Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jose-jws/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let jws_json = serde_json::json!({
]
});

let Jws::General(jws) = serde_json::from_value(jws_json).unwrap() else {
let Jws::General(jws) = serde_json::from_value::<Jws>(jws_json).unwrap() else {
panic!("couldn't deserialize JWS");
};

Expand Down
17 changes: 13 additions & 4 deletions jose-jws/src/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,32 @@ use jose_b64::stream::Error;

use crate::{Flattened, General, Jws, Signature};

impl FromStr for Jws {
impl<U, P> FromStr for Jws<U, P>
where
P: serde::de::DeserializeOwned,
{
type Err = Error<serde_json::Error>;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Flattened::from_str(s)?.into())
}
}

impl FromStr for General {
impl<U, P> FromStr for General<U, P>
where
P: serde::de::DeserializeOwned,
{
type Err = Error<serde_json::Error>;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Flattened::from_str(s)?.into())
}
}

impl FromStr for Flattened {
impl<U, P> FromStr for Flattened<U, P>
where
P: serde::de::DeserializeOwned,
{
type Err = Error<serde_json::Error>;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Expand Down Expand Up @@ -54,7 +63,7 @@ impl FromStr for Flattened {
}
}

impl Display for Flattened {
impl<U, P> Display for Flattened<U, P> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let mut prot = alloc::string::String::new();
if let Some(x) = self.signature.protected.as_ref() {
Expand Down
35 changes: 18 additions & 17 deletions jose-jws/src/crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@ use rand_core::RngCore;
use crate::{Flattened, General, Jws, Protected, Signature, Unprotected};

/// Signature creation state
pub trait Signer: Update {
pub trait Signer<U = Unprotected, P = Protected<U>>: Update {
#[allow(missing_docs)]
type FinishError: From<Self::Error>;

/// Finish processing payload and create the signature.
fn finish(self, rng: impl 'static + RngCore) -> Result<Signature, Self::FinishError>;
fn finish(self, rng: impl 'static + RngCore) -> Result<Signature<U, P>, Self::FinishError>;
}

/// A signature creation key
pub trait SigningKey<'a> {
pub trait SigningKey<'a, U = Unprotected, P = Protected<U>> {
#[allow(missing_docs)]
type StartError: From<<Self::Signer as Update>::Error>;

/// The state object used during signing.
type Signer: Signer;
type Signer: Signer<U, P>;

/// Begin the signature creation process.
fn sign(
&'a self,
prot: Option<Protected>,
head: Option<Unprotected>,
) -> Result<Self::Signer, Self::StartError>;
fn sign(&'a self, prot: Option<P>, head: Option<U>) -> Result<Self::Signer, Self::StartError>;
}

/// Signature verification state
Expand Down Expand Up @@ -98,26 +94,31 @@ where
}
}

impl<'a, T: VerifyingKey<'a, &'a Signature>> VerifyingKey<'a, &'a Flattened> for T
impl<'a, T, U, P> VerifyingKey<'a, &'a Flattened<U, P>> for T
where
T: VerifyingKey<'a, &'a Signature<U, P>>,
<T::Verifier as Verifier<'a>>::FinishError: Default,
{
type StartError = T::StartError;
type Verifier = Vec<T::Verifier>;

fn verify(&'a self, flattened: &'a Flattened) -> Result<Self::Verifier, Self::StartError> {
fn verify(
&'a self,
flattened: &'a Flattened<U, P>,
) -> Result<Self::Verifier, Self::StartError> {
Ok(vec![self.verify(&flattened.signature)?])
}
}

impl<'a, T: VerifyingKey<'a, &'a Signature>> VerifyingKey<'a, &'a General> for T
impl<'a, T, U, P> VerifyingKey<'a, &'a General<U, P>> for T
where
T: VerifyingKey<'a, &'a Signature<U, P>>,
<T::Verifier as Verifier<'a>>::FinishError: Default,
{
type StartError = T::StartError;
type Verifier = Vec<T::Verifier>;

fn verify(&'a self, general: &'a General) -> Result<Self::Verifier, Self::StartError> {
fn verify(&'a self, general: &'a General<U, P>) -> Result<Self::Verifier, Self::StartError> {
general
.signatures
.iter()
Expand All @@ -126,17 +127,17 @@ where
}
}

impl<'a, T, V, E> VerifyingKey<'a, &'a Jws> for T
impl<'a, T, V, E, U, P> VerifyingKey<'a, &'a Jws<U, P>> for T
where
T: VerifyingKey<'a, &'a Flattened, Verifier = V, StartError = E>,
T: VerifyingKey<'a, &'a General, Verifier = V, StartError = E>,
T: VerifyingKey<'a, &'a Flattened<U, P>, Verifier = V, StartError = E>,
T: VerifyingKey<'a, &'a General<U, P>, Verifier = V, StartError = E>,
E: From<V::Error>,
V: Verifier<'a>,
{
type StartError = E;
type Verifier = V;

fn verify(&'a self, jws: &'a Jws) -> Result<Self::Verifier, Self::StartError> {
fn verify(&'a self, jws: &'a Jws<U, P>) -> Result<Self::Verifier, Self::StartError> {
match jws {
Jws::General(general) => self.verify(general),
Jws::Flattened(flattened) => self.verify(flattened),
Expand Down
23 changes: 19 additions & 4 deletions jose-jws/src/head.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use alloc::vec::Vec;
use alloc::{boxed::Box, string::String};
use core::ops::{Deref, DerefMut};

use jose_b64::base64ct::Base64;
use jose_b64::serde::Bytes;
Expand All @@ -22,7 +23,7 @@ fn b64_serialize(value: &bool) -> bool {

/// The JWS Protected Header
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Protected {
pub struct Protected<U = Unprotected> {
/// RFC 7517 Section 4.1.11
#[serde(skip_serializing_if = "Option::is_none", default)]
pub crit: Option<Vec<String>>,
Expand All @@ -37,20 +38,34 @@ pub struct Protected {

/// Other values that may appear in the protected header.
#[serde(flatten)]
pub oth: Unprotected,
pub oth: U,
}

impl Default for Protected {
impl<U: Default> Default for Protected<U> {
fn default() -> Self {
Self {
crit: None,
nonce: None,
b64: true,
oth: Unprotected::default(),
oth: U::default(),
}
}
}

impl<U> Deref for Protected<U> {
type Target = U;

fn deref(&self) -> &Self::Target {
&self.oth
}
}

impl<U> DerefMut for Protected<U> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.oth
}
}

/// The JWS Unprotected Header
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Unprotected {
Expand Down
36 changes: 20 additions & 16 deletions jose-jws/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,24 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
#[allow(clippy::large_enum_variant)]
#[serde(bound(deserialize = "U: Deserialize<'de>, P: serde::de::DeserializeOwned"))]
#[serde(untagged)]
pub enum Jws {
pub enum Jws<U = Unprotected, P = Protected<U>> {
/// General Serialization. This is
General(General),
General(General<U, P>),

/// Flattened Serialization
Flattened(Flattened),
Flattened(Flattened<U, P>),
}

impl From<General> for Jws {
fn from(value: General) -> Self {
impl<U, P> From<General<U, P>> for Jws<U, P> {
fn from(value: General<U, P>) -> Self {
Jws::General(value)
}
}

impl From<Flattened> for Jws {
fn from(value: Flattened) -> Self {
impl<U, P> From<Flattened<U, P>> for Jws<U, P> {
fn from(value: Flattened<U, P>) -> Self {
Jws::Flattened(value)
}
}
Expand All @@ -77,16 +78,17 @@ impl From<Flattened> for Jws {
/// }
/// ```
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct General {
#[serde(bound(deserialize = "U: Deserialize<'de>, P: serde::de::DeserializeOwned"))]
pub struct General<U = Unprotected, P = Protected<U>> {
/// The payload of the signature.
pub payload: Option<Bytes>,

/// The signatures over the payload.
pub signatures: Vec<Signature>,
pub signatures: Vec<Signature<U, P>>,
}

impl From<Flattened> for General {
fn from(value: Flattened) -> Self {
impl<U, P> From<Flattened<U, P>> for General<U, P> {
fn from(value: Flattened<U, P>) -> Self {
Self {
payload: value.payload,
signatures: vec![value.signature],
Expand All @@ -108,23 +110,25 @@ impl From<Flattened> for General {
/// }
/// ```
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Flattened {
#[serde(bound(deserialize = "U: Deserialize<'de>, P: serde::de::DeserializeOwned"))]
pub struct Flattened<U = Unprotected, P = Protected<U>> {
/// The payload of the signature.
pub payload: Option<Bytes>,

/// The signature over the payload.
#[serde(flatten)]
pub signature: Signature,
pub signature: Signature<U, P>,
}

/// A Signature
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Signature {
#[serde(bound(deserialize = "U: Deserialize<'de>, P: serde::de::DeserializeOwned"))]
pub struct Signature<U = Unprotected, P = Protected<U>> {
/// The JWS Unprotected Header
pub header: Option<Unprotected>,
pub header: Option<U>,

/// The JWS Protected Header
pub protected: Option<Json<Protected>>,
pub protected: Option<Json<P>>,

/// The Signature Bytes
pub signature: Bytes,
Expand Down