"""Chunk documents and index them for hybrid (BM25 + vector-store) retrieval."""

import asyncio
import os

from langchain.retrievers import EnsembleRetriever
from langchain.text_splitter import TokenTextSplitter
from langchain_chroma.vectorstores import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from langchain_openai.embeddings import OpenAIEmbeddings


class Compressor:
    """Chunk documents and index them for hybrid BM25 + vector-store retrieval.

    Chunks produced by ``compress`` are added to a Chroma vector store
    (embedded with ``OpenAIEmbeddings``) and to an in-memory BM25 index.
    ``retrieve`` combines both with langchain's ``EnsembleRetriever``.

    NOTE: the BM25 index lives only in memory on this instance. With
    ``persist_directory`` the Chroma collection is persisted, but the BM25
    index only covers chunks added through ``compress`` on this instance.
    """

    def __init__(
        self, collection_name: str = "default", persist_directory: "str | None" = None
    ):
        """
        Create the embedding model and the Chroma vector store.

        Args:
        - collection_name (str): Name of the Chroma collection to use.
        - persist_directory (str | None): Directory passed to Chroma for
            persistence. When falsy, no persist directory is passed.
        """
        self.persist_directory = persist_directory
        # The API key is read from the environment at construction time.
        self.embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
        # BM25 retriever over every chunk indexed so far; None until the first
        # compress() call that produces chunks.
        self.bm25_retriever = None
        # Every chunk indexed so far. The BM25 index is rebuilt from this list
        # on each compress() so it covers the same chunks as the vector store.
        self._chunks = []

        chroma_kwargs = {
            "embedding_function": self.embeddings,
            "collection_name": collection_name,
        }
        if self.persist_directory:
            chroma_kwargs["persist_directory"] = self.persist_directory
        self.vectorstore = Chroma(**chroma_kwargs)

    def compress(
        self,
        documents: "list[Document | str]",
        chunk_size: int = 1000,
        chunk_overlap: "int | None" = None,
        **kwargs,
    ):
        """
        Split documents into chunks and add them to both retrieval indexes.

        Args:
        - documents (list[Document | str]): Documents to index. Plain strings
            are wrapped in ``Document`` objects; the caller's list is not
            modified.
        - chunk_size (int): Size of each chunk, in the splitter's units
            (tokens for the default ``TokenTextSplitter``).
        - chunk_overlap (int | None): Overlap between consecutive chunks;
            defaults to 10% of ``chunk_size``.
        - **kwargs: ``token_splitter`` may give a text-splitter class to use
            instead of ``TokenTextSplitter``. It is called with ``chunk_size``
            and ``chunk_overlap``. Other keyword arguments are ignored.

        Returns:
        - None. Query the indexed chunks with ``retrieve`` / ``aretrieve``.
        """
        documents = self.validate_docs(documents)
        if not documents:
            # Nothing to index. Don't hand an empty batch to Chroma or BM25,
            # which presumably cannot handle it (BM25 needs a non-empty corpus).
            return

        if chunk_overlap is None:
            chunk_overlap = int(chunk_size * 0.1)

        splitter_cls = kwargs.get("token_splitter") or TokenTextSplitter
        text_splitter = splitter_cls(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        texts = text_splitter.split_documents(documents)
        if not texts:
            return

        self.vectorstore.add_documents(texts)
        # Rebuild BM25 over *all* chunks, not just this batch, so it stays in
        # sync with the vector store across repeated compress() calls.
        self._chunks.extend(texts)
        self.bm25_retriever = BM25Retriever.from_documents(self._chunks)

    def validate_docs(self, documents):
        """
        Return a copy of ``documents`` with plain strings wrapped as ``Document``.

        Args:
        - documents (list[Document | str]): Documents to normalise. Not modified.

        Returns:
        - A new list in the same order. Every ``str`` element is replaced by
          ``Document(page_content=<str>)``; other elements are passed through
          unchanged (they are not type-checked).
        """
        return [
            Document(page_content=doc) if isinstance(doc, str) else doc
            for doc in documents
        ]

    def retrieve(self, query: str, k: int = 5, bm25_weight: float = 0.5):
        """
        Retrieve documents matching ``query`` from the BM25 and vector indexes.

        Args:
        - query (str): The query to search for.
        - k (int): Number of results requested from each retriever and the
            maximum number of results returned.
        - bm25_weight (float): Weight given to the BM25 retriever; the vector
            store retriever gets ``1 - bm25_weight``.

        Returns:
        - A list of at most ``k`` documents.

        Raises:
        - RuntimeError: If ``compress`` has not indexed any chunks yet.
        """
        if self.bm25_retriever is None:
            raise RuntimeError(
                "No documents indexed yet; call compress() before retrieve()."
            )

        # NOTE: this sets k on the shared BM25 retriever instance, so
        # concurrent calls (e.g. via aretrieve) with different k can race.
        bm25_retriever = self.bm25_retriever
        bm25_retriever.k = k

        vectorstore_retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vectorstore_retriever],
            weights=[bm25_weight, 1 - bm25_weight],
        )

        # Each retriever returns up to k hits, so the merged list can be longer.
        return ensemble_retriever.invoke(query)[:k]

    async def aretrieve(self, query: str, k: int = 5, bm25_weight: float = 0.5):
        """
        Async variant of ``retrieve``: runs it in a worker thread.

        Args:
        - query (str): The query to search for.
        - k (int): Maximum number of results to return.
        - bm25_weight (float): Weight given to the BM25 retriever.

        Returns:
        - A list of at most ``k`` documents.

        Raises:
        - RuntimeError: If ``compress`` has not indexed any chunks yet.
        """
        return await asyncio.to_thread(self.retrieve, query, k, bm25_weight)


if __name__ == "__main__":
    # Manual smoke test: index one project document from S3 and run a query.
    import sys

    sys.path.append(".")

    import boto3

    from services.ppt_generator.data_classes.project import Project
    from utils.client_check import ClientConfig
    from utils.document_loader.DocUploader import DocUploader
    from utils.dynamo_db import DynamoDB

    def list_files_in_folder(bucket_name: str, folder_name: str):
        """Return every object key under ``folder_name`` in ``bucket_name``.

        Follows list_objects_v2 pagination so prefixes with more keys than fit
        in one response page are listed completely. The key equal to
        ``folder_name`` itself (the folder placeholder) is skipped.
        """
        paginator = boto3.client("s3").get_paginator("list_objects_v2")
        return [
            obj["Key"]
            for page in paginator.paginate(Bucket=bucket_name, Prefix=folder_name)
            for obj in page.get("Contents", [])
            if obj["Key"] != folder_name
        ]

    project_id = "tequity_1"
    client = "tequity"
    client_config = ClientConfig(client).client_config
    bucket_name_loc = os.environ["xcap_s3_storage"]
    # The folder name must end with '/' so the prefix matches only this project.
    project_folder_name = f"{client_config.client}/{project_id}/"

    db = DynamoDB()
    project = Project(**db.get_item(db.projects, project_id))

    project_files = list_files_in_folder(bucket_name_loc, project_folder_name)

    text_chunks = []
    # Only the first file is indexed to keep the smoke test quick.
    for key in project_files[:1]:
        print(f"Uploading {key} to ChromaDB")
        # Build the URI from the bucket that was actually listed, rather than
        # a hard-coded bucket name.
        doc_path = f"s3://{bucket_name_loc}/{key}"
        doc_uploader = DocUploader(project, client_config, doc_path)
        print(doc_uploader.doc_classification)

        page_texts = [page["text"] for page in doc_uploader.document_text_chunked]
        text_chunks.append(
            Document(
                page_content="\n\n\n".join(page_texts),
                metadata={
                    "source": doc_path,
                    "project_id": project_id,
                    "client": client,
                    "doc_classification": doc_uploader.doc_classification,
                },
            )
        )

    # Placeholder path; point this at a real directory before running.
    compressor = Compressor(persist_directory="path/to/persist/directory")
    compressor.compress(text_chunks, chunk_size=250)

    question = "When was the company founded?"

    results = compressor.retrieve(question)
    print([doc.page_content for doc in results])
