import os

import faiss
import numpy as np
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from sentence_transformers import CrossEncoder


def setup_environment():
    """환경 변수를 로드하고 OpenAI API 키를 설정합니다."""
    load_dotenv()
    if not os.getenv("OPENAI_API_KEY"):
        raise ValueError(
            "OPENAI_API_KEY가 설정되지 않았습니다. .env 파일을 확인하세요."
        )
    print("API 키 로드 완료.")


def prepare_documents():
    """지식 베이스로 사용할 샘플 문서를 준비합니다."""
    return [
        Document(
            page_content="LangChain은 LLM을 활용한 애플리케이션 개발을 돕는 프레임워크입니다. 다양한 모듈을 제공합니다.",
            metadata={"doc_id": 1},
        ),
        Document(
            page_content="RAG는 '검색 증강 생성'의 약자로, LLM이 외부 지식 베이스를 참조하여 답변을 생성하는 기술입니다.",
            metadata={"doc_id": 2},
        ),
        Document(
            page_content="고급 RAG 기법에는 쿼리를 여러 개로 변환하는 Multi-Query와 검색된 문서 순위를 재조정하는 Re-ranking이 있습니다.",
            metadata={"doc_id": 3},
        ),
        Document(
            page_content="Re-ranking에는 Bi-Encoder로 1차 검색 후, Cross-Encoder로 정교하게 순위를 다시 매기는 방식이 효과적입니다.",
            metadata={"doc_id": 4},
        ),
        Document(
            page_content="쿼리 변환(Query Transformation)은 사용자의 모호한 질문을 명확하게 만들어 검색 성능을 높이는 것을 목표로 합니다.",
            metadata={"doc_id": 5},
        ),
        Document(
            page_content="BM25는 키워드 기반 검색 알고리즘으로, 벡터 검색과 함께 사용하면(하이브리드 검색) 성능을 보완할 수 있습니다.",
            metadata={"doc_id": 6},
        ),
    ]


def build_hybrid_retriever(docs, embeddings_model):
    """키워드 검색(BM25)과 벡터 검색(FAISS)을 결합한 하이브리드 리트리버를 구축합니다."""
    # 1. FAISS 벡터 스토어 및 리트리버
    faiss_vectorstore = FAISS.from_documents(docs, embeddings_model)
    faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": 5})

    # 2. BM25 리트리버
    bm25_retriever = BM25Retriever.from_documents(docs)
    bm25_retriever.k = 5

    # 3. 앙상블 리트리버 (두 리트리버 결합)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
    )
    return ensemble_retriever


def get_query_transformation_chain(llm):
    """사용자의 질문을 여러 관점의 질문 3개로 변환하는 체인을 생성합니다."""
    prompt = PromptTemplate(
        input_variables=["question"],
        template="""당신은 AI 언어 모델입니다. 사용자의 질문을 검색에 더 효과적인 3개의 다른 버전으로 다시 작성해주세요.
        다양한 관점에서 질문을 재구성해야 합니다. 각 질문은 한 줄로 구분해주세요.
        원본 질문: {question}""",
    )
    return LLMChain(llm=llm, prompt=prompt, output_key="queries")


def rerank_documents(question, retrieved_docs):
    """CrossEncoder를 사용하여 검색된 문서들의 순위를 재조정합니다."""
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512)
    pairs = [(question, doc.page_content) for doc in retrieved_docs]
    scores = model.predict(pairs)

    # 점수가 높은 순으로 정렬
    sorted_docs = sorted(zip(scores, retrieved_docs), key=lambda x: x[0], reverse=True)
    return [doc for score, doc in sorted_docs]


def format_docs(docs):
    """검색된 문서들을 LLM 프롬프트에 넣기 좋은 형태로 포맷팅합니다."""
    return "\n\n".join(doc.page_content for doc in docs)


def main():
    """고급 RAG 파이프라인을 시연하는 메인 함수"""
    setup_environment()
    docs = prepare_documents()

    llm = ChatOpenAI(temperature=0)
    embeddings = HuggingFaceEmbeddings(model_name="jhgan/ko-sbert-nli")

    # 1. 하이브리드 리트리버 구축
    hybrid_retriever = build_hybrid_retriever(docs, embeddings)

    # 2. 쿼리 변환 체인 준비
    query_transformer_chain = get_query_transformation_chain(llm)

    # 3. 최종 답변 생성을 위한 프롬프트 및 체인
    template = """당신은 주어진 컨텍스트를 바탕으로 질문에 답변하는 AI 어시스턴트입니다.
    컨텍스트를 벗어난 답변은 하지 마세요.
    
    컨텍스트:
    {context}
    
    질문:
    {question}
    
    답변:
    """
    prompt = PromptTemplate.from_template(template)

    final_rag_chain = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["documents"])))
        | prompt
        | llm
        | StrOutputParser()
    )

    # --- 고급 RAG 파이프라인 실행 ---
    question = "RAG의 검색 성능을 어떻게 높일 수 있을까?"
    print(f"\n원본 질문: {question}")

    # 쿼리 변환
    transformed_queries_result = query_transformer_chain.invoke({"question": question})
    transformed_queries = transformed_queries_result["queries"].strip().split("\n")
    print(f"\n변환된 쿼리:\n{transformed_queries_result['queries']}")

    # 변환된 각 쿼리로 문서 검색 및 종합
    all_retrieved_docs = []
    for q in transformed_queries:
        all_retrieved_docs.extend(hybrid_retriever.invoke(q))

    # 중복 제거
    unique_docs = {doc.metadata["doc_id"]: doc for doc in all_retrieved_docs}.values()
    print(
        f"\n1차 검색된 문서 (중복 제거 전 {len(all_retrieved_docs)}개, 후 {len(unique_docs)}개)"
    )

    # 재순위화
    reranked_docs = rerank_documents(question, list(unique_docs))
    top_k_reranked = reranked_docs[:3]  # 상위 3개만 사용

    print("\n재순위화 후 선택된 최종 문서:")
    for doc in top_k_reranked:
        print(f"- {doc.page_content} (Doc ID: {doc.metadata['doc_id']})")

    # 최종 답변 생성
    print("\n--- 최종 답변 생성 ---")
    answer = final_rag_chain.invoke({"documents": top_k_reranked, "question": question})
    print(answer)

    print("\n\n--- 비교: 기본 RAG 파이프라인 ---")
    basic_retrieved_docs = hybrid_retriever.invoke(question)
    basic_answer = final_rag_chain.invoke(
        {"documents": basic_retrieved_docs, "question": question}
    )
    print("기본 RAG로 검색된 문서:")
    for doc in basic_retrieved_docs:
        print(f"- {doc.page_content} (Doc ID: {doc.metadata['doc_id']})")
    print("\n기본 RAG 답변:")
    print(basic_answer)
    print("\n실습 완료!")


if __name__ == "__main__":
    main()
