Skip to content

Commit

Permalink
Merge pull request #122 from salemove/reauthenticate_sasl_connections
Browse files Browse the repository at this point in the history
Proposal: Reauthenticate SASL connections based on session lifetime
  • Loading branch information
zmstone authored Aug 14, 2024
2 parents 8bece10 + 49ccc86 commit 95d0944
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 41 deletions.
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
* 4.1.7
- Automatically re-authenticate before session lifetime expires if SASL
authentication module returns `{ok, ServerResponse}` and ServerResponse
contains a non-zero `session_timeout_ms`.
https://github.com/kafka4beam/kafka_protocol/pull/122

* 4.1.6
- Fix docs. PR #120

Expand Down
10 changes: 6 additions & 4 deletions src/kpro_auth_backend.erl
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@

-export([auth/8]).

-type server_auth_response() :: term().

-callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.

-callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(),
HandShakeVsn :: non_neg_integer(), Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.

-optional_callbacks([auth/6]).

-spec auth(CallbackModule :: atom(), Host :: string(),
Sock :: gen_tcp:socket() | ssl:sslsocket(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.
auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) ->
CallbackModule:auth(Host, Sock, Mod, ClientName, Timeout, SaslOpts).

Expand All @@ -43,7 +45,7 @@ auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) ->
HandShakeVsn :: non_neg_integer(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.
auth(CallbackModule, Host, Sock, Vsn, Mod, ClientName, Timeout, SaslOpts) ->
case is_exported(CallbackModule, auth, 7) of
true ->
Expand Down
137 changes: 100 additions & 37 deletions src/kpro_connection.erl
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,14 @@

-record(state, { client_id :: client_id()
, parent :: pid()
, config :: config()
, remote :: kpro:endpoint()
, sock :: gen_tcp:socket() | ssl:sslsocket()
, mod :: ?undef | gen_tcp | ssl
, req_timeout :: ?undef | timeout()
, api_vsns :: ?undef | kpro:vsn_ranges()
, requests :: ?undef | requests()
, backlog :: false | queue:queue()
}).

-type state() :: #state{}.
Expand Down Expand Up @@ -226,10 +228,12 @@ connect(Parent, Host, Port, Config) ->
SockOpts = [{active, false}, binary] ++ get_extra_sock_opts(Config),
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, sock = Sock
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, config = Config
, sock = Sock
, backlog = false
},
init_connection(State, Config, Deadline);
{error, Reason} ->
Expand Down Expand Up @@ -260,14 +264,8 @@ init_connection(#state{ client_id = ClientId
#{query_api_versions := false} -> ?undef;
_ -> query_api_versions(NewSock, Mod, ClientId, Deadline)
end,
HandshakeVsn = case Versions of
#{sasl_handshake := {_, V}} -> V;
_ -> 0
end,
SaslOpts = get_sasl_opt(Config),
ok = kpro_sasl:auth(Host, NewSock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn),
State#state{mod = Mod, sock = NewSock, api_vsns = Versions}.
State1 = State#state{mod = Mod, sock = NewSock, api_vsns = Versions},
sasl_authenticate(State1).

query_api_versions(Sock, Mod, ClientId, Deadline) ->
Req = kpro_req_lib:make(api_versions, 0, []),
Expand Down Expand Up @@ -408,7 +406,8 @@ handle_msg({_, Sock, Bin}, #state{ sock = Sock
Rsp = kpro_rsp_lib:decode(API, Vsn, Body, Ref),
ok = cast(Caller, {msg, self(), Rsp}),
NewRequests = kpro_sent_reqs:del(Requests, CorrId),
?MODULE:loop(State#state{requests = NewRequests}, Debug);
State1 = maybe_flush_backlog(State#state{requests = NewRequests}),
?MODULE:loop(State1, Debug);
handle_msg(assert_max_req_age, #state{ requests = Requests
, req_timeout = ReqTimeout
} = State, Debug) ->
Expand All @@ -426,12 +425,41 @@ handle_msg({tcp_error, Sock, Reason}, #state{sock = Sock}, _) ->
exit({tcp_error, Reason});
handle_msg({ssl_error, Sock, Reason}, #state{sock = Sock}, _) ->
exit({ssl_error, Reason});
handle_msg({From, {send, Request}},
#state{ client_id = ClientId
, mod = Mod
, sock = Sock
, requests = Requests
} = State, Debug) ->
handle_msg({_From, {send, _}} = Msg, #state{backlog = false} = State, Debug) ->
State1 = send_request(Msg, State),
?MODULE:loop(State1, Debug);
handle_msg({_From, {send, _}} = Msg, #state{backlog = Q} = State, Debug) ->
%% Avoid sending new requests until in-flight requests have been resolved
State1 = State#state{backlog = queue:in(Msg, Q)},
?MODULE:loop(State1, Debug);
handle_msg({From, get_api_vsns}, State, Debug) ->
maybe_reply(From, {ok, State#state.api_vsns}),
?MODULE:loop(State, Debug);
handle_msg({From, get_endpoint}, State, Debug) ->
maybe_reply(From, {ok, State#state.remote}),
?MODULE:loop(State, Debug);
handle_msg({From, get_tcp_sock}, State, Debug) ->
maybe_reply(From, {ok, State#state.sock}),
?MODULE:loop(State, Debug);
handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) ->
Mod:close(Sock),
maybe_reply(From, ok),
ok;
handle_msg(sasl_authenticate, State, Debug) ->
State1 = State#state{backlog = queue:from_list([sasl_authenticate])},
State2 = maybe_flush_backlog(State1),
?MODULE:loop(State2, Debug);
handle_msg(Msg, #state{} = State, Debug) ->
error_logger:warning_msg("[~p] ~p got unrecognized message: ~p",
[?MODULE, self(), Msg]),
?MODULE:loop(State, Debug).

send_request({From, {send, Request}},
#state{ client_id = ClientId
, mod = Mod
, sock = Sock
, requests = Requests
} = State) ->
{Caller, _Ref} = From,
#kpro_req{api = API, vsn = Vsn} = Request,
{CorrId, NewRequests} =
Expand All @@ -457,24 +485,53 @@ handle_msg({From, {send, Request}},
],
exit({send_error, Reason})
end,
?MODULE:loop(State#state{requests = NewRequests}, Debug);
handle_msg({From, get_api_vsns}, State, Debug) ->
maybe_reply(From, {ok, State#state.api_vsns}),
?MODULE:loop(State, Debug);
handle_msg({From, get_endpoint}, State, Debug) ->
maybe_reply(From, {ok, State#state.remote}),
?MODULE:loop(State, Debug);
handle_msg({From, get_tcp_sock}, State, Debug) ->
maybe_reply(From, {ok, State#state.sock}),
?MODULE:loop(State, Debug);
handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) ->
Mod:close(Sock),
maybe_reply(From, ok),
ok;
handle_msg(Msg, #state{} = State, Debug) ->
error_logger:warning_msg("[~p] ~p got unrecognized message: ~p",
[?MODULE, self(), Msg]),
?MODULE:loop(State, Debug).
State#state{requests = NewRequests}.

maybe_flush_backlog(#state{backlog = false} = State) ->
State;
maybe_flush_backlog(#state{requests = Requests, backlog = Backlog} = State) ->
case kpro_sent_reqs:is_empty(Requests) of
true ->
NewState = case queue:out(Backlog) of
{{value, sasl_authenticate}, RemainingBacklog} ->
sasl_authenticate(State#state{backlog = RemainingBacklog});
{{value, {_From, {send, _}} = Msg}, RemainingBacklog} ->
send_request(Msg, State#state{backlog = RemainingBacklog});
{empty, _} ->
State#state{backlog = false}
end,
maybe_flush_backlog(NewState);
false ->
State
end.

sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config} = State) ->
Timeout = get_connect_timeout(Config),
Deadline = deadline(Timeout),
SaslOpts = get_sasl_opt(Config),
HandshakeVsn = case Versions of
#{sasl_handshake := {_, V}} -> V;
_ -> 0
end,
ok = setopts(Sock, Mod, [{active, false}]),
case kpro_sasl:auth(Host, Sock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn) of
ok ->
ok;
{ok, ServerResponse} ->
case find(session_lifetime_ms, ServerResponse) of
Lifetime when is_integer(Lifetime) andalso Lifetime > 0 ->
%% Broker can report back a maximal session lifetime: https://kafka.apache.org/protocol#The_Messages_SaslAuthenticate.
%% Respect the session lifetime by draining in-flight requests and re-authenticating in half the time.
ReauthenticationDeadline = Lifetime div 2,
_ = erlang:send_after(ReauthenticationDeadline, self(), sasl_authenticate),
ok;
_ ->
ok
end
end,
ok = setopts(Sock, Mod, [{active, once}]),
State.

cast(Pid, Msg) ->
try
Expand All @@ -500,8 +557,14 @@ format_status(Opt, Status) ->

print_msg(Device, {_From, {send, Request}}, State) ->
do_print_msg(Device, "send: ~p", [Request], State);
print_msg(Device, {_From, {get_api_vsns, Request}}, State) ->
do_print_msg(Device, "get_api_vsns", [Request], State);
print_msg(Device, sasl_authenticate, State) ->
do_print_msg(Device, "sasl_authenticate", [], State);
print_msg(Device, {tcp, _Sock, Bin}, State) ->
do_print_msg(Device, "tcp: ~p", [Bin], State);
print_msg(Device, {ssl, _Sock, Bin}, State) ->
do_print_msg(Device, "ssl: ~p", [Bin], State);
print_msg(Device, {tcp_closed, _Sock}, State) ->
do_print_msg(Device, "tcp_closed", [], State);
print_msg(Device, {tcp_error, _Sock, Reason}, State) ->
Expand Down
2 changes: 2 additions & 0 deletions src/kpro_sasl.erl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ auth(Host, Sock, Mod, ClientId, Timeout,
ClientId, Timeout, Opts) of
ok ->
ok;
{ok, ServerResponse} ->
{ok, ServerResponse};
{error, Reason} ->
?ERROR(Reason)
end;
Expand Down
4 changes: 4 additions & 0 deletions src/kpro_sent_reqs.erl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
, get_corr_id/1
, increment_corr_id/1
, scan_for_max_age/1
, is_empty/1
]).

-export_type([requests/0]).
Expand All @@ -56,6 +57,9 @@
-spec new() -> requests().
new() -> #requests{}.

-spec is_empty(requests()) -> boolean().
is_empty(#requests{sent = Sent}) -> maps:size(Sent) == 0.

%% @doc Add a new request to sent collection.
%% Return the last corrlation ID and the new collection.
-spec add(requests(), pid(), reference(), kpro:api(), kpro:vsn()) ->
Expand Down
29 changes: 29 additions & 0 deletions test/kpro_connection_tests.erl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
-include_lib("eunit/include/eunit.hrl").
-include("kpro_private.hrl").

-export([ auth/7 ]).

plaintext_test() ->
Config = kpro_test_lib:connection_config(plaintext),
{ok, Pid} = connect(Config),
Expand Down Expand Up @@ -57,6 +59,33 @@ sasl_file_test() ->
ok = kpro_connection:stop(Pid)
end.

% SASL callback implementation for subsequent tests
auth(_Host, _Sock, _Vsn, _Mod, _ClientName, _Timeout, #{test_pid := TestPid} = SaslOpts) ->
case SaslOpts of
#{response_session_lifetime_ms := ResponseSessionLifeTimeMs} ->
TestPid ! sasl_authenticated,
{ok, #{session_lifetime_ms => ResponseSessionLifeTimeMs}};
_ ->
ok
end.

sasl_callback_test() ->
Config0 = kpro_test_lib:connection_config(sasl_ssl),
case kpro_test_lib:get_kafka_version() of
?KAFKA_0_9 ->
ok;
_ ->
Config = Config0#{sasl => {callback, ?MODULE, #{response_session_lifetime_ms => 51, test_pid => self()}}},
{ok, Pid} = connect(Config),

% initial authentication
receive sasl_authenticated -> ok end,
% repeated authentication as session expires
receive sasl_authenticated -> ok end,

ok = kpro_connection:stop(Pid)
end.

no_api_version_query_test() ->
Config = #{query_api_versions => false},
{ok, Pid} = connect(Config),
Expand Down

0 comments on commit 95d0944

Please sign in to comment.