# Part 7.4 Graph Neural Networks
"""
Part 7.4 그래프 신경망 (GNN) 실습
PyTorch Geometric을 사용한 GNN 기초 실습

이 파일은 다음 내용을 포함합니다:
1. 그래프 데이터 생성 및 시각화
2. Graph Convolutional Network (GCN) 구현
3. Graph Attention Network (GAT) 구현
4. 노드 분류 실습
5. 그래프 분류용 합성 데이터셋 생성
6. 결과 분석 및 시각화
"""

import warnings

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, classification_report

warnings.filterwarnings("ignore")

# PyTorch Geometric이 없는 경우를 위한 대체 구현
try:
    from torch_geometric.data import Data, DataLoader
    from torch_geometric.datasets import Planetoid, TUDataset
    from torch_geometric.nn import GATConv, GCNConv, global_mean_pool

    TORCH_GEOMETRIC_AVAILABLE = True
except ImportError:
    print("PyTorch Geometric이 설치되지 않았습니다. 기본 구현을 사용합니다.")
    TORCH_GEOMETRIC_AVAILABLE = False

# Fix the global RNG seeds so weight initialisation, the random node split in
# main() and the NumPy-driven synthetic graph sampling repeat across runs.
np.random.seed(42)
torch.manual_seed(42)


class SimpleGCNLayer(nn.Module):
    """Minimal dense GCN layer used when PyTorch Geometric is unavailable.

    Computes ``adj_matrix @ (x W + b)``: a learnable linear projection of
    every node's features followed by aggregation through the supplied dense
    adjacency matrix. No normalisation happens here, so callers are expected
    to pass an already-normalised matrix.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Projection W (with bias) applied to each node before aggregation.
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj_matrix):
        """Project ``x`` and aggregate it over ``adj_matrix``.

        Args:
            x: 2-D node feature matrix (rows are nodes).
            adj_matrix: 2-D dense matrix whose column count equals the number
                of rows of ``x``.

        Returns:
            The 2-D tensor ``adj_matrix @ linear(x)``.
        """
        return adj_matrix.mm(self.linear(x))


class GCNModel(nn.Module):
    """Two-layer Graph Convolutional Network for node classification.

    Uses PyTorch Geometric ``GCNConv`` layers when requested *and* the PyG
    import succeeded; otherwise falls back to the dense ``SimpleGCNLayer``.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the hidden layer.
        output_dim: number of classes (size of the log-probability output).
        use_pytorch_geometric: prefer the PyG implementation when available.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, use_pytorch_geometric=True):
        super(GCNModel, self).__init__()
        # Only use PyG when the caller asks for it AND the import succeeded.
        self.use_pytorch_geometric = use_pytorch_geometric and TORCH_GEOMETRIC_AVAILABLE

        # Both backends take (in_dim, out_dim); the conditional expression only
        # evaluates GCNConv when PyG is actually in use.
        layer_cls = GCNConv if self.use_pytorch_geometric else SimpleGCNLayer
        self.conv1 = layer_cls(input_dim, hidden_dim)
        self.conv2 = layer_cls(hidden_dim, output_dim)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index_or_adj):
        """Return per-node class log-probabilities (log-softmax along dim 1).

        Args:
            x: node feature matrix.
            edge_index_or_adj: graph connectivity in whatever form the active
                backend expects. That is an edge_index accepted by ``GCNConv``
                on the PyG path, or a dense normalised adjacency matrix for
                ``SimpleGCNLayer``.
        """
        # Both backends share the call signature conv(x, graph), so a single
        # code path serves either one.
        x = F.relu(self.conv1(x, edge_index_or_adj))
        x = self.dropout(x)
        x = self.conv2(x, edge_index_or_adj)
        return F.log_softmax(x, dim=1)


class GATModel(nn.Module):
    """Graph Attention Network for node classification.

    With PyTorch Geometric available, this is a two-layer GAT:
    - the first layer has ``heads`` concatenated attention heads;
    - the output layer has a single non-concatenated head;
    - dropout 0.6 and ELU sit in between.

    Without PyTorch Geometric, the fallback is a plain two-layer MLP over the
    node features. ``edge_index_or_adj`` is accepted for interface
    compatibility but ignored, so no graph attention is performed.

    Args:
        input_dim: number of input features per node.
        hidden_dim: per-head hidden width (PyG) / hidden width (fallback).
        output_dim: number of classes.
        heads: attention heads in the first GAT layer (unused by the fallback).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, heads=8):
        super(GATModel, self).__init__()
        self.use_pytorch_geometric = TORCH_GEOMETRIC_AVAILABLE

        if self.use_pytorch_geometric:
            self.conv1 = GATConv(input_dim, hidden_dim, heads=heads, dropout=0.6)
            # conv1 concatenates its heads, so conv2 sees hidden_dim * heads inputs.
            self.conv2 = GATConv(
                hidden_dim * heads, output_dim, heads=1, concat=False, dropout=0.6
            )
        else:
            # MLP fallback. No nn.MultiheadAttention(input_dim, heads) is built:
            # forward() would never use it, and its constructor rejects an
            # input_dim not divisible by heads (e.g. 3 features with the
            # default 8 heads), which would make construction crash.
            self.linear1 = nn.Linear(input_dim, hidden_dim)
            self.linear2 = nn.Linear(hidden_dim, output_dim)

        self.dropout = nn.Dropout(0.6)

    def forward(self, x, edge_index_or_adj):
        """Return per-node class log-probabilities (log-softmax along dim 1).

        Args:
            x: node feature matrix.
            edge_index_or_adj: connectivity in a form ``GATConv`` accepts
                (e.g. a ``[2, num_edges]`` edge_index) on the PyG path;
                ignored by the MLP fallback.
        """
        if self.use_pytorch_geometric:
            x = F.dropout(x, p=0.6, training=self.training)
            x = self.conv1(x, edge_index_or_adj)
            x = F.elu(x)
            x = F.dropout(x, p=0.6, training=self.training)
            x = self.conv2(x, edge_index_or_adj)
        else:
            # Graph-agnostic MLP: only the node features are used.
            x = self.linear1(x)
            x = F.relu(x)
            x = self.dropout(x)
            x = self.linear2(x)

        return F.log_softmax(x, dim=1)


class GraphDataGenerator:
    """Factory for the toy graph datasets used in this walkthrough."""

    def __init__(self):
        pass

    def create_karate_club_graph(self):
        """Build Zachary's Karate Club graph as dense tensors for node classification.

        Returns:
            A 4-tuple ``(features, labels, adj_matrix, G)``:

            - features: FloatTensor with one row per node,
              ``[degree, clustering coefficient, degree * clustering]``.
            - labels: LongTensor, 1 where the node's ``"club"`` attribute is
              ``"Mr. Hi"``, else 0.
            - adj_matrix: FloatTensor adjacency with self-loops added,
              row-normalised (``D^-1 (A + I)``).
            - G: the underlying networkx graph.

            Rows and columns follow ``G.nodes()`` order.
        """
        G = nx.karate_club_graph()
        nodes = list(G.nodes())

        # Node features: [degree, clustering coefficient, degree * clustering].
        degrees = dict(G.degree())
        clustering = nx.clustering(G)  # computed for all nodes in one call
        features = [
            [degrees[node], clustering[node], degrees[node] * clustering[node]]
            for node in nodes
        ]
        # Binary target: 1 for members of Mr. Hi's club, 0 otherwise.
        labels = [G.nodes[node]["club"] == "Mr. Hi" for node in nodes]

        # Dense adjacency as a plain ndarray. adjacency_matrix(G).todense()
        # returns np.matrix with older networkx/SciPy, and then
        # np.diag(np.sum(..., axis=1)) extracts a diagonal instead of building
        # a diagonal matrix, which breaks the normalisation below.
        adj_matrix = nx.to_numpy_array(G, nodelist=nodes)

        # Add self-loops, then row-normalise: D^-1 (A + I). Dividing each row
        # by its sum is equivalent to left-multiplying by inv(D) without
        # forming or inverting D.
        adj_matrix = adj_matrix + np.eye(len(nodes))
        adj_matrix = adj_matrix / adj_matrix.sum(axis=1, keepdims=True)

        return (
            torch.FloatTensor(features),
            torch.LongTensor(labels),
            torch.FloatTensor(adj_matrix),
            G,
        )

    def create_synthetic_graphs(self, num_graphs=100):
        """Create a balanced synthetic dataset for graph classification.

        Graphs at even indices are random trees (label 0). Graphs at odd
        indices are cycles with up to two random extra edges (label 1). Each
        graph has 10-19 nodes and a node feature matrix of
        ``[degree, degree**2]`` per node. All randomness is drawn from
        ``np.random``, so seeding NumPy makes the dataset reproducible.

        Args:
            num_graphs: number of graphs to generate.

        Returns:
            ``(graphs, labels)``, where ``graphs`` is a list of
            ``(FloatTensor features, networkx graph)`` pairs and ``labels`` is
            a LongTensor of class labels.
        """
        # networkx 3.2 deprecated random_tree in favour of random_labeled_tree
        # (random_tree is slated for removal); use whichever is available.
        make_random_tree = getattr(nx, "random_labeled_tree", None) or nx.random_tree

        graphs = []
        labels = []

        for i in range(num_graphs):
            num_nodes = np.random.randint(10, 20)
            if i % 2 == 0:
                # Class 0: tree. The generator is seeded from np.random so the
                # global NumPy seed makes it reproducible.
                G = make_random_tree(num_nodes, seed=int(np.random.randint(2**31 - 1)))
                label = 0
            else:
                # Class 1: cycle plus up to two random chords (a chord between
                # already-adjacent nodes is a no-op).
                G = nx.cycle_graph(num_nodes)
                nodes = list(G.nodes())
                for _ in range(np.random.randint(0, 3)):
                    u, v = np.random.choice(nodes, 2, replace=False)
                    G.add_edge(int(u), int(v))
                label = 1

            features = [[G.degree(node), G.degree(node) ** 2] for node in G.nodes()]
            graphs.append((torch.FloatTensor(features), G))
            labels.append(label)

        return graphs, torch.LongTensor(labels)


class GNNTrainer:
    """Train and evaluate a GNN on a single-graph node-classification task.

    Holds an Adam optimiser (weight decay 5e-4) over the model's parameters.
    It records the per-epoch training loss and training accuracy in
    ``train_losses`` and ``train_accuracies``.
    """

    def __init__(self, model, learning_rate=0.01):
        params = model.parameters()
        self.model = model
        self.optimizer = optim.Adam(params, lr=learning_rate, weight_decay=5e-4)
        # Per-epoch history, appended to by train_node_classification().
        self.train_losses, self.train_accuracies = [], []

    def train_node_classification(
        self, features, labels, adj_matrix, train_mask, val_mask, epochs=200
    ):
        """Full-batch training loop for node classification.

        Each epoch does the following:
        - takes one optimiser step on the NLL loss of the ``train_mask`` nodes;
        - re-runs the model in eval mode without gradients to measure train
          and validation accuracy;
        - restores train mode.

        Progress is printed every 50 epochs.
        """
        model = self.model
        model.train()

        for epoch_idx in range(epochs):
            self.optimizer.zero_grad()
            log_probs = model(features, adj_matrix)
            train_loss = F.nll_loss(log_probs[train_mask], labels[train_mask])
            train_loss.backward()
            self.optimizer.step()

            # Measure accuracy with dropout disabled, then switch back.
            with torch.no_grad():
                model.eval()
                predicted = model(features, adj_matrix).argmax(dim=1)
                train_acc = accuracy_score(
                    labels[train_mask].cpu(), predicted[train_mask].cpu()
                )
                val_acc = accuracy_score(
                    labels[val_mask].cpu(), predicted[val_mask].cpu()
                )
                model.train()

            loss_value = train_loss.item()
            self.train_losses.append(loss_value)
            self.train_accuracies.append(train_acc)

            if (epoch_idx + 1) % 50 == 0:
                print(
                    f"Epoch [{epoch_idx+1}/{epochs}], "
                    f"Loss: {loss_value:.4f}, "
                    f"Train Acc: {train_acc:.4f}, "
                    f"Val Acc: {val_acc:.4f}"
                )

    def evaluate(self, features, labels, adj_matrix, test_mask):
        """Print test accuracy and a classification report; return all predictions.

        Returns:
            The argmax class prediction for every node (not only test nodes).
        """
        self.model.eval()
        with torch.no_grad():
            pred = self.model(features, adj_matrix).argmax(dim=1)
            y_true = labels[test_mask].cpu()
            y_pred = pred[test_mask].cpu()

            test_acc = accuracy_score(y_true, y_pred)
            print(f"테스트 정확도: {test_acc:.4f}")
            print("\n분류 리포트:")
            print(classification_report(y_true, y_pred))

        return pred


def visualize_graph(G, labels=None, node_embeddings=None, title="Graph Visualization"):
    """Draw ``G`` with a seeded spring layout and show the figure.

    Args:
        G: networkx graph to draw.
        labels: optional per-node labels (presumably in ``G.nodes()`` order).
            Nodes whose label equals 1 are red and all others blue; without
            labels every node is light blue.
        node_embeddings: accepted but currently unused.
        title: figure title.
    """
    if labels is None:
        node_colors = "lightblue"
    else:
        node_colors = [("red" if lbl == 1 else "blue") for lbl in labels]

    plt.figure(figsize=(12, 8))
    layout = nx.spring_layout(G, seed=42)

    draw_options = dict(
        node_color=node_colors,
        with_labels=True,
        node_size=300,
        font_size=8,
        font_weight="bold",
    )
    nx.draw(G, layout, **draw_options)

    plt.title(title)
    plt.axis("off")
    plt.show()


def visualize_embeddings(embeddings, labels, title="Node Embeddings"):
    """Project embeddings to 2-D with t-SNE and show a scatter plot coloured by label.

    Args:
        embeddings: tensor with one embedding row per node (detached and
            moved to CPU before projection).
        labels: tensor of per-node labels used as scatter colours.
        title: figure title.
    """
    points = embeddings.detach().cpu().numpy()
    n_samples = points.shape[0]

    # scikit-learn's TSNE requires perplexity < n_samples. Its default of 30
    # fails on graphs with 30 or fewer nodes, so clamp it; larger inputs keep
    # the default.
    perplexity = min(30.0, float(max(n_samples - 1, 1)))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    embeddings_2d = tsne.fit_transform(points)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=labels.cpu().numpy(),
        cmap="viridis",
        alpha=0.7,
    )
    plt.colorbar(scatter)
    plt.title(title)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.show()


def main():
    """Run the GNN walkthrough end to end.

    Steps:
    1. Build and draw the Karate Club graph.
    2. Make a random (about 60/20/20) train/val/test node split.
    3. Train and evaluate a dense fallback GCN.
    4. If PyTorch Geometric is installed, train and evaluate a GAT.
    5. Plot the GCN training curves.
    6. Visualise the first-layer GCN node embeddings with t-SNE.
    """
    print("=== 그래프 신경망 (GNN) 실습 시작 ===\n")

    # 1. Data
    print("1. 그래프 데이터 생성 중...")
    data_gen = GraphDataGenerator()
    features, labels, adj_matrix, G = data_gen.create_karate_club_graph()
    num_classes = len(torch.unique(labels))

    print(f"   - 노드 수: {features.shape[0]}")
    print(f"   - 특성 차원: {features.shape[1]}")
    print(f"   - 클래스 수: {num_classes}")

    # 2. Draw the raw graph
    print("\n2. 원본 그래프 시각화...")
    visualize_graph(G, labels.numpy(), title="Karate Club Graph")

    # 3. Random train / validation / test split (the remainder goes to test)
    num_nodes = features.shape[0]
    indices = torch.randperm(num_nodes)

    train_size = int(0.6 * num_nodes)
    val_size = int(0.2 * num_nodes)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[indices[:train_size]] = True
    val_mask[indices[train_size : train_size + val_size]] = True
    test_mask[indices[train_size + val_size :]] = True

    # 4. GCN training (dense fallback layers operating on adj_matrix)
    print("\n3. GCN 모델 학습 중...")
    gcn_model = GCNModel(
        input_dim=features.shape[1],
        hidden_dim=16,
        output_dim=num_classes,
        use_pytorch_geometric=False,
    )

    gcn_trainer = GNNTrainer(gcn_model)
    gcn_trainer.train_node_classification(
        features, labels, adj_matrix, train_mask, val_mask, epochs=200
    )

    # 5. GCN evaluation
    print("\n4. GCN 모델 평가...")
    gcn_trainer.evaluate(features, labels, adj_matrix, test_mask)

    # 6. GAT training (only when PyTorch Geometric is available)
    if TORCH_GEOMETRIC_AVAILABLE:
        print("\n5. GAT 모델 학습 중...")
        gat_model = GATModel(
            input_dim=features.shape[1],
            hidden_dim=8,
            output_dim=num_classes,
        )

        gat_trainer = GNNTrainer(gat_model)
        # GATConv needs a [2, num_edges] LongTensor edge_index, not the dense
        # float adjacency used by the fallback GCN. Every non-zero entry of
        # adj_matrix (including the added self-loops) becomes an edge;
        # GATConv manages self-loops itself by default.
        edge_index = torch.nonzero(adj_matrix > 0).t().contiguous()
        gat_trainer.train_node_classification(
            features, labels, edge_index, train_mask, val_mask, epochs=100
        )

        print("\n6. GAT 모델 평가...")
        gat_trainer.evaluate(features, labels, edge_index, test_mask)

    # 7. GCN learning curves
    print("\n7. 학습 곡선 시각화...")
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(gcn_trainer.train_losses)
    plt.title("GCN Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(gcn_trainer.train_accuracies)
    plt.title("GCN Training Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # 8. Node embeddings: first-layer GCN output (before ReLU / the last layer)
    print("\n8. 노드 임베딩 시각화...")
    gcn_model.eval()
    with torch.no_grad():
        embeddings = gcn_model.conv1(features, adj_matrix)
        visualize_embeddings(embeddings, labels, "GCN Node Embeddings")

    print("\n=== 실습 완료 ===")

if __name__ == "__main__":
    main()
