"Graph research agent"

import sys

sys.path.append(r".")

import os
from typing import Optional

from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI

from services.ppt_generator.data_classes.graph_classes import (Bar, Line, Pie,
                                                               Scatter)
from utils.researcher.researcher import Researcher

# Shared GPT-3.5 client used by every chain in this module.
# temperature=0 keeps the formatting/classification answers as deterministic
# as the API allows.
llm35 = ChatOpenAI(
    api_key=os.getenv("OPENAI_API_KEY"), model="gpt-3.5-turbo", temperature=0
)


def research_data(parent_question: Optional[str] = None):
    """Run the researcher for *parent_question* and return its raw answer.

    Args:
        parent_question: The top-level question to research. Defaults to
            ``None`` (``Researcher`` is assumed to accept that — TODO confirm).

    Returns:
        Whatever ``Researcher.researcher()`` produces, unmodified.
        NOTE(review): the commented-out caller below treats this as an
        iterable of dicts with a ``'summary'`` key — verify against Researcher.
    """
    researcher = Researcher(parent_question=parent_question)
    answer_ = researcher.researcher()

    # Debug output so the raw research result is visible when run as a script.
    print(answer_)

    return answer_


def format_for_chart(research_context, graph_type):
    """Reformat free-text research into the structured schema of *graph_type*.

    Args:
        research_context: The research text to be converted into chart data.
        graph_type: A pydantic model (e.g. ``Line``, ``Bar``, ``Pie``,
            ``Scatter``) describing the target data structure.

    Returns:
        The parsed JSON output of the LLM chain, shaped per *graph_type*.
    """
    mathematician_role = """You are a mathematician who has been asked to format text into a structured data format for a chart
                You are skilled at logical reasoning. If given a few data points you are able to interpolate and extrapolate the data to fill in the gaps."""

    messages = [
        ("system", mathematician_role),
        (
            "system",
            "Your task is to re-format the text into the data structure requested",
        ),
        (
            "human",
            "The text you should reference is ** \n {research_context} \n **",
        ),
        ("system", "The data is formatted as follows:{graph_format}"),
    ]

    # The parser both validates the LLM output and supplies the format
    # instructions that are injected into the prompt below.
    parser = JsonOutputParser(pydantic_object=graph_type)
    chain = ChatPromptTemplate.from_messages(messages) | llm35 | parser

    return chain.invoke(
        {
            "research_context": research_context,
            "graph_format": parser.get_format_instructions(),
        }
    )


def get_graph_type(question: str) -> str:
    """Ask the LLM which chart type best fits *question*.

    Args:
        question: The user question the chart should answer.

    Returns:
        The raw message content from the model. The prompt requests a
        JSON-like answer of the form ``'graph_type': '', 'x_axis': ''`` —
        note this is a string, not a parsed dict, so callers must not
        compare it directly against bare names like ``"Line"``.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "What is the best graph type for this question?"),
            ("human", "The question: ** \n {question} \n **"),
            (
                "human",
                "Your response should be in JSON as the following example: 'graph_type': '', 'x_axis': ''",
            ),
            ("system", "Your graph type choices are: 'Line', 'Bar', 'Pie', 'Scatter'"),
            (
                "system",
                # Fixed duplicated word ("the the") in the instruction.
                "Only return the graph type and what the x-axis should be: 'Category', 'Date', 'Number'",
            ),
        ]
    )

    chain = prompt | llm35

    chain_output = chain.invoke({"question": question})

    return chain_output.content


if __name__ == "__main__":

    # Single question string; the second literal is appended exactly as the
    # original `+=` did.
    QUESTION = (
        "What is global market size for a electric vehicles?"
        " If it is a timeline graph, provide at minimum 3 historical and 5 future data points"
    )

    graph_type = get_graph_type(QUESTION)
    print(graph_type)

    # NOTE(review): the dispatch sketched below compares `graph_type` against
    # bare names, but get_graph_type returns a JSON-ish *string* — confirm the
    # comparison before re-enabling.
    # answer = research_data(parent_question=QUESTION)
    # chart_cls = {"Line": Line, "Bar": Bar, "Pie": Pie, "Scatter": Scatter}.get(graph_type)
    # if chart_cls is not None:
    #     formatted_data = format_for_chart(
    #         "\n".join([a['summary'] for a in answer]), chart_cls
    #     )
    #     print(formatted_data)
