

import inspect
from dataclasses import dataclass, field
from datetime import datetime

import pytz

from configs.firebase import (SERVER_TIMESTAMP, ArrayUnion, DocumentReference,
                              db)
from functions.Bins import clearBin
from V2.functions.Address.main import Address
from V2.functions.main import DocumentData, validate
from V2.functions.Orders.OrderItem import OrderItem
from V2.functions.Printify.Auth import printifyTime
from V2.functions.Shipments.main import Shipment


@dataclass(kw_only=True)
class Order(DocumentData):
    """Order record persisted as a document in the ``orders`` collection.

    Related order items live in the separate ``orderItems`` collection and
    are linked to the order through their ``orderId`` field.
    """
    platformOrderId:str
    platformId:str
    shopId:str
    shipped:bool
    cancelled:bool
    draft:bool
    totalDiscount:int|float
    totalPrice:int|float
    totalTax:int|float
    shippingCost:int|float
    grandTotal:int|float
    shippingAddress:Address
    orderNumber:str|None = None
    shopName:str=""
    currencyCode:str="USD"
    routed:bool|None=False
    shippingSpeedId:str|None=None
    itemIds:list[str] = field(default_factory=list)
    mappedItems:list[str] = field(default_factory=list)
    productIds:list[str] = field(default_factory=list)
    routedOrderIds:list[str] = field(default_factory=list)
    # May hold a "location" dict whose "timezone" entry is used for timestamps.
    metadata:dict|None = field(default_factory=dict)
    # When set, its "orderId" entry names the order this one was routed from.
    routedFrom:dict|None = field(default_factory=dict)
    labels:list[dict] = field(default_factory=list)
    printFilePending:bool|None = False
    statusTags:list[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data):
        """Build an instance from a plain dict.

        Only keys matching constructor parameters are used. Parameters that
        are absent from ``data`` are passed as ``None`` (field defaults are
        NOT applied for them).
        """
        return cls(**{
            k: data.get(k) for k in inspect.signature(cls).parameters
        })

    @classmethod
    def get(cls, orderId:str):
        """Load the order stored at ``orders/{orderId}``.

        If the stored document has no truthy ``id`` field, the document id is
        written back to storage and used for the returned instance.

        Returns:
            The order, or ``None`` when the document does not exist.
        """
        snap = db.document(f"orders/{orderId}").get()
        if not snap.exists:
            return None
        data = snap.to_dict()
        if not data.get("id"):
            snap.reference.update(dict(id=snap.id))
            # Also fix the in-memory copy; previously a stored falsy ``id``
            # overrode the snapshot id in the returned instance.
            data["id"] = snap.id
        return cls.from_dict(data)

    def getOrderItems(self):
        """Return every order item whose ``orderId`` equals this order's id."""
        docs = db.collection("orderItems").where("orderId", "==", self.id).get()
        return [OrderItem.from_dict(doc.to_dict()) for doc in docs]

    def __post_init__(self):
        validate(self)

    def save(self,orderItems:list[OrderItem],rewrite=False, readd=False):
        """Persist this order and its items; thin wrapper around ``saveOrder``."""
        return saveOrder(self, orderItems=orderItems, rewrite=rewrite, readd=readd)

    def update(self, **kwargs):
        """Merge ``kwargs`` into the stored order document.

        The instance's own attributes are not modified.

        Returns:
            The ``kwargs`` dict that was written.
        """
        # A merge-set needs only the reference, so no read is performed first.
        db.document(f"orders/{self.id}").set(kwargs, merge=True)
        return kwargs

    def markAsShipped(self, shipment:Shipment):
        """Mark this order, its items and (if any) its source order as shipped.

        All writes are committed in one batch. Afterwards the order's bin is
        cleared and matching batch items are flagged as fulfilled.

        Args:
            shipment: Supplies the label (image or pdf) and tracking details
                recorded on the order.

        Returns:
            The dict of fields merged into the order document.
        """
        orderRef = db.document(f"orders/{self.id}")
        items = self.getOrderItems()
        batch = db.batch()

        timezone = pytz.utc
        location:dict = self.metadata.get("location") if self.metadata else None
        if location: timezone = pytz.timezone(location.get("timezone", "UTC"))

        # Take the timestamps once so every document touched by this shipment
        # records the same instant.
        shippedAt = datetime.now(tz=timezone)
        eventTime = printifyTime(timezone.zone)
        affectedItems = [item.platformOrderItemId for item in items]

        update = dict(
            shipped=True,
            shippedAt=shippedAt,
            labels = ArrayUnion([dict(
                id=shipment.id,
                type="shipment",
                url=shipment.image if shipment.image else shipment.pdf
            )]),
            maskedOrderId = self.id.lower()+str(int(shippedAt.timestamp())),
            metadata=dict(
                status = "shipped",
                events=ArrayUnion([dict(
                    time=eventTime,
                    action="shipped",
                    affected_items=affectedItems,
                    tracking_number=shipment.trackingCode,
                    carrier=shipment.carrierName,
                    tracking_url=shipment.trackingUrl
                ), dict(
                    time=eventTime,
                    action="packaged",
                    affected_items=affectedItems,
                )])),
            statusTags = ["shipped"]
        )
        batch.set(orderRef, update, merge=True)
        for item in items:
            batch.update(db.document(f"orderItems/{item.id}"), dict(
                shipped=True,
                shippedAt=shippedAt,
            ))

        # Propagate the shipped state to the order this one was routed from.
        routedOrderId = self.routedFrom.get("orderId") if self.routedFrom else None
        if routedOrderId:
            routedSnap = db.collection('orders').document(routedOrderId).get()
            if routedSnap.exists:
                routedItems = db.collection('orderItems').where('orderId', '==', routedOrderId).get()
                for routedItem in routedItems:
                    batch.update(routedItem.reference, dict(
                        shipped=True,
                        shippedAt=shippedAt,
                    ))
                batch.update(routedSnap.reference, dict(
                    shipped=True,
                    shippedAt=shippedAt,
                    statusTags = ["shipped"]
                ))

        batch.commit()
        clearBin(self.enterpriseId, self.id)
        fulfillBatchItems(self.enterpriseId, self.id)
        return update

    def cancel(self, orderItemIds:list[str]|None=None):
        """Cancel the whole order or only the given order items.

        Args:
            orderItemIds: Ids of the items to cancel. ``None`` or an empty
                list cancels every item of the order.

        Returns:
            ``False`` when the order document does not exist; otherwise the
            instance's ``cancelled`` flag after the operation.
        """
        requestedIds = orderItemIds if orderItemIds is not None else []
        ref = db.document(f"orders/{self.id}").get()
        if not ref.exists:
            return False

        timezone = pytz.utc
        # ``metadata`` may be None, same guard as in markAsShipped.
        location:dict = self.metadata.get("location") if self.metadata else None
        if location: timezone = pytz.timezone(location.get("timezone", "UTC"))

        orderItems = self.getOrderItems()
        if len(requestedIds) > 0:
            wantedIds = set(requestedIds)
            orderItemsToCancel = [orderItem for orderItem in orderItems if orderItem.id in wantedIds]
        else:
            orderItemsToCancel = list(orderItems)
        fullyCancelled = len(orderItems) == len(orderItemsToCancel)

        # The order-level flag is only touched when the request could cover
        # the whole order (no ids given, or as many ids as there are items).
        if len(requestedIds) in (0, len(orderItems)):
            ref.reference.update(dict(
                cancelled=fullyCancelled,
                cancelledAt=datetime.now(timezone) if fullyCancelled else None
            ))
            self.cancelled = fullyCancelled

        metadata = dict(
            events=ArrayUnion([
                {
                    "time": printifyTime(timezone.zone),
                    "action": "canceled",
                    "affected_items": [i.platformOrderItemId for i in orderItemsToCancel],
                }
            ])
        )
        if fullyCancelled: metadata['status'] = "canceled"
        # NOTE(review): statusTags is set to ["cancelled"] even for a partial
        # cancellation - confirm this is intended.
        ref.reference.set(dict(
            metadata=metadata,
            statusTags=["cancelled"]
        ), merge=True)
        for orderItem in orderItemsToCancel: orderItem.cancel()
        return self.cancelled

def getOrderRef(shopId:str,platformOrderId:str,platformId:str) -> DocumentReference | None:
    """Find the stored order matching a shop, platform and platform order id.

    Returns the first query result, or ``None`` when nothing matches.
    NOTE(review): the returned object is a query result (callers in this
    module use its ``to_dict()`` and ``reference``), so the annotated return
    type looks inaccurate - confirm.
    """
    query = db.collection('orders')
    query = query.where('platformOrderId', '==', platformOrderId)
    query = query.where('platformId', '==', platformId)
    query = query.where('shopId', '==', shopId)
    matches = query.get()
    return matches[0] if matches else None

def _syncOrderAggregates(order, orderItems) -> list:
    """Refresh the item-derived fields of ``order`` and return its status tags."""
    order.itemIds = [orderItem.id for orderItem in orderItems]
    order.productIds = [orderItem.productId for orderItem in orderItems]
    order.draft = any(item.draft for item in orderItems)
    # At most one tag is kept; precedence is draft > shipped > cancelled.
    if order.draft: return ["draft"]
    if order.shipped: return ["shipped"]
    if order.cancelled: return ["cancelled"]
    return []


def saveOrder(order:Order, orderItems: list[OrderItem], rewrite=False, readd=False) -> str:
    """Create or update an order and its items.

    The stored order is looked up by shop, platform and platform order id.

    * Existing order already shipped or cancelled: nothing is written.
    * Existing order, ``rewrite=False``: only the shipped/cancelled flags,
      ``updatedAt`` and ``statusTags`` are updated on the order and on every
      stored item of that order.
    * Existing order, ``rewrite=True``: order and items are written in full,
      merged into the stored documents unless ``readd`` is true, in which
      case they are overwritten and ``mappedItems`` is reset to the item ids.
    * No existing order: a new order document and its items are created.

    In the last two cases ``order`` and ``orderItems`` are mutated in place
    (ids, ``orderId``, 1-based ``index``, aggregate fields).

    Returns:
        The id of the stored order document.
    """
    existing = getOrderRef(order.shopId, order.platformOrderId, order.platformId)
    batch = db.batch()
    statusTags = _syncOrderAggregates(order, orderItems)

    if existing:
        orderId = existing.id
        oldOrder = Order.from_dict(existing.to_dict())
        # Shipped or cancelled orders are final. Return the stored document
        # id (not the incoming object's id) so every path returns the same
        # kind of value.
        if oldOrder.shipped or oldOrder.cancelled: return orderId

        if not rewrite:
            update = dict(
                shipped=order.shipped,
                cancelled=order.cancelled,
                updatedAt=order.updatedAt,
                statusTags=statusTags,
            )
            batch.update(existing.reference, update)
            storedItems = db.collection("orderItems").where("orderId", "==", orderId).get()
            for item in storedItems: batch.update(item.reference, update)
        else:
            order.id = orderId
            # enumerate gives each item its true position even when two items
            # compare equal, and avoids a list scan per item.
            for index, orderItem in enumerate(orderItems, start=1):
                orderItem.orderId = orderId
                orderItem.id = f"{orderId}{orderItem.platformOrderItemId}"
                orderItem.index = index
                batch.set(db.document(f"orderItems/{orderItem.id}"),orderItem.to_dict(), merge=not readd)
            # Item ids were just reassigned, so refresh the aggregates.
            order.statusTags = _syncOrderAggregates(order, orderItems)
            if readd: order.mappedItems = order.itemIds
            batch.set(existing.reference,order.to_dict(), merge=not readd)
        batch.commit()
        return orderId

    # New order: create the document first to obtain its generated id.
    _, newRef = db.collection('orders').add(order.to_dict())
    order.id = newRef.id
    for index, orderItem in enumerate(orderItems, start=1):
        orderItem.orderId = order.id
        orderItem.index = index
        orderItem.id = f"{order.id}{orderItem.platformOrderItemId}"
        batch.set(db.document(f"orderItems/{orderItem.id}"),orderItem.to_dict(), merge=True)
    order.statusTags = _syncOrderAggregates(order, orderItems)
    order.printFilePending = any(item.printFilePending for item in orderItems)
    batch.set(newRef,order.to_dict(), merge=True)
    batch.commit()
    return newRef.id


def fulfillBatchItems(enterpriseId:str,orderId:str):
    """Flag an order's batch items as fulfilled.

    For every batch of the enterprise whose ``orderIds`` contains ``orderId``,
    each batch item belonging to that order gets ``fulfilled``/``fulfilledAt``
    set, and the item ids are added to the batch's ``fulfilledItems``.
    All updates are committed in a single write batch.
    """
    matchingBatches = db.collection(f"enterprises/{enterpriseId}/batches").where("orderIds", "array_contains", orderId).get()
    writer = db.batch()
    for batch in matchingBatches:
        itemsQuery = db.collection(f"enterprises/{enterpriseId}/batches/{batch.id}/batchItems").where("orderId", "==", orderId)
        fulfilledIds = []
        for batchItem in itemsQuery.get():
            writer.update(batchItem.reference, {
                "fulfilled": True,
                "fulfilledAt": SERVER_TIMESTAMP,
            })
            fulfilledIds.append(batchItem.id)
        writer.update(batch.reference, {"fulfilledItems": ArrayUnion(fulfilledIds)})
    writer.commit()

def findOrderByPlatformOrderId(shopId:str, platformOrderId:str) -> Order | None:
    """Return the first order of a shop with the given platform order id.

    When the stored document has no truthy ``id`` field, the document id is
    used for the returned instance (storage is not modified).

    Returns:
        The matching order, or ``None`` when there is none.
    """
    matches = db.collection('orders').where('platformOrderId', '==', platformOrderId).where('shopId', '==', shopId).get()
    if not matches: return None
    snap = matches[0]
    data = snap.to_dict()
    # Documents may lack a stored id; fall back to the document id so the
    # returned order can be used for follow-up reads and writes.
    if not data.get("id"): data["id"] = snap.id
    return Order.from_dict(data)