# app/api.py
import logging

import numpy as np
from fastapi import APIRouter, HTTPException
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

from . import schemas

# Model Preparation
# The classifier is trained once, when this module is imported, on the
# Iris dataset bundled with scikit-learn; the fitted `model` and the
# `iris` bunch are then shared by the route handlers below.
logger = logging.getLogger(__name__)

iris = load_iris()
X = iris.data
y = iris.target
# A shallow tree keeps the model small; a fixed random_state makes the
# fitted tree reproducible across restarts.
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X, y)
# Static metadata served by the GET /model endpoint.
model_info = {
    "name": "Iris Species Predictor",
    "version": "1.0",
    "description": "A simple Decision Tree classifier for Iris species.",
}
# Use logging instead of print() so the readiness message follows the
# application's logging configuration (level, handlers, format).
logger.info("'%s' model is ready.", model_info["name"])

# API Router
router = APIRouter()


@router.get("/model", response_model=schemas.ModelInfo)
async def get_model_info():
    """Describe the model served by this API.

    Returns the module-level ``model_info`` mapping (name, version and
    description); FastAPI shapes the response using ``schemas.ModelInfo``.
    """
    info = model_info
    return info


@router.post("/predict", response_model=schemas.IrisPrediction)
async def predict_species(iris_input: schemas.IrisInput):
    """
    Predicts the species of an Iris flower based on its features.

    - **Input**: Four features of the Iris flower (JSON).
    - **Output**: Predicted species name and class index (JSON).
    - **Errors**: HTTP 500 if the model fails to produce a prediction.
    """
    try:
        # A single-row 2-D batch; the column order must match the order of
        # the training features in `iris.data` (see `iris.feature_names`).
        input_data = np.array(
            [
                [
                    iris_input.sepal_length,
                    iris_input.sepal_width,
                    iris_input.petal_length,
                    iris_input.petal_width,
                ]
            ]
        )

        prediction_value = model.predict(input_data)
        predicted_class_index = int(prediction_value[0])

        # target_names holds numpy strings; hand the schema a plain str.
        predicted_species_name = str(iris.target_names[predicted_class_index])

        return schemas.IrisPrediction(
            species_name=predicted_species_name, prediction=predicted_class_index
        )
    except Exception as exc:
        # Top-level request boundary: record the full traceback server-side,
        # since FastAPI does not log HTTPExceptions it turns into responses.
        logging.getLogger(__name__).exception("Iris prediction failed")
        # NOTE(review): str(exc) exposes internal error text to API clients;
        # consider a generic message once no client depends on this detail.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
