diff --git a/Taskfile.yml b/Taskfile.yml index a7fa677..c964b13 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -36,6 +36,7 @@ tasks: test: cmds: + - task: pm:test - task: cu:test dev-doc: diff --git a/packages/copper-proc-macros/Cargo.toml b/packages/copper-proc-macros/Cargo.toml index 877c296..cf0f588 100644 --- a/packages/copper-proc-macros/Cargo.toml +++ b/packages/copper-proc-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pistonite-cu-proc-macros" -version = "0.2.0" +version = "0.2.1" edition = "2024" description = "Proc-macros for Cu" repository = "https://github.com/Pistonite/cu" @@ -12,7 +12,7 @@ exclude = [ [dependencies.pm] package = "pistonite-pm" -version = "0.2.0" +version = "0.2.1" path = "../promethium" features = ["full"] diff --git a/packages/copper-proc-macros/src/cli.rs b/packages/copper-proc-macros/src/cli.rs index 9525d12..fa0c93c 100644 --- a/packages/copper-proc-macros/src/cli.rs +++ b/packages/copper-proc-macros/src/cli.rs @@ -53,9 +53,9 @@ pub fn expand(attr: TokenStream, input: TokenStream) -> pm::Result Ok(expanded) } -fn parse_attributes(attr: TokenStream) -> pm::Result { +fn parse_attributes(attr: TokenStream) -> pm::Result { let attrs = pm::parse_punctuated::(attr)?; - let mut out = CliAttributes::default(); + let mut out = ParsedAttributes::default(); for attr in attrs { if attr.path.is_ident("flags") { @@ -79,7 +79,7 @@ fn parse_attributes(attr: TokenStream) -> pm::Result { Ok(out) } #[derive(Default)] -struct CliAttributes { +struct ParsedAttributes { flags_ident: Option, preprocess_fn: Option, } diff --git a/packages/copper-proc-macros/src/error_ctx.rs b/packages/copper-proc-macros/src/error_ctx.rs new file mode 100644 index 0000000..1cb62cd --- /dev/null +++ b/packages/copper-proc-macros/src/error_ctx.rs @@ -0,0 +1,128 @@ +use pm::pre::*; + +pub fn expand(attr: TokenStream, input: TokenStream) -> pm::Result { + let attrs = parse_attributes(attr)?; + let item: syn::ItemFn = syn::parse(input)?; + + let item_attrs = &item.attrs; + let item_block = &item.block; + let sig = &item.sig; + let is_async = item.sig.asyncness.is_some(); + let retty = match &item.sig.output { + syn::ReturnType::Default => pm::quote! {()}, + syn::ReturnType::Type(_, ty) => pm::quote! {#ty}, + }; + let args = attrs.format_args; + + let block = match (attrs.is_pre, is_async) { + // for non-async, we need to use a closure + // to prevent `?` operator from returning directly. + // for async, we can use an async block + (true, false) => { + pm::quote! { + use cu::Context as _; + let __error_msg = format!(#args); + let __result: #retty = (move|| -> #retty #item_block)(); + __result.context(__error_msg) + } + } + (true, true) => { + pm::quote! { + use cu::Context as _; + let __error_msg = format!(#args); + let __result: #retty = async move #item_block.await; + __result.context(__error_msg) + } + } + (false, false) => { + pm::quote! { + use cu::Context as _; + let __result: #retty = (move|| -> #retty #item_block)(); + __result.with_context(|| format!(#args)) + } + } + (false, true) => { + pm::quote! { + use cu::Context as _; + let __result: #retty = async move #item_block.await; + __result.with_context(|| format!(#args)) + } + } + }; + + let expanded = pm::quote! { + #(#item_attrs)* #sig { #block } + }; + + Ok(expanded) +} + +fn parse_attributes(attr: TokenStream) -> pm::Result { + let Ok(attrs) = pm::parse_punctuated::(attr.clone()) else { + // if the input is not a list of meta, assuming using shorthand + // input is format args + return Ok(ParsedAttributes { + format_args: attr.into(), + ..Default::default() + }); + }; + + let mut out = ParsedAttributes::default(); + for attr in attrs { + match attr { + syn::Meta::Path(attr) => { + if attr.is_ident("pre") { + out.is_pre = true; + continue; + } else if attr.is_ident("format") { + pm::bail!( + attr, + "`format` attribute should contain the format args, i.e. #[cu::error_ctx(format(...))], or use the shorthand #[cu::error_ctx(...)]" + ); + } + pm::bail!(attr, "unknown attribute") + } + syn::Meta::List(attr) => { + if attr.path.is_ident("pre") { + pm::bail!( + attr, + "`pre` attribute should not have a value, i.e. #[cu::error_ctx(pre, ...)]" + ); + } + if attr.path.is_ident("format") { + out.format_args = attr.tokens; + continue; + } + pm::bail!(attr, "unknown attribute") + } + syn::Meta::NameValue(attr) => { + if attr.path.is_ident("pre") { + pm::bail!( + attr, + "`pre` attribute should not have a value, i.e. #[cu::error_ctx(pre, ...)]" + ); + } + if attr.path.is_ident("format") { + pm::bail!( + attr, + "`format` attribute should be parenthesized, i.e. #[cu::error_ctx(format(...))], or use the shorthand #[cu::error_ctx(...)]" + ); + } + pm::bail!(attr, "unknown attribute") + } + } + } + + Ok(out) +} + +#[derive(Default)] +struct ParsedAttributes { + /// if the error string should be formatted before invoking + /// the function. this is needed if some non-Copy values + /// are moved into the function + is_pre: bool, + + /// The format expression + format_args: TokenStream2, +} diff --git a/packages/copper-proc-macros/src/lib.rs b/packages/copper-proc-macros/src/lib.rs index 8e7f2ef..9803a4a 100644 --- a/packages/copper-proc-macros/src/lib.rs +++ b/packages/copper-proc-macros/src/lib.rs @@ -13,3 +13,13 @@ pub fn derive_parse(input: TokenStream) -> TokenStream { pm::flatten(derive_parse::expand(input)) } mod derive_parse; + +/// Attribute macro for wrapping a function with an error context +/// +/// See the [tests](https://github.com/Pistonite/cu/blob/main/packages/copper/tests/error_ctx.rs) +/// for examples +#[proc_macro_attribute] +pub fn error_ctx(attr: TokenStream, input: TokenStream) -> TokenStream { + pm::flatten(error_ctx::expand(attr, input)) +} +mod error_ctx; diff --git a/packages/copper/Cargo.toml b/packages/copper/Cargo.toml index 835c522..e71820d 100644 --- a/packages/copper/Cargo.toml +++ b/packages/copper/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pistonite-cu" -version = "0.6.4" +version = "0.6.5" edition = "2024" description = "Battery-included common utils to speed up development of rust tools" repository = "https://github.com/Pistonite/cu" @@ -11,7 +11,7 @@ exclude = [ ] [dependencies] -pistonite-cu-proc-macros = { version = "0.2.0", path = "../copper-proc-macros", optional = true } +pistonite-cu-proc-macros = { version = "0.2.1", path = "../copper-proc-macros" } anyhow = "1.0.100" log = "0.4.28" @@ -72,7 +72,7 @@ full = [ ] # Command Line Interface (enables integration with `clap` and command line entry points) -cli = ["dep:clap", "dep:pistonite-cu-proc-macros", "print"] +cli = ["dep:clap", "print"] print = ["dep:regex", "dep:env_filter", "dep:terminal_size", "dep:unicode-width"] # Utils to show prompt for user input in terminal prompt = ["print"] @@ -96,7 +96,7 @@ fs = [ "tokio?/fs" ] # Enable parsing utils -parse = ["dep:pistonite-cu-proc-macros"] +parse = [] serde = ["dep:serde"] json = ["parse", "serde", "dep:serde_json"] json-preserve-order = ["json", "serde_json/preserve_order"] diff --git a/packages/copper/src/lib.rs b/packages/copper/src/lib.rs index bc31ff3..18c48b2 100644 --- a/packages/copper/src/lib.rs +++ b/packages/copper/src/lib.rs @@ -120,6 +120,7 @@ pub use misc::*; // re-exports from libraries pub use anyhow::{Context, Error, Ok, Result, anyhow as fmterr, bail, ensure}; pub use log::{debug, error, info, trace, warn}; +pub use pistonite_cu_proc_macros::error_ctx; #[cfg(feature = "coroutine")] pub use tokio::{join, try_join}; diff --git a/packages/copper/tests/error_ctx.rs b/packages/copper/tests/error_ctx.rs new file mode 100644 index 0000000..f715eb9 --- /dev/null +++ b/packages/copper/tests/error_ctx.rs @@ -0,0 +1,112 @@ +use pistonite_cu as cu; + +#[test] +fn test_example1() { + let msg = format!("{:?}", example1(42).unwrap_err()); + assert_eq!( + msg, + r"failed with arg 42 + +Caused by: + example1" + ) +} + +#[cu::error_ctx("failed with arg {arg}")] +fn example1(arg: u32) -> cu::Result<()> { + cu::bail!("example1") +} + +#[test] +fn test_example2() { + let msg = format!("{:?}", example2("hello".to_string()).unwrap_err()); + assert_eq!( + msg, + r"failed with arg hello + +Caused by: + example2: hello" + ) +} + +// 'pre' is needed because s is moved into the function +// so the error message needs to be formatted before running the function +#[cu::error_ctx(pre, format("failed with arg {s}"))] +fn example2(s: String) -> cu::Result<()> { + cu::bail!("example2: {s}") +} + +#[tokio::test] +async fn test_example3_err() { + let msg = format!("{:?}", example3(4).await.unwrap_err()); + assert_eq!( + msg, + r"async failed with arg 4 + +Caused by: + Condition failed: `value > 4` (4 vs 4)" + ) +} + +#[tokio::test] +async fn test_example3_ok() { + assert!(example3(42).await.is_ok()) +} + +// question mark works as expected (context is added at return time) +#[cu::error_ctx("async failed with arg {}", s)] +async fn example3(s: u32) -> cu::Result<()> { + let value = returns_ok(s)?; + cu::ensure!(value > 4); + Ok(()) +} + +#[tokio::test] +async fn test_example4_err() { + let msg = format!("{:?}", example4("".to_string()).await.unwrap_err()); + assert_eq!( + msg, + r"async failed with arg + +Caused by: + Condition failed: `!value.is_empty()`" + ) +} + +#[tokio::test] +async fn test_example4_ok() { + assert!(example4("hello".to_string()).await.is_ok()) +} + +// question mark works as expected (context is added at return time) +#[cu::error_ctx(pre, format("async failed with arg {s}"))] +async fn example4(s: String) -> cu::Result<()> { + let value = returns_ok(s)?; + cu::ensure!(!value.is_empty()); + Ok(()) +} + +#[test] +fn test_example5() { + let msg = format!("{:?}", Foo(7).example5().unwrap_err()); + assert_eq!( + msg, + r"Foo failed with arg 7 + +Caused by: + example5" + ) +} + +// associated functions also work +struct Foo(u32); +impl Foo { + #[cu::error_ctx("Foo failed with arg {}", self.0)] + fn example5(&self) -> cu::Result<()> { + cu::bail!("example5") + } +} + +fn returns_ok(t: T) -> cu::Result { + Ok(t) +} diff --git a/packages/promethium/Cargo.toml b/packages/promethium/Cargo.toml index c3b01c2..a41aad1 100644 --- a/packages/promethium/Cargo.toml +++ b/packages/promethium/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pistonite-pm" -version = "0.2.0" +version = "0.2.1" edition = "2024" description = "Procedural Macro Common Utils" repository = "https://github.com/Pistonite/cu" diff --git a/packages/promethium/Taskfile.yml b/packages/promethium/Taskfile.yml index 3752a60..ed9be82 100644 --- a/packages/promethium/Taskfile.yml +++ b/packages/promethium/Taskfile.yml @@ -18,3 +18,8 @@ tasks: cmds: - cmd: cargo publish ignore_error: true + test: + cmds: + - cargo test + - cargo test --features full + - cargo test --no-default-features --features full diff --git a/packages/promethium/src/util.rs b/packages/promethium/src/util.rs index 99b9371..0b97d89 100644 --- a/packages/promethium/src/util.rs +++ b/packages/promethium/src/util.rs @@ -22,6 +22,7 @@ pub fn flatten(result: syn::Result) -> crate::TokenStream } /// Convenience wrapper for parsing punctuated syntax +#[cfg(feature = "proc-macro")] pub fn parse_punctuated( input: crate::TokenStream, ) -> syn::Result> {