import os
import socket
import sys
import unicodedata
from datetime import datetime

from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pptx import Presentation
from pptx.exc import PackageNotFoundError
# from pptx.enum.shapes import MSO_SHAPE_TYPE
# from pptx.util import Inches
from pydantic import BaseModel

sys.path.append(".")
# from utils.logger import ServiceLogger
import logging

from configs.config import OPENAI_API_KEY, OPENAI_MODEL_MINI
from services.ppt_training.training_data_models import *

# Logger used throughout this module. getLogger() is called without a name,
# so this is the root logger.
XCM_logger = logging.getLogger()


def get_temp_dir(path: str = None):
    """Return a scratch directory, creating it on disk when it is missing.

    Args:
        path: Directory to use. When ``None``, ``/tmp/ppt_analysis`` is used
            (``/tmp`` being the writable location on AWS Lambda).

    Returns:
        The directory path that was ensured to exist.
    """
    target_dir = "/tmp/ppt_analysis" if path is None else path

    # exist_ok=True makes repeated calls for the same directory harmless.
    os.makedirs(target_dir, exist_ok=True)
    return target_dir


class SlideProcessor:
    """
    Class for processing slides in a PowerPoint presentation.
    """

    def __init__(self, pptx_path: str, project_id: str = None):
        XCM_logger.info(
            "Initialized SlideProcessor for project: %s and ppt: %s",
            project_id,
            pptx_path,
        )
        self.pptx_path = pptx_path
        self.temp_dir = get_temp_dir()
        # if the path is an s3 path, download it
        if "s3://" in pptx_path:
            import boto3

            s3 = boto3.client("s3")
            # download the file to a local temp directory
            pptx_path = pptx_path.replace("s3://", "")
            bucket = pptx_path.split("/")[0]
            pptx_path = "/".join(pptx_path.split("/")[1:])
            temp_file = os.path.join(self.temp_dir, os.path.basename(pptx_path))
            s3.download_file(bucket, pptx_path, temp_file)
            pptx_path = temp_file

        # check if the file exists
        if not os.path.exists(pptx_path):
            raise ValueError(f"File {pptx_path} does not exist")

        # check if the file is a PowerPoint presentation
        if not pptx_path.endswith(".pptx"):
            raise ValueError(f"File {pptx_path} is not a PowerPoint presentation")

        # check if the file is a valid PowerPoint presentation
        try:
            prs = Presentation(pptx_path)
        except PackageNotFoundError as exc:
            raise ValueError(
                f"File {pptx_path} is not a valid PowerPoint presentation"
            ) from exc

        self.prs = prs
        # self.slide_data =
        self.ppt_object = PPTData(
            project_id=project_id,
            slide_data=None,
            pptx_path=self.pptx_path,
        )

    def process_slides(self):
        "process the slides in the presentation"

        XCM_logger.info("Processing slides in presentation: %s", self.pptx_path)
        slide_data = []
        if self.prs.slides is None:
            raise ValueError(
                "Error: self.prs.slides is None, cannot proceed with analysis."
            )
        for slide in self.prs.slides:
            title = self._get_title(slide)
            text = self._get_text(slide, title)

            ## if there is no title, use the text as the title
            if not title:
                title = text[0]
                text = text[1:]

            image_path = self._save_slide_image(slide)
            is_section_divider = False
            tables = self._get_tables(slide)
            slide_data.append(
                SlideData(
                    title=title,
                    text=text,
                    tables=tables,
                    image_path=image_path,
                    is_section_divider=is_section_divider,
                )
            )

        # loop through the slies_data and run the is_section_divider logic
        for i, slide in enumerate(slide_data):
            try:
                if i == 0 or i == len(slide_data) - 1:
                    # if the slide is the first or last slide, then it is not a section divider
                    slide.is_section_divider = False
                else:
                    # pass in the slide before and after the current slide
                    slide.is_section_divider = self._is_section_divider(
                        slide_data[i - 1 : i + 2]
                    )
            except UnicodeEncodeError:
                continue

        return slide_data

    def _get_title(self, slide):
        "Get the title of the slide"
        if slide.shapes.title is None:
            return None
        return self._clean_text(slide.shapes.title.text)

    def _get_text(self, slide, title: str = None):
        text_content = []
        for shape in slide.shapes:
            text_content.extend(
                self._clean_text(self._extract_text_from_shape(shape, title))
            )
        return text_content

    @staticmethod
    def _clean_text(text: str | list[str]):
        "clean the text"
        if isinstance(text, list):
            return [SlideProcessor._clean_text(t) for t in text]
        normalized_text = unicodedata.normalize("NFKD", text)
        cleaned_text = normalized_text.encode("ascii", "ignore").decode("utf-8")
        return cleaned_text

    def _extract_text_from_shape(self, shape, title_text: str = None):
        """Recursively extract text from shapes, excluding the title text."""
        text_content = []

        # If shape has a text frame, extract the text
        if shape.has_text_frame:
            shape_text = shape.text.strip()
            if shape_text and shape_text != title_text:
                text_content.append(shape_text)

        # If shape is a group, recursively process each sub-shape
        elif shape.shape_type == 6:  # Group shape type is 6
            for sub_shape in shape.shapes:
                text_content.extend(
                    self._extract_text_from_shape(sub_shape, title_text)
                )

        return text_content

    def _get_tables(self, slide):
        """Extract table data from a shape."""
        tables = []

        for shape in slide.shapes:
            if shape.has_table:
                table_data = []
                table = shape.table
                for row in table.rows:
                    row_data = [cell.text.strip() for cell in row.cells]
                    table_data.append(row_data)
                tables.append(table_data)

        return tables

    def _save_slide_image(self, slide):
        # Placeholder for saving slide as image
        return ""

    def _is_section_divider(self, slides: list[SlideData]):
        "using an LLM, we are going to determine if the slide is a section divider"
        # Placeholder logic for determining section divider
        if slides[0].is_section_divider == True or slides[2].is_section_divider == True:
            return False

        if (
            len(slides[1].text) > 3
        ):  # if the text is more than 3 sections , then it is not a section divider
            return False

        if slides[1].title == None:
            return False

        llm = ChatOpenAI(
            api_key=OPENAI_API_KEY,
            model=OPENAI_MODEL_MINI,
            temperature=0.5,
        )

        if isinstance(slides[1].text, list):
            slide_text = "\n".join(slides[1].text)
        else:
            slide_text = slides[1].text
        slide_text = slide_text.replace("\u2010", "-")
        slide_text_safe = slide_text.encode("utf-8", errors="replace").decode("utf-8")

        chat_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a skilled slide deck analyzer. You are an expert in determining if a slide is a section divider. You are not providing an opinion but rather a factual answer and current analysis.",
                ),
                (
                    "human",
                    f"The slide we are trying to determine if it is a section divider is titled: {slides[1].title}",
                ),
                ("human", f"The content of the slide is: {slide_text_safe}"),
                (
                    "human",
                    f"The previous slide is: {slides[0].title} and the next slide is: {slides[2].title}",
                ),
                (
                    "human",
                    "Is this slide a section divider? True or False. Only return the answer",
                ),
            ]
        )

        chain = chat_prompt | llm
        chain_output = chain.invoke({})
        if chain_output.content.lower() == "true":
            return True

        return False

    def __del__(self):
        "Delete the files in the temp directory"
        import shutil

        try:
            if hasattr(self, "temp_dir") and self.temp_dir:
                shutil.rmtree(self.temp_dir)
                print(f"Temporary directory {self.temp_dir} deleted.")
        except Exception as e:
            print(f"Error deleting temporary directory {self.temp_dir}: {e}")

    # def upload_to_database(self):
    #     "upload the slide data to the database"
    #     from utils.dynamo_db import DynamoDB

    #     db = DynamoDB()
    #     table =
    #     response = db.upload_to_dynamodb(db.training_ppts, self.ppt_object.model_dump())

    #     return response


if __name__ == "__main__":
    # Manual smoke run against a local sample deck: process every slide,
    # attach the result to the PPTData object and dump it to stdout.
    sample_deck = r"C:\Users\sagar\OneDrive\Desktop\X Cap Market\Engineering\training_data\sunbelt\4.19.2024_Stairworx_CIM_v7.pptx"
    processor = SlideProcessor(sample_deck)
    slides_info = processor.process_slides()
    processor.ppt_object.slide_data = slides_info
    # Uploading is disabled: processor.upload_to_database()
    print(slides_info)
