from datetime import datetime

import pytz
from zeep import helpers

from configs.firebase import SERVER_TIMESTAMP, db
from functions.Suppliers.Auth import CREDS, getClient, getSupplierCredentials
from functions.Suppliers.BlankProducts import (getBlankProduct,
                                               getBlanksUsedByEnterprises)
from functions.Suppliers.Locations import getLocations


def updateBlankPricing(params):
    """Fetch supplier pricing for a blank product and store it per variant.

    For each location returned by ``getLocations(supplierId)``, the supplier
    pricing SOAP client from ``getClient(supplierId, 'pricing')`` is asked for
    the "Blank" configuration in USD. Every non-zero part price is collected
    as ``{partId: {locationId: price}}`` and handed to ``savePricing``
    (enterprise-scoped when the current user carries an ``enterpriseId``).

    Args:
        params: dict with ``id`` (the blank product id passed to
            ``getBlankProduct`` and ``savePricing``) and optionally
            ``currentUser``, a dict whose ``enterpriseId`` selects
            enterprise-specific credentials and storage.

    Returns:
        True.
    """
    blankDocId = params.get('id')
    blankProduct = getBlankProduct(blankDocId)
    currentUser = params.get("currentUser")
    enterpriseId = currentUser.get('enterpriseId') if currentUser else None
    supplierId = blankProduct.get('supplierId')
    client, version = getClient(supplierId, 'pricing')
    blankProductId = blankProduct.get("blankProductId")
    locations = getLocations(supplierId) or []
    creds = getSupplierCredentials(supplierId, enterpriseId)

    partsPricing = {}
    for location in locations:
        fobId = location.get('id')
        pricingResponse = helpers.serialize_object(
            client.service.getConfigurationAndPricing(
                wsVersion=version,
                id=creds.get('username'),
                password=creds.get('password'),
                productId=blankProductId,
                currency='USD',
                fobId=fobId,
                priceType="Customer",
                localizationCountry="US",
                localizationLanguage="en",
                configurationType="Blank",
            ),
            target_cls=dict,
        ) or {}

        configuration = pricingResponse.get("Configuration")
        if not configuration:
            # Log identifiers and the response only -- never `creds`, which
            # holds the supplier password.
            print("Pricing updates: no Configuration returned",
                  supplierId, blankProductId, fobId, version, pricingResponse)
            continue

        # Optional SOAP elements can come back as None, hence the `or` fallbacks.
        parts = (configuration.get("PartArray") or {}).get('Part') or []
        for part in parts:
            try:
                partId = part.get('partId')
                partPrices = (part.get('PartPriceArray') or {}).get('PartPrice') or []
                for partPrice in partPrices:
                    price = float(partPrice.get('price') or 0)
                    if price != 0:
                        # Only parts with a usable price get an entry; a later
                        # non-zero tier overwrites an earlier one for this location.
                        partsPricing.setdefault(partId, {})[fobId] = round(price, 2)
            except Exception as e:
                # Best effort: a malformed part is logged and skipped so the
                # remaining parts of this location are still processed.
                print("Pricing updates", e)

    if not partsPricing:
        # savePricing deletes stored variants that are missing from its input,
        # so saving an empty result (e.g. every location failed) would wipe all
        # existing prices. Keep them instead.
        print("Pricing updates: no prices collected, existing prices kept", blankDocId)
        return True

    savePricing(blankDocId, partsPricing, enterpriseId)
    return True

# def savePricing(blankProductId, pricings = {}):
#     ref = db.collection('blankProducts').document(blankProductId).collection('inventoryPricings')
#     for blankVariantId in pricings:
#         prices = pricings.get(blankVariantId).values()
#         price = round(sum(prices)/len(prices),2)
#         pricing = dict(id=blankVariantId, price = price, updatedAt = SERVER_TIMESTAMP)
#         ref.document(blankVariantId).set(pricing,merge=True)
#     print(f'Pricing Updated => {blankProductId}', len(pricings.keys()))
#     return blankProductId

def savePricing(blankProductId, pricings=None, enterpriseId=None):
    """Write one price per variant and delete stored variants not in ``pricings``.

    Args:
        blankProductId: document id under ``blankProducts``.
        pricings: mapping ``{blankVariantId: {key: price}}`` (updateBlankPricing
            keys the inner dict by location id). Each variant is stored with the
            highest of its prices, rounded to 2 decimals. Variants with a
            missing id or no prices are skipped and therefore treated as absent.
        enterpriseId: when set, writes under
            ``enterprises/{enterpriseId}/blankProducts/...`` instead of the
            global ``blankProducts/...`` collection.

    Returns:
        blankProductId, unchanged.
    """
    pricings = pricings or {}
    if enterpriseId:
        path = f"enterprises/{enterpriseId}/blankProducts/{blankProductId}/inventoryPricings"
    else:
        path = f'blankProducts/{blankProductId}/inventoryPricings'
    ref = db.collection(path)

    # Ids currently stored; whatever is not rewritten below gets deleted.
    staleIds = {doc.id for doc in ref.stream()}

    for blankVariantId, variantPrices in pricings.items():
        prices = list((variantPrices or {}).values())
        # ref.document(None) would create an auto-id document, and max([])
        # raises ValueError, so skip both cases.
        if blankVariantId is None or not prices:
            continue
        pricing = dict(id=blankVariantId, price=round(max(prices), 2), updatedAt=SERVER_TIMESTAMP)
        ref.document(blankVariantId).set(pricing, merge=True)
        staleIds.discard(blankVariantId)

    for variantId in staleIds:
        ref.document(variantId).delete()
    return blankProductId

def getBlankVariantPrice(blankProductId, blankVariantId, enterpriseId=None):
    """Return the stored price of a blank variant, or 0 when none is stored.

    The enterprise-specific price
    (``enterprises/{enterpriseId}/blankProducts/.../inventoryPricings/...``) is
    tried first when ``enterpriseId`` is given. Otherwise, or when that
    document does not exist, the global ``blankProducts/.../inventoryPricings``
    document is used.
    """
    snapshot = None
    if enterpriseId:
        snapshot = db.document(
            f"enterprises/{enterpriseId}/blankProducts/{blankProductId}/inventoryPricings/{blankVariantId}"
        ).get()
    # Previously `ref` was unbound here when enterpriseId was falsy.
    if snapshot is None or not snapshot.exists:
        snapshot = db.document(f'blankProducts/{blankProductId}/inventoryPricings/{blankVariantId}').get()
    if not snapshot.exists:
        return 0
    return (snapshot.to_dict() or {}).get("price", 0)

import googlemaps


def calculate_distance(api_key, origin_zip, destination_zip):
    """Return the driving distance in meters between two zip codes.

    Uses the Google Maps Distance Matrix API.

    Args:
        api_key: Google Maps API key.
        origin_zip: origin zip code.
        destination_zip: destination zip code.

    Returns:
        The distance in meters, or None if the request failed or no route was
        found (the error is printed).
    """
    gmaps = googlemaps.Client(key=api_key)

    try:
        distance_matrix = gmaps.distance_matrix(
            origins=origin_zip,
            destinations=destination_zip,
            mode="driving",  # alternatives: 'walking', 'bicycling', 'transit'
            # `units` only affects the human-readable `text`; `value` is always meters.
            units="metric",
        )
        element = distance_matrix['rows'][0]['elements'][0]
        # Elements without a route (e.g. NOT_FOUND, ZERO_RESULTS) carry no 'distance'.
        status = element.get('status')
        if status != 'OK':
            print(f"Error: no route from {origin_zip} to {destination_zip} ({status})")
            return None
        return element['distance']['value']
    except Exception as e:
        # Best effort by design: any API or parsing failure yields None.
        print(f"Error: {e}")
        return None


def find_closest_zip_code(api_key, origin_zip, zip_code_list, distance_func=None):
    """Find the zip code in ``zip_code_list`` closest to ``origin_zip``.

    Args:
        api_key: Google Maps API key, forwarded to ``distance_func``.
        origin_zip: origin zip code.
        zip_code_list: candidate zip codes.
        distance_func: optional ``(api_key, origin, destination) -> meters | None``
            callable. Defaults to ``calculate_distance``. A None result means
            "unknown" and the candidate is skipped.

    Returns:
        ``(closest_zip, distance)``. When no candidate yields a distance, this
        is ``(None, float('inf'))``. Ties keep the first candidate.
    """
    if distance_func is None:
        distance_func = calculate_distance

    closest_zip = None
    closest_distance = float('inf')

    for zip_code in zip_code_list:
        distance = distance_func(api_key, origin_zip, zip_code)
        # Compare against None, not truthiness: 0 meters is a valid best match.
        if distance is not None and distance < closest_distance:
            closest_zip = zip_code
            closest_distance = distance

    return closest_zip, closest_distance


# # Example usage
# api_key = "YOUR_GOOGLE_MAPS_API_KEY"  # Placeholder: load the real key from config/env; never commit it (a key was previously committed here -- rotate it).
# origin_zip = "10001"
# zip_code_list = ["90210", "60601", "77001"]

# # Calculate distance between two zip codes
# distance = calculate_distance(api_key, origin_zip, "90210")
# if distance:
#   print(f"The distance between {origin_zip} and 90210 is {distance} meters.")

# # Find the closest zip code
# closest_zip, closest_distance = find_closest_zip_code(api_key, origin_zip, zip_code_list)
# if closest_zip:
#   print(f"The closest zip code to {origin_zip} is {closest_zip} ({closest_distance} meters).")

import requests
from requests.auth import HTTPBasicAuth


def getBlankVariantsPricesFromSS(api_username, api_password, blankVariantIds=None, *, timeout=30):
    """Fetch product records for the given SKUs from the S&S Activewear v2 API.

    Args:
        api_username: S&S API username (HTTP basic auth).
        api_password: S&S API password (HTTP basic auth).
        blankVariantIds: identifiers joined with commas into the
            ``/v2/products/{ids}`` path. When empty or None, no request is
            made, because an empty path segment would request all products.
        timeout: seconds passed to ``requests.get`` so a stalled connection
            cannot hang the caller indefinitely.

    Returns:
        The decoded JSON response, or ``[]`` when there is nothing to fetch or
        the request failed (the error is printed).
    """
    if not blankVariantIds:
        return []
    url = f"https://api.ssactivewear.com/v2/products/{','.join(blankVariantIds)}"
    try:
        response = requests.get(
            url,
            auth=HTTPBasicAuth(api_username, api_password),
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return []
    
def updateSSPricing(blankProductId):
    """Refresh global variant prices of a blank product from S&S Activewear.

    Reads the variant document ids under ``blankProducts/{id}/blankVariants``
    (assumed to be S&S SKUs -- TODO confirm), fetches them from S&S 100 at a
    time, and writes each record's ``salePrice`` as ``price`` to
    ``blankProducts/{id}/inventoryPricings/{sku}``. Writes are committed in
    batches of at most 500.
    """
    blankVariants = db.collection(f"blankProducts/{blankProductId}/blankVariants").get()
    blankVariantIds = [blankVariant.id for blankVariant in blankVariants]

    ssCreds = CREDS.get("SS")
    username, password = ssCreds.get("username"), ssCreds.get("password")

    all_pricings = []
    # Make API calls in batches of 100
    for start in range(0, len(blankVariantIds), 100):
        pricings = getBlankVariantsPricesFromSS(username, password, blankVariantIds[start:start + 100])
        if pricings:
            all_pricings.extend(pricings)

    # Write in batches of 500
    writer = db.batch()
    op_count = 0  # Track operations in current batch

    for idx, pricing in enumerate(all_pricings):
        blankVariantId = pricing.get("sku")
        if not blankVariantId:
            # Without a sku the document path would end in ".../None".
            continue
        # `or 0` also covers an explicit null, which float() would reject.
        price = float(pricing.get("salePrice") or 0)
        writer.set(
            db.document(f"blankProducts/{blankProductId}/inventoryPricings/{blankVariantId}"),
            dict(id=blankVariantId, price=price, updatedAt=SERVER_TIMESTAMP),
            merge=True
        )
        op_count += 1

        # Commit every 500 writes
        if op_count == 500:
            writer.commit()
            print(f"Committed 500 records at index {idx}")
            writer = db.batch()
            op_count = 0

    # Commit remaining writes
    if op_count > 0:
        writer.commit()
        print(f"Committed final {op_count} records")

    print(f"Pricing Updated => {blankProductId}, Total: {len(all_pricings)}")

def updateBlankProductsPricings(params):
    """Refresh pricing for each in-use blank product whose scheduled update is due.

    For each id yielded by ``getBlanksUsedByEnterprises()`` (presumably blank
    product document ids -- verify), reads ``blankProducts/{id}``. If its
    ``updates.updatedAt`` timestamp lies in the past, it re-fetches supplier
    pricing through ``updateBlankPricing`` (global, no enterprise) and then
    clears the ``updates`` field.

    Args:
        params: unused; kept for the existing call signature.
    """
    for blankProductId in getBlanksUsedByEnterprises():
        # Look the product up by document id, the same path that is updated below.
        productRef = db.document(f"blankProducts/{blankProductId}")
        snapshot = productRef.get()
        if not snapshot.exists:
            continue

        pricingUpdate = (snapshot.to_dict() or {}).get("updates")
        if not pricingUpdate:
            continue
        updatedAt = dict(pricingUpdate).get("updatedAt")
        # NOTE(review): non-datetime values (e.g. strings) are skipped; comparing
        # them with a datetime would raise TypeError.
        if not isinstance(updatedAt, datetime):
            continue
        # Firestore timestamps are timezone-aware; comparing them with a naive
        # "now" raises TypeError, so compare aware to aware (assume UTC if naive).
        if updatedAt.tzinfo is None:
            updatedAt = updatedAt.replace(tzinfo=pytz.utc)
        if updatedAt < datetime.now(pytz.utc):
            # updateBlankPricing returns True, not an id, so keep blankProductId.
            updateBlankPricing(dict(id=blankProductId))
            productRef.update({"updates": None})