"""
Part 6 머신러닝 기초 테스트 코드
"""

import os
import sys
import unittest

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# 상위 디렉토리의 모듈을 import하기 위한 경로 추가
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class TestMachineLearningBasics(unittest.TestCase):
    """머신러닝 기초 테스트 클래스"""

    def setUp(self):
        """테스트 설정"""
        # 붓꽃 데이터셋 로드
        self.iris = load_iris()
        self.X = self.iris.data
        self.y = self.iris.target

        # 훈련/테스트 분할
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.2, random_state=42, stratify=self.y
        )

        # 스케일러
        self.scaler = StandardScaler()
        self.X_train_scaled = self.scaler.fit_transform(self.X_train)
        self.X_test_scaled = self.scaler.transform(self.X_test)

    def test_data_loading(self):
        """데이터 로딩 테스트"""
        # 데이터셋 크기 확인
        self.assertEqual(self.X.shape, (150, 4))
        self.assertEqual(self.y.shape, (150,))

        # 클래스 개수 확인
        self.assertEqual(len(np.unique(self.y)), 3)

        # 특성 이름 확인
        self.assertEqual(len(self.iris.feature_names), 4)
        self.assertEqual(len(self.iris.target_names), 3)

    def test_data_preprocessing(self):
        """데이터 전처리 테스트"""
        # 훈련/테스트 분할 확인
        self.assertEqual(self.X_train.shape[0], 120)
        self.assertEqual(self.X_test.shape[0], 30)

        # 스케일링 확인 (평균이 0에 가깝고 표준편차가 1에 가까운지)
        train_mean = np.mean(self.X_train_scaled, axis=0)
        train_std = np.std(self.X_train_scaled, axis=0)

        np.testing.assert_allclose(train_mean, 0, atol=1e-10)
        np.testing.assert_allclose(train_std, 1, atol=1e-10)

    def test_logistic_regression(self):
        """로지스틱 회귀 모델 테스트"""
        # 모델 훈련
        log_reg = LogisticRegression(random_state=42)
        log_reg.fit(self.X_train_scaled, self.y_train)

        # 예측 수행
        y_pred = log_reg.predict(self.X_test_scaled)

        # 정확도 확인 (최소 85% 이상)
        accuracy = accuracy_score(self.y_test, y_pred)
        self.assertGreaterEqual(accuracy, 0.85)

        # 예측 결과 형태 확인
        self.assertEqual(len(y_pred), len(self.y_test))
        self.assertTrue(all(pred in [0, 1, 2] for pred in y_pred))

    def test_decision_tree(self):
        """결정 트리 모델 테스트"""
        # 모델 훈련
        tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
        tree_clf.fit(self.X_train_scaled, self.y_train)

        # 예측 수행
        y_pred = tree_clf.predict(self.X_test_scaled)

        # 정확도 확인 (최소 80% 이상)
        accuracy = accuracy_score(self.y_test, y_pred)
        self.assertGreaterEqual(accuracy, 0.80)

        # 트리 깊이 확인
        self.assertLessEqual(tree_clf.get_depth(), 3)

    def test_model_comparison(self):
        """두 모델 성능 비교 테스트"""
        # 로지스틱 회귀
        log_reg = LogisticRegression(random_state=42)
        log_reg.fit(self.X_train_scaled, self.y_train)
        y_pred_log = log_reg.predict(self.X_test_scaled)
        accuracy_log = accuracy_score(self.y_test, y_pred_log)

        # 결정 트리
        tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
        tree_clf.fit(self.X_train_scaled, self.y_train)
        y_pred_tree = tree_clf.predict(self.X_test_scaled)
        accuracy_tree = accuracy_score(self.y_test, y_pred_tree)

        # 두 모델 모두 합리적인 성능을 보이는지 확인
        self.assertGreaterEqual(accuracy_log, 0.75)
        self.assertGreaterEqual(accuracy_tree, 0.75)

    def test_confusion_matrix(self):
        """혼동 행렬 테스트"""
        # 모델 훈련 및 예측
        log_reg = LogisticRegression(random_state=42)
        log_reg.fit(self.X_train_scaled, self.y_train)
        y_pred = log_reg.predict(self.X_test_scaled)

        # 혼동 행렬 생성
        cm = confusion_matrix(self.y_test, y_pred)

        # 혼동 행렬 크기 확인
        self.assertEqual(cm.shape, (3, 3))

        # 대각선 합이 전체 정확한 예측 수와 같은지 확인
        correct_predictions = np.trace(cm)
        total_predictions = np.sum(cm)
        accuracy = correct_predictions / total_predictions

        expected_accuracy = accuracy_score(self.y_test, y_pred)
        self.assertAlmostEqual(accuracy, expected_accuracy, places=10)


class TestModelEvaluation(unittest.TestCase):
    """모델 평가 테스트 클래스"""

    def setUp(self):
        """테스트 설정"""
        # 간단한 테스트 데이터 생성
        np.random.seed(42)
        self.y_true = np.array([0, 0, 1, 1, 2, 2])
        self.y_pred = np.array([0, 1, 1, 1, 2, 2])

    def test_accuracy_calculation(self):
        """정확도 계산 테스트"""
        accuracy = accuracy_score(self.y_true, self.y_pred)
        # 실제 계산: [0,0,1,1,2,2] vs [0,1,1,1,2,2]
        # 정확한 예측: 0번째(0), 2번째(1), 3번째(1), 4번째(2), 5번째(2) = 5개
        # 전체: 6개
        # 정확도: 5/6 = 0.8333...
        expected_accuracy = 5 / 6  # 6개 중 5개 정확
        self.assertAlmostEqual(accuracy, expected_accuracy, places=10)

    def test_classification_report(self):
        """분류 보고서 테스트"""
        try:
            report = classification_report(self.y_true, self.y_pred, output_dict=True)

            # 보고서 구조 확인
            self.assertIn("accuracy", report)
            self.assertIn("macro avg", report)
            self.assertIn("weighted avg", report)

            # 클래스별 메트릭 확인
            for class_label in [0, 1, 2]:
                self.assertIn(str(class_label), report)
                class_metrics = report[str(class_label)]
                self.assertIn("precision", class_metrics)
                self.assertIn("recall", class_metrics)
                self.assertIn("f1-score", class_metrics)

        except Exception as e:
            self.skipTest(f"Classification report test failed: {e}")


def run_tests():
    """테스트 실행 함수"""
    # 테스트 스위트 생성
    test_suite = unittest.TestSuite()

    # 테스트 클래스 추가
    test_suite.addTest(unittest.makeSuite(TestMachineLearningBasics))
    test_suite.addTest(unittest.makeSuite(TestModelEvaluation))

    # 테스트 실행
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(test_suite)

    # 결과 출력
    print(f"\n=== 테스트 결과 ===")
    print(f"실행된 테스트: {result.testsRun}")
    print(f"실패한 테스트: {len(result.failures)}")
    print(f"오류가 발생한 테스트: {len(result.errors)}")

    if result.failures:
        print("\n실패한 테스트:")
        for test, error in result.failures:
            print(f"- {test}: {error}")

    if result.errors:
        print("\n오류가 발생한 테스트:")
        for test, error in result.errors:
            print(f"- {test}: {error}")

    return result.wasSuccessful()


if __name__ == "__main__":
    # 직접 실행 시 테스트 수행
    success = run_tests()

    if success:
        print("\n모든 테스트가 성공적으로 완료되었습니다! ✅")
    else:
        print("\n일부 테스트가 실패했습니다. ❌")
        sys.exit(1)
