
Training Code #200

Open
yashsandansing opened this issue Oct 6, 2022 · 9 comments

Comments

@yashsandansing commented Oct 6, 2022

I had to dig through a lot of documentation and issues to implement the training code, so here is the code you'll need for training.

Initially, you'll need to download the pretrained model files from https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing and move them to MODNet/pretrained.
If you need to fine-tune the model on your own dataset, download modnet_photographic_portrait_matting.ckpt.
If you need to use the pretrained mobilenetv2 backbone, download that too.

To prepare the dataset, I built a pandas DataFrame with 2 columns, ["image", "matte"]:
"image" holds the absolute path to each image, and "matte" holds the path to that image's matte.

After downloading, here is the preprocessing code (with the imports it needs):

import numpy as np
import cv2
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.ndimage import morphology
from torch.utils.data import Dataset

class ModNetDataLoader(Dataset):
    def __init__(self, annotations_file, resize_dim, transform=None):
        self.img_labels = annotations_file
        self.transform = transform
        self.resize_dim = resize_dim

    def __len__(self):
        #return the total number of images
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_labels.iloc[idx,0]
        mask_path = self.img_labels.iloc[idx,1]

        img = np.asarray(Image.open(img_path))

        # the matte is read with its alpha channel; this assumes RGBA mattes
        in_image = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        mask = in_image[:,:,3]

        if len(img.shape)==2:
            img = img[:,:,None]
        if img.shape[2]==1:
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2]==4:
            img = img[:,:,0:3]

        if len(mask.shape)==3:
            mask = mask[:,:, 0]

        #convert Image to pytorch tensor
        img = Image.fromarray(img)
        mask = Image.fromarray(mask)
        if self.transform:
            img = self.transform(img)
            trimap = self.get_trimap(mask)
            mask = self.transform(mask)

        img = self._resize(img)
        mask = self._resize(mask)
        trimap = self._resize(trimap, trimap=True)

        img = torch.squeeze(img, 0)
        mask = torch.squeeze(mask, 0)
        trimap = torch.squeeze(trimap, 1)

        return img, trimap, mask

    def get_trimap(self, alpha):
        # alpha is a PIL image with values in [0, 255];
        # be careful when dealing with regions of alpha=0 and alpha=255
        fg = np.array(np.equal(alpha, 255).astype(np.float32))
        unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))  # unknown = alpha > 0
        unknown = unknown - fg
        # dilate the unknown region, implemented via a Euclidean distance transform
        unknown = morphology.distance_transform_edt(unknown == 0) <= np.random.randint(1, 20)
        trimap = fg
        trimap[unknown] = 0.5
        return torch.unsqueeze(torch.from_numpy(trimap), dim=0)

    def _resize(self, img, trimap=False):
        im = img[None, :, :, :]
        ref_size = self.resize_dim

        # resize image for input
        im_b, im_c, im_h, im_w = im.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32
        # nearest-neighbour for the trimap so its values stay in {0, 0.5, 1};
        # area interpolation for the image and matte
        if trimap:
            im = F.interpolate(im, size=(im_rh, im_rw), mode='nearest')
        else:
            im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
        return im

You might need to adapt the get_trimap and __getitem__ methods above to your dataset.
You can verify that your data comes out correctly two steps below.

Finally, create your dataset using the code below:

from torchvision import transforms
from torch.utils.data import DataLoader

transformer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
# data_csv is the pandas DataFrame with ["image", "matte"] columns described above
data = ModNetDataLoader(data_csv, 512, transform=transformer)

After your dataset has been created, first verify it by printing the first row of data and checking that the shapes of the image, matte, and trimap match (only the channel counts may differ).
IMPORTANT: Also print the first trimap. The only values in the tensor should be 0, 0.5, and 1.
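For example, a quick sanity check along these lines:

img, trimap, mask = data[0]
print(img.shape, trimap.shape, mask.shape)   # e.g. [3, H, W], [1, H, W], [1, H, W] with matching H, W
print(torch.unique(trimap))                  # should be exactly tensor([0.0000, 0.5000, 1.0000])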
Then use a DataLoader to prepare your data for training:

train_dataloader = DataLoader(data, batch_size=8, shuffle=True)

After this, training uses supervised_training_iter from src/trainer.py:

import torch
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter

lr = 0.01       # learning rate
epochs = 40     # total epochs
# note: the batch size is set in the DataLoader above (batch_size=8)

modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

for epoch in range(0, epochs):
    for idx, (image, trimap, gt_matte) in enumerate(train_dataloader):
        semantic_loss, detail_loss, matte_loss = \
            supervised_training_iter(modnet, optimizer, image.cuda(), trimap.cuda(), gt_matte.cuda())
    lr_scheduler.step()

To use the pretrained backbone, change modnet = torch.nn.DataParallel(MODNet()).cuda() to modnet = torch.nn.DataParallel(MODNet(backbone_pretrained=True)).cuda(), provided you have the mobilenetv2 checkpoint in your pretrained directory.

For fine-tuning the existing MODNet model, use this snippet before the optimizer line:

modnet = torch.nn.DataParallel(MODNet()).cuda()
# e.g. the modnet_photographic_portrait_matting.ckpt downloaded earlier
state_dict = torch.load("path_to_torchscript_model.ckpt")
modnet.load_state_dict(state_dict)
modnet.train()
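One thing to watch out for (a general PyTorch note, not part of the original post): checkpoint keys may carry a module. prefix when the weights were saved from a DataParallel-wrapped model. Loading into a DataParallel-wrapped model as above is consistent with that; if you instead load into a bare MODNet and load_state_dict complains about missing/unexpected keys, stripping the prefix is the usual fix:

state_dict = torch.load("path_to_torchscript_model.ckpt")
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
bare_modnet = MODNet()
bare_modnet.load_state_dict(state_dict)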
@ezio1320

Thanks a lot! It's very helpful for me!

@SamStark-AW

Hi, I was wondering if anyone has tried and succeeded in training/tuning the model? The code runs well, but it seems like the model didn't change at all.
I wrote an extra bit of code that lets the model eval a few test images after each epoch; one would assume the eval results would get better over the epochs, but there were zero changes, all the pixels had exactly the same values.

Was wondering if there's a mistake in my code, or is it the same for everyone?

@yashsandansing
Author

@SamStark-AtWork I was the one who wrote this training code, and even I couldn't get it to work later. The loss remains more or less constant throughout the training process. I read in one of the issues that if you set the model to model.eval() in the trainer function before training, you might get better results, but this script has given me some terrible results on multiple datasets. People have gotten it to work, but there have been zero training scripts posted here. I think I once saw a training script pending approval in the PR section; you could try that code.

@SamStark-AW

I see, thanks for sharing, so it's not just me haha. I found the training script you mentioned and will try it soon; the author mentioned he's too busy to check it, so hopefully it will work!

@SamStark-AW commented Nov 9, 2022

Some updates: I've taken a look at the training script mentioned by @yashsandansing. To be honest, it's not much different; the biggest difference is how the original trimap is produced (which did work a bit better), plus most of the imported functions and utilities. There are some slight adjustments here and there in the training code across that person's commits, but overall they don't produce much difference.

For the problem I described above, where there were zero changes per epoch: I realized the LR scheduler was stepping through each batch (it should step once per epoch), which was my mistake. The problem was gone after that was fixed.
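In code, the fix is just moving the step out of the inner loop (a sketch based on the training loop posted above):

for epoch in range(0, epochs):
    for idx, (image, trimap, gt_matte) in enumerate(train_dataloader):
        supervised_training_iter(modnet, optimizer, image.cuda(), trimap.cuda(), gt_matte.cuda())
        # lr_scheduler.step()  <- bug: stepping per batch makes StepLR count batches
        #                         instead of epochs, so the LR collapses very early
    lr_scheduler.step()        # correct: step once per epoch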

The section below is more about quality, for anyone who would like to try training/tuning the model.

For context, I'm trying to produce a model that outputs the matte of an object in a given picture (not a human portrait). My dataset is self-produced and of very questionable quality (most likely one of the reasons I couldn't get great results), and it also contains only very similar objects differing in angle, which might explain the overfitting.

I tried both training the model from scratch and fine-tuning a pre-trained model. The results weren't good; it feels like there's a lot of overfitting (it might work better with a better dataset). However, it does work better than a generic VGG16 autoencoder, and at a much lower VRAM cost too (which is incredible). Fine-tuning the pre-trained model works better than expected, and both methods do seem to move closer to the target than where they started.

I also tried some hyperparameters; the current optimizer and its params are pretty good already. The biggest change I noticed came from the batch size: I tried 4 and 32, and the bigger batch size does produce better results (less overfitting).

The accuracy jumped significantly between no training and the first epoch; any change after that is small to almost none, and my dataset contains a lot of similar pictures, which is why I deduced that overfitting is happening.

So far I haven't seen anyone who trained/fine-tuned the model and got great results, so do be aware of that if you're interested in trying it yourself.

@mtrabelsi commented Jan 21, 2023

@SamStark-AtWork nice findings, do you mind providing your training code?
I want to check whether the dataset matters at all.
I will post the results here.

Best regards

@SamStark-AW

Hey @mtrabelsi, here you go, hope you get some good results! It's been a while since I last did anything on this, so it might be a bit messy; let me know if anything isn't explained properly.

# imports
import os
import torch
import torch.nn as nn
import torchvision
from src.models import modnet as MODNet
from src import trainer as MODTrainer

# import dataset and load as "dataloader" (check @yashsandansing's comment or the scripts at the PR)

# import model
modnet = MODNet.MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()

# for evaluating progress
evalPath = "INSERT YOUR PATH HERE"
if not os.path.isdir(evalPath):
    os.makedirs(evalPath)
# pick 2 or more images here and store them as "testImages" for infer/eval later

# hyperparameters (lr and epochs follow @yashsandansing's snippet above)
lr = 0.01       # learning rate
epochs = 40     # total epochs
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)   # can try momentum=0.45 too
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)


# Training starts here
for epoch in range(0, epochs):
    for idx, (image, trimap, gt_matte) in enumerate(dataloader):
        semantic_loss, detail_loss, matte_loss = MODTrainer.supervised_training_iter(
            modnet, optimizer, image.cuda(), trimap.cuda(), gt_matte.cuda())
    lr_scheduler.step()

    # eval for progress check and save images (here's where you visualize changes over training time)
    with torch.no_grad():
        _, _, debugImages = modnet(testImages.cuda(), True)
        for idx, img in enumerate(debugImages):
            saveName = "eval_%g_%g.jpg" % (idx, epoch + 1)
            torchvision.utils.save_image(img, os.path.join(evalPath, saveName))

    print("Epoch done: " + str(epoch))

@zzzcyyyw

@yashsandansing
Hi, I would like to ask: in the code you provided, the images coming out of the data-processing section are not all the same size, so how do you do batch training?

@yashsandansing
Author

hey @zzzcyyyw sorry for the extremely late reply. If I remember correctly, the code did give different sizes. I believe the resize code was data-specific and worked for my dataset. There were some other datasets for which it didn't work, so you'd need to tweak that section of the code; one option is sketched below.
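A simple workaround (my suggestion, untested on the original datasets) is to replace the aspect-preserving _resize with a fixed square output, so every sample in a batch has the same shape:

    def _resize(self, img, trimap=False):
        # fixed-size variant: every sample comes out resize_dim x resize_dim,
        # so the default DataLoader collation works; aspect ratio is not preserved
        im = img[None, :, :, :]
        mode = 'nearest' if trimap else 'area'
        return F.interpolate(im, size=(self.resize_dim, self.resize_dim), mode=mode)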
