diff --git a/pkg/workflows/wasm/host/engine/engine.go b/pkg/workflows/wasm/host/engine/engine.go new file mode 100644 index 0000000000..0147650135 --- /dev/null +++ b/pkg/workflows/wasm/host/engine/engine.go @@ -0,0 +1,79 @@ +package engine + +import "fmt" + +type EngineType string + +const ( + EngineWasmtime EngineType = "wasmtime" + EngineWasm3 EngineType = "wasm3" +) + +type MemoryAccessor interface { + Memory() []byte +} + +type LoadConfig struct { + InitialFuel uint64 +} + +type Engine interface { + Load(binary []byte, cfg LoadConfig) (Runtime, error) +} + +type Runtime interface { + V2ImportName() string + NewStore() Store + IncrementEpoch() + Close() +} + +type Store interface { + SetWasi(argv []string) + SetFuel(fuel uint64) error + SetLimiter(memoryBytes, tableElements, instances, tables, memories int64) + SetEpochDeadline(deadline uint64) + LinkNoDAG(exec *Execution) (Instance, error) + LinkLegacyDAG(exec *LegacyExecution) (Instance, error) + Close() +} + +type Instance interface { + CallStart() error +} + +type Execution struct { + V2ImportName string + SendResponse func(caller MemoryAccessor, ptr, ptrlen int32) int32 + CallCapability func(caller MemoryAccessor, ptr, ptrlen int32) int64 + AwaitCapabilities func(caller MemoryAccessor, awaitReq, awaitReqLen, respBuf, maxRespLen int32) int64 + GetSecrets func(caller MemoryAccessor, req, reqLen, respBuf, maxRespLen int32) int64 + AwaitSecrets func(caller MemoryAccessor, awaitReq, awaitReqLen, respBuf, maxRespLen int32) int64 + Log func(caller MemoryAccessor, ptr, ptrlen int32) + SwitchModes func(caller MemoryAccessor, mode int32) + RandomSeed func(mode int32) int64 + Now func(caller MemoryAccessor, resultTimestamp int32) int32 + PollOneoff func(caller MemoryAccessor, subPtr, evPtr, nSubs, resultNEvents int32) int32 + ClockTimeGet func(caller MemoryAccessor, id int32, precision int64, resultTs int32) int32 +} + +type LegacyExecution struct { + SendResponse func(caller MemoryAccessor, ptr, ptrlen int32) int32 + Fetch func(caller MemoryAccessor, respPtr, respLenPtr, reqPtr, reqPtrLen int32) int32 + Emit func(caller MemoryAccessor, respPtr, respLenPtr, msgPtr, msgLen int32) int32 + Log func(caller MemoryAccessor, ptr, ptrlen int32) + PollOneoff func(caller MemoryAccessor, subPtr, evPtr, nSubs, resultNEvents int32) int32 + ClockTimeGet func(caller MemoryAccessor, id int32, precision int64, resultTs int32) int32 + RandomGet func(caller MemoryAccessor, buf, bufLen int32) int32 // nil when determinism is disabled +} + +func NewEngine(engineType EngineType) (Engine, error) { + switch engineType { + case EngineWasmtime: + return &wasmtimeEngine{}, nil + case EngineWasm3: + return newWasm3Engine() + default: + return nil, fmt.Errorf("unsupported engine type: %q", engineType) + } +} diff --git a/pkg/workflows/wasm/host/engine/wasm3.go b/pkg/workflows/wasm/host/engine/wasm3.go new file mode 100644 index 0000000000..4e858ce2ec --- /dev/null +++ b/pkg/workflows/wasm/host/engine/wasm3.go @@ -0,0 +1,15 @@ +//go:build wasm3 + +package engine + +import "fmt" + +func newWasm3Engine() (Engine, error) { + return &wasm3Engine{}, nil +} + +type wasm3Engine struct{} + +func (e *wasm3Engine) Load(binary []byte, cfg LoadConfig) (Runtime, error) { + return nil, fmt.Errorf("wasm3 engine: not yet implemented") +} diff --git a/pkg/workflows/wasm/host/engine/wasm3_stub.go b/pkg/workflows/wasm/host/engine/wasm3_stub.go new file mode 100644 index 0000000000..b1e1dd535a --- /dev/null +++ b/pkg/workflows/wasm/host/engine/wasm3_stub.go @@ -0,0 +1,9 @@ +//go:build !wasm3 + +package engine + +import "fmt" + +func newWasm3Engine() (Engine, error) { + return nil, fmt.Errorf("wasm3 engine not available: build with -tags wasm3") +} diff --git a/pkg/workflows/wasm/host/engine/wasmtime.go b/pkg/workflows/wasm/host/engine/wasmtime.go new file mode 100644 index 0000000000..125821e544 --- /dev/null +++ b/pkg/workflows/wasm/host/engine/wasmtime.go @@ -0,0 +1,354 @@ +package engine + +import ( + "fmt" + "strings" + + "github.com/bytecodealliance/wasmtime-go/v28" +) + +const v2ImportPrefix = "version_v2" + +// wasmtimeEngine implements Engine using the wasmtime runtime. +type wasmtimeEngine struct{} + +func (e *wasmtimeEngine) Load(binary []byte, cfg LoadConfig) (Runtime, error) { + wcfg := wasmtime.NewConfig() + wcfg.SetEpochInterruption(true) + if cfg.InitialFuel > 0 { + wcfg.SetConsumeFuel(true) + } + wcfg.CacheConfigLoadDefault() + wcfg.SetCraneliftOptLevel(wasmtime.OptLevelSpeedAndSize) + setUnwinding(wcfg) + + eng := wasmtime.NewEngineWithConfig(wcfg) + + mod, err := wasmtime.NewModule(eng, binary) + if err != nil { + return nil, fmt.Errorf("error creating wasmtime module: %w", err) + } + + v2ImportName := "" + for _, modImport := range mod.Imports() { + name := modImport.Name() + if modImport.Module() == "env" && name != nil && strings.HasPrefix(*name, v2ImportPrefix) { + v2ImportName = *name + break + } + } + + return &wasmtimeRuntime{ + engine: eng, + module: mod, + config: wcfg, + v2ImportName: v2ImportName, + }, nil +} + +// wasmtimeRuntime holds a compiled wasmtime module and its engine. +type wasmtimeRuntime struct { + engine *wasmtime.Engine + module *wasmtime.Module + config *wasmtime.Config + v2ImportName string +} + +func (r *wasmtimeRuntime) V2ImportName() string { return r.v2ImportName } + +func (r *wasmtimeRuntime) NewStore() Store { + return &wasmtimeStore{ + store: wasmtime.NewStore(r.engine), + runtime: r, + } +} + +func (r *wasmtimeRuntime) IncrementEpoch() { r.engine.IncrementEpoch() } + +func (r *wasmtimeRuntime) Close() { + r.engine.Close() + r.module.Close() + r.config.Close() +} + +// wasmtimeStore wraps a wasmtime.Store and implements Store. +type wasmtimeStore struct { + store *wasmtime.Store + runtime *wasmtimeRuntime +} + +func (s *wasmtimeStore) SetWasi(argv []string) { + wasi := wasmtime.NewWasiConfig() + wasi.InheritStdout() + wasi.SetArgv(argv) + s.store.SetWasi(wasi) +} + +func (s *wasmtimeStore) SetFuel(fuel uint64) error { + return s.store.SetFuel(fuel) +} + +func (s *wasmtimeStore) SetLimiter(memoryBytes, tableElements, instances, tables, memories int64) { + s.store.Limiter(memoryBytes, tableElements, instances, tables, memories) +} + +func (s *wasmtimeStore) SetEpochDeadline(deadline uint64) { + s.store.SetEpochDeadline(deadline) +} + +func (s *wasmtimeStore) LinkNoDAG(exec *Execution) (Instance, error) { + linker := wasmtime.NewLinker(s.runtime.engine) + linker.AllowShadowing(true) + + if err := linker.DefineWasi(); err != nil { + return nil, err + } + + if err := linker.FuncWrap( + "wasi_snapshot_preview1", "poll_oneoff", + wrapPollOneoff(exec.PollOneoff), + ); err != nil { + return nil, err + } + + if err := linker.FuncWrap( + "wasi_snapshot_preview1", "clock_time_get", + wrapClockTimeGet(exec.ClockTimeGet), + ); err != nil { + return nil, err + } + + if err := linker.FuncWrap( + "env", exec.V2ImportName, + func(caller *wasmtime.Caller) {}, + ); err != nil { + return nil, fmt.Errorf("error wrapping v2 import func: %w", err) + } + + if err := linker.FuncWrap("env", "send_response", + wrapSendResponse(exec.SendResponse), + ); err != nil { + return nil, fmt.Errorf("error wrapping sendResponse func: %w", err) + } + + if err := linker.FuncWrap("env", "call_capability", + wrapCallCapability(exec.CallCapability), + ); err != nil { + return nil, fmt.Errorf("error wrapping callcap func: %w", err) + } + + if err := linker.FuncWrap("env", "await_capabilities", + wrapAwaitCapabilities(exec.AwaitCapabilities), + ); err != nil { + return nil, fmt.Errorf("error wrapping awaitcaps func: %w", err) + } + + if err := linker.FuncWrap("env", "get_secrets", + wrapGetSecrets(exec.GetSecrets), + ); err != nil { + return nil, fmt.Errorf("error wrapping get_secrets func: %w", err) + } + + if err := linker.FuncWrap("env", "await_secrets", + wrapAwaitSecrets(exec.AwaitSecrets), + ); err != nil { + return nil, fmt.Errorf("error wrapping await_secrets func: %w", err) + } + + if err := linker.FuncWrap("env", "log", + wrapLog(exec.Log), + ); err != nil { + return nil, fmt.Errorf("error wrapping log func: %w", err) + } + + if err := linker.FuncWrap("env", "switch_modes", + wrapSwitchModes(exec.SwitchModes), + ); err != nil { + return nil, fmt.Errorf("error wrapping switchModes func: %w", err) + } + + if err := linker.FuncWrap("env", "random_seed", exec.RandomSeed); err != nil { + return nil, fmt.Errorf("error wrapping getSeed func: %w", err) + } + + if err := linker.FuncWrap("env", "now", + wrapNow(exec.Now), + ); err != nil { + return nil, fmt.Errorf("error wrapping now func: %w", err) + } + + inst, err := linker.Instantiate(s.store, s.runtime.module) + if err != nil { + return nil, err + } + return &wasmtimeInstance{instance: inst, store: s.store}, nil +} + +func (s *wasmtimeStore) LinkLegacyDAG(exec *LegacyExecution) (Instance, error) { + linker := wasmtime.NewLinker(s.runtime.engine) + linker.AllowShadowing(true) + + if err := linker.DefineWasi(); err != nil { + return nil, err + } + + if err := linker.FuncWrap( + "wasi_snapshot_preview1", "poll_oneoff", + wrapPollOneoff(exec.PollOneoff), + ); err != nil { + return nil, err + } + + if err := linker.FuncWrap( + "wasi_snapshot_preview1", "clock_time_get", + wrapClockTimeGet(exec.ClockTimeGet), + ); err != nil { + return nil, err + } + + if exec.RandomGet != nil { + if err := linker.FuncWrap( + "wasi_snapshot_preview1", "random_get", + wrapRandomGet(exec.RandomGet), + ); err != nil { + return nil, err + } + } + + if err := linker.FuncWrap("env", "sendResponse", + wrapSendResponse(exec.SendResponse), + ); err != nil { + return nil, fmt.Errorf("error wrapping sendResponse func: %w", err) + } + + if err := linker.FuncWrap("env", "fetch", + wrapFetch(exec.Fetch), + ); err != nil { + return nil, fmt.Errorf("error wrapping fetch func: %w", err) + } + + if err := linker.FuncWrap("env", "emit", + wrapEmit(exec.Emit), + ); err != nil { + return nil, fmt.Errorf("error wrapping emit func: %w", err) + } + + if err := linker.FuncWrap("env", "log", + wrapLog(exec.Log), + ); err != nil { + return nil, fmt.Errorf("error wrapping log func: %w", err) + } + + inst, err := linker.Instantiate(s.store, s.runtime.module) + if err != nil { + return nil, err + } + return &wasmtimeInstance{instance: inst, store: s.store}, nil +} + +func (s *wasmtimeStore) Close() { s.store.Close() } + +// wasmtimeInstance wraps a wasmtime.Instance. +type wasmtimeInstance struct { + instance *wasmtime.Instance + store *wasmtime.Store +} + +func (i *wasmtimeInstance) CallStart() error { + start := i.instance.GetFunc(i.store, "_start") + if start == nil { + return fmt.Errorf("could not get start function") + } + _, err := start.Call(i.store) + return err +} + +// wasmtimeCallerAdapter wraps *wasmtime.Caller to implement MemoryAccessor. +type wasmtimeCallerAdapter struct { + c *wasmtime.Caller +} + +func (a *wasmtimeCallerAdapter) Memory() []byte { + return a.c.GetExport("memory").Memory().UnsafeData(a.c) +} + +// Wrapper functions that convert *wasmtime.Caller to MemoryAccessor and +// delegate to the engine-agnostic closures. + +func wrapSendResponse(fn func(MemoryAccessor, int32, int32) int32) func(*wasmtime.Caller, int32, int32) int32 { + return func(caller *wasmtime.Caller, ptr, ptrlen int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, ptr, ptrlen) + } +} + +func wrapCallCapability(fn func(MemoryAccessor, int32, int32) int64) func(*wasmtime.Caller, int32, int32) int64 { + return func(caller *wasmtime.Caller, ptr, ptrlen int32) int64 { + return fn(&wasmtimeCallerAdapter{caller}, ptr, ptrlen) + } +} + +func wrapAwaitCapabilities(fn func(MemoryAccessor, int32, int32, int32, int32) int64) func(*wasmtime.Caller, int32, int32, int32, int32) int64 { + return func(caller *wasmtime.Caller, a, b, c, d int32) int64 { + return fn(&wasmtimeCallerAdapter{caller}, a, b, c, d) + } +} + +func wrapGetSecrets(fn func(MemoryAccessor, int32, int32, int32, int32) int64) func(*wasmtime.Caller, int32, int32, int32, int32) int64 { + return func(caller *wasmtime.Caller, a, b, c, d int32) int64 { + return fn(&wasmtimeCallerAdapter{caller}, a, b, c, d) + } +} + +func wrapAwaitSecrets(fn func(MemoryAccessor, int32, int32, int32, int32) int64) func(*wasmtime.Caller, int32, int32, int32, int32) int64 { + return func(caller *wasmtime.Caller, a, b, c, d int32) int64 { + return fn(&wasmtimeCallerAdapter{caller}, a, b, c, d) + } +} + +func wrapLog(fn func(MemoryAccessor, int32, int32)) func(*wasmtime.Caller, int32, int32) { + return func(caller *wasmtime.Caller, ptr, ptrlen int32) { + fn(&wasmtimeCallerAdapter{caller}, ptr, ptrlen) + } +} + +func wrapSwitchModes(fn func(MemoryAccessor, int32)) func(*wasmtime.Caller, int32) { + return func(caller *wasmtime.Caller, mode int32) { + fn(&wasmtimeCallerAdapter{caller}, mode) + } +} + +func wrapNow(fn func(MemoryAccessor, int32) int32) func(*wasmtime.Caller, int32) int32 { + return func(caller *wasmtime.Caller, resultTs int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, resultTs) + } +} + +func wrapPollOneoff(fn func(MemoryAccessor, int32, int32, int32, int32) int32) func(*wasmtime.Caller, int32, int32, int32, int32) int32 { + return func(caller *wasmtime.Caller, a, b, c, d int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, a, b, c, d) + } +} + +func wrapClockTimeGet(fn func(MemoryAccessor, int32, int64, int32) int32) func(*wasmtime.Caller, int32, int64, int32) int32 { + return func(caller *wasmtime.Caller, id int32, precision int64, resultTs int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, id, precision, resultTs) + } +} + +func wrapFetch(fn func(MemoryAccessor, int32, int32, int32, int32) int32) func(*wasmtime.Caller, int32, int32, int32, int32) int32 { + return func(caller *wasmtime.Caller, a, b, c, d int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, a, b, c, d) + } +} + +func wrapEmit(fn func(MemoryAccessor, int32, int32, int32, int32) int32) func(*wasmtime.Caller, int32, int32, int32, int32) int32 { + return func(caller *wasmtime.Caller, a, b, c, d int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, a, b, c, d) + } +} + +func wrapRandomGet(fn func(MemoryAccessor, int32, int32) int32) func(*wasmtime.Caller, int32, int32) int32 { + return func(caller *wasmtime.Caller, buf, bufLen int32) int32 { + return fn(&wasmtimeCallerAdapter{caller}, buf, bufLen) + } +} diff --git a/pkg/workflows/wasm/host/unwind_unix.go b/pkg/workflows/wasm/host/engine/wasmtime_unwind_unix.go similarity index 83% rename from pkg/workflows/wasm/host/unwind_unix.go rename to pkg/workflows/wasm/host/engine/wasmtime_unwind_unix.go index b05aff84b8..0ac7b6baf0 100644 --- a/pkg/workflows/wasm/host/unwind_unix.go +++ b/pkg/workflows/wasm/host/engine/wasmtime_unwind_unix.go @@ -1,11 +1,11 @@ //go:build unix -package host +package engine import "github.com/bytecodealliance/wasmtime-go/v28" // Load testing shows that leaving native unwind info enabled causes a very large slowdown when loading multiple modules. -func SetUnwinding(cfg *wasmtime.Config) { +func setUnwinding(cfg *wasmtime.Config) { if cfg == nil { panic("wasmtime.Config cannot be nil") } diff --git a/pkg/workflows/wasm/host/unwind_windows.go b/pkg/workflows/wasm/host/engine/wasmtime_unwind_windows.go similarity index 68% rename from pkg/workflows/wasm/host/unwind_windows.go rename to pkg/workflows/wasm/host/engine/wasmtime_unwind_windows.go index b70d3a4003..585622c697 100644 --- a/pkg/workflows/wasm/host/unwind_windows.go +++ b/pkg/workflows/wasm/host/engine/wasmtime_unwind_windows.go @@ -1,9 +1,9 @@ //go:build windows -package host +package engine import "github.com/bytecodealliance/wasmtime-go/v28" -func SetUnwinding(cfg *wasmtime.Config) { +func setUnwinding(_ *wasmtime.Config) { // Unwinding cannot be disabled on Windows. } diff --git a/pkg/workflows/wasm/host/engine_selection_test.go b/pkg/workflows/wasm/host/engine_selection_test.go new file mode 100644 index 0000000000..241bbe7606 --- /dev/null +++ b/pkg/workflows/wasm/host/engine_selection_test.go @@ -0,0 +1,48 @@ +package host + +import ( + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/engine" + "github.com/stretchr/testify/require" +) + +func TestEngineSelection_Default(t *testing.T) { + binary := createTestBinary(nodagRandomBinaryCmd, nodagRandomBinaryLocation, true, t) + + mc := &ModuleConfig{ + Logger: logger.Test(t), + IsUncompressed: true, + } + m, err := NewModule(t.Context(), mc, binary) + require.NoError(t, err) + defer m.Close() + require.Equal(t, string(engine.EngineWasmtime), mc.Engine) +} + +func TestEngineSelection_ExplicitWasmtime(t *testing.T) { + binary := createTestBinary(nodagRandomBinaryCmd, nodagRandomBinaryLocation, true, t) + + mc := &ModuleConfig{ + Logger: logger.Test(t), + IsUncompressed: true, + Engine: string(engine.EngineWasmtime), + } + m, err := NewModule(t.Context(), mc, binary) + require.NoError(t, err) + defer m.Close() +} + +func TestEngineSelection_Invalid(t *testing.T) { + binary := createTestBinary(nodagRandomBinaryCmd, nodagRandomBinaryLocation, true, t) + + mc := &ModuleConfig{ + Logger: logger.Test(t), + IsUncompressed: true, + Engine: "nonexistent", + } + _, err := NewModule(t.Context(), mc, binary) + require.Error(t, err) + require.Contains(t, err.Error(), "nonexistent") +} diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index 8848e07508..a48f1f0b73 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -7,8 +7,7 @@ import ( "sync" "time" - "github.com/bytecodealliance/wasmtime-go/v28" - + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/engine" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) @@ -140,7 +139,7 @@ func (e *execution[T]) awaitSecrets(ctx context.Context, acr *sdkpb.AwaitSecrets }, nil } -func (e *execution[T]) log(caller *wasmtime.Caller, ptr int32, ptrlen int32) { +func (e *execution[T]) log(caller engine.MemoryAccessor, ptr int32, ptrlen int32) { switch e.mode { case sdkpb.Mode_MODE_DON: e.donLogCount++ @@ -193,14 +192,14 @@ func (e *execution[T]) getSeed(mode int32) int64 { return -1 } -func (e *execution[T]) switchModes(_ *wasmtime.Caller, mode int32) { +func (e *execution[T]) switchModes(_ engine.MemoryAccessor, mode int32) { e.hasRun = true e.mode = sdkpb.Mode(mode) } // clockTimeGet is the default time.Now() which is also called by Go many times. // This implementation uses Node Mode to not have to wait for OCR rounds. -func (e *execution[T]) clockTimeGet(caller *wasmtime.Caller, id int32, precision int64, resultTimestamp int32) int32 { +func (e *execution[T]) clockTimeGet(caller engine.MemoryAccessor, id int32, precision int64, resultTimestamp int32) int32 { donTime, err := e.timeFetcher.GetTime(sdkpb.Mode_MODE_NODE) if err != nil { return ErrnoInval @@ -230,7 +229,7 @@ func (e *execution[T]) clockTimeGet(caller *wasmtime.Caller, id int32, precision } // now is used by rawsdk for Workflows and should be called instead of Go's time.Now(). -func (e *execution[T]) now(caller *wasmtime.Caller, resultTimestamp int32) int32 { +func (e *execution[T]) now(caller engine.MemoryAccessor, resultTimestamp int32) int32 { donTime, err := e.timeFetcher.GetTime(e.mode) if err != nil { return ErrnoInval @@ -250,7 +249,7 @@ func (e *execution[T]) now(caller *wasmtime.Caller, resultTimestamp int32) int32 // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md // This implementation only responds to clock events, not to file descriptor notifications. // It sleeps based on the largest timeout -func (e *execution[T]) pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, nsubscriptions int32, resultNevents int32) int32 { +func (e *execution[T]) pollOneoff(caller engine.MemoryAccessor, subscriptionptr int32, eventsptr int32, nsubscriptions int32, resultNevents int32) int32 { if nsubscriptions == 0 { return ErrnoInval } diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 01f66eacea..c8147e3fa0 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -18,7 +18,6 @@ import ( "time" "github.com/andybalholm/brotli" - "github.com/bytecodealliance/wasmtime-go/v28" "google.golang.org/protobuf/proto" "github.com/smartcontractkit/chainlink-common/pkg/config" @@ -28,12 +27,12 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" dagsdk "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/engine" wasmdagpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" "github.com/smartcontractkit/chainlink-protos/cre/go/values" ) -const v2ImportPrefix = "version_v2" var ( defaultTickInterval = 100 * time.Millisecond @@ -86,6 +85,9 @@ type ModuleConfig struct { // If Determinism is set, the module will override the random_get function in the WASI API with // the provided seed to ensure deterministic behavior. Determinism *DeterminismConfig + + // Engine selects which WASM runtime to use. Defaults to "wasmtime". + Engine string } type ModuleBase interface { @@ -124,22 +126,14 @@ type ExecutionHelper interface { } type module struct { - engine *wasmtime.Engine - module *wasmtime.Module - wconfig *wasmtime.Config - - cfg *ModuleConfig - - wg sync.WaitGroup - stopCh chan struct{} - - v2ImportName string + runtime engine.Runtime + cfg *ModuleConfig + wg sync.WaitGroup + stopCh chan struct{} } var _ ModuleV1 = (*module)(nil) -type linkFn[T any] func(m *module, store *wasmtime.Store, exec *execution[T]) (*wasmtime.Instance, error) - // WithDeterminism sets the Determinism field to a deterministic seed from a known time. // // "The Times 03/Jan/2009 Chancellor on brink of second bailout for banks" @@ -182,6 +176,10 @@ func NewModule(ctx context.Context, modCfg *ModuleConfig, binary []byte, opts .. modCfg.SdkLabeler = func(string) {} } + if modCfg.Engine == "" { + modCfg.Engine = string(engine.EngineWasmtime) + } + if modCfg.TickInterval == 0 { modCfg.TickInterval = defaultTickInterval } @@ -290,177 +288,27 @@ func NewModule(ctx context.Context, modCfg *ModuleConfig, binary []byte, opts .. } func newModule(modCfg *ModuleConfig, binary []byte) (*module, error) { - cfg := wasmtime.NewConfig() - cfg.SetEpochInterruption(true) - if modCfg.InitialFuel > 0 { - cfg.SetConsumeFuel(true) - } - cfg.CacheConfigLoadDefault() - cfg.SetCraneliftOptLevel(wasmtime.OptLevelSpeedAndSize) - SetUnwinding(cfg) // Handled differenty based on host OS. - - engine := wasmtime.NewEngineWithConfig(cfg) - - mod, err := wasmtime.NewModule(engine, binary) + eng, err := engine.NewEngine(engine.EngineType(modCfg.Engine)) if err != nil { - return nil, fmt.Errorf("error creating wasmtime module: %w", err) + return nil, fmt.Errorf("error creating engine: %w", err) } - v2ImportName := "" - for _, modImport := range mod.Imports() { - name := modImport.Name() - if modImport.Module() == "env" && name != nil && strings.HasPrefix(*name, v2ImportPrefix) { - v2ImportName = *name - break - } + rt, err := eng.Load(binary, engine.LoadConfig{ + InitialFuel: modCfg.InitialFuel, + }) + if err != nil { + return nil, fmt.Errorf("error loading module: %w", err) } - modCfg.SdkLabeler(v2ImportName) + modCfg.SdkLabeler(rt.V2ImportName()) return &module{ - engine: engine, - module: mod, - wconfig: cfg, - cfg: modCfg, - stopCh: make(chan struct{}), - v2ImportName: v2ImportName, + runtime: rt, + cfg: modCfg, + stopCh: make(chan struct{}), }, nil } -func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*sdkpb.ExecutionResult]) (*wasmtime.Instance, error) { - linker, err := newWasiLinker(exec, m.engine) - if err != nil { - return nil, err - } - - if err = linker.FuncWrap( - "env", - m.v2ImportName, - func(caller *wasmtime.Caller) {}); err != nil { - return nil, fmt.Errorf("error wrapping log func: %w", err) - } - - logger := m.cfg.Logger - if err = linker.FuncWrap( - "env", - "send_response", - createSendResponseFn(logger, exec, func() *sdkpb.ExecutionResult { - return &sdkpb.ExecutionResult{} - }), - ); err != nil { - return nil, fmt.Errorf("error wrapping sendResponse func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "call_capability", - createCallCapFn(logger, exec), - ); err != nil { - return nil, fmt.Errorf("error wrapping callcap func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "await_capabilities", - createAwaitCapsFn(logger, exec), - ); err != nil { - return nil, fmt.Errorf("error wrapping awaitcaps func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "get_secrets", - createGetSecretsFn(logger, exec), - ); err != nil { - return nil, fmt.Errorf("error wrapping get_secrets func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "await_secrets", - createAwaitSecretsFn(logger, exec), - ); err != nil { - return nil, fmt.Errorf("error wrapping await_secrets func: %w", err) - } - - if err := linker.FuncWrap( - "env", - "log", - exec.log, - ); err != nil { - return nil, fmt.Errorf("error wrapping log func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "switch_modes", - exec.switchModes); err != nil { - return nil, fmt.Errorf("error wrapping switchModes func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "random_seed", - exec.getSeed); err != nil { - return nil, fmt.Errorf("error wrapping getSeed func: %w", err) - } - - if err = linker.FuncWrap( - "env", - "now", - exec.now); err != nil { - return nil, fmt.Errorf("error wrapping get_time func: %w", err) - } - - return linker.Instantiate(store, m.module) -} - -func linkLegacyDAG(m *module, store *wasmtime.Store, exec *execution[*wasmdagpb.Response]) (*wasmtime.Instance, error) { - linker, err := newDagWasiLinker(m.cfg, m.engine) - if err != nil { - return nil, err - } - - logger := m.cfg.Logger - - if err = linker.FuncWrap( - "env", - "sendResponse", - createSendResponseFn(logger, exec, func() *wasmdagpb.Response { - return &wasmdagpb.Response{} - }), - ); err != nil { - return nil, fmt.Errorf("error wrapping sendResponse func: %w", err) - } - - err = linker.FuncWrap( - "env", - "fetch", - createFetchFn(logger, wasmRead, wasmWrite, wasmWriteUInt32, m.cfg, exec), - ) - if err != nil { - return nil, fmt.Errorf("error wrapping fetch func: %w", err) - } - - err = linker.FuncWrap( - "env", - "emit", - createEmitFn(logger, exec, m.cfg.Labeler, wasmRead, wasmWrite, wasmWriteUInt32), - ) - if err != nil { - return nil, fmt.Errorf("error wrapping emit func: %w", err) - } - - if err := linker.FuncWrap( - "env", - "log", - createLogFn(logger), - ); err != nil { - return nil, fmt.Errorf("error wrapping log func: %w", err) - } - - return linker.Instantiate(store, m.module) -} func (m *module) Start() { m.wg.Go(func() { @@ -470,7 +318,7 @@ func (m *module) Start() { case <-m.stopCh: return case <-ticker.C: - m.engine.IncrementEpoch() + m.runtime.IncrementEpoch() } } }) @@ -479,14 +327,11 @@ func (m *module) Start() { func (m *module) Close() { close(m.stopCh) m.wg.Wait() - - m.engine.Close() - m.module.Close() - m.wconfig.Close() + m.runtime.Close() } func (m *module) IsLegacyDAG() bool { - return m.v2ImportName == "" + return m.runtime.V2ImportName() == "" } func (m *module) Execute(ctx context.Context, req *sdkpb.ExecuteRequest, executor ExecutionHelper) (*sdkpb.ExecutionResult, error) { @@ -535,6 +380,45 @@ func (m *module) Run(ctx context.Context, request *wasmdagpb.Request) (*wasmdagp return runWasm(ctx, m, request, setMaxResponseSize, linkLegacyDAG, nil) } +type linkFn[O any] func(m *module, store engine.Store, exec *execution[O]) (engine.Instance, error) + +func linkNoDAG(m *module, store engine.Store, exec *execution[*sdkpb.ExecutionResult]) (engine.Instance, error) { + exec.timeFetcher = newTimeFetcher(exec.ctx, exec.executor) + exec.timeFetcher.Start() + + lgr := m.cfg.Logger + return store.LinkNoDAG(&engine.Execution{ + V2ImportName: m.runtime.V2ImportName(), + SendResponse: createSendResponseFn(lgr, exec, func() *sdkpb.ExecutionResult { return &sdkpb.ExecutionResult{} }), + CallCapability: createCallCapFn(lgr, exec), + AwaitCapabilities: createAwaitCapsFn(lgr, exec), + GetSecrets: createGetSecretsFn(lgr, exec), + AwaitSecrets: createAwaitSecretsFn(lgr, exec), + Log: exec.log, + SwitchModes: exec.switchModes, + RandomSeed: exec.getSeed, + Now: exec.now, + PollOneoff: exec.pollOneoff, + ClockTimeGet: exec.clockTimeGet, + }) +} + +func linkLegacyDAG(m *module, store engine.Store, exec *execution[*wasmdagpb.Response]) (engine.Instance, error) { + lgr := m.cfg.Logger + le := &engine.LegacyExecution{ + SendResponse: createSendResponseFn(lgr, exec, func() *wasmdagpb.Response { return &wasmdagpb.Response{} }), + Fetch: createFetchFn(lgr, wasmRead, wasmWrite, wasmWriteUInt32, m.cfg, exec), + Emit: createEmitFn(lgr, exec, m.cfg.Labeler, wasmRead, wasmWrite, wasmWriteUInt32), + Log: createLogFn(lgr), + PollOneoff: legacyPollOneoff, + ClockTimeGet: legacyClockTimeGet, + } + if m.cfg.Determinism != nil { + le.RandomGet = createRandomGet(m.cfg) + } + return store.LinkLegacyDAG(le) +} + func runWasm[I, O proto.Message]( ctx context.Context, m *module, @@ -548,8 +432,7 @@ func runWasm[I, O proto.Message]( ctxWithTimeout, cancel := context.WithTimeout(ctx, *m.cfg.Timeout) defer cancel() - store := wasmtime.NewStore(m.engine) - + store := m.runtime.NewStore() defer store.Close() maxResponseSizeBytes, err := m.cfg.MaxResponseSizeLimiter.Limit(ctx) @@ -564,13 +447,7 @@ func runWasm[I, O proto.Message]( reqstr := base64.StdEncoding.EncodeToString(reqpb) - wasi := wasmtime.NewWasiConfig() - wasi.InheritStdout() - defer wasi.Close() - - wasi.SetArgv([]string{"wasi", reqstr}) - - store.SetWasi(wasi) + store.SetWasi([]string{"wasi", reqstr}) if m.cfg.InitialFuel > 0 { err = store.SetFuel(m.cfg.InitialFuel) @@ -579,12 +456,11 @@ func runWasm[I, O proto.Message]( } } - // Limit memory to max memory megabytes per instance. maxMemoryBytes, err := m.cfg.MemoryLimiter.Limit(ctx) if err != nil { return o, fmt.Errorf("failed to get memory limit: %w", err) } - store.Limiter( + store.SetLimiter( int64(maxMemoryBytes/config.MByte)*int64(math.Pow(10, 6)), -1, // tableElements, -1 == default 1, // instances @@ -618,16 +494,10 @@ func runWasm[I, O proto.Message]( return o, fmt.Errorf("error linking wasm: %w", err) } - start := instance.GetFunc(store, "_start") - if start == nil { - return o, errors.New("could not get start function") - } - startTime := time.Now() - _, err = start.Call(store) + err = instance.CallStart() executionDuration := time.Since(startTime) - // The error codes below are only returned by the v1 legacy DAG workflow. switch { case containsCode(err, wasm.CodeSuccess): if any(exec.response) == nil { @@ -639,7 +509,6 @@ func runWasm[I, O proto.Message]( case containsCode(err, wasm.CodeInvalidRequest): return o, fmt.Errorf("invariant violation: invalid request to runner") case containsCode(err, wasm.CodeRunnerErr): - // legacy DAG captured all errors, since the function didn't return an error resp, ok := any(exec).(*execution[*wasmdagpb.Response]) if ok && resp.response != nil { return o, fmt.Errorf("error executing runner: %s: %w", resp.response.ErrMsg, err) @@ -649,11 +518,7 @@ func runWasm[I, O proto.Message]( return o, fmt.Errorf("invariant violation: host errored during sendResponse") } - // If an error has occurred and the deadline has been reached or exceeded, return a deadline exceeded error. - // Note - there is no other reliable signal on the error that can be used to infer it is due to epoch deadline - // being reached, so if an error is returned after the deadline it is assumed it is due to that and return - // context.DeadlineExceeded. - if err != nil && executionDuration >= *m.cfg.Timeout-m.cfg.TickInterval { // As start could be called just before epoch update 1 tick interval is deducted to account for this + if err != nil && executionDuration >= *m.cfg.Timeout-m.cfg.TickInterval { m.cfg.Logger.Errorw("start function returned error after deadline reached, returning deadline exceeded error", "errFromStartFunction", err) return o, context.DeadlineExceeded } @@ -673,8 +538,8 @@ func containsCode(err error, code int) bool { func createSendResponseFn[T proto.Message]( logger logger.Logger, exec *execution[T], - newT func() T) func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int32 { - return func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int32 { + newT func() T) func(caller engine.MemoryAccessor, ptr int32, ptrlen int32) int32 { + return func(caller engine.MemoryAccessor, ptr int32, ptrlen int32) int32 { b, innerErr := wasmRead(caller, ptr, ptrlen) if innerErr != nil { logger.Errorf("error calling sendResponse: %s", innerErr) @@ -770,8 +635,8 @@ func createFetchFn( sizeWriter unsafeFixedLengthWriterFunc, modCfg *ModuleConfig, exec *execution[*wasmdagpb.Response], -) func(caller *wasmtime.Caller, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 { - return func(caller *wasmtime.Caller, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 { +) func(caller engine.MemoryAccessor, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 { + return func(caller engine.MemoryAccessor, respptr int32, resplenptr int32, reqptr int32, reqptrlen int32) int32 { const errFetchSfx = "error calling fetch" // writeErr marshals and writes an error response to wasm @@ -860,12 +725,12 @@ func createEmitFn( reader unsafeReaderFunc, writer unsafeWriterFunc, sizeWriter unsafeFixedLengthWriterFunc, -) func(caller *wasmtime.Caller, respptr, resplenptr, msgptr, msglen int32) int32 { +) func(caller engine.MemoryAccessor, respptr, resplenptr, msgptr, msglen int32) int32 { logErr := func(err error) { l.Errorf("error emitting message: %s", err) } - return func(caller *wasmtime.Caller, respptr, resplenptr, msgptr, msglen int32) int32 { + return func(caller engine.MemoryAccessor, respptr, resplenptr, msgptr, msglen int32) int32 { // writeErr marshals and writes an error response to wasm writeErr := func(err error) int32 { logErr(err) @@ -914,8 +779,8 @@ func createEmitFn( } // createLogFn injects dependencies and builds the log function exposed by the WASM. -func createLogFn(logger logger.Logger) func(caller *wasmtime.Caller, ptr int32, ptrlen int32) { - return func(caller *wasmtime.Caller, ptr int32, ptrlen int32) { +func createLogFn(logger logger.Logger) func(caller engine.MemoryAccessor, ptr int32, ptrlen int32) { + return func(caller engine.MemoryAccessor, ptr int32, ptrlen int32) { b, innerErr := wasmRead(caller, ptr, ptrlen) if innerErr != nil { logger.Errorf("error calling log: %s", innerErr) @@ -1015,24 +880,19 @@ func toValidatedLabels(msg *wasmdagpb.EmitMessageRequest) (map[string]string, er // unsafeWriterFunc defines behavior for writing directly to wasm memory. A source slice of bytes // is written to the location defined by the ptr. -type unsafeWriterFunc func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 +type unsafeWriterFunc func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 // unsafeFixedLengthWriterFunc defines behavior for writing a uint32 value to wasm memory at the location defined // by the ptr. -type unsafeFixedLengthWriterFunc func(c *wasmtime.Caller, ptr int32, val uint32) int64 +type unsafeFixedLengthWriterFunc func(c engine.MemoryAccessor, ptr int32, val uint32) int64 // unsafeReaderFunc abstractly defines the behavior of reading from WASM memory. Returns a copy of // the memory at the given pointer and size. -type unsafeReaderFunc func(c *wasmtime.Caller, ptr, len int32) ([]byte, error) - -// wasmMemoryAccessor is the default implementation for unsafely accessing the memory of the WASM module. -func wasmMemoryAccessor(caller *wasmtime.Caller) []byte { - return caller.GetExport("memory").Memory().UnsafeData(caller) -} +type unsafeReaderFunc func(c engine.MemoryAccessor, ptr, len int32) ([]byte, error) // wasmRead returns a copy of the wasm module memory at the given pointer and size. -func wasmRead(caller *wasmtime.Caller, ptr int32, size int32) ([]byte, error) { - return read(wasmMemoryAccessor(caller), ptr, size) +func wasmRead(caller engine.MemoryAccessor, ptr int32, size int32) ([]byte, error) { + return read(caller.Memory(), ptr, size) } // Read acts on a byte slice that should represent an unsafely accessed slice of memory. It returns @@ -1052,13 +912,13 @@ func read(memory []byte, ptr int32, size int32) ([]byte, error) { } // wasmWrite copies the given src byte slice into the wasm module memory at the given pointer and size. -func wasmWrite(caller *wasmtime.Caller, src []byte, ptr int32, maxSize int32) int64 { - return write(wasmMemoryAccessor(caller), src, ptr, maxSize) +func wasmWrite(caller engine.MemoryAccessor, src []byte, ptr int32, maxSize int32) int64 { + return write(caller.Memory(), src, ptr, maxSize) } // wasmWriteUInt32 binary encodes and writes a uint32 to the wasm module memory at the given pointer. -func wasmWriteUInt32(caller *wasmtime.Caller, ptr int32, val uint32) int64 { - return writeUInt32(wasmMemoryAccessor(caller), ptr, val) +func wasmWriteUInt32(caller engine.MemoryAccessor, ptr int32, val uint32) int64 { + return writeUInt32(caller.Memory(), ptr, val) } // writeUInt32 binary encodes and writes a uint32 to the memory at the given pointer. @@ -1069,8 +929,8 @@ func writeUInt32(memory []byte, ptr int32, val uint32) int64 { return write(memory, buffer, ptr, uint32Size) } -func truncateWasmWrite(caller *wasmtime.Caller, src []byte, ptr int32, size int32) int64 { - memory := wasmMemoryAccessor(caller) +func truncateWasmWrite(caller engine.MemoryAccessor, src []byte, ptr int32, size int32) int64 { + memory := caller.Memory() if int32(len(memory)) < ptr+size { size = int32(len(memory)) - ptr src = src[:size] @@ -1100,8 +960,8 @@ func write(memory, src []byte, ptr, maxSize int32) int64 { func createCallCapFn( logger logger.Logger, - exec *execution[*sdkpb.ExecutionResult]) func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int64 { - return func(caller *wasmtime.Caller, ptr int32, ptrlen int32) int64 { + exec *execution[*sdkpb.ExecutionResult]) func(caller engine.MemoryAccessor, ptr int32, ptrlen int32) int64 { + return func(caller engine.MemoryAccessor, ptr int32, ptrlen int32) int64 { b, innerErr := wasmRead(caller, ptr, ptrlen) if innerErr != nil { errStr := fmt.Sprintf("error calling wasmRead: %s", innerErr) @@ -1131,8 +991,8 @@ func createCallCapFn( func createAwaitCapsFn( logger logger.Logger, exec *execution[*sdkpb.ExecutionResult], -) func(caller *wasmtime.Caller, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { - return func(caller *wasmtime.Caller, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { +) func(caller engine.MemoryAccessor, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { + return func(caller engine.MemoryAccessor, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { b, err := wasmRead(caller, awaitRequest, awaitRequestLen) if err != nil { errStr := fmt.Sprintf("error reading from wasm %s", err) @@ -1175,8 +1035,8 @@ func createAwaitCapsFn( func createGetSecretsFn( logger logger.Logger, - exec *execution[*sdkpb.ExecutionResult]) func(caller *wasmtime.Caller, req, requestLen, responseBuffer, maxResponseLen int32) int64 { - return func(caller *wasmtime.Caller, req, requestLen, responseBuffer, maxResponseLen int32) int64 { + exec *execution[*sdkpb.ExecutionResult]) func(caller engine.MemoryAccessor, req, requestLen, responseBuffer, maxResponseLen int32) int64 { + return func(caller engine.MemoryAccessor, req, requestLen, responseBuffer, maxResponseLen int32) int64 { b, innerErr := wasmRead(caller, req, requestLen) if innerErr != nil { errStr := fmt.Sprintf("error calling wasmRead: %s", innerErr) @@ -1205,8 +1065,8 @@ func createGetSecretsFn( func createAwaitSecretsFn( logger logger.Logger, exec *execution[*sdkpb.ExecutionResult], -) func(caller *wasmtime.Caller, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { - return func(caller *wasmtime.Caller, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { +) func(caller engine.MemoryAccessor, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { + return func(caller engine.MemoryAccessor, awaitRequest, awaitRequestLen, responseBuffer, maxResponseLen int32) int64 { b, err := wasmRead(caller, awaitRequest, awaitRequestLen) if err != nil { errStr := fmt.Sprintf("error reading from wasm %s", err) diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index af5d4fc12c..f05503f79a 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -7,7 +7,7 @@ import ( "sync" "testing" - "github.com/bytecodealliance/wasmtime-go/v28" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/engine" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -21,6 +21,14 @@ import ( "github.com/smartcontractkit/chainlink-protos/cre/go/values/pb" ) +type fakeMemoryAccessor struct{ data []byte } + +func (f *fakeMemoryAccessor) Memory() []byte { return f.data } + +func newFakeMemoryAccessor() engine.MemoryAccessor { + return &fakeMemoryAccessor{data: make([]byte, 65536)} +} + type mockMessageEmitter struct { e func(context.Context, string, map[string]string) error labels map[string]string @@ -66,7 +74,7 @@ func Test_createEmitFn(t *testing.T) { assert.Equal(t, ctxValue, v) return nil }), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.EmitMessageRequest{ RequestId: reqId, Message: "hello, world", @@ -83,14 +91,14 @@ func Test_createEmitFn(t *testing.T) { assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), ) - gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := emitFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -102,19 +110,19 @@ func Test_createEmitFn(t *testing.T) { newMockMessageEmitter(func(_ context.Context, _ string, _ map[string]string) error { return nil }), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.EmitMessageRequest{}) assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), ) - gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := emitFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -131,19 +139,19 @@ func Test_createEmitFn(t *testing.T) { logger.Test(t), exec, nil, - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { return nil, assert.AnError }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { assert.Equal(t, respBytes, src, "marshalled response not equal to bytes to write") return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { assert.Equal(t, uint32(len(respBytes)), val, "did not write length of response") return 0 }), ) - gotCode := emitFn(new(wasmtime.Caller), 0, int32(len(respBytes)), 0, 0) + gotCode := emitFn(newFakeMemoryAccessor(), 0, int32(len(respBytes)), 0, 0) assert.Equal(t, ErrnoSuccess, gotCode, "code mismatch") }) @@ -163,23 +171,23 @@ func Test_createEmitFn(t *testing.T) { newMockMessageEmitter(func(_ context.Context, _ string, _ map[string]string) error { return assert.AnError }), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.EmitMessageRequest{ RequestId: reqId, }) assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { assert.Equal(t, respBytes, src, "marshalled response not equal to bytes to write") return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { assert.Equal(t, uint32(len(respBytes)), val, "did not write length of response") return 0 }), ) - gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := emitFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -201,19 +209,19 @@ func Test_createEmitFn(t *testing.T) { logger.Test(t), exec, nil, - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { return badData, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { assert.Equal(t, respBytes, src, "marshalled response not equal to bytes to write") return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { assert.Equal(t, uint32(len(respBytes)), val, "did not write length of response") return 0 }), ) - gotCode := emitFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := emitFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) } @@ -225,17 +233,17 @@ func TestCreateFetchFn(t *testing.T) { fetchFn := createFetchFn( logger.Test(t), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.FetchRequest{ Id: testID, }) assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), &ModuleConfig{ @@ -248,7 +256,7 @@ func TestCreateFetchFn(t *testing.T) { exec, ) - gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := fetchFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -257,10 +265,10 @@ func TestCreateFetchFn(t *testing.T) { fetchFn := createFetchFn( logger.Test(t), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { return nil, assert.AnError }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { // the error is handled and written to the buffer resp := &wasmpb.FetchResponse{} err := proto.Unmarshal(src, resp) @@ -268,7 +276,7 @@ func TestCreateFetchFn(t *testing.T) { assert.Equal(t, assert.AnError.Error(), resp.ErrorMessage) return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), &ModuleConfig{ @@ -280,7 +288,7 @@ func TestCreateFetchFn(t *testing.T) { exec, ) - gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := fetchFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -289,10 +297,10 @@ func TestCreateFetchFn(t *testing.T) { fetchFn := createFetchFn( logger.Test(t), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { return []byte("bad-request-payload"), nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { // the error is handled and written to the buffer resp := &wasmpb.FetchResponse{} err := proto.Unmarshal(src, resp) @@ -301,7 +309,7 @@ func TestCreateFetchFn(t *testing.T) { assert.Contains(t, resp.ErrorMessage, expectedErr) return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), &ModuleConfig{ @@ -313,7 +321,7 @@ func TestCreateFetchFn(t *testing.T) { exec, ) - gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := fetchFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -322,14 +330,14 @@ func TestCreateFetchFn(t *testing.T) { fetchFn := createFetchFn( logger.Test(t), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.FetchRequest{ Id: testID, }) assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { // the error is handled and written to the buffer resp := &wasmpb.FetchResponse{} err := proto.Unmarshal(src, resp) @@ -338,7 +346,7 @@ func TestCreateFetchFn(t *testing.T) { assert.Equal(t, expectedErr, resp.ErrorMessage) return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), &ModuleConfig{ @@ -351,7 +359,7 @@ func TestCreateFetchFn(t *testing.T) { exec, ) - gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := fetchFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoSuccess, gotCode) }) @@ -360,17 +368,17 @@ func TestCreateFetchFn(t *testing.T) { fetchFn := createFetchFn( logger.Test(t), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.FetchRequest{ Id: testID, }) assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { return -1 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return 0 }), &ModuleConfig{ @@ -382,7 +390,7 @@ func TestCreateFetchFn(t *testing.T) { exec, ) - gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := fetchFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoFault, gotCode) }) @@ -391,17 +399,17 @@ func TestCreateFetchFn(t *testing.T) { fetchFn := createFetchFn( logger.Test(t), - unsafeReaderFunc(func(_ *wasmtime.Caller, _, _ int32) ([]byte, error) { + unsafeReaderFunc(func(_ engine.MemoryAccessor, _, _ int32) ([]byte, error) { b, err := proto.Marshal(&wasmpb.FetchRequest{ Id: testID, }) assert.NoError(t, err) return b, nil }), - unsafeWriterFunc(func(c *wasmtime.Caller, src []byte, ptr, len int32) int64 { + unsafeWriterFunc(func(c engine.MemoryAccessor, src []byte, ptr, len int32) int64 { return 0 }), - unsafeFixedLengthWriterFunc(func(c *wasmtime.Caller, ptr int32, val uint32) int64 { + unsafeFixedLengthWriterFunc(func(c engine.MemoryAccessor, ptr int32, val uint32) int64 { return -1 }), &ModuleConfig{ @@ -413,7 +421,7 @@ func TestCreateFetchFn(t *testing.T) { exec, ) - gotCode := fetchFn(new(wasmtime.Caller), 0, 0, 0, 0) + gotCode := fetchFn(newFakeMemoryAccessor(), 0, 0, 0, 0) assert.Equal(t, ErrnoFault, gotCode) }) } @@ -613,7 +621,7 @@ func Test_SdkLabeler(t *testing.T) { require.NoError(t, err) require.False(t, m.IsLegacyDAG(), "expected NoDAG module") require.NotEmpty(t, capturedName, "SdkLabeler should have been called with v2 import name") - require.True(t, strings.HasPrefix(capturedName, v2ImportPrefix), "captured name should have v2 prefix") + require.True(t, strings.HasPrefix(capturedName, "version_v2"), "captured name should have v2 prefix") }) } diff --git a/pkg/workflows/wasm/host/wasip1.go b/pkg/workflows/wasm/host/wasip1.go index b1d58b286d..205a70fd22 100644 --- a/pkg/workflows/wasm/host/wasip1.go +++ b/pkg/workflows/wasm/host/wasip1.go @@ -7,8 +7,9 @@ import ( "math/rand" "time" - "github.com/bytecodealliance/wasmtime-go/v28" "github.com/jonboulle/clockwork" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/engine" ) var ( @@ -17,89 +18,22 @@ var ( tick = 100 * time.Millisecond ) -func newWasiLinker[T any](exec *execution[T], engine *wasmtime.Engine) (*wasmtime.Linker, error) { - linker := wasmtime.NewLinker(engine) - linker.AllowShadowing(true) - - err := linker.DefineWasi() - if err != nil { - return nil, err - } - - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "poll_oneoff", - exec.pollOneoff, - ) - if err != nil { - return nil, err - } - - exec.timeFetcher = newTimeFetcher(exec.ctx, exec.executor) - exec.timeFetcher.Start() - - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "clock_time_get", - exec.clockTimeGet, - ) - if err != nil { - return nil, err - } - - return linker, nil -} - -func newDagWasiLinker(modCfg *ModuleConfig, engine *wasmtime.Engine) (*wasmtime.Linker, error) { - linker := wasmtime.NewLinker(engine) - linker.AllowShadowing(true) - - err := linker.DefineWasi() - if err != nil { - return nil, err - } - - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "poll_oneoff", - pollOneoff, - ) - if err != nil { - return nil, err - } - - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "clock_time_get", - clockTimeGet, - ) - if err != nil { - return nil, err - } - - if modCfg.Determinism != nil { - err = linker.FuncWrap( - "wasi_snapshot_preview1", - "random_get", - createRandomGet(modCfg), - ) - if err != nil { - return nil, err - } - } - - return linker, nil -} - const ( clockIDRealtime = iota clockIDMonotonic ) -// Loosely based off the implementation here: -// https://github.com/tetratelabs/wazero/blob/main/imports/wasi_snapshot_preview1/clock.go#L42 -// Each call to clockTimeGet increments our fake clock by `tick`. -func clockTimeGet(caller *wasmtime.Caller, id int32, precision int64, resultTimestamp int32) int32 { +const ( + subscriptionLen = 48 + eventsLen = 32 + + eventTypeClock = iota + eventTypeFDRead + eventTypeFDWrite +) + +// legacyClockTimeGet is the fake-clock implementation used by legacy DAG workflows. +func legacyClockTimeGet(caller engine.MemoryAccessor, id int32, precision int64, resultTimestamp int32) int32 { var val int64 switch id { case clockIDMonotonic: @@ -119,22 +53,9 @@ func clockTimeGet(caller *wasmtime.Caller, id int32, precision int64, resultTime return ErrnoSuccess } -const ( - subscriptionLen = 48 - eventsLen = 32 - - eventTypeClock = iota - eventTypeFDRead - eventTypeFDWrite -) - -// Loosely based off the implementation here: -// https://github.com/tetratelabs/wazero/blob/main/imports/wasi_snapshot_preview1/poll.go#L52 -// For an overview of the spec, including the datatypes being referred to, see: -// https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md -// This implementation only responds to clock events, not to file descriptor notifications. -// It doesn't actually sleep though, and will instead advance our fake clock by the sleep duration. -func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, nsubscriptions int32, resultNevents int32) int32 { +// legacyPollOneoff is the fake-clock implementation used by legacy DAG workflows. +// It advances the fake clock by the sleep duration rather than actually sleeping. +func legacyPollOneoff(caller engine.MemoryAccessor, subscriptionptr int32, eventsptr int32, nsubscriptions int32, resultNevents int32) int32 { if nsubscriptions == 0 { return ErrnoInval } @@ -144,42 +65,28 @@ func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, return ErrnoFault } - // Each subscription should have an event events := make([]byte, nsubscriptions*eventsLen) timeout := time.Duration(0) for i := range nsubscriptions { - // First, let's read the subscription inOffset := i * subscriptionLen - userData := subs[inOffset : inOffset+8] eventType := subs[inOffset+8] argBuf := subs[inOffset+8+8:] - slot, err := getSlot(events, i) - if err != nil { + slot, serr := getSlot(events, i) + if serr != nil { return ErrnoFault } switch eventType { case eventTypeClock: - // We want to stub out clock events, - // so let's just return success, and - // we'll advance the clock by the timeout duration - // below. - - // Structure of event, per: - // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-subscription_clock-struct - // - 0-8: clock id - // - 8-16: timeout - // - 16-24: precision - // - 24-32: flag newTimeout := binary.LittleEndian.Uint64(argBuf[8:16]) flag := binary.LittleEndian.Uint16(argBuf[24:32]) var errno Errno switch flag { - case 0: // relative time + case 0: errno = ErrnoSuccess if timeout < time.Duration(newTimeout) { timeout = time.Duration(newTimeout) @@ -189,20 +96,14 @@ func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, } writeEvent(slot, userData, errno, eventTypeClock) case eventTypeFDRead: - // Our sandbox doesn't allow access to the filesystem, - // so let's just error these events writeEvent(slot, userData, ErrnoBadf, eventTypeFDRead) case eventTypeFDWrite: - // Our sandbox doesn't allow access to the filesystem, - // so let's just error these events writeEvent(slot, userData, ErrnoBadf, eventTypeFDWrite) default: writeEvent(slot, userData, ErrnoInval, int(eventType)) } } - // Advance the clock by timeout. - // This will make it seem like we've slept by timeout. if timeout > 0 { clock.Advance(timeout) } @@ -211,15 +112,10 @@ func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, rne := make([]byte, uint32Size) binary.LittleEndian.PutUint32(rne, uint32(nsubscriptions)) - // Write the number of events to `resultNevents` - size := wasmWrite(caller, rne, resultNevents, uint32Size) - if size == -1 { + if wasmWrite(caller, rne, resultNevents, uint32Size) == -1 { return ErrnoFault } - - // Write the events to `events` - size = wasmWrite(caller, events, eventsptr, nsubscriptions*eventsLen) - if size == -1 { + if wasmWrite(caller, events, eventsptr, nsubscriptions*eventsLen) == -1 { return ErrnoFault } @@ -227,36 +123,28 @@ func pollOneoff(caller *wasmtime.Caller, subscriptionptr int32, eventsptr int32, } func writeEvent(slot []byte, userData []byte, errno Errno, eventType int) { - // the event structure is described here: - // https://github.com/WebAssembly/WASI/blob/snapshot-01/phases/snapshot/docs.md#-event-struct copy(slot, userData) slot[8] = byte(errno) slot[9] = 0 binary.LittleEndian.PutUint32(slot[10:], uint32(eventType)) } -// createRandomGet accepts a seed from the module config and overrides the random_get function in -// the WASI API. The override fixes the random source with a hardcoded seed via insecure randomness. -// Function errors if the config is not set or does not contain a seed. -func createRandomGet(cfg *ModuleConfig) func(caller *wasmtime.Caller, buf, bufLen int32) int32 { - return func(caller *wasmtime.Caller, buf, bufLen int32) int32 { +func createRandomGet(cfg *ModuleConfig) func(caller engine.MemoryAccessor, buf, bufLen int32) int32 { + return func(caller engine.MemoryAccessor, buf, bufLen int32) int32 { if cfg == nil || cfg.Determinism == nil { return ErrnoInval } var ( - // Fix the random source with a hardcoded seed seed = cfg.Determinism.Seed randSource = rand.New(rand.NewSource(seed)) //nolint:gosec randOutput = make([]byte, bufLen) ) - // Generate random bytes from the source if _, err := io.ReadAtLeast(randSource, randOutput, int(bufLen)); err != nil { return ErrnoFault } - // Copy the random bytes into the wasm module memory if n := wasmWrite(caller, randOutput, buf, bufLen); n != int64(len(randOutput)) { return ErrnoFault } diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 8a11564195..7600a3a75e 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -11,12 +11,22 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/engine" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +type testRuntimeOverride struct { + engine.Runtime + v2ImportNameOverride string +} + +func (t *testRuntimeOverride) V2ImportName() string { + return t.v2ImportNameOverride +} + const ( nodagRandomBinaryCmd = "standard_tests/multiple_triggers" nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" @@ -35,7 +45,7 @@ func Test_Sleep_Timeout(t *testing.T) { m, err := NewModule(t.Context(), mc, binary) require.NoError(t, err) - m.v2ImportName = "test" + m.runtime = &testRuntimeOverride{Runtime: m.runtime, v2ImportNameOverride: "test"} m.Start() defer m.Close()