import json
import os
import sys

sys.path.append(r".")

import re

from anytree import Node, PostOrderIter, PreOrderIter, RenderTree
# BeautifulSoup is used to parse the HTML of the SEC report
from bs4 import BeautifulSoup, element
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from configs.config import OPENAI_MODEL_35

load_dotenv()


class TableJson(BaseModel):
    # Schema passed to JsonOutputParser in ParseSECReport.convert_table_to_json.
    # `table` is required (the `...` default) and holds the table as a dict.
    table: dict = Field(..., description="The table in JSON format")

class ParseSECReport:
    """Parse the HTML of an SEC report into a tree of nodes.

    The constructor flattens the document body into a list of content
    elements, then builds a tree whose structural nodes are the headings
    matched by ``sec_parts``, ``sec_items`` and ``sec_notes``; the result is
    stored in ``report_tree``.
    """

    # Heading patterns, applied with ``.match`` (i.e. anchored at the start of
    # the element text) while the tree is built. All are case-insensitive.
    # "Part" followed by digits or a roman numeral made of I, V and X.
    sec_parts = re.compile(r"part\s+(\d+|[IVX]+)", re.IGNORECASE)
    # "Item <number>" with an optional single trailing letter.
    sec_items = re.compile(r"Item \d+[a-zA-Z]?", re.IGNORECASE)
    # "Note <number>" with an optional single trailing letter.
    sec_notes = re.compile(r"Note \d+[a-zA-Z]?", re.IGNORECASE)

    # Tables that have already been converted and attached to a tree.
    # NOTE: this is a mutable list defined on the class, so it is shared by
    # every instance that does not assign its own ``tables_reviewed``.
    tables_reviewed = []

    def __init__(self, doc_text: str):
        """Parse the SEC report and create a tree structure of the report.

        Args:
            doc_text: HTML source of the report's main document.

        After construction the instance exposes:
            soup: the parsed document.
            body: the ``<body>`` element, or the whole parsed document when
                the markup has no ``<body>``.
            repetitive_strings: texts that recur often enough to be ignored
                when building the tree.
            styles_in_report: mapping of inline style -> rank by frequency.
            report_tree: root node of the resulting tree.
        """
        # Give every instance its own list. Relying on the class-level list
        # would make tables seen in one report suppress equal tables in every
        # report parsed afterwards, and would keep old documents alive.
        self.tables_reviewed = []

        self.soup = BeautifulSoup(doc_text, "html.parser")
        body = self.soup.find("body")
        # Fragments without a <body> would otherwise fail on ``None.attrs``
        # inside loop_children; walk the whole parsed document instead.
        self.body = body if body is not None else self.soup

        self._items_children = []
        self.loop_children(self.body)

        self.repetitive_strings = self.get_repetitive_strings(self._items_children)
        self.styles_in_report = self.create_styles_count(self._items_children)
        self.report_tree = self.create_tree(self._items_children)
        self.clean_up_nodes(self.report_tree)
        self.clean_up_tree(self.report_tree)

    def loop_children(self, node, parent=None, id=0, level=0):
        """Walk ``node`` depth-first and flatten it into ``self._items_children``.

        Elements classified as content are appended to
        ``self._items_children`` as ``[level, node, text]``; the remaining
        elements are descended into. A nested dict describing the visited
        elements is filled in along the way.

        Args:
            node: Element to inspect; it must expose ``name``, ``attrs``,
                ``text``, ``children`` and item access for its attributes.
            parent: Dict to fill in for ``node``; created when omitted.
            id: Counter used to key the entries of ``parent["children"]``.
            level: Depth of ``node`` relative to where the walk started.

        Returns:
            The dict describing ``node``, or ``None`` when the element's
            inline style matches the one treated here as a page number.
        """
        record = {} if parent is None else parent
        attrs = node.attrs

        record["node"] = node
        record["el_id"] = node["id"] if "id" in attrs else None
        record["continuedat"] = node["continuedat"] if "continuedat" in attrs else None
        record["children"] = {}
        record["text"] = node.text.strip() if node.name == "p" else None

        kids = self.get_children(node)
        kid_count = len(kids)

        # Elements carrying this exact style fragment are taken to be page
        # numbers and are left out entirely.
        if "style" in attrs and "bottom:0;position:absolute;width:100%" in node["style"]:
            return None

        entry = [level, node, record["text"]]
        descend = True

        if kid_count > 0 and all("ix" in kid.name for kid in kids):
            # Only ix-like children: record the element and still walk into it.
            self._items_children.append(entry)
        elif kid_count == 0:
            self._items_children.append(entry)
            descend = False
        elif kid_count == 1:
            # NOTE: a lone "ix:nonnumeric" child already satisfies the branch
            # above, so this inner test does not fire; single-child elements
            # are simply descended into.
            if kids[0].name == "ix:nonnumeric":
                self._items_children.append(entry)
                descend = False
        elif all(kid.name == "span" for kid in kids):
            # Nothing but spans underneath: keep the element as one unit.
            self._items_children.append(entry)
            descend = False

        if not descend:
            return record

        if node.name == "table":
            # Tables are recorded as a whole and never walked into.
            self._items_children.append([level, node, "Table"])
            return record

        for kid in kids:
            if kid.text.strip() == "":
                continue
            id += 1
            record["children"][id] = {}
            self.loop_children(
                kid, parent=record["children"][id], id=id, level=level + 1
            )

        return record

    def get_children(self, node):
        """Return the direct tag children of ``node`` whose text is not blank.

        Children that are not ``element.Tag`` instances are ignored, as are
        tags whose text is empty once stripped of surrounding whitespace.
        """
        return [
            kid
            for kid in node.children
            if isinstance(kid, element.Tag) and len(kid.text.strip()) != 0
        ]

    def get_repetitive_strings(self, _items_children):
        """Return texts that recur so often they are treated as boilerplate.

        Only non-table elements that carry a ``style`` attribute are counted.
        Of the ten most frequent texts, those seen more than ten times are
        returned, unless the text starts with ten digits (a CIK number) or
        with a ``YYYY-MM-DD`` date.

        Args:
            _items_children: ``[level, node, text]`` entries produced by
                ``loop_children``.

        Returns:
            The repeated texts, most frequent first.
        """
        occurrences = {}
        for entry in _items_children:
            tag = entry[1]
            if tag.name == "table":
                continue
            text = tag.text
            if tag.attrs.get("style", None) is None:
                continue
            occurrences[text] = occurrences.get(text, 0) + 1

        # Most frequent first; ties keep the order in which they were seen.
        ranked = sorted(occurrences.items(), key=lambda pair: pair[1], reverse=True)

        repeated = []
        for text, hits in ranked[:10]:
            if hits <= 10:
                continue
            looks_like_cik = re.match(r"\d{10}", text)
            looks_like_date = re.match(r"\d{4}-\d{2}-\d{2}", text)
            if looks_like_cik or looks_like_date:
                continue
            repeated.append(text)

        return repeated

    def create_styles_count(self, _items_children):
        """Rank the inline styles used in the report by how often they occur.

        Tables and elements without a ``style`` attribute are ignored.

        Args:
            _items_children: ``[level, node, text]`` entries produced by
                ``loop_children``.

        Returns:
            A dict mapping each style string to a number starting at 1; the
            least frequent style gets 1, and ties keep first-seen order.
        """
        tally = {}
        for entry in _items_children:
            tag = entry[1]
            if tag.name == "table":
                continue
            style = tag.attrs.get("style", None)
            if style is None:
                continue
            tally[style] = tally.get(style, 0) + 1

        # Ascending by frequency; the sort is stable so ties stay in order.
        by_frequency = sorted(tally, key=lambda style: tally[style])

        return {style: rank for rank, style in enumerate(by_frequency, start=1)}

    def create_tree(self, _items_children):
        """Create a tree structure of the report.

        Every entry of ``_items_children`` becomes a candidate node named
        after its index. Entries with no style, or whose text is one of
        ``self.repetitive_strings``, are left out of the tree.

        Placement rules for text nodes:
            * the first kept node and every "Part" heading hang off the root;
            * an "Item" heading hangs off the latest Part;
            * a "Note" heading hangs off the latest Item, else the latest Part;
            * any other node hangs off the latest Note, else the latest Item,
              else the first kept node.
        Whenever the preferred parent does not exist yet, the node falls back
        to the root instead of being dropped from the tree.

        Tables (an element that is, or wraps via single children, a
        ``<table>``) are converted with ``clean_table_for_storage`` and
        attached to the most recently attached text node. A table equal to
        one already in ``self.tables_reviewed`` is skipped.

        Args:
            _items_children: ``[level, node, text]`` entries produced by
                ``loop_children``.

        Returns:
            The root node of the tree.
        """
        root_node = Node(name="root")
        root_node.style = None
        root_node.text = ""
        root_node = self.set_attr_flags(root_node)

        last_sec_part_node = None
        last_sec_item_node = None
        last_node_note = None
        first_node = None

        # Most recently attached text node; tables hang off it. It starts at
        # the root so that a table preceding all text still has a parent
        # (previously this raised UnboundLocalError).
        node_old = root_node

        for i, _item2 in enumerate(_items_children):
            _item = _item2[1]

            n_node = self.set_attr_flags(Node(i))
            n_node.style = self.get_node_style(_item)

            is_table = self.check_if_table(_item)
            if is_table:
                if is_table not in self.tables_reviewed:
                    n_node.text = self.clean_table_for_storage(is_table)
                    n_node.style = "Table"
                    n_node.parent = node_old
                    n_node.is_table = True
                    self.tables_reviewed.append(is_table)
                continue

            item_text = _item.text.replace("\n", " ").replace("\xa0", " ")

            # Unstyled elements and recurring boilerplate never enter the tree.
            if n_node.style is None or item_text in self.repetitive_strings:
                continue

            n_node.text = item_text

            if first_node is None:
                n_node.is_first_node = True
                first_node = n_node

            # The length cap keeps ordinary sentences that merely start with
            # "Part ..." from being treated as headings.
            if self.sec_parts.match(item_text) and len(item_text) < 100:
                n_node.is_sec_part = True
                last_sec_part_node = n_node
                last_sec_item_node = None
                last_node_note = None

            if self.sec_items.match(item_text):
                n_node.is_sec_item = True
                last_sec_item_node = n_node
                last_node_note = None

            if self.sec_notes.match(item_text):
                n_node.is_sec_note = True
                last_node_note = n_node

            # Preferred parents, best first. The first kept node must go to
            # the root: sending it through the generic branch would make it
            # its own parent, which anytree rejects.
            if n_node.is_first_node or n_node.is_sec_part:
                candidates = ()
            elif n_node.is_sec_item:
                candidates = (last_sec_part_node,)
            elif n_node.is_sec_note:
                candidates = (last_sec_item_node, last_sec_part_node)
            else:
                candidates = (last_node_note, last_sec_item_node, first_node)

            # Fall back to the root rather than assigning ``None``, which
            # would silently detach the node and everything attached to it.
            n_node.parent = next(
                (candidate for candidate in candidates if candidate is not None),
                root_node,
            )
            node_old = n_node

        return root_node

    def get_node_style(self, node) -> str:
        """Return the style that best represents ``node``.

        When the node has children and every one of them is a ``span``, the
        result is the span style covering the most characters (``None`` is a
        possible winner when the dominant spans carry no style; ties go to
        the style seen first). Otherwise the node's own ``style`` attribute
        is returned, which is ``None`` when the attribute is absent.

        Args:
            node: Element exposing ``children`` and ``attrs``.
        """
        children = list(node.children)

        # A childless node would pass an "all children are spans" test
        # vacuously and then fail in max() on an empty mapping, so it is
        # routed to the node's own style instead.
        if children and all(child.name == "span" for child in children):
            chars_per_style = {}
            for child in children:
                style = child.attrs.get("style", None)
                chars_per_style[style] = chars_per_style.get(style, 0) + len(child.text)
            return max(chars_per_style, key=chars_per_style.get)

        return node.attrs.get("style", None)

    def set_attr_flags(self, n_node):
        """Reset every classification flag on ``n_node`` to ``False``.

        Returns:
            The same node, so the call can be used in an assignment.
        """
        flag_names = (
            "is_first_node",
            "is_sec_part",
            "is_sec_item",
            "is_sec_note",
            "is_table",
        )
        for flag_name in flag_names:
            setattr(n_node, flag_name, False)
        return n_node

    def get_parent_node(self, root_node, node_old, n_node, style_enums_dict):
        """Re-parent ``n_node`` according to style rank.

        Starting from ``node_old``, ancestors are climbed for as long as the
        rank of ``n_node``'s style is not greater than the rank of the
        candidate's style. Climbing stops at ``root_node`` or at any node
        flagged as a section or table. ``n_node`` is then attached to the
        candidate reached.

        Nothing happens when ``node_old`` itself is flagged as a section or
        table.

        Args:
            root_node: Upper bound for the climb.
            node_old: Node handled just before ``n_node``.
            n_node: Node to re-parent.
            style_enums_dict: Mapping of style -> rank, as built by
                ``style_enums``.
        """
        if self.check_sec_sections(node_old):
            return

        candidate = node_old
        while style_enums_dict[n_node.style] <= style_enums_dict[candidate.style]:
            candidate = candidate.parent
            reached_limit = candidate == root_node or self.check_sec_sections(candidate)
            if reached_limit:
                break

        n_node.parent = candidate

    def clean_up_nodes(self, root):
        """Nest the children of each section node by style rank.

        The tree is visited in pre-order. For every node flagged as a part,
        item, note or first node, the styles of its children are ranked with
        ``style_enums`` and each unflagged child is re-parented with
        ``get_parent_node`` relative to the child handled before it. Tables
        keep the parent they were given when the tree was created.
        """
        for current in PreOrderIter(root):
            # NOTE(review): this compares a node *name* with the root node
            # object, presumably to skip the root - confirm the intent.
            if current.name == root or current.is_table:
                continue
            if not self.check_sec_sections(current, check_table=False):
                continue

            ranks = self.style_enums(current)

            previous = current
            for kid in current.children:
                # Sections and tables stay where they are and do not become
                # the reference point for the next child.
                if self.check_sec_sections(kid):
                    continue
                self.get_parent_node(current, previous, kid, ranks)
                previous = kid

    def check_sec_sections(self, child, check_table=True):
        """Tell whether ``child`` is flagged as a part, item, note or first node.

        Args:
            child: Node carrying the ``is_*`` flags.
            check_table: When true, a node flagged as a table counts as well.

        Returns:
            The first truthy flag found, otherwise the value of
            ``child.is_first_node``.
        """
        if check_table:
            table_flag = child.is_table
            if table_flag:
                return table_flag

        return (
            child.is_sec_part
            or child.is_sec_item
            or child.is_sec_note
            or child.is_first_node
        )

    def style_enums(self, node) -> dict:
        """Rank the styles used by the plain children of ``node``.

        Children flagged as tables, parts, items, notes or first node are not
        counted.

        Returns:
            A dict mapping each style to a number starting at 0; the least
            frequent style gets 0, and ties keep first-seen order.
        """
        tally = {}
        for kid in node.children:
            if self.check_sec_sections(kid):
                continue
            tally[kid.style] = tally.get(kid.style, 0) + 1

        # Ascending by number of occurrences (stable for equal counts).
        by_frequency = sorted(tally, key=lambda style: tally[style])

        return {style: rank for rank, style in enumerate(by_frequency)}

    def clean_up_tree(self, node):
        """Fold single leaf children into their parent, bottom-up.

        After the subtree of each child has been processed, a child that is
        left with exactly one descendant absorbs it: the descendant's text is
        appended to the child's text (separated by two spaces) and the
        descendant is detached. Tables are never folded.

        Returns:
            ``node`` itself.
        """
        for kid in node.children:
            self.clean_up_tree(kid)

            if len(kid.descendants) != 1:
                continue

            grandkids = kid.children
            if any(grandkid.is_table for grandkid in grandkids):
                continue

            kid.text += "  " + " ".join(grandkid.text for grandkid in grandkids)
            grandkids[0].parent = None

        return node

    def clean_table_for_storage(self, table_item):
        """Clean the table for storage and return it as a JSON string.

        The table is rendered to HTML text, double-quoted ``style="..."``
        attributes and newlines are removed, and the result is passed to
        ``convert_html_table_to_json``.

        Args:
            table_item: The table element; anything whose ``str()`` is the
                table's HTML.

        Returns:
            Whatever ``convert_html_table_to_json`` returns for the cleaned
            markup.
        """
        table_text = str(table_item)

        # Each step must work on the output of the previous one; restarting
        # from ``table_text`` would throw the style stripping away.
        cleaned_text = re.sub(r'style="[^"]+"', "", table_text)
        cleaned_text = cleaned_text.replace("\n", "")

        return self.convert_html_table_to_json(cleaned_text)

    def check_if_table(self, _item):
        """Return the table that ``_item`` is or wraps, else ``False``.

        An element counts as wrapping a table when following its children,
        for as long as each element has exactly one child, leads to an
        element named ``table``.
        """
        current = _item
        while True:
            if current.name == "table":
                return current

            if not getattr(current, "children", None):
                return False

            only_children = list(current.children)
            if len(only_children) != 1:
                return False

            # Exactly one child: keep following the chain downwards.
            current = only_children[0]

    def convert_html_table_to_json(self, table_text):
        """Convert an HTML table to a JSON string of rows and cells.

        Each row that has at least one non-empty cell becomes a list of cell
        dicts. Every cell dict has a ``"text"`` key; cells containing an
        ``ix:*`` element additionally carry ``"ix_id"``, ``"ix_name"`` and
        ``"ix_scale"`` taken from the first such element, with ``None`` for
        attributes that element does not have.

        Args:
            table_text: HTML markup containing a ``<table>``.

        Returns:
            The JSON-encoded list of rows, or ``""`` when the markup holds no
            table.
        """
        soup = BeautifulSoup(table_text.replace("\n", " "), "html.parser")
        table = soup.find("table")
        if table is None:
            return ""

        # Compiled once instead of once per cell.
        ix_tag_name = re.compile("^ix:")

        table_data = []
        for row in table.find_all("tr"):
            row_data = []
            empty_row = True

            for cell in row.find_all(["th", "td"]):
                cell_text = (
                    cell.get_text(strip=True)
                    .replace("\n", " ")
                    .replace("\xa0", "")
                    .replace("&nbsp;", " ")
                )
                if cell_text != "":
                    empty_row = False

                first_ix = cell.find(ix_tag_name)
                if first_ix is None:
                    row_data.append({"text": cell_text})
                    continue

                # Not every ix element carries all three attributes, so use
                # .get() rather than indexing to avoid a KeyError.
                ix_attrs = first_ix.attrs
                row_data.append(
                    {
                        "text": cell_text,
                        "ix_id": ix_attrs.get("id"),
                        "ix_name": ix_attrs.get("name"),
                        "ix_scale": ix_attrs.get("scale"),
                    }
                )

            # Rows made only of empty cells are dropped.
            if not empty_row:
                table_data.append(row_data)

        return json.dumps(table_data)

    def convert_table_to_json(self, table_text):
        """Convert an HTML table to JSON by asking the chat model.

        A two-message prompt (instructions, then the table markup) is piped
        through ``ChatOpenAI`` at temperature 0 and a ``JsonOutputParser``
        configured with ``TableJson``.

        Args:
            table_text: The table markup substituted into the prompt.

        Returns:
            The output of invoking the prompt | model | parser chain.
        """
        instructions = """Convert the below from HTML to JSON table in a way that makes sense. \
            Take the table below and only give me the relevant info with keys that make sense. \
            If there are dates associated, you need have the dates. Make the date the highest level when dealing with financial data. \
            You need to retain all of the information in the table."""

        messages = [
            ("human", instructions),
            ("human", "{table_text}"),
        ]

        output_parser = JsonOutputParser(pydantic_object=TableJson)
        prompt = ChatPromptTemplate.from_messages(messages)
        model = ChatOpenAI(model=OPENAI_MODEL_35, temperature=0.0)

        pipeline = prompt | model | output_parser
        return pipeline.invoke({"table_text": table_text})


if __name__ == "__main__":

    from utils.researcher.SEC.SECReports import SECReports

    # Parse the latest annual and quarterly report of each ticker and dump
    # the resulting tree to "<ticker>_<report type>_tree.txt".
    tickers = ("AAPL", "TSLA", "MRAI", "FDX", "WDFC", "UAL")
    report_types = ("10-K", "10-Q")

    for ticker in tickers:
        for report_type in report_types:
            sec_report = SECReports(
                ticker=ticker, filing_type=report_type, last_n_reports=1
            )
            main_document = sec_report.reports[0]["report_main_doc"]
            report = ParseSECReport(main_document)

            # One line per node: indentation filler, node name, tab, text.
            output_name = f"{ticker}_{report_type}_tree.txt"
            with open(output_name, "w", encoding="utf-8") as f:
                f.writelines(
                    f"{fill}{node.name}\t{node.text}\n"
                    for pre, fill, node in RenderTree(report.report_tree)
                )
