"""
클래스별 정확도 계산 모듈
3D 모델의 객체 클래스 분류 정확도를 평가합니다.
"""

import numpy as np
import trimesh
from typing import Dict, List, Tuple, Optional
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.cluster import KMeans
import cv2


class ClassAccuracyCalculator:
    """클래스별 정확도 계산을 담당하는 클래스"""
    
    def __init__(self, num_classes: int = 10):
        """
        클래스 정확도 계산기 초기화
        
        Args:
            num_classes (int): 예상 클래스 개수
        """
        self.num_classes = num_classes
        self.class_names = [
            'chair', 'table', 'sofa', 'bed', 'desk', 
            'bookshelf', 'lamp', 'cabinet', 'door', 'window'
        ]
    
    def calculate_class_accuracy(self, model_3d: Dict, ground_truth_labels: Dict) -> float:
        """
        클래스별 정확도를 계산합니다.
        
        Args:
            model_3d (Dict): 변환된 3D 모델 정보
            ground_truth_labels (Dict): Ground Truth 클래스 레이블
            
        Returns:
            float: 클래스 정확도 점수 (0-1)
        """
        # 3D 모델에서 객체 클래스 분류
        predicted_classes = self.classify_objects(model_3d)
        
        # Ground Truth 클래스 레이블 추출
        gt_classes = self._extract_ground_truth_classes(ground_truth_labels)
        
        if not predicted_classes or not gt_classes:
            return 0.0
        
        # 클래스별 정확도 계산
        accuracy = self._compute_class_wise_accuracy(predicted_classes, gt_classes)
        
        return accuracy
    
    def classify_objects(self, mesh: Dict) -> List[str]:
        """
        3D 메시에서 객체 클래스를 분류합니다.
        
        Args:
            mesh (Dict): 3D 메시 정보
            
        Returns:
            List[str]: 예측된 클래스 레이블 리스트
        """
        vertices = mesh['vertices']
        faces = mesh['faces']
        
        if len(vertices) == 0:
            return []
        
        # 기하학적 특성 추출
        geometric_features = self._extract_geometric_features(vertices, faces)
        
        # 클래스 분류
        predicted_classes = self._classify_by_geometric_features(geometric_features)
        
        return predicted_classes
    
    def compute_class_wise_metrics(self, predictions: List, ground_truth: List) -> Dict:
        """
        클래스별 정밀도, 재현율, F1 점수를 계산합니다.
        
        Args:
            predictions (List): 예측 결과 리스트
            ground_truth (List): Ground Truth 리스트
            
        Returns:
            Dict: 클래스별 메트릭
        """
        if not predictions or not ground_truth:
            return {}
        
        # 정확도 계산
        accuracy = accuracy_score(ground_truth, predictions)
        
        # 정밀도, 재현율, F1 점수 계산
        precision, recall, f1, support = precision_recall_fscore_support(
            ground_truth, predictions, average='weighted', zero_division=0
        )
        
        # 혼동 행렬 계산
        cm = confusion_matrix(ground_truth, predictions)
        
        # 클래스별 상세 메트릭
        class_metrics = {}
        unique_classes = list(set(ground_truth + predictions))
        
        for class_name in unique_classes:
            class_precision, class_recall, class_f1, class_support = precision_recall_fscore_support(
                ground_truth, predictions, labels=[class_name], average='weighted', zero_division=0
            )
            
            class_metrics[class_name] = {
                'precision': class_precision,
                'recall': class_recall,
                'f1': class_f1,
                'support': class_support
            }
        
        return {
            'overall_accuracy': accuracy,
            'weighted_precision': precision,
            'weighted_recall': recall,
            'weighted_f1': f1,
            'confusion_matrix': cm,
            'class_metrics': class_metrics
        }
    
    def _extract_geometric_features(self, vertices: np.ndarray, faces: np.ndarray) -> Dict:
        """
        3D 메시에서 기하학적 특성을 추출합니다.
        
        Args:
            vertices (np.ndarray): 메시 정점
            faces (np.ndarray): 메시 면
            
        Returns:
            Dict: 기하학적 특성
        """
        features = {}
        
        # 기본 통계
        features['num_vertices'] = len(vertices)
        features['num_faces'] = len(faces)
        
        # 바운딩 박스 특성
        min_coords = np.min(vertices, axis=0)
        max_coords = np.max(vertices, axis=0)
        bbox_size = max_coords - min_coords
        
        features['bbox_width'] = bbox_size[0]
        features['bbox_height'] = bbox_size[1]
        features['bbox_depth'] = bbox_size[2]
        features['bbox_volume'] = np.prod(bbox_size)
        
        # 종횡비
        features['aspect_ratio_xy'] = bbox_size[0] / (bbox_size[1] + 1e-8)
        features['aspect_ratio_xz'] = bbox_size[0] / (bbox_size[2] + 1e-8)
        features['aspect_ratio_yz'] = bbox_size[1] / (bbox_size[2] + 1e-8)
        
        # 표면적과 부피
        try:
            trimesh_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
            features['surface_area'] = trimesh_mesh.surface_area
            features['volume'] = trimesh_mesh.volume
            features['compactness'] = features['volume'] / (features['surface_area'] + 1e-8)
        except:
            features['surface_area'] = 0.0
            features['volume'] = features['bbox_volume']
            features['compactness'] = 0.0
        
        # 복잡도 (면 대 정점 비율)
        features['complexity'] = len(faces) / (len(vertices) + 1e-8)
        
        # 중심점
        features['center'] = np.mean(vertices, axis=0)
        
        # 주성분 분석
        centered_vertices = vertices - features['center']
        cov_matrix = np.cov(centered_vertices.T)
        eigenvalues = np.linalg.eigvals(cov_matrix)
        eigenvalues = np.sort(eigenvalues)[::-1]
        
        features['eigenvalue_1'] = eigenvalues[0] if len(eigenvalues) > 0 else 0
        features['eigenvalue_2'] = eigenvalues[1] if len(eigenvalues) > 1 else 0
        features['eigenvalue_3'] = eigenvalues[2] if len(eigenvalues) > 2 else 0
        
        # 형태 특성
        if len(eigenvalues) >= 3:
            features['linearity'] = (eigenvalues[0] - eigenvalues[1]) / (eigenvalues[0] + 1e-8)
            features['planarity'] = (eigenvalues[1] - eigenvalues[2]) / (eigenvalues[0] + 1e-8)
            features['sphericity'] = eigenvalues[2] / (eigenvalues[0] + 1e-8)
        else:
            features['linearity'] = 0.0
            features['planarity'] = 0.0
            features['sphericity'] = 0.0
        
        return features
    
    def _classify_by_geometric_features(self, features: Dict) -> List[str]:
        """
        기하학적 특성을 기반으로 객체 클래스를 분류합니다.
        
        Args:
            features (Dict): 기하학적 특성
            
        Returns:
            List[str]: 예측된 클래스 레이블
        """
        predicted_classes = []
        
        # 규칙 기반 분류
        if self._is_chair_like(features):
            predicted_classes.append('chair')
        elif self._is_table_like(features):
            predicted_classes.append('table')
        elif self._is_sofa_like(features):
            predicted_classes.append('sofa')
        elif self._is_bed_like(features):
            predicted_classes.append('bed')
        elif self._is_desk_like(features):
            predicted_classes.append('desk')
        elif self._is_bookshelf_like(features):
            predicted_classes.append('bookshelf')
        elif self._is_lamp_like(features):
            predicted_classes.append('lamp')
        elif self._is_cabinet_like(features):
            predicted_classes.append('cabinet')
        elif self._is_door_like(features):
            predicted_classes.append('door')
        elif self._is_window_like(features):
            predicted_classes.append('window')
        else:
            predicted_classes.append('unknown')
        
        return predicted_classes
    
    def _is_chair_like(self, features: Dict) -> bool:
        """의자와 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_yz'] > 1.5 and  # 높이가 깊이보다 큼
                features['aspect_ratio_xy'] < 2.0 and  # 너비가 높이의 2배 이하
                features['compactness'] > 0.1)
    
    def _is_table_like(self, features: Dict) -> bool:
        """테이블과 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_xy'] > 1.5 and  # 너비가 높이보다 큼
                features['aspect_ratio_yz'] < 0.5 and  # 높이가 깊이보다 작음
                features['linearity'] < 0.3)
    
    def _is_sofa_like(self, features: Dict) -> bool:
        """소파와 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_xy'] > 2.0 and  # 너비가 높이의 2배 이상
                features['aspect_ratio_xz'] > 1.5 and  # 너비가 깊이의 1.5배 이상
                features['compactness'] > 0.05)
    
    def _is_bed_like(self, features: Dict) -> bool:
        """침대와 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_xy'] > 2.5 and  # 너비가 높이의 2.5배 이상
                features['aspect_ratio_xz'] > 1.8 and  # 너비가 깊이의 1.8배 이상
                features['aspect_ratio_yz'] < 0.3)     # 높이가 깊이보다 작음
    
    def _is_desk_like(self, features: Dict) -> bool:
        """책상과 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_xy'] > 1.2 and  # 너비가 높이보다 큼
                features['aspect_ratio_xz'] > 1.0 and  # 너비가 깊이보다 큼
                features['aspect_ratio_yz'] < 0.4)     # 높이가 깊이보다 작음
    
    def _is_bookshelf_like(self, features: Dict) -> bool:
        """책장과 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_yz'] > 2.0 and  # 높이가 깊이의 2배 이상
                features['aspect_ratio_xy'] > 1.5 and  # 너비가 높이의 1.5배 이상
                features['linearity'] > 0.4)
    
    def _is_lamp_like(self, features: Dict) -> bool:
        """램프와 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_yz'] > 3.0 and  # 높이가 깊이의 3배 이상
                features['aspect_ratio_xy'] < 1.5 and  # 너비가 높이의 1.5배 이하
                features['compactness'] < 0.1)
    
    def _is_cabinet_like(self, features: Dict) -> bool:
        """캐비닛과 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_xy'] > 1.0 and  # 너비가 높이보다 큼
                features['aspect_ratio_xz'] > 1.0 and  # 너비가 깊이보다 큼
                features['aspect_ratio_yz'] > 1.0 and  # 높이가 깊이보다 큼
                features['compactness'] > 0.2)
    
    def _is_door_like(self, features: Dict) -> bool:
        """문과 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_yz'] > 2.5 and  # 높이가 깊이의 2.5배 이상
                features['aspect_ratio_xy'] < 0.5 and  # 너비가 높이의 절반 이하
                features['planarity'] > 0.7)
    
    def _is_window_like(self, features: Dict) -> bool:
        """창문과 유사한 특성을 가졌는지 확인"""
        return (features['aspect_ratio_xy'] > 1.5 and  # 너비가 높이의 1.5배 이상
                features['aspect_ratio_yz'] > 1.5 and  # 높이가 깊이의 1.5배 이상
                features['aspect_ratio_xz'] < 0.1)     # 너비가 깊이의 0.1배 이하
    
    def _extract_ground_truth_classes(self, ground_truth_labels: Dict) -> List[str]:
        """
        Ground Truth에서 클래스 레이블을 추출합니다.
        
        Args:
            ground_truth_labels (Dict): Ground Truth 레이블 정보
            
        Returns:
            List[str]: Ground Truth 클래스 레이블 리스트
        """
        if 'classes' in ground_truth_labels:
            return ground_truth_labels['classes']
        elif 'labels' in ground_truth_labels:
            return ground_truth_labels['labels']
        else:
            # 기본 클래스 반환 (실제 구현에서는 사용자가 제공해야 함)
            return ['unknown']
    
    def _compute_class_wise_accuracy(self, predictions: List[str], ground_truth: List[str]) -> float:
        """
        클래스별 정확도를 계산합니다.
        
        Args:
            predictions (List[str]): 예측된 클래스 레이블
            ground_truth (List[str]): Ground Truth 클래스 레이블
            
        Returns:
            float: 클래스 정확도 (0-1)
        """
        if not predictions or not ground_truth:
            return 0.0
        
        # 정확도 계산
        correct = 0
        total = min(len(predictions), len(ground_truth))
        
        for i in range(total):
            if predictions[i] == ground_truth[i]:
                correct += 1
        
        accuracy = correct / total if total > 0 else 0.0
        
        return accuracy
    
    def classify_objects_clustering(self, mesh: Dict, num_clusters: int = 5) -> List[str]:
        """
        클러스터링을 사용하여 객체를 분류합니다.
        
        Args:
            mesh (Dict): 3D 메시 정보
            num_clusters (int): 클러스터 개수
            
        Returns:
            List[str]: 예측된 클래스 레이블
        """
        vertices = mesh['vertices']
        
        if len(vertices) < num_clusters:
            return ['unknown']
        
        try:
            # K-means 클러스터링
            kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
            cluster_labels = kmeans.fit_predict(vertices)
            
            # 클러스터별 특성 분석
            predicted_classes = []
            
            for cluster_id in range(num_clusters):
                cluster_vertices = vertices[cluster_labels == cluster_id]
                
                if len(cluster_vertices) < 10:
                    continue
                
                # 클러스터의 기하학적 특성 계산
                cluster_features = self._extract_geometric_features(cluster_vertices, np.array([]))
                
                # 클래스 분류
                cluster_classes = self._classify_by_geometric_features(cluster_features)
                predicted_classes.extend(cluster_classes)
            
            return predicted_classes if predicted_classes else ['unknown']
            
        except Exception:
            return ['unknown']
    
    def calculate_class_distribution_similarity(self, model_3d: Dict, reference_3d: Dict) -> float:
        """
        두 모델 간의 클래스 분포 유사성을 계산합니다.
        
        Args:
            model_3d (Dict): 변환된 3D 모델 정보
            reference_3d (Dict): 참조 3D 모델 정보
            
        Returns:
            float: 클래스 분포 유사성 (0-1)
        """
        # 각 모델의 클래스 분류
        model_classes = self.classify_objects(model_3d)
        reference_classes = self.classify_objects(reference_3d)
        
        if not model_classes or not reference_classes:
            return 0.0
        
        # 클래스 분포 계산
        model_dist = self._calculate_class_distribution(model_classes)
        reference_dist = self._calculate_class_distribution(reference_classes)
        
        # 분포 유사성 계산 (코사인 유사도)
        similarity = self._calculate_distribution_similarity(model_dist, reference_dist)
        
        return similarity
    
    def _calculate_class_distribution(self, classes: List[str]) -> Dict[str, float]:
        """
        클래스 분포를 계산합니다.
        
        Args:
            classes (List[str]): 클래스 레이블 리스트
            
        Returns:
            Dict[str, float]: 클래스별 분포
        """
        total = len(classes)
        if total == 0:
            return {}
        
        distribution = {}
        for class_name in set(classes):
            distribution[class_name] = classes.count(class_name) / total
        
        return distribution
    
    def _calculate_distribution_similarity(self, dist1: Dict[str, float], 
                                         dist2: Dict[str, float]) -> float:
        """
        두 분포 간의 유사성을 계산합니다.
        
        Args:
            dist1 (Dict[str, float]): 첫 번째 분포
            dist2 (Dict[str, float]): 두 번째 분포
            
        Returns:
            float: 분포 유사성 (0-1)
        """
        # 모든 클래스 수집
        all_classes = set(dist1.keys()) | set(dist2.keys())
        
        if not all_classes:
            return 1.0
        
        # 벡터 생성
        vec1 = np.array([dist1.get(cls, 0.0) for cls in all_classes])
        vec2 = np.array([dist2.get(cls, 0.0) for cls in all_classes])
        
        # 코사인 유사도 계산
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        
        if norm1 == 0 or norm2 == 0:
            return 0.0
        
        similarity = dot_product / (norm1 * norm2)
        
        return similarity
