import os
import shutil
import threading
import time

import numpy as np
import requests
import tritonclient.http as httpclient
import uvicorn
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# --- 1. 기본 설정 ---
MODEL_DIR = "triton_model_repo"
MODEL_NAME = "sklearn_rf"
MODEL_VERSION = "1"
FASTAPI_PORT = 8001
TRITON_PORT = 8000


def prepare_model_and_repo():
    """
    Scikit-learn 모델을 학습하고, ONNX로 변환한 후,
    Triton 서버가 사용할 수 있는 모델 리포지토리 구조를 생성합니다.
    """
    print("--- 1. 모델 학습 및 ONNX 변환 시작 ---")
    # 간단한 분류 모델 학습
    X, y = make_classification(
        n_samples=1000, n_features=10, n_classes=2, random_state=42
    )
    model = RandomForestClassifier(random_state=42)
    model.fit(X, y)
    print("RandomForestClassifier 모델 학습 완료.")

    # ONNX 변환을 위한 초기 타입 정의
    initial_type = [("float_input", FloatTensorType([None, X.shape[1]]))]
    onnx_model = convert_sklearn(model, initial_types=initial_type)
    print("모델을 ONNX 형식으로 변환 완료.")

    # Triton 모델 리포지토리 생성
    model_repo_path = os.path.join(MODEL_DIR, MODEL_NAME, MODEL_VERSION)
    if os.path.exists(MODEL_DIR):
        shutil.rmtree(MODEL_DIR)
    os.makedirs(model_repo_path)

    # ONNX 모델 저장
    onnx_model_path = os.path.join(model_repo_path, "model.onnx")
    with open(onnx_model_path, "wb") as f:
        f.write(onnx_model.SerializeToString())
    print(f"ONNX 모델을 '{onnx_model_path}'에 저장했습니다.")

    # Triton 설정 파일(config.pbtxt) 생성
    config_pbtxt = f"""
name: "{MODEL_NAME}"
platform: "onnxruntime_onnx"
input [
  {{
    name: "float_input"
    data_type: TYPE_FP32
    dims: [ -1, {X.shape[1]} ]
  }}
]
output [
  {{
    name: "output_label"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }},
  {{
    name: "output_probability"
    data_type: TYPE_FP32
    dims: [ -1, 2 ]
  }}
]
dynamic_batching {{
  preferred_batch_size: [4, 8]
  max_queue_delay_microseconds: 100
}}
"""
    config_path = os.path.join(MODEL_DIR, MODEL_NAME, "config.pbtxt")
    with open(config_path, "w") as f:
        f.write(config_pbtxt)
    print(f"Triton 설정 파일을 '{config_path}'에 생성했습니다.")
    return X[0].astype(np.float32)  # 벤치마킹에 사용할 샘플 데이터 반환


def run_fastapi_server(stop_event):
    """
    간단한 FastAPI 서버를 백그라운드 스레드에서 실행합니다.
    Triton과 동일한 ONNX 모델을 서빙합니다. (실제로는 onnxruntime 사용 필요)
    여기서는 간단히 scikit-learn 모델을 직접 사용합니다.
    """
    from fastapi import FastAPI
    from pydantic import BaseModel

    # 모델 재학습 (간단한 예제를 위해)
    X, y = make_classification(
        n_samples=1000, n_features=10, n_classes=2, random_state=42
    )
    model = RandomForestClassifier(random_state=42)
    model.fit(X, y)

    app = FastAPI()

    class Item(BaseModel):
        data: list

    @app.post("/predict")
    def predict(item: Item):
        data = np.array(item.data).reshape(1, -1)
        prediction = model.predict(data)
        return {"prediction": prediction.tolist()}

    config = uvicorn.Config(app, host="0.0.0.0", port=FASTAPI_PORT, log_level="info")
    server = uvicorn.Server(config)

    # stop_event가 설정될 때까지 서버 실행
    server.run()
    # while not stop_event.is_set():
    #     # uvicorn.run()은 블로킹 함수이므로, 실제 백그라운드 실행을 위해서는
    #     # 별도 프로세스나 더 정교한 스레딩 제어가 필요합니다.
    #     # 이 예제에서는 시연을 위해 메인 스레드에서 제어합니다.
    #     pass


def benchmark_server(name, url, sample_data, n_requests=100):
    """서버 성능을 벤치마킹합니다."""
    print(f"\n--- {name} 벤치마킹 시작 (요청 {n_requests}회) ---")

    # Latency 측정 (순차 요청)
    latencies = []
    start_time = time.time()
    for _ in range(n_requests):
        req_start_time = time.time()
        if "Triton" in name:
            triton_client = httpclient.InferenceServerClient(url=url)
            inputs = [httpclient.InferInput("float_input", [1, 10], "FP32")]
            inputs[0].set_data_from_numpy(sample_data.reshape(1, -1), binary_data=True)
            triton_client.infer(model_name=MODEL_NAME, inputs=inputs)
        else:
            requests.post(url, json={"data": sample_data.tolist()})
        latencies.append(time.time() - req_start_time)
    total_time = time.time() - start_time

    avg_latency = np.mean(latencies) * 1000  # ms
    throughput = n_requests / total_time  # rps

    print(f"평균 지연 시간 (Avg. Latency): {avg_latency:.2f} ms")
    print(f"처리량 (Throughput): {throughput:.2f} req/s")
    return avg_latency, throughput


def main():
    """메인 실행 함수"""
    sample_data = prepare_model_and_repo()

    print("\n--- 2. 서버 실행 준비 ---")
    print(f"FastAPI 서버가 포트 {FASTAPI_PORT}에서 실행될 예정입니다.")
    print("Triton 서버를 별도의 터미널에서 실행해주세요.")
    print("\n" + "=" * 50)
    print(" 다음 Docker 명령어를 사용하여 Triton 서버를 실행하세요:")
    print(f"  docker run --rm -p {TRITON_PORT}:8000 -p 8001:8001 -p 8002:8002 \\")
    print(
        f"           -v {os.getcwd()}/{MODEL_DIR}:/models nvcr.io/nvidia/tritonserver:23.10-py3 tritonserver --model-repository=/models"
    )
    print("=" * 50 + "\n")

    input("Triton 서버가 실행되면 Enter 키를 눌러 벤치마킹을 시작하세요...")

    # FastAPI 서버를 백그라운드 스레드에서 실행
    # stop_event = threading.Event()
    # fastapi_thread = threading.Thread(target=run_fastapi_server, args=(stop_event,))
    # fastapi_thread.start()
    # time.sleep(5)  # 서버가 시작될 때까지 대기
    print("\n[알림] 이 스크립트는 FastAPI 서버를 직접 실행하지 않습니다.")
    print(
        "별도 터미널에서 'python -m uvicorn fastapi_server:app --port 8001'과 같이 실행하거나,"
    )
    print("아래 코드를 활성화하여 스레드 기반으로 실행할 수 있습니다.")
    print("이 데모에서는 Triton과의 비교에 집중합니다.")

    # 벤치마킹 실행
    triton_latency, triton_throughput = benchmark_server(
        "NVIDIA Triton Server", f"localhost:{TRITON_PORT}", sample_data
    )
    # fastapi_latency, fastapi_throughput = benchmark_server(
    #     "FastAPI Server", f"http://localhost:{FASTAPI_PORT}/predict", sample_data
    # )

    print("\n\n--- 벤치마킹 결과 요약 ---")
    print(
        f"| {'Server':<20} | {'Avg. Latency (ms)':<20} | {'Throughput (req/s)':<20} |"
    )
    print(f"| {'-'*20} | {'-'*20} | {'-'*20} |")
    # print(f"| {'FastAPI Server':<20} | {fastapi_latency:<20.2f} | {fastapi_throughput:<20.2f} |")
    print(
        f"| {'NVIDIA Triton Server':<20} | {triton_latency:<20.2f} | {triton_throughput:<20.2f} |"
    )
    print("\n[참고] FastAPI 벤치마크를 위해서는 위 코드의 주석을 해제하고,")
    print("`run_fastapi_server` 함수 및 관련 스레드 코드를 활성화해야 합니다.")

    # FastAPI 서버 종료
    # stop_event.set()
    # fastapi_thread.join()
    print("\n실습 완료!")


if __name__ == "__main__":
    main()
