"""
정확도 계산기
기하학적 정확성을 평가하여 정확도 점수를 계산합니다.
"""

import numpy as np
import logging
from typing import Dict, List, Tuple, Optional
from ..utils.exception_handler import EvaluationExceptionHandler

logger = logging.getLogger(__name__)

class AccuracyCalculator:
    """Accuracy score calculator.

    Combines three sub-scores, each clamped to ``[0, 1]``, into a weighted sum:

    * geometric accuracy    - mean per-vertex distance relative to ``tolerance_threshold``
    * shape similarity      - relative agreement of centroid-distance statistics and
      volume / surface-area estimates
    * proportional accuracy - relative agreement of bounding-box extents and their
      pairwise ratios

    The weights are used exactly as configured (they are not normalized), so the
    combined score is only guaranteed to lie in ``[0, 1]`` when the weights are
    non-negative and sum to 1.
    """

    def __init__(self, config: Dict):
        """
        Initialize the accuracy calculator.

        Args:
            config (Dict): Accuracy settings. Keys read here (with defaults):
                ``geometric_accuracy_weight`` (0.5), ``shape_similarity_weight`` (0.3),
                ``proportional_accuracy_weight`` (0.2), ``tolerance_threshold`` (0.1).
        """
        self.config = config
        self.geometric_accuracy_weight = config.get('geometric_accuracy_weight', 0.5)
        self.shape_similarity_weight = config.get('shape_similarity_weight', 0.3)
        self.proportional_accuracy_weight = config.get('proportional_accuracy_weight', 0.2)
        self.tolerance_threshold = config.get('tolerance_threshold', 0.1)

        logger.info("정확도 계산기가 초기화되었습니다.")

    @EvaluationExceptionHandler.handle_metric_calculation_error("Accuracy", {'score': 0.0})
    def calculate_accuracy(self, model_3d: Dict, ground_truth_3d: Dict) -> Dict:
        """
        Compute the combined accuracy score.

        Args:
            model_3d (Dict): 3D model data; only its ``'vertices'`` entry is read.
            ground_truth_3d (Dict): Ground-truth data; only its ``'vertices'`` entry is read.

        Returns:
            Dict: ``{'score': float, 'details': {...}}`` with the three sub-scores and
            the weights used. On any error, ``{'score': 0.0, 'details': {'error': msg}}``.
        """
        try:
            logger.info("정확도 점수 계산 시작...")

            geometric_accuracy_score = self._calculate_geometric_accuracy(model_3d, ground_truth_3d)
            shape_similarity_score = self._calculate_shape_similarity(model_3d, ground_truth_3d)
            proportional_accuracy_score = self._calculate_proportional_accuracy(model_3d, ground_truth_3d)

            # Plain weighted sum; see the class docstring regarding weight normalization.
            accuracy_score = float(
                geometric_accuracy_score * self.geometric_accuracy_weight +
                shape_similarity_score * self.shape_similarity_weight +
                proportional_accuracy_score * self.proportional_accuracy_weight
            )

            result = {
                'score': accuracy_score,
                'details': {
                    'geometric_accuracy': geometric_accuracy_score,
                    'shape_similarity': shape_similarity_score,
                    'proportional_accuracy': proportional_accuracy_score,
                    'weights': {
                        'geometric_accuracy': self.geometric_accuracy_weight,
                        'shape_similarity': self.shape_similarity_weight,
                        'proportional_accuracy': self.proportional_accuracy_weight
                    }
                }
            }

            logger.info(f"정확도 점수 계산 완료: {accuracy_score:.2f}")
            return result

        except Exception as e:
            logger.error(f"정확도 점수 계산 중 오류: {str(e)}")
            return {'score': 0.0, 'details': {'error': str(e)}}

    @staticmethod
    def _is_empty(vertices) -> bool:
        """Return True when *vertices* is None or has no elements.

        Uses ``len`` rather than truthiness so NumPy arrays are accepted as well as
        lists. The truth value of a multi-element array is ambiguous and raises.
        """
        return vertices is None or len(vertices) == 0

    @staticmethod
    def _relative_similarity(model_val: float, gt_val: float) -> Optional[float]:
        """Return ``max(0, 1 - |model_val - gt_val| / gt_val)``, or None if ``gt_val <= 0``.

        A non-positive reference value has no meaningful relative error. Callers skip
        such features, so they do not contribute to the average, instead of dividing by zero.
        """
        if gt_val > 0:
            return max(0.0, 1.0 - abs(model_val - gt_val) / gt_val)
        return None

    def _calculate_geometric_accuracy(self, model_3d: Dict, ground_truth_3d: Dict) -> float:
        """Score the mean distance between paired model / ground-truth vertices.

        Vertex ``i`` of the model is paired with vertex ``i`` of the ground truth.
        When the counts differ, both sets are first subsampled to the smaller count
        via ``_align_vertices``. NOTE(review): no registration or nearest-neighbour
        matching is performed, so this assumes the two vertex orders correspond.

        Returns:
            float: ``1 - mean_distance / tolerance_threshold`` clamped to [0, 1].
            0.0 is returned when either vertex set is missing/empty, or on any error.
        """
        try:
            model_vertices = model_3d.get('vertices', [])
            gt_vertices = ground_truth_3d.get('vertices', [])

            if self._is_empty(model_vertices) or self._is_empty(gt_vertices):
                return 0.0

            if len(model_vertices) != len(gt_vertices):
                model_vertices, gt_vertices = self._align_vertices(model_vertices, gt_vertices)

            if self._is_empty(model_vertices) or self._is_empty(gt_vertices):
                return 0.0

            # One row per vertex; reshape also accepts a flat 1-D sequence of scalars.
            model_array = np.asarray(model_vertices, dtype=float)
            gt_array = np.asarray(gt_vertices, dtype=float)
            model_array = model_array.reshape(len(model_array), -1)
            gt_array = gt_array.reshape(len(gt_array), -1)

            distances = np.linalg.norm(model_array - gt_array, axis=1)
            mean_distance = float(np.mean(distances))

            # A non-positive tolerance would divide by zero; only an exact match passes.
            if self.tolerance_threshold <= 0:
                return 1.0 if mean_distance == 0.0 else 0.0

            accuracy_score = 1.0 - mean_distance / self.tolerance_threshold
            return min(1.0, max(0.0, accuracy_score))

        except Exception as e:
            logger.error(f"기하학적 정확도 계산 중 오류: {str(e)}")
            return 0.0

    def _calculate_shape_similarity(self, model_3d: Dict, ground_truth_3d: Dict) -> float:
        """Score how closely global shape descriptors of the model match the ground truth.

        Returns:
            float: Mean relative similarity of the features produced by
            ``_extract_shape_features``, clamped to [0, 1]. 0.0 is returned when either
            vertex set is missing/empty, or on any error.
        """
        try:
            model_vertices = model_3d.get('vertices', [])
            gt_vertices = ground_truth_3d.get('vertices', [])

            if self._is_empty(model_vertices) or self._is_empty(gt_vertices):
                return 0.0

            model_features = self._extract_shape_features(model_vertices)
            gt_features = self._extract_shape_features(gt_vertices)

            if not model_features or not gt_features:
                return 0.0

            similarity_score = self._calculate_feature_similarity(model_features, gt_features)
            return min(1.0, max(0.0, float(similarity_score)))

        except Exception as e:
            logger.error(f"형태 유사도 계산 중 오류: {str(e)}")
            return 0.0

    def _calculate_proportional_accuracy(self, model_3d: Dict, ground_truth_3d: Dict) -> float:
        """Score how closely the bounding-box proportions of the model match the ground truth.

        Returns:
            float: Mean relative accuracy of the values produced by
            ``_calculate_proportions``, clamped to [0, 1]. 0.0 is returned when either
            vertex set is missing/empty, or on any error.
        """
        try:
            model_vertices = model_3d.get('vertices', [])
            gt_vertices = ground_truth_3d.get('vertices', [])

            if self._is_empty(model_vertices) or self._is_empty(gt_vertices):
                return 0.0

            model_proportions = self._calculate_proportions(model_vertices)
            gt_proportions = self._calculate_proportions(gt_vertices)

            if not model_proportions or not gt_proportions:
                return 0.0

            proportion_accuracy = self._calculate_proportion_accuracy(model_proportions, gt_proportions)
            return min(1.0, max(0.0, float(proportion_accuracy)))

        except Exception as e:
            logger.error(f"비례 정확도 계산 중 오류: {str(e)}")
            return 0.0

    @staticmethod
    def _subsample(vertices, count: int):
        """Return *count* evenly spaced vertices, preserving their original order.

        Indices ``i * total // count`` are strictly increasing and below ``total``
        whenever ``total > count``. The result is deterministic, unlike random sampling.
        """
        total = len(vertices)
        if total <= count:
            return vertices
        return [vertices[i * total // count] for i in range(count)]

    def _align_vertices(self, model_vertices: List, gt_vertices: List) -> Tuple[List, List]:
        """Reduce both vertex sets to the same length (the smaller of the two counts).

        The larger set is subsampled deterministically with evenly spaced indices, so
        repeated evaluations of the same data give the same score. Returns
        ``([], [])`` on error.
        """
        try:
            if len(model_vertices) == len(gt_vertices):
                return model_vertices, gt_vertices

            target_count = min(len(model_vertices), len(gt_vertices))
            return (
                self._subsample(model_vertices, target_count),
                self._subsample(gt_vertices, target_count),
            )

        except Exception as e:
            logger.error(f"정점 정렬 중 오류: {str(e)}")
            return [], []

    def _extract_shape_features(self, vertices: List) -> Dict:
        """Compute global shape descriptors of a vertex set.

        Returns:
            Dict: ``center`` (centroid as a list), statistics of the vertex distances
            to the centroid (``mean_distance``, ``std_distance``, ``max_distance``,
            ``min_distance``), and ``volume_estimate`` / ``surface_area_estimate``.
            Returns ``{}`` for empty input or on error.
        """
        try:
            if self._is_empty(vertices):
                return {}

            vertices_array = np.asarray(vertices, dtype=float)
            center = np.mean(vertices_array, axis=0)
            distances = np.linalg.norm(vertices_array - center, axis=1)

            return {
                'center': center.tolist(),
                'mean_distance': float(np.mean(distances)),
                'std_distance': float(np.std(distances)),
                'max_distance': float(np.max(distances)),
                'min_distance': float(np.min(distances)),
                'volume_estimate': self._estimate_volume(vertices_array),
                'surface_area_estimate': self._estimate_surface_area(vertices_array)
            }

        except Exception as e:
            logger.error(f"형태 특성 추출 중 오류: {str(e)}")
            return {}

    def _calculate_feature_similarity(self, model_features: Dict, gt_features: Dict) -> float:
        """Average the relative similarity of the distance, volume and area features.

        ``center`` is not compared. Features whose ground-truth value is non-positive
        are skipped (see ``_relative_similarity``). Returns 0.0 when nothing is
        comparable, or on error.
        """
        try:
            if not model_features or not gt_features:
                return 0.0

            compared_features = (
                # Statistics of vertex distances to the centroid
                'mean_distance', 'std_distance', 'max_distance', 'min_distance',
                # Volume and surface-area estimates
                'volume_estimate', 'surface_area_estimate',
            )

            similarities = []
            for feature in compared_features:
                if feature in model_features and feature in gt_features:
                    similarity = self._relative_similarity(model_features[feature], gt_features[feature])
                    if similarity is not None:
                        similarities.append(similarity)

            if not similarities:
                return 0.0

            return float(np.mean(similarities))

        except Exception as e:
            logger.error(f"특성 유사도 계산 중 오류: {str(e)}")
            return 0.0

    def _calculate_proportions(self, vertices: List) -> Dict:
        """Compute axis-aligned bounding-box extents and their pairwise ratios.

        Columns 0, 1 and 2 are treated as x, y and z. Input with fewer than three
        columns raises internally and yields ``{}``. A ratio whose denominator extent
        is zero defaults to 1.0. Note that ``aspect_ratios`` holds the raw extents
        ``[x_range, y_range, z_range]``, not ratios.
        """
        try:
            if self._is_empty(vertices):
                return {}

            vertices_array = np.asarray(vertices, dtype=float)
            extent = np.max(vertices_array, axis=0) - np.min(vertices_array, axis=0)
            x_range, y_range, z_range = (float(extent[axis]) for axis in range(3))

            return {
                'x_y_ratio': x_range / y_range if y_range > 0 else 1.0,
                'x_z_ratio': x_range / z_range if z_range > 0 else 1.0,
                'y_z_ratio': y_range / z_range if z_range > 0 else 1.0,
                'aspect_ratios': [x_range, y_range, z_range]
            }

        except Exception as e:
            logger.error(f"비례 특성 계산 중 오류: {str(e)}")
            return {}

    def _calculate_proportion_accuracy(self, model_proportions: Dict, gt_proportions: Dict) -> float:
        """Average the relative accuracy of the extent ratios and the raw extents.

        Because ``aspect_ratios`` holds absolute extents, this sub-score also
        penalizes differences in overall scale, not just shape proportions.
        Returns 0.0 when nothing is comparable, or on error.
        """
        try:
            if not model_proportions or not gt_proportions:
                return 0.0

            accuracies = []

            for feature in ('x_y_ratio', 'x_z_ratio', 'y_z_ratio'):
                if feature in model_proportions and feature in gt_proportions:
                    accuracy = self._relative_similarity(model_proportions[feature], gt_proportions[feature])
                    if accuracy is not None:
                        accuracies.append(accuracy)

            if 'aspect_ratios' in model_proportions and 'aspect_ratios' in gt_proportions:
                model_extents = model_proportions['aspect_ratios']
                gt_extents = gt_proportions['aspect_ratios']

                if len(model_extents) == len(gt_extents):
                    for model_extent, gt_extent in zip(model_extents, gt_extents):
                        accuracy = self._relative_similarity(model_extent, gt_extent)
                        if accuracy is not None:
                            accuracies.append(accuracy)

            if not accuracies:
                return 0.0

            return float(np.mean(accuracies))

        except Exception as e:
            logger.error(f"비례 정확도 계산 중 오류: {str(e)}")
            return 0.0

    def _estimate_volume(self, vertices: np.ndarray) -> float:
        """Estimate the enclosed volume as the volume of the convex hull.

        Fewer than 4 points yield 0.0. If the hull cannot be built (scipy missing, or
        degenerate points such as a coplanar set), the axis-aligned bounding-box volume
        is returned instead, or 0.0 if even that fails.
        """
        try:
            if len(vertices) < 4:
                return 0.0

            # Local import keeps scipy an optional dependency; ImportError uses the fallback.
            from scipy.spatial import ConvexHull
            return float(ConvexHull(vertices).volume)

        except Exception as e:
            # Degenerate geometry is an expected case with a defined fallback, not an error.
            logger.warning(f"부피 추정 중 오류: {str(e)}")
            try:
                extent = np.max(vertices, axis=0) - np.min(vertices, axis=0)
                return float(np.prod(extent))
            except Exception:
                return 0.0

    def _estimate_surface_area(self, vertices: np.ndarray) -> float:
        """Estimate surface area as that of the axis-aligned bounding box.

        Fewer than 3 points yield 0.0. Columns 0, 1 and 2 are treated as the box
        dimensions. Returns 0.0 on error.
        """
        try:
            if len(vertices) < 3:
                return 0.0

            dimensions = np.max(vertices, axis=0) - np.min(vertices, axis=0)
            surface_area = 2 * (dimensions[0] * dimensions[1] +
                                dimensions[1] * dimensions[2] +
                                dimensions[0] * dimensions[2])
            return float(surface_area)

        except Exception as e:
            logger.error(f"표면적 추정 중 오류: {str(e)}")
            return 0.0

    def get_details(self) -> Dict:
        """Return the calculator name, raw config, effective weights and parameters."""
        return {
            'name': 'AccuracyCalculator',
            'config': self.config,
            'weights': {
                'geometric_accuracy': self.geometric_accuracy_weight,
                'shape_similarity': self.shape_similarity_weight,
                'proportional_accuracy': self.proportional_accuracy_weight
            },
            'parameters': {
                'tolerance_threshold': self.tolerance_threshold
            }
        }
