Commit cd24d70

feat: token compression at 80 % of max_context_tokens
Signed-off-by: Pierric Buchez <[email protected]>
1 parent 5d0f9e5

File tree

11 files changed: +691, -13 lines


README.md

Lines changed: 44 additions & 0 deletions
````diff
@@ -29,6 +29,50 @@ the `shai` binary will be installed in `$HOME/.local/bin`
 
 ## Configure a provider and Run!
 
+### Configuration files
+
+Shai can be configured via **configuration files** written in JSON. By default, the configuration file is `auth.config` located in `~/.config/shai/`. The file defines the list of LLM providers, the selected provider, model, and tool call method.
+
+#### Example `.shai.config`
+```json
+{
+  "providers": [
+    {
+      "provider": "ovhcloud",
+      "env_vars": {
+        "OVH_BASE_URL": "https://gpt-oss-120b.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1"
+      },
+      "model": "gpt-oss-120b",
+      "tool_method": "FunctionCall",
+      "max_context_tokens": 8192
+    }
+  ],
+  "selected_provider": 0
+}
+```
+
+- **providers**: an array of provider definitions. Each provider can specify environment variables (`env_vars`), the model name, the tool call method (`FunctionCall` or `Chat`), and optionally `max_context_tokens` to limit the context size.
+- **selected_provider**: the index of the provider to use (starting at `0`).
+- **max_context_tokens** (optional, per provider): maximum number of tokens that can be sent in the context to the LLM. If omitted, the default for the model is used.
+
+You can create multiple configuration files for different agents (see the *Custom Agent* section). To use a specific configuration, place the file in `~/.config/shai/agents/` and run the agent by its filename (without the `.config` extension):
+```
+shai my_custom_agent
+```
+
+Shai will automatically load the configuration, set the required environment variables, and use the selected provider for all subsequent interactions.
+
+### Using the configuration
+
+- **Automatic loading**: If a `.shai.config` file is present in the current directory, Shai will load it automatically.
+- **Explicit loading**: Use the `--config <path>` flag to specify a custom configuration file:
+```
+shai --config ~/.config/shai/agents/example.config
+```
+
+The configuration system allows you to switch providers, models, or tool call methods without recompiling the binary.
+
 By default `shai` uses OVHcloud as an anonymous user meaning you will be rate limited! If you want to sign in with your account or select another provider, run:
 
 ```
````
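The JSON above maps one-to-one onto the `ProviderConfig`/`ShaiConfig` serde structs changed in `shai-core/src/config/config.rs` (last file below). As a minimal, self-contained sketch of what parsing such a file involves — assuming `serde`/`serde_json` dependencies, and simplifying `tool_method` to a `String` instead of the real `ToolCallMethod` enum:

```rust
use serde::Deserialize;
use std::collections::HashMap;

// Field names mirror ProviderConfig in config.rs; tool_method is
// simplified to String here (the real type is the ToolCallMethod enum).
#[derive(Debug, Deserialize)]
struct ProviderConfig {
    provider: String,
    env_vars: HashMap<String, String>,
    model: String,
    tool_method: String,
    #[serde(default)]
    max_context_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
struct ShaiConfig {
    providers: Vec<ProviderConfig>,
    selected_provider: usize,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // In Shai this file would live under ~/.config/shai/
    let raw = std::fs::read_to_string("auth.config")?;
    let cfg: ShaiConfig = serde_json::from_str(&raw)?;
    let p = &cfg.providers[cfg.selected_provider];
    println!("provider={} model={} max_context_tokens={:?}",
             p.provider, p.model, p.max_context_tokens);
    Ok(())
}
```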

shai-core/src/agent/actions/brain.rs

Lines changed: 69 additions & 2 deletions
```diff
@@ -47,7 +47,7 @@ impl AgentCore {
 
     /// Process a brain task result
     pub async fn process_next_step(&mut self, result: Result<ThinkerDecision, AgentError>) -> Result<(), AgentError> {
-        let ThinkerDecision{message, flow, token_usage} = self.handle_brain_error(result).await?;
+        let ThinkerDecision{message, flow, token_usage, compression_info} = self.handle_brain_error(result).await?;
         let ChatMessage::Assistant { content, reasoning_content, tool_calls, .. } = message.clone() else {
             return self.handle_brain_error::<ThinkerDecision>(
                 Err(AgentError::InvalidResponse(format!("ChatMessage::Assistant expected, but got {:?} instead", message)))).await.map(|_| ()
@@ -72,6 +72,18 @@ impl AgentCore {
                 output_tokens
             }).await;
         }
+
+        // Emit context compression event if available
+        if let Some(compression_info) = compression_info {
+            let _ = self.emit_event(AgentEvent::ContextCompressed {
+                original_message_count: compression_info.original_message_count,
+                compressed_message_count: compression_info.compressed_message_count,
+                tokens_before: compression_info.tokens_before,
+                current_tokens: compression_info.current_tokens,
+                max_tokens: compression_info.max_tokens,
+                ai_summary: compression_info.ai_summary,
+            }).await;
+        }
 
         // run tool call if any
         let tool_calls_from_brain = tool_calls.unwrap_or(vec![]);
@@ -86,19 +98,74 @@ impl AgentCore {
                 self.set_state(InternalAgentState::Running).await;
             }
             ThinkerFlowControl::AgentPause => {
+                // Check if we need to compress context when task is complete
+                self.check_and_compress_context().await?;
                 self.set_state(InternalAgentState::Paused).await;
             }
         }
         Ok(())
     }
 
+    /// Check if context compression is needed and apply it when task is complete
+    async fn check_and_compress_context(&mut self) -> Result<(), AgentError> {
+        // Extract compression logic from the brain if it's a CoderBrain
+        let brain = self.brain.clone();
+        let brain_read = brain.read().await;
+
+        // This is a bit hacky but we need to check if the brain has a compressor
+        // We'll use Any trait to downcast to CoderBrain
+        use std::any::Any;
+
+        if let Some(coder_brain) = (&**brain_read as &dyn Any).downcast_ref::<crate::runners::coder::coder::CoderBrain>() {
+            if let Some(compressor) = &coder_brain.context_compressor {
+                let compressor_clone = compressor.clone();
+                drop(brain_read); // Release the read lock
+
+                let trace = self.trace.read().await.clone();
+                let mut compressor_clone = compressor_clone;
+
+                if compressor_clone.should_compress_conversation(&trace) {
+                    let (compressed_trace, compression_info) = compressor_clone.compress_messages(trace).await;
+
+                    // Update the trace with compressed version
+                    {
+                        let mut trace_write = self.trace.write().await;
+                        *trace_write = compressed_trace;
+                    }
+
+                    // Update the compressor in the brain
+                    {
+                        let mut brain_write = brain.write().await;
+                        if let Some(coder_brain_mut) = (&mut **brain_write as &mut dyn Any).downcast_mut::<crate::runners::coder::coder::CoderBrain>() {
+                            coder_brain_mut.context_compressor = Some(compressor_clone);
+                        }
+                    }
+
+                    // Emit compression event if compression occurred
+                    if let Some(compression_info) = compression_info {
+                        let _ = self.emit_event(AgentEvent::ContextCompressed {
+                            original_message_count: compression_info.original_message_count,
+                            compressed_message_count: compression_info.compressed_message_count,
+                            tokens_before: compression_info.tokens_before,
+                            current_tokens: compression_info.current_tokens,
+                            max_tokens: compression_info.max_tokens,
+                            ai_summary: compression_info.ai_summary,
+                        }).await;
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
     // Helper method that emits error events before returning the error
     async fn handle_brain_error<T>(&mut self, result: Result<T, AgentError>) -> Result<T, AgentError> {
         match result {
             Ok(value) => Ok(value),
             Err(error) => {
                 self.set_state(InternalAgentState::Paused).await;
-                let _ = self.emit_event(AgentEvent::BrainResult {
+                let _ = self.emit_event(AgentEvent::BrainResult {
                     timestamp: Utc::now(),
                     thought: Err(error.clone())
                 }).await;
```
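The compressor itself (`should_compress_conversation`, `compress_messages`) lives in `shai-core/src/runners/compacter.rs`, which is not shown in this excerpt. Going by the commit title, compression triggers once the context crosses 80% of `max_context_tokens`. A minimal sketch of that check, where the struct name and the characters-per-token estimate are assumptions for illustration:

```rust
// Hypothetical sketch of the 80% threshold named in the commit title;
// the real compressor in runners/compacter.rs is not part of this diff.
struct ContextCompressor {
    max_context_tokens: u32,
}

impl ContextCompressor {
    /// Rough heuristic: ~4 characters per token (an assumption,
    /// not the crate's actual token counting).
    fn estimate_tokens(conversation_chars: usize) -> u32 {
        (conversation_chars / 4) as u32
    }

    /// Compress once the estimated context reaches 80% of the budget.
    fn should_compress(&self, conversation_chars: usize) -> bool {
        let estimated = Self::estimate_tokens(conversation_chars);
        estimated >= (self.max_context_tokens as f64 * 0.8) as u32
    }
}

fn main() {
    let c = ContextCompressor { max_context_tokens: 8192 };
    // 80% of 8192 = 6553 tokens, i.e. ~26,214 chars under this heuristic
    assert!(!c.should_compress(20_000));
    assert!(c.should_compress(30_000));
}
```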

shai-core/src/agent/brain.rs

Lines changed: 26 additions & 1 deletion
```diff
@@ -4,6 +4,7 @@ use shai_llm::{ChatMessage, ToolCallMethod};
 use tokio::sync::RwLock;
 
 use crate::tools::types::AnyToolBox;
+use crate::runners::compacter::CompressionInfo;
 use super::error::AgentError;
 
 
@@ -29,6 +30,7 @@ pub struct ThinkerDecision {
     pub message: ChatMessage,
     pub flow: ThinkerFlowControl,
     pub token_usage: Option<(u32, u32)>, // (input_tokens, output_tokens)
+    pub compression_info: Option<CompressionInfo>,
 }
 
 impl ThinkerDecision {
@@ -37,6 +39,7 @@ impl ThinkerDecision {
             message,
             flow: ThinkerFlowControl::AgentPause,
             token_usage: None,
+            compression_info: None,
         }
     }
 
@@ -45,6 +48,7 @@ impl ThinkerDecision {
             message,
             flow: ThinkerFlowControl::AgentContinue,
             token_usage: None,
+            compression_info: None,
         }
     }
 
@@ -53,6 +57,7 @@ impl ThinkerDecision {
             message,
             flow: ThinkerFlowControl::AgentPause,
             token_usage: None,
+            compression_info: None,
         }
     }
 
@@ -61,6 +66,7 @@ impl ThinkerDecision {
             message,
             flow: ThinkerFlowControl::AgentContinue,
             token_usage: Some((input_tokens, output_tokens)),
+            compression_info: None,
         }
     }
 
@@ -69,6 +75,25 @@ impl ThinkerDecision {
             message,
             flow: ThinkerFlowControl::AgentPause,
             token_usage: Some((input_tokens, output_tokens)),
+            compression_info: None,
+        }
+    }
+
+    pub fn agent_continue_with_compression(message: ChatMessage, input_tokens: u32, output_tokens: u32, compression_info: CompressionInfo) -> Self {
+        ThinkerDecision{
+            message,
+            flow: ThinkerFlowControl::AgentContinue,
+            token_usage: Some((input_tokens, output_tokens)),
+            compression_info: Some(compression_info),
+        }
+    }
+
+    pub fn agent_pause_with_compression(message: ChatMessage, input_tokens: u32, output_tokens: u32, compression_info: CompressionInfo) -> Self {
+        ThinkerDecision{
+            message,
+            flow: ThinkerFlowControl::AgentPause,
+            token_usage: Some((input_tokens, output_tokens)),
+            compression_info: Some(compression_info),
         }
     }
 
@@ -79,7 +104,7 @@ impl ThinkerDecision {
 
 /// Core thinking interface - pure decision making
 #[async_trait]
-pub trait Brain: Send + Sync {
+pub trait Brain: Send + Sync + std::any::Any {
     /// This method is called at every step of the agent to decide next step
     /// note that if the message contains toolcall, it will always continue
     async fn next_step(&mut self, context: ThinkerContext) -> Result<ThinkerDecision, AgentError>;
```
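The new `std::any::Any` supertrait on `Brain` is what allows `check_and_compress_context` in `actions/brain.rs` to upcast a `&dyn Brain` to `&dyn Any` and downcast it to the concrete `CoderBrain`. The imported `CompressionInfo` type is defined in `runners/compacter.rs`, which is not part of this excerpt; judging from the fields copied one-to-one into `AgentEvent::ContextCompressed` below, its shape is presumably close to this sketch:

```rust
// Presumed shape of CompressionInfo, inferred from the fields that
// actions/brain.rs forwards into AgentEvent::ContextCompressed; the
// actual definition in runners/compacter.rs may differ.
#[derive(Debug, Clone)]
pub struct CompressionInfo {
    pub original_message_count: usize,
    pub compressed_message_count: usize,
    pub tokens_before: Option<u32>,
    pub current_tokens: Option<u32>,
    pub max_tokens: u32,
    pub ai_summary: Option<String>,
}
```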

shai-core/src/agent/events.rs

Lines changed: 19 additions & 0 deletions
```diff
@@ -101,6 +101,15 @@ pub enum AgentEvent {
         input_tokens: u32,
         output_tokens: u32
     },
+    /// Context compression notification
+    ContextCompressed {
+        original_message_count: usize,
+        compressed_message_count: usize,
+        tokens_before: Option<u32>,
+        current_tokens: Option<u32>,
+        max_tokens: u32,
+        ai_summary: Option<String>,
+    },
 }
 
 /// Types of user input that an agent can request
@@ -274,6 +283,16 @@ impl std::fmt::Debug for AgentEvent {
                     .field("output_tokens", output_tokens)
                     .finish()
             }
+            AgentEvent::ContextCompressed { original_message_count, compressed_message_count, tokens_before, current_tokens, max_tokens, ai_summary } => {
+                f.debug_struct("ContextCompressed")
+                    .field("original_message_count", original_message_count)
+                    .field("compressed_message_count", compressed_message_count)
+                    .field("tokens_before", tokens_before)
+                    .field("current_tokens", current_tokens)
+                    .field("max_tokens", max_tokens)
+                    .field("ai_summary", ai_summary)
+                    .finish()
+            }
         }
     }
 }
```

shai-core/src/agent/output/log.rs

Lines changed: 16 additions & 0 deletions
```diff
@@ -56,6 +56,22 @@ impl FileEventLogger {
             AgentEvent::TokenUsage { input_tokens, output_tokens } => {
                 format!("Token Usage: input={} output={} total={}", input_tokens, output_tokens, input_tokens + output_tokens)
             }
+            AgentEvent::ContextCompressed { original_message_count, compressed_message_count, tokens_before, current_tokens, max_tokens, ai_summary } => {
+                let summary_text = if let Some(summary) = ai_summary {
+                    format!(" | Summary: {}", summary)
+                } else {
+                    "".to_string()
+                };
+
+                let token_info = match (tokens_before, current_tokens) {
+                    (Some(before), Some(after)) => format!(", tokens: {} → {}", before, after),
+                    (Some(before), None) => format!(", tokens before: {}", before),
+                    _ => "".to_string(),
+                };
+
+                format!("Context Compressed with AI Summary: {} → {} messages{}{}",
+                    original_message_count, compressed_message_count, token_info, summary_text)
+            }
         };
 
         let log_line = format!("[{}] {}\n", timestamp.format("%Y-%m-%d %H:%M:%S%.3f"), event_str);
```

shai-core/src/agent/output/pretty.rs

Lines changed: 41 additions & 0 deletions
```diff
@@ -114,6 +114,47 @@ impl PrettyFormatter {
                 // Don't display token usage in the main output - it's handled by /tokens command
                 None
             },
+            AgentEvent::ContextCompressed { original_message_count, compressed_message_count, tokens_before, current_tokens, max_tokens, ai_summary } => {
+                let net_change = if original_message_count > compressed_message_count {
+                    original_message_count - compressed_message_count
+                } else {
+                    0
+                };
+
+                let markdown = match (tokens_before, current_tokens) {
+                    (Some(before), Some(after)) => {
+                        if net_change > 0 {
+                            format!(
+                                "● **Context Compressed with AI Summary** - Summarized {} message(s) to stay within token limits ({} → {} tokens)",
+                                net_change, before, after
+                            )
+                        } else {
+                            format!(
+                                "● **Context Compression Applied** - Added AI summary to optimize token usage ({} → {} tokens)",
+                                before, after
+                            )
+                        }
+                    }
+                    _ => {
+                        if net_change > 0 {
+                            format!(
+                                "● **Context Compressed with AI Summary** - Summarized {} message(s) to stay within token limits",
+                                net_change
+                            )
+                        } else {
+                            format!(
+                                "● **Context Compression Applied** - Added AI summary to optimize token usage"
+                            )
+                        }
+                    }
+                };
+
+                let mut compression_skin = self.skin.clone();
+                compression_skin.paragraph.set_fg(rgb(100, 200, 255)); // Blue for AI compression
+                compression_skin.bold.set_fg(rgb(120, 220, 255)); // Light blue for bold
+
+                Some(compression_skin.term_text(&markdown).to_string())
+            },
         }.map(|s| format!("\n{}", s))
     }
```

shai-core/src/config/config.rs

Lines changed: 21 additions & 4 deletions
```diff
@@ -12,7 +12,9 @@ pub struct ProviderConfig {
     pub provider: String,
     pub env_vars: std::collections::HashMap<String, String>,
     pub model: String,
-    pub tool_method: ToolCallMethod
+    pub tool_method: ToolCallMethod,
+    #[serde(default)]
+    pub max_context_tokens: Option<u32>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -36,9 +38,23 @@ impl ShaiConfig {
             provider,
             env_vars,
             model,
-            tool_method: ToolCallMethod::FunctionCall
+            tool_method: ToolCallMethod::FunctionCall,
+            max_context_tokens: None,
         };
-
+
+        self.providers.push(provider_config);
+        self.providers.len() - 1
+    }
+
+    pub fn add_provider_with_context(&mut self, provider: String, env_vars: std::collections::HashMap<String, String>, model: String, max_context_tokens: Option<u32>) -> usize {
+        let provider_config = ProviderConfig {
+            provider,
+            env_vars,
+            model,
+            tool_method: ToolCallMethod::FunctionCall,
+            max_context_tokens,
+        };
+
         self.providers.push(provider_config);
         self.providers.len() - 1
     }
@@ -228,7 +244,8 @@ impl Default for ShaiConfig {
                 (String::from("OVH_BASE_URL"), String::from("https://qwen-3-32b.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1"))
             ]),
             model: "Qwen3-32B".to_string(),
-            tool_method: ToolCallMethod::FunctionCall
+            tool_method: ToolCallMethod::FunctionCall,
+            max_context_tokens: Some(32768), // Qwen3-32B has 32k context
         }],
         selected_provider: 0,
         mcp_configs: HashMap::new(),
```
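A hedged usage sketch of the new `add_provider_with_context` constructor, assuming the `ShaiConfig` fields are public as the `Default` impl above suggests:

```rust
use std::collections::HashMap;

// Assumes ShaiConfig/ProviderConfig from config.rs are in scope and that
// selected_provider is a public usize field, as the Default literal suggests.
fn configure() -> ShaiConfig {
    let mut cfg = ShaiConfig::default();
    let idx = cfg.add_provider_with_context(
        "ovhcloud".to_string(),
        HashMap::from([(
            "OVH_BASE_URL".to_string(),
            "https://gpt-oss-120b.endpoints.kepler.ai.cloud.ovh.net/api/openai_compat/v1".to_string(),
        )]),
        "gpt-oss-120b".to_string(),
        Some(8192), // compression should kick in near 80% of this (~6553 tokens)
    );
    cfg.selected_provider = idx;
    cfg
}
```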
