// The pieces below live in separate files of the `validator` crate; inline
// modules stand in for those files so the wiring between them stays visible.

// --- validator/src/traits.rs ---
pub mod traits {
    use std::borrow::Cow;

    /// Length of a string measured in UTF-16 code units.
    ///
    /// UTF-16 is the string representation used by JavaScript and Java, so this
    /// is the length those runtimes report. It is *not* the number of Unicode
    /// scalar values: a code point outside the Basic Multilingual Plane is
    /// encoded as a surrogate pair and therefore counts as 2.
    pub trait HasLenUTF16 {
        /// Returns the number of UTF-16 code units needed to encode `self`.
        fn length_utf16(&self) -> u64;
    }

    impl HasLenUTF16 for String {
        fn length_utf16(&self) -> u64 {
            // `count()` yields `usize`; widening to `u64` matches the trait.
            self.encode_utf16().count() as u64
        }
    }

    impl<'a> HasLenUTF16 for &'a String {
        fn length_utf16(&self) -> u64 {
            self.encode_utf16().count() as u64
        }
    }

    impl<'a> HasLenUTF16 for &'a str {
        fn length_utf16(&self) -> u64 {
            self.encode_utf16().count() as u64
        }
    }

    impl<'a> HasLenUTF16 for Cow<'a, str> {
        fn length_utf16(&self) -> u64 {
            self.encode_utf16().count() as u64
        }
    }
}

// --- validator/src/validation/mod.rs ---
pub mod validation {
    // --- validator/src/validation/length_utf16.rs ---
    pub mod length_utf16 {
        use crate::traits::HasLenUTF16;

        /// Validates the length of `value`, measured in UTF-16 code units.
        ///
        /// * `equal` - when set, the length must match it exactly and `min` /
        ///   `max` are ignored.
        /// * `min` / `max` - inclusive bounds; a bound that is `None` is not
        ///   checked.
        ///
        /// Returns `true` when the value satisfies the constraints (including
        /// the case where no constraint is given).
        ///
        /// The UTF-16 length can differ from both the byte length and the
        /// number of visible characters: code points outside the Basic
        /// Multilingual Plane count as two units.
        #[must_use]
        pub fn validate_length_utf16<T: HasLenUTF16>(
            value: T,
            min: Option<u64>,
            max: Option<u64>,
            equal: Option<u64>,
        ) -> bool {
            let val_length = value.length_utf16();

            match equal {
                // `equal` takes precedence over any range.
                Some(eq) => val_length == eq,
                None => {
                    min.map_or(true, |m| val_length >= m) && max.map_or(true, |m| val_length <= m)
                }
            }
        }

        #[cfg(test)]
        mod tests {
            use std::borrow::Cow;

            use super::validate_length_utf16;

            #[test]
            fn test_validate_length_equal_overrides_min_max() {
                assert!(validate_length_utf16("hello", Some(1), Some(2), Some(5)));
            }

            #[test]
            fn test_validate_length_string_min_max() {
                assert!(validate_length_utf16("hello", Some(1), Some(10), None));
            }

            #[test]
            fn test_validate_length_string_min_only() {
                assert!(!validate_length_utf16("hello", Some(10), None, None));
            }

            #[test]
            fn test_validate_length_string_max_only() {
                assert!(!validate_length_utf16("hello", None, Some(1), None));
            }

            #[test]
            fn test_validate_length_cow() {
                let test: Cow<'static, str> = "hello".into();
                assert!(validate_length_utf16(test, None, None, Some(5)));

                let test: Cow<'static, str> = String::from("hello").into();
                assert!(validate_length_utf16(test, None, None, Some(5)));
            }

            #[test]
            fn test_validate_length_unicode_chars() {
                // One code point outside the BMP => one surrogate pair => 2 units.
                assert!(validate_length_utf16("𝔠", None, None, Some(2)));
            }
        }
    }
}

// --- validator/src/lib.rs ---
pub use validation::length_utf16::validate_length_utf16;
e6bfa660..a5bf5483 100644 --- a/validator/tests/display.rs +++ b/validator/tests/display.rs @@ -15,6 +15,19 @@ mod tests { assert_eq!(err, "foo: Please provide a valid foo!"); } + #[derive(Validate, Clone)] + struct FooUTF16 { + #[validate(length_utf16(equal = 5, message = "Please provide a valid utf16 foo!"))] + foo: String, + } + + #[test] + fn test_message() { + let bad_foo = FooUTF16 { foo: "hi!".into() }; + let err = format!("{}", bad_foo.validate().unwrap_err()); + assert_eq!(err, "foo: Please provide a valid utf16 foo!"); + } + #[derive(Validate)] struct Bar { #[validate] diff --git a/validator_derive/src/asserts.rs b/validator_derive/src/asserts.rs index 6c1b8b30..f02cf62f 100644 --- a/validator_derive/src/asserts.rs +++ b/validator_derive/src/asserts.rs @@ -110,6 +110,30 @@ pub fn assert_has_len(field_name: String, type_name: &str, field_type: &syn::Typ } } +pub fn assert_has_len_utf16(field_name: String, type_name: &str, field_type: &syn::Type) { + if let syn::Type::Reference(ref tref) = field_type { + let elem = &tref.elem; + let type_name = format!("{}", quote::quote! 
{ #elem }).replace(' ', ""); + + if type_name == "str" { + return; + } + assert_has_len_utf16(field_name, &type_name, elem); + return; + } + + if !type_name.contains("String") + && !type_name.contains("str") + // a bit ugly + && !COW_TYPE.is_match(type_name) + { + abort!(field_type.span(), + "Validator `length` can only be used on types `String`, `&str`, Cow<'_,str> types but found `{}` for field `{}`", + type_name, field_name + ); + } +} + pub fn assert_has_range(field_name: String, type_name: &str, field_type: &syn::Type) { if !NUMBER_TYPES.contains(&type_name) { abort!( diff --git a/validator_derive/src/lib.rs b/validator_derive/src/lib.rs index fdecfbd1..64e19cfd 100644 --- a/validator_derive/src/lib.rs +++ b/validator_derive/src/lib.rs @@ -9,7 +9,9 @@ use quote::ToTokens; use quote::{quote, quote_spanned}; use syn::{parse_quote, spanned::Spanned, GenericParam, Lifetime, LifetimeDef, Type}; -use asserts::{assert_has_len, assert_has_range, assert_string_type, assert_type_matches}; +use asserts::{ + assert_has_len, assert_has_len_utf16, assert_has_range, assert_string_type, assert_type_matches, +}; use lit::*; use quoting::{quote_schema_validations, quote_validator, FieldQuoter}; use validation::*; @@ -513,6 +515,18 @@ fn find_validators_for_field( &meta_items, )); } + "length_utf16" => { + assert_has_len_utf16( + rust_ident.clone(), + field_type, + &field.ty, + ); + validators.push(extract_length_utf16_validation( + rust_ident.clone(), + attr, + &meta_items, + )); + } "range" => { assert_has_range(rust_ident.clone(), field_type, &field.ty); validators.push(extract_range_validation( diff --git a/validator_derive/src/quoting.rs b/validator_derive/src/quoting.rs index 03e40a31..701c0bef 100644 --- a/validator_derive/src/quoting.rs +++ b/validator_derive/src/quoting.rs @@ -223,6 +223,66 @@ pub fn quote_length_validation( unreachable!() } +pub fn quote_length_utf16_validation( + field_quoter: &FieldQuoter, + validation: &FieldValidation, +) -> 
proc_macro2::TokenStream { + let field_name = &field_quoter.name; + let validator_param = field_quoter.quote_validator_param(); + + if let Validator::LengthUTF16 { min, max, equal } = &validation.validator { + let min_err_param_quoted = if let Some(v) = min { + let v = value_or_path_to_tokens(v); + quote!(err.add_param(::std::borrow::Cow::from("min"), &#v);) + } else { + quote!() + }; + let max_err_param_quoted = if let Some(v) = max { + let v = value_or_path_to_tokens(v); + quote!(err.add_param(::std::borrow::Cow::from("max"), &#v);) + } else { + quote!() + }; + let equal_err_param_quoted = if let Some(v) = equal { + let v = value_or_path_to_tokens(v); + quote!(err.add_param(::std::borrow::Cow::from("equal"), &#v);) + } else { + quote!() + }; + + let min_tokens = option_to_tokens( + &min.clone().as_ref().map(value_or_path_to_tokens).map(|x| quote!(#x as u64)), + ); + let max_tokens = option_to_tokens( + &max.clone().as_ref().map(value_or_path_to_tokens).map(|x| quote!(#x as u64)), + ); + let equal_tokens = option_to_tokens( + &equal.clone().as_ref().map(value_or_path_to_tokens).map(|x| quote!(#x as u64)), + ); + + let quoted_error = quote_error(validation); + let quoted = quote!( + if !::validator::validate_length_utf16( + #validator_param, + #min_tokens, + #max_tokens, + #equal_tokens + ) { + #quoted_error + #min_err_param_quoted + #max_err_param_quoted + #equal_err_param_quoted + err.add_param(::std::borrow::Cow::from("value"), &#validator_param); + errors.add(#field_name, err); + } + ); + + return field_quoter.wrap_if_option(quoted); + } + + unreachable!() +} + pub fn quote_range_validation( field_quoter: &FieldQuoter, validation: &FieldValidation, @@ -505,6 +565,9 @@ pub fn quote_validator( Validator::Length { .. } => { validations.push(quote_length_validation(field_quoter, validation)) } + Validator::LengthUTF16 { .. } => { + validations.push(quote_length_utf16_validation(field_quoter, validation)) + } Validator::Range { .. 
} => { validations.push(quote_range_validation(field_quoter, validation)) } diff --git a/validator_derive/src/validation.rs b/validator_derive/src/validation.rs index b889ccad..c3dfd4c8 100644 --- a/validator_derive/src/validation.rs +++ b/validator_derive/src/validation.rs @@ -124,6 +124,81 @@ pub fn extract_length_validation( } } +pub fn extract_length_utf16_validation( + field: String, + attr: &syn::Attribute, + meta_items: &[syn::NestedMeta], +) -> FieldValidation { + let mut min = None; + let mut max = None; + let mut equal = None; + + let (message, code) = extract_message_and_code("length_utf16", &field, meta_items); + + let error = |span: Span, msg: &str| -> ! { + abort!(span, "Invalid attribute #[validate] on field `{}`: {}", field, msg); + }; + + for meta_item in meta_items { + if let syn::NestedMeta::Meta(ref item) = *meta_item { + if let syn::Meta::NameValue(syn::MetaNameValue { ref path, ref lit, .. }) = *item { + let ident = path.get_ident().unwrap(); + match ident.to_string().as_ref() { + "message" | "code" => continue, + "min" => { + min = match lit_to_u64_or_path(lit) { + Some(s) => Some(s), + None => error(lit.span(), "invalid argument type for `min` of `length_utf16` validator: only number literals or value paths are allowed"), + }; + } + "max" => { + max = match lit_to_u64_or_path(lit) { + Some(s) => Some(s), + None => error(lit.span(), "invalid argument type for `max` of `length_utf16` validator: only number literals or value paths are allowed"), + }; + } + "equal" => { + equal = match lit_to_u64_or_path(lit) { + Some(s) => Some(s), + None => error(lit.span(), "invalid argument type for `equal` of `length_utf16` validator: only number literals or value paths are allowed"), + }; + } + v => error(path.span(), &format!( + "unknown argument `{}` for validator `length_utf16` (it only has `min`, `max`, `equal`)", + v + )) + } + } else { + error( + item.span(), + &format!( + "unexpected item {:?} while parsing `length_utf16` validator of field {}", + 
item, field + ), + ) + } + } + + if equal.is_some() && (min.is_some() || max.is_some()) { + error(meta_item.span(), "both `equal` and `min` or `max` have been set in `length_utf16` validator: probably a mistake"); + } + } + + if min.is_none() && max.is_none() && equal.is_none() { + error( + attr.span(), + "Validator `length_utf16` requires at least 1 argument out of `min`, `max` and `equal`", + ); + } + + let validator = Validator::LengthUTF16 { min, max, equal }; + FieldValidation { + message, + code: code.unwrap_or_else(|| validator.code().to_string()), + validator, + } +} + pub fn extract_range_validation( field: String, attr: &syn::Attribute, diff --git a/validator_derive_tests/tests/length_utf16.rs b/validator_derive_tests/tests/length_utf16.rs new file mode 100644 index 00000000..a67c4b92 --- /dev/null +++ b/validator_derive_tests/tests/length_utf16.rs @@ -0,0 +1,127 @@ +use validator::Validate; + +const MIN_CONST: u64 = 1; +const MAX_CONST: u64 = 10; + +const MAX_CONST_I32: i32 = 2; +const NEGATIVE_CONST_I32: i32 = -10; + +#[test] +fn can_validate_length_utf16_ok() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(min = 5, max = 10))] + val: String, + } + + let s = TestStruct { val: "hello".to_string() }; + + assert!(s.validate().is_ok()); +} + +#[test] +fn validate_length_utf16_with_ref_ok() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(min = "MIN_CONST", max = "MAX_CONST"))] + val: String, + } + + let s = TestStruct { val: "hello".to_string() }; + + assert!(s.validate().is_ok()); +} + +#[test] +fn validate_length_utf16_with_ref_fails() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(min = "MIN_CONST", max = "MAX_CONST"))] + val: String, + } + + let s = TestStruct { val: "".to_string() }; + + assert!(s.validate().is_err()); +} + +#[test] +fn validate_length_utf16_with_ref_i32_fails() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(max 
= "MAX_CONST_I32"))] + val: String, + } + + let s = TestStruct { val: "TO_LONG_YAY".to_string() }; + + assert!(s.validate().is_err()); +} + +#[test] +fn validate_length_utf16_with_ref_negative_i32_fails() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(max = "NEGATIVE_CONST_I32"))] + val: String, + } + + let s = TestStruct { val: "TO_LONG_YAY".to_string() }; + + assert!(s.validate().is_ok()); +} + +#[test] +fn value_out_of_length_utf16_fails_validation() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(min = 5, max = 10))] + val: String, + } + + let s = TestStruct { val: String::new() }; + let res = s.validate(); + assert!(res.is_err()); + let err = res.unwrap_err(); + let errs = err.field_errors(); + assert!(errs.contains_key("val")); + assert_eq!(errs["val"].len(), 1); + assert_eq!(errs["val"][0].code, "length_utf16"); + assert_eq!(errs["val"][0].params["value"], ""); + assert_eq!(errs["val"][0].params["min"], 5); + assert_eq!(errs["val"][0].params["max"], 10); +} + +#[test] +fn can_specify_code_for_length_utf16() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(min = 5, max = 10, code = "oops"))] + val: String, + } + let s = TestStruct { val: String::new() }; + let res = s.validate(); + assert!(res.is_err()); + let err = res.unwrap_err(); + let errs = err.field_errors(); + assert!(errs.contains_key("val")); + assert_eq!(errs["val"].len(), 1); + assert_eq!(errs["val"][0].code, "oops"); +} + +#[test] +fn can_specify_message_for_length_utf16() { + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length_utf16(min = 5, max = 10, message = "oops"))] + val: String, + } + let s = TestStruct { val: String::new() }; + let res = s.validate(); + assert!(res.is_err()); + let err = res.unwrap_err(); + let errs = err.field_errors(); + assert!(errs.contains_key("val")); + assert_eq!(errs["val"].len(), 1); + assert_eq!(errs["val"][0].clone().message.unwrap(), "oops"); +} diff 
--git a/validator_types/src/lib.rs b/validator_types/src/lib.rs index 42dbdaff..abdb4cd1 100644 --- a/validator_types/src/lib.rs +++ b/validator_types/src/lib.rs @@ -31,6 +31,12 @@ pub enum Validator { max: Option>, equal: Option>, }, + // string value validated against length in UTF-16 characters (string repr in JavaScript, Java) + LengthUTF16 { + min: Option>, + max: Option>, + equal: Option>, + }, #[cfg(feature = "card")] CreditCard, #[cfg(feature = "phone")] @@ -81,6 +87,7 @@ impl Validator { Validator::Regex(_) => "regex", Validator::Range { .. } => "range", Validator::Length { .. } => "length", + Validator::LengthUTF16 { .. } => "length_utf16", #[cfg(feature = "card")] Validator::CreditCard => "credit_card", #[cfg(feature = "phone")]