import sys

sys.path.append(".")


import asyncio

# from utils.logger import ServiceLogger
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import boto3
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from configs.config import OPENAI_EF, OPENAI_MODEL_MINI
from services.ppt_generator.data_classes.project import Project
from utils.chroma_db import ChromaDB
from utils.client_check import ClientConfig
from utils.document_loader.pdf_loader import LoadPDF
from utils.document_loader.word_loader import LoadWordDoc
from utils.dynamo_db import DynamoDB
from utils.webscrape.sitemap_scrape import SitemapScrape

# Module-wide logger. NOTE: getLogger() with no name returns the *root* logger,
# so these records share the root logger's handlers and level.
XCM_logger = logging.getLogger()


# Schema for the structured answer to the "company_financial" follow-up
# question; it is handed to JsonOutputParser in follow_up_questions_to_ask.
# Documented with comments rather than a docstring: pydantic uses a class
# docstring as the JSON-schema description, which would change the format
# instructions embedded in the prompt.
class FinancialInfo(BaseModel):
    # Free-text type(s) of the financial document, as answered by the model.
    document_type: str = Field(..., title="Document Type")
    # Dates are requested as YYYY-MM-DD (see the titles) but kept as plain
    # strings; nothing here parses or validates them as dates.
    start_date: str = Field(..., title="Start Date YYYY-MM-DD")
    end_date: str = Field(..., title="End Date YYYY-MM-DD")


# Follow-up questions asked about a document after it has been indexed, keyed by
# the label returned by DocUploader.classify_document. The keys must match the
# entries of DocUploader.classification exactly (including the "comapny_info"
# spelling). Each question dict holds the question text, the Project attribute
# the final answer is appended to, and optionally a JsonOutputParser used to
# structure the answer. A value of None means "no follow-up questions".
follow_up_questions_to_ask = {
    "comapny_info": [
        {
            "question": "What is the legal name of the business?",
            "attribute": "company_alt_names",
        },
    ],
    "company_financial": [
        {
            # Two questions in one prompt; the sentences must stay separated so
            # the model reads them as distinct requests.
            "question": """What type of financial document is this: tax, income_statement, balance_sheet, cash_flow, Other? \
                    If other, explain it. It can be multiple types. What is the date range of the document? Start date and end date \
                    return a JSON format: document_type, start_date: YYYY-MM-DD, end_date: YYYY-MM-DD""",
            "attribute": "company_financials",
            "output_parser": JsonOutputParser(pydantic_object=FinancialInfo),
        },
        {
            "question": "What is the legal name of the business?",
            "attribute": "company_alt_names",
        },
    ],
    "industry": None,
    "other": None,
    "website": None,
}


# Structured output of DocUploader.classify_document: a single label the prompt
# asks to be one of DocUploader.classification (the model is not forced to comply).
# Comments instead of a docstring: pydantic would put a class docstring into the
# JSON schema that is embedded in the prompt's format instructions.
class DocClassification(BaseModel):
    classification: str = Field(..., title="Classification of the document")


class DocUploader:
    """Load a project document, index it in ChromaDB and record it in DynamoDB.

    A typical run (see ``doc_uploader``) is: construct with a ``doc_path``
    (which classifies the file name with the LLM and loads the document text),
    then ``upload_to_chromadb``, ``follow_up_questions``, ``summarize_chunks``
    and finally ``upload_to_dynamodb``. Constructed without a ``doc_path`` the
    instance can still scrape and index the project website through
    ``vectorize_website``.
    """

    # Class-level objects: created once, at import time, and shared by all instances.
    llm35 = ChatOpenAI(
        model=OPENAI_MODEL_MINI, api_key=os.environ["OPENAI_API_KEY"], temperature=0.1
    )
    chunk_size = 2000
    chunk_overlap = chunk_size // 10

    # Separator placed between the parts of a ChromaDB record id (see unique_id).
    delimiter_text = "_+_"

    # Allowed document classes. NOTE: "comapny_info" is misspelled, but it is the
    # label written to ChromaDB/DynamoDB metadata and the key used in
    # follow_up_questions_to_ask, so it is kept as-is.
    classification = [
        "comapny_info",
        "company_financial",
        "industry",
        "other",
        "website",
    ]

    db = DynamoDB()

    def __init__(
        self, project: Project, client: ClientConfig = None, doc_path: str = None
    ):
        """Create an uploader for ``project``.

        Args:
            project: Project the document (or website) belongs to.
            client: Client configuration; its ``chroma_db`` names the ChromaDB
                collection to write to.
            doc_path: Path/URI of the document. When given, the file name is
                classified with the LLM and the document text is loaded
                immediately; unsupported extensions set ``self.supported`` to False.
        """
        self.supported = True
        self.project = project
        self.client = client
        self.customer_chroma_collection = None
        self.chunks = []
        self.doc_summary = ""
        self.chunks_summary = []

        self.file_path = doc_path
        # Defaults so later methods see None instead of raising AttributeError
        # when no document was given or its type is unsupported.
        self.file_name = None
        self.doc_type = None
        self.doc_classification = None
        self.document_text_chunked = None

        if client:
            self.customer_chroma_collection = self.client.chroma_db

        if doc_path:
            self.file_name = os.path.basename(doc_path)
            self.doc_classification = self.classify_document(self.file_name)
            self.document_text_chunked = self.get_document_text(doc_path)

    def get_document_text(self, doc_path: str):
        """Load ``doc_path`` with the loader matching its extension.

        Sets ``self.doc_type`` to ``"pdf_doc"`` or ``"word_doc"``.

        Returns:
            The loader's ``document_text_chunked``, or None when the extension is
            not .pdf/.docx/.doc (``self.supported`` is then set to False).
        """
        lower_doc_path = doc_path.lower()
        if lower_doc_path.endswith(".pdf"):
            doc_loader = LoadPDF(doc_path)
            self.doc_type = "pdf_doc"
        elif lower_doc_path.endswith((".docx", ".doc")):
            doc_loader = LoadWordDoc(doc_path, self.project)
            self.doc_type = "word_doc"
        else:
            # Flag instead of raising so callers can skip the file.
            self.supported = False
            return None

        return doc_loader.document_text_chunked

    def classify_document(self, filename: str):
        """Ask the LLM to classify a document from its file name alone.

        Args:
            filename: Base name of the document file.

        Returns:
            The ``classification`` value of the parsed JSON answer, normally one
            of ``self.classification`` (the model is not forced to comply).
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "You are a financial analyst able to classify items"),
                (
                    "system",
                    "You will choose between the following classes: {classification_choices}",
                ),
                (
                    "system",
                    # The examples must use the exact labels in self.classification,
                    # otherwise the model copies a label follow_up_questions rejects.
                    """
            Here are some examples of the classification:
            "TTM through March 2024" -> "company_financial"
            "questionnaire" -> "comapny_info"
            "23611D Remodeling in the US Industry Report -IBIS" -> "industry"
            "2023 P and L-Bal Sheet" -> "company_financial"
            "2024 YTD comparison" -> "company_financial"
            """,
                ),
                ("human", "The document file name is: {filename}"),
                (
                    "human",
                    "Return the answer in JSON format as only the classification: {JSON_format}",
                ),
            ]
        )

        parser = JsonOutputParser(pydantic_object=DocClassification)
        chain = prompt | self.llm35 | parser

        chain_output = chain.invoke(
            {
                "classification_choices": "\n".join(self.classification),
                "filename": filename,
                "JSON_format": parser.get_format_instructions(),
            }
        )

        label = chain_output["classification"]
        # The model may "correct" the label's spelling; map it back onto the
        # label used throughout this module.
        if label == "company_info":
            label = "comapny_info"
        return label

    def upload_to_chromadb(self):
        """Split the loaded document into token chunks and upsert them into ChromaDB.

        Each record gets an id from ``unique_id`` and metadata from
        ``standardize_metadata`` (plus ``table_name`` when the source chunk has one).

        Returns:
            list[str]: The text chunks that were upserted; empty when no document
            text is loaded or splitting produced nothing.
        """
        # TODO: Convert this to be a generic function
        if self.document_text_chunked is None:
            return []

        chromadb = ChromaDB()
        collection = chromadb.chroma_client.get_or_create_collection(
            self.customer_chroma_collection, embedding_function=OPENAI_EF
        )

        # Built once: constructing this splitter loads a sentence-transformers model.
        # NOTE(review): this splitter appears to size chunks by ``tokens_per_chunk``
        # (default: the model's max sequence length), so ``chunk_size`` may have no
        # effect — confirm against the installed langchain version.
        splitter = SentenceTransformersTokenTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )

        chunks = []
        unique_ids = []
        metadatas = []
        total_chunks = len(self.document_text_chunked)
        for chunk_number, chunk in enumerate(self.document_text_chunked, start=1):
            for i, text_chunk in enumerate(splitter.split_text(chunk["text"])):
                metadata = self.standardize_metadata(chunk)
                if "table_name" in chunk:
                    metadata["table_name"] = chunk["table_name"]

                unique_ids.append(self.unique_id(chunk_number, i))
                metadatas.append(metadata)
                chunks.append(text_chunk)

            logging.info(
                "Uploaded chunk %s of %s \n File: %s to ChromaDB",
                chunk_number,
                total_chunks,
                self.file_name,
            )

        if not unique_ids:
            # Chroma rejects an upsert with an empty id list; nothing to store.
            return chunks

        collection.upsert(
            ids=unique_ids,
            documents=chunks,
            metadatas=metadatas,
        )
        return chunks

    def standardize_metadata(
        self,
        chunk,
        file_name: str = None,
        doc_type: str = None,
        doc_classification: str = None,
        chunk_type: str = None,
    ):
        """Build the ChromaDB metadata dict for one record.

        Each falsy argument falls back to the instance value (``file_name``,
        ``doc_type``, ``doc_classification``) or, for ``chunk_type``, to
        ``chunk["type"]`` (only read when ``chunk_type`` is falsy).
        """
        return {
            "project_id": self.project.project_id,
            "file_name": file_name or self.file_name,
            "document_type": doc_type or self.doc_type,
            "doc_classification": doc_classification or self.doc_classification,
            "chunk_type": chunk_type or chunk["type"],
        }

    def unique_id(self, chunk_number: "int | str", i: int = 0):
        """Create a unique ID for the document record.

        Joins project id, file name, ``chunk_number`` and ``i`` with
        ``delimiter_text``. ``chunk_number`` may be any value with a useful
        ``str()`` form (website records pass the page URL).
        """
        return self.delimiter_text.join(
            [self.project.project_id, self.file_name, str(chunk_number), str(i)]
        )

    @staticmethod
    def _sanitize_text(text, keep=""):
        """Return ``text`` without non-printable characters.

        Args:
            text: String to clean.
            keep: Characters to retain even though ``str.isprintable`` rejects
                them (e.g. a newline). Defaults to none.
        """
        return "".join(c for c in text if c.isprintable() or c in keep)

    def vectorize_website(self):
        """Scrape the project website, store the pages in DynamoDB and index them.

        Writes one ``web_pages`` item per scraped URL plus one ``"images"`` item
        holding the scraped image URLs, then indexes the page summaries through
        ``upload_website_to_vectorDB``.
        """
        import time

        XCM_logger.info("Uploading website %s to ChromaDB", self.project.company_url)
        scrape = SitemapScrape(
            self.project.company_url, summarize=True, exclude_blogs=True
        )

        # upload the website to dynamoDB
        items_to_upload = []
        for url, page in scrape.urls_scraped.items():
            text = page["text"]
            # Join fragments with newlines; a plain string is kept as-is so it is
            # not split into one character per line.
            if not isinstance(text, str):
                text = "\n".join(text)
            items_to_upload.append(
                {
                    "project_id": self.project.project_id,
                    "url": url,
                    "text": text,
                    "summary": page["summary"],
                }
            )

        # self.db.batch_upload_to_dynamodb(self.db.web_pages, items_to_upload)
        db = DynamoDB()
        for item in items_to_upload:
            # Short pause between writes; presumably to throttle DynamoDB
            # write throughput (TODO confirm).
            time.sleep(0.25)
            # Strip non-printable characters from every string field, but keep
            # the newlines that separate the page's text fragments.
            item_clean = {
                k: self._sanitize_text(v, keep="\n") if isinstance(v, str) else v
                for k, v in item.items()
            }
            db.upload_to_dynamodb(self.db.web_pages, item_clean)
            logging.info("Uploaded website %s to DynamoDB", item["url"])

        self.db.upload_to_dynamodb(
            self.db.web_pages,
            {
                "project_id": self.project.project_id,
                "url": "images",
                "text": "images",
                "images": scrape.image_urls,
            },
        )
        self.upload_website_to_vectorDB(scrape.urls_scraped)

    def upload_website_to_vectorDB(self, website_scraped: dict):
        """Index the summary of each scraped page in the customer's ChromaDB collection.

        Args:
            website_scraped: Mapping of URL -> scraped page dict; only the
                ``"summary"`` value is read. Pages without a summary are skipped.
        """
        chromadb = ChromaDB()
        collection = chromadb.chroma_client.get_or_create_collection(
            self.customer_chroma_collection, embedding_function=OPENAI_EF
        )

        # unique_id() reads self.file_name, so website record ids share "Website".
        self.file_name = "Website"

        unique_ids = []
        metadatas = []
        document_text = []
        for url, item in website_scraped.items():
            if not item["summary"]:
                continue
            unique_ids.append(self.unique_id(url))
            metadatas.append(
                self.standardize_metadata(
                    item,
                    url,
                    doc_type="website",
                    doc_classification="website",
                    chunk_type="text",
                )
            )
            document_text.append(item["summary"])

        if not unique_ids:
            # No page produced a summary; Chroma rejects an upsert with no ids.
            return

        collection.upsert(
            ids=unique_ids,
            documents=document_text,
            metadatas=metadatas,
        )

    async def summarize_chunks(self, chunks):
        """Summarize every chunk, then condense the chunk summaries into one document summary.

        Stores the results on ``self.chunks_summary`` and ``self.doc_summary``.

        Returns:
            tuple: ``(chunks_summary, doc_summary)``; ``([], "")`` when ``chunks``
            is empty (no LLM call is made).
        """
        if not chunks:
            self.chunks_summary = []
            self.doc_summary = ""
            return [], ""

        # summarize_chunk sends a list of chunks through chain.abatch concurrently.
        chunks_summary = await self.summarize_chunk(chunks)

        if len(chunks_summary) == 1:
            doc_summary = chunks_summary[0]
        else:
            doc_summary = await self.summarize_chunk("\n".join(chunks_summary))

        self.doc_summary = doc_summary
        self.chunks_summary = chunks_summary

        return chunks_summary, doc_summary

    async def summarize_chunk(self, chunks):
        """Summarize one text (str input) or many texts (list input) with the LLM.

        Returns:
            A summary string for str input, or a list of summary strings, in
            input order, for list input.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """ You are skilled investment banker and expert in summarizing information as a financial analyst. """,
                ),
                ("system", "Summarize the following text:"),
                ("human", "The following text:"),
                ("human", "{chunk}"),
            ]
        )

        chain = prompt | self.llm35

        if isinstance(chunks, list):
            chain_output = await chain.abatch([{"chunk": chunk} for chunk in chunks])
            return [chunk_output.content for chunk_output in chain_output]

        chain_output = await chain.ainvoke({"chunk": chunks})
        return chain_output.content

    def upload_to_dynamodb(self):
        """Write this document's record (classification, summaries, chunks) to ``project_docs``.

        Logs a warning and writes nothing when no supported document was loaded
        (``doc_type`` is None).
        """
        if self.doc_type is None:
            logging.warning(
                "Not uploading %s to DynamoDB: no supported document loaded",
                self.file_path,
            )
            return

        # A fresh DynamoDB() per call rather than the shared class-level ``db``;
        # kept as-is (instances of this class run in worker threads).
        db = DynamoDB()

        db.upload_to_dynamodb(
            db.project_docs,
            {
                "project_id": self.project.project_id,
                "doc_path": self.file_path,
                "doc_classification": self.doc_classification,
                "file_name": self.file_name,
                "doc_type": self.doc_type,
                "doc_summary": self.doc_summary,
                "chunks_summary": self.chunks_summary,
                "chunks": self.chunks,
            },
        )

    def follow_up_questions(self):
        """Answer the follow-up questions configured for this document's class.

        For each question in ``follow_up_questions_to_ask[self.doc_classification]``,
        ``self.chunks`` are fed to the LLM two at a time and the model refines
        its previous answer with each pair. The final answer is appended to the
        project attribute named by the question, and the project is saved via
        ``update_project_in_db``.

        Returns:
            None. Returns early, leaving the project untouched, when the class is
            unknown, has no questions configured, or there are no chunks.
        """
        if self.doc_classification not in self.classification:
            return None

        questions = follow_up_questions_to_ask.get(self.doc_classification)
        if questions is None:
            return None

        if not self.chunks:
            # Nothing to read: every answer would be "" — don't record those.
            return None

        for question in questions:
            output_parser = question.get("output_parser")
            prompt_messages = [
                (
                    "system",
                    "You are a skilled knowledge worker. Answer the questions asked and return only the answer",
                ),
                ("human", "The questions is {question}"),
                (
                    "human",
                    "You have access to the following chunk of the document: {document}",
                ),
                (
                    "human",
                    "You have access to the previous answer to the question as well: {previous_answer}. Update it.",
                ),
            ]
            if output_parser is not None:
                prompt_messages.append(
                    (
                        "human",
                        "Your output should be in the following format: {output_format}",
                    )
                )

            chain = ChatPromptTemplate.from_messages(prompt_messages) | self.llm35
            format_instructions = None
            if output_parser is not None:
                chain = chain | output_parser
                format_instructions = output_parser.get_format_instructions()

            previous_answer = ""
            for start in range(0, len(self.chunks), 2):
                chain_input = {
                    "question": question["question"],
                    # Slicing past the end is safe; the last window may hold one chunk.
                    "document": "\n\n".join(self.chunks[start : start + 2]),
                    "previous_answer": previous_answer,
                }
                if output_parser is not None:
                    chain_input["output_format"] = format_instructions
                chain_output = chain.invoke(chain_input)

                # A parsed answer (e.g. a dict) has no ``content``; a chat message does.
                content = getattr(chain_output, "content", None)
                previous_answer = chain_output if content is None else content

            # update the project in the DB
            attribute_to_update = getattr(self.project, question["attribute"])
            if attribute_to_update is None:
                attribute_to_update = []
            attribute_to_update.append(previous_answer)
            setattr(self.project, question["attribute"], attribute_to_update)

        self.project.update_project_in_db()


def list_files_in_folder(bucket_name: str, folder_name: str):
    """Return the keys of every object under ``folder_name`` in ``bucket_name``.

    Uses the ``list_objects_v2`` paginator because a single call returns at
    most 1000 keys. The folder placeholder itself and any other "directory"
    keys (ending in "/") are skipped because they are not documents.

    Args:
        bucket_name: S3 bucket to list.
        folder_name: Key prefix to list (callers pass it with a trailing "/").

    Returns:
        list[str]: Object keys in the order S3 returns them.
    """
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    files = []
    for page in paginator.paginate(Bucket=bucket_name, Prefix=folder_name):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key == folder_name or key.endswith("/"):
                continue
            files.append(key)
    return files


def doc_uploader(
    bucket_name: str, doc_path: str, project: Project, client_config: ClientConfig
):
    """Parse and Upload the document to ChromaDB.

    Runs the whole pipeline for one S3 object: classify and load it, upsert its
    chunks to ChromaDB, answer the follow-up questions, summarize it and write
    its DynamoDB record. Unsupported file types are skipped. Any exception is
    logged with its traceback and swallowed so one bad file does not stop a
    batch run.

    Args:
        bucket_name: S3 bucket holding the document.
        doc_path: Object key of the document inside ``bucket_name``.
        project: Project the document belongs to (updated by the follow-up questions).
        client_config: Client configuration; supplies the ChromaDB collection.
    """
    try:
        print(f"Uploading {doc_path} to ChromaDB")
        doc_path = f"s3://{bucket_name}/{doc_path}"
        uploader = DocUploader(project, client_config, doc_path)
        if not uploader.supported:
            # Nothing to index, summarize or record for this file type.
            logging.warning("Skipping unsupported document: %s", doc_path)
            return
        print(uploader.doc_classification)
        uploader.chunks = uploader.upload_to_chromadb()
        # ask the follow up questions
        uploader.follow_up_questions()
        # summarize the chunks
        asyncio.run(uploader.summarize_chunks(uploader.chunks))
        # upload to dynamoDB
        uploader.upload_to_dynamodb()
        print(f"Uploaded {doc_path} to DynamoDB")
    except Exception:
        logging.error("Error processing document: %s", doc_path, exc_info=True)


def doc_uploader_call(project_id, client: ClientConfig, max_workers: int = 1):
    """Process every document in the project's S3 folder, then index the project website.

    Args:
        project_id: Project id; also the last segment of the S3 folder
            ``"<client>/<project_id>/"`` in the ``xcap_s3_storage`` bucket.
        client: Client configuration (``client.client`` is the folder prefix).
        max_workers: Number of documents processed concurrently (must be >= 1).
            Workers share and mutate the same ``project`` object (the follow-up
            questions append to its attributes and save it), so values above 1
            may race on those updates.

    Returns:
        dict: ``{"message": "Files uploaded successfully."}``. Per-document
        failures are logged inside ``doc_uploader`` and do not change this result.
    """
    client_config = client
    bucket_name = os.environ["xcap_s3_storage"]
    folder_name = f"{client_config.client}/{project_id}/"  # Ensure the folder name ends with a '/'

    db = DynamoDB()
    project = Project(**db.get_item(db.projects, project_id))

    files = list_files_in_folder(bucket_name, folder_name)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(doc_uploader, bucket_name, doc_path, project, client_config)
            for doc_path in files
        ]
        for future in futures:
            future.result()  # Wait for all uploads to finish

    # Final vectorization step (runs after all document uploads are complete)
    DocUploader(project, client_config).vectorize_website()

    return {"message": "Files uploaded successfully."}


if __name__ == "__main__":

    # project_id = "icebox"
    # client = ClientConfig("sunbelt").get_client_config()

    # doc_uploader_call(project_id, client)

    project_id = "iti-digital-1b7107ab-9e73-4590-90ce-1ba6c6b6a791"
    client = ClientConfig("tequity").get_client_config()

    # doc_uploader_call uploads every document and then vectorizes the company
    # website as its final step, so no separate website pass is needed here.
    doc_uploader_call(project_id, client)
