Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: Rust

on:
push:
branches: ["main"]
pull_request:

jobs:
format:
name: Format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: moonrepo/setup-rust@v1
with:
components: rustfmt
- name: Check formatting
run: >-
cargo fmt --all --check
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: moonrepo/setup-rust@v1
with:
components: clippy
- name: Run linter
run: >-
cargo clippy --all-features
test:
name: Test
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: moonrepo/setup-rust@v1
- name: Run tests
run: >-
cargo test --no-default-features
- name: Run tests (all features)
run: >-
cargo test --all-features
24 changes: 11 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{
Attribute, Ident, ItemFn, ItemImpl, ItemTrait, Meta, ReturnType, Signature, Token, TraitItem,
TraitItemFn, Type, TypeImplTrait, TypeParamBound, WherePredicate, ImplItem, ImplItemFn,
Attribute, FnArg, GenericArgument, Ident, ImplItem, ImplItemFn, ItemFn, ItemImpl, ItemTrait,
Meta, PathArguments, ReturnType, Signature, Token, TraitItem, TraitItemFn, Type, TypeImplTrait,
TypeParamBound, WherePredicate,
parse::{Parse, ParseStream},
parse_quote, FnArg, PathArguments, GenericArgument,
parse_quote,
};

/// Whether to bound an `async fn` or its receiver by [`Send`] or [`Sync`].
Expand Down Expand Up @@ -207,17 +208,17 @@ impl DesugarAsync for ImplItemFn {
if self.sig.asyncness.is_some() {
// Store the original body
let body = &self.block;

// Transform the signature
self.sig.desugar_async(config);

// Wrap the body in an async block
self.block = parse_quote! {
{
async move #body
}
};

// Add #[must_use] attribute to async methods
self.attrs.push(parse_quote! { #[must_use] });
// Add lint suppression
Expand Down Expand Up @@ -267,7 +268,7 @@ impl DesugarAsync for TraitItemFn {
}
};
self.attrs.push(lint_attr);

// Transform default method body if present
if let Some(block) = &mut self.default {
let transformed = quote! {
Expand Down Expand Up @@ -296,15 +297,14 @@ impl DesugarAsync for Signature {
// Build the Future bounds
let mut bounds: Vec<TypeParamBound> =
vec![parse_quote! { std::future::Future<Output = #output_type> }];

// Check receiver type to determine bounds
let receiver_bounds = analyze_receiver(&self.inputs);

if config.send || receiver_bounds.needs_send {
bounds.push(parse_quote! { Send });
}


// Create the new return type
let impl_trait = TypeImplTrait {
impl_token: syn::token::Impl::default(),
Expand Down Expand Up @@ -368,7 +368,7 @@ fn analyze_receiver(inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>)
}
}
}

ReceiverBounds {
needs_send: false,
needs_sync: false,
Expand All @@ -389,5 +389,3 @@ fn add_self_sync_bound(sig: &mut Signature) {
.predicates
.push(sync_bound);
}


45 changes: 27 additions & 18 deletions tests/default_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@ use bitte::bitte;
trait AsyncTraitWithDefaults {
// Required method
async fn required_method(&self) -> String;

// Default async method
async fn default_method(&self) -> String {
"default implementation".to_string()
}

// Default async method that calls another method
async fn composed_default(&self) -> String {
let required = self.required_method().await;
format!("composed: {}", required)
}

// Default async method with parameters (takes ownership to avoid lifetime issues)
async fn default_with_params(&self, prefix: String) -> String {
format!("{}: default", prefix)
}

// Non-async default method (should remain unchanged)
fn sync_default(&self) -> &'static str {
"sync default"
Expand All @@ -34,7 +34,7 @@ impl AsyncTraitWithDefaults for CustomImpl {
async fn required_method(&self) -> String {
"custom implementation".to_string()
}

// Override one default method
async fn default_method(&self) -> String {
"overridden default".to_string()
Expand All @@ -58,29 +58,38 @@ mod tests {
#[tokio::test]
async fn test_custom_impl() {
let custom = CustomImpl;

assert_eq!(custom.required_method().await, "custom implementation");
assert_eq!(custom.default_method().await, "overridden default");
assert_eq!(custom.composed_default().await, "composed: custom implementation");
assert_eq!(custom.default_with_params("test".to_string()).await, "test: default");
assert_eq!(
custom.composed_default().await,
"composed: custom implementation"
);
assert_eq!(
custom.default_with_params("test".to_string()).await,
"test: default"
);
assert_eq!(custom.sync_default(), "sync default");
}

#[tokio::test]
async fn test_minimal_impl() {
let minimal = MinimalImpl;

assert_eq!(minimal.required_method().await, "minimal");
assert_eq!(minimal.default_method().await, "default implementation");
assert_eq!(minimal.composed_default().await, "composed: minimal");
assert_eq!(minimal.default_with_params("hello".to_string()).await, "hello: default");
assert_eq!(
minimal.default_with_params("hello".to_string()).await,
"hello: default"
);
assert_eq!(minimal.sync_default(), "sync default");
}

#[tokio::test]
async fn test_future_is_send() {
fn assert_send<T: Send>(_: T) {}

let custom = CustomImpl;
assert_send(custom.required_method());
assert_send(custom.default_method());
Expand All @@ -92,8 +101,8 @@ mod tests {
#[bitte]
trait GenericAsyncWithDefaults<T: Send + Sync + 'static> {
async fn process(&self, value: T) -> T;
async fn process_twice(&self, value: T) -> T

async fn process_twice(&self, value: T) -> T
where
T: Clone,
{
Expand All @@ -118,11 +127,11 @@ mod generic_tests {
#[tokio::test]
async fn test_generic_defaults() {
let generic_impl = GenericImpl;

let result = generic_impl.process(42).await;
assert_eq!(result, 42);

let result = generic_impl.process_twice(10).await;
assert_eq!(result, 10);
}
}
}