# Part 7.5 Reinforcement Learning
"""
Part 7.5 강화학습 (Reinforcement Learning) 실습
Q-Learning과 Deep Q-Network (DQN) 기초 실습

이 파일은 다음 내용을 포함합니다:
1. Q-Learning 알고리즘 구현
2. Deep Q-Network (DQN) 구현
3. FrozenLake 환경에서 Q-Learning 실습
4. CartPole 환경에서 DQN 실습
5. 학습 과정 시각화
6. 성능 비교 및 분석
"""

import random
import warnings
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Silence all warnings so the tutorial's console output stays readable.
warnings.filterwarnings("ignore")

# Gym is optional: when it cannot be imported, main() falls back to the
# built-in SimpleFrozenLake environment below.
GYM_AVAILABLE = False
try:
    import gym
except ImportError:
    print("OpenAI Gym이 설치되지 않았습니다. 간단한 환경을 구현합니다.")
else:
    GYM_AVAILABLE = True

# Seed every random source used in this file (torch weight init, NumPy and
# the stdlib `random` module) so runs are repeatable.
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)
random.seed(_SEED)


class SimpleFrozenLake:
    """Minimal deterministic FrozenLake grid world (stand-in for Gym's env).

    The agent starts in the top-left cell (state 0) and tries to reach the
    bottom-right cell (state ``size * size - 1``). States are numbered
    row-major: ``state = row * size + col``.

    Actions: 0 = up, 1 = down, 2 = left, 3 = right. A move that would leave
    the grid keeps the agent in place; any other action value is a no-op.
    NOTE: Gym's ``FrozenLake-v1`` orders its actions Left, Down, Right, Up.

    Rewards: 1.0 for reaching the goal (episode ends), 0.0 for falling into a
    hole (episode ends), 0.0 otherwise.

    Args:
        size: Side length of the square grid.
        holes: Iterable of hole state indices. ``None`` selects the classic
            4x4 layout ``[5, 7, 11, 12]``; pass explicit holes for other sizes.
    """

    # Classic 4x4 map: SFFF / FHFH / FFFH / HFFG
    DEFAULT_HOLES = (5, 7, 11, 12)

    def __init__(self, size=4, holes=None):
        self.size = size
        self.n_states = size * size
        self.n_actions = 4  # up, down, left, right
        self.current_state = 0
        self.goal_state = self.n_states - 1
        self.holes = list(self.DEFAULT_HOLES if holes is None else holes)

    def reset(self):
        """Put the agent back on the start cell and return that state (0)."""
        self.current_state = 0
        return self.current_state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done, info)``."""
        row, col = divmod(self.current_state, self.size)

        # Clamp to the grid so moves into a wall leave the agent where it is.
        if action == 0:  # up
            row = max(row - 1, 0)
        elif action == 1:  # down
            row = min(row + 1, self.size - 1)
        elif action == 2:  # left
            col = max(col - 1, 0)
        elif action == 3:  # right
            col = min(col + 1, self.size - 1)

        new_state = row * self.size + col
        reached_goal = new_state == self.goal_state
        done = reached_goal or new_state in self.holes
        reward = 1.0 if reached_goal else 0.0

        self.current_state = new_state
        return new_state, reward, done, {}


class QLearningAgent:
    """Tabular Q-Learning agent with an epsilon-greedy behaviour policy.

    Args:
        n_states: Number of discrete states (rows of the Q-table).
        n_actions: Number of discrete actions (columns of the Q-table).
        learning_rate: Step size (alpha) of the temporal-difference update.
        discount_factor: Discount (gamma) applied to the bootstrapped value.
        epsilon: Initial exploration probability.
        epsilon_decay: Multiplicative decay applied after every ``learn``
            call, i.e. once per environment step, not once per episode.
        epsilon_min: Lower bound for ``epsilon``.
    """

    def __init__(
        self,
        n_states,
        n_actions,
        learning_rate=0.1,
        discount_factor=0.99,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
    ):
        self.n_states = n_states
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        # Q-table starts at zero: every action ties until rewards propagate.
        self.q_table = np.zeros((n_states, n_actions))

    def choose_action(self, state):
        """Pick an action epsilon-greedily; return it as a Python ``int``.

        Ties among greedy actions are broken uniformly at random. With a
        zero-initialised table, ``np.argmax`` would otherwise always return
        action 0 and the greedy policy could never leave unexplored states.
        """
        if np.random.random() < self.epsilon:
            return int(np.random.choice(self.n_actions))

        q_values = self.q_table[state]
        best_actions = np.flatnonzero(q_values == q_values.max())
        if len(best_actions) == 1:
            return int(best_actions[0])
        return int(np.random.choice(best_actions))

    def learn(self, state, action, reward, next_state, done):
        """Apply one Q-Learning update, then decay epsilon.

        ``done`` should be True only when the episode truly terminated. In
        that case the target is the immediate reward, with no bootstrap from
        ``next_state``.
        """
        current_q = self.q_table[state, action]

        if done:
            target_q = reward
        else:
            target_q = reward + self.discount_factor * np.max(self.q_table[next_state])

        self.q_table[state, action] += self.learning_rate * (target_q - current_q)

        # Clamp so epsilon never drops below epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


class DQNNetwork(nn.Module):
    """Three-layer fully connected Q-value approximator.

    Maps states to one Q-value per action through
    ``input_size -> hidden_size -> hidden_size -> output_size``, with ReLU
    after each hidden layer and a linear output layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in the order fc1, fc2, fc3, which fixes the
        # order in which the global torch RNG is consumed for weight init.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw (unbounded) Q-values for the states in ``x``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        # Output stays linear: Q-values may be any real number.
        return self.fc3(x)


class DQNAgent:
    """Deep Q-Network agent with experience replay and a target network.

    ``q_network`` is trained on random minibatches from a bounded replay
    buffer. Bootstrapped targets come from ``target_network``, which is only
    synchronised when ``update_target_network`` is called.

    States may be given either as flat vectors of length ``state_size`` or as
    ``(1, state_size)`` arrays; they are flattened before reaching the network.

    Args:
        state_size: Length of the flattened state vector.
        action_size: Number of discrete actions.
        learning_rate: Adam learning rate.
        discount_factor: Discount (gamma) applied to bootstrapped targets.
        hidden_size: Width of both hidden layers of the Q-networks.
        memory_size: Maximum number of transitions kept in the replay buffer.
    """

    def __init__(
        self,
        state_size,
        action_size,
        learning_rate=0.001,
        discount_factor=0.99,
        hidden_size=64,
        memory_size=10000,
    ):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=memory_size)
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        # Applied once per successful replay() call, i.e. per training step.
        self.epsilon_decay = 0.995
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Online and target networks share the same architecture.
        self.q_network = DQNNetwork(state_size, hidden_size, action_size).to(self.device)
        self.target_network = DQNNetwork(state_size, hidden_size, action_size).to(self.device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)

        # Start with identical weights in both networks.
        self.update_target_network()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (oldest evicted first)."""
        self.memory.append((state, action, reward, next_state, done))

    def _as_batch(self, states):
        """Stack ``states`` into a float32 tensor of shape ``(len(states), -1)``.

        Flattening each state lets the agent accept both 1-D vectors and the
        ``(1, state_size)`` arrays produced by ``np.reshape(state, [1, n])``.
        Without it, ``replay`` would build a 3-D batch and ``gather`` would fail.
        """
        stacked = np.asarray(states, dtype=np.float32).reshape(len(states), -1)
        return torch.as_tensor(stacked, device=self.device)

    def choose_action(self, state):
        """Pick an action epsilon-greedily; return it as a Python ``int``."""
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_size)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            q_values = self.q_network(self._as_batch([state]))
        return int(q_values.argmax(dim=1).item())

    def replay(self, batch_size=32):
        """Train on one random minibatch; no-op until enough transitions exist."""
        if len(self.memory) < batch_size:
            return

        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._as_batch(states)
        next_states = self._as_batch(next_states)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=self.device)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=self.device)
        dones = torch.as_tensor(np.asarray(dones, dtype=bool), device=self.device)

        # Q(s, a) for the actions actually taken -> shape (batch_size,)
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets; terminal transitions contribute only the reward.
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(dim=1).values
            target_q_values = rewards + self.discount_factor * next_q_values * (~dones).float()

        loss = F.mse_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Clamp so epsilon never drops below epsilon_min.
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


def train_q_learning(env, agent, episodes=1000, max_steps=100):
    """Train a tabular agent with online Q-Learning updates.

    Works with both environment APIs:

    * classic: ``reset() -> obs`` and ``step() -> (obs, reward, done, info)``
      (SimpleFrozenLake, Gym < 0.26);
    * new: ``reset() -> (obs, info)`` and
      ``step() -> (obs, reward, terminated, truncated, info)``
      (Gym >= 0.26, Gymnasium).

    Args:
        env: Environment with discrete integer states.
        agent: Object providing ``choose_action(state)``,
            ``learn(state, action, reward, next_state, done)`` and ``epsilon``.
        episodes: Number of episodes to run.
        max_steps: Step cap per episode.

    Returns:
        List of total rewards, one per episode.
    """
    scores = []

    for episode in range(episodes):
        reset_result = env.reset()
        # The new API returns (observation, info); the classic one returns just the observation.
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        total_reward = 0

        for _ in range(max_steps):
            action = agent.choose_action(state)
            step_result = env.step(action)
            if len(step_result) == 5:
                next_state, reward, terminated, truncated, _ = step_result
            else:
                next_state, reward, terminated, _ = step_result
                truncated = False

            # Only genuine termination stops bootstrapping. A time-limit
            # truncation ends the episode but the state still has future value.
            agent.learn(state, action, reward, next_state, terminated)

            state = next_state
            total_reward += reward

            if terminated or truncated:
                break

        scores.append(total_reward)

        if episode % 100 == 0:
            avg_score = np.mean(scores[-100:])
            print(
                f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {agent.epsilon:.2f}"
            )

    return scores


def train_dqn(env, agent, episodes=1000, max_steps=500, batch_size=32):
    """Train a DQN-style agent with experience replay.

    Observations are turned into flat float32 vectors before the agent sees
    them:

    * scalar (discrete) states become one-hot vectors of length ``agent.state_size``;
    * array observations are flattened to 1-D.

    Both the classic (``obs``; 4-tuple ``step``) and the Gym >= 0.26 /
    Gymnasium (``(obs, info)``; 5-tuple ``step``) environment APIs are accepted.

    Args:
        env: Environment to train in.
        agent: Object providing ``state_size``, ``memory``, ``epsilon``,
            ``choose_action``, ``remember``, ``replay`` and ``update_target_network``.
        episodes: Number of episodes to run.
        max_steps: Step cap per episode.
        batch_size: Minibatch size passed to ``agent.replay``. Replay starts
            once the buffer holds more than ``batch_size`` transitions.

    Returns:
        List of total rewards, one per episode.
    """

    def encode(observation):
        """Return ``observation`` as a flat float32 feature vector."""
        if np.ndim(observation) == 0:
            one_hot = np.zeros(agent.state_size, dtype=np.float32)
            one_hot[int(observation)] = 1.0
            return one_hot
        return np.asarray(observation, dtype=np.float32).reshape(-1)

    scores = []

    for episode in range(episodes):
        reset_result = env.reset()
        # The new API returns (observation, info); the classic one returns just the observation.
        state = encode(reset_result[0] if isinstance(reset_result, tuple) else reset_result)
        total_reward = 0

        for _ in range(max_steps):
            action = agent.choose_action(state)
            step_result = env.step(action)
            if len(step_result) == 5:
                raw_next_state, reward, terminated, truncated, _ = step_result
            else:
                raw_next_state, reward, terminated, _ = step_result
                truncated = False

            next_state = encode(raw_next_state)
            # Store true termination only: truncated transitions still bootstrap.
            agent.remember(state, action, reward, next_state, terminated)
            state = next_state
            total_reward += reward

            if terminated or truncated:
                # Sync the target network once per finished episode.
                agent.update_target_network()
                break

            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

        scores.append(total_reward)

        if episode % 100 == 0:
            avg_score = np.mean(scores[-100:])
            print(
                f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {agent.epsilon:.2f}"
            )

    return scores


def visualize_q_table(q_table, title="Q-Table Visualization", action_names=None):
    """Draw one annotated heat map of Q-values per action.

    Each action's column is reshaped row-major into a square grid, matching
    the ``state = row * size + col`` numbering used by SimpleFrozenLake.

    Args:
        q_table: Array of shape ``(n_states, n_actions)``. ``n_states`` must
            be a perfect square.
        title: Figure title.
        action_names: Optional label per action column. Defaults to
            ``["Up", "Down", "Left", "Right"]`` (SimpleFrozenLake's order) for
            4 actions, otherwise ``"Action i"``. Gym's ``FrozenLake-v1`` orders
            actions Left, Down, Right, Up, so pass matching labels or reorder
            the columns for a Gym-trained table.

    Returns:
        The matplotlib ``Figure`` that was drawn.

    Raises:
        ValueError: If ``q_table`` is not a non-empty 2-D array with a
            perfect-square number of rows, or ``action_names`` has the wrong length.
    """
    q_table = np.asarray(q_table)
    if q_table.ndim != 2 or q_table.size == 0:
        raise ValueError("q_table must be a non-empty 2-D array (n_states, n_actions)")
    n_states, n_actions = q_table.shape

    grid_size = int(round(np.sqrt(n_states)))
    if grid_size * grid_size != n_states:
        raise ValueError(f"n_states={n_states} is not a perfect square; cannot draw a grid")

    if action_names is None:
        if n_actions == 4:
            action_names = ["Up", "Down", "Left", "Right"]
        else:
            action_names = [f"Action {i}" for i in range(n_actions)]
    if len(action_names) != n_actions:
        raise ValueError("action_names must have one entry per action column")

    # Two panels per row; 4 actions give the original 2x2 layout at (12, 10).
    n_cols = min(2, n_actions)
    n_rows = -(-n_actions // n_cols)  # ceiling division
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False
    )

    for action_idx, ax in enumerate(axes.flat):
        if action_idx >= n_actions:
            ax.axis("off")  # unused panel when n_actions is odd
            continue

        q_values = q_table[:, action_idx].reshape(grid_size, grid_size)
        im = ax.imshow(q_values, cmap="viridis")
        ax.set_title(f"{action_names[action_idx]} Action Q-Values")
        fig.colorbar(im, ax=ax)

        # Annotate each cell with its Q-value.
        for row in range(grid_size):
            for col in range(grid_size):
                ax.text(
                    col,
                    row,
                    f"{q_values[row, col]:.2f}",
                    ha="center",
                    va="center",
                    color="white",
                    fontsize=8,
                )

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()
    return fig


def _moving_average(values, window):
    """Return the moving average of ``values`` over ``window`` consecutive points.

    Element ``k`` of the result is ``mean(values[k : k + window])``. The result
    is empty when fewer than ``window`` values are available.

    Raises:
        ValueError: If ``window`` is not positive.
    """
    if window <= 0:
        raise ValueError("window must be positive")
    values = np.asarray(values, dtype=float)
    # np.convolve(mode="valid") swaps its arguments when the kernel is longer,
    # so short inputs must be handled explicitly.
    if len(values) < window:
        return np.empty(0)
    return np.convolve(values, np.ones(window) / window, mode="valid")


def plot_training_results(q_scores, dqn_scores, window=100):
    """Plot both agents' learning curves and a moving-average comparison.

    Args:
        q_scores: Per-episode total rewards from Q-Learning.
        dqn_scores: Per-episode total rewards from DQN.
        window: Number of episodes per moving-average point. A run shorter
            than ``window`` contributes no curve to the comparison panel.
    """
    runs = [(q_scores, "Q-Learning"), (dqn_scores, "DQN")]

    plt.figure(figsize=(15, 5))

    # Panels 1-2: raw per-episode scores.
    for position, (scores, name) in enumerate(runs, start=1):
        plt.subplot(1, 3, position)
        plt.plot(scores)
        plt.title(f"{name} Training Scores")
        plt.xlabel("Episode")
        plt.ylabel("Score")
        plt.grid(True)

    # Panel 3: moving averages. Each point sits at the last episode of its
    # window, so the curve lines up with the raw-score x-axis.
    plt.subplot(1, 3, 3)
    for scores, name in runs:
        averaged = _moving_average(scores, window)
        episodes = np.arange(window - 1, window - 1 + len(averaged))
        plt.plot(episodes, averaged, label=name, alpha=0.7)
    plt.title(f"Training Comparison (Moving Average, window={window})")
    plt.xlabel("Episode")
    plt.ylabel("Average Score")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()


def main():
    """Run the tutorial end to end.

    1. Train a tabular Q-Learning agent on FrozenLake. Uses Gym's
       non-slippery ``FrozenLake-v1`` when it loads, otherwise ``SimpleFrozenLake``.
    2. Plot the learned Q-table.
    3. Train a DQN agent on Gym's ``CartPole-v1``. Falls back to
       ``SimpleFrozenLake`` with one-hot states.
    4. Plot and print a comparison of both learning curves.
    """

    def close_env(env):
        """Release resources held by Gym environments (SimpleFrozenLake has none)."""
        close = getattr(env, "close", None)
        if callable(close):
            close()

    print("=== 강화학습 (RL) 실습 시작 ===\n")

    # 1. Q-Learning (FrozenLake)
    print("1. Q-Learning 실습 (FrozenLake 환경)...")

    env_q = None
    if GYM_AVAILABLE:
        try:
            env_q = gym.make("FrozenLake-v1", is_slippery=False)
            n_states = env_q.observation_space.n
            n_actions = env_q.action_space.n
        except Exception as exc:  # not a bare except: Ctrl-C / SystemExit still propagate
            print("   FrozenLake 환경을 로드할 수 없습니다. 간단한 환경을 사용합니다.")
            print(f"   ({type(exc).__name__}: {exc})")
            env_q = None
    if env_q is None:
        env_q = SimpleFrozenLake()
        n_states = env_q.n_states
        n_actions = env_q.n_actions

    q_agent = QLearningAgent(n_states, n_actions)
    print(f"   - 상태 공간: {n_states}")
    print(f"   - 행동 공간: {n_actions}")

    print("\n2. Q-Learning 학습 중...")
    q_scores = train_q_learning(env_q, q_agent, episodes=1000)

    # 3. Q-table plot. visualize_q_table labels its panels Up, Down, Left, Right
    # (SimpleFrozenLake's action order). Gym's FrozenLake numbers actions
    # Left=0, Down=1, Right=2, Up=3, so reorder its columns to match the labels.
    print("\n3. Q-table 시각화...")
    q_table_view = q_agent.q_table
    if not isinstance(env_q, SimpleFrozenLake):
        q_table_view = q_table_view[:, [3, 1, 0, 2]]
    visualize_q_table(q_table_view, "Learned Q-Table")
    close_env(env_q)

    # 4. DQN (CartPole, or the simple environment as a fallback)
    print("\n4. DQN 실습...")

    env_dqn = None
    if GYM_AVAILABLE:
        try:
            env_dqn = gym.make("CartPole-v1")
            state_size = env_dqn.observation_space.shape[0]
            action_size = env_dqn.action_space.n
        except Exception as exc:  # not a bare except: Ctrl-C / SystemExit still propagate
            print("   CartPole 환경을 로드할 수 없습니다. 간단한 환경을 사용합니다.")
            print(f"   ({type(exc).__name__}: {exc})")
            env_dqn = None
    if env_dqn is None:
        env_dqn = SimpleFrozenLake()
        state_size = env_dqn.n_states
        action_size = env_dqn.n_actions

    dqn_agent = DQNAgent(state_size, action_size)
    print(f"   - 상태 크기: {state_size}")
    print(f"   - 행동 크기: {action_size}")

    print("\n5. DQN 학습 중...")
    dqn_scores = train_dqn(env_dqn, dqn_agent, episodes=500)
    close_env(env_dqn)

    # 6. Compare and visualise
    print("\n6. 학습 결과 시각화...")
    plot_training_results(q_scores, dqn_scores)

    # 7. Summary statistics over the final 100 episodes
    print("\n=== 성능 분석 ===")
    print(f"Q-Learning 평균 점수 (마지막 100 에피소드): {np.mean(q_scores[-100:]):.2f}")
    print(f"DQN 평균 점수 (마지막 100 에피소드): {np.mean(dqn_scores[-100:]):.2f}")

    print("\n=== 실습 완료 ===")


# Run the tutorial only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
