"""Download images from URLs, rendering SVGs to PNG and re-encoding other formats onto a white background."""
import io
import logging
import re
from datetime import datetime

import defusedxml
import requests
from cairosvg import svg2png
from PIL import Image, UnidentifiedImageError

# HTTP headers sent with every requests.get image download in this module.
# A browser-like User-Agent is used, presumably so hosts that reject
# non-browser clients still serve the image (TODO confirm).
request_headers = dict(
    [
        (
            "User-Agent",
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0",
        ),
        ("Accept-Encoding", "*"),
        ("Connection", "keep-alive"),
    ]
)

from configs.config import STOCK_IMAGE_FILENAME


class DownloadImage:
    """
    Download images from URLs and save them to local files.

    SVG URLs are rendered to PNG with cairosvg. Every other URL is fetched with
    ``requests``, flattened onto an opaque white background and saved. Whenever
    a download or decode fails, the configured stock image is returned instead.
    """

    STOCK_IMAGE_FILENAME = STOCK_IMAGE_FILENAME

    def __init__(self):
        pass

    @classmethod
    def _stock_image(cls):
        """Return ``(image, filename)`` for the fallback stock image.

        The stock image is opened and immediately closed, so callers should
        rely on the returned filename rather than on the image's pixel data.
        """
        stock_image_filename = cls.STOCK_IMAGE_FILENAME
        img = Image.open(stock_image_filename)
        img.close()
        return img, stock_image_filename

    @staticmethod
    def _sanitize_file_name(file_name):
        """Make the stem of ``file_name`` filesystem-safe.

        The last 4 characters are treated as the extension (e.g. ``.png``) and
        are left untouched. In the remaining stem, whitespace becomes ``_``,
        every character outside ``[a-zA-Z0-9_]`` is dropped, and runs of
        underscores are collapsed into a single ``_``.
        """
        stem, extension = file_name[:-4], file_name[-4:]
        stem = re.sub(r"\s", "_", stem)
        stem = re.sub(r"[^a-zA-Z0-9_]", "", stem)
        # Collapse any run of underscores. A plain "__" -> "_" pass would
        # turn "___" into "__".
        stem = re.sub(r"_{2,}", "_", stem)
        return stem + extension

    @staticmethod
    def _flatten_and_save(content, file_name):
        """Decode raster image bytes, composite them over white and save as RGB.

        Input: content raw image bytes, file_name destination path (the
        extension picks the output format; ``quality=50`` applies to JPEG).
        Output: the flattened RGB image, already closed after saving.
        Raises: PIL.UnidentifiedImageError if ``content`` is not a decodable image.
        """
        with Image.open(io.BytesIO(content)) as source:
            img = source.convert("RGBA")

        # Opaque white backdrop, so transparent areas end up white once the
        # alpha channel is dropped by the RGB conversion.
        background = Image.new("RGBA", img.size, (255, 255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert("RGB")
        alpha_composite.save(file_name, quality=50)
        alpha_composite.close()
        img.close()
        return alpha_composite

    async def get_image(self, url: str, file_name: str, picture_background_dark=True):
        """
        Download the image at ``url`` and save it locally.

        Input: url of the image (or STOCK_IMAGE_FILENAME itself), file_name to
        save the image as (its last 4 characters are treated as the
        extension), picture_background_dark whether the picture background is
        dark (True/False; only ``False`` renders SVGs on black).
        Output: tuple ``(image, file_name)``:
          - SVG URLs: the svg2png result and the sanitized file_name (saved in
            the current directory).
          - other URLs: the closed flattened image and
            ``"images/<sanitized file_name>.jpg"``.
          - on any failure: the stock image and its filename.

        NOTE(review): requests.get and svg2png are synchronous calls, so this
        coroutine blocks the event loop while it downloads.
        """
        if url == self.STOCK_IMAGE_FILENAME:
            return self._stock_image()

        try:
            file_name = self._sanitize_file_name(file_name)

            if url.endswith(".svg"):
                return self.svg_img(url, file_name, picture_background_dark)

            try:
                response = requests.get(url, timeout=100, headers=request_headers)
            except requests.exceptions.ReadTimeout:
                logging.error(
                    "%s: Download Images: Read timeout for %s",
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    url,
                )
                # Return here: falling through would use an unbound `response`.
                return self._stock_image()
            except requests.exceptions.ConnectionError:
                logging.error(
                    "%s: Download Images: Connection error for %s",
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    url,
                )
                return self._stock_image()

            # Raster images are always re-encoded as JPEG. ".jpg" is appended to
            # the whole sanitized name, so "logo.png" becomes "images/logo.png.jpg".
            file_name = f"images/{file_name}.jpg"
            return self._flatten_and_save(response.content, file_name), file_name

        except UnidentifiedImageError:
            return self._stock_image()

        except requests.exceptions.MissingSchema:
            return self._stock_image()

        except Exception as e:
            logging.error(
                "%s: Download Images: Error downloading image %s",
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                e,
            )
            return self._stock_image()

    @classmethod
    def svg_img(cls, url, file_name, picture_background_dark=True):
        """Render the SVG at ``url`` to PNG and write it to ``file_name``.

        Input: url of the SVG, file_name output path, picture_background_dark.
        When it is exactly ``False`` the SVG is rendered on a black background;
        otherwise no background colour is set.
        Output: ``(svg2png return value, file_name)``. The return value is
        presumably ``None`` because ``write_to`` is given (confirm against
        the cairosvg docs). If the SVG is rejected by defusedxml
        (EntitiesForbidden), the stock image tuple is returned instead.
        Other errors propagate to the caller.
        """
        background_color = "black" if picture_background_dark is False else None
        try:
            img = svg2png(
                url=url,
                write_to=file_name,
                background_color=background_color,
                scale=1,
                dpi=96,
                # 0.65 inch tall at 96 px per inch.
                output_height=0.65 * 96,
            )
            return img, file_name

        except defusedxml.common.EntitiesForbidden:
            return cls._stock_image()

    def _get_image(self, url: str, file_name: str, picture_background_dark=True):
        """
        Synchronous variant of ``get_image`` with simpler name handling.

        Input: url, file_name to save the image as (spaces become ``_`` and the
        result is placed under ``images/``), picture_background_dark whether
        the picture background is dark (True/False).
        Output: tuple ``(image, file_name)``, or the stock image tuple on a
        read timeout, an undecodable image, a schemeless URL or a forbidden
        SVG entity. Unlike ``get_image``, connection errors and any other
        exceptions propagate to the caller.
        """
        try:
            file_name = "images/%s" % re.sub(" ", "_", file_name)

            if url.endswith(".svg"):
                return self.svg_img(url, file_name, picture_background_dark)

            try:
                response = requests.get(url, timeout=100, headers=request_headers)
            except requests.exceptions.ReadTimeout:
                logging.error(
                    "%s: Download Images: Read timeout for %s",
                    datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    url,
                )
                return self._stock_image()

            # Saved with the caller's own extension (no ".jpg" appended here).
            return self._flatten_and_save(response.content, file_name), file_name

        except UnidentifiedImageError:
            return self._stock_image()

        except requests.exceptions.MissingSchema:
            return self._stock_image()

    def resize_image(self, image_path, max_width=1, max_height=1):
        """Scale the image at ``image_path`` down to fit a box given in inches.

        Input: image_path of an existing image, max_width and max_height in
        inches (converted to pixels at 96 px per inch).
        Output: ``(resized image, image_path)``. The aspect ratio is kept,
        images already inside the box are not enlarged, and each side is at
        least 1 px. The file on disk is not modified; the resized image is
        returned in memory and left open so its pixel data stays usable.
        """
        max_width_px = max_width * 96
        max_height_px = max_height * 96

        with Image.open(image_path) as img:
            # Take the tighter of the two ratios so BOTH dimensions fit the box.
            # Capping at 1 prevents upscaling small images.
            rescale = min(1, max_width_px / img.width, max_height_px / img.height)
            new_size = (
                max(1, int(img.width * rescale)),
                max(1, int(img.height * rescale)),
            )
            resized = img.resize(new_size)

        return resized, image_path


if __name__ == "__main__":
    import asyncio

    download_image = DownloadImage()
    # get_image is a coroutine. Calling it without running it on an event loop
    # never executes the download, so scale_logo.png would never be written
    # and resize_image below would fail to open it.
    asyncio.run(
        download_image.get_image(
            "https://asset.brandfetch.io/idLdViRnHy/idXqSF1uBx.svg",
            "scale_logo.png",
            picture_background_dark=True,
        )
    )
    download_image.resize_image("scale_logo.png", max_width=1.23, max_height=0.65)
