[Feature] remote sensing inference (#3131)
## Motivation

Supports inference for ultra-large-scale remote sensing images.

## Modification

Add `rs_image_inference.py` to the `demo` directory.

## Use cases

Taking inference on the Vaihingen dataset with PSPNet as an example, the
following settings are used (names follow `demo/rs_image_inference.py`):

**image**: Path of the input image.
**config**: Config file of the model.
**checkpoint**: Checkpoint (weight) file of the model.
**--output-path**: Output path for the result image (default: `result.png`).
**--batch-size**: Maximum number of sliding windows inferred simultaneously (default: 1).
**--window-size**: Width and height of the sliding window, e.g. 512x512 (default: 224x224).
**--stride**: Stride of the sliding window, e.g. 400x400 (default: 224x224).
**--thread**: Number of threads used for inference (default: 1).
**--device**: Device used for inference, e.g. `cuda:0` for GPU or `cpu` for CPU (default: `cuda:0`).

```shell
python demo/rs_image_inference.py demo/demo.png projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth --batch-size 8 --device cpu --thread 2
```
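
The same workflow can also be driven from Python through the API added in this commit (`RSInferencer` and `RSImage` in `mmseg.apis`). The sketch below mirrors what `demo/rs_image_inference.py` does; the config and checkpoint paths are placeholders for illustration, not files shipped with this change:

```python
from mmseg.apis import RSImage, RSInferencer

# Build the inferencer from a config file and a checkpoint.
# batch_size is the maximum number of sliding windows inferred at once,
# thread the number of inference threads.
inferencer = RSInferencer.from_config_path(
    'configs/pspnet/pspnet_vaihingen.py',  # placeholder config path
    'pspnet_vaihingen.pth',  # placeholder checkpoint path
    batch_size=8,
    thread=2,
    device='cuda:0')

# Wrap the (potentially very large) remote sensing image.
image = RSImage('demo/demo.png')

# Sliding-window inference: 512x512 windows with a 400x400 stride,
# writing the merged prediction to result.png.
inferencer.run(image, (512, 512), (400, 400), 'result.png')
```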

---------

Co-authored-by: xiexinch <[email protected]>
Zoulinx and xiexinch committed Aug 31, 2023
1 parent 35ff78a commit 72e20a8
Showing 9 changed files with 489 additions and 76 deletions.
4 changes: 2 additions & 2 deletions .circleci/test.yml
```diff
@@ -73,7 +73,7 @@ jobs:
       - run:
           name: Skip timm unittests and generate coverage report
           command: |
-            python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
+            python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
             python -m coverage xml
             python -m coverage report -m
   build_cuda:
@@ -119,7 +119,7 @@ jobs:
       - run:
           name: Run unittests but skip timm unittests
           command: |
-            docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
+            docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
 workflows:
   pr_stage_lint:
     when: << pipeline.parameters.lint_only >>
```
50 changes: 50 additions & 0 deletions demo/rs_image_inference.py
@@ -0,0 +1,50 @@
```python
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

from mmseg.apis import RSImage, RSInferencer


def main():
    parser = ArgumentParser()
    parser.add_argument('image', help='Image file path')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--output-path',
        help='Path to save result image',
        default='result.png')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='maximum number of windows inferred simultaneously')
    parser.add_argument(
        '--window-size',
        help='window xsize,ysize',
        default=(224, 224),
        type=int,
        nargs=2)
    parser.add_argument(
        '--stride',
        help='window xstride,ystride',
        default=(224, 224),
        type=int,
        nargs=2)
    parser.add_argument(
        '--thread', default=1, type=int, help='number of inference threads')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()
    inferencer = RSInferencer.from_config_path(
        args.config,
        args.checkpoint,
        batch_size=args.batch_size,
        thread=args.thread,
        device=args.device)
    image = RSImage(args.image)

    inferencer.run(image, args.window_size, args.stride, args.output_path)


if __name__ == '__main__':
    main()
```
4 changes: 3 additions & 1 deletion mmseg/apis/__init__.py
```diff
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .inference import inference_model, init_model, show_result_pyplot
 from .mmseg_inferencer import MMSegInferencer
+from .remote_sense_inferencer import RSImage, RSInferencer

 __all__ = [
-    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
+    'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
+    'RSInferencer', 'RSImage'
 ]
```
40 changes: 2 additions & 38 deletions mmseg/apis/inference.py
```diff
@@ -1,14 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
-from collections import defaultdict
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Optional, Union

 import mmcv
 import numpy as np
 import torch
 from mmengine import Config
-from mmengine.dataset import Compose
 from mmengine.registry import init_default_scope
 from mmengine.runner import load_checkpoint
 from mmengine.utils import mkdir_or_exist
@@ -18,6 +16,7 @@
 from mmseg.structures import SegDataSample
 from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
 from mmseg.visualization import SegLocalVisualizer
+from .utils import ImageType, _preprare_data


 def init_model(config: Union[str, Path, Config],
@@ -90,41 +89,6 @@ def init_model(config: Union[str, Path, Config],
     return model


-ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
-
-
-def _preprare_data(imgs: ImageType, model: BaseSegmentor):
-
-    cfg = model.cfg
-    for t in cfg.test_pipeline:
-        if t.get('type') == 'LoadAnnotations':
-            cfg.test_pipeline.remove(t)
-
-    is_batch = True
-    if not isinstance(imgs, (list, tuple)):
-        imgs = [imgs]
-        is_batch = False
-
-    if isinstance(imgs[0], np.ndarray):
-        cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
-
-    # TODO: Consider using the singleton pattern to avoid building
-    # a pipeline for each inference
-    pipeline = Compose(cfg.test_pipeline)
-
-    data = defaultdict(list)
-    for img in imgs:
-        if isinstance(img, np.ndarray):
-            data_ = dict(img=img)
-        else:
-            data_ = dict(img_path=img)
-        data_ = pipeline(data_)
-        data['inputs'].append(data_['inputs'])
-        data['data_samples'].append(data_['data_samples'])
-
-    return data, is_batch
-
-
 def inference_model(model: BaseSegmentor,
                     img: ImageType) -> Union[SegDataSample, SampleList]:
     """Inference image(s) with the segmentor.
```