Skip to content

Commit

Permalink
[Fix] fix inferencer ut (#3117)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Jun 19, 2023
1 parent c30d506 commit b2f4b4f
Showing 1 changed file with 8 additions and 27 deletions.
35 changes: 8 additions & 27 deletions tests/test_apis/test_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
import torch.nn as nn
from mmengine import ConfigDict
from torch.utils.data import DataLoader, Dataset

from mmseg.apis import MMSegInferencer
from mmseg.models import EncoderDecoder
Expand Down Expand Up @@ -46,33 +45,8 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)


class ExampleDataset(Dataset):

def __init__(self) -> None:
super().__init__()
self.pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]

def __getitem__(self, idx):
return dict(img=torch.tensor([1]), img_metas=dict())

def __len__(self):
return 1


def test_inferencer():
register_all_modules()
test_dataset = ExampleDataset()
data_loader = DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False,
)

visualizer = dict(
type='SegLocalVisualizer',
Expand All @@ -87,7 +61,14 @@ def test_inferencer():
decode_head=dict(type='InferExampleHead'),
test_cfg=dict(mode='whole')),
visualizer=visualizer,
test_dataloader=data_loader)
test_dataloader=dict(
dataset=dict(
type='ExampleDataset',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]), ))
cfg = ConfigDict(cfg_dict)
model = MODELS.build(cfg.model)

Expand Down

0 comments on commit b2f4b4f

Please sign in to comment.