Skip to content

Commit

Permalink
Merge branch 'master' into update-cicd
Browse files Browse the repository at this point in the history
  • Loading branch information
wbarnha committed Jun 26, 2023
2 parents 336e0c9 + f79cc16 commit 96e7b0c
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 54 deletions.
35 changes: 35 additions & 0 deletions docs/includes/settingref.txt
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,41 @@ SASL Authentication
password=BROKER_PASSWORD,
))

OAuth2 Authentication
You can enable SASL authentication via OAuth2 Bearer tokens:

.. sourcecode:: python

import faust
from asyncio import get_running_loop
from aiokafka.helpers import create_ssl_context
from aiokafka.conn import AbstractTokenProvider

class TokenProvider(AbstractTokenProvider):
async def token(self):
return await get_running_loop().run_in_executor(
None, self.get_token)

def get_token(self):
return 'token'

app = faust.App(
broker_credentials=faust.OAuthCredentials(
oauth_cb=TokenProvider()
ssl_context=create_ssl_context()
)
)

.. info::

The implementation should ensure token reuse so that multiple
calls at connect time do not create multiple tokens.
The implementation should also periodically refresh the token in order to
guarantee that each call returns an unexpired token.

Token Providers MUST implement the :meth:`token` method


GSSAPI Authentication
GSSAPI authentication over plain text:

Expand Down
9 changes: 8 additions & 1 deletion faust/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ def _extract_arg_from_argv( # pragma: no cover

from .agents import Agent # noqa: E402
from .app import App # noqa: E402
from .auth import GSSAPICredentials, SASLCredentials, SSLCredentials # noqa: E402
from .auth import ( # noqa: E402
GSSAPICredentials,
OAuthCredentials,
SASLCredentials,
SSLCredentials,
)
from .channels import Channel, ChannelT # noqa: E402
from .events import Event, EventT # noqa: E402
from .models import Model, ModelOptions, Record # noqa: E402
Expand Down Expand Up @@ -184,6 +189,7 @@ def _extract_arg_from_argv( # pragma: no cover
"TopicT",
"GSSAPICredentials",
"SASLCredentials",
"OAuthCredentials",
"SSLCredentials",
"Settings",
"HoppingWindow",
Expand Down Expand Up @@ -219,6 +225,7 @@ def _extract_arg_from_argv( # pragma: no cover
"GSSAPICredentials",
"SASLCredentials",
"SSLCredentials",
"OAuthCredentials",
],
"faust.types.settings": ["Settings"],
"faust.windows": [
Expand Down
30 changes: 30 additions & 0 deletions faust/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import ssl
from typing import Any, Optional, Union

from aiokafka.conn import AbstractTokenProvider

from faust.types.auth import AuthProtocol, CredentialsT, SASLMechanism

__all__ = [
"Credentials",
"SASLCredentials",
"OAuthCredentials",
"GSSAPICredentials",
"SSLCredentials",
]
Expand Down Expand Up @@ -49,6 +52,33 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}: username={self.username}>"


class OAuthCredentials(Credentials):
"""Describe OAuth Bearer credentials over SASL"""

protocol = AuthProtocol.SASL_PLAINTEXT
mechanism: SASLMechanism = SASLMechanism.OAUTHBEARER

ssl_context: Optional[ssl.SSLContext]

def __init__(
self,
*,
oauth_cb: AbstractTokenProvider,
ssl_context: Optional[ssl.SSLContext] = None,
):
self.oauth_cb = oauth_cb
self.ssl_context = ssl_context

if ssl_context is not None:
self.protocol = AuthProtocol.SASL_SSL

def __repr__(self) -> str:
return "<{0}: oauth credentials {1} SSL support".format(
type(self).__name__,
"with" if self.protocol == AuthProtocol.SASL_SSL else "without",
)


class GSSAPICredentials(Credentials):
"""Describe GSSAPI credentials over SASL."""

Expand Down
19 changes: 1 addition & 18 deletions faust/tables/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,18 @@ class ChangeloggedObjectManager(Store):
data: MutableMapping

_storage: Optional[StoreT] = None
_dirty: Set

def __init__(self, table: Table, **kwargs: Any) -> None:
self.table = table
self.table_name = self.table.name
self.data = {}
self._dirty = set()
Service.__init__(self, loop=table.loop, **kwargs)

def send_changelog_event(self, key: Any, operation: int, value: Any) -> None:
"""Send changelog event to the tables changelog topic."""
event = current_event()
self._dirty.add(key)
self.table._send_changelog(event, (operation, key), value)
self.storage[key] = self[key].as_stored_value()

def __getitem__(self, key: Any) -> ChangeloggedObject:
if key in self.data:
Expand All @@ -100,10 +98,6 @@ async def on_start(self) -> None:
"""Call when the changelogged object manager starts."""
await self.add_runtime_dependency(self.storage)

async def on_stop(self) -> None:
"""Call when the changelogged object manager stops."""
self.flush_to_storage()

def persisted_offset(self, tp: TP) -> Optional[int]:
"""Get the last persisted offset for changelog topic partition."""
return self.storage.persisted_offset(tp)
Expand Down Expand Up @@ -133,17 +127,6 @@ def sync_from_storage(self) -> None:
for key, value in self.storage.items():
self[key].sync_from_storage(value)

def flush_to_storage(self) -> None:
"""Flush set contents to storage."""
for key in self._dirty:
self.storage[key] = self.data[key].as_stored_value()
self._dirty.clear()

@Service.task
async def _periodic_flush(self) -> None: # pragma: no cover
async for sleep_time in self.itertimer(2.0, name="SetManager.flush"):
self.flush_to_storage()

def reset_state(self) -> None:
"""Reset table local state."""
# delegate to underlying RocksDB store.
Expand Down
10 changes: 5 additions & 5 deletions faust/tables/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def on_change(self, added: Set[VT], removed: Set[VT]) -> None:
self.manager.send_changelog_event(self.key, OPERATION_UPDATE, [added, removed])

def sync_from_storage(self, value: Any) -> None:
self.data = cast(Set, value)
self.data = set(value)

def as_stored_value(self) -> Any:
return self.data
Expand Down Expand Up @@ -204,19 +204,19 @@ async def symmetric_difference_update(self, key: KT, members: Iterable[VT]) -> N
await self._send_operation(SetAction.SYMDIFF, key, members)

def _update(self, key: KT, members: List[VT]) -> None:
self.set_table[key].update(members)
self.set_table[key].update(set(members))

def _difference_update(self, key: KT, members: List[VT]) -> None:
self.set_table[key].difference_update(members)
self.set_table[key].difference_update(set(members))

def _clear(self, key: KT, members: List[VT]) -> None:
self.set_table[key].clear()

def _intersection_update(self, key: KT, members: List[VT]) -> None:
self.set_table[key].intersection_update(members)
self.set_table[key].intersection_update(set(members))

def _symmetric_difference_update(self, key: KT, members: List[VT]) -> None:
self.set_table[key].symmetric_difference_update(members)
self.set_table[key].symmetric_difference_update(set(members))

async def _send_operation(
self, action: SetAction, key: KT, members: Iterable[VT]
Expand Down
14 changes: 13 additions & 1 deletion faust/transport/drivers/aiokafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@
from opentracing.ext import tags
from yarl import URL

from faust.auth import GSSAPICredentials, SASLCredentials, SSLCredentials
from faust.auth import (
GSSAPICredentials,
OAuthCredentials,
SASLCredentials,
SSLCredentials,
)
from faust.exceptions import (
ConsumerNotStarted,
ImproperlyConfigured,
Expand Down Expand Up @@ -1598,6 +1603,13 @@ def credentials_to_aiokafka_auth(
"security_protocol": credentials.protocol.value,
"ssl_context": credentials.context,
}
elif isinstance(credentials, OAuthCredentials):
return {
"security_protocol": credentials.protocol.value,
"sasl_mechanism": credentials.mechanism.value,
"sasl_oauth_token_provider": credentials.oauth_cb,
"ssl_context": credentials.ssl_context,
}
elif isinstance(credentials, SASLCredentials):
return {
"security_protocol": credentials.protocol.value,
Expand Down
1 change: 1 addition & 0 deletions faust/types/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class SASLMechanism(Enum):
GSSAPI = "GSSAPI"
SCRAM_SHA_256 = "SCRAM-SHA-256"
SCRAM_SHA_512 = "SCRAM-SHA-512"
OAUTHBEARER = "OAUTHBEARER"


AUTH_PROTOCOLS_SSL = {AuthProtocol.SSL, AuthProtocol.SASL_SSL}
Expand Down
17 changes: 2 additions & 15 deletions tests/unit/tables/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class Test_ChangeloggedObjectManager:
def man(self, *, table):
man = ChangeloggedObjectManager(table)
man.ValueType = ValueType
man.storage.__setitem__ = Mock()
return man

@pytest.fixture()
Expand All @@ -62,7 +63,7 @@ def storage(self, *, table):

def test_send_changelog_event(self, *, man, table, key, current_event):
man.send_changelog_event(key, 3, "value")
assert key in man._dirty
assert man.storage.__setitem__.called_once_with(key, "value")
table._send_changelog.assert_called_once_with(
current_event(),
(3, key),
Expand Down Expand Up @@ -98,12 +99,6 @@ async def test_on_start(self, *, man):
await man.on_start()
man.add_runtime_dependency.assert_called_once_with(man.storage)

@pytest.mark.asyncio
async def test_on_stop(self, *, man):
man.flush_to_storage = Mock()
await man.on_stop()
man.flush_to_storage.assert_called_once_with()

def test_persisted_offset(self, *, man, storage):
ret = man.persisted_offset(TP1)
storage.persisted_offset.assert_called_once_with(TP1)
Expand Down Expand Up @@ -135,14 +130,6 @@ def test_sync_from_storage(self, *, man, storage):
assert 1 in man["foo"].synced
assert 2 in man["bar"].synced

def test_flush_to_storage(self, *, man):
man._storage = {}
man._dirty = {"foo", "bar"}
assert man["foo"]
assert man["bar"]
man.flush_to_storage()
assert man._storage["foo"] == "foo-stored"

def test_reset_state(self, *, man, storage):
man.reset_state()
storage.reset_state.assert_called_once_with()
Expand Down
26 changes: 13 additions & 13 deletions tests/unit/tables/test_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,12 @@ async def test_symmetric_difference_update(self, *, man):
def test__update(self, *, man):
man.set_table = {"a": Mock(name="a"), "b": Mock(name="b")}
man._update("a", ["v1"])
man.set_table["a"].update.assert_called_once_with(["v1"])
man.set_table["a"].update.assert_called_once_with({"v1"})

def test__difference_update(self, *, man):
man.set_table = {"a": Mock(name="a"), "b": Mock(name="b")}
man._difference_update("a", ["v1"])
man.set_table["a"].difference_update.assert_called_once_with(["v1"])
man.set_table["a"].difference_update.assert_called_once_with({"v1"})

def test__clear(self, *, man):
man.set_table = {"a": Mock(name="a"), "b": Mock(name="b")}
Expand All @@ -264,14 +264,14 @@ def test__intersection_update(self, *, man):
man.set_table = {"a": Mock(name="a"), "b": Mock(name="b")}
man._intersection_update("a", ["v1", "v2", "v3"])
man.set_table["a"].intersection_update.assert_called_once_with(
["v1", "v2", "v3"],
{"v1", "v2", "v3"},
)

def test__symmetric_difference_update(self, *, man):
man.set_table = {"a": Mock(name="a"), "b": Mock(name="b")}
man._symmetric_difference_update("a", ["v1", "v2", "v3"])
man.set_table["a"].symmetric_difference_update.assert_called_once_with(
["v1", "v2", "v3"],
{"v1", "v2", "v3"},
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -396,29 +396,29 @@ async def stream_items():

await man._modify_set(stream)

man.set_table["k1"].update.assert_called_with(["v"])
man.set_table["k2"].difference_update.assert_called_with(["v2"])
man.set_table["k3"].difference_update.assert_called_with([X(10, 30)])
man.set_table["k1"].update.assert_called_with({"v"})
man.set_table["k2"].difference_update.assert_called_with({"v2"})
man.set_table["k3"].difference_update.assert_called_with({X(10, 30)})
man.set_table["k5"].update.assert_called_with(
[
{
X(10, 30),
X(20, 40),
"v3",
]
}
)
man.set_table["k6"].intersection_update.assert_called_with(
[
{
X(10, 30),
X(20, 40),
"v3",
]
}
)
man.set_table["k7"].symmetric_difference_update.assert_called_with(
[
{
X(10, 30),
X(20, 40),
"v3",
]
}
)
man.set_table["k8"].clear.assert_called_once_with()

Expand Down
Loading

0 comments on commit 96e7b0c

Please sign in to comment.