Skip to content

Commit

Permalink
torchrun for DDP
Browse files: browse the repository at this point in the history
  • Loading branch information
wenzhangliu committed Sep 20, 2024
1 parent 3a925bf commit c171b2c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion xuance/torch/learners/learner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,15 +26,17 @@ def __init__(self,

self.distributed_training = config.distributed_training
if self.distributed_training:
self.device = int(os.environ['LOCAL_RANK'])
self.snapshot_path = os.path.join(config.model_dir, "DDP_Snapshot")
if os.path.exists(self.snapshot_path):
print("Loading Snapshot...")
self.load_snapshot(self.snapshot_path)
else:
os.makedirs(self.snapshot_path)
self.device = int(os.environ['LOCAL_RANK'])
self.policy = DistributedDataParallel(self.policy, find_unused_parameters=True,
device_ids=[int(os.environ['LOCAL_RANK'])])
else:
self.device = config.device
self.use_grad_clip = config.use_grad_clip
self.grad_clip_norm = config.grad_clip_norm
self.device = config.device
Expand Down

0 comments on commit c171b2c

Please sign in to comment.