diff --git a/src/api/Elastic.Documentation.Api.Core/AskAi/AskAiUsecase.cs b/src/api/Elastic.Documentation.Api.Core/AskAi/AskAiUsecase.cs
index 12043ed7f..9c5c9f5b7 100644
--- a/src/api/Elastic.Documentation.Api.Core/AskAi/AskAiUsecase.cs
+++ b/src/api/Elastic.Documentation.Api.Core/AskAi/AskAiUsecase.cs
@@ -9,7 +9,7 @@ namespace Elastic.Documentation.Api.Core.AskAi;
 public class AskAiUsecase(
-	IAskAiGateway<Stream> askAiGateway,
+	IAskAiGateway askAiGateway,
 	IStreamTransformer streamTransformer,
 	ILogger<AskAiUsecase> logger)
 {
@@ -24,6 +24,7 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx)
 		_ = activity?.SetTag("gen_ai.agent.id", streamTransformer.AgentId); // docs-agent or docs_assistant
 		if (askAiRequest.ConversationId is not null)
 			_ = activity?.SetTag("gen_ai.conversation.id", askAiRequest.ConversationId.ToString());
+
 		var inputMessages = new[]
 		{
 			new InputMessage("user", [new MessagePart("text", askAiRequest.Message)])
@@ -33,9 +34,22 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx)
 		var sanitizedMessage = askAiRequest.Message?.Replace("\r", "").Replace("\n", "");
 		logger.LogInformation("AskAI input message: <{ask_ai.input.message}>", sanitizedMessage);
 		logger.LogInformation("Streaming AskAI response");
-		var rawStream = await askAiGateway.AskAi(askAiRequest, ctx);
-		// The stream transformer will handle disposing the activity when streaming completes
-		var transformedStream = await streamTransformer.TransformAsync(rawStream, askAiRequest.ConversationId?.ToString(), activity, ctx);
+
+		// Gateway handles conversation ID generation if needed
+		var response = await askAiGateway.AskAi(askAiRequest, ctx);
+
+		// Use generated ID if available, otherwise use the original request ID
+		var conversationId = response.GeneratedConversationId ?? askAiRequest.ConversationId;
+		if (conversationId is not null)
+			_ = activity?.SetTag("gen_ai.conversation.id", conversationId.ToString());
+
+		// The stream transformer takes ownership of the activity and disposes it when streaming completes.
+		// This is necessary because streaming happens asynchronously after this method returns.
+		var transformedStream = await streamTransformer.TransformAsync(
+			response.Stream,
+			response.GeneratedConversationId,
+			activity,
+			ctx);
 		return transformedStream;
 	}
 }
diff --git a/src/api/Elastic.Documentation.Api.Core/AskAi/IAskAiGateway.cs b/src/api/Elastic.Documentation.Api.Core/AskAi/IAskAiGateway.cs
index 236bf94b1..07e516196 100644
--- a/src/api/Elastic.Documentation.Api.Core/AskAi/IAskAiGateway.cs
+++ b/src/api/Elastic.Documentation.Api.Core/AskAi/IAskAiGateway.cs
@@ -4,7 +4,18 @@ namespace Elastic.Documentation.Api.Core.AskAi;
-public interface IAskAiGateway<T>
+/// <summary>
+/// Response from an AI gateway containing the stream and conversation metadata
+/// </summary>
+/// <param name="Stream">The SSE response stream</param>
+/// <param name="GeneratedConversationId">
+/// Non-null ONLY if the gateway generated a new conversation ID for this request.
+/// When set, the transformer should emit a ConversationStart event with this ID.
+/// Null means either: (1) user provided an ID (continuing conversation), or (2) gateway doesn't generate IDs (e.g., Agent Builder).
+/// </param>
+public record AskAiGatewayResponse(Stream Stream, Guid? GeneratedConversationId);
+
+public interface IAskAiGateway
 {
-	Task<T> AskAi(AskAiRequest askAiRequest, Cancel ctx = default);
+	Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default);
 }
diff --git a/src/api/Elastic.Documentation.Api.Core/AskAi/IStreamTransformer.cs b/src/api/Elastic.Documentation.Api.Core/AskAi/IStreamTransformer.cs
index 56435df6d..6bda05297 100644
--- a/src/api/Elastic.Documentation.Api.Core/AskAi/IStreamTransformer.cs
+++ b/src/api/Elastic.Documentation.Api.Core/AskAi/IStreamTransformer.cs
@@ -23,9 +23,12 @@ public interface IStreamTransformer
 	/// Transforms a raw SSE stream into a stream of AskAiEvent objects
 	/// </summary>
 	/// <param name="rawStream">Raw SSE stream from gateway (Agent Builder, LLM Gateway, etc.)</param>
-	/// <param name="conversationId">Thread/conversation ID (if known)</param>
+	/// <param name="generatedConversationId">
+	/// Non-null if the gateway generated a new conversation ID (LLM Gateway only).
+	/// When set, transformer should emit ConversationStart event with this ID.
+	/// </param>
 	/// <param name="parentActivity">Parent activity to track the streaming operation (will be disposed when stream completes)</param>
 	/// <param name="cancellationToken">Cancellation token</param>
 	/// <returns>Stream containing SSE-formatted AskAiEvent objects</returns>
-	Task<Stream> TransformAsync(Stream rawStream, string? conversationId, System.Diagnostics.Activity? parentActivity, CancellationToken cancellationToken = default);
+	Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, System.Diagnostics.Activity? parentActivity, CancellationToken cancellationToken = default);
 }
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AgentBuilderAskAiGateway.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AgentBuilderAskAiGateway.cs
index 69387f460..6c182efde 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AgentBuilderAskAiGateway.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AgentBuilderAskAiGateway.cs
@@ -12,7 +12,7 @@ namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
-public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kibanaOptions, ILogger<AgentBuilderAskAiGateway> logger) : IAskAiGateway<Stream>
+public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kibanaOptions, ILogger<AgentBuilderAskAiGateway> logger) : IAskAiGateway
 {
 	/// <summary>
 	/// Model name used by Agent Builder (from AgentId)
 	/// </summary>
@@ -23,8 +23,10 @@ public class AgentBuilderAskAiGateway(HttpClient httpClient, KibanaOptions kiban
 	/// Provider name for tracing
 	/// </summary>
 	public const string ProviderName = "agent-builder";
-	public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
+	public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
 	{
+		// Agent Builder returns the conversation ID in the stream via conversation_id_set event
+		// We don't generate IDs - Agent Builder handles that in the stream
 		var agentBuilderPayload = new AgentBuilderPayload(
 			askAiRequest.Message,
 			"docs-agent",
@@ -55,7 +57,9 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
 		logger.LogInformation("Response Content-Length: {ContentLength}", response.Content.Headers.ContentLength?.ToString(CultureInfo.InvariantCulture));
 		// Agent Builder already returns SSE format, just return the stream directly
-		return await response.Content.ReadAsStreamAsync(ctx);
+		// The conversation ID will be extracted from the stream by the transformer
+		var stream = await response.Content.ReadAsStreamAsync(ctx);
+		return new AskAiGatewayResponse(stream, GeneratedConversationId: null);
 	}
 }
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AskAiGatewayFactory.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AskAiGatewayFactory.cs
index f5e094324..ec55afc03 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AskAiGatewayFactory.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/AskAiGatewayFactory.cs
@@ -14,13 +14,13 @@ namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
 public class AskAiGatewayFactory(
 	IServiceProvider serviceProvider,
 	AskAiProviderResolver providerResolver,
-	ILogger<AskAiGatewayFactory> logger) : IAskAiGateway<Stream>
+	ILogger<AskAiGatewayFactory> logger) : IAskAiGateway
 {
-	public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
+	public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
 	{
 		var provider = providerResolver.ResolveProvider();
-		IAskAiGateway<Stream> gateway = provider switch
+		IAskAiGateway gateway = provider switch
 		{
 			"LlmGateway" => serviceProvider.GetRequiredService<LlmGatewayAskAiGateway>(),
 			"AgentBuilder" => serviceProvider.GetRequiredService<AgentBuilderAskAiGateway>(),
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs
index 322deacd0..153cf79d6 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs
@@ -10,7 +10,7 @@ namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
-public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway<Stream>
+public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway
 {
 	/// <summary>
 	/// Model name used by LLM Gateway (from PlatformContext.UseCase)
 	/// </summary>
@@ -21,9 +21,13 @@ public class LlmGatewayAskAiGateway(HttpClient httpClient, IGcpIdTokenProvider t
 	/// Provider name for tracing
 	/// </summary>
 	public const string ProviderName = "llm-gateway";
-	public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
+	public async Task<AskAiGatewayResponse> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
 	{
-		var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest);
+		// LLM Gateway requires a ThreadId - generate one if not provided
+		var generatedId = askAiRequest.ConversationId is null ? Guid.NewGuid() : (Guid?)null;
+		var threadId = askAiRequest.ConversationId ?? generatedId!.Value;
+
+		var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest, threadId);
 		var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest);
 		using var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl);
 		request.Content = new StringContent(requestBody, Encoding.UTF8, "application/json");
@@ -46,7 +50,8 @@ public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
 		// Return the response stream directly - this enables true streaming
 		// The stream will be consumed as data arrives from the LLM Gateway
-		return await response.Content.ReadAsStreamAsync(ctx);
+		var stream = await response.Content.ReadAsStreamAsync(ctx);
+		return new AskAiGatewayResponse(stream, generatedId);
 	}
 }
@@ -57,7 +62,7 @@ public record LlmGatewayRequest(
 	string ThreadId
 )
 {
-	public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
+	public static LlmGatewayRequest CreateFromRequest(AskAiRequest request, Guid conversationId) =>
 		new(
 			UserContext: new UserContext("elastic-docs-v3@invalid"),
 			PlatformContext: new PlatformContext("docs_site", "docs_assistant", []),
@@ -66,7 +71,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
 			// new ChatInput("user", AskAiRequest.SystemPrompt),
 			new ChatInput("user", request.Message)
 			],
-			ThreadId: request.ConversationId?.ToString() ?? Guid.NewGuid().ToString()
+			ThreadId: conversationId.ToString()
 		);
 }
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayStreamTransformer.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayStreamTransformer.cs
index d12ae3541..dbb4cfe7f 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayStreamTransformer.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayStreamTransformer.cs
@@ -19,33 +19,33 @@ public class LlmGatewayStreamTransformer(ILogger<LlmGatewayStreamTransformer> lo
 	protected override string GetAgentProvider() => LlmGatewayAskAiGateway.ProviderName;
 	/// <summary>
-	/// Override to emit ConversationStart event when conversationId is null (new conversation)
+	/// Override to emit ConversationStart event for new conversations.
+	/// LLM Gateway doesn't return a conversation ID, so we emit one to match Agent Builder behavior.
+	/// The generatedConversationId is the ID generated by the gateway and used as ThreadId.
 	/// </summary>
-	protected override async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
+	protected override async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
 	{
-		// If conversationId is null, generate a new one and emit ConversationStart event
-		// This matches the ThreadId format used in LlmGatewayAskAiGateway
-		var actualConversationId = conversationId;
-		if (conversationId == null)
+		// Emit ConversationStart event only if a new conversation ID was generated
+		if (generatedConversationId is not null)
 		{
-			actualConversationId = Guid.NewGuid().ToString();
+			var conversationId = generatedConversationId.Value.ToString();
 			var timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
 			var conversationStartEvent = new AskAiEvent.ConversationStart(
 				Id: Guid.NewGuid().ToString(),
 				Timestamp: timestamp,
-				ConversationId: actualConversationId
+				ConversationId: conversationId
 			);
 			// Set activity tags for the new conversation
-			_ = parentActivity?.SetTag("gen_ai.conversation.id", actualConversationId);
-			Logger.LogDebug("LLM Gateway conversation started: {ConversationId}", actualConversationId);
+			_ = parentActivity?.SetTag("gen_ai.conversation.id", conversationId);
+			Logger.LogDebug("LLM Gateway conversation started: {ConversationId}", conversationId);
 			// Write the ConversationStart event to the stream
 			await WriteEventAsync(conversationStartEvent, writer, cancellationToken);
 		}
-		// Continue with normal stream processing using the actual conversation ID
-		await base.ProcessStreamAsync(reader, writer, actualConversationId, parentActivity, cancellationToken);
+		// Continue with normal stream processing
+		await base.ProcessStreamAsync(reader, writer, generatedConversationId, parentActivity, cancellationToken);
 	}
 	protected override AskAiEvent? TransformJsonEvent(string? eventType, JsonElement json)
 	{
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerBase.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerBase.cs
index 3079d7a77..d71d0f44f 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerBase.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerBase.cs
@@ -42,7 +42,7 @@ public abstract class StreamTransformerBase(ILogger logger) : IStreamTransformer
 	/// </summary>
 	public string AgentProvider => GetAgentProvider();
-	public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Activity? parentActivity, Cancel cancellationToken = default)
+	public Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, Activity? parentActivity, Cancel cancellationToken = default)
 	{
 		// Configure pipe for low-latency streaming
 		var pipeOptions = new PipeOptions(
@@ -61,7 +61,7 @@ public Task<Stream> TransformAsync(Stream rawStream, string? conversationId, Act
 		// Note: We intentionally don't await this task as we need to return the stream immediately
 		// The pipe handles synchronization and backpressure between producer and consumer
 		// Pass parent activity - it will be disposed when streaming completes
-		_ = ProcessPipeAsync(reader, pipe.Writer, conversationId, parentActivity, cancellationToken);
+		_ = ProcessPipeAsync(reader, pipe.Writer, generatedConversationId, parentActivity, cancellationToken);
 		// Return the read side of the pipe as a stream
 		return Task.FromResult(pipe.Reader.AsStream());
@@ -71,48 +71,42 @@
 	/// Process the pipe reader and write transformed events to the pipe writer.
 	/// This runs concurrently with the consumer reading from the output stream.
 	/// </summary>
-	private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
+	private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
 	{
+		using var activityScope = parentActivity;
 		try
 		{
+			await ProcessStreamAsync(reader, writer, generatedConversationId, parentActivity, cancellationToken);
+		}
+		catch (OperationCanceledException ex)
+		{
+			Logger.LogDebug(ex, "Stream processing was cancelled for transformer {TransformerType}", GetType().Name);
+		}
+		catch (Exception ex)
+		{
+			Logger.LogError(ex, "Error transforming stream for transformer {TransformerType}. Stream processing will be terminated.", GetType().Name);
+			_ = parentActivity?.SetTag("error.type", ex.GetType().Name);
 			try
 			{
-				await ProcessStreamAsync(reader, writer, conversationId, parentActivity, cancellationToken);
+				// Complete writer first, then reader - but don't try to complete reader
+				// if the exception came from reading (would cause "read operation pending" error)
+				await writer.CompleteAsync(ex);
 			}
-			catch (OperationCanceledException ex)
+			catch (Exception completeEx)
 			{
-				Logger.LogDebug(ex, "Stream processing was cancelled for transformer {TransformerType}", GetType().Name);
-			}
-			catch (Exception ex)
-			{
-				Logger.LogError(ex, "Error transforming stream for transformer {TransformerType}. Stream processing will be terminated.", GetType().Name);
-				_ = parentActivity?.SetTag("error.type", ex.GetType().Name);
-				try
-				{
-					// Complete writer first, then reader - but don't try to complete reader
-					// if the exception came from reading (would cause "read operation pending" error)
-					await writer.CompleteAsync(ex);
-				}
-				catch (Exception completeEx)
-				{
-					Logger.LogError(completeEx, "Error completing pipe after transformation error for transformer {TransformerType}", GetType().Name);
-				}
-				return;
+				Logger.LogError(completeEx, "Error completing pipe after transformation error for transformer {TransformerType}", GetType().Name);
 			}
+			return;
+		}
-			// Normal completion - ensure cleanup happens
-			try
-			{
-				await writer.CompleteAsync();
-			}
-			catch (Exception ex)
-			{
-				Logger.LogError(ex, "Error completing pipe after successful transformation");
-			}
+		// Normal completion - ensure cleanup happens
+		try
+		{
+			await writer.CompleteAsync();
 		}
-		finally
+		catch (Exception ex)
 		{
-			parentActivity?.Dispose();
+			Logger.LogError(ex, "Error completing pipe after successful transformation");
 		}
 	}
@@ -122,7 +116,7 @@ private async Task ProcessPipeAsync(PipeReader reader, PipeWriter writer, string
 	/// Default implementation parses SSE events and JSON, then calls TransformJsonEvent.
 	/// </summary>
 	/// <returns>Stream processing result with metrics and captured output</returns>
-	protected virtual async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, string? conversationId, Activity? parentActivity, CancellationToken cancellationToken)
+	protected virtual async Task ProcessStreamAsync(PipeReader reader, PipeWriter writer, Guid? generatedConversationId, Activity? parentActivity, CancellationToken cancellationToken)
 	{
 		using var activity = StreamTransformerActivitySource.StartActivity("process ask_ai stream", ActivityKind.Internal);
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerFactory.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerFactory.cs
index 4d9937871..d59d9c695 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerFactory.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/StreamTransformerFactory.cs
@@ -39,9 +39,9 @@ private IStreamTransformer GetTransformer()
 	public string AgentId => GetTransformer().AgentId;
 	public string AgentProvider => GetTransformer().AgentProvider;
-	public async Task<Stream> TransformAsync(Stream rawStream, string? conversationId, System.Diagnostics.Activity? parentActivity, Cancel cancellationToken = default)
+	public async Task<Stream> TransformAsync(Stream rawStream, Guid? generatedConversationId, System.Diagnostics.Activity? parentActivity, Cancel cancellationToken = default)
 	{
 		var transformer = GetTransformer();
-		return await transformer.TransformAsync(rawStream, conversationId, parentActivity, cancellationToken);
+		return await transformer.TransformAsync(rawStream, generatedConversationId, parentActivity, cancellationToken);
 	}
 }
diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs b/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs
index e5b491dbb..e597e9e45 100644
--- a/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs
+++ b/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs
@@ -173,7 +173,7 @@ private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv)
 		logger?.LogInformation("Both stream transformers registered as concrete types");
 		// Register factories as interface implementations
-		_ = services.AddScoped<IAskAiGateway<Stream>, AskAiGatewayFactory>();
+		_ = services.AddScoped<IAskAiGateway, AskAiGatewayFactory>();
 		_ = services.AddScoped<IStreamTransformer, StreamTransformerFactory>();
 		logger?.LogInformation("Gateway and transformer factories registered successfully - provider switchable via X-AI-Provider header");
diff --git a/tests-integration/Elastic.Documentation.Api.IntegrationTests/AskAiGatewayStreamingTests.cs b/tests-integration/Elastic.Documentation.Api.IntegrationTests/AskAiGatewayStreamingTests.cs
index d60df351d..7e42d47c1 100644
--- a/tests-integration/Elastic.Documentation.Api.IntegrationTests/AskAiGatewayStreamingTests.cs
+++ b/tests-integration/Elastic.Documentation.Api.IntegrationTests/AskAiGatewayStreamingTests.cs
@@ -58,15 +58,15 @@ public async Task AgentBuilderGatewayDoesNotDisposeHttpResponsePrematurely()
 		var request = new AskAiRequest("Test message", null);
-		// Act - get the stream from the gateway
-		var stream = await gateway.AskAi(request, TestContext.Current.CancellationToken);
+		// Act - get the response from the gateway
+		var response = await gateway.AskAi(request, TestContext.Current.CancellationToken);
 		// Assert - the stream should be readable (not disposed)
-		stream.Should().NotBeNull();
-		stream.CanRead.Should().BeTrue("stream should not be disposed by the gateway");
+		response.Should().NotBeNull();
+		response.Stream.CanRead.Should().BeTrue("stream should not be disposed by the gateway");
 		// Read the entire stream to verify it works
-		using var reader = new StreamReader(stream);
+		using var reader = new StreamReader(response.Stream);
 		var content = await reader.ReadToEndAsync(TestContext.Current.CancellationToken);
 		content.Should().NotBeEmpty();
@@ -113,14 +113,14 @@ public async Task AgentBuilderGatewayAllowsMultipleReadsFromStream()
 		var request = new AskAiRequest("Test", null);
-		// Act - get the stream and read it in chunks
-		var stream = await gateway.AskAi(request, TestContext.Current.CancellationToken);
+		// Act - get the response and read it in chunks
+		var response = await gateway.AskAi(request, TestContext.Current.CancellationToken);
 		var chunks = new List<string>();
 		var buffer = new byte[16]; // Small buffer to force multiple reads
 		int bytesRead;
-		while ((bytesRead = await stream.ReadAsync(buffer.AsMemory(0, buffer.Length), TestContext.Current.CancellationToken)) > 0)
+		while ((bytesRead = await response.Stream.ReadAsync(buffer.AsMemory(0, buffer.Length), TestContext.Current.CancellationToken)) > 0)
 		{
 			var chunk = Encoding.UTF8.GetString(buffer, 0, bytesRead);
 			chunks.Add(chunk);
@@ -168,15 +168,15 @@ public async Task LlmGatewayDoesNotDisposeHttpResponsePrematurely()
 		var request = new AskAiRequest("Test message", null);
-		// Act - get the stream from the gateway
-		var stream = await gateway.AskAi(request, TestContext.Current.CancellationToken);
+		// Act - get the response from the gateway
+		var response = await gateway.AskAi(request, TestContext.Current.CancellationToken);
 		// Assert - the stream should be readable (not disposed)
-		stream.Should().NotBeNull();
-		stream.CanRead.Should().BeTrue("stream should not be disposed by the gateway");
+		response.Should().NotBeNull();
+		response.Stream.CanRead.Should().BeTrue("stream should not be disposed by the gateway");
 		// Read the entire stream to verify it works
-		using var reader = new StreamReader(stream);
+		using var reader = new StreamReader(response.Stream);
 		var content = await reader.ReadToEndAsync(TestContext.Current.CancellationToken);
 		content.Should().NotBeEmpty();
@@ -228,14 +228,14 @@ public async Task LlmGatewayGatewayAllowsMultipleReadsFromStream()
 		var request = new AskAiRequest("Test", null);
-		// Act - get the stream and read it in chunks
-		var stream = await gateway.AskAi(request, TestContext.Current.CancellationToken);
+		// Act - get the response and read it in chunks
+		var response = await gateway.AskAi(request, TestContext.Current.CancellationToken);
 		var chunks = new List<string>();
 		var buffer = new byte[16]; // Small buffer to force multiple reads
 		int bytesRead;
-		while ((bytesRead = await stream.ReadAsync(buffer.AsMemory(0, buffer.Length), TestContext.Current.CancellationToken)) > 0)
+		while ((bytesRead = await response.Stream.ReadAsync(buffer.AsMemory(0, buffer.Length), TestContext.Current.CancellationToken)) > 0)
 		{
 			var chunk = Encoding.UTF8.GetString(buffer, 0, bytesRead);
 			chunks.Add(chunk);
@@ -270,16 +270,16 @@ public async Task AgentBuilderGatewayUsesResponseHeadersReadForStreaming()
 		var request = new AskAiRequest("Test", null);
 		// Act
-		var stream = await gateway.AskAi(request, TestContext.Current.CancellationToken);
+		var response = await gateway.AskAi(request, TestContext.Current.CancellationToken);
 		// Assert
-		stream.Should().NotBeNull();
-		stream.CanRead.Should().BeTrue();
+		response.Should().NotBeNull();
+		response.Stream.CanRead.Should().BeTrue();
 		// The fact that we can immediately read from the stream indicates
 		// that ResponseHeadersRead was used (otherwise it would buffer)
 		var buffer = new byte[10];
-		var bytesRead = await stream.ReadAsync(buffer.AsMemory(0, buffer.Length), TestContext.Current.CancellationToken);
+		var bytesRead = await response.Stream.ReadAsync(buffer.AsMemory(0, buffer.Length), TestContext.Current.CancellationToken);
 		bytesRead.Should().BeGreaterThan(0, "stream should be readable immediately");
 	}
 }
diff --git a/tests-integration/Elastic.Documentation.Api.IntegrationTests/EuidEnrichmentIntegrationTests.cs b/tests-integration/Elastic.Documentation.Api.IntegrationTests/EuidEnrichmentIntegrationTests.cs
index 3dfd28bd3..069a02f06 100644
--- a/tests-integration/Elastic.Documentation.Api.IntegrationTests/EuidEnrichmentIntegrationTests.cs
+++ b/tests-integration/Elastic.Documentation.Api.IntegrationTests/EuidEnrichmentIntegrationTests.cs
@@ -51,13 +51,13 @@ public async Task AskAiEndpointPropagatatesEuidToAllSpansAndLogs()
 		using var factory = ApiWebApplicationFactory.WithMockedServices(services =>
 		{
 			// Mock IAskAiGateway to avoid external AI service calls
-			var mockAskAiGateway = A.Fake<IAskAiGateway<Stream>>();
+			var mockAskAiGateway = A.Fake<IAskAiGateway>();
 			A.CallTo(() => mockAskAiGateway.AskAi(A<AskAiRequest>._, A<Cancel>._))
 				.ReturnsLazily(() =>
 				{
 					var stream = new MemoryStream(Encoding.UTF8.GetBytes("data: test\n\n"));
 					mockStreams.Add(stream);
-					return Task.FromResult<Stream>(stream);
+					return Task.FromResult(new AskAiGatewayResponse(stream, GeneratedConversationId: Guid.NewGuid()));
 				});
 			services.AddSingleton(mockAskAiGateway);
@@ -65,8 +65,8 @@ public async Task AskAiEndpointPropagatatesEuidToAllSpansAndLogs()
 			var mockTransformer = A.Fake<IStreamTransformer>();
 			A.CallTo(() => mockTransformer.AgentProvider).Returns("test-provider");
 			A.CallTo(() => mockTransformer.AgentId).Returns("test-agent");
-			A.CallTo(() => mockTransformer.TransformAsync(A<Stream>._, A<string?>._, A<Activity?>._, A<Cancel>._))
-				.ReturnsLazily((Stream s, string? _, Activity? activity, Cancel _) =>
+			A.CallTo(() => mockTransformer.TransformAsync(A<Stream>._, A<Guid?>._, A<Activity?>._, A<Cancel>._))
+				.ReturnsLazily((Stream s, Guid? _, Activity? activity, Cancel _) =>
 				{
 					// Dispose the activity if provided (simulating what the real transformer does)
 					activity?.Dispose();
diff --git a/tests/Elastic.Documentation.Api.Infrastructure.Tests/Adapters/AskAi/StreamTransformerTests.cs b/tests/Elastic.Documentation.Api.Infrastructure.Tests/Adapters/AskAi/StreamTransformerTests.cs
index 384ffaaa2..ad531c32c 100644
--- a/tests/Elastic.Documentation.Api.Infrastructure.Tests/Adapters/AskAi/StreamTransformerTests.cs
+++ b/tests/Elastic.Documentation.Api.Infrastructure.Tests/Adapters/AskAi/StreamTransformerTests.cs
@@ -76,8 +76,8 @@ public async Task TransformAsyncWithRealAgentBuilderPayloadParsesAllEventTypes()
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
-		// Act
-		var outputStream = await _transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		// Act - No generated ID (Agent Builder handles IDs in stream)
+		var outputStream = await _transformer.TransformAsync(inputStream, generatedConversationId: null, null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
 		// Assert
@@ -85,7 +85,7 @@ public async Task TransformAsyncWithRealAgentBuilderPayloadParsesAllEventTypes()
 		// In production, real SSE streams stay open, so this isn't an issue
 		events.Should().HaveCountGreaterOrEqualTo(7);
-		// Verify we got the key events
+		// Verify we got the key events (ConversationStart comes from stream for Agent Builder)
 		events.Should().ContainSingle(e => e is AskAiEvent.ConversationStart);
 		events.Should().ContainSingle(e => e is AskAiEvent.Reasoning);
 		events.Should().ContainSingle(e => e is AskAiEvent.SearchToolCall);
@@ -139,7 +139,7 @@ public async Task TransformAsyncWithKeepAliveCommentsSkipsThem()
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
 		// Act
-		var outputStream = await _transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		var outputStream = await _transformer.TransformAsync(inputStream, generatedConversationId: null, null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
 		// Assert - Should have at least 1 event (round_complete might not be written in time)
@@ -162,7 +162,7 @@ public async Task TransformAsyncWithMultilineDataFieldsAccumulatesCorrectly()
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
 		// Act
-		var outputStream = await _transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		var outputStream = await _transformer.TransformAsync(inputStream, generatedConversationId: null, null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
@@ -210,21 +210,18 @@ public async Task TransformAsyncWithRealLlmGatewayPayloadParsesAllEventTypes()
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
-		// Act
-		var outputStream = await _transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		// Act - Simulate new conversation to get ConversationStart event
+		var testConversationId = Guid.NewGuid().ToString();
+		var outputStream = await _transformer.TransformAsync(inputStream, Guid.Parse(testConversationId), null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
 		// Assert
 		events.Should().HaveCount(7);
-		// Event 1: agent_start -> ConversationStart (with generated UUID)
+		// Event 1: ConversationStart (emitted by transformer for new conversation)
 		events[0].Should().BeOfType<AskAiEvent.ConversationStart>();
 		var convStart = events[0] as AskAiEvent.ConversationStart;
-		convStart!.ConversationId.Should().NotBeNullOrEmpty();
-
-		// convStart!.ConversationId.Should().Be("1");
-
-		_ = Guid.TryParse(convStart.ConversationId, out _).Should().BeTrue();
+		convStart!.ConversationId.Should().Be(testConversationId);
 		// Event 2: ai_message_chunk (first)
@@ -279,11 +276,12 @@ public async Task TransformAsyncWithEmptyDataLinesSkipsThem()
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
-		// Act
-		var outputStream = await _transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		// Act - Simulate new conversation
+		var testConversationId = Guid.NewGuid().ToString();
+		var outputStream = await _transformer.TransformAsync(inputStream, Guid.Parse(testConversationId), null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
-		// Assert - Should only have 2 events
+		// Assert - Should only have 2 events (ConversationStart + ConversationEnd)
 		events.Should().HaveCount(2);
 		events[0].Should().BeOfType<AskAiEvent.ConversationStart>();
 		events[1].Should().BeOfType<AskAiEvent.ConversationEnd>();
@@ -304,11 +302,12 @@ public async Task TransformAsyncSkipsModelLifecycleEvents()
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
-		// Act
-		var outputStream = await _transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		// Act - Simulate new conversation
+		var testConversationId = Guid.NewGuid().ToString();
+		var outputStream = await _transformer.TransformAsync(inputStream, Guid.Parse(testConversationId), null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
-		// Assert - Should only have the message chunk, model events skipped
+		// Assert - Should only have the ConversationStart and message chunk, model events skipped
 		events.Should().HaveCount(2);
 		events[0].Should().BeOfType<AskAiEvent.ConversationStart>();
 		events[1].Should().BeOfType();
@@ -354,33 +353,26 @@ public static TheoryData<string, IStreamTransformer, string> StreamTransformerTe
 	[Theory]
 	[MemberData(nameof(StreamTransformerTestCases))]
-	public async Task TransformAsyncWhenConversationIdIsNullEmitsConversationStartEvent(
+	public async Task TransformAsyncWhenIsNewConversationEmitsConversationStartEvent(
 		string transformerName,
 		IStreamTransformer transformer,
 		string sseData)
 	{
 		// Arrange
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
+		var testConversationId = Guid.NewGuid().ToString();
-		// Act - Pass null conversationId to simulate new conversation
-		var outputStream = await transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		// Act - Pass isNewConversation: true to simulate new conversation
+		var outputStream = await transformer.TransformAsync(inputStream, Guid.Parse(testConversationId), null, CancellationToken.None);
 		var events = await StreamTransformerTestHelpers.ParseAskAiEventsAsync(outputStream);
 		// Assert - Should have ConversationStart event
 		events.Should().ContainSingle(e => e is AskAiEvent.ConversationStart,
-			$"{transformerName} should emit ConversationStart when conversationId is null");
+			$"{transformerName} should emit ConversationStart when isNewConversation is true");
 		var conversationStart = events.OfType<AskAiEvent.ConversationStart>().First();
 		conversationStart.ConversationId.Should().NotBeNullOrEmpty(
 			$"{transformerName} should have a non-empty conversation ID in ConversationStart event");
-
-		// For LlmGateway, when conversationId is null, we generate a pure GUID
-		// For AgentBuilder, the conversation ID comes from the SSE event and may have a different format
-		if (transformerName == "LlmGatewayStreamTransformer")
-		{
-			Guid.TryParse(conversationStart.ConversationId, out _).Should().BeTrue(
-				$"{transformerName} should generate a valid GUID as conversation ID when conversationId is null");
-		}
 	}
 	[Theory]
@@ -392,9 +384,10 @@ public async Task TransformAsyncConversationStartEventHasValidTimestamp(
 	{
 		// Arrange
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
+		var testConversationId = Guid.NewGuid().ToString();
 		// Act
-		var outputStream = await transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		var outputStream = await transformer.TransformAsync(inputStream, Guid.Parse(testConversationId), null, CancellationToken.None);
 		var events = new List<AskAiEvent>();
 		var reader = PipeReader.Create(outputStream);
 		await foreach (var sseEvent in SseParser.ParseAsync(reader, CancellationToken.None))
@@ -422,9 +415,10 @@ public async Task TransformAsyncConversationStartEventHasValidId(
 	{
 		// Arrange
 		var inputStream = new MemoryStream(Encoding.UTF8.GetBytes(sseData));
+		var testConversationId = Guid.NewGuid().ToString();
 		// Act
-		var outputStream = await transformer.TransformAsync(inputStream, null, null, CancellationToken.None);
+		var outputStream = await transformer.TransformAsync(inputStream, Guid.Parse(testConversationId), null, CancellationToken.None);
 		var events = new List<AskAiEvent>();
 		var reader = PipeReader.Create(outputStream);
 		await foreach (var sseEvent in SseParser.ParseAsync(reader, CancellationToken.None))