import json
import subprocess
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

import pynvml
from influxdb_client import InfluxDBClient, Point
from influxdb_client.client.write_api import SYNCHRONOUS
from loguru import logger

from config.config import settings
from src.utils.influxdb_client import InfluxDBManager

class GPUMonitor:
    """Collect NVIDIA GPU metrics, raise threshold alerts and persist snapshots.

    Per-device readings come from ``nvidia-smi`` (queried for one GPU at a
    time) with NVML (``pynvml``) as the fallback. Snapshots of GPUs that
    trigger an alert are written to the ``gpu_metrics`` measurement in
    InfluxDB, and temperature history can be read back from the same
    measurement.
    """

    # Thresholds applied by check_gpu_status.
    TEMPERATURE_ALERT_C = 85.0  # TODO: threshold should become configurable
    MEMORY_ALERT_RATIO = 0.9
    POWER_ALERT_RATIO = 0.9
    # Readings above this value are only logged as abnormal by get_gpu_info.
    ABNORMAL_TEMPERATURE_C = 100
    # Upper bound for one nvidia-smi invocation so a hung driver cannot block
    # metric collection forever; on timeout the NVML fallback is used.
    NVIDIA_SMI_TIMEOUT_S = 10

    def __init__(self):
        """Initialise NVML and the InfluxDB clients.

        Failures are logged rather than raised. Every attribute is defined
        before the first fallible call so that a failed initialisation still
        leaves a well-formed object whose methods fail inside their own
        error handling instead of with ``AttributeError``.
        """
        self._nvml_ready = False
        self.influxdb = None
        self.influx_client = None
        self.write_api = None
        try:
            pynvml.nvmlInit()
            self._nvml_ready = True
            self.influxdb = InfluxDBManager()
            self.influx_client = InfluxDBClient(
                url=settings.INFLUXDB_URL,
                token=settings.INFLUXDB_TOKEN,
                org=settings.INFLUXDB_ORG
            )
            self.write_api = self.influx_client.write_api(write_options=SYNCHRONOUS)
            logger.info("GPU 모니터링 초기화 완료")
        except Exception as e:
            logger.error(f"GPU 모니터링 초기화 실패: {str(e)}")

    @staticmethod
    def _parse_nvidia_smi_csv(output: str) -> Dict[str, float]:
        """Parse one ``csv,noheader,nounits`` row produced by ``nvidia-smi``.

        The expected column order is temperature, fan speed, power draw,
        memory used, memory total. A non-numeric fan speed (``[N/A]`` on
        GPUs without a fan) is reported as ``0.0``; any other non-numeric
        field is an error.

        Raises:
            ValueError: if *output* does not contain exactly one non-empty
                row with exactly five fields, or a required field is not
                numeric.
        """
        rows = [row for row in output.splitlines() if row.strip()]
        if len(rows) != 1:
            raise ValueError(f"expected exactly one nvidia-smi row, got {len(rows)}")
        fields = [field.strip() for field in rows[0].split(',')]
        if len(fields) != 5:
            raise ValueError(f"expected 5 nvidia-smi fields, got {len(fields)}")
        temperature, fan_raw, power_w, mem_used, mem_total = fields
        try:
            fan_speed = float(fan_raw)
        except ValueError:
            fan_speed = 0.0
        return {
            'temperature': float(temperature),
            'fan_speed': fan_speed,
            'power_draw': float(power_w),  # nvidia-smi already reports watts
            'memory_used': float(mem_used),
            'memory_total': float(mem_total),
        }

    def _query_nvidia_smi(self, index: int) -> Optional[Dict[str, float]]:
        """Read the metrics of GPU *index* via ``nvidia-smi``.

        Returns ``None`` when the command exits non-zero, times out, cannot
        be started, or prints output that cannot be parsed, so the caller
        can fall back to NVML.
        """
        try:
            # --id restricts the output to a single GPU; without it the
            # command prints one row per device and the row cannot be parsed.
            result = subprocess.run(
                ['nvidia-smi', f'--id={index}',
                 '--query-gpu=temperature.gpu,fan.speed,power.draw,memory.used,memory.total',
                 '--format=csv,noheader,nounits'],
                capture_output=True, text=True,
                timeout=self.NVIDIA_SMI_TIMEOUT_S,
            )
            if result.returncode != 0:
                return None
            return self._parse_nvidia_smi_csv(result.stdout)
        except Exception as e:
            logger.error(f"GPU {index} 온도 정보 가져오기 실패: {str(e)}")
            return None

    @staticmethod
    def _query_nvml(handle) -> Dict[str, float]:
        """Read the same metrics as ``_query_nvidia_smi`` straight from NVML.

        Values are normalised to the units of the nvidia-smi path: the power
        reading is divided by 1000 (mW to W) and memory byte counts are
        divided by 1024**2. A fan-speed query rejected by NVML is reported
        as ``0.0`` instead of aborting the whole collection.
        """
        memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
        try:
            fan_speed = float(pynvml.nvmlDeviceGetFanSpeed(handle))
        except pynvml.NVMLError:
            fan_speed = 0.0
        return {
            'temperature': float(
                pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            ),
            'fan_speed': fan_speed,
            'power_draw': pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0,  # mW -> W
            'memory_used': memory.used / 1024**2,
            'memory_total': memory.total / 1024**2,
        }

    def get_gpu_info(self) -> List[Dict]:
        """Return one metrics dict per GPU, or ``[]`` if collection fails.

        Each dict has the keys ``id``, ``name``, ``temperature``,
        ``fan_speed``, ``memory_used``, ``memory_total``, ``utilization``,
        ``power_draw`` and ``power_limit``. Both power values are in watts
        regardless of which source supplied the reading.
        """
        try:
            device_count = pynvml.nvmlDeviceGetCount()
            gpus = []

            for i in range(device_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                power_limit = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle)

                # Prefer nvidia-smi; fall back to NVML when it is unavailable.
                readings = self._query_nvidia_smi(i)
                if readings is None:
                    readings = self._query_nvml(handle)

                # Log abnormally high temperatures.
                temp = readings['temperature']
                if temp > self.ABNORMAL_TEMPERATURE_C:
                    logger.warning(f"GPU {i}의 온도가 비정상적으로 높습니다: {temp}°C")

                # Older pynvml releases return bytes, newer ones return str.
                name = pynvml.nvmlDeviceGetName(handle)
                if isinstance(name, bytes):
                    name = name.decode('utf-8')

                gpus.append({
                    'id': i,
                    'name': name,
                    'temperature': float(temp),
                    'fan_speed': float(readings['fan_speed']),
                    'memory_used': float(readings['memory_used']),
                    'memory_total': float(readings['memory_total']),
                    'utilization': util.gpu,
                    'power_draw': float(readings['power_draw']),  # already watts
                    'power_limit': power_limit / 1000.0  # mW -> W
                })

            return gpus
        except Exception as e:
            logger.error(f"GPU 정보 수집 중 오류 발생: {str(e)}")
            return []

    @staticmethod
    def _to_flux_time(moment: datetime) -> str:
        """Format *moment* as an RFC 3339 UTC timestamp for a Flux query.

        Naive datetimes are interpreted as UTC, matching how this class
        timestamps the points it writes. Flux time literals need an explicit
        offset, which a naive ``isoformat()`` does not provide.
        """
        if moment.utcoffset() is None:
            moment = moment.replace(tzinfo=timezone.utc)
        return moment.astimezone(timezone.utc).isoformat().replace('+00:00', 'Z')

    def get_metrics_history(self, start_time: datetime, end_time: datetime) -> List[Dict]:
        """Return the stored temperature records between two instants.

        Args:
            start_time: Start of the range; naive values are treated as UTC.
            end_time: End of the range; naive values are treated as UTC.

        Returns:
            One dict per record with the keys ``_measurement``, ``_time``
            (ISO 8601 string), ``_value`` and ``gpu_id`` (``''`` when the
            record carries no such tag). ``[]`` if the query fails.
        """
        try:
            # Query the GPU temperature series from InfluxDB.
            query = f'''
                from(bucket: "{self.influxdb.bucket}")
                    |> range(start: {self._to_flux_time(start_time)}, stop: {self._to_flux_time(end_time)})
                    |> filter(fn: (r) => r["_measurement"] == "gpu_metrics")
                    |> filter(fn: (r) => r["_field"] == "temperature")
            '''

            result = self.influxdb.query_api.query(query)

            return [
                {
                    '_measurement': record.get_measurement(),
                    '_time': record.get_time().isoformat(),
                    '_value': record.get_value(),
                    'gpu_id': record.values.get('gpu_id', '')
                }
                for table in result
                for record in table.records
            ]
        except Exception as e:
            logger.error(f"메트릭 히스토리 조회 중 오류 발생: {str(e)}")
            return []

    def save_to_influxdb(self, gpu_data, gpu_id=None):
        """Write one metrics snapshot to the ``gpu_metrics`` measurement.

        Args:
            gpu_data: Comma-separated string
                ``"temperature,fan_speed,power_draw,memory_used,memory_total"``.
            gpu_id: Optional GPU identifier stored as the ``gpu_id`` tag so
                that records can be attributed to a device when read back.
                Omitted from the point when ``None``.

        Errors are logged and swallowed; the method never raises.
        """
        try:
            # Parse the snapshot.
            temp, fan_speed, power, mem_used, mem_total = map(float, gpu_data.split(','))

            # Build the data point.
            point = Point("gpu_metrics")
            if gpu_id is not None:
                point = point.tag("gpu_id", str(gpu_id))
            point = (
                point
                .field("temperature", temp)
                .field("fan_speed", fan_speed)
                .field("power_draw", power)
                .field("memory_used", mem_used)
                .field("memory_total", mem_total)
                .time(datetime.now(timezone.utc))
            )

            # Persist it.
            self.write_api.write(bucket=settings.INFLUXDB_BUCKET, record=point)

        except Exception as e:
            logger.error(f"Error saving to InfluxDB: {str(e)}")

    @classmethod
    def _alerts_for_gpu(cls, gpu: Dict) -> List[Dict]:
        """Return the threshold alerts triggered by a single GPU metrics dict.

        Ratio checks are skipped when the denominator is not positive so a
        device reporting zero total memory or a zero power limit cannot
        cause ``ZeroDivisionError``.
        """
        alerts = []

        if gpu['temperature'] > cls.TEMPERATURE_ALERT_C:
            alerts.append({
                'type': 'temperature',
                'gpu_id': gpu['id'],
                'message': f"GPU {gpu['id']} 온도가 높습니다: {gpu['temperature']}°C"
            })

        memory_total = gpu['memory_total']
        if memory_total > 0 and gpu['memory_used'] / memory_total > cls.MEMORY_ALERT_RATIO:
            alerts.append({
                'type': 'memory',
                'gpu_id': gpu['id'],
                'message': f"GPU {gpu['id']} 메모리 사용량이 높습니다: {gpu['memory_used']:.1f}MB / {gpu['memory_total']:.1f}MB"
            })

        power_limit = gpu['power_limit']
        if power_limit > 0 and gpu['power_draw'] / power_limit > cls.POWER_ALERT_RATIO:
            alerts.append({
                'type': 'power',
                'gpu_id': gpu['id'],
                'message': f"GPU {gpu['id']} 전력 사용량이 높습니다: {gpu['power_draw']:.1f}W / {gpu['power_limit']:.1f}W"
            })

        return alerts

    def check_gpu_status(self, gpus: Optional[List[Dict]] = None) -> List[Dict]:
        """Evaluate alert thresholds and persist snapshots of alerting GPUs.

        Args:
            gpus: Metrics dicts shaped like the output of ``get_gpu_info``.
                When ``None`` (the default) a fresh snapshot is collected.

        Returns:
            A list of alert dicts with the keys ``type``, ``gpu_id`` and
            ``message``; empty when nothing exceeds a threshold.
        """
        if gpus is None:
            gpus = self.get_gpu_info()

        alerts = []
        for gpu in gpus:
            gpu_alerts = self._alerts_for_gpu(gpu)
            if not gpu_alerts:
                continue
            alerts.extend(gpu_alerts)
            # Store the snapshot of the GPU that actually raised the alert,
            # tagged with its id, rather than whichever GPU was visited last.
            self.save_to_influxdb(
                f"{gpu['temperature']},{gpu['fan_speed']},{gpu['power_draw']},{gpu['memory_used']},{gpu['memory_total']}",
                gpu_id=gpu['id'],
            )

        return alerts

    def __del__(self):
        """Best-effort release of NVML and the InfluxDB client.

        The two resources are released independently so a failure in one
        does not leak the other. Errors are deliberately ignored because a
        finalizer may run during interpreter shutdown.
        """
        if getattr(self, '_nvml_ready', False):
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass
        client = getattr(self, 'influx_client', None)
        if client is not None:
            try:
                client.close()
            except Exception:
                pass