improvement(billing): migrate to decimaljs from number.parseFloat (#2588)

* improvement(billing): migrate to decimaljs from number.parseFloat

* ack PR comments

* ack pr comment

* consistency
This commit is contained in:
Waleed
2025-12-26 12:35:49 -08:00
committed by GitHub
parent d707d18ee6
commit 88cda3a9ce
6 changed files with 135 additions and 88 deletions

View File

@@ -5,6 +5,7 @@ import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getUserUsageData } from '@/lib/billing/core/usage' import { getUserUsageData } from '@/lib/billing/core/usage'
import { getCreditBalance } from '@/lib/billing/credits/balance' import { getCreditBalance } from '@/lib/billing/credits/balance'
import { getFreeTierLimit, getPlanPricing } from '@/lib/billing/subscriptions/utils' import { getFreeTierLimit, getPlanPricing } from '@/lib/billing/subscriptions/utils'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
export { getPlanPricing } export { getPlanPricing }
@@ -99,7 +100,7 @@ export async function calculateSubscriptionOverage(sub: {
return 0 return 0
} }
let totalOverage = 0 let totalOverageDecimal = new Decimal(0)
if (sub.plan === 'team') { if (sub.plan === 'team') {
const members = await db const members = await db
@@ -107,10 +108,10 @@ export async function calculateSubscriptionOverage(sub: {
.from(member) .from(member)
.where(eq(member.organizationId, sub.referenceId)) .where(eq(member.organizationId, sub.referenceId))
let totalTeamUsage = 0 let totalTeamUsageDecimal = new Decimal(0)
for (const m of members) { for (const m of members) {
const usage = await getUserUsageData(m.userId) const usage = await getUserUsageData(m.userId)
totalTeamUsage += usage.currentUsage totalTeamUsageDecimal = totalTeamUsageDecimal.plus(toDecimal(usage.currentUsage))
} }
const orgData = await db const orgData = await db
@@ -119,28 +120,29 @@ export async function calculateSubscriptionOverage(sub: {
.where(eq(organization.id, sub.referenceId)) .where(eq(organization.id, sub.referenceId))
.limit(1) .limit(1)
const departedUsage = const departedUsageDecimal =
orgData.length > 0 && orgData[0].departedMemberUsage orgData.length > 0 ? toDecimal(orgData[0].departedMemberUsage) : new Decimal(0)
? Number.parseFloat(orgData[0].departedMemberUsage)
: 0
const totalUsageWithDeparted = totalTeamUsage + departedUsage const totalUsageWithDepartedDecimal = totalTeamUsageDecimal.plus(departedUsageDecimal)
const { basePrice } = getPlanPricing(sub.plan) const { basePrice } = getPlanPricing(sub.plan)
const baseSubscriptionAmount = (sub.seats ?? 0) * basePrice const baseSubscriptionAmount = (sub.seats ?? 0) * basePrice
totalOverage = Math.max(0, totalUsageWithDeparted - baseSubscriptionAmount) totalOverageDecimal = Decimal.max(
0,
totalUsageWithDepartedDecimal.minus(baseSubscriptionAmount)
)
logger.info('Calculated team overage', { logger.info('Calculated team overage', {
subscriptionId: sub.id, subscriptionId: sub.id,
currentMemberUsage: totalTeamUsage, currentMemberUsage: toNumber(totalTeamUsageDecimal),
departedMemberUsage: departedUsage, departedMemberUsage: toNumber(departedUsageDecimal),
totalUsage: totalUsageWithDeparted, totalUsage: toNumber(totalUsageWithDepartedDecimal),
baseSubscriptionAmount, baseSubscriptionAmount,
totalOverage, totalOverage: toNumber(totalOverageDecimal),
}) })
} else if (sub.plan === 'pro') { } else if (sub.plan === 'pro') {
// Pro plan: include snapshot if user joined a team // Pro plan: include snapshot if user joined a team
const usage = await getUserUsageData(sub.referenceId) const usage = await getUserUsageData(sub.referenceId)
let totalProUsage = usage.currentUsage let totalProUsageDecimal = toDecimal(usage.currentUsage)
// Add any snapshotted Pro usage (from when they joined a team) // Add any snapshotted Pro usage (from when they joined a team)
const userStatsRows = await db const userStatsRows = await db
@@ -150,41 +152,41 @@ export async function calculateSubscriptionOverage(sub: {
.limit(1) .limit(1)
if (userStatsRows.length > 0 && userStatsRows[0].proPeriodCostSnapshot) { if (userStatsRows.length > 0 && userStatsRows[0].proPeriodCostSnapshot) {
const snapshotUsage = Number.parseFloat(userStatsRows[0].proPeriodCostSnapshot.toString()) const snapshotUsageDecimal = toDecimal(userStatsRows[0].proPeriodCostSnapshot)
totalProUsage += snapshotUsage totalProUsageDecimal = totalProUsageDecimal.plus(snapshotUsageDecimal)
logger.info('Including snapshotted Pro usage in overage calculation', { logger.info('Including snapshotted Pro usage in overage calculation', {
userId: sub.referenceId, userId: sub.referenceId,
currentUsage: usage.currentUsage, currentUsage: usage.currentUsage,
snapshotUsage, snapshotUsage: toNumber(snapshotUsageDecimal),
totalProUsage, totalProUsage: toNumber(totalProUsageDecimal),
}) })
} }
const { basePrice } = getPlanPricing(sub.plan) const { basePrice } = getPlanPricing(sub.plan)
totalOverage = Math.max(0, totalProUsage - basePrice) totalOverageDecimal = Decimal.max(0, totalProUsageDecimal.minus(basePrice))
logger.info('Calculated pro overage', { logger.info('Calculated pro overage', {
subscriptionId: sub.id, subscriptionId: sub.id,
totalProUsage, totalProUsage: toNumber(totalProUsageDecimal),
basePrice, basePrice,
totalOverage, totalOverage: toNumber(totalOverageDecimal),
}) })
} else { } else {
// Free plan or unknown plan type // Free plan or unknown plan type
const usage = await getUserUsageData(sub.referenceId) const usage = await getUserUsageData(sub.referenceId)
const { basePrice } = getPlanPricing(sub.plan || 'free') const { basePrice } = getPlanPricing(sub.plan || 'free')
totalOverage = Math.max(0, usage.currentUsage - basePrice) totalOverageDecimal = Decimal.max(0, toDecimal(usage.currentUsage).minus(basePrice))
logger.info('Calculated overage for plan', { logger.info('Calculated overage for plan', {
subscriptionId: sub.id, subscriptionId: sub.id,
plan: sub.plan || 'free', plan: sub.plan || 'free',
usage: usage.currentUsage, usage: usage.currentUsage,
basePrice, basePrice,
totalOverage, totalOverage: toNumber(totalOverageDecimal),
}) })
} }
return totalOverage return toNumber(totalOverageDecimal)
} }
/** /**
@@ -272,14 +274,16 @@ export async function getSimplifiedBillingSummary(
const licensedSeats = subscription.seats ?? 0 const licensedSeats = subscription.seats ?? 0
const totalBasePrice = basePricePerSeat * licensedSeats // Based on Stripe subscription const totalBasePrice = basePricePerSeat * licensedSeats // Based on Stripe subscription
let totalCurrentUsage = 0 let totalCurrentUsageDecimal = new Decimal(0)
let totalCopilotCost = 0 let totalCopilotCostDecimal = new Decimal(0)
let totalLastPeriodCopilotCost = 0 let totalLastPeriodCopilotCostDecimal = new Decimal(0)
// Calculate total team usage across all members // Calculate total team usage across all members
for (const memberInfo of members) { for (const memberInfo of members) {
const memberUsageData = await getUserUsageData(memberInfo.userId) const memberUsageData = await getUserUsageData(memberInfo.userId)
totalCurrentUsage += memberUsageData.currentUsage totalCurrentUsageDecimal = totalCurrentUsageDecimal.plus(
toDecimal(memberUsageData.currentUsage)
)
// Fetch copilot cost for this member // Fetch copilot cost for this member
const memberStats = await db const memberStats = await db
@@ -292,17 +296,21 @@ export async function getSimplifiedBillingSummary(
.limit(1) .limit(1)
if (memberStats.length > 0) { if (memberStats.length > 0) {
totalCopilotCost += Number.parseFloat( totalCopilotCostDecimal = totalCopilotCostDecimal.plus(
memberStats[0].currentPeriodCopilotCost?.toString() || '0' toDecimal(memberStats[0].currentPeriodCopilotCost)
) )
totalLastPeriodCopilotCost += Number.parseFloat( totalLastPeriodCopilotCostDecimal = totalLastPeriodCopilotCostDecimal.plus(
memberStats[0].lastPeriodCopilotCost?.toString() || '0' toDecimal(memberStats[0].lastPeriodCopilotCost)
) )
} }
} }
const totalCurrentUsage = toNumber(totalCurrentUsageDecimal)
const totalCopilotCost = toNumber(totalCopilotCostDecimal)
const totalLastPeriodCopilotCost = toNumber(totalLastPeriodCopilotCostDecimal)
// Calculate team-level overage: total usage beyond what was already paid to Stripe // Calculate team-level overage: total usage beyond what was already paid to Stripe
const totalOverage = Math.max(0, totalCurrentUsage - totalBasePrice) const totalOverage = toNumber(Decimal.max(0, totalCurrentUsageDecimal.minus(totalBasePrice)))
// Get user's personal limits for warnings // Get user's personal limits for warnings
const percentUsed = const percentUsed =
@@ -380,14 +388,10 @@ export async function getSimplifiedBillingSummary(
.limit(1) .limit(1)
const copilotCost = const copilotCost =
userStatsRows.length > 0 userStatsRows.length > 0 ? toNumber(toDecimal(userStatsRows[0].currentPeriodCopilotCost)) : 0
? Number.parseFloat(userStatsRows[0].currentPeriodCopilotCost?.toString() || '0')
: 0
const lastPeriodCopilotCost = const lastPeriodCopilotCost =
userStatsRows.length > 0 userStatsRows.length > 0 ? toNumber(toDecimal(userStatsRows[0].lastPeriodCopilotCost)) : 0
? Number.parseFloat(userStatsRows[0].lastPeriodCopilotCost?.toString() || '0')
: 0
// For team and enterprise plans, calculate total team usage instead of individual usage // For team and enterprise plans, calculate total team usage instead of individual usage
let currentUsage = usageData.currentUsage let currentUsage = usageData.currentUsage
@@ -400,12 +404,12 @@ export async function getSimplifiedBillingSummary(
.from(member) .from(member)
.where(eq(member.organizationId, subscription.referenceId)) .where(eq(member.organizationId, subscription.referenceId))
let totalTeamUsage = 0 let totalTeamUsageDecimal = new Decimal(0)
let totalTeamCopilotCost = 0 let totalTeamCopilotCostDecimal = new Decimal(0)
let totalTeamLastPeriodCopilotCost = 0 let totalTeamLastPeriodCopilotCostDecimal = new Decimal(0)
for (const teamMember of teamMembers) { for (const teamMember of teamMembers) {
const memberUsageData = await getUserUsageData(teamMember.userId) const memberUsageData = await getUserUsageData(teamMember.userId)
totalTeamUsage += memberUsageData.currentUsage totalTeamUsageDecimal = totalTeamUsageDecimal.plus(toDecimal(memberUsageData.currentUsage))
// Fetch copilot cost for this team member // Fetch copilot cost for this team member
const memberStats = await db const memberStats = await db
@@ -418,20 +422,20 @@ export async function getSimplifiedBillingSummary(
.limit(1) .limit(1)
if (memberStats.length > 0) { if (memberStats.length > 0) {
totalTeamCopilotCost += Number.parseFloat( totalTeamCopilotCostDecimal = totalTeamCopilotCostDecimal.plus(
memberStats[0].currentPeriodCopilotCost?.toString() || '0' toDecimal(memberStats[0].currentPeriodCopilotCost)
) )
totalTeamLastPeriodCopilotCost += Number.parseFloat( totalTeamLastPeriodCopilotCostDecimal = totalTeamLastPeriodCopilotCostDecimal.plus(
memberStats[0].lastPeriodCopilotCost?.toString() || '0' toDecimal(memberStats[0].lastPeriodCopilotCost)
) )
} }
} }
currentUsage = totalTeamUsage currentUsage = toNumber(totalTeamUsageDecimal)
totalCopilotCost = totalTeamCopilotCost totalCopilotCost = toNumber(totalTeamCopilotCostDecimal)
totalLastPeriodCopilotCost = totalTeamLastPeriodCopilotCost totalLastPeriodCopilotCost = toNumber(totalTeamLastPeriodCopilotCostDecimal)
} }
const overageAmount = Math.max(0, currentUsage - basePrice) const overageAmount = toNumber(Decimal.max(0, toDecimal(currentUsage).minus(basePrice)))
const percentUsed = usageData.limit > 0 ? (currentUsage / usageData.limit) * 100 : 0 const percentUsed = usageData.limit > 0 ? (currentUsage / usageData.limit) * 100 : 0
// Calculate days remaining in billing period // Calculate days remaining in billing period

View File

@@ -15,6 +15,7 @@ import {
getPlanPricing, getPlanPricing,
} from '@/lib/billing/subscriptions/utils' } from '@/lib/billing/subscriptions/utils'
import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types' import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
import { isBillingEnabled } from '@/lib/core/config/feature-flags' import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls' import { getBaseUrl } from '@/lib/core/utils/urls'
import { sendEmail } from '@/lib/messaging/email/mailer' import { sendEmail } from '@/lib/messaging/email/mailer'
@@ -45,7 +46,7 @@ export async function getOrgUsageLimit(
const configured = const configured =
orgData.length > 0 && orgData[0].orgUsageLimit orgData.length > 0 && orgData[0].orgUsageLimit
? Number.parseFloat(orgData[0].orgUsageLimit) ? toNumber(toDecimal(orgData[0].orgUsageLimit))
: null : null
if (plan === 'enterprise') { if (plan === 'enterprise') {
@@ -111,22 +112,23 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
} }
const stats = userStatsData[0] const stats = userStatsData[0]
let currentUsage = Number.parseFloat(stats.currentPeriodCost?.toString() ?? '0') let currentUsageDecimal = toDecimal(stats.currentPeriodCost)
// For Pro users, include any snapshotted usage (from when they joined a team) // For Pro users, include any snapshotted usage (from when they joined a team)
// This ensures they see their total Pro usage in the UI // This ensures they see their total Pro usage in the UI
if (subscription && subscription.plan === 'pro' && subscription.referenceId === userId) { if (subscription && subscription.plan === 'pro' && subscription.referenceId === userId) {
const snapshotUsage = Number.parseFloat(stats.proPeriodCostSnapshot?.toString() ?? '0') const snapshotUsageDecimal = toDecimal(stats.proPeriodCostSnapshot)
if (snapshotUsage > 0) { if (snapshotUsageDecimal.greaterThan(0)) {
currentUsage += snapshotUsage currentUsageDecimal = currentUsageDecimal.plus(snapshotUsageDecimal)
logger.info('Including Pro snapshot in usage display', { logger.info('Including Pro snapshot in usage display', {
userId, userId,
currentPeriodCost: stats.currentPeriodCost, currentPeriodCost: stats.currentPeriodCost,
proPeriodCostSnapshot: snapshotUsage, proPeriodCostSnapshot: toNumber(snapshotUsageDecimal),
totalUsage: currentUsage, totalUsage: toNumber(currentUsageDecimal),
}) })
} }
} }
const currentUsage = toNumber(currentUsageDecimal)
// Determine usage limit based on plan type // Determine usage limit based on plan type
let limit: number let limit: number
@@ -134,7 +136,7 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') { if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') {
// Free/Pro: Use individual user limit from userStats // Free/Pro: Use individual user limit from userStats
limit = stats.currentUsageLimit limit = stats.currentUsageLimit
? Number.parseFloat(stats.currentUsageLimit) ? toNumber(toDecimal(stats.currentUsageLimit))
: getFreeTierLimit() : getFreeTierLimit()
} else { } else {
// Team/Enterprise: Use organization limit // Team/Enterprise: Use organization limit
@@ -163,7 +165,7 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
isExceeded, isExceeded,
billingPeriodStart, billingPeriodStart,
billingPeriodEnd, billingPeriodEnd,
lastPeriodCost: Number.parseFloat(stats.lastPeriodCost?.toString() || '0'), lastPeriodCost: toNumber(toDecimal(stats.lastPeriodCost)),
} }
} catch (error) { } catch (error) {
logger.error('Failed to get user usage data', { userId, error }) logger.error('Failed to get user usage data', { userId, error })
@@ -195,7 +197,7 @@ export async function getUserUsageLimitInfo(userId: string): Promise<UsageLimitI
if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') { if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') {
// Free/Pro: Use individual limits // Free/Pro: Use individual limits
currentLimit = stats.currentUsageLimit currentLimit = stats.currentUsageLimit
? Number.parseFloat(stats.currentUsageLimit) ? toNumber(toDecimal(stats.currentUsageLimit))
: getFreeTierLimit() : getFreeTierLimit()
minimumLimit = getPerUserMinimumLimit(subscription) minimumLimit = getPerUserMinimumLimit(subscription)
canEdit = canEditUsageLimit(subscription) canEdit = canEditUsageLimit(subscription)
@@ -353,7 +355,7 @@ export async function getUserUsageLimit(userId: string): Promise<number> {
) )
} }
return Number.parseFloat(userStatsQuery[0].currentUsageLimit) return toNumber(toDecimal(userStatsQuery[0].currentUsageLimit))
} }
// Team/Enterprise: Verify org exists then use organization limit // Team/Enterprise: Verify org exists then use organization limit
const orgExists = await db const orgExists = await db
@@ -438,7 +440,7 @@ export async function syncUsageLimitsFromSubscription(userId: string): Promise<v
// Free/Pro: Handle individual limits // Free/Pro: Handle individual limits
const defaultLimit = getPerUserMinimumLimit(subscription) const defaultLimit = getPerUserMinimumLimit(subscription)
const currentLimit = currentStats.currentUsageLimit const currentLimit = currentStats.currentUsageLimit
? Number.parseFloat(currentStats.currentUsageLimit) ? toNumber(toDecimal(currentStats.currentUsageLimit))
: 0 : 0
if (!subscription || subscription.status !== 'active') { if (!subscription || subscription.status !== 'active') {
@@ -503,9 +505,9 @@ export async function getTeamUsageLimits(organizationId: string): Promise<
userId: memberData.userId, userId: memberData.userId,
userName: memberData.userName, userName: memberData.userName,
userEmail: memberData.userEmail, userEmail: memberData.userEmail,
currentLimit: Number.parseFloat(memberData.currentLimit || getFreeTierLimit().toString()), currentLimit: toNumber(toDecimal(memberData.currentLimit || getFreeTierLimit().toString())),
currentUsage: Number.parseFloat(memberData.currentPeriodCost || '0'), currentUsage: toNumber(toDecimal(memberData.currentPeriodCost)),
totalCost: Number.parseFloat(memberData.totalCost || '0'), totalCost: toNumber(toDecimal(memberData.totalCost)),
lastActive: memberData.lastActive, lastActive: memberData.lastActive,
})) }))
} catch (error) { } catch (error) {
@@ -531,7 +533,7 @@ export async function getEffectiveCurrentPeriodCost(userId: string): Promise<num
.limit(1) .limit(1)
if (rows.length === 0) return 0 if (rows.length === 0) return 0
return rows[0].current ? Number.parseFloat(rows[0].current.toString()) : 0 return toNumber(toDecimal(rows[0].current))
} }
// Team/Enterprise: pooled usage across org members // Team/Enterprise: pooled usage across org members
@@ -548,11 +550,11 @@ export async function getEffectiveCurrentPeriodCost(userId: string): Promise<num
.from(userStats) .from(userStats)
.where(inArray(userStats.userId, memberIds)) .where(inArray(userStats.userId, memberIds))
let pooled = 0 let pooled = new Decimal(0)
for (const r of rows) { for (const r of rows) {
pooled += r.current ? Number.parseFloat(r.current.toString()) : 0 pooled = pooled.plus(toDecimal(r.current))
} }
return pooled return toNumber(pooled)
} }
/** /**

View File

@@ -3,6 +3,7 @@ import { member, organization, userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger' import { createLogger } from '@sim/logger'
import { and, eq, sql } from 'drizzle-orm' import { and, eq, sql } from 'drizzle-orm'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription' import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { Decimal, toDecimal, toFixedString, toNumber } from '@/lib/billing/utils/decimal'
const logger = createLogger('CreditBalance') const logger = createLogger('CreditBalance')
@@ -23,7 +24,7 @@ export async function getCreditBalance(userId: string): Promise<CreditBalanceInf
.limit(1) .limit(1)
return { return {
balance: orgRows.length > 0 ? Number.parseFloat(orgRows[0].creditBalance || '0') : 0, balance: orgRows.length > 0 ? toNumber(toDecimal(orgRows[0].creditBalance)) : 0,
entityType: 'organization', entityType: 'organization',
entityId: subscription.referenceId, entityId: subscription.referenceId,
} }
@@ -36,7 +37,7 @@ export async function getCreditBalance(userId: string): Promise<CreditBalanceInf
.limit(1) .limit(1)
return { return {
balance: userRows.length > 0 ? Number.parseFloat(userRows[0].creditBalance || '0') : 0, balance: userRows.length > 0 ? toNumber(toDecimal(userRows[0].creditBalance)) : 0,
entityType: 'user', entityType: 'user',
entityId: userId, entityId: userId,
} }
@@ -92,20 +93,21 @@ export interface DeductResult {
} }
async function atomicDeductUserCredits(userId: string, cost: number): Promise<number> { async function atomicDeductUserCredits(userId: string, cost: number): Promise<number> {
const costStr = cost.toFixed(6) const costDecimal = toDecimal(cost)
const costStr = toFixedString(costDecimal)
// Use raw SQL with CTE to capture old balance before update // Use raw SQL with CTE to capture old balance before update
const result = await db.execute<{ old_balance: string; new_balance: string }>(sql` const result = await db.execute<{ old_balance: string; new_balance: string }>(sql`
WITH old_balance AS ( WITH old_balance AS (
SELECT credit_balance FROM user_stats WHERE user_id = ${userId} SELECT credit_balance FROM user_stats WHERE user_id = ${userId}
) )
UPDATE user_stats UPDATE user_stats
SET credit_balance = CASE SET credit_balance = CASE
WHEN credit_balance >= ${costStr}::decimal THEN credit_balance - ${costStr}::decimal WHEN credit_balance >= ${costStr}::decimal THEN credit_balance - ${costStr}::decimal
ELSE 0 ELSE 0
END END
WHERE user_id = ${userId} AND credit_balance >= 0 WHERE user_id = ${userId} AND credit_balance >= 0
RETURNING RETURNING
(SELECT credit_balance FROM old_balance) as old_balance, (SELECT credit_balance FROM old_balance) as old_balance,
credit_balance as new_balance credit_balance as new_balance
`) `)
@@ -113,25 +115,26 @@ async function atomicDeductUserCredits(userId: string, cost: number): Promise<nu
const rows = Array.from(result) const rows = Array.from(result)
if (rows.length === 0) return 0 if (rows.length === 0) return 0
const oldBalance = Number.parseFloat(rows[0].old_balance || '0') const oldBalance = toDecimal(rows[0].old_balance)
return Math.min(oldBalance, cost) return toNumber(oldBalance.lessThan(costDecimal) ? oldBalance : costDecimal)
} }
async function atomicDeductOrgCredits(orgId: string, cost: number): Promise<number> { async function atomicDeductOrgCredits(orgId: string, cost: number): Promise<number> {
const costStr = cost.toFixed(6) const costDecimal = toDecimal(cost)
const costStr = toFixedString(costDecimal)
// Use raw SQL with CTE to capture old balance before update // Use raw SQL with CTE to capture old balance before update
const result = await db.execute<{ old_balance: string; new_balance: string }>(sql` const result = await db.execute<{ old_balance: string; new_balance: string }>(sql`
WITH old_balance AS ( WITH old_balance AS (
SELECT credit_balance FROM organization WHERE id = ${orgId} SELECT credit_balance FROM organization WHERE id = ${orgId}
) )
UPDATE organization UPDATE organization
SET credit_balance = CASE SET credit_balance = CASE
WHEN credit_balance >= ${costStr}::decimal THEN credit_balance - ${costStr}::decimal WHEN credit_balance >= ${costStr}::decimal THEN credit_balance - ${costStr}::decimal
ELSE 0 ELSE 0
END END
WHERE id = ${orgId} AND credit_balance >= 0 WHERE id = ${orgId} AND credit_balance >= 0
RETURNING RETURNING
(SELECT credit_balance FROM old_balance) as old_balance, (SELECT credit_balance FROM old_balance) as old_balance,
credit_balance as new_balance credit_balance as new_balance
`) `)
@@ -139,8 +142,8 @@ async function atomicDeductOrgCredits(orgId: string, cost: number): Promise<numb
const rows = Array.from(result) const rows = Array.from(result)
if (rows.length === 0) return 0 if (rows.length === 0) return 0
const oldBalance = Number.parseFloat(rows[0].old_balance || '0') const oldBalance = toDecimal(rows[0].old_balance)
return Math.min(oldBalance, cost) return toNumber(oldBalance.lessThan(costDecimal) ? oldBalance : costDecimal)
} }
export async function deductFromCredits(userId: string, cost: number): Promise<DeductResult> { export async function deductFromCredits(userId: string, cost: number): Promise<DeductResult> {
@@ -159,7 +162,7 @@ export async function deductFromCredits(userId: string, cost: number): Promise<D
creditsUsed = await atomicDeductUserCredits(userId, cost) creditsUsed = await atomicDeductUserCredits(userId, cost)
} }
const overflow = Math.max(0, cost - creditsUsed) const overflow = toNumber(Decimal.max(0, toDecimal(cost).minus(creditsUsed)))
if (creditsUsed > 0) { if (creditsUsed > 0) {
logger.info('Deducted credits atomically', { logger.info('Deducted credits atomically', {

View File

@@ -0,0 +1,36 @@
import Decimal from 'decimal.js'
/**
* Configure Decimal.js for billing precision.
* 20 significant digits is more than enough for currency calculations.
*/
Decimal.set({ precision: 20, rounding: Decimal.ROUND_HALF_UP })
/**
* Parse a value to Decimal for precise billing calculations.
* Handles null, undefined, empty strings, and number/string inputs.
*/
export function toDecimal(value: string | number | null | undefined): Decimal {
if (value === null || value === undefined || value === '') {
return new Decimal(0)
}
return new Decimal(value)
}
/**
* Convert Decimal back to number for storage/API responses.
* Use this at the final step when returning values.
*/
export function toNumber(value: Decimal): number {
return value.toNumber()
}
/**
* Format a Decimal to a fixed string for database storage.
* Uses 6 decimal places which matches current DB precision.
*/
export function toFixedString(value: Decimal, decimalPlaces = 6): string {
return value.toFixed(decimalPlaces)
}
export { Decimal }

View File

@@ -12,6 +12,7 @@
"@tanstack/react-query-devtools": "5.90.2", "@tanstack/react-query-devtools": "5.90.2",
"@types/fluent-ffmpeg": "2.1.28", "@types/fluent-ffmpeg": "2.1.28",
"cronstrue": "3.3.0", "cronstrue": "3.3.0",
"decimal.js": "10.6.0",
"drizzle-orm": "^0.44.5", "drizzle-orm": "^0.44.5",
"ffmpeg-static": "5.3.0", "ffmpeg-static": "5.3.0",
"fluent-ffmpeg": "2.1.3", "fluent-ffmpeg": "2.1.3",

View File

@@ -35,27 +35,28 @@
}, },
"dependencies": { "dependencies": {
"@linear/sdk": "40.0.0", "@linear/sdk": "40.0.0",
"next-runtime-env": "3.3.0",
"@modelcontextprotocol/sdk": "1.20.2", "@modelcontextprotocol/sdk": "1.20.2",
"@t3-oss/env-nextjs": "0.13.4", "@t3-oss/env-nextjs": "0.13.4",
"zod": "^3.24.2",
"@tanstack/react-query": "5.90.8", "@tanstack/react-query": "5.90.8",
"@tanstack/react-query-devtools": "5.90.2", "@tanstack/react-query-devtools": "5.90.2",
"@types/fluent-ffmpeg": "2.1.28", "@types/fluent-ffmpeg": "2.1.28",
"cronstrue": "3.3.0", "cronstrue": "3.3.0",
"decimal.js": "10.6.0",
"drizzle-orm": "^0.44.5", "drizzle-orm": "^0.44.5",
"ffmpeg-static": "5.3.0", "ffmpeg-static": "5.3.0",
"fluent-ffmpeg": "2.1.3", "fluent-ffmpeg": "2.1.3",
"isolated-vm": "6.0.2", "isolated-vm": "6.0.2",
"mongodb": "6.19.0", "mongodb": "6.19.0",
"neo4j-driver": "6.0.1", "neo4j-driver": "6.0.1",
"next-runtime-env": "3.3.0",
"nodemailer": "7.0.11", "nodemailer": "7.0.11",
"onedollarstats": "0.0.10", "onedollarstats": "0.0.10",
"postgres": "^3.4.5", "postgres": "^3.4.5",
"remark-gfm": "4.0.1", "remark-gfm": "4.0.1",
"rss-parser": "3.13.0", "rss-parser": "3.13.0",
"socket.io-client": "4.8.1", "socket.io-client": "4.8.1",
"twilio": "5.9.0" "twilio": "5.9.0",
"zod": "^3.24.2"
}, },
"devDependencies": { "devDependencies": {
"@biomejs/biome": "2.0.0-beta.5", "@biomejs/biome": "2.0.0-beta.5",