# Part 7.2 Transformer and LLM Principles
"""
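Part 7.2: Transformer and LLM principles - hands-on exercise.
Implements a simple Transformer architecture and visualizes its attention mechanism.

This file covers:
1. Multi-Head Attention implementation
2. Positional encoding implementation
3. Transformer encoder implementation
4. Transformer decoder implementation (note: no decoder class is defined in this file; only the encoder stack is implemented)
5. Training a simple language model
6. Attention mechanism visualization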
"""

import math
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Suppress every warning for the rest of the process (presumably to keep the
# tutorial output uncluttered; note this also hides deprecation warnings).
warnings.filterwarnings("ignore")

# Seed the global PyTorch and NumPy random number generators with fixed values.
torch.manual_seed(42)
np.random.seed(42)


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first inputs.

    A table of shape ``(max_len, 1, d_model)`` is computed once and stored as
    the non-trainable buffer ``pe``. Even feature columns hold
    ``sin(pos * w_i)`` and odd feature columns hold ``cos(pos * w_i)`` with
    ``w_i = exp(-ln(10000) * 2i / d_model)``.

    Args:
        d_model: Size of the feature dimension. Odd sizes are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # One column of angles per even feature index: (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # surplus frequency is dropped instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add a size-1 middle axis so the table broadcasts over the batch
        # dimension of a (seq_len, batch, d_model) input.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding of positions ``0 .. x.size(0) - 1`` to ``x``.

        Args:
            x: Tensor whose first dimension indexes the sequence position,
                e.g. ``(seq_len, batch, d_model)``.

        Returns:
            ``x`` plus the first ``x.size(0)`` rows of the encoding table,
            broadcast over the second dimension.
        """
        return x + self.pe[: x.size(0)]


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first tensors.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    After every call, ``attention_weights`` holds the detached softmax weights
    of that call with shape ``(batch, num_heads, query_len, key_len)`` so they
    can be plotted later. It is ``None`` until the first call.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention_weights = None  # kept for visualization

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

        Args:
            Q: Queries; the last dimension is the per-head feature size.
            K: Keys, same layout as ``Q``.
            V: Values, same layout as ``K``.
            mask: Optional tensor broadcastable to the score tensor. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attention_weights)``.
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A fixed -1e9 is not representable in float16 and makes
            # masked_fill raise when the model runs in half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        self.attention_weights = attention_weights.detach()  # saved for visualization

        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, d_k)``."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, query_len, d_model)``.
            key: Tensor of shape ``(batch, key_len, d_model)``.
            value: Tensor of shape ``(batch, key_len, d_model)``.
            mask: Optional mask broadcastable to
                ``(batch, num_heads, query_len, key_len)``; zeros mark
                positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = query.size(0)

        # Project the inputs and split the feature dimension into heads.
        Q = self._split_heads(self.W_q(query), batch_size)
        K = self._split_heads(self.W_k(key), batch_size)
        V = self._split_heads(self.W_v(value), batch_size)

        attention_output, _ = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        attention_output = (
            attention_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )

        return self.W_o(attention_output)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)


class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer output passes through dropout, is added back to its input,
    and the sum is layer-normalized.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability used on both sub-layer outputs and inside
            the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, residual, sublayer_output, norm):
        """Return ``norm(residual + dropout(sublayer_output))``."""
        return norm(residual + self.dropout(sublayer_output))

    def forward(self, x, mask=None):
        """Run both sub-layers; ``mask`` is forwarded to the attention module."""
        # Queries, keys and values all come from the same tensor.
        attended = self.self_attention(x, x, x, mask)
        x = self._add_and_norm(x, attended, self.norm1)

        transformed = self.feed_forward(x)
        return self._add_and_norm(x, transformed, self.norm2)


class SimpleTransformer(nn.Module):
    """Small encoder-only Transformer that maps token ids to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        d_ff: Hidden size of each feed-forward sub-layer.
        max_len: Number of positions the positional encoding precomputes.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        vocab_size,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048,
        max_len=5000,
        dropout=0.1,
    ):
        super().__init__()

        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ]
        )

        self.output_layer = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Compute per-position vocabulary logits.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.
            mask: Optional attention mask passed unchanged to every encoder
                layer; zeros mark positions that must not be attended to.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        # Scale embeddings by sqrt(d_model) before adding positions.
        x = self.embedding(x) * math.sqrt(self.d_model)

        # PositionalEncoding slices its table along dim 0, i.e. it expects a
        # sequence-first tensor. The embeddings are batch-first, so swap the
        # axes around the call; otherwise the encoding would be selected by
        # batch index and every token in a sequence would get the same offset.
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1).contiguous()
        x = self.dropout(x)

        for layer in self.encoder_layers:
            x = layer(x, mask)

        # Project hidden states to vocabulary logits.
        return self.output_layer(x)


class SimpleTextDataset(Dataset):
    """Character-level next-token dataset built from a single string.

    Every distinct character of ``text`` becomes one vocabulary entry (ids are
    assigned in sorted character order). Item ``i`` is the pair
    ``(data[i : i + seq_length], data[i + 1 : i + seq_length + 1])``, i.e. the
    target is the input shifted one character to the right.

    Args:
        text: Source text.
        seq_length: Number of characters per training window.
    """

    def __init__(self, text, seq_length=50):
        self.seq_length = seq_length

        # Simple character-level tokenizer.
        self.chars = sorted(set(text))
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))

        self.vocab_size = len(self.chars)
        self.data = [self.char_to_idx[ch] for ch in text]

    def __len__(self):
        # Clamp at zero: a text shorter than seq_length simply has no windows,
        # and len() raises if __len__ returns a negative number.
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors for window ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Without this check an out-of-range index would silently return
            # truncated windows whose input and target lengths differ.
            raise IndexError("SimpleTextDataset index out of range")

        # One slice of seq_length + 1 ids yields both halves of the pair.
        window = self.data[idx : idx + self.seq_length + 1]
        input_seq = torch.tensor(window[:-1])
        target_seq = torch.tensor(window[1:])
        return input_seq, target_seq


class TransformerTrainer:
    """Train a token-level language model and sample text from it.

    Args:
        model: Module called as ``model(input_ids)`` or
            ``model(input_ids, mask)`` that returns logits of shape
            ``(batch, seq_len, vocab_size)``.
        vocab_size: Number of classes in the logits.
        learning_rate: Learning rate for the Adam optimizer.
        causal: When true (default), every forward pass receives a
            lower-triangular mask so position ``t`` attends only to positions
            ``<= t``. Without it the next-token target at position ``t`` is
            directly visible as the input at position ``t + 1``. Pass
            ``False`` for models that do not accept a mask argument.
    """

    def __init__(self, model, vocab_size, learning_rate=0.0001, causal=True):
        self.model = model
        self.vocab_size = vocab_size
        self.causal = causal
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.train_losses = []  # average loss of each finished epoch

    def _forward(self, input_seq):
        """Run the model, adding a causal mask when ``self.causal`` is set."""
        if not self.causal:
            return self.model(input_seq)
        seq_len = input_seq.size(1)
        # Ones on and below the diagonal mark the positions that may be seen.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=input_seq.device))
        return self.model(input_seq, mask)

    def train(self, dataloader, epochs=10):
        """Train for ``epochs`` passes and record each epoch's mean batch loss.

        Args:
            dataloader: Iterable of ``(input_seq, target_seq)`` batches.
            epochs: Number of passes over ``dataloader``.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.model.train()

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0
            for input_seq, target_seq in dataloader:
                self.optimizer.zero_grad()

                output = self._forward(input_seq)
                # reshape (not view) also works for non-contiguous tensors.
                loss = self.criterion(
                    output.reshape(-1, self.vocab_size), target_seq.reshape(-1)
                )

                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                num_batches += 1

            if num_batches == 0:
                raise ValueError("dataloader produced no batches")

            avg_loss = epoch_loss / num_batches
            self.train_losses.append(avg_loss)

            if (epoch + 1) % 2 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

    def generate_text(self, dataset, start_str="Hello", length=100):
        """Sample ``length`` characters that continue ``start_str``.

        Args:
            dataset: Object providing ``char_to_idx``, ``idx_to_char`` and
                ``seq_length``.
            start_str: Non-empty prompt; every character must be a key of
                ``dataset.char_to_idx``.
            length: Number of characters to append.

        Returns:
            ``start_str`` followed by the sampled characters.

        Raises:
            ValueError: If ``start_str`` is empty.
            KeyError: If ``start_str`` contains a character missing from
                ``dataset.char_to_idx``.
        """
        if not start_str:
            raise ValueError("start_str must contain at least one character")

        self.model.eval()

        input_seq = [dataset.char_to_idx[ch] for ch in start_str]
        pieces = [start_str]

        with torch.no_grad():
            for _ in range(length):
                # Feed at most the last seq_length tokens as context.
                context = input_seq[-dataset.seq_length :]
                input_tensor = torch.tensor(context).unsqueeze(0)

                output = self._forward(input_tensor)
                # Sample from the distribution predicted at the last position.
                probabilities = F.softmax(output[0, -1], dim=-1)
                next_char_idx = torch.multinomial(probabilities, 1).item()

                pieces.append(dataset.idx_to_char[next_char_idx])
                input_seq.append(next_char_idx)

        return "".join(pieces)


def visualize_attention(model, input_text, dataset, layer_idx=0):
    """Plot head 0 attention weights of one encoder layer for ``input_text``.

    The text is encoded with ``dataset.char_to_idx``, pushed through ``model``
    in eval mode without gradients, and the weights stored on
    ``model.encoder_layers[layer_idx].self_attention`` are drawn as a heatmap
    for the first batch element and first head.

    Returns:
        The plotted weight matrix as a NumPy array, or ``None`` when the layer
        has no stored weights.
    """
    model.eval()

    # Encode the characters as a batch containing a single sequence.
    token_ids = torch.tensor([dataset.char_to_idx[ch] for ch in input_text])
    batch = token_ids.unsqueeze(0)

    # The forward pass is run only for its side effect of storing the weights.
    with torch.no_grad():
        model(batch)

    stored = model.encoder_layers[layer_idx].self_attention.attention_weights
    if stored is None:
        print("Attention weights not available")
        return None

    # First batch element, first head.
    head_matrix = stored[0, 0].cpu().numpy()
    tick_labels = list(input_text)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        head_matrix,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        cmap="Blues",
        cbar=True,
    )
    plt.title(f"Attention Weights - Layer {layer_idx}, Head 0")
    plt.xlabel("Key Position")
    plt.ylabel("Query Position")
    plt.show()

    return head_matrix


def plot_positional_encoding(d_model=512, max_len=100):
    """Show the positional encoding as line curves and as a heatmap.

    Args:
        d_model: Feature size of the encoding to plot.
        max_len: Number of positions to compute.
    """
    encoder = PositionalEncoding(d_model, max_len)

    # Adding the encoding to an all-zero input leaves just the encoding itself.
    zeros = torch.zeros(max_len, 1, d_model)
    values = encoder(zeros).detach().numpy()

    plt.figure(figsize=(15, 5))

    # Left panel: the first eight feature dimensions as curves over position.
    plt.subplot(1, 2, 1)
    plt.plot(values[:, 0, 0:8])
    plt.title("Positional Encoding (first 8 dimensions)")
    plt.xlabel("Position")
    plt.ylabel("Encoding Value")
    plt.legend([f"dim {i}" for i in range(8)])

    # Right panel: heatmap of the first 50 positions and 50 dimensions,
    # transposed so that position runs along the x axis.
    plt.subplot(1, 2, 2)
    heat = values[:50, 0, :50].T
    plt.imshow(heat, cmap="RdYlBu", aspect="auto")
    plt.title("Positional Encoding Heatmap")
    plt.xlabel("Position")
    plt.ylabel("Encoding Dimension")
    plt.colorbar()

    plt.tight_layout()
    plt.show()


def main():
    """Run the whole exercise: plots, data, model, training, sampling."""
    seq_length = 30
    print("=== Transformer 및 LLM 원리 실습 시작 ===\n")

    # Step 1: positional encoding plots.
    print("1. 위치 인코딩 시각화...")
    plot_positional_encoding()

    # Step 2: build the character-level dataset from a small sample text.
    print("2. 데이터 준비 중...")
    sample_text = """
    Hello world! This is a simple example of transformer model training.
    The transformer architecture uses self-attention mechanisms to process sequences.
    It has revolutionized natural language processing and deep learning.
    """

    dataset = SimpleTextDataset(sample_text, seq_length=seq_length)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    print(f"   - 어휘 크기: {dataset.vocab_size}")
    print(f"   - 데이터 길이: {len(dataset)}")

    # Step 3: create a small model.
    print("\n3. Transformer 모델 생성 중...")
    model = SimpleTransformer(
        vocab_size=dataset.vocab_size,
        d_model=128,
        num_heads=4,
        num_layers=2,
        d_ff=256,
        dropout=0.1,
    )

    print(f"   - 모델 파라미터 수: {sum(p.numel() for p in model.parameters()):,}")

    # Step 4: train.
    print("\n4. 모델 학습 중...")
    trainer = TransformerTrainer(model, dataset.vocab_size)
    trainer.train(loader, epochs=10)

    # Step 5: sample text from the trained model.
    print("\n5. 텍스트 생성 중...")
    generated_text = trainer.generate_text(dataset, start_str="Hello", length=100)
    print(f"생성된 텍스트:\n{generated_text}")

    # Step 6: attention heatmap, only for text that fits the training window.
    print("\n6. Attention 메커니즘 시각화...")
    probe_text = "Hello world transformer"
    if len(probe_text) <= seq_length:
        visualize_attention(model, probe_text, dataset)

    # Step 7: training loss curve.
    print("\n7. 학습 곡선 시각화...")
    plt.figure(figsize=(10, 5))
    plt.plot(trainer.train_losses)
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()

    print("\n=== 실습 완료 ===")


if __name__ == "__main__":
    main()
