diff --git a/.circleci/test.yml b/.circleci/test.yml index 505822b3db..ceef7884f7 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -73,7 +73,7 @@ jobs: - run: name: Skip timm unittests and generate coverage report command: | - python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py + python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py python -m coverage xml python -m coverage report -m build_cuda: @@ -119,7 +119,7 @@ jobs: - run: name: Run unittests but skip timm unittests command: | - docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py + docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py workflows: pr_stage_lint: when: << pipeline.parameters.lint_only >> diff --git a/demo/rs_image_inference.py b/demo/rs_image_inference.py new file mode 100644 index 0000000000..799181f93c --- /dev/null +++ b/demo/rs_image_inference.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser + +from mmseg.apis import RSImage, RSInferencer + + +def main(): + parser = ArgumentParser() + parser.add_argument('image', help='Image file path') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--output-path', + help='Path to save result image', + default='result.png') + parser.add_argument( + '--batch-size', + type=int, + default=1, + help='maximum number of windows inferred simultaneously') + parser.add_argument( + '--window-size', + help='window xsize,ysize', + default=(224, 224), + type=int, + nargs=2) + parser.add_argument( + '--stride', + help='window xstride,ystride', + default=(224, 224), + type=int, + nargs=2) + parser.add_argument( + '--thread', default=1, type=int, help='number of inference threads') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + inferencer = RSInferencer.from_config_path( + args.config, + args.checkpoint, + batch_size=args.batch_size, + thread=args.thread, + device=args.device) + image = RSImage(args.image) + + inferencer.run(image, args.window_size, args.stride, args.output_path) + + +if __name__ == '__main__': + main() diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py index d22dc3f0ad..b50a266319 100644 --- a/mmseg/apis/__init__.py +++ b/mmseg/apis/__init__.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import inference_model, init_model, show_result_pyplot from .mmseg_inferencer import MMSegInferencer +from .remote_sense_inferencer import RSImage, RSInferencer __all__ = [ - 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer' + 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer', + 'RSInferencer', 'RSImage' ] diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 6a398ebc5e..0dd70cd615 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -1,14 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from collections import defaultdict from pathlib import Path -from typing import Optional, Sequence, Union +from typing import Optional, Union import mmcv import numpy as np import torch from mmengine import Config -from mmengine.dataset import Compose from mmengine.registry import init_default_scope from mmengine.runner import load_checkpoint from mmengine.utils import mkdir_or_exist @@ -18,6 +16,7 @@ from mmseg.structures import SegDataSample from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette from mmseg.visualization import SegLocalVisualizer +from .utils import ImageType, _preprare_data def init_model(config: Union[str, Path, Config], @@ -90,41 +89,6 @@ def init_model(config: Union[str, Path, Config], return model -ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] - - -def _preprare_data(imgs: ImageType, model: BaseSegmentor): - - cfg = model.cfg - for t in cfg.test_pipeline: - if t.get('type') == 'LoadAnnotations': - cfg.test_pipeline.remove(t) - - is_batch = True - if not isinstance(imgs, (list, tuple)): - imgs = [imgs] - is_batch = False - - if isinstance(imgs[0], np.ndarray): - cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray' - - # TODO: Consider using the singleton pattern to avoid building - # a pipeline for each inference - pipeline = Compose(cfg.test_pipeline) - - data = defaultdict(list) - for img in imgs: - if isinstance(img, np.ndarray): - data_ = dict(img=img) - else: - data_ = dict(img_path=img) - data_ = pipeline(data_) - data['inputs'].append(data_['inputs']) - data['data_samples'].append(data_['data_samples']) - - return data, is_batch - - def inference_model(model: BaseSegmentor, img: ImageType) -> Union[SegDataSample, SampleList]: """Inference image(s) with the segmentor. diff --git a/mmseg/apis/remote_sense_inferencer.py b/mmseg/apis/remote_sense_inferencer.py new file mode 100644 index 0000000000..6726c6ae34 --- /dev/null +++ b/mmseg/apis/remote_sense_inferencer.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import threading +from queue import Queue +from typing import List, Optional, Tuple + +import numpy as np +import torch +from mmengine import Config +from mmengine.model import BaseModel +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint + +try: + from osgeo import gdal +except ImportError: + gdal = None + +from mmseg.registry import MODELS +from .utils import _preprare_data + + +class RSImage: + """Remote sensing image class. + + Args: + img (str or gdal.Dataset): Image file path or gdal.Dataset. + """ + + def __init__(self, image): + self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance( + image, str) else image + assert isinstance(self.dataset, gdal.Dataset), \ + f'{image} is not a image' + self.width = self.dataset.RasterXSize + self.height = self.dataset.RasterYSize + self.channel = self.dataset.RasterCount + self.trans = self.dataset.GetGeoTransform() + self.proj = self.dataset.GetProjection() + self.band_list = [] + self.band_list.extend( + self.dataset.GetRasterBand(c + 1) for c in range(self.channel)) + self.grids = [] + + def read(self, grid: Optional[List] = None) -> np.ndarray: + """Read image data. If grid is None, read the whole image. + + Args: + grid (Optional[List], optional): Grid to read. Defaults to None. + Returns: + np.ndarray: Image data. + """ + if grid is None: + return np.einsum('ijk->jki', self.dataset.ReadAsArray()) + assert len( + grid) >= 4, 'grid must be a list containing at least 4 elements' + data = self.dataset.ReadAsArray(*grid[:4]) + if data.ndim == 2: + data = data[np.newaxis, ...] + return np.einsum('ijk->jki', data) + + def write(self, data: Optional[np.ndarray], grid: Optional[List] = None): + """Write image data. + + Args: + grid (Optional[List], optional): Grid to write. Defaults to None. + data (Optional[np.ndarray], optional): Data to write. + Defaults to None. + + Raises: + ValueError: Either grid or data must be provided. + """ + if grid is not None: + assert len(grid) == 8, 'grid must be a list of 8 elements' + for band in self.band_list: + band.WriteArray( + data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]], + grid[0] + grid[4], grid[1] + grid[5]) + elif data is not None: + for i in range(self.channel): + self.band_list[i].WriteArray(data[..., i]) + else: + raise ValueError('Either grid or data must be provided.') + + def create_seg_map(self, output_path: Optional[str] = None): + if output_path is None: + output_path = 'output_label.tif' + driver = gdal.GetDriverByName('GTiff') + seg_map = driver.Create(output_path, self.width, self.height, 1, + gdal.GDT_Byte) + seg_map.SetGeoTransform(self.trans) + seg_map.SetProjection(self.proj) + seg_map_img = RSImage(seg_map) + seg_map_img.path = output_path + return seg_map_img + + def create_grids(self, + window_size: Tuple[int, int], + stride: Tuple[int, int] = (0, 0)): + """Create grids for image inference. + + Args: + window_size (Tuple[int, int]): the size of the sliding window. + stride (Tuple[int, int], optional): the stride of the sliding + window. Defaults to (0, 0). + + Raises: + AssertionError: window_size must be a tuple of 2 elements. + AssertionError: stride must be a tuple of 2 elements. + """ + assert len( + window_size) == 2, 'window_size must be a tuple of 2 elements' + assert len(stride) == 2, 'stride must be a tuple of 2 elements' + win_w, win_h = window_size + stride_x, stride_y = stride + + stride_x = win_w if stride_x == 0 else stride_x + stride_y = win_h if stride_y == 0 else stride_y + + x_half_overlap = (win_w - stride_x + 1) // 2 + y_half_overlap = (win_h - stride_y + 1) // 2 + + for y in range(0, self.height, stride_y): + y_end = y + win_h >= self.height + y_offset = self.height - win_h if y_end else y + y_size = win_h + y_crop_off = 0 if y_offset == 0 else y_half_overlap + y_crop_size = y_size if y_end else win_h - y_crop_off + + for x in range(0, self.width, stride_x): + x_end = x + win_w >= self.width + x_offset = self.width - win_w if x_end else x + x_size = win_w + x_crop_off = 0 if x_offset == 0 else x_half_overlap + x_crop_size = x_size if x_end else win_w - x_crop_off + + self.grids.append([ + x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off, + x_crop_size, y_crop_size + ]) + + +class RSInferencer: + """Remote sensing inference class. + + Args: + model (BaseModel): The loaded model. + batch_size (int, optional): Batch size. Defaults to 1. + thread (int, optional): Number of threads. Defaults to 1. + """ + + def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1): + self.model = model + self.batch_size = batch_size + self.END_FLAG = object() + self.read_buffer = Queue(self.batch_size) + self.write_buffer = Queue(self.batch_size) + self.thread = thread + + @classmethod + def from_config_path(cls, + config_path: str, + checkpoint_path: str, + batch_size: int = 1, + thread: int = 1, + device: Optional[str] = 'cpu'): + """Initialize a segmentor from config file. + + Args: + config_path (str): Config file path. + checkpoint_path (str): Checkpoint path. + batch_size (int, optional): Batch size. Defaults to 1. + """ + init_default_scope('mmseg') + cfg = Config.fromfile(config_path) + model = MODELS.build(cfg.model) + model.cfg = cfg + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + model.eval() + return cls(model, batch_size, thread) + + @classmethod + def from_model(cls, + model: BaseModel, + checkpoint_path: Optional[str] = None, + batch_size: int = 1, + thread: int = 1, + device: Optional[str] = 'cpu'): + """Initialize a segmentor from model. + + Args: + model (BaseModel): The loaded model. + checkpoint_path (Optional[str]): Checkpoint path. + batch_size (int, optional): Batch size. Defaults to 1. + """ + if checkpoint_path is not None: + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + return cls(model, batch_size, thread) + + def read(self, + image: RSImage, + window_size: Tuple[int, int], + strides: Tuple[int, int] = (0, 0)): + """Load image data to read buffer. + + Args: + image (RSImage): The image to read. + window_size (Tuple[int, int]): The size of the sliding window. + strides (Tuple[int, int], optional): The stride of the sliding + window. Defaults to (0, 0). + """ + image.create_grids(window_size, strides) + for grid in image.grids: + self.read_buffer.put([grid, image.read(grid=grid)]) + self.read_buffer.put(self.END_FLAG) + + def inference(self): + """Inference image data from read buffer and put the result to write + buffer.""" + while True: + item = self.read_buffer.get() + if item == self.END_FLAG: + self.read_buffer.put(self.END_FLAG) + self.write_buffer.put(item) + break + data, _ = _preprare_data(item[1], self.model) + with torch.no_grad(): + result = self.model.test_step(data) + item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0] + self.write_buffer.put(item) + self.read_buffer.task_done() + + def write(self, image: RSImage, output_path: Optional[str] = None): + """Write image data from write buffer. + + Args: + image (RSImage): The image to write. + output_path (Optional[str], optional): The path to save the + segmentation map. Defaults to None. + """ + seg_map = image.create_seg_map(output_path) + while True: + item = self.write_buffer.get() + if item == self.END_FLAG: + break + seg_map.write(data=item[1], grid=item[0]) + self.write_buffer.task_done() + + def run(self, + image: RSImage, + window_size: Tuple[int, int], + strides: Tuple[int, int] = (0, 0), + output_path: Optional[str] = None): + """Run inference with multi-threading. + + Args: + image (RSImage): The image to inference. + window_size (Tuple[int, int]): The size of the sliding window. + strides (Tuple[int, int], optional): The stride of the sliding + window. Defaults to (0, 0). + output_path (Optional[str], optional): The path to save the + segmentation map. Defaults to None. + """ + read_thread = threading.Thread( + target=self.read, args=(image, window_size, strides)) + read_thread.start() + inference_threads = [] + for _ in range(self.thread): + inference_thread = threading.Thread(target=self.inference) + inference_thread.start() + inference_threads.append(inference_thread) + write_thread = threading.Thread( + target=self.write, args=(image, output_path)) + write_thread.start() + read_thread.join() + for inference_thread in inference_threads: + inference_thread.join() + write_thread.join() diff --git a/mmseg/apis/utils.py b/mmseg/apis/utils.py new file mode 100644 index 0000000000..4cf8775660 --- /dev/null +++ b/mmseg/apis/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from typing import Sequence, Union + +import numpy as np +from mmengine.dataset import Compose +from mmengine.model import BaseModel + +ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] + + +def _preprare_data(imgs: ImageType, model: BaseModel): + + cfg = model.cfg + for t in cfg.test_pipeline: + if t.get('type') == 'LoadAnnotations': + cfg.test_pipeline.remove(t) + + is_batch = True + if not isinstance(imgs, (list, tuple)): + imgs = [imgs] + is_batch = False + + if isinstance(imgs[0], np.ndarray): + cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray' + + # TODO: Consider using the singleton pattern to avoid building + # a pipeline for each inference + pipeline = Compose(cfg.test_pipeline) + + data = defaultdict(list) + for img in imgs: + if isinstance(img, np.ndarray): + data_ = dict(img=img) + else: + data_ = dict(img_path=img) + data_ = pipeline(data_) + data['inputs'].append(data_['inputs']) + data['data_samples'].append(data_['data_samples']) + + return data, is_batch diff --git a/tests/test_apis/test_inferencer.py b/tests/test_apis/test_inferencer.py index 663680976e..d8dbce8f38 100644 --- a/tests/test_apis/test_inferencer.py +++ b/tests/test_apis/test_inferencer.py @@ -3,48 +3,14 @@ import numpy as np import torch -import torch.nn as nn from mmengine import ConfigDict +from utils import * # noqa: F401, F403 from mmseg.apis import MMSegInferencer -from mmseg.models import EncoderDecoder -from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.registry import MODELS from mmseg.utils import register_all_modules -@MODELS.register_module(name='InferExampleHead') -class ExampleDecodeHead(BaseDecodeHead): - - def __init__(self, num_classes=19, out_channels=None): - super().__init__( - 3, 3, num_classes=num_classes, out_channels=out_channels) - - def forward(self, inputs): - return self.cls_seg(inputs[0]) - - -@MODELS.register_module(name='InferExampleBackbone') -class ExampleBackbone(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(3, 3, 3) - - def init_weights(self, pretrained=None): - pass - - def forward(self, x): - return [self.conv(x)] - - -@MODELS.register_module(name='InferExampleModel') -class ExampleModel(EncoderDecoder): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def test_inferencer(): register_all_modules() diff --git a/tests/test_apis/test_rs_inferencer.py b/tests/test_apis/test_rs_inferencer.py new file mode 100644 index 0000000000..03423d9680 --- /dev/null +++ b/tests/test_apis/test_rs_inferencer.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from unittest import TestCase + +import numpy as np +from mmengine import ConfigDict, init_default_scope +from utils import * # noqa: F401, F403 + +from mmseg.apis import RSImage, RSInferencer +from mmseg.registry import MODELS + + +class TestRSImage(TestCase): + + def test_read_whole_image(self): + init_default_scope('mmseg') + img_path = osp.join( + osp.dirname(__file__), + '../data/pseudo_loveda_dataset/img_dir/0.png') + rs_image = RSImage(img_path) + window_size = (16, 16) + rs_image.create_grids(window_size) + image_data = rs_image.read(rs_image.grids[0]) + self.assertIsNotNone(image_data) + + def test_write_image_data(self): + init_default_scope('mmseg') + img_path = osp.join( + osp.dirname(__file__), + '../data/pseudo_loveda_dataset/img_dir/0.png') + rs_image = RSImage(img_path) + window_size = (16, 16) + rs_image.create_grids(window_size) + data = np.random.random((16, 16)).astype(np.int8) + rs_image.write(data, rs_image.grids[0]) + + +class TestRSInferencer(TestCase): + + def test_read_and_inference(self): + init_default_scope('mmseg') + cfg_dict = dict( + model=dict( + type='InferExampleModel', + data_preprocessor=dict(type='SegDataPreProcessor'), + backbone=dict(type='InferExampleBackbone'), + decode_head=dict(type='InferExampleHead'), + test_cfg=dict(mode='whole')), + test_dataloader=dict( + dataset=dict( + type='ExampleDataset', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') + ])), + test_pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') + ]) + cfg = ConfigDict(cfg_dict) + model = MODELS.build(cfg.model) + model.cfg = cfg + inferencer = RSInferencer.from_model(model) + + img_path = osp.join( + osp.dirname(__file__), + '../data/pseudo_loveda_dataset/img_dir/0.png') + rs_image = RSImage(img_path) + window_size = (16, 16) + stride = (16, 16) + inferencer.run(rs_image, window_size, stride) diff --git a/tests/test_apis/utils.py b/tests/test_apis/utils.py new file mode 100644 index 0000000000..0a9928fccf --- /dev/null +++ b/tests/test_apis/utils.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmseg.models import EncoderDecoder +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS + + +@MODELS.register_module(name='InferExampleHead') +class ExampleDecodeHead(BaseDecodeHead): + + def __init__(self, num_classes=19, out_channels=None): + super().__init__( + 3, 3, num_classes=num_classes, out_channels=out_channels) + + def forward(self, inputs): + return self.cls_seg(inputs[0]) + + +@MODELS.register_module(name='InferExampleBackbone') +class ExampleBackbone(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 3) + + def init_weights(self, pretrained=None): + pass + + def forward(self, x): + return [self.conv(x)] + + +@MODELS.register_module(name='InferExampleModel') +class ExampleModel(EncoderDecoder): + + def __init__(self, **kwargs): + super().__init__(**kwargs)