diff --git a/ARROW.md b/ARROW.md new file mode 100644 index 0000000..4c96b00 --- /dev/null +++ b/ARROW.md @@ -0,0 +1,326 @@ +# Arrow Columnar Data Processing + +This document explains the Apache Arrow implementation introduced for efficient batch aggregation of oracle observations in the LLO (Low-Latency Oracle) system. + +## Overview + +### Why Arrow? + +The Arrow implementation addresses key performance challenges: + +1. **Memory Efficiency** - Replaces the 1GB static memory ballast with controlled, bounded allocation +2. **Batch Processing** - Enables efficient columnar operations on thousands of observations simultaneously +3. **Reduced Allocations** - Builder pooling minimizes GC pressure during repeated aggregation cycles + +### Dependencies + +```go +github.com/apache/arrow-go/v18 v18.3.1 +``` + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Arrow Data Pipeline │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ │ +│ Node Observations │ +│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ +│ │ Observer │ │ Observer │ │ Observer │ ... │ +│ │ 0 │ │ 1 │ │ N │ │ +│ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ +│ │ │ │ │ +│ └────────────┼────────────┘ │ +│ ▼ │ +│ ┌────────────────────────┐ │ +│ │ ArrowObservationMerger │ │ +│ │ MergeObservations() │ │ +│ └───────────┬────────────┘ │ +│ ▼ │ +│ ┌────────────────────────┐ │ +│ │ Arrow Record │ (ObservationSchema) │ +│ │ [observer_id, stream_id, value_type, values...] │ +│ └───────────┬────────────┘ │ +│ ▼ │ +│ ┌────────────────────────┐ │ +│ │ ArrowAggregator │ │ +│ │ AggregateObservations()│ │ +│ └───────────┬────────────┘ │ +│ ▼ │ +│ ┌────────────────────────┐ │ +│ │ StreamAggregates │ (per-stream aggregated values) │ +│ └────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +## Schemas + +Four Arrow schemas are defined in `llo/arrow_schemas.go`: + +### 1. 
ObservationSchema + +Stores merged observations from all nodes for batch aggregation. + +| Column | Type | Description | +|--------|------|-------------| +| `observer_id` | uint8 | Node that produced the observation (0-255) | +| `stream_id` | uint32 | Stream identifier | +| `value_type` | uint8 | Type discriminator (see Value Types) | +| `decimal_value` | binary | Encoded decimal value | +| `quote_bid` | binary | Quote bid component | +| `quote_benchmark` | binary | Quote benchmark component | +| `quote_ask` | binary | Quote ask component | +| `observed_at_ns` | uint64 | Provider timestamp (nanoseconds) | +| `timestamp_ns` | uint64 | Node observation timestamp | + +### 2. StreamAggregatesSchema + +Output from aggregation, input to report generation. + +| Column | Type | Description | +|--------|------|-------------| +| `stream_id` | uint32 | Stream identifier | +| `aggregator` | uint32 | Aggregator type used | +| `value_type` | uint8 | Type discriminator | +| `decimal_value` | binary | Aggregated decimal | +| `quote_*` | binary | Aggregated quote components | +| `observed_at_ns` | uint64 | Observation timestamp | + +### 3. CacheSchema + +Observation cache with TTL-based expiration. + +| Column | Type | Description | +|--------|------|-------------| +| `stream_id` | uint32 | Stream identifier | +| `value_type` | uint8 | Type discriminator | +| `decimal_value` | binary | Cached decimal | +| `quote_*` | binary | Cached quote components | +| `observed_at_ns` | uint64 | Observation timestamp | +| `expires_at_ns` | int64 | TTL expiration timestamp | + +### 4. TransmissionSchema + +Batched report transmissions with Arrow IPC compression. 
+ +| Column | Type | Description | +|--------|------|-------------| +| `server_url` | string | Destination server | +| `config_digest` | fixed[32] | Configuration hash | +| `seq_nr` | uint64 | Sequence number | +| `report_data` | large_binary | Encoded report | +| `lifecycle_stage` | string | Report lifecycle stage | +| `report_format` | uint32 | Format identifier | +| `signatures` | list | Report signatures | +| `signers` | list | Signer indices | +| `transmission_hash` | fixed[32] | Transmission hash | +| `created_at_ns` | timestamp[ns] | Creation time | + +## Value Types + +Three value types are supported, identified by `value_type` column: + +```go +const ( + StreamValueTypeDecimal uint8 = 0 // Single decimal value + StreamValueTypeQuote uint8 = 1 // Quote with bid/benchmark/ask + StreamValueTypeTimestampd uint8 = 2 // Decimal with observation timestamp +) +``` + +### Decimal Encoding + +Values use `shopspring/decimal` binary encoding for precise representation: + +```go +// Encode +bytes, _ := decimal.MarshalBinary() + +// Decode +var d decimal.Decimal +d.UnmarshalBinary(bytes) +``` + +## Core Components + +### arrow_schemas.go + +Defines all four Arrow schemas and column index constants for type-safe access: + +```go +// Column indices for ObservationSchema +const ( + ObsColObserverID = iota + ObsColStreamID + ObsColValueType + // ... +) +``` + +### arrow_pool.go + +Memory management with two pool types: + +**LLOMemoryPool** - Wraps Arrow's allocator with metrics and optional bounds: + +```go +pool := NewLLOMemoryPool(maxBytes) // 0 for unlimited +allocated, allocs, releases := pool.Metrics() +``` + +**ArrowBuilderPool** - Unified pool for all schema builders: + +```go +builderPool := NewArrowBuilderPool(maxMemoryBytes) + +// Get a builder for observations +builder := builderPool.GetObservationBuilder() +// ... use builder ... 
+builderPool.PutObservationBuilder(builder)
+```
+
+### arrow_converters.go
+
+Type conversion between Go types and Arrow columns:
+
+```go
+// Write StreamValue to Arrow builders
+StreamValueToArrow(sv, valueTypeBuilder, decimalBuilder, bidBuilder, ...)
+
+// Read StreamValue from Arrow arrays
+sv, _ := ArrowToStreamValue(idx, valueTypeArr, decimalArr, bidArr, ...)
+
+// Batch conversion for cache operations
+record, _ := StreamValuesToArrowRecord(values, pool)
+values, _ := ArrowRecordToStreamValues(record)
+```
+
+### arrow_observation_merger.go
+
+Merges observations from multiple nodes into a single Arrow record:
+
+```go
+merger := NewArrowObservationMerger(pool, codec)
+
+// Merge attributed observations from consensus
+record, _ := merger.MergeObservations(attributedObservations)
+defer record.Release()
+
+// Utility functions
+counts := CountByStreamID(record) // {streamID: count}
+counts := CountByObserver(record) // {observerID: count}
+```
+
+### arrow_aggregators.go
+
+Performs vectorized aggregation on Arrow records:
+
+```go
+aggregator := NewArrowAggregator(pool, nil) // logger may be nil if logging is not needed
+
+// Aggregate with channel definitions providing aggregator type per stream
+results, _ := aggregator.AggregateObservations(record, channelDefs, f)
+// f = fault tolerance threshold (observations must exceed f)
+```
+
+**Supported Aggregators:**
+
+| Aggregator | Description |
+|------------|-------------|
+| `Median` | Sorts values, returns middle element |
+| `Mode` | Most common value (requires f+1 agreement) |
+| `Quote` | Median of each quote component separately |
+
+## Data Flow Example
+
+```go
+// 1. Create pools
+builderPool := NewArrowBuilderPool(0)
+codec := &StandardObservationCodec{}
+
+// 2. Merge observations from all nodes
+merger := NewArrowObservationMerger(builderPool, codec)
+record, _ := merger.MergeObservations(attributedObservations)
+defer record.Release()
+
+// 3. 
Aggregate using channel definitions
+aggregator := NewArrowAggregator(builderPool, nil)
+streamAggregates, _ := aggregator.AggregateObservations(record, channelDefs, f)
+
+// 4. Use aggregated values for report generation
+for streamID, aggregatorValues := range streamAggregates {
+    for aggregatorType, value := range aggregatorValues {
+        // Build reports...
+    }
+}
+```
+
+## Memory Management Best Practices
+
+1. **Always release records** when done:
+   ```go
+   record, _ := merger.MergeObservations(...)
+   defer record.Release()
+   ```
+
+2. **Return builders to pool** after use:
+   ```go
+   builder := pool.GetObservationBuilder()
+   // ... use builder ...
+   pool.PutObservationBuilder(builder)
+   ```
+
+3. **Set memory limits** in production:
+   ```go
+   pool := NewArrowBuilderPool(500 * 1024 * 1024) // 500MB limit
+   ```
+
+4. **Monitor allocation metrics**:
+   ```go
+   allocated, allocs, releases := pool.MemoryStats()
+   ```
+
+## Testing
+
+### Unit Tests
+
+`llo/arrow_aggregators_test.go` - Validates aggregation logic for all value types and aggregator combinations.
+
+### Benchmarks
+
+`llo/arrow_bench_test.go` - Performance benchmarks for:
+- Median aggregation
+- Quote aggregation
+- Type conversion operations
+- Builder pool efficiency
+
+Run benchmarks:
+```bash
+cd llo
+go test -bench=. -benchmem ./... 
+``` + +### Comparison Tests + +`llo/arrow_comparison_test.go` - Compares Arrow implementation against the original implementation at various scales: +- 10, 100, 1000, 5000, 10000 observations + +Run comparison: +```bash +go test -run=Comparison -v ./llo/ +``` + +## File Reference + +| File | Purpose | +|------|---------| +| `llo/arrow_schemas.go` | Schema definitions and column constants | +| `llo/arrow_pool.go` | Memory pool and builder pool management | +| `llo/arrow_converters.go` | Go type <-> Arrow conversion utilities | +| `llo/arrow_observation_merger.go` | Multi-node observation merging | +| `llo/arrow_aggregators.go` | Vectorized aggregation algorithms | +| `llo/arrow_aggregators_test.go` | Unit tests | +| `llo/arrow_bench_test.go` | Performance benchmarks | +| `llo/arrow_comparison_test.go` | Before/after comparison tests | diff --git a/go.mod b/go.mod index c29c03f..d53f495 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/smartcontractkit/chainlink-data-streams go 1.25.3 require ( + github.com/apache/arrow-go/v18 v18.3.1 github.com/ethereum/go-ethereum v1.15.3 github.com/expr-lang/expr v1.17.5 github.com/goccy/go-json v0.10.5 @@ -52,6 +53,7 @@ require ( github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect @@ -73,12 +75,12 @@ require ( github.com/jmoiron/sqlx v1.4.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lib/pq v1.10.9 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // 
indirect - github.com/mattn/go-runewidth v0.0.14 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -93,7 +95,6 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.65.0 // indirect github.com/prometheus/procfs v0.16.1 // indirect - github.com/rivo/uniseg v0.4.4 // indirect github.com/rs/cors v1.9.0 // indirect github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 // indirect github.com/scylladb/go-reflectx v1.0.1 // indirect @@ -111,6 +112,7 @@ require ( github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/otel v1.38.0 // indirect @@ -134,11 +136,15 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/crypto v0.45.0 // indirect + golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect + golang.org/x/telemetry v0.0.0-20251008203120-078029d740a8 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.38.0 // indirect + golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0101e03..137dcd6 100644 --- a/go.sum +++ b/go.sum @@ -52,9 +52,13 @@ github.com/VictoriaMetrics/fastcache v1.12.2 h1:N0y9ASrJ0F6h0QaC3o6uJb3NIZ9VKLjC 
github.com/VictoriaMetrics/fastcache v1.12.2/go.mod h1:AmC+Nzz1+3G2eCPapF6UcsnkThDcMsQicp4xDukwJYI= github.com/XSAM/otelsql v0.37.0 h1:ya5RNw028JW0eJW8Ma4AmoKxAYsJSGuNVbC7F1J457A= github.com/XSAM/otelsql v0.37.0/go.mod h1:LHbCu49iU8p255nCn1oi04oX2UjSoRcUMiKEHo2a5qM= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow-go/v18 v18.3.1 h1:oYZT8FqONiK74JhlH3WKVv+2NKYoyZ7C2ioD4Dj3ixk= github.com/apache/arrow-go/v18 v18.3.1/go.mod h1:12QBya5JZT6PnBihi5NJTzbACrDGXYkrgjujz3MRQXU= +github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= +github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -232,8 +236,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= -github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= 
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= @@ -397,6 +401,8 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= +github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= @@ -443,12 +449,16 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= -github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v2.0.3+incompatible 
h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= +github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= +github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= +github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= @@ -456,8 +466,9 @@ github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS4 github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/pointerstructure v1.2.0 h1:O+i9nHnXS3l/9Wu7r4NrEdwA2VFTicjUEN1uBnDo34A= github.com/mitchellh/pointerstructure v1.2.0/go.mod h1:BRAsLI5zgXmw97Lf6s25bs8ohIXc3tViBH44KcwB2g4= github.com/mmcloughlin/addchain v0.4.0 
h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= @@ -526,7 +537,6 @@ github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2 github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -633,6 +643,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= diff --git a/llo/arrow_aggregators.go b/llo/arrow_aggregators.go new file mode 100644 index 0000000..31af838 --- /dev/null +++ b/llo/arrow_aggregators.go @@ -0,0 +1,368 @@ +package llo + +import ( + "fmt" + "sort" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/shopspring/decimal" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + llotypes 
"github.com/smartcontractkit/chainlink-common/pkg/types/llo" +) + +// ArrowAggregator performs vectorized aggregation on Arrow records. +type ArrowAggregator struct { + pool *ArrowBuilderPool + logger logger.Logger +} + +// NewArrowAggregator creates a new Arrow-based aggregator. +// Logger can be nil if logging is not needed. +func NewArrowAggregator(pool *ArrowBuilderPool, lggr logger.Logger) *ArrowAggregator { + return &ArrowAggregator{pool: pool, logger: lggr} +} + +// streamObservations groups observations by stream ID for aggregation. +type streamObservations struct { + valueType uint8 + decimals []decimal.Decimal + quotes []*Quote + timestamps []uint64 + innerValues []decimal.Decimal // For TimestampedStreamValue +} + +// AggregateObservations performs aggregation on merged observations. +// It groups observations by stream ID and applies the appropriate aggregator. +func (a *ArrowAggregator) AggregateObservations( + record arrow.Record, + channelDefs llotypes.ChannelDefinitions, + f int, +) (StreamAggregates, error) { + if record == nil || record.NumRows() == 0 { + return nil, nil + } + + // Extract columns + streamIDArr := record.Column(ObsColStreamID).(*array.Uint32) + valueTypeArr := record.Column(ObsColValueType).(*array.Uint8) + decimalArr := record.Column(ObsColDecimalValue).(*array.Binary) + bidArr := record.Column(ObsColQuoteBid).(*array.Binary) + benchmarkArr := record.Column(ObsColQuoteBenchmark).(*array.Binary) + askArr := record.Column(ObsColQuoteAsk).(*array.Binary) + observedAtArr := record.Column(ObsColObservedAtNs).(*array.Uint64) + + // Group observations by stream ID + grouped := make(map[llotypes.StreamID]*streamObservations) + + for i := 0; i < int(record.NumRows()); i++ { + streamID := streamIDArr.Value(i) + if valueTypeArr.IsNull(i) { + continue + } + + obs, exists := grouped[streamID] + if !exists { + obs = &streamObservations{ + valueType: valueTypeArr.Value(i), + } + grouped[streamID] = obs + } + + valueType := valueTypeArr.Value(i) 
+ + switch valueType { + case StreamValueTypeDecimal: + if !decimalArr.IsNull(i) { + d, err := BytesToDecimal(decimalArr.Value(i)) + if err == nil { + obs.decimals = append(obs.decimals, d) + } + } + + case StreamValueTypeQuote: + if !bidArr.IsNull(i) && !benchmarkArr.IsNull(i) && !askArr.IsNull(i) { + bid, err1 := BytesToDecimal(bidArr.Value(i)) + benchmark, err2 := BytesToDecimal(benchmarkArr.Value(i)) + ask, err3 := BytesToDecimal(askArr.Value(i)) + // Skip quotes with parsing errors to avoid injecting zero/corrupt values + if err1 != nil || err2 != nil || err3 != nil { + continue + } + q := &Quote{Bid: bid, Benchmark: benchmark, Ask: ask} + if q.IsValid() { + obs.quotes = append(obs.quotes, q) + } + } + + case StreamValueTypeTimestampd: + if !observedAtArr.IsNull(i) { + obs.timestamps = append(obs.timestamps, observedAtArr.Value(i)) + } + if !decimalArr.IsNull(i) { + d, err := BytesToDecimal(decimalArr.Value(i)) + if err == nil { + obs.innerValues = append(obs.innerValues, d) + } + } + } + } + + // Determine required aggregators for each stream from channel definitions + streamAggregators := make(map[llotypes.StreamID]llotypes.Aggregator) + for _, cd := range channelDefs { + for _, stream := range cd.Streams { + streamAggregators[stream.StreamID] = stream.Aggregator + } + } + + // Apply aggregators + result := make(StreamAggregates) + + for streamID, obs := range grouped { + aggregator, exists := streamAggregators[streamID] + if !exists { + aggregator = llotypes.AggregatorMedian // Default + } + + var sv StreamValue + var err error + + switch aggregator { + case llotypes.AggregatorMedian: + sv, err = a.medianAggregate(obs, f) + case llotypes.AggregatorMode: + sv, err = a.modeAggregate(obs, f) + case llotypes.AggregatorQuote: + sv, err = a.quoteAggregate(obs, f) + default: + sv, err = a.medianAggregate(obs, f) + } + + if err != nil { + if a.logger != nil { + a.logger.Debugw("Aggregation failed for stream", "streamID", streamID, "aggregator", aggregator, "err", 
err) + } + continue // Skip streams that fail aggregation + } + + if sv != nil { + if result[streamID] == nil { + result[streamID] = make(map[llotypes.Aggregator]StreamValue) + } + result[streamID][aggregator] = sv + } + } + + return result, nil +} + +// medianAggregate computes the median for a stream's observations. +func (a *ArrowAggregator) medianAggregate(obs *streamObservations, f int) (StreamValue, error) { + switch obs.valueType { + case StreamValueTypeDecimal: + if len(obs.decimals) <= f { + return nil, fmt.Errorf("not enough observations: %d <= %d", len(obs.decimals), f) + } + // Sort decimals + sorted := make([]decimal.Decimal, len(obs.decimals)) + copy(sorted, obs.decimals) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Cmp(sorted[j]) < 0 + }) + return ToDecimal(sorted[len(sorted)/2]), nil + + case StreamValueTypeQuote: + // For quotes, use benchmark for median calculation + if len(obs.quotes) <= f { + return nil, fmt.Errorf("not enough observations: %d <= %d", len(obs.quotes), f) + } + benchmarks := make([]decimal.Decimal, len(obs.quotes)) + for i, q := range obs.quotes { + benchmarks[i] = q.Benchmark + } + sort.Slice(benchmarks, func(i, j int) bool { + return benchmarks[i].Cmp(benchmarks[j]) < 0 + }) + return ToDecimal(benchmarks[len(benchmarks)/2]), nil + + case StreamValueTypeTimestampd: + if len(obs.innerValues) <= f { + return nil, fmt.Errorf("not enough observations: %d <= %d", len(obs.innerValues), f) + } + // Sort inner values + sorted := make([]decimal.Decimal, len(obs.innerValues)) + copy(sorted, obs.innerValues) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Cmp(sorted[j]) < 0 + }) + + // Sort timestamps + sortedTs := make([]uint64, len(obs.timestamps)) + copy(sortedTs, obs.timestamps) + sort.Slice(sortedTs, func(i, j int) bool { + return sortedTs[i] < sortedTs[j] + }) + + return &TimestampedStreamValue{ + ObservedAtNanoseconds: sortedTs[len(sortedTs)/2], + StreamValue: ToDecimal(sorted[len(sorted)/2]), + }, nil 
+ + default: + return nil, fmt.Errorf("unsupported value type for median: %d", obs.valueType) + } +} + +// modeAggregate computes the mode (most common value) for a stream's observations. +func (a *ArrowAggregator) modeAggregate(obs *streamObservations, f int) (StreamValue, error) { + switch obs.valueType { + case StreamValueTypeDecimal: + if len(obs.decimals) == 0 { + return nil, fmt.Errorf("no observations") + } + + // Count occurrences using string representation + counts := make(map[string]int) + valueMap := make(map[string]decimal.Decimal) + for _, d := range obs.decimals { + key := d.String() + counts[key]++ + valueMap[key] = d + } + + // Find mode + var modeKey string + var modeCount int + for key, count := range counts { + if count > modeCount || (count == modeCount && key < modeKey) { + modeKey = key + modeCount = count + } + } + + if modeCount < f+1 { + return nil, fmt.Errorf("not enough observations in agreement: %d < %d", modeCount, f+1) + } + + return ToDecimal(valueMap[modeKey]), nil + + case StreamValueTypeQuote: + if len(obs.quotes) == 0 { + return nil, fmt.Errorf("no observations") + } + + // Count occurrences using string representation + counts := make(map[string]int) + valueMap := make(map[string]*Quote) + for _, q := range obs.quotes { + key := fmt.Sprintf("%s|%s|%s", q.Bid.String(), q.Benchmark.String(), q.Ask.String()) + counts[key]++ + valueMap[key] = q + } + + // Find mode + var modeKey string + var modeCount int + for key, count := range counts { + if count > modeCount || (count == modeCount && key < modeKey) { + modeKey = key + modeCount = count + } + } + + if modeCount < f+1 { + return nil, fmt.Errorf("not enough observations in agreement: %d < %d", modeCount, f+1) + } + + return valueMap[modeKey], nil + + default: + return nil, fmt.Errorf("unsupported value type for mode: %d", obs.valueType) + } +} + +// quoteAggregate computes the median for each component of quote observations. 
+func (a *ArrowAggregator) quoteAggregate(obs *streamObservations, f int) (StreamValue, error) { + if obs.valueType != StreamValueTypeQuote { + return nil, fmt.Errorf("quote aggregator requires quote observations") + } + + if len(obs.quotes) <= f { + return nil, fmt.Errorf("not enough observations: %d <= %d", len(obs.quotes), f) + } + + // Sort and get median for each component + bids := make([]decimal.Decimal, len(obs.quotes)) + benchmarks := make([]decimal.Decimal, len(obs.quotes)) + asks := make([]decimal.Decimal, len(obs.quotes)) + + for i, q := range obs.quotes { + bids[i] = q.Bid + benchmarks[i] = q.Benchmark + asks[i] = q.Ask + } + + sort.Slice(bids, func(i, j int) bool { return bids[i].Cmp(bids[j]) < 0 }) + sort.Slice(benchmarks, func(i, j int) bool { return benchmarks[i].Cmp(benchmarks[j]) < 0 }) + sort.Slice(asks, func(i, j int) bool { return asks[i].Cmp(asks[j]) < 0 }) + + mid := len(obs.quotes) / 2 + return &Quote{ + Bid: bids[mid], + Benchmark: benchmarks[mid], + Ask: asks[mid], + }, nil +} + +// MedianDecimalBatch computes median for a slice of decimals. +// This is optimized for batch processing. +func MedianDecimalBatch(values []decimal.Decimal, f int) (decimal.Decimal, error) { + if len(values) <= f { + return decimal.Decimal{}, fmt.Errorf("not enough values: %d <= %d", len(values), f) + } + + sorted := make([]decimal.Decimal, len(values)) + copy(sorted, values) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Cmp(sorted[j]) < 0 + }) + + return sorted[len(sorted)/2], nil +} + +// AggregateResult holds the result of Arrow-based aggregation for a single stream. +type AggregateResult struct { + StreamID llotypes.StreamID + Aggregator llotypes.Aggregator + Value StreamValue +} + +// BatchAggregateToRecord converts aggregation results to an Arrow record. 
+func (a *ArrowAggregator) BatchAggregateToRecord(results []AggregateResult) (arrow.Record, error) { + builder := a.pool.GetAggregatesBuilder() + defer a.pool.PutAggregatesBuilder(builder) + + streamIDBuilder := builder.Field(AggColStreamID).(*array.Uint32Builder) + aggregatorBuilder := builder.Field(AggColAggregator).(*array.Uint32Builder) + valueTypeBuilder := builder.Field(AggColValueType).(*array.Uint8Builder) + decimalBuilder := builder.Field(AggColDecimalValue).(*array.BinaryBuilder) + bidBuilder := builder.Field(AggColQuoteBid).(*array.BinaryBuilder) + benchmarkBuilder := builder.Field(AggColQuoteBenchmark).(*array.BinaryBuilder) + askBuilder := builder.Field(AggColQuoteAsk).(*array.BinaryBuilder) + observedAtBuilder := builder.Field(AggColObservedAtNs).(*array.Uint64Builder) + + for _, result := range results { + streamIDBuilder.Append(result.StreamID) + aggregatorBuilder.Append(uint32(result.Aggregator)) + + _, err := StreamValueToArrow(result.Value, valueTypeBuilder, decimalBuilder, + bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder) + if err != nil { + return nil, err + } + } + + return builder.NewRecord(), nil +} diff --git a/llo/arrow_aggregators_test.go b/llo/arrow_aggregators_test.go new file mode 100644 index 0000000..0cfea9f --- /dev/null +++ b/llo/arrow_aggregators_test.go @@ -0,0 +1,376 @@ +package llo + +import ( + "testing" + + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" +) + +func TestDecimalToBytes(t *testing.T) { + tests := []struct { + name string + value string + }{ + {"zero", "0"}, + {"positive", "123.456"}, + {"negative", "-789.012"}, + {"large", "999999999999999999.999999999999999999"}, + {"small", "0.000000000000000001"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d, err := decimal.NewFromString(tt.value) + 
require.NoError(t, err) + + // Convert to bytes + bytes, err := DecimalToBytes(d) + require.NoError(t, err) + assert.NotEmpty(t, bytes) + + // Convert back + result, err := BytesToDecimal(bytes) + require.NoError(t, err) + assert.True(t, d.Equal(result), "expected %s, got %s", d, result) + }) + } +} + +func TestStreamValueToArrow(t *testing.T) { + pool := NewArrowBuilderPool(0) + builder := pool.GetCacheBuilder() + defer pool.PutCacheBuilder(builder) + + valueTypeBuilder := builder.Field(CacheColValueType).(*array.Uint8Builder) + decimalBuilder := builder.Field(CacheColDecimalValue).(*array.BinaryBuilder) + bidBuilder := builder.Field(CacheColQuoteBid).(*array.BinaryBuilder) + benchmarkBuilder := builder.Field(CacheColQuoteBenchmark).(*array.BinaryBuilder) + askBuilder := builder.Field(CacheColQuoteAsk).(*array.BinaryBuilder) + observedAtBuilder := builder.Field(CacheColObservedAtNs).(*array.Uint64Builder) + + t.Run("decimal", func(t *testing.T) { + d := decimal.NewFromFloat(123.456) + sv := ToDecimal(d) + + valueType, err := StreamValueToArrow(sv, valueTypeBuilder, decimalBuilder, + bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder) + require.NoError(t, err) + assert.Equal(t, StreamValueTypeDecimal, valueType) + }) + + t.Run("quote", func(t *testing.T) { + q := &Quote{ + Bid: decimal.NewFromFloat(99.5), + Benchmark: decimal.NewFromFloat(100.0), + Ask: decimal.NewFromFloat(100.5), + } + + valueType, err := StreamValueToArrow(q, valueTypeBuilder, decimalBuilder, + bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder) + require.NoError(t, err) + assert.Equal(t, StreamValueTypeQuote, valueType) + }) + + t.Run("timestamped", func(t *testing.T) { + tsv := &TimestampedStreamValue{ + ObservedAtNanoseconds: 1234567890, + StreamValue: ToDecimal(decimal.NewFromFloat(42.0)), + } + + valueType, err := StreamValueToArrow(tsv, valueTypeBuilder, decimalBuilder, + bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder) + require.NoError(t, err) + 
assert.Equal(t, StreamValueTypeTimestampd, valueType) + }) + + t.Run("nil", func(t *testing.T) { + valueType, err := StreamValueToArrow(nil, valueTypeBuilder, decimalBuilder, + bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder) + require.NoError(t, err) + assert.Equal(t, uint8(0), valueType) + }) +} + +func TestMedianDecimalBatch(t *testing.T) { + t.Run("odd count", func(t *testing.T) { + values := []decimal.Decimal{ + decimal.NewFromFloat(10), + decimal.NewFromFloat(20), + decimal.NewFromFloat(30), + decimal.NewFromFloat(40), + decimal.NewFromFloat(50), + } + + result, err := MedianDecimalBatch(values, 1) + require.NoError(t, err) + assert.True(t, decimal.NewFromFloat(30).Equal(result)) + }) + + t.Run("even count", func(t *testing.T) { + values := []decimal.Decimal{ + decimal.NewFromFloat(10), + decimal.NewFromFloat(20), + decimal.NewFromFloat(30), + decimal.NewFromFloat(40), + } + + // With even count, we take the higher middle value (rank-k median) + result, err := MedianDecimalBatch(values, 1) + require.NoError(t, err) + assert.True(t, decimal.NewFromFloat(30).Equal(result)) + }) + + t.Run("unsorted input", func(t *testing.T) { + values := []decimal.Decimal{ + decimal.NewFromFloat(50), + decimal.NewFromFloat(10), + decimal.NewFromFloat(40), + decimal.NewFromFloat(20), + decimal.NewFromFloat(30), + } + + result, err := MedianDecimalBatch(values, 1) + require.NoError(t, err) + assert.True(t, decimal.NewFromFloat(30).Equal(result)) + }) + + t.Run("not enough values", func(t *testing.T) { + values := []decimal.Decimal{ + decimal.NewFromFloat(10), + } + + _, err := MedianDecimalBatch(values, 1) + require.Error(t, err) + }) +} + +func TestArrowAggregator_MedianAggregate(t *testing.T) { + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + + t.Run("decimal values", func(t *testing.T) { + obs := &streamObservations{ + valueType: StreamValueTypeDecimal, + decimals: []decimal.Decimal{ + decimal.NewFromFloat(10), + decimal.NewFromFloat(20), + 
decimal.NewFromFloat(30), + decimal.NewFromFloat(40), + decimal.NewFromFloat(50), + }, + } + + result, err := agg.medianAggregate(obs, 1) + require.NoError(t, err) + require.NotNil(t, result) + + dec, ok := result.(*Decimal) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(30).Equal(dec.Decimal())) + }) + + t.Run("quote values uses benchmark", func(t *testing.T) { + obs := &streamObservations{ + valueType: StreamValueTypeQuote, + quotes: []*Quote{ + {Bid: decimal.NewFromFloat(9), Benchmark: decimal.NewFromFloat(10), Ask: decimal.NewFromFloat(11)}, + {Bid: decimal.NewFromFloat(19), Benchmark: decimal.NewFromFloat(20), Ask: decimal.NewFromFloat(21)}, + {Bid: decimal.NewFromFloat(29), Benchmark: decimal.NewFromFloat(30), Ask: decimal.NewFromFloat(31)}, + }, + } + + result, err := agg.medianAggregate(obs, 1) + require.NoError(t, err) + require.NotNil(t, result) + + dec, ok := result.(*Decimal) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(20).Equal(dec.Decimal())) + }) + + t.Run("timestamped values", func(t *testing.T) { + obs := &streamObservations{ + valueType: StreamValueTypeTimestampd, + timestamps: []uint64{ + 1000, + 2000, + 3000, + }, + innerValues: []decimal.Decimal{ + decimal.NewFromFloat(10), + decimal.NewFromFloat(20), + decimal.NewFromFloat(30), + }, + } + + result, err := agg.medianAggregate(obs, 1) + require.NoError(t, err) + require.NotNil(t, result) + + tsv, ok := result.(*TimestampedStreamValue) + require.True(t, ok) + assert.Equal(t, uint64(2000), tsv.ObservedAtNanoseconds) + + dec, ok := tsv.StreamValue.(*Decimal) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(20).Equal(dec.Decimal())) + }) +} + +func TestArrowAggregator_QuoteAggregate(t *testing.T) { + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + + t.Run("quote aggregation", func(t *testing.T) { + obs := &streamObservations{ + valueType: StreamValueTypeQuote, + quotes: []*Quote{ + {Bid: decimal.NewFromFloat(95), Benchmark: 
decimal.NewFromFloat(100), Ask: decimal.NewFromFloat(105)}, + {Bid: decimal.NewFromFloat(96), Benchmark: decimal.NewFromFloat(101), Ask: decimal.NewFromFloat(106)}, + {Bid: decimal.NewFromFloat(97), Benchmark: decimal.NewFromFloat(102), Ask: decimal.NewFromFloat(107)}, + {Bid: decimal.NewFromFloat(98), Benchmark: decimal.NewFromFloat(103), Ask: decimal.NewFromFloat(108)}, + {Bid: decimal.NewFromFloat(99), Benchmark: decimal.NewFromFloat(104), Ask: decimal.NewFromFloat(109)}, + }, + } + + result, err := agg.quoteAggregate(obs, 1) + require.NoError(t, err) + require.NotNil(t, result) + + quote, ok := result.(*Quote) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(97).Equal(quote.Bid)) + assert.True(t, decimal.NewFromFloat(102).Equal(quote.Benchmark)) + assert.True(t, decimal.NewFromFloat(107).Equal(quote.Ask)) + }) +} + +func TestArrowAggregator_ModeAggregate(t *testing.T) { + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + + t.Run("decimal mode", func(t *testing.T) { + obs := &streamObservations{ + valueType: StreamValueTypeDecimal, + decimals: []decimal.Decimal{ + decimal.NewFromFloat(10), + decimal.NewFromFloat(20), + decimal.NewFromFloat(20), + decimal.NewFromFloat(20), + decimal.NewFromFloat(30), + }, + } + + result, err := agg.modeAggregate(obs, 2) + require.NoError(t, err) + require.NotNil(t, result) + + dec, ok := result.(*Decimal) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(20).Equal(dec.Decimal())) + }) + + t.Run("not enough agreement", func(t *testing.T) { + obs := &streamObservations{ + valueType: StreamValueTypeDecimal, + decimals: []decimal.Decimal{ + decimal.NewFromFloat(10), + decimal.NewFromFloat(20), + decimal.NewFromFloat(30), + decimal.NewFromFloat(40), + decimal.NewFromFloat(50), + }, + } + + _, err := agg.modeAggregate(obs, 2) + require.Error(t, err) + }) +} + +func TestStreamValuesToArrowRecord(t *testing.T) { + pool := NewArrowBuilderPool(0) + + values := map[llotypes.StreamID]StreamValue{ + 
1: ToDecimal(decimal.NewFromFloat(100.5)), + 2: &Quote{ + Bid: decimal.NewFromFloat(99), + Benchmark: decimal.NewFromFloat(100), + Ask: decimal.NewFromFloat(101), + }, + 3: &TimestampedStreamValue{ + ObservedAtNanoseconds: 1234567890, + StreamValue: ToDecimal(decimal.NewFromFloat(42.0)), + }, + } + + record, err := StreamValuesToArrowRecord(values, pool) + require.NoError(t, err) + require.NotNil(t, record) + defer record.Release() + + assert.Equal(t, int64(3), record.NumRows()) +} + +func TestArrowRecordToStreamValues(t *testing.T) { + pool := NewArrowBuilderPool(0) + + original := map[llotypes.StreamID]StreamValue{ + 1: ToDecimal(decimal.NewFromFloat(100.5)), + 2: &Quote{ + Bid: decimal.NewFromFloat(99), + Benchmark: decimal.NewFromFloat(100), + Ask: decimal.NewFromFloat(101), + }, + } + + // Convert to Arrow record + record, err := StreamValuesToArrowRecord(original, pool) + require.NoError(t, err) + require.NotNil(t, record) + defer record.Release() + + // Convert back + result, err := ArrowRecordToStreamValues(record) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify values match + assert.Len(t, result, 2) + + dec1, ok := result[1].(*Decimal) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(100.5).Equal(dec1.Decimal())) + + quote2, ok := result[2].(*Quote) + require.True(t, ok) + assert.True(t, decimal.NewFromFloat(99).Equal(quote2.Bid)) + assert.True(t, decimal.NewFromFloat(100).Equal(quote2.Benchmark)) + assert.True(t, decimal.NewFromFloat(101).Equal(quote2.Ask)) +} + +func TestArrowBuilderPool(t *testing.T) { + pool := NewArrowBuilderPool(1024 * 1024) // 1MB limit + + t.Run("get and put observation builder", func(t *testing.T) { + builder := pool.GetObservationBuilder() + require.NotNil(t, builder) + + // Add some data + builder.Field(ObsColObserverID).(*array.Uint8Builder).Append(1) + builder.Field(ObsColStreamID).(*array.Uint32Builder).Append(100) + + pool.PutObservationBuilder(builder) + }) + + t.Run("memory stats", func(t 
*testing.T) { + allocated, allocs, releases := pool.MemoryStats() + assert.GreaterOrEqual(t, allocated, int64(0)) + assert.GreaterOrEqual(t, allocs, int64(0)) + assert.GreaterOrEqual(t, releases, int64(0)) + }) +} diff --git a/llo/arrow_bench_test.go b/llo/arrow_bench_test.go new file mode 100644 index 0000000..26dbedc --- /dev/null +++ b/llo/arrow_bench_test.go @@ -0,0 +1,234 @@ +package llo + +import ( + "fmt" + "testing" + + "github.com/shopspring/decimal" + + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" +) + +// BenchmarkMedianDecimalBatch benchmarks the median aggregation. +func BenchmarkMedianDecimalBatch(b *testing.B) { + sizes := []int{10, 100, 1000, 10000} + + for _, size := range sizes { + values := make([]decimal.Decimal, size) + for i := 0; i < size; i++ { + values[i] = decimal.NewFromFloat(float64(i)) + } + + b.Run(fmt.Sprintf("size_%d", len(values)), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = MedianDecimalBatch(values, 1) + } + }) + } +} + +// BenchmarkStreamValuesToArrowRecord benchmarks conversion to Arrow. +func BenchmarkStreamValuesToArrowRecord(b *testing.B) { + pool := NewArrowBuilderPool(0) + + sizes := []int{100, 1000, 10000} + + for _, size := range sizes { + values := make(map[llotypes.StreamID]StreamValue, size) + for i := 0; i < size; i++ { + values[llotypes.StreamID(i)] = ToDecimal(decimal.NewFromFloat(float64(i))) + } + + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + record, _ := StreamValuesToArrowRecord(values, pool) + if record != nil { + record.Release() + } + } + }) + } +} + +// BenchmarkArrowRecordToStreamValues benchmarks conversion from Arrow. 
+func BenchmarkArrowRecordToStreamValues(b *testing.B) { + pool := NewArrowBuilderPool(0) + + sizes := []int{100, 1000, 10000} + + for _, size := range sizes { + values := make(map[llotypes.StreamID]StreamValue, size) + for i := 0; i < size; i++ { + values[llotypes.StreamID(i)] = ToDecimal(decimal.NewFromFloat(float64(i))) + } + + record, _ := StreamValuesToArrowRecord(values, pool) + + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = ArrowRecordToStreamValues(record) + } + }) + + record.Release() + } +} + +// BenchmarkArrowAggregator_MedianAggregate benchmarks median aggregation. +func BenchmarkArrowAggregator_MedianAggregate(b *testing.B) { + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + obs := &streamObservations{ + valueType: StreamValueTypeDecimal, + decimals: make([]decimal.Decimal, size), + } + for i := 0; i < size; i++ { + obs.decimals[i] = decimal.NewFromFloat(float64(i)) + } + + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = agg.medianAggregate(obs, 1) + } + }) + } +} + +// BenchmarkArrowAggregator_QuoteAggregate benchmarks quote aggregation. 
+func BenchmarkArrowAggregator_QuoteAggregate(b *testing.B) { + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + obs := &streamObservations{ + valueType: StreamValueTypeQuote, + quotes: make([]*Quote, size), + } + for i := 0; i < size; i++ { + obs.quotes[i] = &Quote{ + Bid: decimal.NewFromFloat(float64(i) - 0.5), + Benchmark: decimal.NewFromFloat(float64(i)), + Ask: decimal.NewFromFloat(float64(i) + 0.5), + } + } + + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = agg.quoteAggregate(obs, 1) + } + }) + } +} + +// BenchmarkDecimalConversion benchmarks decimal to/from bytes conversion. +func BenchmarkDecimalConversion(b *testing.B) { + d := decimal.NewFromFloat(123456789.123456789) + + b.Run("to_bytes", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = DecimalToBytes(d) + } + }) + + bytes, _ := DecimalToBytes(d) + b.Run("from_bytes", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = BytesToDecimal(bytes) + } + }) +} + +// BenchmarkBuilderPool benchmarks the builder pool. +func BenchmarkBuilderPool(b *testing.B) { + pool := NewArrowBuilderPool(0) + + b.Run("get_put", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + builder := pool.GetObservationBuilder() + pool.PutObservationBuilder(builder) + } + }) + + b.Run("get_build_put", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + builder := pool.GetObservationBuilder() + // Simulate some work + _ = builder.NewRecord() + pool.PutObservationBuilder(builder) + } + }) +} + +// BenchmarkMemoryPoolAllocation benchmarks memory pool allocations. 
func BenchmarkMemoryPoolAllocation(b *testing.B) {
	pool := NewLLOMemoryPool(0) // 0 presumably means "no memory limit" — TODO confirm against NewLLOMemoryPool

	sizes := []int{64, 256, 1024, 4096, 16384}

	for _, size := range sizes {
		b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) {
			b.ReportAllocs()
			// Measure a full Allocate/Free round trip at each buffer size.
			for i := 0; i < b.N; i++ {
				buf := pool.Allocate(size)
				pool.Free(buf)
			}
		})
	}
}

// BenchmarkExistingMedianAggregator benchmarks the existing implementation for comparison.
func BenchmarkExistingMedianAggregator(b *testing.B) {
	sizes := []int{10, 100, 1000}

	for _, size := range sizes {
		// Fixture is built once per size, outside the timed region.
		values := make([]StreamValue, size)
		for i := 0; i < size; i++ {
			values[i] = ToDecimal(decimal.NewFromFloat(float64(i)))
		}

		b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, _ = MedianAggregator(values, 1)
			}
		})
	}
}

// BenchmarkExistingQuoteAggregator benchmarks the existing implementation for comparison.
func BenchmarkExistingQuoteAggregator(b *testing.B) {
	sizes := []int{10, 100, 1000}

	for _, size := range sizes {
		// Fixture: well-formed quotes with a fixed 1.0 bid/ask spread.
		values := make([]StreamValue, size)
		for i := 0; i < size; i++ {
			values[i] = &Quote{
				Bid:       decimal.NewFromFloat(float64(i) - 0.5),
				Benchmark: decimal.NewFromFloat(float64(i)),
				Ask:       decimal.NewFromFloat(float64(i) + 0.5),
			}
		}

		b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, _ = QuoteAggregator(values, 1)
			}
		})
	}
}
diff --git a/llo/arrow_comparison_test.go b/llo/arrow_comparison_test.go
new file mode 100644
index 0000000..b74959b
--- /dev/null
+++ b/llo/arrow_comparison_test.go
@@ -0,0 +1,449 @@
package llo

import (
	"fmt"
	"math/rand"
	"testing"

	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/shopspring/decimal"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo"
)

// ============================================================================
// BEFORE/AFTER COMPARISON BENCHMARKS
// ============================================================================
//
// These benchmarks compare the original implementation with the new Arrow-based
// implementation. Run with:
//
//	go test ./llo/... -bench=Comparison -benchmem -count=5
//
// The "Before_Original" benchmarks use the existing implementation.
// The "After_Arrow" benchmarks use the new Arrow-based implementation.
// ============================================================================

// BenchmarkComparison_MedianAggregation compares median aggregation performance.
func BenchmarkComparison_MedianAggregation(b *testing.B) {
	sizes := []int{10, 100, 1000, 5000, 10000}

	for _, size := range sizes {
		name := fmt.Sprintf("%d_observations", size)

		// Setup: create test data for original implementation
		// NOTE(review): the two fixtures draw independent random values, so
		// the two implementations aggregate different data — fine for timing
		// comparisons, but their results are not comparable value-for-value
		// (the equivalence tests below cover correctness).
		originalValues := make([]StreamValue, size)
		for i := 0; i < size; i++ {
			originalValues[i] = ToDecimal(decimal.NewFromFloat(float64(rand.Intn(10000))))
		}

		// Setup: create test data for Arrow implementation
		pool := NewArrowBuilderPool(0)
		agg := NewArrowAggregator(pool, nil)
		arrowObs := &streamObservations{
			valueType: StreamValueTypeDecimal,
			decimals:  make([]decimal.Decimal, size),
		}
		for i := 0; i < size; i++ {
			arrowObs.decimals[i] = decimal.NewFromFloat(float64(rand.Intn(10000)))
		}

		b.Run("Before_Original/"+name, func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, _ = MedianAggregator(originalValues, 1)
			}
		})

		b.Run("After_Arrow/"+name, func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, _ = agg.medianAggregate(arrowObs, 1)
			}
		})
	}
}

// BenchmarkComparison_QuoteAggregation compares quote aggregation performance.
+func BenchmarkComparison_QuoteAggregation(b *testing.B) { + sizes := []int{10, 100, 1000, 5000} + + for _, size := range sizes { + name := fmt.Sprintf("%d_observations", size) + + // Setup: create test data for original implementation + originalValues := make([]StreamValue, size) + for i := 0; i < size; i++ { + base := float64(rand.Intn(10000)) + originalValues[i] = &Quote{ + Bid: decimal.NewFromFloat(base - 0.5), + Benchmark: decimal.NewFromFloat(base), + Ask: decimal.NewFromFloat(base + 0.5), + } + } + + // Setup: create test data for Arrow implementation + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + arrowObs := &streamObservations{ + valueType: StreamValueTypeQuote, + quotes: make([]*Quote, size), + } + for i := 0; i < size; i++ { + base := float64(rand.Intn(10000)) + arrowObs.quotes[i] = &Quote{ + Bid: decimal.NewFromFloat(base - 0.5), + Benchmark: decimal.NewFromFloat(base), + Ask: decimal.NewFromFloat(base + 0.5), + } + } + + b.Run("Before_Original/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = QuoteAggregator(originalValues, 1) + } + }) + + b.Run("After_Arrow/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = agg.quoteAggregate(arrowObs, 1) + } + }) + } +} + +// BenchmarkComparison_ModeAggregation compares mode aggregation performance. 
+func BenchmarkComparison_ModeAggregation(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + name := fmt.Sprintf("%d_observations", size) + + // Create values with some repetition for mode to work + numUnique := size / 5 // 20% unique values + if numUnique < 3 { + numUnique = 3 + } + + // Setup: create test data for original implementation + originalValues := make([]StreamValue, size) + for i := 0; i < size; i++ { + originalValues[i] = ToDecimal(decimal.NewFromInt(int64(i % numUnique))) + } + + // Setup: create test data for Arrow implementation + pool := NewArrowBuilderPool(0) + agg := NewArrowAggregator(pool, nil) + arrowObs := &streamObservations{ + valueType: StreamValueTypeDecimal, + decimals: make([]decimal.Decimal, size), + } + for i := 0; i < size; i++ { + arrowObs.decimals[i] = decimal.NewFromInt(int64(i % numUnique)) + } + + b.Run("Before_Original/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = ModeAggregator(originalValues, 1) + } + }) + + b.Run("After_Arrow/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = agg.modeAggregate(arrowObs, 1) + } + }) + } +} + +// BenchmarkComparison_StreamValuesConversion compares map vs Arrow record operations. 
+func BenchmarkComparison_StreamValuesConversion(b *testing.B) { + sizes := []int{100, 1000, 5000, 10000} + + for _, size := range sizes { + name := fmt.Sprintf("%d_streams", size) + + // Setup: create test data + values := make(map[llotypes.StreamID]StreamValue, size) + for i := 0; i < size; i++ { + values[llotypes.StreamID(i)] = ToDecimal(decimal.NewFromFloat(float64(i))) + } + + pool := NewArrowBuilderPool(0) + + b.Run("Before_Original_MapIteration/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + // Simulate iterating over map (common operation) + count := 0 + for _, v := range values { + if v != nil { + count++ + } + } + _ = count + } + }) + + b.Run("After_Arrow_ToRecord/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + record, _ := StreamValuesToArrowRecord(values, pool) + if record != nil { + record.Release() + } + } + }) + + // Create Arrow record for FromRecord benchmark + record, _ := StreamValuesToArrowRecord(values, pool) + defer record.Release() + + b.Run("After_Arrow_FromRecord/"+name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = ArrowRecordToStreamValues(record) + } + }) + } +} + +// BenchmarkComparison_MemoryAllocation measures memory allocation patterns. 
func BenchmarkComparison_MemoryAllocation(b *testing.B) {
	sizes := []int{100, 1000, 10000}

	for _, size := range sizes {
		name := fmt.Sprintf("%d_values", size)

		b.Run("Before_Original_MapAllocation/"+name, func(b *testing.B) {
			b.ReportAllocs()
			// Allocate a fresh map on every iteration to measure the
			// steady-state cost of the map-based representation.
			for i := 0; i < b.N; i++ {
				values := make(map[llotypes.StreamID]StreamValue, size)
				for j := 0; j < size; j++ {
					values[llotypes.StreamID(j)] = ToDecimal(decimal.NewFromFloat(float64(j)))
				}
				_ = values
			}
		})

		pool := NewArrowBuilderPool(0)

		b.Run("After_Arrow_PooledAllocation/"+name, func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				builder := pool.GetObservationBuilder()
				// Simulate adding data - must fill all fields for valid record
				// (every column of the observation schema needs one entry per
				// row; quote columns get explicit nulls for decimal rows).
				for j := 0; j < size; j++ {
					builder.Field(ObsColObserverID).(*array.Uint8Builder).Append(1)
					builder.Field(ObsColStreamID).(*array.Uint32Builder).Append(uint32(j))
					builder.Field(ObsColValueType).(*array.Uint8Builder).Append(StreamValueTypeDecimal)
					builder.Field(ObsColDecimalValue).(*array.BinaryBuilder).AppendNull()
					builder.Field(ObsColQuoteBid).(*array.BinaryBuilder).AppendNull()
					builder.Field(ObsColQuoteBenchmark).(*array.BinaryBuilder).AppendNull()
					builder.Field(ObsColQuoteAsk).(*array.BinaryBuilder).AppendNull()
					builder.Field(ObsColObservedAtNs).(*array.Uint64Builder).Append(0)
					builder.Field(ObsColTimestampNs).(*array.Uint64Builder).Append(0)
				}
				record := builder.NewRecord()
				record.Release()
				pool.PutObservationBuilder(builder)
			}
		})
	}
}

// BenchmarkComparison_EndToEnd simulates full aggregation pipeline.
func BenchmarkComparison_EndToEnd(b *testing.B) {
	sizes := []int{100, 500, 1000}
	numObservers := 5 // Simulate 5 nodes

	for _, size := range sizes {
		name := fmt.Sprintf("%d_streams_%d_observers", size, numObservers)

		// Setup: simulate observations from multiple nodes
		allObservations := make([][]StreamValue, numObservers)
		for obs := 0; obs < numObservers; obs++ {
			allObservations[obs] = make([]StreamValue, size)
			for i := 0; i < size; i++ {
				// Add some variance between observers
				allObservations[obs][i] = ToDecimal(decimal.NewFromFloat(float64(i) + float64(obs)*0.1))
			}
		}

		b.Run("Before_Original/"+name, func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				// Merge observations (simple concatenation for streams)
				merged := make([]StreamValue, 0, size*numObservers)
				for _, obs := range allObservations {
					merged = append(merged, obs...)
				}
				// Aggregate per stream (simplified: just take all)
				// NOTE(review): only the first `size` merged values (observer
				// 0's slice) are aggregated here, while the Arrow side below
				// aggregates all size*numObservers values — the two sides do
				// different amounts of aggregation work. Intentional
				// simplification? TODO confirm.
				_, _ = MedianAggregator(merged[:size], 1)
			}
		})

		pool := NewArrowBuilderPool(0)
		agg := NewArrowAggregator(pool, nil)

		// Pre-build Arrow observations (setup cost excluded from timing).
		arrowObs := &streamObservations{
			valueType: StreamValueTypeDecimal,
			decimals:  make([]decimal.Decimal, size*numObservers),
		}
		idx := 0
		for obs := 0; obs < numObservers; obs++ {
			for i := 0; i < size; i++ {
				arrowObs.decimals[idx] = decimal.NewFromFloat(float64(i) + float64(obs)*0.1)
				idx++
			}
		}

		b.Run("After_Arrow/"+name, func(b *testing.B) {
			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				_, _ = agg.medianAggregate(arrowObs, 1)
			}
		})
	}
}

// ============================================================================
// RESULT EQUIVALENCE TESTS
// ============================================================================
// These tests verify that the Arrow implementation produces the same results
// as the original implementation.
// ============================================================================

// TestResultEquivalence_MedianAggregator verifies Arrow median matches original.
func TestResultEquivalence_MedianAggregator(t *testing.T) {
	// Mix of odd and even sizes to cover both median-index cases.
	sizes := []int{5, 10, 21, 100}

	for _, size := range sizes {
		t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
			// Create identical test data
			values := make([]decimal.Decimal, size)
			for i := 0; i < size; i++ {
				values[i] = decimal.NewFromFloat(float64(i * 10))
			}

			// Original implementation
			originalInput := make([]StreamValue, size)
			for i := 0; i < size; i++ {
				originalInput[i] = ToDecimal(values[i])
			}
			originalResult, err := MedianAggregator(originalInput, 1)
			require.NoError(t, err)

			// Arrow implementation
			// NOTE(review): `values` is aliased into the aggregator input;
			// this relies on medianAggregate not mutating its input slice —
			// verify against the aggregator implementation.
			pool := NewArrowBuilderPool(0)
			agg := NewArrowAggregator(pool, nil)
			arrowObs := &streamObservations{
				valueType: StreamValueTypeDecimal,
				decimals:  values,
			}
			arrowResult, err := agg.medianAggregate(arrowObs, 1)
			require.NoError(t, err)

			// Compare results
			originalDec := originalResult.(*Decimal).Decimal()
			arrowDec := arrowResult.(*Decimal).Decimal()
			assert.True(t, originalDec.Equal(arrowDec),
				"Results differ: original=%s, arrow=%s", originalDec, arrowDec)
		})
	}
}

// TestResultEquivalence_QuoteAggregator verifies Arrow quote matches original.
func TestResultEquivalence_QuoteAggregator(t *testing.T) {
	// Mix of odd and even sizes to cover both median-index cases.
	sizes := []int{5, 10, 21}

	for _, size := range sizes {
		t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
			// Create identical test data: well-formed quotes (Bid < Benchmark < Ask).
			quotes := make([]*Quote, size)
			for i := 0; i < size; i++ {
				base := float64(i * 10)
				quotes[i] = &Quote{
					Bid:       decimal.NewFromFloat(base - 1),
					Benchmark: decimal.NewFromFloat(base),
					Ask:       decimal.NewFromFloat(base + 1),
				}
			}

			// Original implementation
			// NOTE(review): the same *Quote pointers are shared between both
			// inputs; this relies on neither aggregator mutating the quotes.
			originalInput := make([]StreamValue, size)
			for i := 0; i < size; i++ {
				originalInput[i] = quotes[i]
			}
			originalResult, err := QuoteAggregator(originalInput, 1)
			require.NoError(t, err)

			// Arrow implementation
			pool := NewArrowBuilderPool(0)
			agg := NewArrowAggregator(pool, nil)
			arrowObs := &streamObservations{
				valueType: StreamValueTypeQuote,
				quotes:    quotes,
			}
			arrowResult, err := agg.quoteAggregate(arrowObs, 1)
			require.NoError(t, err)

			// Compare results component-wise.
			originalQuote := originalResult.(*Quote)
			arrowQuote := arrowResult.(*Quote)

			assert.True(t, originalQuote.Bid.Equal(arrowQuote.Bid),
				"Bid differs: original=%s, arrow=%s", originalQuote.Bid, arrowQuote.Bid)
			assert.True(t, originalQuote.Benchmark.Equal(arrowQuote.Benchmark),
				"Benchmark differs: original=%s, arrow=%s", originalQuote.Benchmark, arrowQuote.Benchmark)
			assert.True(t, originalQuote.Ask.Equal(arrowQuote.Ask),
				"Ask differs: original=%s, arrow=%s", originalQuote.Ask, arrowQuote.Ask)
		})
	}
}

// TestResultEquivalence_ModeAggregator verifies Arrow mode matches original.
func TestResultEquivalence_ModeAggregator(t *testing.T) {
	t.Run("clear_mode", func(t *testing.T) {
		// Create data with a clear mode: 20 appears three times, presumably
		// enough agreement for f=2 below — verify against ModeAggregator's
		// threshold semantics.
		values := []decimal.Decimal{
			decimal.NewFromInt(10),
			decimal.NewFromInt(20),
			decimal.NewFromInt(20),
			decimal.NewFromInt(20),
			decimal.NewFromInt(30),
		}

		// Original implementation
		originalInput := make([]StreamValue, len(values))
		for i, v := range values {
			originalInput[i] = ToDecimal(v)
		}
		originalResult, err := ModeAggregator(originalInput, 2)
		require.NoError(t, err)

		// Arrow implementation (shares the `values` slice with the fixture;
		// relies on modeAggregate not mutating its input).
		pool := NewArrowBuilderPool(0)
		agg := NewArrowAggregator(pool, nil)
		arrowObs := &streamObservations{
			valueType: StreamValueTypeDecimal,
			decimals:  values,
		}
		arrowResult, err := agg.modeAggregate(arrowObs, 2)
		require.NoError(t, err)

		// Compare results: both implementations must agree, and both must
		// pick the expected mode.
		originalDec := originalResult.(*Decimal).Decimal()
		arrowDec := arrowResult.(*Decimal).Decimal()
		assert.True(t, originalDec.Equal(arrowDec),
			"Results differ: original=%s, arrow=%s", originalDec, arrowDec)
		assert.True(t, decimal.NewFromInt(20).Equal(arrowDec), "Expected mode to be 20")
	})
}
diff --git a/llo/arrow_converters.go b/llo/arrow_converters.go
new file mode 100644
index 0000000..87c64c8
--- /dev/null
+++ b/llo/arrow_converters.go
@@ -0,0 +1,278 @@
package llo

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/shopspring/decimal"

	llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo"
)

// maxDecimalBytes is the maximum size of a binary-encoded decimal.
// This limit prevents potential DoS from oversized input data.
const maxDecimalBytes = 256

// DecimalToBytes converts a shopspring decimal to its binary representation.
// This uses the decimal's native MarshalBinary for consistent serialization.
+func DecimalToBytes(d decimal.Decimal) ([]byte, error) {
+	return d.MarshalBinary()
+}
+
+// BytesToDecimal converts bytes back to a shopspring decimal.
+// Returns an error if the input exceeds maxDecimalBytes.
+func BytesToDecimal(b []byte) (decimal.Decimal, error) {
+	if len(b) > maxDecimalBytes {
+		return decimal.Decimal{}, fmt.Errorf("decimal bytes exceed max size: %d > %d", len(b), maxDecimalBytes)
+	}
+	var d decimal.Decimal
+	err := d.UnmarshalBinary(b)
+	return d, err
+}
+
+// StreamValueToArrow appends a StreamValue to the appropriate Arrow builder fields.
+// Returns the value type that was appended.
+//
+// The append is atomic with respect to errors: all decimal encoding is done
+// BEFORE any builder is touched, so on error no builder has been appended to
+// and the columns remain balanced. (Previously a mid-conversion encoding
+// failure could leave some columns one row longer than others, corrupting the
+// record being built.)
+func StreamValueToArrow(
+	sv StreamValue,
+	valueTypeBuilder *array.Uint8Builder,
+	decimalBuilder *array.BinaryBuilder,
+	bidBuilder *array.BinaryBuilder,
+	benchmarkBuilder *array.BinaryBuilder,
+	askBuilder *array.BinaryBuilder,
+	observedAtBuilder *array.Uint64Builder,
+) (uint8, error) {
+	if sv == nil {
+		// Null value - append nulls to all fields
+		valueTypeBuilder.AppendNull()
+		decimalBuilder.AppendNull()
+		bidBuilder.AppendNull()
+		benchmarkBuilder.AppendNull()
+		askBuilder.AppendNull()
+		observedAtBuilder.AppendNull()
+		return 0, nil
+	}
+
+	switch v := sv.(type) {
+	case *Decimal:
+		// Encode first so a failure leaves every builder untouched.
+		b, err := DecimalToBytes(v.Decimal())
+		if err != nil {
+			return 0, err
+		}
+		valueTypeBuilder.Append(StreamValueTypeDecimal)
+		decimalBuilder.Append(b)
+		bidBuilder.AppendNull()
+		benchmarkBuilder.AppendNull()
+		askBuilder.AppendNull()
+		observedAtBuilder.AppendNull()
+		return StreamValueTypeDecimal, nil
+
+	case *Quote:
+		// Encode all three components up front; append only once all succeed.
+		bidBytes, err := DecimalToBytes(v.Bid)
+		if err != nil {
+			return 0, err
+		}
+		benchmarkBytes, err := DecimalToBytes(v.Benchmark)
+		if err != nil {
+			return 0, err
+		}
+		askBytes, err := DecimalToBytes(v.Ask)
+		if err != nil {
+			return 0, err
+		}
+		valueTypeBuilder.Append(StreamValueTypeQuote)
+		decimalBuilder.AppendNull()
+		bidBuilder.Append(bidBytes)
+		benchmarkBuilder.Append(benchmarkBytes)
+		askBuilder.Append(askBytes)
+		observedAtBuilder.AppendNull()
+		return StreamValueTypeQuote, nil
+
+	case *TimestampedStreamValue:
+		// Handle the inner stream value (usually Decimal); encode it first so
+		// an encoding error cannot leave a partial row behind.
+		var innerBytes []byte
+		hasInner := false
+		if inner, ok := v.StreamValue.(*Decimal); ok {
+			b, err := DecimalToBytes(inner.Decimal())
+			if err != nil {
+				return 0, err
+			}
+			innerBytes, hasInner = b, true
+		}
+		valueTypeBuilder.Append(StreamValueTypeTimestampd)
+		observedAtBuilder.Append(v.ObservedAtNanoseconds)
+		if hasInner {
+			decimalBuilder.Append(innerBytes)
+		} else {
+			// Non-Decimal inner values are stored as null, matching the
+			// original behavior.
+			decimalBuilder.AppendNull()
+		}
+		bidBuilder.AppendNull()
+		benchmarkBuilder.AppendNull()
+		askBuilder.AppendNull()
+		return StreamValueTypeTimestampd, nil
+
+	default:
+		// Unknown type - append nulls
+		valueTypeBuilder.AppendNull()
+		decimalBuilder.AppendNull()
+		bidBuilder.AppendNull()
+		benchmarkBuilder.AppendNull()
+		askBuilder.AppendNull()
+		observedAtBuilder.AppendNull()
+		return 0, nil
+	}
+}
+
+// ArrowToStreamValue extracts a StreamValue from Arrow arrays at the given index.
+// A null value_type, a row missing required components, or an unknown type
+// discriminator decodes to (nil, nil); only decode failures return an error.
+func ArrowToStreamValue(
+	idx int,
+	valueTypeArr *array.Uint8,
+	decimalArr *array.Binary,
+	bidArr *array.Binary,
+	benchmarkArr *array.Binary,
+	askArr *array.Binary,
+	observedAtArr *array.Uint64,
+) (StreamValue, error) {
+	if valueTypeArr.IsNull(idx) {
+		return nil, nil
+	}
+
+	valueType := valueTypeArr.Value(idx)
+
+	switch valueType {
+	case StreamValueTypeDecimal:
+		if decimalArr.IsNull(idx) {
+			return nil, nil
+		}
+		d, err := BytesToDecimal(decimalArr.Value(idx))
+		if err != nil {
+			return nil, err
+		}
+		return ToDecimal(d), nil
+
+	case StreamValueTypeQuote:
+		// All three quote components are required; a partially-null quote row
+		// decodes to nil rather than a half-filled Quote.
+		if bidArr.IsNull(idx) || benchmarkArr.IsNull(idx) || askArr.IsNull(idx) {
+			return nil, nil
+		}
+		bid, err := BytesToDecimal(bidArr.Value(idx))
+		if err != nil {
+			return nil, err
+		}
+		benchmark, err := BytesToDecimal(benchmarkArr.Value(idx))
+		if err != nil {
+			return nil, err
+		}
+		ask, err := BytesToDecimal(askArr.Value(idx))
+		if err != nil {
+			return nil, err
+		}
+		return &Quote{Bid: bid, Benchmark: benchmark, Ask: ask}, nil
+
+	case StreamValueTypeTimestampd:
+		// A null timestamp defaults to zero rather than failing the row.
+		observedAt := uint64(0)
+		if !observedAtArr.IsNull(idx) {
+			observedAt = observedAtArr.Value(idx)
+		}
+
+		var innerValue StreamValue
+		if !decimalArr.IsNull(idx) {
+			d, err := BytesToDecimal(decimalArr.Value(idx))
+			if err != nil {
+				return nil, err
+			}
+			innerValue = ToDecimal(d)
+		}
+
+		return &TimestampedStreamValue{
+			ObservedAtNanoseconds: observedAt,
+			StreamValue:           innerValue,
+		}, nil
+
+	default:
+		return nil, nil
+	}
+}
+
+// StreamValuesToArrowRecord converts a map of StreamValues to an Arrow Record.
+// This is useful for batch operations on cached stream values.
+func StreamValuesToArrowRecord(
+	values map[llotypes.StreamID]StreamValue,
+	pool *ArrowBuilderPool,
+) (arrow.Record, error) {
+	builder := pool.GetCacheBuilder()
+
+	streamIDBuilder := builder.Field(CacheColStreamID).(*array.Uint32Builder)
+	valueTypeBuilder := builder.Field(CacheColValueType).(*array.Uint8Builder)
+	decimalBuilder := builder.Field(CacheColDecimalValue).(*array.BinaryBuilder)
+	bidBuilder := builder.Field(CacheColQuoteBid).(*array.BinaryBuilder)
+	benchmarkBuilder := builder.Field(CacheColQuoteBenchmark).(*array.BinaryBuilder)
+	askBuilder := builder.Field(CacheColQuoteAsk).(*array.BinaryBuilder)
+	observedAtBuilder := builder.Field(CacheColObservedAtNs).(*array.Uint64Builder)
+	expiresAtBuilder := builder.Field(CacheColExpiresAtNs).(*array.Int64Builder)
+
+	// Pre-allocate capacity
+	streamIDBuilder.Reserve(len(values))
+
+	for streamID, sv := range values {
+		streamIDBuilder.Append(streamID)
+		_, err := StreamValueToArrow(sv, valueTypeBuilder, decimalBuilder,
+			bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder)
+		if err != nil {
+			// BUG FIX: previously the builder was returned to the pool here,
+			// but the stream_id column has already been extended for this row,
+			// so the builder's columns are unbalanced. Reusing it would corrupt
+			// the next record built from the pool; drop it and let GC reclaim it.
+			return nil, err
+		}
+		expiresAtBuilder.Append(0) // Caller should set expiration
+	}
+
+	record := builder.NewRecord()
+	pool.PutCacheBuilder(builder)
+	return record, nil
+}
+
+// ArrowRecordToStreamValues converts an Arrow Record back to a map of StreamValues.
+func ArrowRecordToStreamValues(record arrow.Record) (map[llotypes.StreamID]StreamValue, error) { + if record == nil || record.NumRows() == 0 { + return nil, nil + } + + streamIDArr := record.Column(CacheColStreamID).(*array.Uint32) + valueTypeArr := record.Column(CacheColValueType).(*array.Uint8) + decimalArr := record.Column(CacheColDecimalValue).(*array.Binary) + bidArr := record.Column(CacheColQuoteBid).(*array.Binary) + benchmarkArr := record.Column(CacheColQuoteBenchmark).(*array.Binary) + askArr := record.Column(CacheColQuoteAsk).(*array.Binary) + observedAtArr := record.Column(CacheColObservedAtNs).(*array.Uint64) + + result := make(map[llotypes.StreamID]StreamValue, record.NumRows()) + + for i := 0; i < int(record.NumRows()); i++ { + streamID := streamIDArr.Value(i) + sv, err := ArrowToStreamValue(i, valueTypeArr, decimalArr, bidArr, + benchmarkArr, askArr, observedAtArr) + if err != nil { + return nil, err + } + if sv != nil { + result[streamID] = sv + } + } + + return result, nil +} + +// ExtractDecimalColumn extracts all decimal values from an Arrow record column. +// Useful for aggregation operations that need to work with decimal arrays. 
+func ExtractDecimalColumn(arr *array.Binary) ([]decimal.Decimal, error) {
+	// Null entries are skipped, so the result may be shorter than arr.Len().
+	result := make([]decimal.Decimal, 0, arr.Len())
+	for i := 0; i < arr.Len(); i++ {
+		if arr.IsNull(i) {
+			continue
+		}
+		d, err := BytesToDecimal(arr.Value(i))
+		if err != nil {
+			return nil, err
+		}
+		result = append(result, d)
+	}
+	return result, nil
+}
diff --git a/llo/arrow_observation_merger.go b/llo/arrow_observation_merger.go
new file mode 100644
index 0000000..9677f95
--- /dev/null
+++ b/llo/arrow_observation_merger.go
@@ -0,0 +1,298 @@
+package llo
+
+import (
+	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/array"
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+	"github.com/smartcontractkit/libocr/offchainreporting2plus/types"
+)
+
+// ArrowObservationMerger converts multiple node observations into a single Arrow record.
+// This enables efficient batch aggregation across all streams from all observers.
+type ArrowObservationMerger struct {
+	pool *ArrowBuilderPool // source of pooled observation RecordBuilders
+	codec ObservationCodec // decodes raw attributed observations
+	logger logger.Logger // optional; may be nil (all call sites nil-check it)
+}
+
+// NewArrowObservationMerger creates a new observation merger.
+// Logger can be nil if logging is not needed.
+func NewArrowObservationMerger(pool *ArrowBuilderPool, codec ObservationCodec, lggr logger.Logger) *ArrowObservationMerger {
+	return &ArrowObservationMerger{
+		pool: pool,
+		codec: codec,
+		logger: lggr,
+	}
+}
+
+// MergeObservations converts attributed observations from multiple nodes into a single Arrow record.
+// The record contains all stream values from all observers, ready for aggregation.
+//
+// Returns an Arrow record that must be released by the caller when done.
+func (m *ArrowObservationMerger) MergeObservations(
+	aos []types.AttributedObservation,
+) (arrow.Record, error) {
+	builder := m.pool.GetObservationBuilder()
+
+	// Get typed builders for each column
+	observerIDBuilder := builder.Field(ObsColObserverID).(*array.Uint8Builder)
+	streamIDBuilder := builder.Field(ObsColStreamID).(*array.Uint32Builder)
+	valueTypeBuilder := builder.Field(ObsColValueType).(*array.Uint8Builder)
+	decimalBuilder := builder.Field(ObsColDecimalValue).(*array.BinaryBuilder)
+	bidBuilder := builder.Field(ObsColQuoteBid).(*array.BinaryBuilder)
+	benchmarkBuilder := builder.Field(ObsColQuoteBenchmark).(*array.BinaryBuilder)
+	askBuilder := builder.Field(ObsColQuoteAsk).(*array.BinaryBuilder)
+	observedAtBuilder := builder.Field(ObsColObservedAtNs).(*array.Uint64Builder)
+	timestampBuilder := builder.Field(ObsColTimestampNs).(*array.Uint64Builder)
+
+	// Estimate capacity: assume each observation has ~100 stream values on average
+	estimatedRows := len(aos) * 100
+	observerIDBuilder.Reserve(estimatedRows)
+	streamIDBuilder.Reserve(estimatedRows)
+
+	for _, ao := range aos {
+		// Validate observer ID bounds before casting to uint8.
+		// NOTE(review): if Observer is already a uint8-backed type this check is
+		// unreachable; kept as cheap defence in depth — confirm against libocr.
+		if ao.Observer > 255 {
+			if m.logger != nil {
+				m.logger.Warnw("Observer ID exceeds uint8 bounds", "observer", ao.Observer)
+			}
+			continue
+		}
+
+		// Decode the observation. Skipping an undecodable observation is safe:
+		// it drops whole rows, which cannot unbalance the columns.
+		obs, err := m.codec.Decode(ao.Observation)
+		if err != nil {
+			if m.logger != nil {
+				m.logger.Debugw("Failed to decode observation", "observer", ao.Observer, "err", err)
+			}
+			continue
+		}
+
+		observerID := uint8(ao.Observer)
+		timestamp := obs.UnixTimestampNanoseconds
+
+		// Add each stream value
+		for streamID, sv := range obs.StreamValues {
+			if sv == nil {
+				continue
+			}
+
+			observerIDBuilder.Append(observerID)
+			streamIDBuilder.Append(streamID)
+			timestampBuilder.Append(timestamp)
+
+			// Append the stream value to appropriate columns
+			_, err := StreamValueToArrow(sv, valueTypeBuilder, decimalBuilder,
+				bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder)
+			if err != nil {
+				// BUG FIX: previously this logged and continued, but the
+				// observer/stream/timestamp columns were already extended for
+				// this row, leaving the builder's columns unbalanced and the
+				// resulting record corrupted. Abort instead, and do NOT return
+				// the unbalanced builder to the pool (see RecordBuilderPool.Put
+				// contract) — let it be garbage collected.
+				if m.logger != nil {
+					m.logger.Debugw("Failed to convert stream value", "streamID", streamID, "observer", observerID, "err", err)
+				}
+				return nil, err
+			}
+		}
+	}
+
+	record := builder.NewRecord()
+	m.pool.PutObservationBuilder(builder)
+	return record, nil
+}
+
+// MergeStreamValues merges stream values from multiple sources into a single Arrow record.
+// This is useful for merging cached values with new observations.
+func (m *ArrowObservationMerger) MergeStreamValues(
+	valuesByObserver map[uint8]StreamValues,
+	timestamps map[uint8]uint64,
+) (arrow.Record, error) {
+	builder := m.pool.GetObservationBuilder()
+
+	observerIDBuilder := builder.Field(ObsColObserverID).(*array.Uint8Builder)
+	streamIDBuilder := builder.Field(ObsColStreamID).(*array.Uint32Builder)
+	valueTypeBuilder := builder.Field(ObsColValueType).(*array.Uint8Builder)
+	decimalBuilder := builder.Field(ObsColDecimalValue).(*array.BinaryBuilder)
+	bidBuilder := builder.Field(ObsColQuoteBid).(*array.BinaryBuilder)
+	benchmarkBuilder := builder.Field(ObsColQuoteBenchmark).(*array.BinaryBuilder)
+	askBuilder := builder.Field(ObsColQuoteAsk).(*array.BinaryBuilder)
+	observedAtBuilder := builder.Field(ObsColObservedAtNs).(*array.Uint64Builder)
+	timestampBuilder := builder.Field(ObsColTimestampNs).(*array.Uint64Builder)
+
+	for observerID, values := range valuesByObserver {
+		// An observer missing from timestamps gets the zero value.
+		timestamp := timestamps[observerID]
+
+		for streamID, sv := range values {
+			if sv == nil {
+				continue
+			}
+
+			observerIDBuilder.Append(observerID)
+			streamIDBuilder.Append(streamID)
+			timestampBuilder.Append(timestamp)
+
+			_, err := StreamValueToArrow(sv, valueTypeBuilder, decimalBuilder,
+				bidBuilder, benchmarkBuilder, askBuilder, observedAtBuilder)
+			if err != nil {
+				// BUG FIX: as in MergeObservations, continuing here would leave
+				// the columns unbalanced. Abort and drop the builder (do not
+				// return it to the pool).
+				return nil, err
+			}
+		}
+	}
+
+	record := builder.NewRecord()
+	m.pool.PutObservationBuilder(builder)
+	return record, nil
+}
+
+// FilterByStreamIDs creates a new record containing only the specified stream IDs.
+// This is useful for extracting relevant streams for a specific channel.
+// If pool is nil, a new temporary pool will be created (less efficient).
+func FilterByStreamIDs(record arrow.Record, streamIDs []uint32, pool *LLOMemoryPool) arrow.Record {
+	if record == nil || len(streamIDs) == 0 || record.NumRows() == 0 {
+		return nil
+	}
+
+	// Set of wanted stream IDs for O(1) membership checks.
+	wanted := make(map[uint32]struct{}, len(streamIDs))
+	for _, id := range streamIDs {
+		wanted[id] = struct{}{}
+	}
+
+	// Collect the row indices whose stream_id is in the wanted set.
+	col := record.Column(ObsColStreamID).(*array.Uint32)
+	rows := make([]int, 0, record.NumRows())
+	for i, n := 0, int(record.NumRows()); i < n; i++ {
+		if _, ok := wanted[col.Value(i)]; ok {
+			rows = append(rows, i)
+		}
+	}
+	if len(rows) == 0 {
+		return nil
+	}
+
+	// Build filtered record
+	// Note: In production, you'd use Arrow's Take kernel for efficiency
+	return takeRows(record, rows, pool)
+}
+
+// takeRows creates a new record with only the specified row indices.
+// This is a simplified implementation; production code should use Arrow compute.
+// If pool is nil, a new temporary pool will be created (less efficient).
+func takeRows(record arrow.Record, indices []int, pool *LLOMemoryPool) arrow.Record {
+	if len(indices) == 0 {
+		return nil
+	}
+
+	// Use provided pool or create a temporary one
+	if pool == nil {
+		pool = NewLLOMemoryPool(0)
+	}
+
+	// Create new builders for each column
+	schema := record.Schema()
+	builders := make([]array.Builder, schema.NumFields())
+
+	for i := 0; i < schema.NumFields(); i++ {
+		builders[i] = array.NewBuilder(pool, schema.Field(i).Type)
+	}
+
+	// Copy selected rows
+	for _, idx := range indices {
+		for colIdx := 0; colIdx < int(record.NumCols()); colIdx++ {
+			col := record.Column(colIdx)
+			appendValue(builders[colIdx], col, idx)
+		}
+	}
+
+	// Build arrays and release the (now empty) builders.
+	arrays := make([]arrow.Array, len(builders))
+	for i, b := range builders {
+		arrays[i] = b.NewArray()
+		b.Release()
+	}
+
+	rec := array.NewRecord(schema, arrays, int64(len(indices)))
+	// BUG FIX: array.NewRecord retains each column, so our local references
+	// must be released here. Previously they were never released, leaking a
+	// refcount on every column (reported as a leak under a checked allocator).
+	for _, a := range arrays {
+		a.Release()
+	}
+	return rec
+}
+
+// appendValue appends a single value from an array to a builder.
+// Uses safe type assertions to prevent panics on type mismatches.
+func appendValue(builder array.Builder, arr arrow.Array, idx int) { + if arr.IsNull(idx) { + builder.AppendNull() + return + } + + switch b := builder.(type) { + case *array.Uint8Builder: + if a, ok := arr.(*array.Uint8); ok { + b.Append(a.Value(idx)) + } else { + builder.AppendNull() + } + case *array.Uint32Builder: + if a, ok := arr.(*array.Uint32); ok { + b.Append(a.Value(idx)) + } else { + builder.AppendNull() + } + case *array.Uint64Builder: + if a, ok := arr.(*array.Uint64); ok { + b.Append(a.Value(idx)) + } else { + builder.AppendNull() + } + case *array.Int64Builder: + if a, ok := arr.(*array.Int64); ok { + b.Append(a.Value(idx)) + } else { + builder.AppendNull() + } + case *array.BinaryBuilder: + if a, ok := arr.(*array.Binary); ok { + b.Append(a.Value(idx)) + } else { + builder.AppendNull() + } + case *array.StringBuilder: + if a, ok := arr.(*array.String); ok { + b.Append(a.Value(idx)) + } else { + builder.AppendNull() + } + default: + builder.AppendNull() + } +} + +// CountByStreamID returns the count of observations per stream ID. +// Useful for validation and debugging. +func CountByStreamID(record arrow.Record) map[uint32]int { + if record == nil || record.NumRows() == 0 { + return nil + } + + streamIDArr := record.Column(ObsColStreamID).(*array.Uint32) + counts := make(map[uint32]int) + + for i := 0; i < int(record.NumRows()); i++ { + counts[streamIDArr.Value(i)]++ + } + + return counts +} + +// CountByObserver returns the count of observations per observer. +// Useful for validation and debugging. 
+func CountByObserver(record arrow.Record) map[uint8]int {
+	if record == nil || record.NumRows() == 0 {
+		return nil
+	}
+
+	observerIDArr := record.Column(ObsColObserverID).(*array.Uint8)
+	counts := make(map[uint8]int)
+
+	for i := 0; i < int(record.NumRows()); i++ {
+		counts[observerIDArr.Value(i)]++
+	}
+
+	return counts
+}
diff --git a/llo/arrow_pool.go b/llo/arrow_pool.go
new file mode 100644
index 0000000..e321ca6
--- /dev/null
+++ b/llo/arrow_pool.go
@@ -0,0 +1,186 @@
+package llo
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+
+	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/array"
+	"github.com/apache/arrow-go/v18/arrow/memory"
+)
+
+// LLOMemoryPool wraps Arrow's allocator with metrics and bounds checking.
+// This can replace the 1GB memory ballast by providing controlled allocation.
+// Its Allocate/Reallocate/Free methods let it be passed wherever an Arrow
+// memory allocator is expected.
+type LLOMemoryPool struct {
+	allocator memory.Allocator
+	allocated atomic.Int64 // bytes currently outstanding (len-based accounting)
+	maxBytes int64 // 0 means unlimited
+	allocCount atomic.Int64 // total successful Allocate calls
+	releaseCount atomic.Int64 // total Free calls
+}
+
+// NewLLOMemoryPool creates a new memory pool with optional max memory limit.
+// If maxBytes is 0, no limit is enforced.
+func NewLLOMemoryPool(maxBytes int64) *LLOMemoryPool {
+	return &LLOMemoryPool{
+		allocator: memory.NewGoAllocator(),
+		maxBytes: maxBytes,
+	}
+}
+
+// ErrMemoryLimitExceeded is returned when an allocation would exceed the memory limit.
+// NOTE(review): nothing in this file returns this error — Allocate signals
+// failure by returning nil instead. Confirm intended usage with callers.
+var ErrMemoryLimitExceeded = fmt.Errorf("memory limit exceeded")
+
+// Allocate allocates a new byte slice of the given size.
+// Returns nil if the allocation would exceed the configured memory limit.
+// NOTE(review): the Load-then-Add is not atomic as a unit, so concurrent
+// callers may collectively overshoot maxBytes slightly — confirm acceptable.
+func (p *LLOMemoryPool) Allocate(size int) []byte {
+	if p.maxBytes > 0 && p.allocated.Load()+int64(size) > p.maxBytes {
+		// Enforce memory limit - return nil to signal allocation failure.
+		// Caller must handle this case appropriately.
+		return nil
+	}
+	buf := p.allocator.Allocate(size)
+	p.allocated.Add(int64(size))
+	p.allocCount.Add(1)
+	return buf
+}
+
+// Reallocate reallocates a byte slice to a new size.
+func (p *LLOMemoryPool) Reallocate(size int, b []byte) []byte {
+	oldSize := len(b)
+	// BUG FIX: enforce the configured memory limit on growth, mirroring
+	// Allocate. Previously Reallocate bypassed the limit entirely, so a
+	// growing buffer could push the pool arbitrarily far past maxBytes.
+	// Only the growth delta counts against the limit; shrinking always works.
+	if p.maxBytes > 0 && size > oldSize && p.allocated.Load()+int64(size-oldSize) > p.maxBytes {
+		// Signal failure the same way Allocate does.
+		return nil
+	}
+	buf := p.allocator.Reallocate(size, b)
+	p.allocated.Add(int64(size - oldSize))
+	return buf
+}
+
+// Free releases a byte slice back to the pool.
+func (p *LLOMemoryPool) Free(b []byte) {
+	p.allocated.Add(-int64(len(b)))
+	p.releaseCount.Add(1)
+	p.allocator.Free(b)
+}
+
+// Metrics returns current allocation statistics: bytes currently outstanding,
+// total successful allocations, and total frees.
+func (p *LLOMemoryPool) Metrics() (allocated, allocs, releases int64) {
+	return p.allocated.Load(), p.allocCount.Load(), p.releaseCount.Load()
+}
+
+// AllocatedBytes returns the current allocated bytes count.
+func (p *LLOMemoryPool) AllocatedBytes() int64 {
+	return p.allocated.Load()
+}
+
+// RecordBuilderPool manages a pool of Arrow RecordBuilders for efficient reuse.
+// This reduces allocations when building Arrow records repeatedly.
+type RecordBuilderPool struct {
+	pool    sync.Pool
+	memPool *LLOMemoryPool // allocator backing every pooled builder
+	schema  *arrow.Schema  // all builders in this pool share one schema
+}
+
+// NewRecordBuilderPool creates a new pool for the given schema.
+func NewRecordBuilderPool(schema *arrow.Schema, memPool *LLOMemoryPool) *RecordBuilderPool {
+	rbp := &RecordBuilderPool{
+		memPool: memPool,
+		schema:  schema,
+	}
+	rbp.pool = sync.Pool{
+		New: func() any {
+			return array.NewRecordBuilder(memPool, schema)
+		},
+	}
+	return rbp
+}
+
+// Get retrieves a RecordBuilder from the pool.
+func (p *RecordBuilderPool) Get() *array.RecordBuilder {
+	return p.pool.Get().(*array.RecordBuilder)
+}
+
+// Put returns a RecordBuilder to the pool.
+//
+// IMPORTANT: Callers MUST ensure the builder is in a clean state before calling Put.
+// This means either:
+//  1. NewRecord() was called (which resets the builder), OR
+//  2. No data was appended to the builder since it was retrieved
+//
+// Returning a builder with partial/unbalanced data will cause corruption for
+// the next user. In error paths where NewRecord() wasn't called, callers should
+// either call NewRecord() and release the result, or simply not return the
+// builder to the pool (let it be garbage collected).
+func (p *RecordBuilderPool) Put(b *array.RecordBuilder) {
+	p.pool.Put(b)
+}
+
+// ArrowBuilderPool contains pools for all LLO-related Arrow schemas.
+type ArrowBuilderPool struct {
+	memPool          *LLOMemoryPool
+	observationPool  *RecordBuilderPool
+	aggregatesPool   *RecordBuilderPool
+	cachePool        *RecordBuilderPool
+	transmissionPool *RecordBuilderPool
+}
+
+// NewArrowBuilderPool creates a new pool for all LLO Arrow schemas.
+// maxMemoryBytes sets the memory limit (0 for unlimited). All four schema
+// pools share the single bounded memory pool.
+func NewArrowBuilderPool(maxMemoryBytes int64) *ArrowBuilderPool {
+	memPool := NewLLOMemoryPool(maxMemoryBytes)
+	return &ArrowBuilderPool{
+		memPool:          memPool,
+		observationPool:  NewRecordBuilderPool(ObservationSchema, memPool),
+		aggregatesPool:   NewRecordBuilderPool(StreamAggregatesSchema, memPool),
+		cachePool:        NewRecordBuilderPool(CacheSchema, memPool),
+		transmissionPool: NewRecordBuilderPool(TransmissionSchema, memPool),
+	}
+}
+
+// GetObservationBuilder returns a builder for observation records.
+func (p *ArrowBuilderPool) GetObservationBuilder() *array.RecordBuilder {
+	return p.observationPool.Get()
+}
+
+// PutObservationBuilder returns an observation builder to the pool.
+// The builder must be in a clean state; see RecordBuilderPool.Put.
+func (p *ArrowBuilderPool) PutObservationBuilder(b *array.RecordBuilder) {
+	p.observationPool.Put(b)
+}
+
+// GetAggregatesBuilder returns a builder for aggregate records.
+func (p *ArrowBuilderPool) GetAggregatesBuilder() *array.RecordBuilder {
+	return p.aggregatesPool.Get()
+}
+
+// PutAggregatesBuilder returns an aggregates builder to the pool.
+// The builder must be in a clean state; see RecordBuilderPool.Put.
+func (p *ArrowBuilderPool) PutAggregatesBuilder(b *array.RecordBuilder) {
+	p.aggregatesPool.Put(b)
+}
+
+// GetCacheBuilder returns a builder for cache records.
+func (p *ArrowBuilderPool) GetCacheBuilder() *array.RecordBuilder {
+	return p.cachePool.Get()
+}
+
+// PutCacheBuilder returns a cache builder to the pool.
+// The builder must be in a clean state; see RecordBuilderPool.Put.
+func (p *ArrowBuilderPool) PutCacheBuilder(b *array.RecordBuilder) {
+	p.cachePool.Put(b)
+}
+
+// GetTransmissionBuilder returns a builder for transmission records.
+func (p *ArrowBuilderPool) GetTransmissionBuilder() *array.RecordBuilder {
+	return p.transmissionPool.Get()
+}
+
+// PutTransmissionBuilder returns a transmission builder to the pool.
+// The builder must be in a clean state; see RecordBuilderPool.Put.
+func (p *ArrowBuilderPool) PutTransmissionBuilder(b *array.RecordBuilder) {
+	p.transmissionPool.Put(b)
+}
+
+// MemoryPool returns the underlying memory pool for direct access.
+func (p *ArrowBuilderPool) MemoryPool() *LLOMemoryPool {
+	return p.memPool
+}
+
+// MemoryStats returns current memory allocation statistics.
+// See LLOMemoryPool.Metrics for the meaning of the three return values.
+func (p *ArrowBuilderPool) MemoryStats() (allocated, allocs, releases int64) {
+	return p.memPool.Metrics()
+}
diff --git a/llo/arrow_schemas.go b/llo/arrow_schemas.go
new file mode 100644
index 0000000..92a1d0e
--- /dev/null
+++ b/llo/arrow_schemas.go
@@ -0,0 +1,128 @@
+package llo
+
+import (
+	"github.com/apache/arrow-go/v18/arrow"
+)
+
+// StreamValueType constants for Arrow arrays
+// ("Timestampd" spelling is part of the identifier used throughout; keep it).
+const (
+	StreamValueTypeDecimal uint8 = 0
+	StreamValueTypeQuote uint8 = 1
+	StreamValueTypeTimestampd uint8 = 2
+)
+
+// ObservationSchema defines the Arrow schema for merged observations from multiple nodes.
+// This columnar format enables efficient batch aggregation across all streams.
+//
+// Fields:
+// - observer_id: The node that produced this observation (0-255)
+// - stream_id: The stream identifier (uint32)
+// - value_type: Type discriminator (0=Decimal, 1=Quote, 2=TimestampedStreamValue)
+// - decimal_value: Binary-encoded decimal value (shopspring/decimal format)
+// - quote_bid/benchmark/ask: Binary-encoded quote components
+// - observed_at_ns: Timestamp from provider (for TimestampedStreamValue)
+// - timestamp_ns: Node's observation timestamp
+//
+// The field order must stay in sync with the ObsCol* column-index constants.
+var ObservationSchema = arrow.NewSchema(
+	[]arrow.Field{
+		{Name: "observer_id", Type: arrow.PrimitiveTypes.Uint8, Nullable: false},
+		{Name: "stream_id", Type: arrow.PrimitiveTypes.Uint32, Nullable: false},
+		{Name: "value_type", Type: arrow.PrimitiveTypes.Uint8, Nullable: false},
+		{Name: "decimal_value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_bid", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_benchmark", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_ask", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "observed_at_ns", Type: arrow.PrimitiveTypes.Uint64, Nullable: true},
+		{Name: "timestamp_ns", Type: arrow.PrimitiveTypes.Uint64, Nullable: false},
+	},
+	nil, // no metadata
+)
+
+// StreamAggregatesSchema defines the Arrow schema for aggregated stream values.
+// Used as output from aggregation and input to report generation.
+// The field order must stay in sync with the AggCol* column-index constants.
+var StreamAggregatesSchema = arrow.NewSchema(
+	[]arrow.Field{
+		{Name: "stream_id", Type: arrow.PrimitiveTypes.Uint32, Nullable: false},
+		{Name: "aggregator", Type: arrow.PrimitiveTypes.Uint32, Nullable: false},
+		{Name: "value_type", Type: arrow.PrimitiveTypes.Uint8, Nullable: false},
+		{Name: "decimal_value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_bid", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_benchmark", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_ask", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "observed_at_ns", Type: arrow.PrimitiveTypes.Uint64, Nullable: true},
+	},
+	nil,
+)
+
+// CacheSchema defines the Arrow schema for the observation cache.
+// This is optimized for fast lookups by stream_id with TTL-based expiration.
+// The field order must stay in sync with the CacheCol* column-index constants.
+var CacheSchema = arrow.NewSchema(
+	[]arrow.Field{
+		{Name: "stream_id", Type: arrow.PrimitiveTypes.Uint32, Nullable: false},
+		{Name: "value_type", Type: arrow.PrimitiveTypes.Uint8, Nullable: false},
+		{Name: "decimal_value", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_bid", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_benchmark", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "quote_ask", Type: arrow.BinaryTypes.Binary, Nullable: true},
+		{Name: "observed_at_ns", Type: arrow.PrimitiveTypes.Uint64, Nullable: true},
+		{Name: "expires_at_ns", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
+	},
+	nil,
+)
+
+// TransmissionSchema defines the Arrow schema for batched report transmissions.
+// Used for efficient batch encoding with Arrow IPC and compression.
+//
+// TODO: This schema is defined for future use in batched transmission encoding.
+// Implementation pending integration with the transmission subsystem.
+var TransmissionSchema = arrow.NewSchema(
+	[]arrow.Field{
+		{Name: "server_url", Type: arrow.BinaryTypes.String, Nullable: false},
+		{Name: "config_digest", Type: &arrow.FixedSizeBinaryType{ByteWidth: 32}, Nullable: false},
+		{Name: "seq_nr", Type: arrow.PrimitiveTypes.Uint64, Nullable: false},
+		{Name: "report_data", Type: arrow.BinaryTypes.LargeBinary, Nullable: false},
+		{Name: "lifecycle_stage", Type: arrow.BinaryTypes.String, Nullable: false},
+		{Name: "report_format", Type: arrow.PrimitiveTypes.Uint32, Nullable: false},
+		{Name: "signatures", Type: arrow.ListOf(arrow.BinaryTypes.Binary), Nullable: false},
+		{Name: "signers", Type: arrow.ListOf(arrow.PrimitiveTypes.Uint8), Nullable: false},
+		{Name: "transmission_hash", Type: &arrow.FixedSizeBinaryType{ByteWidth: 32}, Nullable: false},
+		{Name: "created_at_ns", Type: arrow.FixedWidthTypes.Timestamp_ns, Nullable: false},
+	},
+	nil,
+)
+
+// Column indices for ObservationSchema - for efficient column access.
+// These must match the field order of ObservationSchema exactly.
+const (
+	ObsColObserverID = iota
+	ObsColStreamID
+	ObsColValueType
+	ObsColDecimalValue
+	ObsColQuoteBid
+	ObsColQuoteBenchmark
+	ObsColQuoteAsk
+	ObsColObservedAtNs
+	ObsColTimestampNs
+)
+
+// Column indices for StreamAggregatesSchema.
+// These must match the field order of StreamAggregatesSchema exactly.
+const (
+	AggColStreamID = iota
+	AggColAggregator
+	AggColValueType
+	AggColDecimalValue
+	AggColQuoteBid
+	AggColQuoteBenchmark
+	AggColQuoteAsk
+	AggColObservedAtNs
+)
+
+// Column indices for CacheSchema.
+// These must match the field order of CacheSchema exactly.
+const (
+	CacheColStreamID = iota
+	CacheColValueType
+	CacheColDecimalValue
+	CacheColQuoteBid
+	CacheColQuoteBenchmark
+	CacheColQuoteAsk
+	CacheColObservedAtNs
+	CacheColExpiresAtNs
+)