Skip to content

Commit ed796ff

Browse files
committed
Support async require loaders for Luau
1 parent 65e5be8 commit ed796ff

File tree

5 files changed

+150
-29
lines changed

5 files changed

+150
-29
lines changed

mlua-sys/src/luau/luarequire.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::os::raw::{c_char, c_int, c_void};
44

55
use super::lua::lua_State;
66

7+
pub const LUA_REGISTERED_MODULES_TABLE: *const c_char = cstr!("_REGISTEREDMODULES");
8+
79
#[repr(C)]
810
pub enum luarequire_NavigateResult {
911
Success,

src/luau/mod.rs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::ffi::CStr;
2-
use std::mem;
32
use std::os::raw::c_int;
43

54
use crate::error::Result;
@@ -16,26 +15,7 @@ impl Lua {
1615
#[cfg(any(feature = "luau", doc))]
1716
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
1817
pub fn create_require_function<R: Require + 'static>(&self, require: R) -> Result<Function> {
19-
unsafe extern "C-unwind" fn mlua_require(state: *mut ffi::lua_State) -> c_int {
20-
let mut ar: ffi::lua_Debug = mem::zeroed();
21-
if ffi::lua_getinfo(state, 1, cstr!("s"), &mut ar) == 0 {
22-
ffi::luaL_error(state, cstr!("require is not supported in this context"));
23-
}
24-
let top = ffi::lua_gettop(state);
25-
ffi::lua_pushvalue(state, ffi::lua_upvalueindex(2)); // the "proxy" require function
26-
ffi::lua_pushvalue(state, 1); // require path
27-
ffi::lua_pushstring(state, ar.source); // current file
28-
ffi::lua_call(state, 2, ffi::LUA_MULTRET);
29-
ffi::lua_gettop(state) - top
30-
}
31-
32-
unsafe {
33-
self.exec_raw((), move |state| {
34-
let requirer_ptr = ffi::lua_newuserdata_t::<Box<dyn Require>>(state, Box::new(require));
35-
ffi::luarequire_pushproxyrequire(state, require::init_config, requirer_ptr as *mut _);
36-
ffi::lua_pushcclosured(state, mlua_require, cstr!("require"), 2);
37-
})
38-
}
18+
require::create_require_function(self, require)
3919
}
4020

4121
pub(crate) unsafe fn configure_luau(&self) -> Result<()> {

src/luau/require.rs

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ use std::io::Result as IoResult;
55
use std::os::raw::{c_char, c_int, c_void};
66
use std::path::{Component, Path, PathBuf};
77
use std::result::Result as StdResult;
8-
use std::{env, fmt, fs, ptr};
8+
use std::{env, fmt, fs, mem, ptr};
99

1010
use crate::error::Result;
1111
use crate::function::Function;
1212
use crate::state::{callback_error_ext, Lua};
13+
use crate::table::Table;
1314
use crate::types::MaybeSend;
14-
use crate::value::Value;
1515

1616
/// An error that can occur during navigation in the Luau `require` system.
1717
pub enum NavigateError {
@@ -87,6 +87,8 @@ pub trait Require: MaybeSend {
8787
fn config(&self) -> IoResult<Vec<u8>>;
8888

8989
/// Returns a loader that when called, loads the module and returns the result.
90+
///
91+
/// Loader can be sync or async.
9092
fn loader(&self, lua: &Lua, path: &str, chunk_name: &str, content: &[u8]) -> Result<Function> {
9193
let _ = path;
9294
lua.load(content).set_name(chunk_name).into_function()
@@ -425,10 +427,7 @@ pub(super) unsafe extern "C" fn init_config(config: *mut ffi::luarequire_Configu
425427
let contents = CStr::from_ptr(contents).to_bytes();
426428
callback_error_ext(state, ptr::null_mut(), false, move |extra, _| {
427429
let rawlua = (*extra).raw_lua();
428-
match (this.loader(rawlua.lua(), &path, &chunk_name, contents)?).call(())? {
429-
Value::Nil => rawlua.push(true)?,
430-
value => rawlua.push_value(&value)?,
431-
};
430+
rawlua.push(this.loader(rawlua.lua(), &path, &chunk_name, contents)?)?;
432431
Ok(1)
433432
})
434433
}
@@ -495,6 +494,105 @@ unsafe fn write_to_buffer(
495494
}
496495
}
497496

497+
#[cfg(feature = "luau")]
498+
pub fn create_require_function<R: Require + 'static>(lua: &Lua, require: R) -> Result<Function> {
499+
unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int {
500+
let mut ar: ffi::lua_Debug = mem::zeroed();
501+
for level in 2.. {
502+
if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 {
503+
ffi::luaL_error(state, cstr!("require is not supported in this context"));
504+
}
505+
if CStr::from_ptr(ar.what) != c"C" {
506+
break;
507+
}
508+
}
509+
ffi::lua_pushstring(state, ar.source);
510+
1
511+
}
512+
513+
unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int {
514+
let requirer = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1)) as *const Box<dyn Require>;
515+
let cache_key = (*requirer).cache_key();
516+
ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len());
517+
1
518+
}
519+
520+
let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe {
521+
lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| {
522+
let requirer_ptr = ffi::lua_newuserdata_t::<Box<dyn Require>>(state, Box::new(require));
523+
ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1);
524+
ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file"));
525+
ffi::luarequire_pushproxyrequire(state, init_config, requirer_ptr as *mut _);
526+
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE);
527+
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__MLUA_LOADER_CACHE"));
528+
})
529+
}?;
530+
531+
unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int {
532+
ffi::luaL_where(state, 1);
533+
ffi::lua_pushvalue(state, 1);
534+
ffi::lua_concat(state, 2);
535+
ffi::lua_error(state);
536+
}
537+
538+
unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int {
539+
ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1)));
540+
1
541+
}
542+
543+
let (error, r#type) = unsafe {
544+
lua.exec_raw::<(Function, Function)>((), move |state| {
545+
ffi::lua_pushcfunctiond(state, error, cstr!("error"));
546+
ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
547+
})
548+
}?;
549+
550+
// Prepare environment for the "require" function
551+
let env = lua.create_table_with_capacity(0, 7)?;
552+
env.raw_set("get_cache_key", get_cache_key)?;
553+
env.raw_set("find_current_file", find_current_file)?;
554+
env.raw_set("proxyrequire", proxyrequire)?;
555+
env.raw_set("REGISTERED_MODULES", registered_modules)?;
556+
env.raw_set("LOADER_CACHE", loader_cache)?;
557+
env.raw_set("error", error)?;
558+
env.raw_set("type", r#type)?;
559+
560+
lua.load(
561+
r#"
562+
local path = ...
563+
if type(path) ~= "string" then
564+
error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")")
565+
end
566+
567+
-- Check if the module (path) is explicitly registered
568+
local maybe_result = REGISTERED_MODULES[path]
569+
if maybe_result ~= nil then
570+
return maybe_result
571+
end
572+
573+
local loader = proxyrequire(path, find_current_file())
574+
local cache_key = get_cache_key()
575+
-- Check if the loader result is already cached
576+
local result = LOADER_CACHE[cache_key]
577+
if result ~= nil then
578+
return result
579+
end
580+
581+
-- Call the loader function and cache the result
582+
result = loader()
583+
if result == nil then
584+
result = true
585+
end
586+
LOADER_CACHE[cache_key] = result
587+
return result
588+
"#,
589+
)
590+
.try_cache()
591+
.set_name("=__mlua_require")
592+
.set_environment(env)
593+
.into_function()
594+
}
595+
498596
#[cfg(test)]
499597
mod tests {
500598
use std::path::Path;

src/state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ impl Lua {
356356
#[cfg(not(feature = "luau"))]
357357
const LOADED_MODULES_KEY: *const c_char = ffi::LUA_LOADED_TABLE;
358358
#[cfg(feature = "luau")]
359-
const LOADED_MODULES_KEY: *const c_char = cstr!("_REGISTEREDMODULES");
359+
const LOADED_MODULES_KEY: *const c_char = ffi::LUA_REGISTERED_MODULES_TABLE;
360360

361361
if cfg!(feature = "luau") && !modname.starts_with('@') {
362362
return Err(Error::runtime("module name must begin with '@'"));

tests/luau/require.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use mlua::{IntoLua, Lua, Result, Value};
22

3-
fn run_require(lua: &Lua, path: &str) -> Result<Value> {
3+
fn run_require(lua: &Lua, path: impl IntoLua) -> Result<Value> {
44
lua.load(r#"return require(...)"#).call(path)
55
}
66

@@ -26,6 +26,12 @@ fn test_require_errors() {
2626
assert!(
2727
(res.unwrap_err().to_string()).contains("require path must start with a valid prefix: ./, ../, or @")
2828
);
29+
30+
// Pass non-string to require
31+
let res = run_require(&lua, true);
32+
assert!(res.is_err());
33+
assert!((res.unwrap_err().to_string())
34+
.contains("bad argument #1 to 'require' (string expected, got boolean)"));
2935
}
3036

3137
#[test]
@@ -100,3 +106,38 @@ fn test_require_with_config() {
100106
assert!(res.is_err());
101107
assert!((res.unwrap_err().to_string()).contains("@ is not a valid alias"));
102108
}
109+
110+
#[cfg(feature = "async")]
111+
#[tokio::test]
112+
async fn test_async_require() -> Result<()> {
113+
let lua = Lua::new();
114+
115+
let temp_dir = tempfile::tempdir().unwrap();
116+
let temp_path = temp_dir.path().join("async_chunk.luau");
117+
std::fs::write(
118+
&temp_path,
119+
r#"
120+
sleep_ms(10)
121+
return "result_after_async_sleep"
122+
"#,
123+
)
124+
.unwrap();
125+
126+
lua.globals().set(
127+
"sleep_ms",
128+
lua.create_async_function(|_, ms: u64| async move {
129+
tokio::time::sleep(std::time::Duration::from_millis(ms)).await;
130+
Ok(())
131+
})?,
132+
)?;
133+
134+
lua.load(
135+
r#"
136+
local result = require("./async_chunk")
137+
assert(result == "result_after_async_sleep")
138+
"#,
139+
)
140+
.set_name(format!("@{}", temp_dir.path().join("require.rs").display()))
141+
.exec_async()
142+
.await
143+
}

0 commit comments

Comments
 (0)