Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/api/Elastic.Documentation.Api.Core/AskAi/AskAiUsecase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace Elastic.Documentation.Api.Core.AskAi;

public class AskAiUsecase(
IAskAiGateway<Stream> askAiGateway,
IAskAiGateway askAiGateway,
IStreamTransformer streamTransformer,
ILogger<AskAiUsecase> logger)
{
Expand All @@ -24,6 +24,7 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx)
_ = activity?.SetTag("gen_ai.agent.id", streamTransformer.AgentId); // docs-agent or docs_assistant
if (askAiRequest.ConversationId is not null)
_ = activity?.SetTag("gen_ai.conversation.id", askAiRequest.ConversationId.ToString());

var inputMessages = new[]
{
new InputMessage("user", [new MessagePart("text", askAiRequest.Message)])
Expand All @@ -33,9 +34,22 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx)
var sanitizedMessage = askAiRequest.Message?.Replace("\r", "").Replace("\n", "");
logger.LogInformation("AskAI input message: <{ask_ai.input.message}>", sanitizedMessage);
logger.LogInformation("Streaming AskAI response");
var rawStream = await askAiGateway.AskAi(askAiRequest, ctx);
// The stream transformer will handle disposing the activity when streaming completes
var transformedStream = await streamTransformer.TransformAsync(rawStream, askAiRequest.ConversationId?.ToString(), activity, ctx);

// Gateway handles conversation ID generation if needed
var response = await askAiGateway.AskAi(askAiRequest, ctx);

// Use generated ID if available, otherwise use the original request ID
var conversationId = response.GeneratedConversationId ?? askAiRequest.ConversationId;
if (conversationId is not null)
_ = activity?.SetTag("gen_ai.conversation.id", conversationId.ToString());

// The stream transformer takes ownership of the activity and disposes it when streaming completes.
// This is necessary because streaming happens asynchronously after this method returns.
var transformedStream = await streamTransformer.TransformAsync(
response.Stream,
response.GeneratedConversationId,
activity,
ctx);
return transformedStream;
}
}
Expand Down
15 changes: 13 additions & 2 deletions src/api/Elastic.Documentation.Api.Core/AskAi/IAskAiGateway.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@

namespace Elastic.Documentation.Api.Core.AskAi;

public interface IAskAiGateway<T>
/// <summary>
/// Response from an AI gateway: the response stream plus the conversation ID
/// the gateway generated for this request, if it generated one.
/// </summary>
/// <param name="Stream">The SSE response stream returned by the gateway.</param>
/// <param name="GeneratedConversationId">
/// Non-null ONLY if the gateway generated a new conversation ID for this request.
/// When set, the transformer should emit a ConversationStart event with this ID.
/// Null means either: (1) the caller provided an ID (continuing a conversation), or
/// (2) the gateway does not generate IDs itself (e.g., Agent Builder, which reports the ID inside the stream).
/// </param>
public record AskAiGatewayResponse(Stream Stream, Guid? GeneratedConversationId);

public interface IAskAiGateway
{
Task<T> AskAi(AskAiRequest askAiRequest, Cancel ctx = default);
Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ public interface IStreamTransformer
/// Transforms a raw SSE stream into a stream of AskAiEvent objects
/// </summary>
/// <param name="rawStream">Raw SSE stream from gateway (Agent Builder, LLM Gateway, etc.)</param>
/// <param name="conversationId">Thread/conversation ID (if known)</param>
/// <param name="generatedConversationId">
/// Non-null if the gateway generated a new conversation ID (LLM Gateway only).
/// When set, transformer should emit ConversationStart event with this ID.
/// </param>
/// <param name="parentActivity">Parent activity to track the streaming operation (will be disposed when stream completes)</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Stream containing SSE-formatted AskAiEvent objects</returns>
Task<Stream> TransformAsync(Stream rawStream, string? conversationId, System.Diagnostics.Activity? parentActivity, CancellationToken cancellationToken = default);
Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, System.Diagnostics.Activity? parentActivity, CancellationToken cancellationToken = default);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;

public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kibanaOptions, ILogger<AgentBuilderAskAiGateway> logger) : IAskAiGateway<Stream>
public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kibanaOptions, ILogger<AgentBuilderAskAiGateway> logger) : IAskAiGateway
{
/// <summary>
/// Model name used by Agent Builder (from AgentId)
Expand All @@ -23,8 +23,10 @@ public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kiban
/// Provider name for tracing
/// </summary>
public const string ProviderName = "agent-builder";
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
{
// Agent Builder returns the conversation ID in the stream via conversation_id_set event
// We don't generate IDs - Agent Builder handles that in the stream
var agentBuilderPayload = new AgentBuilderPayload(
askAiRequest.Message,
"docs-agent",
Expand Down Expand Up @@ -55,7 +57,9 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
logger.LogInformation("Response Content-Length: {ContentLength}", response.Content.Headers.ContentLength?.ToString(CultureInfo.InvariantCulture));

// Agent Builder already returns SSE format, just return the stream directly
return await response.Content.ReadAsStreamAsync(ctx);
// The conversation ID will be extracted from the stream by the transformer
var stream = await response.Content.ReadAsStreamAsync(ctx);
return new AskAiGatewayResponse(stream, GeneratedConversationId: null);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
public class AskAiGatewayFactory(
IServiceProvider serviceProvider,
AskAiProviderResolver providerResolver,
ILogger<AskAiGatewayFactory> logger) : IAskAiGateway<Stream>
ILogger<AskAiGatewayFactory> logger) : IAskAiGateway
{
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
{
var provider = providerResolver.ResolveProvider();

IAskAiGateway<Stream> gateway = provider switch
IAskAiGateway gateway = provider switch
{
"LlmGateway" => serviceProvider.GetRequiredService<LlmGatewayAskAiGateway>(),
"AgentBuilder" => serviceProvider.GetRequiredService<AgentBuilderAskAiGateway>(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;

public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway<Stream>
public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway
{
/// <summary>
/// Model name used by LLM Gateway (from PlatformContext.UseCase)
Expand All @@ -21,9 +21,13 @@ public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider t
/// Provider name for tracing
/// </summary>
public const string ProviderName = "llm-gateway";
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
{
var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest);
// LLM Gateway requires a ThreadId - generate one if not provided
var generatedId = askAiRequest.ConversationId is null ? Guid.NewGuid() : (Guid?)null;
var threadId = askAiRequest.ConversationId ?? generatedId!.Value;

var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest, threadId);
var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest);
using var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl);
request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json");
Expand All @@ -46,7 +50,8 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)

// Return the response stream directly - this enables true streaming
// The stream will be consumed as data arrives from the LLM Gateway
return await response.Content.ReadAsStreamAsync(ctx);
var stream = await response.Content.ReadAsStreamAsync(ctx);
return new AskAiGatewayResponse(stream, generatedId);
}
}

Expand All @@ -57,7 +62,7 @@ public record LlmGatewayRequest(
string ThreadId
)
{
public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
public static LlmGatewayRequest CreateFromRequest(AskAiRequest request, Guid conversationId) =>
new(
UserContext: new UserContext("elastic-docs-v3@invalid"),
PlatformContext: new PlatformContext("docs_site", "docs_assistant", []),
Expand All @@ -66,7 +71,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
// new ChatInput("user", AskAiRequest.SystemPrompt),
new ChatInput("user", request.Message)
],
ThreadId: request.ConversationId?.ToString() ?? Guid.NewGuid().ToString()
ThreadId: conversationId.ToString()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,33 @@ public class LlmGatewayStreamTransformer(ILogger<LlmGatewayStreamTransformer> lo
protected override string GetAgentProvider() => LlmGatewayAskAiGateway.ProviderName;

/// <summary>
/// Override to emit ConversationStart event when conversationId is null (new conversation)
/// Override to emit ConversationStart event for new conversations.
/// LLM Gateway doesn't return a conversation ID, so we emit one to match Agent Builder behavior.
/// The generatedConversationId is the ID generated by the gateway and used as ThreadId.
/// </summary>
protected override async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
protected override async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
{
// If conversationId is null, generate a new one and emit ConversationStart event
// This matches the ThreadId format used in LlmGatewayAskAiGateway
var actualConversationId = conversationId;
if (conversationId == null)
// Emit ConversationStart event only if a new conversation ID was generated
if (generatedConversationId is not null)
{
actualConversationId = Guid.NewGuid().ToString();
var conversationId = generatedConversationId.Value.ToString();
var timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
var conversationStartEvent = new AskAiEvent.ConversationStart(
Id: Guid.NewGuid().ToString(),
Timestamp: timestamp,
ConversationId: actualConversationId
ConversationId: conversationId
);

// Set activity tags for the new conversation
_ = parentActivity?.SetTag("gen_ai.conversation.id", actualConversationId);
Logger.LogDebug("LLM Gateway conversation started: {ConversationId}", actualConversationId);
_ = parentActivity?.SetTag("gen_ai.conversation.id", conversationId);
Logger.LogDebug("LLM Gateway conversation started: {ConversationId}", conversationId);

// Write the ConversationStart event to the stream
await WriteEventAsync(conversationStartEvent, writer, cancellationToken);
}

// Continue with normal stream processing using the actual conversation ID
await base.ProcessStreamAsync(reader, writer, actualConversationId, parentActivity, cancellationToken);
// Continue with normal stream processing
await base.ProcessStreamAsync(reader, writer, generatedConversationId, parentActivity, cancellationToken);
}
protected override AskAiEvent? TransformJsonEvent(string? eventType, JsonElement json)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public abstract class StreamTransformerBase(ILogger logger) : IStreamTransformer
/// </summary>
public string AgentProvider => GetAgentProvider();

public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Activity? parentActivity, Cancel cancellationToken = default)
public Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, Activity? parentActivity, Cancel cancellationToken = default)
{
// Configure pipe for low-latency streaming
var pipeOptions = new PipeOptions(
Expand All @@ -61,7 +61,7 @@ public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Act
// Note: We intentionally don't await this task as we need to return the stream immediately
// The pipe handles synchronization and backpressure between producer and consumer
// Pass parent activity - it will be disposed when streaming completes
_ = ProcessPipeAsync(reader, pipe.Writer, conversationId, parentActivity, cancellationToken);
_ = ProcessPipeAsync(reader, pipe.Writer, generatedConversationId, parentActivity, cancellationToken);

// Return the read side of the pipe as a stream
return Task.FromResult(pipe.Reader.AsStream());
Expand All @@ -71,48 +71,42 @@ public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Act
/// Process the pipe reader and write transformed events to the pipe writer.
/// This runs concurrently with the consumer reading from the output stream.
/// </summary>
private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
{
using var activityScope = parentActivity;
try
{
await ProcessStreamAsync(reader, writer, generatedConversationId, parentActivity, cancellationToken);
}
catch (OperationCanceledException ex)
{
Logger.LogDebug(ex, "Stream processing was cancelled for transformer {TransformerType}", GetType().Name);
}
catch (Exception ex)
{
Logger.LogError(ex, "Error transforming stream for transformer {TransformerType}. Stream processing will be terminated.", GetType().Name);
_ = parentActivity?.SetTag("error.type", ex.GetType().Name);
try
{
await ProcessStreamAsync(reader, writer, conversationId, parentActivity, cancellationToken);
// Complete writer first, then reader - but don't try to complete reader
// if the exception came from reading (would cause "read operation pending" error)
await writer.CompleteAsync(ex);
}
catch (OperationCanceledException ex)
catch (Exception completeEx)
{
Logger.LogDebug(ex, "Stream processing was cancelled for transformer {TransformerType}", GetType().Name);
}
catch (Exception ex)
{
Logger.LogError(ex, "Error transforming stream for transformer {TransformerType}. Stream processing will be terminated.", GetType().Name);
_ = parentActivity?.SetTag("error.type", ex.GetType().Name);
try
{
// Complete writer first, then reader - but don't try to complete reader
// if the exception came from reading (would cause "read operation pending" error)
await writer.CompleteAsync(ex);
}
catch (Exception completeEx)
{
Logger.LogError(completeEx, "Error completing pipe after transformation error for transformer {TransformerType}", GetType().Name);
}
return;
Logger.LogError(completeEx, "Error completing pipe after transformation error for transformer {TransformerType}", GetType().Name);
}
return;
}
Comment on lines +85 to +100

// Normal completion - ensure cleanup happens
try
{
await writer.CompleteAsync();
}
catch (Exception ex)
{
Logger.LogError(ex, "Error completing pipe after successful transformation");
}
// Normal completion - ensure cleanup happens
try
{
await writer.CompleteAsync();
}
finally
catch (Exception ex)
{
parentActivity?.Dispose();
Logger.LogError(ex, "Error completing pipe after successful transformation");
}
}

Expand All @@ -122,7 +116,7 @@ private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, string
/// Default implementation parses SSE events and JSON, then calls TransformJsonEvent.
/// </summary>
/// <returns>A task that completes when the stream has been fully processed</returns>
protected virtual async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
protected virtual async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
{
using var activity = StreamTransformerActivitySource.StartActivity("process ask_ai stream", ActivityKind.Internal);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ private IStreamTransformer GetTransformer()
public string AgentId => GetTransformer().AgentId;
public string AgentProvider => GetTransformer().AgentProvider;

public async Task<Stream> TransformAsync(Stream rawStream, string? conversationId, System.Diagnostics.Activity? parentActivity, Cancel cancellationToken = default)
public async Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, System.Diagnostics.Activity? parentActivity, Cancel cancellationToken = default)
{
var transformer = GetTransformer();
return await transformer.TransformAsync(rawStream, conversationId, parentActivity, cancellationToken);
return await transformer.TransformAsync(rawStream, generatedConversationId, parentActivity, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv)
logger?.LogInformation("Both stream transformers registered as concrete types");

// Register factories as interface implementations
_ = services.AddScoped<IAskAiGateway<Stream>, AskAiGatewayFactory>();
_ = services.AddScoped<IAskAiGateway, AskAiGatewayFactory>();
_ = services.AddScoped<IStreamTransformer, StreamTransformerFactory>();
logger?.LogInformation("Gateway and transformer factories registered successfully - provider switchable via X-AI-Provider header");

Expand Down
Loading
Loading