---
layout: hub_detail
background-class: hub-background
body-class: hub
title: ResNeXt101
summary: ResNet with bottleneck 3x3 Convolutions substituted by 3x3 Grouped Convolutions, trained with mixed precision using Tensor Cores.
category: researchers
image: nvidia_logo.png
author: NVIDIA
tags: [vision]
github-link:
github-id: NVIDIA/DeepLearningExamples
featured_image_1: ResNeXtArch.png
featured_image_2: classification.jpg
accelerator: cuda
order: 10
demo-model-link:
---

๋ชจ๋ธ ์„ค๋ช…

ResNeXt101-32x4d is a model introduced in the paper Aggregated Residual Transformations for Deep Neural Networks.

It is based on the regular ResNet model, substituting the 3x3 convolutions inside the bottleneck block with 3x3 grouped convolutions.
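To see what the substitution changes, the sketch below counts the weights of a plain 3x3 convolution and of a 3x3 grouped convolution with the same number of input and output channels. The helper `conv_weight_count` is hypothetical, and the 128-channel width with 32 groups is an assumption chosen to match the first ResNeXt101-32x4d stage; biases are ignored, as is usual when a convolution is followed by batch normalization.

```python
def conv_weight_count(in_channels: int, out_channels: int, kernel: int, groups: int = 1) -> int:
    """Number of weights in a 2D convolution (bias excluded).

    Each output channel only sees in_channels / groups input channels,
    so the weight tensor has shape (out_channels, in_channels // groups, kernel, kernel).
    """
    if in_channels % groups or out_channels % groups:
        raise ValueError("in_channels and out_channels must be divisible by groups")
    return out_channels * (in_channels // groups) * kernel * kernel


plain = conv_weight_count(128, 128, 3)               # regular 3x3 convolution
grouped = conv_weight_count(128, 128, 3, groups=32)  # 3x3 grouped convolution with 32 groups
print(plain, grouped, plain // grouped)              # 147456 4608 32
```

Splitting the 3x3 layer into 32 groups divides its cost by 32, which is what lets ResNeXt use a 128-channel bottleneck where ResNet uses 64 channels in the same stage while keeping a comparable parameter count.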

The ResNeXt101 model is trained with mixed precision[1] using Tensor Cores on the Volta, Turing, and NVIDIA Ampere GPU architectures. As a result, researchers can get results 3x faster than training without Tensor Cores while still benefiting from mixed-precision training. The model is tested against each monthly NGC container release to ensure consistent accuracy and performance over time.
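One reason mixed precision keeps 32-bit values alongside 16-bit ones is that FP16 loses small quantities. The sketch below illustrates only the numerics, using Python's standard `struct` module (format `e` for 16-bit and `f` for 32-bit floats) to round values; the helper names and the example numbers are assumptions, and this is not how PyTorch or NVIDIA's training scripts implement mixed precision.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision (16-bit) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision (32-bit) value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


weight = 1.0
update = 1e-4  # an illustrative small weight update

# FP16 has roughly 3 decimal digits of precision near 1.0, so the update is lost
print(to_fp16(weight + update))  # 1.0
# FP32 has roughly 7 decimal digits, so the update survives
print(to_fp32(weight + update) > weight)  # True
# Very small values, such as tiny gradients, underflow to zero in FP16
print(to_fp16(1e-8))  # 0.0
```

In PyTorch, automatic mixed precision is typically driven by `torch.autocast` together with a gradient scaler such as `torch.cuda.amp.GradScaler`, which counters the underflow shown in the last line.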

The NHWC data layout is used for mixed-precision training.
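NHWC ("channels last") stores all channels of one pixel next to each other, whereas PyTorch's default NCHW layout stores each channel as a separate plane. The two hypothetical helpers below compute the flat memory offset of element (n, c, h, w) in a contiguous tensor for each layout, which makes the difference concrete.

```python
def offset_nchw(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in a contiguous NCHW tensor."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in a contiguous NHWC (channels-last) tensor."""
    return ((n * H + h) * W + w) * C + c


C, H, W = 3, 2, 2  # a tiny 3-channel, 2x2 image
# NHWC: the channels of pixel (0, 0) sit next to each other in memory
print([offset_nhwc(0, c, 0, 0, C, H, W) for c in range(C)])  # [0, 1, 2]
# NCHW: the same channels are H * W = 4 elements apart
print([offset_nchw(0, c, 0, 0, C, H, W) for c in range(C)])  # [0, 4, 8]
```

In PyTorch itself, the channels-last arrangement is requested with `tensor.to(memory_format=torch.channels_last)`, which changes the strides while keeping the logical NCHW shape.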

For inference, the ResNeXt101-32x4d model can be deployed to NVIDIA Triton Inference Server with TorchScript, ONNX Runtime, or TensorRT as the execution backend. For details, see NGC.

๋ชจ๋ธ ๊ตฌ์กฐ

ResNextArch

Image source: [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf)

์œ„์˜ ์ด๋ฏธ์ง€๋Š” ResNet ๋ชจ๋ธ์˜ ๋ณ‘๋ชฉ ๋ธ”๋ก๊ณผ ResNeXt ๋ชจ๋ธ์˜ ๋ณ‘๋ชฉ ๋ธ”๋ก์˜ ์ฐจ์ด๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.

ResNeXt101-32x4d ๋ชจ๋ธ์˜ ์นด๋””๋„๋ฆฌํ‹ฐ(Cardinality)๋Š” 32์ด๊ณ  ๋ณ‘๋ชฉ ๋ธ”๋ก์˜ Width๋Š” 4์ž…๋‹ˆ๋‹ค.

Example

The example below uses the pretrained ResNeXt101-32x4d model to run inference on a set of images and display the results.

Running the example requires a few additional Python packages, which are used for image preprocessing and visualization.

```shell
pip install validators matplotlib
```

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
# In a Jupyter notebook, inline plots can also be enabled with: %matplotlib inline

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')
```

Load the model pretrained on the ImageNet dataset.

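```python
# Load the ResNeXt101-32x4d model and NVIDIA's pre/post-processing helpers from Torch Hub
resneXt = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resneXt')
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')

# Switch to inference mode (batch normalization uses its running statistics) and move to the device
resneXt.eval().to(device)
```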

Prepare sample input data.

```python
# Four sample images from the COCO test-stuff2017 set
uris = [
    'http://images.cocodataset.org/test-stuff2017/000000024309.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000028117.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000006149.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000004954.jpg',
]

# Download and preprocess each image, then concatenate them into one batch on the device
batch = torch.cat(
    [utils.prepare_input_from_uri(uri) for uri in uris]
).to(device)
```
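The internals of `utils.prepare_input_from_uri` are not described here. Assuming it follows the common ImageNet evaluation recipe (resize the shorter side to 256, center-crop 224x224, then normalize each channel with the ImageNet mean and standard deviation), the arithmetic can be sketched with the hypothetical helpers below.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # per-channel RGB mean (assumed ImageNet statistics)
IMAGENET_STD = (0.229, 0.224, 0.225)   # per-channel RGB standard deviation


def resized_size(width, height, shorter=256):
    """(width, height) after scaling so the shorter side equals `shorter`, keeping aspect ratio."""
    if width <= height:
        return shorter, int(shorter * height / width)
    return int(shorter * width / height), shorter


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop box."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Map 8-bit RGB values to normalized floats: (x / 255 - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


w, h = resized_size(640, 480)          # a 640x480 photo becomes 341x256
print((w, h), center_crop_box(w, h))
print(normalize_pixel((124, 116, 104)))  # a pixel near the mean maps to values close to 0
```

Under these assumptions each prepared image has shape 1 x 3 x 224 x 224, so concatenating the four inputs yields a batch of shape 4 x 3 x 224 x 224.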

์ถ”๋ก ์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค. ํ—ฌํผ ํ•จ์ˆ˜ pick_n_best(predictions=output, n=topN)๋ฅผ ์‚ฌ์šฉํ•ด ๋ชจ๋ธ์— ๋Œ€ํ•œ N๊ฐœ์˜ ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์€ ๊ฐ€์„ค์„ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค.

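```python
# Disable gradient tracking and convert the logits to per-class probabilities
with torch.no_grad():
    output = torch.nn.functional.softmax(resneXt(batch), dim=1)

# Keep the 5 most probable classes for each image
results = utils.pick_n_best(predictions=output, n=5)
```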
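The implementation of `pick_n_best` is not shown in this card. As a rough, hypothetical sketch of what selecting the N most probable hypotheses involves for a single image, the code below applies a numerically stable softmax and keeps the N highest-probability classes. `softmax` and `top_n` are illustrative helpers, not NVIDIA's functions, and the labels and logits are made up.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_n(probs, labels, n=5):
    """Hypothetical helper: the n (label, probability) pairs with the highest probability."""
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:n]


# Made-up labels and logits for a single image
labels = ["cat", "dog", "fox", "owl", "car", "ship"]
logits = [4.0, 3.0, 2.5, 1.0, 0.5, 0.2]
for label, p in top_n(softmax(logits), labels, n=3):
    print(f"{label}: {100 * p:.1f}%")
```

On the real batch, `torch.topk(output, k=5, dim=1)` performs the same selection for every image at once and returns the top probabilities together with their class indices.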

Display the results.

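```python
# Show each image as a thumbnail, followed by its top-5 predictions
for uri, result in zip(uris, results):
    img = Image.open(requests.get(uri, stream=True).raw)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same resampling filter
    img.thumbnail((256, 256), Image.LANCZOS)
    plt.imshow(img)
    plt.show()
    print(result)
```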

Details

๋ชจ๋ธ ์ž…๋ ฅ ๋ฐ ์ถœ๋ ฅ, ํ•™์Šต ๋ฐฉ๋ฒ•, ์ถ”๋ก  ๋ฐ ์„ฑ๋Šฅ์— ๋Œ€ํ•œ ์ž์„ธํ•œ ๋‚ด์šฉ์€ github์ด๋‚˜ NGC์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

References

[1]: Mixed precision: a training method that uses 16-bit and 32-bit floating point together for faster, more efficient processing.