From 60f3996de022dc5c3f6478c770648d2e26b349f0 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 10:02:18 -0400 Subject: [PATCH 01/17] initial changes --- go.mod | 1 + go.sum | 2 ++ iter/iter.go | 39 ++++++++++++++++++++++++++++++++------- iter/map.go | 29 +++++++++++++++++++++++------ 4 files changed, 58 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index e06798f..df2c3c7 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.19 require ( github.com/stretchr/testify v1.8.1 go.uber.org/multierr v1.9.0 + golang.org/x/sync v0.2.0 ) require ( diff --git a/go.sum b/go.sum index 2eaf607..8ee06ca 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/iter/iter.go b/iter/iter.go index 124b4f9..0fa1a5c 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -1,10 +1,11 @@ package iter import ( + "context" "runtime" "sync/atomic" - "github.com/sourcegraph/conc" + "golang.org/x/sync/errgroup" ) // defaultMaxGoroutines returns the default maximum number of @@ -35,6 +36,10 @@ type Iterator[T any] struct { // a configurable goroutine limit, use a custom Iterator. func ForEach[T any](input []T, f func(*T)) { Iterator[T]{}.ForEach(input, f) } +func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error { + return Iterator[T]{}.ForEachCtx(ctx, input, f) +} + // ForEach executes f in parallel over each element in input, // using up to the Iterator's configured maximum number of // goroutines. @@ -50,13 +55,30 @@ func (iter Iterator[T]) ForEach(input []T, f func(*T)) { }) } +func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { + return iter.ForEachIdxCtx(ctx, input, func(_ context.Context, _ int, input *T) error { + return f(ctx, input) + }) +} + // ForEachIdx is the same as ForEach except it also provides the // index of the element to the callback. func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(input, f) } +func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error { + return Iterator[T]{}.ForEachIdxCtx(ctx, input, f) +} + // ForEachIdx is the same as ForEach except it also provides the // index of the element to the callback. func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { + _ = iter.ForEachIdxCtx(context.Background(), input, func(_ context.Context, idx int, input *T) error { + f(idx, input) + return nil + }) +} + +func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate iter.MaxGoroutines = defaultMaxGoroutines() @@ -68,18 +90,21 @@ func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { iter.MaxGoroutines = numInput } + eg, ectx := errgroup.WithContext(ctx) var idx atomic.Int64 // Create the task outside the loop to avoid extra closure allocations. - task := func() { + task := func() error { i := int(idx.Add(1) - 1) - for ; i < numInput; i = int(idx.Add(1) - 1) { - f(i, &input[i]) + for ; i < numInput && ectx.Err() == nil; i = int(idx.Add(1) - 1) { + if err := f(ectx, i, &input[i]); err != nil { + return err + } } + return nil } - var wg conc.WaitGroup for i := 0; i < iter.MaxGoroutines; i++ { - wg.Go(task) + eg.Go(task) } - wg.Wait() + return eg.Wait() } diff --git a/iter/map.go b/iter/map.go index efbe6bf..c874a7a 100644 --- a/iter/map.go +++ b/iter/map.go @@ -1,6 +1,7 @@ package iter import ( + "context" "sync" "github.com/sourcegraph/conc/internal/multierror" @@ -25,9 +26,8 @@ func Map[T, R any](input []T, f func(*T) R) []R { // // Map uses up to the configured Mapper's maximum number of goroutines. func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R { - res := make([]R, len(input)) - Iterator[T](m).ForEachIdx(input, func(i int, t *T) { - res[i] = f(t) + res, _ := m.MapErr(input, func(t *T) (R, error) { + return f(t), nil }) return res } @@ -41,25 +41,42 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) { return Mapper[T, R]{}.MapErr(input, f) } +func MapErrCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + return Mapper[T, R]{}.MapErrCtx(ctx, input, f) +} + // MapErr applies f to each element of the input, returning the mapped result // and a combined error of all returned errors. // // Map uses up to the configured Mapper's maximum number of goroutines. func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { var ( - res = make([]R, len(input)) errMux sync.Mutex errs error ) - Iterator[T](m).ForEachIdx(input, func(i int, t *T) { + // MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapErrCtx + res, _ := m.MapErrCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) { var err error - res[i], err = f(t) + var ires R + ires, err = f(t) if err != nil { errMux.Lock() // TODO: use stdlib errors once multierrors land in go 1.20 errs = multierror.Join(errs, err) errMux.Unlock() } + return ires, nil }) return res, errs } + +func (m Mapper[T, R]) MapErrCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + var ( + res = make([]R, len(input)) + ) + return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(ctx context.Context, i int, t *T) error { + var err error + res[i], err = f(ctx, t) + return err + }) +} From a22d62e24245de7a93167336daefe764fb6ec908 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 12:22:22 -0400 Subject: [PATCH 02/17] using contextpool over errgroup --- go.mod | 1 - go.sum | 2 -- iter/iter.go | 22 ++++++++++------------ 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index df2c3c7..e06798f 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.19 require ( github.com/stretchr/testify v1.8.1 go.uber.org/multierr v1.9.0 - golang.org/x/sync v0.2.0 ) require ( diff --git a/go.sum b/go.sum index 8ee06ca..2eaf607 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= -golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= -golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/iter/iter.go b/iter/iter.go index 0fa1a5c..8b39038 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -5,7 +5,7 @@ import ( "runtime" "sync/atomic" - "golang.org/x/sync/errgroup" + "github.com/sourcegraph/conc/pool" ) // defaultMaxGoroutines returns the default maximum number of @@ -85,26 +85,24 @@ func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(con } numInput := len(input) - if iter.MaxGoroutines > numInput { - // No more concurrent tasks than the number of input items. - iter.MaxGoroutines = numInput - } - - eg, ectx := errgroup.WithContext(ctx) var idx atomic.Int64 // Create the task outside the loop to avoid extra closure allocations. - task := func() error { + task := func(ctx context.Context) error { i := int(idx.Add(1) - 1) - for ; i < numInput && ectx.Err() == nil; i = int(idx.Add(1) - 1) { - if err := f(ectx, i, &input[i]); err != nil { + for ; i < numInput && ctx.Err() == nil; i = int(idx.Add(1) - 1) { + if err := f(ctx, i, &input[i]); err != nil { return err } } return nil } + runner := pool.New().WithContext(ctx). + WithCancelOnError(). + WithFirstError(). + WithMaxGoroutines(iter.MaxGoroutines) for i := 0; i < iter.MaxGoroutines; i++ { - eg.Go(task) + runner.Go(task) } - return eg.Wait() + return runner.Wait() } From 57ae9ca743ab29f61b9c212cb0e5263ac62133f5 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 12:27:54 -0400 Subject: [PATCH 03/17] checking input size --- iter/iter.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/iter/iter.go b/iter/iter.go index 8b39038..9132ce3 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -85,6 +85,11 @@ func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(con } numInput := len(input) + if iter.MaxGoroutines > numInput && numInput > 0 { + // No more concurrent tasks than the number of input items. + iter.MaxGoroutines = numInput + } + var idx atomic.Int64 // Create the task outside the loop to avoid extra closure allocations. task := func(ctx context.Context) error { From 722e19d9b23962ac7e994bcb7c396d27bb312159 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 12:30:24 -0400 Subject: [PATCH 04/17] removing declarations --- iter/map.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/iter/map.go b/iter/map.go index c874a7a..4106e57 100644 --- a/iter/map.go +++ b/iter/map.go @@ -56,9 +56,7 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { ) // MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapErrCtx res, _ := m.MapErrCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) { - var err error - var ires R - ires, err = f(t) + ires, err := f(t) if err != nil { errMux.Lock() // TODO: use stdlib errors once multierrors land in go 1.20 From e9a88bf6add3619847f2d540612ad52be0330954 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 12:32:03 -0400 Subject: [PATCH 05/17] Moving function --- iter/map.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/iter/map.go b/iter/map.go index 4106e57..8c3b42c 100644 --- a/iter/map.go +++ b/iter/map.go @@ -41,10 +41,6 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) { return Mapper[T, R]{}.MapErr(input, f) } -func MapErrCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { - return Mapper[T, R]{}.MapErrCtx(ctx, input, f) -} - // MapErr applies f to each element of the input, returning the mapped result // and a combined error of all returned errors. // @@ -68,6 +64,10 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { return res, errs } +func MapErrCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + return Mapper[T, R]{}.MapErrCtx(ctx, input, f) +} + func (m Mapper[T, R]) MapErrCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) From c6581d390afe2c7a2e89dbb5846303a62f0d4e8c Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 12:34:24 -0400 Subject: [PATCH 06/17] moving new functions down --- iter/iter.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index 9132ce3..af3c099 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -36,10 +36,6 @@ type Iterator[T any] struct { // a configurable goroutine limit, use a custom Iterator. func ForEach[T any](input []T, f func(*T)) { Iterator[T]{}.ForEach(input, f) } -func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error { - return Iterator[T]{}.ForEachCtx(ctx, input, f) -} - // ForEach executes f in parallel over each element in input, // using up to the Iterator's configured maximum number of // goroutines. @@ -55,20 +51,10 @@ func (iter Iterator[T]) ForEach(input []T, f func(*T)) { }) } -func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { - return iter.ForEachIdxCtx(ctx, input, func(_ context.Context, _ int, input *T) error { - return f(ctx, input) - }) -} - // ForEachIdx is the same as ForEach except it also provides the // index of the element to the callback. func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(input, f) } -func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error { - return Iterator[T]{}.ForEachIdxCtx(ctx, input, f) -} - // ForEachIdx is the same as ForEach except it also provides the // index of the element to the callback. func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { @@ -78,6 +64,20 @@ func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { }) } +func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error { + return Iterator[T]{}.ForEachIdxCtx(ctx, input, f) +} + +func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error { + return Iterator[T]{}.ForEachCtx(ctx, input, f) +} + +func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { + return iter.ForEachIdxCtx(ctx, input, func(_ context.Context, _ int, input *T) error { + return f(ctx, input) + }) +} + func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate From 4a15de50a0e38b73284c6e2d28ad28035985f730 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 14:09:14 -0400 Subject: [PATCH 07/17] renaming outer context --- iter/iter.go | 16 ++++++++-------- iter/map.go | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index af3c099..73ac2d7 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -64,21 +64,21 @@ func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { }) } -func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error { - return Iterator[T]{}.ForEachIdxCtx(ctx, input, f) +func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error { + return Iterator[T]{}.ForEachIdxCtx(octx, input, f) } -func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error { - return Iterator[T]{}.ForEachCtx(ctx, input, f) +func ForEachCtx[T any](octx context.Context, input []T, f func(context.Context, *T) error) error { + return Iterator[T]{}.ForEachCtx(octx, input, f) } -func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { - return iter.ForEachIdxCtx(ctx, input, func(_ context.Context, _ int, input *T) error { +func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(context.Context, *T) error) error { + return iter.ForEachIdxCtx(octx, input, func(ctx context.Context, _ int, input *T) error { return f(ctx, input) }) } -func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error { +func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate iter.MaxGoroutines = defaultMaxGoroutines() @@ -102,7 +102,7 @@ func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(con return nil } - runner := pool.New().WithContext(ctx). + runner := pool.New().WithContext(octx). WithCancelOnError(). WithFirstError(). WithMaxGoroutines(iter.MaxGoroutines) diff --git a/iter/map.go b/iter/map.go index 8c3b42c..0e56dab 100644 --- a/iter/map.go +++ b/iter/map.go @@ -64,15 +64,15 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { return res, errs } -func MapErrCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { - return Mapper[T, R]{}.MapErrCtx(ctx, input, f) +func MapErrCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + return Mapper[T, R]{}.MapErrCtx(octx, input, f) } -func (m Mapper[T, R]) MapErrCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { +func (m Mapper[T, R]) MapErrCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) ) - return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(ctx context.Context, i int, t *T) error { + return res, Iterator[T](m).ForEachIdxCtx(octx, input, func(ctx context.Context, i int, t *T) error { var err error res[i], err = f(ctx, t) return err From 3ae45b68d9d6eb38d5f0eb1a84182dafb5a1a09a Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 14:10:51 -0400 Subject: [PATCH 08/17] splitting init --- iter/iter.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/iter/iter.go b/iter/iter.go index 73ac2d7..a674939 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -102,7 +102,8 @@ func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(co return nil } - runner := pool.New().WithContext(octx). + runner := pool.New(). + WithContext(octx). WithCancelOnError(). WithFirstError(). WithMaxGoroutines(iter.MaxGoroutines) From 11cbc39ab9c46ae486cb8e4393e5abaab56cd5ef Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 14:49:07 -0400 Subject: [PATCH 09/17] added comments --- iter/iter.go | 22 +++++++++++++++++----- iter/map.go | 6 ++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index a674939..ba4c140 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -64,20 +64,32 @@ func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { }) } -func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error { - return Iterator[T]{}.ForEachIdxCtx(octx, input, f) -} - +// ForEachCtx is the same as ForEach except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned func ForEachCtx[T any](octx context.Context, input []T, f func(context.Context, *T) error) error { return Iterator[T]{}.ForEachCtx(octx, input, f) } +// ForEachCtx is the same as ForEach except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(context.Context, *T) error) error { return iter.ForEachIdxCtx(octx, input, func(ctx context.Context, _ int, input *T) error { return f(ctx, input) }) } +// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned +func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error { + return Iterator[T]{}.ForEachIdxCtx(octx, input, f) +} + +// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate @@ -99,7 +111,7 @@ func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(co return err } } - return nil + return ctx.Err() // nil if the context was never cancelled } runner := pool.New(). diff --git a/iter/map.go b/iter/map.go index 0e56dab..8e0a2a6 100644 --- a/iter/map.go +++ b/iter/map.go @@ -64,10 +64,16 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { return res, errs } +// MapErrCtx is the same as MapErr except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned func MapErrCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { return Mapper[T, R]{}.MapErrCtx(octx, input, f) } +// MapErrCtx is the same as MapErr except it also accepts a context +// that it uses to manages the execution of tasks. +// The context is cancelled on task failure and the first error is returned func (m Mapper[T, R]) MapErrCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) From 1d04c682487fb66a29a9e584ef0ecc3a3526708a Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 16:28:10 -0400 Subject: [PATCH 10/17] testing Map --- iter/map_test.go | 67 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/iter/map_test.go b/iter/map_test.go index 28be912..70de5cd 100644 --- a/iter/map_test.go +++ b/iter/map_test.go @@ -1,6 +1,7 @@ package iter import ( + "context" "errors" "fmt" "testing" @@ -82,11 +83,51 @@ func TestMap(t *testing.T) { func TestMapErr(t *testing.T) { t.Parallel() + err1 := errors.New("error1") + err2 := errors.New("error2") + + t.Run("error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + res, err := MapErr(ints, func(val *int) (int, error) { + if *val == 3 { + return 0, err1 + } + return *val + 1, nil + }) + require.ErrorIs(t, err, err1) + require.Equal(t, []int{2, 3, 0, 5, 6}, res) + require.Equal(t, []int{1, 2, 3, 4, 5}, ints) + }) + + t.Run("first errors are propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + res, err := MapErr(ints, func(val *int) (int, error) { + if *val == 3 { + return 0, err1 + } + if *val == 4 { + return 0, err2 + } + return *val + 1, nil + }) + require.ErrorIs(t, err, err1) + require.ErrorIs(t, err, err2) + require.Equal(t, []int{2, 3, 0, 0, 6}, res) + require.Equal(t, []int{1, 2, 3, 4, 5}, ints) + }) +} + +func TestMapErrCtx(t *testing.T) { + t.Parallel() + + bgctx := context.Background() t.Run("empty", func(t *testing.T) { t.Parallel() f := func() { ints := []int{} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { panic("this should never be called") }) require.NoError(t, err) @@ -99,7 +140,7 @@ func TestMapErr(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - _, _ = MapErr(ints, func(val *int) (int, error) { + _, _ = MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { panic("super bad thing happened") }) } @@ -109,7 +150,7 @@ func TestMapErr(t *testing.T) { t.Run("mutating inputs is fine, though not recommended", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { *val += 1 return 0, nil }) @@ -121,7 +162,7 @@ func TestMapErr(t *testing.T) { t.Run("basic increment", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { return *val + 1, nil }) require.NoError(t, err) @@ -130,26 +171,26 @@ func TestMapErr(t *testing.T) { }) err1 := errors.New("error1") - err2 := errors.New("error1") + err2 := errors.New("error2") t.Run("error is propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { if *val == 3 { return 0, err1 } return *val + 1, nil }) require.ErrorIs(t, err, err1) - require.Equal(t, []int{2, 3, 0, 5, 6}, res) + require.Equal(t, []int{2, 3, 0, 0, 0}, res) require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) - t.Run("multiple errors are propagated", func(t *testing.T) { + t.Run("first error is propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErr(ints, func(val *int) (int, error) { + res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { if *val == 3 { return 0, err1 } @@ -159,21 +200,21 @@ func TestMapErr(t *testing.T) { return *val + 1, nil }) require.ErrorIs(t, err, err1) - require.ErrorIs(t, err, err2) - require.Equal(t, []int{2, 3, 0, 0, 6}, res) + require.Equal(t, []int{2, 3, 0, 0, 0}, res) require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - res := Map(ints, func(val *int) int { - return 1 + res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + return 1, nil }) expected := make([]int, 10000) for i := 0; i < 10000; i++ { expected[i] = 1 } require.Equal(t, expected, res) + require.NoError(t, err) }) } From d71cc0582269668d4ec8485ffdfbb1e5733916aa Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Mon, 15 May 2023 16:48:53 -0400 Subject: [PATCH 11/17] tests --- iter/iter_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++----- iter/map_test.go | 2 ++ 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/iter/iter_test.go b/iter/iter_test.go index d65af2b..86b1126 100644 --- a/iter/iter_test.go +++ b/iter/iter_test.go @@ -1,6 +1,8 @@ package iter import ( + "context" + "errors" "fmt" "strconv" "sync/atomic" @@ -70,16 +72,18 @@ func TestIterator(t *testing.T) { }) } -func TestForEachIdx(t *testing.T) { +func TestForEachIdxCtx(t *testing.T) { t.Parallel() + bgctx := context.Background() t.Run("empty", func(t *testing.T) { t.Parallel() f := func() { ints := []int{} - ForEachIdx(ints, func(i int, val *int) { + err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { panic("this should never be called") }) + require.NoError(t, err) } require.NotPanics(t, f) }) @@ -88,9 +92,10 @@ func TestForEachIdx(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - ForEachIdx(ints, func(i int, val *int) { - panic("super bad thing happened") - }) + ForEachIdxCtx(bgctx, ints, + func(ctx context.Context, i int, val *int) error { + panic("super bad thing happened") + }) } require.Panics(t, f) }) @@ -98,23 +103,46 @@ func TestForEachIdx(t *testing.T) { t.Run("mutating inputs is fine", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - ForEachIdx(ints, func(i int, val *int) { + err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { *val += 1 + return nil }) require.Equal(t, []int{2, 3, 4, 5, 6}, ints) + require.NoError(t, err) }) t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - ForEachIdx(ints, func(i int, val *int) { + err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { *val = i + return nil }) expected := make([]int, 10000) for i := 0; i < 10000; i++ { expected[i] = i } require.Equal(t, expected, ints) + require.NoError(t, err) + }) + + err1 := errors.New("error1") + err2 := errors.New("error2") + + t.Run("first error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + err := ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { + if *val == 3 { + return err1 + } + if *val == 4 { + return err2 + } + return nil + }) + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) }) } @@ -166,6 +194,41 @@ func TestForEach(t *testing.T) { }) } +func TestForEachCtx(t *testing.T) { + t.Parallel() + + bgctx := context.Background() + t.Run("mutating inputs is fine", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + err := ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error { + *val += 1 + return nil + }) + require.Equal(t, []int{2, 3, 4, 5, 6}, ints) + require.NoError(t, err) + }) + + err1 := errors.New("error1") + err2 := errors.New("error2") + + t.Run("first error is propagated", func(t *testing.T) { + t.Parallel() + ints := []int{1, 2, 3, 4, 5} + err := ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error { + if *val == 3 { + return err1 + } + if *val == 4 { + return err2 + } + return nil + }) + require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) + }) +} + func BenchmarkForEach(b *testing.B) { for _, count := range []int{0, 1, 8, 100, 1000, 10000, 100000} { b.Run(strconv.Itoa(count), func(b *testing.B) { diff --git a/iter/map_test.go b/iter/map_test.go index 70de5cd..3cd3a0e 100644 --- a/iter/map_test.go +++ b/iter/map_test.go @@ -183,6 +183,7 @@ func TestMapErrCtx(t *testing.T) { return *val + 1, nil }) require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) require.Equal(t, []int{2, 3, 0, 0, 0}, res) require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) @@ -200,6 +201,7 @@ func TestMapErrCtx(t *testing.T) { return *val + 1, nil }) require.ErrorIs(t, err, err1) + require.NotErrorIs(t, err, err2) require.Equal(t, []int{2, 3, 0, 0, 0}, res) require.Equal(t, []int{1, 2, 3, 4, 5}, ints) }) From 2f570b7c92d58adb577bf6d3712836c2ae74893d Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Wed, 31 May 2023 12:06:56 -0400 Subject: [PATCH 12/17] Renaming MapErrCtx to MapCtx --- iter/map.go | 12 ++++++------ iter/map_test.go | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/iter/map.go b/iter/map.go index 8e0a2a6..6a675e9 100644 --- a/iter/map.go +++ b/iter/map.go @@ -51,7 +51,7 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { errs error ) // MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapErrCtx - res, _ := m.MapErrCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) { + res, _ := m.MapCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) { ires, err := f(t) if err != nil { errMux.Lock() @@ -64,17 +64,17 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { return res, errs } -// MapErrCtx is the same as MapErr except it also accepts a context +// MapCtx is the same as Map except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned -func MapErrCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { - return Mapper[T, R]{}.MapErrCtx(octx, input, f) +func MapCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + return Mapper[T, R]{}.MapCtx(octx, input, f) } -// MapErrCtx is the same as MapErr except it also accepts a context +// MapCtx is the same as Map except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned -func (m Mapper[T, R]) MapErrCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { +func (m Mapper[T, R]) MapCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) ) diff --git a/iter/map_test.go b/iter/map_test.go index 3cd3a0e..93a6657 100644 --- a/iter/map_test.go +++ b/iter/map_test.go @@ -119,7 +119,7 @@ func TestMapErr(t *testing.T) { }) } -func TestMapErrCtx(t *testing.T) { +func TestMapCtx(t *testing.T) { t.Parallel() bgctx := context.Background() @@ -127,7 +127,7 @@ func TestMapErrCtx(t *testing.T) { t.Parallel() f := func() { ints := []int{} - res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + res, err := MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { panic("this should never be called") }) require.NoError(t, err) @@ -140,7 +140,7 @@ func TestMapErrCtx(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - _, _ = MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + _, _ = MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { panic("super bad thing happened") }) } @@ -150,7 +150,7 @@ func TestMapErrCtx(t *testing.T) { t.Run("mutating inputs is fine, though not recommended", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + res, err := MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { *val += 1 return 0, nil }) @@ -162,7 +162,7 @@ func TestMapErrCtx(t *testing.T) { t.Run("basic increment", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + res, err := MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { return *val + 1, nil }) require.NoError(t, err) @@ -176,7 +176,7 @@ func TestMapErrCtx(t *testing.T) { t.Run("error is propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + res, err := MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { if *val == 3 { return 0, err1 } @@ -191,7 +191,7 @@ func TestMapErrCtx(t *testing.T) { t.Run("first error is propagated", func(t *testing.T) { t.Parallel() ints := []int{1, 2, 3, 4, 5} - res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + res, err := MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { if *val == 3 { return 0, err1 } @@ -209,7 +209,7 @@ func TestMapErrCtx(t *testing.T) { t.Run("huge inputs", func(t *testing.T) { t.Parallel() ints := make([]int, 10000) - res, err := MapErrCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { + res, err := MapCtx(bgctx, ints, func(ctx context.Context, val *int) (int, error) { return 1, nil }) expected := make([]int, 10000) From bba912709ca33d40b7d8ad5bcfc13897a75ada7c Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Wed, 31 May 2023 12:10:16 -0400 Subject: [PATCH 13/17] reimplemented Map with MapCtx --- iter/map.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iter/map.go b/iter/map.go index 6a675e9..31beee5 100644 --- a/iter/map.go +++ b/iter/map.go @@ -26,7 +26,7 @@ func Map[T, R any](input []T, f func(*T) R) []R { // // Map uses up to the configured Mapper's maximum number of goroutines. func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R { - res, _ := m.MapErr(input, func(t *T) (R, error) { + res, _ := m.MapCtx(context.Background(), input, func(_ context.Context, t *T) (R, error) { return f(t), nil }) return res From 825d1dfa1caf699891cf7b7fc83bf92a8e1ff7bd Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Wed, 31 May 2023 12:12:53 -0400 Subject: [PATCH 14/17] Fixing comment --- iter/map.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iter/map.go b/iter/map.go index 31beee5..39c1c18 100644 --- a/iter/map.go +++ b/iter/map.go @@ -50,7 +50,7 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { errMux sync.Mutex errs error ) - // MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapErrCtx + // MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapCtx which is only the first error res, _ := m.MapCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) { ires, err := f(t) if err != nil { From 296e00b97c2916d40b4b29cb851d86dc7b240c1f Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Thu, 1 Jun 2023 10:51:58 -0400 Subject: [PATCH 15/17] Fixing lint issues --- iter/iter.go | 8 ++++---- iter/iter_test.go | 2 +- iter/map.go | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index ba4c140..e6dd4b0 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -66,14 +66,14 @@ func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { // ForEachCtx is the same as ForEach except it also accepts a context // that it uses to manages the execution of tasks. -// The context is cancelled on task failure and the first error is returned +// The context is cancelled on task failure and the first error is returned. func ForEachCtx[T any](octx context.Context, input []T, f func(context.Context, *T) error) error { return Iterator[T]{}.ForEachCtx(octx, input, f) } // ForEachCtx is the same as ForEach except it also accepts a context // that it uses to manages the execution of tasks. -// The context is cancelled on task failure and the first error is returned +// The context is cancelled on task failure and the first error is returned. func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(context.Context, *T) error) error { return iter.ForEachIdxCtx(octx, input, func(ctx context.Context, _ int, input *T) error { return f(ctx, input) @@ -82,14 +82,14 @@ func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(conte // ForEachIdxCtx is the same as ForEachIdx except it also accepts a context // that it uses to manages the execution of tasks. -// The context is cancelled on task failure and the first error is returned +// The context is cancelled on task failure and the first error is returned. func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error { return Iterator[T]{}.ForEachIdxCtx(octx, input, f) } // ForEachIdxCtx is the same as ForEachIdx except it also accepts a context // that it uses to manages the execution of tasks. -// The context is cancelled on task failure and the first error is returned +// The context is cancelled on task failure and the first error is returned. func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate diff --git a/iter/iter_test.go b/iter/iter_test.go index 86b1126..962e00b 100644 --- a/iter/iter_test.go +++ b/iter/iter_test.go @@ -92,7 +92,7 @@ func TestForEachIdxCtx(t *testing.T) { t.Parallel() f := func() { ints := []int{1} - ForEachIdxCtx(bgctx, ints, + _ = ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error { panic("super bad thing happened") }) diff --git a/iter/map.go b/iter/map.go index 39c1c18..b99e7ac 100644 --- a/iter/map.go +++ b/iter/map.go @@ -66,14 +66,14 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { // MapCtx is the same as Map except it also accepts a context // that it uses to manages the execution of tasks. -// The context is cancelled on task failure and the first error is returned +// The context is cancelled on task failure and the first error is returned. func MapCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { return Mapper[T, R]{}.MapCtx(octx, input, f) } // MapCtx is the same as Map except it also accepts a context // that it uses to manages the execution of tasks. -// The context is cancelled on task failure and the first error is returned +// The context is cancelled on task failure and the first error is returned. func (m Mapper[T, R]) MapCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) From 9018777a50c77300893524ba23098e98d11ba471 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Thu, 1 Jun 2023 15:35:29 -0400 Subject: [PATCH 16/17] removed octx --- iter/iter.go | 18 +++++++++--------- iter/map.go | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index e6dd4b0..5e73011 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -67,30 +67,30 @@ func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) { // ForEachCtx is the same as ForEach except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. -func ForEachCtx[T any](octx context.Context, input []T, f func(context.Context, *T) error) error { - return Iterator[T]{}.ForEachCtx(octx, input, f) +func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error { + return Iterator[T]{}.ForEachCtx(ctx, input, f) } // ForEachCtx is the same as ForEach except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. -func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(context.Context, *T) error) error { - return iter.ForEachIdxCtx(octx, input, func(ctx context.Context, _ int, input *T) error { - return f(ctx, input) +func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { + return iter.ForEachIdxCtx(ctx, input, func(ictx context.Context, _ int, input *T) error { + return f(ictx, input) }) } // ForEachIdxCtx is the same as ForEachIdx except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. -func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error { - return Iterator[T]{}.ForEachIdxCtx(octx, input, f) +func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error { + return Iterator[T]{}.ForEachIdxCtx(ctx, input, f) } // ForEachIdxCtx is the same as ForEachIdx except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. -func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(context.Context, int, *T) error) error { +func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error { if iter.MaxGoroutines == 0 { // iter is a value receiver and is hence safe to mutate iter.MaxGoroutines = defaultMaxGoroutines() @@ -115,7 +115,7 @@ func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(co } runner := pool.New(). - WithContext(octx). + WithContext(ctx). WithCancelOnError(). WithFirstError(). WithMaxGoroutines(iter.MaxGoroutines) diff --git a/iter/map.go b/iter/map.go index b99e7ac..6302d1e 100644 --- a/iter/map.go +++ b/iter/map.go @@ -67,20 +67,20 @@ func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) { // MapCtx is the same as Map except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. -func MapCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { - return Mapper[T, R]{}.MapCtx(octx, input, f) +func MapCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { + return Mapper[T, R]{}.MapCtx(ctx, input, f) } // MapCtx is the same as Map except it also accepts a context // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. -func (m Mapper[T, R]) MapCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { +func (m Mapper[T, R]) MapCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) { var ( res = make([]R, len(input)) ) - return res, Iterator[T](m).ForEachIdxCtx(octx, input, func(ctx context.Context, i int, t *T) error { + return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(ictx context.Context, i int, t *T) error { var err error - res[i], err = f(ctx, t) + res[i], err = f(ictx, t) return err }) } From 08c699dcaa7cf1525cae3c29cffe420050b3fa93 Mon Sep 17 00:00:00 2001 From: Ron Koehler Date: Thu, 1 Jun 2023 15:39:31 -0400 Subject: [PATCH 17/17] fixing naming --- iter/iter.go | 12 ++++++------ iter/map.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/iter/iter.go b/iter/iter.go index 5e73011..4f22e43 100644 --- a/iter/iter.go +++ b/iter/iter.go @@ -75,8 +75,8 @@ func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, * // that it uses to manages the execution of tasks. // The context is cancelled on task failure and the first error is returned. func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error { - return iter.ForEachIdxCtx(ctx, input, func(ictx context.Context, _ int, input *T) error { - return f(ictx, input) + return iter.ForEachIdxCtx(ctx, input, func(innerctx context.Context, _ int, input *T) error { + return f(innerctx, input) }) } @@ -104,14 +104,14 @@ func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(con var idx atomic.Int64 // Create the task outside the loop to avoid extra closure allocations. - task := func(ctx context.Context) error { + task := func(innerctx context.Context) error { i := int(idx.Add(1) - 1) - for ; i < numInput && ctx.Err() == nil; i = int(idx.Add(1) - 1) { - if err := f(ctx, i, &input[i]); err != nil { + for ; i < numInput && innerctx.Err() == nil; i = int(idx.Add(1) - 1) { + if err := f(innerctx, i, &input[i]); err != nil { return err } } - return ctx.Err() // nil if the context was never cancelled + return innerctx.Err() // nil if the context was never cancelled } runner := pool.New(). diff --git a/iter/map.go b/iter/map.go index 6302d1e..542ab50 100644 --- a/iter/map.go +++ b/iter/map.go @@ -78,9 +78,9 @@ func (m Mapper[T, R]) MapCtx(ctx context.Context, input []T, f func(context.Cont var ( res = make([]R, len(input)) ) - return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(ictx context.Context, i int, t *T) error { + return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(innerctx context.Context, i int, t *T) error { var err error - res[i], err = f(ictx, t) + res[i], err = f(innerctx, t) return err }) }