import json
import shutil
import sys
from collections import Counter
from typing import Dict, List

import requests

sys.path.append(".")

import os

from cairosvg import svg2png
from PIL import Image
from rembg import remove

from configs.config import pic_file_extension
from services.company_profile.data_classes.company_info import Logo

# Brandfetch API key, read once at import time.
# NOTE: raises KeyError immediately if BRAND_FETCH_API_KEY is not set.
BRAND_FETCH_API = os.environ["BRAND_FETCH_API_KEY"]


def get_brand_from_url(url):
    """
    Fetch brand information for a website via the Brandfetch API and
    pick its best logo.

    Args:
        url (str): Company website URL (e.g. "https://example.com").

    Returns:
        Logo | None | bool: The logo selected by ``find_logo_from_logos``;
        ``False`` when the API returned an empty logo list; ``None`` when
        the request failed or returned a non-200 status.
    """

    request_header = {
        "accept": "application/json",
        "Authorization": "Bearer " + BRAND_FETCH_API,
    }

    # Brandfetch expects a bare domain; split("//")[-1] strips the scheme
    # and, unlike [1], does not raise IndexError for scheme-less input.
    request_url = "https://api.brandfetch.io/v2/brands/" + url.split("//")[-1]
    try:
        response = requests.get(request_url, headers=request_header, timeout=20)
        if response.status_code != 200:
            return None
    except requests.exceptions.RequestException:
        # Covers ReadTimeout (the only case handled before) as well as
        # connection errors and DNS failures.
        return None

    data = response.json()

    # Some brands carry no "logos" key at all; treat that as "not found".
    return find_logo_from_logos(data["logos"]) if "logos" in data else None


def find_logo_from_logos(logos):
    """
    Pick the best logo from a Brandfetch "logos" array.

    Each entry carries a "theme" ("dark"/"light"), a "type" (e.g. "logo",
    "icon") and a list of "formats" ({"src", "format", ...}).  Only
    entries with type "logo" are considered.  Candidates are ranked by a
    score: the extension preference from ``pic_file_extension`` plus a
    +10 penalty for non-dark themes; the lowest score wins.

    Args:
        logos (list[dict]): Logo entries as returned by Brandfetch.

    Returns:
        Logo | bool: The selected logo, or ``False`` when ``logos`` is
        empty.  If no entry has type "logo", a Logo with an empty URL is
        returned (pre-existing behavior, kept for compatibility).
    """
    if not logos:
        print("NO LOGO FOUND")
        return False

    final_logo_url = ""
    final_logo_extension = 0
    final_logo_dark = False
    final_logo_format = ""

    # Walk every format of every logo entry and keep the lowest-scoring one.
    for logo in logos:
        is_dark = logo["theme"] == "dark"
        for fmt in logo["formats"]:
            # BUGFIX: the old expression
            #   pic_file_extension[...] + 0 if dark else 10
            # parsed as "(rank + 0) if dark else 10", discarding the
            # format rank entirely for light-themed logos.  The intended
            # score is the rank plus a light-theme penalty.
            score = pic_file_extension[fmt["format"]] + (0 if is_dark else 10)

            if (
                len(final_logo_url) == 0 or score < final_logo_extension
            ) and logo["type"] == "logo":
                final_logo_url = fmt["src"]
                final_logo_extension = score
                final_logo_dark = is_dark
                final_logo_format = fmt["format"]

    return Logo(
        logo_url=final_logo_url,
        logo_extension=final_logo_extension,
        logo_dark=final_logo_dark,
        logo_format=final_logo_format,
    )


# save a logo to a file
def save_logo(logo, filename):
    """
    Download ``logo.logo_url`` and save it to ``filename``.

    For raster formats (png/jpg/jpeg) the background is then removed and
    the result saved alongside the file (see ``remove_background``).
    Other formats (e.g. SVG) are saved as-is — PIL cannot open them.

    Args:
        logo (Logo): Logo carrying a ``logo_url`` attribute.
        filename (str): Destination path.

    Raises:
        requests.HTTPError: If the download returns an error status
            (previously an error page was silently written to the file).
    """
    response = requests.get(logo.logo_url, timeout=10)
    response.raise_for_status()
    with open(filename, "wb") as f:
        f.write(response.content)

    # Background removal only works on raster images PIL can open.
    if logo.logo_url.endswith((".png", ".jpg", ".jpeg")):
        remove_background(filename, save_image=True)


def load_json_data(file_path: str) -> List[Dict]:
    """Read and parse a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed content (expected to be a list of dicts).
    """
    # Explicit encoding avoids locale-dependent decoding failures.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json_data(data: List[Dict], file_path: str):
    """Write ``data`` to ``file_path`` as pretty-printed (indent=4) JSON.

    Args:
        data: JSON-serializable content (expected list of dicts).
        file_path: Destination path; overwritten if it already exists.
    """
    # Explicit encoding keeps output byte-stable across platforms.
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def remove_background(
    image_path: str, output_path: str = None, save_image: bool = False
):
    """Remove the border-colored background from an image and trim it.

    The image is converted to RGBA, its dominant border color is made
    transparent (``remove_background_color``), and fully-dark/transparent
    margins are cropped away.

    Args:
        image_path: Path of the input image (any format PIL can open).
        output_path: Where to save the result when ``save_image`` is True;
            defaults to ``<image_path without extension>_no_bg.png``.
        save_image: If True, save the processed image to ``output_path``.

    Returns:
        PIL.Image.Image: The processed RGBA image.
    """
    image = Image.open(image_path).convert("RGBA")

    image = remove_background_color(image)

    # getbbox() on the grayscale version yields the bounding box of
    # non-zero pixels; crop(None) keeps the entire image, so a fully
    # transparent result passes through unchanged.
    bbox = image.convert("L").getbbox()
    image = image.crop(bbox)

    if save_image:
        if output_path is None:
            # BUGFIX: was '"".join(image_path.split(".")[:-1])', which
            # mangled multi-dot names ("a.b.png" -> "ab_no_bg.png").
            output_path = os.path.splitext(image_path)[0] + "_no_bg.png"
        image.save(output_path)

    return image


def remove_background_color(image: Image.Image):
    """Make the dominant border color of ``image`` transparent.

    Samples every pixel along the four edges, takes the most common color
    among them as the background color, and returns a new RGBA image in
    which every pixel within tolerance of that color is fully transparent.

    Args:
        image: Source image; pixels are expected as RGBA tuples
            (callers convert with ``.convert("RGBA")`` first).

    Returns:
        PIL.Image.Image: A new RGBA image with the background removed.
    """
    width, height = image.size

    # Sample the four edges. Corners are counted twice, which only
    # slightly biases the Counter and does not change the usual winner.
    border_pixels = (
        [image.getpixel((x, 0)) for x in range(width)]          # top
        + [image.getpixel((x, height - 1)) for x in range(width)]  # bottom
        + [image.getpixel((0, y)) for y in range(height)]          # left
        + [image.getpixel((width - 1, y)) for y in range(height)]  # right
    )

    # The background is assumed to be the most common border color.
    most_common_color = Counter(border_pixels).most_common(1)[0][0]

    new_image = Image.new("RGBA", (width, height))

    for y in range(height):
        for x in range(width):
            current_color = image.getpixel((x, y))
            # Pixels close to the background color become transparent.
            if is_color_within_tolerance(
                current_color, most_common_color, tolerance=0.1
            ):
                new_image.putpixel((x, y), (0, 0, 0, 0))  # transparent
            else:
                new_image.putpixel((x, y), current_color)  # keep original

    return new_image


def is_color_within_tolerance(color, target_color, tolerance=0.05):
    """Compare two colors channel-by-channel.

    Args:
        color: First color, an iterable of channel values (0-255).
        target_color: Second color, compared channel-wise with ``color``.
            Extra channels on either side are ignored.
        tolerance: Allowed per-channel difference as a fraction of 255.

    Returns:
        bool: True if every compared channel differs by at most
        ``tolerance * 255``.
    """
    max_delta = tolerance * 255
    for channel, target in zip(color, target_color):
        if abs(channel - target) > max_delta:
            return False
    return True


# find the domain from the comapny name and industry
def get_url_from_name(company_name: str, industry: str) -> str:
    """Resolve a company's website URL from its name and industry.

    Google-searches "<company_name> <industry> logo" and returns the
    result URL that occurs most often.

    Args:
        company_name: Company name, or an http(s) URL (returned as-is).
        industry: Industry keywords used to disambiguate the search.

    Returns:
        str | bool: The most frequent result URL, or ``False`` when the
        search yields nothing (kept for backward compatibility even
        though the annotation says ``str``).
    """
    from collections import Counter

    # BUGFIX: url_parser was previously imported only under the
    # __main__ guard, so calling this function from another module
    # raised NameError at the Counter line below.
    from utils import url_parser
    from utils.search_google import GoogleSearch

    # Already a URL — nothing to resolve.
    if company_name.startswith(("http://", "https://")):
        return company_name

    search_google = GoogleSearch()
    results = search_google._search_google(f"{company_name} {industry} logo", "search")
    if not results:
        return False

    # NOTE(review): this counts full parsed URLs, not bare domains —
    # confirm url_parser.parsed_url(...).url collapses to the domain.
    url_count = Counter(
        url_parser.parsed_url(result["link"]).url for result in results
    )
    if not url_count:
        return False
    return max(url_count, key=url_count.get)


if __name__ == "__main__":

    from utils import url_parser

    # Sectors to process; each maps to a list of company names whose
    # logos should be downloaded into a per-sector folder.
    companies = {
        "Law firms": [
            "Mintz",
            "Morgan Lewis",
            "Cooley",
            "Kirkland & Ellis",
            "Jones Day",
            "Covington & Burling",
            "Williams and Connely",
            "Hogan Lovells",
        ],
    }

    location = r"C:\Users\sagar\X Cap Market\X Cap Market - Customers\Kenyon\projects\Barrister Digital\logos"

    for sector, companies_list in companies.items():
        industry = f"AI {sector}"
        # BUGFIX: was location + f"\{sector}" — "\{" is an invalid escape
        # sequence (SyntaxWarning on modern Python); os.path.join is the
        # portable way to build the path.
        sector_location = os.path.join(location, sector)
        os.makedirs(sector_location, exist_ok=True)
        os.chdir(sector_location)
        for company in companies_list:
            try:
                url = get_url_from_name(company, industry)
                url = url_parser.parsed_url(url)

                logo = get_brand_from_url(url.url)
                # get_brand_from_url can return None (request failed) or
                # False (no logos); skip instead of crashing on attribute
                # access below.
                if not logo:
                    print("No logo found for:", company)
                    continue

                destination = os.path.join(
                    sector_location, url.domain + "." + logo.logo_format
                )
                save_logo(logo, destination)
            except Exception as e:
                # BUGFIX: printing url.url here could itself raise when
                # the failure happened before/at the url assignment.
                print("Error with:", company, "-", e)
                continue

        os.chdir(location)
