"Service to generate PPT structure."
import asyncio
import json
import sys

import faiss
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

sys.path.append(r".")
# from utils.logger import ServiceLogger
import logging

from configs.config import (
    EMBEDDING_MODEL,
    OPENAI_API_KEY,
    OPENAI_EF,
    OPENAI_MODEL_4O,
    OPENAI_MODEL_MINI,
)
from services.company_profile.data_classes.llm_results import LLMResults
from services.ppt_generator.create_ppt import CreatePPT
from services.ppt_generator.data_classes.presentation_outline import (
    PresentationOutline,
    Section_LLM,
    SectionLLMWithSlides,
    SlideLLM,
)
from services.ppt_generator.data_classes.project import Project, Section, Slide
from services.ppt_generator.presentation_types import presentation_types
from utils.client_check import ClientConfig
from utils.dynamo_db import DynamoDB
from utils.redis_cache import redis_cache
from utils.url_parser import parsed_url

# Load variables from a local .env file into the process environment.
# NOTE(review): this runs after ``configs.config`` has already been imported above,
# so it cannot influence values that module read at import time — confirm intent.
load_dotenv()
# Allowed values for ``Project.public_status`` (checked in PPTStructure.__init__).
PUBLIC_STATUS = ["public", "private"]
# Embedding client used to vectorise uploaded-document and website summaries.
embeddings = OpenAIEmbeddings(model=EMBEDDING_MODEL, api_key=OPENAI_API_KEY)
# Embeds a throwaway string at import time (a call to the embeddings API) only to
# learn the embedding dimension, then creates a flat L2 FAISS index of that size.
# NOTE(review): this index object is module-global; any FAISS store constructed
# on it shares its vectors with every other store built on the same object.
index = faiss.IndexFlatL2(len(embeddings.embed_query("hello world")))

# Root logger (``getLogger()`` with no name) used throughout this module.
XCM_logger = logging.getLogger()


class PPTStructure:
    """Service to generate PPT structure.

    Builds a section/slide outline for a ``Project`` with OpenAI chat models.
    Prompts are grounded in the company's stored profile plus summaries of the
    project's uploaded documents and web pages, which are indexed in a
    per-instance FAISS vector store. The result is handed to ``CreatePPT`` for
    research and deck rendering.
    """

    llm_model4 = ChatOpenAI(
        model=OPENAI_MODEL_4O, temperature=0.3, api_key=OPENAI_API_KEY
    )

    llm_model3 = ChatOpenAI(
        model=OPENAI_MODEL_MINI, temperature=0.3, api_key=OPENAI_API_KEY
    )

    # Default layout-choices file; a client config may override it per instance.
    LAYOUT_CHOICES = "services/ppt_generator/content_slide_layout_choices_v5.txt"

    # Per-pitch-type prompt instructions; its keys are the supported pitch types.
    additional_instructions = presentation_types

    def __init__(self, project: Project = None, client: ClientConfig = None):
        """Validate the project, load its context and build the vector store.

        Args:
            project: Project to build the deck for. Required despite the
                ``None`` default.
            client: Client configuration (layout file, structure
                recommendations, client type). When ``None`` or missing an
                attribute, the class defaults are used.

        Raises:
            ValueError: if ``project`` is missing, or its ``public_status`` /
                ``pitch_type`` is not supported.
        """
        if project is None:
            raise ValueError("project is required")

        # run through checks to make sure the project is supported
        if project.public_status not in PUBLIC_STATUS:
            raise ValueError("public_status must be either public or private")

        if project.pitch_type not in self.additional_instructions:
            raise ValueError(
                f"pitch_type must be {', '.join(self.additional_instructions.keys())}"
            )

        self.project = project
        self.main_company = self.get_main_company()  # sets the main company
        self.client_config = client

        # Client-specific layout file, falling back to the class default.
        client_layouts = getattr(self.client_config, "layout_choices", None)
        self.LAYOUT_CHOICES = (
            client_layouts if client_layouts is not None else self.LAYOUT_CHOICES
        )

        # Optional client-provided slide list this pitch type should cover.
        recommendations = (
            getattr(self.client_config, "structure_recommendation", None) or {}
        )
        self.structure_recommendation = recommendations.get(
            self.project.pitch_type, None
        )

        with open(self.LAYOUT_CHOICES, "r", encoding="utf-8") as file:
            self.content_slide_layout_choices = file.read()

        self.get_project_docs()

    def get_project_docs(self):
        """Load document/website summaries and index them in a fresh vector store.

        Sets ``project_doc_summaries``, ``project_website_summaries`` and
        ``doc_store`` on the instance.
        """
        self.project_doc_summaries, docs = self.get_uploaded_documents_summaries()
        self.project_website_summaries, websites = self.get_website_summaries()

        # Build a new FAISS index for every instance. Sharing the module-level
        # ``index`` would mix vectors from different projects and desynchronise
        # it from this store's empty ``index_to_docstore_id`` mapping, so
        # searches could return other projects' documents or fail.
        self.doc_store = FAISS(
            embedding_function=embeddings,
            index=faiss.IndexFlatL2(index.d),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )
        self._create_vector_store(docs, websites)

    def add_or_update_project_in_db(self):
        "Add or update the structure in the DB."
        return self.project.update_project_in_db()

    def _industry(self):
        "Return the project's industry, falling back to the main company's."
        return self.project.industry or self.main_company["industry"]

    def _sector(self):
        "Return the project's sector, falling back to the main company's."
        return self.project.sector or self.main_company["sector"]

    def _is_ib_client(self) -> bool:
        "True when the client config declares ``client_type == 'ib'``."
        return getattr(self.client_config, "client_type", None) == "ib"

    def get_structure(self) -> list[Section]:
        """Ask the LLM for the deck's sections and store them on the project.

        The prompt combines the company profile, any uploaded-document and
        website summaries, the client's structure recommendation and a section
        cap (8 for ``ib`` clients, 5 otherwise). The parsed sections are
        assigned to ``self.project.sections`` and the project is persisted.

        Returns:
            The new list of sections (also stored on ``self.project``).
        """
        parser = JsonOutputParser(pydantic_object=PresentationOutline)

        llm_message = [
            (
                "system",
                """
                    You are an Investment Banking Managing Director creating a slide deck outline.\n
                    The structure of the deck should be broken into sections which will contain different slides.\n
                    Don't include an introduction section. Always include a company and market overview section.\n
                    Each section should have multiple slides and a slide should only cover one topic.\n
                    You want to ensure you have an executive overview section to start and a financial overview at the end.\n
                    {additional_instructions}
                """,
            ),
            (
                "system",
                "You will answer in json format which will be provided below. You are only being tasked to provide the structure\
                    (sections) of the PPT with the theme or content that should be included in each section.",
            ),
            (
                "human",
                """You are pitching {company}, more information on the company is provided below, which is {public_status} in the {industry} and sector {sector} options for {pitch_type}.""",
            ),
            (
                "human",
                "Company overview is below \n {company_overview} \n\n Company product and services overview is below \n {company_products}",
            ),
        ]
        inputs = {
            "company": self.project.company_name,
            "public_status": self.project.public_status,
            "industry": self._industry(),
            "sector": self._sector(),
            "pitch_type": self.project.pitch_type,
            "company_overview": self.main_company["overview"],
            "company_products": self.main_company["products"],
            "additional_instructions": (
                self.additional_instructions[self.project.pitch_type]
                if self.project.pitch_type in self.additional_instructions
                else ""
            ),
            "json_format": parser.get_format_instructions(),
        }

        # Free text is supplied as template variables rather than interpolated
        # into the message strings, so literal "{" / "}" inside a summary can
        # never be mistaken for a prompt placeholder.
        if self.project_doc_summaries:
            llm_message.append(
                (
                    "human",
                    "Here are the summaries of the uploaded documents: {project_doc_summaries}",
                )
            )
            inputs["project_doc_summaries"] = "\n\n".join(self.project_doc_summaries)

        if self.project_website_summaries:
            llm_message.append(
                (
                    "human",
                    "Here are the summaries of the website: {project_website_summaries}",
                )
            )
            inputs["project_website_summaries"] = "\n\n".join(
                self.project_website_summaries
            )

        if self.structure_recommendation:
            llm_message.append(
                (
                    "human",
                    "Ensure the sections could cover the slides from this list: {structure_recommendation}",
                )
            )
            inputs["structure_recommendation"] = str(self.structure_recommendation)

        section_limit = 8 if self._is_ib_client() else 5
        llm_message.append(
            ("human", f"You are going to limit your sections to {section_limit}.")
        )
        llm_message.append(("system", "You will answer in json format {json_format}"))

        chain = ChatPromptTemplate.from_messages(llm_message) | self.llm_model4 | parser
        response = chain.invoke(inputs)

        try:
            self.project.sections = [
                Section_LLM(**section) for section in response["sections"]
            ]
        except (KeyError, ValueError):
            # Presumably the model echoed the JSON-schema wrapper and nested the
            # payload under "properties"; validate it the same way so the
            # project never ends up holding raw dicts.
            self.project.sections = [
                Section_LLM(**section)
                for section in response["properties"]["sections"]
            ]

        self.add_or_update_project_in_db()
        return self.project.sections

    async def flesh_out_sections(self):
        """Generate slides for every section concurrently, then persist the project.

        Each section is given the titles of all sections before and after it as
        context. A section that fails is logged and its slides are reset to
        ``None``; the remaining sections still complete.
        """
        sections = self.project.sections or []

        async def _flesh_out_one(
            section: Section,
            previous_sections: list[Section],
            next_sections: list[Section],
        ):
            "Flesh out one section, containing any failure to that section."
            try:
                XCM_logger.info("Fleshing out section %s", section.title)
                await self.flesh_out_section(section, previous_sections, next_sections)
                XCM_logger.info("Fleshed out section %s", section.title)
            except Exception as e:  # best-effort: don't abort sibling tasks
                XCM_logger.error(
                    "Error in fleshing out section: %s", e, exc_info=True
                )
                if getattr(section, "slides", None) is not None:
                    section.slides = None

        tasks = [
            asyncio.create_task(
                # sections[:i] includes the immediately preceding section.
                _flesh_out_one(section, sections[:i], sections[i + 1 :])
            )
            for i, section in enumerate(sections)
        ]

        await asyncio.gather(*tasks)
        self.project.update_project_in_db()

    async def flesh_out_section(
        self,
        section: Section,
        previous_sections: list[Section],
        next_sections: list[Section],
    ) -> Section:
        """Generate the slides for ``section`` with the LLM.

        Args:
            section: Section to fill; its ``slides`` attribute is overwritten.
            previous_sections: Sections before this one (titles used as context).
            next_sections: Sections after this one (titles used as context).

        Returns:
            The same ``section`` object with ``slides`` populated.

        Raises:
            AttributeError: if the LLM response cannot be turned into slides.
        """
        parser = JsonOutputParser(pydantic_object=SectionLLMWithSlides)

        prompt_messages = [
            (
                "system",
                "You are an Investment Banking Managing Director who is an expert in the {industry}. Specifically, you are an expert in the {sector} sector.",
            ),
            (
                "system",
                "You are going to be given a section of a {pitch_type} presentation and you are going to help flesh it out.",
            ),
            ("system", "One slide can only have one topic. "),
            ("system", "You will return a json object as {json_output}"),
            (
                "human",
                "Company you are working with is {company} with the products / services of {products}",
            ),
            (
                "human",
                "The section you are working on is: {section_title} and the content of the section should be: {section_content}. Don't overlap with other sections.",
            ),
            (
                "human",
                "In addition to your knowledge, you have the following documents to reference: {documents}",
            ),
        ]
        inputs = {
            "industry": self._industry(),
            "sector": self._sector(),
            "pitch_type": self.project.pitch_type,
            "company": self.project.company_name,
            "products": self.main_company["products"],
            "section_title": section.title,
            "section_content": section.content,
            "documents": self.find_relevant_summary(section),
            "json_output": parser.get_format_instructions(),
        }

        # Neighbouring titles are passed as variables (not f-strings) so braces
        # in a title cannot break the prompt template.
        if previous_sections:
            prompt_messages.append(
                (
                    "human",
                    "The previous section titles are: {previous_sections}. Don't overlap slides with other sections.",
                )
            )
            inputs["previous_sections"] = ", ".join(
                prev.title for prev in previous_sections
            )

        if next_sections:
            prompt_messages.append(
                (
                    "human",
                    "The next section titles are: {next_sections}. Don't overlap slides with other sections.",
                )
            )
            inputs["next_sections"] = ", ".join(nxt.title for nxt in next_sections)

        # Every client type currently shares the same per-section slide cap.
        prompt_messages.append(
            (
                "human",
                "You are going to limit your slides to 10 per section but you don't need all of them.",
            )
        )

        llm_ = ChatOpenAI(
            model=OPENAI_MODEL_MINI, temperature=1, api_key=OPENAI_API_KEY
        )
        chain = ChatPromptTemplate.from_messages(prompt_messages) | llm_ | parser

        response = await chain.ainvoke(inputs)

        try:
            section.slides = [SlideLLM(**slide) for slide in response["slides"]]
        except (KeyError, TypeError, ValueError) as exc:
            raise AttributeError(
                f"Error in fleshing out section for project {self.project.project_id}, section {section.title}"
            ) from exc
        return section

    async def flesh_out_new_section(self, sectionsLLM: list[Section]):
        """Flesh out new or updated sections and slides, flagging what changed.

        New/updated sections are regenerated wholesale. For other sections, only
        their new/updated slides are regenerated. Anything regenerated is marked
        ``isModified = True``.

        Args:
            sectionsLLM: The sections to process, in deck order.

        Returns:
            A new list containing the processed sections.
        """
        updated_sections = []

        for i, section in enumerate(sectionsLLM):
            if self._is_item_modified(section):
                section = await self._process_modified_section(section, i)
            else:  # only individual slides may need fleshing out
                section = await self._process_existing_section(section)
                slides = getattr(section, "slides", None) or []
                if any(getattr(slide, "isModified", False) for slide in slides):
                    section.isModified = True

            updated_sections.append(section)

        return updated_sections

    async def _process_modified_section(self, section: Section, index: int):
        """Regenerate all slides of a new or updated section.

        On failure the error is logged and the section's slides are reset to
        ``None``, because the section was asked to be reset.
        """
        try:
            # NOTE(review): context comes from self.project.sections, not the
            # list given to flesh_out_new_section; confirm they are aligned.
            previous_sections = self.project.sections[:index]
            next_sections = self.project.sections[index + 1 :]

            section = await self.flesh_out_section(
                section=section,
                previous_sections=previous_sections,
                next_sections=next_sections,
            )

            self._clear_modification_flags(section)
            section.isModified = True

        except Exception as e:
            XCM_logger.error("Error in fleshing out section: %s", e, exc_info=True)

            ## since the section was asked to be reset, the slides need to be reset as well
            if getattr(section, "slides", None) is not None:
                section.slides = None

        return section

    async def _process_existing_section(self, section: Section):
        """Regenerate only the new/updated slides of an otherwise unchanged section.

        A slide whose regeneration fails is replaced by ``None`` in
        ``section.slides``. A regenerated slide that differs from its input is
        marked ``isModified = True``.
        """
        slides = getattr(section, "slides", None)
        if not slides:
            return section

        slide_tasks = [
            {"index": index, "slide": slide}
            for index, slide in enumerate(slides)
            if self._is_item_modified(slide)
        ]
        if not slide_tasks:
            return section

        slide_tasks = await self.flesh_out_slides(slide_tasks, section)

        for slide_task in slide_tasks:
            output = slide_task["output"]
            if output is None:
                # NOTE(review): leaves a None placeholder in section.slides;
                # confirm downstream consumers tolerate it.
                section.slides[slide_task["index"]] = None
                continue
            if slide_task["slide"].dict() != output.dict():
                output.isModified = True

            section.slides[slide_task["index"]] = self._clear_modification_flags(
                output
            )

        return section

    def _is_item_modified(self, item):
        "Return truthy if ``item`` carries a truthy ``isNew`` or ``isUpdated`` flag."
        return getattr(item, "isNew", False) or getattr(item, "isUpdated", False)

    def _clear_modification_flags(self, item):
        "Delete ``isNew``/``isUpdated`` from ``item`` (keeping ``isModified``) and return it."
        for flag in ["isNew", "isUpdated"]:
            if hasattr(item, flag):
                delattr(item, flag)

        return item

    async def flesh_out_slides(self, slide_tasks: list[dict], section: Section):
        """Regenerate a batch of slides concurrently.

        Args:
            slide_tasks: Dicts with ``"index"`` (position in ``section.slides``)
                and ``"slide"`` (the slide to regenerate).
            section: The section the slides belong to (used as context).

        Returns:
            ``slide_tasks``, with an ``"output"`` key added to each entry: the
            generated ``SlideLLM``, or ``None`` if that slide failed.
        """
        if not slide_tasks:
            return slide_tasks

        parser = JsonOutputParser(pydantic_object=SlideLLM)
        chat_prompt = ChatPromptTemplate.from_messages(
            self._build_slide_prompt_messages()
        )
        llm = ChatOpenAI(
            model=OPENAI_MODEL_MINI, temperature=0.5, api_key=OPENAI_API_KEY
        )
        chain = chat_prompt | llm | parser

        # return_exceptions=True keeps one failing slide from aborting the batch;
        # failures come back in-place as exception objects.
        responses = await chain.abatch(
            [
                self._get_slide_chain_inputs(
                    slide_task["slide"], section, slide_task["index"], parser
                )
                for slide_task in slide_tasks
            ],
            return_exceptions=True,
        )

        for slide_task, response in zip(slide_tasks, responses):
            if isinstance(response, Exception):
                error = response
            else:
                try:
                    slide_task["output"] = SlideLLM(**response)
                    continue
                except Exception as exc:
                    error = exc

            XCM_logger.error(
                "Error in fleshing out slide: %s in section %s in Project %s. \n Error was %s",
                slide_task["slide"].title,
                section.title,
                self.project.project_id,
                error,
                exc_info=error,
            )
            slide_task["output"] = None

        return slide_tasks

    def _build_slide_prompt_messages(self):
        """Return the chat-message templates used to (re)generate one slide.

        All variables (industry, sector, section and slide text, neighbouring
        slide titles, ``json_output``) are supplied by ``_get_slide_chain_inputs``.
        """
        messages = [
            (
                "system",
                "You are an investment banking expert in the {industry} industry and the {sector} sector.",
            ),
            (
                "system",
                "You are going to be provided with guidance on how to create a new slide in a presentation.",
            ),
            (
                "human",
                """The section you are working on is: {section_title} and the content of the section is: {section_content}. \
                Ensure this slide doesn't overlap with other slides in the section.""",
            ),
        ]

        # TODO: these messages always describe a brand-new slide, even when an
        # existing slide was only updated.
        messages.extend(
            [
                (
                    "human",
                    "You are going to be adding a brand new slide which is titled: {slide_title} and should contain the following content: {slide_content}",
                ),
                ("human", "The slides before this one are titled: {slides_before}"),
                ("human", "The slides after this one are titled: {slides_after}"),
                (
                    "human",
                    "Make sure this slide doesn't overlap with other slides in the section.",
                ),
            ]
        )

        messages.append(
            ("system", "You will respond in a json object as {json_output}")
        )
        return messages

    def _get_slide_chain_inputs(
        self, slide: Slide, section: Section, index: int, parser: JsonOutputParser
    ):
        """Build the template variables for regenerating ``slide``.

        ``index`` is the slide's position in ``section.slides``. The titles before
        and after it are joined with "; " as context.
        """
        previous_slides = [s.title for s in section.slides[:index]]
        next_slides = [s.title for s in section.slides[index + 1 :]]

        return {
            "industry": self._industry(),
            "sector": self._sector(),
            "section_title": section.title,
            "section_content": section.content,
            "slide_title": slide.title,
            "slide_content": slide.content,
            "slides_before": "; ".join(previous_slides),
            "slides_after": "; ".join(next_slides),
            "json_output": parser.get_format_instructions(),
        }

    def find_relevant_summary(self, section: Section):
        """Return the stored summaries most relevant to ``section``.

        Runs an MMR search over ``doc_store`` using the section's title and
        content. The matching page contents are joined with blank lines.
        """
        results = self.doc_store.search(
            query=f"Section title: {section.title} and Section content: {section.content}",
            search_type="mmr",
        )

        return "\n\n".join(result.page_content for result in results)

    def create_PPT(self):
        "Render the deck from the current structure and return its location."
        ppt_creator = CreatePPT(self.project, self.main_company, self.client_config)
        return ppt_creator.convert_structure_to_ppt()

    def get_main_company(self):
        """Load the company record for the project's URL and attach LLM summaries.

        Returns:
            The item from ``company_info_table``, with ``overview`` and
            ``products`` keys filled from the ``Overview_Stats`` and ``Products``
            LLM results ("" when no such result exists).

        Raises:
            ValueError: if no company record is found for the URL.
        """
        db = DynamoDB()
        company_url = parsed_url(self.project.company_url).url
        main_company = db.get_item(db.company_info_table, company_url)
        if not main_company:
            raise ValueError(f"No company info found for {company_url}")
        llm_info = db.get_item(db.llm_table, company_url)

        main_company["overview"] = ""
        main_company["products"] = ""
        if not llm_info:
            XCM_logger.warning("No LLM results found for %s", company_url)
            return main_company

        for item in LLMResults(**llm_info).LLM_results:
            if item.category == "Overview_Stats":
                main_company["overview"] = item.response
            if item.category == "Products":
                main_company["products"] = item.response

        return main_company

    @redis_cache(ttl=3600)  # caches for an hour
    def get_uploaded_documents_summaries(self):
        """Return ``(summaries, documents)`` for the project's uploaded documents.

        ``documents`` are the raw ``project_docs`` items; ``summaries`` is each
        item's ``doc_summary`` in the same order.
        """
        db = DynamoDB()
        documents = db.query_items(db.project_docs, self.project.project_id)
        return [doc["doc_summary"] for doc in documents], documents

    @redis_cache(ttl=3600)  # caches for an hour
    def get_website_summaries(self):
        """Return ``(summaries, websites)`` for the project's web pages.

        Queries ``web_pages`` through ``project_id-index``. A page without a
        ``summary`` contributes "".
        """
        from boto3.dynamodb.conditions import Attr

        db = DynamoDB()
        websites = db.query_index(
            table=db.web_pages,
            index_name="project_id-index",
            sort_key_value=self.project.project_id,
            expression_values=Attr("project_id").eq(self.project.project_id),
        )
        return [website.get("summary", "") for website in websites], websites

    def _create_vector_store(self, docs=None, websites=None):
        """Add document and website summaries to ``doc_store``.

        Each becomes a ``Document`` with ``file_name``/``doc_path`` metadata.
        Websites use their URL for both fields.
        """
        if docs:
            self.doc_store.add_documents(
                documents=[
                    Document(
                        page_content=doc["doc_summary"],
                        metadata={
                            "file_name": doc["file_name"],
                            "doc_path": doc["doc_path"],
                        },
                    )
                    for doc in docs
                ]
            )

        if websites:
            self.doc_store.add_documents(
                documents=[
                    Document(
                        page_content=site.get("summary", ""),
                        metadata={"file_name": site["url"], "doc_path": site["url"]},
                    )
                    for site in websites
                ]
            )

    def research_presentation(self):
        "Run research for the whole presentation, refresh and return the project."
        XCM_logger.info(
            "Conducting research for presentation %s", self.project.project_id
        )
        ppt_creator = CreatePPT(self.project, self.main_company, self.client_config)
        ppt_creator.research_presentation()
        self._update_project()
        XCM_logger.info(
            "Research for presentation %s completed", self.project.project_id
        )
        return self.project

    def research_slide(self, section_number: int, slide_number: int):
        """Run research for one slide and refresh the project model.

        Args:
            section_number: 1-based index of the section.
            slide_number: 1-based index of the slide within that section.

        Returns:
            The researched slide (its ``research`` attribute is set).

        Raises:
            IndexError: if either index is < 1 or out of range.
        """
        if section_number < 1 or slide_number < 1:
            # Without this guard, 0 would silently wrap to the last element.
            raise IndexError("section_number and slide_number are 1-based")
        section = self.project.sections[section_number - 1]
        slide = section.slides[slide_number - 1]
        ppt_creator = CreatePPT(self.project, self.main_company, self.client_config)
        slide.research = ppt_creator.conduct_research(slide)
        self._update_project()
        return slide

    def _update_project(self):
        "Rebuild ``self.project`` from its own ``dict()`` and return it."
        self.project = Project(**self.project.dict())
        return self.project

    ## TODO
    @staticmethod
    def _check_new_updates(self, project_id: str, new_sections: list[dict]):
        """Check that ``project_id`` exists before applying section updates.

        Unfinished: it currently only validates existence and returns ``None``.

        NOTE(review): this is a ``@staticmethod`` whose first parameter is named
        ``self`` and is unused, so callers must pass a placeholder first argument.
        The signature is kept as-is so existing call sites do not break.

        Raises:
            ValueError: if the project is not found in the database.
        """
        project = Project.check_project_in_db(project_id=project_id)
        if project is None:
            raise ValueError("Project does not exist in the database")

        # TODO: for each updated section in ``new_sections``, call flesh_out_section


if __name__ == "__main__":
    # Manual smoke test: run the research step for each listed project and
    # print how long it took.
    import time

    projects = [
        {
            "project_id": "iti-digital-1b7107ab-9e73-4590-90ce-1ba6c6b6a791",
            "client": "tequity",
        },
    ]
    for project_details in projects:
        project = Project.check_project_in_db(
            project_id=project_details["project_id"],
        )
        if project is None:
            print(f"Project {project_details['project_id']} not found; skipping")
            continue

        ppt_structure = PPTStructure(
            project, client=ClientConfig(project_details["client"]).get_client_config()
        )

        # Other pipeline stages, uncomment as needed:
        #   ppt_structure.get_structure()
        #   asyncio.run(ppt_structure.flesh_out_sections())
        #   ppt_structure._update_project()
        #   ppt_structure.create_PPT()

        # %% research the presentation
        start = time.time()
        ppt_structure.research_presentation()
        print(time.time() - start)
