Skip to content

Commit

Permalink
Format by black
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jun 24, 2023
1 parent ad4631f commit f4fd5cd
Show file tree
Hide file tree
Showing 40 changed files with 918 additions and 864 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
<a href="https://pepy.tech/project/rapidocr_onnxruntime"><img src="https://static.pepy.tech/personalized-badge/rapidocr_onnxruntime?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Ort"></a>
<a href="https://pypi.org/project/rapidocr-onnxruntime/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr-onnxruntime"></a>
<a href="https://github.com/RapidAI/RapidOCR/stargazers"><img src="https://img.shields.io/github/stars/RapidAI/RapidOCR?color=ccf"></a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href='https://rapidocr.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/rapidocr/badge/?version=latest' alt='Documentation Status' />
</a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
</p>


Expand Down
29 changes: 14 additions & 15 deletions api/rapidocr_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent))


class OCRAPIUtils():
class OCRAPIUtils:
def __init__(self) -> None:
self.ocr = RapidOCR()

Expand All @@ -31,10 +31,10 @@ def __call__(self, img):
if not ocr_res:
return json.dumps({})

out_dict = {str(i): {'rec_txt': rec,
'dt_boxes': dt_box,
'score': score}
for i, (dt_box, rec, score) in enumerate(ocr_res)}
out_dict = {
str(i): {"rec_txt": rec, "dt_boxes": dt_box, "score": score}
for i, (dt_box, rec, score) in enumerate(ocr_res)
}
return out_dict


Expand All @@ -44,10 +44,10 @@ def __call__(self, img):

@app.get("/")
async def root():
return {'message': 'Welcome to RapidOCR Server!'}
return {"message": "Welcome to RapidOCR Server!"}


@app.post('/ocr')
@app.post("/ocr")
async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
if image_file:
img = Image.open(image_file.file)
Expand All @@ -57,25 +57,24 @@ async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
img = Image.open(io.BytesIO(img_b64decode))
else:
raise ValueError(
'When sending a post request, data or files must have a value.')
"When sending a post request, data or files must have a value."
)

ocr_res = processor(img)
return ocr_res


def main():
parser = argparse.ArgumentParser('rapidocr_api')
parser.add_argument('-ip', '--ip', type=str, default='0.0.0.0',
help='IP Address')
parser.add_argument('-p', '--port', type=int, default=9003,
help='IP port')
parser = argparse.ArgumentParser("rapidocr_api")
parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address")
parser.add_argument("-p", "--port", type=int, default=9003, help="IP port")
args = parser.parse_args()

cur_file_path = Path(__file__).resolve()
app_path = f'{cur_file_path.parent.name}.{cur_file_path.stem}:app'
app_path = f"{cur_file_path.parent.name}.{cur_file_path.stem}:app"
print(app_path)
uvicorn.run(app_path, host=args.ip, port=args.port, reload=True)


if __name__ == '__main__':
if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion docs/README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
<a href="https://pepy.tech/project/rapidocr_onnxruntime"><img src="https://static.pepy.tech/personalized-badge/rapidocr_onnxruntime?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Ort"></a>
<a href="https://pypi.org/project/rapidocr-onnxruntime/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr-onnxruntime"></a>
<a href="https://github.com/RapidAI/RapidOCR/stargazers"><img src="https://img.shields.io/github/stars/RapidAI/RapidOCR?color=ccf"></a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href='https://rapidocr.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/rapidocr/badge/?version=latest' alt='Documentation Status' />
</a>
<a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
</p>

<details>
Expand Down
26 changes: 12 additions & 14 deletions ocrweb/rapidocr_web/ocrweb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,34 @@

root_dir = Path(__file__).resolve().parent

app = Flask(__name__, template_folder='templates')
app.config['MAX_CONTENT_LENGTH'] = 3 * 1024 * 1024
app = Flask(__name__, template_folder="templates")
app.config["MAX_CONTENT_LENGTH"] = 3 * 1024 * 1024
processor = OCRWebUtils()


@app.route('/')
@app.route("/")
def index():
return render_template('index.html')
return render_template("index.html")


@app.route('/ocr', methods=['POST'])
@app.route("/ocr", methods=["POST"])
def ocr():
if request.method == 'POST':
img_str = request.get_json().get('file', None)
if request.method == "POST":
img_str = request.get_json().get("file", None)
ocr_res = processor(img_str)
return ocr_res


def main():
parser = argparse.ArgumentParser('rapidocr_web')
parser.add_argument('-ip', '--ip', type=str, default='0.0.0.0',
help='IP Address')
parser.add_argument('-p', '--port', type=int, default=9003,
help='IP port')
parser = argparse.ArgumentParser("rapidocr_web")
parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address")
parser.add_argument("-p", "--port", type=int, default=9003, help="IP port")
args = parser.parse_args()

print(f'Successfully launched and visit https://{args.ip}:{args.port} to view.')
print(f"Successfully launched and visit https://{args.ip}:{args.port} to view.")
server = make_server(args.ip, args.port, app)
server.serve_forever()


if __name__ == '__main__':
if __name__ == "__main__":
main()
26 changes: 13 additions & 13 deletions ocrweb_multi/build.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os
import shutil

print('Compile ocrweb')
os.system('pyinstaller -y main.spec')
print("Compile ocrweb")
os.system("pyinstaller -y main.spec")

print('Compile wrapper')
os.system('windres .\wrapper.rc -O coff -o wrapper.res')
os.system('gcc .\wrapper.c wrapper.res -o dist/ocrweb.exe')
print("Compile wrapper")
os.system("windres .\wrapper.rc -O coff -o wrapper.res")
os.system("gcc .\wrapper.c wrapper.res -o dist/ocrweb.exe")

print('Copy config.yaml')
shutil.copy2('config.yaml', 'dist/config.yaml')
print("Copy config.yaml")
shutil.copy2("config.yaml", "dist/config.yaml")

print('Copy models')
shutil.copytree('models', 'dist/models', dirs_exist_ok=True)
os.remove('dist/models/.gitkeep')
print("Copy models")
shutil.copytree("models", "dist/models", dirs_exist_ok=True)
os.remove("dist/models/.gitkeep")

print('Pack to ocrweb.zip')
shutil.make_archive('ocrweb', 'zip', 'dist')
print("Pack to ocrweb.zip")
shutil.make_archive("ocrweb", "zip", "dist")

print('Done')
print("Done")
60 changes: 33 additions & 27 deletions ocrweb_multi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,61 +13,67 @@
from utils.utils import tojson, parse_bool

app = Flask(__name__)
log = logging.getLogger('app')
log = logging.getLogger("app")
# Limit the size of uploaded files (3 MiB)
app.config['MAX_CONTENT_LENGTH'] = 3 * 1024 * 1024
app.config["MAX_CONTENT_LENGTH"] = 3 * 1024 * 1024


@app.route('/')
@app.route("/")
def index():
return send_file('static/index.html')
return send_file("static/index.html")


def json_response(data, status=200):
return make_response(tojson(data), status, {"content-type": 'application/json'})
return make_response(tojson(data), status, {"content-type": "application/json"})


@app.route('/lang')
@app.route("/lang")
def get_languages():
"""Return the list of available languages."""
data = [
{'code': key, 'name': val['name']} for key, val in conf['languages'].items()
{"code": key, "name": val["name"]} for key, val in conf["languages"].items()
]
result = {'msg': 'OK', 'data': data}
log.info('Send langs: %s', data)
result = {"msg": "OK", "data": data}
log.info("Send langs: %s", data)
return json_response(result)


@app.route('/ocr', methods=['POST', 'GET'])
@app.route("/ocr", methods=["POST", "GET"])
def ocr():
"""Run text recognition (OCR)."""
if conf['server'].get('token'):
if request.values.get('token') != conf['server']['token']:
return json_response({'msg': 'invalid token'}, status=403)
if conf["server"].get("token"):
if request.values.get("token") != conf["server"]["token"]:
return json_response({"msg": "invalid token"}, status=403)

lang = request.values.get('lang') or 'ch'
detect = parse_bool(request.values.get('detect') or 'true')
classify = parse_bool(request.values.get('classify') or 'true')
lang = request.values.get("lang") or "ch"
detect = parse_bool(request.values.get("detect") or "true")
classify = parse_bool(request.values.get("classify") or "true")

image_file = request.files.get('image')
image_file = request.files.get("image")
if not image_file:
return json_response({'msg': 'no image'}, 400)
return json_response({"msg": "no image"}, 400)
nparr = np.frombuffer(image_file.stream.read(), np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
log.info('Input: image %s, lang=%s, detect=%s, classify=%s', image.shape, lang, detect, classify)
log.info(
"Input: image %s, lang=%s, detect=%s, classify=%s",
image.shape,
lang,
detect,
classify,
)
if image.ndim == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
result = detect_recognize(image, lang=lang, detect=detect, classify=classify)
log.info('OCR Done %s %s', result['ts'], len(result['results']))
return json_response({'msg': 'OK', 'data': result})
log.info("OCR Done %s %s", result["ts"], len(result["results"]))
return json_response({"msg": "OK", "data": result})


if __name__ == '__main__':
logging.basicConfig(level='INFO')
logging.getLogger('waitress').setLevel(logging.INFO)
if parse_bool(conf.get('debug', '0')):
if __name__ == "__main__":
logging.basicConfig(level="INFO")
logging.getLogger("waitress").setLevel(logging.INFO)
if parse_bool(conf.get("debug", "0")):
# Debug
app.run(host=conf['server']['host'], port=conf['server']['port'], debug=True)
app.run(host=conf["server"]["host"], port=conf["server"]["port"], debug=True)
else:
# Deploy with waitress
serve(app, host=conf['server']['host'], port=conf['server']['port'])
serve(app, host=conf["server"]["host"], port=conf["server"]["port"])
18 changes: 9 additions & 9 deletions ocrweb_multi/rapidocr/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from utils.utils import OrtInferSession


class ClsPostProcess():
class ClsPostProcess:
"""Convert between text-label and text-index"""

def __init__(self, label_list):
Expand All @@ -40,18 +40,18 @@ def __call__(self, preds, label=None):
return decode_out, label


class TextClassifier():
class TextClassifier:
def __init__(self, path, config):
self.cls_batch_num = config['batch_size']
self.cls_thresh = config['score_thresh']
self.cls_batch_num = config["batch_size"]
self.cls_thresh = config["score_thresh"]

session_instance = OrtInferSession(path)
self.session = session_instance.session
metamap = self.session.get_modelmeta().custom_metadata_map

self.cls_image_shape = json.loads(metamap['shape'])
self.cls_image_shape = json.loads(metamap["shape"])

labels = json.loads(metamap['labels'])
labels = json.loads(metamap["labels"])
self.postprocess_op = ClsPostProcess(labels)
self.input_name = session_instance.get_input_name()

Expand All @@ -65,7 +65,7 @@ def resize_norm_img(self, img):
resized_w = int(math.ceil(img_h * ratio))

resized_image = cv2.resize(img, (resized_w, img_h))
resized_image = resized_image.astype('float32')
resized_image = resized_image.astype("float32")
if img_c == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
Expand All @@ -91,7 +91,7 @@ def __call__(self, img_list: List[np.ndarray]):
indices = np.argsort(np.array(width_list))

img_num = len(img_list)
cls_res = [['', 0.0]] * img_num
cls_res = [["", 0.0]] * img_num
batch_num = self.cls_batch_num
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
Expand All @@ -115,7 +115,7 @@ def __call__(self, img_list: List[np.ndarray]):
for rno in range(len(cls_result)):
label, score = cls_result[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
if label == '180' and score > self.cls_thresh:
if label == "180" and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1
)
Expand Down
12 changes: 6 additions & 6 deletions ocrweb_multi/rapidocr/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@
from .detect_process import DBPostProcess, create_operators, transform


class TextDetector():
class TextDetector:
def __init__(self, path, config):
self.preprocess_op = create_operators(config['pre_process'])
self.postprocess_op = DBPostProcess(**config['post_process'])
self.preprocess_op = create_operators(config["pre_process"])
self.postprocess_op = DBPostProcess(**config["post_process"])

session_instance = OrtInferSession(path)
self.session = session_instance.session
self.input_name = session_instance.get_input_name()

def __call__(self, img):
if img is None:
raise ValueError('img is None')
raise ValueError("img is None")

ori_im_shape = img.shape[:2]

data = {'image': img}
data = {"image": img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
Expand All @@ -49,7 +49,7 @@ def __call__(self, img):

post_result = self.postprocess_op(preds[0], shape_list)

dt_boxes = post_result[0]['points']
dt_boxes = post_result[0]["points"]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im_shape)
return dt_boxes

Expand Down
Loading

0 comments on commit f4fd5cd

Please sign in to comment.