# Part 7.1 Recurrent Neural Networks
"""
Part 7.1 순환 신경망 (RNN) 실습
LSTM과 GRU를 활용한 시계열 데이터 예측

이 파일은 다음 내용을 포함합니다:
1. 시계열 데이터 생성 및 전처리
2. LSTM 모델 구현
3. GRU 모델 구현
4. 모델 학습 및 평가
5. 결과 비교 및 시각화
"""

import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.preprocessing import MinMaxScaler

# Silence every warning category for the whole process
# (presumably to keep the tutorial's console output uncluttered).
warnings.filterwarnings("ignore")

# Seed PyTorch's and NumPy's global random number generators with a fixed value.
torch.manual_seed(42)
np.random.seed(42)


class TimeSeriesData:
    """Generate a synthetic noisy time series and turn it into model-ready tensors.

    Attributes:
        seq_length: Number of consecutive points in each input window.
        scaler: ``MinMaxScaler`` fitted by :meth:`prepare_data`; callers reuse it
            to map scaled values back to the original range.
    """

    def __init__(self, seq_length=50):
        self.seq_length = seq_length
        self.scaler = MinMaxScaler()

    def generate_sine_wave_data(self, n_samples=1000):
        """Return ``n_samples`` points of a composite sine wave with Gaussian noise.

        The signal is ``sin(t) + 0.5*sin(3t) + 0.3*sin(5t)`` sampled evenly on
        ``t`` in [0, 100], plus noise of standard deviation 0.1 drawn from
        NumPy's global random state, so every call yields a new realisation.
        """
        t = np.linspace(0, 100, n_samples)
        # Composite sine wave + noise
        data = (
            np.sin(t)
            + 0.5 * np.sin(3 * t)
            + 0.3 * np.sin(5 * t)
            + 0.1 * np.random.randn(n_samples)
        )
        return data

    def create_sequences(self, data):
        """Slice a series into sliding input windows and next-step targets.

        Args:
            data: Array-like series, indexed along its first axis.

        Returns:
            Tuple ``(X, y)``. For 1-D input ``X`` has shape ``(n, seq_length)``
            and ``y`` has shape ``(n,)`` where
            ``n = max(len(data) - seq_length, 0)``; ``y[i]`` is the value that
            immediately follows window ``X[i]``. A series that is too short
            yields empty arrays that still carry the window dimension.
        """
        data = np.asarray(data)
        n_sequences = max(len(data) - self.seq_length, 0)

        # Row i of the index matrix is [i, i + 1, ..., i + seq_length - 1], so a
        # single fancy-indexing step gathers every window without a Python loop.
        window_index = np.arange(n_sequences)[:, None] + np.arange(self.seq_length)
        X = data[window_index]
        # Copy so the targets do not alias the caller's array.
        y = data[self.seq_length : self.seq_length + n_sequences].copy()
        return X, y

    def prepare_data(self, train_ratio=0.8, n_samples=1000):
        """Generate the series, scale it, window it and split it chronologically.

        The scaler is fitted only on the raw points covered by the training
        windows and their targets, so statistics of the held-out tail do not
        leak into preprocessing. Scaled test values may therefore fall slightly
        outside [0, 1]. When the split produces no training windows the scaler
        is fitted on the whole series instead.

        Args:
            train_ratio: Fraction of the windows (earliest first) used for training.
            n_samples: Length of the generated series.

        Returns:
            ``(X_train, X_test, y_train, y_test, raw_data)`` where the ``X``
            tensors have shape ``(batch, seq_length, 1)``, the ``y`` tensors
            have shape ``(batch, 1)``, all are float32 and hold scaled values,
            and ``raw_data`` is the unscaled NumPy series.
        """
        # Generate the raw series
        raw_data = self.generate_sine_wave_data(n_samples)
        raw_column = raw_data.reshape(-1, 1)

        # Chronological split point, expressed in number of windows
        n_sequences = max(len(raw_data) - self.seq_length, 0)
        train_size = int(n_sequences * train_ratio)

        # Training windows 0 .. train_size - 1 together with their targets touch
        # raw indices [0, train_size + seq_length); fit the scaler on those only.
        if train_size > 0:
            fit_rows = raw_column[: train_size + self.seq_length]
        else:
            fit_rows = raw_column
        self.scaler.fit(fit_rows)
        scaled_data = self.scaler.transform(raw_column).flatten()

        # Build the sliding windows
        X, y = self.create_sequences(scaled_data)

        # Train/test split (no shuffling: order matters for time series)
        X_train, X_test = X[:train_size], X[train_size:]
        y_train, y_test = y[:train_size], y[train_size:]

        # Convert to tensors and append the trailing feature dimension
        X_train = torch.FloatTensor(X_train).unsqueeze(-1)  # (batch, seq_len, features)
        X_test = torch.FloatTensor(X_test).unsqueeze(-1)
        y_train = torch.FloatTensor(y_train).unsqueeze(-1)
        y_test = torch.FloatTensor(y_test).unsqueeze(-1)

        return X_train, X_test, y_train, y_test, raw_data


class LSTMModel(nn.Module):
    """Stacked LSTM with a linear head that predicts from the final time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values predicted per sequence.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1):
        super().__init__()
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # The recurrent stack is built before the head, as in the original,
        # so parameter initialisation consumes the RNG in the same order.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map a ``(batch, seq_len, input_size)`` tensor to ``(batch, output_size)``."""
        sequence_features, _state = self.lstm(x)
        # Only the representation at the last time step feeds the linear head.
        final_step = sequence_features[:, -1, :]
        return self.fc(final_step)


class GRUModel(nn.Module):
    """Stacked GRU with a linear head that predicts from the final time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each GRU layer's hidden state.
        num_layers: Number of stacked GRU layers.
        output_size: Number of values predicted per sequence.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1):
        super().__init__()
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # The recurrent stack is built before the head, as in the original,
        # so parameter initialisation consumes the RNG in the same order.
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map a ``(batch, seq_len, input_size)`` tensor to ``(batch, output_size)``."""
        sequence_features, _state = self.gru(x)
        # Only the representation at the last time step feeds the linear head.
        final_step = sequence_features[:, -1, :]
        return self.fc(final_step)


class ModelTrainer:
    """Train a regression model with Adam on MSE loss and evaluate it.

    Attributes:
        model: The ``nn.Module`` being trained.
        criterion: ``nn.MSELoss`` instance used as the training objective.
        optimizer: Adam optimizer over ``model.parameters()``.
        train_losses: Average batch loss of every epoch run so far.
    """

    def __init__(self, model, learning_rate=0.001):
        self.model = model
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.train_losses = []

    def train(self, X_train, y_train, epochs=100, batch_size=32):
        """Run mini-batch training and record the mean batch loss per epoch.

        Batches are taken in order, without shuffling. Progress is printed
        every 20 epochs.

        Args:
            X_train: Input tensor whose first dimension indexes samples.
            y_train: Target tensor aligned with ``X_train``.
            epochs: Number of passes over the training data.
            batch_size: Maximum number of samples per optimisation step.
        """
        self.model.train()
        n_samples = len(X_train)

        for epoch in range(epochs):
            epoch_loss = 0.0
            num_batches = 0
            for start in range(0, n_samples, batch_size):
                batch_X = X_train[start : start + batch_size]
                batch_y = y_train[start : start + batch_size]

                self.optimizer.zero_grad()
                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                num_batches += 1

            # Divide by the batches actually processed. The former
            # ``len // batch_size + 1`` divisor was one too large whenever the
            # sample count was an exact multiple of ``batch_size``. ``max`` keeps
            # the empty-dataset case at 0.0 instead of dividing by zero.
            avg_loss = epoch_loss / max(num_batches, 1)
            self.train_losses.append(avg_loss)

            if (epoch + 1) % 20 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.6f}")

    def evaluate(self, X_test, y_test, scaler):
        """Predict on the test set and score the predictions in original units.

        The model is switched to evaluation mode and left in that mode.

        Args:
            X_test: Input tensor passed to the model in a single forward call.
            y_test: Scaled target tensor aligned with ``X_test``.
            scaler: Fitted scaler whose ``inverse_transform`` undoes the scaling
                applied to the targets.

        Returns:
            ``(predictions, targets, metrics)`` where the first two are NumPy
            arrays after inverse scaling and ``metrics`` is a dict with keys
            ``"MSE"``, ``"MAE"`` and ``"RMSE"``.
        """
        self.model.eval()
        with torch.no_grad():
            predictions = self.model(X_test)

        # Undo the scaling so errors are reported in the series' own units
        predictions_np = scaler.inverse_transform(predictions.cpu().numpy())
        y_test_np = scaler.inverse_transform(y_test.cpu().numpy())

        # Compute the evaluation metrics
        mse = mean_squared_error(y_test_np, predictions_np)
        mae = mean_absolute_error(y_test_np, predictions_np)
        rmse = np.sqrt(mse)

        return predictions_np, y_test_np, {"MSE": mse, "MAE": mae, "RMSE": rmse}


def _plot_prediction_panel(ax, y_true, predictions, model_name, rmse):
    """Draw the first 100 actual and predicted values for one model on ``ax``."""
    ax.plot(y_true[:100], label="Actual", alpha=0.7)
    ax.plot(predictions[:100], label=f"{model_name} Prediction", alpha=0.7)
    ax.set_title(f"{model_name} Predictions (RMSE: {rmse:.4f})")
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    ax.legend()


def visualize_results(
    data_gen,
    lstm_pred,
    gru_pred,
    y_true,
    lstm_metrics,
    gru_metrics,
    raw_data=None,
):
    """Show a 2x2 figure comparing LSTM and GRU forecasts.

    Panels: the first 200 points of the raw series, LSTM predictions against
    actual values, GRU predictions against actual values, and a grouped bar
    chart of the MSE / MAE / RMSE metrics.

    Args:
        data_gen: ``TimeSeriesData`` instance; only used to generate a series
            when ``raw_data`` is not supplied.
        lstm_pred: LSTM predictions in original units.
        gru_pred: GRU predictions in original units.
        y_true: Actual target values in original units.
        lstm_metrics: Dict with keys ``"MSE"``, ``"MAE"``, ``"RMSE"`` for the LSTM.
        gru_metrics: Same as ``lstm_metrics`` for the GRU.
        raw_data: The unscaled series the models were trained and tested on.
            Pass it so the first panel shows that exact series; if omitted, a
            fresh series with new random noise is generated instead.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Raw series
    if raw_data is None:
        # Backward-compatible fallback: this draws a NEW noise realisation and
        # advances NumPy's global RNG, so it is not the data the models saw.
        raw_data = data_gen.generate_sine_wave_data()
    axes[0, 0].plot(raw_data[:200])
    axes[0, 0].set_title("Original Time Series Data")
    axes[0, 0].set_xlabel("Time")
    axes[0, 0].set_ylabel("Value")

    # Predictions vs. actual values, one panel per model
    _plot_prediction_panel(axes[0, 1], y_true, lstm_pred, "LSTM", lstm_metrics["RMSE"])
    _plot_prediction_panel(axes[1, 0], y_true, gru_pred, "GRU", gru_metrics["RMSE"])

    # Metric comparison as side-by-side bars
    metrics = ["MSE", "MAE", "RMSE"]
    lstm_values = [lstm_metrics[m] for m in metrics]
    gru_values = [gru_metrics[m] for m in metrics]

    positions = list(range(len(metrics)))
    width = 0.35
    comparison_ax = axes[1, 1]
    comparison_ax.bar(
        [p - width / 2 for p in positions], lstm_values, width, label="LSTM", alpha=0.7
    )
    comparison_ax.bar(
        [p + width / 2 for p in positions], gru_values, width, label="GRU", alpha=0.7
    )
    comparison_ax.set_title("Model Performance Comparison")
    comparison_ax.set_xlabel("Metrics")
    comparison_ax.set_ylabel("Error Value")
    comparison_ax.set_xticks(positions)
    comparison_ax.set_xticklabels(metrics)
    comparison_ax.legend()

    plt.tight_layout()
    plt.show()


def main():
    """Run the demo end to end: prepare data, train both models, evaluate, plot."""
    print("=== RNN 시계열 예측 실습 시작 ===\n")

    # Step 1: data preparation
    print("1. 데이터 준비 중...")
    data_gen = TimeSeriesData(seq_length=50)
    X_train, X_test, y_train, y_test, raw_data = data_gen.prepare_data()
    print(f"   - 훈련 데이터: {X_train.shape}")
    print(f"   - 테스트 데이터: {X_test.shape}\n")

    # Step 2: model construction
    print("2. 모델 생성 중...")
    lstm_model = LSTMModel(input_size=1, hidden_size=50, num_layers=2)
    gru_model = GRUModel(input_size=1, hidden_size=50, num_layers=2)
    print("   - LSTM 모델 생성 완료")
    print("   - GRU 모델 생성 완료\n")

    # Steps 3-4: train each model on the same data
    print("3. LSTM 모델 학습 중...")
    lstm_trainer = ModelTrainer(lstm_model)
    lstm_trainer.train(X_train, y_train, epochs=100)

    print("\n4. GRU 모델 학습 중...")
    gru_trainer = ModelTrainer(gru_model)
    gru_trainer.train(X_train, y_train, epochs=100)

    # Step 5: evaluation in original units
    print("\n5. 모델 평가 중...")
    lstm_pred, y_true, lstm_metrics = lstm_trainer.evaluate(
        X_test, y_test, data_gen.scaler
    )
    gru_pred, _, gru_metrics = gru_trainer.evaluate(X_test, y_test, data_gen.scaler)

    print("\n=== 평가 결과 ===")
    print("LSTM 성능:")
    for metric, value in lstm_metrics.items():
        print(f"  {metric}: {value:.6f}")

    print("\nGRU 성능:")
    for metric, value in gru_metrics.items():
        print(f"  {metric}: {value:.6f}")

    # Step 6: visualisation; hand over the series that was actually used so the
    # "original data" panel does not show a regenerated one.
    print("\n6. 결과 시각화...")
    visualize_results(
        data_gen,
        lstm_pred,
        gru_pred,
        y_true,
        lstm_metrics,
        gru_metrics,
        raw_data=raw_data,
    )

    print("\n=== 실습 완료 ===")


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
