Skip to content

Commit

Permalink
fix: python client (#219)
Browse files Browse the repository at this point in the history
Co-authored-by: Pratik Mishra <[email protected]>
  • Loading branch information
pratikmishra356 and Pratik Mishra committed Sep 5, 2024
1 parent 670d44d commit c86d330
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 80 deletions.
2 changes: 1 addition & 1 deletion clients/python/cacclient/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .client import CacClient
from .client import CacClient, MergeStrategy
68 changes: 50 additions & 18 deletions clients/python/cacclient/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ctypes
import os
import threading
import ast

platform = os.uname().sysname.lower()
lib_path = os.environ.get("SUPERPOSITION_LIB_PATH")
Expand All @@ -15,6 +16,22 @@

lib_path = os.path.join(lib_path, file_name)


from enum import Enum, auto

class MergeStrategy(Enum):
MERGE = auto()
REPLACE = auto()

class Config:
def __init__(self, config_dict):
try:
self.contexts = config_dict['contexts']
self.overrides = config_dict['overrides']
self.default_configs = config_dict['default_configs']
except Exception as e:
raise Exception("Invalid config dictionary", e)

class CacClient:
rust_lib = ctypes.CDLL(lib_path)

Expand Down Expand Up @@ -57,6 +74,12 @@ def __init__(self, tenant_name: str, polling_frequency: int, cac_host_name: str)
self.polling_frequency = polling_frequency
self.cac_host_name = cac_host_name

resp = self.rust_lib.cac_new_client(
self.tenant.encode(), self.polling_frequency, self.cac_host_name.encode())
if resp == 1:
error_message = self.get_cac_last_error_message()
raise Exception("Error Occured while creating new client ", error_message)

def get_cac_last_error_message(self) -> str:
return self.rust_lib.cac_last_error_message().decode()

Expand All @@ -66,40 +89,49 @@ def get_cac_last_error_length(self) -> int:
def get_cac_client(self) -> str:
return self.rust_lib.cac_get_client(self.tenant.encode())

def create_new_cac_client(self) -> int:
resp = self.rust_lib.cac_new_client(
self.tenant.encode(), self.polling_frequency, self.cac_host_name.encode())
if resp == 1:
error_message = self.get_cac_last_error_message()
print("Some Error Occur while creating new client ", error_message)
return resp

def start_cac_polling_update(self):
threading.Thread(target=self._polling_update_worker).start()

def _polling_update_worker(self):
self.rust_lib.cac_start_polling_update(self.tenant.encode())

def get_cac_config(self, filter_query: str | None = None, filter_prefix: str | None = None) -> str:
def get_cac_config(self, filter_query: str | None = None, filter_prefix: str | None = None) -> Config:
client_ptr = self.get_cac_client()
filter_prefix_ptr = None if filter_prefix is None else filter_prefix.encode()
filter_query_ptr = None if filter_query is None else filter_query.encode()
return self.rust_lib.cac_get_config(client_ptr, filter_query_ptr, filter_prefix_ptr).decode()

try:
result = self.rust_lib.cac_get_config(client_ptr, filter_query_ptr, filter_prefix_ptr).decode()
print("pppp", result)
print(ast.literal_eval(result))
return Config(ast.literal_eval(result))
except:
raise Exception(self.rust_lib.get_cac_last_error_message())

def free_cac_client(self, client_ptr: str):
self.rust_lib.cac_free_client(client_ptr.encode())

def free_cac_string(self, string: str):
self.rust_lib.cac_free_string(string.encode())

def get_last_modified(self) -> str:
return self.rust_lib.cac_get_last_modified(self.get_cac_client()).decode()
try:
return self.rust_lib.cac_get_last_modified(self.get_cac_client()).decode()
except:
raise Exception(self.rust_lib.get_cac_last_error_message())

def get_resolved_config(self, query: str, merge_strategy: str, filter_keys: str | None = None) -> str:
def get_resolved_config(self, query: dict, merge_strategy: MergeStrategy, filter_keys: str | None = None) -> dict:
filter_keys_ptr = None if filter_keys is None else filter_keys.encode()
return self.rust_lib.cac_get_resolved_config(
self.get_cac_client(), query.encode(), filter_keys_ptr, merge_strategy.encode()).decode()

def get_default_config(self, filter_keys: str | None = None) -> str:
try:
result = self.rust_lib.cac_get_resolved_config(
self.get_cac_client(), str(query).encode(), filter_keys_ptr, merge_strategy.name.encode()).decode()
return ast.literal_eval(result)
except:
raise Exception(self.rust_lib.get_cac_last_error_message())

def get_default_config(self, filter_keys: list[str] | None = None) -> dict:
filter_keys_ptr = None if filter_keys is None else filter_keys.encode()
return self.rust_lib.cac_get_default_config(self.get_cac_client(), filter_keys_ptr).decode()
try:
result = self.rust_lib.cac_get_default_config(self.get_cac_client(), filter_keys_ptr).decode()
return ast.literal_eval(result)
except:
raise Exception(self.rust_lib.get_cac_last_error_message())
56 changes: 34 additions & 22 deletions clients/python/expclient/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ctypes
import os
import threading
import ast

platform = os.uname().sysname.lower()
lib_path = os.environ.get("SUPERPOSITION_LIB_PATH")
Expand Down Expand Up @@ -59,22 +60,22 @@ def __init__(self, tenant_name: str, polling_frequency: int, cac_host_name: str)
self.polling_frequency = polling_frequency
self.cac_host_name = cac_host_name

def get_experimentation_last_error_message(self) -> str:
return self.rust_lib.expt_last_error_message().decode()

def create_new_experimentation_client(self) -> int:
resp_code = self.rust_lib.expt_new_client(self.tenant.encode(), self.polling_frequency, self.cac_host_name.encode())
if resp_code == 1:
error_message = self.get_experimentation_last_error_message()
print("Some error occurred while creating new experimentation client:", error_message)
raise Exception("Client Creation Error")
return resp_code
raise Exception("Error Occured while creating new client", error_message)

def get_experimentation_last_error_message(self) -> str:
return self.rust_lib.expt_last_error_message().decode()

def get_experimentation_client(self) -> str:
return self.rust_lib.expt_get_client(self.tenant.encode())

def get_running_experiments(self) -> str:
return self.rust_lib.expt_get_running_experiments(self.get_experimentation_client()).decode()
try:
return self.rust_lib.expt_get_running_experiments(self.get_experimentation_client()).decode()
except:
raise Exception(self.rust_lib.get_experimentation_last_error_message())

def free_string(self, string: str):
self.rust_lib.expt_free_string(string.encode())
Expand All @@ -91,19 +92,30 @@ def get_experimentation_last_error_length(self) -> int:
def free_experimentation_client(self):
self.rust_lib.expt_free_client(self.get_experimentation_client())

def get_filtered_satisfied_experiments(self, context: str, filter_prefix: str | None = None) -> str:
def get_filtered_satisfied_experiments(self, context: dict, filter_prefix: list[str] | None = None) -> list[dict]:
filter_prefix_ptr = None if filter_prefix is None else filter_prefix.encode()
return self.rust_lib.expt_get_filtered_satisfied_experiments(
self.get_experimentation_client(), context.encode(), filter_prefix_ptr
).decode()

def get_applicable_variant(self, context: str, toss: int) -> str:
return self.rust_lib.expt_get_applicable_variant(
self.get_experimentation_client(), context.encode(), toss
).decode()

def get_satisfied_experiments(self, context: str, filter_prefix: str | None = None) -> str:
try:
return self.rust_lib.expt_get_filtered_satisfied_experiments(
self.get_experimentation_client(), str(context).encode(), filter_prefix_ptr
).decode()
except:
raise Exception(self.rust_lib.get_experimentation_last_error_message())

def get_applicable_variant(self, context: dict, toss: int) -> list[str]:
try:
result = self.rust_lib.expt_get_applicable_variant(
self.get_experimentation_client(), str(context).encode(), toss).decode()
return ast.literal_eval(result)
except:
raise Exception(self.rust_lib.get_experimentation_last_error_message())

def get_satisfied_experiments(self, context: dict, filter_prefix: list[str] | None = None) -> list[dict]:
filter_prefix_ptr = None if filter_prefix is None else filter_prefix.encode()
return self.rust_lib.expt_get_satisfied_experiments(
self.get_experimentation_client(), context.encode(), filter_prefix_ptr
).decode()
try:
result = self.rust_lib.expt_get_satisfied_experiments(
self.get_experimentation_client(), str(context).encode(), filter_prefix_ptr
).decode()
return ast.literal_eval(result)

except:
raise Exception(self.rust_lib.get_experimentation_last_error_message())
11 changes: 4 additions & 7 deletions clients/python/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from expclient import ExperimentationClient
from cacclient import CacClient
from cacclient import CacClient , MergeStrategy
import http.server
import socketserver

Expand All @@ -10,9 +10,6 @@
exp_client = ExperimentationClient(tenant_name, polling_frequency, cac_host_name)
cac_client = CacClient(tenant_name, polling_frequency, cac_host_name)

exp_client.create_new_experimentation_client()
cac_client.create_new_cac_client()

cac_client.start_cac_polling_update()
exp_client.start_experimentation_polling_update()

Expand All @@ -28,13 +25,13 @@ def do_GET(self):
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(cacClientResp.encode())
self.wfile.write(str((cacClientResp)).encode())
elif self.path == '/testexp':
expClientResp = exp_client.get_running_experiments()
expClientResp = exp_client.get_satisfied_experiments({})
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(expClientResp.encode())
self.wfile.write(str(expClientResp).encode())

with socketserver.TCPServer(("", PORT), MyHandler) as httpd:
print(f"Serving at port http://localhost:{PORT}")
Expand Down
3 changes: 3 additions & 0 deletions crates/cac_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ impl Client {
let reqwc = clone_reqw(&reqw)?;
let resp = reqwc.send().await.map_err_to_string()?;
let last_modified_at = get_last_modified(&resp);
if resp.status().is_client_error() {
return Err("Invalid tenant".to_string());
}
let config = resp.json::<Config>().await.map_err_to_string()?;

let client = Client {
Expand Down
72 changes: 40 additions & 32 deletions crates/experimentation_client/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,20 @@ pub extern "C" fn expt_get_satisfied_experiments(
};

let local = task::LocalSet::new();
let experiments = local.block_on(&Runtime::new().unwrap(), unsafe {
(*client).get_satisfied_experiments(&context, prefix_list)
});
let experiments = unwrap_safe!(
serde_json::to_value(experiments),
return std::ptr::null_mut()
);
serde_json::to_string(&experiments)
.map(|exp| rstring_to_cstring(exp).into_raw())
.unwrap_or_else(|err| error_block(err.to_string()))
local.block_on(&Runtime::new().unwrap(), async move {
unsafe {
unwrap_safe!(
(*client)
.get_satisfied_experiments(&context, prefix_list)
.await
.map(|exp| {
rstring_to_cstring(serde_json::to_value(exp).unwrap().to_string())
.into_raw()
}),
std::ptr::null_mut()
)
}
})
}

#[no_mangle]
Expand Down Expand Up @@ -249,31 +253,35 @@ pub extern "C" fn expt_get_filtered_satisfied_experiments(

Some(prefix_list).filter(|list| !list.is_empty())
};

let local = task::LocalSet::new();
let experiments = local.block_on(&Runtime::new().unwrap(), unsafe {
(*client).get_filtered_satisfied_experiments(&context, prefix_list)
});
let experiments = unwrap_safe!(
serde_json::to_value(experiments),
return std::ptr::null_mut()
);
serde_json::to_string(&experiments)
.map(|exp| rstring_to_cstring(exp).into_raw())
.unwrap_or_else(|err| error_block(err.to_string()))
local.block_on(&Runtime::new().unwrap(), async move {
unsafe {
unwrap_safe!(
(*client)
.get_filtered_satisfied_experiments(&context, prefix_list)
.await
.map(|exp| {
rstring_to_cstring(serde_json::to_value(exp).unwrap().to_string())
.into_raw()
}),
std::ptr::null_mut()
)
}
})
}

#[no_mangle]
pub extern "C" fn expt_get_running_experiments(client: *mut Arc<Client>) -> *mut c_char {
let experiments =
EXP_RUNTIME.block_on(unsafe { (*client).get_running_experiments() });
let experiments = unwrap_safe!(
serde_json::to_value(experiments),
return std::ptr::null_mut()
);
let result = unwrap_safe!(
serde_json::to_string(&experiments),
return std::ptr::null_mut()
);
rstring_to_cstring(result).into_raw()
let local = task::LocalSet::new();
local.block_on(&Runtime::new().unwrap(), async move {
unsafe {
unwrap_safe!(
(*client).get_running_experiments().await.map(|exp| {
rstring_to_cstring(serde_json::to_value(exp).unwrap().to_string())
.into_raw()
}),
std::ptr::null_mut()
)
}
})
}

0 comments on commit c86d330

Please sign in to comment.