import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def visualize_derivatives():
    """1. 미분의 직관적 이해: 기울기 시각화"""
    print("=== 1. 미분: 순간 변화율과 기울기 ===")
    
    def plot_function_with_tangent(x, y, tangent_x, tangent_y, slope, title):
        """함수와 접선을 시각화"""
        plt.figure(figsize=(10, 6))
        
        # 함수 그리기
        plt.plot(x, y, 'b-', linewidth=2, label='함수 f(x)')
        
        # 접선 그리기
        plt.plot(tangent_x, tangent_y, 'r--', linewidth=2, label=f'접선 (기울기: {slope:.2f})')
        
        # 접점 표시
        plt.scatter(tangent_x[len(tangent_x)//2], tangent_y[len(tangent_y)//2], 
                   color='red', s=100, zorder=5, label='접점')
        
        plt.grid(True, alpha=0.3)
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title(title)
        plt.legend()
        plt.show()

    # 예제: f(x) = x² 함수의 미분
    x = np.linspace(-3, 3, 100)
    y = x**2

    # x = 1에서의 접선
    x0 = 1
    y0 = x0**2
    slope = 2 * x0  # f'(x) = 2x

    # 접선 그리기
    tangent_x = np.linspace(x0 - 1, x0 + 1, 50)
    tangent_y = slope * (tangent_x - x0) + y0

    plot_function_with_tangent(x, y, tangent_x, tangent_y, slope, 
                              f"f(x) = x²의 미분: x = {x0}에서 기울기 = {slope}")

    print(f"f(x) = x²")
    print(f"f'({x0}) = {slope}")
    print(f"x = {x0}에서 함수가 {slope}만큼 빠르게 증가하고 있습니다!")

def visualize_multiple_derivatives():
    """2. 여러 점에서의 기울기 시각화"""
    print("\n=== 2. 여러 점에서의 미분값 ===")
    
    x = np.linspace(-3, 3, 100)
    y = x**2
    
    # 여러 점에서의 접선
    points = [-2, -1, 0, 1, 2]
    
    plt.figure(figsize=(12, 8))
    plt.plot(x, y, 'b-', linewidth=3, label='f(x) = x²')
    
    colors = ['red', 'orange', 'green', 'blue', 'purple']
    
    for i, x0 in enumerate(points):
        y0 = x0**2
        slope = 2 * x0
        
        # 접선 그리기
        tangent_x = np.linspace(x0 - 0.5, x0 + 0.5, 20)
        tangent_y = slope * (tangent_x - x0) + y0
        
        plt.plot(tangent_x, tangent_y, '--', color=colors[i], linewidth=2,
                label=f'x = {x0}, 기울기 = {slope}')
        plt.scatter(x0, y0, color=colors[i], s=100, zorder=5)
    
    plt.grid(True, alpha=0.3)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('f(x) = x²의 여러 점에서의 기울기')
    plt.legend()
    plt.show()
    
    print("미분의 직관적 의미:")
    print("- 음수 기울기: 함수가 감소하고 있음")
    print("- 0 기울기: 함수가 변하지 않음 (극점)")
    print("- 양수 기울기: 함수가 증가하고 있음")

def numerical_vs_analytical_derivatives():
    """3. 수치 미분 vs 해석적 미분"""
    print("\n=== 3. 수치 미분 vs 해석적 미분 ===")
    
    def numerical_derivative(func, x, h=1e-4):
        """수치 미분: 미분의 정의를 이용한 근사"""
        return (func(x + h) - func(x - h)) / (2 * h)

    # 테스트 함수들
    functions = {
        'f(x) = x²': (lambda x: x**2, lambda x: 2*x),
        'f(x) = x³': (lambda x: x**3, lambda x: 3*x**2),
        'f(x) = sin(x)': (lambda x: np.sin(x), lambda x: np.cos(x))
    }
    
    x_test = 1.0
    
    print("수치 미분 vs 해석적 미분 비교:")
    print("-" * 50)
    
    for name, (func, derivative) in functions.items():
        numerical = numerical_derivative(func, x_test)
        analytical = derivative(x_test)
        
        print(f"{name}:")
        print(f"  수치 미분: f'({x_test}) ≈ {numerical:.6f}")
        print(f"  해석적 미분: f'({x_test}) = {analytical:.6f}")
        print(f"  오차: {abs(numerical - analytical):.2e}")
        print()

def gradient_descent_visualization():
    """4. 기울기 하강법: 산에서 내려오는 사람"""
    print("\n=== 4. 기울기 하강법 시각화 ===")
    
    # 목표 함수: f(x) = x² + 2
    def f(x):
        return x**2 + 2
    
    def df(x):
        return 2*x
    
    # 기울기 하강법
    x = 3.0  # 시작점
    learning_rate = 0.1
    steps = 20
    
    x_history = [x]
    y_history = [f(x)]
    
    for i in range(steps):
        gradient = df(x)
        x = x - learning_rate * gradient
        x_history.append(x)
        y_history.append(f(x))
    
    # 시각화
    x_plot = np.linspace(-4, 4, 100)
    y_plot = f(x_plot)
    
    plt.figure(figsize=(12, 8))
    
    # 함수 그리기
    plt.plot(x_plot, y_plot, 'b-', linewidth=2, label='f(x) = x² + 2')
    
    # 기울기 하강 과정
    plt.plot(x_history, y_history, 'ro-', linewidth=2, markersize=8, label='기울기 하강 경로')
    
    # 화살표로 이동 방향 표시
    for i in range(len(x_history) - 1):
        dx = x_history[i+1] - x_history[i]
        dy = y_history[i+1] - y_history[i]
        plt.arrow(x_history[i], y_history[i], dx, dy, 
                 head_width=0.05, head_length=0.1, fc='red', ec='red', alpha=0.7)
    
    plt.grid(True, alpha=0.3)
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title('기울기 하강법: 산에서 내려오는 과정')
    plt.legend()
    plt.show()
    
    print("기울기 하강법 과정:")
    for i, (x_val, y_val) in enumerate(zip(x_history, y_history)):
        print(f"단계 {i}: x = {x_val:.4f}, f(x) = {y_val:.4f}")

def compare_learning_rates():
    """5. 학습률에 따른 기울기 하강법 비교"""
    print("\n=== 5. 학습률의 중요성 ===")
    
    def f(x):
        return x**2 + 2
    
    def df(x):
        return 2*x
    
    learning_rates = [0.01, 0.1, 0.5, 1.0]
    colors = ['blue', 'green', 'orange', 'red']
    
    plt.figure(figsize=(15, 10))
    
    for i, lr in enumerate(learning_rates):
        # 기울기 하강법
        x = 3.0
        steps = 20
        x_history = [x]
        y_history = [f(x)]
        
        for step in range(steps):
            gradient = df(x)
            x = x - lr * gradient
            x_history.append(x)
            y_history.append(f(x))
        
        # 경로 그리기
        plt.subplot(2, 2, i+1)
        x_plot = np.linspace(-4, 4, 100)
        y_plot = f(x_plot)
        
        plt.plot(x_plot, y_plot, 'b-', linewidth=2, label='f(x) = x² + 2')
        plt.plot(x_history, y_history, 'o-', color=colors[i], linewidth=2, 
                markersize=6, label=f'학습률 = {lr}')
        
        plt.grid(True, alpha=0.3)
        plt.xlabel('x')
        plt.ylabel('f(x)')
        plt.title(f'학습률 = {lr}')
        plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    print("학습률의 영향:")
    print("- 너무 작음 (0.01): 천천히 수렴")
    print("- 적당함 (0.1): 빠르게 수렴")
    print("- 너무 큼 (0.5, 1.0): 진동하거나 발산")

def partial_derivatives_visualization():
    """6. 편미분: 여러 변수 중 하나만 바꾸기"""
    print("\n=== 6. 편미분 시각화 ===")
    
    # 예제 함수: f(x, y) = x² + y²
    def f(x, y):
        return x**2 + y**2
    
    def df_dx(x, y):
        return 2*x
    
    def df_dy(x, y):
        return 2*y
    
    # 3D 시각화
    x = np.linspace(-3, 3, 50)
    y = np.linspace(-3, 3, 50)
    X, Y = np.meshgrid(x, y)
    Z = f(X, Y)
    
    fig = plt.figure(figsize=(15, 5))
    
    # 3D 표면
    ax1 = fig.add_subplot(131, projection='3d')
    surf = ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
    ax1.set_title('f(x, y) = x² + y²')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('f(x, y)')
    
    # x에 대한 편미분 (y = 1에서의 단면)
    ax2 = fig.add_subplot(132)
    y_fixed = 1
    x_slice = np.linspace(-3, 3, 100)
    z_slice = f(x_slice, y_fixed)
    ax2.plot(x_slice, z_slice, 'b-', linewidth=2, label=f'f(x, {y_fixed})')
    
    # 접선 그리기 (x = 1에서)
    x0 = 1
    z0 = f(x0, y_fixed)
    slope = df_dx(x0, y_fixed)
    tangent_x = np.linspace(x0 - 1, x0 + 1, 20)
    tangent_z = slope * (tangent_x - x0) + z0
    ax2.plot(tangent_x, tangent_z, 'r--', linewidth=2, label=f'∂f/∂x = {slope}')
    
    ax2.scatter(x0, z0, color='red', s=100)
    ax2.grid(True, alpha=0.3)
    ax2.set_xlabel('x')
    ax2.set_ylabel('f(x, y)')
    ax2.set_title(f'∂f/∂x (y = {y_fixed})')
    ax2.legend()
    
    # y에 대한 편미분 (x = 1에서의 단면)
    ax3 = fig.add_subplot(133)
    x_fixed = 1
    y_slice = np.linspace(-3, 3, 100)
    z_slice = f(x_fixed, y_slice)
    ax3.plot(y_slice, z_slice, 'g-', linewidth=2, label=f'f({x_fixed}, y)')
    
    # 접선 그리기 (y = 1에서)
    y0 = 1
    z0 = f(x_fixed, y0)
    slope = df_dy(x_fixed, y0)
    tangent_y = np.linspace(y0 - 1, y0 + 1, 20)
    tangent_z = slope * (tangent_y - y0) + z0
    ax3.plot(tangent_y, tangent_z, 'r--', linewidth=2, label=f'∂f/∂y = {slope}')
    
    ax3.scatter(y0, z0, color='red', s=100)
    ax3.grid(True, alpha=0.3)
    ax3.set_xlabel('y')
    ax3.set_ylabel('f(x, y)')
    ax3.set_title(f'∂f/∂y (x = {x_fixed})')
    ax3.legend()
    
    plt.tight_layout()
    plt.show()
    
    print("편미분의 직관적 의미:")
    print(f"∂f/∂x = {df_dx(1, 1)}: x를 조금 바꾸면 함수가 {df_dx(1, 1)}배만큼 변함")
    print(f"∂f/∂y = {df_dy(1, 1)}: y를 조금 바꾸면 함수가 {df_dy(1, 1)}배만큼 변함")

def gradient_vector_visualization():
    """7. 그래디언트 벡터 시각화"""
    print("\n=== 7. 그래디언트 벡터 ===")
    
    def f(x, y):
        return x**2 + y**2
    
    def gradient(x, y):
        return np.array([2*x, 2*y])
    
    # 등고선 그리기
    x = np.linspace(-3, 3, 50)
    y = np.linspace(-3, 3, 50)
    X, Y = np.meshgrid(x, y)
    Z = f(X, Y)
    
    plt.figure(figsize=(10, 8))
    
    # 등고선
    contours = plt.contour(X, Y, Z, levels=10, colors='blue', alpha=0.6)
    plt.clabel(contours, inline=True, fontsize=8)
    
    # 그래디언트 벡터들
    x_points = np.linspace(-2, 2, 8)
    y_points = np.linspace(-2, 2, 8)
    
    for i, x_val in enumerate(x_points):
        for j, y_val in enumerate(y_points):
            grad = gradient(x_val, y_val)
            # 벡터 크기 정규화
            grad_norm = grad / np.linalg.norm(grad) * 0.3
            plt.quiver(x_val, y_val, grad_norm[0], grad_norm[1], 
                      color='red', alpha=0.7, scale=20)
    
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('그래디언트 벡터: 가장 가파르게 증가하는 방향')
    plt.grid(True, alpha=0.3)
    plt.axis('equal')
    plt.show()
    
    print("그래디언트의 의미:")
    print("- 방향: 함수가 가장 가파르게 증가하는 방향")
    print("- 크기: 그 방향으로의 증가율")
    print("- 기울기 하강법: 그래디언트의 반대 방향으로 이동")

def chain_rule_visualization():
    """8. 연쇄법칙: 파이프라인을 통해 변화 전파"""
    print("\n=== 8. 연쇄법칙 시각화 ===")
    
    # 예제: f(x) = sin(x²)
    def f(x):
        return np.sin(x**2)
    
    def df_dx(x):
        # 연쇄법칙: d/dx[sin(x²)] = cos(x²) * 2x
        return np.cos(x**2) * 2*x
    
    x = np.linspace(-2, 2, 100)
    y = f(x)
    
    # 수치 미분으로 검증
    def numerical_derivative(func, x, h=1e-4):
        return (func(x + h) - func(x - h)) / (2 * h)
    
    x_test = 1.0
    analytical = df_dx(x_test)
    numerical = numerical_derivative(f, x_test)
    
    plt.figure(figsize=(12, 8))
    
    # 함수와 미분
    plt.subplot(2, 1, 1)
    plt.plot(x, y, 'b-', linewidth=2, label='f(x) = sin(x²)')
    plt.grid(True, alpha=0.3)
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title('합성 함수: f(x) = sin(x²)')
    plt.legend()
    
    # 미분
    plt.subplot(2, 1, 2)
    dy_dx = df_dx(x)
    plt.plot(x, dy_dx, 'r-', linewidth=2, label="f'(x) = cos(x²) * 2x")
    plt.grid(True, alpha=0.3)
    plt.xlabel('x')
    plt.ylabel("f'(x)")
    plt.title('연쇄법칙을 이용한 미분')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    print("연쇄법칙의 직관적 의미:")
    print(f"f(x) = sin(x²)에서 x = {x_test}일 때:")
    print(f"  해석적 미분: f'({x_test}) = {analytical:.6f}")
    print(f"  수치 미분: f'({x_test}) ≈ {numerical:.6f}")
    print(f"  오차: {abs(analytical - numerical):.2e}")
    print()
    print("연쇄법칙: d/dx[sin(x²)] = cos(x²) * d/dx[x²] = cos(x²) * 2x")

def simple_neural_network():
    """9. 간단한 신경망의 역전파"""
    print("\n=== 9. 간단한 신경망 역전파 ===")
    
    # 간단한 신경망: 입력 -> 은닉층 -> 출력
    # f(x) = w2 * sigmoid(w1 * x + b1) + b2
    
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))
    
    def sigmoid_derivative(x):
        return sigmoid(x) * (1 - sigmoid(x))
    
    # 가중치 초기화
    w1, b1 = 2.0, 1.0
    w2, b2 = 3.0, 0.5
    
    def forward(x):
        """순전파"""
        z1 = w1 * x + b1
        a1 = sigmoid(z1)
        z2 = w2 * a1 + b2
        return z2, a1, z1
    
    def backward(x, target, learning_rate=0.1):
        """역전파 (연쇄법칙)"""
        # 순전파
        output, a1, z1 = forward(x)
        
        # 오차
        error = output - target
        
        # 역전파 (연쇄법칙)
        # ∂E/∂w2 = ∂E/∂output * ∂output/∂w2
        dE_dw2 = error * a1
        
        # ∂E/∂w1 = ∂E/∂output * ∂output/∂a1 * ∂a1/∂z1 * ∂z1/∂w1
        dE_dw1 = error * w2 * sigmoid_derivative(z1) * x
        
        # ∂E/∂b1 = ∂E/∂output * ∂output/∂a1 * ∂a1/∂z1 * ∂z1/∂b1
        dE_db1 = error * w2 * sigmoid_derivative(z1)
        
        # ∂E/∂b2 = ∂E/∂output * ∂output/∂b2
        dE_db2 = error
        
        return dE_dw1, dE_db1, dE_dw2, dE_db2
    
    # 학습 과정 시각화
    x_train = np.array([0.5, 1.0, 1.5])
    y_train = np.array([2.0, 3.0, 4.0])
    
    print("간단한 신경망 학습 과정:")
    print("-" * 50)
    
    for epoch in range(5):
        total_error = 0
        for x, target in zip(x_train, y_train):
            output, _, _ = forward(x)
            error = output - target
            total_error += error**2
            
            # 그래디언트 계산
            dw1, db1, dw2, db2 = backward(x, target)
            
            print(f"입력: {x}, 목표: {target}, 출력: {output:.4f}, 오차: {error:.4f}")
            print(f"  그래디언트: dw1={dw1:.4f}, db1={db1:.4f}, dw2={dw2:.4f}, db2={db2:.4f}")
        
        print(f"에포크 {epoch+1}: 평균 오차 = {total_error/len(x_train):.4f}")
        print()

def linear_regression_gradient_descent():
    """10. 선형 회귀의 경사 하강법"""
    print("\n=== 10. 선형 회귀 경사 하강법 ===")
    
    # 데이터 생성
    np.random.seed(42)
    X = np.random.rand(100, 1) * 10
    y = 3 * X + 2 + np.random.randn(100, 1) * 0.5
    
    # 모델: y = w * x + b
    w, b = 0.0, 0.0
    learning_rate = 0.01
    epochs = 100
    
    # 학습 과정 기록
    w_history = [w]
    b_history = [b]
    loss_history = []
    
    for epoch in range(epochs):
        # 예측
        y_pred = w * X + b
        
        # 손실 함수 (MSE)
        loss = np.mean((y_pred - y)**2)
        loss_history.append(loss)
        
        # 그래디언트 계산
        dw = np.mean(2 * (y_pred - y) * X)
        db = np.mean(2 * (y_pred - y))
        
        # 파라미터 업데이트
        w = w - learning_rate * dw
        b = b - learning_rate * db
        
        w_history.append(w)
        b_history.append(b)
    
    # 결과 시각화
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # 데이터와 학습된 선
    ax1.scatter(X, y, alpha=0.6, label='데이터')
    X_line = np.linspace(0, 10, 100).reshape(-1, 1)
    y_line = w * X_line + b
    ax1.plot(X_line, y_line, 'r-', linewidth=2, label=f'학습된 선: y = {w:.2f}x + {b:.2f}')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_title('선형 회귀 결과')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 손실 함수 변화
    ax2.plot(loss_history, 'b-', linewidth=2)
    ax2.set_xlabel('에포크')
    ax2.set_ylabel('손실 (MSE)')
    ax2.set_title('손실 함수 수렴')
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("선형 회귀 학습 결과:")
    print(f"학습된 가중치: w = {w:.4f}")
    print(f"학습된 편향: b = {b:.4f}")
    print(f"실제 값: w = 3.0, b = 2.0")
    print(f"최종 손실: {loss_history[-1]:.6f}")

def compare_optimizers():
    """11. 다양한 최적화 알고리즘 비교"""
    print("\n=== 11. 최적화 알고리즘 비교 ===")
    
    def objective_function(x):
        return x**2 + 2*x + 1
    
    def gradient(x):
        return 2*x + 2
    
    # 최적화 알고리즘들
    def gradient_descent(x0, learning_rate=0.1, epochs=50):
        x = x0
        history = [x]
        for _ in range(epochs):
            x = x - learning_rate * gradient(x)
            history.append(x)
        return history
    
    def momentum_gradient_descent(x0, learning_rate=0.1, momentum=0.9, epochs=50):
        x = x0
        velocity = 0
        history = [x]
        for _ in range(epochs):
            velocity = momentum * velocity - learning_rate * gradient(x)
            x = x + velocity
            history.append(x)
        return history
    
    # 최적화 실행
    x0 = 5.0
    gd_history = gradient_descent(x0)
    mgd_history = momentum_gradient_descent(x0)
    
    # 시각화
    x_plot = np.linspace(-1, 6, 100)
    y_plot = objective_function(x_plot)
    
    plt.figure(figsize=(12, 8))
    
    # 목적 함수
    plt.subplot(2, 1, 1)
    plt.plot(x_plot, y_plot, 'b-', linewidth=2, label='목적 함수: f(x) = x² + 2x + 1')
    plt.plot(gd_history, [objective_function(x) for x in gd_history], 'ro-', 
             markersize=6, label='기본 경사 하강법')
    plt.plot(mgd_history, [objective_function(x) for x in mgd_history], 'go-', 
             markersize=6, label='모멘텀 경사 하강법')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title('최적화 과정 비교')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 수렴 과정
    plt.subplot(2, 1, 2)
    plt.plot([objective_function(x) for x in gd_history], 'r-', linewidth=2, label='기본 경사 하강법')
    plt.plot([objective_function(x) for x in mgd_history], 'g-', linewidth=2, label='모멘텀 경사 하강법')
    plt.xlabel('에포크')
    plt.ylabel('f(x)')
    plt.title('손실 함수 수렴 비교')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("최적화 알고리즘 비교:")
    print(f"기본 경사 하강법 최종값: {gd_history[-1]:.6f}")
    print(f"모멘텀 경사 하강법 최종값: {mgd_history[-1]:.6f}")
    print(f"실제 최솟값: -1.0")

def main():
    """메인 함수: 모든 시각화 실행"""
    print("🚀 미적분 직관적 이해 - 시각화 중심 학습")
    print("=" * 60)
    
    # 1. 미분 시각화
    visualize_derivatives()
    
    # 2. 여러 점에서의 미분
    visualize_multiple_derivatives()
    
    # 3. 수치 vs 해석적 미분
    numerical_vs_analytical_derivatives()
    
    # 4. 기울기 하강법
    gradient_descent_visualization()
    
    # 5. 학습률 비교
    compare_learning_rates()
    
    # 6. 편미분
    partial_derivatives_visualization()
    
    # 7. 그래디언트 벡터
    gradient_vector_visualization()
    
    # 8. 연쇄법칙
    chain_rule_visualization()
    
    # 9. 신경망 역전파
    simple_neural_network()
    
    # 10. 선형 회귀
    linear_regression_gradient_descent()
    
    # 11. 최적화 알고리즘 비교
    compare_optimizers()
    
    print("\n🎉 미적분 학습 완료!")
    print("핵심 메시지: 복잡한 수식보다는 직관적 이해와 실제 구현에 집중하세요!")

if __name__ == "__main__":
    main() 