import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from fairlearn.datasets import fetch_adult
from fairlearn.metrics import (
    MetricFrame,
    demographic_parity_difference,
    equalized_odds_difference,
)
from fairlearn.postprocessing import ThresholdOptimizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler


def run():
    """
    Fairlearn을 사용하여 성인 소득 데이터셋의 성별 편향성을 탐지하고 완화하는 예제입니다.
    1. 데이터 로드 및 전처리
    2. 기준 모델 학습 및 성능/편향성 평가
    3. Fairlearn의 ThresholdOptimizer를 사용한 편향성 완화
    4. 완화 전후의 성능 및 편향성 지표 비교
    """
    # 1. 데이터 로드 및 전처리
    data = fetch_adult(as_frame=True)
    X_raw = data.data
    y = (data.target == ">50K").astype(int)

    # 민감한 특성(성별) 분리
    sensitive_features = X_raw["sex"]

    # 범주형 변수 인코딩 및 데이터 스케일링
    X = X_raw.drop(columns=["sex"])
    X = pd.get_dummies(X, drop_first=True)

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    X_scaled = pd.DataFrame(X_scaled, columns=X.columns)

    X_train, X_test, y_train, y_test, sf_train, sf_test = train_test_split(
        X_scaled, y, sensitive_features, test_size=0.2, random_state=42, stratify=y
    )

    # 2. 기준 모델 학습 및 평가
    baseline_model = LogisticRegression(solver="liblinear", random_state=42)
    baseline_model.fit(X_train, y_train)
    y_pred_baseline = baseline_model.predict(X_test)

    # MetricFrame을 사용한 평가
    metrics = {
        "accuracy": accuracy_score,
        "demographic_parity_difference": demographic_parity_difference,
        "equalized_odds_difference": equalized_odds_difference,
    }

    grouped_on_sex = MetricFrame(
        metrics=metrics,
        y_true=y_test,
        y_pred=y_pred_baseline,
        sensitive_features=sf_test,
    )

    print("--- 기준 모델 평가 결과 ---")
    print(grouped_on_sex.overall)
    print("\n--- 성별에 따른 그룹별 평가 결과 ---")
    print(grouped_on_sex.by_group)

    # 3. Fairlearn ThresholdOptimizer를 사용한 편향성 완화
    # `constraint="demographic_parity"`는 그룹 간 예측 결과의 비율을 유사하게 만듭니다.
    postprocess_model = ThresholdOptimizer(
        estimator=baseline_model,
        constraints="demographic_parity",  # or "equalized_odds"
        objective="accuracy_score",
        prefit=True,
    )

    # 최적의 임계값을 찾기 위해 postprocess_model을 학습시킵니다.
    postprocess_model.fit(X_train, y_train, sensitive_features=sf_train)
    y_pred_postprocess = postprocess_model.predict(X_test, sensitive_features=sf_test)

    # 완화 후 모델 평가
    grouped_on_sex_postprocess = MetricFrame(
        metrics=metrics,
        y_true=y_test,
        y_pred=y_pred_postprocess,
        sensitive_features=sf_test,
    )

    print("\n\n--- 편향성 완화 후 모델 평가 결과 ---")
    print(grouped_on_sex_postprocess.overall)
    print("\n--- 성별에 따른 그룹별 평가 결과 (완화 후) ---")
    print(grouped_on_sex_postprocess.by_group)

    # 4. 결과 비교 및 시각화
    summary = pd.DataFrame(
        {
            "Baseline": grouped_on_sex.overall,
            "ThresholdOptimizer": grouped_on_sex_postprocess.overall,
        }
    )

    print("\n\n--- 최종 비교 ---")
    print(summary)

    # 시각화
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

    summary.loc[["accuracy"]].T.plot(
        kind="bar", ax=axes[0], title="Accuracy Comparison", legend=False
    )
    axes[0].set_ylabel("Score")
    axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=0)

    summary.loc[["demographic_parity_difference", "equalized_odds_difference"]].T.plot(
        kind="bar", ax=axes[1], title="Fairness Metrics Comparison"
    )
    axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=0)

    plt.suptitle("Model Performance and Fairness Comparison")
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # 결과 이미지 파일로 저장
    plt.savefig("ai lecture/source_code/part_14_ai_ethics/fairness_comparison.png")
    print("\n결과 비교 그래프를 'fairness_comparison.png' 파일로 저장했습니다.")
    plt.show()


if __name__ == "__main__":
    run()
