Skip to content

Commit 497485c

Browse files
authored
Merge pull request #241 from influxdata/crepererum/udf-move
refactor: move UDF code to dedicated module
2 parents a0df97a + 6743037 commit 497485c

File tree

2 files changed

+329
-324
lines changed

2 files changed

+329
-324
lines changed

host/src/lib.rs

Lines changed: 2 additions & 324 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,6 @@
22
//!
33
//!
44
//! [DataFusion]: https://datafusion.apache.org/
5-
use std::{any::Any, collections::HashSet, hash::Hash, sync::Arc};
6-
7-
use arrow::datatypes::DataType;
8-
use datafusion_common::{DataFusionError, Result as DataFusionResult};
9-
use datafusion_execution::memory_pool::MemoryPool;
10-
use datafusion_expr::{
11-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
12-
async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl},
13-
};
14-
use tokio::runtime::Handle;
15-
use uuid::Uuid;
16-
use wasmtime::component::ResourceAny;
17-
use wasmtime_wasi::async_trait;
18-
19-
use crate::{
20-
bindings::exports::datafusion_udf_wasm::udf::types as wit_types,
21-
component::WasmComponentInstance,
22-
conversion::limits::{CheckedInto, ComplexityToken},
23-
error::{DataFusionResultExt, WasmToDataFusionResultExt, WitDataFusionResultExt},
24-
tokio_helpers::async_in_sync_context,
25-
};
265
276
pub use crate::{
287
component::WasmComponentPrecompiled,
@@ -33,6 +12,7 @@ pub use crate::{
3312
},
3413
limiter::StaticResourceLimits,
3514
permissions::WasmPermissions,
15+
udf::WasmScalarUdf,
3616
vfs::limits::VfsLimits,
3717
};
3818

@@ -55,307 +35,5 @@ mod linker;
5535
mod permissions;
5636
mod state;
5737
mod tokio_helpers;
38+
mod udf;
5839
mod vfs;
59-
60-
/// A [`ScalarUDFImpl`] that wraps a WebAssembly payload.
61-
///
62-
/// # Async, Blocking, Cancellation
63-
/// Async methods will yield back to the runtime in periodical intervals. The caller should implement some form of
64-
/// timeout, e.g. using [`tokio::time::timeout`]. It is safe to cancel async methods.
65-
///
66-
/// For the async interruption to work it is important that the I/O [runtime] passed to [`WasmScalarUdf::new`] is
67-
/// different from the runtime used to call UDF methods, since the I/O runtime is also used to schedule an
68-
/// [epoch timer](WasmPermissions::with_epoch_tick_time).
69-
///
70-
/// Methods that return references -- e.g. [`ScalarUDFImpl::name`] and [`ScalarUDFImpl::signature`] -- are cached
71-
/// during UDF creation.
72-
///
73-
/// Some methods do NOT offer an async interface yet, e.g. [`ScalarUDFImpl::return_type`]. For these we try to cache
74-
/// them during creation, but if that is not possible we need to block in place when the method is called. This only
75-
/// works when a multi-threaded tokio runtime is used. There is a
76-
/// [timeout](WasmPermissions::with_inplace_blocking_max_ticks). See
77-
/// <https://github.com/influxdata/datafusion-udf-wasm/issues/169> for a potential future improvement on that front.
78-
///
79-
///
80-
/// [runtime]: tokio::runtime::Runtime
81-
#[derive(Debug)]
82-
pub struct WasmScalarUdf {
83-
/// WASM component instance.
84-
instance: Arc<WasmComponentInstance>,
85-
86-
/// Resource handle for the Scalar UDF within the VM.
87-
///
88-
/// This is somewhat an "object reference".
89-
resource: ResourceAny,
90-
91-
/// Name of the UDF.
92-
///
93-
/// This was pre-fetched during UDF generation because
94-
/// [`ScalarUDFImpl::name`] is sync and requires us to return a reference.
95-
name: String,
96-
97-
/// We treat every UDF as unique, but we need a proxy value to express that.
98-
id: Uuid,
99-
100-
/// Signature of the UDF.
101-
///
102-
/// This was pre-fetched during UDF generation because
103-
/// [`ScalarUDFImpl::signature`] is sync and requires us to return a
104-
/// reference.
105-
signature: Signature,
106-
107-
/// Return type of the UDF.
108-
///
109-
/// This was pre-fetched during UDF generation because
110-
/// [`ScalarUDFImpl::return_type`] is sync and requires us to return a
111-
/// reference. We can only compute the return type if the underlying
112-
/// [TypeSignature] is [Exact](TypeSignature::Exact).
113-
return_type: Option<DataType>,
114-
}
115-
116-
impl WasmScalarUdf {
117-
/// Create multiple UDFs from a single WASM VM.
118-
///
119-
/// UDFs bound to the same VM share state, however calling this method
120-
/// multiple times will yield independent WASM VMs.
121-
pub async fn new(
122-
component: &WasmComponentPrecompiled,
123-
permissions: &WasmPermissions,
124-
io_rt: Handle,
125-
memory_pool: &Arc<dyn MemoryPool>,
126-
source: String,
127-
) -> DataFusionResult<Vec<Self>> {
128-
let instance =
129-
Arc::new(WasmComponentInstance::new(component, permissions, io_rt, memory_pool).await?);
130-
131-
let udf_resources = {
132-
let mut state = instance.lock_state().await;
133-
instance
134-
.bindings()
135-
.datafusion_udf_wasm_udf_types()
136-
.call_scalar_udfs(&mut state, &source)
137-
.await
138-
.context(
139-
"calling scalar_udfs() method failed",
140-
Some(&state.stderr.contents()),
141-
)?
142-
.convert_err(permissions.trusted_data_limits.clone())
143-
.context("scalar_udfs")?
144-
};
145-
if udf_resources.len() > permissions.max_udfs {
146-
return Err(DataFusionError::ResourcesExhausted(format!(
147-
"guest returned too many UDFs: got={}, limit={}",
148-
udf_resources.len(),
149-
permissions.max_udfs,
150-
)));
151-
}
152-
153-
let mut udfs = Vec::with_capacity(udf_resources.len());
154-
let mut names_seen = HashSet::with_capacity(udf_resources.len());
155-
for resource in udf_resources {
156-
let mut state = instance.lock_state().await;
157-
let name = instance
158-
.bindings()
159-
.datafusion_udf_wasm_udf_types()
160-
.scalar_udf()
161-
.call_name(&mut state, resource)
162-
.await
163-
.context("call ScalarUdf::name", Some(&state.stderr.contents()))?;
164-
ComplexityToken::new(permissions.trusted_data_limits.clone())?
165-
.check_identifier(&name)
166-
.context("UDF name")?;
167-
if !names_seen.insert(name.clone()) {
168-
return Err(DataFusionError::External(
169-
format!("non-unique UDF name: '{name}'").into(),
170-
));
171-
}
172-
173-
let signature: Signature = instance
174-
.bindings()
175-
.datafusion_udf_wasm_udf_types()
176-
.scalar_udf()
177-
.call_signature(&mut state, resource)
178-
.await
179-
.context("call ScalarUdf::signature", Some(&state.stderr.contents()))?
180-
.checked_into_root(&permissions.trusted_data_limits)?;
181-
182-
let return_type = match &signature.type_signature {
183-
TypeSignature::Exact(t) => {
184-
let r = instance
185-
.bindings()
186-
.datafusion_udf_wasm_udf_types()
187-
.scalar_udf()
188-
.call_return_type(
189-
&mut state,
190-
resource,
191-
&t.iter()
192-
.map(|dt| wit_types::DataType::from(dt.clone()))
193-
.collect::<Vec<_>>(),
194-
)
195-
.await
196-
.context(
197-
"call ScalarUdf::return_type",
198-
Some(&state.stderr.contents()),
199-
)?
200-
.convert_err(permissions.trusted_data_limits.clone())?;
201-
Some(r.checked_into_root(&permissions.trusted_data_limits)?)
202-
}
203-
_ => None,
204-
};
205-
206-
udfs.push(Self {
207-
instance: Arc::clone(&instance),
208-
resource,
209-
name,
210-
id: Uuid::new_v4(),
211-
signature,
212-
return_type,
213-
});
214-
}
215-
216-
Ok(udfs)
217-
}
218-
219-
/// Convert this [WasmScalarUdf] into an [AsyncScalarUDF].
220-
pub fn as_async_udf(self) -> AsyncScalarUDF {
221-
AsyncScalarUDF::new(Arc::new(self))
222-
}
223-
224-
/// Check that the provided argument types match the UDF signature.
225-
fn check_arg_types(&self, arg_types: &[DataType]) -> DataFusionResult<()> {
226-
if let TypeSignature::Exact(expected_types) = &self.signature.type_signature {
227-
if arg_types.len() != expected_types.len() {
228-
return Err(DataFusionError::Plan(format!(
229-
"`{}` expects {} parameters but got {}",
230-
self.name,
231-
expected_types.len(),
232-
arg_types.len()
233-
)));
234-
}
235-
236-
for (i, (provided, expected)) in arg_types.iter().zip(expected_types.iter()).enumerate()
237-
{
238-
if provided != expected {
239-
return Err(DataFusionError::Plan(format!(
240-
"argument {} of `{}` should be {:?}, got {:?}",
241-
i + 1,
242-
self.name,
243-
expected,
244-
provided
245-
)));
246-
}
247-
}
248-
}
249-
250-
Ok(())
251-
}
252-
}
253-
254-
impl PartialEq<Self> for WasmScalarUdf {
255-
fn eq(&self, other: &Self) -> bool {
256-
self.id == other.id
257-
}
258-
}
259-
260-
impl Eq for WasmScalarUdf {}
261-
262-
impl Hash for WasmScalarUdf {
263-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
264-
self.id.hash(state);
265-
}
266-
}
267-
268-
impl ScalarUDFImpl for WasmScalarUdf {
269-
fn as_any(&self) -> &dyn Any {
270-
self
271-
}
272-
273-
fn name(&self) -> &str {
274-
&self.name
275-
}
276-
277-
fn signature(&self) -> &Signature {
278-
&self.signature
279-
}
280-
281-
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
282-
self.check_arg_types(arg_types)?;
283-
284-
if let Some(return_type) = &self.return_type {
285-
return Ok(return_type.clone());
286-
}
287-
288-
async_in_sync_context(
289-
async {
290-
let arg_types = arg_types
291-
.iter()
292-
.map(|t| wit_types::DataType::from(t.clone()))
293-
.collect::<Vec<_>>();
294-
let mut state = self.instance.lock_state().await;
295-
let return_type = self
296-
.instance
297-
.bindings()
298-
.datafusion_udf_wasm_udf_types()
299-
.scalar_udf()
300-
.call_return_type(&mut state, self.resource, &arg_types)
301-
.await
302-
.context(
303-
"call ScalarUdf::return_type",
304-
Some(&state.stderr.contents()),
305-
)?
306-
.convert_err(self.instance.trusted_data_limits().clone())?;
307-
return_type.checked_into_root(self.instance.trusted_data_limits())
308-
},
309-
self.instance.inplace_blocking_timeout(),
310-
)
311-
}
312-
313-
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
314-
Err(DataFusionError::NotImplemented(
315-
"synchronous invocation of WasmScalarUdf is not supported, use invoke_async_with_args instead".to_string(),
316-
))
317-
}
318-
}
319-
320-
#[async_trait]
321-
impl AsyncScalarUDFImpl for WasmScalarUdf {
322-
fn ideal_batch_size(&self) -> Option<usize> {
323-
None
324-
}
325-
326-
async fn invoke_async_with_args(
327-
&self,
328-
args: ScalarFunctionArgs,
329-
) -> DataFusionResult<ColumnarValue> {
330-
let args = args.try_into()?;
331-
let mut state = self.instance.lock_state().await;
332-
let return_type = self
333-
.instance
334-
.bindings()
335-
.datafusion_udf_wasm_udf_types()
336-
.scalar_udf()
337-
.call_invoke_with_args(&mut state, self.resource, &args)
338-
.await
339-
.context(
340-
"call ScalarUdf::invoke_with_args",
341-
Some(&state.stderr.contents()),
342-
)?
343-
.convert_err(self.instance.trusted_data_limits().clone())?;
344-
345-
match return_type.checked_into_root(self.instance.trusted_data_limits()) {
346-
Ok(ColumnarValue::Scalar(scalar)) => Ok(ColumnarValue::Scalar(scalar)),
347-
Ok(ColumnarValue::Array(array)) if array.len() as u64 != args.number_rows => {
348-
Err(DataFusionError::External(
349-
format!(
350-
"UDF returned array of length {} but should produce {} rows",
351-
array.len(),
352-
args.number_rows
353-
)
354-
.into(),
355-
))
356-
}
357-
Ok(ColumnarValue::Array(array)) => Ok(ColumnarValue::Array(array)),
358-
Err(e) => Err(e),
359-
}
360-
}
361-
}

0 commit comments

Comments
 (0)