Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
90d93c6
make core types gob-serializable
yarolegovich Sep 5, 2025
a70892e
in-memory task store
yarolegovich Sep 5, 2025
043ba2e
task update logic
yarolegovich Sep 5, 2025
80d3853
concurrent task execution and cancelation management
yarolegovich Sep 5, 2025
14da9f8
taskexec integration with default request handler
yarolegovich Sep 5, 2025
0df268e
prevent possibility of failed cancelation destroying the queue which …
yarolegovich Sep 8, 2025
67290a0
comments and t.Helper() calls
yarolegovich Sep 12, 2025
efd6922
Merge branch 'main' into yarolegovich/result-aggregation-3
yarolegovich Sep 16, 2025
ef1170a
lint
yarolegovich Sep 16, 2025
a1971d9
Merge branch 'main' into yarolegovich/result-aggregation-3
yarolegovich Sep 22, 2025
601aeda
Merge branch 'yarolegovich/result-aggregation-4' into yarolegovich/re…
yarolegovich Sep 22, 2025
789dfb2
artifact update logic
yarolegovich Sep 22, 2025
9e28fb2
OnSendMessageStream() and tests
yarolegovich Sep 22, 2025
2a84c8b
test and fix
yarolegovich Sep 22, 2025
2fa71f4
PR review improvements
yarolegovich Sep 22, 2025
eeba7ff
Merge branch 'main' into yarolegovich/result-aggregation-4
yarolegovich Oct 2, 2025
c5c9812
fix blocking on nil channel and empty yield(nil, nil) in defer
yarolegovich Oct 2, 2025
c7251a7
lint
yarolegovich Oct 2, 2025
d7fdefe
Merge branch 'yarolegovich/result-aggregation-4' into yarolegovich/ar…
yarolegovich Oct 3, 2025
8dbd52b
Merge branch 'main' into yarolegovich/artifacts
yarolegovich Oct 8, 2025
12bb659
refactor messages
yarolegovich Oct 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions a2asrv/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,32 +142,9 @@ func (h *defaultRequestHandler) OnCancelTask(ctx context.Context, params *a2a.Ta
}

func (h *defaultRequestHandler) OnSendMessage(ctx context.Context, params *a2a.MessageSendParams) (a2a.SendMessageResult, error) {
if params.Message == nil {
return nil, fmt.Errorf("message is required: %w", a2a.ErrInvalidRequest)
}

var task *a2a.Task
if len(params.Message.TaskID) == 0 {
task = taskupdate.NewSubmittedTask(params.Message)
} else {
localResult, err := h.taskStore.Get(ctx, params.Message.TaskID)
if err != nil {
return nil, err
}
task = localResult
}

// TODO(yarolegovich): move to task-locked section in executor https://github.com/a2aproject/a2a-go/issues/18
reqCtx := RequestContext{Request: params, TaskID: task.ID, ContextID: task.ContextID}
processor := &processor{updateManager: taskupdate.NewManager(h.taskStore, task)}
executor := &executor{
agent: h.agentExecutor,
reqCtx: reqCtx,
processor: processor,
}
execution, err := h.taskExecutor.Execute(ctx, task.ID, executor)
execution, err := h.handleSendMessage(ctx, params)
if err != nil {
return nil, fmt.Errorf("failed to execute: %w", err)
return nil, err
}

for event, err := range execution.Events(ctx) {
Expand All @@ -182,6 +159,17 @@ func (h *defaultRequestHandler) OnSendMessage(ctx context.Context, params *a2a.M
return execution.Result(ctx)
}

func (h *defaultRequestHandler) OnSendMessageStream(ctx context.Context, params *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] {
execution, err := h.handleSendMessage(ctx, params)
if err != nil {
return func(yield func(a2a.Event, error) bool) {
yield(nil, err)
}
}

return execution.Events(ctx)
}

func (h *defaultRequestHandler) OnResubscribeToTask(ctx context.Context, params *a2a.TaskIDParams) iter.Seq2[a2a.Event, error] {
exec, ok := h.taskExecutor.GetExecution(params.ID)
if !ok {
Expand All @@ -192,8 +180,31 @@ func (h *defaultRequestHandler) OnResubscribeToTask(ctx context.Context, params
return exec.Events(ctx)
}

func (h *defaultRequestHandler) OnSendMessageStream(ctx context.Context, message *a2a.MessageSendParams) iter.Seq2[a2a.Event, error] {
return nil
func (h *defaultRequestHandler) handleSendMessage(ctx context.Context, params *a2a.MessageSendParams) (*taskexec.Execution, error) {
if params.Message == nil {
return nil, fmt.Errorf("message is required: %w", a2a.ErrInvalidRequest)
}

var task *a2a.Task
if len(params.Message.TaskID) == 0 {
task = taskupdate.NewSubmittedTask(params.Message)
} else {
localResult, err := h.taskStore.Get(ctx, params.Message.TaskID)
if err != nil {
return nil, err
}
task = localResult
}

// TODO(yarolegovich): move to task-locked section in executor https://github.com/a2aproject/a2a-go/issues/18
reqCtx := RequestContext{Request: params, TaskID: task.ID, ContextID: task.ContextID}
processor := &processor{updateManager: taskupdate.NewManager(h.taskStore, task)}
executor := &executor{
agent: h.agentExecutor,
reqCtx: reqCtx,
processor: processor,
}
return h.taskExecutor.Execute(ctx, task.ID, executor)
}

func (h *defaultRequestHandler) OnGetTaskPushConfig(ctx context.Context, params *a2a.GetTaskPushConfigParams) (*a2a.TaskPushConfig, error) {
Expand Down
115 changes: 101 additions & 14 deletions a2asrv/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"time"

"github.com/a2aproject/a2a-go/a2a"
"github.com/a2aproject/a2a-go/a2asrv/eventqueue"
"github.com/a2aproject/a2a-go/internal/taskstore"
"github.com/google/go-cmp/cmp"
)

var fixedTime = time.Now()
Expand Down Expand Up @@ -158,7 +158,14 @@ func newTaskWithMeta(task *a2a.Task, meta map[string]any) *a2a.Task {
return &a2a.Task{ID: task.ID, ContextID: task.ContextID, Metadata: meta}
}

func newArtifactEvent(task *a2a.Task, aid a2a.ArtifactID, parts ...a2a.Part) *a2a.TaskArtifactUpdateEvent {
ev := a2a.NewArtifactEvent(task, parts...)
ev.Artifact.ID = aid
return ev
}

func TestDefaultRequestHandler_OnSendMessage(t *testing.T) {
artifactID := a2a.NewArtifactID()
taskSeed := &a2a.Task{ID: a2a.NewTaskID(), ContextID: a2a.NewContextID()}

tests := []struct {
Expand Down Expand Up @@ -254,6 +261,43 @@ func TestDefaultRequestHandler_OnSendMessage(t *testing.T) {
},
wantResult: newTaskWithStatus(taskSeed, a2a.TaskStateCompleted, "no status change history"),
},
{
name: "task artifact streaming",
agentEvents: []a2a.Event{
newTaskStatusUpdate(taskSeed, a2a.TaskStateSubmitted, "Ack"),
newArtifactEvent(taskSeed, artifactID, a2a.TextPart{Text: "Hello"}),
a2a.NewArtifactUpdateEvent(taskSeed, artifactID, a2a.TextPart{Text: ", world!"}),
newFinalTaskStatusUpdate(taskSeed, a2a.TaskStateCompleted, "Done!"),
},
wantResult: &a2a.Task{
ID: taskSeed.ID,
ContextID: taskSeed.ContextID,
Status: a2a.TaskStatus{State: a2a.TaskStateCompleted, Message: newAgentMessage("Done!"), Timestamp: &fixedTime},
History: []*a2a.Message{newAgentMessage("Ack")},
Artifacts: []*a2a.Artifact{
{ID: artifactID, Parts: a2a.ContentParts{a2a.TextPart{Text: "Hello"}, a2a.TextPart{Text: ", world!"}}},
},
},
},
{
name: "task with multiple artifacts",
agentEvents: []a2a.Event{
newTaskStatusUpdate(taskSeed, a2a.TaskStateSubmitted, "Ack"),
newArtifactEvent(taskSeed, artifactID, a2a.TextPart{Text: "Hello"}),
newArtifactEvent(taskSeed, artifactID+"2", a2a.TextPart{Text: "World"}),
newFinalTaskStatusUpdate(taskSeed, a2a.TaskStateCompleted, "Done!"),
},
wantResult: &a2a.Task{
ID: taskSeed.ID,
ContextID: taskSeed.ContextID,
Status: a2a.TaskStatus{State: a2a.TaskStateCompleted, Message: newAgentMessage("Done!"), Timestamp: &fixedTime},
History: []*a2a.Message{newAgentMessage("Ack")},
Artifacts: []*a2a.Artifact{
{ID: artifactID, Parts: a2a.ContentParts{a2a.TextPart{Text: "Hello"}}},
{ID: artifactID + "2", Parts: a2a.ContentParts{a2a.TextPart{Text: "World"}}},
},
},
},
{
name: "fails on non-existent task reference",
input: &a2a.MessageSendParams{Message: &a2a.Message{TaskID: "non-existent", ID: "test-message"}},
Expand All @@ -266,6 +310,11 @@ func TestDefaultRequestHandler_OnSendMessage(t *testing.T) {
}

for _, tt := range tests {
input := &a2a.MessageSendParams{Message: &a2a.Message{TaskID: taskSeed.ID}}
if tt.input != nil {
input = tt.input
}

t.Run(tt.name, func(t *testing.T) {
ctx := t.Context()
var qm eventqueue.Manager
Expand All @@ -276,22 +325,15 @@ func TestDefaultRequestHandler_OnSendMessage(t *testing.T) {
}
store := taskstore.NewMem()
_ = store.Save(ctx, taskSeed)
handler := newTestHandler(
WithEventQueueManager(qm),
WithTaskStore(store),
)
input := &a2a.MessageSendParams{Message: &a2a.Message{TaskID: taskSeed.ID}}
if tt.input != nil {
input = tt.input
}
handler := newTestHandler(WithEventQueueManager(qm), WithTaskStore(store))

result, gotErr := handler.OnSendMessage(ctx, input)
if tt.wantErr == nil {
if gotErr != nil {
t.Fatalf("OnSendMessage() error = %v, wantErr nil", gotErr)
}
if !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("OnSendMessage() got = %v, want %v", result, tt.wantResult)
if diff := cmp.Diff(tt.wantResult, result); diff != "" {
t.Errorf("OnSendMessage() (+got,-want):\ngot = %v\nwant %v\ndiff = %s", result, tt.wantResult, diff)
}
} else {
if gotErr == nil {
Expand All @@ -302,6 +344,54 @@ func TestDefaultRequestHandler_OnSendMessage(t *testing.T) {
}
}
})

t.Run(tt.name+" (streaming)", func(t *testing.T) {
ctx := t.Context()
var qm eventqueue.Manager
if tt.agentEvents == nil {
qm = newEventReplayQueueManager()
} else {
qm = newEventReplayQueueManager(tt.agentEvents...)
}
store := taskstore.NewMem()
_ = store.Save(ctx, taskSeed)
handler := newTestHandler(WithEventQueueManager(qm), WithTaskStore(store))

eventI := 0
var streamErr bool
for got, gotErr := range handler.OnSendMessageStream(ctx, input) {
var want a2a.Event
if eventI < len(tt.agentEvents) {
Copy link
Collaborator Author

@yarolegovich yarolegovich Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these assertions are refactored in another PR

want = tt.agentEvents[eventI]
eventI += 1
} else if streamErr {
t.Errorf("expected stream close after %v, got %v, %v", eventI, got, gotErr)
} else if tt.wantErr != nil {
streamErr = true
} else {
t.Errorf("expected error after %d-th event, got %v, %v", eventI, got, gotErr)
}

if streamErr {
if gotErr == nil {
t.Fatalf("OnSendMessageStream() error = nil, wantErr %q", tt.wantErr)
}
if gotErr.Error() != tt.wantErr.Error() {
t.Errorf("OnSendMessageStream() error = %v, wantErr %v", gotErr, tt.wantErr)
}
} else {
if gotErr != nil {
t.Fatalf("OnSendMessageStream() error = %v, wantErr nil", gotErr)
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("OnSendMessageStream() (+got,-want):\ngot = %v\nwant %v\ndiff = %s", got, want, diff)
}
}
}
if tt.wantErr == nil && eventI != len(tt.agentEvents) {
t.Errorf("OnSendMessageStream() received %d events, want %d", eventI, len(tt.agentEvents))
}
})
}
}

Expand Down Expand Up @@ -346,9 +436,6 @@ func TestDefaultRequestHandler_Unimplemented(t *testing.T) {
if _, err := handler.OnGetTask(ctx, &a2a.TaskQueryParams{}); !errors.Is(err, ErrUnimplemented) {
t.Errorf("OnGetTask: expected unimplemented error, got %v", err)
}
if seq := handler.OnSendMessageStream(ctx, &a2a.MessageSendParams{}); seq != nil {
t.Error("OnSendMessageStream: expected nil iterator, got non-nil")
}
if _, err := handler.OnGetTaskPushConfig(ctx, &a2a.GetTaskPushConfigParams{}); !errors.Is(err, ErrUnimplemented) {
t.Errorf("OnGetTaskPushConfig: expected unimplemented error, got %v", err)
}
Expand Down
44 changes: 42 additions & 2 deletions internal/taskupdate/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package taskupdate
import (
"context"
"fmt"
"slices"

"github.com/a2aproject/a2a-go/a2a"
)
Expand Down Expand Up @@ -75,8 +76,47 @@ func (mgr *Manager) Process(ctx context.Context, event a2a.Event) (*a2a.Task, er
}
}

func (mgr *Manager) updateArtifact(_ context.Context, _ *a2a.TaskArtifactUpdateEvent) (*a2a.Task, error) {
return nil, fmt.Errorf("not implemented")
func (mgr *Manager) updateArtifact(ctx context.Context, event *a2a.TaskArtifactUpdateEvent) (*a2a.Task, error) {
task := mgr.task

updateIdx := slices.IndexFunc(task.Artifacts, func(a *a2a.Artifact) bool {
return a.ID == event.Artifact.ID
})

if updateIdx < 0 {
if event.Append {
// TODO(yarolegovich): log "artifact for update not found" as Python does
return task, nil
}
task.Artifacts = append(task.Artifacts, event.Artifact)
if err := mgr.saver.Save(ctx, task); err != nil {
return nil, err
}
return task, nil
}

if !event.Append {
task.Artifacts[updateIdx] = event.Artifact
if err := mgr.saver.Save(ctx, task); err != nil {
return nil, err
}
return task, nil
}

toUpdate := task.Artifacts[updateIdx]
toUpdate.Parts = append(toUpdate.Parts, event.Artifact.Parts...)
if toUpdate.Metadata == nil && event.Artifact.Metadata != nil {
toUpdate.Metadata = event.Artifact.Metadata
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing handling of toUpdate.Metadata == nil case.

for k, v := range event.Artifact.Metadata {
toUpdate.Metadata[k] = v
}
}

if err := mgr.saver.Save(ctx, task); err != nil {
return nil, err
}
return task, nil
}

func (mgr *Manager) updateStatus(ctx context.Context, event *a2a.TaskStatusUpdateEvent) (*a2a.Task, error) {
Expand Down
Loading