"Class to contain the slide creation process"

import asyncio
import json
import math
import os

import requests
from PIL import Image
from pptx.dml.color import RGBColor
from pptx.enum.shapes import (  # pylint: disable=import-error
    MSO_AUTO_SHAPE_TYPE, MSO_SHAPE)
from pptx.enum.text import MSO_ANCHOR, PP_ALIGN  # pylint: disable=import-error
from pptx.util import Inches, Pt

from services.ppt_generator.data_classes.slide_layout_models import *
from services.ppt_generator.slide_research import SlideResearch
from utils.client_check import ClientConfig
from utils.download_image import DownloadImage
from utils.search_google import GoogleSearch


class SlideCreator:
    "Class to contain the slide creation process"

    # x coordinate (EMU) at which iconography() and graph() park extra /
    # placeholder images, i.e. in the margin beyond the content area.
    slide_max_width = Inches(13.5)
    # Base icon size: icon_circle() draws its ring circles at 1.5x this size
    # and landscape() reserves this much height per cell for the icon row.
    icon_circle_size = Inches(0.35)

    # Default font sizes (presumably points, as interpreted by add_text --
    # confirm). default_font_size_body is used for landscape body text.
    default_font_size = 18
    default_font_size_body = 12

    # Styling of the bar that takeaway() draws along the bottom of its area.
    takeaway_properties = {
        "height": Inches(0.5),
        "font_size": 24,
        "alignment": PP_ALIGN.CENTER,
        "font_color": RGBColor(255, 255, 255),
    }

    # Styling of the centred text box drawn by summary_text().
    summary_text_properties = {
        "font_size": 38,
        "font_color": RGBColor(0, 0, 0),
        "alignment": PP_ALIGN.CENTER,
    }

    # Geometry and styling used by flowchart() and the timeline variants.
    # chevron_spacing is negative so consecutive chevrons overlap slightly.
    flowchart_properties = {
        "min_chevron_length": Inches(0.75),
        "chevron_height": Inches(0.5),
        "chevron_spacing": -Inches(0.2),
        "text_spacing_from_chevron": Inches(0.1),
        "font_size_header": 18,
        "font_size_text": 14,
        "font_color_header": RGBColor(255, 255, 255),
        "font_color_text": RGBColor(0, 0, 0),
    }

    # Matrix layout limits. NOTE(review): not referenced in this part of the
    # file; presumably consumed by define_matrix() -- confirm.
    landscape_properties = {
        "max_rows": 3,
        "max_columns": 4,
        "min_spacing": Inches(0.25),
    }

    # Image sizing constants. NOTE(review): not referenced in this part of
    # the file -- confirm where (or whether) they are used.
    images = {
        "height": Inches(0.5),
        "spacing": Inches(0.1),
        "max_height": Inches(1.5),
    }

    # Primary research text; overwritten per instance in __init__.
    research = ""
    # Secondary research inputs; only referenced by commented-out code here.
    secondary_research = None
    secondary_research_text = ""

    # Font size of the strip added by subtitle_content().
    subtitle_properties = {
        "font_size": 24,
    }

    # Styling of the header text drawn by graph().
    graph_properties = {"font_size": 18, "font_color": RGBColor(255, 255, 255)}

    # Styling of the header / body text boxes in circular_process_flow().
    process_flow_properties = {
        "font_size_header": 18,
        "font_size_text": 14,
        "font_color_header": RGBColor(255, 255, 255),
        "font_color_text": RGBColor(0, 0, 0),
    }

    def __init__(
        self,
        project,
        slide,
        conduct_research_flag=True,
        client=None,
        prs_slide_obj=None,
        research: str = None,
    ):
        """Initialize the slide creator.

        project: project the slide belongs to (stored as ``self.project``).
        slide: slide description, passed to ``SlideResearch`` as ``slide``.
        conduct_research_flag: currently unused; kept for API compatibility.
        client: client configuration. When omitted, the ``"xcm"`` client
            config from ``ClientConfig`` is loaded for this instance.
        prs_slide_obj: optional python-pptx slide object
            (stored as ``self.prs_slide_object``).
        research: primary research text handed to ``SlideResearch``.
        """
        # The default used to be computed in the signature, i.e. once at
        # import time, with the same object shared by every instance.
        if client is None:
            client = ClientConfig("xcm").get_client_config()

        self.project = project
        self.slide = slide
        self.prs_slide_object = prs_slide_obj

        self.client_config = client

        self.research = research

        # Maps a layout item name to the method that renders it. Several
        # names are aliases that reuse another renderer.
        self.item_to_function_mapping = {
            "summary_text": self.summary_text,
            "bullet_text": self.bullet_text,
            "icon_text": self.icon_text,
            "takeaway": self.takeaway,
            "icon_circle": self.icon_circle,
            "solution_diagram": self.iconography,
            "time_line_graph": self.graph,  # TODO change to a line graph with a time x-axis
            "20_token_text": self.summary_text,
            "graph": self.graph,
            "graph_with_text": self.graph_with_text,
            "iconography": self.iconography,
            "full_page_iconography": self.full_page_iconography,
            "infographic": self.full_page_iconography,
            "flowchart": self.flowchart,
            "circular_process_flow": self.circular_process_flow,
            "timeline": self.timeline,
            "centered_timeline": self.centered_timeline,
            "top_timeline": self.top_timeline,
            "thirds_matrix": self.thirds_matrix,
            "landscape": self.landscape,
            "grid": self.grid,
            "subtitle_content": self.subtitle_content,
            "gantt_chart": self.gantt_chart,
            "venn_diagram": self.venn_diagram,
            "overlapping_elements": self.overlapping_elements,
        }

    ## Text based only
    def summary_text(self, ppt_slide_object, top, left, width, height, _text=None):
        """
        Add a centred summary text box to the slide.

        The box occupies the middle cell of a 3x3 split of the given area and
        is styled with ``summary_text_properties``.

        ppt_slide_object: pptx slide object
        top: int in EMU for PPT
        left: int in EMU for PPT
        width: int in EMU for PPT
        height: int in EMU for PPT
        _text: text to display; when None it is generated via SlideResearch
            with the ``token20_text`` parsing class.

        Returns the slide object.
        """
        if _text is None:
            slide_research = SlideResearch(
                slide=self.slide,
                primary_data=self.research,
                parsing_class=token20_text,
                prompt_addition="Keep the summary text to 20 tokens or less.",
            )
            summary_text = slide_research._format()["content"]
        else:
            summary_text = _text

        # split the work area into 3rds; the box goes in the middle 3rd.
        # Offsets use the original width/height, so update them first.
        left = left + width / 3
        top = top + height / 3
        width = width / 3
        height = height / 3

        # add a rectangle to the slide (width was previously missing here,
        # which made add_shape raise TypeError)
        summary_text_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, height
        )

        # set rectangle to no fill, no line, and black font
        self.set_object_fill_and_line(summary_text_obj)

        props = self.summary_text_properties
        self.add_text(
            summary_text_obj.text_frame.paragraphs[0],
            summary_text,
            font_size=props["font_size"],
            font_color=props["font_color"],
            alignment=props["alignment"],
        )

        return ppt_slide_object

    @classmethod
    def unpack_bullet_text(cls, bullet_text):
        """Flatten (possibly nested) bullet text into a flat list of lines.

        A string is split on newlines; a list is flattened recursively; any
        other value contributes no lines.
        """
        if isinstance(bullet_text, str):
            return bullet_text.split("\n")
        if not isinstance(bullet_text, list):
            return []

        lines = []
        for item in bullet_text:
            lines += cls.unpack_bullet_text(item)
        return lines

    def bullet_text(self, ppt_slide_object, top, left, width, height):
        """
        Add a box of bullet lines to the slide.

        The lines are generated by SlideResearch (``bullet_text`` parsing
        class) and written one paragraph per line, left-aligned, with the
        text anchored to the top of the box.

        ppt_slide_object: pptx slide object
        top: int in EMU for PPT
        left: int in EMU for PPT
        width: int in EMU for PPT
        height: int in EMU for PPT

        Returns the slide object.
        """
        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=bullet_text,
            prompt_addition="Keep the summary text to 20 tokens or less.",
        )
        bullet_texts = slide_research._format()["content"]

        # Normalise the parser output. It may be a newline-separated string,
        # a list of strings, a list of {"text": ...} dicts, or a dict with a
        # "text" entry. The "text" lookup is only done on dicts: on a list of
        # strings, `"text" in bullet_texts` matched a bullet literally equal to
        # "text" and then crashed on `bullet_texts["text"]`.
        if isinstance(bullet_texts, dict):
            bullet_texts = bullet_texts.get("text", [])
        if isinstance(bullet_texts, list):
            # Works for empty lists and mixed str/dict items (no [0] lookup).
            bullet_texts = [
                item["text"] if isinstance(item, dict) else item
                for item in bullet_texts
            ]
        bullet_texts = self.unpack_bullet_text(bullet_texts)

        bullet_text_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, height
        )

        # set rectangle to no fill, no line, and black font
        self.set_object_fill_and_line(bullet_text_obj)

        bullet_text_frame = bullet_text_obj.text_frame
        bullet_text_frame.vertical_anchor = MSO_ANCHOR.TOP
        bullet_text_frame.word_wrap = True

        for i, bullet in enumerate(bullet_texts):
            # every text frame starts with one (empty) paragraph; reuse it
            if i == 0:
                bullet_p = bullet_text_frame.paragraphs[0]
            else:
                bullet_p = bullet_text_frame.add_paragraph()

            # Alignment is a paragraph property in python-pptx; TextFrame has
            # no `alignment`, so setting it on the frame had no effect.
            self.add_text(
                bullet_p,
                bullet,
                font_size=self.default_font_size,
                alignment=PP_ALIGN.LEFT,
            )

        return ppt_slide_object

    def takeaway(self, ppt_slide_object, top, left, width, height):
        """Add a black takeaway bar along the bottom edge of the given area.

        The bar text is generated by SlideResearch (``token20_text`` parsing
        class) and drawn with ``takeaway_properties`` (white, centred).
        Returns the slide object.
        """
        props = self.takeaway_properties
        bar_height = props["height"]

        researcher = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=token20_text,
            prompt_addition="Keep the summary text to 20 tokens or less.",
        )
        message = researcher._format()["content"]

        # Anchor the bar so its bottom edge sits on the bottom of the area.
        bar_top = top + height - bar_height
        bar = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, bar_top, width, bar_height
        )

        black = RGBColor(0, 0, 0)
        self.set_object_fill_and_line(bar, fill_color=black, line_color=black)
        self.add_text(
            bar.text_frame.paragraphs[0],
            message,
            font_color=props["font_color"],
            font_size=props["font_size"],
            alignment=props["alignment"],
        )

        return ppt_slide_object

    def subtitle_content(self, ppt_slide_object, top, left, width, height):
        """Add a left-aligned subtitle strip at the top of the given area.

        The subtitle text is generated by SlideResearch (``token20_text``
        parsing class). Unlike most renderers, this returns a tuple
        ``(ppt_slide_object, new_top, new_height)``, where the new top/height
        describe the space left below the 0.5in strip plus a 0.125in gap.
        """
        strip_height = Inches(0.5)
        gap = Inches(0.125)

        subtitle_text = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=token20_text,
            prompt_addition="Keep the summary text to 20 tokens or less.",
        )._format()["content"]

        strip = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, strip_height
        )
        self.set_object_fill_and_line(strip)
        self.add_text(
            strip.text_frame.paragraphs[0],
            subtitle_text,
            font_size=self.subtitle_properties["font_size"],
            alignment=PP_ALIGN.LEFT,
        )

        remaining_top = top + strip_height + gap
        remaining_height = height - strip_height - gap
        return ppt_slide_object, remaining_top, remaining_height

    ## Text and Icons

    def icon_circle(self, ppt_slide_object, top, left, width, height):
        """
        Add a summary text surrounded by a ring of circular icons.

        SlideResearch (``icon_circle`` parsing class) provides ``content``
        (the summary text) and ``icons`` (each with an ``icon_search_query``).
        The summary text is drawn via ``summary_text``. Each icon is drawn in
        a black circle, placed evenly around a ring centred on the area and
        starting at 12 o'clock going clockwise.

        ppt_slide_object: pptx slide object
        top: int in EMU for PPT
        left: int in EMU for PPT
        width: int in EMU for PPT
        height: int in EMU for PPT

        Returns the slide object.
        """
        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=icon_circle,
            prompt_addition="The icons amount should be between 4 to 10.",
        )
        output_data = slide_research._format()

        # add the summary text to the slide
        self.summary_text(
            ppt_slide_object, top, left, width, height, output_data["content"]
        )

        icons = output_data["icons"]
        if not icons:
            # nothing to place around the text (avoids division by zero)
            return ppt_slide_object

        icon_size = self.icon_circle_size
        circle_size = icon_size * 1.5
        # Ring radius: half the smaller side, minus 1.5 icon sizes of
        # clearance on each side so the circles stay inside the area.
        radius = (min(width, height) - icon_size * 3) / 2
        # Centre of the area (previously `left - width / 2`, which put the
        # ring to the left of the area).
        center_x = left + width / 2
        center_y = top + height / 2

        step = 2 * math.pi / len(icons)
        for i, icon in enumerate(icons):
            # point on the ring; y grows downwards, so this runs clockwise
            point_x = center_x + radius * math.sin(i * step)
            point_y = center_y - radius * math.cos(i * step)

            # circle centred on the ring point
            icon_obj = ppt_slide_object.shapes.add_shape(
                MSO_SHAPE.OVAL,
                point_x - circle_size / 2,
                point_y - circle_size / 2,
                circle_size,
                circle_size,
            )
            self.set_object_fill_and_line(
                icon_obj,
                fill_color=RGBColor(0, 0, 0),
                line_color=RGBColor(0, 0, 0),
                line_width=2,
            )

            # TODO add icon a better version of the icon search
            icon_loc = self.search_and_download_image(
                search_query=icon["icon_search_query"], square_ratio=True
            )[1]

            # Centre the icon in its circle; assumes the downloaded image is
            # square (square_ratio=True) -- TODO confirm.
            ppt_slide_object.shapes.add_picture(
                icon_loc,
                point_x - icon_size / 2,
                point_y - icon_size / 2,
                width=icon_size,
            )

        return ppt_slide_object

    def iconography(self, ppt_slide_object, top, left, width, height):
        """Add iconography (a searched image) to the slide.

        SlideResearch (``iconography`` parsing class) provides a
        ``search_query`` and ``design_instructions``. The instructions are
        appended to the slide notes. The query is run as a Google image
        search:

        - The first image that downloads successfully is placed at
          (left, top) with width ``min(width, height)``.
        - Every further image is stacked 1in wide in the margin at
          ``slide_max_width`` as alternatives.

        URLs that fail to download are printed and skipped.

        Returns the slide object.
        """
        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=iconography,
            prompt_addition="This is going to be a single image",
        )
        output_data = slide_research._format()

        search_query = output_data["search_query"]
        design_instruction = output_data["design_instructions"]
        # TODO use the design insructions for llm generation

        # add design_instruction to the notes section for the slide
        ppt_slide_object.notes_slide.notes_text_frame.text += (
            "\n" + f"Design instructions for image: {design_instruction}"
        )

        # search for the image
        google_search = GoogleSearch()
        images_url = asyncio.run(google_search.search_google(search_query, "images"))

        download_image = DownloadImage()
        for image_url in images_url or []:
            try:
                _, img_loc = asyncio.run(
                    download_image.get_image(
                        image_url["imageUrl"], "images/iconography.png"
                    )
                )
            except requests.exceptions.RequestException as e:
                # Any request failure (bad schema, connection error, HTTP
                # error, ...) skips this URL instead of aborting the slide.
                print(e)
                continue

            # add the image to the slide
            img = ppt_slide_object.shapes.add_picture(
                img_loc, left, top, width=min(width, height)
            )

            # for the rest of the images, put them in the margin
            top += img.height
            left = self.slide_max_width
            width = Inches(1)

        return ppt_slide_object

    def full_page_iconography(self, ppt_slide_object, top, left, width, height):
        """Add iconography inset by a 0.5in margin from the top-left corner.

        Width and height shrink by a single margin, so the right and bottom
        edges of the area are unchanged. NOTE(review): confirm whether a
        symmetric inset (2x margin) was intended.
        """
        margin = Inches(0.5)
        inset_top, inset_left = top + margin, left + margin
        inset_width, inset_height = width - margin, height - margin

        self.iconography(
            ppt_slide_object, inset_top, inset_left, inset_width, inset_height
        )
        return ppt_slide_object

    ## Graph based
    def graph(self, ppt_slide_object, top, left, width, height, parser_class=None):
        """Add a graph placeholder to the slide.

        The graph spec is produced by SlideResearch with ``parser_class``
        (defaults to the ``graph`` parsing class). The slide gets:

        - A black header bar showing the spec's ``takeaway_text``, falling
          back to ``title``.
        - Below it, after a gap of one header height, a red rectangle listing
          the graph type and the JSON-encoded data.
        - One sample "bar chart" image from a Google image search, placed in
          the margin at ``slide_max_width`` as a visual reference.

        Returns the slide object.
        """
        if parser_class is None:
            parser_class = graph

        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=parser_class,
        )
        graph_data = slide_research._format()

        # add a rectangle header for the graph
        graph_header_height = Inches(0.35)
        graph_header = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, graph_header_height
        )
        self.set_object_fill_and_line(
            graph_header, fill_color=RGBColor(0, 0, 0), line_color=RGBColor(0, 0, 0)
        )

        # Only look up "title" when there is no takeaway: dict.get evaluates
        # its default eagerly, so get("takeaway_text", graph_data["title"])
        # raised KeyError whenever "title" was missing.
        if "takeaway_text" in graph_data:
            graph_header_text = graph_data["takeaway_text"]
        else:
            graph_header_text = graph_data.get("title", "")

        self.add_text(
            graph_header.text_frame.paragraphs[0],
            graph_header_text,
            font_color=self.graph_properties["font_color"],
            font_size=self.graph_properties["font_size"],
            alignment=PP_ALIGN.LEFT,
        )

        # add a rectangle for the graph, one header height below the header
        graph_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE,
            left,
            top + graph_header_height * 2,
            width,
            height - graph_header_height * 2,
        )
        self.set_object_fill_and_line(graph_obj, fill_color=RGBColor(200, 0, 0))
        self.add_text(
            graph_obj.text_frame.paragraphs[0],
            f"{graph_data['graph_type']}\nData: {json.dumps(graph_data['data'])}",  # TODO change to adjustments
            font_color=RGBColor(255, 255, 255),
            font_size=18,
            alignment=PP_ALIGN.LEFT,
        )

        # FIXME for the time being, we will add a placeholder image
        google_search = GoogleSearch()
        images_url = asyncio.run(google_search.search_google("bar chart", "images"))
        download_image = DownloadImage()
        for image_url in images_url or []:
            try:
                img_loc = asyncio.run(
                    download_image.get_image(image_url["imageUrl"], "images/graph.png")
                )[1]
            except requests.exceptions.RequestException as e:
                # try the next search result instead of failing the slide
                print(e)
                continue

            # place a single reference image in the top of the right margin
            ppt_slide_object.shapes.add_picture(
                img_loc, self.slide_max_width, 0, height=Inches(2)
            )
            break

        return ppt_slide_object

    def graph_with_text(self, ppt_slide_object, top, left, width, height):
        """Add a graph that leaves a 0.5in strip free at the bottom for text.

        Delegates to ``graph`` with the ``graph_with_text`` parsing class.
        NOTE(review): the reserved strip is not filled with any text here.
        Returns the slide object.
        """
        reserved_text_height = Inches(0.5)
        graph_height = height - reserved_text_height

        self.graph(
            ppt_slide_object,
            top,
            left,
            width,
            graph_height,
            parser_class=graph_with_text,
        )
        return ppt_slide_object

    ## Circular Based
    def circular_process_flow(self, ppt_slide_object, top, left, width, height):
        """
        Add a circular process flow to the slide.

        Steps come from SlideResearch with the ``flowchart`` parsing class;
        each provides an icon search query, a title and a text.

        An outlined circle is drawn centred in the area. Each step's icon sits
        on that circle, clockwise from 12 o'clock. Beside each icon go a black
        header box and an outlined text box: to the right for steps in the
        first half of the turn, to the left otherwise.

        Raises ValueError when research returns no steps.
        Returns the slide object.
        """
        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=flowchart,
            prompt_addition="The icons amount should be between 5 to 10.",
        )
        output_data = slide_research._format()
        steps = [
            (i["icon_content"]["icon_search_query"], i["title"], i["text"])
            for i in output_data["data"]
        ]
        if not steps:
            raise ValueError("The circular process flow has no steps to draw")

        rads = 2 * math.pi / len(steps)

        # loop-invariant sizes
        circle_size = min(width, height) - Inches(0.5)
        icon_size = circle_size / 5
        text_width = Inches(2)
        header_height = Inches(0.5)
        body_height = Inches(1)

        # add the circle, centred in the area
        circle_top = top + height / 2 - circle_size / 2
        circle_left = left + width / 2 - circle_size / 2
        circle_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.OVAL, circle_left, circle_top, circle_size, circle_size
        )
        self.set_object_fill_and_line(circle_obj, line_color=RGBColor(0, 0, 0))

        center_x = circle_left + circle_size / 2
        center_y = circle_top + circle_size / 2

        for i, (icon_query, header, text) in enumerate(steps):
            angle = i * rads
            # top-left of the icon so that it is centred on the circle
            x = center_x + circle_size / 2 * math.sin(angle) - icon_size / 2
            y = center_y - circle_size / 2 * math.cos(angle) - icon_size / 2

            _, icon_loc = self.search_and_download_image(
                search_query=icon_query + " icon", square_ratio=True
            )
            icon_image = ppt_slide_object.shapes.add_picture(
                icon_loc, x, y, width=icon_size
            )
            # mask the picture with an oval
            icon_image.auto_shape_type = MSO_AUTO_SHAPE_TYPE.OVAL

            # text goes right of the icon on the first half of the turn,
            # left of it otherwise
            if angle < math.pi:
                text_left = x + icon_size * 1.25
            else:
                text_left = x - icon_size * 0.25 - text_width

            header_obj = ppt_slide_object.shapes.add_shape(
                MSO_SHAPE.RECTANGLE, text_left, y, text_width, header_height
            )
            self.set_object_fill_and_line(header_obj, fill_color=RGBColor(0, 0, 0))
            self.add_text(
                header_obj.text_frame.paragraphs[0],
                header,
                font_color=self.process_flow_properties["font_color_header"],
                font_size=self.process_flow_properties["font_size_header"],
                alignment=PP_ALIGN.CENTER,
            )

            # body text directly below the header
            body_obj = ppt_slide_object.shapes.add_shape(
                MSO_SHAPE.RECTANGLE,
                text_left,
                y + header_height,
                text_width,
                body_height,
            )
            self.set_object_fill_and_line(
                body_obj, line_color=RGBColor(0, 0, 0), line_width=2
            )
            self.add_text(
                body_obj.text_frame.paragraphs[0],
                text,
                font_color=self.process_flow_properties["font_color_text"],
                font_size=self.process_flow_properties["font_size_text"],
                alignment=PP_ALIGN.LEFT,
            )

        return ppt_slide_object

    def flowchart(self, ppt_slide_object, top, left, width, height, parsing_class=None):
        """Add a chevron flowchart to the slide.

        Steps come from SlideResearch with ``parsing_class`` (defaults to the
        ``flowchart`` parsing class); each provides a title (drawn inside the
        chevron) and a text (drawn in a box next to it).

        Orientation:
        - Horizontal: chevrons run left to right, with text below each one.
        - Vertical: used when that gives longer chevrons. Chevrons are rotated
          90 degrees and stacked top to bottom, with text to their right.

        Raises ValueError when there are no steps, or when the chevron length
        would fall below ``min_chevron_length``.
        Returns the slide object.
        """
        if parsing_class is None:
            parsing_class = flowchart

        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=parsing_class,
            prompt_addition="The icons amount should be between 5 to 10.",
        )
        output_data = slide_research._format()

        # TODO add the title to the slide

        flowchart_data = [
            [i["icon_content"]["icon_search_query"], i["title"], i["text"]]
            for i in output_data["data"]
        ]
        if not flowchart_data:
            raise ValueError("The flowchart has no steps to draw")
        n_steps = len(flowchart_data)

        props = self.flowchart_properties
        min_chevron_length = props["min_chevron_length"]
        chevron_height = props["chevron_height"]
        chevron_spacing = props["chevron_spacing"]
        text_spacing_from_chevron = props["text_spacing_from_chevron"]

        # Chevron length if stacked vertically / laid out horizontally.
        # chevron_spacing is negative by default, so neighbours overlap.
        vertical_chevron_length = (
            height - (n_steps - 1) * chevron_spacing
        ) / n_steps
        horizontal_chevron_length = (
            width - (n_steps - 1) * chevron_spacing
        ) / n_steps

        vertical = horizontal_chevron_length < vertical_chevron_length
        left_i = left  # left edge of the (rotated) chevron column
        if vertical:
            chevron_width = vertical_chevron_length
            # Shapes rotate about their centre: shift the unrotated box so the
            # rotated chevron's visual bounding box starts at (left_i, top).
            top = top - (chevron_height - chevron_width) / 2
            left = left - (chevron_width - chevron_height) / 2
        else:
            chevron_width = horizontal_chevron_length

        if chevron_width < min_chevron_length:
            raise ValueError("The flowchart is too large for the slide")

        # add the chevrons to the slide
        for _icon, header, text in flowchart_data:
            chevron = ppt_slide_object.shapes.add_shape(
                MSO_SHAPE.CHEVRON, left, top, chevron_width, chevron_height
            )
            chevron.rotation = 90 if vertical else 0  # in degrees

            self.set_object_fill_and_line(
                chevron, fill_color=RGBColor(0, 0, 0), line_color=RGBColor(0, 0, 0)
            )
            self.add_text(
                chevron.text_frame.paragraphs[0],
                header,
                font_color=props["font_color_header"],
                font_size=props["font_size_header"],
                alignment=PP_ALIGN.CENTER,
            )

            if vertical:
                # text to the right of the rotated chevron, which is
                # chevron_height wide on the slide (not chevron_width, which
                # is its vertical length and could make this width negative)
                text_left = left_i + chevron_height + text_spacing_from_chevron
                text_top = top
                text_width = width - (chevron_height + text_spacing_from_chevron)
                text_height = chevron_height / 1.17
            else:
                # text below the chevron, down to the bottom of the area
                text_left = left
                text_top = top + chevron_height + text_spacing_from_chevron
                text_width = chevron_width / 1.17
                text_height = height - (chevron_height + text_spacing_from_chevron)

            text_obj = ppt_slide_object.shapes.add_shape(
                MSO_SHAPE.RECTANGLE, text_left, text_top, text_width, text_height
            )
            self.set_object_fill_and_line(text_obj)

            text_obj.text_frame.vertical_anchor = MSO_ANCHOR.TOP
            self.add_text(
                text_obj.text_frame.paragraphs[0],
                text,
                font_color=props["font_color_text"],
                font_size=props["font_size_text"],
                alignment=PP_ALIGN.LEFT,
            )

            # advance along the flow direction
            if vertical:
                top += chevron_width + chevron_spacing
            else:
                left += chevron_width + chevron_spacing

        return ppt_slide_object

    def timeline(self, ppt_slide_object, top, left, width, height):
        """
        Add a timeline to the slide.

        Rendered as a chevron flowchart whose steps come from the
        ``timeline`` parsing class. Returns the slide object.
        """
        area = (top, left, width, height)
        self.flowchart(ppt_slide_object, *area, parsing_class=timeline)
        return ppt_slide_object

    def centered_timeline(self, ppt_slide_object, top, left, width, height):
        """
        Add a timeline whose chevron row is vertically centred in the area.

        Steps come from the ``timeline`` parsing class and are drawn by
        ``flowchart``. NOTE(review): ``height`` is passed through unchanged
        while ``top`` moves down. In horizontal layout the text boxes under
        the chevrons can therefore extend below the area; confirm whether
        ``height`` should shrink too. Returns the slide object.
        """
        half_chevron = self.flowchart_properties["chevron_height"] / 2
        centered_top = top + height / 2 - half_chevron

        self.flowchart(
            ppt_slide_object,
            centered_top,
            left,
            width,
            height,
            parsing_class=timeline,
        )
        return ppt_slide_object

    def top_timeline(self, ppt_slide_object, top, left, width, height):
        """Add a timeline starting at the top of the area.

        Currently identical to ``timeline``: a chevron flowchart fed by the
        ``timeline`` parsing class. Returns the slide object.
        """
        area = (top, left, width, height)
        self.flowchart(ppt_slide_object, *area, parsing_class=timeline)
        return ppt_slide_object

    ## Matrix based
    def landscape(self, ppt_slide_object, top, left, width, height):
        """
        Add a landscape matrix of icon / title / text cells to the slide.

        Cells come from SlideResearch with the ``landscape`` parsing class.
        Each cell stacks three rows:

        - an icon row, ``icon_circle_size`` high;
        - a title, one third of the remaining cell height;
        - body text, the other two thirds, at ``default_font_size_body``.

        Grid sizing and drawing are delegated to ``define_matrix`` and
        ``create_matrix``. Returns the slide object.
        """
        cells = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=landscape,
            prompt_addition="The icons amount should be between 3 to 12.",
        )._format()["content"]

        # NOTE(review): define_matrix presumably validates the cell count
        # against the max rows/columns; index 0 (rows?) is unused here.
        matrix_def = self.define_matrix(width, height, len(cells))
        cols = matrix_def[1]
        matrix_height = matrix_def[2]
        matrix_width = matrix_def[3]

        icon_row = self.icon_circle_size
        remaining = matrix_height - icon_row
        # TODO change to a better icon search
        cell_structures = [
            [
                [icon_row, cell["icon_content"]["icon_search_query"], "icon"],
                [remaining / 3, cell["title"], "title"],
                [
                    remaining * 2 / 3,
                    cell["text"],
                    "text",
                    self.default_font_size_body,
                ],
            ]
            for cell in cells
        ]

        self.create_matrix(
            ppt_slide_object,
            top,
            left,
            cell_structures,
            cols,
            matrix_height,
            matrix_width,
        )
        return ppt_slide_object

    def thirds_matrix(self, ppt_slide_object, top, left, width, height):
        """
        Add a matrix whose cells are split into equal icon / title / text thirds.

        Cells come from SlideResearch with the ``thirds_matrix`` parsing
        class. Grid sizing and drawing are delegated to ``define_matrix`` and
        ``create_matrix``. Returns the slide object.
        """
        cells = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=thirds_matrix,
            prompt_addition="The icons amount should be between 3 to 9.",
        )._format()["content"]

        matrix_def = self.define_matrix(width, height, len(cells))
        cols = matrix_def[1]
        matrix_height = matrix_def[2]
        matrix_width = matrix_def[3]

        third = matrix_height / 3
        cell_structures = [
            [
                [third, cell["icon_content"]["icon_search_query"], "icon"],
                [third, cell["title"], "title"],
                [third, cell["text"], "text"],
            ]
            for cell in cells
        ]

        self.create_matrix(
            ppt_slide_object,
            top,
            left,
            cell_structures,
            cols,
            matrix_height,
            matrix_width,
        )
        return ppt_slide_object

    def grid(self, ppt_slide_object, top, left, width, height, grid_details=None):
        """
        Draw a grid of cells, each holding a title, a text and images.

        Input:
        ppt_slide_object: pptx slide object to draw on
        top, left, width, height: work area on the slide
        grid_details: optional sequence of cells, each subscriptable by
            "title", "text" and "images". When None, the cells are produced
            through SlideResearch instead.
        Output: the same slide object

        Per cell, the title and the text each get a quarter of the cell height
        and the images get the remaining half.
        """
        if grid_details is not None:
            cells = grid_details
        else:
            # `grid` below is the module-level parsing class (presumably from
            # the slide_layout_models star import), not this method
            research_request = SlideResearch(
                slide=self.slide,
                primary_data=self.research,
                parsing_class=grid,
                prompt_addition="The grid amount should be between 4 to 12.",
            )
            cells = research_request._format()["content"]

        _, column_count, cell_height, cell_width = self.define_matrix(
            width, height, len(cells)
        )
        quarter = cell_height / 4
        half = cell_height / 2

        # TODO change a better approach to images making them in a landscape / grid manner
        layouts = [
            [
                [quarter, cell["title"], "title"],
                [quarter, cell["text"], "text"],
                [half, cell["images"], "image"],
            ]
            for cell in cells
        ]

        self.create_matrix(
            ppt_slide_object,
            top,
            left,
            layouts,
            column_count,
            cell_height,
            cell_width,
        )
        return ppt_slide_object

    def icon_text(self, ppt_slide_object, top, left, width, height):
        """
        Research a list of bullets and draw each one as an icon followed by text.

        The work area is split vertically into equal rows, one per bullet,
        separated by a 0.05" gap. Each row holds an icon on the left (width at
        most icon_circle_size and at most 3/4 of the row height) and the
        bullet text in a rectangle with no fill and no line to its right.

        Input:
        ppt_slide_object: pptx slide object to draw on
        top, left, width, height: work area on the slide
        Output: the same slide object (left untouched when the research
        returns no bullets)
        """
        # `icon_text` below is the module-level parsing class (presumably from
        # the slide_layout_models star import), not this method
        slide_research = SlideResearch(
            slide=self.slide,
            primary_data=self.research,
            parsing_class=icon_text,
            prompt_addition="Keep the response between 6 to 10 bullets",
        )
        bullets = slide_research._format()["content"]

        # nothing to draw, and the row height below would divide by zero
        if not bullets:
            return ppt_slide_object

        bullet_count = len(bullets)
        height_spacing = Inches(0.05)
        width_spacing = Inches(0.1)
        height_for_each_bullet = int(
            (height - height_spacing * (bullet_count - 1)) / bullet_count
        )
        icon_height = min(self.icon_circle_size, int(height_for_each_bullet / 4 * 3))

        for bullet in bullets:
            # only the location of the downloaded icon is needed by add_picture
            icon_loc = self.search_and_download_image(
                search_query=bullet["icon_content"]["icon_search_query"],
                square_ratio=True,
            )[1]
            ppt_slide_object.shapes.add_picture(
                icon_loc, left, top, width=icon_height
            )

            # the text sits to the right of the icon, filling the rest of the row
            text_obj = ppt_slide_object.shapes.add_shape(
                MSO_SHAPE.RECTANGLE,
                left + icon_height + width_spacing,
                top,
                width - icon_height - width_spacing,
                height_for_each_bullet,
            )
            self.set_object_fill_and_line(text_obj)
            self.add_text(
                text_obj.text_frame.paragraphs[0],
                bullet["bullet_text"],
                font_size=self.default_font_size,
                alignment=PP_ALIGN.LEFT,
            )

            top += height_for_each_bullet + height_spacing

        return ppt_slide_object

    ## TODO not implemented
    def gantt_chart(self, ppt_slide_object, top, left, width, height, gantt_chart):
        """
        Placeholder for a Gantt chart.

        Draws a rectangle with no fill and no line over the work area and
        writes the chart data into it as JSON text.

        Input:
        ppt_slide_object: pptx slide object to draw on
        top, left, width, height: work area on the slide
        gantt_chart: str, dict or list; normalised with format_input_to_list
            before being serialised
        Output: the same slide object
        """
        gantt_chart = self.format_input_to_list(gantt_chart)

        gantt_chart_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, height
        )
        self.set_object_fill_and_line(gantt_chart_obj)
        # ensure_ascii=False keeps non-ASCII characters readable on the slide
        # instead of rendering them as \uXXXX escapes
        self.add_text(
            gantt_chart_obj.text_frame.paragraphs[0],
            json.dumps(gantt_chart, ensure_ascii=False),
            font_size=self.default_font_size,
            alignment=PP_ALIGN.CENTER,
        )

        return ppt_slide_object

    def venn_diagram(self, ppt_slide_object, top, left, width, height, venn_diagram):
        """
        Placeholder for a Venn diagram.

        Draws a rectangle with no fill and no line over the work area and
        writes the diagram data into it as JSON text.

        Input:
        ppt_slide_object: pptx slide object to draw on
        top, left, width, height: work area on the slide
        venn_diagram: str, dict or list; normalised with format_input_to_list
            before being serialised
        Output: the same slide object
        """
        venn_diagram = self.format_input_to_list(venn_diagram)

        venn_diagram_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, height
        )
        self.set_object_fill_and_line(venn_diagram_obj)
        # ensure_ascii=False keeps non-ASCII characters readable on the slide
        # instead of rendering them as \uXXXX escapes
        self.add_text(
            venn_diagram_obj.text_frame.paragraphs[0],
            json.dumps(venn_diagram, ensure_ascii=False),
            font_size=self.default_font_size,
            alignment=PP_ALIGN.CENTER,
        )

        return ppt_slide_object

    def overlapping_elements(
        self, ppt_slide_object, top, left, width, height, overlapping_elements
    ):
        """
        Placeholder for overlapping elements.

        Draws a rectangle with no fill and no line over the work area and
        writes the element data into it as JSON text.

        Input:
        ppt_slide_object: pptx slide object to draw on
        top, left, width, height: work area on the slide
        overlapping_elements: str, dict or list; normalised with
            format_input_to_list before being serialised
        Output: the same slide object
        """
        overlapping_elements = self.format_input_to_list(overlapping_elements)

        overlapping_elements_obj = ppt_slide_object.shapes.add_shape(
            MSO_SHAPE.RECTANGLE, left, top, width, height
        )
        self.set_object_fill_and_line(overlapping_elements_obj)
        # ensure_ascii=False keeps non-ASCII characters readable on the slide
        # instead of rendering them as \uXXXX escapes
        self.add_text(
            overlapping_elements_obj.text_frame.paragraphs[0],
            json.dumps(overlapping_elements, ensure_ascii=False),
            font_size=self.default_font_size,
            alignment=PP_ALIGN.CENTER,
        )

        return ppt_slide_object

    @staticmethod
    def set_object_fill_and_line(
        _object, fill_color=None, line_color=None, line_width=None
    ):
        "Set the fill and line color of the object. If none, set to none"
        if fill_color:
            _object.fill.solid()
            _object.fill.fore_color.rgb = fill_color
        else:
            _object.fill.background()

        if line_color:
            _object.line.color.rgb = line_color
            _object.line.width = line_width
        else:
            _object.line.fill.background()

        # return _object

    @staticmethod
    def add_text(
        _paragraph,
        text,
        font_color=RGBColor(0, 0, 0),
        font_size=12,
        level=0,
        space_after=None,
        alignment=PP_ALIGN.LEFT,
    ):
        """
        Write text into a paragraph and apply its formatting.

        Input:
        _paragraph: pptx paragraph to fill
        text: the text to write (replaces the paragraph's current text)
        font_color: RGB colour of the font
        font_size: font size in points
        level: indentation level of the paragraph
        space_after: spacing after the paragraph in points; defaults to
            font_size when None
        alignment: paragraph alignment
        """
        gap_after = font_size if space_after is None else space_after

        _paragraph.text = text
        _paragraph.level = level

        font = _paragraph.font
        font.size = Pt(font_size)
        font.color.rgb = font_color

        _paragraph.space_after = Pt(gap_after)
        _paragraph.alignment = alignment

    @classmethod
    def add_single_cell(cls, ppt_slide_object, top, left, width, height, cell_contents):
        """
        Stack the given contents vertically inside one matrix cell.

        Input:
        ppt_slide_object: pptx slide object to draw on
        top: top of the cell
        left: left of the cell
        width: width of the cell
        height: height of the whole cell (each content uses its own slot
            height; kept for interface compatibility)
        cell_contents: list of [height, content, content_type] or
            [height, content, content_type, font_size], where content_type is
            "text", "title", "icon" or "image". The optional font size is
            used by "text" (default default_font_size_body) and "title"
            (default default_font_size) entries.
        Output: the same slide object
        """
        default_font_sizes = {
            "text": cls.default_font_size_body,
            "title": cls.default_font_size,
        }

        for cell_content in cell_contents:
            cell_height = cell_content[0]
            content = cell_content[1]
            content_type = cell_content[2]

            if content_type in default_font_sizes:
                # text-like slot: rectangle with no fill and no line, centred text
                font_size = default_font_sizes[content_type]
                if len(cell_content) > 3 and cell_content[3] is not None:
                    font_size = cell_content[3]

                cell_obj = ppt_slide_object.shapes.add_shape(
                    MSO_SHAPE.RECTANGLE, left, top, width, cell_height
                )
                cls.set_object_fill_and_line(cell_obj)
                cls.add_text(
                    cell_obj.text_frame.paragraphs[0],
                    content,
                    font_size=font_size,
                    alignment=PP_ALIGN.CENTER,
                )

            elif content_type == "icon":
                # icon is searched for, then centred horizontally in the slot
                icon_loc = cls.search_and_download_image(
                    search_query=content + " icon", square_ratio=True
                )[1]
                ppt_slide_object.shapes.add_picture(
                    icon_loc, left + width / 2 - cell_height / 2, top, width=cell_height
                )

            elif content_type == "image":
                # accept a single image as well as a list, without mutating
                # the caller's cell definition
                images = content if isinstance(content, list) else [content]
                # size the images to this slot, not to the whole cell, so they
                # stay inside the cell below the preceding slots
                cls.multi_image(
                    ppt_slide_object, top, left, width, cell_height, images
                )

            top += cell_height

        return ppt_slide_object

    @classmethod
    def multi_image(
        cls, ppt_slide_object, top, left, width, height, multi_image: list[str]
    ):
        "add multiple images to the slide"
        top_i = top

        if len(multi_image) == 0:
            return ppt_slide_object

        # if number of images is 1, then set the image to the full width
        if (height / (cls.images["height"] + cls.images["spacing"])) >= len(
            multi_image
        ):
            image_height = min(height / len(multi_image), cls.images["max_height"])
            image_width = width
            add_column = False
            # center the image in the box
            left = left + width / 2 - image_width / 2
        else:
            image_height = min(cls.images["height"], cls.images["max_height"])
            image_width = width / 2 - cls.images["spacing"]
            add_column = False

        # add a rectangle to the slide
        # add a text to the slide for the gantt chart
        for image_str in multi_image:
            # if the image is a url, then download the image
            if image_str.startswith("http"):
                image = cls.download_image(image_str)[1]
            else:
                image = cls.search_and_download_image(image_str)[1]

            image_obj = ppt_slide_object.shapes.add_picture(
                image,
                left,
                top,
                width=image_width if image_width < image_height else image_height,
            )

            top += image_obj.height + (cls.images["spacing"])

            if add_column and top > top_i:
                left += image_width + cls.images["spacing"]
                top = top_i

        return ppt_slide_object

    @classmethod
    def create_matrix(
        cls,
        ppt_slide_object,
        top,
        left,
        cell_structures,
        cols,
        matrix_height,
        matrix_width,
    ):
        """
        Create the matrix
        Input:
        ppt_slide_object: pptx slide object
        top: int in em for PPT
        left: int in em for PPT
        cell_structures: list of [height, item, type of item]
        cols: int
        matrix_height: int in em for PPT
        matrix_width: int in em for PPT
        """

        for i, cell_structure in enumerate(cell_structures):
            # determine the top and left
            row = i // cols
            col = i % cols
            top_i = top + row * (
                matrix_height + cls.landscape_properties["min_spacing"]
            )
            left_i = left + col * (
                matrix_width + cls.landscape_properties["min_spacing"]
            )

            if cell_structure is None:
                raise ValueError("cell_structures is required")

            # add the entry to as a cell
            cls.add_single_cell(
                ppt_slide_object,
                top_i,
                left_i,
                matrix_width,
                matrix_height,
                cell_structure,
            )

    @classmethod
    def define_matrix(cls, width, height, landscape_len):
        """
        Define the matrix for the landscape
        Input: work area width, height, and the length of the landscape
        Output: number of rows, columns, matrix height, matrix width
        """
        if (
            landscape_len
            > cls.landscape_properties["max_rows"]
            * cls.landscape_properties["max_columns"]
        ):
            raise ValueError(
                f"Landscape of length: {landscape_len} is too large for the slide"
            )

        # check to see if landscape_len is a square number
        square_root = math.sqrt(landscape_len)
        if square_root.is_integer():
            rows = int(square_root)
            cols = int(square_root)
        elif landscape_len < 4:
            rows = 1
            cols = landscape_len
        else:
            rows = math.floor(square_root)
            cols = math.ceil(landscape_len / rows)

        # determine the size of one matrix
        matrix_height = (
            height - (rows - 1) * cls.landscape_properties["min_spacing"]
        ) / rows
        matrix_width = (
            width - (cols - 1) * cls.landscape_properties["min_spacing"]
        ) / cols

        return rows, cols, matrix_height, matrix_width

    @staticmethod
    def format_input_to_list(input_value):
        "format the input value to a list"

        # if the input value is a string, return the input value
        if isinstance(input_value, str):
            return input_value

        # if the input value is a dict, return the input value
        elif isinstance(input_value, dict):
            return list(input_value.values())

        # if the input value is a list, return the input value
        elif isinstance(input_value, list):
            input_value_unpacked = []
            for item in input_value:
                if isinstance(item, dict):
                    input_value_unpacked.append(list(item.values()))
                elif isinstance(item, list):
                    input_value_unpacked.append(item)

                elif isinstance(item, str):
                    input_value_unpacked.append(item)

                else:
                    raise ValueError(f"Item: {item} is not a list or a dict")

            return input_value_unpacked

        # if neither a list or a dict, then raise an error
        raise ValueError("Input value is not a list or a dict")

    @classmethod
    def search_and_download_image(cls, search_query: str, square_ratio: bool = False):
        """
        Search Google images for a query and download a usable result.

        Input:
        search_query: str
        square_ratio: bool to only accept roughly square images
            (width / height strictly between 0.9 and 1.1)

        Output: tuple (img, img_loc).
        - Without square_ratio: the download_image result of the first search
          result that downloads successfully.
        - With square_ratio: for the first square result, img is the PIL image
          opened from img_loc. If no result is square, the first downloadable
          result is downloaded again and its download_image result returned.

        Raises:
        IndexError: no search result could be downloaded
        """
        google_search = GoogleSearch()
        results = asyncio.run(google_search.search_google(search_query, "images"))

        first_url = None
        for result in results or []:
            url = result["imageUrl"]
            downloaded = cls.download_image(url)
            if downloaded is None:
                # download_image returns None when the URL cannot be fetched
                continue
            img, img_loc = downloaded

            if not square_ratio:
                return img, img_loc

            if first_url is None:
                first_url = url

            pil_img = Image.open(img_loc)
            img_width, img_height = pil_img.size
            if img_height and 0.9 < img_width / img_height < 1.1:
                # the opened PIL image is handed to the caller
                return pil_img, img_loc
            # not square: release the file handle before the next download
            pil_img.close()

        if first_url is not None:
            # download_image always passes the same target path
            # ("images/graph.png"), so later downloads overwrote the first
            # result's file; fetch it again before returning it
            downloaded = cls.download_image(first_url)
            if downloaded is not None:
                return downloaded[0], downloaded[1]

        raise IndexError(
            f"No downloadable image found for search query: {search_query!r}"
        )

    @staticmethod
    def download_image(url):
        """
        Download the image at url through DownloadImage.get_image.

        The fixed path "images/graph.png" is passed to get_image (presumably
        the destination file, so successive downloads overwrite each other).

        Output: the (img, img_loc) pair produced by get_image. When the URL
        has an invalid schema the error is printed and execution falls
        through the try block (returning None if nothing follows).
        """
        downloader = DownloadImage()
        try:
            fetched_img, saved_loc = asyncio.run(
                downloader.get_image(url, "images/graph.png")
            )
            return fetched_img, saved_loc
        except requests.exceptions.InvalidSchema as err:
            # best effort: report the bad URL instead of raising
            # TODO add logging
            print(err)
