diff --git a/Cargo.toml b/Cargo.toml index e861269f..45e9b38d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,8 +83,8 @@ http-tls = ["http", "reqwest/default-tls"] http-native-tls = ["http", "reqwest/native-tls"] http-rustls-tls = ["http", "reqwest/rustls-tls"] signing = ["secp256k1", "once_cell"] -ws-tokio = ["soketto", "url", "tokio", "tokio-util", "headers"] -ws-async-std = ["soketto", "url", "async-std", "headers"] +ws-tokio = ["soketto", "url", "tokio", "tokio-util", "headers", "tokio-stream"] +ws-async-std = ["soketto", "url", "async-std", "headers", "tokio-stream"] ws-tls-tokio = ["async-native-tls", "async-native-tls/runtime-tokio", "ws-tokio"] ws-rustls-tokio = ["tokio-rustls", "webpki-roots", "rustls-pki-types", "ws-tokio"] ws-tls-async-std = ["async-native-tls", "async-native-tls/runtime-async-std", "ws-async-std"] diff --git a/src/helpers.rs b/src/helpers.rs index b88a8710..45da34ad 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -6,7 +6,7 @@ use futures::{ Future, }; use pin_project::pin_project; -use serde::de::DeserializeOwned; +use serde::{de::DeserializeOwned, Deserialize}; use std::{marker::PhantomData, pin::Pin}; /// Takes any type which is deserializable from rpc::Value and such a value and @@ -87,6 +87,18 @@ where } } +/// Extract the response id from slice. Used to obtain the response id if the deserialization of the whole response fails, +/// workraround for https://github.com/tomusdrw/rust-web3/issues/566 +pub fn response_id_from_slice(response: &[u8]) -> Option { + #[derive(Deserialize)] + struct JustId { + id: rpc::Id, + } + + let value: JustId = serde_json::from_slice(response).ok()?; + Some(value.id) +} + /// Parse bytes slice into JSON-RPC notification. pub fn to_notification_from_slice(notification: &[u8]) -> error::Result { serde_json::from_slice(notification).map_err(|e| error::Error::InvalidResponse(format!("{:?}", e))) diff --git a/src/transports/ws.rs b/src/transports/ws.rs index 1e6cd19e..abd3636e 100644 --- a/src/transports/ws.rs +++ b/src/transports/ws.rs @@ -8,6 +8,7 @@ use crate::{ }; use futures::{ channel::{mpsc, oneshot}, + future, task::{Context, Poll}, AsyncRead, AsyncWrite, Future, FutureExt, Stream, StreamExt, }; @@ -17,10 +18,12 @@ use soketto::{ }; use std::{ collections::BTreeMap, + convert::TryInto, fmt, marker::Unpin, pin::Pin, sync::{atomic, Arc}, + time::{Duration, Instant}, }; use url::Url; @@ -41,6 +44,8 @@ type BatchResult = error::Result>; type Pending = oneshot::Sender; type Subscription = mpsc::UnboundedSender; +const PING_PONG_INTERVAL: Duration = Duration::from_secs(20); + /// Stream, either plain TCP or TLS. enum MaybeTlsStream { /// Unencrypted socket stream. @@ -95,6 +100,27 @@ struct WsServerTask { subscriptions: BTreeMap, sender: connection::Sender>, receiver: connection::Receiver>, + ping_pong_interval: Option, +} + +#[cfg(target_arch = "wasm32")] +fn interval_stream(interval: Option) -> impl Stream { + if interval.is_some() { + log::warn!("Ignoring the ping pong interval, feature unsupported on wasm32"); + } + future::pending().into_stream() +} + +#[cfg(not(target_arch = "wasm32"))] +fn interval_stream(interval: Option) -> impl Stream { + use tokio::time; + use tokio_stream::wrappers::IntervalStream; + if let Some(interval) = interval { + let interval = time::interval(interval); + IntervalStream::new(interval).map(|instant| instant.into_std()).boxed() + } else { + future::pending().into_stream().boxed() + } } impl WsServerTask { @@ -196,6 +222,7 @@ impl WsServerTask { subscriptions: Default::default(), sender, receiver, + ping_pong_interval: PING_PONG_INTERVAL.into(), }) } @@ -205,8 +232,11 @@ impl WsServerTask { mut sender, mut pending, mut subscriptions, + ping_pong_interval, } = self; + let mut ping_pong_interval = interval_stream(ping_pong_interval); + let receiver = as_data_stream(receiver).fuse(); let requests = requests.fuse(); pin_mut!(receiver); @@ -248,6 +278,13 @@ impl WsServerTask { }, None => break, }, + _ = ping_pong_interval.next().fuse() => { + log::trace!("Pinging the WS connection"); + let data = [].as_slice().try_into().unwrap(); + if let Err(e) = sender.send_ping(data).await { + log::error!("Sending ping failed: {}", e); + } + } complete => break, } } @@ -292,56 +329,93 @@ fn as_data_stream( }) } +enum ParsedMessage { + /// Represents a JSON-RPC notification + Notification(rpc::Notification), + /// Represents a valid JSON-RPC response + Response(rpc::Response), + /// Represents an invalid JSON-RPC response + InvalidResponse(rpc::Id), +} + +fn parse_message(data: &[u8]) -> Option { + if let Ok(notification) = helpers::to_notification_from_slice(data) { + Some(ParsedMessage::Notification(notification)) + } else if let Ok(response) = helpers::to_response_from_slice(data) { + Some(ParsedMessage::Response(response)) + } else if let Some(id) = helpers::response_id_from_slice(data) { + Some(ParsedMessage::InvalidResponse(id)) + } else { + None + } +} + +fn respond(id: rpc::Id, outputs: Result, Error>, pending: &mut BTreeMap) { + if let rpc::Id::Num(num) = id { + if let Some(request) = pending.remove(&(num as usize)) { + log::trace!("Responding to (id: {:?}) with {:?}", num, outputs); + let response = outputs.and_then(helpers::to_results_from_outputs); + if let Err(err) = request.send(response) { + log::warn!("Sending a response to deallocated channel: {:?}", err); + } + } else { + log::warn!("Got response for unknown request (id: {:?})", num); + } + } else { + log::warn!("Got unsupported response (id: {:?})", id); + } +} + fn handle_message( data: &[u8], subscriptions: &BTreeMap, pending: &mut BTreeMap, ) { - log::trace!("Message received: {:?}", data); - if let Ok(notification) = helpers::to_notification_from_slice(data) { - if let rpc::Params::Map(params) = notification.params { - let id = params.get("subscription"); - let result = params.get("result"); - - if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) { - let id: SubscriptionId = id.clone().into(); - if let Some(stream) = subscriptions.get(&id) { - if let Err(e) = stream.unbounded_send(result.clone()) { - log::error!("Error sending notification: {:?} (id: {:?}", e, id); + log::trace!("Message received: {:?}", String::from_utf8_lossy(data)); + match parse_message(data) { + Some(ParsedMessage::Notification(notification)) => { + if let rpc::Params::Map(params) = notification.params { + let id = params.get("subscription"); + let result = params.get("result"); + log::debug!("params={:#?}", params); + + if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) { + let id: SubscriptionId = id.clone().into(); + log::debug!("subscriptions={:#?}", subscriptions); + + if let Some(stream) = subscriptions.get(&id) { + if let Err(e) = stream.unbounded_send(result.clone()) { + log::error!("Error sending notification: {:?} (id: {:?}", e, id); + } + } else { + log::warn!("Got notification for unknown subscription (id: {:?})", id); } } else { - log::warn!("Got notification for unknown subscription (id: {:?})", id); + log::error!("Got unsupported notification (id: {:?})", id); } - } else { - log::error!("Got unsupported notification (id: {:?})", id); } } - } else { - let response = helpers::to_response_from_slice(data); - let outputs = match response { - Ok(rpc::Response::Single(output)) => vec![output], - Ok(rpc::Response::Batch(outputs)) => outputs, - _ => vec![], - }; + Some(ParsedMessage::Response(response)) => { + let outputs = match response { + rpc::Response::Single(output) => vec![output], + rpc::Response::Batch(outputs) => outputs, + }; - let id = match outputs.get(0) { - Some(&rpc::Output::Success(ref success)) => success.id.clone(), - Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(), - None => rpc::Id::Num(0), - }; + let id = match outputs.get(0).unwrap() { + &rpc::Output::Success(ref success) => success.id.clone(), + &rpc::Output::Failure(ref failure) => failure.id.clone(), + }; - if let rpc::Id::Num(num) = id { - if let Some(request) = pending.remove(&(num as usize)) { - log::trace!("Responding to (id: {:?}) with {:?}", num, outputs); - if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) { - log::warn!("Sending a response to deallocated channel: {:?}", err); - } - } else { - log::warn!("Got response for unknown request (id: {:?})", num); - } - } else { - log::warn!("Got unsupported response (id: {:?})", id); + respond(id, Ok(outputs), pending); + } + Some(ParsedMessage::InvalidResponse(id)) => { + let error = Error::Decoder(String::from_utf8_lossy(data).to_string()); + respond(id, Err(error), pending); } + None => log::warn!( + "Got invalid response, which could not be parsed: {}", + String::from_utf8_lossy(data) + ), } }