"Primary research script for the project."

import os
import sys
from datetime import datetime

sys.path.append(r".")


from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from configs.config import EMBEDDING_MODEL, OPENAI_EF
from services.ppt_generator.data_classes.project import Project
from utils.chroma_db import ChromaDB
from utils.client_check import ClientConfig
from utils.dynamo_db import DynamoDB
from utils.researcher.dataclass.research_llm_classes import Questions


class ResearchChroma:
    """Primary research for the project, backed by the client's Chroma collection.

    Given a parent question, retrieves matching project documents from Chroma,
    summarizes them, asks an LLM for chain-of-thought sub-questions, and
    summarizes the retrieved context for each sub-question as well.
    """

    # Cheap model used for per-document summarization.
    llm35 = ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"), model="gpt-3.5-turbo", temperature=0.1
    )

    # Stronger model used to generate chain-of-thought sub-questions.
    llm4_sub_query = ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4-turbo", temperature=0.3
    )

    # Upper bound on the number of sub-questions requested from the LLM.
    max_number_of_questions = 5

    COT_sub_instructions = """The questions should be chain of thought questions that can help provide context to a large-language model.\
        Such as:
        Parent question: What are the top 5 clients for the company?
        Sub-questions: What type of customers does the company serve? Where are the clients located?\

        Parent question: What trends and risks are faced by companies in the automotive industry (specific industry)?
        Sub-questions: Who are the public companies in the space? For each of those companies, what are the trends and risks they are facing?\
        """

    def __init__(self, project: Project, client: ClientConfig = None, query=""):
        """Set up the research state.

        Args:
            project: Project whose documents are queried (results are filtered
                by ``project.project_id``).
            client: Optional client configuration; supplies the Chroma
                collection name. When omitted, the collection name is "".
            query: The parent research question.
        """
        self.project = project
        # Normalize any falsy client value to None.
        self.client_config = client or None
        self.client_chroma_collection = client.chroma_db if client else ""

        self.question = query
        self.sub_questions = []  # populated by research()

    def research(self, verbose=False):
        """Run the full research loop for the parent question.

        Args:
            verbose: When True, ask the summarizer for longer, bulleted answers.

        Returns:
            A list of ``{"question": ..., "summary": ...}`` dicts, parent
            question first, or an empty list when no documents match the
            parent question.
        """
        # Get the data from chroma for the initial query.
        docs = self.retrieve_from_chroma(self.question)
        if not docs:
            return []
        summary = self.summarize_document_for_query(self.question, verbose, docs)
        self.sub_questions.append({"question": self.question, "summary": summary})

        # Expand the parent question into sub-questions and summarize each one.
        sub_questions = self._get_parent_sub_questions(self.question, summary)
        for sub_question in sub_questions["questions"]:
            docs = self.retrieve_from_chroma(sub_question)
            summary = self.summarize_document_for_query(sub_question, verbose, docs)
            self.sub_questions.append({"question": sub_question, "summary": summary})

        return self.sub_questions

    def retrieve_from_chroma(self, query: str):
        """Query the client's Chroma collection, filtered to this project.

        Args:
            query: Natural-language query text.

        Returns:
            The ``"documents"`` entry of the Chroma query result (a list of
            document lists, one per query text).
        """
        chroma_db = ChromaDB()
        try:
            collection = chroma_db.chroma_client.get_collection(
                self.client_chroma_collection, embedding_function=OPENAI_EF
            )
        except Exception:
            # Retry once: get_collection can fail transiently. A second
            # failure propagates to the caller instead of being swallowed.
            collection = chroma_db.chroma_client.get_collection(
                self.client_chroma_collection, embedding_function=OPENAI_EF
            )

        item = collection.query(
            where={
                "project_id": self.project.project_id,
            },
            query_texts=[query],
        )
        return item["documents"]

    def _get_parent_sub_questions(self, question: str, summary: str):
        """Ask the LLM for sub-questions that add context to *question*.

        Args:
            question: The parent question being researched.
            summary: What is already known (used as context for the LLM).

        Returns:
            The parsed JSON output; expected to contain a ``"questions"`` list
            (per the ``Questions`` schema), possibly empty.
        """
        chat_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    f"""\
                    You are an expert researcher skilled in finding information and reasoning.\
                    You strive to provide data-points and quantitative answers. For any market or industry questions, you need to go higher level and then dig deep.
                    You are not providing an opinion but rather a factual answer and current analysis.\
                    {self.COT_sub_instructions}
                    """,
                ),
                ("human", "The tasks you are getting an answer for is: {question}"),
                ("human", "You already know this about the company: {summary_context}"),
                (
                    "system",
                    """\
                    Give back a list of {number_of_questions} sub-questions which can help provide context to a large-language model. Each question should only cover one topic.
                    Use today's date if needed {today_date}
                    Provide your answer in a list format in such as JSON: {json_output}
                    Only ask additional questions if you need to. If you don't need to ask any questions, just say return an empty list.
                    """,
                ),
            ]
        )

        parser = JsonOutputParser(pydantic_object=Questions)

        chain = chat_prompt | self.llm4_sub_query | parser

        chain_output = chain.invoke(
            {
                "question": question,
                "summary_context": summary,
                "today_date": datetime.now().strftime("%Y-%m-%d"),
                "number_of_questions": self.max_number_of_questions,
                "json_output": parser.get_format_instructions(),
            }
        )

        return chain_output

    def summarize_document_for_query(
        self, query, verbose=False, content: list = None, past_summary=""
    ):
        """Fold each document in *content* into a running summary for *query*.

        Args:
            query: The question the summary should answer.
            verbose: When True, request longer, bulleted answers.
            content: Documents to summarize; a single string is wrapped in a
                list. Defaults to one empty document.
            past_summary: Summary to extend (passed through the chain).

        Returns:
            The final accumulated summary string.
        """
        if content is None:
            content = [""]  # preserve the historical default of one empty doc
        elif isinstance(content, str):
            content = [content]

        if verbose:
            verbose_prompt = "Provide a verbose answer with multiple bullets or longer text each roughly 40 tokens long"
        else:
            verbose_prompt = ""
        # Each document's summary is fed back in as context for the next one.
        summary = past_summary
        for doc in content:
            summary = self.summarize_content(
                query, verbose_prompt, doc, past_summary=summary
            )

        return summary

    def summarize_content(self, query, verbose_prompt, doc, past_summary=""):
        """Summarize a single document for *query* with the cheap LLM.

        Args:
            query: The (sub-)question being answered; the parent question is
                taken from ``self.question``.
            verbose_prompt: Extra system instruction ("" for concise output).
            doc: The new document text to incorporate.
            past_summary: The summary accumulated so far.

        Returns:
            The updated summary text from the LLM.
        """
        chat_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """\
                You are an Investment Banking managing director, creating a detail overview of a company for a buyer. \
                You strive to provide data-points when answering the question.
                You are not providing an opinion but rather a factual answer and current analysis.\
                You want to make your answer specific to the company and quantitative in nature and comparing company specifics to the trends in the industry. If you don't have enough information use [X] as placeholder \
                If you don't have enough information, do not try to answer the question and only return 'not enough information'.\
                """,
                ),
                (
                    "system",
                    """Answer the question based on the information provided.
                    Only return the answer and add to the answer if relevant information is provided.
                    Do not add any additional information.
                 """,
                ),
                (
                    "system",
                    "If you need to ask the internet for information, you can do so by providing the questions as a list in the answer.",
                ),
                ("human", "The parent question we are answering is: {parent_query}"),
                ("human", "The question we are answering is: {query}"),
                ("human", "Add to the summary here: {summary}"),
                ("human", "New information: {doc}"),
                ("system", verbose_prompt),
            ]
        )

        chain = chat_prompt | self.llm35
        chain_output = chain.invoke(
            {
                "parent_query": self.question,
                "query": query,
                "summary": past_summary,
                "doc": doc,
            }
        )
        summary = chain_output.content
        return summary


if __name__ == "__main__":
    # Ad-hoc smoke test: research the first slide of the demo project.
    dynamo = DynamoDB()
    raw_project = dynamo.get_item(dynamo.projects, "sunbelt_demo_3")
    project = Project(**raw_project)

    slide = project.sections[0].slides[0]
    query = slide.content + "\n".join(slide.questions)
    print(query)

    researcher = ResearchChroma(
        project, ClientConfig("sunbelt").get_client_config(), query
    )
    print(researcher.research(verbose=True))
