diff --git a/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml b/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml index be7b97146..cbaf1cef6 100644 --- a/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml +++ b/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml @@ -50,9 +50,10 @@ modelprop: JBROWSE_SERVICE_URL: "/{{ webapp_ctx }}/service/jbrowse" AI_EXPRESSION_CACHE_DIR: "/var/www/Common/ai-expr-cache" AI_EXPRESSION_QUALTRICS_ID: SV_38C4ZX1JxLi2SEe - OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST: 33 - OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS: 2.5 - OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS: 10 + MAX_DAILY_AI_EXPRESSION_DOLLAR_COST: 33 + # Claude Sonnet 4.5 costs from here: https://platform.claude.com/docs/en/about-claude/pricing + DOLLAR_COST_PER_1M_AI_INPUT_TOKENS: 3 + DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS: 15 user_datasets_uploadTypes_env_map: w: "genelist" diff --git a/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml b/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml index 82596e15e..81c090100 100644 --- a/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml +++ b/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml @@ -46,6 +46,7 @@ modelprop: GOOGLE_MAPS_API_KEY: "{{ lookup('euparc', 'attr=api_key xpath=sites/google_maps default=NOKEY') }}" COMMUNITY_SITE: "//{{ community_env_map[prefix]|default(community_env_map['default']) }}" OPENAI_API_KEY: "{{ lookup('euparc', 'attr=api_key xpath=sites/openai default=NOKEY') }}" + CLAUDE_API_KEY: "{{ lookup('euparc', 'attr=api_key xpath=sites/claude default=NOKEY') }}" # the below extends the w_ q_ prefix pattern used for workspace_env_map, which diff --git a/Model/pom.xml b/Model/pom.xml index 39a98f933..41c4cfbba 100644 --- a/Model/pom.xml +++ b/Model/pom.xml @@ -135,6 +135,11 @@ openai-java + + com.anthropic + anthropic-java + + diff --git 
a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java index 9b61a9f54..cf2f0d1ad 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java @@ -13,6 +13,8 @@ import org.apidb.apicommon.model.report.ai.expression.DailyCostMonitor; import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor; import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.GeneSummaryInputs; +import org.apidb.apicommon.model.report.ai.expression.ClaudeSummarizer; +//import org.apidb.apicommon.model.report.ai.expression.OpenAISummarizer; import org.apidb.apicommon.model.report.ai.expression.Summarizer; import org.gusdb.wdk.model.WdkModelException; import org.gusdb.wdk.model.WdkServiceTemporarilyUnavailableException; @@ -27,13 +29,44 @@ import org.json.JSONException; import org.json.JSONObject; +/** + * Reporter that generates AI-powered gene expression summaries using LLM models. + * + *

This reporter analyzes expression data across multiple experiments for a single gene + * and generates natural language summaries of expression patterns and biological significance. + * Results are cached to minimize API costs and response times.

+ * + *

Configuration (JSON request payload)

+ *
+ * {
+ *   "populateIfNotPresent": true|false,  // If true, generate summary if not cached (default: false)
+ *   "makeTopicEmbeddings": true|false    // If true, generate embedding vectors for topics (default: false)
+ * }
+ * 
+ * + *

Cache Invalidation Warning

+ *

IMPORTANT: Changing the {@code makeTopicEmbeddings} setting will invalidate + * the entire cache for all genes, as this value is included in the cache digest. To avoid costly + * cache regeneration, choose a setting and stick with it across requests. Only change this value + * when you intentionally want to regenerate all summaries with or without embeddings.

+ * + *

Model Configuration

+ *

The chat model is hardcoded in each summarizer implementation + * ({@link ClaudeSummarizer}, {@link org.apidb.apicommon.model.report.ai.expression.OpenAISummarizer}), and the + * embedding model is hardcoded in {@link Summarizer}. Changing either model will also invalidate the cache.

+ */ public class SingleGeneAiExpressionReporter extends AbstractReporter { private static final int MAX_RESULT_SIZE = 1; // one gene at a time for now private static final String POPULATION_MODE_PROP_KEY = "populateIfNotPresent"; + private static final String AI_MAX_CONCURRENT_REQUESTS_PROP_KEY = "AI_MAX_CONCURRENT_REQUESTS"; + private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10; + private static final String MAKE_TOPIC_EMBEDDINGS_PROP_KEY = "makeTopicEmbeddings"; private boolean _populateIfNotPresent; + private int _maxConcurrentRequests; + private boolean _makeTopicEmbeddings; private DailyCostMonitor _costMonitor; @Override @@ -42,6 +75,15 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk // assign cache mode _populateIfNotPresent = config.optBoolean(POPULATION_MODE_PROP_KEY, false); + // assign topic embeddings flag + _makeTopicEmbeddings = config.optBoolean(MAKE_TOPIC_EMBEDDINGS_PROP_KEY, false); + + // read max concurrent requests from model properties or use default + String maxConcurrentRequestsStr = _wdkModel.getProperties().get(AI_MAX_CONCURRENT_REQUESTS_PROP_KEY); + _maxConcurrentRequests = maxConcurrentRequestsStr != null + ? 
Integer.parseInt(maxConcurrentRequestsStr) + : DEFAULT_MAX_CONCURRENT_REQUESTS; + // instantiate cost monitor _costMonitor = new DailyCostMonitor(_wdkModel); @@ -52,7 +94,7 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk " should only be assigned to " + geneRecordClass.getFullName()); } - // check result size; limit to small results due to OpenAI cost + // check result size; limit to small results due to AI API cost if (_baseAnswer.getResultSizeFactory().getResultSize() > MAX_RESULT_SIZE) { throw new ReporterConfigException("This reporter cannot be called with results of size greater than " + MAX_RESULT_SIZE); } @@ -79,9 +121,11 @@ protected void write(OutputStream out) throws IOException, WdkModelException { // open summary cache (manages persistence of expression data) AiExpressionCache cache = AiExpressionCache.getInstance(_wdkModel); - // create summarizer (interacts with OpenAI) - Summarizer summarizer = new Summarizer(_wdkModel, _costMonitor); - + // create summarizer (interacts with Claude) + ClaudeSummarizer summarizer = new ClaudeSummarizer(_wdkModel, _costMonitor, _makeTopicEmbeddings); + // or alternatively use OpenAI (with the appropriate import) + // OpenAISummarizer summarizer = new OpenAISummarizer(_wdkModel, _costMonitor, _makeTopicEmbeddings); + // open record and output streams try (RecordStream recordStream = RecordStreamFactory.getRecordStream(_baseAnswer, List.of(), tables); BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out))) { @@ -93,12 +137,13 @@ protected void write(OutputStream out) throws IOException, WdkModelException { // create summary inputs GeneSummaryInputs summaryInputs = - GeneRecordProcessor.getSummaryInputsFromRecord(record, Summarizer.OPENAI_CHAT_MODEL.toString(), + GeneRecordProcessor.getSummaryInputsFromRecord(record, ClaudeSummarizer.CLAUDE_MODEL.toString(), + Summarizer.EMBEDDING_MODEL.asString(), _makeTopicEmbeddings, ClaudeSummarizer.USE_EXTENDED_THINKING, 
Summarizer::getExperimentMessage, Summarizer::getFinalSummaryMessage); // fetch summary, producing if necessary and requested JSONObject expressionSummary = _populateIfNotPresent - ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments) + ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments, _maxConcurrentRequests) : cache.readSummary(summaryInputs); // join entries with commas diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java index 3bc5768b7..4054eb1e6 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java @@ -73,7 +73,6 @@ public class AiExpressionCache { private static Logger LOG = Logger.getLogger(AiExpressionCache.class); // parallel processing - private static final int MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST = 10; private static final long VISIT_ENTRY_LOCK_MAX_WAIT_MILLIS = 50; // cache location @@ -317,18 +316,20 @@ private static Optional readCachedData(Path entryDir) { * @param summaryInputs gene summary inputs * @param experimentDescriber function to describe an experiment * @param experimentSummarizer function to summarize experiments into an expression summary + * @param maxConcurrentRequests maximum number of concurrent experiment lookups * @return expression summary (will always be a cache hit) */ public JSONObject populateSummary(GeneSummaryInputs summaryInputs, FunctionWithException> experimentDescriber, - BiFunctionWithException, JSONObject> experimentSummarizer) { + BiFunctionWithException, JSONObject> experimentSummarizer, + int maxConcurrentRequests) { try { return _cache.populateAndProcessContent(summaryInputs.getGeneId(), // populator entryDir -> { // 
first populate each dataset entry as needed and collect experiment descriptors - List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber); + List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber, maxConcurrentRequests); // sort them most-interesting first so that the "Other" section will be filled // in that order (and also to give the AI the data in a sensible order) @@ -362,14 +363,16 @@ public JSONObject populateSummary(GeneSummaryInputs summaryInputs, * * @param experimentData experiment inputs * @param experimentDescriber function to describe an experiment + * @param maxConcurrentRequests maximum number of concurrent experiment lookups * @return list of cached experiment descriptions * @throws Exception if unable to generate descriptions or store */ private List populateExperiments(List experimentData, - FunctionWithException> experimentDescriber) throws Exception { + FunctionWithException> experimentDescriber, + int maxConcurrentRequests) throws Exception { // use a thread for each experiment, up to a reasonable max - int threadPoolSize = Math.min(MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST, experimentData.size()); + int threadPoolSize = Math.min(maxConcurrentRequests, experimentData.size()); ExecutorService exec = Executors.newFixedThreadPool(threadPoolSize); try { diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java new file mode 100644 index 000000000..2e77db18e --- /dev/null +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java @@ -0,0 +1,134 @@ +package org.apidb.apicommon.model.report.ai.expression; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +import org.gusdb.wdk.model.WdkModel; +import org.gusdb.wdk.model.WdkModelException; + +import 
com.anthropic.client.AnthropicClientAsync; +import com.anthropic.client.okhttp.AnthropicOkHttpClientAsync; +import com.anthropic.models.messages.MessageCreateParams; +import com.anthropic.models.messages.Model; +import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema; + +public class ClaudeSummarizer extends Summarizer { + + public static final Model CLAUDE_MODEL = Model.CLAUDE_SONNET_4_5_20250929; + public static final boolean USE_EXTENDED_THINKING = false; + + private static final String CLAUDE_API_KEY_PROP_NAME = "CLAUDE_API_KEY"; + + private final AnthropicClientAsync _claudeClient; + + public ClaudeSummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor, boolean makeTopicEmbeddings) throws WdkModelException { + super(wdkModel, costMonitor, makeTopicEmbeddings); + + String apiKey = wdkModel.getProperties().get(CLAUDE_API_KEY_PROP_NAME); + if (apiKey == null) { + throw new WdkModelException("WDK property '" + CLAUDE_API_KEY_PROP_NAME + "' has not been set."); + } + + _claudeClient = AnthropicOkHttpClientAsync.builder() + .apiKey(apiKey) + .maxRetries(32) // Handle 429 errors + .checkJacksonVersionCompatibility(false) + .build(); + } + + @Override + protected CompletableFuture callApiForJson(String prompt, Schema schema) { + // Convert JSON schema to natural language description for Claude + String jsonFormatInstructions = convertSchemaToPromptInstructions(schema); + + String enhancedPrompt = prompt + "\n\n" + jsonFormatInstructions; + + MessageCreateParams.Builder requestBuilder = MessageCreateParams.builder() + .model(CLAUDE_MODEL) + .maxTokens((long) MAX_RESPONSE_TOKENS) + .system(SYSTEM_MESSAGE) + .addUserMessage(enhancedPrompt); + + if (USE_EXTENDED_THINKING) { + requestBuilder.enabledThinking(1024); + } + + MessageCreateParams request = requestBuilder.build(); + + return retryOnOverload( + () -> _claudeClient.messages().create(request), + e -> e instanceof com.anthropic.errors.InternalServerException, + "Claude API call" + 
).thenApply(response -> { + // Convert Claude usage to TokenUsage for cost monitoring + com.anthropic.models.messages.Usage claudeUsage = response.usage(); + TokenUsage tokenUsage = TokenUsage.builder() + .promptTokens(claudeUsage.inputTokens()) + .completionTokens(claudeUsage.outputTokens()) + .build(); + + _costMonitor.updateCost(tokenUsage); + + // Extract text from content blocks using stream API + String rawText = response.content().stream() + .flatMap(contentBlock -> contentBlock.text().stream()) + .map(textBlock -> textBlock.text()) + .findFirst() + .orElseThrow(() -> new RuntimeException("No text content found in Claude response")); + + // Strip JSON markdown formatting if present + return stripJsonMarkdown(rawText); + }); + } + + @Override + protected void updateCostMonitor(Object apiResponse) { + // Claude response handling is done in callApiForJson + } + + private String stripJsonMarkdown(String text) { + String trimmed = text.trim(); + + // Remove ```json and ``` markdown formatting + if (trimmed.startsWith("```json")) { + trimmed = trimmed.substring(7); // Remove "```json" + } else if (trimmed.startsWith("```")) { + trimmed = trimmed.substring(3); // Remove "```" + } + + if (trimmed.endsWith("```")) { + trimmed = trimmed.substring(0, trimmed.length() - 3); // Remove trailing "```" + } + + return trimmed.trim(); + } + + private String convertSchemaToPromptInstructions(Schema schema) { + // Convert OpenAI JSON schema to Claude-friendly format instructions + if (schema == experimentResponseSchema) { + return "Respond in valid JSON format matching this exact structure:\n" + + "{\n" + + " \"one_sentence_summary\": \"string describing gene expression\",\n" + + " \"biological_importance\": \"integer 0-5\",\n" + + " \"confidence\": \"integer 0-5\",\n" + + " \"experiment_keywords\": [\"array\", \"of\", \"strings\"],\n" + + " \"notes\": \"string with additional context\"\n" + + "}"; + } else if (schema == finalResponseSchema) { + return "Respond in valid JSON 
format matching this exact structure:\n" + + "{\n" + + " \"headline\": \"string summarizing key results\",\n" + + " \"one_paragraph_summary\": \"string with ~100 words\",\n" + + " \"topics\": [\n" + + " {\n" + + " \"headline\": \"string summarizing topic\",\n" + + " \"one_sentence_summary\": \"string describing topic results\",\n" + + " \"dataset_ids\": [\"array\", \"of\", \"dataset_id\", \"strings\"]\n" + + " }\n" + + " ]\n" + + "}"; + } else { + return "Respond in valid JSON format."; + } + } +} diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java index fe7459174..277853496 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java @@ -21,8 +21,6 @@ import org.json.JSONException; import org.json.JSONObject; -import com.openai.models.completions.CompletionUsage; - public class DailyCostMonitor { private static final Logger LOG = Logger.getLogger(DailyCostMonitor.class); @@ -31,10 +29,18 @@ public class DailyCostMonitor { private static final String DAILY_COST_ACCUMULATION_FILE_DIR = "dailyCost"; private static final String DAILY_COST_ACCUMULATION_FILE = "daily_cost_accumulation.txt"; - // model prop keys - private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST"; - private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS"; - private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS"; + // model prop keys (new names without OPENAI_ prefix) + private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "MAX_DAILY_AI_EXPRESSION_DOLLAR_COST"; + private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = 
"DOLLAR_COST_PER_1M_AI_INPUT_TOKENS"; + private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS"; + + // hardcoded embedding token cost + private static final double DOLLAR_COST_PER_1M_EMBEDDING_TOKENS = 0.02; + + // deprecated model prop keys (with OPENAI_ prefix) + private static final String DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST"; + private static final String DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS"; + private static final String DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS"; // lock characteristics private static final long DEFAULT_TIMEOUT_MILLIS = 1000; @@ -44,11 +50,11 @@ public class DailyCostMonitor { private static final String JSON_DATE_PROP = "currentDate"; private static final String JSON_COST_PROP = "accumulatedCost"; - // completion usage object representing 0 cost - private static final CompletionUsage EMPTY_COST = CompletionUsage.builder() + // token usage object representing 0 cost + private static final TokenUsage EMPTY_COST = TokenUsage.builder() .promptTokens(0) .completionTokens(0) - .totalTokens(0) + .embeddingTokens(0) .build(); private final Path _costMonitoringDir; @@ -57,6 +63,7 @@ public class DailyCostMonitor { private final double _maxDailyDollarCost; private final double _costPerInputToken; private final double _costPerOutputToken; + private final double _costPerEmbeddingToken; public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException { _costMonitoringDir = AiExpressionCache.getAiExpressionCacheParentDir(wdkModel).resolve(DAILY_COST_ACCUMULATION_FILE_DIR).toAbsolutePath(); @@ -68,9 +75,26 @@ public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException { _costMonitoringFile = _costMonitoringDir.resolve(DAILY_COST_ACCUMULATION_FILE); - _maxDailyDollarCost = getNumberProp(wdkModel, 
MAX_DAILY_DOLLAR_COST_PROP_NAME); - _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000; - _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000; + _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME, DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME); + _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000; + _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000; + _costPerEmbeddingToken = DOLLAR_COST_PER_1M_EMBEDDING_TOKENS / 1000000; + } + + private double getNumberProp(WdkModel wdkModel, String propName, String deprecatedPropName) throws WdkModelException { + // First try the new property name + if (wdkModel.getProperties().get(propName) != null) { + return getNumberProp(wdkModel, propName); + } + + // Fall back to deprecated property name with warning + if (wdkModel.getProperties().get(deprecatedPropName) != null) { + LOG.warn("WDK property '" + deprecatedPropName + "' is deprecated. 
Please use '" + propName + "' instead."); + return getNumberProp(wdkModel, deprecatedPropName); + } + + // Neither property is set + throw new WdkModelException("WDK property '" + propName + "' (or deprecated '" + deprecatedPropName + "') has not been set."); } private double getNumberProp(WdkModel wdkModel, String propName) throws WdkModelException { @@ -88,11 +112,11 @@ public boolean isCostExceeded() { return updateAndGetCost(EMPTY_COST) > _maxDailyDollarCost; } - public void updateCost(Optional usage) { - updateAndGetCost(usage.orElse(EMPTY_COST)); + public void updateCost(TokenUsage usage) { + updateAndGetCost(usage); } - public double updateAndGetCost(CompletionUsage usageCost) { + public double updateAndGetCost(TokenUsage usageCost) { try (DirectoryLock lock = new DirectoryLock(_costMonitoringDir, DEFAULT_TIMEOUT_MILLIS, DEFAULT_POLL_FREQUENCY_MILLIS)) { // read current values from file @@ -103,7 +127,8 @@ public double updateAndGetCost(CompletionUsage usageCost) { // calculate cost of the current usage double additionalCost = (usageCost.promptTokens() * _costPerInputToken) + - (usageCost.completionTokens() * _costPerOutputToken); + (usageCost.completionTokens() * _costPerOutputToken) + + (usageCost.embeddingTokens() * _costPerEmbeddingToken); // reset cost to zero if date has rolled over to the next day String newDate = getCurrentDateString(); diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java index 26fa39bab..0b6e00653 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java @@ -6,6 +6,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import org.apache.log4j.Logger; import org.gusdb.fgputil.EncryptionUtil; import 
org.gusdb.wdk.model.WdkModelException; import org.gusdb.wdk.model.WdkUserException; @@ -21,9 +22,12 @@ */ public class GeneRecordProcessor { + private static final Logger LOG = Logger.getLogger(GeneRecordProcessor.class); + private static final Set KEYS_TO_KEEP = Set.of("y_axis", "description", "genus_species", "project_id", "summary", "dataset_id", "assay_type", "x_axis", "module", "dataset_name", "display_name", - "short_attribution", "paralog_number"); + "short_attribution"); + // TODO: restore "paralog_number" to KEYS_TO_KEEP once it is reliably present in gene records again private static final String EXPRESSION_GRAPH_TABLE = "ExpressionGraphs"; private static final String EXPRESSION_GRAPH_DATA_TABLE = "ExpressionGraphsDataTable"; @@ -32,7 +36,7 @@ public class GeneRecordProcessor { // Increment this to invalidate all previous cache entries: // (for example if changing first level model outputs rather than inputs which are already digestified) - private static final String DATA_MODEL_VERSION = "v3b"; + private static final String DATA_MODEL_VERSION = "v4"; public interface ExperimentInputs { @@ -62,7 +66,7 @@ private static String getGeneId(RecordInstance record) { return record.getPrimaryKey().getValues().get("source_id"); } - public static GeneSummaryInputs getSummaryInputsFromRecord(RecordInstance record, String aiChatModel, Function getExperimentPrompt, Function, String> getFinalSummaryPrompt) throws WdkModelException { + public static GeneSummaryInputs getSummaryInputsFromRecord(RecordInstance record, String aiChatModel, String embeddingModel, boolean makeTopicEmbeddings, boolean useExtendedThinking, Function getExperimentPrompt, Function, String> getFinalSummaryPrompt) throws WdkModelException { String geneId = getGeneId(record); @@ -90,7 +94,7 @@ public String getDigest() { List digests = experimentsWithData.stream() .map(exp -> new JSONObject().put("digest", exp.getDigest())) .collect(Collectors.toList()); - return EncryptionUtil.md5(aiChatModel + ":" 
+ DATA_MODEL_VERSION + ":" + getFinalSummaryPrompt.apply(digests)); + return EncryptionUtil.md5(aiChatModel + ":" + embeddingModel + ":" + makeTopicEmbeddings + ":" + useExtendedThinking + ":" + DATA_MODEL_VERSION + ":" + getFinalSummaryPrompt.apply(digests)); } }; @@ -114,6 +118,16 @@ private static List processExpressionData(RecordInstance recor experimentInfo.put(key, experimentRow.getAttributeValue(key).getValue()); } + // TODO: remove this fallback once paralog_number is reliably present in gene records again; + // restore "paralog_number" to KEYS_TO_KEEP above and delete this block. + try { + experimentInfo.put("paralog_number", experimentRow.getAttributeValue("paralog_number").getValue()); + } catch (WdkModelException e) { + LOG.warn("paralog_number attribute is missing from gene record; defaulting to 0. " + + "This is a temporary workaround - restore it to KEYS_TO_KEEP once the field is available again."); + experimentInfo.put("paralog_number", "0"); + } + String datasetId = experimentRow.getAttributeValue("dataset_id").getValue(); String assayType = experimentRow.getAttributeValue("assay_type").getValue(); String experimentName = experimentRow.getAttributeValue("display_name").getValue(); diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java new file mode 100644 index 000000000..bcbd4cbab --- /dev/null +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java @@ -0,0 +1,77 @@ +package org.apidb.apicommon.model.report.ai.expression; + +import java.util.concurrent.CompletableFuture; + +import org.gusdb.wdk.model.WdkModel; +import org.gusdb.wdk.model.WdkModelException; + +import com.openai.client.OpenAIClientAsync; +import com.openai.client.okhttp.OpenAIOkHttpClientAsync; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.ChatModel; +import 
com.openai.models.ResponseFormatJsonSchema; +import com.openai.models.ResponseFormatJsonSchema.JsonSchema; + +public class OpenAISummarizer extends Summarizer { + + // provide exact model number for semi-reproducibility + public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06; + + private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY"; + + private final OpenAIClientAsync _openAIClient; + + public OpenAISummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor, boolean makeTopicEmbeddings) throws WdkModelException { + super(wdkModel, costMonitor, makeTopicEmbeddings); + + String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME); + if (apiKey == null) { + throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set."); + } + + _openAIClient = OpenAIOkHttpClientAsync.builder() + .apiKey(apiKey) + .maxRetries(32) // Handle 429 errors + .build(); + } + + @Override + protected CompletableFuture callApiForJson(String prompt, com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema schema) { + ChatCompletionCreateParams request = ChatCompletionCreateParams.builder() + .model(OPENAI_CHAT_MODEL) + .maxCompletionTokens(MAX_RESPONSE_TOKENS) + .responseFormat(ResponseFormatJsonSchema.builder() + .jsonSchema(JsonSchema.builder() + .name("structured-response") + .schema(schema) + .strict(true) + .build()) + .build()) + .addSystemMessage(SYSTEM_MESSAGE) + .addUserMessage(prompt) + .build(); + + return retryOnOverload( + () -> _openAIClient.chat().completions().create(request), + e -> e instanceof com.openai.errors.InternalServerException, + "OpenAI API call" + ).thenApply(completion -> { + // update cost accumulator - convert to TokenUsage + completion.usage().ifPresent(usage -> { + TokenUsage tokenUsage = TokenUsage.builder() + .promptTokens(usage.promptTokens()) + .completionTokens(usage.completionTokens()) + .build(); + _costMonitor.updateCost(tokenUsage); + }); + + 
// return JSON string + return completion.choices().get(0).message().content().get(); + }); + } + + @Override + protected void updateCostMonitor(Object apiResponse) { + // OpenAI response handling is done in callApiForJson + } +} diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java index 4b119fe3c..e64d8e27c 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java @@ -7,6 +7,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import java.util.stream.Collectors; import org.apache.log4j.Logger; import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.ExperimentInputs; @@ -20,25 +21,29 @@ import com.openai.client.OpenAIClientAsync; import com.openai.client.okhttp.OpenAIOkHttpClientAsync; import com.openai.core.JsonValue; +import com.openai.models.embeddings.EmbeddingCreateParams; +import com.openai.models.embeddings.EmbeddingModel; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.ChatModel; import com.openai.models.ResponseFormatJsonSchema; import com.openai.models.ResponseFormatJsonSchema.JsonSchema; import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema; -public class Summarizer { +public abstract class Summarizer { - // provide exact model number for semi-reproducibility - public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06; + protected static final int MAX_RESPONSE_TOKENS = 10000; - private static final int MAX_RESPONSE_TOKENS = 10000; + public static final EmbeddingModel EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_3_SMALL; + private static final int EMBEDDING_DIMENSIONS = 512; + private static final int EMBEDDING_DECIMAL_PLACES 
= 4; private static final int MAX_MALFORMED_RESPONSE_RETRIES = 3; + private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY"; - private static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data"; + protected static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data"; // Prepare JSON schemas for structured responses - private static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder() + protected static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder() .putAdditionalProperty("type", JsonValue.from("object")) .putAdditionalProperty("properties", JsonValue.from(Map.of( "one_sentence_summary", Map.of("type", "string"), @@ -57,7 +62,7 @@ public class Summarizer { .putAdditionalProperty("additionalProperties", JsonValue.from(false)) .build(); - private static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder() + protected static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder() .putAdditionalProperty("type", JsonValue.from("object")) .putAdditionalProperty("properties", JsonValue.from(Map.of( "headline", Map.of("type", "string"), @@ -81,26 +86,139 @@ public class Summarizer { .putAdditionalProperty("additionalProperties", JsonValue.from(false)) .build(); - private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY"; - - private final OpenAIClientAsync _openAIClient; - private final DailyCostMonitor _costMonitor; + protected final DailyCostMonitor _costMonitor; + private final OpenAIClientAsync _embeddingClient; + protected final boolean _makeTopicEmbeddings; private static final Logger LOG = Logger.getLogger(Summarizer.class); - public Summarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws 
WdkModelException { + public Summarizer(WdkModel wdkModel, DailyCostMonitor costMonitor, boolean makeTopicEmbeddings) throws WdkModelException { + _costMonitor = costMonitor; + _makeTopicEmbeddings = makeTopicEmbeddings; - String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME); - if (apiKey == null) { - throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set."); + // Only create embedding client if we need to make topic embeddings + if (makeTopicEmbeddings) { + String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME); + if (apiKey == null) { + throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set."); + } + + _embeddingClient = OpenAIOkHttpClientAsync.builder() + .apiKey(apiKey) + .maxRetries(32) // Handle 429 errors + .build(); + } else { + _embeddingClient = null; } + } - _openAIClient = OpenAIOkHttpClientAsync.builder() - .apiKey(apiKey) - .maxRetries(32) // Handle 429 errors + private CompletableFuture> getEmbedding(String text) { + // Safety check: ensure embedding client was initialized + if (_embeddingClient == null) { + LOG.error("Attempted to generate embedding but embedding client was not initialized (makeTopicEmbeddings=false)"); + return CompletableFuture.completedFuture(List.of()); + } + + EmbeddingCreateParams request = EmbeddingCreateParams.builder() + .model(EMBEDDING_MODEL) + .input(text) + .dimensions(EMBEDDING_DIMENSIONS) .build(); - _costMonitor = costMonitor; + return _embeddingClient.embeddings().create(request).thenApply(response -> { + // Update cost monitor - convert embedding usage to TokenUsage + com.openai.models.embeddings.CreateEmbeddingResponse.Usage embeddingUsage = response.usage(); + TokenUsage tokenUsage = TokenUsage.builder() + .embeddingTokens(embeddingUsage.totalTokens()) + .build(); + _costMonitor.updateCost(tokenUsage); + + // Extract embedding vector from first result (convert Float to Double) + List rawEmbedding = 
response.data().get(0).embedding(); + + // Round to specified decimal places + double scale = Math.pow(10, EMBEDDING_DECIMAL_PLACES); + return rawEmbedding.stream() + .map(val -> Math.round(val.doubleValue() * scale) / scale) + .collect(Collectors.toList()); + }).exceptionally(e -> { + LOG.error("Failed to generate embedding: " + e.getMessage(), e); + return List.of(); // Return empty list on error + }); + } + + /** + * Retries an operation with exponential backoff if it fails with a retriable error. + * + * @param the return type of the operation + * @param operation supplier that produces the CompletableFuture to execute + * @param shouldRetry predicate to determine if an exception should trigger a retry + * @param operationDescription description for logging purposes + * @return CompletableFuture with the result of the operation + */ + protected CompletableFuture retryOnOverload( + java.util.function.Supplier> operation, + java.util.function.Predicate shouldRetry, + String operationDescription) { + + final int maxRetries = 3; + final long[] backoffDelaysMs = {1000, 2000, 4000}; // 1s, 2s, 4s + + return retryWithBackoff(operation, shouldRetry, operationDescription, 0, maxRetries, backoffDelaysMs); + } + + private CompletableFuture retryWithBackoff( + java.util.function.Supplier> operation, + java.util.function.Predicate shouldRetry, + String operationDescription, + int attemptNumber, + int maxRetries, + long[] backoffDelaysMs) { + + CompletableFuture result = new CompletableFuture<>(); + + operation.get().whenComplete((value, throwable) -> { + if (throwable == null) { + // Success case + result.complete(value); + } else { + // Error case - unwrap CompletionException to get the actual cause + Throwable actualCause = throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null + ? 
throwable.getCause() + : throwable; + + // Check if we should retry this exception and haven't exceeded max retries + if (shouldRetry.test(actualCause) && attemptNumber < maxRetries) { + long delayMs = backoffDelaysMs[attemptNumber]; + LOG.warn(String.format( + "Retrying %s after error (attempt %d/%d, waiting %dms): %s", + operationDescription, attemptNumber + 1, maxRetries, delayMs, actualCause.getMessage())); + + // Schedule retry after delay + new java.util.Timer().schedule(new java.util.TimerTask() { + @Override + public void run() { + retryWithBackoff(operation, shouldRetry, operationDescription, attemptNumber + 1, maxRetries, backoffDelaysMs) + .whenComplete((retryValue, retryError) -> { + if (retryError != null) { + result.completeExceptionally(retryError); + } else { + result.complete(retryValue); + } + }); + } + }, delayMs); + } else { + // No more retries or non-retriable exception + if (attemptNumber >= maxRetries) { + LOG.error(String.format("Failed %s after %d retries: %s", operationDescription, maxRetries, actualCause.getMessage())); + } + result.completeExceptionally(throwable); + } + } + }); + + return result; } public static String getExperimentMessage(JSONObject experiment) { @@ -133,12 +251,9 @@ public static String getExperimentMessage(JSONObject experiment) { public CompletableFuture describeExperiment(ExperimentInputs experimentInputs) { - ChatCompletionCreateParams request = buildAiRequest( - "experiment-summary", - experimentResponseSchema, - getExperimentMessage(experimentInputs.getExperimentData())); + String prompt = getExperimentMessage(experimentInputs.getExperimentData()); - return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), request, json -> { + return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), prompt, experimentResponseSchema, json -> { // add some fields to the result to aid the final summarization return json .put("dataset_id", experimentInputs.getDatasetId()) @@ -151,6 +266,7 @@ 
public static String getFinalSummaryMessage(List experiments) { return "Below are AI-generated summaries of one gene's behavior in all the transcriptomics experiments available in VEuPathDB, provided in JSON format:\n\n" + String.format("```json\n%s\n```\n\n", new JSONArray(experiments).toString(2)) + "Generate a one-paragraph summary (~100 words) describing the gene's expression. Structure it using ,
    , and
  • tags with no attributes. If relevant, briefly speculate on the gene's potential function, but only if justified by the data. Also, generate a short, specific headline for the summary. The headline must reflect this gene's expression and **must not** include generic phrases like \"comprehensive insights into\" or the word \"gene\".\n\n" + + "Use sentence case for all headlines: capitalize only the first word and proper nouns, not every word.\n\n" + "Additionally, group the per-experiment summaries (identified by `dataset_id`) with `biological_importance > 3` and `confidence > 3` into sections by topic. For each topic, provide:\n" + "- A headline summarizing the key experimental results within the topic\n" + "- A concise one-sentence summary of the topic's experimental results\n\n" + @@ -159,18 +275,17 @@ public static String getFinalSummaryMessage(List experiments) { public JSONObject summarizeExperiments(String geneId, List experiments) { - ChatCompletionCreateParams request = buildAiRequest( - "expression-summary", - finalResponseSchema, - getFinalSummaryMessage(experiments)); + String prompt = getFinalSummaryMessage(experiments); - return getValidatedAiResponse("summary for gene " + geneId, request, json -> + return getValidatedAiResponse("summary for gene " + geneId, prompt, finalResponseSchema, json -> + json // Return json as-is; consolidateSummary will be called separately + ).thenCompose(json -> // quality control (remove bad `dataset_id`s) and add 'Others' section for any experiments not listed by AI consolidateSummary(json, experiments) ).join(); } - private static JSONObject consolidateSummary(JSONObject summaryResponse, + private CompletableFuture consolidateSummary(JSONObject summaryResponse, List individualResults) { // Gather all dataset IDs from individualResults and map them to summaries. // Preserving the order of individualResults. 
@@ -180,7 +295,8 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse, } Set seenDatasetIds = new LinkedHashSet<>(); - JSONArray deduplicatedTopics = new JSONArray(); + List deduplicatedTopicsList = new java.util.ArrayList<>(); + List> embeddingFutures = new java.util.ArrayList<>(); JSONArray topics = summaryResponse.getJSONArray("topics"); for (int i = 0; i < topics.length(); i++) { @@ -210,7 +326,21 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse, if (summaries.length() > 0) { topic.put("summaries", summaries); topic.remove("dataset_ids"); - deduplicatedTopics.put(topic); + deduplicatedTopicsList.add(topic); + + // Generate embedding for non-"Other" topics (if enabled) + if (_makeTopicEmbeddings) { + String headline = topic.optString("headline", ""); + if (!headline.equals("Other")) { + String embeddingText = headline + "\n\n" + topic.optString("one_sentence_summary", ""); + CompletableFuture embeddingFuture = getEmbedding(embeddingText).thenAccept(embedding -> { + if (!embedding.isEmpty()) { + topic.put("embedding_vector", embedding); + } + }); + embeddingFutures.add(embeddingFuture); + } + } } } @@ -230,44 +360,38 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse, otherTopic.put("one_sentence_summary", "The AI ordered these experiments by biological importance but did not group them into topics."); otherTopic.put("summaries", otherSummaries); - deduplicatedTopics.put(otherTopic); + deduplicatedTopicsList.add(otherTopic); + // Note: no embedding for "Other" topic } - // Create final deduplicated summary - JSONObject finalSummary = new JSONObject(summaryResponse.toString()); - finalSummary.put("topics", deduplicatedTopics); - return finalSummary; + // Wait for all embeddings to complete, then create final summary + return CompletableFuture.allOf(embeddingFutures.toArray(new CompletableFuture[0])) + .thenApply(v -> { + // Convert deduplicated topics list back to JSONArray + JSONArray 
deduplicatedTopics = new JSONArray(); + for (JSONObject topic : deduplicatedTopicsList) { + deduplicatedTopics.put(topic); + } + + // Create final deduplicated summary + JSONObject finalSummary = new JSONObject(summaryResponse.toString()); + finalSummary.put("topics", deduplicatedTopics); + return finalSummary; + }); } - private static ChatCompletionCreateParams buildAiRequest(String name, Schema schema, String userMessage) { - return ChatCompletionCreateParams.builder() - .model(OPENAI_CHAT_MODEL) - .maxCompletionTokens(MAX_RESPONSE_TOKENS) - .responseFormat(ResponseFormatJsonSchema.builder() - .jsonSchema(JsonSchema.builder() - .name(name) - .schema(schema) - .strict(true) - .build()) - .build()) - .addSystemMessage(SYSTEM_MESSAGE) - .addUserMessage(userMessage) - .build(); - } + protected abstract CompletableFuture callApiForJson(String prompt, Schema schema); + + protected abstract void updateCostMonitor(Object apiResponse); private CompletableFuture getValidatedAiResponse( String operationDescription, - ChatCompletionCreateParams request, + String prompt, + Schema schema, Function createFinalJson ) { - return _openAIClient.chat().completions().create(request).thenApply(completion -> { - - // update cost accumulator - _costMonitor.updateCost(completion.usage()); - - // expect response to be a JSON string - String jsonString = completion.choices().get(0).message().content().get(); + return callApiForJson(prompt, schema).thenApply(jsonString -> { int attempts = 1; Exception mostRecentError; @@ -281,12 +405,10 @@ private CompletableFuture getValidatedAiResponse( } catch (JSONException e) { mostRecentError = e; - LOG.warn("Malformed JSON from OpenAI (attempt " + attempts + ") for " + operationDescription + ". Retrying..."); + LOG.warn("Malformed JSON from AI (attempt " + attempts + ") for " + operationDescription + ". 
Retrying..."); - // Re-request from OpenAI - completion = _openAIClient.chat().completions().create(request).join(); - _costMonitor.updateCost(completion.usage()); - jsonString = completion.choices().get(0).message().content().get(); + // Re-request from AI + jsonString = callApiForJson(prompt, schema).join(); attempts++; } } diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/TokenUsage.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/TokenUsage.java new file mode 100644 index 000000000..ea76ca3f8 --- /dev/null +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/TokenUsage.java @@ -0,0 +1,59 @@ +package org.apidb.apicommon.model.report.ai.expression; + +/** + * Represents token usage for AI API calls, including chat completions and embeddings. + * Immutable value object with builder pattern. + */ +public class TokenUsage { + + private final long promptTokens; + private final long completionTokens; + private final long embeddingTokens; + + private TokenUsage(Builder builder) { + this.promptTokens = builder.promptTokens; + this.completionTokens = builder.completionTokens; + this.embeddingTokens = builder.embeddingTokens; + } + + public long promptTokens() { + return promptTokens; + } + + public long completionTokens() { + return completionTokens; + } + + public long embeddingTokens() { + return embeddingTokens; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private long promptTokens = 0; + private long completionTokens = 0; + private long embeddingTokens = 0; + + public Builder promptTokens(long promptTokens) { + this.promptTokens = promptTokens; + return this; + } + + public Builder completionTokens(long completionTokens) { + this.completionTokens = completionTokens; + return this; + } + + public Builder embeddingTokens(long embeddingTokens) { + this.embeddingTokens = embeddingTokens; + return this; + } + + public TokenUsage build() { 
+ return new TokenUsage(this); + } + } +}