6 changes: 3 additions & 3 deletions files/shinkai_welcome.md
@@ -16,9 +16,9 @@ At its core, an AI agent in Shinkai starts with a base AI model (like those from
2. **Give Instructions:** Write a "System Prompt" detailing how you want your agent to act, what knowledge it should focus on, or the persona it should adopt.
3. **Equip with Tools:** Grant your agent specific skills by enabling "Tools" (you can build your own with our specialized AI, download them from the AI Store, or create them manually). A minimal sketch of these three ingredients appears after the diagram below.

┌──────────────────┐ ┌────────────────────────┐ ┌──────────────────┐
│ 1. Pick a Model │→ │ 2. Prompt + Add Tools │→ │ 3. Launch Agent │
└──────────────────┘ └────────────────────────┘ └──────────────────┘
* ┌──────────────────┐ ┌────────────────────────┐ ┌──────────────────┐
* │ 1. Pick a Model │→ │ 2. Prompt + Add Tools │→ │ 3. Launch Agent │
* └──────────────────┘ └────────────────────────┘ └──────────────────┘
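
The three steps above boil down to a model choice, a system prompt, and a set of enabled tools. The sketch below is a hypothetical Rust illustration of that shape only; `AgentDefinition` and its field names are invented here and are not Shinkai's actual configuration API.

```rust
// Hypothetical sketch: the real agent type lives inside the Shinkai codebase.
struct AgentDefinition {
    base_model: String,          // 1. Pick a Model (e.g. a provider/model id)
    system_prompt: String,       // 2. Give Instructions
    enabled_tools: Vec<String>,  // 3. Equip with Tools (tool router keys)
}

fn example_agent() -> AgentDefinition {
    AgentDefinition {
        base_model: "openai/gpt-4o-mini".to_string(),
        system_prompt: "You are a concise research assistant.".to_string(),
        enabled_tools: vec!["web_search".to_string(), "pdf_reader".to_string()],
    }
}
```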

## Tools: Supercharging Your Agents

45 changes: 14 additions & 31 deletions shinkai-bin/shinkai-node/src/llm_provider/job_manager.rs
@@ -777,37 +777,20 @@ impl JobManager {
// Remove from the jobs map
self.jobs.lock().await.remove(&job_id);

// Remove from the database
if let Some(db_arc) = self.db.upgrade() {
// Remove job from database
if let Err(e) = db_arc.remove_job(&job_id) {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Failed to delete job {} from database: {}", job_id, e),
);
return Err(LLMProviderError::ShinkaiDB(e));
}

// Remove from both job queues
let _ = self.job_queue_manager_normal.lock().await.dequeue(&job_id).await;
let _ = self.job_queue_manager_immediate.lock().await.dequeue(&job_id).await;

shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Info,
&format!(
"Successfully killed job with conversation inbox: {}",
conversation_inbox_name
),
);

Ok(job_id)
} else {
Err(LLMProviderError::DatabaseError(
"Failed to upgrade database reference".to_string(),
))
}
// Remove from both job queues
let _ = self.job_queue_manager_normal.lock().await.dequeue(&job_id).await;
let _ = self.job_queue_manager_immediate.lock().await.dequeue(&job_id).await;

shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Info,
&format!(
"Successfully killed job with conversation inbox: {}",
conversation_inbox_name
),
);

Ok(job_id)
}
}

17 changes: 1 addition & 16 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
@@ -18,7 +18,7 @@ use shinkai_message_primitives::schemas::job_config::JobConfig;
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::{LLMProviderInterface, OpenAI};
use shinkai_message_primitives::schemas::prompts::Prompt;
use shinkai_message_primitives::schemas::ws_types::{
ToolMetadata, ToolStatus, ToolStatusType, WSMessageType, WSMetadata, WSUpdateHandler, WidgetMetadata,
ToolMetadata, ToolStatus, ToolStatusType, WSMessageType, WSMetadata, WSUpdateHandler, WidgetMetadata
};
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::WSTopic;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
@@ -151,21 +151,6 @@ impl LLMService for OpenAI {
}
}

if let Some(ref msg_id) = tracing_message_id {
let network_info = json!({
"url": url,
"payload": payload_log
});
if let Err(e) = db.add_tracing(
msg_id,
inbox_name.as_ref().map(|i| i.get_value()).as_deref(),
"llm_network_request",
&network_info,
) {
eprintln!("failed to add network request trace: {:?}", e);
}
}

if is_stream {
handle_streaming_response(
client,
43 changes: 35 additions & 8 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
@@ -17,11 +17,7 @@ impl Node {
let listen_address_clone = self.listen_address;
let libp2p_manager_clone = self.libp2p_manager.clone();
tokio::spawn(async move {
let _ = Self::ping_all(
listen_address_clone,
libp2p_manager_clone,
)
.await;
let _ = Self::ping_all(listen_address_clone, libp2p_manager_clone).await;
});
}
NodeCommand::GetPublicKeys(sender) => {
@@ -797,6 +793,28 @@ impl Node {
let _ = Node::v2_remove_job(db_clone, bearer, job_id, res).await;
});
}
NodeCommand::V2ApiKillJob {
bearer,
conversation_inbox_name,
res,
} => {
let db_clone = self.db.clone();
let job_manager_clone = self.job_manager.clone().unwrap();
let ws_manager_clone = self.ws_manager.clone();
let llm_stopper_clone = self.llm_stopper.clone();
tokio::spawn(async move {
let _ = Node::v2_api_kill_job(
db_clone,
job_manager_clone,
ws_manager_clone,
llm_stopper_clone,
bearer,
conversation_inbox_name,
res,
)
.await;
});
}
NodeCommand::V2ApiVecFSRetrievePathSimplifiedJson { bearer, payload, res } => {
let db_clone = Arc::clone(&self.db);

@@ -1245,11 +1263,16 @@ impl Node {
let _ = Node::v2_api_get_shinkai_tool_metadata(db_clone, bearer, tool_router_key, res).await;
});
}
NodeCommand::V2ApiGetToolWithOffering { bearer, tool_key_name, res } => {
NodeCommand::V2ApiGetToolWithOffering {
bearer,
tool_key_name,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_name_clone = self.node_name.clone();
tokio::spawn(async move {
let _ = Node::v2_api_get_tool_with_offering(db_clone, node_name_clone, bearer, tool_key_name, res).await;
let _ = Node::v2_api_get_tool_with_offering(db_clone, node_name_clone, bearer, tool_key_name, res)
.await;
});
}
NodeCommand::V2ApiGetToolsWithOfferings { bearer, res } => {
@@ -1669,7 +1692,11 @@ impl Node {
let _ = Node::v2_api_get_job_scope(db_clone, bearer, job_id, res).await;
});
}
NodeCommand::V2ApiGetMessageTraces { bearer, message_id, res } => {
NodeCommand::V2ApiGetMessageTraces {
bearer,
message_id,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_get_message_traces(db_clone, bearer, message_id, res).await;
100 changes: 94 additions & 6 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs
@@ -25,7 +25,7 @@ use tokio::sync::Mutex;
use x25519_dalek::PublicKey as EncryptionPublicKey;

use crate::{
llm_provider::job_manager::JobManager, managers::IdentityManager, network::{node_error::NodeError, Node}
llm_provider::{job_manager::JobManager, llm_stopper::LLMStopper}, managers::IdentityManager, network::{node_error::NodeError, ws_manager::WebSocketManager, Node}
};

use x25519_dalek::StaticSecret as EncryptionStaticKey;
@@ -1605,6 +1605,98 @@ impl Node {
Ok(())
}

pub async fn v2_api_kill_job(
db: Arc<SqliteManager>,
job_manager: Arc<Mutex<JobManager>>,
ws_manager: Option<Arc<Mutex<WebSocketManager>>>,
llm_stopper: Arc<LLMStopper>,
bearer: String,
conversation_inbox_name: String,
res: Sender<Result<SendResponseBody, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Kill the job and capture necessary info
let (job_id, identity_sk, node_name) = {
let mut jm = job_manager.lock().await;
match jm.kill_job_by_conversation_inbox_name(&conversation_inbox_name).await {
Ok(job_id) => {
let id_sk = clone_signature_secret_key(&jm.identity_secret_key);
let node_name = jm.node_profile_name.clone();
(job_id, id_sk, node_name)
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: err.to_string(),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
}
};

// Obtain partial assistant message from the WebSocket manager
let partial_text = if let Some(manager) = ws_manager.as_ref() {
manager
.lock()
.await
.get_fragment(&conversation_inbox_name)
.await
.unwrap_or_default()
} else {
String::new()
};

// Signal the LLM to stop processing
llm_stopper.stop(&conversation_inbox_name);

// Insert an assistant message with the partial text
let ai_message = ShinkaiMessageBuilder::job_message_from_llm_provider(
job_id.clone(),
partial_text,
Vec::new(),
None,
identity_sk,
node_name.node_name.clone(),
node_name.node_name.clone(),
)
.map_err(|_| NodeError {
message: "Failed to build message".to_string(),
})?;

if let Err(err) = db.add_message_to_job_inbox(&job_id, &ai_message, None, None).await {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to add message: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}

if let Some(manager) = ws_manager {
manager.lock().await.clear_fragment(&conversation_inbox_name).await;
}

// Clear any stop signal set for this job
llm_stopper.reset(&conversation_inbox_name);

let _ = res
.send(Ok(SendResponseBody {
status: "success".to_string(),
message: "Job killed successfully".to_string(),
data: None,
}))
.await;

Ok(())
}
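
// Editor's sketch (not part of this diff): v2_api_kill_job relies on the LLMStopper
// contract: `stop` flags the in-flight generation for this inbox so the streaming loop
// can bail out, and `reset` clears the flag once the partial message has been persisted.
// The minimal, self-contained version below illustrates that contract; the internal
// layout and the `should_stop` check are assumptions, only `stop` and `reset` appear
// in this change.
mod llm_stopper_sketch {
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    #[derive(Default)]
    pub struct StopperSketch {
        flags: Mutex<HashMap<String, AtomicBool>>,
    }

    impl StopperSketch {
        /// Signal the generation loop for `inbox` to halt at its next flag check.
        pub fn stop(&self, inbox: &str) {
            let mut flags = self.flags.lock().unwrap();
            flags
                .entry(inbox.to_string())
                .or_insert_with(|| AtomicBool::new(false))
                .store(true, Ordering::SeqCst);
        }

        /// Polled between streamed chunks (assumed name, not taken from this diff).
        pub fn should_stop(&self, inbox: &str) -> bool {
            let flags = self.flags.lock().unwrap();
            flags.get(inbox).map(|f| f.load(Ordering::SeqCst)).unwrap_or(false)
        }

        /// Clear the stop flag so the job can accept new messages afterwards.
        pub fn reset(&self, inbox: &str) {
            let flags = self.flags.lock().unwrap();
            if let Some(flag) = flags.get(inbox) {
                flag.store(false, Ordering::SeqCst);
            }
        }
    }
}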

pub async fn v2_export_messages_from_inbox(
db: Arc<SqliteManager>,
bearer: String,
@@ -1761,11 +1853,7 @@ impl Node {

for messages in v2_chat_messages {
for message in messages {
let role = if message
.sender_subidentity
.to_lowercase()
.contains("/agent/")
{
let role = if message.sender_subidentity.to_lowercase().contains("/agent/") {
"assistant"
} else {
"user"
25 changes: 25 additions & 0 deletions shinkai-bin/shinkai-node/src/network/ws_manager.rs
@@ -42,6 +42,7 @@ pub struct WebSocketManager {
identity_manager_trait: Arc<Mutex<dyn IdentityManagerTrait + Send>>,
encryption_secret_key: EncryptionStaticKey,
message_queue: MessageQueue,
message_fragments: Arc<Mutex<HashMap<String, String>>>,
}

impl Clone for WebSocketManager {
@@ -55,6 +56,7 @@ impl Clone for WebSocketManager {
identity_manager_trait: Arc::clone(&self.identity_manager_trait),
encryption_secret_key: self.encryption_secret_key.clone(),
message_queue: Arc::clone(&self.message_queue),
message_fragments: Arc::clone(&self.message_fragments),
}
}
}
@@ -87,6 +89,7 @@ impl WebSocketManager {
identity_manager_trait,
encryption_secret_key,
message_queue: Arc::new(Mutex::new(VecDeque::new())),
message_fragments: Arc::new(Mutex::new(HashMap::new())),
}));

let manager_clone = Arc::clone(&manager);
@@ -493,6 +496,16 @@ impl WebSocketManager {
}
}
}

pub async fn get_fragment(&self, inbox: &str) -> Option<String> {
let fragments = self.message_fragments.lock().await;
fragments.get(inbox).cloned()
}

pub async fn clear_fragment(&self, inbox: &str) {
let mut fragments = self.message_fragments.lock().await;
fragments.remove(inbox);
}
}

#[async_trait]
@@ -505,6 +518,18 @@ impl WSUpdateHandler for WebSocketManager {
metadata: WSMessageType,
is_stream: bool,
) {
if is_stream && matches!(topic, WSTopic::Inbox) {
let mut fragments = self.message_fragments.lock().await;
let entry = fragments.entry(subtopic.clone()).or_default();
entry.push_str(&update);

if let WSMessageType::Metadata(meta) = &metadata {
if meta.is_done {
fragments.remove(&subtopic);
}
}
}

let mut queue = self.message_queue.lock().await;
queue.push_back((topic, subtopic, update, metadata, is_stream));
}
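
// Editor's sketch (not part of this diff): the `message_fragments` map added above
// accumulates streamed chunks per inbox so that a kill request can salvage whatever the
// model has produced so far (`get_fragment`) before the entry is dropped, either
// explicitly via `clear_fragment` or automatically when the stream reports it is done.
// The standalone reproduction below shows that accumulate/salvage/clear cycle with the
// async locking stripped out for brevity.
mod fragment_lifecycle_sketch {
    use std::collections::HashMap;

    #[derive(Default)]
    pub struct FragmentStore {
        fragments: HashMap<String, String>,
    }

    impl FragmentStore {
        /// Mirrors the `is_stream` branch of `queue_message`: append each chunk and
        /// drop the entry once the stream signals completion.
        pub fn on_stream_update(&mut self, inbox: &str, chunk: &str, is_done: bool) {
            self.fragments.entry(inbox.to_string()).or_default().push_str(chunk);
            if is_done {
                self.fragments.remove(inbox);
            }
        }

        pub fn get_fragment(&self, inbox: &str) -> Option<String> {
            self.fragments.get(inbox).cloned()
        }

        pub fn clear_fragment(&mut self, inbox: &str) {
            self.fragments.remove(inbox);
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn salvages_partial_text_on_kill() {
            let mut store = FragmentStore::default();
            store.on_stream_update("inbox::job1", "The answer is ", false);
            store.on_stream_update("inbox::job1", "42", false);

            // A kill arrives mid-stream: salvage the partial text, then clear it.
            assert_eq!(store.get_fragment("inbox::job1").as_deref(), Some("The answer is 42"));
            store.clear_fragment("inbox::job1");
            assert!(store.get_fragment("inbox::job1").is_none());
        }
    }
}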