import math

import torch
import torch.nn as nn

# Live Coding 시나리오에 따라 Transformer의 핵심 구성 요소를 밑바닥부터 구현합니다.


def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Scaled Dot-Product Attention을 계산합니다.
    Q, K, V는 (batch_size, num_heads, seq_len, d_k) 형태의 텐서입니다.
    """
    # 1. Q와 K의 전치 행렬을 곱합니다. (Attention Score 계산)
    # (batch_size, num_heads, seq_len, d_k) @ (batch_size, num_heads, d_k, seq_len)
    # -> (batch_size, num_heads, seq_len, seq_len)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    # 2. (선택적) 마스킹을 적용합니다.
    if mask is not None:
        # 마스크에서 0인 위치를 매우 작은 값으로 채워 Softmax 이후 해당 위치의 확률이 0에 가깝도록 만듭니다.
        scores = scores.masked_fill(mask == 0, -1e9)

    # 3. Softmax를 적용하여 Attention Weight를 계산합니다.
    attn_weights = torch.softmax(scores, dim=-1)

    # 4. 계산된 Attention Weight와 V를 곱합니다.
    # (batch_size, num_heads, seq_len, seq_len) @ (batch_size, num_heads, seq_len, d_v)
    # -> (batch_size, num_heads, seq_len, d_v)
    output = torch.matmul(attn_weights, v)
    return output, attn_weights


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention 모듈 구현
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model  # 모델의 전체 차원
        self.num_heads = num_heads  # 헤드의 수
        self.d_k = d_model // num_heads  # 각 헤드의 차원

        # Q, K, V 및 최종 출력을 위한 Linear 레이어
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # (batch_size, num_heads, seq_len, d_k) -> (batch_size, seq_len, d_model)
        batch_size, _, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, q, k, v, mask=None):
        # 1. 입력 Q, K, V에 각각 Linear 레이어를 통과시킵니다.
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # 2. 헤드를 여러 개로 나눕니다.
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        # 3. Scaled Dot-Product Attention을 수행합니다.
        attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)

        # 4. 나누었던 헤드를 다시 합칩니다.
        output = self.combine_heads(attn_output)

        # 5. 최종 Linear 레이어를 통과시켜 결과를 반환합니다.
        output = self.w_o(output)
        return output


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network 구현
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(self.relu(self.linear1(x))))


class EncoderBlock(nn.Module):
    """
    Transformer Encoder Block 하나를 구현합니다.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderBlock, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 1. Multi-Head Self-Attention (첫 번째 서브레이어)
        # 잔차 연결(Residual Connection) 및 Layer Normalization
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # 2. Position-wise Feed-Forward Network (두 번째 서브레이어)
        # 잔차 연결(Residual Connection) 및 Layer Normalization
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))

        return x


if __name__ == "__main__":
    # --- 파라미터 설정 ---
    batch_size = 4
    seq_len = 60  # 문장의 최대 길이
    d_model = 512  # 모델의 임베딩 차원
    num_heads = 8  # Multi-Head Attention의 헤드 수
    d_ff = 2048  # Feed-Forward 네트워크의 내부 차원

    # --- 더미 데이터 생성 ---
    # 실제로는 Tokenizer를 통해 얻은 토큰 ID를 Embedding 레이어에 통과시킨 결과입니다.
    dummy_input = torch.rand(batch_size, seq_len, d_model)
    print(f"입력 텐서 형태: {dummy_input.shape}")

    # --- Encoder Block 인스턴스화 및 실행 ---
    encoder_block = EncoderBlock(d_model, num_heads, d_ff)
    output = encoder_block(dummy_input)

    print(f"Encoder Block 출력 텐서 형태: {output.shape}")
    print("\n성공: Encoder Block을 통과한 후 입력과 출력의 텐서 형태가 동일합니다.")
    print(
        "이는 여러 개의 Encoder Block을 쌓아 전체 Encoder를 구성할 수 있음을 의미합니다."
    )
