fix apex fused rms (#1166)
t-vi committed Sep 18, 2024 · 1 parent ea96657 · commit 9580411
Showing 2 changed files with 67 additions and 77 deletions.
thunder/core/jit_ext.py: 45 changes (42 additions, 3 deletions)

@@ -644,6 +644,31 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
 )
 orig_scopes[-1].append(bsym_of_custom_fwd)

+# not augmented for when we don't need grad
+trace_of_fwd = TraceCtx()
+for bsym in custom_fwd_bsyms:
+    trace_of_fwd.add_bound_symbol(bsym)
+with tracectx(trace_of_fwd):
+    prims.python_return(unwrapped_custom_forward_result)
+
+si = SigInfo(custom_fwd_sym.name)
+for a in unwrapped_custom_forward_args:
+    if isinstance(a, Proxy):
+        si.args.append((a.name, None))
+    else:
+        pa = proxy(a)
+        si.args.append((pa.name, None))
+trace_of_fwd._siginfo = si
+trace_of_fwd.args = unwrapped_custom_forward_args
+
+@wraps(trace_of_fwd.python_callable())
+def core_of_forward(*args, **kwargs):
+    return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)
+
+thunder.executors.torchex._register_implementation(
+    custom_fwd_sym, core_of_forward, checker=thunder.executors.torchex._always_executable
+)
+
 augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
     tuple(sequencify(unwrapped_custom_forward_result)),
     ctx_proxy.saved_tensors,
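Across all four hunks in this file the fix is the same: instead of baking each trace into a fixed Python function with `trace.python_callable(include_decorators=False)`, the trace gets its `args` set and is exposed through a thin wrapper around `thunder.core.trace_interpreter.interpret_trace`, so the recorded bound symbols are replayed through the interpreter at call time rather than executed as one opaque compiled callable. A minimal sketch of that pattern, factored into a hypothetical helper (`make_replayable` is not part of the commit; module paths are taken from the diff above):

    from functools import wraps

    import thunder.core.prims as prims
    from thunder.core.trace import TraceCtx, tracectx
    from thunder.core.trace_interpreter import interpret_trace


    def make_replayable(bsyms, result, sig_info, args):
        # Collect the recorded bound symbols into a standalone trace.
        trace = TraceCtx()
        for bsym in bsyms:
            trace.add_bound_symbol(bsym)
        with tracectx(trace):
            prims.python_return(result)
        trace._siginfo = sig_info
        trace.args = args

        # Call through interpret_trace rather than trace.python_callable():
        # the apparent intent is that replaying the bound symbols keeps them
        # visible to whatever trace context is active when the wrapper runs.
        @wraps(trace.python_callable())
        def replay(*a, **kw):
            return interpret_trace(trace, *a, **kw)

        return replay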
@@ -660,7 +685,13 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
         pa = proxy(a)
         si.args.append((pa.name, None))
 trace_of_augmented_fwd._siginfo = si
-core_of_augmented_forward = trace_of_augmented_fwd.python_callable(include_decorators=False)
+trace_of_augmented_fwd.args = unwrapped_custom_forward_args
+
+@wraps(trace_of_augmented_fwd.python_callable())
+def core_of_augmented_forward(*args, **kwargs):
+    return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, *args, **kwargs)
+
+# core_of_augmented_forward = trace_of_augmented_fwd.python_callable(include_decorators=False)

 @wraps(core_of_augmented_forward)
 def augmented_custom_forward_rule(*args, **kwargs):
@@ -695,7 +726,11 @@ def augmented_custom_forward_rule(*args, **kwargs):
         pa = proxy(a)
         bwd_si.args.append((pa.name, None))
 trace_of_backward._siginfo = bwd_si
-bwd_trace_callable_interface = trace_of_backward.python_callable(include_decorators=False)
+trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)
+
+@wraps(trace_of_backward.python_callable())
+def bwd_trace_callable_interface(*args, **kwargs):
+    return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs)

 bwd_si = SigInfo("backward_impl")
 for a in ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads:
@@ -709,7 +744,11 @@ def augmented_custom_forward_rule(*args, **kwargs):
     bwd_trace_impl.add_bound_symbol(bsym)
 bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*sequencify(unwrap(custom_backward_result)), output=()))
 bwd_trace_impl._siginfo = bwd_si
-bwd_impl_callable = bwd_trace_impl.python_callable(include_decorators=False)
+bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)
+
+@wraps(bwd_trace_impl.python_callable())
+def bwd_impl_callable(*args, **kwargs):
+    return thunder.core.trace_interpreter.interpret_trace(bwd_trace_impl, *args, **kwargs)

 @wraps(bwd_trace_callable_interface)
 def backward_impl(*args, **kwargs):
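For context, the lookaside patched above is what lets thunder trace through `torch.autograd.Function.apply`; apex's `FusedRMSNormAffineMixedDtypesFunction` (second file) is exactly such a class. A toy example of the kind of user code that exercises this path, assuming a thunder build from around this commit:

    import torch
    import thunder


    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, s):
            ctx.save_for_backward(x)  # saved tensors surface as ctx_proxy.saved_tensors above
            ctx.s = s
            return x * s

        @staticmethod
        def backward(ctx, g):
            (x,) = ctx.saved_tensors
            return g * ctx.s, None


    def f(x):
        return Scale.apply(x, 2.0)


    jf = thunder.jit(f)
    x = torch.randn(4, requires_grad=True)
    torch.testing.assert_close(jf(x), 2.0 * x)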
thunder/executors/apex_fused_rms_norm_impl.py: 99 changes (25 additions, 74 deletions)

@@ -24,96 +24,47 @@ def apex_fused_norms_available() -> bool:
     return APEX_FUSED_NORMS_AVAILABLE


-def meta_impl_fn(
-    input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
+def apex_fused_rms_norm_forward_affine_meta(
+    input: TensorLike, normalized_shape: Sequence[int], weight: TensorLike, eps: float
 ):
-    output_or_input = TensorProxy(like=input)
-    weight = TensorProxy(like=input, shape=normalized_shape)
     unnormalized_dims = len(input.shape) - len(normalized_shape)
     invvar = TensorProxy(like=input, shape=(math.prod(input.shape[:unnormalized_dims]),))
-    return TensorProxy(like=input), (output_or_input, weight, invvar), AnyProxy(object())
+    return TensorProxy(like=input), invvar

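The new forward meta encodes the shape contract of apex's `rms_forward_affine` kernels: the output matches the input, and `invvar` holds one inverse-variance value per normalized row, flattened over the leading (unnormalized) dimensions. Spelling out the arithmetic from the meta above:

    import math

    input_shape = (2, 3, 8)   # e.g. (batch, seq, hidden)
    normalized_shape = (8,)   # RMS-normalize over the last dimension
    unnormalized_dims = len(input_shape) - len(normalized_shape)  # 2
    invvar_shape = (math.prod(input_shape[:unnormalized_dims]),)  # (6,): one value per row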
-def fused_rms_norm_impl(
-    input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-):
-    ctx = Context()
-    output = FusedRMSNormAffineMixedDtypesFunction.forward(ctx, input, weight, normalized_shape, eps, memory_efficient)
-    return output, ctx.pop_saved_tensors(), ctx
-
-
-fused_rms_norm_fwd = apex_ex.register_operator("fused_rms_norm_fwd", meta=meta_impl_fn, fn=fused_rms_norm_impl)
-
-
-def fused_rms_norm_backward_meta(saved_tensors: Sequence[torch.Tensor], ctx: Context, g: TensorLike):
-    # saved_tensors[0] - input or output
-    # saved_tensors[1] - weight
-    return TensorProxy(like=saved_tensors[0]), TensorProxy(like=saved_tensors[1])
-
-
-def fused_rms_norm_backward_impl(saved_tensors: Sequence[torch.Tensor], ctx: Context, g: TensorLike):
-    with set_saved_tensors(ctx, saved_tensors):
-        return FusedRMSNormAffineMixedDtypesFunction.backward(ctx, g)[:2]
-
-
-fused_rms_norm_backward = apex_ex.register_operator(
-    "fused_rms_norm_backward", meta=fused_rms_norm_backward_meta, fn=fused_rms_norm_backward_impl
-)
-
-
-def fused_rms_norm_grad_rule(
-    ctx, input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-):
-    output, saved_tensors, saved_meta = fused_rms_norm_fwd(input, weight, normalized_shape, eps, memory_efficient)
-    g = get_grad(output)
-    grad_input, grad_weight = fused_rms_norm_backward(saved_tensors, saved_meta, g)
-    put_grads((input, weight), (grad_input, grad_weight))
-    return output
-
-
-def execution_tfms(
-    ctx, input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-):
-    output, _, _ = fused_rms_norm_fwd(input, weight, normalized_shape, eps, memory_efficient)
-    return output
-
-
-def _fused_rms_norm_checker(
-    ctx: torch.autograd.Function,
-    input: TensorLike,
-    weight: TensorLike,
+def apex_fused_rms_norm_backward_affine_meta(
+    grad_output: TensorLike,
+    invvar: TensorLike,
+    input_or_output: TensorLike,
     normalized_shape: Sequence[int],
+    weight_,
     eps: float,
     memory_efficient: bool,
 ):
-    use_apex_fused_rms_norm = get_compile_option(
-        "use_apex_fused_rms_norm", "Whether to enable `fused_rms_norm` from `apex_ex`. Defaults to `True`."
-    )
-    # We explicitly check for `False` because if the value is unspecified by the user,
-    # `get_compile_option` returns `None`, and `not None` evaluates to `True`.
-    if use_apex_fused_rms_norm == False:  # User explicitly disabled this.
-        return False
-
-    # use_apex_fused_rms_norm is `None` or `True`.
-    return True
+    return TensorProxy(like=grad_output), TensorProxy(like=weight_)

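The new backward meta mirrors `fused_layer_norm_cuda.rms_backward_affine`, which consumes the incoming gradient, the forward's `invvar`, and the saved input (or output, in the memory-efficient mode) and returns `(grad_input, grad_weight)`. The diff does not show how this commit wires gradients, but the deleted `fused_rms_norm_grad_rule` above illustrates the usual thunder pattern; a condensed, hypothetical sketch of that wiring for the new operators:

    # Assumes get_grad/put_grads from thunder.core.transforms, as in the removed rule.
    def apex_rms_norm_grad_rule(input, normalized_shape, weight, eps, memory_efficient):
        output, invvar = apex_fused_rms_norm_forward_affine(input, normalized_shape, weight, eps)
        g = get_grad(output)
        grad_input, grad_weight = apex_fused_rms_norm_backward_affine(
            g, invvar, input, normalized_shape, weight, eps, memory_efficient
        )
        put_grads((input, weight), (grad_input, grad_weight))
        return output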
 # Create a new symbol and register lookaside only if import is available.
 if apex_fused_norms_available():
+    apex_fused_rms_norm_forward_affine = apex_ex.register_operator(
+        "apex_fused_rms_norm_forward_affine",
+        meta=apex_fused_rms_norm_forward_affine_meta,
+        fn=fused_layer_norm_cuda.rms_forward_affine,
+        replaces=fused_layer_norm_cuda.rms_forward_affine,
+    )

-    def meta_fn(
-        ctx, input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-    ):
-        return TensorProxy(like=input)
-
-    # Symbol which will be used by lookaside.
-    fused_rms_norm = apex_ex.register_operator(
-        "fused_rms_norm", meta=meta_fn, replaces=FusedRMSNormAffineMixedDtypesFunction.forward
+    apex_fused_rms_norm_forward_affine_mixed_dtypes = apex_ex.register_operator(
+        "apex_fused_rms_norm_forward_affine_mixed_dtypes",
+        meta=apex_fused_rms_norm_forward_affine_meta,
+        fn=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
+        replaces=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
     )

-    apex_ex.register_implementation(
-        fused_rms_norm,
-        execution_transform=execution_tfms,
-        grad_transform=fused_rms_norm_grad_rule,
-        checker=_fused_rms_norm_checker,
+    apex_fused_rms_norm_backward_affine = apex_ex.register_operator(
+        "apex_fused_rms_norm_backward_affine",
+        meta=apex_fused_rms_norm_backward_affine_meta,
+        fn=fused_layer_norm_cuda.rms_backward_affine,
+        replaces=fused_layer_norm_cuda.rms_backward_affine,
     )
-    apex_ex.register_implementation(fused_rms_norm_backward, fused_rms_norm_backward)
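With these registrations, calls into `fused_layer_norm_cuda.rms_forward_affine`, `rms_forward_affine_mixed_dtypes`, or `rms_backward_affine` (for example, from apex's `FusedRMSNorm` module) are intercepted via `replaces=` when the apex executor is active. A rough end-to-end sketch; it needs CUDA and an apex build that ships `fused_layer_norm_cuda`, and the `apex_ex` import path is an assumption not shown in the diff:

    import torch
    import thunder
    from apex.normalization import FusedRMSNorm
    from thunder.executors.apexex import apex_ex  # import path assumed

    model = FusedRMSNorm(normalized_shape=(1024,), eps=1e-5).cuda()
    jmodel = thunder.jit(model, executors=[apex_ex])

    x = torch.randn(8, 1024, device="cuda", requires_grad=True)
    y = jmodel(x)      # forward -> apex_fused_rms_norm_forward_affine*
    y.sum().backward() # backward -> apex_fused_rms_norm_backward_affine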
