"""
Part 9: 프로덕션 준비 API 테스트
"""

import unittest
import sys
import os
import json
from unittest.mock import patch, MagicMock
import numpy as np
from fastapi.testclient import TestClient
from fastapi import HTTPException

# 상위 디렉토리를 Python 경로에 추가
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Part 9 모듈 import
try:
    # 숫자로 시작하는 모듈명은 직접 import할 수 없으므로 sys.path를 통해 접근
    import importlib.util
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # main.py
    spec = importlib.util.spec_from_file_location(
        "app",
        os.path.abspath(os.path.join(base_dir, "..", "main.py"))
    )
    app_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(app_module)
    app = app_module.app
    # api.py
    spec = importlib.util.spec_from_file_location(
        "api",
        os.path.abspath(os.path.join(base_dir, "..", "api.py"))
    )
    api_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(api_module)
    api = api_module
    # schemas.py
    spec = importlib.util.spec_from_file_location(
        "schemas",
        os.path.abspath(os.path.join(base_dir, "..", "schemas.py"))
    )
    schemas_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(schemas_module)
    schemas = schemas_module
    
except ImportError as e:
    print(f"Warning: Part 9 모듈을 import할 수 없습니다: {e}")
    # 대안 경로 시도
    try:
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from app.main import app
        from app import api, schemas
    except ImportError as e2:
        print(f"Warning: 대안 경로도 실패했습니다: {e2}")
        app = None
        api = None
        schemas = None


class TestPart9Schemas(unittest.TestCase):
    """Part 9 스키마 테스트 클래스"""

    def test_iris_input_schema(self):
        """IrisInput 스키마 테스트"""
        if schemas is None:
            self.skipTest("schemas 모듈을 import할 수 없습니다.")

        # 유효한 입력 데이터
        valid_data = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        
        iris_input = schemas.IrisInput(**valid_data)
        self.assertEqual(iris_input.sepal_length, 5.1)
        self.assertEqual(iris_input.sepal_width, 3.5)
        self.assertEqual(iris_input.petal_length, 1.4)
        self.assertEqual(iris_input.petal_width, 0.2)

        # 잘못된 데이터 타입 테스트
        with self.assertRaises(Exception):
            invalid_data = {
                "sepal_length": "invalid",
                "sepal_width": 3.5,
                "petal_length": 1.4,
                "petal_width": 0.2
            }
            schemas.IrisInput(**invalid_data)

    def test_iris_prediction_schema(self):
        """IrisPrediction 스키마 테스트"""
        if schemas is None:
            self.skipTest("schemas 모듈을 import할 수 없습니다.")

        prediction_data = {
            "species_name": "setosa",
            "prediction": 0
        }
        
        prediction = schemas.IrisPrediction(**prediction_data)
        self.assertEqual(prediction.species_name, "setosa")
        self.assertEqual(prediction.prediction, 0)

    def test_model_info_schema(self):
        """ModelInfo 스키마 테스트"""
        if schemas is None:
            self.skipTest("schemas 모듈을 import할 수 없습니다.")

        model_info_data = {
            "name": "Test Model",
            "version": "1.0",
            "description": "Test description"
        }
        
        model_info = schemas.ModelInfo(**model_info_data)
        self.assertEqual(model_info.name, "Test Model")
        self.assertEqual(model_info.version, "1.0")
        self.assertEqual(model_info.description, "Test description")


class TestPart9API(unittest.TestCase):
    """Part 9 API 테스트 클래스"""

    def setUp(self):
        """테스트 설정"""
        if app is None:
            self.skipTest("app 모듈을 import할 수 없습니다.")
        
        self.client = TestClient(app)

    def test_get_model_info(self):
        """모델 정보 조회 API 테스트"""
        response = self.client.get("/api/v1/model")
        self.assertEqual(response.status_code, 200)
        
        data = response.json()
        self.assertIn("name", data)
        self.assertIn("version", data)
        self.assertIn("description", data)
        self.assertEqual(data["name"], "Iris Species Predictor")
        self.assertEqual(data["version"], "1.0")

    def test_predict_species_valid_input(self):
        """유효한 입력으로 종 예측 API 테스트"""
        test_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        
        response = self.client.post("/api/v1/predict", json=test_input)
        self.assertEqual(response.status_code, 200)
        
        data = response.json()
        self.assertIn("species_name", data)
        self.assertIn("prediction", data)
        self.assertIsInstance(data["species_name"], str)
        self.assertIsInstance(data["prediction"], int)

    def test_predict_species_invalid_input(self):
        """잘못된 입력으로 종 예측 API 테스트"""
        # 잘못된 데이터 타입
        invalid_input = {
            "sepal_length": "invalid",
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        
        response = self.client.post("/api/v1/predict", json=invalid_input)
        self.assertEqual(response.status_code, 422)  # Validation error

    def test_predict_species_missing_fields(self):
        """필수 필드 누락으로 종 예측 API 테스트"""
        # 필수 필드 누락
        incomplete_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5
            # petal_length, petal_width 누락
        }
        
        response = self.client.post("/api/v1/predict", json=incomplete_input)
        self.assertEqual(response.status_code, 422)  # Validation error

    def test_predict_species_edge_cases(self):
        """경계값으로 종 예측 API 테스트"""
        # 매우 작은 값
        small_input = {
            "sepal_length": 0.1,
            "sepal_width": 0.1,
            "petal_length": 0.1,
            "petal_width": 0.1
        }
        
        response = self.client.post("/api/v1/predict", json=small_input)
        self.assertEqual(response.status_code, 200)
        
        # 매우 큰 값
        large_input = {
            "sepal_length": 100.0,
            "sepal_width": 100.0,
            "petal_length": 100.0,
            "petal_width": 100.0
        }
        
        response = self.client.post("/api/v1/predict", json=large_input)
        self.assertEqual(response.status_code, 200)

    def test_api_endpoints_exist(self):
        """API 엔드포인트 존재 확인"""
        # 모델 정보 엔드포인트
        response = self.client.get("/api/v1/model")
        self.assertNotEqual(response.status_code, 404)
        
        # 예측 엔드포인트
        test_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        response = self.client.post("/api/v1/predict", json=test_input)
        self.assertNotEqual(response.status_code, 404)

    def test_response_format(self):
        """응답 형식 테스트"""
        # 모델 정보 응답 형식
        response = self.client.get("/api/v1/model")
        data = response.json()
        self.assertIsInstance(data, dict)
        self.assertIn("name", data)
        self.assertIn("version", data)
        self.assertIn("description", data)
        
        # 예측 응답 형식
        test_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        response = self.client.post("/api/v1/predict", json=test_input)
        data = response.json()
        self.assertIsInstance(data, dict)
        self.assertIn("species_name", data)
        self.assertIn("prediction", data)


class TestPart9ModelLogic(unittest.TestCase):
    """Part 9 모델 로직 테스트 클래스"""

    def setUp(self):
        if app is None:
            self.skipTest("app 모듈을 import할 수 없습니다.")
        self.client = TestClient(app)

    def test_model_prediction_logic(self):
        """모델 예측 로직 테스트 (TestClient 사용)"""
        test_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        response = self.client.post("/api/v1/predict", json=test_input)
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertEqual(data["species_name"], "setosa")
        self.assertEqual(data["prediction"], 0)

    def test_model_info_structure(self):
        """모델 정보 구조 테스트 (TestClient 사용)"""
        response = self.client.get("/api/v1/model")
        self.assertEqual(response.status_code, 200)
        data = response.json()
        self.assertIsInstance(data, dict)
        self.assertIn("name", data)
        self.assertIn("version", data)
        self.assertIn("description", data)
        self.assertEqual(data["name"], "Iris Species Predictor")


class TestPart9Integration(unittest.TestCase):
    """Part 9 통합 테스트 클래스"""

    def setUp(self):
        """테스트 설정"""
        if app is None:
            self.skipTest("app 모듈을 import할 수 없습니다.")
        
        self.client = TestClient(app)

    def test_full_prediction_workflow(self):
        """전체 예측 워크플로우 테스트"""
        # 1. 모델 정보 조회
        model_response = self.client.get("/api/v1/model")
        self.assertEqual(model_response.status_code, 200)
        
        # 2. 예측 수행
        test_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2
        }
        predict_response = self.client.post("/api/v1/predict", json=test_input)
        self.assertEqual(predict_response.status_code, 200)
        
        # 3. 결과 검증
        data = predict_response.json()
        self.assertIn("species_name", data)
        self.assertIn("prediction", data)

    def test_multiple_predictions_consistency(self):
        """여러 예측의 일관성 테스트"""
        test_inputs = [
            {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2},
            {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}
        ]
        
        predictions = []
        for input_data in test_inputs:
            response = self.client.post("/api/v1/predict", json=input_data)
            self.assertEqual(response.status_code, 200)
            predictions.append(response.json()["prediction"])
        
        # 동일한 입력에 대해 동일한 예측이 나와야 함
        self.assertEqual(predictions[0], predictions[1])

    def test_different_inputs_produce_different_predictions(self):
        """다른 입력에 대한 다른 예측 테스트"""
        # setosa와 virginica의 대표적인 값들
        setosa_input = {"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}
        virginica_input = {"sepal_length": 6.3, "sepal_width": 3.3, "petal_length": 6.0, "petal_width": 2.5}
        
        setosa_response = self.client.post("/api/v1/predict", json=setosa_input)
        virginica_response = self.client.post("/api/v1/predict", json=virginica_input)
        
        self.assertEqual(setosa_response.status_code, 200)
        self.assertEqual(virginica_response.status_code, 200)
        
        setosa_prediction = setosa_response.json()["prediction"]
        virginica_prediction = virginica_response.json()["prediction"]
        
        # 다른 종으로 예측되어야 함 (항상은 아니지만 대부분의 경우)
        # 실제로는 모델에 따라 다를 수 있으므로 예측이 유효한 범위인지만 확인
        self.assertIn(setosa_prediction, [0, 1, 2])
        self.assertIn(virginica_prediction, [0, 1, 2])


class TestPart9ErrorHandling(unittest.TestCase):
    """Part 9 에러 처리 테스트 클래스"""

    def setUp(self):
        """테스트 설정"""
        if app is None:
            self.skipTest("app 모듈을 import할 수 없습니다.")
        
        self.client = TestClient(app)

    def test_invalid_json_input(self):
        """잘못된 JSON 입력 테스트"""
        response = self.client.post("/api/v1/predict", data="invalid json")
        self.assertEqual(response.status_code, 422)

    def test_empty_request_body(self):
        """빈 요청 본문 테스트"""
        response = self.client.post("/api/v1/predict", json={})
        self.assertEqual(response.status_code, 422)

    def test_negative_values(self):
        """음수 값 테스트"""
        negative_input = {
            "sepal_length": -1.0,
            "sepal_width": -1.0,
            "petal_length": -1.0,
            "petal_width": -1.0
        }
        
        response = self.client.post("/api/v1/predict", json=negative_input)
        # 음수 값도 유효한 입력으로 처리되어야 함 (모델이 처리할 수 있는지 확인)
        self.assertEqual(response.status_code, 200)


if __name__ == "__main__":
    unittest.main() 