diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..6320cd2
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+data
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..0ceb8d0
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3.7-slim
+
+WORKDIR /opt/ann
+
+RUN apt-get update && \
+    apt-get install --no-install-suggests --no-install-recommends -y build-essential && \
+    apt-get autoremove --purge && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+COPY *.py /opt/ann/
+COPY requirements.txt /opt/ann/
+RUN pip3 install --no-cache-dir -r /opt/ann/requirements.txt
+
+WORKDIR /opt/ann
+ENTRYPOINT ["/usr/local/bin/gunicorn", "--config", "/opt/ann/gunicorn_conf.py", "api:api"]
diff --git a/api.py b/api.py
new file mode 100644
index 0000000..8fced75
--- /dev/null
+++ b/api.py
@@ -0,0 +1,222 @@
+import pathlib
+import random
+from typing import Any, Dict, List, Optional
+
+import annoy
+import falcon
+import numpy as np
+import schema
+import settings
+from embeddings import EMBEDDING_STORE, add_logos, get_embedding
+from falcon.media.validators import jsonschema
+from falcon_cors import CORS
+from falcon_multipart.middleware import MultipartMiddleware
+from sentry_sdk.integrations.falcon import FalconIntegration
+from utils import get_image_from_url, get_logger, text_file_iter
+
+logger = get_logger()
+
+settings.init_sentry(integrations=[FalconIntegration()])
+
+
+def load_index(file_path: pathlib.Path, dimension: int) -> annoy.AnnoyIndex:
+    index = annoy.AnnoyIndex(dimension, "euclidean")
+    index.load(str(file_path), prefault=True)
+    return index
+
+
+def load_keys(file_path: pathlib.Path) -> List[int]:
+    return [int(x) for x in text_file_iter(file_path)]
+
+
+class ANNIndex:
+    def __init__(self, index: annoy.AnnoyIndex, keys: List[int]):
+        self.index: annoy.AnnoyIndex = index
+        self.keys: List[int] = keys
+        self.key_to_ann_id = {x: i for i, x in enumerate(self.keys)}
+
+    @classmethod
+    def load(cls, index_dir: pathlib.Path) -> "ANNIndex":
+        dimension = settings.INDEX_DIM[index_dir.name]
+        index = load_index(index_dir / settings.INDEX_FILE_NAME, dimension)
+        keys = load_keys(index_dir / settings.KEYS_FILE_NAME)
+        return cls(index, keys)
+
+
+logger.info("Loading ANN indexes...")
+INDEXES: Dict[str, ANNIndex] = {
+    index_dir.name: ANNIndex.load(index_dir)
+    for index_dir in settings.DATA_DIR.iterdir()
+    if index_dir.is_dir()
+}
+logger.info("ANN indexes loaded")
+
+
+class ANNResource:
+    def on_get(
+        self, req: falcon.Request, resp: falcon.Response, logo_id: Optional[int] = None
+    ):
+        index_name = req.get_param("index", default=settings.DEFAULT_INDEX)
+        count = req.get_param_as_int("count", min_value=1, max_value=500, default=100)
+
+        if index_name not in INDEXES:
+            raise falcon.HTTPBadRequest("unknown index: {}".format(index_name))
+
+        ann_index = INDEXES[index_name]
+
+        if logo_id is None:
+            logo_id = random.choice(ann_index.keys)
+
+        results = get_nearest_neighbors(ann_index, count, logo_id)
+
+        if results is None:
+            resp.status = falcon.HTTP_404
+        else:
+            resp.media = {"results": results, "count": len(results)}
+
+
+class ANNBatchResource:
+    def on_get(self, req: falcon.Request, resp: falcon.Response):
+        index_name = req.get_param("index", default=settings.DEFAULT_INDEX)
+        count = req.get_param_as_int("count", min_value=1, max_value=500, default=100)
+        logo_ids = req.get_param_as_list(
+            "logo_ids", required=True, transform=int, default=[]
+        )
+        if index_name not in INDEXES:
+            raise falcon.HTTPBadRequest("unknown index: {}".format(index_name))
+
+        ann_index = INDEXES[index_name]
+        results = {}
+
+        for logo_id in logo_ids:
+            logo_results = get_nearest_neighbors(ann_index, count, logo_id)
+
+            if logo_results is not None:
+                results[logo_id] = logo_results
+
+        resp.media = {
+            "results": results,
+            "count": len(results),
+        }
+
+
+def get_nearest_neighbors(
+    ann_index: ANNIndex, count: int, logo_id: int
+) -> Optional[List[Dict[str, Any]]]:
+    if logo_id in ann_index.key_to_ann_id:
+        item_index = ann_index.key_to_ann_id[logo_id]
+        indexes, distances = ann_index.index.get_nns_by_item(
+            item_index, count, include_distances=True
+        )
+    else:
+        embedding = get_embedding(logo_id)
+
+        if embedding is None:
+            return None
+
+        indexes, distances = ann_index.index.get_nns_by_vector(
+            embedding, count, include_distances=True
+        )
+
+    logo_ids = [ann_index.keys[index] for index in indexes]
+    results = []
+
+    for ann_logo_id, distance in zip(logo_ids, distances):
+        results.append({"distance": distance, "logo_id": ann_logo_id})
+
+    return results
+
+
+class ANNEmbeddingResource:
+    def on_post(self, req: falcon.Request, resp: falcon.Response):
+        index_name = req.get_param("index", default=settings.DEFAULT_INDEX)
+
+        if index_name not in INDEXES:
+            raise falcon.HTTPBadRequest("unknown index: {}".format(index_name))
+
+        ann_index = INDEXES[index_name]
+
+        count = req.media.get("count", 1)
+        embedding = req.media["embedding"]
+
+        if len(embedding) != settings.INDEX_DIM[index_name]:
+            raise falcon.HTTPBadRequest(
+                "invalid dimension",
+                "embedding must be of size {}, here: {}".format(
+                    settings.INDEX_DIM[index_name], len(embedding)
+                ),
+            )
+
+        indexes, distances = ann_index.index.get_nns_by_vector(
+            embedding, count, include_distances=True
+        )
+
+        logo_ids = [ann_index.keys[index] for index in indexes]
+        results = []
+
+        for ann_logo_id, distance in zip(logo_ids, distances):
+            results.append({"distance": distance, "logo_id": ann_logo_id})
+
+        resp.media = {"results": results, "count": len(results)}
+
+
+class AddLogoResource:
+    @jsonschema.validate(schema.ADD_LOGO_SCHEMA)
+    def on_post(self, req: falcon.Request, resp: falcon.Response):
+        image_url = req.media["image_url"]
+        logos = req.media["logos"]
+        logo_ids = [logo["id"] for logo in logos]
+
+        if all(logo_id in EMBEDDING_STORE for logo_id in logo_ids):
+            resp.media = {
+                "added": 0,
+            }
+            return
+
+        bounding_boxes = [logo["bounding_box"] for logo in logos]
+
+        image = get_image_from_url(image_url)
+
+        if image is None:
+            raise falcon.HTTPBadRequest("invalid image")
+
+        if np.array(image).shape[-1] != 3:
+            image = image.convert("RGB")
+
+        added = add_logos(image, logo_ids, bounding_boxes)
+        resp.media = {
+            "added": added,
+        }
+
+
+class ANNCountResource:
+    def on_get(self, req: falcon.Request, resp: falcon.Response):
+        resp.media = {"count": len(EMBEDDING_STORE)}
+
+
+class ANNStoredLogoResource:
+    def on_get(self, req: falcon.Request, resp: falcon.Response):
+        resp.media = {"stored": list(EMBEDDING_STORE.get_logo_ids())}
+
+
+cors = CORS(
+    allow_all_origins=True,
+    allow_all_headers=True,
+    allow_all_methods=True,
+    allow_credentials_all_origins=True,
+    max_age=600,
+)
+
+api = falcon.API(middleware=[cors.middleware, MultipartMiddleware()])
+
+# Parse form parameters
+api.req_options.auto_parse_form_urlencoded = True
+api.req_options.strip_url_path_trailing_slash = True
+api.req_options.auto_parse_qs_csv = True
+api.add_route("/api/v1/ann/{logo_id:int}", ANNResource())
+api.add_route("/api/v1/ann", ANNResource())
+api.add_route("/api/v1/ann/batch", ANNBatchResource())
+api.add_route("/api/v1/ann/from_embedding", ANNEmbeddingResource())
+api.add_route("/api/v1/ann/add", AddLogoResource())
+api.add_route("/api/v1/ann/count", ANNCountResource())
+api.add_route("/api/v1/ann/stored", ANNStoredLogoResource())
diff --git a/embeddings.py b/embeddings.py
new file mode 100644
index 0000000..34d9715
--- /dev/null
+++ b/embeddings.py
@@ -0,0 +1,191 @@
+import operator
+import pathlib
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import h5py
+import numpy as np
+import settings
+import torch
+from efficientnet_pytorch import EfficientNet
+from PIL import Image
+
+
+class EmbeddingStore:
+    def __init__(self, hdf5_path: pathlib.Path):
+        self.hdf5_path = hdf5_path
+        self.logo_id_to_idx: Dict[int, int] = self.load()
+        self.offset = (
+            max(self.logo_id_to_idx.values()) + 1 if self.logo_id_to_idx else 0
+        )
+
+    def __len__(self):
+        return len(self.logo_id_to_idx)
+
+    def __contains__(self, logo_id: int) -> bool:
+        return self.get_index(logo_id) is not None
+
+    def get_logo_ids(self) -> Iterable[int]:
+        return self.logo_id_to_idx.keys()
+
+    def get_index(self, logo_id: int) -> Optional[int]:
+        return self.logo_id_to_idx.get(logo_id)
+
+    def get_embedding(self, logo_id: int) -> Optional[np.ndarray]:
+        idx = self.get_index(logo_id)
+
+        if idx is None:
+            return None
+
+        if self.hdf5_path.is_file():
+            with h5py.File(self.hdf5_path, "r") as f:
+                embedding_dset = f["embedding"]
+                return embedding_dset[idx]
+
+        return None
+
+    def load(self):
+        if self.hdf5_path.is_file():
+            with h5py.File(self.hdf5_path, "r") as f:
+                external_id_dset = f["external_id"]
+                array = external_id_dset[:]
+                non_zero_indexes = np.flatnonzero(array)
+                array = array[: non_zero_indexes[-1] + 1]
+                return {int(x): i for i, x in enumerate(array)}
+
+        return {}
+
+    def iter_embeddings(self) -> Iterable[Tuple[int, np.ndarray]]:
+        if not self.hdf5_path.is_file():
+            return
+
+        idx_logo_id = sorted(
+            ((idx, logo_id) for logo_id, idx in self.logo_id_to_idx.items()),
+            key=operator.itemgetter(0),
+        )
+
+        with h5py.File(self.hdf5_path, "r") as f:
+            embedding_dset = f["embedding"]
+            for idx, logo_id in idx_logo_id:
+                embedding = embedding_dset[idx]
+                yield logo_id, embedding
+
+    def save_embeddings(
+        self,
+        embeddings: np.ndarray,
+        external_ids: np.ndarray,
+    ):
+        file_exists = self.hdf5_path.is_file()
+
+        with h5py.File(self.hdf5_path, "a") as f:
+            if not file_exists:
+                embedding_dset = f.create_dataset(
+                    "embedding",
+                    (settings.DEFAULT_HDF5_COUNT, embeddings.shape[-1]),
+                    dtype="f",
+                    chunks=True,
+                )
+                external_id_dset = f.create_dataset(
+                    "external_id",
+                    (settings.DEFAULT_HDF5_COUNT,),
+                    dtype="i",
+                    chunks=True,
+                )
+            else:
+                embedding_dset = f["embedding"]
+                external_id_dset = f["external_id"]
+
+            slicing = slice(self.offset, self.offset + len(embeddings))
+            embedding_dset[slicing] = embeddings
+            external_id_dset[slicing] = external_ids
+
+        for external_id, idx in zip(
+            external_ids, range(self.offset, self.offset + len(embeddings))
+        ):
+            self.logo_id_to_idx[int(external_id)] = idx
+
+        self.offset += len(embeddings)
+
+
+EMBEDDING_STORE = EmbeddingStore(settings.EMBEDDINGS_HDF5_PATH)
+
+
+def build_model(model_type: str):
+    return EfficientNet.from_pretrained(model_type)
+
+
+def generate_embeddings(model, images: np.ndarray, device: torch.device) -> np.ndarray:
+    images = np.moveaxis(images, -1, 1)  # move channel dim to 1st dim
+
+    with torch.no_grad():
+        torch_images = torch.tensor(images, dtype=torch.float32, device=device)
+        embeddings = model.extract_features(torch_images).cpu().numpy()
+
+    return np.max(embeddings, (-1, -2))
+
+
+def crop_image(
+    image: Image.Image, bounding_box: Tuple[float, float, float, float]
+) -> Image.Image:
+    y_min, x_min, y_max, x_max = bounding_box
+    (left, right, top, bottom) = (
+        x_min * image.width,
+        x_max * image.width,
+        y_min * image.height,
+        y_max * image.height,
+    )
+    return image.crop((left, top, right, bottom))
+
+
+def get_embedding(logo_id: int) -> Optional[np.ndarray]:
+    return EMBEDDING_STORE.get_embedding(logo_id)
+
+
+def add_logos(
+    image: Image.Image,
+    external_ids: List[int],
+    bounding_boxes: List[Tuple[float, float, float, float]],
+    device: Optional[torch.device] = None,
+) -> int:
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    model = ModelStore.get(settings.DEFAULT_MODEL, device)
+    image_dim = settings.IMAGE_INPUT_DIM[settings.DEFAULT_MODEL]
+
+    selected_external_ids = []
+    selected_bounding_boxes = []
+
+    for (bounding_box, external_id) in zip(bounding_boxes, external_ids):
+        if external_id in EMBEDDING_STORE:
+            continue
+
+        selected_external_ids.append(external_id)
+        selected_bounding_boxes.append(bounding_box)
+
+    if not selected_bounding_boxes:
+        return 0
+
+    images = np.zeros((len(selected_bounding_boxes), image_dim, image_dim, 3))
+    for i, bounding_box in enumerate(selected_bounding_boxes):
+        cropped_image = crop_image(image, bounding_box)
+        cropped_image = cropped_image.resize((image_dim, image_dim))
+        images[i] = np.array(cropped_image)
+
+    embeddings = generate_embeddings(model, images, device)
+    EMBEDDING_STORE.save_embeddings(
+        embeddings, np.array(selected_external_ids, dtype="i")
+    )
+    return len(embeddings)
+
+
+class ModelStore:
+    store: Dict[str, Any] = {}
+
+    @classmethod
+    def get(cls, model_name: str, device: torch.device):
+        if model_name not in cls.store:
+            model = build_model(model_name)
+            model = model.to(device)
+            cls.store[model_name] = model
+
+        return cls.store[model_name]
diff --git a/gunicorn_conf.py b/gunicorn_conf.py
new file mode 100644
index 0000000..a2f2287
--- /dev/null
+++ b/gunicorn_conf.py
@@ -0,0 +1,3 @@
+bind = ":5501"
+workers = 1
+timeout = 60
diff --git a/manage.py b/manage.py
new file mode 100644
index 0000000..e8e20fd
--- /dev/null
+++ b/manage.py
@@ -0,0 +1,62 @@
+if __name__ == "__main__":
+    import pathlib
+
+    import click
+
+    @click.group()
+    def cli():
+        pass
+
+    @click.command()
+    @click.argument("output", type=pathlib.Path)
+    @click.option("--tree-count", type=int, default=100)
+    def generate_index(output: pathlib.Path, tree_count: int):
+        import shutil
+        import tempfile
+
+        import settings
+        import tqdm
+        from annoy import AnnoyIndex
+        from embeddings import EmbeddingStore
+        from utils import get_logger
+
+        logger = get_logger()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            embedding_path = pathlib.Path(tmp_dir) / "embeddings.hdf5"
+            logger.info(f"Copying embedding file to {embedding_path}...")
+            shutil.copy(str(settings.EMBEDDINGS_HDF5_PATH), str(embedding_path))
+
+            logger.info(f"Loading {embedding_path}...")
+            embedding_store = EmbeddingStore(embedding_path)
+
+            index = None
+            offset: int = 0
+            keys = []
+
+            logger.info("Adding embeddings to index...")
+            for logo_id, embedding in tqdm.tqdm(embedding_store.iter_embeddings()):
+                if index is None:
+                    output_dim = embedding.shape[-1]
+                    index = AnnoyIndex(output_dim, "euclidean")
+
+                index.add_item(offset, embedding)
+                keys.append(int(logo_id))
+                offset += 1
+
+            logger.info("Building index...")
+            if index is not None:
+                index.build(tree_count)
+                index.save(str(output))
+
+            logger.info("Index built.")
+            logger.info("Saving keys...")
+
+            with output.with_suffix(".txt").open("w") as f:
+                for key in keys:
+                    f.write(str(key) + "\n")
+
+            logger.info("Keys saved.")
+
+    cli.add_command(generate_index)
+    cli()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..73002e4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+annoy==1.16.3
+gunicorn==20.0.4
+falcon==2.0.0
+falcon-cors==1.1.7
+falcon-multipart==0.2.0
+sentry-sdk[falcon]==0.14.4
+efficientnet_pytorch==0.6.3
+torch==1.5.0
+h5py==2.10.0
+Pillow==7.1.2
+requests==2.23.0
+jsonschema==3.2.0
+click==7.1.2
+tqdm==4.47.0
\ No newline at end of file
diff --git a/schema.py b/schema.py
new file mode 100644
index 0000000..6d4d8b4
--- /dev/null
+++ b/schema.py
@@ -0,0 +1,27 @@
+from typing import Any, Dict
+
+ADD_LOGO_SCHEMA: Dict[str, Any] = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "title": "Add Logo",
+    "type": "object",
+    "properties": {
+        "logos": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "id": {"type": "integer"},
+                    "bounding_box": {
+                        "type": "array",
+                        "minItems": 4,
+                        "maxItems": 4,
+                        "items": {"type": "number"},
+                    },
+                },
+                "required": ["id", "bounding_box"],
+            },
+        },
+        "image_url": {"type": "string", "format": "uri"},
+    },
+    "required": ["image_url", "logos"],
+}
diff --git a/settings.py b/settings.py
new file mode 100644
index 0000000..6ffe8d1
--- /dev/null
+++ b/settings.py
@@ -0,0 +1,43 @@
+import os
+import pathlib
+from typing import Dict, Sequence
+
+import sentry_sdk
+from sentry_sdk.integrations import Integration
+
+PROJECT_DIR = pathlib.Path(__file__).parent
+DATA_DIR = PROJECT_DIR / "data"
+
+INDEX_DIM: Dict[str, int] = {"efficientnet-b0": 1280, "efficientnet-b5": 2048}
+IMAGE_INPUT_DIM: Dict[str, int] = {"efficientnet-b0": 224}
+
+INDEX_FILE_NAME = "index.bin"
+KEYS_FILE_NAME = "index.txt"
+
+DEFAULT_INDEX = "efficientnet-b0"
+DEFAULT_MODEL = "efficientnet-b0"
+DEFAULT_HDF5_COUNT = 10000000
+EMBEDDINGS_HDF5_PATH = DATA_DIR / "efficientnet-b0.hdf5"
+
+# Should be either 'prod' or 'dev'.
+_ann_instance = os.environ.get("ANN_INSTANCE", "dev")
+
+if _ann_instance not in ("prod", "dev"):
+    raise ValueError(
+        "ANN_INSTANCE should be either 'prod' or 'dev', got %s" % _ann_instance
+    )
+
+_sentry_dsn = os.environ.get("SENTRY_DSN")
+
+
+def init_sentry(integrations: Sequence[Integration] = ()):
+    if _sentry_dsn:
+        sentry_sdk.init(
+            _sentry_dsn,
+            environment=_ann_instance,
+            integrations=integrations,
+        )
+    else:
+        raise ValueError(
+            "init_sentry was requested, yet SENTRY_DSN env variable was not provided"
+        )
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..cf4d323
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,128 @@
+import gzip
+import json
+import logging
+import os
+import pathlib
+import sys
+from io import BytesIO
+from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+
+import requests
+from PIL import Image
+
+
+def get_logger(name=None, level: Optional[int] = None):
+    logger = logging.getLogger(name)
+
+    if level is None:
+        log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
+        level = logging.getLevelName(log_level)
+
+        if not isinstance(level, int):
+            print(
+                "Unknown log level: {}, fallback to INFO".format(log_level),
+                file=sys.stderr,
+            )
+            level = logging.INFO
+
+    logger.setLevel(level)
+
+    if name is None:
+        configure_root_logger(logger, level)
+
+    return logger
+
+
+def configure_root_logger(logger, level: int = logging.INFO):
+    logger.setLevel(level)
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter(
+        "%(asctime)s :: %(processName)s :: "
+        "%(threadName)s :: %(levelname)s :: "
+        "%(message)s"
+    )
+    handler.setFormatter(formatter)
+    handler.setLevel(level)
+    logger.addHandler(handler)
+
+
+def jsonl_iter(jsonl_path: Union[str, pathlib.Path]) -> Iterable[Dict]:
+    open_fn = get_open_fn(jsonl_path)
+
+    with open_fn(str(jsonl_path), "rt", encoding="utf-8") as f:
+        yield from jsonl_iter_fp(f)
+
+
+def gzip_jsonl_iter(jsonl_path: Union[str, pathlib.Path]) -> Iterable[Dict]:
+    with gzip.open(jsonl_path, "rt", encoding="utf-8") as f:
+        yield from jsonl_iter_fp(f)
+
+
+def jsonl_iter_fp(fp) -> Iterable[Dict]:
+    for line in fp:
+        line = line.strip("\n")
+        if line:
+            yield json.loads(line)
+
+
+def dump_jsonl(filepath: Union[str, pathlib.Path], json_iter: Iterable[Dict]) -> int:
+    count = 0
+    open_fn = get_open_fn(filepath)
+
+    with open_fn(str(filepath), "wt") as f:
+        for item in json_iter:
+            f.write(json.dumps(item) + "\n")
+            count += 1
+
+    return count
+
+
+def get_open_fn(filepath: Union[str, pathlib.Path]) -> Callable:
+    filepath = str(filepath)
+    if filepath.endswith(".gz"):
+        return gzip.open
+    else:
+        return open
+
+
+def text_file_iter(filepath: Union[str, pathlib.Path]) -> Iterable[str]:
+    open_fn = get_open_fn(filepath)
+
+    with open_fn(str(filepath), "rt") as f:
+        for item in f:
+            item = item.strip("\n")
+
+            if item:
+                yield item
+
+
+def crop_image(
+    image: Image.Image, bounding_box: Tuple[float, float, float, float]
+) -> Image.Image:
+    y_min, x_min, y_max, x_max = bounding_box
+    (left, right, top, bottom) = (
+        x_min * image.width,
+        x_max * image.width,
+        y_min * image.height,
+        y_max * image.height,
+    )
+    return image.crop((left, top, right, bottom))
+
+
+def get_image_from_url(
+    image_url: str,
+    error_raise: bool = False,
+    session: Optional[requests.Session] = None,
+) -> Optional[Image.Image]:
+    if session:
+        r = session.get(image_url)
+    else:
+        r = requests.get(image_url)
+
+    if error_raise:
+        r.raise_for_status()
+
+    if r.status_code != 200:
+        return None
+
+    return Image.open(BytesIO(r.content))
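
A minimal usage sketch, not part of the diff itself: assuming the service is started via the gunicorn entrypoint above (bound to port 5501, reached here on localhost) and at least one index directory exists under data/, the endpoints can be exercised with the requests library already pinned in requirements.txt. The base URL, image URL, and logo ids below are placeholders.

import requests

BASE_URL = "http://localhost:5501/api/v1/ann"  # assumption: local deployment on the gunicorn bind port

# Store embeddings for logo crops of an image; bounding boxes are
# normalized (y_min, x_min, y_max, x_max), matching crop_image().
payload = {
    "image_url": "https://example.org/product.jpg",  # placeholder URL
    "logos": [{"id": 42, "bounding_box": [0.1, 0.1, 0.9, 0.9]}],
}
print(requests.post(f"{BASE_URL}/add", json=payload).json())  # {"added": 1} on first call, if the image fetch succeeds

# 5 nearest neighbours of logo 42 in the default index (efficientnet-b0);
# a logo absent from both the index and the embedding store yields a 404.
print(requests.get(f"{BASE_URL}/42", params={"count": 5}).json())

# Batch variant; auto_parse_qs_csv splits the comma-separated ids.
print(requests.get(f"{BASE_URL}/batch", params={"logo_ids": "42,43"}).json())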