Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/MaIN.Core.UnitTests/AgentContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ public async Task ProcessAsync_WithStringMessage_ShouldReturnChatResult()
.ReturnsAsync(chat);

_mockAgentService
.Setup(s => s.Process(It.IsAny<Chat>(), _agentContext.GetAgentId(), It.IsAny<Knowledge>(), It.IsAny<bool>()))
.Setup(s => s.Process(It.IsAny<Chat>(), _agentContext.GetAgentId(), It.IsAny<Knowledge>(), It.IsAny<bool>(), null))
.ReturnsAsync(new Chat {
Model = "test-model",
Name = "test",
Expand Down
2 changes: 1 addition & 1 deletion src/MaIN.Core.UnitTests/FlowContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public async Task ProcessAsync_WithStringMessage_ShouldReturnChatResult()
.ReturnsAsync(chat);

_mockAgentService
.Setup(s => s.Process(It.IsAny<Chat>(), firstAgent.Id, It.IsAny<Knowledge>(), It.IsAny<bool>()))
.Setup(s => s.Process(It.IsAny<Chat>(), firstAgent.Id, It.IsAny<Knowledge>(), It.IsAny<bool>(), null))
.ReturnsAsync(new Chat {
Model = "test-model",
Name = "test",
Expand Down
2 changes: 1 addition & 1 deletion src/MaIN.Core/.nuspec
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<package>
<metadata>
<id>MaIN.NET</id>
<version>0.7.2</version>
<version>0.7.3</version>
<authors>Wisedev</authors>
<owners>Wisedev</owners>
<icon>favicon.png</icon>
Expand Down
12 changes: 6 additions & 6 deletions src/MaIN.Core/Hub/Contexts/AgentContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ public async Task<ChatResult> ProcessAsync(Chat chat, bool translate = false)
};
}

public async Task<ChatResult> ProcessAsync(string message, bool translate = false)
public async Task<ChatResult> ProcessAsync(string message, bool translate = false, Func<LLMTokenValue, Task>? callback = null)
{
if (_knowledge == null)
{
Expand All @@ -240,7 +240,7 @@ public async Task<ChatResult> ProcessAsync(string message, bool translate = fals
Type = MessageType.LocalLLM,
Time = DateTime.Now
});
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate);
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, callback);
var messageResult = result.Messages.LastOrDefault()!;
return new ChatResult()
{
Expand All @@ -251,15 +251,15 @@ public async Task<ChatResult> ProcessAsync(string message, bool translate = fals
};
}

public async Task<ChatResult> ProcessAsync(Message message, bool translate = false)
public async Task<ChatResult> ProcessAsync(Message message, bool translate = false, Func<LLMTokenValue, Task>? callback = null)
{
if (_knowledge == null)
{
LoadExistingKnowledgeIfExists();
}
var chat = await _agentService.GetChatByAgent(_agent.Id);
chat.Messages.Add(message);
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate);
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, callback);
var messageResult = result.Messages.LastOrDefault()!;
return new ChatResult()
{
Expand All @@ -270,7 +270,7 @@ public async Task<ChatResult> ProcessAsync(Message message, bool translate = fal
};
}

public async Task<ChatResult> ProcessAsync(IEnumerable<Message> messages, bool translate = false)
public async Task<ChatResult> ProcessAsync(IEnumerable<Message> messages, bool translate = false, Func<LLMTokenValue, Task>? callback = null)
{
if (_knowledge == null)
{
Expand All @@ -279,7 +279,7 @@ public async Task<ChatResult> ProcessAsync(IEnumerable<Message> messages, bool t
var chat = await _agentService.GetChatByAgent(_agent.Id);
chat.Messages.Clear();
chat.Messages.AddRange(messages);
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate);
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, callback);
var messageResult = result.Messages.LastOrDefault()!;
return new ChatResult()
{
Expand Down
5 changes: 4 additions & 1 deletion src/MaIN.Services/Services/Abstract/IAgentService.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
using MaIN.Domain.Entities;
using MaIN.Domain.Entities.Agents;
using MaIN.Domain.Entities.Agents.Knowledge;
using MaIN.Domain.Models;
using MaIN.Services.Services.LLMService;

namespace MaIN.Services.Services.Abstract;

public interface IAgentService
{
Task<Chat> Process(Chat chat, string agentId, Knowledge? knowledge, bool translatePrompt = false);
Task<Chat> Process(Chat chat, string agentId, Knowledge? knowledge, bool translatePrompt = false,
Func<LLMTokenValue, Task>? callback = null);
Task<Agent> CreateAgent(Agent agent, bool flow = false, bool interactiveResponse = false,
InferenceParams? inferenceParams = null, MemoryParams? memoryParams = null, bool disableCache = false);
Task<Chat> GetChatByAgent(string agentId);
Expand Down
2 changes: 2 additions & 0 deletions src/MaIN.Services/Services/Abstract/IStepProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using MaIN.Domain.Entities;
using MaIN.Domain.Entities.Agents.Knowledge;
using MaIN.Domain.Models;
using MaIN.Infrastructure.Models;
using Microsoft.Extensions.Logging;

Expand All @@ -11,6 +12,7 @@ Task<Chat> ProcessSteps(AgentContextDocument context,
AgentDocument agent,
Knowledge? knowledge,
Chat chat,
Func<LLMTokenValue, Task>? callback,
Func<string, string, string?, string, string, Task> notifyProgress,
Func<Chat, Task> updateChat,
ILogger logger);
Expand Down
10 changes: 9 additions & 1 deletion src/MaIN.Services/Services/AgentService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
using MaIN.Domain.Entities;
using MaIN.Domain.Entities.Agents;
using MaIN.Domain.Entities.Agents.Knowledge;
using MaIN.Domain.Models;
using MaIN.Infrastructure.Repositories.Abstract;
using MaIN.Services.Constants;
using MaIN.Services.Mappers;
using MaIN.Services.Services.Abstract;
using MaIN.Services.Services.ImageGenServices;
using MaIN.Services.Services.LLMService;
using MaIN.Services.Services.LLMService.Factory;
using MaIN.Services.Services.Models.Commands;
using MaIN.Services.Services.Steps.Commands;
Expand All @@ -28,7 +30,12 @@ public class AgentService(
MaINSettings maInSettings)
: IAgentService
{
public async Task<Chat> Process(Chat chat, string agentId, Knowledge? knowledge, bool translatePrompt = false)
public async Task<Chat> Process(
Chat chat,
string agentId,
Knowledge? knowledge,
bool translatePrompt = false,
Func<LLMTokenValue, Task>? callback = null)
{
var agent = await agentRepository.GetAgentById(agentId);
if (agent == null)
Expand All @@ -46,6 +53,7 @@ await notificationService.DispatchNotification(
agent,
knowledge,
chat,
callback,
async (status, id, progress, behaviour, details) =>
{
await notificationService.DispatchNotification(
Expand Down
2 changes: 2 additions & 0 deletions src/MaIN.Services/Services/Models/Commands/AnswerCommand.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using MaIN.Domain.Entities;
using MaIN.Domain.Entities.Agents.Knowledge;
using MaIN.Domain.Models;
using MaIN.Services.Constants;
using MaIN.Services.Services.Models.Commands.Base;
using MaIN.Services.Services.Steps.Commands;
Expand All @@ -14,4 +15,5 @@ public class AnswerCommand : BaseCommand, ICommand<Message?>
public KnowledgeUsage KnowledgeUsage { get; init; }
public Knowledge? Knowledge { get; init; }
public string CommandName => "ANSWER";
public Func<LLMTokenValue, Task>? Callback { get; set; }
}
2 changes: 2 additions & 0 deletions src/MaIN.Services/Services/Models/StepContext.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using MaIN.Domain.Entities;
using MaIN.Domain.Entities.Agents.Knowledge;
using MaIN.Domain.Models;
using MaIN.Infrastructure.Models;

namespace MaIN.Services.Services.Models;
Expand All @@ -16,4 +17,5 @@ public class StepContext
public required Func<Chat, Task> UpdateChat { get; init; }
public required string StepName { get; init; }
public Knowledge? Knowledge { get; set; }
public Func<LLMTokenValue, Task>? Callback { get; set; }
}
6 changes: 6 additions & 0 deletions src/MaIN.Services/Services/StepProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using MaIN.Domain.Entities;
using MaIN.Domain.Entities.Agents.Knowledge;
using MaIN.Domain.Models;
using MaIN.Infrastructure.Models;
using MaIN.Services.Services.Abstract;
using MaIN.Services.Services.Models;
Expand Down Expand Up @@ -29,14 +30,18 @@ public async Task<Chat> ProcessSteps(
AgentDocument agent,
Knowledge? knowledge,
Chat chat,
Func<LLMTokenValue, Task>? callback,
Func<string, string, string?, string, string, Task> notifyProgress,
Func<Chat, Task> updateChat,
ILogger logger)
{
Message redirectMessage = chat.Messages.Last();
var stepCount = 0;
var tagsToReplaceWithFilter = new List<string>();
foreach (var step in context.Steps!)
{
stepCount++;
var lastStep = stepCount.Equals(context.Steps.Count);
logger.LogInformation("Processing step: {Step} on agent {agent}", step, agent.Name);

var (stepName, arguments) = ParseStep(step);
Expand All @@ -53,6 +58,7 @@ public async Task<Chat> ProcessSteps(
McpConfig = context.McpConfig,
NotifyProgress = notifyProgress,
UpdateChat = updateChat,
Callback = lastStep ? callback : null,
StepName = stepName
};

Expand Down
3 changes: 2 additions & 1 deletion src/MaIN.Services/Services/Steps/AnswerStepHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public async Task<StepResult> Handle(StepContext context)
: useMemory ? KnowledgeUsage.UseMemory
: KnowledgeUsage.None,
Knowledge = context.Knowledge,
AgentId = context.Agent.Id
AgentId = context.Agent.Id,
Callback = context.Callback
};

var answerResponse = await commandDispatcher.DispatchAsync(answerCommand);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class AnswerCommandHandler(
result = command.Chat.Visual
? await imageGenService!.Send(command.Chat)
: await llmService.Send(command.Chat,
new ChatRequestOptions { InteractiveUpdates = command.Chat.Interactive });
new ChatRequestOptions { InteractiveUpdates = command.Chat.Interactive, TokenCallback = command.Callback });

return result!.Message;
}
Expand Down
Loading