Skip to content

Commit

Permalink
Reauthenticate SASL connections based on session lifetime
Browse files Browse the repository at this point in the history
The broker response to a SASL authentication request can contain a
maximum session lifetime (see the [KIP][kip]).
Session lifetime is returned by the broker in [Version 1
SaslAuthenticate Response][sasl_authenticate_protocol].

When a SASL authentication callback returns `{ok, ServerResponse}` and
the ServerResponse contains a larger than 0 session lifetime,
kpro_connection automatically sets a timer to re-authenticate in half
the session lifetime.

As kpro_sasl mechanisms are synchronous, in-flight requests must first
be drained to ensure that kpro_sasl receives a response to its own SASL
request.

The draining algorithm behaves as follows:
* `sasl_authenticate` message handler adds the message onto the backlog
  and immediately flush the backlog if there are no in-flight requests.
* `{From, {send, Request}}` handler adds the request onto the backlog if
  the backlog has any items to allow in-flight requests to drain.
* Inbound message handler flushes the backlog if in-flight requests are
  empty.

[kip]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-368%3A+Allow+SASL+Connections+to+Periodically+Re-Authenticate
[sasl_authenticate_protocol]: https://kafka.apache.org/protocol#The_Messages_SaslAuthenticate
  • Loading branch information
urmastalimaa committed Aug 13, 2024
1 parent ca0bf49 commit 9f1e8cf
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 32 deletions.
108 changes: 76 additions & 32 deletions src/kpro_connection.erl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
, req_timeout :: ?undef | timeout()
, api_vsns :: ?undef | kpro:vsn_ranges()
, requests :: ?undef | requests()
, backlog :: false | queue:queue()
}).

-type state() :: #state{}.
Expand Down Expand Up @@ -225,11 +226,12 @@ connect(Parent, Host, Port, Config) ->
SockOpts = [{active, false}, binary] ++ get_extra_sock_opts(Config),
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, config = Config
, sock = Sock
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, config = Config
, sock = Sock
, backlog = false
},
init_connection(State, Config, Deadline);
{error, Reason} ->
Expand Down Expand Up @@ -402,7 +404,8 @@ handle_msg({_, Sock, Bin}, #state{ sock = Sock
Rsp = kpro_rsp_lib:decode(API, Vsn, Body, Ref),
ok = cast(Caller, {msg, self(), Rsp}),
NewRequests = kpro_sent_reqs:del(Requests, CorrId),
?MODULE:loop(State#state{requests = NewRequests}, Debug);
State1 = maybe_flush_backlog(State#state{requests = NewRequests}),
?MODULE:loop(State1, Debug);
handle_msg(assert_max_req_age, #state{ requests = Requests
, req_timeout = ReqTimeout
} = State, Debug) ->
Expand All @@ -420,12 +423,41 @@ handle_msg({tcp_error, Sock, Reason}, #state{sock = Sock}, _) ->
exit({tcp_error, Reason});
handle_msg({ssl_error, Sock, Reason}, #state{sock = Sock}, _) ->
exit({ssl_error, Reason});
handle_msg({From, {send, Request}},
#state{ client_id = ClientId
, mod = Mod
, sock = Sock
, requests = Requests
} = State, Debug) ->
handle_msg({_From, {send, _}} = Msg, #state{backlog = false} = State, Debug) ->
State1 = send_request(Msg, State),
?MODULE:loop(State1, Debug);
handle_msg({_From, {send, _}} = Msg, #state{backlog = Q} = State, Debug) ->
%% Avoid sending new requests until in-flight requests have been resolved
State1 = State#state{backlog = queue:in(Msg, Q)},
?MODULE:loop(State1, Debug);
handle_msg({From, get_api_vsns}, State, Debug) ->
maybe_reply(From, {ok, State#state.api_vsns}),
?MODULE:loop(State, Debug);
handle_msg({From, get_endpoint}, State, Debug) ->
maybe_reply(From, {ok, State#state.remote}),
?MODULE:loop(State, Debug);
handle_msg({From, get_tcp_sock}, State, Debug) ->
maybe_reply(From, {ok, State#state.sock}),
?MODULE:loop(State, Debug);
handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) ->
Mod:close(Sock),
maybe_reply(From, ok),
ok;
handle_msg(sasl_authenticate, State, Debug) ->
State1 = State#state{backlog = queue:from_list([sasl_authenticate])},
State2 = maybe_flush_backlog(State1),
?MODULE:loop(State2, Debug);
handle_msg(Msg, #state{} = State, Debug) ->
error_logger:warning_msg("[~p] ~p got unrecognized message: ~p",
[?MODULE, self(), Msg]),
?MODULE:loop(State, Debug).

send_request({From, {send, Request}},
#state{ client_id = ClientId
, mod = Mod
, sock = Sock
, requests = Requests
} = State) ->
{Caller, _Ref} = From,
#kpro_req{api = API, vsn = Vsn} = Request,
{CorrId, NewRequests} =
Expand All @@ -451,24 +483,25 @@ handle_msg({From, {send, Request}},
],
exit({send_error, Reason})
end,
?MODULE:loop(State#state{requests = NewRequests}, Debug);
handle_msg({From, get_api_vsns}, State, Debug) ->
maybe_reply(From, {ok, State#state.api_vsns}),
?MODULE:loop(State, Debug);
handle_msg({From, get_endpoint}, State, Debug) ->
maybe_reply(From, {ok, State#state.remote}),
?MODULE:loop(State, Debug);
handle_msg({From, get_tcp_sock}, State, Debug) ->
maybe_reply(From, {ok, State#state.sock}),
?MODULE:loop(State, Debug);
handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) ->
Mod:close(Sock),
maybe_reply(From, ok),
ok;
handle_msg(Msg, #state{} = State, Debug) ->
error_logger:warning_msg("[~p] ~p got unrecognized message: ~p",
[?MODULE, self(), Msg]),
?MODULE:loop(State, Debug).
State#state{requests = NewRequests}.

maybe_flush_backlog(#state{backlog = false} = State) ->
State;
maybe_flush_backlog(#state{requests = Requests, backlog = Backlog} = State) ->
case kpro_sent_reqs:is_empty(Requests) of
true ->
NewState = case queue:out(Backlog) of
{{value, sasl_authenticate}, RemainingBacklog} ->
sasl_authenticate(State#state{backlog = RemainingBacklog});
{{value, {_From, {send, _}} = Msg}, RemainingBacklog} ->
send_request(Msg, State#state{backlog = RemainingBacklog});
{empty, _} ->
State#state{backlog = false}
end,
maybe_flush_backlog(NewState);
false ->
State
end.

sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config} = State) ->
Timeout = get_connect_timeout(Config),
Expand All @@ -483,8 +516,17 @@ sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote =
timeout(Deadline), SaslOpts, HandshakeVsn) of
ok ->
ok;
{ok, _ServerResponse} ->
ok
{ok, ServerResponse} ->
case find(session_lifetime_ms, ServerResponse) of
Lifetime when is_integer(Lifetime) andalso Lifetime > 0 ->
%% Broker can report back a maximal session lifetime: https://kafka.apache.org/protocol#The_Messages_SaslAuthenticate.
%% Respect the session lifetime by draining in-flight requests and re-authenticating in half the time.
ReauthenticationDeadline = Lifetime div 2,
_ = erlang:send_after(ReauthenticationDeadline, self(), sasl_authenticate),
ok;
_ ->
ok
end
end,
ok = setopts(Sock, Mod, [{active, once}]),
State.
Expand Down Expand Up @@ -515,6 +557,8 @@ print_msg(Device, {_From, {send, Request}}, State) ->
do_print_msg(Device, "send: ~p", [Request], State);
print_msg(Device, {_From, {get_api_vsns, Request}}, State) ->
do_print_msg(Device, "get_api_vsns", [Request], State);
print_msg(Device, sasl_authenticate, State) ->
do_print_msg(Device, "sasl_authenticate", [], State);
print_msg(Device, {tcp, _Sock, Bin}, State) ->
do_print_msg(Device, "tcp: ~p", [Bin], State);
print_msg(Device, {ssl, _Sock, Bin}, State) ->
Expand Down
4 changes: 4 additions & 0 deletions src/kpro_sent_reqs.erl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
, get_corr_id/1
, increment_corr_id/1
, scan_for_max_age/1
, is_empty/1
]).

-export_type([requests/0]).
Expand All @@ -56,6 +57,9 @@
-spec new() -> requests().
new() -> #requests{}.

-spec is_empty(requests()) -> boolean().
is_empty(#requests{sent = Sent}) -> maps:size(Sent) == 0.

%% @doc Add a new request to sent collection.
%% Return the last corrlation ID and the new collection.
-spec add(requests(), pid(), reference(), kpro:api(), kpro:vsn()) ->
Expand Down
29 changes: 29 additions & 0 deletions test/kpro_connection_tests.erl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
-include_lib("eunit/include/eunit.hrl").
-include("kpro_private.hrl").

-export([ auth/7 ]).

plaintext_test() ->
Config = kpro_test_lib:connection_config(plaintext),
{ok, Pid} = connect(Config),
Expand Down Expand Up @@ -57,6 +59,33 @@ sasl_file_test() ->
ok = kpro_connection:stop(Pid)
end.

% SASL callback implementation for subsequent tests
auth(_Host, _Sock, _Vsn, _Mod, _ClientName, _Timeout, #{test_pid := TestPid} = SaslOpts) ->
case SaslOpts of
#{response_session_lifetime_ms := ResponseSessionLifeTimeMs} ->
TestPid ! sasl_authenticated,
{ok, #{session_lifetime_ms => ResponseSessionLifeTimeMs}};
_ ->
ok
end.

sasl_callback_test() ->
Config0 = kpro_test_lib:connection_config(sasl_ssl),
case kpro_test_lib:get_kafka_version() of
?KAFKA_0_9 ->
ok;
_ ->
Config = Config0#{sasl => {callback, ?MODULE, #{response_session_lifetime_ms => 51, test_pid => self()}}},
{ok, Pid} = connect(Config),

% initial authentication
receive sasl_authenticated -> ok end,
% repeated authentication as session expires
receive sasl_authenticated -> ok end,

ok = kpro_connection:stop(Pid)
end.

no_api_version_query_test() ->
Config = #{query_api_versions => false},
{ok, Pid} = connect(Config),
Expand Down

0 comments on commit 9f1e8cf

Please sign in to comment.