" This module is responsible for compressing the documents and returning the relevant context. "
import os
from itertools import islice

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline, EmbeddingsFilter)
from langchain.text_splitter import RecursiveCharacterTextSplitter
# use the openAI GPT-3 embeddings
from langchain_openai import OpenAIEmbeddings

from configs.config import EMBEDDING_MODEL, OPENAI_EF

from .retriever import SearchAPIRetriever

# Module-wide embeddings client built once at import time.
# os.environ[...] is a hard lookup, so importing this module raises KeyError
# when OPENAI_API_KEY is not set.
# NOTE(review): OPENAI_EF is imported from configs.config but not used in this
# module; confirm whether constructing a separate client here is intentional.
openai_ef = OpenAIEmbeddings(
    api_key=os.environ["OPENAI_API_KEY"],
    model=EMBEDDING_MODEL,
)


class ContextCompressor:
    """Compress a set of documents down to the context relevant to a query.

    The documents given at construction are served through
    ``SearchAPIRetriever``. Results are split into overlapping chunks and
    filtered by embedding similarity to the query. The survivors are then
    formatted into a single text block.
    """

    def __init__(self, documents, max_results=7, **kwargs):
        """Initialize the compressor.

        Args:
            documents: Pages handed to ``SearchAPIRetriever(pages=...)``.
            max_results: Number of documents ``get_context`` returns when the
                caller does not pass its own ``max_results``.
            **kwargs: Stored on ``self.kwargs``; not otherwise used here.
        """
        self.max_results = max_results
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = openai_ef
        # Minimum embedding similarity a chunk needs to pass EmbeddingsFilter.
        self.similarity_threshold = 0.38

    def _get_contextual_retriever(self):
        """Build a retriever that splits and relevance-filters the documents.

        Returns:
            A ``ContextualCompressionRetriever`` over ``SearchAPIRetriever``
            whose pipeline first splits documents (chunk_size=2500,
            chunk_overlap=250) and then drops chunks scoring below
            ``self.similarity_threshold`` against the query embedding.
        """
        splitter = RecursiveCharacterTextSplitter(chunk_size=2500, chunk_overlap=250)
        relevance_filter = EmbeddingsFilter(
            embeddings=self.embeddings, similarity_threshold=self.similarity_threshold
        )
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SearchAPIRetriever(pages=self.documents)
        contextual_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=base_retriever
        )
        return contextual_retriever

    def _pretty_print_docs(self, docs, top_n):
        """Format at most ``top_n`` documents as one string.

        Each entry lists the document's ``source`` and ``title`` metadata
        (rendered as ``None`` when absent) followed by its page content.
        Entries are separated by a blank line.

        Args:
            docs: Iterable of documents exposing ``metadata`` and
                ``page_content``.
            top_n: Maximum number of documents to include. Values <= 0 yield
                an empty string.

        Returns:
            The formatted text. This method does not print anything.
        """
        # islice stops after top_n items instead of scanning every document.
        # Clamping at 0 keeps the old result (nothing) for a negative top_n.
        return "\n\n".join(
            f"Source: {doc.metadata.get('source')}\n"
            f"Title: {doc.metadata.get('title')}\n"
            f"Content: {doc.page_content}\n"
            for doc in islice(docs, max(top_n, 0))
        )

    def get_context(self, query, max_results=None):
        """Return the formatted context relevant to ``query``.

        Args:
            query: Text used both to retrieve documents and to score
                chunk relevance.
            max_results: Maximum number of documents to include. ``None``
                (the default) uses the value given to the constructor.

        Returns:
            The formatted context string built by ``_pretty_print_docs``.
        """
        if max_results is None:
            # Previously the constructor's max_results was never consulted.
            max_results = self.max_results
        retriever = self._get_contextual_retriever()
        # ``invoke`` is the Runnable entry point; ``get_relevant_documents``
        # is deprecated in LangChain.
        relevant_docs = retriever.invoke(query)
        return self._pretty_print_docs(relevant_docs, max_results)
