import stripe

from configs.firebase import SERVER_TIMESTAMP, db
from functions.Invoices import (createSetupIntent, getCurrency, getCustomer,
                                getStripeCreds, saveInvoice)
from V2.functions.Users.main import User, getUser
from V2.middlewares.auth import API_Error
from V2.Params import Params


def getPaymentMethods(params: Params):
    """Return the caller's saved Stripe payment methods and a new SetupIntent.

    Pro users and enterprise admins are served from the hard-coded enterprise
    ``"cfbZh6XFBG3usCcUwpRE"``; everyone else from their own enterprise. A
    customer record is created on the fly if none is stored.

    Returns a dict with:
        paymentMethods: saved ``card`` methods followed by ``us_bank_account``
            methods. When the Stripe customer has
            ``invoice_settings.default_payment_method`` set, every method gets
            a boolean ``"default"`` key.
        stripeCreds: the credentials' ``apiKey`` value.
        setupIntent: the result of ``createSetupIntent``.

    Raises:
        API_Error(400): no Stripe credentials for the enterprise, or any error
            while creating the setup intent / listing payment methods.
    """
    me = params.currentUser
    proUser = me.isProUser
    enterpriseAdmin = me.isEnterpriseAdmin
    uid = me.uid
    enterpriseId = me.enterpriseId
    if proUser or enterpriseAdmin:
        enterpriseId = "cfbZh6XFBG3usCcUwpRE"

    creds = getStripeCreds(enterpriseId)
    if not creds:
        raise API_Error("Can't get payment methods, contact admin.", 400)

    stripe.api_key = creds.get("apiSecret")
    customer = getCustomer(uid, enterpriseId=enterpriseId) or createCustomer(
        me, creds, enterpriseId
    )

    try:
        setupIntent = createSetupIntent(uid, enterpriseId, creds)
        stripeCustomerId = customer.get("customerId")

        # Best effort: if this lookup fails, no "default" flags are set.
        defaultMethodId = None
        try:
            invoiceSettings = stripe.Customer.retrieve(stripeCustomerId).get(
                "invoice_settings"
            )
            defaultMethodId = invoiceSettings.get("default_payment_method")
        except Exception as e:
            print(e)

        methods = []
        for methodType in ("card", "us_bank_account"):
            listing = stripe.PaymentMethod.list(
                customer=stripeCustomerId, type=methodType
            )
            methods.extend(listing.get("data", []))

        if defaultMethodId:
            for method in methods:
                method["default"] = method.get("id") == defaultMethodId

        return dict(
            paymentMethods=methods,
            stripeCreds=creds.get("apiKey"),
            setupIntent=setupIntent,
        )
    except Exception as e:
        raise API_Error(str(e), 400)


def updateCustomerBalance(enterpriseId, uid, amount):
    """Overwrite the stored wallet ``balance`` of ``uid`` in ``enterpriseId``.

    Note the argument order differs from updateCustomer (which takes the user
    id first).

    Returns:
        The refreshed customer dict from updateCustomer, None when the user
        has no customer record, or False if the update raised (the error is
        printed, not propagated).
    """
    try:
        return updateCustomer(uid, enterpriseId, balance=amount)
    except Exception as e:
        print(e)
        return False


def createPaymentIntent(
    api_key, customerId, paymentMethodId, amount, currency, **kwargs
):
    """Create and immediately confirm a card PaymentIntent.

    Args:
        api_key: Stripe secret key; assigned to the global ``stripe.api_key``.
        customerId: Stripe customer id.
        paymentMethodId: payment method to charge. Only card payment methods
            are accepted because ``payment_method_types=["card"]``.
        amount: amount in the currency's smallest unit (callers pass
            ``major_amount * 100``). It may be a float or numeric string and
            is rounded to the nearest integer.
        currency: currency code passed to Stripe.
        **kwargs: stored as the PaymentIntent ``metadata``.

    Returns:
        The PaymentIntent, or None if an error was raised (it is printed).
    """
    try:
        stripe.api_key = api_key
        return stripe.PaymentIntent.create(
            # round() before int(): float amounts such as 19.99 * 100 evaluate
            # to 1998.9999999999998, which int() alone truncates to 1998.
            amount=int(round(float(amount))),
            customer=customerId,
            confirm=True,
            currency=currency,
            payment_method_types=["card"],
            payment_method=paymentMethodId,
            metadata=kwargs,
        )
    except Exception as e:
        print(e)
        return None


def updateThreshold(params: Params):
    """Set the current user's wallet ``rechargeThreshold``.

    ``params.args["amount"]`` may be a number or a numeric string. It is
    rounded to 2 decimals; numeric (non-string) inputs keep their type.

    Returns:
        The refreshed customer dict from updateCustomer (None when the user
        has no customer record).

    Raises:
        API_Error(400): ``amount`` is missing or not numeric.
        API_Error(500): the customer update failed.
    """
    enterpriseId = params.currentUser.enterpriseId
    uid = params.currentUser.uid
    rawAmount = params.args.get("amount")
    try:
        # Accept numeric strings (e.g. from form/query data) as well as numbers.
        amount = round(
            float(rawAmount) if isinstance(rawAmount, str) else rawAmount, 2
        )
    except (TypeError, ValueError):
        raise API_Error("A numeric amount is required.", 400) from None
    try:
        return updateCustomer(uid, enterpriseId, rechargeThreshold=amount)
    except Exception as e:
        raise API_Error(str(e), 500)


def updateRechargeAmount(params: Params):
    """Set the current user's wallet ``rechargeAmount``.

    ``params.args["amount"]`` may be a number or a numeric string. It is
    rounded to 2 decimals; numeric (non-string) inputs keep their type.

    Returns:
        The refreshed customer dict from updateCustomer (None when the user
        has no customer record).

    Raises:
        API_Error(400): ``amount`` is missing or not numeric.
        API_Error(500): the customer update failed.
    """
    enterpriseId = params.currentUser.enterpriseId
    uid = params.currentUser.uid
    rawAmount = params.args.get("amount")
    try:
        # Accept numeric strings (e.g. from form/query data) as well as numbers.
        amount = round(
            float(rawAmount) if isinstance(rawAmount, str) else rawAmount, 2
        )
    except (TypeError, ValueError):
        raise API_Error("A numeric amount is required.", 400) from None
    try:
        return updateCustomer(uid, enterpriseId, rechargeAmount=amount)
    except Exception as e:
        raise API_Error(str(e), 500)


def getMinimumThreshold(enterpriseId, customerId):
    """Return the stored ``rechargeThreshold`` of a wallet customer, or 0.

    Reads ``paymentPlatformsCredentials/STRIPE{enterpriseId}/customers/{customerId}``.
    This is the same collection that createCustomer and updateCustomer write.

    NOTE(review): createCustomer keys customer documents by the user's uid.
    Despite the parameter name, ``customerId`` must be that document key, not
    the Stripe customer id. Confirm with callers.

    Returns 0 when the document or field is missing/None, or on any read
    error (the error is printed).
    """
    try:
        snapshot = (
            db.collection("paymentPlatformsCredentials")
            .document(f"STRIPE{enterpriseId}")
            .collection("customers")
            .document(customerId)
            .get()
        )
        # to_dict() is None for a non-existent document.
        data = snapshot.to_dict() or {}
        return data.get("rechargeThreshold") or 0
    except Exception as e:
        print(e)
        return 0


def _defaultCardId(apiSecret, customerId):
    """Return the id of the card to charge for ``customerId``, or None if it has none.

    Prefers the Stripe customer's ``invoice_settings.default_payment_method``
    when that is one of its cards; otherwise uses the first listed card.
    Errors from listing the cards propagate to the caller.
    """
    stripe.api_key = apiSecret
    cards = stripe.PaymentMethod.list(customer=customerId, type="card").get(
        "data", []
    )
    cardIds = [card.get("id") for card in cards]
    if not cardIds:
        return None
    defaultId = None
    try:
        defaultId = (
            stripe.Customer.retrieve(customerId)
            .get("invoice_settings")
            .get("default_payment_method")
        )
    except Exception as e:
        # Best effort: fall back to the first card.
        print(e)
    return defaultId if defaultId in cardIds else cardIds[0]


def addBalance(params: Params):
    """Charge a saved card through Stripe and credit the amount to the wallet.

    Identity comes from ``params.currentUser`` when set. Otherwise it comes
    from ``params.args["uid"]`` / ``params.args["enterpriseId"]``; system
    callers such as createBatchInvoiceFromWallet pass ``currentUser=None``.

    ``params.args``:
        amount (required): major-unit amount; parsed with ``float`` and
            rounded to 2 decimals (raises if missing or non-numeric).
        customer: optional pre-fetched customer record; fetched with
            getCustomer otherwise.
        paymentMethodId: optional. Otherwise the customer's Stripe default
            card is used, falling back to the first saved card.
        description, batchId, invoiceId: copied onto the transaction record.

    Returns:
        True when the charge succeeded and the wallet balance was updated.
        False otherwise; errors after argument parsing are printed, not
        raised, except errors from listing the customer's cards.
    """
    enterpriseId = params.args.get("enterpriseId")
    uid = params.args.get("uid")
    if params.currentUser:
        enterpriseId = params.currentUser.enterpriseId
        uid = params.currentUser.uid
    customer = params.args.get("customer")
    if not customer:
        customer = getCustomer(uid, enterpriseId)
    amount = round(float(params.args.get("amount")), 2)
    stripeCreds = getStripeCreds(enterpriseId)
    currency = getCurrency(enterpriseId)
    if not customer or not stripeCreds:
        print(uid, "addBalance: missing customer record or Stripe credentials")
        return False
    customerId = customer.get("customerId")

    paymentMethodId = params.args.get("paymentMethodId")
    if not paymentMethodId:
        paymentMethodId = _defaultCardId(stripeCreds.get("apiSecret"), customerId)
        if not paymentMethodId:
            print(uid, "addBalance: no saved card to charge")
            return False

    try:
        intent = createPaymentIntent(
            stripeCreds.get("apiSecret"),
            customerId,
            paymentMethodId,
            # Round to whole minor units; int(19.99 * 100) alone gives 1998.
            round(amount * 100),
            currency,
        )
        if not intent:
            return False
        if intent.get("status") != "succeeded":
            print(intent)
            return False
        saveTransaction(
            id=intent.get("id"),
            customerId=customerId,
            amount=amount,
            uid=uid,
            enterpriseId=enterpriseId,
            status=intent.get("status"),
            type="credit",
            description=params.args.get("description", "Added by user/system."),
            batchId=params.args.get("batchId"),
            invoiceId=params.args.get("invoiceId"),
        )
        # Credit on top of the stored balance; a missing/None or negative
        # balance counts as 0 (a None here previously crashed after charging).
        currentBalance = max(round(customer.get("balance") or 0, 2), 0) + amount
        updateCustomer(uid, enterpriseId, balance=currentBalance)
        return True
    except Exception as e:
        print(uid, customer)
        print(e)
        return False


def saveTransaction(
    id, customerId, amount, uid, enterpriseId, status, type, orderId=None, **kwargs
):
    """Write a wallet transaction record and return it.

    The record is stored, replacing any existing one with the same id, at
    ``paymentPlatformsCredentials/STRIPE{enterpriseId}/transactions/{id}``.

    ``prevBalance``/``newBalance`` are derived from the customer's stored
    balance:
      * a ``"succeeded"`` ``"credit"`` adds ``amount``;
      * a ``"succeeded"`` ``"debit"`` subtracts it;
      * anything else leaves it unchanged.

    The customer's balance itself is NOT updated here; callers do that with
    updateCustomer. Extra ``kwargs`` (e.g. description, batchId, invoiceId)
    are stored as-is, and ``createdAt`` is a server timestamp.
    """
    customer = getCustomer(uid, enterpriseId) or {}
    # A missing or None stored balance counts as 0 for both balance fields.
    balance = customer.get("balance") or 0
    newBalance = balance
    if status == "succeeded":
        if type == "credit":
            newBalance = balance + amount
        elif type == "debit":
            newBalance = balance - amount
    transaction = dict(
        id=id,
        amount=amount,
        uid=uid,
        enterpriseId=enterpriseId,
        status=status,
        customerId=customerId,
        type=type,
        orderId=orderId,
        createdAt=SERVER_TIMESTAMP,
        prevBalance=balance,
        newBalance=round(newBalance, 2),
        **kwargs,
    )
    db.collection(
        f"paymentPlatformsCredentials/STRIPE{enterpriseId}/transactions"
    ).document(id).set(transaction)
    return transaction


def getCustomerTransactions(enterpriseId, uid):
    """List ``uid``'s wallet transactions in ``enterpriseId``, newest first.

    Queries ``paymentPlatformsCredentials/STRIPE{enterpriseId}/transactions``
    filtered by ``uid`` and ordered by ``createdAt`` descending.

    Returns a list of transaction dicts, or None if the query fails (the
    error is printed).
    """
    try:
        transactionsPath = (
            f"paymentPlatformsCredentials/STRIPE{enterpriseId}/transactions"
        )
        query = (
            db.collection(transactionsPath)
            .where("uid", "==", uid)
            .order_by("createdAt", "DESCENDING")
        )
        history = []
        for snapshot in query.get():
            history.append(snapshot.to_dict())
        return history
    except Exception as e:
        print(e)
        return None


def getWallet(params: Params):
    """Return the current user's wallet summary.

    Pro users read the wallet under enterprise ``"cfbZh6XFBG3usCcUwpRE"``.
    A customer record is created (via createCustomer) if none exists.

    Returns a dict with:
        rechargeAmount: stored value (default 0).
        balance: rounded to 2 decimals; missing/None treated as 0.
        transactions: from getCustomerTransactions, ``[]`` if unavailable.
        currency: from getCurrency.
        rechargeThreshold: stored value (default 0).
    """
    me = params.currentUser
    enterpriseId = me.enterpriseId
    uid = me.uid
    if me.isProUser:
        enterpriseId = "cfbZh6XFBG3usCcUwpRE"

    customer = getCustomer(uid, enterpriseId) or createCustomer(
        me, getStripeCreds(enterpriseId), enterpriseId
    )
    balance = customer.get("balance", 0) or 0
    transactions = getCustomerTransactions(enterpriseId, uid) or []
    return {
        "rechargeAmount": customer.get("rechargeAmount", 0),
        "balance": round(balance, 2),
        "transactions": transactions,
        "currency": getCurrency(enterpriseId),
        "rechargeThreshold": customer.get("rechargeThreshold", 0),
    }


# Sellers whose orders are never charged by createBatchInvoiceFromWallet.
_BATCH_EXCLUDED_USER_IDS = frozenset(
    {
        "S6cjn7gKnDh2O6OGNaZRrZPdYlh2",
        "YpgkOp5oD1bwPEg6UhCB78kwv053",
        "gG895WoY1TWSCwGgNCGupORs7Xg1",
        "ad0WwLLxhAODJzg58kmBqA9ys4j1",
        "XiZwBRHZtmXeJREtBhLVQVdJm6f1",
        # "p510LF4wLkhM88ybNKaH1LdxoCu2",  # rumreefapparel@gmail.com
        # "lL5gKEFEWOO3FATxSg5YHhLEKdC3",  # mark@coedsportswear.com
        "cVlvEpN6XYe0YTBQJNYzDKrvBFg2",  # seller1@candlebuilders.com
        # "Vv40VRMA7ZfKUcUbgGkwy9K83M92", # tpvia01@gmail.com
        "PhDBWLh23xM4DKiMXmOs2mSf5L32",  # collision@85supply.com
    }
)


def _sumOrderField(orders, field):
    """Sum ``field`` over order-invoice dicts, counting missing/None values as 0."""
    return sum(order.get(field) or 0 for order in orders)


def createBatchInvoiceFromWallet(params: Params):
    """Invoice the un-invoiced orders of a batch, charging each seller's wallet.

    Overview:
      1. The batch ``enterprises/{enterpriseId}/batches/{args["batchId"]}``
         of the current user's enterprise is loaded. Its ``orderIds`` are
         looked up in ``orderInvoices`` and grouped by the order's ``userId``
         (falling back to ``uid``).
      2. Each seller is processed, skipping the current user, the excluded
         ids, sellers with nothing left to charge, and sellers without a
         customer record. For each one:
           * the seller's not-yet-invoiced orders are totalled;
           * if the wallet balance is short, the shortfall plus the seller's
             ``rechargeThreshold`` is charged via addBalance;
           * a debit transaction ("succeeded" or "failed"), an invoice
             (saveInvoice) and per-order invoiced flags are written;
           * on a failed top-up, the batch items of those orders are marked
             ``paymentStatus="failed"``.
      3. Finally the batch is marked invoiced. Its ``paymentStatus`` is
         "failed" if ANY seller's payment failed, and ``invoiceId`` is the
         last invoice saved (None if none was).

    Returns None.

    Raises:
        API_Error(400): the enterprise has no Stripe credentials.
    """
    user = params.currentUser
    enterpriseId = user.enterpriseId
    batchId = params.args.get("batchId")
    batchRef = db.document(f"enterprises/{enterpriseId}/batches/{batchId}").get()
    if not batchRef.exists:
        return
    batch = batchRef.to_dict()

    stripeCreds = getStripeCreds(enterpriseId)
    if not stripeCreds:
        raise API_Error("Can't charge batch, contact admin.", 400)
    # NOTE(review): these are process-wide stripe settings. They stay in effect
    # for the addBalance charges made below (and for later Stripe calls).
    stripe.api_key = stripeCreds.get("apiSecret")
    stripe.api_version = "2017-08-15"

    # Group the batch's order invoices by the seller they belong to.
    groupedOrders = {}
    for orderId in batch.get("orderIds") or []:
        orderInvoiceRef = db.collection("orderInvoices").document(orderId).get()
        if not orderInvoiceRef.exists:
            print(orderInvoiceRef.id, "Not invoiced")
            continue
        orderInvoice = orderInvoiceRef.to_dict()
        orderUserId = orderInvoice.get("userId", orderInvoice.get("uid"))
        groupedOrders.setdefault(orderUserId, []).append(orderInvoice)

    # Batch-level status must reflect every seller, not just the last one.
    anyFailed = False
    lastInvoiceId = None
    for userId, orderInvoices in groupedOrders.items():
        if userId == user.uid or userId in _BATCH_EXCLUDED_USER_IDS:
            continue
        # Orders already invoiced elsewhere are neither charged nor re-stamped.
        pending = [order for order in orderInvoices if not order.get("invoiced")]
        totalCost = _sumOrderField(pending, "totalCost")
        if totalCost == 0:
            continue
        customer = getCustomer(userId, enterpriseId)
        if not customer:
            continue

        invoiceId = str(userId) + str(batchId)
        customerId = customer.get("customerId")
        # Missing/None or negative wallet balances are treated as 0.
        balance = max(customer.get("balance") or 0, 0)
        failed = False
        if balance < totalCost:
            # Top up the shortfall plus the seller's configured recharge threshold.
            amountToAdd = (
                totalCost - balance + (customer.get("rechargeThreshold") or 0)
            )
            added = addBalance(
                Params(
                    currentUser=None,
                    hostname=params.hostname,
                    args=dict(
                        amount=amountToAdd,
                        uid=userId,
                        enterpriseId=enterpriseId,
                        description=f"Added for charging batch {batchId}",
                        invoiceId=invoiceId,
                        batchId=batchId,
                    ),
                    id=None,
                )
            )
            if added:
                balance += amountToAdd
            else:
                failed = True

        # One record per seller and batch; its description reflects the outcome.
        transaction = saveTransaction(
            id=invoiceId,
            customerId=customerId,
            amount=round(totalCost, 2),
            uid=userId,
            enterpriseId=enterpriseId,
            status="failed" if failed else "succeeded",
            batchId=batchId,
            type="debit",
            description="Payment Failed." if failed else "Charged for batch.",
            invoiceId=invoiceId,
        )
        if not failed:
            balance -= totalCost
        # NOTE(review): on failure this still writes the clamped (>= 0) balance.
        updateCustomer(userId, enterpriseId, balance=balance)

        pendingIds = [order.get("id") for order in pending]
        if failed:
            anyFailed = True
            for orderId in pendingIds:
                batchItems = (
                    batchRef.reference.collection("batchItems")
                    .where("orderId", "==", orderId)
                    .get()
                )
                for batchItem in batchItems:
                    batchItem.reference.update(dict(paymentStatus="failed"))

        saveInvoice(
            id=invoiceId,
            uid=userId,
            enterpriseId=enterpriseId,
            shippingCost=_sumOrderField(pending, "shippingCost"),
            printingCost=_sumOrderField(pending, "printingCost"),
            totalCost=totalCost,
            poCost=_sumOrderField(pending, "poCost"),
            discount=_sumOrderField(pending, "discount"),
            orderIds=pendingIds,
            platformOrderIds=[order.get("platformOrderId") for order in pending],
            sent=True,
            paid=not failed,
            sentAt=SERVER_TIMESTAMP,
            paidAt=SERVER_TIMESTAMP,
            userId=userId,
            platformCustomerId=customerId,
            paymentMethod="wallet",
            transactionId=transaction.get("id"),
            batchId=batchId,
        )
        for order in pending:
            db.document(f"orderInvoices/{order.get('id')}").update(
                dict(
                    platformInvoiceItemId=order.get("id"),
                    invoicedAt=SERVER_TIMESTAMP,
                    invoiced=True,
                    invoiceId=invoiceId,
                )
            )
        lastInvoiceId = invoiceId

    batchRef.reference.update(
        dict(
            invoiced=True,
            paymentStatus="failed" if anyFailed else "succeeded",
            invoiceId=lastInvoiceId,
            invoicedAt=SERVER_TIMESTAMP,
        )
    )


def createOrderInvoiceFromWallet(params: Params):
    """Invoice one order, paying from the wallet when the balance covers it.

    The order invoice is read from ``orderInvoices/{args["orderId"]}``. The
    current user's customer record is looked up under enterprise
    ``"cfbZh6XFBG3usCcUwpRE"`` for pro users, otherwise under the order's
    ``enterpriseId``.

    Payment path:
      * Wallet covers ``totalCost``: a succeeded debit transaction is
        recorded and the balance is reduced.
      * Otherwise: a Stripe invoice item plus a ``charge_automatically``
        invoice are created and the invoice is finalized.

    In both cases an invoice is saved via saveInvoice and the order invoice
    is marked sent/invoiced.

    Returns:
        ``{"invoice": <refreshed order invoice dict>}``, or None for orders
        of the hard-coded excluded users.

    Raises:
        API_Error(404): the order invoice does not exist.
        API_Error(400): the user has no customer record, or any error while
            invoicing.
    """
    user = params.currentUser
    isProUser = user.isProUser
    uid = user.uid
    orderId = params.args.get("orderId")
    orderInvoiceRef = db.collection("orderInvoices").document(orderId).get()
    if not orderInvoiceRef.exists:
        raise API_Error("Order invoice not found.", 404)
    orderInvoice = orderInvoiceRef.to_dict()
    enterpriseId = orderInvoice.get("enterpriseId")
    if isProUser:
        enterpriseId = "cfbZh6XFBG3usCcUwpRE"
    # Orders of these users are never invoiced from here.
    if orderInvoice.get("userId") in [
        "H3VPShrDPIXOqRQdep6EnBp7zXZ2",
        "XiZwBRHZtmXeJREtBhLVQVdJm6f1",
        "tL328gQi1HMyTQ3Y0QuhvALmKGE2",
    ]:
        return
    customer = getCustomer(uid, enterpriseId)
    if not customer:
        raise API_Error("Customer not found.", 400)
    customerId = customer.get("customerId")
    totalCost = orderInvoice.get("totalCost")
    stripeInvoice = None
    invoiceItem = None
    transaction = None
    paymentMethod = "wallet"
    try:
        # A missing/None stored balance counts as 0.
        balance = customer.get("balance") or 0
        if balance < totalCost:
            # Not enough in the wallet: bill the full amount through Stripe.
            currency = getCurrency(enterpriseId)
            stripeCreds = getStripeCreds(enterpriseId)
            metadata = dict(orderId=orderId, uid=uid, enterpriseId=enterpriseId)
            description = f"Invoice for order {orderInvoice.get('platformOrderId')}"
            stripe.api_key = stripeCreds.get("apiSecret")
            stripe.api_version = "2017-08-15"
            invoiceItem = stripe.InvoiceItem.create(
                customer=customerId,
                amount=int(round(totalCost * 100)),
                currency=currency,
                description=description,
                metadata=metadata,
            )
            stripeInvoice = stripe.Invoice.create(
                customer=customerId,
                collection_method="charge_automatically",
                description=description,
                metadata=metadata,
                currency=currency,
            )
            stripeInvoice = dict(
                stripe.Invoice.finalize_invoice(stripeInvoice.get("id"))
            )
            paymentMethod = "card/bank_account"
        if paymentMethod == "wallet":
            transaction = saveTransaction(
                id=orderId,
                customerId=customerId,
                amount=round(totalCost, 2),
                uid=uid,
                enterpriseId=enterpriseId,
                status="succeeded",
                orderId=orderId,
                type="debit",
            )
        invoiceId = saveInvoice(
            uid=uid,
            enterpriseId=enterpriseId,
            shippingCost=orderInvoice.get("shippingCost"),
            printingCost=orderInvoice.get("printingCost"),
            totalCost=totalCost,
            poCost=orderInvoice.get("poCost"),
            sent=True,
            paid=True,
            paidAt=SERVER_TIMESTAMP,
            userId=uid,
            discount=orderInvoice.get("discount"),
            platformInvoiceId=(
                stripeInvoice.get("id")
                if paymentMethod == "card/bank_account"
                else transaction.get("id")
            ),
            orderIds=[orderId],
            platformOrderIds=[orderInvoice.get("platformOrderId")],
            platformCustomerId=customerId,
            sentAt=SERVER_TIMESTAMP,
        )
        orderInvoiceRef.reference.update(
            dict(
                sent=True,
                sentAt=SERVER_TIMESTAMP,
                invoiced=True,
                paymentMethod=paymentMethod,
                transactionId=(
                    transaction.get("id") if paymentMethod == "wallet" else None
                ),
                invoiceId=invoiceId,
                platformInvoiceItemId=(
                    invoiceItem.get("id") if invoiceItem else orderInvoice.get("id")
                ),
            )
        )
        if paymentMethod == "wallet":
            updateCustomer(uid, enterpriseId, balance=balance - totalCost)
    except Exception as e:
        print(e)
        raise API_Error(str(e), 400)
    return dict(invoice=orderInvoiceRef.reference.get().to_dict())


def createCustomer(user: User, stripeCreds: dict, enterpriseId=None):
    """Create a Stripe customer for ``user`` and store its wallet record.

    Args:
        user: the user to create the customer for. Its ``uid``, ``email``,
            ``displayName`` and ``enterpriseId`` attributes are used.
        stripeCreds: credentials whose ``apiSecret`` is used for the Stripe call.
        enterpriseId: enterprise under which the record is stored
            (``paymentPlatformsCredentials/STRIPE{enterpriseId}/customers/{uid}``).
            Defaults to ``user.enterpriseId``. Callers such as
            getPaymentMethods and getWallet already pass the overridden id
            for pro users, matching the ``stripeCreds`` they resolved, so no
            override is applied here.

    Returns:
        The stored record, containing:
          * ``id``/``uid``: the user's uid;
          * ``customerId``: the Stripe customer id;
          * ``name``, ``email``, ``metadata``;
          * ``createdAt``/``updatedAt`` timestamps;
          * the user's own ``enterpriseId``.

    Raises:
        API_Error(400): ``stripeCreds`` is empty.
    """
    uid = user.uid
    email = user.email
    name = user.displayName
    if not enterpriseId:
        enterpriseId = user.enterpriseId
    if not stripeCreds:
        raise API_Error("Can't create customer, contact admin.", 400)
    stripe.api_key = stripeCreds.get("apiSecret")
    stripeCustomer = dict(
        stripe.Customer.create(
            description="",
            name=name,
            email=email,
            metadata=dict(uid=uid, displayName=name),
        )
    )
    customer = dict(
        id=uid,
        customerId=stripeCustomer.get("id"),
        name=stripeCustomer.get("name"),
        email=stripeCustomer.get("email"),
        metadata=stripeCustomer.get("metadata"),
        createdAt=SERVER_TIMESTAMP,
        updatedAt=SERVER_TIMESTAMP,
        # Attribute access, like the other User fields. User is not
        # necessarily a mapping, and a failing ``user.get`` here would leave
        # an orphaned Stripe customer behind.
        enterpriseId=user.enterpriseId,
        uid=uid,
    )
    db.collection("paymentPlatformsCredentials").document(
        "STRIPE" + enterpriseId
    ).collection("customers").document(uid).set(customer, merge=True)
    return customer


def updateCustomer(userId, enterpriseId, **kwargs):
    """Apply ``kwargs`` as a partial update to a stored wallet customer.

    Targets ``paymentPlatformsCredentials/STRIPE{enterpriseId}/customers/{userId}``.

    Returns:
        The customer dict re-read after the update, or None (without
        writing) when the document does not exist.
    """
    docRef = (
        db.collection("paymentPlatformsCredentials")
        .document("STRIPE" + enterpriseId)
        .collection("customers")
        .document(userId)
    )
    if not docRef.get().exists:
        return None
    docRef.update(kwargs)
    return docRef.get().to_dict()
