import json
import os
import sys
import unittest
from unittest.mock import MagicMock, patch

import numpy as np
from fastapi.testclient import TestClient
from sklearn.datasets import load_iris

# 상위 디렉토리를 Python 경로에 추가
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# FastAPI 앱 임포트
from main import app


class TestFastAPIModelServing(unittest.TestCase):
    """FastAPI 모델 서빙 테스트 클래스"""

    def setUp(self):
        """테스트 전 설정"""
        self.client = TestClient(app)
        self.iris = load_iris()

        # 테스트용 샘플 데이터
        self.test_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2,
        }

    def test_root_endpoint(self):
        """루트 엔드포인트 테스트"""
        response = self.client.get("/")

        # 상태 코드 확인
        self.assertEqual(response.status_code, 200)

        # 응답 내용 확인
        data = response.json()
        self.assertIn("message", data)
        self.assertIn("AI 모델 서빙 API", data["message"])

    def test_model_info_endpoint(self):
        """모델 정보 엔드포인트 테스트"""
        response = self.client.get("/api/v1/model")

        # 상태 코드 확인
        self.assertEqual(response.status_code, 200)

        # 응답 내용 확인
        data = response.json()
        self.assertIn("name", data)
        self.assertIn("version", data)
        self.assertIn("description", data)

        # 모델 정보 검증
        self.assertEqual(data["name"], "Iris Species Predictor")
        self.assertEqual(data["version"], "1.0")
        self.assertIn("Decision Tree", data["description"])

    def test_predict_endpoint_valid_input(self):
        """예측 엔드포인트 - 유효한 입력 테스트"""
        response = self.client.post("/api/v1/predict", json=self.test_input)

        # 상태 코드 확인
        self.assertEqual(response.status_code, 200)

        # 응답 내용 확인
        data = response.json()
        self.assertIn("species_name", data)
        self.assertIn("prediction", data)

        # 예측 결과 검증
        self.assertIsInstance(data["species_name"], str)
        self.assertIsInstance(data["prediction"], int)
        self.assertIn(data["species_name"], self.iris.target_names)
        self.assertGreaterEqual(data["prediction"], 0)
        self.assertLess(data["prediction"], len(self.iris.target_names))

    def test_predict_endpoint_invalid_input(self):
        """예측 엔드포인트 - 잘못된 입력 테스트"""
        # 잘못된 데이터 타입
        invalid_input = {
            "sepal_length": "invalid",
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2,
        }

        response = self.client.post("/api/v1/predict", json=invalid_input)

        # 상태 코드 확인 (422: Validation Error)
        self.assertEqual(response.status_code, 422)

    def test_predict_endpoint_missing_fields(self):
        """예측 엔드포인트 - 필드 누락 테스트"""
        # 필수 필드 누락
        incomplete_input = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            # petal_length, petal_width 누락
        }

        response = self.client.post("/api/v1/predict", json=incomplete_input)

        # 상태 코드 확인 (422: Validation Error)
        self.assertEqual(response.status_code, 422)

    def test_predict_endpoint_edge_cases(self):
        """예측 엔드포인트 - 경계값 테스트"""
        # 매우 작은 값
        small_input = {
            "sepal_length": 0.1,
            "sepal_width": 0.1,
            "petal_length": 0.1,
            "petal_width": 0.1,
        }

        response = self.client.post("/api/v1/predict", json=small_input)
        self.assertEqual(response.status_code, 200)

        # 매우 큰 값
        large_input = {
            "sepal_length": 100.0,
            "sepal_width": 100.0,
            "petal_length": 100.0,
            "petal_width": 100.0,
        }

        response = self.client.post("/api/v1/predict", json=large_input)
        self.assertEqual(response.status_code, 200)

    def test_predict_endpoint_multiple_samples(self):
        """예측 엔드포인트 - 여러 샘플 테스트"""
        # 여러 테스트 케이스
        test_cases = [
            {
                "sepal_length": 5.1,
                "sepal_width": 3.5,
                "petal_length": 1.4,
                "petal_width": 0.2,
            },
            {
                "sepal_length": 6.3,
                "sepal_width": 3.3,
                "petal_length": 4.7,
                "petal_width": 1.6,
            },
            {
                "sepal_length": 7.2,
                "sepal_width": 3.2,
                "petal_length": 6.0,
                "petal_width": 1.8,
            },
        ]

        for test_case in test_cases:
            response = self.client.post("/api/v1/predict", json=test_case)
            self.assertEqual(response.status_code, 200)

            data = response.json()
            self.assertIn("species_name", data)
            self.assertIn("prediction", data)

    def test_api_documentation_endpoints(self):
        """API 문서 엔드포인트 테스트"""
        # OpenAPI 스키마
        response = self.client.get("/openapi.json")
        self.assertEqual(response.status_code, 200)

        # API 문서
        response = self.client.get("/docs")
        self.assertEqual(response.status_code, 200)

        # ReDoc 문서
        response = self.client.get("/redoc")
        self.assertEqual(response.status_code, 200)


class TestSchemas(unittest.TestCase):
    """Pydantic 스키마 테스트 클래스"""

    def setUp(self):
        """테스트 전 설정"""
        from schemas import IrisInput, IrisPrediction, ModelInfo

        self.IrisInput = IrisInput
        self.IrisPrediction = IrisPrediction
        self.ModelInfo = ModelInfo

    def test_iris_input_schema(self):
        """IrisInput 스키마 테스트"""
        # 유효한 입력
        valid_data = {
            "sepal_length": 5.1,
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2,
        }

        iris_input = self.IrisInput(**valid_data)
        self.assertEqual(iris_input.sepal_length, 5.1)
        self.assertEqual(iris_input.sepal_width, 3.5)
        self.assertEqual(iris_input.petal_length, 1.4)
        self.assertEqual(iris_input.petal_width, 0.2)

        # 잘못된 데이터 타입
        invalid_data = {
            "sepal_length": "invalid",
            "sepal_width": 3.5,
            "petal_length": 1.4,
            "petal_width": 0.2,
        }

        with self.assertRaises(Exception):
            self.IrisInput(**invalid_data)

    def test_iris_prediction_schema(self):
        """IrisPrediction 스키마 테스트"""
        valid_data = {"species_name": "setosa", "prediction": 0}

        prediction = self.IrisPrediction(**valid_data)
        self.assertEqual(prediction.species_name, "setosa")
        self.assertEqual(prediction.prediction, 0)

    def test_model_info_schema(self):
        """ModelInfo 스키마 테스트"""
        valid_data = {
            "name": "Test Model",
            "version": "1.0",
            "description": "Test Description",
        }

        model_info = self.ModelInfo(**valid_data)
        self.assertEqual(model_info.name, "Test Model")
        self.assertEqual(model_info.version, "1.0")
        self.assertEqual(model_info.description, "Test Description")


class TestModelLogic(unittest.TestCase):
    """모델 로직 테스트 클래스"""

    def setUp(self):
        """테스트 전 설정"""
        from sklearn.datasets import load_iris
        from sklearn.tree import DecisionTreeClassifier

        self.iris = load_iris()
        self.X = self.iris.data
        self.y = self.iris.target

        # 모델 훈련
        self.model = DecisionTreeClassifier(max_depth=3, random_state=42)
        self.model.fit(self.X, self.y)

    def test_model_prediction_consistency(self):
        """모델 예측 일관성 테스트"""
        # 동일한 입력에 대해 항상 같은 결과가 나오는지 확인
        test_input = np.array([[5.1, 3.5, 1.4, 0.2]])

        prediction1 = self.model.predict(test_input)
        prediction2 = self.model.predict(test_input)

        self.assertEqual(prediction1[0], prediction2[0])

    def test_model_prediction_range(self):
        """모델 예측 범위 테스트"""
        # 예측 결과가 유효한 범위 내에 있는지 확인
        test_inputs = [
            [5.1, 3.5, 1.4, 0.2],  # setosa
            [6.3, 3.3, 4.7, 1.6],  # versicolor
            [7.2, 3.2, 6.0, 1.8],  # virginica
        ]

        for test_input in test_inputs:
            prediction = self.model.predict(np.array([test_input]))
            predicted_class = prediction[0]

            # 예측 클래스가 유효한 범위 내에 있는지 확인
            self.assertGreaterEqual(predicted_class, 0)
            self.assertLess(predicted_class, len(self.iris.target_names))

            # 예측된 클래스에 해당하는 품종 이름이 존재하는지 확인
            species_name = self.iris.target_names[predicted_class]
            self.assertIn(species_name, self.iris.target_names)

    def test_model_accuracy_on_training_data(self):
        """훈련 데이터에 대한 모델 정확도 테스트"""
        # 훈련 데이터에 대한 예측
        predictions = self.model.predict(self.X)

        # 정확도 계산
        accuracy = np.mean(predictions == self.y)

        # 결정 트리 모델이므로 훈련 데이터에 대해 높은 정확도를 가져야 함
        self.assertGreater(accuracy, 0.8)  # 80% 이상의 정확도


if __name__ == "__main__":
    # 테스트 스위트 생성
    test_suite = unittest.TestSuite()

    # 테스트 클래스 추가
    test_suite.addTest(unittest.makeSuite(TestFastAPIModelServing))
    test_suite.addTest(unittest.makeSuite(TestSchemas))
    test_suite.addTest(unittest.makeSuite(TestModelLogic))

    # 테스트 실행
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(test_suite)

    # 테스트 결과 출력
    print(f"\n테스트 결과: {len(result.failures)} 실패, {len(result.errors)} 오류")
    if result.failures:
        print("\n실패한 테스트:")
        for test, traceback in result.failures:
            print(f"- {test}: {traceback}")
    if result.errors:
        print("\n오류가 발생한 테스트:")
        for test, traceback in result.errors:
            print(f"- {test}: {traceback}")
