Skip to content

Commit c6837a0

Browse files
committed
Allow to overwrite init on a per test basis
with `#[test(init=other_init)]`
1 parent ba92102 commit c6837a0

File tree

7 files changed

+153
-57
lines changed

7 files changed

+153
-57
lines changed

macros/src/attributes/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@ pub(crate) fn expand(args: TokenStream, input: TokenStream) -> TokenStream {
1717
let validated_module = validate::ValidatedModule::from_module_and_args(module, macro_args);
1818

1919
let untouched_tokens = &validated_module.untouched_tokens;
20-
let init_fn = validated_module.init_func.as_ref().map(|i| &i.func);
2120
let tests = validated_module
2221
.tests
2322
.iter()
2423
.map(|test| codegen::test(test, &validated_module));
24+
let init_fns = validated_module.init_funcs.values().map(|i| &i.func);
2525

2626
let mod_name = format_ident!("{}", validated_module.module_name);
2727
quote!(
2828
#[cfg(test)]
2929
mod #mod_name {
3030
#(#untouched_tokens)*
3131

32-
#init_fn
32+
#(#init_fns)*
3333

3434
#(#tests)*
3535
}

macros/src/attributes/tests/codegen/test.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@ pub(crate) fn test(test: &TestFunc, module: &ValidatedModule) -> TokenStream {
1313
let mut embassy_task = None;
1414

1515
// Generate the code block that will call init, run the test and check the outcome.
16-
let mut test_invocation = call_test_fn(test, module.init_func.as_ref());
16+
let init = module.init_function_for_test(test);
17+
let mut test_invocation = call_test_fn(test, init);
1718

18-
let init_is_async = module
19-
.init_func
20-
.as_ref()
21-
.map(|i| i.asyncness)
22-
.unwrap_or_default();
19+
let init_is_async = init.map(|i| i.asyncness).unwrap_or_default();
2320

2421
// If the test is async or the init function is async, we need to wrap the test invocation in an executor.
2522
// Result is still a block

macros/src/attributes/tests/parse/function_attributes.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
use darling::FromMeta;
12
use proc_macro_error2::abort;
23
use syn::spanned::Spanned;
34
use syn::{Attribute, ItemFn};
45

56
/// Represents the attributes that can be applied to a function in the test module
67
pub(crate) enum FuncAttribute {
78
Init,
8-
Test,
9+
Test(TestAttribute),
910
ShouldPanic,
1011
Ignore,
1112
Timeout(TimeoutAttribute),
@@ -19,7 +20,7 @@ impl FuncAttribute {
1920
let ident = attr.path().get_ident()?.to_string();
2021
Some(match ident.as_str() {
2122
"init" => FuncAttribute::Init,
22-
"test" => FuncAttribute::Test,
23+
"test" => FuncAttribute::Test(TestAttribute::from_attr(attr)),
2324
"should_panic" => FuncAttribute::ShouldPanic,
2425
"ignore" => FuncAttribute::Ignore,
2526
"timeout" => FuncAttribute::Timeout(TimeoutAttribute::from_attr(attr)),
@@ -56,6 +57,24 @@ impl TimeoutAttribute {
5657
}
5758
}
5859

60+
#[derive(Debug, FromMeta, Default)]
61+
pub(crate) struct TestAttribute {
62+
#[darling(default)]
63+
pub init: Option<syn::Ident>,
64+
}
65+
66+
impl TestAttribute {
67+
pub fn from_attr(attr: &Attribute) -> Self {
68+
match &attr.meta {
69+
syn::Meta::Path(_) => TestAttribute::default(),
70+
meta => match TestAttribute::from_meta(meta) {
71+
Ok(test_attr) => test_attr,
72+
Err(e) => abort!(attr, "failed to parse `test` attribute. Must be of the form #[test(init = init_function)]: {}", e),
73+
},
74+
}
75+
}
76+
}
77+
5978
pub(crate) struct FunctionWithAttributes {
6079
/// Original function item without the attributes that we recognize
6180
pub func: ItemFn,

macros/src/attributes/tests/validate/function.rs

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use proc_macro_error2::abort;
33
use syn::{Attribute, ItemFn, ReturnType, Type};
44

55
pub(crate) struct InitFunc {
6+
pub name: String,
67
pub func: ItemFn,
78
pub state: Option<Type>,
89
pub asyncness: bool,
@@ -14,15 +15,15 @@ impl From<FunctionWithAttributes> for InitFunc {
1415
for (attr, span) in attributes {
1516
match attr {
1617
FuncAttribute::Init => {}
17-
FuncAttribute::Test => unreachable!(),
18+
FuncAttribute::Test(_) => unreachable!(),
1819
_ => abort!(span, "The `#[init]` function can not have this attribute"),
1920
}
2021
}
2122
if check_fn_sig(&func.sig).is_err() || !func.sig.inputs.is_empty() {
2223
abort!(
23-
func.sig,
24-
"`#[init]` function must have signature `async fn() [-> Type]` (async/return type are optional)",
25-
);
24+
func.sig,
25+
"`#[init]` function must have signature `async fn() [-> Type]` (async/return type are optional)",
26+
);
2627
}
2728

2829
if cfg!(not(feature = "embassy")) && func.sig.asyncness.is_some() {
@@ -37,6 +38,7 @@ impl From<FunctionWithAttributes> for InitFunc {
3738
ReturnType::Type(.., ty) => Some(*ty.clone()),
3839
};
3940
InitFunc {
41+
name: func.sig.ident.to_string(),
4042
asyncness: func.sig.asyncness.is_some(),
4143
func,
4244
state,
@@ -52,6 +54,7 @@ pub(crate) struct TestFunc {
5254
pub ignore: bool,
5355
pub asyncness: bool,
5456
pub timeout: Option<u32>,
57+
pub custom_init: Option<syn::Ident>,
5558
}
5659

5760
impl From<FunctionWithAttributes> for TestFunc {
@@ -60,10 +63,11 @@ impl From<FunctionWithAttributes> for TestFunc {
6063
let mut should_panic = false;
6164
let mut ignore = false;
6265
let mut timeout = None;
66+
let mut custom_init = None;
6367
for (attr, _span) in attributes {
6468
match attr {
6569
FuncAttribute::Init => unreachable!(),
66-
FuncAttribute::Test => {}
70+
FuncAttribute::Test(attr) => custom_init = attr.init,
6771
FuncAttribute::ShouldPanic => should_panic = true,
6872
FuncAttribute::Ignore => ignore = true,
6973
FuncAttribute::Timeout(t) => timeout = Some(t.value),
@@ -99,17 +103,31 @@ impl From<FunctionWithAttributes> for TestFunc {
99103
should_panic,
100104
ignore,
101105
timeout,
106+
custom_init,
107+
}
108+
}
109+
}
110+
111+
pub(crate) struct OtherFunc(pub FunctionWithAttributes);
112+
impl From<FunctionWithAttributes> for OtherFunc {
113+
fn from(func: FunctionWithAttributes) -> Self {
114+
if let Some((_attr, span)) = func.attributes.first() {
115+
abort!(
116+
span,
117+
"Only `#[test]` or `#[init]` functions can have such an attribute"
118+
);
102119
}
120+
OtherFunc(func)
103121
}
104122
}
105123

106-
pub(crate) enum Func {
124+
pub(crate) enum AnnotatedFunction {
107125
Init(InitFunc),
108126
Test(TestFunc),
109-
Other(FunctionWithAttributes),
127+
Other(OtherFunc),
110128
}
111129

112-
impl From<FunctionWithAttributes> for Func {
130+
impl From<FunctionWithAttributes> for AnnotatedFunction {
113131
fn from(func: FunctionWithAttributes) -> Self {
114132
enum FuncKind {
115133
Init,
@@ -119,8 +137,8 @@ impl From<FunctionWithAttributes> for Func {
119137
for (attr, span) in &func.attributes {
120138
match attr {
121139
FuncAttribute::Init if func_kind.is_none() => func_kind = Some(FuncKind::Init),
122-
FuncAttribute::Test if func_kind.is_none() => func_kind = Some(FuncKind::Test),
123-
FuncAttribute::Init | FuncAttribute::Test => {
140+
FuncAttribute::Test(_) if func_kind.is_none() => func_kind = Some(FuncKind::Test),
141+
FuncAttribute::Init | FuncAttribute::Test(_) => {
124142
abort!(
125143
span,
126144
"A function can only be marked with one of `#[init]` or `#[test]`"
@@ -131,9 +149,9 @@ impl From<FunctionWithAttributes> for Func {
131149
}
132150

133151
match func_kind {
134-
Some(FuncKind::Init) => Func::Init(InitFunc::from(func)),
135-
Some(FuncKind::Test) => Func::Test(TestFunc::from(func)),
136-
None => Func::Other(func),
152+
Some(FuncKind::Init) => AnnotatedFunction::Init(InitFunc::from(func)),
153+
Some(FuncKind::Test) => AnnotatedFunction::Test(TestFunc::from(func)),
154+
None => AnnotatedFunction::Other(OtherFunc::from(func)),
137155
}
138156
}
139157
}

0 commit comments

Comments
 (0)