# Merge a new presentation's analysis into a client's stored pitch training data.
import sys

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

sys.path.append(".")
# from utils.logger import ServiceLogger
import logging

from services.ppt_training.training_data_models import (
    ClientTraining, PPTData, PresentationInstructions, SlideInstructions)
from utils.dynamo_db import DynamoDB

# Root logger: getLogger() is called without a name, so records go through the
# handlers configured on the root logger by the application.
XCM_logger = logging.getLogger()


class CombinePresentationsAnalysis:
    """Merge one presentation's analysis into a client's stored training record.

    The record for ``client``/``pitch_type`` lives in the
    ``Client_Pitch_Instructions`` table. New text items (presentation overviews
    and enhancements, slide questions, overviews and design instructions) are
    appended only when they are not too similar, by TF-IDF cosine similarity,
    to the items already stored.
    """

    # Class attributes: the DynamoDB wrapper is constructed once, when the class
    # body executes, and is shared by every instance.
    db = DynamoDB()
    db_table = "Client_Pitch_Instructions"

    def __init__(self, client: str, pitch_type: str, new_ppt_data: PPTData) -> None:
        """Load (or create) the training record for ``client``/``pitch_type``.

        Args:
            client: client identifier, passed to ``DynamoDB.get_item``.
            pitch_type: pitch type, passed to ``DynamoDB.get_item``.
            new_ppt_data: analysed presentation whose findings will be merged.
        """
        XCM_logger.info(
            "Initialized CombinePresentationsAnalysis for client: %s and pitch type: %s",
            client,
            pitch_type,
        )
        self.client = client
        self.pitch_type = pitch_type
        self.new_ppt_data = new_ppt_data

        # Reads from (and, for a new client/pitch type, writes to) DynamoDB.
        self.client_training_data = self._get_client()

    def _get_client(self):
        """Return the stored ClientTraining record, creating and saving one if absent."""
        table = self.db.dynamodb.Table(self.db_table)
        response = self.db.get_item(table, self.client, self.pitch_type)

        if not response:
            # Nothing stored yet: persist a record holding only client/pitch_type.
            client_training = ClientTraining(
                client=self.client, pitch_type=self.pitch_type
            )
            client_training.save_to_db()
            return client_training

        return ClientTraining(**response)

    def consolidate_analysis(self):
        """Merge presentation- and slide-level analysis, then save the record.

        Returns:
            Whatever ``save_to_db`` returns on success, or False if any step
            raised (the error is logged together with its traceback).
        """
        XCM_logger.info(
            "Consolidating analysis for client: %s and pitch type: %s",
            self.client,
            self.pitch_type,
        )
        try:
            self.consolidate_presentation_analysis()
            self.consolidate_slides_analysis()
            return self.client_training_data.save_to_db()
        except Exception:  # pylint: disable=broad-except
            # Top-level boundary: log with traceback and report failure to the caller.
            XCM_logger.exception(
                "Error consolidating analysis for client: %s and pitch type: %s",
                self.client,
                self.pitch_type,
            )
            return False

    def consolidate_presentation_analysis(self):
        """Merge the new presentation overview and enhancements into the record."""
        training_presentation_analysis = (
            self.client_training_data.presentation_instructions
        )
        ppt_analysis = self.new_ppt_data.PPT_analysis
        current_presentation_overview = ppt_analysis.presentation_overview
        current_presentation_enhancements = ppt_analysis.presentation_improvements

        # No presentation-level instructions stored for this client and pitch
        # type yet: seed them with the current presentation's analysis.
        if training_presentation_analysis is None:
            self.client_training_data.presentation_instructions = (
                PresentationInstructions(
                    presentation_overview=[current_presentation_overview],
                    presentation_enhancements=[current_presentation_enhancements],
                )
            )
            return

        # Otherwise append only the items that are novel w.r.t. what is stored.
        novel_overview = self.return_disimalar(
            training_presentation_analysis.presentation_overview,
            [current_presentation_overview],
        )
        novel_enhancements = self.return_disimalar(
            training_presentation_analysis.presentation_enhancements,
            [current_presentation_enhancements],
        )

        training_presentation_analysis.presentation_overview.extend(novel_overview)
        training_presentation_analysis.presentation_enhancements.extend(
            novel_enhancements
        )

    def consolidate_slides_analysis(self):
        """Merge every analysed slide into the instructions of each of its categories."""
        slides_instructions = self.client_training_data.slides_instructions
        for current_slide in self.new_ppt_data.slides_analysis:
            for slide_type in current_slide.slide_categories:
                training_slide = slides_instructions.get(slide_type)

                # No instructions for this slide type yet: create them from this slide.
                if training_slide is None:
                    slides_instructions[slide_type] = SlideInstructions(
                        slide_type=slide_type,
                        # `+` builds a new list, so later .extend() calls on the
                        # stored questions do not mutate the slide analysis.
                        questions_to_answer=current_slide.questions_answered_by_slide
                        + current_slide.additional_information,
                        slide_overview=[current_slide.slide_overview],
                        design_instructions=[current_slide.slide_direction],
                    )
                else:
                    self.consolidate_slide_analysis(current_slide, training_slide)

    def consolidate_slide_analysis(self, current_slide, training_slide):
        """Append the novel questions, overview and design notes of one slide.

        Args:
            current_slide: slide analysis from the new presentation.
            training_slide: stored SlideInstructions for one of its categories;
                its lists are extended in place.
        """
        current_slide_questions = (
            current_slide.questions_answered_by_slide
            + current_slide.additional_information
        )
        current_slide_overview = [current_slide.slide_overview]
        current_slide_design_instructions = [current_slide.slide_direction]

        novel_questions = self.return_disimalar(
            training_slide.questions_to_answer, current_slide_questions
        )
        novel_overview = self.return_disimalar(
            training_slide.slide_overview, current_slide_overview
        )
        novel_design_instructions = self.return_disimalar(
            training_slide.design_instructions, current_slide_design_instructions
        )

        training_slide.questions_to_answer.extend(novel_questions)
        training_slide.slide_overview.extend(novel_overview)
        training_slide.design_instructions.extend(novel_design_instructions)

    def check_similarity(self, list1, list2):
        """Return the cosine-similarity matrix between two lists of texts.

        One TF-IDF model is fitted on ``list1 + list2`` so both share a
        vocabulary. Entry ``[i, j]`` is the similarity of ``list1[i]`` to
        ``list2[j]``.

        Raises:
            ValueError: from ``TfidfVectorizer`` when no token survives
                tokenisation (empty vocabulary).
        """
        tfidf_matrix = TfidfVectorizer().fit_transform(list1 + list2)
        split = len(list1)
        # cosine_similarity accepts the sparse matrix directly and returns a dense
        # ndarray by default, so the TF-IDF matrix need not be densified first.
        return cosine_similarity(tfidf_matrix[:split], tfidf_matrix[split:])

    @staticmethod
    def _is_blank_text(item):
        """True for entries with no comparable text: None or whitespace-only strings."""
        return item is None or (isinstance(item, str) and not item.strip())

    def return_disimalar(self, list1, list2, threshold=0.85):
        """Return the items of ``list2`` that are not similar to any item of ``list1``.

        A candidate is kept when its TF-IDF cosine similarity to every entry of
        ``list1`` is below ``threshold``. ``None`` and blank strings are ignored
        on both sides: they carry no text, and ``None`` would crash the
        vectoriser.

        Args:
            list1: texts already stored (the reference set).
            list2: candidate texts to filter.
            threshold: similarity at or above which a candidate is a duplicate.

        Returns:
            A new list of the novel items of ``list2``, in their original order.
        """
        candidates = [item for item in list2 if not self._is_blank_text(item)]
        if not candidates:
            return []

        references = [item for item in list1 if not self._is_blank_text(item)]
        if not references:
            # Nothing to compare against: every candidate is novel.
            return candidates

        try:
            similarity_matrix = self.check_similarity(candidates, references)
        except ValueError:
            # Empty TF-IDF vocabulary (e.g. only one-character tokens): fall back
            # to exact matching instead of aborting the whole consolidation.
            return [item for item in candidates if item not in references]

        return [
            item
            for item, similarities in zip(candidates, similarity_matrix)
            if (similarities < threshold).all()
        ]


if __name__ == "__main__":
    # Manual smoke run: merge one stored presentation's analysis into a client's
    # training record and print what consolidate_analysis returns (False on error).
    analysis = CombinePresentationsAnalysis(
        client="sunbelt",
        pitch_type="CIM",
        new_ppt_data=PPTData.get_ppt_data(project_id="EdgarReeves"),
    )
    print(analysis.consolidate_analysis())
