Skip to content

Commit

Permalink
filter allowed_types (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
emphasize committed Jun 26, 2023
1 parent 4d367c8 commit f3f4648
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 10 deletions.
13 changes: 11 additions & 2 deletions hivemind_core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
def cast_to_client_obj():
valid_kwargs: Iterable[str] = ("client_id", "api_key", "name",
"description", "is_admin", "last_seen",
"blacklist", "crypto_key", "password")
"blacklist", "allowed_types", "crypto_key",
"password")

def _handler(func):

Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self,
is_admin: bool = False,
last_seen: float = -1,
blacklist: Optional[Dict[str, List[str]]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None):

Expand All @@ -62,6 +64,9 @@ def __init__(self,
"skills": [],
"intents": []
}
self.allowed_types = allowed_types or ["recognizer_loop:utterance"]
if "recognizer_loop:utterance" not in self.allowed_types:
self.allowed_types.append("recognizer_loop:utterance")

def __getitem__(self, item: str) -> Any:
return self.__dict__.get(item)
Expand Down Expand Up @@ -179,6 +184,7 @@ def add_client(self,
key: str = "",
admin: bool = False,
blacklist: Optional[Dict[str, Any]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None) -> Client:

Expand All @@ -191,6 +197,8 @@ def add_client(self,
user["name"] = name
if blacklist:
user["blacklist"] = blacklist
if allowed_types:
user["allowed_types"] = allowed_types
if admin is not None:
user["is_admin"] = admin
if crypto_key:
Expand All @@ -202,7 +210,8 @@ def add_client(self,
user = Client(api_key=key, name=name,
blacklist=blacklist, crypto_key=crypto_key,
client_id=self.total_clients() + 1,
is_admin=admin, password=password)
is_admin=admin, password=password,
allowed_types=allowed_types)
self.add_item(user)
return user

Expand Down
20 changes: 17 additions & 3 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class HiveMindClientConnection:
socket: Optional[WebSocketHandler] = None
crypto_key: Optional[str] = None
blacklist: List[str] = field(default_factory=list) # list of ovos message_type to never be sent to this client
allowed_types: List[str] = field(default_factory=list) # list of ovos message_type to allow to be sent from this client

@property
def peer(self) -> str:
Expand All @@ -62,6 +63,16 @@ def peer(self) -> str:
return f"{self.name}:{self.ip}::{self.sess.session_id}"

def send(self, message: HiveMessage):
# TODO some cleaning around HiveMessage
if isinstance(message.payload, dict):
_msg_type = message.payload.get("type")
else:
_msg_type = message.payload.msg_type

if _msg_type in self.blacklist:
return LOG.debug(f"message type {_msg_type} "
f"is blacklisted for {self.peer}")

LOG.info(f"sending to {self.peer}: {message}")
payload = message.serialize() # json string
if self.crypto_key and message.msg_type not in [HiveMessageType.HANDSHAKE,
Expand Down Expand Up @@ -90,7 +101,7 @@ def decode(self, payload: str) -> HiveMessage:
def authorize(self, message: Message) -> bool:
""" parse the message being injected into ovos-core bus
if this client is not authorized to inject it return False"""
if message.msg_type in self.blacklist:
if message.msg_type not in self.allowed_types:
return False

# TODO check intent / skill that will trigger
Expand Down Expand Up @@ -185,7 +196,7 @@ def handle_internal_mycroft(self, message: str):
@dataclass()
class HiveMindListenerProtocol:
loop: ioloop.IOLoop
clients: dict = field(default_factory=dict)
clients = {}
internal_protocol: Optional[HiveMindListenerInternalProtocol] = None
peer: str = "master:0.0.0.0"

Expand Down Expand Up @@ -455,7 +466,10 @@ def handle_inject_mycroft_msg(self, message: Message, client: HiveMindClientConn

# ensure client specific session data is injected in query to ovos
message.context["session"] = client.sess.serialize()
message.context["destination"] = "skills" # ensure not treated as a broadcast
if message.msg_type == "speak":
message.context["destination"] = ["audio"]
elif message.context.get("destination") is None:
message.context["destination"] = "skills" # ensure not treated as a broadcast

# send client message to internal mycroft bus
LOG.info(f"Forwarding message to mycroft bus from client: {client.peer}")
Expand Down
53 changes: 52 additions & 1 deletion hivemind_core/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import click
from ovos_utils.xdg_utils import xdg_data_home
from rich.console import Console
from rich.prompt import Prompt
from rich.table import Table

from hivemind_core.database import ClientDatabase
Expand Down Expand Up @@ -51,7 +52,57 @@ def add_client(name, access_key, password, crypto_key):
print("WARNING: Encryption Key is deprecated, only use if your client does not support password")


@hmcore_cmds.command(help="remove credentials for a client (numeric unique ID)", name="delete-client")
@hmcore_cmds.command(help="allow message types sent from a client", name="allow-msg")
@click.argument("msg_type", required=True, type=str)
@click.argument("node_id", required=False, type=int)
def allow_msg(msg_type, node_id):
if not node_id:
# list clients and prompt for id using rich
table = Table(title="HiveMind Clients")
table.add_column("ID", justify="right", style="cyan", no_wrap=True)
table.add_column("Name", style="magenta")
table.add_column("Allowed Msg Types", style="yellow")
_choices = []
for client in ClientDatabase():
if client["client_id"] != -1:
table.add_row(str(client["client_id"]),
client["name"],
str(client.get("allowed_types", [])))
_choices.append(str(client["client_id"]))

if not _choices:
print("No clients found!")
exit()
elif len(_choices) > 1:
console = Console()
console.print(table)
_exit = str(max(int(i) for i in _choices) + 1)
node_id = Prompt.ask(f"To which client you want to add '{msg_type}'? ({_exit}='Exit')",
choices=_choices + [_exit])
if node_id == _exit:
console.print("User exit", style="red")
exit()
else:
node_id = _choices[0]

with ClientDatabase() as db:
for client in db:
if client["client_id"] == int(node_id):
allowed_types = client.get("allowed_types", [])
if msg_type in allowed_types:
print(f"Client {client['name']} already allowed '{msg_type}'")
exit()

allowed_types.append(msg_type)
client["allowed_types"] = allowed_types
item_id = db.get_item_id(client)
db.update_item(item_id, client)
print(f"Allowed '{msg_type}' for {client['name']}")
break


@hmcore_cmds.command(help="remove credentials for a client (numeric unique ID)",
name="delete-client")
@click.argument("node_id", required=True, type=int)
def delete_client(node_id):
with ClientDatabase() as db:
Expand Down
9 changes: 5 additions & 4 deletions hivemind_core/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,12 @@ def open(self):
self.close()
return

self.client.crypto_key = users.get_crypto_key(key)
pswd = users.get_password(key)
if pswd:
self.client.crypto_key = user.crypto_key
self.client.blacklist = user.blacklist.get("messages", [])
self.client.allowed_types = user.allowed_types
if user.password:
# pre-shared password to derive aes_key
self.client.pswd_handshake = PasswordHandShake(pswd)
self.client.pswd_handshake = PasswordHandShake(user.password)

self.client.node_type = HiveMindNodeType.NODE # TODO . placeholder

Expand Down

0 comments on commit f3f4648

Please sign in to comment.