"Public comparable data for stocks"

# This module uses the Alpha Vantage API to fetch public comparable data for stock tickers.

import sys
from typing import Literal

sys.path.append(r".")

import os

import numpy as np
import pandas as pd
import requests
from dotenv import load_dotenv

load_dotenv()


class PublicComparables:
    """Public comparable data for stock tickers.

    For each ticker passed to the constructor the instance fetches the
    Alpha Vantage company overview, income statement and balance sheet,
    derives a comparable-company table from them, and stores everything in
    ``self.data[ticker]`` under the keys ``"overview_data"``, ``"is"``,
    ``"bs"`` and ``"comp_table"``.
    """

    # Read once, when the class body executes, from the ``Alpha_Vantage``
    # environment variable; None when the variable is not set.
    alpha_vantage_api_key = os.getenv("Alpha_Vantage")

    # Endpoint shared by every query this class sends.
    _API_URL = "https://www.alphavantage.co/query"
    # Seconds to wait for the API; without a timeout a stalled connection
    # would block the caller indefinitely.
    _REQUEST_TIMEOUT = 30
    # order-of-magnitude name -> (divisor, unit suffix)
    _ORDERS_OF_MAG = {
        "billion": (1_000_000_000, "B"),
        "million": (1_000_000, "M"),
        "thousand": (1_000, "K"),
    }

    def __init__(self, public_comparables: list[str] = None) -> None:
        """Fetch and store the data for every ticker in *public_comparables*.

        Args:
            public_comparables: ticker symbols to load. ``None`` (the
                default) loads nothing and leaves ``self.data`` empty.

        Raises:
            ValueError: if the overview lookup reports a ticker as invalid.
        """
        self.data = {}
        for ticker in public_comparables or []:
            overview = self.get_ticker_data(ticker)
            if not overview:
                raise ValueError(f"Invalid ticker: {ticker}")
            entry = {"overview_data": overview}
            # Register the entry before building the comp table:
            # create_comp_table reads the balance sheet back from self.data.
            self.data[ticker] = entry
            entry["is"] = self.get_income_statement(ticker)
            entry["bs"] = self.get_balance_sheet(ticker)
            entry["comp_table"] = self.create_comp_table(
                overview, entry["is"], ticker
            )

    def _query(self, function: str, ticker: str):
        """Send one Alpha Vantage request and return the decoded JSON body.

        Args:
            function: value of the API's ``function`` query parameter.
            ticker: value of the API's ``symbol`` query parameter.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            RuntimeError: when the body holds nothing but an API notice.
        """
        response = requests.get(
            self._API_URL,
            # Passing params lets requests URL-encode the ticker instead of
            # splicing it into the query string verbatim.
            params={
                "function": function,
                "symbol": ticker,
                "apikey": self.alpha_vantage_api_key,
            },
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        payload = response.json()
        # NOTE(review): Alpha Vantage appears to signal throttling and API-key
        # problems with a body holding only a "Note" or "Information" message
        # - confirm against the current API. A body of that shape is never
        # usable data, so surface the message here rather than failing later
        # with a KeyError.
        if isinstance(payload, dict) and len(payload) == 1:
            notice = payload.get("Note") or payload.get("Information")
            if notice:
                raise RuntimeError(
                    f"Alpha Vantage request for {ticker} failed: {notice}"
                )
        return payload

    @staticmethod
    def _float_or_none(raw):
        "Parse a ratio field; None when the field holds a no-value placeholder"
        if raw is None or raw in ("None", "-", ""):
            return None
        return float(raw)

    @staticmethod
    def _int_or_zero(raw):
        "Parse a balance-sheet amount; a missing ('None') amount counts as 0"
        if raw is None or raw == "None":
            return 0
        return int(raw)

    def get_ticker_data(self, ticker: str):
        "Gets the ticker data if the ticker is valid, else returns False"
        overview = self._query("OVERVIEW", ticker)
        # An empty body or one carrying "Error Message" means the lookup
        # did not produce an overview for this symbol.
        if not overview or "Error Message" in overview:
            return False
        return overview

    def create_comp_table(self, d, income_statement, ticker):
        """Create the necessary items for the comparable table.

        Args:
            d: overview payload for the ticker (as returned by
                ``get_ticker_data``).
            income_statement: income-statement payload for the ticker.
            ticker: symbol; ``self.data[ticker]["bs"]`` must already hold the
                balance sheet, because the enterprise value is computed from it.

        Returns:
            dict with ``Name``, ``stock_price``, ``Mkt Cap`` and ``EV`` plus
            one sub-dict per metric keyed by ``"LTM"``, the current calendar
            year and the two following years (the latter are left ``None``).
            Ratio cells are ``None`` when the overview carries a placeholder
            (e.g. ``"None"``) instead of a number.
        """
        year = pd.Timestamp.now().year
        scale = self.normalize_orders_of_mag

        def by_period(ltm, current):
            # The two forward years are placeholders; nothing here fills them.
            return {"LTM": ltm, year: current, year + 1: None, year + 2: None}

        enterprise_value = self.calculate_enterprise_value(
            self.data[ticker]["bs"], d["MarketCapitalization"]
        )
        ltm_ebitda = self.calculate_ttm_data(
            income_statement["quarterlyReports"], "ebitda"
        )
        # The current-year revenue cell is filled from the first (index 0)
        # annual report in the payload.
        annual_revenue = income_statement["annualReports"][0]["totalRevenue"]

        return {
            "Name": d["Name"],
            "stock_price": self.get_stock_price(ticker),
            "Mkt Cap": scale(int(d["MarketCapitalization"]), "billion"),
            "EV": scale(enterprise_value, "billion"),
            "Revenue": by_period(
                scale(int(d["RevenueTTM"]), "million"),
                scale(int(annual_revenue), "million"),
            ),
            "EBITDA": by_period(
                scale(int(ltm_ebitda), "million"),
                scale(int(d["EBITDA"]), "million"),
            ),
            "EV/Rev": by_period(None, self._float_or_none(d["EVToRevenue"])),
            "EV/EBITDA": by_period(None, self._float_or_none(d["EVToEBITDA"])),
            "P/E": by_period(None, self._float_or_none(d["PERatio"])),
        }

    def get_stock_price(self, ticker):
        """Get the stock price for the ticker.

        Returns the ``"4. close"`` value of the first entry in the daily time
        series, exactly as the API delivers it (not converted to a number).
        """
        series = self._query("TIME_SERIES_DAILY", ticker)["Time Series (Daily)"]
        # NOTE(review): takes the first entry, presumably the most recent
        # trading day - confirm the ordering of the API response.
        first_bar = next(iter(series.values()))
        return first_bar["4. close"]

    def normalize_orders_of_mag(
        self, value, order_of_mag: Literal["billion", "million", "thousand"]
    ):
        """Normalize the orders of magnitude.

        Args:
            value: number to scale down.
            order_of_mag: target magnitude.

        Returns:
            dict with the normalized ``"value"`` as a float and its ``"unit"``
            suffix (``"B"``, ``"M"`` or ``"K"``).

        Raises:
            ValueError: if *order_of_mag* is not one of the supported names.
        """
        try:
            divisor, unit = self._ORDERS_OF_MAG[order_of_mag]
        except KeyError:
            # Fail loudly instead of silently handing back None.
            raise ValueError(
                f"Unsupported order of magnitude: {order_of_mag!r}"
            ) from None
        return {"value": value / divisor, "unit": unit}

    def get_income_statement(self, ticker):
        "get the income statement for the ticker"
        return self._query("INCOME_STATEMENT", ticker)

    def get_balance_sheet(self, ticker):
        "get the balance sheet for the ticker"
        return self._query("BALANCE_SHEET", ticker)

    def calculate_enterprise_value(self, data, mkt_cap):
        """Calculate the enterprise value: market cap + debt - cash.

        Args:
            data: balance-sheet payload with ``"annualReports"`` and
                ``"quarterlyReports"`` lists; index 0 of each is compared.
            mkt_cap: market capitalization (anything ``int()`` accepts).

        Returns:
            The enterprise value as an int. Balance-sheet amounts reported as
            ``"None"`` are counted as zero.
        """
        annual = data["annualReports"][0]
        quarterly = data["quarterlyReports"][0]
        # Use whichever report is more recent. The dates are compared as
        # strings, which matches chronological order for ISO-style dates;
        # on a tie the quarterly report wins.
        if annual["fiscalDateEnding"] > quarterly["fiscalDateEnding"]:
            bs_data = annual
        else:
            bs_data = quarterly

        amount = self._int_or_zero
        # NOTE(review): if shortLongTermDebtTotal already includes long-term
        # debt, adding longTermDebt counts it twice - confirm against the
        # API's field definitions before relying on this figure.
        total_debt = amount(bs_data["shortLongTermDebtTotal"]) + amount(
            bs_data["longTermDebt"]
        )
        total_cash = amount(bs_data["cashAndShortTermInvestments"]) + amount(
            bs_data["longTermInvestments"]
        )
        return int(mkt_cap) + total_debt - total_cash

    def calculate_ttm_data(self, data, key):
        """Calculate the trailing twelve months data.

        Sums ``key`` over the first four entries of *data* (a list of
        quarterly reports). Raises ``IndexError`` if fewer than four entries
        are available.
        """
        return sum(float(data[quarter][key]) for quarter in range(4))

if __name__ == "__main__":
    import json

    # Demo run: load a handful of tickers and dump each comp table as JSON.
    comps = PublicComparables(["AAPL", "GOOGL", "MSFT", "TSLA"])

    for symbol, payload in comps.data.items():
        print(symbol)
        print(json.dumps(payload["comp_table"], indent=2))

    # To validate a single ticker instead:
    # print(PublicComparables().get_ticker_data("aapl"))
