# built-in sitemap parser
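"""Scrape a website's sitemap(s), crawl the discovered pages, and optionally
summarize each page with an LLM.

Sitemap locations are taken from robots.txt (plus the conventional
/sitemap.xml), nested sitemap indexes are followed, and the resulting page
urls are crawled asynchronously via utils.webscrape.crawl_pages.
"""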

import asyncio
import logging
import re
import sys
import time
from urllib.parse import urlparse
from urllib.robotparser import RobotFileParser

import urllib3

sys.path.append(".")  # allow running this file directly from the project root

# from utils.logger import ServiceLogger

import requests
from anytree import Node, PreOrderIter
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_text_splitters import CharacterTextSplitter
from parsel import Selector

from configs.config import (
    OPENAI_API_KEY,
    OPENAI_MODEL_35,
    OPENAI_MODEL_MINI,
    OPENAI_TOKEN_LIMIT,
)
from utils.selenium_driver_chrome import SeleniumDriverChrome
from utils.url_parser import parsed_url
from utils.webscrape.crawl_pages import crawl_pages

XCM_logger = logging.getLogger()


class SitemapScrape:
    "Discover a site's sitemaps, scrape the listed pages, and optionally summarize them"

    # locations to check for a robots.txt file, relative to the domain root
    robots_locations = ["/robots.txt"]

    # only scrape urls whose path has at most this many segments
    max_depth_to_scrape = 3

    def __init__(
        self,
        url: str,
        scrape_sitemap: bool = True,
        url_scrape: bool = True,
        summarize: bool = True,
        exclude_blogs: bool = False,
    ):
        ## set the configs
        self.sitemap_scrape = scrape_sitemap
        self.url_scrape = url_scrape
        self.summarize = summarize
        self.exclude_blogs = exclude_blogs

        # page text seen so far, shared across pages to drop repeated boilerplate
        self.duplicate_page_text_dict = {}

        self.url = url
        self.domain_sitemap = self.sitemaps_from_robots(parsed_url(url).url)
        self.consolidate_sitemaps = self.remove_duplicate_sitemaps_urls(
            self.domain_sitemap
        )

        self.site_urls = []
        if self.sitemap_scrape:
            self.site_urls = list(set(self.scrape_sitemap()))
            # self.convert_to_tree()
        if self.url_scrape:
            self.urls_scraped = asyncio.run(self.scrape_urls(self.site_urls))
            self.image_urls = self.consolidate_images()

    def sitemaps_from_robots(self, url: str) -> list[str]:
        "Get the sitemap for a url"
        sitemap_urls = [url + "/sitemap.xml"]
        for robots_location in self.robots_locations:
            robot_url = url + robots_location
            rp = RobotFileParser(robot_url)
            r = self.get_url(robot_url)
            if r.status_code in [404, 403]:
                continue

            lines = r.text.split("\n")
            rp.parse(lines)
            sitemaps = rp.site_maps()
            if sitemaps:
                sitemap_urls.extend(sitemaps)

        return list(set(sitemap_urls))

    def get_url(self, url: str) -> requests.Response:
        "Get the url"
        try:
            r = requests.get(
                url,
                headers={"User-Agent": "Mozilla/5.0"},
                timeout=10,
                allow_redirects=True,
            )
        except requests.exceptions.ConnectionError as e:
            XCM_logger.error("ConnectionError with %s: %s", url, e, exc_info=True)
            time.sleep(1)
            r = requests.get(
                url.replace("//www.", "//"),  # remove www
                headers={"User-Agent": "Mozilla/5.0"},
                timeout=10,
                allow_redirects=True,
                # verify=False,
            )
        except urllib3.exceptions.MaxRetryError as e:
            XCM_logger.error("MaxRetryError with %s: %s", url, e, exc_info=True)
            time.sleep(0.5)
            raise e

        return r

    def remove_duplicate_sitemaps_urls(self, domain_sitemap: list[str]) -> set[str]:
        "Remove duplicate sitemap urls"
        consolidate_sitemaps = set()
        for sitemap in domain_sitemap:
            r = self.get_url(sitemap)
            # store the final url after any redirects so duplicates collapse to one entry
            consolidate_sitemaps.add(r.url)

        return consolidate_sitemaps

    def scrape_sitemap(self):
        "Scrape every sitemap, following nested sitemap indexes, and collect page urls"
        XCM_logger.info("Scrape sitemap for %s", self.url)

        urls = []
        seen_sitemaps = set()

        while self.consolidate_sitemaps:
            time.sleep(0.5)
            sitemap_url = self.consolidate_sitemaps.pop()
            # guard against sitemap indexes that reference each other
            if sitemap_url in seen_sitemaps:
                continue
            seen_sitemaps.add(sitemap_url)

            response = self.get_url(sitemap_url)
            response_text = self._clean_cdata(response.text)
            # parsel parses this as HTML by default, so the sitemap XML namespace
            # does not need to be registered for these xpaths to match
            selector = Selector(response_text)
            ## find nested sitemaps and queue them for scraping
            for loc in selector.xpath("//sitemap/loc/text()").getall():
                self.consolidate_sitemaps.add(self.clean_loc_text(loc))
            ## find page urls and collect them
            for loc in selector.xpath("//url/loc/text()").getall():
                urls.append(self.clean_loc_text(loc))

        XCM_logger.info("Scraped %s urls from sitemap for %s", len(urls), self.url)
        return urls

    def _clean_cdata(self, text: str) -> str:
        "Unwrap CDATA sections so the xpath selectors can see the urls inside them"
        return re.sub(r"<!\[CDATA\[(.*?)\]\]>", r"\1", text)

    def clean_loc_text(self, loc_text: str) -> str:
        "Clean the <loc> url text"
        # strip whitespace, drop a trailing slash, and keep only the first
        # whitespace-separated token in case extra content follows the url
        loc_text = loc_text.strip().rstrip("/").split(" ")[0]
        return loc_text.strip()

    async def scrape_urls(self, urls: list):
        """Scrape the URLs asynchronously"""

        # Filter URLs: respect the max path depth and optionally skip blog pages
        urls_to_scrape = [
            url
            for url in urls
            if len(urlparse(url).path.split("/")) <= self.max_depth_to_scrape
            and not ("blog" in url.lower() and self.exclude_blogs)
        ]

        # Async fetch pages; crawl_pages returns {url: {"markdown": ..., "images": [...]}}
        urls_scraped = await crawl_pages(urls_to_scrape, with_images=True)

        # If summarization is enabled, summarize all pages concurrently,
        # remembering which URL each task belongs to so results line up
        summary_urls = []
        tasks = []
        if self.summarize:
            for url, page_data in urls_scraped.items():
                page_text = page_data.get("markdown", "")
                if page_text:
                    summary_urls.append(url)
                    tasks.append(self.summarize_website(url, page_text))

        summaries = dict(zip(summary_urls, await asyncio.gather(*tasks)))

        # Reshape the crawl results and attach the summaries
        for url, page_data in urls_scraped.items():
            urls_scraped[url] = {
                "text": page_data.get("markdown", ""),
                "images": page_data.get("images", []),
                "summary": summaries.get(url, ""),
            }

        return urls_scraped

    def remove_duplicate_text(self, url_node, texts: list[str]) -> list[str]:
        "Remove text blocks that have already been seen on a previously processed page"
        unique_page_text = []
        for text in texts:
            if text not in self.duplicate_page_text_dict:
                self.duplicate_page_text_dict[text] = True
                unique_page_text.append(text)

        return unique_page_text

    async def summarize_website_text_chunk(self, url: str, page_text_list: list[str]):
        "Summarize the website text chunk"
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a research analyst with 20 years of experience and you have been asked to summarize website pages. \
                You will first understand the website text in detail and then summarize the relevant information. \
                Only return the summary text of the website text. Make your summary verbose but factual. Extra information is better than truncating information that could be useful.\
                You may be provided a prior summary of the website if the page had too much text; in that case, only append to that summary.""",
                ),
                ("system" "Bring the relevant information from the text"),
                ("human", "The URL you are summarizing is: {url}"),
                ("human", "The website text is: {page_text}"),
            ]
        )

        llm = ChatOpenAI(
            model_name=OPENAI_MODEL_MINI,
            api_key=OPENAI_API_KEY,
            temperature=0.1,
            max_tokens=500,
        )

        chain = prompt | llm

        output_results = await chain.abatch(
            [{"url": url, "page_text": text} for text in page_text_list]
        )

        output = "\n".join([result.content for result in output_results])
        return output

    async def summarize_website(self, url: str, page_text: str):
        "Summarize a page, splitting long text into chunks that fit the model context"

        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            encoding_name="cl100k_base",
            chunk_size=int(OPENAI_TOKEN_LIMIT / 2),
            chunk_overlap=int(0.05 * OPENAI_TOKEN_LIMIT),
        )
        texts = text_splitter.split_text(page_text)

        # all chunks are summarized in a single batched LLM call
        summaries = await self.summarize_website_text_chunk(url, texts)

        return summaries

    def consolidate_images(self):
        "Consolidate the images found across all scraped pages, counting repeats"
        XCM_logger.info("Consolidating images for %s", self.url)

        image_urls = {}
        for site in self.urls_scraped:
            for image in self.urls_scraped[site]["images"]:
                image_url = image.get("src")
                if not image_url:
                    continue
                if image_url not in image_urls:
                    image_urls[image_url] = {"url": image_url, "count": 1}
                else:
                    image_urls[image_url]["count"] += 1

        # sort the images by how often they appear, most frequent first
        image_urls = dict(
            sorted(image_urls.items(), key=lambda item: item[1]["count"], reverse=True)
        )

        return image_urls

    # def convert_to_tree(self):
    #     "Convert the sitemap to a tree"
    #     # sort the urls by the length
    #     self.site_urls_sorted = sorted(self.site_urls, key=lambda x: len(x))

    #     root_node = None

    #     for i, url in enumerate(self.site_urls_sorted):
    #         new_node = Node(url)
    #         new_node.url = url
    #         new_node.is_blog = True if "blog" in url.lower() else False

    #         if i == 0:
    #             root_node = new_node
    #         else:
    #             for j in range(i):
    #                 parent_node_set = False
    #                 if url.startswith(self.site_urls_sorted[j]):
    #                     new_node.parent = prior_node
    #         self.site_tree = new_node.root


if __name__ == "__main__":

    URL = "https://munichmotorsport.com/"

    sitemap_scrape = SitemapScrape(
        URL, scrape_sitemap=True, url_scrape=True, summarize=False, exclude_blogs=True
    )
    print(sitemap_scrape.site_urls)
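    # Other results are available on the instance after construction, e.g.:
    #   sitemap_scrape.urls_scraped  # {url: {"text": ..., "images": [...], "summary": ...}}
    #   sitemap_scrape.image_urls    # {src: {"url": src, "count": n}}, most frequent first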
