From d0be1a6aa1a6b80119ee1e69883fdd72403a77dc Mon Sep 17 00:00:00 2001 From: Travis Fischer Date: Sat, 17 May 2025 20:07:24 +0700 Subject: [PATCH] feat: WIP stripe billing refactor update for 2025 --- apps/api/src/db/schema/deployment.ts | 15 +- apps/api/src/db/schema/project.ts | 21 +- apps/api/src/db/schema/types.ts | 122 ++++++++--- apps/api/src/db/schema/utils.ts | 58 +++-- .../upsert-stripe-products-and-pricing.ts | 201 ++++++++++++------ apps/api/src/lib/utils.test.ts | 19 ++ 6 files changed, 292 insertions(+), 144 deletions(-) create mode 100644 apps/api/src/lib/utils.test.ts diff --git a/apps/api/src/db/schema/deployment.ts b/apps/api/src/db/schema/deployment.ts index c9886b9e..10678c4e 100644 --- a/apps/api/src/db/schema/deployment.ts +++ b/apps/api/src/db/schema/deployment.ts @@ -14,8 +14,8 @@ import { teams, teamSelectSchema } from './team' import { // type Coupon, // couponSchema, - type PricingPlanMapByInterval, - pricingPlanMapByIntervalSchema + type PricingPlanMap, + pricingPlanMapSchema } from './types' import { users, userSelectSchema } from './user' import { @@ -64,11 +64,8 @@ export const deployments = pgTable( // Backend API URL _url: text().notNull(), - // NOTE: this does not have a default value and must be given a value at creation. - // Record> - pricingPlanMapByInterval: jsonb() - .$type() - .notNull() + // Record + pricingPlanMap: jsonb().$type().notNull() // coupons: jsonb().$type().default([]).notNull() }, @@ -110,7 +107,7 @@ export const deploymentSelectSchema = createSelectSchema(deployments, { // build: z.object({}), // env: z.object({}), - pricingPlanMapByInterval: pricingPlanMapByIntervalSchema + pricingPlanMap: pricingPlanMapSchema // coupons: z.array(couponSchema) }) .omit({ @@ -145,7 +142,7 @@ export const deploymentInsertSchema = createInsertSchema(deployments, { // TODO: should this public resource be decoupled from the internal pricing // plan structure? - pricingPlanMapByInterval: pricingPlanMapByIntervalSchema + pricingPlanMap: pricingPlanMapSchema // TODO // coupons: z.array(couponSchema).optional() diff --git a/apps/api/src/db/schema/project.ts b/apps/api/src/db/schema/project.ts index fd3df4c9..fbde2f6f 100644 --- a/apps/api/src/db/schema/project.ts +++ b/apps/api/src/db/schema/project.ts @@ -85,8 +85,6 @@ export const projects = pgTable( _webhooks: jsonb().$type().default([]).notNull(), - // TODO: currency? - // Stripe coupons associated with this project, mapping from unique coupon // object hash to stripe coupon id. // `[hash: string]: string` @@ -180,8 +178,8 @@ export const projectSelectSchema = createSelectSchema(projects, { _stripePriceIdMap: stripePriceIdMapSchema, _stripeMeterIdMap: stripeMeterIdMapSchema, - pricingIntervals: z.array(pricingIntervalSchema).optional(), - defaultPricingInterval: pricingIntervalSchema.optional() + pricingIntervals: z.array(pricingIntervalSchema).nonempty(), + defaultPricingInterval: pricingIntervalSchema }) .omit({ _secret: true, @@ -241,20 +239,5 @@ export const projectUpdateSchema = createUpdateSchema(projects) }) .strict() -export const projectDebugSelectSchema = createSelectSchema(projects).pick({ - id: true, - name: true, - alias: true, - userId: true, - teamId: true, - createdAt: true, - updatedAt: true, - isStripeConnectEnabled: true, - lastPublishedDeploymentId: true, - lastDeploymentId: true, - pricingIntervals: true, - defaultPricingInterval: true -}) - // TODO: virtual saasUrl // TODO: virtual aliasUrl diff --git a/apps/api/src/db/schema/types.ts b/apps/api/src/db/schema/types.ts index 3ef6cc01..299ac86b 100644 --- a/apps/api/src/db/schema/types.ts +++ b/apps/api/src/db/schema/types.ts @@ -59,13 +59,8 @@ export type Webhook = z.infer export const rateLimitSchema = z .object({ - enabled: z.boolean(), - interval: z.number(), // seconds - maxPerInterval: z.number(), // unitless - - // informal description that overrides any other properties - desc: z.string().optional() + maxPerInterval: z.number() // unitless }) .openapi('RateLimit') export type RateLimit = z.infer @@ -88,6 +83,7 @@ export type PricingPlanTier = z.infer export const pricingIntervalSchema = z .enum(['day', 'week', 'month', 'year']) + .describe('The frequency at which a subscription is billed.') .openapi('PricingInterval') export type PricingInterval = z.infer @@ -117,15 +113,21 @@ const commonPricingPlanMetricSchema = z.object({ /** * Slugs act as the primary key for metrics. They should be lower and * kebab-cased ("base", "requests", "image-transformations"). + * + * TODO: ensure user-provided custom metrics don't use reserved 'base' + * and 'requests' slugs. */ slug: z.union([z.string(), z.literal('base'), z.literal('requests')]), - interval: pricingIntervalSchema, + /** + * The frequency at which a subscription is billed. + * + * Only optional when `PricingPlan.slug` is `free`. + */ + interval: pricingIntervalSchema.optional(), label: z.string().optional().openapi('label', { example: 'API calls' }), - rateLimit: rateLimitSchema.optional(), - stripePriceId: z.string().optional() }) @@ -134,7 +136,7 @@ export const pricingPlanMetricSchema = z commonPricingPlanMetricSchema.merge( z.object({ usageType: z.literal('licensed'), - amount: z.number() + amount: z.number().nonnegative() }) ), @@ -143,10 +145,29 @@ export const pricingPlanMetricSchema = z usageType: z.literal('metered'), unitLabel: z.string().optional(), + /** + * Optional rate limit to enforce for this metric. + * + * You can use this, for example, to limit the number of API calls that + * can be made during a given interval. + */ + rateLimit: rateLimitSchema.optional(), + + /** + * Describes how to compute the price per period. Either `per_unit` or + * `tiered`. + * + * `per_unit` indicates that the fixed amount (specified in + * `unitAmount`) will be charged per unit of total usage. + * + * `tiered` indicates that the unit pricing will be computed using a + * tiering strategy as defined using the `tiers` and `tiersMode` + * attributes. + */ billingScheme: z.enum(['per_unit', 'tiered']), // Only applicable for `per_unit` billing schemes - amount: z.number().optional(), + unitAmount: z.number().nonnegative().optional(), // Only applicable for `tiered` billing schemes tiersMode: z.enum(['graduated', 'volume']).optional(), @@ -154,14 +175,48 @@ export const pricingPlanMetricSchema = z // TODO: add support for tiered rate limits? + /** + * The default settings to aggregate the Stripe Meter's events with. + * + * Deafults to `{ formula: 'sum' }`. + */ defaultAggregation: z .object({ - formula: z.enum(['sum', 'count', 'last']) + /** + * Specifies how events are aggregated for a Stripe Metric. + * Allowed values are `count` to count the number of events, `sum` + * to sum each event's value and `last` to take the last event's + * value in the window. + * + * Defaults to `sum`. + */ + formula: z.enum(['sum', 'count', 'last']).default('sum') }) .optional(), - // Stripe metric id, which is created lazily upon first use. - stripeMetricId: z.string().optional() + /** + * Optionally apply a transformation to the reported usage or set + * quantity before computing the amount billed. Cannot be combined + * with `tiers`. + */ + transformQuantity: z + .object({ + /** + * Divide usage by this number. + */ + divideBy: z.number().positive(), + + /** + * After division, either round the result `up` or `down`. + */ + round: z.enum(['down', 'up']) + }) + .optional(), + + /** + * Stripe Meter id, which is created lazily upon first use. + */ + stripeMeterId: z.string().optional() }) ) ]) @@ -171,15 +226,19 @@ export type PricingPlanMetric = z.infer export const pricingPlanSchema = z .object({ - name: z.string().openapi('name', { example: 'Starter Monthly' }), - slug: z.string().openapi('slug', { example: 'starter-monthly' }), + name: z.string().nonempty().openapi('name', { example: 'Starter Monthly' }), + slug: z.string().nonempty().openapi('slug', { example: 'starter-monthly' }), + + /** + * The frequency at which a subscription is billed. + */ + interval: pricingIntervalSchema.optional(), desc: z.string().optional(), features: z.array(z.string()), - interval: pricingIntervalSchema, - - trialPeriodDays: z.number().optional(), + // TODO? + trialPeriodDays: z.number().nonnegative().optional(), metricsMap: z .record(pricingPlanMetricSlugSchema, pricingPlanMetricSchema) @@ -190,6 +249,15 @@ export const pricingPlanSchema = z }) .default({}) }) + .refine((data) => { + if (data.interval === undefined && data.slug !== 'free') { + throw new Error( + `Invalid PricingPlan "${data.slug}": non-free pricing plans must have an interval` + ) + } + + return data + }) .openapi('PricingPlan') export type PricingPlan = z.infer @@ -199,19 +267,13 @@ export const stripeProductIdMapSchema = z .openapi('StripeProductIdMap') export type StripeProductIdMap = z.infer -export const pricingPlanMapBySlugSchema = z +export const pricingPlanMapSchema = z .record(z.string().describe('PricingPlan slug'), pricingPlanSchema) + .refine((data) => Object.keys(data).length > 0, { + message: 'Must contain at least one PricingPlan' + }) .describe('Map from PricingPlan slug to PricingPlan') -export type PricingPlanMapBySlug = z.infer - -export const pricingPlanMapByIntervalSchema = z - .record(pricingIntervalSchema, pricingPlanMapBySlugSchema) - .describe( - 'Map from PricingInterval to a map from PricingPlan slug to PricingPlan' - ) -export type PricingPlanMapByInterval = z.infer< - typeof pricingPlanMapByIntervalSchema -> +export type PricingPlanMap = z.infer // export const couponSchema = z // .object({ diff --git a/apps/api/src/db/schema/utils.ts b/apps/api/src/db/schema/utils.ts index 2ead3cde..607de3e7 100644 --- a/apps/api/src/db/schema/utils.ts +++ b/apps/api/src/db/schema/utils.ts @@ -16,7 +16,12 @@ import { createId } from '@paralleldrive/cuid2' import { hashObject, omit } from '@/lib/utils' import type { RawProject } from '../types' -import type { PricingPlanMetric } from './types' +import type { + PricingInterval, + PricingPlan, + PricingPlanMap, + PricingPlanMetric +} from './types' const usernameAndTeamSlugLength = 64 as const @@ -135,33 +140,42 @@ export function getPricingPlanMetricHashForStripePrice({ pricingPlanMetric: PricingPlanMetric project: RawProject }) { + // TODO: use pricingPlan.slug as well here? + // 'price:free:base:' + // 'price:basic-monthly:base:' + // 'price:basic-monthly:requests:' + const hash = hashObject({ - ...omit(pricingPlanMetric, 'stripePriceId', 'stripeMetricId'), + ...omit(pricingPlanMetric, 'stripePriceId', 'stripeMeterId'), projectId: project.id, - stripeAccountId: project._stripeAccountId + stripeAccountId: project._stripeAccountId, + currency: project.pricingCurrency }) return `price:${pricingPlanMetric.slug}:${hash}` } -/** - * Gets the hash used to uniquely map a PricingPlanMetric to its corresponding - * Stripe Meter in a stable way across deployments within a project. - * - * This hash is used as the key for the `Project._stripePriceIdMap`. - */ -export function getPricingPlanMetricHashForStripeMeter({ - pricingPlanMetric, - project +export function getPricingPlansByInterval({ + pricingInterval, + pricingPlanMap }: { - pricingPlanMetric: PricingPlanMetric - project: RawProject -}) { - const hash = hashObject({ - ...omit(pricingPlanMetric, 'stripePriceId', 'stripeMetricId'), - projectId: project.id, - stripeAccountId: project._stripeAccountId - }) - - return `price:${pricingPlanMetric.slug}:${hash}` + pricingInterval: PricingInterval + pricingPlanMap: PricingPlanMap +}): PricingPlan[] { + return Object.values(pricingPlanMap).filter( + (pricingPlan) => pricingPlan.interval === pricingInterval + ) +} + +const pricingIntervalToLabelMap: Record = { + day: 'daily', + week: 'weekly', + month: 'monthly', + year: 'yearly' +} + +export function getLabelForPricingInterval( + pricingInterval: PricingInterval +): string { + return pricingIntervalToLabelMap[pricingInterval] } diff --git a/apps/api/src/lib/billing/upsert-stripe-products-and-pricing.ts b/apps/api/src/lib/billing/upsert-stripe-products-and-pricing.ts index d5dde312..534c32d2 100644 --- a/apps/api/src/lib/billing/upsert-stripe-products-and-pricing.ts +++ b/apps/api/src/lib/billing/upsert-stripe-products-and-pricing.ts @@ -3,7 +3,9 @@ import pAll from 'p-all' import { db, eq, type RawDeployment, type RawProject, schema } from '@/db' import { - getPricingPlanMetricHash, + getLabelForPricingInterval, + getPricingPlanMetricHashForStripePrice, + getPricingPlansByInterval, type PricingPlan, type PricingPlanMetric } from '@/db/schema' @@ -33,13 +35,14 @@ export async function upsertStripeProductsAndPricing({ pricingPlan: PricingPlan pricingPlanMetric: PricingPlanMetric }) { - const { slug: pricingPlanSlug } = pricingPlan // TODO + const { slug: pricingPlanSlug } = pricingPlan const { slug: pricingPlanMetricSlug } = pricingPlanMetric - const pricingPlanMetricHash = getPricingPlanMetricHash({ - pricingPlanMetric, - project - }) + const pricingPlanMetricHashForStripePrice = + getPricingPlanMetricHashForStripePrice({ + pricingPlanMetric, + project + }) // Upsert the Stripe Product if (!project._stripeProductIdMap[pricingPlanMetricSlug]) { @@ -69,51 +72,82 @@ export async function upsertStripeProductsAndPricing({ assert(project._stripeProductIdMap[pricingPlanMetricSlug]) - // Upsert the Stripe Meter - if ( - pricingPlanMetric.usageType === 'metered' && - !project._stripeMeterIdMap[pricingPlanMetricSlug] - ) { - const meter = await stripe.billing.meters.create( - { - display_name: `${project.id} ${pricingPlanMetric.label || pricingPlanMetricSlug}`, - event_name: `meter-${project.id}-${pricingPlanMetricSlug}`, - default_aggregation: { - formula: 'sum' + if (pricingPlanMetric.usageType === 'metered') { + // Upsert the Stripe Meter + if (!project._stripeMeterIdMap[pricingPlanMetricSlug]) { + const stripeMeter = await stripe.billing.meters.create( + { + display_name: `${project.id} ${pricingPlanMetric.label || pricingPlanMetricSlug}`, + event_name: `meter-${project.id}-${pricingPlanMetricSlug}`, + // TODO: This currently isn't taken into account for the slug, so if it + // changes across deployments, the meter will not be updated. + default_aggregation: { + formula: pricingPlanMetric.defaultAggregation?.formula ?? 'sum' + }, + customer_mapping: { + event_payload_key: 'stripe_customer_id', + type: 'by_id' + }, + value_settings: { + event_payload_key: 'value' + } }, - customer_mapping: { - event_payload_key: 'stripe_customer_id', - type: 'by_id' - }, - value_settings: { - event_payload_key: 'value' - } - }, - ...stripeConnectParams - ) + ...stripeConnectParams + ) - project._stripeMeterIdMap[pricingPlanMetricSlug] = meter.id - dirty = true + project._stripeMeterIdMap[pricingPlanMetricSlug] = stripeMeter.id + dirty = true + } + + assert(project._stripeMeterIdMap[pricingPlanMetricSlug]) + + if (!pricingPlanMetric.stripeMeterId) { + pricingPlanMetric.stripeMeterId = + project._stripeMeterIdMap[pricingPlanMetricSlug] + dirty = true + + assert(pricingPlanMetric.stripeMeterId) + } + } else { + assert(pricingPlanMetric.usageType === 'licensed') + + assert( + !project._stripeMeterIdMap[pricingPlanMetricSlug], + `Invalid pricing plan metric "${pricingPlanMetricSlug}" for pricing plan "${pricingPlanSlug}": licensed pricing plan metrics cannot replace a previous metered pricing plan metric. Use a different pricing plan metric slug for the new licensed plan.` + ) } - assert( - pricingPlanMetric.usageType === 'licensed' || - project._stripeMeterIdMap[pricingPlanMetricSlug] - ) - // Upsert the Stripe Price - if (!project._stripePriceIdMap[pricingPlanMetricHash]) { + if (!project._stripePriceIdMap[pricingPlanMetricHashForStripePrice]) { + const interval = + pricingPlanMetric.interval ?? project.defaultPricingInterval + + const nickname = [ + 'price', + project.id, + pricingPlanMetricSlug, + getLabelForPricingInterval(interval) + ] + .filter(Boolean) + .join('-') + const priceParams: Stripe.PriceCreateParams = { - nickname: `price-${project.id}-${pricingPlan.slug}-${pricingPlanMetricSlug}`, + nickname, product: project._stripeProductIdMap[pricingPlanMetricSlug], currency: project.pricingCurrency, recurring: { - interval: pricingPlanMetric.interval, + interval, // TODO: support this interval_count: 1, - usage_type: pricingPlanMetric.usageType + usage_type: pricingPlanMetric.usageType, + + meter: project._stripeMeterIdMap[pricingPlanMetricSlug] + }, + metadata: { + projectId: project.id, + pricingPlanMetricSlug } } @@ -124,26 +158,53 @@ export async function upsertStripeProductsAndPricing({ if (pricingPlanMetric.billingScheme === 'tiered') { assert( - pricingPlanMetric.tiers, - `Invalid pricing plan metric: ${pricingPlanMetricSlug}` + pricingPlanMetric.tiers?.length, + `Invalid pricing plan metric "${pricingPlanMetricSlug}" for pricing plan "${pricingPlanSlug}": tiered billing schemes must have at least one tier.` + ) + assert( + !pricingPlanMetric.transformQuantity, + `Invalid pricing plan metric "${pricingPlanMetricSlug}" for pricing plan "${pricingPlanSlug}": tiered billing schemes cannot have transformQuantity.` ) priceParams.tiers_mode = pricingPlanMetric.tiersMode - priceParams.tiers = pricingPlanMetric.tiers!.map((tier) => { - const result: Stripe.PriceCreateParams.Tier = { - up_to: tier.upTo + priceParams.tiers = pricingPlanMetric.tiers!.map((tierData) => { + const tier: Stripe.PriceCreateParams.Tier = { + up_to: tierData.upTo } - if (tier.unitAmount !== undefined) { - result.unit_amount_decimal = tier.unitAmount.toFixed(12) + if (tierData.unitAmount !== undefined) { + tier.unit_amount_decimal = tierData.unitAmount.toFixed(12) } - if (tier.flatAmount !== undefined) { - result.flat_amount_decimal = tier.flatAmount.toFixed(12) + if (tierData.flatAmount !== undefined) { + tier.flat_amount_decimal = tierData.flatAmount.toFixed(12) } - return result + return tier }) + } else { + assert( + pricingPlanMetric.billingScheme === 'per_unit', + `Invalid pricing plan metric "${pricingPlanMetricSlug}" for pricing plan "${pricingPlanSlug}": invalid billing scheme.` + ) + assert( + pricingPlanMetric.unitAmount !== undefined, + `Invalid pricing plan metric "${pricingPlanMetricSlug}" for pricing plan "${pricingPlanSlug}": unitAmount is required for per_unit billing schemes.` + ) + assert( + !pricingPlanMetric.tiers, + `Invalid pricing plan metric "${pricingPlanMetricSlug}" for pricing plan "${pricingPlanSlug}": per_unit billing schemes cannot have tiers.` + ) + + priceParams.unit_amount_decimal = + pricingPlanMetric.unitAmount.toFixed(12) + + if (pricingPlanMetric.transformQuantity) { + priceParams.transform_quantity = { + divide_by: pricingPlanMetric.transformQuantity.divideBy, + round: pricingPlanMetric.transformQuantity.round + } + } } } @@ -152,31 +213,43 @@ export async function upsertStripeProductsAndPricing({ ...stripeConnectParams ) - project._stripePriceIdMap[pricingPlanMetricHash] = stripePrice.id + project._stripePriceIdMap[pricingPlanMetricHashForStripePrice] = + stripePrice.id dirty = true } - assert(project._stripePriceIdMap[pricingPlanMetricHash]) + assert(project._stripePriceIdMap[pricingPlanMetricHashForStripePrice]) + + if (!pricingPlanMetric.stripePriceId) { + pricingPlanMetric.stripePriceId = + project._stripePriceIdMap[pricingPlanMetricHashForStripePrice] + } + + assert(pricingPlanMetric.stripePriceId) } const upserts: Array<() => Promise> = [] for (const pricingInterval of project.pricingIntervals) { - const pricingPlanMap = deployment.pricingPlanMapByInterval[pricingInterval] - assert( - pricingPlanMap, - `Invalid pricing config for deployment "${deployment.id}": missing pricing plan map for interval "${pricingInterval}"` - ) + const pricingPlans = getPricingPlansByInterval({ + pricingInterval, + pricingPlanMap: deployment.pricingPlanMap + }) - for (const pricingPlan of Object.values(pricingPlanMap)) { - for (const pricingPlanMetric of Object.values(pricingPlan.metricsMap)) { - upserts.push(() => - upsertStripeProductAndPricingForMetric({ - pricingPlan, - pricingPlanMetric - }) - ) - } + assert( + pricingPlans.length > 0, + `Invalid pricing config for deployment "${deployment.id}": no pricing plans for interval "${pricingInterval}"` + ) + } + + for (const pricingPlan of Object.values(deployment.pricingPlanMap)) { + for (const pricingPlanMetric of Object.values(pricingPlan.metricsMap)) { + upserts.push(() => + upsertStripeProductAndPricingForMetric({ + pricingPlan, + pricingPlanMetric + }) + ) } } diff --git a/apps/api/src/lib/utils.test.ts b/apps/api/src/lib/utils.test.ts new file mode 100644 index 00000000..150a691e --- /dev/null +++ b/apps/api/src/lib/utils.test.ts @@ -0,0 +1,19 @@ +import { expect, test } from 'vitest' + +import { omit, pick } from './utils' + +test('pick', () => { + expect(pick({ a: 1, b: 2, c: 3 }, 'a', 'c')).toEqual({ a: 1, c: 3 }) + expect( + pick({ a: { b: 'foo' }, d: -1, foo: null } as any, 'b', 'foo') + ).toEqual({ foo: null }) +}) + +test('omit', () => { + expect(omit({ a: 1, b: 2, c: 3 }, 'a', 'c')).toEqual({ b: 2 }) + expect(omit({ a: { b: 'foo' }, d: -1, foo: null }, 'b', 'foo')).toEqual({ + a: { b: 'foo' }, + d: -1 + }) + expect(omit({ a: 1, b: 2, c: 3 }, 'foo', 'bar', 'c')).toEqual({ a: 1, b: 2 }) +})