"""Client for calling the Perplexity AI chat-completions API."""

import os
import re

import requests
from dotenv import load_dotenv

load_dotenv()


class PerplexityResearch:
    """Thin client for the Perplexity chat-completions endpoint.

    The API key is read from the ``PERPLEXITY_KEY`` environment variable once,
    when the class body is evaluated; it is ``None`` if the variable is unset.
    """

    api_key = os.getenv("PERPLEXITY_KEY")
    url = "https://api.perplexity.ai/chat/completions"
    # Matches inline citation markers such as "[1]" or "[12]" in the answer text.
    citation_pattern = r"\[(\d+)\]"

    def ask_perplexity(
        self,
        messages: list,
        temperature: float = 0.3,
        return_citations: bool = True,
        model: str = "sonar",
    ) -> dict:
        """
        Ask Perplexity AI a question with a given set of messages and return the parsed answer.

        Args:
            messages (list): ``(role, content)`` pairs. Role ``"system"`` is sent as
                ``"system"``, ``"human"`` as ``"user"``, and any other role as
                ``"assistant"`` (see ``convert_messages``).
            temperature (float, optional): Sampling temperature. Defaults to 0.3.
            return_citations (bool, optional): Forwarded to the API as
                ``return_citations``. Defaults to True.
            model (str, optional): The model to use. Defaults to "sonar".

        Returns:
            dict | None: ``{"answer": ..., "citations": [...]}`` when the API answers
            with HTTP 200, otherwise ``None``.

        Raises:
            requests.RequestException: On connection failures or when the
                60-second timeout is exceeded.
        """

        payload = {
            "return_citations": return_citations,
            "model": model,
            "messages": self.convert_messages(messages),
            "temperature": temperature,
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        response = requests.post(self.url, json=payload, headers=headers, timeout=60)

        return self.extract_answer_and_citations(response)

    def convert_messages(self, messages):
        """Convert ``(role, content)`` pairs into Perplexity message dicts.

        ``"system"`` maps to ``"system"`` and ``"human"`` maps to ``"user"``.
        Every other role is sent as ``"assistant"``.
        """
        role_map = {"system": "system", "human": "user"}
        return [
            {"role": role_map.get(message[0], "assistant"), "content": message[1]}
            for message in messages
        ]

    def extract_answer_and_citations(self, response):
        """Extract the answer text and cited sources from an API response.

        Inline markers such as ``[3]`` are stripped from the answer. Each marker
        is resolved as a 1-based index into the response's ``citations`` list,
        in the order the markers appear (repeats included). Markers that do not
        refer to an existing citation are ignored.

        Args:
            response: The HTTP response object returned by ``requests.post``.

        Returns:
            dict | None: ``{"answer": ..., "citations": [...]}``, or ``None`` if
            the HTTP status is not 200.
        """
        # Non-200 responses are reported to the caller as None.
        if response.status_code != 200:
            return None

        # Decode the body once instead of once per field.
        body = response.json()
        answer = body["choices"][0]["message"]["content"]

        matches = re.findall(self.citation_pattern, answer)
        answer = re.sub(self.citation_pattern, "", answer)

        # The body may not carry a "citations" field (or it may be null);
        # treat that as "no sources" instead of raising KeyError.
        citations = body.get("citations") or []

        citations_to_return = []
        for match in matches:
            index = int(match) - 1
            # Markers are 1-based. Without this check, "[0]" would wrap around to
            # citations[-1], and a number past the end would raise IndexError.
            if 0 <= index < len(citations):
                citations_to_return.append(citations[index])

        return {"answer": answer, "citations": citations_to_return}

    @staticmethod
    def pretty_print(response):
        """Format a parsed response as ``Answer: ..., Citations: ...``.

        Despite the name, this returns the formatted string rather than printing it.
        """
        return f"Answer: {response['answer']}, Citations: {response['citations']}"


if __name__ == "__main__":
    # Demo: a system prompt followed by a user question. The Perplexity API
    # expects the conversation to end with a user ("human") message.
    messages_perplexity = [
        ("system", "You are a fan of F1."),
        ("human", "Who won the most recent Formula 1 Drivers' Championship?"),
    ]

    perplexity = PerplexityResearch()
    per_response = perplexity.ask_perplexity(messages_perplexity)

    print(per_response)

    # ask_perplexity returns None on a non-200 response; pretty_print would
    # raise TypeError when indexing None, so guard it.
    if per_response is None:
        print("Perplexity request failed (non-200 response).")
    else:
        print(perplexity.pretty_print(per_response))
