# api.py # API 엔드포인트의 실제 로직을 정의합니다. from fastapi import APIRouter, HTTPException from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier import numpy as np from . import schemas # 현재 패키지(디렉터리)의 schemas 모듈을 임포트 # --- 모델 준비 --- # 실제 프로덕션 환경에서는 미리 학습되고 저장된 모델 파일(e.g., .joblib, .pkl)을 로드해야 합니다. # 여기서는 실습 편의를 위해 서버 시작 시 간단한 모델을 학습시켜 메모리에 올립니다. iris = load_iris() X = iris.data y = iris.target # 간단한 결정 트리 모델 훈련 model = DecisionTreeClassifier(max_depth=3, random_state=42) model.fit(X, y) model_info = { "name": "Iris Species Predictor", "version": "1.0", "description": "A simple Decision Tree classifier for Iris species." } print(f"'{model_info['name']}' 모델이 준비되었습니다.") # --- API 라우터 설정 --- # APIRouter를 사용하면 엔드포인트를 모듈화하여 관리할 수 있습니다. router = APIRouter() @router.get("/model", response_model=schemas.ModelInfo) async def get_model_info(): """로드된 머신러닝 모델의 정보를 반환합니다.""" return model_info @router.post("/predict", response_model=schemas.IrisPrediction) async def predict_species(iris_input: schemas.IrisInput): """ 붓꽃의 특성(sepal/petal length/width)을 입력받아 품종을 예측합니다. - **입력**: 붓꽃의 4가지 특성 (JSON) - **출력**: 예측된 품종 이름과 클래스 번호 (JSON) """ try: # Pydantic 모델에서 받은 입력을 numpy 배열로 변환 input_data = np.array([[ iris_input.sepal_length, iris_input.sepal_width, iris_input.petal_length, iris_input.petal_width ]]) # 모델 예측 prediction_value = model.predict(input_data) predicted_class_index = int(prediction_value[0]) # 예측된 클래스 인덱스에 해당하는 품종 이름 찾기 predicted_species_name = iris.target_names[predicted_class_index] return schemas.IrisPrediction( species_name=predicted_species_name, prediction=predicted_class_index ) except Exception as e: # 예외 발생 시, 서버 에러(500)와 함께 상세 내용을 반환 raise HTTPException(status_code=500, detail=str(e))