Commit 5d46522

Merge remote-tracking branch 'origin/main'
haobibo committed May 15, 2024
2 parents b0bd45f + 0bcf850 commit 5d46522
Showing 13 changed files with 357 additions and 30 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,7 +1,7 @@
# Aloha!

[![License](https://img.shields.io/github/license/QPod/aloha)](https://github.com/QPod/aloha/blob/main/LICENSE)
-[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/QPod/aloha/build)](https://github.com/QPod/aloha/actions)
+[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/QPod/aloha-python/pip.yml?branch=main)](https://github.com/QPod/aloha-python/actions)
[![Join the Gitter Chat](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/QPod/)
[![PyPI version](https://img.shields.io/pypi/v/aloha)](https://pypi.python.org/pypi/aloha/)
[![PyPI Downloads](https://img.shields.io/pypi/dm/aloha)](https://pepy.tech/badge/aloha/)
@@ -21,6 +21,6 @@ Please generously STAR★ our project or donate to us! [![GitHub Starts](https:

## Getting started

-```py
+```shell
pip install aloha[all]
```
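
Beyond installation, a minimal service module looks much like the handlers added later in this commit. The sketch below is illustrative only: it mirrors demo/app_common/api/api_multipart.py, and the /api/echo route name is a hypothetical example, not an existing endpoint.

```python
# Minimal sketch of an aloha API handler, modelled on api_multipart.py from this
# commit; the /api/echo route is a made-up example.
from aloha.logger import LOG
from aloha.service.api.v0 import APIHandler


class EchoHandler(APIHandler):
    def response(self, params=None, *args, **kwargs):
        LOG.debug(params)  # log the merged body/query parameters
        return params      # echoed back inside the framework's response envelope


default_handlers = [
    (r"/api/echo", EchoHandler),
]
```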
Empty file.
86 changes: 86 additions & 0 deletions demo/app_common/ainlp/model_bert.py
@@ -0,0 +1,86 @@
from typing import List

import torch
from transformers import AutoTokenizer, AutoModel

from aloha.service.streamer import ManagedModel

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


class TextUnmaskModel:
def __init__(self, max_sent_len=16, model_path="bert-base-uncased"):
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
self.transformer = AutoModel.from_pretrained(self.model_path)
self.transformer.eval()
self.transformer.to(device="cuda")
self.max_sent_len = max_sent_len

def predict(self, batch: List[str]) -> List[str]:
"""predict masked word"""
batch_inputs = []
masked_indexes = []

for text in batch:
tokenized_text = self.tokenizer.tokenize(text)
if len(tokenized_text) > self.max_sent_len - 2:
tokenized_text = tokenized_text[: self.max_sent_len - 2]

tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]']
tokenized_text += ['[PAD]'] * (self.max_sent_len - len(tokenized_text))

indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
batch_inputs.append(indexed_tokens)
masked_indexes.append(tokenized_text.index('[MASK]'))

tokens_tensor = torch.tensor(batch_inputs).to("cuda")

with torch.no_grad():
# prediction_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
prediction_scores = self.transformer(tokens_tensor)[0]

batch_outputs = []
for i in range(len(batch_inputs)):
predicted_index = torch.argmax(prediction_scores[i, masked_indexes[i]]).item()
predicted_token = self.tokenizer.convert_ids_to_tokens(predicted_index)
batch_outputs.append(predicted_token)

return batch_outputs


class ManagedBertModel(ManagedModel):
def init_model(self):
self.model = TextUnmaskModel()

def predict(self, batch):
return self.model.predict(batch)


def test_simple():
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello! My name is [MASK]!", return_tensors="pt")
outputs = model(**inputs)
print(outputs)

predicted_index = torch.argmax(outputs[1]).item()
predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
print(predicted_token)


def test_batch():
batch_text = [
"twinkle twinkle [MASK] star.",
"Happy birthday to [MASK].",
'the answer to life, the [MASK], and everything.'
]
model = TextUnmaskModel()
outputs = model.predict(batch_text)
print(outputs)


if __name__ == "__main__":
test_simple()
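
Note that AutoModel("bert-base-uncased") returns hidden states of shape (batch_size, seq_len, hidden_size), not the vocabulary-sized prediction scores described in the in-code comment, so the argmax in TextUnmaskModel.predict (and in test_simple) indexes hidden dimensions rather than vocabulary ids. For comparison only, a minimal sketch using the masked-LM head, which does produce vocab-size logits:

```python
# Sketch for comparison with the file above: AutoModelForMaskedLM exposes the
# masked-LM head, so argmax over its logits really selects a vocabulary token.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

inputs = tokenizer("twinkle twinkle [MASK] star.", return_tensors="pt")
mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, seq_len, vocab_size)

predicted_ids = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.convert_ids_to_tokens(predicted_ids.tolist()))  # e.g. ['little']
```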
Empty file.
14 changes: 14 additions & 0 deletions demo/app_common/api/api_multipart.py
@@ -0,0 +1,14 @@
from aloha.logger import LOG
from aloha.service.api.v0 import APIHandler


class MultipartHandler(APIHandler):
def response(self, params=None, *args, **kwargs):
LOG.debug(params)
return params


default_handlers = [
# internal API: QueryDB Postgres with sql directly
(r"/api_internal/multipart", MultipartHandler),
]
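
For reference, the new endpoint can be exercised with a multipart request. The sketch below uses requests (already used elsewhere in this commit); the host, port, and field names are assumptions for illustration.

```python
# Hypothetical client call against the MultipartHandler above; host/port and
# field names are placeholders.
import requests

resp = requests.post(
    "http://localhost:8080/api_internal/multipart",
    data={"task_id": "demo-001"},  # ordinary form fields
    files={"upload": ("hello.txt", b"hello aloha", "text/plain")},  # file part
)
print(resp.status_code, resp.text)
```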
1 change: 1 addition & 0 deletions demo/app_common/debug.py
@@ -6,6 +6,7 @@ def main():
modules_to_load = [
"app_common.api.api_common_sys_info",
"app_common.api.api_common_query_postgres",
"app_common.api.api_multipart",
]

if 'service' not in SETTINGS.config:
8 changes: 6 additions & 2 deletions src/aloha/config/paths.py
@@ -48,15 +48,19 @@ def get_config_files() -> list:

files = files_config.split(',')
ret = []
+    msgs = []
for f in files:
file = get_config_dir(f)
if not os.path.exists(file):
-            warnings.warn('Expecting config file [%s] but it does not exists!' % file)
+            msgs.append('Expecting config file [%s] but it does not exists!' % file)
else:
print(' ---> Loading config file [%s]' % file)
ret.append(os.path.expandvars(f))
if len(ret) == 0:
-        warnings.warn('No config files set properly, EMPTY config will be used!')
+        msgs.append('No config files set properly, EMPTY config will be used!')

+    if len(msgs) > 0:
+        warnings.warn('\n'.join(msgs))
return ret


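
The effect of this change is that missing-file problems are collected and reported once instead of as separate warnings. A standalone sketch of that collect-then-warn pattern (file names made up):

```python
# Standalone illustration of the pattern used in get_config_files() above;
# the paths here are made up.
import os
import warnings


def load_existing(paths):
    msgs, found = [], []
    for p in paths:
        if not os.path.exists(p):
            msgs.append('Expecting config file [%s] but it does not exist!' % p)
        else:
            found.append(p)
    if len(found) == 0:
        msgs.append('No config files set properly, EMPTY config will be used!')
    if msgs:
        warnings.warn('\n'.join(msgs))  # one consolidated warning instead of several
    return found


print(load_existing(['app.yml', 'missing.yml']))
```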
3 changes: 2 additions & 1 deletion src/aloha/encrypt/vault/cyberark.py
@@ -11,7 +11,8 @@
from ...logger import LOG

requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
-requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += ':HIGHT:!DH:!aNULL'
+if hasattr(requests.packages.urllib3.util.ssl_, 'DEFAULT_CIPHERS'):
+    requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += ':HIGHT:!DH:!aNULL'


class CyberArkVault(BaseVault, AesEncryptor):
6 changes: 4 additions & 2 deletions src/aloha/service/api/v0.py
@@ -13,8 +13,10 @@ class APIHandler(AbstractApiHandler, ABC):
}

async def post(self, *args, **kwargs):
-        body_arguments = self.request_body
-        kwargs.update(body_arguments)
+        req_body = self.request_body
+
+        if req_body is not None:  # body_arguments
+            kwargs.update(req_body)

resp = dict(code=5200, message=['success'])
try:
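
The guard matters because dict.update(None) raises TypeError. A small sketch of the failure mode this change avoids (the kwargs contents are illustrative):

```python
# Why the None check is needed: updating kwargs with a missing request body used to raise.
kwargs = {"traceback": False}
req_body = None  # e.g. a POST with an empty or unparseable body

try:
    kwargs.update(req_body)  # TypeError: 'NoneType' object is not iterable
except TypeError as e:
    print("unguarded update fails:", e)

if req_body is not None:  # the guarded form introduced in this commit
    kwargs.update(req_body)
print(kwargs)  # unchanged: {'traceback': False}
```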
14 changes: 11 additions & 3 deletions src/aloha/service/http/base_api_handler.py
@@ -51,7 +51,7 @@ def request_body(self) -> dict:
body_arguments: dict = Optional[None]

if content_type.startswith('multipart/form-data'): # only parse files when 'Content-Type' starts with 'multipart/form-data'
-            body_arguments = self.request.body_arguments
+            body_arguments = self.request_param  # self.request.body_arguments
else:
try:
body = self.request.body.decode('utf-8')
@@ -62,8 +62,16 @@

@property
def request_param(self) -> dict:
-        url_arguments: dict = {k: v[0].decode('utf-8') for k, v in self.request.arguments.items()}
-        return url_arguments
+        ret: dict = {}
+        for k, v in self.request.arguments.items():
+            val = v[0].decode('utf-8')
+            try:
+                value = json.loads(val)
+            except json.JSONDecodeError:
+                value = val
+            ret[k] = value

+        return ret


class DefaultHandler404(AbstractApiHandler):
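
The new request_param behaviour tries each query argument as JSON first and falls back to the raw string. A sketch of that decoding rule in isolation:

```python
# Isolated illustration of the JSON-or-string rule applied to each query argument.
import json


def decode_arg(raw: bytes):
    val = raw.decode('utf-8')
    try:
        return json.loads(val)
    except json.JSONDecodeError:
        return val


print(decode_arg(b'42'))             # -> 42 (int)
print(decode_arg(b'{"a": [1, 2]}'))  # -> {'a': [1, 2]} (dict)
print(decode_arg(b'hello'))          # -> 'hello' (falls back to the raw string)
```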
33 changes: 33 additions & 0 deletions src/aloha/service/http/files.py
@@ -0,0 +1,33 @@
import time

import requests

from ...logger import LOG


def iter_over_request_files(request, url_files):
for file_key, files in request.files.items(): # iter over files uploaded by multipart
for f in files:
file_name, content_type = f["filename"], f["content_type"]
body = f.get('body', b"")
LOG.info(f"File {file_name} from multipart has content type {content_type} and length bytes={len(body)}")
yield file_key, file_name, content_type, body

for file_key, list_url in {'url_files': url_files or []}.items(): # iter over files specified by `url_files`
for url in sorted(set(list_url)):
try:
t_start = time.time()
resp = requests.get(url, stream=True) # download the file from given url
if resp.status_code == 200:
body = resp.content
content_type = resp.headers.get("Content-Type", "UNKNOWN")
else:
raise RuntimeError("Failed to download file after %s seconds with code=%s from URL %s" % (
time.time() - t_start, resp.status_code, url
))
del resp
except Exception as e:
raise e
t_cost = time.time() - t_start
LOG.info(f"File {url} has content type {content_type} and length bytes={len(body)}, downloaded in {t_cost} seconds")
yield 'url_files', url, content_type, body
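
A sketch of how a handler might consume this generator; the surrounding function and parameter names are assumptions, but the yielded (key, name, content_type, body) tuple shape comes from the file above.

```python
# Hypothetical consumer of iter_over_request_files(); only the yielded tuple
# shape is taken from the function above.
from aloha.service.http.files import iter_over_request_files


def summarize_files(request, params):
    url_files = params.get('url_files') or []  # e.g. ['https://example.com/data.csv']
    summary = []
    for file_key, file_name, content_type, body in iter_over_request_files(request, url_files):
        summary.append({
            'key': file_key,
            'name': file_name,
            'content_type': content_type,
            'size_bytes': len(body),
        })
    return summary
```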
8 changes: 6 additions & 2 deletions src/aloha/service/streamer/redis.py
@@ -5,9 +5,13 @@
import threading
import time

-from redis import Redis

from .base import BaseStreamer, BaseWorker, TIMEOUT, TIME_SLEEP, logger
+from ...logger import LOG

+try:
+    from redis import Redis
+except ImportError:
+    LOG.warn('redis not installed, service.streamer.RedisStreamer will no be available!')


class RedisWorker(BaseWorker):
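
One consequence of the try/except import is that the module now loads without redis installed, and Redis-backed classes fail only when actually used. A generic sketch of that trade-off (RedisQueue and its error message are hypothetical, not part of aloha):

```python
# Generic optional-dependency pattern; RedisQueue is a made-up example class.
try:
    from redis import Redis
except ImportError:
    Redis = None  # in the committed code only a warning is logged


class RedisQueue:
    def __init__(self, host='localhost', port=6379):
        if Redis is None:
            raise RuntimeError('redis is not installed; run `pip install redis` to enable it')
        self.client = Redis(host=host, port=port)
```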

