"Create a ppt with the given information"
import asyncio
import json
import math
import os
import time

import pptx
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI
from PIL import ImageFont
from pptx.enum.shapes import MSO_SHAPE  # pylint: disable=no-name-in-module
from pptx.util import Inches, Pt
from pydantic import BaseModel, Field

from configs.config import OPENAI_MODEL_35
# load services
from services.company_profile.company_profile import (CompanyProfileMaker,
                                                      find_llm_response)
# load in data models
from services.company_profile.data_classes.company_info import CompanyInfo
from services.company_profile.data_classes.llm_results import LLMResults
from services.ppt_generator.chart_creator import ChartCreator
from services.ppt_generator.data_classes.company_overview import \
    CompanyOverview
# load database functions
from services.ppt_generator.layouts.xcm.layouts import (DefaultSettings,
                                                        SlideWorkArea)
# load slide layouts
from services.ppt_generator.slide_layouts_choices import (
    slide_layouts_choices, slide_layouts_choices_keys)
# load other utilities
from utils.download_image import DownloadImage
from utils.search_google import GoogleSearch
from utils.url_parser import parsed_url

# Load the environment variables (from a local .env file, if present)
load_dotenv()
# Read eagerly: a missing key fails at import time with a KeyError.
GOOGLE_SEARCH_API = os.environ["GOOGLE_SEARCH_API"]
# Template deck used when CompanyProfileSlide is asked to start a new deck.
PRES_TEMPLATE = "ppt_templates/ppt_template.pptx"
# Original (misspelled) name, kept as an alias so existing imports keep working.
PRES_TEMPATE = PRES_TEMPLATE


class NewArticles(BaseModel):
    """News articles judged relevant to the client company.

    The three lists are parallel: index i of each list describes the same article.
    """

    # The docstring and these descriptions end up in the JSON schema that
    # JsonOutputParser.get_format_instructions() sends to the LLM, so they
    # double as instructions for the model.
    articles: list[str] = Field(description="Title of each relevant article")
    articles_summary: list[str] = Field(
        description="Short summary of each article, in the same order as `articles`"
    )
    articles_links: list[str] = Field(
        description="Link to each article, in the same order as `articles`"
    )


class CompanyProfileSlide:
    """Create company-profile slides for one company in a python-pptx deck.

    Construction loads the company's stored records, fills in a missing logo
    or leadership list, downloads the logo and leader pictures, and assembles
    a ``CompanyOverview``.  Slides are then added with :meth:`create_slides`.
    """

    # Static LinkedIn asset URL that is filtered out of headshot search hits
    # (presumably LinkedIn's generic no-photo avatar -- confirm).
    _LINKEDIN_STATIC_IMAGE = (
        "https://static.licdn.com/aero-v1/sc/h/1c5u578iilxfi4m4dvc4q810q"
    )

    def __init__(
        self,
        company_url: str,
        ppt_object: pptx.Presentation = None,
        work_area: SlideWorkArea = None,
        default_settings: DefaultSettings = None,
        comparable_tickers: list = None,
        new_deck: bool = False,
    ):
        """
        Load the company's data and prepare to build slides.

        company_url: the company url (normalised with ``parsed_url``)
        ppt_object: an existing presentation to add slides to
        work_area: the work area for the slide; a fresh ``SlideWorkArea()``
            when None
        default_settings: the default font settings; a fresh
            ``DefaultSettings()`` when None
        comparable_tickers: tickers charted next to the company's stock;
            defaults to ["^DJI", "^SPX", "^IXIC"]
        new_deck: when True and ``ppt_object`` is None, start a new deck from
            ``PRES_TEMPATE``; otherwise ``self.prs`` stays None and slide
            creation will fail
        """
        self.prs = ppt_object
        if self.prs is None and new_deck:
            self.prs = pptx.Presentation(PRES_TEMPATE)

        # Build defaults per instance: a default argument such as
        # ``SlideWorkArea()`` is evaluated once and shared by every instance.
        self.work_area = SlideWorkArea() if work_area is None else work_area
        if default_settings is None:
            default_settings = DefaultSettings()

        self.comparable_tickers = comparable_tickers
        if self.comparable_tickers is None:
            self.comparable_tickers = ["^DJI", "^SPX", "^IXIC"]

        self.default_font = default_settings.font
        self.default_font_size = default_settings.font_size
        self.default_font_color = default_settings.font_color

        # parse the URL
        self.parsed_url = parsed_url(company_url)
        self.company_url = self.parsed_url.url

        # load company_info / llm_results and fill in missing pieces
        self.get_data()

        def llm_response(category: str):
            # A fresh dump per lookup (as before), in case find_llm_response
            # mutates the list it is given.
            return find_llm_response(
                self.llm_results.model_dump()["LLM_results"], category
            )

        self.company_overview = CompanyOverview(
            company_name=self.company_info.company_name,
            logo=self.company_info.logo,
            stock_ticker=self.company_info.stock_ticker,
            customers=llm_response("Customers"),
            products=llm_response("Products"),
            management=self.llm_results.leadership["leadership"],
            # HQ and founding date are both read from the "HQ_Founded" response.
            hq=llm_response("HQ_Founded"),
            founded=llm_response("HQ_Founded"),
            overview=llm_response("Overview") + "\n" + llm_response("Overview_Stats"),
        )

        # News fetching is currently disabled; "" means "no news" and makes
        # the news placeholder be skipped.
        # self.relevant_news = self.get_company_news()
        self.relevant_news = ""

    def get_data(self):
        """Load the company's records and fill in anything missing.

        Sets ``self.company_info`` and ``self.llm_results``.  When the company
        has no logo, ``CompanyProfileMaker.load_brand`` is used to fetch one;
        when there is no leadership list, it is built from the "Leadership"
        LLM response.  The logo and every leader's picture are downloaded and
        their locations stored on ``img_loc``.
        """
        self.company_info = CompanyInfo.get_company_info(self.company_url)
        self.llm_results = LLMResults.get_LLM_results(self.company_url)

        # The profile maker is kept in a local: storing it in
        # self.company_overview (as before) clobbered that attribute.
        if self.company_info.logo is None:
            profile_maker = CompanyProfileMaker(self.company_url)
            self.company_info = asyncio.run(
                profile_maker.load_brand(supplied_company_info=self.company_info)
            )
        if self.company_info.logo:
            self.company_info.logo.img_loc = asyncio.run(
                self.download_image(
                    self.company_info.logo.logo_url,
                    "logo",
                    self.company_info.logo.logo_dark,
                )
            )

        if self.llm_results.leadership is None:
            profile_maker = CompanyProfileMaker(self.company_url)
            company_leadership_text = find_llm_response(
                self.llm_results.model_dump()["LLM_results"], "Leadership"
            )
            self.llm_results.leadership = asyncio.run(
                profile_maker.map_leadership(company_leadership_text)
            )

        asyncio.run(self.search_management_pictures())

    async def download_image(
        self, image_url: str, file_name: str, dark_background=True
    ):
        """Download ``image_url`` and return its saved location.

        The file is named ``"{company_name}_{file_name}.png"``.  An empty (or
        None) URL falls back to ``DownloadImage.STOCK_IMAGE_FILENAME``.
        ``dark_background`` is forwarded as ``picture_background_dark``.
        """
        di = DownloadImage()
        if not image_url:
            image_url = di.STOCK_IMAGE_FILENAME

        _, image_loc = await di.get_image(
            url=image_url,
            file_name=f"{self.company_info.company_name}_{file_name}.png",
            picture_background_dark=dark_background,
        )

        return image_loc

    @staticmethod
    def _pick_square_image(results: list) -> str:
        """Return the URL of the first result with width/height in (0.9, 1.2).

        Falls back to the first result's URL, or "" when there are no results.
        Results with a missing or zero width/height are skipped rather than
        raising KeyError / ZeroDivisionError.
        """
        for result in results:
            width = result.get("imageWidth")
            height = result.get("imageHeight")
            if not width or not height:
                continue
            if 0.9 < width / height < 1.2:
                return result["imageUrl"]
        return results[0]["imageUrl"] if results else ""

    async def search_management_pictures(self):
        """Find and download a headshot for every leader.

        Runs one Google image search per leader (restricted to linkedin.com)
        concurrently, drops hits pointing at ``_LINKEDIN_STATIC_IMAGE``, picks
        a roughly square image per leader, downloads them concurrently and
        stores each location on ``leader.img_loc``.  Returns
        ``self.llm_results.leadership``.
        """
        leaders = self.llm_results.leadership["leadership"]

        search_tasks = [
            asyncio.create_task(
                GoogleSearch().search_google(
                    f'"{self.company_info.company_name}" "{leader.name}" headshot square',
                    "images",
                    domain_check=True,
                    domain_name_to_check="linkedin.com",
                    num_of_results=100,
                )
            )
            for leader in leaders
        ]
        results_of_results = await asyncio.gather(*search_tasks)

        # Filter in place with one pass (list.remove in a loop was O(n^2)).
        for results in results_of_results:
            results[:] = [
                result
                for result in results
                if self._LINKEDIN_STATIC_IMAGE not in result["imageUrl"]
            ]

        download_tasks = [
            asyncio.create_task(
                self.download_image(self._pick_square_image(results), leader.name, False)
            )
            for leader, results in zip(leaders, results_of_results)
        ]
        image_locs = await asyncio.gather(*download_tasks)

        for leader, img_loc in zip(leaders, image_locs):
            leader.img_loc = img_loc

        return self.llm_results.leadership

    @staticmethod
    def convert_llm_to_ci(llm_response_category: str):
        """Map an LLM response category to its company-info field name, or None."""
        llm_to_ci_mapping = {
            "Products": "products",
            "Leadership": "management",
            "Customers": "customers",
            "Overview_Stats": "overview",
            "Overview": "overview",
            "HQ_Founded": "hq",
        }
        return llm_to_ci_mapping.get(llm_response_category)

    def get_slide_layout(self, layout_name):
        """Get the slide layout by name
        input:
        - layout_name: the name of the layout to get
        output:
        - layout: the first layout in ``self.prs`` with that name, or None
        - slide_structure: ``slide_layouts_choices[layout_name]``, or None when
          the name has no structure (or no layout matched)
        """
        for layout in self.prs.slide_layouts:
            if layout.name == layout_name:
                slide_structure = (
                    slide_layouts_choices[layout_name]
                    if layout_name in slide_layouts_choices
                    else None
                )
                return layout, slide_structure

        return None, None

    def create_slides(self, slide_to_create: slide_layouts_choices_keys):
        """Add one slide of the named layout to the deck and fill it in.

        Raises ValueError when ``slide_to_create`` is not a known layout.  Any
        error while building the slide is printed and None is returned;
        otherwise the new slide is returned.
        """
        if slide_to_create not in slide_layouts_choices:
            raise ValueError(f"{slide_to_create} not in slide_layouts_choices")

        layout, slide_structure = self.get_slide_layout(slide_to_create)
        try:
            if slide_to_create == "extra_content":
                # extra_content_slide adds its own slide; adding one here too
                # used to leave an empty slide in the deck.
                return self.extra_content_slide(
                    slide_to_create,
                    self.company_overview.company_name,
                    self.company_overview.logo,
                    slide_structure["header"],
                    slide_structure["content"],
                )

            slide = self.prs.slides.add_slide(layout)

            if slide_to_create == "private_company_profile":
                self.create_private_company_profile(slide, slide_structure)
            elif slide_to_create == "stock_chart_output":
                self.add_stock_chart_slide(slide, slide_structure, years_to_show=5)

        except Exception as e:  # pylint: disable=broad-except
            print(e)
            return None

        return slide

    def _write_news(self, slide, place_holder, spec):
        """Write ``self.relevant_news`` summaries into ``place_holder``.

        The article titles and links go into the slide notes.  Does nothing
        while no news is loaded (``relevant_news`` is "" or None).  ``zip``
        tolerates the LLM returning lists of different lengths.
        """
        news = self.relevant_news
        if news is None or news == "":
            return
        self.write_content(
            place_holder,
            "\n".join(
                f"{title}: {summary}"
                for title, summary in zip(news.articles, news.articles_summary)
            ),
            spec,
        )
        slide.notes_slide.notes_text_frame.text = "News Links:\n" + "\n\n".join(
            f"{title}\n{link}"
            for title, link in zip(news.articles, news.articles_links)
        )

    def create_private_company_profile(self, slide, slide_structure):
        """Fill a "private_company_profile" slide.

        Each placeholder named in ``slide_structure`` is handled by its
        "field":
        - "content": the ``company_overview`` attribute named by "data", or
          the news when "data" is "news"
        - "header": the literal "data" text
        - "logo_image": the company logo
        - "leader_picture" / "leader_name": the picture, or the bold name plus
          title, of the leader at index "data"

        With more than four leaders an extra leadership slide is also added.
        """
        management = self.company_overview.management

        for place_holder in slide.placeholders:
            item_name = place_holder.name
            if item_name not in slide_structure:
                continue
            spec = slide_structure[item_name]
            field = spec["field"]

            if field == "content":
                if spec["data"] == "news":
                    self._write_news(slide, place_holder, spec)
                else:
                    self.write_content(
                        place_holder,
                        getattr(self.company_overview, spec["data"]),
                        spec,
                    )

            elif field == "header":
                place_holder.text = spec["data"]

            elif field == "logo_image":
                self.add_logo(place_holder, self.company_overview.logo, slide)

            elif field == "leader_picture":
                if len(management) > spec["data"]:
                    leader = management[spec["data"]]
                    if leader.img_loc is not None:
                        place_holder.insert_picture(leader.img_loc)

            elif field == "leader_name":
                if len(management) > spec["data"]:
                    leader = management[spec["data"]]

                    # leader name in bold, then the title
                    run = place_holder.text_frame.paragraphs[0].add_run()
                    run.text = leader.name
                    run.font.bold = True

                    run = place_holder.text_frame.paragraphs[0].add_run()
                    run.text = f"\n{leader.title}"

        if len(management) > 4:
            self.add_leadership(
                "extra_content",
                self.company_overview.company_name,
                self.company_overview.logo,
                "Leadership",
                management,
            )

    def add_stock_chart_slide(self, slide, slide_structure, years_to_show=5):
        """Fill a stock-chart slide.

        Placeholders named in ``slide_structure`` are handled by their "field":
        "content" (``company_overview`` attribute via write_content), "header"
        (literal text), "logo_image" (company logo) and "stock_chart" (a line
        chart of ``years_to_show`` years of prices for the company's ticker
        and ``self.comparable_tickers``, drawn over the placeholder, which is
        then removed).
        """
        for place_holder in slide.placeholders:
            item_name = place_holder.name
            if item_name not in slide_structure:
                continue
            spec = slide_structure[item_name]
            field = spec["field"]

            if field == "content":
                self.write_content(
                    place_holder,
                    getattr(self.company_overview, spec["data"]),
                    spec,
                )

            elif field == "header":
                place_holder.text = spec["data"]

            elif field == "logo_image":
                self.add_logo(place_holder, self.company_overview.logo, slide)

            elif field == "stock_chart":
                chart_creator = ChartCreator()
                stock_data_df = chart_creator.get_stock_data(
                    self.company_info.stock_ticker,
                    years_to_show,
                    other_tickers=self.comparable_tickers,
                )
                # NOTE(review): assumes add_line_chart takes
                # (slide, top, left, width, height, data, ...) -- verify order.
                chart_creator.add_line_chart(
                    slide,
                    place_holder.top,
                    place_holder.left,
                    place_holder.width,
                    place_holder.height,
                    stock_data_df,
                    x_axis="date",
                )
                # the chart replaces the placeholder
                place_holder.element.getparent().remove(place_holder.element)

    def add_logo(self, place_holder, logo_item, slide_obj):
        """Replace ``place_holder`` with the logo picture, fitted to its box.

        The picture is added at the placeholder's position and height; if it
        comes out wider than the placeholder it is scaled down, keeping its
        aspect ratio.  Without a logo (or a downloaded logo file) the
        placeholder text is set to "No Logo Found" instead.
        """
        if logo_item is None or getattr(logo_item, "img_loc", None) is None:
            # add text for no logo found
            place_holder.text = "No Logo Found"
            return

        # Read the geometry before the placeholder is removed from the slide.
        top, left = place_holder.top, place_holder.left
        max_width, height = place_holder.width, place_holder.height

        new_pic = slide_obj.shapes.add_picture(
            logo_item.img_loc,
            top=top,
            left=left,
            height=height,
        )

        place_holder.element.getparent().remove(place_holder.element)

        if new_pic.width > max_width:
            rescale = max_width / new_pic.width
            new_pic.width = int(new_pic.width * rescale)
            new_pic.height = int(new_pic.height * rescale)

    @staticmethod
    def _strip_bullet_markers(text: str) -> str:
        """Remove a leading "- " bullet marker (and its indent) from each line.

        Hyphens elsewhere in a line (e.g. "2010 - 2020") are kept; the old
        ``text.replace("- ", "")`` deleted those as well.
        """
        lines = []
        for line in text.split("\n"):
            stripped = line.lstrip()
            lines.append(stripped[2:] if stripped.startswith("- ") else line)
        return "\n".join(lines)

    def write_content(self, place_holder, content, content_attributes):
        """Write ``content`` into ``place_holder``, spilling to an extra slide.

        Leading bullet markers are removed, the text is truncated to the lines
        ``estimate_text_size`` says fit, and when something did not fit the
        full text is also written to an "extra_content" slide headed
        ``content_attributes["data"]``.

        content_attributes: the slide-structure entry for this placeholder;
            "height_multiplier", "font_size", "font_type" and "data" are read.
        """
        content = self._strip_bullet_markers(content or "")

        text_to_write, text_over_height = self.estimate_text_size(
            content,
            # Subtract the bullet spacing and the padding.
            # NOTE(review): Pt(0.25)/Pt(0.2) are fractions of a point; the
            # intent may have been Inches -- confirm.
            max_width_px=((place_holder.width - Pt(0.25) - Pt(0.2)) / Pt(1)),
            max_height_px=(place_holder.height / Pt(1))
            * content_attributes["height_multiplier"],
            font_size=content_attributes["font_size"],
            font=content_attributes["font_type"],
        )

        place_holder.text = text_to_write

        if text_over_height:
            self.extra_content_slide(
                "extra_content",
                self.company_overview.company_name,
                self.company_overview.logo,
                content_attributes["data"],
                content,
            )

    def extra_content_slide(self, slide_name, company_name, logo, header, content=""):
        """Add a slide using layout ``slide_name`` and fill its placeholders.

        A "content" placeholder gets the argument named by its "data" entry
        ("slide_name", "company_name", "header" -- capitalized -- or
        "content"); an "image" placeholder gets the logo when there is one.
        Returns the new slide.
        """
        layout, slide_structure = self.get_slide_layout(slide_name)
        slide = self.prs.slides.add_slide(layout)
        slide.name = f"{company_name}_{header}_extra_content"

        header = header.capitalize()

        # Explicit lookup table; this used to be locals(), which also exposed
        # every loop variable to the slide structure.
        text_values = {
            "slide_name": slide_name,
            "company_name": company_name,
            "header": header,
            "content": content,
        }

        for place_holder in slide.placeholders:
            spec = slide_structure.get(place_holder.name, None)
            if spec is None:
                continue

            if spec["field"] == "content":
                place_holder.text = text_values[spec["data"]]

            elif spec["field"] == "image":
                if logo is not None:
                    self.add_logo(place_holder, logo, slide)

        return slide

    def add_leadership(self, slide_name, company_name, logo, header, leadership: list):
        """Add an extra slide listing each leader's picture, name and title.

        Entries start at "Content Placeholder 1" and are stacked 0.81in apart;
        once an entry is placed lower than 5in the next one starts a new
        column 3.5in to the right.  The placeholder is removed afterwards.  A
        leader whose picture or text cannot be added is printed and skipped.
        """
        slide = self.extra_content_slide(slide_name, company_name, logo, header)
        for place_holder in slide.placeholders:
            if place_holder.name == "Content Placeholder 1":
                content_left = place_holder.left
                content_top = place_holder.top

                for leader in leadership:
                    try:
                        # picture, cropped to a circle
                        new_pic = slide.shapes.add_picture(
                            leader.img_loc,
                            top=content_top,
                            left=content_left,
                            height=Inches(0.56),
                            width=Inches(0.56),
                        )
                        new_pic.auto_shape_type = MSO_SHAPE.OVAL

                        # text box to the right of the picture
                        left = content_left + new_pic.width + Inches(0.1)
                        top = content_top
                        width = Inches(2.36)
                        height = Inches(0.56)

                        txBox = slide.shapes.add_textbox(left, top, width, height)
                        txBox.text_frame.word_wrap = True

                        run = txBox.text_frame.paragraphs[0].add_run()
                        run.text = leader.name
                        run.font.size = Pt(12)
                        run.font.bold = True
                        # title on the next line
                        run = txBox.text_frame.paragraphs[0].add_run()
                        run.text = f"\n{leader.title}"
                        run.font.size = Pt(12)

                        # wrap to a new column once past 5in, else move down
                        if content_top > Inches(5):
                            content_left += Inches(3.5)
                            content_top = place_holder.top
                        else:
                            content_top += Inches(0.56) + Inches(0.25)

                    except Exception as e:  # pylint: disable=broad-except
                        print(e)
                        continue

                place_holder.element.getparent().remove(place_holder.element)

        return slide

    def estimate_text_size(
        self,
        text: str,
        max_width_px: int,
        max_height_px: int,
        font=None,
        font_size=None,
    ):
        """
        estimate how much of a text body fits in a box
        input:
        - text: the text to estimate
        - max_width_px: the max width of the text in pixels
        - max_height_px: the max height of the text in pixels
        - font: the font file to use (defaults to self.default_font)
        - font_size: the font size to use (defaults to self.default_font_size)

        Each non-blank line is assumed to wrap into
        ceil(rendered_width / max_width_px) rows of 1.5 * font_size height.
        Lines are kept while they fit in the remaining height.

        output:
        - text_to_draw: the lines that fit, joined with "\\n"
        - text_over_height: True when at least one line did not fit
        """
        if font is None:
            font = self.default_font
        if font_size is None:
            font_size = self.default_font_size

        fnt = ImageFont.truetype(font, font_size)

        text_to_draw = []
        text_over_height = False
        remaining_height = max_height_px
        for line in text.split("\n"):
            if not line:
                continue

            # getbbox() is None when nothing visible is rendered
            # (e.g. a whitespace-only line); such lines are not drawn.
            bbox = fnt.getmask(line).getbbox()
            if bbox is None:
                continue

            line_height = math.ceil(bbox[2] / max_width_px) * font_size * 1.5
            if remaining_height > line_height:
                text_to_draw.append(line)
            else:
                # Flag every dropped line (a line exactly as tall as the
                # remaining space used to be dropped without being reported).
                text_over_height = True

            remaining_height -= line_height

        return "\n".join(text_to_draw), text_over_height

    def get_company_news(self):
        """Search Google News for the company and filter it with an LLM.

        Returns a ``NewArticles`` or None when the search found nothing.
        """
        search_google = GoogleSearch()
        results = asyncio.run(
            search_google.search_google(f"{self.company_info.company_name}", "news")
        )
        if not results:
            return None
        return self.filter_news(results)

    def filter_news(self, news: list):
        """Ask the LLM which news articles are relevant; return ``NewArticles``.

        If the first call fails (including unparseable or invalid output) it
        waits 5 seconds and retries once, sending only each article's title,
        snippet and link.
        """
        llm = ChatOpenAI(
            model=OPENAI_MODEL_35,
            api_key=os.environ["OPENAI_API_KEY"],
            temperature=0.2,
        )
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """
            You are going to be given news articles with the title and snippet. You need to determine if the news is relevant to your client.
            You need to return the articles that are relevant, with a summary and the link to the article.
            Make sure the article is relevant to the company and industry you are working with.
             """,
                ),
                (
                    "human",
                    "The company you are working with is {company_name} in the {industry} industry with an overview of this {overview}",
                ),
                ("human", "{articles}"),
                (
                    "human",
                    "Output format should be: {json_structure} and give me atleast 5 articles if not more",
                ),
            ]
        )

        parser = JsonOutputParser(pydantic_object=NewArticles)
        chain = prompt | llm | parser

        def run_chain(articles: list) -> NewArticles:
            output = chain.invoke(
                {
                    "company_name": self.company_info.company_name,
                    # avoid sending the literal "None" to the model
                    "industry": self.company_info.industry or "",
                    "overview": self.company_overview.overview,
                    "articles": json.dumps(articles),
                    "json_structure": parser.get_format_instructions(),
                }
            )
            return NewArticles(**output)

        try:
            # The parser itself can raise, so the invoke is inside the try.
            return run_chain(news)
        except Exception:  # pylint: disable=broad-except
            time.sleep(5)
            return run_chain(
                [
                    {
                        "title": article.get("title", ""),
                        "snippet": article.get("snippet", ""),
                        "link": article.get("link", ""),
                    }
                    for article in news
                ]
            )


if __name__ == "__main__":
    # Build and save a company-profile deck for each URL listed here.
    for target_url in ("unity.com",):
        deck = CompanyProfileSlide(company_url=target_url, new_deck=True)
        deck.create_slides("private_company_profile")

        # Only companies with a stock ticker get the stock-chart slide.
        if deck.company_info.stock_ticker is not None:
            deck.create_slides("stock_chart_output")

        deck_company = deck.company_info.company_name
        deck.prs.save(f"{deck_company}_company_profile.pptx")
        print(f"{deck_company} PPT created")
