Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ http-tls = ["http", "reqwest/default-tls"]
http-native-tls = ["http", "reqwest/native-tls"]
http-rustls-tls = ["http", "reqwest/rustls-tls"]
signing = ["secp256k1", "once_cell"]
ws-tokio = ["soketto", "url", "tokio", "tokio-util", "headers"]
ws-async-std = ["soketto", "url", "async-std", "headers"]
ws-tokio = ["soketto", "url", "tokio", "tokio-util", "headers", "tokio-stream"]
ws-async-std = ["soketto", "url", "async-std", "headers", "tokio-stream"]
ws-tls-tokio = ["async-native-tls", "async-native-tls/runtime-tokio", "ws-tokio"]
ws-rustls-tokio = ["tokio-rustls", "webpki-roots", "rustls-pki-types", "ws-tokio"]
ws-tls-async-std = ["async-native-tls", "async-native-tls/runtime-async-std", "ws-async-std"]
Expand Down
14 changes: 13 additions & 1 deletion src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use futures::{
Future,
};
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use serde::{de::DeserializeOwned, Deserialize};
use std::{marker::PhantomData, pin::Pin};

/// Takes any type which is deserializable from rpc::Value and such a value and
Expand Down Expand Up @@ -87,6 +87,18 @@ where
}
}

/// Extract the response id from slice. Used to obtain the response id if the deserialization of the whole response fails,
/// workraround for https://github.com/tomusdrw/rust-web3/issues/566
pub fn response_id_from_slice(response: &[u8]) -> Option<rpc::Id> {
#[derive(Deserialize)]
struct JustId {
id: rpc::Id,
}

let value: JustId = serde_json::from_slice(response).ok()?;
Some(value.id)
}

/// Parse bytes slice into JSON-RPC notification.
pub fn to_notification_from_slice(notification: &[u8]) -> error::Result<rpc::Notification> {
serde_json::from_slice(notification).map_err(|e| error::Error::InvalidResponse(format!("{:?}", e)))
Expand Down
148 changes: 111 additions & 37 deletions src/transports/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
};
use futures::{
channel::{mpsc, oneshot},
future,
task::{Context, Poll},
AsyncRead, AsyncWrite, Future, FutureExt, Stream, StreamExt,
};
Expand All @@ -17,10 +18,12 @@
};
use std::{
collections::BTreeMap,
convert::TryInto,
fmt,
marker::Unpin,
pin::Pin,
sync::{atomic, Arc},
time::{Duration, Instant},
};
use url::Url;

Expand All @@ -41,6 +44,8 @@
type Pending = oneshot::Sender<BatchResult>;
type Subscription = mpsc::UnboundedSender<rpc::Value>;

const PING_PONG_INTERVAL: Duration = Duration::from_secs(20);

/// Stream, either plain TCP or TLS.
enum MaybeTlsStream<P, T> {
/// Unencrypted socket stream.
Expand Down Expand Up @@ -95,6 +100,27 @@
subscriptions: BTreeMap<SubscriptionId, Subscription>,
sender: connection::Sender<MaybeTlsStream<TcpStream, TlsStream>>,
receiver: connection::Receiver<MaybeTlsStream<TcpStream, TlsStream>>,
ping_pong_interval: Option<Duration>,
}

#[cfg(target_arch = "wasm32")]
fn interval_stream(interval: Option<Duration>) -> impl Stream<Item = Instant> {
if interval.is_some() {
log::warn!("Ignoring the ping pong interval, feature unsupported on wasm32");
}
future::pending().into_stream()
}

#[cfg(not(target_arch = "wasm32"))]
fn interval_stream(interval: Option<Duration>) -> impl Stream<Item = Instant> {
use tokio::time;

Check failure on line 116 in src/transports/ws.rs

View workflow job for this annotation

GitHub Actions / Check Transports

unresolved import `tokio`
use tokio_stream::wrappers::IntervalStream;
if let Some(interval) = interval {
let interval = time::interval(interval);
IntervalStream::new(interval).map(|instant| instant.into_std()).boxed()
} else {
future::pending().into_stream().boxed()
}
}

impl WsServerTask {
Expand Down Expand Up @@ -196,6 +222,7 @@
subscriptions: Default::default(),
sender,
receiver,
ping_pong_interval: PING_PONG_INTERVAL.into(),
})
}

Expand All @@ -205,8 +232,11 @@
mut sender,
mut pending,
mut subscriptions,
ping_pong_interval,
} = self;

let mut ping_pong_interval = interval_stream(ping_pong_interval);

let receiver = as_data_stream(receiver).fuse();
let requests = requests.fuse();
pin_mut!(receiver);
Expand Down Expand Up @@ -248,6 +278,13 @@
},
None => break,
},
_ = ping_pong_interval.next().fuse() => {
log::trace!("Pinging the WS connection");
let data = [].as_slice().try_into().unwrap();
if let Err(e) = sender.send_ping(data).await {
log::error!("Sending ping failed: {}", e);
}
}
complete => break,
}
}
Expand Down Expand Up @@ -292,56 +329,93 @@
})
}

enum ParsedMessage {
/// Represents a JSON-RPC notification
Notification(rpc::Notification),
/// Represents a valid JSON-RPC response
Response(rpc::Response),
/// Represents an invalid JSON-RPC response
InvalidResponse(rpc::Id),
}

fn parse_message(data: &[u8]) -> Option<ParsedMessage> {
if let Ok(notification) = helpers::to_notification_from_slice(data) {
Some(ParsedMessage::Notification(notification))
} else if let Ok(response) = helpers::to_response_from_slice(data) {
Some(ParsedMessage::Response(response))
} else if let Some(id) = helpers::response_id_from_slice(data) {
Some(ParsedMessage::InvalidResponse(id))
} else {
None
}
}

fn respond(id: rpc::Id, outputs: Result<Vec<rpc::Output>, Error>, pending: &mut BTreeMap<RequestId, Pending>) {
if let rpc::Id::Num(num) = id {
if let Some(request) = pending.remove(&(num as usize)) {
log::trace!("Responding to (id: {:?}) with {:?}", num, outputs);
let response = outputs.and_then(helpers::to_results_from_outputs);
if let Err(err) = request.send(response) {
log::warn!("Sending a response to deallocated channel: {:?}", err);
}
} else {
log::warn!("Got response for unknown request (id: {:?})", num);
}
} else {
log::warn!("Got unsupported response (id: {:?})", id);
}
}

fn handle_message(
data: &[u8],
subscriptions: &BTreeMap<SubscriptionId, Subscription>,
pending: &mut BTreeMap<RequestId, Pending>,
) {
log::trace!("Message received: {:?}", data);
if let Ok(notification) = helpers::to_notification_from_slice(data) {
if let rpc::Params::Map(params) = notification.params {
let id = params.get("subscription");
let result = params.get("result");

if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
let id: SubscriptionId = id.clone().into();
if let Some(stream) = subscriptions.get(&id) {
if let Err(e) = stream.unbounded_send(result.clone()) {
log::error!("Error sending notification: {:?} (id: {:?}", e, id);
log::trace!("Message received: {:?}", String::from_utf8_lossy(data));
match parse_message(data) {
Some(ParsedMessage::Notification(notification)) => {
if let rpc::Params::Map(params) = notification.params {
let id = params.get("subscription");
let result = params.get("result");
log::debug!("params={:#?}", params);

if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
let id: SubscriptionId = id.clone().into();
log::debug!("subscriptions={:#?}", subscriptions);

if let Some(stream) = subscriptions.get(&id) {
if let Err(e) = stream.unbounded_send(result.clone()) {
log::error!("Error sending notification: {:?} (id: {:?}", e, id);
}
} else {
log::warn!("Got notification for unknown subscription (id: {:?})", id);
}
} else {
log::warn!("Got notification for unknown subscription (id: {:?})", id);
log::error!("Got unsupported notification (id: {:?})", id);
}
} else {
log::error!("Got unsupported notification (id: {:?})", id);
}
}
} else {
let response = helpers::to_response_from_slice(data);
let outputs = match response {
Ok(rpc::Response::Single(output)) => vec![output],
Ok(rpc::Response::Batch(outputs)) => outputs,
_ => vec![],
};
Some(ParsedMessage::Response(response)) => {
let outputs = match response {
rpc::Response::Single(output) => vec![output],
rpc::Response::Batch(outputs) => outputs,
};

let id = match outputs.get(0) {
Some(&rpc::Output::Success(ref success)) => success.id.clone(),
Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(),
None => rpc::Id::Num(0),
};
let id = match outputs.get(0).unwrap() {
&rpc::Output::Success(ref success) => success.id.clone(),
&rpc::Output::Failure(ref failure) => failure.id.clone(),
};

if let rpc::Id::Num(num) = id {
if let Some(request) = pending.remove(&(num as usize)) {
log::trace!("Responding to (id: {:?}) with {:?}", num, outputs);
if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) {
log::warn!("Sending a response to deallocated channel: {:?}", err);
}
} else {
log::warn!("Got response for unknown request (id: {:?})", num);
}
} else {
log::warn!("Got unsupported response (id: {:?})", id);
respond(id, Ok(outputs), pending);
}
Some(ParsedMessage::InvalidResponse(id)) => {
let error = Error::Decoder(String::from_utf8_lossy(data).to_string());
respond(id, Err(error), pending);
}
None => log::warn!(
"Got invalid response, which could not be parsed: {}",
String::from_utf8_lossy(data)
),
}
}

Expand Down
Loading