From ebad4869e4fd539f659ad8ed41e17d7c3aaa5d38 Mon Sep 17 00:00:00 2001 From: urmastalimaa Date: Mon, 12 Aug 2024 16:42:15 +0300 Subject: [PATCH 1/3] Revert "Add kpro_connection:sasl_reauthenticate_after/2" This reverts commit aa0ba2183f09b02482cc17c1a5197a3deeb34d46 which was never included in the upstream --- src/kpro_connection.erl | 47 ---------------------------------- test/kpro_connection_tests.erl | 15 ----------- 2 files changed, 62 deletions(-) diff --git a/src/kpro_connection.erl b/src/kpro_connection.erl index d2f871c..1ad8c91 100644 --- a/src/kpro_connection.erl +++ b/src/kpro_connection.erl @@ -27,7 +27,6 @@ , send/2 , start/3 , stop/1 - , sasl_reauthenticate_after/2 , debug/2 ]). @@ -92,7 +91,6 @@ -record(state, { client_id :: client_id() , parent :: pid() - , config :: config() , remote :: kpro:endpoint() , sock :: gen_tcp:socket() | ssl:sslsocket() , mod :: ?undef | gen_tcp | ssl @@ -163,30 +161,6 @@ stop(Pid) when is_pid(Pid) -> stop(_) -> ok. -%% @doc Reauthenticates SASL by repeating the authentication flow after given time. -%% The authentication flow is repeated only once and not periodically. -%% This would be called by an authentication adapter to reauthenticate before -%% session_lifetime_ms provided in the v1 SASL authentication response is -%% reached. -%% Example use case: -%% -module(my_custom_sasl_authentication). -%% -%% auth(Host, Sock, Vsn, Mod, ClientId, Timeout, Opts) -> -%% {ok, SaslResponse} = do_authenticate(...), -%% case session_lifetime_ms(SaslResponse) of -%% SessionLifetime when SessionLifetime > 0 -> -%% kpro_connection:sasl_reauthenticate_after(self(), a_bit_before(SessionLifetime)); -%% _ -> -%% ok -%% end, -%% ok. --spec sasl_reauthenticate_after(connection(), timeout()) -> ok. -sasl_reauthenticate_after(Pid, Time) when is_pid(Pid) -> - erlang:send_after(Time, Pid, sasl_reauthenticate), - ok; -sasl_reauthenticate_after(_, _) -> - ok. - -spec get_api_vsns(pid()) -> {ok, ?undef | kpro:vsn_ranges()} | {error, any()}. get_api_vsns(Pid) -> @@ -253,7 +227,6 @@ connect(Parent, Host, Port, Config) -> State = #state{ client_id = get_client_id(Config) , parent = Parent , remote = {Host, Port} - , config = Config , sock = Sock }, init_connection(State, Config, Deadline); @@ -496,29 +469,11 @@ handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) -> Mod:close(Sock), maybe_reply(From, ok), ok; -handle_msg(sasl_reauthenticate, State, Debug) -> - do_sasl_reauthenticate(State), - ?MODULE:loop(State, Debug); handle_msg(Msg, #state{} = State, Debug) -> error_logger:warning_msg("[~p] ~p got unrecognized message: ~p", [?MODULE, self(), Msg]), ?MODULE:loop(State, Debug). -do_sasl_reauthenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config}) -> - %% Imitates logic in init -> connect, but using existing api_vsns and socket - Timeout = get_connect_timeout(Config), - Deadline = deadline(Timeout), - SaslOpts = get_sasl_opt(Config), - HandshakeVsn = case Versions of - #{sasl_handshake := {_, V}} -> V; - _ -> 0 - end, - ok = setopts(Sock, Mod, [{active, false}]), - ok = kpro_sasl:auth(Host, Sock, Mod, ClientId, - timeout(Deadline), SaslOpts, HandshakeVsn), - ok = setopts(Sock, Mod, [{active, once}]), - ok. - cast(Pid, Msg) -> try Pid ! Msg, @@ -545,8 +500,6 @@ print_msg(Device, {_From, {send, Request}}, State) -> do_print_msg(Device, "send: ~p", [Request], State); print_msg(Device, {_From, {get_api_vsns, Request}}, State) -> do_print_msg(Device, "get_api_vsns", [Request], State); -print_msg(Device, sasl_reauthenticate, State) -> - do_print_msg(Device, "sasl_reauthenticate", [], State); print_msg(Device, {tcp, _Sock, Bin}, State) -> do_print_msg(Device, "tcp: ~p", [Bin], State); print_msg(Device, {ssl, _Sock, Bin}, State) -> diff --git a/test/kpro_connection_tests.erl b/test/kpro_connection_tests.erl index 583379b..44c0f16 100644 --- a/test/kpro_connection_tests.erl +++ b/test/kpro_connection_tests.erl @@ -73,21 +73,6 @@ extra_sock_opts_test() -> ?assertEqual(true, proplists:get_value(delay_send, InetSockOpts)), ok = kpro_connection:stop(Pid). -sasl_reauthenticate_after_test() -> - Config0 = kpro_test_lib:connection_config(ssl), - case kpro_test_lib:get_kafka_version() of - ?KAFKA_0_9 -> - ok; - ?KAFKA_0_10 -> - {ok, Pid} = connect(Config0#{sasl => kpro_test_lib:sasl_config(file)}), - ok = kpro_connection:sasl_reauthenticate_after(Pid, 1000), - ok = kpro_connection:stop(Pid); - _ -> - {ok, Pid} = connect(Config0#{sasl => kpro_test_lib:sasl_config(file)}), - ok = kpro_connection:sasl_reauthenticate_after(Pid, 1000), - ok = kpro_connection:stop(Pid) - end. - connect(Config) -> Protocol = kpro_test_lib:guess_protocol(Config), [{Host, Port} | _] = kpro_test_lib:get_endpoints(Protocol), From ebbeff887b9c429bd3c241a851e19e82f9ef41c1 Mon Sep 17 00:00:00 2001 From: urmastalimaa Date: Fri, 9 Aug 2024 14:16:26 +0300 Subject: [PATCH 2/3] Allow SASL callback modules to return {ok, ServerResponse} This expansion of the callback return values allows `kpro_connection` to interrogate the server response message, in preparation for re-authenticating SASL connections before session lifetime expires. Authentication was moved to a separate function to allow repeating authentication flow, which also required storing connection configuration in process state. --- src/kpro_auth_backend.erl | 10 ++++++---- src/kpro_connection.erl | 31 +++++++++++++++++++++++-------- src/kpro_sasl.erl | 2 ++ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/kpro_auth_backend.erl b/src/kpro_auth_backend.erl index 568cbf6..60c9f3e 100644 --- a/src/kpro_auth_backend.erl +++ b/src/kpro_auth_backend.erl @@ -18,15 +18,17 @@ -export([auth/8]). +-type server_auth_response() :: term(). + -callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(), Mod :: gen_tcp | ssl, ClientName :: binary(), Timeout :: pos_integer(), SaslOpts :: term()) -> - ok | {error, Reason :: term()}. + ok | {ok, server_auth_response()} | {error, Reason :: term()}. -callback auth(Host :: string(), Sock :: gen_tcp:socket() | ssl:sslsocket(), HandShakeVsn :: non_neg_integer(), Mod :: gen_tcp | ssl, ClientName :: binary(), Timeout :: pos_integer(), SaslOpts :: term()) -> - ok | {error, Reason :: term()}. + ok | {ok, server_auth_response()} | {error, Reason :: term()}. -optional_callbacks([auth/6]). @@ -34,7 +36,7 @@ Sock :: gen_tcp:socket() | ssl:sslsocket(), Mod :: gen_tcp | ssl, ClientName :: binary(), Timeout :: pos_integer(), SaslOpts :: term()) -> - ok | {error, Reason :: term()}. + ok | {ok, server_auth_response()} | {error, Reason :: term()}. auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) -> CallbackModule:auth(Host, Sock, Mod, ClientName, Timeout, SaslOpts). @@ -43,7 +45,7 @@ auth(CallbackModule, Host, Sock, Mod, ClientName, Timeout, SaslOpts) -> HandShakeVsn :: non_neg_integer(), Mod :: gen_tcp | ssl, ClientName :: binary(), Timeout :: pos_integer(), SaslOpts :: term()) -> - ok | {error, Reason :: term()}. + ok | {ok, server_auth_response()} | {error, Reason :: term()}. auth(CallbackModule, Host, Sock, Vsn, Mod, ClientName, Timeout, SaslOpts) -> case is_exported(CallbackModule, auth, 7) of true -> diff --git a/src/kpro_connection.erl b/src/kpro_connection.erl index 1ad8c91..0d3be1a 100644 --- a/src/kpro_connection.erl +++ b/src/kpro_connection.erl @@ -91,6 +91,7 @@ -record(state, { client_id :: client_id() , parent :: pid() + , config :: config() , remote :: kpro:endpoint() , sock :: gen_tcp:socket() | ssl:sslsocket() , mod :: ?undef | gen_tcp | ssl @@ -227,6 +228,7 @@ connect(Parent, Host, Port, Config) -> State = #state{ client_id = get_client_id(Config) , parent = Parent , remote = {Host, Port} + , config = Config , sock = Sock }, init_connection(State, Config, Deadline); @@ -258,14 +260,8 @@ init_connection(#state{ client_id = ClientId #{query_api_versions := false} -> ?undef; _ -> query_api_versions(NewSock, Mod, ClientId, Deadline) end, - HandshakeVsn = case Versions of - #{sasl_handshake := {_, V}} -> V; - _ -> 0 - end, - SaslOpts = get_sasl_opt(Config), - ok = kpro_sasl:auth(Host, NewSock, Mod, ClientId, - timeout(Deadline), SaslOpts, HandshakeVsn), - State#state{mod = Mod, sock = NewSock, api_vsns = Versions}. + State1 = State#state{mod = Mod, sock = NewSock, api_vsns = Versions}, + sasl_authenticate(State1). query_api_versions(Sock, Mod, ClientId, Deadline) -> Req = kpro_req_lib:make(api_versions, 0, []), @@ -474,6 +470,25 @@ handle_msg(Msg, #state{} = State, Debug) -> [?MODULE, self(), Msg]), ?MODULE:loop(State, Debug). +sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = {Host, _Port}, api_vsns = Versions, config = Config} = State) -> + Timeout = get_connect_timeout(Config), + Deadline = deadline(Timeout), + SaslOpts = get_sasl_opt(Config), + HandshakeVsn = case Versions of + #{sasl_handshake := {_, V}} -> V; + _ -> 0 + end, + ok = setopts(Sock, Mod, [{active, false}]), + case kpro_sasl:auth(Host, Sock, Mod, ClientId, + timeout(Deadline), SaslOpts, HandshakeVsn) of + ok -> + ok; + {ok, _ServerResponse} -> + ok + end, + ok = setopts(Sock, Mod, [{active, once}]), + State. + cast(Pid, Msg) -> try Pid ! Msg, diff --git a/src/kpro_sasl.erl b/src/kpro_sasl.erl index ec274a0..f09814a 100644 --- a/src/kpro_sasl.erl +++ b/src/kpro_sasl.erl @@ -39,6 +39,8 @@ auth(Host, Sock, Mod, ClientId, Timeout, ClientId, Timeout, Opts) of ok -> ok; + {ok, ServerResponse} -> + {ok, ServerResponse}; {error, Reason} -> ?ERROR(Reason) end; From ed065d0704b0bf2a84bfef6afcb032d97a1c45d8 Mon Sep 17 00:00:00 2001 From: urmastalimaa Date: Fri, 9 Aug 2024 14:31:28 +0300 Subject: [PATCH 3/3] Reauthenticate SASL connections based on session lifetime The broker response to a SASL authentication request can contain a maximum session lifetime (see the [KIP][kip]). Session lifetime is returned by the broker in [Version 1 SaslAuthenticate Response][sasl_authenticate_protocol]. When a SASL authentication callback returns `{ok, ServerResponse}` and the ServerResponse contains a larger than 0 session lifetime, kpro_connection automatically sets a timer to re-authenticate in half the session lifetime. As kpro_sasl mechanisms are synchronous, in-flight requests must first be drained to ensure that kpro_sasl receives a response to its own SASL request. The draining mechanism is tied to the main loop, flushing the post-drain queue when `requests` are empty. When requests are not empty, previous behaviour is retained with the exception of `{From, {send, Req}}` handler, which adds the request onto the queue when in `drain` state. [kip]: https://cwiki.apache.org/confluence/display/KAFKA/KIP-368%3A+Allow+SASL+Connections+to+Periodically+Re-Authenticate [sasl_authenticate_protocol]: https://kafka.apache.org/protocol#The_Messages_SaslAuthenticate --- src/kpro_connection.erl | 62 +++++++++++++++++++++++++++++----- src/kpro_sent_reqs.erl | 4 +++ test/kpro_connection_tests.erl | 29 ++++++++++++++++ 3 files changed, 86 insertions(+), 9 deletions(-) diff --git a/src/kpro_connection.erl b/src/kpro_connection.erl index 0d3be1a..31138a2 100644 --- a/src/kpro_connection.erl +++ b/src/kpro_connection.erl @@ -86,6 +86,7 @@ -type portnum() :: kpro:portnum(). -type client_id() :: kpro:client_id(). -type connection() :: pid(). +-type post_drain_queue() :: list(). -define(undef, undefined). @@ -98,6 +99,7 @@ , req_timeout :: ?undef | timeout() , api_vsns :: ?undef | kpro:vsn_ranges() , requests :: ?undef | requests() + , draining :: false | {drain | flush_queue, post_drain_queue()} }). -type state() :: #state{}. @@ -225,11 +227,12 @@ connect(Parent, Host, Port, Config) -> SockOpts = [{active, false}, binary] ++ get_extra_sock_opts(Config), case gen_tcp:connect(Host, Port, SockOpts, Timeout) of {ok, Sock} -> - State = #state{ client_id = get_client_id(Config) - , parent = Parent - , remote = {Host, Port} - , config = Config - , sock = Sock + State = #state{ client_id = get_client_id(Config) + , parent = Parent + , remote = {Host, Port} + , config = Config + , sock = Sock + , draining = false }, init_connection(State, Config, Deadline); {error, Reason} -> @@ -380,9 +383,25 @@ maybe_reply({To, Ref}, Reply) -> _ = erlang:send(To, {Ref, Reply}), ok. -loop(#state{} = State, Debug) -> +loop(#state{draining = false} = State, Debug) -> Msg = receive Input -> Input end, - decode_msg(Msg, State, Debug). + decode_msg(Msg, State, Debug); + +loop(#state{draining = {_DrainState, DrainQueue}, requests = Requests} = State, Debug) -> + case kpro_sent_reqs:is_empty(Requests) of + true -> + case DrainQueue of + [Msg | Rest] -> + %% decode calls back to loop which recursively flushes Rest + decode_msg(Msg, State#state{draining = {flush_queue, Rest}}, Debug); + [] -> + %% Queue has been flushed, draining is complete + loop(State#state{draining = false}, Debug) + end; + false -> + Msg = receive Input -> Input end, + decode_msg(Msg, State, Debug) + end. decode_msg({system, From, Msg}, #state{parent = Parent} = State, Debug) -> sys:handle_system_msg(Msg, From, Parent, ?MODULE, Debug, State); @@ -420,6 +439,9 @@ handle_msg({tcp_error, Sock, Reason}, #state{sock = Sock}, _) -> exit({tcp_error, Reason}); handle_msg({ssl_error, Sock, Reason}, #state{sock = Sock}, _) -> exit({ssl_error, Reason}); +handle_msg({_From, {send, _}} = Msg, #state{ draining = {drain, _}} = State, Debug) -> + %% Avoid sending new requests until in-flight requests have been resolved + ?MODULE:loop(drain_and_postpone(Msg, State), Debug); handle_msg({From, {send, Request}}, #state{ client_id = ClientId , mod = Mod @@ -465,6 +487,10 @@ handle_msg({From, stop}, #state{mod = Mod, sock = Sock}, _Debug) -> Mod:close(Sock), maybe_reply(From, ok), ok; +handle_msg(sasl_authenticate, #state{draining = {flush_queue, _}} = State, Debug) -> + ?MODULE:loop(sasl_authenticate(State), Debug); +handle_msg(sasl_authenticate, State, Debug) -> + ?MODULE:loop(drain_and_postpone(sasl_authenticate, State), Debug); handle_msg(Msg, #state{} = State, Debug) -> error_logger:warning_msg("[~p] ~p got unrecognized message: ~p", [?MODULE, self(), Msg]), @@ -483,12 +509,28 @@ sasl_authenticate(#state{client_id = ClientId, mod = Mod, sock = Sock, remote = timeout(Deadline), SaslOpts, HandshakeVsn) of ok -> ok; - {ok, _ServerResponse} -> - ok + {ok, ServerResponse} -> + case find(session_lifetime_ms, ServerResponse) of + Lifetime when is_integer(Lifetime) andalso Lifetime > 0 -> + %% Broker can report back a maximal session lifetime: https://kafka.apache.org/protocol#The_Messages_SaslAuthenticate. + %% Respect the session lifetime by draining in-flight requests and re-authenticating in half the time. + ReauthenticationDeadline = Lifetime div 2, + _ = erlang:send_after(ReauthenticationDeadline, self(), sasl_authenticate), + ok; + _ -> + ok + end end, ok = setopts(Sock, Mod, [{active, once}]), State. +drain_and_postpone(Msg, #state{draining = false} = State) -> + DrainQueue = [Msg], + State#state{draining = {drain, DrainQueue}}; +drain_and_postpone(Msg, #state{draining = {drain, DrainQueue}} = State) -> + DrainQueue = lists:append(DrainQueue, [Msg]), + State#state{draining = {drain, DrainQueue}}. + cast(Pid, Msg) -> try Pid ! Msg, @@ -515,6 +557,8 @@ print_msg(Device, {_From, {send, Request}}, State) -> do_print_msg(Device, "send: ~p", [Request], State); print_msg(Device, {_From, {get_api_vsns, Request}}, State) -> do_print_msg(Device, "get_api_vsns", [Request], State); +print_msg(Device, sasl_authenticate, State) -> + do_print_msg(Device, "sasl_authenticate", [], State); print_msg(Device, {tcp, _Sock, Bin}, State) -> do_print_msg(Device, "tcp: ~p", [Bin], State); print_msg(Device, {ssl, _Sock, Bin}, State) -> diff --git a/src/kpro_sent_reqs.erl b/src/kpro_sent_reqs.erl index 2dc09f9..8dc8161 100644 --- a/src/kpro_sent_reqs.erl +++ b/src/kpro_sent_reqs.erl @@ -33,6 +33,7 @@ , get_corr_id/1 , increment_corr_id/1 , scan_for_max_age/1 + , is_empty/1 ]). -export_type([requests/0]). @@ -56,6 +57,9 @@ -spec new() -> requests(). new() -> #requests{}. +-spec is_empty(requests()) -> boolean(). +is_empty(#requests{sent = Sent}) -> maps:size(Sent) == 0. + %% @doc Add a new request to sent collection. %% Return the last corrlation ID and the new collection. -spec add(requests(), pid(), reference(), kpro:api(), kpro:vsn()) -> diff --git a/test/kpro_connection_tests.erl b/test/kpro_connection_tests.erl index 44c0f16..0ee17d3 100644 --- a/test/kpro_connection_tests.erl +++ b/test/kpro_connection_tests.erl @@ -17,6 +17,8 @@ -include_lib("eunit/include/eunit.hrl"). -include("kpro_private.hrl"). +-export([ auth/7 ]). + plaintext_test() -> Config = kpro_test_lib:connection_config(plaintext), {ok, Pid} = connect(Config), @@ -57,6 +59,33 @@ sasl_file_test() -> ok = kpro_connection:stop(Pid) end. +% SASL callback implementation for subsequent tests +auth(_Host, _Sock, _Vsn, _Mod, _ClientName, _Timeout, #{test_pid := TestPid} = SaslOpts) -> + case SaslOpts of + #{response_session_lifetime_ms := ResponseSessionLifeTimeMs} -> + TestPid ! sasl_authenticated, + {ok, #{session_lifetime_ms => ResponseSessionLifeTimeMs}}; + _ -> + ok + end. + +sasl_callback_test() -> + Config0 = kpro_test_lib:connection_config(sasl_ssl), + case kpro_test_lib:get_kafka_version() of + ?KAFKA_0_9 -> + ok; + _ -> + Config = Config0#{sasl => {callback, ?MODULE, #{response_session_lifetime_ms => 51, test_pid => self()}}}, + {ok, Pid} = connect(Config), + + % initial authentication + receive sasl_authenticated -> ok end, + % repeated authentication as session expires + receive sasl_authenticated -> ok end, + + ok = kpro_connection:stop(Pid) + end. + no_api_version_query_test() -> Config = #{query_api_versions => false}, {ok, Pid} = connect(Config),