maciejbalawejder/Data-Augmentation
Data Augmentation in Torchvision

This is an analysis of different data augmentation techniques in Torchvision, evaluated on CIFAR10. You can find the accompanying article on my Medium page.

Table of contents:

Augmentations

  1. Plain - only the Normalize() operation.

  2. Baseline - RandomHorizontalFlip(), RandomCrop(), RandomErasing().

  3. AutoAugment - the AutoAugment policy for CIFAR10 applied on top of the Baseline configuration.

from augmentations import GetAugment
plain, baseline, autoaugment = GetAugment()

Dataset

from cifar10 import LoadDataset
trainloader, valloader, testloader = LoadDataset(batch, normalization, augmentations) 
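LoadDataset's internals aren't shown either; the usual pattern is to carve a validation split out of the CIFAR-10 training set and wrap each split in a DataLoader. A minimal sketch with a small in-memory dataset standing in for CIFAR-10 (the split ratio and batch size are assumptions):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

# Dummy stand-in for the 50,000-image CIFAR-10 training set
images = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
full_train = TensorDataset(images, labels)

# Hold out part of the training set for validation (90/10 split assumed)
train_set, val_set = random_split(full_train, [90, 10])

batch = 16
trainloader = DataLoader(train_set, batch_size=batch, shuffle=True)
valloader = DataLoader(val_set, batch_size=batch, shuffle=False)
```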

Model

from resnet import ResNet
n = 3  # number of residual blocks per stage
resnet20 = ResNet(n)
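In the CIFAR-10 ResNets of He et al., each of the three feature-map stages stacks n residual blocks of two conv layers each; with the initial conv and the final classifier that gives 6n + 2 weight layers, which is why n = 3 yields ResNet-20:

```python
def resnet_depth(n: int) -> int:
    """Depth of a CIFAR-style ResNet: 3 stages * n blocks * 2 convs + 2."""
    return 6 * n + 2

# n = 3 -> ResNet-20, n = 5 -> ResNet-32, n = 9 -> ResNet-56
assert resnet_depth(3) == 20
```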

Training Loops

from training_functions import Network
network = Network(model=ResNet(3), learning_rate=0.01, device="cuda")
network.train_step(trainloader)
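train_step presumably runs one epoch of mini-batches. A minimal sketch of what such a step typically does, with a stand-in model and dummy data (these internals are assumptions, not the repository's code):

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in for ResNet
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

loader = DataLoader(
    TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))),
    batch_size=16,
)

def train_step(loader):
    """One epoch: forward pass, loss, backward pass, parameter update."""
    model.train()
    total = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)

avg_loss = train_step(loader)
```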

Plots

from plots import plot
plot([model1_train_loss, model1_val_loss, model2_train_loss, model2_val_loss], "Loss")

Training

train.py - combines all of the files above and trains the three different configurations
