Skip to content

Commit 0fe48f3

Browse files
committed
Implement thread creation deletion event callback.
1 parent cd4091f commit 0fe48f3

File tree

6 files changed

+123
-5
lines changed

6 files changed

+123
-5
lines changed

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ pub use crate::thread::{Thread, ThreadStatus};
112112
pub use crate::traits::{
113113
FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike,
114114
};
115+
#[cfg(feature = "luau")]
116+
pub use crate::types::ThreadEventInfo;
115117
pub use crate::types::{
116118
AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
117119
};

src/prelude.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ pub use crate::{
2020
#[doc(no_inline)]
2121
pub use crate::HookTriggers as LuaHookTriggers;
2222

23-
#[cfg(feature = "luau")]
24-
#[doc(no_inline)]
25-
pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector};
26-
2723
#[cfg(feature = "async")]
2824
#[doc(no_inline)]
2925
pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn};
26+
#[cfg(feature = "luau")]
27+
#[doc(no_inline)]
28+
pub use crate::{
29+
CoverageInfo as LuaCoverageInfo, ThreadEventInfo as LuaThreadEventInfo, Vector as LuaVector,
30+
};
3031

3132
#[cfg(feature = "serialize")]
3233
#[doc(no_inline)]

src/state.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ use crate::types::{
2323
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex,
2424
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
2525
};
26+
27+
#[cfg(any(feature = "luau", doc))]
28+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
29+
use crate::types::ThreadEventInfo;
2630
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
2731
use crate::util::{
2832
assert_stack, check_stack, protect_lua_closure, push_string, push_table, rawset_field, StackGuard,
@@ -671,6 +675,72 @@ impl Lua {
671675
}
672676
}
673677

678+
/// Sets a callback that will be called by Luau whenever a thread is created/destroyed.
679+
///
680+
/// Often used for keeping track of threads.
681+
#[cfg(any(feature = "luau", doc))]
682+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
683+
pub fn set_thread_event_callback<F>(&self, callback: F)
684+
where
685+
F: Fn(&Lua, ThreadEventInfo) -> Result<()> + MaybeSend + 'static,
686+
{
687+
use std::rc::Rc;
688+
689+
unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, state: *mut ffi::lua_State) {
690+
callback_error_ext(state, ptr::null_mut(), move |extra, _| {
691+
let raw_lua: &RawLua = (*extra).raw_lua();
692+
let _guard = StateGuard::new(raw_lua, state);
693+
694+
let userthread_cb = (*extra).userthread_callback.clone();
695+
let userthread_cb =
696+
mlua_expect!(userthread_cb, "no userthread callback set in userthread_proc");
697+
if parent.is_null() {
698+
raw_lua.push(Value::Nil).unwrap();
699+
} else {
700+
raw_lua.push_ref_thread(parent).unwrap();
701+
}
702+
if parent.is_null() {
703+
let event_info = ThreadEventInfo::Destroyed(state.cast_const().cast());
704+
let main_state = raw_lua.main_state();
705+
if main_state == state {
706+
return Ok(()); // Don't process Destroyed event on main thread.
707+
}
708+
let main_extra = ExtraData::get(main_state);
709+
let main_raw_lua: &RawLua = (*main_extra).raw_lua();
710+
let _guard = StateGuard::new(main_raw_lua, state);
711+
userthread_cb((*main_extra).lua(), event_info)
712+
} else {
713+
raw_lua.push_ref_thread(parent).unwrap();
714+
let event_info = match raw_lua.pop_value() {
715+
Value::Thread(thr) => ThreadEventInfo::Created(thr),
716+
_ => unimplemented!(),
717+
};
718+
userthread_cb((*extra).lua(), event_info)
719+
}
720+
});
721+
}
722+
723+
// Set interrupt callback
724+
let lua = self.lock();
725+
unsafe {
726+
(*lua.extra.get()).userthread_callback = Some(Rc::new(callback));
727+
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc);
728+
}
729+
}
730+
731+
/// Removes any thread event function previously set by `set_thread_event_callback`.
732+
///
733+
/// This function has no effect if a callback was not previously set.
734+
#[cfg(any(feature = "luau", doc))]
735+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
736+
pub fn remove_thread_event_callback(&self) {
737+
let lua = self.lock();
738+
unsafe {
739+
(*lua.extra.get()).userthread_callback = None;
740+
(*ffi::lua_callbacks(lua.main_state())).userthread = None;
741+
}
742+
}
743+
674744
/// Sets the warning function to be used by Lua to emit warnings.
675745
///
676746
/// Requires `feature = "lua54"`

src/state/extra.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ pub(crate) struct ExtraData {
8080
pub(super) warn_callback: Option<crate::types::WarnCallback>,
8181
#[cfg(feature = "luau")]
8282
pub(super) interrupt_callback: Option<crate::types::InterruptCallback>,
83+
#[cfg(feature = "luau")]
84+
pub(super) userthread_callback: Option<crate::types::ThreadEventCallback>,
8385

8486
#[cfg(feature = "luau")]
8587
pub(super) sandboxed: bool,
@@ -177,6 +179,8 @@ impl ExtraData {
177179
#[cfg(feature = "luau")]
178180
interrupt_callback: None,
179181
#[cfg(feature = "luau")]
182+
userthread_callback: None,
183+
#[cfg(feature = "luau")]
180184
sandboxed: false,
181185
#[cfg(feature = "luau")]
182186
compiler: None,

src/state/raw.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ impl Drop for RawLua {
6464
}
6565

6666
let mem_state = MemoryState::get(self.main_state());
67-
67+
#[cfg(feature = "luau")] // Fixes a crash during shutdown
68+
{
69+
(*ffi::lua_callbacks(self.main_state())).userthread = None;
70+
}
6871
ffi::lua_close(self.main_state());
6972

7073
// Deallocate `MemoryState`
@@ -556,6 +559,21 @@ impl RawLua {
556559
value.push_into_stack(self)
557560
}
558561

562+
pub(crate) unsafe fn push_ref_thread(&self, ref_thread: *mut ffi::lua_State) -> Result<()> {
563+
let state = self.state();
564+
check_stack(state, 1)?;
565+
let _sg = StackGuard::new(ref_thread);
566+
check_stack(ref_thread, 1)?;
567+
568+
if self.unlikely_memory_error() {
569+
ffi::lua_pushthread(ref_thread)
570+
} else {
571+
protect_lua!(ref_thread, 0, 1, |ref_thread| ffi::lua_pushthread(ref_thread))?
572+
};
573+
ffi::lua_xmove(ref_thread, self.state(), 1);
574+
Ok(())
575+
}
576+
559577
/// Pushes a `Value` (by reference) onto the Lua stack.
560578
///
561579
/// Uses 2 stack spaces, does not call `checkstack`.

src/types.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ use crate::error::Result;
77
use crate::hook::Debug;
88
use crate::state::{ExtraData, Lua, RawLua};
99

10+
#[cfg(any(feature = "luau", doc))]
11+
use crate::thread::Thread;
12+
1013
// Re-export mutex wrappers
1114
pub(crate) use sync::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak};
1215

@@ -73,6 +76,20 @@ pub enum VmState {
7376
Yield,
7477
}
7578

79+
/// Information about a thread event.
80+
///
81+
/// For creating a thread, it contains the thread that created it.
82+
///
83+
/// This is useful for tracking the origin of all threads.
84+
#[cfg(any(feature = "luau", doc))]
85+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
86+
pub enum ThreadEventInfo {
87+
/// When a thread is created, it contains the thread that created it.
88+
Created(Thread),
89+
/// When a thread is destroyed, it returns its .to_pointer representation.
90+
Destroyed(*const c_void),
91+
}
92+
7693
#[cfg(all(feature = "send", not(feature = "luau")))]
7794
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
7895

@@ -85,6 +102,12 @@ pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
85102
#[cfg(all(not(feature = "send"), feature = "luau"))]
86103
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState>>;
87104

105+
#[cfg(all(feature = "send", feature = "luau"))]
106+
pub(crate) type ThreadEventCallback = Rc<dyn Fn(&Lua, ThreadEventInfo) -> Result<()> + Send>;
107+
108+
#[cfg(all(not(feature = "send"), feature = "luau"))]
109+
pub(crate) type ThreadEventCallback = Rc<dyn Fn(&Lua, ThreadEventInfo) -> Result<()>>;
110+
88111
#[cfg(all(feature = "send", feature = "lua54"))]
89112
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;
90113

0 commit comments

Comments
 (0)