Skip to content

Commit

Permalink
Merge pull request #4077 from esl/c2s_features
Browse files Browse the repository at this point in the history
C2S features small optimisation
  • Loading branch information
chrzaszcz authored Aug 4, 2023
2 parents 017a1b5 + 4650328 commit b27f3bd
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 26 deletions.
29 changes: 16 additions & 13 deletions src/c2s/mongoose_c2s.erl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_info/1, set_info/2,
get_mod_state/2, get_listener_opts/1, merge_mod_state/2, remove_mod_state/2,
get_ip/1, get_socket/1, get_lang/1, get_stream_id/1, hook_arg/5]).
-export([filter_mechanism/2, c2s_stream_error/2, maybe_retry_state/1, merge_states/2]).
-export([get_auth_mechs/1, c2s_stream_error/2, maybe_retry_state/1, merge_states/2]).
-export([route/2, reroute_buffer/2, reroute_buffer_to_pid/3, open_session/1]).

-ignore_xref([get_ip/1, get_socket/1]).
Expand Down Expand Up @@ -260,16 +260,6 @@ close_socket(#c2s_data{socket = Socket}) ->
activate_socket(#c2s_data{socket = Socket}) ->
mongoose_c2s_socket:activate(Socket).

-spec filter_mechanism(data(), binary()) -> boolean().
filter_mechanism(#c2s_data{socket = Socket}, <<"SCRAM-SHA-1-PLUS">>) ->
mongoose_c2s_socket:is_channel_binding_supported(Socket);
filter_mechanism(#c2s_data{socket = Socket}, <<"SCRAM-SHA-", _N:3/binary, "-PLUS">>) ->
mongoose_c2s_socket:is_channel_binding_supported(Socket);
filter_mechanism(#c2s_data{socket = Socket, listener_opts = LOpts}, <<"EXTERNAL">>) ->
mongoose_c2s_socket:has_peer_cert(Socket, LOpts);
filter_mechanism(_, _) ->
true.

%%%----------------------------------------------------------------------
%%% error handler helpers
%%%----------------------------------------------------------------------
Expand Down Expand Up @@ -404,8 +394,7 @@ handle_auth_start(StateData, El, SaslState, Retries) ->
do_handle_auth_start(StateData, El, SaslState, Retries) ->
Mech = exml_query:attr(El, <<"mechanism">>),
ClientIn = base64:mime_decode(exml_query:cdata(El)),
HostType = StateData#c2s_data.host_type,
AuthMech = [M || M <- cyrsasl:listmech(HostType), filter_mechanism(StateData, M)],
AuthMech = get_auth_mechs(StateData),
SocketData = #{socket => StateData#c2s_data.socket, auth_mech => AuthMech,
listener_opts => StateData#c2s_data.listener_opts},
StepResult = cyrsasl:server_start(SaslState, Mech, ClientIn, SocketData),
Expand Down Expand Up @@ -1050,6 +1039,20 @@ cast(Pid, EventTag, EventContent) ->
create_data(#{host_type := HostType, jid := Jid}) ->
#c2s_data{host_type = HostType, jid = Jid}.

-spec get_auth_mechs(data()) -> [cyrsasl:mechanism()].
get_auth_mechs(#c2s_data{host_type = HostType} = StateData) ->
[M || M <- cyrsasl:listmech(HostType), filter_mechanism(StateData, M)].

-spec filter_mechanism(data(), binary()) -> boolean().
filter_mechanism(#c2s_data{socket = Socket}, <<"SCRAM-SHA-1-PLUS">>) ->
mongoose_c2s_socket:is_channel_binding_supported(Socket);
filter_mechanism(#c2s_data{socket = Socket}, <<"SCRAM-SHA-", _N:3/binary, "-PLUS">>) ->
mongoose_c2s_socket:is_channel_binding_supported(Socket);
filter_mechanism(#c2s_data{socket = Socket, listener_opts = LOpts}, <<"EXTERNAL">>) ->
mongoose_c2s_socket:has_peer_cert(Socket, LOpts);
filter_mechanism(_, _) ->
true.

-spec get_host_type(data()) -> mongooseim:host_type().
get_host_type(#c2s_data{host_type = HostType}) ->
HostType.
Expand Down
19 changes: 10 additions & 9 deletions src/c2s/mongoose_c2s_stanzas.erl
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,20 @@ stream_features_before_auth(HostType, LServer, LOpts, StateData) ->
determine_features(_, _, #{tls := #{mode := starttls_required}}, false, _StateData) ->
[starttls_stanza(required)];
determine_features(HostType, LServer, _, true, StateData) ->
mongoose_hooks:c2s_stream_features(HostType, LServer) ++ maybe_sasl_mechanisms(HostType, StateData);
InitialFeatures = maybe_sasl_mechanisms(StateData),
mongoose_hooks:c2s_stream_features(HostType, LServer, InitialFeatures);
determine_features(HostType, LServer, _, _, StateData) ->
[starttls_stanza(optional)
| mongoose_hooks:c2s_stream_features(HostType, LServer) ++ maybe_sasl_mechanisms(HostType, StateData)].
InitialFeatures = [starttls_stanza(optional) | maybe_sasl_mechanisms(StateData)],
mongoose_hooks:c2s_stream_features(HostType, LServer, InitialFeatures).

maybe_sasl_mechanisms(HostType, StateData) ->
case cyrsasl:listmech(HostType) of
-spec maybe_sasl_mechanisms(mongoose_c2s:data()) -> [exml:element()].
maybe_sasl_mechanisms(StateData) ->
case mongoose_c2s:get_auth_mechs(StateData) of
[] -> [];
Mechanisms ->
[#xmlel{name = <<"mechanisms">>,
attrs = [{<<"xmlns">>, ?NS_SASL}],
children = [ mechanism(M)
|| M <- Mechanisms, mongoose_c2s:filter_mechanism(StateData, M) ]}]
children = [ mechanism(M) || M <- Mechanisms ]}]
end.

-spec mechanism(binary()) -> exml:element().
Expand Down Expand Up @@ -103,8 +104,8 @@ stream_features_after_auth(HostType, LServer, #{backwards_compatible_session :=
stream_features(Features).

hook_enabled_features(HostType, LServer) ->
mongoose_hooks:roster_get_versioning_feature(HostType)
++ mongoose_hooks:c2s_stream_features(HostType, LServer).
InitialFeatures = mongoose_hooks:roster_get_versioning_feature(HostType),
mongoose_hooks:c2s_stream_features(HostType, LServer, InitialFeatures).

-spec sasl_success_stanza(binary()) -> exml:element().
sasl_success_stanza(ServerOut) ->
Expand Down
9 changes: 5 additions & 4 deletions src/hooks/mongoose_hooks.erl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

-export([get_pep_recipients/2,
filter_pep_recipient/3,
c2s_stream_features/2,
c2s_stream_features/3,
check_bl_c2s/1,
forbidden_session_hook/3,
session_opening_allowed_for_user/2]).
Expand Down Expand Up @@ -469,13 +469,14 @@ filter_pep_recipient(C2SData, Feature, To) ->
HostType = mongoose_c2s:get_host_type(C2SData),
run_hook_for_host_type(filter_pep_recipient, HostType, true, Params).

-spec c2s_stream_features(HostType, LServer) -> Result when
-spec c2s_stream_features(HostType, LServer, InitialFeatures) -> Result when
HostType :: mongooseim:host_type(),
LServer :: jid:lserver(),
InitialFeatures :: [exml:element()],
Result :: [exml:element()].
c2s_stream_features(HostType, LServer) ->
c2s_stream_features(HostType, LServer, InitialFeatures) ->
Params = #{lserver => LServer},
run_hook_for_host_type(c2s_stream_features, HostType, [], Params).
run_hook_for_host_type(c2s_stream_features, HostType, InitialFeatures, Params).

-spec check_bl_c2s(IP) -> Result when
IP :: inet:ip_address(),
Expand Down

0 comments on commit b27f3bd

Please sign in to comment.