Skip to content

Commit

Permalink
Merge pull request #2 from salemove/replace_reauthenticate_after_with…
Browse files Browse the repository at this point in the history
…_automatic_reauthentication

Replace reauthenticate after with automatic reauthentication
  • Loading branch information
urmastalimaa authored Aug 12, 2024
2 parents 7486e39 + ed065d0 commit 1624e79
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 69 deletions.
10 changes: 6 additions & 4 deletions src/kpro_auth_backend.erl
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@

-export([auth/8]).

-type server_auth_response() :: term().

-callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.

-callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(),
HandShakeVsn :: non_neg_integer(), Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.

-optional_callbacks([auth/6]).

-spec auth(CallbackModule :: atom(), Host :: string(),
Sock :: gen_tcp:socket() | ssl:sslsocket(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.
auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) ->
CallbackModule:auth(Host, Sock, Mod, ClientName, Timeout, SaslOpts).

Expand All @@ -43,7 +45,7 @@ auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) ->
HandShakeVsn :: non_neg_integer(),
Mod :: gen_tcp | ssl, ClientName :: binary(),
Timeout :: pos_integer(), SaslOpts :: term()) ->
ok | {error, Reason :: term()}.
ok | {ok, server_auth_response()} | {error, Reason :: term()}.
auth(CallbackModule, Host, Sock, Vsn, Mod, ClientName, Timeout, SaslOpts) ->
case is_exported(CallbackModule, auth, 7) of
true ->
Expand Down
112 changes: 62 additions & 50 deletions src/kpro_connection.erl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
, send/2
, start/3
, stop/1
, sasl_reauthenticate_after/2
, debug/2
]).

Expand Down Expand Up @@ -87,6 +86,7 @@
-type portnum() :: kpro:portnum().
-type client_id() :: kpro:client_id().
-type connection() :: pid().
-type post_drain_queue() :: list().

-define(undef, undefined).

Expand All @@ -99,6 +99,7 @@
, req_timeout :: ?undef | timeout()
, api_vsns :: ?undef | kpro:vsn_ranges()
, requests :: ?undef | requests()
, draining :: false | {drain | flush_queue, post_drain_queue()}
}).

-type state() :: #state{}.
Expand Down Expand Up @@ -163,30 +164,6 @@ stop(Pid) when is_pid(Pid) ->
stop(_) ->
ok.

%% @doc Reauthenticates SASL by repeating the authentication flow after given time.
%% The authentication flow is repeated only once and not periodically.
%% This would be called by an authentication adapter to reauthenticate before
%% session_lifetime_ms provided in the v1 SASL authentication response is
%% reached.
%% Example use case:
%% -module(my_custom_sasl_authentication).
%%
%% auth(Host, Sock, Vsn, Mod, ClientId, Timeout, Opts) ->
%% {ok, SaslResponse} = do_authenticate(...),
%% case session_lifetime_ms(SaslResponse) of
%% SessionLifetime when SessionLifetime > 0 ->
%% kpro_connection:sasl_reauthenticate_after(self(), a_bit_before(SessionLifetime));
%% _ ->
%% ok
%% end,
%% ok.
-spec sasl_reauthenticate_after(connection(), timeout()) -> ok.
sasl_reauthenticate_after(Pid, Time) when is_pid(Pid) ->
erlang:send_after(Time, Pid, sasl_reauthenticate),
ok;
sasl_reauthenticate_after(_, _) ->
ok.

-spec get_api_vsns(pid()) ->
{ok, ?undef | kpro:vsn_ranges()} | {error, any()}.
get_api_vsns(Pid) ->
Expand Down Expand Up @@ -250,11 +227,12 @@ connect(Parent, Host, Port, Config) ->
SockOpts = [{active, false}, binary] ++ get_extra_sock_opts(Config),
case gen_tcp:connect(Host, Port, SockOpts, Timeout) of
{ok, Sock} ->
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, config = Config
, sock = Sock
State = #state{ client_id = get_client_id(Config)
, parent = Parent
, remote = {Host, Port}
, config = Config
, sock = Sock
, draining = false
},
init_connection(State, Config, Deadline);
{error, Reason} ->
Expand Down Expand Up @@ -285,14 +263,8 @@ init_connection(#state{ client_id = ClientId
#{query_api_versions := false} -> ?undef;
_ -> query_api_versions(NewSock, Mod, ClientId, Deadline)
end,
HandshakeVsn = case Versions of
#{sasl_handshake := {_, V}} -> V;
_ -> 0
end,
SaslOpts = get_sasl_opt(Config),
ok = kpro_sasl:auth(Host, NewSock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn),
State#state{mod = Mod, sock = NewSock, api_vsns = Versions}.
State1 = State#state{mod = Mod, sock = NewSock, api_vsns = Versions},
sasl_authenticate(State1).

query_api_versions(Sock, Mod, ClientId, Deadline) ->
Req = kpro_req_lib:make(api_versions, 0, []),
Expand Down Expand Up @@ -411,9 +383,25 @@ maybe_reply({To, Ref}, Reply) ->
_ = erlang:send(To, {Ref, Reply}),
ok.

loop(#state{} = State, Debug) ->
loop(#state{draining = false} = State, Debug) ->
Msg = receive Input -> Input end,
decode_msg(Msg, State, Debug).
decode_msg(Msg, State, Debug);

loop(#state{draining = {_DrainState, DrainQueue}, requests = Requests} = State, Debug) ->
case kpro_sent_reqs:is_empty(Requests) of
true ->
case DrainQueue of
[Msg | Rest] ->
%% decode calls back to loop which recursively flushes Rest
decode_msg(Msg, State#state{draining = {flush_queue, Rest}}, Debug);
[] ->
%% Queue has been flushed, draining is complete
loop(State#state{draining = false}, Debug)
end;
false ->
Msg = receive Input -> Input end,
decode_msg(Msg, State, Debug)
end.

decode_msg({system, From, Msg}, #state{parent = Parent} = State, Debug) ->
sys:handle_system_msg(Msg, From, Parent, ?MODULE, Debug, State);
Expand Down Expand Up @@ -451,6 +439,9 @@ handle_msg({tcp_error, Sock, Reason}, #state{sock = Sock}, _) ->
exit({tcp_error, Reason});
handle_msg({ssl_error, Sock, Reason}, #state{sock = Sock}, _) ->
exit({ssl_error, Reason});
handle_msg({_From, {send, _}} = Msg, #state{ draining = {drain, _}} = State, Debug) ->
%% Avoid sending new requests until in-flight requests have been resolved
?MODULE:loop(drain_and_postpone(Msg, State), Debug);
handle_msg({From, {send, Request}},
#state{ client_id = ClientId
, mod = Mod
Expand Down Expand Up @@ -496,16 +487,16 @@ handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) ->
Mod:close(Sock),
maybe_reply(From, ok),
ok;
handle_msg(sasl_reauthenticate, State, Debug) ->
do_sasl_reauthenticate(State),
?MODULE:loop(State, Debug);
handle_msg(sasl_authenticate, #state{draining = {flush_queue, _}} = State, Debug) ->
?MODULE:loop(sasl_authenticate(State), Debug);
handle_msg(sasl_authenticate, State, Debug) ->
?MODULE:loop(drain_and_postpone(sasl_authenticate, State), Debug);
handle_msg(Msg, #state{} = State, Debug) ->
error_logger:warning_msg("[~p] ~p got unrecognized message: ~p",
[?MODULE, self(), Msg]),
?MODULE:loop(State, Debug).

do_sasl_reauthenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config}) ->
%% Imitates logic in init -> connect, but using existing api_vsns and socket
sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config} = State) ->
Timeout = get_connect_timeout(Config),
Deadline = deadline(Timeout),
SaslOpts = get_sasl_opt(Config),
Expand All @@ -514,10 +505,31 @@ do_sasl_reauthenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remo
_ -> 0
end,
ok = setopts(Sock, Mod, [{active, false}]),
ok = kpro_sasl:auth(Host, Sock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn),
case kpro_sasl:auth(Host, Sock, Mod, ClientId,
timeout(Deadline), SaslOpts, HandshakeVsn) of
ok ->
ok;
{ok, ServerResponse} ->
case find(session_lifetime_ms, ServerResponse) of
Lifetime when is_integer(Lifetime) andalso Lifetime > 0 ->
%% Broker can report back a maximal session lifetime: https://kafka.apache.org/protocol#The_Messages_SaslAuthenticate.
%% Respect the session lifetime by draining in-flight requests and re-authenticating in half the time.
ReauthenticationDeadline = Lifetime div 2,
_ = erlang:send_after(ReauthenticationDeadline, self(), sasl_authenticate),
ok;
_ ->
ok
end
end,
ok = setopts(Sock, Mod, [{active, once}]),
ok.
State.

drain_and_postpone(Msg, #state{draining = false} = State) ->
DrainQueue = [Msg],
State#state{draining = {drain, DrainQueue}};
drain_and_postpone(Msg, #state{draining = {drain, DrainQueue}} = State) ->
DrainQueue = lists:append(DrainQueue, [Msg]),
State#state{draining = {drain, DrainQueue}}.

cast(Pid, Msg) ->
try
Expand Down Expand Up @@ -545,8 +557,8 @@ print_msg(Device, {_From, {send, Request}}, State) ->
do_print_msg(Device, "send: ~p", [Request], State);
print_msg(Device, {_From, {get_api_vsns, Request}}, State) ->
do_print_msg(Device, "get_api_vsns", [Request], State);
print_msg(Device, sasl_reauthenticate, State) ->
do_print_msg(Device, "sasl_reauthenticate", [], State);
print_msg(Device, sasl_authenticate, State) ->
do_print_msg(Device, "sasl_authenticate", [], State);
print_msg(Device, {tcp, _Sock, Bin}, State) ->
do_print_msg(Device, "tcp: ~p", [Bin], State);
print_msg(Device, {ssl, _Sock, Bin}, State) ->
Expand Down
2 changes: 2 additions & 0 deletions src/kpro_sasl.erl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ auth(Host, Sock, Mod, ClientId, Timeout,
ClientId, Timeout, Opts) of
ok ->
ok;
{ok, ServerResponse} ->
{ok, ServerResponse};
{error, Reason} ->
?ERROR(Reason)
end;
Expand Down
4 changes: 4 additions & 0 deletions src/kpro_sent_reqs.erl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
, get_corr_id/1
, increment_corr_id/1
, scan_for_max_age/1
, is_empty/1
]).

-export_type([requests/0]).
Expand All @@ -56,6 +57,9 @@
-spec new() -> requests().
new() -> #requests{}.

-spec is_empty(requests()) -> boolean().
is_empty(#requests{sent = Sent}) -> maps:size(Sent) == 0.

%% @doc Add a new request to sent collection.
%% Return the last corrlation ID and the new collection.
-spec add(requests(), pid(), reference(), kpro:api(), kpro:vsn()) ->
Expand Down
44 changes: 29 additions & 15 deletions test/kpro_connection_tests.erl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
-include_lib("eunit/include/eunit.hrl").
-include("kpro_private.hrl").

-export([ auth/7 ]).

plaintext_test() ->
Config = kpro_test_lib:connection_config(plaintext),
{ok, Pid} = connect(Config),
Expand Down Expand Up @@ -57,6 +59,33 @@ sasl_file_test() ->
ok = kpro_connection:stop(Pid)
end.

% SASL callback implementation for subsequent tests
auth(_Host, _Sock, _Vsn, _Mod, _ClientName, _Timeout, #{test_pid := TestPid} = SaslOpts) ->
case SaslOpts of
#{response_session_lifetime_ms := ResponseSessionLifeTimeMs} ->
TestPid ! sasl_authenticated,
{ok, #{session_lifetime_ms => ResponseSessionLifeTimeMs}};
_ ->
ok
end.

sasl_callback_test() ->
Config0 = kpro_test_lib:connection_config(sasl_ssl),
case kpro_test_lib:get_kafka_version() of
?KAFKA_0_9 ->
ok;
_ ->
Config = Config0#{sasl => {callback, ?MODULE, #{response_session_lifetime_ms => 51, test_pid => self()}}},
{ok, Pid} = connect(Config),

% initial authentication
receive sasl_authenticated -> ok end,
% repeated authentication as session expires
receive sasl_authenticated -> ok end,

ok = kpro_connection:stop(Pid)
end.

no_api_version_query_test() ->
Config = #{query_api_versions => false},
{ok, Pid} = connect(Config),
Expand All @@ -73,21 +102,6 @@ extra_sock_opts_test() ->
?assertEqual(true, proplists:get_value(delay_send, InetSockOpts)),
ok = kpro_connection:stop(Pid).

sasl_reauthenticate_after_test() ->
Config0 = kpro_test_lib:connection_config(ssl),
case kpro_test_lib:get_kafka_version() of
?KAFKA_0_9 ->
ok;
?KAFKA_0_10 ->
{ok, Pid} = connect(Config0#{sasl => kpro_test_lib:sasl_config(file)}),
ok = kpro_connection:sasl_reauthenticate_after(Pid, 1000),
ok = kpro_connection:stop(Pid);
_ ->
{ok, Pid} = connect(Config0#{sasl => kpro_test_lib:sasl_config(file)}),
ok = kpro_connection:sasl_reauthenticate_after(Pid, 1000),
ok = kpro_connection:stop(Pid)
end.

connect(Config) ->
Protocol = kpro_test_lib:guess_protocol(Config),
[{Host, Port} | _] = kpro_test_lib:get_endpoints(Protocol),
Expand Down

0 comments on commit 1624e79

Please sign in to comment.