From 9580411e1dcec34ce0226a5a688ed65c32422b8b Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Wed, 18 Sep 2024 13:58:35 +0200
Subject: [PATCH] fix apex fused rms (#1166)

---
 thunder/core/jit_ext.py                       | 45 ++++++++-
 thunder/executors/apex_fused_rms_norm_impl.py | 99 +++++--------------
 2 files changed, 67 insertions(+), 77 deletions(-)

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index d0524c87a..8c67b5d3b 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -644,6 +644,31 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
         )
         orig_scopes[-1].append(bsym_of_custom_fwd)
 
+    # not augmented for when we don't need grad
+    trace_of_fwd = TraceCtx()
+    for bsym in custom_fwd_bsyms:
+        trace_of_fwd.add_bound_symbol(bsym)
+    with tracectx(trace_of_fwd):
+        prims.python_return(unwrapped_custom_forward_result)
+
+    si = SigInfo(custom_fwd_sym.name)
+    for a in unwrapped_custom_forward_args:
+        if isinstance(a, Proxy):
+            si.args.append((a.name, None))
+        else:
+            pa = proxy(a)
+            si.args.append((pa.name, None))
+    trace_of_fwd._siginfo = si
+    trace_of_fwd.args = unwrapped_custom_forward_args
+
+    @wraps(trace_of_fwd.python_callable())
+    def core_of_forward(*args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)
+
+    thunder.executors.torchex._register_implementation(
+        custom_fwd_sym, core_of_forward, checker=thunder.executors.torchex._always_executable
+    )
+
     augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
         tuple(sequencify(unwrapped_custom_forward_result)),
         ctx_proxy.saved_tensors,
@@ -660,7 +685,13 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
             pa = proxy(a)
             si.args.append((pa.name, None))
     trace_of_augmented_fwd._siginfo = si
-    core_of_augmented_forward = trace_of_augmented_fwd.python_callable(include_decorators=False)
+    trace_of_augmented_fwd.args = unwrapped_custom_forward_args
+
+    @wraps(trace_of_augmented_fwd.python_callable())
+    def core_of_augmented_forward(*args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, *args, **kwargs)
+
+    # core_of_augmented_forward = trace_of_augmented_fwd.python_callable(include_decorators=False)
 
     @wraps(core_of_augmented_forward)
     def augmented_custom_forward_rule(*args, **kwargs):
@@ -695,7 +726,11 @@ def augmented_custom_forward_rule(*args, **kwargs):
             pa = proxy(a)
             bwd_si.args.append((pa.name, None))
     trace_of_backward._siginfo = bwd_si
-    bwd_trace_callable_interface = trace_of_backward.python_callable(include_decorators=False)
+    trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)
+
+    @wraps(trace_of_backward.python_callable())
+    def bwd_trace_callable_interface(*args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs)
 
     bwd_si = SigInfo("backward_impl")
     for a in ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads:
@@ -709,7 +744,11 @@ def augmented_custom_forward_rule(*args, **kwargs):
         bwd_trace_impl.add_bound_symbol(bsym)
     bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*sequencify(unwrap(custom_backward_result)), output=()))
     bwd_trace_impl._siginfo = bwd_si
-    bwd_impl_callable = bwd_trace_impl.python_callable(include_decorators=False)
+    bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)
+
+    @wraps(bwd_trace_impl.python_callable())
+    def bwd_impl_callable(*args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(bwd_trace_impl, *args, **kwargs)
 
     @wraps(bwd_trace_callable_interface)
     def backward_impl(*args, **kwargs):
diff --git a/thunder/executors/apex_fused_rms_norm_impl.py b/thunder/executors/apex_fused_rms_norm_impl.py
index be90c76d3..cb2e51fd7 100644
--- a/thunder/executors/apex_fused_rms_norm_impl.py
+++ b/thunder/executors/apex_fused_rms_norm_impl.py
@@ -24,96 +24,47 @@ def apex_fused_norms_available() -> bool:
     return APEX_FUSED_NORMS_AVAILABLE
 
 
-def meta_impl_fn(
-    input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
+def apex_fused_rms_norm_forward_affine_meta(
+    input: TensorLike, normalized_shape: Sequence[int], weight: TensorLike, eps: float
 ):
     output_or_input = TensorProxy(like=input)
     weight = TensorProxy(like=input, shape=normalized_shape)
     unnormalized_dims = len(input.shape) - len(normalized_shape)
     invvar = TensorProxy(like=input, shape=(math.prod(input.shape[:unnormalized_dims]),))
-    return TensorProxy(like=input), (output_or_input, weight, invvar), AnyProxy(object())
+    return TensorProxy(like=input), invvar
 
 
-def fused_rms_norm_impl(
-    input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-):
-    ctx = Context()
-    output = FusedRMSNormAffineMixedDtypesFunction.forward(ctx, input, weight, normalized_shape, eps, memory_efficient)
-    return output, ctx.pop_saved_tensors(), ctx
-
-
-fused_rms_norm_fwd = apex_ex.register_operator("fused_rms_norm_fwd", meta=meta_impl_fn, fn=fused_rms_norm_impl)
-
-
-def fused_rms_norm_backward_meta(saved_tensors: Sequence[torch.Tensor], ctx: Context, g: TensorLike):
-    # saved_tensors[0] - input or output
-    # saved_tensors[1] - weight
-    return TensorProxy(like=saved_tensors[0]), TensorProxy(like=saved_tensors[1])
-
-
-def fused_rms_norm_backward_impl(saved_tensors: Sequence[torch.Tensor], ctx: Context, g: TensorLike):
-    with set_saved_tensors(ctx, saved_tensors):
-        return FusedRMSNormAffineMixedDtypesFunction.backward(ctx, g)[:2]
-
-
-fused_rms_norm_backward = apex_ex.register_operator(
-    "fused_rms_norm_backward", meta=fused_rms_norm_backward_meta, fn=fused_rms_norm_backward_impl
-)
-
-
-def fused_rms_norm_grad_rule(
-    ctx, input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-):
-    output, saved_tensors, saved_meta = fused_rms_norm_fwd(input, weight, normalized_shape, eps, memory_efficient)
-    g = get_grad(output)
-    grad_input, grad_weight = fused_rms_norm_backward(saved_tensors, saved_meta, g)
-    put_grads((input, weight), (grad_input, grad_weight))
-    return output
-
-
-def execution_tfms(
-    ctx, input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-):
-    output, _, _ = fused_rms_norm_fwd(input, weight, normalized_shape, eps, memory_efficient)
-    return output
-
-
-def _fused_rms_norm_checker(
-    ctx: torch.autograd.Function,
-    input: TensorLike,
-    weight: TensorLike,
+def apex_fused_rms_norm_backward_affine_meta(
+    grad_output: TensorLike,
+    invvar: TensorLike,
+    input_or_output: TensorLike,
     normalized_shape: Sequence[int],
+    weight_,
     eps: float,
     memory_efficient: bool,
 ):
-    use_apex_fused_rms_norm = get_compile_option(
-        "use_apex_fused_rms_norm", "Whether to enable `fused_rms_norm` from `apex_ex`. Defaults to `True`."
-    )
-    # We explicitly check for `False` as if the value is unspecified by user, `get_compile_option` returns `None` and `not None` is equal to True.
-    if use_apex_fused_rms_norm == False:  # User explicitly disabled this.
-        return False
-
-    # use_apex_fused_rms_norm is `None` or `True`.
-    return True
+    return TensorProxy(like=grad_output), TensorProxy(like=weight_)
 
 
 # Create a new symbol and register lookaside only if import is available.
 if apex_fused_norms_available():
+    apex_fused_rms_norm_forward_affine = apex_ex.register_operator(
+        "apex_fused_rms_norm_forward_affine",
+        meta=apex_fused_rms_norm_forward_affine_meta,
+        fn=fused_layer_norm_cuda.rms_forward_affine,
+        replaces=fused_layer_norm_cuda.rms_forward_affine,
+    )
 
-    def meta_fn(
-        ctx, input: TensorLike, weight: TensorLike, normalized_shape: Sequence[int], eps: float, memory_efficient: bool
-    ):
-        return TensorProxy(like=input)
-
-    # Symbol which will be used by lookaside.
-    fused_rms_norm = apex_ex.register_operator(
-        "fused_rms_norm", meta=meta_fn, replaces=FusedRMSNormAffineMixedDtypesFunction.forward
+    apex_fused_rms_norm_forward_affine_mixed_dtypes = apex_ex.register_operator(
+        "apex_fused_rms_norm_forward_affine_mixed_dtypes",
+        meta=apex_fused_rms_norm_forward_affine_meta,
+        fn=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
+        replaces=fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes,
     )
-    apex_ex.register_implementation(
-        fused_rms_norm,
-        execution_transform=execution_tfms,
-        grad_transform=fused_rms_norm_grad_rule,
-        checker=_fused_rms_norm_checker,
+    apex_fused_rms_norm_backward_affine = apex_ex.register_operator(
+        "apex_fused_rms_norm_backward_affine",
+        meta=apex_fused_rms_norm_backward_affine_meta,
+        fn=fused_layer_norm_cuda.rms_backward_affine,
+        replaces=fused_layer_norm_cuda.rms_backward_affine,
    )
-    apex_ex.register_implementation(fused_rms_norm_backward, fused_rms_norm_backward)
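
Usage note (editor's sketch, not part of the patch): with this change the apex executor registers operators that directly replace fused_layer_norm_cuda.rms_forward_affine, rms_forward_affine_mixed_dtypes, and rms_backward_affine, and the torch.autograd.Function lookaside replays its recorded forward/backward traces through interpret_trace instead of calling python_callable(). The sketch below shows one way the patched path might be exercised; the apex_ex import location and the executors= keyword of thunder.jit are assumptions, and apex built with the fused_layer_norm_cuda extension plus a CUDA device are required.

    # Editor's sketch under stated assumptions, not a definitive recipe.
    import torch
    import thunder
    from thunder.executors.apexex import apex_ex      # assumed import location of apex_ex
    from apex.normalization import FusedRMSNorm       # apex module backed by torch.autograd.Function

    model = FusedRMSNorm(1024, eps=1e-5).cuda()
    jmodel = thunder.jit(model, executors=[apex_ex])  # assumed way to opt into apex_ex

    x = torch.randn(8, 1024, device="cuda", requires_grad=True)
    y = jmodel(x)        # forward should hit one of the replaced rms_forward_affine* kernels
    y.sum().backward()   # backward should hit the replaced rms_backward_affine kernel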