@@ -3,6 +3,8 @@ use crate::CodexAuth;
3
3
use crate :: codex:: Codex ;
4
4
use crate :: codex:: CodexSpawnOk ;
5
5
use crate :: codex:: INITIAL_SUBMIT_ID ;
6
+ use crate :: codex:: compact:: content_items_to_text;
7
+ use crate :: codex:: compact:: is_session_prefix_message;
6
8
use crate :: codex_conversation:: CodexConversation ;
7
9
use crate :: config:: Config ;
8
10
use crate :: error:: CodexErr ;
@@ -134,19 +136,19 @@ impl ConversationManager {
134
136
self . conversations . write ( ) . await . remove ( conversation_id)
135
137
}
136
138
137
- /// Fork an existing conversation by dropping the last `drop_last_messages`
138
- /// user/assistant messages from its transcript and starting a new
139
+ /// Fork an existing conversation by taking messages up to the given position
140
+ /// (not including the message at the given position) and starting a new
139
141
/// conversation with identical configuration (unless overridden by the
140
142
/// caller's `config`). The new conversation will have a fresh id.
141
143
pub async fn fork_conversation (
142
144
& self ,
143
- num_messages_to_drop : usize ,
145
+ nth_user_message : usize ,
144
146
config : Config ,
145
147
path : PathBuf ,
146
148
) -> CodexResult < NewConversation > {
147
149
// Compute the prefix up to the cut point.
148
150
let history = RolloutRecorder :: get_rollout_history ( & path) . await ?;
149
- let history = truncate_after_dropping_last_messages ( history, num_messages_to_drop ) ;
151
+ let history = truncate_before_nth_user_message ( history, nth_user_message ) ;
150
152
151
153
// Spawn a new conversation with the computed initial history.
152
154
let auth_manager = self . auth_manager . clone ( ) ;
@@ -159,33 +161,30 @@ impl ConversationManager {
159
161
}
160
162
}
161
163
162
- /// Return a prefix of `items` obtained by dropping the last `n` user messages
163
- /// and all items that follow them.
164
- fn truncate_after_dropping_last_messages ( history : InitialHistory , n : usize ) -> InitialHistory {
165
- if n == 0 {
166
- return InitialHistory :: Forked ( history. get_rollout_items ( ) ) ;
167
- }
168
-
169
- // Work directly on rollout items, and cut the vector at the nth-from-last user message input.
164
+ /// Return a prefix of `items` obtained by cutting strictly before the nth user message
165
+ /// (0-based) and all items that follow it.
166
+ fn truncate_before_nth_user_message ( history : InitialHistory , n : usize ) -> InitialHistory {
167
+ // Work directly on rollout items, and cut the vector at the nth user message input.
170
168
let items: Vec < RolloutItem > = history. get_rollout_items ( ) ;
171
169
172
170
// Find indices of user message inputs in rollout order.
173
171
let mut user_positions: Vec < usize > = Vec :: new ( ) ;
174
172
for ( idx, item) in items. iter ( ) . enumerate ( ) {
175
- if let RolloutItem :: ResponseItem ( ResponseItem :: Message { role, .. } ) = item
173
+ if let RolloutItem :: ResponseItem ( ResponseItem :: Message { role, content , .. } ) = item
176
174
&& role == "user"
175
+ && content_items_to_text ( content) . is_some_and ( |text| !is_session_prefix_message ( & text) )
177
176
{
178
177
user_positions. push ( idx) ;
179
178
}
180
179
}
181
180
182
- // If fewer than n user messages exist, treat as empty.
183
- if user_positions. len ( ) < n {
181
+ // If fewer than or equal to n user messages exist, treat as empty (out of range) .
182
+ if user_positions. len ( ) <= n {
184
183
return InitialHistory :: New ;
185
184
}
186
185
187
- // Cut strictly before the nth-from-last user message (do not keep the nth itself).
188
- let cut_idx = user_positions[ user_positions . len ( ) - n] ;
186
+ // Cut strictly before the nth user message (do not keep the nth itself).
187
+ let cut_idx = user_positions[ n] ;
189
188
let rolled: Vec < RolloutItem > = items. into_iter ( ) . take ( cut_idx) . collect ( ) ;
190
189
191
190
if rolled. is_empty ( ) {
@@ -198,9 +197,11 @@ fn truncate_after_dropping_last_messages(history: InitialHistory, n: usize) -> I
198
197
#[ cfg( test) ]
199
198
mod tests {
200
199
use super :: * ;
200
+ use crate :: codex:: make_session_and_context;
201
201
use codex_protocol:: models:: ContentItem ;
202
202
use codex_protocol:: models:: ReasoningItemReasoningSummary ;
203
203
use codex_protocol:: models:: ResponseItem ;
204
+ use pretty_assertions:: assert_eq;
204
205
205
206
fn user_msg ( text : & str ) -> ResponseItem {
206
207
ResponseItem :: Message {
@@ -252,7 +253,7 @@ mod tests {
252
253
. cloned ( )
253
254
. map ( RolloutItem :: ResponseItem )
254
255
. collect ( ) ;
255
- let truncated = truncate_after_dropping_last_messages ( InitialHistory :: Forked ( initial) , 1 ) ;
256
+ let truncated = truncate_before_nth_user_message ( InitialHistory :: Forked ( initial) , 1 ) ;
256
257
let got_items = truncated. get_rollout_items ( ) ;
257
258
let expected_items = vec ! [
258
259
RolloutItem :: ResponseItem ( items[ 0 ] . clone( ) ) ,
@@ -269,7 +270,37 @@ mod tests {
269
270
. cloned ( )
270
271
. map ( RolloutItem :: ResponseItem )
271
272
. collect ( ) ;
272
- let truncated2 = truncate_after_dropping_last_messages ( InitialHistory :: Forked ( initial2) , 2 ) ;
273
+ let truncated2 = truncate_before_nth_user_message ( InitialHistory :: Forked ( initial2) , 2 ) ;
273
274
assert ! ( matches!( truncated2, InitialHistory :: New ) ) ;
274
275
}
276
+
277
+ #[ test]
278
+ fn ignores_session_prefix_messages_when_truncating ( ) {
279
+ let ( session, turn_context) = make_session_and_context ( ) ;
280
+ let mut items = session. build_initial_context ( & turn_context) ;
281
+ items. push ( user_msg ( "feature request" ) ) ;
282
+ items. push ( assistant_msg ( "ack" ) ) ;
283
+ items. push ( user_msg ( "second question" ) ) ;
284
+ items. push ( assistant_msg ( "answer" ) ) ;
285
+
286
+ let rollout_items: Vec < RolloutItem > = items
287
+ . iter ( )
288
+ . cloned ( )
289
+ . map ( RolloutItem :: ResponseItem )
290
+ . collect ( ) ;
291
+
292
+ let truncated = truncate_before_nth_user_message ( InitialHistory :: Forked ( rollout_items) , 1 ) ;
293
+ let got_items = truncated. get_rollout_items ( ) ;
294
+
295
+ let expected: Vec < RolloutItem > = vec ! [
296
+ RolloutItem :: ResponseItem ( items[ 0 ] . clone( ) ) ,
297
+ RolloutItem :: ResponseItem ( items[ 1 ] . clone( ) ) ,
298
+ RolloutItem :: ResponseItem ( items[ 2 ] . clone( ) ) ,
299
+ ] ;
300
+
301
+ assert_eq ! (
302
+ serde_json:: to_value( & got_items) . unwrap( ) ,
303
+ serde_json:: to_value( & expected) . unwrap( )
304
+ ) ;
305
+ }
275
306
}
0 commit comments