from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, List, Dict, Any

from django.conf import settings
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import os


# Lazily-initialised module-level state shared by the functions below.
# Set by get_embedder() on its first call.
_embedder: SentenceTransformer | None = None
# In-memory FAISS index; set by _load_index(), which upsert_items() and
# search() call on first use.
_faiss_index: faiss.Index | None = None
# One metadata dict per vector in _faiss_index. search() treats the FAISS
# result id as a position in this list, so the two must stay aligned.
_metadata: List[Dict[str, Any]] = []


def get_embedder() -> SentenceTransformer:
    """Return the shared sentence-embedding model, loading it on first use.

    The model named by ``settings.EMBEDDING_MODEL`` is constructed once and
    cached in the module-level ``_embedder``; later calls reuse that instance.
    """
    global _embedder
    if _embedder is not None:
        return _embedder
    _embedder = SentenceTransformer(settings.EMBEDDING_MODEL)
    return _embedder


def _index_paths():
    """Return ``(index_path, meta_path)`` from settings, creating their directories.

    The parent directory of BOTH paths is created if missing, so callers can
    write either file immediately, even when the two live in different
    directories. A bare filename (no directory part) refers to the current
    working directory and needs no directory creation.

    Returns:
        tuple: ``(settings.FAISS_INDEX_PATH, settings.FAISS_META_PATH)``.
    """
    index_path = settings.FAISS_INDEX_PATH
    meta_path = settings.FAISS_META_PATH
    for path in (index_path, meta_path):
        parent = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so skip bare filenames.
        if parent:
            os.makedirs(parent, exist_ok=True)
    return index_path, meta_path


def _load_index():
    """Initialise the module-level FAISS index and metadata list.

    When both the index file and the metadata JSON exist, they are read from
    disk. Otherwise a new, empty ``IndexFlatIP`` is created, sized to the
    embedder's output dimension, and the metadata list is reset to empty.

    NOTE(review): if only one of the two files exists, it is ignored here, and
    the next upsert_items() call will overwrite it. Confirm this is intended.
    """
    global _faiss_index, _metadata
    idx_file, meta_file = _index_paths()
    have_both = os.path.exists(idx_file) and os.path.exists(meta_file)
    if not have_both:
        embedding_dim = get_embedder().get_sentence_embedding_dimension()
        _faiss_index = faiss.IndexFlatIP(embedding_dim)
        _metadata = []
        return
    _faiss_index = faiss.read_index(idx_file)
    with open(meta_file, 'r', encoding='utf-8') as fh:
        _metadata = json.load(fh)


@dataclass
class IngestItem:
    """A source document to be chunked, embedded and stored by ``upsert_items``.

    ``content`` is split into chunks; ``id``, ``title``, ``url`` and
    ``snippet`` are copied into the metadata stored for each chunk.
    """
    # Document identifier; each chunk's id is f"{id}-{chunk_index}".
    id: str
    # Stored in chunk metadata and returned by search().
    title: str
    # Stored in chunk metadata and returned by search().
    url: str
    # Full text; split into word-window chunks by chunk_text().
    content: str
    # Optional display snippet; when falsy, the first 300 characters of each
    # chunk are stored instead.
    snippet: str | None = None


def chunk_text(text: str, chunk_size: int | None = None, overlap: int | None = None) -> List[str]:
    """Split *text* into overlapping windows of whitespace-separated words.

    Words come from ``text.split()``. Each chunk holds at most ``chunk_size``
    words joined by single spaces, and consecutive chunks share ``overlap``
    words. The last chunk always ends at the final word.

    Args:
        text: Text to split. Empty or whitespace-only text yields ``[]``.
        chunk_size: Words per chunk. ``None`` means
            ``settings.RAG_CHUNK_SIZE`` (800 if that setting is absent).
        overlap: Words shared by consecutive chunks. ``None`` means
            ``settings.RAG_CHUNK_OVERLAP`` (100 if absent); ``0`` means no
            overlap.

    Returns:
        The chunks, in document order.

    Raises:
        ValueError: if ``chunk_size < 1`` or ``overlap`` is outside
            ``[0, chunk_size)``. Such values would otherwise loop forever
            or skip words.
    """
    # Test against None rather than truthiness so an explicit overlap=0 is honoured.
    if chunk_size is None:
        chunk_size = getattr(settings, 'RAG_CHUNK_SIZE', 800)
    if overlap is None:
        overlap = getattr(settings, 'RAG_CHUNK_OVERLAP', 100)
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )

    words = text.split()
    n_words = len(words)
    chunks: List[str] = []
    start = 0
    while start < n_words:
        end = min(n_words, start + chunk_size)
        chunks.append(" ".join(words[start:end]))
        if end == n_words:
            break
        # overlap < chunk_size guarantees start strictly advances each step.
        start = end - overlap
    return chunks


def _normalize(vecs: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
    return vecs / norms


def _embed_items(items: Iterable[IngestItem]) -> tuple[List[np.ndarray], List[Dict[str, Any]]]:
    """Chunk and embed each item; return parallel lists of raw vectors and chunk metadata."""
    embedder = get_embedder()
    vectors: List[np.ndarray] = []
    metas: List[Dict[str, Any]] = []
    for item in items:
        chunks = chunk_text(item.content)
        if not chunks:
            continue
        embeddings = embedder.encode(chunks)
        for idx, (chunk, vec) in enumerate(zip(chunks, embeddings)):
            metas.append({
                "id": f"{item.id}-{idx}",
                "doc_id": item.id,
                "title": item.title,
                "url": item.url,
                "snippet": item.snippet or chunk[:300],
                "chunk_index": idx,
            })
            vectors.append(vec)
    return vectors, metas


def _remove_docs(doc_ids: set) -> None:
    """Delete every stored chunk whose ``doc_id`` is in *doc_ids* from the index and metadata."""
    global _metadata
    stale = [pos for pos, meta in enumerate(_metadata) if meta.get("doc_id") in doc_ids]
    if not stale:
        return
    # NOTE(review): assumes a flat index (as created by _load_index). Its
    # remove_ids compacts the remaining vectors in order, which keeps FAISS
    # positions aligned with _metadata positions after the list filter below.
    _faiss_index.remove_ids(np.asarray(stale, dtype='int64'))
    stale_set = set(stale)
    _metadata = [meta for pos, meta in enumerate(_metadata) if pos not in stale_set]


def _persist_index() -> None:
    """Write the in-memory index, then the metadata, to their configured paths.

    Each file is first written to a sibling ``.tmp`` path and then moved into
    place with ``os.replace``. A crash mid-write therefore leaves the previous
    file intact rather than a truncated one.
    """
    index_path, meta_path = _index_paths()
    tmp_index = f"{index_path}.tmp"
    faiss.write_index(_faiss_index, tmp_index)
    os.replace(tmp_index, index_path)
    tmp_meta = f"{meta_path}.tmp"
    with open(tmp_meta, 'w', encoding='utf-8') as f:
        json.dump(_metadata, f)
    os.replace(tmp_meta, meta_path)


def upsert_items(items: Iterable[IngestItem]) -> int:
    """Chunk, embed and store *items*, replacing chunks previously stored for the same ids.

    Each item's ``content`` is split with chunk_text(), embedded, and
    L2-normalised. Any chunks already stored under the same ``doc_id`` are
    removed first, so re-ingesting a document does not duplicate it. The
    index and metadata are then saved to disk. Items whose content yields no
    chunks are skipped; if nothing is produced, the store is left unchanged.

    Args:
        items: Documents to ingest. Iterated exactly once.

    Returns:
        Number of chunk vectors added.
    """
    if _faiss_index is None:
        _load_index()

    vectors, metas = _embed_items(items)
    if not vectors:
        return 0

    mat = _normalize(np.vstack(vectors).astype('float32'))
    # Without this, ingesting a document twice stores every chunk twice and
    # search() returns duplicate hits.
    _remove_docs({meta["doc_id"] for meta in metas})
    _faiss_index.add(mat)
    _metadata.extend(metas)
    _persist_index()
    return len(vectors)


def search(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """Return up to *top_k* stored chunks most similar to *query*.

    The query is embedded, L2-normalised, and searched against the
    inner-product index. Because upsert_items() stores normalised vectors,
    the score is cosine similarity.

    Args:
        query: Free-text query.
        top_k: Maximum number of results; values below 1 return ``[]``.

    Returns:
        One dict per hit with keys ``score`` (float), ``title``, ``url`` and
        ``snippet``, in FAISS result order (highest inner product first).
        Returns ``[]`` when the index is empty.
    """
    # FAISS rejects k < 1, so answer trivially instead of raising.
    if top_k < 1:
        return []
    if _faiss_index is None:
        _load_index()
    if _faiss_index.ntotal == 0:
        return []
    # Requesting more neighbours than stored vectors only yields -1 padding.
    k = min(top_k, _faiss_index.ntotal)
    query_vec = _normalize(get_embedder().encode([query]).astype('float32'))
    scores, idxs = _faiss_index.search(query_vec, k)
    results: List[Dict[str, Any]] = []
    for score, idx in zip(scores[0], idxs[0]):
        # Skip padding ids and any id with no metadata entry.
        if idx < 0 or idx >= len(_metadata):
            continue
        meta = _metadata[idx]
        results.append({
            "score": float(score),
            "title": meta.get("title"),
            "url": meta.get("url"),
            "snippet": meta.get("snippet"),
        })
    return results


