import logging

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer
from starlette.middleware.base import BaseHTTPMiddleware

from dependencies.get_current_user import get_current_user

# HTTP Bearer security scheme instance. It is not referenced anywhere in this
# section; NOTE(review): presumably consumed by route dependencies or the
# OpenAPI schema elsewhere — confirm before removing it as unused.
security = HTTPBearer()

# Request paths that bypass the bearer-token check in AuthMiddleware.
# Membership is tested by exact string equality against ``request.url.path``,
# so there is no prefix matching and no trailing-slash normalization.
PUBLIC_PATHS = {
    # /auth/* endpoints
    "/auth/callback",
    "/auth/login",
    # Interactive API docs and the schema they load
    "/docs",
    "/redoc",
    "/openapi.json",
    # NOTE(review): confirm this utility endpoint is meant to be reachable
    # without authentication.
    "/utilities/sign_blob",
}


class AuthMiddleware(BaseHTTPMiddleware):
    """Require a Bearer token on every request whose path is not public.

    For paths in ``PUBLIC_PATHS`` the request is forwarded with
    ``request.state.current_user`` set to ``None``. Every other path needs an
    ``Authorization: Bearer <token>`` header. The token is resolved through
    ``get_current_user``. The result is stored on
    ``request.state.current_user``, and its ``user_id`` on
    ``request.state.user_id`` when that is truthy, before the request is
    forwarded.

    Failures become JSON responses of the form ``{"detail": ...}``:
    an ``HTTPException`` keeps its status code, detail and headers; any other
    exception is logged with its traceback and answered with a generic 500.
    """

    _BEARER_PREFIX = "Bearer "

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request (unless public), then forward it.

        Returns:
            The downstream response, or a JSON error response on failure.
        """
        # NOTE(review): the try also wraps ``call_next``, so exceptions raised
        # by downstream handlers are turned into the generic 500 below. This
        # matches the previous behavior but pre-empts any app-level
        # ``Exception`` handler; consider narrowing if that is undesired.
        try:
            if request.url.path in PUBLIC_PATHS:
                # Public routes still get the attribute so handlers can read it.
                request.state.current_user = None
            else:
                await self._authenticate(request)
            return await call_next(request)

        except HTTPException as http_exception:
            # Preserve headers (e.g. WWW-Authenticate) set by whoever raised.
            return JSONResponse(
                status_code=http_exception.status_code,
                content={"detail": http_exception.detail},
                headers=http_exception.headers,
            )
        except Exception:
            # Log the full traceback server-side, but do not echo exception
            # text to a possibly unauthenticated client.
            logging.getLogger(__name__).exception("Error in AuthMiddleware")
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Internal server error"},
            )

    async def _authenticate(self, request: Request) -> None:
        """Resolve the bearer token and populate ``request.state``.

        Raises:
            HTTPException: 401 when the Authorization header is missing,
                does not use the ``Bearer`` scheme, carries an empty token,
                or no user is resolved for the token.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(self._BEARER_PREFIX):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authorization header",
            )

        # Take everything after the scheme instead of split(" ")[1], which
        # yields "" for extra spaces and truncates at the first space.
        token = auth_header[len(self._BEARER_PREFIX):].strip()
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authorization header",
            )

        current_user = await get_current_user(token)
        # Check for a missing user BEFORE touching its attributes; otherwise a
        # None user raises AttributeError and surfaces as a 500 instead of 401.
        if not current_user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
            )

        request.state.current_user = current_user
        if current_user.user_id:
            request.state.user_id = current_user.user_id
