diff --git a/include/escalus_xmlns.hrl b/include/escalus_xmlns.hrl index a920062..0b766a9 100644 --- a/include/escalus_xmlns.hrl +++ b/include/escalus_xmlns.hrl @@ -12,6 +12,9 @@ % Defined by XMPP Core (RFC 3920). -define(NS_XMPP, <<"http://etherx.jabber.org/streams">>). +% Defined by XMPP Subprotocol for WebSocket (RFC 7395). +-define(NS_FRAMING, <<"urn:ietf:params:xml:ns:xmpp-framing">>). + -define(NS_STREAM_ERRORS, <<"urn:ietf:params:xml:ns:xmpp-streams">>). -define(NS_TLS, <<"urn:ietf:params:xml:ns:xmpp-tls">>). -define(NS_SASL, <<"urn:ietf:params:xml:ns:xmpp-sasl">>). diff --git a/src/escalus_bosh.erl b/src/escalus_bosh.erl index a49a35e..ec3b851 100644 --- a/src/escalus_bosh.erl +++ b/src/escalus_bosh.erl @@ -141,8 +141,8 @@ set_filter_predicate(Pid, Pred) -> -spec stream_start_req(escalus_users:user_spec()) -> exml_stream:element(). stream_start_req(Props) -> {server, Server} = lists:keyfind(server, 1, Props), - NS = proplists:get_value(stream_ns, Props, <<"jabber:client">>), - escalus_stanza:stream_start(Server, NS). + Attrs = proplists:get_value(stream_attrs, Props, #{}), + escalus_stanza:stream_start(Server, Attrs). -spec stream_end_req(_) -> exml_stream:element(). stream_end_req(_) -> diff --git a/src/escalus_stanza.erl b/src/escalus_stanza.erl index a2bdf03..73b751a 100644 --- a/src/escalus_stanza.erl +++ b/src/escalus_stanza.erl @@ -105,10 +105,13 @@ sm_ack/1, resume/2]). --export([stream_start/2, +-export([stream_start/1, + stream_start/2, stream_end/0, ws_open/1, + ws_open/2, ws_close/0, + ws_close/1, starttls/0, compress/1]). @@ -144,34 +147,51 @@ -define(i2l(I), erlang:integer_to_list(I)). -define(io2b(IOList), erlang:iolist_to_binary(IOList)). +-type attrs() :: #{binary() => binary() | undefined}. + %%-------------------------------------------------------------------- %% Stream - related functions %%-------------------------------------------------------------------- --spec stream_start(binary(), binary()) -> exml_stream:start(). -stream_start(Server, XMLNS) -> +-spec stream_start(binary() | undefined) -> exml_stream:start(). +stream_start(To) -> + stream_start(To, #{}). + +-spec stream_start(binary() | undefined, attrs()) -> exml_stream:start(). +stream_start(To, ExtraAttrs) -> + BasicAttrs = #{<<"to">> => To, + <<"version">> => <<"1.0">>, + <<"xml:lang">> => <<"en">>, + <<"xmlns">> => ?NS_JABBER_CLIENT, + <<"xmlns:stream">> => ?NS_XMPP}, #xmlstreamstart{name = <<"stream:stream">>, - attrs = #{<<"to">> => Server, - <<"version">> => <<"1.0">>, - <<"xml:lang">> => <<"en">>, - <<"xmlns">> => XMLNS, - <<"xmlns:stream">> => <<"http://etherx.jabber.org/streams">>}}. + attrs = skip_undefined(maps:merge(BasicAttrs, ExtraAttrs))}. -spec stream_end() -> exml_stream:stop(). stream_end() -> #xmlstreamend{name = <<"stream:stream">>}. --spec ws_open(binary()) -> exml:element(). -ws_open(Server) -> - #xmlel{name= <<"open">>, - attrs = #{<<"xmlns">> => <<"urn:ietf:params:xml:ns:xmpp-framing">>, - <<"to">> => Server, - <<"version">> => <<"1.0">>}}. +-spec ws_open(binary() | undefined) -> exml:element(). +ws_open(To) -> + ws_open(To, #{}). + +-spec ws_open(binary() | undefined, attrs()) -> exml:element(). +ws_open(To, ExtraAttrs) -> + BasicAttrs = #{<<"to">> => To, + <<"version">> => <<"1.0">>, + <<"xmlns">> => ?NS_FRAMING}, + #xmlel{name = <<"open">>, + attrs = skip_undefined(maps:merge(BasicAttrs, ExtraAttrs))}. -spec ws_close() -> exml:element(). ws_close() -> - #xmlel{name= <<"close">>, - attrs = #{<<"xmlns">> => <<"urn:ietf:params:xml:ns:xmpp-framing">>}}. + ws_close(#{}). + +-spec ws_close(attrs()) -> exml:element(). +ws_close(ExtraAttrs) -> + BasicAttrs = #{<<"xmlns">> => ?NS_FRAMING}, + #xmlel{name = <<"close">>, + attrs = skip_undefined(maps:merge(BasicAttrs, ExtraAttrs))}. -spec starttls() -> exml:element(). starttls() -> @@ -1011,3 +1031,6 @@ argument_to_string(E) when is_binary(E) -> ?b2l(E); argument_to_string(E) when is_list(E) -> E; argument_to_string(I) when is_integer(I) -> ?i2l(I); argument_to_string(F) when is_float(F) -> io_lib:format("~.2f", [F]). + +skip_undefined(Map) -> + maps:filter(fun(_Key, Value) -> Value =/= undefined end, Map). diff --git a/src/escalus_tcp.erl b/src/escalus_tcp.erl index 023cca3..067c98e 100644 --- a/src/escalus_tcp.erl +++ b/src/escalus_tcp.erl @@ -170,8 +170,8 @@ use_zlib(Pid) -> -spec stream_start_req(escalus_users:user_spec()) -> exml_stream:element(). stream_start_req(Props) -> {server, Server} = lists:keyfind(server, 1, Props), - NS = proplists:get_value(stream_ns, Props, <<"jabber:client">>), - escalus_stanza:stream_start(Server, NS). + Attrs = proplists:get_value(stream_attrs, Props, #{}), + escalus_stanza:stream_start(Server, Attrs). -spec stream_end_req(_) -> exml_stream:element(). stream_end_req(_) -> diff --git a/src/escalus_users.erl b/src/escalus_users.erl index 51a6f38..093edbe 100644 --- a/src/escalus_users.erl +++ b/src/escalus_users.erl @@ -372,6 +372,7 @@ is_mod_register_enabled(Config) -> | 'connection_steps' %% [escalus_session:step()] | 'parser_opts' %% a list of exml parser opts, %% e.g. infinite_stream + | stream_attrs %% XML attributes for or open/close elements | received_stanza_handlers %% list of escalus_connection:stanza_handler() | sent_stanza_handlers %% similar as above but for sent stanzas . diff --git a/src/escalus_ws.erl b/src/escalus_ws.erl index 16b0681..2db3ceb 100644 --- a/src/escalus_ws.erl +++ b/src/escalus_ws.erl @@ -128,19 +128,22 @@ use_zlib(Pid) -> -spec stream_start_req(escalus_users:user_spec()) -> exml_stream:element(). stream_start_req(Props) -> {server, Server} = lists:keyfind(server, 1, Props), + Attrs = proplists:get_value(stream_attrs, Props, #{}), case proplists:get_value(wslegacy, Props, false) of true -> - NS = proplists:get_value(stream_ns, Props, <<"jabber:client">>), - escalus_stanza:stream_start(Server, NS); + escalus_stanza:stream_start(Server, Attrs); false -> - escalus_stanza:ws_open(Server) + escalus_stanza:ws_open(Server, Attrs) end. -spec stream_end_req(_) -> exml_stream:element(). stream_end_req(Props) -> case proplists:get_value(wslegacy, Props, false) of - true -> escalus_stanza:stream_end(); - false -> escalus_stanza:ws_close() + true -> + escalus_stanza:stream_end(); + false -> + Attrs = maps:with([<<"xmlns">>], proplists:get_value(stream_attrs, Props, #{})), + escalus_stanza:ws_close(Attrs) end. -spec assert_stream_start(exml_stream:element(), _) -> exml_stream:element().