diff --git a/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml b/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml
index be7b97146..cbaf1cef6 100644
--- a/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml
+++ b/Model/lib/conifer/roles/conifer/vars/ApiCommon/default.yml
@@ -50,9 +50,10 @@ modelprop:
JBROWSE_SERVICE_URL: "/{{ webapp_ctx }}/service/jbrowse"
AI_EXPRESSION_CACHE_DIR: "/var/www/Common/ai-expr-cache"
AI_EXPRESSION_QUALTRICS_ID: SV_38C4ZX1JxLi2SEe
- OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST: 33
- OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS: 2.5
- OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS: 10
+ MAX_DAILY_AI_EXPRESSION_DOLLAR_COST: 33
+ # Claude Sonnet 4.5 costs from here: https://platform.claude.com/docs/en/about-claude/pricing
+ DOLLAR_COST_PER_1M_AI_INPUT_TOKENS: 3
+ DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS: 15
user_datasets_uploadTypes_env_map:
w: "genelist"
diff --git a/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml b/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml
index 82596e15e..81c090100 100644
--- a/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml
+++ b/Model/lib/conifer/roles/conifer/vars/ApiCommon/production/default.yml
@@ -46,6 +46,7 @@ modelprop:
GOOGLE_MAPS_API_KEY: "{{ lookup('euparc', 'attr=api_key xpath=sites/google_maps default=NOKEY') }}"
COMMUNITY_SITE: "//{{ community_env_map[prefix]|default(community_env_map['default']) }}"
OPENAI_API_KEY: "{{ lookup('euparc', 'attr=api_key xpath=sites/openai default=NOKEY') }}"
+ CLAUDE_API_KEY: "{{ lookup('euparc', 'attr=api_key xpath=sites/claude default=NOKEY') }}"
# the below extends the w_ q_ prefix pattern used for workspace_env_map, which
diff --git a/Model/pom.xml b/Model/pom.xml
index 39a98f933..41c4cfbba 100644
--- a/Model/pom.xml
+++ b/Model/pom.xml
@@ -135,6 +135,11 @@
openai-java
+
+ com.anthropic
+ anthropic-java
+
+
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java
index 9b61a9f54..cf2f0d1ad 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java
@@ -13,6 +13,8 @@
import org.apidb.apicommon.model.report.ai.expression.DailyCostMonitor;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.GeneSummaryInputs;
+import org.apidb.apicommon.model.report.ai.expression.ClaudeSummarizer;
+//import org.apidb.apicommon.model.report.ai.expression.OpenAISummarizer;
import org.apidb.apicommon.model.report.ai.expression.Summarizer;
import org.gusdb.wdk.model.WdkModelException;
import org.gusdb.wdk.model.WdkServiceTemporarilyUnavailableException;
@@ -27,13 +29,44 @@
import org.json.JSONException;
import org.json.JSONObject;
+/**
+ * Reporter that generates AI-powered gene expression summaries using LLM models.
+ *
+ *
This reporter analyzes expression data across multiple experiments for a single gene
+ * and generates natural language summaries of expression patterns and biological significance.
+ * Results are cached to minimize API costs and response times.
+ *
+ * Configuration (JSON request payload)
+ *
+ * {
+ * "populateIfNotPresent": true|false, // If true, generate summary if not cached (default: false)
+ * "makeTopicEmbeddings": true|false // If true, generate embedding vectors for topics (default: false)
+ * }
+ *
+ *
+ * Cache Invalidation Warning
+ * IMPORTANT: Changing the {@code makeTopicEmbeddings} setting will invalidate
+ * the entire cache for all genes, as this value is included in the cache digest. To avoid costly
+ * cache regeneration, choose a setting and stick with it across requests. Only change this value
+ * when you intentionally want to regenerate all summaries with or without embeddings.
+ *
+ * Model Configuration
+ * The AI model and embedding model are hardcoded in the summarizer implementations
+ * ({@link ClaudeSummarizer}, {@link org.apidb.apicommon.model.report.ai.expression.OpenAISummarizer}).
+ * Changing models will also invalidate the cache.
+ */
public class SingleGeneAiExpressionReporter extends AbstractReporter {
private static final int MAX_RESULT_SIZE = 1; // one gene at a time for now
private static final String POPULATION_MODE_PROP_KEY = "populateIfNotPresent";
+ private static final String AI_MAX_CONCURRENT_REQUESTS_PROP_KEY = "AI_MAX_CONCURRENT_REQUESTS";
+ private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10;
+ private static final String MAKE_TOPIC_EMBEDDINGS_PROP_KEY = "makeTopicEmbeddings";
private boolean _populateIfNotPresent;
+ private int _maxConcurrentRequests;
+ private boolean _makeTopicEmbeddings;
private DailyCostMonitor _costMonitor;
@Override
@@ -42,6 +75,15 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk
// assign cache mode
_populateIfNotPresent = config.optBoolean(POPULATION_MODE_PROP_KEY, false);
+ // assign topic embeddings flag
+ _makeTopicEmbeddings = config.optBoolean(MAKE_TOPIC_EMBEDDINGS_PROP_KEY, false);
+
+ // read max concurrent requests from model properties or use default
+ String maxConcurrentRequestsStr = _wdkModel.getProperties().get(AI_MAX_CONCURRENT_REQUESTS_PROP_KEY);
+ _maxConcurrentRequests = maxConcurrentRequestsStr != null
+ ? Integer.parseInt(maxConcurrentRequestsStr)
+ : DEFAULT_MAX_CONCURRENT_REQUESTS;
+
// instantiate cost monitor
_costMonitor = new DailyCostMonitor(_wdkModel);
@@ -52,7 +94,7 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk
" should only be assigned to " + geneRecordClass.getFullName());
}
- // check result size; limit to small results due to OpenAI cost
+ // check result size; limit to small results due to AI API cost
if (_baseAnswer.getResultSizeFactory().getResultSize() > MAX_RESULT_SIZE) {
throw new ReporterConfigException("This reporter cannot be called with results of size greater than " + MAX_RESULT_SIZE);
}
@@ -79,9 +121,11 @@ protected void write(OutputStream out) throws IOException, WdkModelException {
// open summary cache (manages persistence of expression data)
AiExpressionCache cache = AiExpressionCache.getInstance(_wdkModel);
- // create summarizer (interacts with OpenAI)
- Summarizer summarizer = new Summarizer(_wdkModel, _costMonitor);
-
+ // create summarizer (interacts with Claude)
+ ClaudeSummarizer summarizer = new ClaudeSummarizer(_wdkModel, _costMonitor, _makeTopicEmbeddings);
+ // or alternatively use OpenAI (with the appropriate import)
+ // OpenAISummarizer summarizer = new OpenAISummarizer(_wdkModel, _costMonitor, _makeTopicEmbeddings);
+
// open record and output streams
try (RecordStream recordStream = RecordStreamFactory.getRecordStream(_baseAnswer, List.of(), tables);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out))) {
@@ -93,12 +137,13 @@ protected void write(OutputStream out) throws IOException, WdkModelException {
// create summary inputs
GeneSummaryInputs summaryInputs =
- GeneRecordProcessor.getSummaryInputsFromRecord(record, Summarizer.OPENAI_CHAT_MODEL.toString(),
+ GeneRecordProcessor.getSummaryInputsFromRecord(record, ClaudeSummarizer.CLAUDE_MODEL.toString(),
+ Summarizer.EMBEDDING_MODEL.asString(), _makeTopicEmbeddings, ClaudeSummarizer.USE_EXTENDED_THINKING,
Summarizer::getExperimentMessage, Summarizer::getFinalSummaryMessage);
// fetch summary, producing if necessary and requested
JSONObject expressionSummary = _populateIfNotPresent
- ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments)
+ ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments, _maxConcurrentRequests)
: cache.readSummary(summaryInputs);
// join entries with commas
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java
index 3bc5768b7..4054eb1e6 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java
@@ -73,7 +73,6 @@ public class AiExpressionCache {
private static Logger LOG = Logger.getLogger(AiExpressionCache.class);
// parallel processing
- private static final int MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST = 10;
private static final long VISIT_ENTRY_LOCK_MAX_WAIT_MILLIS = 50;
// cache location
@@ -317,18 +316,20 @@ private static Optional readCachedData(Path entryDir) {
* @param summaryInputs gene summary inputs
* @param experimentDescriber function to describe an experiment
* @param experimentSummarizer function to summarize experiments into an expression summary
+ * @param maxConcurrentRequests maximum number of concurrent experiment lookups
* @return expression summary (will always be a cache hit)
*/
public JSONObject populateSummary(GeneSummaryInputs summaryInputs,
FunctionWithException> experimentDescriber,
- BiFunctionWithException, JSONObject> experimentSummarizer) {
+ BiFunctionWithException, JSONObject> experimentSummarizer,
+ int maxConcurrentRequests) {
try {
return _cache.populateAndProcessContent(summaryInputs.getGeneId(),
// populator
entryDir -> {
// first populate each dataset entry as needed and collect experiment descriptors
- List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber);
+ List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber, maxConcurrentRequests);
// sort them most-interesting first so that the "Other" section will be filled
// in that order (and also to give the AI the data in a sensible order)
@@ -362,14 +363,16 @@ public JSONObject populateSummary(GeneSummaryInputs summaryInputs,
*
* @param experimentData experiment inputs
* @param experimentDescriber function to describe an experiment
+ * @param maxConcurrentRequests maximum number of concurrent experiment lookups
* @return list of cached experiment descriptions
* @throws Exception if unable to generate descriptions or store
*/
private List populateExperiments(List experimentData,
- FunctionWithException> experimentDescriber) throws Exception {
+ FunctionWithException> experimentDescriber,
+ int maxConcurrentRequests) throws Exception {
// use a thread for each experiment, up to a reasonable max
- int threadPoolSize = Math.min(MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST, experimentData.size());
+ int threadPoolSize = Math.min(maxConcurrentRequests, experimentData.size());
ExecutorService exec = Executors.newFixedThreadPool(threadPoolSize);
try {
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java
new file mode 100644
index 000000000..2e77db18e
--- /dev/null
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java
@@ -0,0 +1,134 @@
+package org.apidb.apicommon.model.report.ai.expression;
+
+import java.time.Duration;
+import java.util.concurrent.CompletableFuture;
+
+import org.gusdb.wdk.model.WdkModel;
+import org.gusdb.wdk.model.WdkModelException;
+
+import com.anthropic.client.AnthropicClientAsync;
+import com.anthropic.client.okhttp.AnthropicOkHttpClientAsync;
+import com.anthropic.models.messages.MessageCreateParams;
+import com.anthropic.models.messages.Model;
+import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema;
+
+public class ClaudeSummarizer extends Summarizer {
+
+ public static final Model CLAUDE_MODEL = Model.CLAUDE_SONNET_4_5_20250929;
+ public static final boolean USE_EXTENDED_THINKING = false;
+
+ private static final String CLAUDE_API_KEY_PROP_NAME = "CLAUDE_API_KEY";
+
+ private final AnthropicClientAsync _claudeClient;
+
+ public ClaudeSummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor, boolean makeTopicEmbeddings) throws WdkModelException {
+ super(wdkModel, costMonitor, makeTopicEmbeddings);
+
+ String apiKey = wdkModel.getProperties().get(CLAUDE_API_KEY_PROP_NAME);
+ if (apiKey == null) {
+ throw new WdkModelException("WDK property '" + CLAUDE_API_KEY_PROP_NAME + "' has not been set.");
+ }
+
+ _claudeClient = AnthropicOkHttpClientAsync.builder()
+ .apiKey(apiKey)
+ .maxRetries(32) // Handle 429 errors
+ .checkJacksonVersionCompatibility(false)
+ .build();
+ }
+
+ @Override
+ protected CompletableFuture callApiForJson(String prompt, Schema schema) {
+ // Convert JSON schema to natural language description for Claude
+ String jsonFormatInstructions = convertSchemaToPromptInstructions(schema);
+
+ String enhancedPrompt = prompt + "\n\n" + jsonFormatInstructions;
+
+ MessageCreateParams.Builder requestBuilder = MessageCreateParams.builder()
+ .model(CLAUDE_MODEL)
+ .maxTokens((long) MAX_RESPONSE_TOKENS)
+ .system(SYSTEM_MESSAGE)
+ .addUserMessage(enhancedPrompt);
+
+ if (USE_EXTENDED_THINKING) {
+ requestBuilder.enabledThinking(1024);
+ }
+
+ MessageCreateParams request = requestBuilder.build();
+
+ return retryOnOverload(
+ () -> _claudeClient.messages().create(request),
+ e -> e instanceof com.anthropic.errors.InternalServerException,
+ "Claude API call"
+ ).thenApply(response -> {
+ // Convert Claude usage to TokenUsage for cost monitoring
+ com.anthropic.models.messages.Usage claudeUsage = response.usage();
+ TokenUsage tokenUsage = TokenUsage.builder()
+ .promptTokens(claudeUsage.inputTokens())
+ .completionTokens(claudeUsage.outputTokens())
+ .build();
+
+ _costMonitor.updateCost(tokenUsage);
+
+ // Extract text from content blocks using stream API
+ String rawText = response.content().stream()
+ .flatMap(contentBlock -> contentBlock.text().stream())
+ .map(textBlock -> textBlock.text())
+ .findFirst()
+ .orElseThrow(() -> new RuntimeException("No text content found in Claude response"));
+
+ // Strip JSON markdown formatting if present
+ return stripJsonMarkdown(rawText);
+ });
+ }
+
+ @Override
+ protected void updateCostMonitor(Object apiResponse) {
+ // Claude response handling is done in callApiForJson
+ }
+
+ private String stripJsonMarkdown(String text) {
+ String trimmed = text.trim();
+
+ // Remove ```json and ``` markdown formatting
+ if (trimmed.startsWith("```json")) {
+ trimmed = trimmed.substring(7); // Remove "```json"
+ } else if (trimmed.startsWith("```")) {
+ trimmed = trimmed.substring(3); // Remove "```"
+ }
+
+ if (trimmed.endsWith("```")) {
+ trimmed = trimmed.substring(0, trimmed.length() - 3); // Remove trailing "```"
+ }
+
+ return trimmed.trim();
+ }
+
+ private String convertSchemaToPromptInstructions(Schema schema) {
+ // Convert OpenAI JSON schema to Claude-friendly format instructions
+ if (schema == experimentResponseSchema) {
+ return "Respond in valid JSON format matching this exact structure:\n" +
+ "{\n" +
+ " \"one_sentence_summary\": \"string describing gene expression\",\n" +
+ " \"biological_importance\": \"integer 0-5\",\n" +
+ " \"confidence\": \"integer 0-5\",\n" +
+ " \"experiment_keywords\": [\"array\", \"of\", \"strings\"],\n" +
+ " \"notes\": \"string with additional context\"\n" +
+ "}";
+ } else if (schema == finalResponseSchema) {
+ return "Respond in valid JSON format matching this exact structure:\n" +
+ "{\n" +
+ " \"headline\": \"string summarizing key results\",\n" +
+ " \"one_paragraph_summary\": \"string with ~100 words\",\n" +
+ " \"topics\": [\n" +
+ " {\n" +
+ " \"headline\": \"string summarizing topic\",\n" +
+ " \"one_sentence_summary\": \"string describing topic results\",\n" +
+ " \"dataset_ids\": [\"array\", \"of\", \"dataset_id\", \"strings\"]\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+ } else {
+ return "Respond in valid JSON format.";
+ }
+ }
+}
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java
index fe7459174..277853496 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java
@@ -21,8 +21,6 @@
import org.json.JSONException;
import org.json.JSONObject;
-import com.openai.models.completions.CompletionUsage;
-
public class DailyCostMonitor {
private static final Logger LOG = Logger.getLogger(DailyCostMonitor.class);
@@ -31,10 +29,18 @@ public class DailyCostMonitor {
private static final String DAILY_COST_ACCUMULATION_FILE_DIR = "dailyCost";
private static final String DAILY_COST_ACCUMULATION_FILE = "daily_cost_accumulation.txt";
- // model prop keys
- private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
- private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
- private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
+ // model prop keys (new names without OPENAI_ prefix)
+ private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
+ private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
+ private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
+
+ // hardcoded embedding token cost
+ private static final double DOLLAR_COST_PER_1M_EMBEDDING_TOKENS = 0.02;
+
+ // deprecated model prop keys (with OPENAI_ prefix)
+ private static final String DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
+ private static final String DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
+ private static final String DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
// lock characteristics
private static final long DEFAULT_TIMEOUT_MILLIS = 1000;
@@ -44,11 +50,11 @@ public class DailyCostMonitor {
private static final String JSON_DATE_PROP = "currentDate";
private static final String JSON_COST_PROP = "accumulatedCost";
- // completion usage object representing 0 cost
- private static final CompletionUsage EMPTY_COST = CompletionUsage.builder()
+ // token usage object representing 0 cost
+ private static final TokenUsage EMPTY_COST = TokenUsage.builder()
.promptTokens(0)
.completionTokens(0)
- .totalTokens(0)
+ .embeddingTokens(0)
.build();
private final Path _costMonitoringDir;
@@ -57,6 +63,7 @@ public class DailyCostMonitor {
private final double _maxDailyDollarCost;
private final double _costPerInputToken;
private final double _costPerOutputToken;
+ private final double _costPerEmbeddingToken;
public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException {
_costMonitoringDir = AiExpressionCache.getAiExpressionCacheParentDir(wdkModel).resolve(DAILY_COST_ACCUMULATION_FILE_DIR).toAbsolutePath();
@@ -68,9 +75,26 @@ public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException {
_costMonitoringFile = _costMonitoringDir.resolve(DAILY_COST_ACCUMULATION_FILE);
- _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME);
- _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000;
- _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000;
+ _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME, DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME);
+ _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000;
+ _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000;
+ _costPerEmbeddingToken = DOLLAR_COST_PER_1M_EMBEDDING_TOKENS / 1000000;
+ }
+
+ private double getNumberProp(WdkModel wdkModel, String propName, String deprecatedPropName) throws WdkModelException {
+ // First try the new property name
+ if (wdkModel.getProperties().get(propName) != null) {
+ return getNumberProp(wdkModel, propName);
+ }
+
+ // Fall back to deprecated property name with warning
+ if (wdkModel.getProperties().get(deprecatedPropName) != null) {
+ LOG.warn("WDK property '" + deprecatedPropName + "' is deprecated. Please use '" + propName + "' instead.");
+ return getNumberProp(wdkModel, deprecatedPropName);
+ }
+
+ // Neither property is set
+ throw new WdkModelException("WDK property '" + propName + "' (or deprecated '" + deprecatedPropName + "') has not been set.");
}
private double getNumberProp(WdkModel wdkModel, String propName) throws WdkModelException {
@@ -88,11 +112,11 @@ public boolean isCostExceeded() {
return updateAndGetCost(EMPTY_COST) > _maxDailyDollarCost;
}
- public void updateCost(Optional usage) {
- updateAndGetCost(usage.orElse(EMPTY_COST));
+ public void updateCost(TokenUsage usage) {
+ updateAndGetCost(usage);
}
- public double updateAndGetCost(CompletionUsage usageCost) {
+ public double updateAndGetCost(TokenUsage usageCost) {
try (DirectoryLock lock = new DirectoryLock(_costMonitoringDir, DEFAULT_TIMEOUT_MILLIS, DEFAULT_POLL_FREQUENCY_MILLIS)) {
// read current values from file
@@ -103,7 +127,8 @@ public double updateAndGetCost(CompletionUsage usageCost) {
// calculate cost of the current usage
double additionalCost =
(usageCost.promptTokens() * _costPerInputToken) +
- (usageCost.completionTokens() * _costPerOutputToken);
+ (usageCost.completionTokens() * _costPerOutputToken) +
+ (usageCost.embeddingTokens() * _costPerEmbeddingToken);
// reset cost to zero if date has rolled over to the next day
String newDate = getCurrentDateString();
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java
index 26fa39bab..0b6e00653 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/GeneRecordProcessor.java
@@ -6,6 +6,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;
+import org.apache.log4j.Logger;
import org.gusdb.fgputil.EncryptionUtil;
import org.gusdb.wdk.model.WdkModelException;
import org.gusdb.wdk.model.WdkUserException;
@@ -21,9 +22,12 @@
*/
public class GeneRecordProcessor {
+ private static final Logger LOG = Logger.getLogger(GeneRecordProcessor.class);
+
private static final Set KEYS_TO_KEEP = Set.of("y_axis", "description", "genus_species",
"project_id", "summary", "dataset_id", "assay_type", "x_axis", "module", "dataset_name", "display_name",
- "short_attribution", "paralog_number");
+ "short_attribution");
+ // TODO: restore "paralog_number" to KEYS_TO_KEEP once it is reliably present in gene records again
private static final String EXPRESSION_GRAPH_TABLE = "ExpressionGraphs";
private static final String EXPRESSION_GRAPH_DATA_TABLE = "ExpressionGraphsDataTable";
@@ -32,7 +36,7 @@ public class GeneRecordProcessor {
// Increment this to invalidate all previous cache entries:
// (for example if changing first level model outputs rather than inputs which are already digestified)
- private static final String DATA_MODEL_VERSION = "v3b";
+ private static final String DATA_MODEL_VERSION = "v4";
public interface ExperimentInputs {
@@ -62,7 +66,7 @@ private static String getGeneId(RecordInstance record) {
return record.getPrimaryKey().getValues().get("source_id");
}
- public static GeneSummaryInputs getSummaryInputsFromRecord(RecordInstance record, String aiChatModel, Function getExperimentPrompt, Function, String> getFinalSummaryPrompt) throws WdkModelException {
+ public static GeneSummaryInputs getSummaryInputsFromRecord(RecordInstance record, String aiChatModel, String embeddingModel, boolean makeTopicEmbeddings, boolean useExtendedThinking, Function getExperimentPrompt, Function, String> getFinalSummaryPrompt) throws WdkModelException {
String geneId = getGeneId(record);
@@ -90,7 +94,7 @@ public String getDigest() {
List digests = experimentsWithData.stream()
.map(exp -> new JSONObject().put("digest", exp.getDigest()))
.collect(Collectors.toList());
- return EncryptionUtil.md5(aiChatModel + ":" + DATA_MODEL_VERSION + ":" + getFinalSummaryPrompt.apply(digests));
+ return EncryptionUtil.md5(aiChatModel + ":" + embeddingModel + ":" + makeTopicEmbeddings + ":" + useExtendedThinking + ":" + DATA_MODEL_VERSION + ":" + getFinalSummaryPrompt.apply(digests));
}
};
@@ -114,6 +118,16 @@ private static List processExpressionData(RecordInstance recor
experimentInfo.put(key, experimentRow.getAttributeValue(key).getValue());
}
+ // TODO: remove this fallback once paralog_number is reliably present in gene records again;
+ // restore "paralog_number" to KEYS_TO_KEEP above and delete this block.
+ try {
+ experimentInfo.put("paralog_number", experimentRow.getAttributeValue("paralog_number").getValue());
+ } catch (WdkModelException e) {
+ LOG.warn("paralog_number attribute is missing from gene record; defaulting to 0. " +
+ "This is a temporary workaround - restore it to KEYS_TO_KEEP once the field is available again.");
+ experimentInfo.put("paralog_number", "0");
+ }
+
String datasetId = experimentRow.getAttributeValue("dataset_id").getValue();
String assayType = experimentRow.getAttributeValue("assay_type").getValue();
String experimentName = experimentRow.getAttributeValue("display_name").getValue();
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java
new file mode 100644
index 000000000..bcbd4cbab
--- /dev/null
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java
@@ -0,0 +1,77 @@
+package org.apidb.apicommon.model.report.ai.expression;
+
+import java.util.concurrent.CompletableFuture;
+
+import org.gusdb.wdk.model.WdkModel;
+import org.gusdb.wdk.model.WdkModelException;
+
+import com.openai.client.OpenAIClientAsync;
+import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+import com.openai.models.ChatModel;
+import com.openai.models.ResponseFormatJsonSchema;
+import com.openai.models.ResponseFormatJsonSchema.JsonSchema;
+
+public class OpenAISummarizer extends Summarizer {
+
+ // provide exact model number for semi-reproducibility
+ public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06;
+
+ private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY";
+
+ private final OpenAIClientAsync _openAIClient;
+
+ public OpenAISummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor, boolean makeTopicEmbeddings) throws WdkModelException {
+ super(wdkModel, costMonitor, makeTopicEmbeddings);
+
+ String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME);
+ if (apiKey == null) {
+ throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set.");
+ }
+
+ _openAIClient = OpenAIOkHttpClientAsync.builder()
+ .apiKey(apiKey)
+ .maxRetries(32) // Handle 429 errors
+ .build();
+ }
+
+ @Override
+ protected CompletableFuture callApiForJson(String prompt, com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema schema) {
+ ChatCompletionCreateParams request = ChatCompletionCreateParams.builder()
+ .model(OPENAI_CHAT_MODEL)
+ .maxCompletionTokens(MAX_RESPONSE_TOKENS)
+ .responseFormat(ResponseFormatJsonSchema.builder()
+ .jsonSchema(JsonSchema.builder()
+ .name("structured-response")
+ .schema(schema)
+ .strict(true)
+ .build())
+ .build())
+ .addSystemMessage(SYSTEM_MESSAGE)
+ .addUserMessage(prompt)
+ .build();
+
+ return retryOnOverload(
+ () -> _openAIClient.chat().completions().create(request),
+ e -> e instanceof com.openai.errors.InternalServerException,
+ "OpenAI API call"
+ ).thenApply(completion -> {
+ // update cost accumulator - convert to TokenUsage
+ completion.usage().ifPresent(usage -> {
+ TokenUsage tokenUsage = TokenUsage.builder()
+ .promptTokens(usage.promptTokens())
+ .completionTokens(usage.completionTokens())
+ .build();
+ _costMonitor.updateCost(tokenUsage);
+ });
+
+ // return JSON string
+ return completion.choices().get(0).message().content().get();
+ });
+ }
+
+ @Override
+ protected void updateCostMonitor(Object apiResponse) {
+ // OpenAI response handling is done in callApiForJson
+ }
+}
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java
index 4b119fe3c..e64d8e27c 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java
@@ -7,6 +7,7 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
+import java.util.stream.Collectors;
import org.apache.log4j.Logger;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.ExperimentInputs;
@@ -20,25 +21,29 @@
import com.openai.client.OpenAIClientAsync;
import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
import com.openai.core.JsonValue;
+import com.openai.models.embeddings.EmbeddingCreateParams;
+import com.openai.models.embeddings.EmbeddingModel;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.openai.models.ChatModel;
import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.ResponseFormatJsonSchema.JsonSchema;
import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema;
-public class Summarizer {
+public abstract class Summarizer {
- // provide exact model number for semi-reproducibility
- public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06;
+ protected static final int MAX_RESPONSE_TOKENS = 10000;
- private static final int MAX_RESPONSE_TOKENS = 10000;
+ public static final EmbeddingModel EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_3_SMALL;
+ private static final int EMBEDDING_DIMENSIONS = 512;
+ private static final int EMBEDDING_DECIMAL_PLACES = 4;
private static final int MAX_MALFORMED_RESPONSE_RETRIES = 3;
+ private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY";
- private static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data";
+ protected static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data";
// Prepare JSON schemas for structured responses
- private static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder()
+ protected static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder()
.putAdditionalProperty("type", JsonValue.from("object"))
.putAdditionalProperty("properties", JsonValue.from(Map.of(
"one_sentence_summary", Map.of("type", "string"),
@@ -57,7 +62,7 @@ public class Summarizer {
.putAdditionalProperty("additionalProperties", JsonValue.from(false))
.build();
- private static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder()
+ protected static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder()
.putAdditionalProperty("type", JsonValue.from("object"))
.putAdditionalProperty("properties", JsonValue.from(Map.of(
"headline", Map.of("type", "string"),
@@ -81,26 +86,139 @@ public class Summarizer {
.putAdditionalProperty("additionalProperties", JsonValue.from(false))
.build();
- private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY";
-
- private final OpenAIClientAsync _openAIClient;
- private final DailyCostMonitor _costMonitor;
+ protected final DailyCostMonitor _costMonitor;
+ private final OpenAIClientAsync _embeddingClient;
+ protected final boolean _makeTopicEmbeddings;
private static final Logger LOG = Logger.getLogger(Summarizer.class);
- public Summarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
+ public Summarizer(WdkModel wdkModel, DailyCostMonitor costMonitor, boolean makeTopicEmbeddings) throws WdkModelException {
+ _costMonitor = costMonitor;
+ _makeTopicEmbeddings = makeTopicEmbeddings;
- String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME);
- if (apiKey == null) {
- throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set.");
+ // Only create embedding client if we need to make topic embeddings
+ if (makeTopicEmbeddings) {
+ String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME);
+ if (apiKey == null) {
+ throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set.");
+ }
+
+ _embeddingClient = OpenAIOkHttpClientAsync.builder()
+ .apiKey(apiKey)
+ .maxRetries(32) // Handle 429 errors
+ .build();
+ } else {
+ _embeddingClient = null;
}
+ }
- _openAIClient = OpenAIOkHttpClientAsync.builder()
- .apiKey(apiKey)
- .maxRetries(32) // Handle 429 errors
+ private CompletableFuture> getEmbedding(String text) {
+ // Safety check: ensure embedding client was initialized
+ if (_embeddingClient == null) {
+ LOG.error("Attempted to generate embedding but embedding client was not initialized (makeTopicEmbeddings=false)");
+ return CompletableFuture.completedFuture(List.of());
+ }
+
+ EmbeddingCreateParams request = EmbeddingCreateParams.builder()
+ .model(EMBEDDING_MODEL)
+ .input(text)
+ .dimensions(EMBEDDING_DIMENSIONS)
.build();
- _costMonitor = costMonitor;
+ return _embeddingClient.embeddings().create(request).thenApply(response -> {
+ // Update cost monitor - convert embedding usage to TokenUsage
+ com.openai.models.embeddings.CreateEmbeddingResponse.Usage embeddingUsage = response.usage();
+ TokenUsage tokenUsage = TokenUsage.builder()
+ .embeddingTokens(embeddingUsage.totalTokens())
+ .build();
+ _costMonitor.updateCost(tokenUsage);
+
+ // Extract embedding vector from first result (convert Float to Double)
+ List rawEmbedding = response.data().get(0).embedding();
+
+ // Round to specified decimal places
+ double scale = Math.pow(10, EMBEDDING_DECIMAL_PLACES);
+ return rawEmbedding.stream()
+ .map(val -> Math.round(val.doubleValue() * scale) / scale)
+ .collect(Collectors.toList());
+ }).exceptionally(e -> {
+ LOG.error("Failed to generate embedding: " + e.getMessage(), e);
+ return List.of(); // Return empty list on error
+ });
+ }
+
+ /**
+ * Retries an operation with exponential backoff if it fails with a retriable error.
+ *
+ * @param the return type of the operation
+ * @param operation supplier that produces the CompletableFuture to execute
+ * @param shouldRetry predicate to determine if an exception should trigger a retry
+ * @param operationDescription description for logging purposes
+ * @return CompletableFuture with the result of the operation
+ */
+ protected CompletableFuture retryOnOverload(
+ java.util.function.Supplier> operation,
+ java.util.function.Predicate shouldRetry,
+ String operationDescription) {
+
+ final int maxRetries = 3;
+ final long[] backoffDelaysMs = {1000, 2000, 4000}; // 1s, 2s, 4s
+
+ return retryWithBackoff(operation, shouldRetry, operationDescription, 0, maxRetries, backoffDelaysMs);
+ }
+
+ private CompletableFuture retryWithBackoff(
+ java.util.function.Supplier> operation,
+ java.util.function.Predicate shouldRetry,
+ String operationDescription,
+ int attemptNumber,
+ int maxRetries,
+ long[] backoffDelaysMs) {
+
+ CompletableFuture result = new CompletableFuture<>();
+
+ operation.get().whenComplete((value, throwable) -> {
+ if (throwable == null) {
+ // Success case
+ result.complete(value);
+ } else {
+ // Error case - unwrap CompletionException to get the actual cause
+ Throwable actualCause = throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null
+ ? throwable.getCause()
+ : throwable;
+
+ // Check if we should retry this exception and haven't exceeded max retries
+ if (shouldRetry.test(actualCause) && attemptNumber < maxRetries) {
+ long delayMs = backoffDelaysMs[attemptNumber];
+ LOG.warn(String.format(
+ "Retrying %s after error (attempt %d/%d, waiting %dms): %s",
+ operationDescription, attemptNumber + 1, maxRetries, delayMs, actualCause.getMessage()));
+
+ // Schedule retry after delay
+ new java.util.Timer().schedule(new java.util.TimerTask() {
+ @Override
+ public void run() {
+ retryWithBackoff(operation, shouldRetry, operationDescription, attemptNumber + 1, maxRetries, backoffDelaysMs)
+ .whenComplete((retryValue, retryError) -> {
+ if (retryError != null) {
+ result.completeExceptionally(retryError);
+ } else {
+ result.complete(retryValue);
+ }
+ });
+ }
+ }, delayMs);
+ } else {
+ // No more retries or non-retriable exception
+ if (attemptNumber >= maxRetries) {
+ LOG.error(String.format("Failed %s after %d retries: %s", operationDescription, maxRetries, actualCause.getMessage()));
+ }
+ result.completeExceptionally(throwable);
+ }
+ }
+ });
+
+ return result;
}
public static String getExperimentMessage(JSONObject experiment) {
@@ -133,12 +251,9 @@ public static String getExperimentMessage(JSONObject experiment) {
public CompletableFuture describeExperiment(ExperimentInputs experimentInputs) {
- ChatCompletionCreateParams request = buildAiRequest(
- "experiment-summary",
- experimentResponseSchema,
- getExperimentMessage(experimentInputs.getExperimentData()));
+ String prompt = getExperimentMessage(experimentInputs.getExperimentData());
- return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), request, json -> {
+ return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), prompt, experimentResponseSchema, json -> {
// add some fields to the result to aid the final summarization
return json
.put("dataset_id", experimentInputs.getDatasetId())
@@ -151,6 +266,7 @@ public static String getFinalSummaryMessage(List experiments) {
return "Below are AI-generated summaries of one gene's behavior in all the transcriptomics experiments available in VEuPathDB, provided in JSON format:\n\n" +
String.format("```json\n%s\n```\n\n", new JSONArray(experiments).toString(2)) +
"Generate a one-paragraph summary (~100 words) describing the gene's expression. Structure it using , , and - tags with no attributes. If relevant, briefly speculate on the gene's potential function, but only if justified by the data. Also, generate a short, specific headline for the summary. The headline must reflect this gene's expression and **must not** include generic phrases like \"comprehensive insights into\" or the word \"gene\".\n\n" +
+ "Use sentence case for all headlines: capitalize only the first word and proper nouns, not every word.\n\n" +
"Additionally, group the per-experiment summaries (identified by `dataset_id`) with `biological_importance > 3` and `confidence > 3` into sections by topic. For each topic, provide:\n" +
"- A headline summarizing the key experimental results within the topic\n" +
"- A concise one-sentence summary of the topic's experimental results\n\n" +
@@ -159,18 +275,17 @@ public static String getFinalSummaryMessage(List experiments) {
public JSONObject summarizeExperiments(String geneId, List experiments) {
- ChatCompletionCreateParams request = buildAiRequest(
- "expression-summary",
- finalResponseSchema,
- getFinalSummaryMessage(experiments));
+ String prompt = getFinalSummaryMessage(experiments);
- return getValidatedAiResponse("summary for gene " + geneId, request, json ->
+ return getValidatedAiResponse("summary for gene " + geneId, prompt, finalResponseSchema, json ->
+ json // Return json as-is; consolidateSummary will be called separately
+ ).thenCompose(json ->
// quality control (remove bad `dataset_id`s) and add 'Others' section for any experiments not listed by AI
consolidateSummary(json, experiments)
).join();
}
- private static JSONObject consolidateSummary(JSONObject summaryResponse,
+ private CompletableFuture consolidateSummary(JSONObject summaryResponse,
List individualResults) {
// Gather all dataset IDs from individualResults and map them to summaries.
// Preserving the order of individualResults.
@@ -180,7 +295,8 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse,
}
Set seenDatasetIds = new LinkedHashSet<>();
- JSONArray deduplicatedTopics = new JSONArray();
+ List deduplicatedTopicsList = new java.util.ArrayList<>();
+ List> embeddingFutures = new java.util.ArrayList<>();
JSONArray topics = summaryResponse.getJSONArray("topics");
for (int i = 0; i < topics.length(); i++) {
@@ -210,7 +326,21 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse,
if (summaries.length() > 0) {
topic.put("summaries", summaries);
topic.remove("dataset_ids");
- deduplicatedTopics.put(topic);
+ deduplicatedTopicsList.add(topic);
+
+ // Generate embedding for non-"Other" topics (if enabled)
+ if (_makeTopicEmbeddings) {
+ String headline = topic.optString("headline", "");
+ if (!headline.equals("Other")) {
+ String embeddingText = headline + "\n\n" + topic.optString("one_sentence_summary", "");
+ CompletableFuture embeddingFuture = getEmbedding(embeddingText).thenAccept(embedding -> {
+ if (!embedding.isEmpty()) {
+ topic.put("embedding_vector", embedding);
+ }
+ });
+ embeddingFutures.add(embeddingFuture);
+ }
+ }
}
}
@@ -230,44 +360,38 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse,
otherTopic.put("one_sentence_summary",
"The AI ordered these experiments by biological importance but did not group them into topics.");
otherTopic.put("summaries", otherSummaries);
- deduplicatedTopics.put(otherTopic);
+ deduplicatedTopicsList.add(otherTopic);
+ // Note: no embedding for "Other" topic
}
- // Create final deduplicated summary
- JSONObject finalSummary = new JSONObject(summaryResponse.toString());
- finalSummary.put("topics", deduplicatedTopics);
- return finalSummary;
+ // Wait for all embeddings to complete, then create final summary
+ return CompletableFuture.allOf(embeddingFutures.toArray(new CompletableFuture[0]))
+ .thenApply(v -> {
+ // Convert deduplicated topics list back to JSONArray
+ JSONArray deduplicatedTopics = new JSONArray();
+ for (JSONObject topic : deduplicatedTopicsList) {
+ deduplicatedTopics.put(topic);
+ }
+
+ // Create final deduplicated summary
+ JSONObject finalSummary = new JSONObject(summaryResponse.toString());
+ finalSummary.put("topics", deduplicatedTopics);
+ return finalSummary;
+ });
}
- private static ChatCompletionCreateParams buildAiRequest(String name, Schema schema, String userMessage) {
- return ChatCompletionCreateParams.builder()
- .model(OPENAI_CHAT_MODEL)
- .maxCompletionTokens(MAX_RESPONSE_TOKENS)
- .responseFormat(ResponseFormatJsonSchema.builder()
- .jsonSchema(JsonSchema.builder()
- .name(name)
- .schema(schema)
- .strict(true)
- .build())
- .build())
- .addSystemMessage(SYSTEM_MESSAGE)
- .addUserMessage(userMessage)
- .build();
- }
+ protected abstract CompletableFuture callApiForJson(String prompt, Schema schema);
+
+ protected abstract void updateCostMonitor(Object apiResponse);
private CompletableFuture getValidatedAiResponse(
String operationDescription,
- ChatCompletionCreateParams request,
+ String prompt,
+ Schema schema,
Function createFinalJson
) {
- return _openAIClient.chat().completions().create(request).thenApply(completion -> {
-
- // update cost accumulator
- _costMonitor.updateCost(completion.usage());
-
- // expect response to be a JSON string
- String jsonString = completion.choices().get(0).message().content().get();
+ return callApiForJson(prompt, schema).thenApply(jsonString -> {
int attempts = 1;
Exception mostRecentError;
@@ -281,12 +405,10 @@ private CompletableFuture getValidatedAiResponse(
}
catch (JSONException e) {
mostRecentError = e;
- LOG.warn("Malformed JSON from OpenAI (attempt " + attempts + ") for " + operationDescription + ". Retrying...");
+ LOG.warn("Malformed JSON from AI (attempt " + attempts + ") for " + operationDescription + ". Retrying...");
- // Re-request from OpenAI
- completion = _openAIClient.chat().completions().create(request).join();
- _costMonitor.updateCost(completion.usage());
- jsonString = completion.choices().get(0).message().content().get();
+ // Re-request from AI
+ jsonString = callApiForJson(prompt, schema).join();
attempts++;
}
}
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/TokenUsage.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/TokenUsage.java
new file mode 100644
index 000000000..ea76ca3f8
--- /dev/null
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/TokenUsage.java
@@ -0,0 +1,59 @@
+package org.apidb.apicommon.model.report.ai.expression;
+
+/**
+ * Represents token usage for AI API calls, including chat completions and embeddings.
+ * Immutable value object with builder pattern.
+ */
+public class TokenUsage {
+
+ private final long promptTokens;
+ private final long completionTokens;
+ private final long embeddingTokens;
+
+ private TokenUsage(Builder builder) {
+ this.promptTokens = builder.promptTokens;
+ this.completionTokens = builder.completionTokens;
+ this.embeddingTokens = builder.embeddingTokens;
+ }
+
+ public long promptTokens() {
+ return promptTokens;
+ }
+
+ public long completionTokens() {
+ return completionTokens;
+ }
+
+ public long embeddingTokens() {
+ return embeddingTokens;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+ private long promptTokens = 0;
+ private long completionTokens = 0;
+ private long embeddingTokens = 0;
+
+ public Builder promptTokens(long promptTokens) {
+ this.promptTokens = promptTokens;
+ return this;
+ }
+
+ public Builder completionTokens(long completionTokens) {
+ this.completionTokens = completionTokens;
+ return this;
+ }
+
+ public Builder embeddingTokens(long embeddingTokens) {
+ this.embeddingTokens = embeddingTokens;
+ return this;
+ }
+
+ public TokenUsage build() {
+ return new TokenUsage(this);
+ }
+ }
+}