"""
Part 12: 모델 최적화 테스트
"""

import unittest
import sys
import os
import tempfile
import shutil
from unittest.mock import patch, MagicMock, mock_open
import numpy as np
import torch

# 상위 디렉토리를 Python 경로에 추가
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Part 12 모듈 import
try:
    # 숫자로 시작하는 모듈명은 직접 import할 수 없으므로 importlib를 사용
    import importlib.util
    base_dir = os.path.dirname(os.path.abspath(__file__))
    spec = importlib.util.spec_from_file_location(
        "bert_quantization",
        os.path.abspath(os.path.join(base_dir, "..", "bert_quantization.py"))
    )
    bert_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(bert_module)
    get_model_size = bert_module.get_model_size
    benchmark_performance = bert_module.benchmark_performance
    
except ImportError as e:
    print(f"Warning: Part 12 모듈을 import할 수 없습니다: {e}")
    # 대안 경로 시도
    try:
        sys.path.append(os.path.dirname(os.path.abspath(__file__)))
        from bert_quantization import get_model_size, benchmark_performance
    except ImportError as e2:
        print(f"Warning: 대안 경로도 실패했습니다: {e2}")
        get_model_size = None
        benchmark_performance = None


class TestPart12ModelSizeCalculation(unittest.TestCase):
    """Part 12 모델 크기 계산 테스트 클래스"""

    def test_get_model_size(self):
        """모델 크기 계산 함수 테스트"""
        if get_model_size is None:
            self.skipTest("get_model_size 함수를 import할 수 없습니다.")

        # 간단한 모델 생성
        model = torch.nn.Sequential(
            torch.nn.Linear(10, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 1)
        )
        
        # 모델 크기 계산
        size = get_model_size(model)
        
        # 결과 검증
        self.assertIsInstance(size, float)
        self.assertGreater(size, 0)
        self.assertLess(size, 100)  # 작은 모델이므로 100MB 미만이어야 함

    def test_get_model_size_zero_parameters(self):
        """파라미터가 없는 모델의 크기 계산 테스트"""
        if get_model_size is None:
            self.skipTest("get_model_size 함수를 import할 수 없습니다.")

        # 파라미터가 없는 모델 생성
        model = torch.nn.Sequential()
        
        # 모델 크기 계산
        size = get_model_size(model)
        
        # 결과 검증
        self.assertIsInstance(size, float)
        self.assertGreaterEqual(size, 0)


class TestPart12PerformanceBenchmark(unittest.TestCase):
    """Part 12 성능 벤치마크 테스트 클래스"""

    @patch('bert_quantization.time.time')
    def test_benchmark_performance(self, mock_time):
        """성능 벤치마크 함수 테스트"""
        if benchmark_performance is None:
            self.skipTest("benchmark_performance 함수를 import할 수 없습니다.")

        # Mock 파이프라인 생성
        mock_pipe = MagicMock()
        mock_pipe.return_value = {"label": "positive", "score": 0.9}
        
        # 시간 Mock 설정 (첫 번째 호출: 0.1초, 두 번째 호출: 0.2초)
        mock_time.side_effect = [0.0, 0.1, 0.0, 0.2]
        
        # 성능 벤치마크 실행
        latency, throughput = benchmark_performance(mock_pipe, "test text", num_runs=2)
        
        # 결과 검증
        self.assertIsInstance(latency, float)
        self.assertIsInstance(throughput, float)
        self.assertGreater(latency, 0)
        self.assertGreater(throughput, 0)
        
        # Mock이 올바른 횟수로 호출되었는지 확인
        self.assertEqual(mock_pipe.call_count, 12)  # 예열 10회 + 실제 테스트 2회


class TestPart12QuantizationEffects(unittest.TestCase):
    """Part 12 양자화 효과 테스트 클래스"""

    def test_quantization_size_reduction(self):
        """양자화를 통한 모델 크기 감소 테스트"""
        if get_model_size is None:
            self.skipTest("get_model_size 함수를 import할 수 없습니다.")

        # 원본 모델 크기 (가상)
        original_size = 100.0  # MB
        
        # 양자화된 모델 크기 (가상, 일반적으로 25-50% 감소)
        quantized_size = 60.0  # MB
        
        # 크기 감소율 계산
        reduction_percentage = 100 * (1 - quantized_size / original_size)
        
        # 검증
        self.assertGreater(reduction_percentage, 20)  # 최소 20% 감소
        self.assertLess(reduction_percentage, 80)     # 최대 80% 감소

    def test_quantization_performance_improvement(self):
        """양자화를 통한 성능 개선 테스트"""
        # 원본 모델 지연 시간 (가상)
        original_latency = 100.0  # ms
        
        # 양자화된 모델 지연 시간 (가상, 일반적으로 10-30% 개선)
        quantized_latency = 80.0  # ms
        
        # 성능 개선율 계산
        improvement_percentage = 100 * (original_latency - quantized_latency) / original_latency
        
        # 검증
        self.assertGreater(improvement_percentage, 0)   # 개선이 있어야 함
        self.assertLess(improvement_percentage, 50)     # 최대 50% 개선


class TestPart12Integration(unittest.TestCase):
    """Part 12 통합 테스트 클래스"""

    @patch('bert_quantization.AutoModelForSequenceClassification')
    @patch('bert_quantization.AutoTokenizer')
    @patch('bert_quantization.pipeline')
    def test_optimization_workflow(self, mock_pipeline, mock_tokenizer, mock_model):
        """모델 최적화 워크플로우 테스트"""
        if get_model_size is None or benchmark_performance is None:
            self.skipTest("필요한 함수들을 import할 수 없습니다.")

        # Mock 모델 설정
        mock_model_instance = MagicMock()
        mock_model_instance.parameters.return_value = [
            torch.randn(10, 5),  # 가상 파라미터
            torch.randn(5, 1)
        ]
        mock_model_instance.buffers.return_value = []
        mock_model.from_pretrained.return_value = mock_model_instance
        
        # Mock 토크나이저 설정
        mock_tokenizer_instance = MagicMock()
        mock_tokenizer.from_pretrained.return_value = mock_tokenizer_instance
        
        # Mock 파이프라인 설정
        mock_pipe_instance = MagicMock()
        mock_pipe_instance.return_value = {"label": "positive", "score": 0.9}
        mock_pipeline.return_value = mock_pipe_instance
        
        # 워크플로우 테스트
        try:
            # 모델 크기 계산
            size = get_model_size(mock_model_instance)
            self.assertIsInstance(size, float)
            self.assertGreater(size, 0)
            
            # 성능 벤치마크
            with patch('bert_quantization.time.time') as mock_time:
                mock_time.side_effect = [0.0, 0.1]
                latency, throughput = benchmark_performance(mock_pipe_instance, "test", num_runs=1)
                self.assertIsInstance(latency, float)
                self.assertIsInstance(throughput, float)
                
        except Exception as e:
            self.fail(f"워크플로우 테스트 중 오류 발생: {e}")


class TestPart12ErrorHandling(unittest.TestCase):
    """Part 12 에러 처리 테스트 클래스"""

    def test_get_model_size_with_none_model(self):
        """None 모델로 크기 계산 시 에러 처리 테스트"""
        if get_model_size is None:
            self.skipTest("get_model_size 함수를 import할 수 없습니다.")

        with self.assertRaises(Exception):
            get_model_size(None)

    def test_benchmark_performance_with_invalid_pipeline(self):
        """잘못된 파이프라인으로 성능 벤치마크 시 에러 처리 테스트"""
        if benchmark_performance is None:
            self.skipTest("benchmark_performance 함수를 import할 수 없습니다.")

        # None 파이프라인
        with self.assertRaises(Exception):
            benchmark_performance(None, "test text")

        # 호출할 수 없는 파이프라인
        invalid_pipe = MagicMock()
        invalid_pipe.side_effect = Exception("Pipeline error")
        
        with self.assertRaises(Exception):
            benchmark_performance(invalid_pipe, "test text")


if __name__ == "__main__":
    unittest.main() 