Skip to content

Commit

Permalink
Add missing input validation for the client config and partition filt…
Browse files Browse the repository at this point in the history
…er dictionary
  • Loading branch information
juliannguyen4 committed Feb 28, 2024
1 parent fb3d610 commit c012182
Show file tree
Hide file tree
Showing 12 changed files with 451 additions and 118 deletions.
7 changes: 6 additions & 1 deletion src/include/tls_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@

#include "macros.h"

void setup_tls_config(as_config *config, PyObject *tls_config);
typedef struct {
char *tls_key;
char *expected_type;
} as_error_type_info;

as_error_type_info *setup_tls_config(as_config *config, PyObject *tls_config);
298 changes: 264 additions & 34 deletions src/main/client/type.c

Large diffs are not rendered by default.

35 changes: 29 additions & 6 deletions src/main/convert_partition_filter.c
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,41 @@ as_status convert_partition_filter(AerospikeClient *self,
}

filter->digest.init = 0;
if (digest && PyDict_Check(digest)) {

if (digest) {
if (!PyDict_Check(digest)) {
as_error_update(err, AEROSPIKE_ERR_PARAM,
"partition_filter[\"digest\"] must be a dict");
goto ERROR_CLEANUP;
}
// TODO check these for overflow
PyObject *init = PyDict_GetItemString(digest, "init");
if (init && PyLong_Check(init)) {
filter->digest.init = PyLong_AsLong(init);
if (init) {
if (!PyBool_Check(init)) {
as_error_update(
err, AEROSPIKE_ERR_PARAM,
"partition_filter[\"digest\"][\"init\"] must be a bool");
goto ERROR_CLEANUP;
}
filter->digest.init = (bool)PyObject_IsTrue(init);
}

PyObject *value = PyDict_GetItemString(digest, "value");
if (value && PyUnicode_Check(value)) {
if (value) {
if (!PyByteArray_Check(value)) {
as_error_update(err, AEROSPIKE_ERR_PARAM,
"partition_filter[\"digest\"][\"value\"] must "
"be a bytearray");
goto ERROR_CLEANUP;
}
if (PyByteArray_Size(value) != AS_DIGEST_VALUE_SIZE) {
as_error_update(err, AEROSPIKE_ERR_PARAM,
"partition_filter[\"digest\"][\"value\"] must "
"be %d bytes long",
AS_DIGEST_VALUE_SIZE);
goto ERROR_CLEANUP;
}
strncpy((char *)filter->digest.value,
(char *)PyUnicode_AsUTF8(value), AS_DIGEST_VALUE_SIZE);
(char *)PyByteArray_AsString(value), AS_DIGEST_VALUE_SIZE);
}
}

Expand Down
108 changes: 52 additions & 56 deletions src/main/tls_config.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "tls_config.h"

static void _set_config_str_if_present(as_config *config, PyObject *tls_config,
static bool _set_config_str_if_present(as_config *config, PyObject *tls_config,
const char *key);

static char *get_string_from_string_like(PyObject *string_like);
Expand All @@ -25,68 +25,59 @@ static char *get_string_from_string_like(PyObject *string_like);
* Param: tls_conf PyDict.
* Fill in the appropriate TLS values of config based on the contents of
* tls_config
Returns NULL if no error occurred, or the config key and expected value type where the error occurred
***/
void setup_tls_config(as_config *config, PyObject *tls_config)
as_error_type_info *setup_tls_config(as_config *config, PyObject *tls_config)
{

PyObject *config_value = NULL;
int truth_value = -1;

// Setup string values in the tls config struct
_set_config_str_if_present(config, tls_config, "cafile");
_set_config_str_if_present(config, tls_config, "capath");
_set_config_str_if_present(config, tls_config, "protocols");
_set_config_str_if_present(config, tls_config, "cipher_suite");
_set_config_str_if_present(config, tls_config, "cert_blacklist");
_set_config_str_if_present(config, tls_config, "keyfile");
_set_config_str_if_present(config, tls_config, "certfile");
_set_config_str_if_present(config, tls_config, "keyfile_pw");

// Setup The boolean values of the struct if they are present
config_value = PyDict_GetItemString(tls_config, "enable");
if (config_value) {
truth_value = PyObject_IsTrue(config_value);
if (truth_value != -1) {
config->tls.enable = (bool)truth_value;
truth_value = -1;
char *tls_config_keys[] = {"cafile", "capath", "protocols",
"cipher_suite", "cert_blacklist", "keyfile",
"certfile", "keyfile_pw"};
for (unsigned long i = 0;
i < sizeof(tls_config_keys) / sizeof(tls_config_keys[0]); i++) {
bool error =
_set_config_str_if_present(config, tls_config, tls_config_keys[i]);
if (error) {
as_error_type_info *info =
(as_error_type_info *)malloc(sizeof(as_error_type_info));
info->tls_key = tls_config_keys[i];
info->expected_type = "str";
return info;
}
}

config_value = PyDict_GetItemString(tls_config, "crl_check");
if (config_value) {
truth_value = PyObject_IsTrue(config_value);
if (truth_value != -1) {
config->tls.crl_check = (bool)truth_value;
truth_value = -1;
}
}

config_value = PyDict_GetItemString(tls_config, "crl_check_all");
if (config_value) {
truth_value = PyObject_IsTrue(config_value);
if (truth_value != -1) {
config->tls.crl_check_all = (bool)truth_value;
truth_value = -1;
}
}

config_value = PyDict_GetItemString(tls_config, "log_session_info");
if (config_value) {
truth_value = PyObject_IsTrue(config_value);
if (truth_value != -1) {
config->tls.log_session_info = (bool)truth_value;
truth_value = -1;
// Setup The boolean values of the struct if they are present
char *tls_config_keys_with_bool_value[] = {
"enable", "crl_check", "crl_check_all", "log_session_info",
"for_login_only"};
bool *config_bool_ptrs[] = {
&config->tls.enable, &config->tls.crl_check, &config->tls.crl_check_all,
&config->tls.log_session_info, &config->tls.for_login_only};
PyObject *config_value = NULL;
int truth_value = -1;
unsigned long config_key_count = sizeof(tls_config_keys_with_bool_value) /
sizeof(tls_config_keys_with_bool_value[0]);
for (unsigned long i = 0; i < config_key_count; i++) {
config_value = PyDict_GetItemString(tls_config,
tls_config_keys_with_bool_value[i]);
Py_XINCREF(config_value);
if (config_value) {
if (!PyBool_Check(config_value)) {
as_error_type_info *info =
(as_error_type_info *)malloc(sizeof(as_error_type_info));
info->tls_key = tls_config_keys_with_bool_value[i];
info->expected_type = "bool";
return info;
}
truth_value = PyObject_IsTrue(config_value);
if (truth_value != -1) {
*config_bool_ptrs[i] = (bool)truth_value;
}
}
Py_XDECREF(config_value);
}

config_value = PyDict_GetItemString(tls_config, "for_login_only");
if (config_value) {
truth_value = PyObject_IsTrue(config_value);
if (truth_value != -1) {
config->tls.for_login_only = (bool)truth_value;
truth_value = -1;
}
}
return NULL;
}

/***
Expand All @@ -95,11 +86,12 @@ void setup_tls_config(as_config *config, PyObject *tls_config)
* Param config: the as_config in which to store information
* If tls_config is a string type, and key is valid,
* the appropriate field is set
* Returns false if no error, true if an invalid value type was passed to the TLS config at "key"
***/
static void _set_config_str_if_present(as_config *config, PyObject *tls_config,
static bool _set_config_str_if_present(as_config *config, PyObject *tls_config,
const char *key)
{

PyObject *config_value = NULL;
char *config_value_str = NULL;

Expand Down Expand Up @@ -142,7 +134,11 @@ static void _set_config_str_if_present(as_config *config, PyObject *tls_config,
(const char *)config_value_str);
}
}
else {
return true;
}
}
return false;
}

/***
Expand Down
5 changes: 1 addition & 4 deletions test/new_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,7 @@ def as_connection(request):
wait_for_port(a, p)
# We are using tls otherwise, so rely on the server being ready

if config["user"] is None and config["password"] is None:
as_client = aerospike.client(config).connect()
else:
as_client = aerospike.client(config).connect(config["user"], config["password"])
as_client = aerospike.client(config).connect()

request.cls.skip_old_server = True
request.cls.server_version = []
Expand Down
11 changes: 5 additions & 6 deletions test/new_tests/test_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ def get_new_connection(add_config=None):
config[key] = add_config[key]

client = aerospike.client(config)
if config["user"] is None and config["password"] is None:
client.connect()
else:
client.connect(config["user"], config["password"])

if client is not None:
build_info = client.info_all("build")
Expand Down Expand Up @@ -214,8 +210,11 @@ def get_connection_config():
config["hosts"] = hosts_conf
config["tls"] = tls_conf
config["policies"] = policies_conf
config["user"] = TestBaseClass.user
config["password"] = TestBaseClass.password
# Cannot pass `None` as a config value
if TestBaseClass.user is not None:
config["user"] = TestBaseClass.user
if TestBaseClass.password is not None:
config["password"] = TestBaseClass.password

# Disable total_timeout and timeout
# config["timeout"] = 0
Expand Down
2 changes: 1 addition & 1 deletion test/new_tests/test_bool_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_bool_read_write_pos(self, send_bool_as, expected_true, expected_false):
"""
config = TestBaseClass.get_connection_config()
config["send_bool_as"] = send_bool_as
test_client = aerospike.client(config).connect(config["user"], config["password"])
test_client = aerospike.client(config).connect()
ops = [
operation.write("cfg_true", True),
operation.write("cfg_false", False),
Expand Down
7 changes: 1 addition & 6 deletions test/new_tests/test_close.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ class TestClose:
def setup_class(cls):
config = TestBaseClass.get_connection_config()
TestClose.hostlist = config["hosts"]
TestClose.user = config["user"]
TestClose.password = config["password"]
TestClose.auth_mode = config["policies"]["auth_mode"]

def test_pos_close(self):
Expand Down Expand Up @@ -53,10 +51,7 @@ def test_close_twice_in_a_row(self):
Client call itself establishes connection.
"""
config = TestBaseClass.get_connection_config()
if TestClose.user is None and TestClose.password is None:
self.client = aerospike.client(config).connect()
else:
self.client = aerospike.client(config).connect(TestClose.user, TestClose.password)
self.client = aerospike.client(config).connect()

assert self.client.is_connected()
self.client.close()
Expand Down
6 changes: 5 additions & 1 deletion test/new_tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def test_connect_with_extra_args(self):
({"hosts": [3000]}, e.ParamError, -2, "Invalid host"),
# Errors that throw -10 can also throw 9
({"hosts": [("127.0.0.1", 2000)]}, (e.ClientError, e.TimeoutError), (-10, 9), "Failed to connect"),
({"hosts": [("127.0.0.1", "3000")]}, e.ClientError, -10, "Failed to connect"),
({"hosts": [("127.0.0.1", "3000")]}, e.ParamError, -2, "config[\"hosts\"][0][1] must be a int"),
({"hosts": [(1, 3000, "tls-name")]}, e.ParamError, -2, "config[\"hosts\"][0][0] must be a str"),
({"hosts": [("127.0.0.1", 3000, 1)]}, e.ParamError, -2, "config[\"hosts\"][0][2] must be a str"),
],
ids=[
"config not dict",
Expand All @@ -205,6 +207,8 @@ def test_connect_with_extra_args(self):
"hosts missing address",
"hosts port is incorrect",
"hosts port is string",
"host name is non-str",
"host tls-name is non-str"
],
)
def test_connect_invalid_configs(self, config, err, err_code, err_msg):
Expand Down
6 changes: 3 additions & 3 deletions test/new_tests/test_expressions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def test_bool_bin_true(self):

config = TestBaseClass.get_connection_config()
config["send_bool_as"] = aerospike.AS_BOOL
test_client = aerospike.client(config).connect(config["user"], config["password"])
test_client = aerospike.client(config).connect()

expr = BoolBin("t")
ops = [operations.write("t", True), expressions.expression_read("test", expr.compile())]
Expand All @@ -301,7 +301,7 @@ def test_bool_bin_false(self):

config = TestBaseClass.get_connection_config()
config["send_bool_as"] = aerospike.AS_BOOL
test_client = aerospike.client(config).connect(config["user"], config["password"])
test_client = aerospike.client(config).connect()

expr = Not(BoolBin("t"))
ops = [operations.write("t", True), expressions.expression_read("test", expr.compile())]
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_bintype_as_bool(self):
# Configure client to encode and send booleans as the server boolean type
config = TestBaseClass.get_connection_config()
config["send_bool_as"] = aerospike.AS_BOOL
test_client = aerospike.client(config).connect(config["user"], config["password"])
test_client = aerospike.client(config).connect()

# Override record 0's "t" bin to be a server bool instead of a python bool
key = ("test", "demo", 0)
Expand Down
65 changes: 65 additions & 0 deletions test/new_tests/test_new_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,68 @@ def test_query_client_default_ttl(self, config_ttl_setup):
wait_for_job_completion(self.client, job_id)

self.check_ttl()

# Helper function for the test below
# Example: {"lua": {"user_path": 1}} -> ["lua"]["user_path"]
def get_err_msg_keys(self, dict_to_traverse: dict):
keys = []
curr_value = dict_to_traverse
while type(curr_value) is dict:
key = list(curr_value.keys())[0]
keys.append(key)
curr_value = curr_value[key]
keys = ["[\"{0}\"]".format(key) for key in keys]
keys = "".join(keys)
return keys

@pytest.mark.parametrize("invalid_config, expected_type", [
({"lua": 1}, dict),
({"lua": {"user_path": 1}}, str),
({"tls": 1}, dict),
({"tls": {"enable": 1}}, bool),
({"tls": {"cafile": 1}}, str),
({"tls": {"capath": 1}}, str),
({"tls": {"protocols": 1}}, str),
({"tls": {"cipher_suite": 1}}, str),
({"tls": {"keyfile": 1}}, str),
({"tls": {"keyfile_pw": 1}}, str),
({"tls": {"cert_blacklist": 1}}, str),
({"tls": {"certfile": 1}}, str),
({"tls": {"crl_check": 1}}, bool),
({"tls": {"crl_check_all": 1}}, bool),
({"tls": {"log_session_info": 1}}, bool),
({"tls": {"for_login_only": 1}}, bool),
({"shm": 1}, dict),
({"shm": {"max_nodes": True}}, int),
({"shm": {"max_namespaces": True}}, int),
({"shm": {"takeover_threshold_sec": True}}, int),
({"shm": {"shm_key": True}}, int),
({"serialization": 1}, tuple),
({"policies": 1}, dict),
({"policies": {"login_timeout_ms": True}}, int),
({"thread_pool_size": True}, int),
({"max_threads": True}, int),
({"min_conns_per_node": True}, int),
({"max_conns_per_node": True}, int),
({"max_error_rate": True}, int),
({"error_rate_window": True}, int),
({"connect_timeout": True}, int),
({"use_shared_connection": 1}, bool),
({"send_bool_as": True}, int),
({"compression_threshold": True}, int),
({"tend_interval": True}, int),
({"cluster_name": 1}, str),
({"max_socket_idle": True}, int),
({"fail_if_not_connected": 1}, bool),
({"user": 1}, str),
({"password": 1}, str),
])
def test_client_config_invalid_value_types(self, invalid_config: dict, expected_type: type):
# Hosts key is required so just copy it from gconfig
config = copy.deepcopy(gconfig)
config.update(invalid_config)
with pytest.raises(e.ParamError) as excinfo:
aerospike.client(config)

err_msg_keys = self.get_err_msg_keys(invalid_config)
assert excinfo.value.msg == f"config{err_msg_keys} must be a {expected_type.__name__}"
Loading

0 comments on commit c012182

Please sign in to comment.