diff --git a/sahi/auto_model.py b/sahi/auto_model.py
index 1768a2a5..2edab669 100644
--- a/sahi/auto_model.py
+++ b/sahi/auto_model.py
@@ -13,6 +13,7 @@
     "yolov5sparse": "Yolov5SparseDetectionModel",
     "yolonas": "YoloNasDetectionModel",
     "yolov8onnx": "Yolov8OnnxDetectionModel",
+    "yolov8engine": "Yolov8EngineDetectionModel",
 }
diff --git a/sahi/models/yolov8engine.py b/sahi/models/yolov8engine.py
new file mode 100644
index 00000000..9957216a
--- /dev/null
+++ b/sahi/models/yolov8engine.py
@@ -0,0 +1,181 @@
+import json
+import logging
+from typing import Any, List, Optional
+
+import numpy as np
+
+from sahi.models.base import DetectionModel
+from sahi.prediction import ObjectPrediction
+from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
+from sahi.utils.import_utils import check_requirements
+
+logger = logging.getLogger(__name__)
+
+# Must be passed a config_path to a JSON file with the fields below.
+# The fields must be captured from the base PyTorch model before it is
+# exported/quantized, since the engine file does not carry this metadata.
+# {
+#     "task": "segment",  # or "detect"
+#     "names": [
+#         "dog",
+#         "cat"
+#     ],
+#     "imgsz": [
+#         640,
+#         640
+#     ],
+#     "half": true
+# }
+
+
+class Yolov8EngineDetectionModel(DetectionModel):
+    def __init__(self, *args, **kwargs):
+        # The config must be loaded before DetectionModel.__init__ runs,
+        # because load_model (called during init) reads self.cfg.
+        config_path = kwargs.get("config_path")
+        try:
+            with open(config_path, "r") as file:
+                self.cfg = json.load(file)
+        except Exception as e:
+            raise TypeError(f"config_path is not a valid yolov8engine config path: {e}") from e
+
+        super().__init__(*args, **kwargs)
+
+    def check_dependencies(self) -> None:
+        check_requirements(["ultralytics"])
+
+    def load_model(self):
+        """
+        Detection model is initialized and set to self.model.
+        """
+        from ultralytics import YOLO
+
+        try:
+            model = YOLO(self.model_path, task=self.cfg["task"])
+            self.set_model(model)
+        except Exception as e:
+            raise TypeError(f"model_path is not a valid yolov8engine model path: {e}") from e
+
+    def set_model(self, model: Any):
+        """
+        Sets the underlying YOLOv8 model.
+        Args:
+            model: Any
+                A YOLOv8 model
+        """
+        self.model = model
+
+        # set category_mapping
+        if not self.category_mapping:
+            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
+            self.category_mapping = category_mapping
+
+    def perform_inference(self, image: np.ndarray):
+        """
+        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
+        Args:
+            image: np.ndarray
+                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
+ """ + + # Confirm model is loaded + if self.model is None: + raise ValueError("Model is not loaded, load it by calling .load_model()") + + kwargs = {"cfg": None, "verbose": False, "conf": self.confidence_threshold, "device": self.device, "imgsz": self.cfg["imgsz"], "half": self.cfg["half"]} + + if self.image_size is not None: + kwargs = {"imgsz": self.image_size, **kwargs} + + prediction_result = self.model.predict(image[:, :, ::-1], **kwargs) # YOLOv8 expects numpy arrays to have BGR + + # We do not filter results again as confidence threshold is already applied above + prediction_result = [result.boxes.data for result in prediction_result] + + self._original_predictions = prediction_result + + @property + def category_names(self): + return self.cfg["names"] + + @property + def num_categories(self): + """ + Returns number of categories + """ + return len(self.cfg["names"]) + + @property + def has_mask(self): + """ + Returns if model output contains segmentation mask + """ + return False + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + """ + self._original_predictions is converted to a list of prediction.ObjectPrediction and set to + self._object_prediction_list_per_image. + Args: + shift_amount_list: list of list + To shift the box and mask predictions from sliced image to full sized image, should + be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...] + full_shape_list: list of list + Size of the full image after shifting, should be in the form of + List[[height, width],[height, width],...] + """ + original_predictions = self._original_predictions + + # compatilibty for sahi v0.8.15 + shift_amount_list = fix_shift_amount_list(shift_amount_list) + full_shape_list = fix_full_shape_list(full_shape_list) + + # handle all predictions + object_prediction_list_per_image = [] + for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions): + shift_amount = shift_amount_list[image_ind] + full_shape = None if full_shape_list is None else full_shape_list[image_ind] + object_prediction_list = [] + + # process predictions + for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy(): + x1 = prediction[0] + y1 = prediction[1] + x2 = prediction[2] + y2 = prediction[3] + bbox = [x1, y1, x2, y2] + score = prediction[4] + category_id = int(prediction[5]) + category_name = self.category_mapping[str(category_id)] + + # fix negative box coords + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = max(0, bbox[2]) + bbox[3] = max(0, bbox[3]) + + # fix out of image box coords + if full_shape is not None: + bbox[0] = min(full_shape[1], bbox[0]) + bbox[1] = min(full_shape[0], bbox[1]) + bbox[2] = min(full_shape[1], bbox[2]) + bbox[3] = min(full_shape[0], bbox[3]) + + # ignore invalid predictions + if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): + logger.warning(f"ignoring invalid prediction with bbox: {bbox}") + continue + + object_prediction = ObjectPrediction( + bbox=bbox, + category_id=category_id, + score=score, + bool_mask=None, + category_name=category_name, + shift_amount=shift_amount, + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + + self._object_prediction_list_per_image = object_prediction_list_per_image