Skip to content

Commit

Permalink
Merge pull request #23 from RapidAI/table_optim
Browse files Browse the repository at this point in the history
feature: add table cls model
  • Loading branch information
SWHL committed Sep 12, 2024
2 parents ad4a3ed + 69cec6f commit 9ebae19
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 1 deletion.
68 changes: 68 additions & 0 deletions .github/workflows/table_cls.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
name: Push table_cls to pypi

on:
push:
branches: [ main ]
paths:
- 'table_cls/**'
# tags:
# - v*

jobs:
UnitTesting:
runs-on: ubuntu-latest
steps:
- name: Pull latest code
uses: actions/checkout@v3

- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: '3.10'
architecture: 'x64'

- name: Display Python version
run: python -c "import sys; print(sys.version)"

- name: Unit testings
run: |
pip install -r requirements.txt
pip install pytest beautifulsoup4
wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip
unzip table_cls_models.zip
mv table_cls_models/*.onnx table_cls/models/
pytest tests/test_table_cls.py
GenerateWHL_PushPyPi:
needs: UnitTesting
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- name: Set up Python 3.7
uses: actions/setup-python@v4
with:
python-version: '3.7'
architecture: 'x64'

- name: Run setup.py
run: |
pip install -r requirements.txt
python -m pip install --upgrade pip
pip install wheel get_pypi_latest_version
wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/table_cls_models.zip
unzip table_cls_models.zip
mv table_cls_models/*.onnx table_cls/models/
python setup_table_cls.py bdist_wheel "${{ github.event.head_commit.message }}"
- name: Publish distribution 📦 to PyPI
uses: pypa/[email protected]
with:
password: ${{ secrets.PYPI_API_TOKEN }}
packages_dir: dist/
1 change: 0 additions & 1 deletion demo_lineless.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

complete_html = format_html(html)
os.makedirs(os.path.dirname(f"{output_dir}/table.html"), exist_ok=True)

with open(f"{output_dir}/table.html", "w", encoding="utf-8") as file:
file.write(complete_html)

Expand Down
9 changes: 9 additions & 0 deletions demo_table_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# -*- encoding: utf-8 -*-
from table_cls import TableCls

table_cls = TableCls()
output_dir = "outputs"
img_path = "tests/test_files/table_cls/lineless_table.png"
cls_str, elapse = table_cls(img_path)
print(cls_str)
print(elapse)
62 changes: 62 additions & 0 deletions setup_table_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import sys
from typing import List, Union
from pathlib import Path
from get_pypi_latest_version import GetPyPiLatestVersion

import setuptools


def read_txt(txt_path: Union[Path, str]) -> List[str]:
with open(txt_path, "r", encoding="utf-8") as f:
data = [v.rstrip("\n") for v in f]
return data


MODULE_NAME = "table_cls"

obtainer = GetPyPiLatestVersion()
try:
latest_version = obtainer(MODULE_NAME)
except Exception:
latest_version = "0.0.0"

VERSION_NUM = obtainer.version_add_one(latest_version)

if len(sys.argv) > 2:
match_str = " ".join(sys.argv[2:])
matched_versions = obtainer.extract_version(match_str)
if matched_versions:
VERSION_NUM = matched_versions
sys.argv = sys.argv[:2]

setuptools.setup(
name=MODULE_NAME,
version=VERSION_NUM,
platforms="Any",
description="A table classifier for further table rec",
long_description="A table classifier that distinguishes between wired and wireless tables",
long_description_content_type="text/markdown",
author="SWHL",
author_email="[email protected]",
url="https://github.com/RapidAI/TableStructureRec",
license="Apache-2.0",
install_requires=read_txt("requirements.txt"),
include_package_data=True,
packages=setuptools.find_packages(include=[MODULE_NAME, f"{MODULE_NAME}.*"]),
package_data={
MODULE_NAME: ["*.onnx"],
},
keywords=["table-classifier", "wired", "wireless", "table-recognition"],
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
python_requires=">=3.6,<3.12",
)
1 change: 1 addition & 0 deletions table_cls/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .main import TableCls
48 changes: 48 additions & 0 deletions table_cls/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import time

from pathlib import Path
import numpy as np
import onnxruntime
from PIL import Image

from .utils import InputType, LoadImage

cur_dir = Path(__file__).resolve().parent
table_cls_model_path = cur_dir / "models" / "table_cls.onnx"


class TableCls:
def __init__(self, device="cpu"):
providers = (
["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
)
self.table_cls = onnxruntime.InferenceSession(
table_cls_model_path, providers=providers
)
self.inp_h = 224
self.inp_w = 224
self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
self.cls = {0: "wired", 1: "wireless"}
self.load_img = LoadImage()

def _preprocess(self, image):
img = Image.fromarray(np.uint8(image))
img = img.resize((self.inp_h, self.inp_w))
img = np.array(img, dtype=np.float32) / 255.0
img -= self.mean
img /= self.std
img = img.transpose(2, 0, 1) # HWC to CHW
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
return img

def __call__(self, content: InputType):
ss = time.perf_counter()
img = self.load_img(content)
img = self._preprocess(img)
output = self.table_cls.run(None, {"input": img})
predict = np.exp(output[0] - np.max(output[0], axis=1, keepdims=True))
predict /= np.sum(predict, axis=1, keepdims=True)
predict_cla = np.argmax(predict, axis=1)[0]
table_elapse = time.perf_counter() - ss
return self.cls[predict_cla], table_elapse
Empty file added table_cls/models/.gitkeep
Empty file.
111 changes: 111 additions & 0 deletions table_cls/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from io import BytesIO
from pathlib import Path
from typing import Union

from PIL import UnidentifiedImageError
from PIL import Image
import numpy as np
import cv2

InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class LoadImageError(Exception):
pass


class LoadImage:
def __init__(
self,
):
pass

def __call__(self, img: InputType) -> np.ndarray:
if not isinstance(img, InputType.__args__):
raise LoadImageError(
f"The img type {type(img)} does not in {InputType.__args__}"
)

origin_img_type = type(img)
img = self.load_img(img)
img = self.convert_img(img, origin_img_type)
return img

def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, (str, Path)):
self.verify_exist(img)
try:
img = np.array(Image.open(img))
except UnidentifiedImageError as e:
raise LoadImageError(f"cannot identify image file {img}") from e
return img

if isinstance(img, bytes):
img = np.array(Image.open(BytesIO(img)))
return img

if isinstance(img, np.ndarray):
return img

if isinstance(img, Image.Image):
return np.array(img)

raise LoadImageError(f"{type(img)} is not supported!")

def convert_img(self, img: np.ndarray, origin_img_type):
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

if img.ndim == 3:
channel = img.shape[2]
if channel == 1:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

if channel == 2:
return self.cvt_two_to_three(img)

if channel == 3:
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img

if channel == 4:
return self.cvt_four_to_three(img)

raise LoadImageError(
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
)

raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

@staticmethod
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
"""gray + alpha → BGR"""
img_gray = img[..., 0]
img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

img_alpha = img[..., 1]
not_a = cv2.bitwise_not(img_alpha)
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
"""RGBA → BGR"""
r, g, b, a = cv2.split(img)
new_img = cv2.merge((b, g, r))

not_a = cv2.bitwise_not(a)
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

new_img = cv2.bitwise_and(new_img, new_img, mask=a)
new_img = cv2.add(new_img, not_a)
return new_img

@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
raise LoadImageError(f"{file_path} does not exist.")
Binary file added tests/test_files/table_cls/lineless_table.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/test_files/table_cls/wired_table.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 23 additions & 0 deletions tests/test_table_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import sys
from pathlib import Path

import pytest

from table_cls import TableCls

cur_dir = Path(__file__).resolve().parent
root_dir = cur_dir.parent

sys.path.append(str(root_dir))
test_file_dir = cur_dir / "test_files" / "table_cls"
table_cls = TableCls()


@pytest.mark.parametrize(
"img_path, expected",
[("wired_table.png", "wired"), ("lineless_table.png", "wireless")],
)
def test_input_normal(img_path, expected):
img_path = test_file_dir / img_path
res, elasp = table_cls(img_path)
assert res == expected

0 comments on commit 9ebae19

Please sign in to comment.