39 changes: 36 additions & 3 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/claude.rs
@@ -225,14 +225,47 @@ async fn handle_streaming_response(
     session_id: String,
     tools: Option<Vec<JsonValue>>,
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
-    let res = client
+    // Use tokio::select! to allow cancellation during the initial request phase
+    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
+    let response_fut = client
         .post(url)
         .header("anthropic-version", "2023-06-01")
         .header("x-api-key", api_key)
         .header("content-type", "application/json")
         .json(&payload)
-        .send()
-        .await?;
+        .send();
+    let mut response_fut = Box::pin(response_fut);
+
+    // Wait for response or cancellation
+    let res = loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                if let Some(ref inbox_name) = inbox_name {
+                    if llm_stopper.should_stop(&inbox_name.to_string()) {
+                        eprintln!("LLM job stopped by user request before response arrived");
+                        llm_stopper.reset(&inbox_name.to_string());
+
+                        // Send WS message indicating the job is done
+                        let _ = send_ws_update(
+                            &ws_manager_trait,
+                            Some(inbox_name.clone()),
+                            &session_id,
+                            "".to_string(),
+                            false,
+                            true,
+                            Some("Stopped by user request".to_string()),
+                        )
+                        .await;
+
+                        return Ok(LLMInferenceResponse::new("".to_string(), None, json!({}), Vec::new(), Vec::new(), None));
+                    }
+                }
+            },
+            response = &mut response_fut => {
+                break response?;
+            }
+        }
+    };
+
     // Check if it's an error response
     if !res.status().is_success() {
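All six providers get the same cancellation shape: the `reqwest` send future is boxed and pinned, then raced inside `tokio::select!` against a 500ms interval that polls `llm_stopper`. A minimal, self-contained sketch of that pattern, with `tokio::time::sleep` standing in for the HTTP request and an `AtomicBool` standing in for `LLMStopper` (both stand-ins are illustrative, not from this PR):

```rust
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
use tokio::time::{interval, sleep, Duration};

// Race a pending "request" against a periodic stop-flag check, as the
// handlers above do. Returns None if cancelled before the response arrives.
async fn await_or_cancel(stop_flag: Arc<AtomicBool>) -> Option<&'static str> {
    // sleep() stands in for client.post(...).send(); Box::pin is needed so
    // the future can be polled by `&mut` across select! iterations.
    let mut fake_request = Box::pin(sleep(Duration::from_secs(5)));
    let mut ticker = interval(Duration::from_millis(500));

    loop {
        tokio::select! {
            _ = ticker.tick() => {
                // Every 500ms, see whether the user asked to stop.
                if stop_flag.load(Ordering::Relaxed) {
                    return None; // stopped before the response arrived
                }
            },
            _ = &mut fake_request => {
                return Some("response"); // request completed normally
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let stop = Arc::new(AtomicBool::new(false));
    let handle = tokio::spawn(await_or_cancel(stop.clone()));
    // Simulate the user pressing "stop" one second in.
    sleep(Duration::from_secs(1)).await;
    stop.store(true, Ordering::Relaxed);
    assert_eq!(handle.await.unwrap(), None);
    println!("request cancelled before completion");
}
```

Without the `loop`, a single `select!` would have to give up after the first tick; looping lets the 500ms check repeat while the same pinned future keeps its progress.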
67 changes: 63 additions & 4 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/gemini.rs
@@ -137,7 +137,7 @@ impl LLMService for Gemini {
         inbox_name: Option<InboxName>,
         ws_manager_trait: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
         config: Option<JobConfig>,
-        _llm_stopper: Arc<LLMStopper>,
+        llm_stopper: Arc<LLMStopper>,
         db: Arc<SqliteManager>,
         tracing_message_id: Option<String>,
     ) -> Result<LLMInferenceResponse, LLMProviderError> {
@@ -236,12 +236,38 @@ impl LLMService for Gemini {
             }
         }

-        let res = client
+        // Use tokio::select! to allow cancellation during the initial request phase
+        let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
+        let response_fut = client
             .post(&url)
             .header("Content-Type", "application/json")
             .json(&payload)
-            .send()
-            .await?;
+            .send();
+        let mut response_fut = Box::pin(response_fut);
+
+        // Wait for response or cancellation
+        let res = loop {
+            tokio::select! {
+                _ = interval.tick() => {
+                    if let Some(ref inbox_name) = inbox_name {
+                        if llm_stopper.should_stop(&inbox_name.to_string()) {
+                            shinkai_log(
+                                ShinkaiLogOption::JobExecution,
+                                ShinkaiLogLevel::Info,
+                                "LLM job stopped by user request before response arrived",
+                            );
+                            llm_stopper.reset(&inbox_name.to_string());
+
+                            return Ok(LLMInferenceResponse::new("".to_string(), None, json!({}), Vec::new(), Vec::new(), None));
+                        }
+                    }
+                },
+                response = &mut response_fut => {
+                    break response?;
+                }
+            }
+        };
+
         shinkai_log(
             ShinkaiLogOption::JobExecution,
             ShinkaiLogLevel::Debug,
@@ -260,6 +286,39 @@
         let mut thinking_started = false;

         while let Some(item) = stream.next().await {
+            // Check if we need to stop the LLM job
+            if let Some(ref inbox_name) = inbox_name {
+                if llm_stopper.should_stop(&inbox_name.to_string()) {
+                    shinkai_log(
+                        ShinkaiLogOption::JobExecution,
+                        ShinkaiLogLevel::Info,
+                        "LLM job stopped by user request during streaming",
+                    );
+                    llm_stopper.reset(&inbox_name.to_string());
+
+                    // Send WS message indicating the job is done
+                    let _ = send_ws_update(
+                        &ws_manager_trait,
+                        Some(inbox_name.clone()),
+                        &session_id,
+                        regular_content.clone(),
+                        false,
+                        true,
+                        Some("Stopped by user request".to_string()),
+                    )
+                    .await;
+
+                    return Ok(LLMInferenceResponse::new(
+                        regular_content,
+                        if thinking_content.is_empty() { None } else { Some(thinking_content) },
+                        json!({}),
+                        function_calls,
+                        generated_files,
+                        None,
+                    ));
+                }
+            }
+
             match item {
                 Ok(chunk) => {
                     process_chunk(
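Beyond the pre-response race, gemini.rs also gains a stop check at the top of its `while let Some(item) = stream.next().await` loop, returning whatever `regular_content` has accumulated so far. A sketch of that mid-stream check, with a channel-backed stream and an `AtomicBool` replacing the real HTTP byte stream and `LLMStopper` (names here are illustrative):

```rust
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

// Drain a stream of text chunks, bailing out with the partial content if the
// stop flag is raised between chunks, the same shape as the gemini.rs loop.
async fn drain_stream(mut stream: ReceiverStream<String>, stop_flag: Arc<AtomicBool>) -> String {
    let mut content = String::new();
    while let Some(chunk) = stream.next().await {
        // Check for a user stop request before processing each chunk.
        if stop_flag.load(Ordering::Relaxed) {
            return content; // partial content accumulated so far
        }
        content.push_str(&chunk);
    }
    content
}

#[tokio::main]
async fn main() {
    let (tx, rx) = tokio::sync::mpsc::channel(8);
    let stop = Arc::new(AtomicBool::new(false));
    tx.send("hello ".to_string()).await.unwrap();
    tx.send("world".to_string()).await.unwrap();
    drop(tx); // close the channel so the stream ends
    // The flag is never raised here, so the full content comes back.
    assert_eq!(drain_stream(ReceiverStream::new(rx), stop).await, "hello world");
}
```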
31 changes: 28 additions & 3 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/grok.rs
@@ -169,13 +169,38 @@ async fn handle_streaming_response(
     session_id: String,
     tools: Option<Vec<JsonValue>>, // Add tools parameter
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
-    let res = client
+    // Use tokio::select! to allow cancellation during the initial request phase
+    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
+    let response_fut = client
         .post(url)
         .bearer_auth(api_key)
         .header("Content-Type", "application/json")
         .json(&payload)
-        .send()
-        .await?;
+        .send();
+    let mut response_fut = Box::pin(response_fut);
+
+    // Wait for response or cancellation
+    let res = loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                if let Some(ref inbox_name) = inbox_name {
+                    if llm_stopper.should_stop(&inbox_name.to_string()) {
+                        shinkai_log(
+                            ShinkaiLogOption::JobExecution,
+                            ShinkaiLogLevel::Info,
+                            "LLM job stopped by user request before response arrived",
+                        );
+                        llm_stopper.reset(&inbox_name.to_string());
+
+                        return Ok(LLMInferenceResponse::new("".to_string(), None, json!({}), Vec::new(), Vec::new(), None));
+                    }
+                }
+            },
+            response = &mut response_fut => {
+                break response?;
+            }
+        }
+    };
+
     // Check if it's an error response
     let status = res.status();
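Every stop branch in this diff calls `llm_stopper.reset(...)` right after `should_stop(...)` returns true, which implies the stopper is a per-inbox latch that must be cleared once the stop has been observed. The real `LLMStopper` internals are not shown in this PR; a hedged sketch of a type with the same surface:

```rust
use std::collections::HashSet;
use std::sync::Mutex;

// Assumed shape only: a set of inbox names flagged for stopping, guarded by a
// mutex so provider tasks and the stop API can share it behind an Arc.
#[derive(Default)]
struct Stopper {
    stopped: Mutex<HashSet<String>>,
}

impl Stopper {
    // Called from the stop endpoint: flag an inbox for cancellation.
    fn stop(&self, inbox: &str) {
        self.stopped.lock().unwrap().insert(inbox.to_string());
    }

    // Polled by the provider loops above.
    fn should_stop(&self, inbox: &str) -> bool {
        self.stopped.lock().unwrap().contains(inbox)
    }

    // Cleared once observed, so the next job in the same inbox is not
    // cancelled immediately.
    fn reset(&self, inbox: &str) {
        self.stopped.lock().unwrap().remove(inbox);
    }
}

fn main() {
    let s = Stopper::default();
    s.stop("inbox_1");
    assert!(s.should_stop("inbox_1"));
    s.reset("inbox_1");
    assert!(!s.should_stop("inbox_1"));
}
```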
31 changes: 28 additions & 3 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/groq.rs
@@ -215,13 +215,38 @@ async fn handle_streaming_response(
     session_id: String,
     tools: Option<Vec<JsonValue>>, // Add tools parameter
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
-    let res = client
+    // Use tokio::select! to allow cancellation during the initial request phase
+    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
+    let response_fut = client
         .post(url)
         .bearer_auth(api_key)
         .header("Content-Type", "application/json")
         .json(&payload)
-        .send()
-        .await?;
+        .send();
+    let mut response_fut = Box::pin(response_fut);
+
+    // Wait for response or cancellation
+    let res = loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                if let Some(ref inbox_name) = inbox_name {
+                    if llm_stopper.should_stop(&inbox_name.to_string()) {
+                        shinkai_log(
+                            ShinkaiLogOption::JobExecution,
+                            ShinkaiLogLevel::Info,
+                            "LLM job stopped by user request before response arrived",
+                        );
+                        llm_stopper.reset(&inbox_name.to_string());
+
+                        return Ok(LLMInferenceResponse::new("".to_string(), None, json!({}), Vec::new(), Vec::new(), None));
+                    }
+                }
+            },
+            response = &mut response_fut => {
+                break response?;
+            }
+        }
+    };
+
     // Check if it's an error response
     if !res.status().is_success() {
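One timing subtlety in these polling loops: `tokio::time::interval` completes its first tick immediately, so the stop flag is checked once as soon as the loop starts and then every 500ms thereafter. A small demo of that behavior:

```rust
use tokio::time::{interval, Duration, Instant};

#[tokio::main]
async fn main() {
    let start = Instant::now();
    let mut ticker = interval(Duration::from_millis(500));
    for i in 0..3 {
        // The first tick resolves at ~0ms; later ticks arrive 500ms apart.
        ticker.tick().await;
        println!("tick {} at {:?}", i, start.elapsed());
    }
}
```

For the handlers above, this means a stop requested before the call even begins is typically honored on the very first loop iteration.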
43 changes: 40 additions & 3 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
@@ -724,7 +724,9 @@ pub async fn handle_streaming_response(
     tools: Option<Vec<JsonValue>>,
     headers: Option<JsonValue>,
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
-    let res = client
+    // Use tokio::select! to allow cancellation during the initial request phase
+    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
+    let response_fut = client
         .post(url)
         .bearer_auth(api_key)
         .header("Content-Type", "application/json")
@@ -776,8 +778,43 @@
                 .unwrap_or(""),
         )
         .json(&payload)
-        .send()
-        .await?;
+        .send();
+    let mut response_fut = Box::pin(response_fut);
+
+    // Wait for response or cancellation
+    let res = loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                if let Some(ref inbox_name) = inbox_name {
+                    if llm_stopper.should_stop(&inbox_name.to_string()) {
+                        shinkai_log(
+                            ShinkaiLogOption::JobExecution,
+                            ShinkaiLogLevel::Info,
+                            "LLM job stopped by user request before response arrived",
+                        );
+                        llm_stopper.reset(&inbox_name.to_string());
+
+                        // Send WS message indicating the job is done
+                        let _ = send_ws_update(
+                            &ws_manager_trait,
+                            Some(inbox_name.clone()),
+                            &session_id,
+                            "".to_string(),
+                            false,
+                            true,
+                            Some("Stopped by user request".to_string()),
+                        )
+                        .await;
+
+                        return Ok(LLMInferenceResponse::new("".to_string(), None, json!({}), Vec::new(), Vec::new(), None));
+                    }
+                }
+            },
+            response = &mut response_fut => {
+                break response?;
+            }
+        }
+    };
+
     // Check if it's an error response
     if !res.status().is_success() {
31 changes: 28 additions & 3 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/openrouter.rs
@@ -170,14 +170,39 @@ async fn handle_streaming_response(
     session_id: String,
     tools: Option<Vec<JsonValue>>, // Add tools parameter
 ) -> Result<LLMInferenceResponse, LLMProviderError> {
-    let res = client
+    // Use tokio::select! to allow cancellation during the initial request phase
+    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(500));
+    let response_fut = client
         .post(url)
         .bearer_auth(api_key)
         .header("Content-Type", "application/json")
         .header("X-Title", "Shinkai")
         .json(&payload)
-        .send()
-        .await?;
+        .send();
+    let mut response_fut = Box::pin(response_fut);
+
+    // Wait for response or cancellation
+    let res = loop {
+        tokio::select! {
+            _ = interval.tick() => {
+                if let Some(ref inbox_name) = inbox_name {
+                    if llm_stopper.should_stop(&inbox_name.to_string()) {
+                        shinkai_log(
+                            ShinkaiLogOption::JobExecution,
+                            ShinkaiLogLevel::Info,
+                            "LLM job stopped by user request before response arrived",
+                        );
+                        llm_stopper.reset(&inbox_name.to_string());
+
+                        return Ok(LLMInferenceResponse::new("".to_string(), None, json!({}), Vec::new(), Vec::new(), None));
+                    }
+                }
+            },
+            response = &mut response_fut => {
+                break response?;
+            }
+        }
+    };
+
     // Check if it's an error response
     if !res.status().is_success() {