"Get items from dynamodb"
import datetime
import socket
import time
from decimal import Decimal

import backoff
import botocore.exceptions
from boto3.dynamodb.conditions import Attr

from utils.dynamo_connector import DynamoConnector


class DynamoDB:
    """Helper around a DynamoDB resource.

    Holds Table resources for the tables named in ``__init__`` and offers
    get / scan / query / batch-read / upload helpers that operate on them.
    """

    # Table *name* as a plain string (the attributes set in __init__ are
    # Table resources instead).
    slides_research_table = "slides_research"

    def __init__(self):
        """Create the connector and build one Table resource per table name."""
        self.dynamodb = DynamoConnector().dynamodb
        open_table = self.dynamodb.Table

        # User / team data
        self.user = open_table("user")
        self.team = open_table("Team")

        # Company research data
        self.company_info_table = open_table("company_information")
        self.llm_table = open_table("LLM_results")
        self.company_list_table = open_table("company_list")

        # Project data
        self.projects = open_table("projects")
        self.projectsV2 = open_table("projectsV2")
        self.document_tables = open_table("document_tables")
        self.fund_info_table = open_table("fund_info")
        self.press_releases = open_table("press_releases")
        self.web_pages = open_table("web_pages")
        self.project_docs = open_table("project_docs")
        self.project_structure = open_table("project_structure")

    def get_table_name(self, table):
        "Get the table name"
        return table.name

    def get_table(self, table_name):
        "Get the table name"
        return self.dynamodb.Table(table_name)

    def get_item(self, table, key_id, sort_key_id=None):
        "Get items from dynamodb"
        root_key_name = table.key_schema[0]["AttributeName"]
        sort_key_name = None

        if len(table.key_schema) == 2 and sort_key_id is not None:
            sort_key_name = table.key_schema[1]["AttributeName"]
            response = table.get_item(
                Key={root_key_name: key_id, sort_key_name: sort_key_id}
            )

        elif len(table.key_schema) == 2 and sort_key_id is None:
            response = table.scan(FilterExpression=Attr(root_key_name).eq(key_id))
            if "Items" not in response or len(response["Items"]) == 0:
                return None
            return response["Items"][0]
        else:
            response = table.get_item(Key={root_key_name: key_id})

        if "Item" not in response:
            return None
        item = response["Item"]
        return item

    def scan_table(self, table, filter_expression=None, expression_values=None):
        "Get all items from dynamodb with an optional filter expression"
        items = []
        if filter_expression is None:
            response = table.scan()
        if expression_values is None:
            response = table.scan(FilterExpression=filter_expression)
        else:
            response = table.scan(
                FilterExpression=filter_expression,
                ExpressionAttributeValues=expression_values,
            )

        items.extend(response.get("Items", []))

        while "LastEvaluatedKey" in response:
            response = self.scan_with_backoff(
                response, table, filter_expression, expression_values
            )
            items.extend(response.get("Items", []))

        return items

    def query_items(self, table, look_up_value: str):
        """Return every item whose first key-schema attribute equals ``look_up_value``.

        Follows ``LastEvaluatedKey`` until the query is exhausted.
        """
        partition_attr = table.key_schema[0]["AttributeName"]
        from boto3.dynamodb.conditions import Key

        # The same key condition is sent with every page request.
        request = {"KeyConditionExpression": Key(partition_attr).eq(look_up_value)}

        page = table.query(**request)
        found = page.get("Items", [])

        while "LastEvaluatedKey" in page:
            page = table.query(**request, ExclusiveStartKey=page["LastEvaluatedKey"])
            found.extend(page.get("Items", []))

        return found

    @backoff.on_exception(backoff.expo, botocore.exceptions.ClientError, max_tries=5)
    def scan_with_backoff(
        self, response, table, filter_expression=None, expression_values=None
    ):
        """Scan the page that follows ``response``.

        The decorator retries on ``botocore.exceptions.ClientError`` with
        exponential backoff, up to 5 tries. ``expression_values`` is only sent
        when ``filter_expression`` is given.
        """
        scan_args = {"ExclusiveStartKey": response["LastEvaluatedKey"]}
        if filter_expression is not None:
            scan_args["FilterExpression"] = filter_expression
            if expression_values is not None:
                scan_args["ExpressionAttributeValues"] = expression_values
        return table.scan(**scan_args)

    def batch_get_item(self, table, partition_keys: dict):
        "get all of the items in the table with the given partition keys"
        keys = {table.name: partition_keys}
        response = self.dynamodb.batch_get_item(RequestItems=keys)
        items = response.get("Responses", {}).get(table.name, [])
        return items

    def convert_for_recommender(self, company_info, llm_result):
        "Convert the dynamodb items to the format outputted by the recommendor"
        overview = None
        hq_n_founded = None
        for results in llm_result["LLM_results"]:
            if results["category"] == "Overview":
                overview = results["response"]
            if results["category"] == "HQ_Founded":
                hq_n_founded = results["response"]

        company_name = company_info["company_name"]
        if "name" in company_info and (company_name == "" or company_name is None):
            company_name = company_info["name"]
        if company_name == "" or company_name is None:
            company_name = (
                company_info["root_url"].split("www.")[-1].split(".")[0].capitalize()
            )

        return {
            "company_name": company_name,
            "company_url": company_info["root_url"],
            "logo": (
                company_info["logo"]["logo_url"]
                if isinstance(company_info["logo"], dict)
                else None
            ),
            "overview": overview,
            "hq_n_founded": hq_n_founded,
        }

    def get_item_for_recommender(self, root_url):
        "Get items from dynamodb"
        company_info = self.get_item(self.company_info_table, root_url)
        llm_result = self.get_item(self.llm_table, root_url)
        return self.convert_for_recommender(company_info, llm_result)

    def get_item_for_fund_recommender(self, root_url):
        "Get items from dynamodb"
        fund_info = self.get_item(self.fund_info_table, root_url)
        return fund_info

    def convert_floats_to_decimals(self, item):
        if isinstance(item, dict):
            return {k: self.convert_floats_to_decimals(v) for k, v in item.items()}
        elif isinstance(item, list):
            return [self.convert_floats_to_decimals(element) for element in item]
        elif isinstance(item, float):
            return Decimal(str(item))
        else:
            return item

    def upload_to_dynamodb(self, table, data):
        "Put item into collection"
        ## convert float to decimal
        data = self.convert_floats_to_decimals(data)
        try:
            response = table.put_item(Item=data)
        except Exception:
            time.sleep(10)
            response = table.put_item(Item=data)
        return response

    @staticmethod
    def _sanitize_text(text):
        """Remove non-printable characters from the text fields."""
        return "".join(c for c in text if c.isprintable())

    def batch_upload_to_dynamodb(
        self, table_name, data_list, max_retries=3, retry_delay=5
    ):
        """
        Batch upload items into the DynamoDB table with retry logic.

        :param table_name: Name of the DynamoDB table.
        :param data_list: A list of dictionaries, each representing an item to upload.
        :param max_retries: Number of retries for unprocessed items.
        :param retry_delay: Time in seconds to wait before retrying.
        """
        table = self.dynamodb.Table(table_name)

        # Convert floats to decimals and sanitize text fields
        cleaned_data_list = []
        for item in data_list:
            sanitized_item = {
                k: self._sanitize_text(v) if isinstance(v, str) else v
                for k, v in item.items()
            }
            cleaned_data_list.append(sanitized_item)

        # Process batch writes in chunks of 25
        for i in range(0, len(cleaned_data_list), 25):
            batch_chunk = cleaned_data_list[i : i + 25]

            retries = 0
            while retries <= max_retries:
                try:
                    with table.batch_writer() as batch:  # ✅ Open once per batch
                        for item in batch_chunk:
                            batch.put_item(Item=item)

                    break  # ✅ Exit retry loop if successful

                except botocore.exceptions.ClientError as e:
                    print(
                        f"DynamoDB error: {e}, retrying {retries + 1}/{max_retries}..."
                    )
                    retries += 1
                    time.sleep(retry_delay)

                except Exception as e:
                    print(f"Unexpected error: {e}")
                    break  # ✅ Stop retrying on unknown errors

    def create_or_update_company_list(self, root_url, scrapped=False):
        "Get or update item in collection"
        response = self.get_item(self.company_list_table, root_url)

        if response is not None:
            # update the item in the collection
            self.company_list_table.update_item(
                Key={"root_url": root_url},
                UpdateExpression="set scrapped = :s, modified_by = :m, modified_date = :d",
                ExpressionAttributeValues={
                    ":s": scrapped,
                    ":m": socket.gethostname(),
                    ":d": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                },
            )
            return response
        else:
            # create
            self.company_list_table.put_item(
                Item={
                    "root_url": root_url,
                    "scrapped": scrapped,
                    "size": 0,
                    "created_by": socket.gethostname(),
                    "created_date": datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S"
                    ),
                    "modified_by": socket.gethostname(),
                    "modified_date": datetime.datetime.now().strftime(
                        "%Y-%m-%d %H:%M:%S"
                    ),
                }
            )
            return None

    def get_all_table_items(self, table, filter_expression=None):
        "For a given table, get all items in the table"
        items = []
        if filter_expression is None:
            response = table.scan()
        else:
            response = table.scan(FilterExpression=filter_expression)

        items = response["Items"]

        while "LastEvaluatedKey" in response:
            try:
                if filter_expression is None:
                    response = table.scan(
                        ExclusiveStartKey=response["LastEvaluatedKey"]
                    )
                else:
                    response = table.scan(
                        ExclusiveStartKey=response["LastEvaluatedKey"],
                        FilterExpression=filter_expression,
                    )
                items += response["Items"]
            except Exception:
                time.sleep(10)
                response = table.scan(
                    ExclusiveStartKey=response["LastEvaluatedKey"],
                )
                items += response["Items"]

        return items

    def query_table(
        self,
        table,
        partition_key_name,
        partition_key_value,
        sort_key_name=None,
        sort_key_value=None,
        filter_expression=None,
        expression_values=None,
    ):
        """
        Query items from DynamoDB based on partition key, optionally sort key, and with an optional filter expression.
        Args:
            table: DynamoDB table object.
            partition_key_name: The name of the partition key.
            partition_key_value: The value of the partition key.
            sort_key_name: (Optional) The name of the sort key.
            sort_key_value: (Optional) The value of the sort key.
            filter_expression: (Optional) Additional filter expression.
            expression_values: (Optional) Expression attribute values for the filter expression.

        Returns:
            List of items matching the query.
        """
        items = []
        key_condition = f"{partition_key_name} = :partition_key"
        expression_values = expression_values or {}
        expression_values[":partition_key"] = partition_key_value

        # Add sort key condition if provided
        if sort_key_name and sort_key_value:
            key_condition += f" AND {sort_key_name} = :sort_key"
            expression_values[":sort_key"] = sort_key_value

        # Construct query parameters
        query_params = {
            "KeyConditionExpression": key_condition,
            "ExpressionAttributeValues": expression_values,
        }
        if filter_expression:
            query_params["FilterExpression"] = filter_expression

        # Execute the query
        response = table.query(**query_params)
        items.extend(response.get("Items", []))

        # Handle pagination
        while "LastEvaluatedKey" in response:
            # time.sleep(1)
            query_params["ExclusiveStartKey"] = response["LastEvaluatedKey"]
            response = table.query(**query_params)
            items.extend(response.get("Items", []))

        return items

    def query_index(
        self,
        table,
        index_name: str,
        sort_key_value: str,
        expression_values=None,
        sort_key="project_id",
    ) -> list:
        """
        Query items from DynamoDB based on index name, key condition, and optional expression values.
        """
        items = []
        response = table.query(
            IndexName=index_name,  # Name of your GSI
            KeyConditionExpression=f"{sort_key} = :{sort_key}",
            ExpressionAttributeValues={f":{sort_key}": sort_key_value},
        )
        items.extend(response.get("Items", []))
        while "LastEvaluatedKey" in response:
            response = table.query(
                IndexName=index_name,
                KeyConditionExpression=f"{sort_key} = :{sort_key}",
                ExpressionAttributeValues={f":{sort_key}": sort_key_value},
                ExclusiveStartKey=response["LastEvaluatedKey"],
            )
            items.extend(response.get("Items", []))

        return items

    def update_item(
        self, table, key, update_expression, expression_values, return_values=None
    ):
        "Update item in collection"
        return table.update_item(
            Key=key,
            UpdateExpression=update_expression,
            ExpressionAttributeValues=expression_values,
            ReturnValues=return_values,
        )

    def batch_get_items(self, table, keys, chunk_size=100):
        """
        Batch get items from DynamoDB table using a list of keys.

        Args:
            table: DynamoDB table object
            keys: List of key dictionaries
            chunk_size: Size of each batch request (max 100)

        Returns:
            List of items retrieved from DynamoDB
        """
        items = []

        for i in range(0, len(keys), chunk_size):
            chunk = keys[i : i + chunk_size]
            response = self.dynamodb.batch_get_item(
                RequestItems={table.name: {"Keys": chunk}}
            )

            if "Responses" in response and table.name in response["Responses"]:
                items.extend(response["Responses"][table.name])

        return items


if __name__ == "__main__":
    # Manual smoke test: query the web_pages table through its "project_id"
    # index. Needs real AWS access via DynamoConnector.

    # get_item = DynamoDB()
    # ci = get_item.get_item(get_item.company_info_table, "https://www.unity.com")
    # llmr = get_item.get_item(get_item.llm_table, "https://www.unity.com")

    # from services.company_profile.data_classes.company_info import CompanyInfo
    # from services.company_profile.data_classes.llm_results import LLMResults

    # new_ci = CompanyInfo(**ci)
    # new_llmr = LLMResults(**llmr)

    # # upload to dynamodb
    # get_item.upload_to_dynamodb(get_item.company_info_table, new_ci.model_dump())
    # get_item.upload_to_dynamodb(get_item.llm_table, new_llmr.model_dump())

    # # create or update company list
    # get_item.create_or_update_company_list("https://www.unity.com")

    db = DynamoDB()
    # db.web_pages is already a Table resource; wrapping it in
    # db.dynamodb.Table(...) passed the resource where a table name belongs.
    table = db.web_pages
    items = db.query_index(table, "project_id", "greenjunkremoval")
    print(items)
