
[On-device Training] Yolov4 custom loss #19390


Here is an idea to consider: since the YOLOv4 loss is custom, make the exported ONNX model compute the loss itself. Then, when generating the training artifacts, pass loss=None so that no built-in loss node is added to the graph:

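```python
import torch

class MyPTModelWithLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...  # define the backbone and the three detection heads

    def forward(self, *inputs):
        p, q, r = self.compute_logits(*inputs)  # hypothetical: outputs of the three heads
        # loss1/loss2/loss3 stand in for the per-head YOLOv4 loss terms
        loss = loss1(p) + loss2(q) + loss3(r)
        return loss  # the exported graph's output is the scalar loss
```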
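To make the idea concrete, here is a minimal, runnable PyTorch sketch of a module whose forward pass ends in the combined loss. TinyThreeHeadWithLoss, its three linear heads and the three placeholder losses are hypothetical stand-ins for the YOLOv4 detection heads and loss terms, and the targets are passed as extra forward inputs (an assumption; adapt it to however your data is fed).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyThreeHeadWithLoss(nn.Module):
    """Hypothetical toy stand-in for a YOLO-style model whose forward() returns the loss."""

    def __init__(self, in_features: int = 8):
        super().__init__()
        self.backbone = nn.Linear(in_features, 16)
        # Three hypothetical heads standing in for the three detection scales
        self.head_p = nn.Linear(16, 4)
        self.head_q = nn.Linear(16, 4)
        self.head_r = nn.Linear(16, 4)

    def forward(self, x, target_p, target_q, target_r):
        features = torch.relu(self.backbone(x))
        p, q, r = self.head_p(features), self.head_q(features), self.head_r(features)
        # Placeholder per-head losses; a real YOLOv4 loss would typically
        # combine box regression, objectness and class terms instead.
        loss = (
            F.mse_loss(p, target_p)
            + F.l1_loss(q, target_q)
            + F.smooth_l1_loss(r, target_r)
        )
        return loss  # a single scalar (0-dim) tensor


torch.manual_seed(0)
model = TinyThreeHeadWithLoss()
x = torch.randn(2, 8)
targets = [torch.randn(2, 4) for _ in range(3)]
loss = model(x, *targets)
loss.backward()  # one backward pass reaches all three heads
print(loss.dim())  # prints 0
```

Because forward ends in a single scalar, tracing this module with torch.onnx.export yields a graph whose output is the loss itself, which is what lets the artifact generation step skip adding its own loss.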

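```python
import onnx
import torch
from onnxruntime.training import artifacts

pt_model = MyPTModelWithLoss(...)
torch.onnx.export(pt_model, ...)  # the graph output is the scalar loss

onnx_model = onnx.load(<exported_onnx_model_path>)
# loss=None: the graph already computes the loss, so no loss node is added
artifacts.generate_artifacts(onnx_model, requires_grad=[...], frozen_params=[...], loss=None, optimizer=...)
```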

Let me know how that goes.

Answer selected by baijumeswani
Labels: training (issues related to ONNX Runtime training; typically submitted using template)
2 participants