"""
평가 지표 테스트 모듈
각 평가 지표의 정확성을 검증합니다.
"""

import unittest
import numpy as np
import sys
import os
from pathlib import Path

# 프로젝트 루트를 Python 경로에 추가
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.metrics.map_2d import Map2DCalculator
from src.metrics.map_3d import Map3DCalculator
from src.metrics.chamfer_distance import ChamferDistanceCalculator
from src.metrics.emd import EMCalculator
from src.metrics.class_accuracy import ClassAccuracyCalculator


class TestMap2DCalculator(unittest.TestCase):
    """2D mAP 계산기 테스트"""
    
    def setUp(self):
        """테스트 설정"""
        self.calculator = Map2DCalculator()
        
        # 테스트용 더미 이미지 생성
        self.test_image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
        self.rendered_images = [
            np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
            for _ in range(4)
        ]
    
    def test_compute_iou(self):
        """IoU 계산 테스트"""
        box1 = np.array([10, 10, 50, 50])
        box2 = np.array([30, 30, 70, 70])
        
        iou = self.calculator.compute_iou(box1, box2)
        
        # IoU는 0과 1 사이의 값이어야 함
        self.assertGreaterEqual(iou, 0.0)
        self.assertLessEqual(iou, 1.0)
    
    def test_compute_iou_no_overlap(self):
        """겹치지 않는 박스의 IoU 테스트"""
        box1 = np.array([0, 0, 10, 10])
        box2 = np.array([20, 20, 30, 30])
        
        iou = self.calculator.compute_iou(box1, box2)
        
        # 겹치지 않는 경우 IoU는 0이어야 함
        self.assertEqual(iou, 0.0)
    
    def test_compute_iou_identical(self):
        """동일한 박스의 IoU 테스트"""
        box1 = np.array([10, 10, 50, 50])
        box2 = np.array([10, 10, 50, 50])
        
        iou = self.calculator.compute_iou(box1, box2)
        
        # 동일한 박스의 IoU는 1이어야 함
        self.assertEqual(iou, 1.0)
    
    def test_calculate_2d_map(self):
        """2D mAP 계산 테스트"""
        map_score = self.calculator.calculate_2d_map(
            self.rendered_images, self.test_image
        )
        
        # mAP는 0과 1 사이의 값이어야 함
        self.assertGreaterEqual(map_score, 0.0)
        self.assertLessEqual(map_score, 1.0)


class TestMap3DCalculator(unittest.TestCase):
    """3D mAP 계산기 테스트"""
    
    def setUp(self):
        """테스트 설정"""
        self.calculator = Map3DCalculator()
        
        # 테스트용 더미 3D 모델 생성
        self.test_model = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3)),
            'center': np.array([0, 0, 0]),
            'scale': np.array([1, 1, 1]),
            'bounding_box': (np.array([-1, -1, -1]), np.array([1, 1, 1]))
        }
        
        self.ground_truth_3d = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3)),
            'classes': ['chair'],
            'bounding_boxes': [{
                'center': np.array([0, 0, 0]),
                'size': np.array([1, 1, 1]),
                'min_coords': np.array([-0.5, -0.5, -0.5]),
                'max_coords': np.array([0.5, 0.5, 0.5]),
                'volume': 1.0,
                'confidence': 1.0
            }]
        }
    
    def test_compute_3d_iou(self):
        """3D IoU 계산 테스트"""
        box1 = {
            'min_coords': np.array([0, 0, 0]),
            'max_coords': np.array([2, 2, 2])
        }
        box2 = {
            'min_coords': np.array([1, 1, 1]),
            'max_coords': np.array([3, 3, 3])
        }
        
        iou = self.calculator.compute_3d_iou(box1, box2)
        
        # 3D IoU는 0과 1 사이의 값이어야 함
        self.assertGreaterEqual(iou, 0.0)
        self.assertLessEqual(iou, 1.0)
    
    def test_calculate_3d_map(self):
        """3D mAP 계산 테스트"""
        map_score = self.calculator.calculate_3d_map(
            self.test_model, self.ground_truth_3d
        )
        
        # mAP는 0과 1 사이의 값이어야 함
        self.assertGreaterEqual(map_score, 0.0)
        self.assertLessEqual(map_score, 1.0)


class TestChamferDistanceCalculator(unittest.TestCase):
    """Chamfer Distance 계산기 테스트"""
    
    def setUp(self):
        """테스트 설정"""
        self.calculator = ChamferDistanceCalculator()
        
        # 테스트용 더미 3D 모델 생성
        self.model_3d = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3))
        }
        
        self.reference_3d = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3))
        }
    
    def test_chamfer_distance_naive(self):
        """Chamfer Distance 계산 테스트"""
        pc1 = np.random.rand(50, 3)
        pc2 = np.random.rand(50, 3)
        
        distance = self.calculator.chamfer_distance_naive(pc1, pc2)
        
        # Chamfer Distance는 0 이상의 값이어야 함
        self.assertGreaterEqual(distance, 0.0)
    
    def test_chamfer_distance_identical(self):
        """동일한 점군의 Chamfer Distance 테스트"""
        pc1 = np.random.rand(50, 3)
        pc2 = pc1.copy()
        
        distance = self.calculator.chamfer_distance_naive(pc1, pc2)
        
        # 동일한 점군의 Chamfer Distance는 0에 가까워야 함
        self.assertLess(distance, 1e-6)
    
    def test_calculate_chamfer_distance(self):
        """Chamfer Distance 계산 테스트"""
        distance = self.calculator.calculate_chamfer_distance(
            self.model_3d, self.reference_3d
        )
        
        # Chamfer Distance는 0 이상의 값이어야 함
        self.assertGreaterEqual(distance, 0.0)


class TestEMCalculator(unittest.TestCase):
    """EMD 계산기 테스트"""
    
    def setUp(self):
        """테스트 설정"""
        self.calculator = EMCalculator()
        
        # 테스트용 더미 3D 모델 생성
        self.model_3d = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3))
        }
        
        self.reference_3d = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3))
        }
    
    def test_earth_movers_distance(self):
        """EMD 계산 테스트"""
        pc1 = np.random.rand(50, 3)
        pc2 = np.random.rand(50, 3)
        
        emd = self.calculator.earth_movers_distance(pc1, pc2)
        
        # EMD는 0 이상의 값이어야 함
        self.assertGreaterEqual(emd, 0.0)
    
    def test_earth_movers_distance_identical(self):
        """동일한 점군의 EMD 테스트"""
        pc1 = np.random.rand(50, 3)
        pc2 = pc1.copy()
        
        emd = self.calculator.earth_movers_distance(pc1, pc2)
        
        # 동일한 점군의 EMD는 0에 가까워야 함
        self.assertLess(emd, 1e-6)
    
    def test_calculate_emd(self):
        """EMD 계산 테스트"""
        emd = self.calculator.calculate_emd(self.model_3d, self.reference_3d)
        
        # EMD는 0 이상의 값이어야 함
        self.assertGreaterEqual(emd, 0.0)


class TestClassAccuracyCalculator(unittest.TestCase):
    """클래스 정확도 계산기 테스트"""
    
    def setUp(self):
        """테스트 설정"""
        self.calculator = ClassAccuracyCalculator()
        
        # 테스트용 더미 3D 모델 생성
        self.model_3d = {
            'vertices': np.random.rand(100, 3),
            'faces': np.random.randint(0, 100, (50, 3))
        }
        
        self.ground_truth_labels = {
            'classes': ['chair', 'table'],
            'labels': ['chair', 'table']
        }
    
    def test_classify_objects(self):
        """객체 분류 테스트"""
        classes = self.calculator.classify_objects(self.model_3d)
        
        # 분류 결과는 리스트여야 함
        self.assertIsInstance(classes, list)
        self.assertGreater(len(classes), 0)
    
    def test_calculate_class_accuracy(self):
        """클래스 정확도 계산 테스트"""
        accuracy = self.calculator.calculate_class_accuracy(
            self.model_3d, self.ground_truth_labels
        )
        
        # 정확도는 0과 1 사이의 값이어야 함
        self.assertGreaterEqual(accuracy, 0.0)
        self.assertLessEqual(accuracy, 1.0)
    
    def test_compute_class_wise_metrics(self):
        """클래스별 메트릭 계산 테스트"""
        predictions = ['chair', 'table', 'chair']
        ground_truth = ['chair', 'table', 'sofa']
        
        metrics = self.calculator.compute_class_wise_metrics(predictions, ground_truth)
        
        # 메트릭은 딕셔너리여야 함
        self.assertIsInstance(metrics, dict)
        self.assertIn('overall_accuracy', metrics)


if __name__ == '__main__':
    # 테스트 실행
    unittest.main(verbosity=2)
