from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
import re
import os
from typing import List, Dict
from pathlib import Path
from sqlalchemy import Column, Integer, String, Text, Boolean, Float, DateTime, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql import func


# Connection string for the email database. An environment variable named
# DATABASE_URL takes precedence so credentials can be supplied at deploy time;
# the literal below is kept only as a backward-compatible fallback.
# NOTE(security): the fallback embeds a password in source control. Rotate it
# and drop the literal once every deployment sets DATABASE_URL.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+asyncpg://postgres:p4python@postgres.cqf3qroocfwm.us-east-1.rds.amazonaws.com:5432/email_filter",
)

# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()

# create_engine() builds a synchronous engine, so the "+asyncpg" driver
# suffix is stripped from the URL before it is used.
# echo=True makes SQLAlchemy log the SQL statements it emits.
engine = create_engine(
    DATABASE_URL.replace("+asyncpg", ""),
    echo=True
)

# Session factory; every handler creates (and must close) its own session.
SessionLocal = sessionmaker(bind=engine)

def get_all_emails_from_db():
    """Load every stored email, newest first.

    Returns:
        A list of ``EmailSpam`` rows ordered by ``received_at`` descending.
        An empty list is returned when the table is empty or when the query
        raises; the error is printed rather than propagated.
    """
    session = SessionLocal()
    rows = []
    try:
        newest_first = EmailSpam.received_at.desc()
        fetched = session.query(EmailSpam).order_by(newest_first).all()
        if fetched:
            rows = fetched
    except Exception as exc:
        # Best-effort read: the caller always gets a list back.
        print("DB error while fetching emails:", exc)
        rows = []
    finally:
        session.close()
    return rows



# Application object; the route decorators in this module register on it.
app = FastAPI()

# Get the directory where this file is located
BASE_DIR = Path(__file__).parent
# Templates are loaded from the "templates" folder next to this file.
TEMPLATES_DIR = BASE_DIR / "templates"
# Jinja2 renderer used by the HTML-returning route handlers.
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))


class EmailSpam(Base):
    """ORM model mapped to the ``email_filter`` table.

    Each row holds one email (sender, subject, body) together with the
    classification fields that the dashboard handlers read and update.
    """
    __tablename__ = "email_filter"

    # Primary key and raw message content; only the sender is mandatory.
    id = Column(Integer, primary_key=True, index=True)
    sender_email = Column(String, nullable=False)
    subject = Column(String, nullable=True)
    body = Column(Text, nullable=True)

    # The handlers multiply this by 100 for display, so it is presumably a
    # 0-1 fraction (None when no score exists) — TODO confirm with the writer.
    spam_probability = Column(Float, nullable=True)
    # The classify endpoint stores "spam", "ham" or "undecided" here.
    predicted_label = Column(String, nullable=True)
    # feedback_label and is_reviewed are not touched by the active code in
    # this module — presumably reserved for human review; verify with callers.
    feedback_label = Column(String, nullable=True)
    is_reviewed = Column(Boolean, default=False)

    # Not read by the active code in this module.
    email_summary = Column(Text, nullable=True)
    # The dashboard shows the most recent non-null value times 100 as a
    # percentage, so this is presumably a 0-1 fraction — TODO confirm.
    model_accuracy = Column(Float, nullable=True)
    # Filled in by the database (func.now()) when no value is supplied.
    received_at = Column(DateTime(timezone=True), server_default=func.now())

# ----------------------------
# Automated Analysis Functions
# ----------------------------
def analyze_email_subject(subject: str) -> List[str]:
    """Check an email subject for spam-like patterns.

    Args:
        subject: The subject line to inspect. ``None`` or an empty string is
            tolerated (stored subjects are nullable) and yields no findings.

    Returns:
        One message per matching pattern, in the order the patterns are
        listed below; an empty list when nothing matches.
    """
    # A missing subject cannot match anything; bail out before .lower(),
    # which would raise AttributeError on None.
    if not subject:
        return []

    subject_lower = subject.lower()

    # Common spam indicators
    spam_keywords = [
        r'\b(win|winner|won|prize|free|urgent|limited time|act now|click here)\b',
        r'\$[\d,]+',
        r'!!!+',
        r'\b(viagra|cialis|pharmacy|pills)\b',
        r'\b(guaranteed|risk-free|no obligation)\b',
        r'\b(click|download|claim|verify)\b.*\b(now|immediately|today)\b'
    ]

    return [
        f"Subject contains suspicious pattern: '{pattern}'"
        for pattern in spam_keywords
        if re.search(pattern, subject_lower, re.IGNORECASE)
    ]

# ----------------------------
# Root → Redirect to listing
# ----------------------------
@app.get("/")
async def root():
    """Send visitors of the site root to the email listing page."""
    listing_url = "/dashboard/spam"
    return RedirectResponse(url=listing_url)

# ----------------------------
# LISTING PAGE
# ----------------------------

@app.get("/dashboard/spam", response_class=HTMLResponse)
async def spam_dashboard(request: Request, page: int = 1):
    """Render the paginated email listing together with label statistics.

    Args:
        request: Incoming request, passed through to the template context.
        page: 1-based page number. Values below 1 are treated as 1.

    Returns:
        The rendered ``spam_dashboard.html`` template. Its context carries
        the emails of the requested page, label counts computed over all
        emails, the page/total-page numbers and the latest model accuracy
        as a percentage (0 when no accuracy has been stored).
    """
    PER_PAGE = 20
    # A page below 1 would produce a negative offset, and the slice further
    # down would then count from the END of the list and show wrong rows.
    page = max(page, 1)
    offset = (page - 1) * PER_PAGE

    def normalize_label(label):
        return label.strip().lower() if label else None

    # 1. Fetch all emails once.
    all_emails = get_all_emails_from_db()
    total_emails = len(all_emails)

    # 2. Statistics are computed over ALL emails, not just the current page.
    all_labels = [normalize_label(e.predicted_label) for e in all_emails]

    spam_count = all_labels.count("spam")
    ham_count = all_labels.count("ham")
    undecided_count = all_labels.count("undecided")
    not_classified_count = all_labels.count(None)
    classified_count = total_emails - not_classified_count

    # Most recent stored accuracy. The session is closed in `finally` so the
    # connection goes back to the pool even when the query raises.
    db = SessionLocal()
    try:
        latest_accuracy_row = (
            db.query(EmailSpam.model_accuracy)
            .filter(EmailSpam.model_accuracy.isnot(None))
            .order_by(EmailSpam.id.desc())
            .first()
        )
    finally:
        db.close()

    if latest_accuracy_row:
        model_accuracy_percent = round(latest_accuracy_row.model_accuracy * 100, 1)
    else:
        model_accuracy_percent = 0

    # 3. Paginate emails for display.
    db_emails = all_emails[offset:offset + PER_PAGE]

    enhanced_emails = []
    for email in db_emails:
        display_label = normalize_label(email.predicted_label)

        spam_prob = email.spam_probability  # None or float (0-1)

        ml_score = round(spam_prob * 100, 1) if spam_prob is not None else None

        enhanced_emails.append({
            "email_id": email.id,
            "sender": email.sender_email,
            "subject": email.subject,
            "summary": email.body,
            # Rows without a label are shown as "undecided".
            "label": display_label if display_label else "undecided",
            "ml_score": ml_score
        })

    # 4. Total pages (ceiling division).
    total_pages = (total_emails + PER_PAGE - 1) // PER_PAGE

    stats = {
        "total": total_emails,
        "spam": spam_count,
        "ham": ham_count,
        "undecided": undecided_count,
        "classified": classified_count
    }

    return templates.TemplateResponse(
        "spam_dashboard.html",
        {
            "request": request,
            "emails": enhanced_emails,
            "stats": stats,
            "page": page,
            "total_pages": total_pages,
            "model_accuracy": model_accuracy_percent
        }
    )


# ----------------------------
# ADD FORM PAGE
# ----------------------------
@app.get("/dashboard/spam/add", response_class=HTMLResponse)
async def spam_add_page(request: Request):
    """Render the ``spam_add.html`` form page."""
    context = {"request": request}
    return templates.TemplateResponse("spam_add.html", context)

# ----------------------------
# FORM SUBMIT
# ----------------------------

# @app.post("/dashboard/spam/add")
# async def spam_add_submit(
#     request: Request,
#     subject: str = Form(...),
#     sender: str = Form(...),
#     summary: str = Form(...)
# ):
#     analysis = generate_automated_analysis(subject, sender, summary)

#     db = SessionLocal()
#     try:
#         new_email = EmailSpam(
#             sender_email=sender,
#             subject=subject,
#             body=summary,
#             spam_probability=None,
#             predicted_label="Not Classified",
#             feedback_label=None,
#             is_reviewed=False
#         )
#         db.add(new_email)
#         db.commit()
#     finally:
#         db.close()

#     return RedirectResponse(url="/dashboard/spam", status_code=302)


# ----------------------------
# CLASSIFICATION UPDATE
# ----------------------------
# allowed values: "spam", "ham", "undecided"

@app.post("/dashboard/spam/classify")
async def classify_email(
    email_id: int = Form(...),
    classification: str = Form(...)
):
    """Store a new label for one email, then return to its detail page.

    Args:
        email_id: Primary key of the email to relabel (form field).
        classification: Requested label (form field); compared
            case-insensitively against "spam", "ham" and "undecided".

    Returns:
        A 303 redirect to the email's detail page in every case: an
        unrecognized label or an unknown id changes nothing.
    """
    detail_url = f"/dashboard/spam/email/{email_id}"
    label = classification.lower()

    if label in ("spam", "ham", "undecided"):
        session = SessionLocal()
        try:
            record = (
                session.query(EmailSpam)
                .filter(EmailSpam.id == email_id)
                .first()
            )
            if record is not None:
                record.predicted_label = label
                session.commit()
        finally:
            session.close()

    # Valid or not, the user lands back on the same email.
    return RedirectResponse(detail_url, status_code=303)



@app.get("/dashboard/spam/email/{email_id}", response_class=HTMLResponse)
async def email_details(request: Request, email_id: int):
    """Render the detail page for a single email.

    Args:
        request: Incoming request, passed through to the template context.
        email_id: Primary key of the email to show (path parameter).

    Returns:
        The rendered ``email_details.html`` template, or a 303 redirect to
        the listing page when no email has that id.
    """
    session = SessionLocal()
    try:
        record = (
            session.query(EmailSpam)
            .filter(EmailSpam.id == email_id)
            .first()
        )
    finally:
        session.close()

    if not record:
        return RedirectResponse("/dashboard/spam", status_code=303)

    raw_body = record.body or ""
    # Rough HTML sniff: any "<" + letter ... ">" sequence counts as markup.
    looks_like_html = re.search(r'<[a-z][\s\S]*>', raw_body, re.IGNORECASE) is not None

    # Probability is shown as a percentage with one decimal when present.
    probability = record.spam_probability
    if probability is None:
        score = None
    else:
        score = round(probability * 100, 1)

    # A missing or blank label is displayed as "undecided".
    normalized_label = (record.predicted_label or "").strip().lower()
    if not normalized_label:
        normalized_label = "undecided"

    enhanced_email = {
        "id": record.id,
        "subject": record.subject or "No Subject",
        "sender": record.sender_email,
        "body": raw_body,
        "summary": raw_body,  # same text under the key the template expects
        "is_html": looks_like_html,
        "ml_score": score,
        "label": normalized_label,
        "spam_probability": probability,
        "received_at": record.received_at
    }

    context = {"request": request, "email": enhanced_email}
    return templates.TemplateResponse("email_details.html", context)


# Runs at import time: asks SQLAlchemy to create the tables for the models
# registered on Base, using the module-level engine.
Base.metadata.create_all(bind=engine)

# Development entry point: start the app with uvicorn when run as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        # NOTE(review): the import string assumes this file is named main.py — confirm.
        "main:app",
        host="0.0.0.0",
        port=9001,
        reload=True,
        # Watch this file's directory for changes, excluding "venv".
        reload_dirs=[str(BASE_DIR)],
        reload_excludes=["venv"]
    )