"""
문맥 압축(Contextual Compression) RAG 예제

이 예제는 검색된 문서에서 사용자 질문과 관련된 핵심 내용만 추출하고, 
불필요한 내용을 제거하여 LLM에 전달함으로써 RAG 성능을 향상시키는 기법을 구현합니다.

요구사항:
- 필요한 패키지: langchain, langchain-openai, dotenv
- OpenAI API 키 필요 (.env 파일에 설정)

실행 방법:
$ python contextual_compression_rag.py
"""

import os
from typing import List

import numpy as np
from dotenv import load_dotenv
from langchain.chains import create_extraction_chain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
    LLMChainExtractor,
)
from langchain.schema import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI


def setup_environment():
    """환경 설정 및 API 키 로드"""
    load_dotenv()
    if not os.getenv("OPENAI_API_KEY"):
        raise ValueError("OPENAI_API_KEY가 설정되지 않았습니다. .env 파일을 확인하세요.")
    print("API 키 로드 완료.")


def prepare_documents() -> List[Document]:
    """샘플 문서 준비"""
    # 실제 애플리케이션에서는 파일에서 문서를 로드할 수 있음
    documents = [
        Document(
            page_content="""
            # 파이썬 머신러닝 개요
            
            머신러닝(Machine Learning)은 컴퓨터가 명시적인 프로그래밍 없이 데이터로부터 학습하여 패턴을 찾고 예측하는 인공지능의 한 분야입니다.
            파이썬은 머신러닝 프로젝트에 가장 많이 사용되는 프로그래밍 언어로, 다양한 라이브러리를 통해 복잡한 알고리즘을 쉽게 구현할 수 있습니다.
            
            ## 주요 머신러닝 라이브러리
            
            1. **scikit-learn**: 지도학습, 비지도학습 알고리즘 및 모델 평가 도구를 제공하는 가장 기본적인 머신러닝 라이브러리입니다.
            2. **TensorFlow**: 구글이 개발한 딥러닝 프레임워크로, 대규모 신경망을 구축하는 데 적합합니다.
            3. **PyTorch**: 페이스북(Meta)이 개발한 유연하고 직관적인 딥러닝 프레임워크입니다.
            4. **XGBoost**: 그래디언트 부스팅 기반의 고성능 머신러닝 라이브러리입니다.
            
            ## 머신러닝 프로젝트 워크플로우
            
            1. 데이터 수집 및 전처리
            2. 탐색적 데이터 분석
            3. 특성 공학 및 선택
            4. 모델 선택 및 학습
            5. 모델 평가 및 튜닝
            6. 모델 배포 및 모니터링
            """,
            metadata={"source": "ml_overview.md", "page": 1},
        ),
        Document(
            page_content="""
            # 자연어 처리(NLP) 기술 동향
            
            자연어 처리는 컴퓨터가 인간의 언어를 이해하고 생성하는 기술 분야입니다.
            최근 트랜스포머(Transformer) 기반 모델의 등장으로 NLP 분야는 혁신적인 발전을 이루었습니다.
            
            ## 주요 NLP 모델 및 라이브러리
            
            1. **BERT**: 구글이 개발한 양방향 인코더 기반 모델로, 문맥을 고려한 단어 표현이 가능합니다.
            2. **GPT**: OpenAI가 개발한 자기회귀 모델로, 텍스트 생성에 탁월한 성능을 보입니다.
            3. **Hugging Face Transformers**: 다양한 사전 학습 모델을 쉽게 사용할 수 있는 라이브러리입니다.
            4. **spaCy**: 효율적인 자연어 처리를 위한 파이썬 라이브러리입니다.
            
            ## NLP 응용 분야
            
            - 감성 분석(Sentiment Analysis)
            - 개체명 인식(Named Entity Recognition)
            - 기계 번역(Machine Translation)
            - 질의응답 시스템(Question Answering)
            - 텍스트 요약(Text Summarization)
            """,
            metadata={"source": "nlp_trends.md", "page": 1},
        ),
        Document(
            page_content="""
            # 컴퓨터 비전 기술의 발전
            
            컴퓨터 비전은 컴퓨터가 디지털 이미지나 비디오로부터 정보를 추출하고 이해하는 기술입니다.
            딥러닝의 발전으로 이미지 인식, 객체 탐지, 이미지 생성 등 다양한 분야에서 혁신적인 성과를 보이고 있습니다.
            
            ## 주요 컴퓨터 비전 라이브러리 및 프레임워크
            
            1. **OpenCV**: 다양한 이미지 처리 기능을 제공하는 오픈소스 라이브러리입니다.
            2. **TensorFlow Object Detection API**: 객체 탐지 모델을 쉽게 구축할 수 있는 프레임워크입니다.
            3. **PyTorch Vision**: 이미지 분류, 세그먼테이션 등을 위한 모델과 도구를 제공합니다.
            4. **YOLO(You Only Look Once)**: 실시간 객체 탐지를 위한 알고리즘입니다.
            
            ## 컴퓨터 비전 응용 분야
            
            - 자율주행 차량
            - 의료 영상 분석
            - 안면 인식
            - 로봇 비전
            - 증강 현실
            """,
            metadata={"source": "computer_vision.md", "page": 1},
        ),
        Document(
            page_content="""
            # 강화학습의 기초와 응용
            
            강화학습은 에이전트가 환경과 상호작용하며 보상을 최대화하는 방향으로 학습하는 머신러닝의 한 분야입니다.
            게임, 로봇 제어, 추천 시스템 등 다양한 분야에 적용되고 있습니다.
            
            ## 주요 강화학습 알고리즘
            
            1. **Q-Learning**: 가치 기반 강화학습 알고리즘으로, 각 상태-행동 쌍의 Q값을 학습합니다.
            2. **Policy Gradient**: 정책 기반 알고리즘으로, 보상을 최대화하는 정책을 직접 학습합니다.
            3. **DQN(Deep Q-Network)**: 딥러닝을 Q-Learning에 결합한 알고리즘입니다.
            4. **PPO(Proximal Policy Optimization)**: 안정적인 정책 최적화를 위한 알고리즘입니다.
            
            ## 강화학습 프레임워크 및 라이브러리
            
            - **OpenAI Gym**: 강화학습 환경을 제공하는 표준 인터페이스입니다.
            - **Stable Baselines**: 강화학습 알고리즘 구현체를 제공하는 라이브러리입니다.
            - **RLlib**: 분산 강화학습을 위한 라이브러리입니다.
            """,
            metadata={"source": "reinforcement_learning.md", "page": 1},
        ),
    ]
    return documents


def create_document_index(documents, embeddings_model):
    """문서를 벡터 저장소에 인덱싱"""
    # 문서 분할
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50, separators=["\n\n", "\n", " ", ""]
    )
    splits = text_splitter.split_documents(documents)
    
    print(f"총 {len(documents)}개 문서를 {len(splits)}개 청크로 분할했습니다.")
    
    # 벡터 저장소 생성
    vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings_model)
    return vectorstore


def create_compression_retriever(vectorstore, llm, embeddings_model):
    """문맥 압축 리트리버 생성"""
    # 1. LLM 기반 관련 내용 추출기 생성 (관련 없는 텍스트 제거)
    llm_extractor = LLMChainExtractor.from_llm(llm)
    
    # 2. 임베딩 기반 관련성 필터링 (질문과 관련성이 낮은 문서 제거)
    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings_model,
        similarity_threshold=0.76,  # 0-1 사이 값, 높을수록 엄격한 필터링
    )
    
    # 3. 중복 내용 제거 필터 생성
    redundant_filter = EmbeddingsRedundantFilter(
        embeddings=embeddings_model,
        similarity_threshold=0.95,  # 높을수록 유사한 내용을 더 많이 유지
    )
    
    # 4. 문서 압축 파이프라인 구성
    doc_compressor = DocumentCompressorPipeline(
        transformers=[redundant_filter, embeddings_filter, llm_extractor]
    )
    
    # 5. 문맥 압축 리트리버 생성
    compression_retriever = ContextualCompressionRetriever(
        base_retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
        base_compressor=doc_compressor,
    )
    
    return compression_retriever


def create_rag_chain(retriever, llm):
    """RAG 체인 생성"""
    # 검색된 문서를 형식화하는 함수
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    
    # 프롬프트 템플릿
    template = """
    당신은 AI 기술에 대한 전문 지식을 갖춘 도우미입니다.
    다음 문맥 정보를 바탕으로 사용자의 질문에 답변해주세요.
    문맥에 없는 정보는 '제공된 문맥에서 해당 정보를 찾을 수 없습니다'라고 답변하세요.
    
    문맥:
    {context}
    
    질문: {question}
    
    답변:
    """
    
    # RAG 체인 구성
    from langchain.prompts import ChatPromptTemplate
    
    prompt = ChatPromptTemplate.from_template(template)
    
    chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    
    return chain


def extract_key_information(document, llm):
    """문서에서 핵심 정보 추출 (추가 데모용)"""
    # 추출할 정보 스키마 정의
    schema = {
        "properties": {
            "main_topic": {"type": "string"},
            "key_points": {"type": "array", "items": {"type": "string"}},
            "technologies": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["main_topic", "key_points", "technologies"],
    }
    
    # 추출 체인 생성 및 실행
    chain = create_extraction_chain(schema, llm)
    result = chain.run(document.page_content)
    
    return result[0] if result else None


def main():
    """메인 실행 함수"""
    setup_environment()
    
    # 모델 초기화
    llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
    embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    
    # 문서 준비 및 인덱싱
    print("문서 준비 및 인덱싱 중...")
    documents = prepare_documents()
    vectorstore = create_document_index(documents, embeddings_model)
    
    # 기본 리트리버와 압축 리트리버 생성
    basic_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
    compression_retriever = create_compression_retriever(vectorstore, llm, embeddings_model)
    
    # RAG 체인 생성
    basic_rag_chain = create_rag_chain(basic_retriever, llm)
    compressed_rag_chain = create_rag_chain(compression_retriever, llm)
    
    # 사용자 질의 설정
    questions = [
        "파이썬 머신러닝에서 가장 많이 사용되는 라이브러리는 무엇인가요?",
        "자연어 처리에서 트랜스포머의 역할은 무엇인가요?",
        "강화학습에 사용되는 알고리즘의 종류에는 어떤 것들이 있나요?"
    ]
    
    # 결과 비교
    for i, question in enumerate(questions):
        print(f"\n\n{'=' * 80}")
        print(f"질문 {i+1}: {question}")
        print(f"{'-' * 80}")
        
        # 기본 RAG로 검색된 문서
        basic_docs = basic_retriever.get_relevant_documents(question)
        print(f"\n1. 기본 리트리버로 검색된 문서 ({len(basic_docs)}개):")
        for j, doc in enumerate(basic_docs):
            source = doc.metadata.get('source', 'unknown')
            print(f"문서 {j+1} (출처: {source}):")
            print(f"{doc.page_content[:150]}...\n")
        
        # 압축 RAG로 검색된 문서
        compressed_docs = compression_retriever.get_relevant_documents(question)
        print(f"\n2. 압축 리트리버로 검색된 문서 ({len(compressed_docs)}개):")
        for j, doc in enumerate(compressed_docs):
            source = doc.metadata.get('source', 'unknown')
            print(f"문서 {j+1} (출처: {source}):")
            print(f"{doc.page_content}")
        
        # 기본 RAG 답변
        print(f"\n3. 기본 RAG 답변:")
        basic_answer = basic_rag_chain.invoke(question)
        print(basic_answer)
        
        # 압축 RAG 답변
        print(f"\n4. 압축 RAG 답변:")
        compressed_answer = compressed_rag_chain.invoke(question)
        print(compressed_answer)
        
        # 핵심 정보 추출 데모 (선택한 문서에 대해)
        if i == 0 and basic_docs:
            print(f"\n5. 핵심 정보 추출 데모:")
            extracted_info = extract_key_information(basic_docs[0], llm)
            if extracted_info:
                print(f"주제: {extracted_info.get('main_topic')}")
                print(f"핵심 요점:")
                for point in extracted_info.get('key_points', []):
                    print(f"- {point}")
                print(f"기술 목록:")
                for tech in extracted_info.get('technologies', []):
                    print(f"- {tech}")
    
    print("\n\n압축 RAG 예제가 완료되었습니다.")


if __name__ == "__main__":
    main() 