diff --git a/src/quota/tokenCount.ts b/src/quota/tokenCount.ts index 70ab318..47e9e99 100644 --- a/src/quota/tokenCount.ts +++ b/src/quota/tokenCount.ts @@ -12,19 +12,16 @@ import { Usage } from "@prisma/client"; export default class tokenCount implements IQuota { defaultQuota: number; lookback: number; - requestTokenMultiplier: number; - responseTokenMultiplier: number; - + considerInputTokensAsHalf: boolean; + constructor( defaultQuota: number = 512 * 25, lookback: number = 1000 * 60 * 60 * 24, - requestTokenMultiplier: number = 1, - responseTokenMultiplier: number = 1, + considerInputTokensAsHalf: boolean = true, ) { this.defaultQuota = defaultQuota; this.lookback = lookback; - this.requestTokenMultiplier = requestTokenMultiplier; - this.responseTokenMultiplier = responseTokenMultiplier; + this.considerInputTokensAsHalf = considerInputTokensAsHalf; } private getUserQuota(id: string) { @@ -58,7 +55,9 @@ export default class tokenCount implements IQuota { const usageResponse = usedTokens.usageResponse === null ? 0 : usedTokens.usageResponse; const usedUnits = (() => { - return usageRequest * this.requestTokenMultiplier + usageResponse * this.responseTokenMultiplier; + if (this.considerInputTokensAsHalf) + return usageResponse + usageRequest / 2; + return usageResponse + usageRequest; })(); if (userQuota?.vip) return this.createUserQuotaData(Infinity, usedUnits); @@ -75,13 +74,30 @@ export default class tokenCount implements IQuota { * @returns promise of giving out the record */ findNthUsage(user: string, requestTimestamp: number, unitCount: number) { - return database.$queryRaw>` + if (this.considerInputTokensAsHalf) + return database.$queryRaw>` + SELECT t1.*, ( + SELECT + SUM(usageResponse + usageRequest/2) AS usage + FROM \`usage\` + WHERE + user = ${user} AND + timestamp >= ${requestTimestamp - this.lookback} AND + timestamp <= t1.timestamp + ) as usage + FROM + \`usage\` AS t1 + WHERE + user = ${user} AND + timestamp >= ${requestTimestamp - this.lookback} AND + usage >= ${unitCount} + ORDER BY timestamp ASC + LIMIT 1 + `; + return database.$queryRaw>` SELECT t1.*, ( SELECT - SUM( - usageRequest * ${this.requestTokenMultiplier} + - usageResponse * ${this.responseTokenMultiplier} - ) AS usage + SUM(usageResponse + usageRequest) AS usage FROM \`usage\` WHERE user = ${user} AND @@ -105,14 +121,14 @@ export default class tokenCount implements IQuota { ): Promise { const userId = typeof user ==="string" ? user : user.id; - const [userQuota, overUnitCountRecord] = await Promise.all([ + const [userQuota, renameMebecause] = await Promise.all([ this.checkUser(userId, request), this.findNthUsage(userId, request.createdTimestamp, unitCount) ]); return { ...userQuota, - recoveryTimestamp: (overUnitCountRecord.at(0)?.timestamp.valueOf() ?? Infinity) + this.lookback, + recoveryTimestamp: (renameMebecause.at(0)?.timestamp.valueOf() ?? Infinity) + this.lookback, }; }