Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added yolov8engine.py for TensorRT-quantized YOLOv8 models #1046

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions sahi/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"yolov5sparse": "Yolov5SparseDetectionModel",
"yolonas": "YoloNasDetectionModel",
"yolov8onnx": "Yolov8OnnxDetectionModel",
"yolov8engine": "Yolov8EngineDetectionModel"
}


Expand Down
181 changes: 181 additions & 0 deletions sahi/models/yolov8engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import json
import logging
from typing import Any, Dict, List, Optional
import numpy as np

logger = logging.getLogger(__name__)

from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import check_requirements

# This model requires a `config_path` pointing to a JSON file with the fields shown below.
# The field values must be taken from the base PyTorch model before it is quantized/exported.
# {
# "task": "segment", # or 'detect'
# "names": [
# "dog",
# "cat",
# ],
# "imgsz": [
# 640,
# 640
# ],
# "half": true
# }

class Yolov8EngineDetectionModel(DetectionModel):
    """SAHI detection model wrapper for TensorRT-exported YOLOv8 models.

    An exported engine does not carry the metadata this wrapper needs, so a
    JSON config file must be supplied through ``config_path``. The following
    keys are read by this class:

    * ``task``  -- passed to ``ultralytics.YOLO(..., task=...)``.
    * ``names`` -- list of category names, indexed by category id.
    * ``imgsz`` -- inference image size passed to ``predict``.
    * ``half``  -- whether to run ``predict`` with ``half=True``.

    Only bounding boxes are produced (``has_mask`` is always ``False``), even
    when ``task`` is ``"segment"``.
    """

    # Keys that must be present in the JSON config; every one is read below.
    _REQUIRED_CFG_KEYS = ("task", "names", "imgsz", "half")

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``DetectionModel`` and load the config.

        The config cannot be read before the base constructor runs because
        ``config_path`` is not yet stored on the instance at that point, so it
        is loaded lazily by ``_ensure_config`` the first time it is needed
        (which may already be during ``super().__init__`` if the base class
        loads the model eagerly).

        Raises:
            TypeError: If ``config_path`` is missing, unreadable, not valid
                JSON, or lacks a required key.
        """
        # Must exist before the base constructor runs: load_model()/set_model()
        # may be invoked from there and consult it.
        self.cfg: Optional[Dict[str, Any]] = None
        super().__init__(*args, **kwargs)
        # Fail fast on a bad config even when the model itself is not loaded
        # at construction time.
        self._ensure_config()

    def _ensure_config(self) -> Dict[str, Any]:
        """Return the parsed JSON config, reading ``self.config_path`` on first use.

        Raises:
            TypeError: If the config file cannot be opened/parsed or a
                required key is missing.
        """
        cached = getattr(self, "cfg", None)
        if cached is not None:
            return cached

        # NOTE(review): relies on the base class storing `config_path` on the
        # instance before load_model()/set_model() are called -- confirm
        # against sahi.models.base.DetectionModel.
        config_path = getattr(self, "config_path", None)
        try:
            with open(config_path, "r") as file:
                cfg = json.load(file)
        except (OSError, TypeError, ValueError) as e:
            # TypeError: config_path is None / wrong type; ValueError: bad JSON.
            raise TypeError("config_path is not a valid yolov8engine config path") from e

        if not isinstance(cfg, dict):
            raise TypeError("config_path is not a valid yolov8engine config path")
        missing = [key for key in self._REQUIRED_CFG_KEYS if key not in cfg]
        if missing:
            raise TypeError(f"yolov8engine config is missing required field(s): {missing}")

        self.cfg = cfg
        return cfg

    def check_dependencies(self) -> None:
        """Verify that the ``ultralytics`` package is available."""
        check_requirements(["ultralytics"])

    def load_model(self):
        """
        Detection model is initialized and set to self.model.

        Raises:
            TypeError: If the config is invalid or the model at
                ``self.model_path`` cannot be loaded.
        """

        from ultralytics import YOLO

        # Resolve the config outside the try-block so a config problem is not
        # misreported as a model-path problem.
        cfg = self._ensure_config()

        try:
            model = YOLO(self.model_path, task=cfg["task"])
        except Exception as e:
            raise TypeError("model_path is not a valid yolov8engine model path: ", e) from e

        self.set_model(model)

    def set_model(self, model: Any):
        """
        Sets the underlying YOLOv8 model.
        Args:
            model: Any
                A YOLOv8 model
        """

        self.model = model

        # Build the default {"<id>": "<name>"} mapping from the config names
        # unless the caller provided one.
        if not self.category_mapping:
            self.category_mapping = {str(ind): name for ind, name in enumerate(self.category_names)}

    def perform_inference(self, image: np.ndarray):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
        Raises:
            ValueError: If no model has been loaded.
        """

        # Confirm model is loaded
        if self.model is None:
            raise ValueError("Model is not loaded, load it by calling .load_model()")

        cfg = self._ensure_config()

        # An explicitly configured image_size takes precedence over the size
        # recorded in the config. (Previously the override was built as
        # {"imgsz": image_size, **kwargs}, which let the config value win.)
        imgsz = self.image_size if self.image_size is not None else cfg["imgsz"]

        kwargs = {
            "cfg": None,
            "verbose": False,
            "conf": self.confidence_threshold,
            "device": self.device,
            "imgsz": imgsz,
            "half": cfg["half"],
        }

        # YOLOv8 expects numpy arrays to have BGR, so reverse the channel axis.
        prediction_result = self.model.predict(image[:, :, ::-1], **kwargs)

        # We do not filter results again as confidence threshold is already applied above
        self._original_predictions = [result.boxes.data for result in prediction_result]

    @property
    def category_names(self):
        """
        Returns the list of category names from the config
        """
        return self._ensure_config()["names"]

    @property
    def num_categories(self):
        """
        Returns number of categories
        """
        return len(self._ensure_config()["names"])

    @property
    def has_mask(self):
        """
        Returns if model output contains segmentation mask
        """
        return False

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        # (the default list above is never mutated in this method)
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # handle all predictions
        object_prediction_list_per_image = []
        for image_ind, image_predictions_in_xyxy_format in enumerate(original_predictions):
            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]
            object_prediction_list = []

            # Each row is indexed as [x1, y1, x2, y2, score, category_id].
            for prediction in image_predictions_in_xyxy_format.cpu().detach().numpy():
                score = prediction[4]
                category_id = int(prediction[5])
                category_name = self.category_mapping[str(category_id)]

                # fix negative box coords
                bbox = [max(0, coord) for coord in prediction[:4]]

                # fix out of image box coords (full_shape is [height, width])
                if full_shape is not None:
                    height, width = full_shape[0], full_shape[1]
                    bbox = [
                        min(width, bbox[0]),
                        min(height, bbox[1]),
                        min(width, bbox[2]),
                        min(height, bbox[3]),
                    ]

                # ignore invalid predictions (zero or negative width/height)
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
                    continue

                object_prediction_list.append(
                    ObjectPrediction(
                        bbox=bbox,
                        category_id=category_id,
                        score=score,
                        bool_mask=None,
                        category_name=category_name,
                        shift_amount=shift_amount,
                        full_shape=full_shape,
                    )
                )
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
Loading