"Get stock data from the OpenBB library"

import logging
from datetime import datetime, timedelta

from openbb import obb
from openbb_core.app.model.abstract.error import OpenBBError

# Data providers this module accepts; OpenBBStockData.__init__ raises ValueError
# for any provider not listed here.
PROVIDERS = ["yfinance"]


class OpenBBStockData:
    """Fetch historical closing prices for a ticker through the OpenBB library.

    Attributes:
        provider_list: the supported providers (the module-level PROVIDERS list).
        provider: OpenBB data provider passed to ``obb.equity.price.historical``.
        interval: bar interval passed to ``obb.equity.price.historical``.
        start_date: default start date ("YYYY-MM-DD") used by ``get_stock_data``
            when a call supplies neither ``start_date`` nor ``duration``.
    """

    def __init__(self, interval="1d", start_date=None, provider="yfinance"):
        """
        Initialize the class with the interval, default start date and provider.

        interval: str: The interval of the stock data (default: 1d, options: 1m, 1h, 1d, 1W, 1M)
            minute, hour, day, week, month
        start_date: str: Default start date of the stock data, "YYYY-MM-DD"
            (default: 5 years before construction, see _get_start_date)
        provider: str: The provider of the stock data (default: yfinance)

        Raises ValueError if ``provider`` is not in PROVIDERS.
        """

        self.provider_list = PROVIDERS

        if provider not in self.provider_list:
            raise ValueError(
                f"Invalid provider: {provider}, options: {', '.join(self.provider_list)}"
            )
        self.provider = provider

        self.interval = interval

        # Empty/None start date -> default 5-year window computed at construction.
        if not start_date:
            self.start_date = self._get_start_date(duration=5)
        else:
            self.start_date = start_date

    def get_stock_data(self, ticker=None, start_date=None, duration=None):
        """
        Get closing prices for ``ticker`` from the OpenBB library.

        ticker: str: The ticker of the stock (required).
        start_date: str: Start date ("YYYY-MM-DD") for this call only; takes
            precedence over ``duration``.
        duration: int: Years of history to fetch when ``start_date`` is not given.
            When omitted, the instance's ``start_date`` is used (5 years before
            construction unless a start date was passed to the constructor).

        Returns a dict with "date" (list of "YYYY-MM-DD" strings built from the
        result index) and "close_price" (list of values from the "close" column).

        Raises ValueError when the ticker is missing, or when the fetch raises
        ValueError or OpenBBError (the original error is chained as the cause).
        Instance state is not modified.
        """
        try:
            if not ticker:
                raise ValueError("Ticker is required to fetch stock data")

            # Resolve this call's window without touching self.start_date:
            # explicit start_date > explicit duration > instance default.
            if not start_date:
                if duration is None:
                    start_date = self.start_date
                else:
                    start_date = self._get_start_date(duration)

            stock_data = obb.equity.price.historical(
                symbol=ticker,
                start_date=start_date,
                interval=self.interval,
                provider=self.provider,
            ).to_df()

            dates = [date.strftime("%Y-%m-%d") for date in stock_data.index]
            close_prices = stock_data["close"].tolist()

            return {"date": dates, "close_price": close_prices}

        except ValueError as e:
            raise self._fetch_error(ticker, e) from e

        except OpenBBError as e:
            raise self._fetch_error(ticker, e) from e

    def _fetch_error(self, ticker, error):
        """Log a failed fetch for ``ticker`` and return the ValueError to raise in its place."""
        logging.error(
            "%s: Error fetching stock prices for %s from %s. \nError: %s",
            datetime.now().strftime("%Y/%m/%d %H:%M:%S"),
            ticker,
            self.provider,
            error,
        )
        return ValueError(
            f"Error fetching {ticker} stock prices from {self.provider}: {error}"
        )

    @staticmethod
    def _get_start_date(duration: int = 5):
        """
        Return the start date ``duration`` years before today as "YYYY-MM-DD".

        A year is approximated as 365 days, and one extra day is subtracted.
        If the result falls on a Saturday or Sunday, it is moved back to the
        preceding Friday.

        duration: int: The duration of the stock data in years (default: 5 years)
        """
        today = datetime.now()
        start_date = today - timedelta(days=365 * duration + 1)

        if start_date.weekday() in [5, 6]:  # Saturday or Sunday
            start_date_offset = start_date.weekday() - 4  # will be 1 or 2 -> Friday
            start_date -= timedelta(days=start_date_offset)

        start_date = start_date.strftime("%Y-%m-%d")

        return start_date


if __name__ == "__main__":
    # Manual smoke run: fetch one year of closing prices for ticker "U" via yfinance.
    demo_client = OpenBBStockData(provider="yfinance")
    one_year_prices = demo_client.get_stock_data(duration=1, ticker="U")
    print(one_year_prices)
