Skip to content

Commit

Permalink
Optim OrtInferSession
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed May 12, 2024
1 parent 01430ed commit c5c2246
Showing 1 changed file with 50 additions and 40 deletions.
90 changes: 50 additions & 40 deletions python/rapidocr_onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
# Absolute path of the directory that contains this module.
root_dir = Path(__file__).resolve().parent
# Alias grouping the input types listed below.
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]

# Execution-provider name strings; OrtInferSession compares them against
# get_available_providers() and passes them to InferenceSession(providers=...).
CPU_EP = "CPUExecutionProvider"
CUDA_EP = "CUDAExecutionProvider"
DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
def __init__(self, config):
Expand All @@ -43,55 +47,56 @@ def __init__(self, config):
if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
sess_opt.inter_op_num_threads = inter_op_num_threads

cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
model_path = config.get("model_path", None)
self._verify_model(model_path)

self.cfg_use_cuda = config.get("use_cuda", None)
EP_list = self._get_ep_list()

self.session = InferenceSession(
model_path, sess_options=sess_opt, providers=EP_list
)

self._verify_providers()

def _get_ep_list(self) -> List[Tuple[str, dict]]:
    """Build the ordered execution-provider list passed to ``InferenceSession``.

    The CPU provider is always included and ends up last. The CUDA provider
    is inserted ahead of it when ``self.cfg_use_cuda`` is truthy,
    ``get_device()`` returns ``"GPU"`` and CUDA is among the available
    providers. When ``os.name == "nt"`` and DirectML is available, DirectML
    is inserted at the front.

    Returns:
        A list of ``(provider_name, provider_options)`` tuples, highest
        priority first.
    """
    had_providers: List[str] = get_available_providers()

    cpu_provider_opts = {
        "arena_extend_strategy": "kSameAsRequested",
    }
    EP_list = [(CPU_EP, cpu_provider_opts)]

    use_cuda = (
        self.cfg_use_cuda and get_device() == "GPU" and CUDA_EP in had_providers
    )
    cuda_provider_opts = {
        "device_id": 0,
        "arena_extend_strategy": "kNextPowerOfTwo",
        "cudnn_conv_algo_search": "EXHAUSTIVE",
        "do_copy_in_default_stream": True,
    }
    if use_cuda:
        EP_list.insert(0, (CUDA_EP, cuda_provider_opts))

    use_directml = os.name == "nt" and DIRECTML_EP in had_providers
    if use_directml:
        print("Windows platform detected, try to use DirectML as primary provider")
        # DirectML is given the CUDA option dict when CUDA is enabled,
        # otherwise the CPU option dict.
        directml_options = cuda_provider_opts if use_cuda else cpu_provider_opts
        EP_list.insert(0, (DIRECTML_EP, directml_options))
    return EP_list

def _verify_providers(self) -> None:
session_providers = self.session.get_providers()
if os.name == "nt" and session_providers[0] != DIRECTML_EP:
warnings.warn(
"DirectML is not available for the current environment, the inference part is automatically shifted to be executed under other EP.\n"
)

EP_list = []
is_use_cude = config["use_cuda"] and get_device() == "GPU" and cuda_ep in get_available_providers()
if (is_use_cude):
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))

# if platform is windows, use directml as primary provider
if os.name == "nt":
directml_ep = "DmlExecutionProvider"
# print (get_available_providers())
if directml_ep in get_available_providers():
print ("Windows platform detected, try to use DirectML as primary provider")
EP_list.insert(0, (directml_ep,
cuda_provider_options if is_use_cude else cpu_provider_options
))


self._verify_model(config["model_path"])
self.session = InferenceSession(
config["model_path"], sess_options=sess_opt, providers=EP_list
)

# TODO: verify this is correct for detecting current_provider
current_provider = self.session.get_providers()[0]

# verify if the DirectML provider is used
if os.name == "nt":
if current_provider != directml_ep:
warnings.warn(
f"DirectML is not available for the current environment, the inference part is automatically shifted to be executed under other EP.\n"
)


if config["use_cuda"] and cuda_ep not in self.session.get_providers():
if self.cfg_use_cuda and CUDA_EP not in session_providers:
warnings.warn(
f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
f"{CUDA_EP} is not avaiable for current env, the inference part is automatically shifted to be executed under {CPU_EP}.\n"
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
"you can check their relations from the offical web site: "
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
Expand Down Expand Up @@ -127,10 +132,15 @@ def have_key(self, key: str = "character") -> bool:
return False

@staticmethod
def _verify_model(model_path):
def _verify_model(model_path: Union[str, Path, None]):
if model_path is None:
raise ValueError("model_path is None!")

model_path = Path(model_path)

if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exists.")

if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")

Expand Down

0 comments on commit c5c2246

Please sign in to comment.