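# This script uses LangChain libraries to upload a PDF, Word doc, or text file to a Chroma DB.
# We then query against that file to see if it contains any relevant information.
# NOTE(review): the code visible in this file only loads Word documents, summarises their
# tables with an OpenAI chat model and stores the tables in DynamoDB - confirm that the
# description above is still accurate.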

import sys

sys.path.append(r".")

import io
import os
import time

import requests
from docx import Document
from docx.oxml.exceptions import InvalidXmlError
from docx.table import Table
from docx.text.paragraph import Paragraph
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from configs.config import EMBEDDING_MODEL, OPENAI_MODEL_35
# from utils.selenium_driver_chrome import SeleniumDriverChrome
from services.ppt_generator.data_classes.project import Project
from utils.dynamo_db import DynamoDB
from utils.s3_storage import S3BucketStorage

# The classes below loop through a Word document's tables and paragraphs to extract the text.
#
# Module-level OpenAI embeddings client, built as soon as this module is imported.
# `os.environ[...]` raises KeyError when OPENAI_API_KEY is not set, so importing this
# module requires the key to be present in the environment.
# NOTE(review): `openai_ef` is not referenced anywhere else in the visible code -
# presumably it is imported by other modules; confirm before removing.
openai_ef = OpenAIEmbeddings(
    model=EMBEDDING_MODEL,
    api_key=os.environ["OPENAI_API_KEY"],
)


class ConvertTable:
    """Flatten a Word table to delimited text, summarise it and store it in DynamoDB.

    An instance wraps one ``docx`` ``Table`` together with the name of the
    document it came from and the owning project. The main entry point is
    :meth:`convert_table_to_text`.
    """

    # Separator placed between the cells of one row in the flattened table text.
    col_delimiter_text = "\t|"
    # Separator placed between rows in the flattened table text.
    row_delimiter_text = "\n\n\n\n"

    # Separator used to build the storage key from project id, document name and timestamp.
    file_delimiter_text = "_+_"

    # Chat model shared by every instance; it is constructed once, when the class body runs.
    llm35 = ChatOpenAI(
        api_key=os.getenv("OPENAI_API_KEY"), model=OPENAI_MODEL_35, temperature=0
    )
    # Token budget quoted to the model in the summary prompt.
    summary_token_limit = 500

    def __init__(self, table: Table, document_name: str, project: Project):
        """Store the table, the name of its source document and the owning project."""
        self.table = table
        self.document_name = document_name
        self.project = project

    def convert_table_to_text(self, prior_str: str = "") -> tuple:
        """Flatten the table, summarise it and persist it.

        Args:
            prior_str: Text that precedes the table in the document; it is
                handed to the summary prompt as context.

        Returns:
            A ``(table_summary, table_name)`` tuple: the summary returned by
            the chat model and the key under which the table was uploaded.
        """
        # Local import so that the warning below needs no change to the file's import block.
        import logging

        rows_text = []
        for row_index, row in enumerate(self.table.rows):
            try:
                row_text = self.col_delimiter_text.join(
                    [cell.text for cell in row.cells]
                )
            except InvalidXmlError:
                # Best effort: a row that cannot be read is skipped so the rest of
                # the table is still processed, but the drop is no longer silent.
                logging.getLogger(__name__).warning(
                    "Skipping unreadable row %d of a table in %s",
                    row_index,
                    self.document_name,
                )
                continue
            rows_text.append(row_text)

        table_str = self.row_delimiter_text.join(rows_text)
        table_summary = self.get_table_summary(table_str, prior_str)

        table_name = self.upload_table_to_dynamodb(table_str, table_summary)

        return table_summary, table_name

    @classmethod
    def get_table_from_dynamodb(cls, table_name: str):
        """Fetch the item stored under ``table_name`` from the document-tables table.

        The item is returned exactly as ``DynamoDB.get_item`` provides it; its
        table text is not decoded to markdown here (see :meth:`decode_table_text`).
        """
        db = DynamoDB()
        return db.get_item(db.document_tables, table_name)

    @classmethod
    def decode_table_text(cls, table_text: str):
        """Turn flattened table text back into pipe-delimited markdown rows.

        Args:
            table_text: Text built with ``row_delimiter_text`` between rows and
                ``col_delimiter_text`` between cells.

        Returns:
            One ``|cell|cell|`` line per row, joined with newlines.
        """
        table_rows = table_text.split(cls.row_delimiter_text)

        # Each row is wrapped in leading/trailing pipes. Rows are joined with a
        # newline: without it every row ran together on a single line.
        return "\n".join(
            "|" + row.replace(cls.col_delimiter_text, "|") + "|" for row in table_rows
        )

    def get_table_summary(self, table_str: str, prior_str: str = ""):
        """Ask the chat model for a short summary of the table.

        Args:
            table_str: The flattened table text.
            prior_str: Text that precedes the table in the document.

        Returns:
            The ``content`` of the model's reply.
        """
        # The document name is passed as a template variable rather than being
        # interpolated into the template text: a file name containing "{" or "}"
        # would otherwise be parsed as a placeholder and break the prompt.
        chat_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a financial analyst and you have been given a table with company data.
              You need to summarize the table in 3-4 sentences.""",
                ),
                (
                    "human",
                    "The document name is {document_name} where the table originates.",
                ),
                ("human", "The information before the table is {prior_str}"),
                ("human", "The rows of the table are: {table_str}"),
                ("human", "The column delimiter is: {delimiter_text}"),
                ("human", "The row delimiter is: {row_delimiter_text}"),
                (
                    "system",
                    "Please summarize the table in less than {summary_token_limit} tokens.",
                ),
                ("system", "Only return the summary, nothing else."),
            ]
        )

        chain = chat_prompt | self.llm35
        chain_output = chain.invoke(
            {
                "document_name": self.document_name,
                "prior_str": prior_str,
                "table_str": table_str,
                "delimiter_text": self.col_delimiter_text,
                "row_delimiter_text": self.row_delimiter_text,
                "summary_token_limit": self.summary_token_limit,
            }
        )

        return chain_output.content

    def upload_table_to_dynamodb(self, table_str: str, table_summary: str):
        """Upload the flattened table and its summary, returning the generated key.

        Args:
            table_str: The flattened table text.
            table_summary: The summary produced for the table.

        Returns:
            The ``unique_id`` under which the item was uploaded.
        """
        _unique_id = self.create_table_name(self.project.project_id, self.document_name)
        _item = {
            "unique_id": _unique_id,
            "table_str": table_str,
            "table_summary": table_summary,
        }
        db = DynamoDB()
        db.upload_to_dynamodb(table=db.document_tables, data=_item)

        return _unique_id

    def create_table_name(self, project_id: str, document_name: str):
        """Build a key from the project id, the document name and the current time.

        The three parts are joined with ``file_delimiter_text``; the timestamp
        is ``time.time()`` rendered as a string.
        """
        return self.file_delimiter_text.join(
            [project_id, document_name, str(time.time())]
        )


class LoadWordDoc:
    """Load a Word document and split it into text and table chunks.

    The document is read when the instance is created; the resulting chunks
    are available as ``document_text_chunked``.
    """

    def __init__(self, doc_path: str, project: Project):
        """Remember the location and project, then load and chunk the document.

        Args:
            doc_path: An ``s3://`` URI, an ``http(s)://`` URL or a local file path.
            project: The project the document belongs to; it is passed on to
                ``ConvertTable`` for every table found.
        """
        self.doc_path = doc_path
        self.project = project
        self.file_name = os.path.basename(doc_path)
        self.document_text_chunked = self.load_doc()

    def _open_document(self):
        """Open ``doc_path`` as a ``docx`` document from S3, HTTP(S) or the local disk."""
        if self.doc_path.startswith("s3://"):
            s3b = S3BucketStorage()
            content = s3b.get_file_from_s3_bucket(self.doc_path)
            return Document(io.BytesIO(content))

        if self.doc_path.startswith(("http://", "https://")):
            response = requests.get(self.doc_path, timeout=10)
            # Fail on 4xx/5xx here instead of handing an error page to the
            # document parser, which would fail with an unrelated error.
            response.raise_for_status()
            return Document(io.BytesIO(response.content))

        return Document(self.doc_path)

    def load_doc(self) -> list:
        """Walk the document body in order and build the list of chunks.

        Consecutive paragraphs are accumulated into one ``"text"`` chunk. Each
        table closes the running text chunk and produces a ``"table"`` chunk
        whose text is the last non-blank paragraph seen so far followed by the
        table summary.

        Returns:
            A list of dicts with the keys ``"text"`` and ``"type"`` (``"text"``
            or ``"table"``); table chunks also carry ``"table_name"``.

        Raises:
            requests.HTTPError: If an HTTP(S) download returns an error status.
        """
        doc = self._open_document()

        last_str = ""  # most recent paragraph that holds more than whitespace
        chunk_str = ""  # text accumulated since the previous table
        document_text_chunked = []

        for _item in doc.iter_inner_content():
            if isinstance(_item, Paragraph):
                chunk_str += "\n" + _item.text
                if _item.text.strip():
                    last_str = _item.text

            elif isinstance(_item, Table):
                # Runs made only of blank paragraphs are not emitted as chunks.
                if chunk_str.strip():
                    document_text_chunked.append({"text": chunk_str, "type": "text"})

                # The last two non-empty lines before the table give the
                # summariser some context.
                table_prior_chunk = "\n".join(
                    [line for line in chunk_str.split("\n") if line != ""][-2:]
                )

                conv_table = ConvertTable(
                    table=_item,
                    document_name=self.file_name,
                    project=self.project,
                )
                table_summary, table_name = conv_table.convert_table_to_text(
                    prior_str=table_prior_chunk
                )

                document_text_chunked.append(
                    {
                        "text": last_str + "\n" + table_summary,
                        "type": "table",
                        "table_name": table_name,
                    }
                )
                chunk_str = ""

        if chunk_str.strip():
            document_text_chunked.append({"text": chunk_str, "type": "text"})

        return document_text_chunked


if __name__ == "__main__":
    # Manual run: load one project record, then chunk a sample document from S3.
    doc_path = r"s3://xcap-storage-dev/Sunbelt/navis_pack_ship/Navis - Questionnaire w Larry Feedback.docx"

    db = DynamoDB()
    # The stored project item is expanded straight into the Project data class.
    project = Project(**db.get_item(db.projects, "navis_pack_ship"))

    # Constructing LoadWordDoc reads and chunks the document immediately.
    load_doc = LoadWordDoc(doc_path, project)
