[WIP] A tensor parallel API for beginners #40

Merged 13 commits on Oct 31, 2023
axonn/intra_layer/__init__.py (6 changes: 5 additions & 1 deletion)
@@ -1,6 +1,9 @@
from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401
from .fully_connected import Linear # noqa: F401
from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401

from .communication import Drop, Gather
from .gradient_normalization import clip_grad_norm_ # noqa: F401

from axonn import axonn as ax


@@ -18,4 +21,5 @@ def gather(x, transpose=False, dim=-1):
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

return Gather.apply(x, group, dim)
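A short sketch (not part of the diff) of how user code is expected to pull in the pieces this module now exposes; it assumes the AxoNN runtime and its intra-layer process groups have already been initialized on every rank.

# Sketch only: importing the beginner-facing API surface touched by this PR.
from axonn.intra_layer import Linear           # tensor-parallel linear layer
from axonn.intra_layer import clip_grad_norm_  # sharding-aware gradient clipping
from axonn.intra_layer import gather           # gather a sharded tensor along an intra-layer group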
axonn/intra_layer/fully_connected.py (152 changes: 135 additions & 17 deletions)
@@ -1,8 +1,9 @@
from axonn import axonn as ax
import torch.distributed as dist
import torch
from .communication import Drop
from .communication import Drop, Gather
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
import math


@@ -11,20 +12,34 @@ def divide(a, b):
return a // b


def extract_local_params_from_full_params(
full_params, out_features_group, in_features_group
):
params = Drop.apply(torch.t(full_params).contiguous(), out_features_group)
params = torch.t(params).contiguous()
params = Drop.apply(params, in_features_group)
return params


@torch.no_grad()
def initialize_params(
out_features, in_features, out_features_group, in_features_group, init_method
):
params = torch.empty((out_features, in_features))
init_method(params)
params = Drop.apply(torch.t(params).contiguous(), out_features_group)
params = torch.t(params).contiguous()
params = Drop.apply(params, in_features_group)
params = extract_local_params_from_full_params(
params, out_features_group, in_features_group
)
return params
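For intuition, the two Drop calls shard the weight first along out_features and then along in_features. A single-process sketch of the resulting shapes, using torch.chunk in place of the distributed Drop; group sizes and ranks below are made-up values, and it assumes Drop keeps the calling rank's slice along the last dimension, which is what the surrounding transposes suggest.

import torch

out_features, in_features = 8, 6
outer_size, outer_rank = 2, 0  # out_features group
inner_size, inner_rank = 3, 1  # in_features group

full = torch.empty(out_features, in_features)
local = torch.chunk(full, outer_size, dim=0)[outer_rank]   # shard rows across the outer group
local = torch.chunk(local, inner_size, dim=1)[inner_rank]  # shard columns across the inner group
print(local.shape)                                          # torch.Size([4, 2])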


def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


class AsyncLinear(Function):
@staticmethod
@custom_fwd
def forward(
ctx,
input_,
@@ -41,6 +56,7 @@ def forward(
return output

@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
handle = None
@@ -53,7 +69,7 @@ def backward(ctx, grad_output):
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.view(-1, grad_output.shape[-1])
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)
@@ -62,17 +78,14 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, None, None, None
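The point of the async option is to overlap the all-reduce of grad_input with the grad_weight GEMM. A generic sketch of that pattern, illustrative only and not the PR's exact code; the group argument stands for whichever tensor-parallel process group grad_input must be summed over.

import torch
import torch.distributed as dist


def overlapped_linear_backward(grad_output, input_, weight, group):
    # Launch the all-reduce of grad_input without blocking ...
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    # ... and compute grad_weight while the collective is in flight.
    grad_weight = (
        grad_output.reshape(-1, grad_output.shape[-1])
        .t()
        .mm(input_.reshape(-1, input_.shape[-1]))
    )
    handle.wait()  # grad_input is fully reduced only after this point
    return grad_input, grad_weight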


def default_init_method(weight):
return torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))


class Linear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
*args,
transpose=False,
bias=True,
skip_bias_add=False,
init_method=None,
async_comm_in_backward_pass=True,
@@ -84,6 +97,10 @@ def __init__(

self.inner_group_size = dist.get_world_size(self.inner_group)
self.outer_group_size = dist.get_world_size(self.outer_group)

self.in_features = in_features
self.out_features = out_features

self.async_comm_in_backward_pass = async_comm_in_backward_pass

if init_method is None:
@@ -116,35 +133,136 @@ def __init__(

self.weight = torch.nn.Parameter(initial_params, requires_grad=True)

self.bias = torch.nn.Parameter(
torch.zeros(
self.local_out_features,
)
setattr(self.weight, "is_tensor_parallel", True)
setattr(
self.weight,
"process_group_for_norm_reduction",
ax.comm_handle.intra_layer_group,
)

if bias:
self.bias = torch.nn.Parameter(
torch.zeros(
self.local_out_features,
)
)
setattr(self.bias, "is_tensor_parallel", True)
if not transpose:
setattr(
self.bias,
"process_group_for_norm_reduction",
ax.comm_handle.outer_intra_layer_parallel_group,
)
else:
setattr(
self.bias,
"process_group_for_norm_reduction",
ax.comm_handle.inner_intra_layer_parallel_group,
)
else:
self.bias = None

self.transpose = transpose
self.skip_bias_add = skip_bias_add
self._old_load_from_state_dict = self._load_from_state_dict
self._load_from_state_dict = self._modified_load_from_state_dict

def get_output_feature_size(self):
return self.local_out_features

def forward(self, x):
def forward(self, x, scatter_input=True, gather_output=True):
if not self.transpose:
if scatter_input:
x = Drop.apply(x, self.inner_group)
x = AsyncLinear.apply(
x,
self.weight,
self.inner_group,
self.outer_group,
self.async_comm_in_backward_pass,
)
if gather_output:
x = Gather.apply(x, self.outer_group)
else:
if scatter_input:
x = Drop.apply(x, self.outer_group)
x = AsyncLinear.apply(
x,
self.weight,
self.outer_group,
self.inner_group,
self.async_comm_in_backward_pass,
)
if self.skip_bias_add:
return x, self.bias
if gather_output:
x = Gather.apply(x, self.inner_group)

if self.bias is None:
return x
else:
return x + self.bias
bias = self.bias
if gather_output:
bias = Gather.apply(
self.bias,
self.outer_group if not self.transpose else self.inner_group,
)
if self.skip_bias_add:
return x, bias
else:
return x + bias
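A minimal usage sketch of the layer; the feature sizes and batch size below are arbitrary, and it assumes torch.distributed plus AxoNN's intra-layer process groups have already been initialized on every rank.

import torch
from axonn.intra_layer import Linear

layer = Linear(in_features=1024, out_features=4096).cuda()

x = torch.randn(16, 1024, device="cuda")
y = layer(x)                             # scatter_input/gather_output default to True,
                                         # so x and y use the full feature dimensions
y_local = layer(x, gather_output=False)  # keep the output sharded for a following
                                         # tensor-parallel layer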

def _is_full_weight_matrix(self, weight):
return (weight.size(0) == self.out_features) and (
weight.size(1) == self.in_features
)

def _is_sharded_weight_matrix(self, weight):
return (weight.size(0) == self.local_out_features) and (
weight.size(1) == self.local_in_features
)

@torch.no_grad()
def _modified_load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
weight = (
state_dict[prefix + "weight"] if prefix + "weight" in state_dict else None
)

if weight is not None:
is_full_weight_matrix = self._is_full_weight_matrix(weight)
is_sharded_weight_matrix = self._is_sharded_weight_matrix(weight)

assert (
is_full_weight_matrix or is_sharded_weight_matrix
), "This is neither a full checkpoint nor a sharded checkpoint"

if is_full_weight_matrix:
out_features_group, in_features_group = (
self.outer_group,
self.inner_group,
)
if self.transpose:
out_features_group, in_features_group = (
self.inner_group,
self.outer_group,
)
weight = extract_local_params_from_full_params(
weight, out_features_group, in_features_group
)
state_dict[prefix + "weight"] = weight

if self.bias is not None:
bias = (
state_dict[prefix + "bias"] if prefix + "bias" in state_dict else None
)
if bias is not None:
if bias.size(0) == self.out_features:
bias = Drop.apply(
bias,
self.outer_group if not self.transpose else self.inner_group,
)
state_dict[prefix + "bias"] = bias
else:
assert (
bias.size(0) == self.local_out_features
), "This is neither a full checkpoint nor a sharded checkpoint"

self._old_load_from_state_dict(state_dict, prefix, *args, **kwargs)
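The hook above lets a single non-sharded checkpoint be loaded into any tensor-parallel configuration. A hedged sketch of what that enables; the layer sizes are arbitrary and an initialized AxoNN runtime is assumed.

import torch
from axonn.intra_layer import Linear

# A plain, full-size nn.Linear checkpoint ...
full_layer = torch.nn.Linear(1024, 4096, bias=True)

# ... loads directly into the sharded layer: _modified_load_from_state_dict
# detects the full weight/bias and drops them down to this rank's local shard.
parallel_layer = Linear(in_features=1024, out_features=4096)
parallel_layer.load_state_dict(full_layer.state_dict())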
axonn/intra_layer/gradient_normalization.py (90 changes: 90 additions & 0 deletions)
@@ -0,0 +1,90 @@
import torch

# for backwards compatibility with pytorch 1.13
try:
from torch._six import inf
except ImportError:
from torch import inf

import torch.distributed as dist
from collections import defaultdict


def get_total_norm(tensors, norm_type, error_if_nonfinite):
if len(tensors) == 0:
return torch.tensor(0.0)
device = tensors[0].device
total_norm = torch.norm(
torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in tensors]),
norm_type,
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)

return total_norm


def clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):
if norm_type == inf:
raise NotImplementedError

if isinstance(parameters, torch.Tensor):
parameters = [parameters]

tensor_parallel_params = defaultdict(list)
non_tensor_parallel_params = []
for p in parameters:
if hasattr(p, "is_tensor_parallel") and p.is_tensor_parallel:
            assert hasattr(
                p, "process_group_for_norm_reduction"
            ), (
                "each tensor parallel tensor should "
                "have a process group for all-reducing norms"
            )
tensor_parallel_params[p.process_group_for_norm_reduction].append(p)
else:
non_tensor_parallel_params.append(p)

tensor_parallel_grads = {}
for process_group, group_params in tensor_parallel_params.items():
tensor_parallel_grads[process_group] = [
p.grad for p in group_params if p.grad is not None
]

non_tensor_parallel_grads = [
p.grad for p in non_tensor_parallel_params if p.grad is not None
]

max_norm = float(max_norm)
norm_type = float(norm_type)

non_tensor_parallel_norm = get_total_norm(
non_tensor_parallel_grads, norm_type, error_if_nonfinite
)

tensor_parallel_norms = []
for process_group, grads in tensor_parallel_grads.items():
local_tensor_parallel_norm = get_total_norm(
grads, norm_type, error_if_nonfinite
)
tensor_parallel_norm = local_tensor_parallel_norm**norm_type
dist.all_reduce(tensor_parallel_norm, group=process_group)
tensor_parallel_norm = tensor_parallel_norm ** (1.0 / norm_type)
tensor_parallel_norms.append(tensor_parallel_norm)

all_norms = tensor_parallel_norms + [non_tensor_parallel_norm]
total_norm = get_total_norm(all_norms, norm_type, error_if_nonfinite)

clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in non_tensor_parallel_grads:
g.detach().mul_(clip_coef_clamped.to(g.device))

for group_grads in tensor_parallel_grads.values():
for g in group_grads:
g.detach().mul_(clip_coef_clamped.to(g.device))

return total_norm
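Each sharded group's local norm is raised to the p-th power, all-reduced across the group, and re-rooted, so the final value matches what a single process would compute over the full, unsharded gradients. A minimal usage sketch follows; the layer sizes are arbitrary, and it assumes AxoNN's process groups are initialized and that the model's tensor-parallel parameters carry the attributes set in fully_connected.py.

import torch
from axonn.intra_layer import Linear, clip_grad_norm_

model = torch.nn.Sequential(Linear(512, 512), torch.nn.GELU(), Linear(512, 512)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(8, 512, device="cuda")
loss = model(x).sum()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0)
optimizer.step()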