-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Project] Add pp_mobileseg onnx inference demo (#3268)
## Motivation Add a model deployment example. ## Modification Add an inference script and update the README. ## BC-breaking (Optional) None ## Use cases (Optional) In README.
- Loading branch information
Showing
2 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
import argparse | ||
import time | ||
from typing import List, Tuple | ||
|
||
import cv2 | ||
import loguru | ||
import numpy as np | ||
import onnxruntime as ort | ||
|
||
logger = loguru.logger | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='PP_Mobileseg ONNX inference demo.') | ||
parser.add_argument('onnx_file', help='ONNX file path') | ||
parser.add_argument('image_file', help='Input image file path') | ||
parser.add_argument( | ||
'--input-size', | ||
type=int, | ||
nargs='+', | ||
default=[512, 512], | ||
help='input image size') | ||
parser.add_argument( | ||
'--device', help='device type for inference', default='cpu') | ||
parser.add_argument( | ||
'--save-path', | ||
help='path to save the output image', | ||
default='output.jpg') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def preprocess( | ||
img: np.ndarray, input_size: Tuple[int, int] = (512, 512) | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
"""Preprocess image for inference.""" | ||
img_shape = img.shape[:2] | ||
# Resize | ||
resized_img = cv2.resize(img, input_size) | ||
|
||
# Normalize | ||
mean = np.array([123.575, 116.28, 103.53], dtype=np.float32) | ||
std = np.array([58.395, 57.12, 57.375], dtype=np.float32) | ||
resized_img = (resized_img - mean) / std | ||
|
||
return resized_img, img_shape | ||
|
||
|
||
def build_session(onnx_file: str, device: str = 'cpu') -> ort.InferenceSession: | ||
"""Build onnxruntime session. | ||
Args: | ||
onnx_file (str): ONNX file path. | ||
device (str): Device type for inference. | ||
Returns: | ||
sess (ort.InferenceSession): ONNXRuntime session. | ||
""" | ||
providers = ['CPUExecutionProvider' | ||
] if device == 'cpu' else ['CUDAExecutionProvider'] | ||
sess = ort.InferenceSession(path_or_bytes=onnx_file, providers=providers) | ||
|
||
return sess | ||
|
||
|
||
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: | ||
"""Inference RTMPose model. | ||
Args: | ||
sess (ort.InferenceSession): ONNXRuntime session. | ||
img (np.ndarray): Input image in shape. | ||
Returns: | ||
outputs (np.ndarray): Output of RTMPose model. | ||
""" | ||
# build input | ||
input_img = [img.transpose(2, 0, 1).astype(np.float32)] | ||
|
||
# build output | ||
sess_input = {sess.get_inputs()[0].name: input_img} | ||
sess_output = [] | ||
for out in sess.get_outputs(): | ||
sess_output.append(out.name) | ||
|
||
# inference | ||
outputs = sess.run(output_names=sess_output, input_feed=sess_input) | ||
|
||
return outputs | ||
|
||
|
||
def postprocess(outputs: List[np.ndarray], | ||
origin_shape: Tuple[int, int]) -> np.ndarray: | ||
"""Postprocess outputs of PP_Mobileseg model. | ||
Args: | ||
outputs (List[np.ndarray]): Outputs of PP_Mobileseg model. | ||
origin_shape (Tuple[int, int]): Input size of PP_Mobileseg model. | ||
Returns: | ||
seg_map (np.ndarray): Segmentation map. | ||
""" | ||
seg_map = outputs[0][0][0] | ||
seg_map = cv2.resize(seg_map.astype(np.float32), origin_shape) | ||
return seg_map | ||
|
||
|
||
def visualize(img: np.ndarray, | ||
seg_map: np.ndarray, | ||
filename: str = 'output.jpg', | ||
opacity: float = 0.8) -> np.ndarray: | ||
assert 0.0 <= opacity <= 1.0, 'opacity should be in range [0, 1]' | ||
palette = np.array(PALETTE) | ||
color_seg = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), | ||
dtype=np.uint8) | ||
for label, color in enumerate(palette): | ||
color_seg[seg_map == label, :] = color | ||
# convert to BGR | ||
color_seg = color_seg[..., ::-1] | ||
|
||
img = img * (1 - opacity) + color_seg * opacity | ||
cv2.imwrite(filename, img) | ||
|
||
return img | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
logger.info('Start running model inference...') | ||
|
||
# read image from file | ||
logger.info(f'1. Read image from file {args.image_file}...') | ||
img = cv2.imread(args.image_file) | ||
|
||
# build onnx model | ||
logger.info(f'2. Build onnx model from {args.onnx_file}...') | ||
sess = build_session(args.onnx_file, args.device) | ||
|
||
# preprocess | ||
logger.info('3. Preprocess image...') | ||
model_input_size = tuple(args.input_size) | ||
assert len(model_input_size) == 2 | ||
resized_img, origin_shape = preprocess(img, model_input_size) | ||
|
||
# inference | ||
logger.info('4. Inference...') | ||
start = time.time() | ||
outputs = inference(sess, resized_img) | ||
logger.info(f'Inference time: {time.time() - start:.4f}s') | ||
|
||
# postprocess | ||
logger.info('5. Postprocess...') | ||
h, w = origin_shape | ||
seg_map = postprocess(outputs, (w, h)) | ||
|
||
# visualize | ||
logger.info('6. Visualize...') | ||
visualize(img, seg_map, args.save_path) | ||
|
||
logger.info('Done...') | ||
|
||
|
||
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], | ||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], | ||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], | ||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], | ||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], | ||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], | ||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], | ||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], | ||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], | ||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], | ||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], | ||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], | ||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], | ||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], | ||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], | ||
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], | ||
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], | ||
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], | ||
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], | ||
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], | ||
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], | ||
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], | ||
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], | ||
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], | ||
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], | ||
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], | ||
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], | ||
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], | ||
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], | ||
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], | ||
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], | ||
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], | ||
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], | ||
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], | ||
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], | ||
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], | ||
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], | ||
[102, 255, 0], [92, 0, 255]] | ||
|
||
if __name__ == '__main__': | ||
main() |