improvement(billing): migrate to decimaljs from number.parseFloat (#2588)

* improvement(billing): migrate to decimaljs from number.parseFloat

* ack PR comments

* ack pr comment

* consistency
This commit is contained in:
Waleed
2025-12-26 12:35:49 -08:00
committed by GitHub
parent d707d18ee6
commit 88cda3a9ce
6 changed files with 135 additions and 88 deletions

View File

@@ -5,6 +5,7 @@ import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getUserUsageData } from '@/lib/billing/core/usage'
import { getCreditBalance } from '@/lib/billing/credits/balance'
import { getFreeTierLimit, getPlanPricing } from '@/lib/billing/subscriptions/utils'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
export { getPlanPricing }
@@ -99,7 +100,7 @@ export async function calculateSubscriptionOverage(sub: {
return 0
}
let totalOverage = 0
let totalOverageDecimal = new Decimal(0)
if (sub.plan === 'team') {
const members = await db
@@ -107,10 +108,10 @@ export async function calculateSubscriptionOverage(sub: {
.from(member)
.where(eq(member.organizationId, sub.referenceId))
let totalTeamUsage = 0
let totalTeamUsageDecimal = new Decimal(0)
for (const m of members) {
const usage = await getUserUsageData(m.userId)
totalTeamUsage += usage.currentUsage
totalTeamUsageDecimal = totalTeamUsageDecimal.plus(toDecimal(usage.currentUsage))
}
const orgData = await db
@@ -119,28 +120,29 @@ export async function calculateSubscriptionOverage(sub: {
.where(eq(organization.id, sub.referenceId))
.limit(1)
const departedUsage =
orgData.length > 0 && orgData[0].departedMemberUsage
? Number.parseFloat(orgData[0].departedMemberUsage)
: 0
const departedUsageDecimal =
orgData.length > 0 ? toDecimal(orgData[0].departedMemberUsage) : new Decimal(0)
const totalUsageWithDeparted = totalTeamUsage + departedUsage
const totalUsageWithDepartedDecimal = totalTeamUsageDecimal.plus(departedUsageDecimal)
const { basePrice } = getPlanPricing(sub.plan)
const baseSubscriptionAmount = (sub.seats ?? 0) * basePrice
totalOverage = Math.max(0, totalUsageWithDeparted - baseSubscriptionAmount)
totalOverageDecimal = Decimal.max(
0,
totalUsageWithDepartedDecimal.minus(baseSubscriptionAmount)
)
logger.info('Calculated team overage', {
subscriptionId: sub.id,
currentMemberUsage: totalTeamUsage,
departedMemberUsage: departedUsage,
totalUsage: totalUsageWithDeparted,
currentMemberUsage: toNumber(totalTeamUsageDecimal),
departedMemberUsage: toNumber(departedUsageDecimal),
totalUsage: toNumber(totalUsageWithDepartedDecimal),
baseSubscriptionAmount,
totalOverage,
totalOverage: toNumber(totalOverageDecimal),
})
} else if (sub.plan === 'pro') {
// Pro plan: include snapshot if user joined a team
const usage = await getUserUsageData(sub.referenceId)
let totalProUsage = usage.currentUsage
let totalProUsageDecimal = toDecimal(usage.currentUsage)
// Add any snapshotted Pro usage (from when they joined a team)
const userStatsRows = await db
@@ -150,41 +152,41 @@ export async function calculateSubscriptionOverage(sub: {
.limit(1)
if (userStatsRows.length > 0 && userStatsRows[0].proPeriodCostSnapshot) {
const snapshotUsage = Number.parseFloat(userStatsRows[0].proPeriodCostSnapshot.toString())
totalProUsage += snapshotUsage
const snapshotUsageDecimal = toDecimal(userStatsRows[0].proPeriodCostSnapshot)
totalProUsageDecimal = totalProUsageDecimal.plus(snapshotUsageDecimal)
logger.info('Including snapshotted Pro usage in overage calculation', {
userId: sub.referenceId,
currentUsage: usage.currentUsage,
snapshotUsage,
totalProUsage,
snapshotUsage: toNumber(snapshotUsageDecimal),
totalProUsage: toNumber(totalProUsageDecimal),
})
}
const { basePrice } = getPlanPricing(sub.plan)
totalOverage = Math.max(0, totalProUsage - basePrice)
totalOverageDecimal = Decimal.max(0, totalProUsageDecimal.minus(basePrice))
logger.info('Calculated pro overage', {
subscriptionId: sub.id,
totalProUsage,
totalProUsage: toNumber(totalProUsageDecimal),
basePrice,
totalOverage,
totalOverage: toNumber(totalOverageDecimal),
})
} else {
// Free plan or unknown plan type
const usage = await getUserUsageData(sub.referenceId)
const { basePrice } = getPlanPricing(sub.plan || 'free')
totalOverage = Math.max(0, usage.currentUsage - basePrice)
totalOverageDecimal = Decimal.max(0, toDecimal(usage.currentUsage).minus(basePrice))
logger.info('Calculated overage for plan', {
subscriptionId: sub.id,
plan: sub.plan || 'free',
usage: usage.currentUsage,
basePrice,
totalOverage,
totalOverage: toNumber(totalOverageDecimal),
})
}
return totalOverage
return toNumber(totalOverageDecimal)
}
/**
@@ -272,14 +274,16 @@ export async function getSimplifiedBillingSummary(
const licensedSeats = subscription.seats ?? 0
const totalBasePrice = basePricePerSeat * licensedSeats // Based on Stripe subscription
let totalCurrentUsage = 0
let totalCopilotCost = 0
let totalLastPeriodCopilotCost = 0
let totalCurrentUsageDecimal = new Decimal(0)
let totalCopilotCostDecimal = new Decimal(0)
let totalLastPeriodCopilotCostDecimal = new Decimal(0)
// Calculate total team usage across all members
for (const memberInfo of members) {
const memberUsageData = await getUserUsageData(memberInfo.userId)
totalCurrentUsage += memberUsageData.currentUsage
totalCurrentUsageDecimal = totalCurrentUsageDecimal.plus(
toDecimal(memberUsageData.currentUsage)
)
// Fetch copilot cost for this member
const memberStats = await db
@@ -292,17 +296,21 @@ export async function getSimplifiedBillingSummary(
.limit(1)
if (memberStats.length > 0) {
totalCopilotCost += Number.parseFloat(
memberStats[0].currentPeriodCopilotCost?.toString() || '0'
totalCopilotCostDecimal = totalCopilotCostDecimal.plus(
toDecimal(memberStats[0].currentPeriodCopilotCost)
)
totalLastPeriodCopilotCost += Number.parseFloat(
memberStats[0].lastPeriodCopilotCost?.toString() || '0'
totalLastPeriodCopilotCostDecimal = totalLastPeriodCopilotCostDecimal.plus(
toDecimal(memberStats[0].lastPeriodCopilotCost)
)
}
}
const totalCurrentUsage = toNumber(totalCurrentUsageDecimal)
const totalCopilotCost = toNumber(totalCopilotCostDecimal)
const totalLastPeriodCopilotCost = toNumber(totalLastPeriodCopilotCostDecimal)
// Calculate team-level overage: total usage beyond what was already paid to Stripe
const totalOverage = Math.max(0, totalCurrentUsage - totalBasePrice)
const totalOverage = toNumber(Decimal.max(0, totalCurrentUsageDecimal.minus(totalBasePrice)))
// Get user's personal limits for warnings
const percentUsed =
@@ -380,14 +388,10 @@ export async function getSimplifiedBillingSummary(
.limit(1)
const copilotCost =
userStatsRows.length > 0
? Number.parseFloat(userStatsRows[0].currentPeriodCopilotCost?.toString() || '0')
: 0
userStatsRows.length > 0 ? toNumber(toDecimal(userStatsRows[0].currentPeriodCopilotCost)) : 0
const lastPeriodCopilotCost =
userStatsRows.length > 0
? Number.parseFloat(userStatsRows[0].lastPeriodCopilotCost?.toString() || '0')
: 0
userStatsRows.length > 0 ? toNumber(toDecimal(userStatsRows[0].lastPeriodCopilotCost)) : 0
// For team and enterprise plans, calculate total team usage instead of individual usage
let currentUsage = usageData.currentUsage
@@ -400,12 +404,12 @@ export async function getSimplifiedBillingSummary(
.from(member)
.where(eq(member.organizationId, subscription.referenceId))
let totalTeamUsage = 0
let totalTeamCopilotCost = 0
let totalTeamLastPeriodCopilotCost = 0
let totalTeamUsageDecimal = new Decimal(0)
let totalTeamCopilotCostDecimal = new Decimal(0)
let totalTeamLastPeriodCopilotCostDecimal = new Decimal(0)
for (const teamMember of teamMembers) {
const memberUsageData = await getUserUsageData(teamMember.userId)
totalTeamUsage += memberUsageData.currentUsage
totalTeamUsageDecimal = totalTeamUsageDecimal.plus(toDecimal(memberUsageData.currentUsage))
// Fetch copilot cost for this team member
const memberStats = await db
@@ -418,20 +422,20 @@ export async function getSimplifiedBillingSummary(
.limit(1)
if (memberStats.length > 0) {
totalTeamCopilotCost += Number.parseFloat(
memberStats[0].currentPeriodCopilotCost?.toString() || '0'
totalTeamCopilotCostDecimal = totalTeamCopilotCostDecimal.plus(
toDecimal(memberStats[0].currentPeriodCopilotCost)
)
totalTeamLastPeriodCopilotCost += Number.parseFloat(
memberStats[0].lastPeriodCopilotCost?.toString() || '0'
totalTeamLastPeriodCopilotCostDecimal = totalTeamLastPeriodCopilotCostDecimal.plus(
toDecimal(memberStats[0].lastPeriodCopilotCost)
)
}
}
currentUsage = totalTeamUsage
totalCopilotCost = totalTeamCopilotCost
totalLastPeriodCopilotCost = totalTeamLastPeriodCopilotCost
currentUsage = toNumber(totalTeamUsageDecimal)
totalCopilotCost = toNumber(totalTeamCopilotCostDecimal)
totalLastPeriodCopilotCost = toNumber(totalTeamLastPeriodCopilotCostDecimal)
}
const overageAmount = Math.max(0, currentUsage - basePrice)
const overageAmount = toNumber(Decimal.max(0, toDecimal(currentUsage).minus(basePrice)))
const percentUsed = usageData.limit > 0 ? (currentUsage / usageData.limit) * 100 : 0
// Calculate days remaining in billing period

View File

@@ -15,6 +15,7 @@ import {
getPlanPricing,
} from '@/lib/billing/subscriptions/utils'
import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { sendEmail } from '@/lib/messaging/email/mailer'
@@ -45,7 +46,7 @@ export async function getOrgUsageLimit(
const configured =
orgData.length > 0 && orgData[0].orgUsageLimit
? Number.parseFloat(orgData[0].orgUsageLimit)
? toNumber(toDecimal(orgData[0].orgUsageLimit))
: null
if (plan === 'enterprise') {
@@ -111,22 +112,23 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
}
const stats = userStatsData[0]
let currentUsage = Number.parseFloat(stats.currentPeriodCost?.toString() ?? '0')
let currentUsageDecimal = toDecimal(stats.currentPeriodCost)
// For Pro users, include any snapshotted usage (from when they joined a team)
// This ensures they see their total Pro usage in the UI
if (subscription && subscription.plan === 'pro' && subscription.referenceId === userId) {
const snapshotUsage = Number.parseFloat(stats.proPeriodCostSnapshot?.toString() ?? '0')
if (snapshotUsage > 0) {
currentUsage += snapshotUsage
const snapshotUsageDecimal = toDecimal(stats.proPeriodCostSnapshot)
if (snapshotUsageDecimal.greaterThan(0)) {
currentUsageDecimal = currentUsageDecimal.plus(snapshotUsageDecimal)
logger.info('Including Pro snapshot in usage display', {
userId,
currentPeriodCost: stats.currentPeriodCost,
proPeriodCostSnapshot: snapshotUsage,
totalUsage: currentUsage,
proPeriodCostSnapshot: toNumber(snapshotUsageDecimal),
totalUsage: toNumber(currentUsageDecimal),
})
}
}
const currentUsage = toNumber(currentUsageDecimal)
// Determine usage limit based on plan type
let limit: number
@@ -134,7 +136,7 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') {
// Free/Pro: Use individual user limit from userStats
limit = stats.currentUsageLimit
? Number.parseFloat(stats.currentUsageLimit)
? toNumber(toDecimal(stats.currentUsageLimit))
: getFreeTierLimit()
} else {
// Team/Enterprise: Use organization limit
@@ -163,7 +165,7 @@ export async function getUserUsageData(userId: string): Promise<UsageData> {
isExceeded,
billingPeriodStart,
billingPeriodEnd,
lastPeriodCost: Number.parseFloat(stats.lastPeriodCost?.toString() || '0'),
lastPeriodCost: toNumber(toDecimal(stats.lastPeriodCost)),
}
} catch (error) {
logger.error('Failed to get user usage data', { userId, error })
@@ -195,7 +197,7 @@ export async function getUserUsageLimitInfo(userId: string): Promise<UsageLimitI
if (!subscription || subscription.plan === 'free' || subscription.plan === 'pro') {
// Free/Pro: Use individual limits
currentLimit = stats.currentUsageLimit
? Number.parseFloat(stats.currentUsageLimit)
? toNumber(toDecimal(stats.currentUsageLimit))
: getFreeTierLimit()
minimumLimit = getPerUserMinimumLimit(subscription)
canEdit = canEditUsageLimit(subscription)
@@ -353,7 +355,7 @@ export async function getUserUsageLimit(userId: string): Promise<number> {
)
}
return Number.parseFloat(userStatsQuery[0].currentUsageLimit)
return toNumber(toDecimal(userStatsQuery[0].currentUsageLimit))
}
// Team/Enterprise: Verify org exists then use organization limit
const orgExists = await db
@@ -438,7 +440,7 @@ export async function syncUsageLimitsFromSubscription(userId: string): Promise<v
// Free/Pro: Handle individual limits
const defaultLimit = getPerUserMinimumLimit(subscription)
const currentLimit = currentStats.currentUsageLimit
? Number.parseFloat(currentStats.currentUsageLimit)
? toNumber(toDecimal(currentStats.currentUsageLimit))
: 0
if (!subscription || subscription.status !== 'active') {
@@ -503,9 +505,9 @@ export async function getTeamUsageLimits(organizationId: string): Promise<
userId: memberData.userId,
userName: memberData.userName,
userEmail: memberData.userEmail,
currentLimit: Number.parseFloat(memberData.currentLimit || getFreeTierLimit().toString()),
currentUsage: Number.parseFloat(memberData.currentPeriodCost || '0'),
totalCost: Number.parseFloat(memberData.totalCost || '0'),
currentLimit: toNumber(toDecimal(memberData.currentLimit || getFreeTierLimit().toString())),
currentUsage: toNumber(toDecimal(memberData.currentPeriodCost)),
totalCost: toNumber(toDecimal(memberData.totalCost)),
lastActive: memberData.lastActive,
}))
} catch (error) {
@@ -531,7 +533,7 @@ export async function getEffectiveCurrentPeriodCost(userId: string): Promise<num
.limit(1)
if (rows.length === 0) return 0
return rows[0].current ? Number.parseFloat(rows[0].current.toString()) : 0
return toNumber(toDecimal(rows[0].current))
}
// Team/Enterprise: pooled usage across org members
@@ -548,11 +550,11 @@ export async function getEffectiveCurrentPeriodCost(userId: string): Promise<num
.from(userStats)
.where(inArray(userStats.userId, memberIds))
let pooled = 0
let pooled = new Decimal(0)
for (const r of rows) {
pooled += r.current ? Number.parseFloat(r.current.toString()) : 0
pooled = pooled.plus(toDecimal(r.current))
}
return pooled
return toNumber(pooled)
}
/**

View File

@@ -3,6 +3,7 @@ import { member, organization, userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, sql } from 'drizzle-orm'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { Decimal, toDecimal, toFixedString, toNumber } from '@/lib/billing/utils/decimal'
const logger = createLogger('CreditBalance')
@@ -23,7 +24,7 @@ export async function getCreditBalance(userId: string): Promise<CreditBalanceInf
.limit(1)
return {
balance: orgRows.length > 0 ? Number.parseFloat(orgRows[0].creditBalance || '0') : 0,
balance: orgRows.length > 0 ? toNumber(toDecimal(orgRows[0].creditBalance)) : 0,
entityType: 'organization',
entityId: subscription.referenceId,
}
@@ -36,7 +37,7 @@ export async function getCreditBalance(userId: string): Promise<CreditBalanceInf
.limit(1)
return {
balance: userRows.length > 0 ? Number.parseFloat(userRows[0].creditBalance || '0') : 0,
balance: userRows.length > 0 ? toNumber(toDecimal(userRows[0].creditBalance)) : 0,
entityType: 'user',
entityId: userId,
}
@@ -92,7 +93,8 @@ export interface DeductResult {
}
async function atomicDeductUserCredits(userId: string, cost: number): Promise<number> {
const costStr = cost.toFixed(6)
const costDecimal = toDecimal(cost)
const costStr = toFixedString(costDecimal)
// Use raw SQL with CTE to capture old balance before update
const result = await db.execute<{ old_balance: string; new_balance: string }>(sql`
@@ -113,12 +115,13 @@ async function atomicDeductUserCredits(userId: string, cost: number): Promise<nu
const rows = Array.from(result)
if (rows.length === 0) return 0
const oldBalance = Number.parseFloat(rows[0].old_balance || '0')
return Math.min(oldBalance, cost)
const oldBalance = toDecimal(rows[0].old_balance)
return toNumber(oldBalance.lessThan(costDecimal) ? oldBalance : costDecimal)
}
async function atomicDeductOrgCredits(orgId: string, cost: number): Promise<number> {
const costStr = cost.toFixed(6)
const costDecimal = toDecimal(cost)
const costStr = toFixedString(costDecimal)
// Use raw SQL with CTE to capture old balance before update
const result = await db.execute<{ old_balance: string; new_balance: string }>(sql`
@@ -139,8 +142,8 @@ async function atomicDeductOrgCredits(orgId: string, cost: number): Promise<numb
const rows = Array.from(result)
if (rows.length === 0) return 0
const oldBalance = Number.parseFloat(rows[0].old_balance || '0')
return Math.min(oldBalance, cost)
const oldBalance = toDecimal(rows[0].old_balance)
return toNumber(oldBalance.lessThan(costDecimal) ? oldBalance : costDecimal)
}
export async function deductFromCredits(userId: string, cost: number): Promise<DeductResult> {
@@ -159,7 +162,7 @@ export async function deductFromCredits(userId: string, cost: number): Promise<D
creditsUsed = await atomicDeductUserCredits(userId, cost)
}
const overflow = Math.max(0, cost - creditsUsed)
const overflow = toNumber(Decimal.max(0, toDecimal(cost).minus(creditsUsed)))
if (creditsUsed > 0) {
logger.info('Deducted credits atomically', {

View File

@@ -0,0 +1,36 @@
import Decimal from 'decimal.js'
/**
* Configure Decimal.js for billing precision.
* 20 significant digits is more than enough for currency calculations.
*/
Decimal.set({ precision: 20, rounding: Decimal.ROUND_HALF_UP })
/**
* Parse a value to Decimal for precise billing calculations.
* Handles null, undefined, empty strings, and number/string inputs.
*/
export function toDecimal(value: string | number | null | undefined): Decimal {
if (value === null || value === undefined || value === '') {
return new Decimal(0)
}
return new Decimal(value)
}
/**
* Convert Decimal back to number for storage/API responses.
* Use this at the final step when returning values.
*/
export function toNumber(value: Decimal): number {
return value.toNumber()
}
/**
* Format a Decimal to a fixed string for database storage.
* Uses 6 decimal places which matches current DB precision.
*/
export function toFixedString(value: Decimal, decimalPlaces = 6): string {
return value.toFixed(decimalPlaces)
}
export { Decimal }

View File

@@ -12,6 +12,7 @@
"@tanstack/react-query-devtools": "5.90.2",
"@types/fluent-ffmpeg": "2.1.28",
"cronstrue": "3.3.0",
"decimal.js": "10.6.0",
"drizzle-orm": "^0.44.5",
"ffmpeg-static": "5.3.0",
"fluent-ffmpeg": "2.1.3",

View File

@@ -35,27 +35,28 @@
},
"dependencies": {
"@linear/sdk": "40.0.0",
"next-runtime-env": "3.3.0",
"@modelcontextprotocol/sdk": "1.20.2",
"@t3-oss/env-nextjs": "0.13.4",
"zod": "^3.24.2",
"@tanstack/react-query": "5.90.8",
"@tanstack/react-query-devtools": "5.90.2",
"@types/fluent-ffmpeg": "2.1.28",
"cronstrue": "3.3.0",
"decimal.js": "10.6.0",
"drizzle-orm": "^0.44.5",
"ffmpeg-static": "5.3.0",
"fluent-ffmpeg": "2.1.3",
"isolated-vm": "6.0.2",
"mongodb": "6.19.0",
"neo4j-driver": "6.0.1",
"next-runtime-env": "3.3.0",
"nodemailer": "7.0.11",
"onedollarstats": "0.0.10",
"postgres": "^3.4.5",
"remark-gfm": "4.0.1",
"rss-parser": "3.13.0",
"socket.io-client": "4.8.1",
"twilio": "5.9.0"
"twilio": "5.9.0",
"zod": "^3.24.2"
},
"devDependencies": {
"@biomejs/biome": "2.0.0-beta.5",