Skip to content

Commit

Permalink
Add support for RapidOCR class init to specify the number of threads …
Browse files Browse the repository at this point in the history
…in rapidocr_paddle
  • Loading branch information
SWHL committed Mar 7, 2024
1 parent 26438c2 commit e9f96b8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 15 deletions.
3 changes: 2 additions & 1 deletion python/rapidocr_openvino/ch_ppocr_v3_rec/text_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import cv2
import numpy as np

from rapidocr_openvino.utils import OpenVINOInferSession, read_yaml

from .utils import CTCLabelDecode
Expand All @@ -33,7 +34,7 @@ def __init__(self, config):
self.rec_batch_num = config["rec_batch_num"]

dict_path = str(Path(__file__).parent / "ppocr_keys_v1.txt")
self.character_dict_path = config.get("keys_path", dict_path)
self.character_dict_path = config.get("rec_keys_path", dict_path)
self.postprocess_op = CTCLabelDecode(self.character_dict_path)

self.infer = OpenVINOInferSession(config)
Expand Down
2 changes: 1 addition & 1 deletion python/rapidocr_paddle/ch_ppocr_v3_rec/text_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, config):
self.session = PaddleInferSession(config, mode="rec")

dict_path = str(Path(__file__).parent / "ppocr_keys_v1.txt")
self.character_dict_path = config.get("keys_path", dict_path)
self.character_dict_path = config.get("rec_keys_path", dict_path)
self.postprocess_op = CTCLabelDecode(self.character_dict_path)

self.rec_batch_num = config["rec_batch_num"]
Expand Down
22 changes: 9 additions & 13 deletions python/rapidocr_paddle/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@ Global:
min_height: 30
width_height_ratio: 8

use_cuda: &use_cuda false
gpu_id: &gpu_id 0
gpu_mem: &gpu_mem 500

cpu_math_library_num_threads: &infer_num_threads -1

Det:
use_cuda: *use_cuda
gpu_id: *gpu_id
gpu_mem: *gpu_mem
use_cuda: false
gpu_id: 0
gpu_mem: 500

cpu_math_library_num_threads: *infer_num_threads

Expand All @@ -33,9 +29,9 @@ Det:
score_mode: fast

Cls:
use_cuda: *use_cuda
gpu_id: *gpu_id
gpu_mem: *gpu_mem
use_cuda: false
gpu_id: 0
gpu_mem: 500

cpu_math_library_num_threads: *infer_num_threads

Expand All @@ -47,9 +43,9 @@ Cls:
label_list: ['0', '180']

Rec:
use_cuda: *use_cuda
gpu_id: *gpu_id
gpu_mem: *gpu_mem
use_cuda: false
gpu_id: 0
gpu_mem: 500

cpu_math_library_num_threads: *infer_num_threads

Expand Down
16 changes: 16 additions & 0 deletions python/rapidocr_paddle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def init_args():
global_group.add_argument("--min_height", type=int, default=30)
global_group.add_argument("--width_height_ratio", type=int, default=8)

global_group.add_argument("--cpu_math_library_num_threads", type=int, default=-1)

det_group = parser.add_argument_group(title="Det")
det_group.add_argument("--det_use_cuda", action="store_true", default=False)
det_group.add_argument("--det_gpu_id", type=int, default=0)
Expand Down Expand Up @@ -299,6 +301,7 @@ def init_args():
rec_group.add_argument("--rec_gpu_id", type=int, default=0)
rec_group.add_argument("--rec_gpu_mem", type=int, default=500)
rec_group.add_argument("--rec_model_path", type=str, default=None)
rec_group.add_argument("--rec_keys_path", type=str, default=None)
rec_group.add_argument("--rec_img_shape", type=list, default=[3, 48, 320])
rec_group.add_argument("--rec_batch_num", type=int, default=6)

Expand Down Expand Up @@ -358,8 +361,21 @@ def __call__(self, config, **kwargs):
config["Rec"], rec_dict, "rec_", ["rec_model_path", "rec_use_cuda"]
),
}

update_params = ["cpu_math_library_num_threads"]
new_config = self.update_global_to_module(
config, update_params, src="Global", dsts=["Det", "Cls", "Rec"]
)
return new_config

def update_global_to_module(
self, config, params: List[str], src: str, dsts: List[str]
):
for dst in dsts:
for param in params:
config[dst].update({param: config[src][param]})
return config

def update_global_params(self, config, global_dict):
if global_dict:
config.update(global_dict)
Expand Down

0 comments on commit e9f96b8

Please sign in to comment.