import numpy as np
import matplotlib.pyplot as plt
from matplotlib.image import imread
from mpl_toolkits.mplot3d import Axes3D

def visualize_vectors():
    """1. 벡터를 화살표로 시각화"""
    print("=== 1. 벡터: 방향과 크기를 가진 화살표 ===")
    
    def plot_vector(ax, start, vector, color='blue', label=''):
        """벡터를 화살표로 그리는 함수"""
        ax.quiver(start[0], start[1], vector[0], vector[1], 
                  angles='xy', scale_units='xy', scale=1, color=color, label=label)

    # 그래프 설정
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

    # 원점
    origin = np.array([0, 0])

    # 다양한 벡터들
    vectors = {
        'v1': np.array([3, 2]),    # 오른쪽 위로
        'v2': np.array([-2, 1]),   # 왼쪽 위로  
        'v3': np.array([0, -3]),   # 아래로
        'v4': np.array([4, 0])     # 오른쪽으로
    }

    colors = ['red', 'blue', 'green', 'purple']

    # 벡터들을 화살표로 그리기
    for i, (name, vector) in enumerate(vectors.items()):
        plot_vector(ax, origin, vector, colors[i], name)
        
        # 벡터의 크기 계산
        magnitude = np.linalg.norm(vector)
        print(f"{name}: 크기 = {magnitude:.2f}, 방향 = {vector}")

    ax.set_xlabel('X축')
    ax.set_ylabel('Y축')
    ax.set_title('벡터 = 방향과 크기를 가진 화살표')
    ax.legend()
    plt.show()

def vector_operations():
    """2. 벡터 연산의 직관적 의미"""
    print("\n=== 2. 벡터 연산: 이동 경로 합치기 ===")
    
    def plot_vector(ax, start, vector, color='blue', label=''):
        ax.quiver(start[0], start[1], vector[0], vector[1], 
                  angles='xy', scale_units='xy', scale=1, color=color, label=label)

    # 벡터 덧셈 시각화
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # 첫 번째 그래프: 개별 벡터들
    ax1.set_xlim(-5, 5)
    ax1.set_ylim(-5, 5)
    ax1.grid(True, alpha=0.3)
    ax1.set_aspect('equal')

    v1 = np.array([2, 1])
    v2 = np.array([1, 3])

    plot_vector(ax1, [0, 0], v1, 'red', 'v1')
    plot_vector(ax1, [0, 0], v2, 'blue', 'v2')
    ax1.set_title('개별 벡터들')
    ax1.legend()

    # 두 번째 그래프: 벡터 덧셈
    ax2.set_xlim(-5, 5)
    ax2.set_ylim(-5, 5)
    ax2.grid(True, alpha=0.3)
    ax2.set_aspect('equal')

    # v1을 먼저 그리고, 그 끝점에서 v2를 그리기
    plot_vector(ax2, [0, 0], v1, 'red', 'v1')
    plot_vector(ax2, v1, v2, 'blue', 'v2')  # v1의 끝점에서 v2 시작

    # 합 벡터 (원점에서 최종 위치까지)
    v_sum = v1 + v2
    plot_vector(ax2, [0, 0], v_sum, 'green', 'v1 + v2')

    ax2.set_title('벡터 덧셈: 이동 경로 합치기')
    ax2.legend()

    plt.tight_layout()
    plt.show()

    print(f"v1 = {v1}")
    print(f"v2 = {v2}")
    print(f"v1 + v2 = {v_sum}")

def vector_similarity():
    """3. 벡터 내적: 서로 얼마나 같은 방향인지"""
    print("\n=== 3. 벡터 내적: 유사도 측정 ===")
    
    def cosine_similarity(v1, v2):
        """두 벡터의 코사인 유사도 계산"""
        dot_product = np.dot(v1, v2)
        norm_v1 = np.linalg.norm(v1)
        norm_v2 = np.linalg.norm(v2)
        return dot_product / (norm_v1 * norm_v2)

    # 다양한 각도의 벡터들
    angles = [0, 45, 90, 135, 180]  # 도
    vectors = []

    for angle in angles:
        # 각도를 라디안으로 변환
        rad = np.radians(angle)
        # 단위 벡터 생성
        v = np.array([np.cos(rad), np.sin(rad)])
        vectors.append(v)

    # 기준 벡터 (오른쪽 방향)
    base_vector = np.array([1, 0])

    print("벡터 간 유사도 (내적의 직관적 의미):")
    print("-" * 50)
    for i, angle in enumerate(angles):
        similarity = cosine_similarity(base_vector, vectors[i])
        direction = '같은 방향' if similarity > 0.9 else '수직' if abs(similarity) < 0.1 else '반대 방향' if similarity < -0.9 else '대각선'
        print(f"각도 {angle}°: 유사도 = {similarity:.3f} ({direction})")

def matrix_transformations():
    """4. 행렬: 데이터를 변형하는 마법 상자"""
    print("\n=== 4. 행렬 변환 시각화 ===")
    
    def plot_transformation(original_points, transformed_points, title):
        """점들의 변환을 시각화"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # 원본 점들
        ax1.scatter(original_points[:, 0], original_points[:, 1], c='blue', alpha=0.6)
        ax1.set_xlim(-3, 3)
        ax1.set_ylim(-3, 3)
        ax1.grid(True, alpha=0.3)
        ax1.set_aspect('equal')
        ax1.set_title('변환 전')
        
        # 변환된 점들
        ax2.scatter(transformed_points[:, 0], transformed_points[:, 1], c='red', alpha=0.6)
        ax2.set_xlim(-3, 3)
        ax2.set_ylim(-3, 3)
        ax2.grid(True, alpha=0.3)
        ax2.set_aspect('equal')
        ax2.set_title('변환 후')
        
        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

    # 원형 점들 생성
    theta = np.linspace(0, 2*np.pi, 50)
    circle_points = np.column_stack([np.cos(theta), np.sin(theta)])

    # 1. 회전 변환 (90도 회전)
    rotation_matrix = np.array([[0, -1], [1, 0]])  # 90도 회전
    rotated_points = circle_points @ rotation_matrix.T
    plot_transformation(circle_points, rotated_points, "회전 변환: 원형 → 회전된 원형")

    # 2. 크기 조절 변환 (X축으로 2배 늘리기)
    scaling_matrix = np.array([[2, 0], [0, 1]])  # X축 2배
    scaled_points = circle_points @ scaling_matrix.T
    plot_transformation(circle_points, scaled_points, "크기 조절: 원형 → 타원형")

    # 3. 전단 변환 (X축 방향으로 기울이기)
    shear_matrix = np.array([[1, 0.5], [0, 1]])  # X축 방향 전단
    sheared_points = circle_points @ shear_matrix.T
    plot_transformation(circle_points, sheared_points, "전단 변환: 원형 → 기울어진 원형")

def matrix_multiplication():
    """5. 행렬곱의 직관적 의미"""
    print("\n=== 5. 행렬곱: 복합 변환 ===")
    
    # 원본 점들
    theta = np.linspace(0, 2*np.pi, 100)
    original = np.column_stack([np.cos(theta), np.sin(theta)])

    # 변환 행렬들
    A = np.array([[0.8, -0.6], [0.6, 0.8]])  # 회전 + 축소
    B = np.array([[1.5, 0], [0, 0.8]])       # X축 확대, Y축 축소

    # 단계별 변환
    step1 = original @ A.T
    step2 = step1 @ B.T

    # 복합 변환 (A × B)
    combined = original @ (A @ B).T

    # 시각화
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    titles = ['원본', 'A 변환 후', 'B 변환 후', 'A×B 복합 변환']
    points_list = [original, step1, step2, combined]

    for i, (ax, title, points) in enumerate(zip(axes.flat, titles, points_list)):
        ax.scatter(points[:, 0], points[:, 1], c='blue', alpha=0.6)
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

    print("행렬곱의 의미: 변환을 순서대로 적용")
    print(f"A = {A}")
    print(f"B = {B}")
    print(f"A × B = {A @ B}")

def eigenvectors_visualization():
    """6. 고유벡터: 변하지 않는 특별한 방향"""
    print("\n=== 6. 고유벡터 시각화 ===")
    
    def visualize_eigenvectors(matrix, title):
        """고유벡터를 시각화하는 함수"""
        # 고유값과 고유벡터 계산
        eigenvalues, eigenvectors = np.linalg.eig(matrix)
        
        # 원형 점들
        theta = np.linspace(0, 2*np.pi, 100)
        circle = np.column_stack([np.cos(theta), np.sin(theta)])
        
        # 변환 전후
        transformed = circle @ matrix.T
        
        # 고유벡터들
        eigenvector1 = eigenvectors[:, 0]
        eigenvector2 = eigenvectors[:, 1]
        
        # 시각화
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        # 변환 전
        ax1.scatter(circle[:, 0], circle[:, 1], c='lightblue', alpha=0.6, s=20)
        ax1.quiver(0, 0, eigenvector1[0], eigenvector1[1], color='red', scale=3, 
                   label=f'고유벡터1 (λ={eigenvalues[0]:.2f})')
        ax1.quiver(0, 0, eigenvector2[0], eigenvector2[1], color='green', scale=3, 
                   label=f'고유벡터2 (λ={eigenvalues[1]:.2f})')
        ax1.set_xlim(-2, 2)
        ax1.set_ylim(-2, 2)
        ax1.grid(True, alpha=0.3)
        ax1.set_aspect('equal')
        ax1.set_title('변환 전: 고유벡터 방향')
        ax1.legend()
        
        # 변환 후
        ax2.scatter(transformed[:, 0], transformed[:, 1], c='lightcoral', alpha=0.6, s=20)
        # 변환된 고유벡터들
        transformed_eigenvector1 = eigenvector1 @ matrix.T
        transformed_eigenvector2 = eigenvector2 @ matrix.T
        ax2.quiver(0, 0, transformed_eigenvector1[0], transformed_eigenvector1[1], 
                   color='red', scale=3, label=f'변환된 고유벡터1')
        ax2.quiver(0, 0, transformed_eigenvector2[0], transformed_eigenvector2[1], 
                   color='green', scale=3, label=f'변환된 고유벡터2')
        ax2.set_xlim(-2, 2)
        ax2.set_ylim(-2, 2)
        ax2.grid(True, alpha=0.3)
        ax2.set_aspect('equal')
        ax2.set_title('변환 후: 고유벡터는 방향이 유지됨')
        ax2.legend()
        
        plt.suptitle(title)
        plt.tight_layout()
        plt.show()
        
        print(f"고유값: {eigenvalues}")
        print(f"고유벡터:\n{eigenvectors}")

    # 예제 행렬들
    matrix1 = np.array([[2, 0], [0, 1]])  # 대각 행렬
    matrix2 = np.array([[1, 0.5], [0.5, 1]])  # 대칭 행렬

    visualize_eigenvectors(matrix1, "대각 행렬의 고유벡터")
    visualize_eigenvectors(matrix2, "대칭 행렬의 고유벡터")

def pca_visualization():
    """7. PCA: 데이터의 핵심 방향 찾기"""
    print("\n=== 7. PCA 시각화 ===")
    
    # 2D 데이터 생성 (상관관계가 있는 데이터)
    np.random.seed(42)
    mean = [0, 0]
    cov = [[3, 2], [2, 3]]  # 상관관계가 있는 공분산 행렬
    data = np.random.multivariate_normal(mean, cov, 200)

    # 데이터 중앙 정렬
    X = data - data.mean(axis=0)

    # 공분산 행렬 계산
    cov_matrix = np.cov(X.T)

    # 고유값과 고유벡터 계산 (PCA)
    eigenvalues, eigenvectors = np.linalg.eig(cov_matrix)

    # 고유값 순으로 정렬
    sorted_indices = np.argsort(eigenvalues)[::-1]
    sorted_eigenvalues = eigenvalues[sorted_indices]
    sorted_eigenvectors = eigenvectors[:, sorted_indices]

    # 시각화
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # 원본 데이터와 주성분
    ax1.scatter(X[:, 0], X[:, 1], alpha=0.6, c='lightblue')
    # 주성분 벡터 그리기
    pc1 = sorted_eigenvectors[:, 0]
    pc2 = sorted_eigenvectors[:, 1]
    ax1.quiver(0, 0, pc1[0]*np.sqrt(sorted_eigenvalues[0]), pc1[1]*np.sqrt(sorted_eigenvalues[0]), 
               color='red', scale=10, label=f'PC1 (분산: {sorted_eigenvalues[0]:.2f})')
    ax1.quiver(0, 0, pc2[0]*np.sqrt(sorted_eigenvalues[1]), pc2[1]*np.sqrt(sorted_eigenvalues[1]), 
               color='green', scale=10, label=f'PC2 (분산: {sorted_eigenvalues[1]:.2f})')
    ax1.set_xlim(-4, 4)
    ax1.set_ylim(-4, 4)
    ax1.grid(True, alpha=0.3)
    ax1.set_aspect('equal')
    ax1.set_title('원본 데이터와 주성분 (PC1, PC2)')
    ax1.legend()

    # PCA 변환된 데이터
    transformed_data = X @ sorted_eigenvectors
    ax2.scatter(transformed_data[:, 0], transformed_data[:, 1], alpha=0.6, c='lightcoral')
    ax2.set_xlim(-4, 4)
    ax2.set_ylim(-4, 4)
    ax2.grid(True, alpha=0.3)
    ax2.set_aspect('equal')
    ax2.set_title('PCA 변환 후: PC1이 가장 큰 분산을 가짐')
    ax2.set_xlabel('PC1')
    ax2.set_ylabel('PC2')

    plt.tight_layout()
    plt.show()

    print("PCA의 직관적 의미:")
    print(f"PC1 (첫 번째 주성분): {pc1}")
    print(f"PC2 (두 번째 주성분): {pc2}")
    print(f"PC1 분산: {sorted_eigenvalues[0]:.2f} ({sorted_eigenvalues[0]/sum(sorted_eigenvalues)*100:.1f}%)")
    print(f"PC2 분산: {sorted_eigenvalues[1]:.2f} ({sorted_eigenvalues[1]/sum(sorted_eigenvalues)*100:.1f}%)")

def image_filtering():
    """8. 이미지 필터링으로 이해하는 행렬"""
    print("\n=== 8. 이미지 필터링 예제 ===")
    
    # 간단한 이미지 생성 (체크무늬 패턴)
    def create_checkerboard(size=50):
        """체크무늬 이미지 생성"""
        img = np.zeros((size, size))
        for i in range(size):
            for j in range(size):
                if (i // 10 + j // 10) % 2 == 0:
                    img[i, j] = 255
        return img

    # 필터 행렬들 (커널)
    filters = {
        '블러': np.array([[1/9, 1/9, 1/9], [1/9, 1/9, 1/9], [1/9, 1/9, 1/9]]),
        '엣지 검출': np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]),
        '수평 엣지': np.array([[-1, -1, -1], [0, 0, 0], [1, 1, 1]]),
        '수직 엣지': np.array([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]])
    }

    # 이미지 생성
    original = create_checkerboard()

    # 필터 적용 함수
    def apply_filter(image, filter_kernel):
        """이미지에 필터 적용"""
        from scipy import ndimage
        return ndimage.convolve(image, filter_kernel, mode='constant', cval=0)

    # 시각화
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    # 원본 이미지
    axes[0].imshow(original, cmap='gray')
    axes[0].set_title('원본 이미지')
    axes[0].axis('off')

    # 필터 적용 결과
    for i, (filter_name, filter_kernel) in enumerate(filters.items()):
        filtered = apply_filter(original, filter_kernel)
        axes[i+1].imshow(filtered, cmap='gray')
        axes[i+1].set_title(f'{filter_name} 필터')
        axes[i+1].axis('off')

    plt.tight_layout()
    plt.show()

    print("이미지 필터링 = 행렬 곱셈의 실제 활용")
    print("각 픽셀 주변의 값들을 행렬(커널)로 가중 평균하여 새로운 픽셀 값 계산")

def recommendation_system():
    """9. 추천 시스템으로 이해하는 행렬 분해"""
    print("\n=== 9. 추천 시스템 시뮬레이션 ===")
    
    # 간단한 추천 시스템 시뮬레이션
    np.random.seed(42)

    # 사용자-아이템 평점 행렬 (5명 사용자, 4개 아이템)
    # 0은 평점이 없는 경우
    ratings = np.array([
        [5, 3, 0, 1],  # 사용자 1
        [4, 0, 0, 1],  # 사용자 2
        [1, 1, 0, 5],  # 사용자 3
        [1, 0, 0, 4],  # 사용자 4
        [0, 1, 5, 4]   # 사용자 5
    ])

    print("사용자-아이템 평점 행렬:")
    print(ratings)
    print("\n행렬 분해의 직관적 의미:")
    print("- 사용자 행렬: 각 사용자의 취향 (예: 액션 선호도, 로맨스 선호도)")
    print("- 아이템 행렬: 각 아이템의 특성 (예: 액션 요소, 로맨스 요소)")
    print("- 두 행렬의 곱으로 누락된 평점을 예측")

    # 간단한 행렬 분해 (SVD의 직관적 이해)
    # 실제로는 더 복잡한 알고리즘 사용
    U = np.random.rand(5, 2)  # 사용자 특성 행렬
    V = np.random.rand(2, 4)  # 아이템 특성 행렬

    # 예측 평점
    predicted = U @ V
    print(f"\n예측된 평점 행렬:\n{predicted.round(2)}")

    # 원본과 예측 비교
    print(f"\n원본 평점 (0은 누락):\n{ratings}")
    print(f"예측 평점:\n{predicted.round(2)}")

def my_vector_class():
    """10. 나만의 벡터 클래스 만들기"""
    print("\n=== 10. 나만의 벡터 클래스 ===")
    
    class MyVector:
        """직관적인 벡터 클래스"""
        
        def __init__(self, x, y):
            self.x = x
            self.y = y
        
        def __add__(self, other):
            """벡터 덧셈: 이동 경로 합치기"""
            return MyVector(self.x + other.x, self.y + other.y)
        
        def __mul__(self, scalar):
            """스칼라 곱: 크기 조절"""
            return MyVector(self.x * scalar, self.y * scalar)
        
        def magnitude(self):
            """벡터의 크기"""
            return np.sqrt(self.x**2 + self.y**2)
        
        def dot(self, other):
            """내적: 서로 얼마나 같은 방향인지"""
            return self.x * other.x + self.y * other.y
        
        def plot(self, ax, color='blue', label=''):
            """벡터를 화살표로 그리기"""
            ax.quiver(0, 0, self.x, self.y, color=color, scale=10, label=label)
        
        def __str__(self):
            return f"Vector({self.x}, {self.y})"

    # 사용 예제
    v1 = MyVector(3, 2)
    v2 = MyVector(1, 4)

    print(f"v1 = {v1}")
    print(f"v2 = {v2}")
    print(f"v1 + v2 = {v1 + v2}")
    print(f"v1의 크기 = {v1.magnitude():.2f}")
    print(f"v1 · v2 = {v1.dot(v2)}")

    # 시각화
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

    v1.plot(ax, 'red', 'v1')
    v2.plot(ax, 'blue', 'v2')
    (v1 + v2).plot(ax, 'green', 'v1 + v2')

    ax.legend()
    ax.set_title('나만의 벡터 클래스')
    plt.show()

def neural_network_weights():
    """11. 간단한 신경망 가중치 시각화"""
    print("\n=== 11. 신경망 가중치 시각화 ===")
    
    # 간단한 신경망 가중치를 행렬로 이해하기
    def create_simple_network():
        """간단한 신경망 가중치 행렬 생성"""
        # 입력층(2) -> 은닉층(3) -> 출력층(1)
        W1 = np.random.randn(2, 3) * 0.1  # 입력층 -> 은닉층 가중치
        W2 = np.random.randn(3, 1) * 0.1  # 은닉층 -> 출력층 가중치
        
        return W1, W2

    def visualize_weights(W1, W2):
        """신경망 가중치 시각화"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # W1 시각화 (2x3 행렬)
        im1 = ax1.imshow(W1, cmap='RdBu', aspect='auto')
        ax1.set_title('입력층 → 은닉층 가중치 (W1)')
        ax1.set_xlabel('은닉층 뉴런')
        ax1.set_ylabel('입력층 뉴런')
        plt.colorbar(im1, ax=ax1)
        
        # W2 시각화 (3x1 행렬)
        im2 = ax2.imshow(W2, cmap='RdBu', aspect='auto')
        ax2.set_title('은닉층 → 출력층 가중치 (W2)')
        ax2.set_xlabel('출력층 뉴런')
        ax2.set_ylabel('은닉층 뉴런')
        plt.colorbar(im2, ax=ax2)
        
        plt.tight_layout()
        plt.show()
        
        print("신경망 가중치 = 행렬의 실제 활용")
        print(f"W1 모양: {W1.shape} (입력 2개 → 은닉층 3개)")
        print(f"W2 모양: {W2.shape} (은닉층 3개 → 출력 1개)")

    # 신경망 생성 및 시각화
    W1, W2 = create_simple_network()
    visualize_weights(W1, W2)

def main():
    """메인 함수: 모든 시각화 실행"""
    print("🚀 선형대수 직관적 이해 - 시각화 중심 학습")
    print("=" * 60)
    
    # 1. 벡터 시각화
    visualize_vectors()
    
    # 2. 벡터 연산
    vector_operations()
    
    # 3. 벡터 유사도
    vector_similarity()
    
    # 4. 행렬 변환
    matrix_transformations()
    
    # 5. 행렬곱
    matrix_multiplication()
    
    # 6. 고유벡터
    eigenvectors_visualization()
    
    # 7. PCA
    pca_visualization()
    
    # 8. 이미지 필터링
    image_filtering()
    
    # 9. 추천 시스템
    recommendation_system()
    
    # 10. 벡터 클래스
    my_vector_class()
    
    # 11. 신경망 가중치
    neural_network_weights()
    
    print("\n🎉 선형대수 학습 완료!")
    print("핵심 메시지: 복잡한 수식보다는 직관적 이해와 실제 활용에 집중하세요!")

if __name__ == "__main__":
    main() 