import asyncio

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from utils.chroma_db import ChromaDB


class ChromaCustomRetriever(BaseRetriever):
    """Retriever that runs a similarity search directly against a Chroma collection.

    The query string is embedded with ``OpenAIEmbeddings`` and passed to
    ``chroma_collection.query`` as a raw embedding, optionally restricted to
    chunks whose ``file_name`` metadata is in ``selected_files``. Each hit is
    returned as a LangChain ``Document``.
    """

    def __init__(
        self,
        chroma_collection,
        embedding_function,
        selected_files=None,
        top_k=5,
    ):
        """
        :param chroma_collection: Chroma collection object exposing ``query``.
        :param embedding_function: Stored on the instance, but not used to embed
            queries (see the note on ``openai_ef`` below).
        :param selected_files: Optional iterable of file names to restrict the
            search to; ``None`` or empty means no restriction.
        :param top_k: Maximum number of documents returned per query.
        """
        super().__init__()
        # BaseRetriever is a pydantic model; object.__setattr__ bypasses its
        # field handling so these undeclared attributes can be attached.
        object.__setattr__(self, "chroma_collection", chroma_collection)
        object.__setattr__(self, "embedding_function", embedding_function)
        object.__setattr__(self, "top_k", top_k)
        object.__setattr__(self, "filter", self.build_filter(selected_files))
        # NOTE(review): queries are embedded with a fresh OpenAIEmbeddings()
        # (library-default model) rather than ``embedding_function``. Confirm
        # this matches the model used to embed the stored documents; otherwise
        # similarity results are meaningless or dimensions will not match.
        object.__setattr__(self, "openai_ef", OpenAIEmbeddings())

    def build_filter(self, files):
        """Return a Chroma ``where`` clause restricting results to ``files``.

        :param files: Iterable of file names, or ``None``/empty for no filter.
        :return: ``{}`` when there is nothing to filter on, the single
            condition when there is exactly one, otherwise all conditions
            combined with ``$and``.
        """
        filters = []
        if files:
            # Chroma's ``$in`` operator takes a list; copying also decouples the
            # stored filter from later mutation of the caller's collection.
            filters.append({"file_name": {"$in": list(files)}})
        if not filters:
            return {}
        if len(filters) == 1:
            return filters[0]

        return {"$and": filters}

    def _get_relevant_documents(self, query: str):
        """Embed ``query`` and return up to ``top_k`` nearest chunks as Documents."""
        embedding = self.openai_ef.embed_query(query)
        results = self.chroma_collection.query(
            query_embeddings=[embedding],
            n_results=self.top_k,
            # Some Chroma versions reject an empty ``where`` dict; ``None``
            # means "no metadata filter".
            where=self.filter or None,
        )
        # A single query embedding was sent, so each result field has one row.
        return [
            # Chroma may return ``None`` for chunks stored without metadata;
            # Document expects a dict.
            Document(page_content=doc, metadata=meta or {})
            for doc, meta in zip(results["documents"][0], results["metadatas"][0])
        ]

    async def _aget_relevant_documents(self, query: str):
        """Async variant: run the blocking lookup in a worker thread.

        Both the embedding call and the Chroma query are synchronous, so
        invoking them directly here would block the event loop for the whole
        round trip.
        """
        return await asyncio.to_thread(self._get_relevant_documents, query)


class ConversationalRAG:
    """Conversational retrieval-augmented generation over one project's documents.

    Wires together:

    * a history-aware retriever that asks the LLM to turn the latest user
      question into a standalone question and then searches the
      ``chat_<project_id>`` Chroma collection (optionally restricted to
      ``selected_files``);
    * a "stuff documents" QA chain that answers with an investment-banking
      analyst persona, inserting the retrieved chunks at ``{context}``;
    * a ``RunnableWithMessageHistory`` wrapper that supplies the chat history
      from ``history_provider`` to both steps.
    """

    def __init__(self, selected_files, project_id, history_provider, ai_model):
        """Create the LLM, open the project's Chroma collection and build the chains.

        :param selected_files: Optional file names used to restrict retrieval
            (matched against each chunk's ``file_name`` metadata).
        :param project_id: Project identifier; selects the ``chat_<project_id>``
            Chroma collection.
        :param history_provider: Chat message history returned for every
            session by ``get_session_history``.
        :param ai_model: OpenAI chat model name passed to ``ChatOpenAI``.
        """
        self.selected_files = selected_files
        # Assigned before any chain is built so get_session_history never runs
        # against a half-initialised instance.
        self.history_provider = history_provider
        self.llm = ChatOpenAI(
            model=ai_model, temperature=0.3, streaming=True  # Enable streaming
        )
        self.chroma_db = ChromaDB()
        self.project_id = project_id
        # f-string rather than "+" so non-string ids (e.g. ints) also work.
        self.collection_name = f"chat_{project_id}"
        self.collection = self.chroma_db.chroma_client.get_collection(
            self.collection_name, embedding_function=self.chroma_db.openai_ef
        )
        self.history_aware_retriever = self.create_history_aware_retriever()
        self.question_answer_chain = self.create_qa_chain()
        self.conversational_rag_chain = self.create_conversational_rag_chain()

    def create_history_aware_retriever(self):
        """Return a retriever that first reformulates the question using chat history.

        The LLM is prompted with ``contextualize_q_prompt`` to produce a
        standalone question, which is then sent to a ``ChromaCustomRetriever``
        over this project's collection (top 5 results).
        """
        contextualize_q_system_prompt = (
            "Given a chat history and the latest user question which might reference context in the chat history, "
            "formulate a standalone question which can be understood without the chat history. "
            "Do NOT answer the question, just reformulate it if needed and otherwise return it as is."
        )
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        retriever = ChromaCustomRetriever(
            chroma_collection=self.collection,
            embedding_function=self.chroma_db.openai_ef,
            selected_files=self.selected_files,
            top_k=5,
        )

        # This resolves to the module-level LangChain factory imported at the
        # top of the file, not to this method of the same name.
        return create_history_aware_retriever(
            self.llm, retriever, contextualize_q_prompt
        )

    def create_qa_chain(self):
        """Return a chain that answers ``{input}`` from retrieved ``{context}``.

        The system prompt defines the analyst persona; retrieved documents are
        "stuffed" into its ``{context}`` placeholder and the chat history is
        inserted between the system and human messages.
        """
        system_prompt = """
You are an AI-powered Investment Banking Analyst. Your role is to support deal teams with research, analysis, 
and presentation material preparation across M&A, capital markets, and strategic advisory mandates. Maintain a 
highly professional, data-driven, and concise tone. Use financial terminology appropriately and ensure all insights 
are substantiated by logic, precedent, or relevant data.

## Core Functions:
- Conduct industry and company research using credible sources and databases
- Build and analyze financial models including DCF, comparable company, and precedent transaction analyses
- Create polished, professional PowerPoint slides for pitch books, CIMs, management presentations, and market updates
- Synthesize complex financial and strategic information into concise insights for internal and client-facing use
- Monitor M&A, equity, and debt capital markets to identify relevant transactions and trends
- Assist with due diligence by summarizing key findings from data rooms and public filings
- Maintain accuracy, clarity, and consistency across deliverables

## Communication Style:
- Use a professional and analytical tone
- Be concise and structured in all written responses
- Support assertions with data or citations when possible
- Prioritize clarity and eliminate fluff
- Never speculate—only provide informed, defensible perspectives

## Formatting Guidelines:
- Use bullet points and short paragraphs for readability
- Clearly label financial metrics, dates, and sources
- Avoid jargon unless appropriate for an investment banking audience
- Keep bullet points clean and without terminal punctuation

{context}
"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        return create_stuff_documents_chain(self.llm, qa_prompt)

    def create_conversational_rag_chain(self):
        """Combine retrieval and QA, wrapped with message-history handling.

        ``RunnableWithMessageHistory`` reads the history via
        ``get_session_history`` into ``chat_history``; the user message comes
        from the ``input`` key and the reply is taken from the ``answer`` key.
        """
        return RunnableWithMessageHistory(
            create_retrieval_chain(
                self.history_aware_retriever, self.question_answer_chain
            ),
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the history for ``session_id``.

        ``session_id`` is ignored: every session shares the single
        ``history_provider`` given to the constructor, so one instance of this
        class is meant to serve one conversation.
        """
        return self.history_provider

    def get_summary_of_combined_answer(self, combined_answer: str) -> BaseMessage:
        """
        Ask the language model for a very short (under five words) summary of an answer.

        :param combined_answer: The full answer to be summarized.
        :return: The chat model's response message; the summary text is in its
            ``content`` attribute.
        """
        llm_message = [
            (
                "system",
                """
                    Summarize the following answer in a concise manner in less than 5 words:\n\n
                """,
            ),
            (
                "human",
                """Here is the answer to summarize: {combined_answer}""",
            ),
        ]

        chat_prompt = ChatPromptTemplate.from_messages(llm_message)

        chain = chat_prompt | self.llm

        response = chain.invoke(
            {
                "combined_answer": combined_answer,
            }
        )

        return response
