From 464b23bdc8a3b80d4ef115c3aca5b73ccef5b0dc Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Sun, 29 Sep 2024 11:55:30 +0000 Subject: [PATCH] add matmul minimal tile size --- .../Evaluator/TIR/CPU/Matmul.cs | 25 +++++++++++ .../Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs | 40 ++++++++++++++++++ .../Passes/Rules/CPU/PackRule.cs | 41 ++++++++----------- 3 files changed, 81 insertions(+), 25 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs index 549dcbbbe..8bf092a53 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs @@ -3,6 +3,7 @@ using Nncase.IR; using Nncase.IR.Affine; +using Nncase.IR.CPU; using Nncase.Schedule; using Nncase.TIR.CPU; @@ -17,6 +18,30 @@ public MicroKernelInfo Visit(Matmul op, MicroKernelContext context) var domain = context.AccessMaps[0].Domains; var primitives = Enumerable.Repeat(1, domain.Length).ToArray(); var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray(); + + var (k, m, n) = (context.BufferShapes[0][^1], context.BufferShapes[2][^2], context.BufferShapes[2][^1]); + var (lpack, rpack) = PackedMatMul.GetPackKind(op.LhsPackedAxes, op.RhsPackedAxes); + switch (lpack, rpack) + { + case (PackedMatMul.PackKind.M | PackedMatMul.PackKind.K, PackedMatMul.PackKind.K | PackedMatMul.PackKind.N): + if (m % 2 == 0) + { + multipliers[^3].Min = 2; + } + + if (k % 2 == 0) + { + multipliers[^2].Min = 2; + } + + if (n % 4 == 0) + { + multipliers[^1].Min = 4; + } + + break; + } + var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length]; var opt = (ICpuTargetOptions)context.TargetOptions; bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read); diff --git a/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs b/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs index 
6d0106040..569364f11 100644 --- a/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs +++ b/modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs @@ -18,6 +18,15 @@ public sealed partial class PackedMatMul : PackedOp /// public static readonly ParameterInfo Rhs = new(typeof(PackedMatMul), 1, "rhs", ParameterKind.Input); + [Flags] + public enum PackKind : byte + { + None = 1 << 0, + M = 1 << 1, + K = 1 << 2, + N = 1 << 3, + } + public IRArray LhsPackedAxes { get; } public IRArray LhsPadedNums { get; } @@ -30,5 +39,36 @@ public sealed partial class PackedMatMul : PackedOp public bool TransposeB { get; } + public static (PackKind Lhs, PackKind Rhs) GetPackKind(IRArray lhsPackedAxes, IRArray rhsPackedAxes) + { + switch (lhsPackedAxes.Count, rhsPackedAxes.Count) + { + case (0, 0): + return (PackKind.None, PackKind.None); + case (0, 1): + return (PackKind.None, PackKind.N); + case (1, 0): + return (PackKind.M, PackKind.None); + case (1, 1): + return (PackKind.M, PackKind.N); + case (1, 2): + return (PackKind.K, PackKind.K | PackKind.N); + case (2, 1): + return (PackKind.M | PackKind.K, PackKind.K); + case (2, 2): + return (PackKind.M | PackKind.K, PackKind.K | PackKind.N); + default: + throw new NotSupportedException($"{lhsPackedAxes.Count}, {rhsPackedAxes.Count}"); + } + } + + public static ((int LM, int LK) Lhs, (int RK, int RN) Rhs) GetAxes(PackedMatMul target, int[] lhs, int[] rhs) + { + // Use Length, not Array.Rank: Rank is the array's dimension count (always 1 for int[]); mirrors the lhsShape.Length / rhsShape.Length indexing in PackRule.AddCandidate. + var (lm, lk) = target.TransposeA ? (lhs.Length - 1, lhs.Length - 2) : (lhs.Length - 2, lhs.Length - 1); + var (rk, rn) = target.TransposeB ? 
(rhs.Length - 1, rhs.Length - 2) : (rhs.Length - 2, rhs.Length - 1); + return ((lm, lk), (rk, rn)); + } + public override string DisplayProperty() => $"LhsPackedAxes: {LhsPackedAxes}, LhsPadedNums: {LhsPadedNums}, RhsPackedAxes: {RhsPackedAxes}, RhsPadedNums: {RhsPadedNums}, TransposeA: {TransposeA}, TransposeB: {TransposeB}"; } diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs index 77989e308..887c07029 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs @@ -151,15 +151,6 @@ public PackMatMul(int rank, int lane) { } - [Flags] - public enum PackKind : byte - { - None = 1 << 0, - M = 1 << 1, - K = 1 << 2, - N = 1 << 3, - } - public override Pattern Pattern { get; } = IsMatMul( "target", IsWildcard("lhs", e => e is not Call { Target: IR.CPU.Unpack }) with { TypePattern = IsFloat() & !IsVector() }, @@ -182,26 +173,26 @@ public override List GetReplaceCandidates(IMatchResult result, RunPassCont // AddCandidate(rcontext, PackKind.M, PackKind.None); // only pack B's n - AddCandidate(rcontext, PackKind.None, PackKind.N, transB: rhs is Const); + AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.None, IR.CPU.PackedMatMul.PackKind.N/* , transB: rhs is Const */); if (Rank > 1) { // pack A's m and B's n, when B is const, force transpose - AddCandidate(rcontext, PackKind.M, PackKind.N, transB: rhs is Const); + AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.M, IR.CPU.PackedMatMul.PackKind.N/* , transB: rhs is Const */); // pack A's m,k and B's k,n - AddCandidate(rcontext, PackKind.M | PackKind.K, PackKind.K | PackKind.N, transB: rhs is Const); + AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.M | IR.CPU.PackedMatMul.PackKind.K, IR.CPU.PackedMatMul.PackKind.K | IR.CPU.PackedMatMul.PackKind.N/* , transB: rhs is Const */); // pack A's m,k and B's k - // AddCandidate(rcontext, PackKind.M | PackKind.K, PackKind.K); + // 
AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.M | IR.CPU.PackedMatMul.PackKind.K, IR.CPU.PackedMatMul.PackKind.K); // pack A's k and B's k,n - AddCandidate(rcontext, PackKind.K, PackKind.K | PackKind.N, transB: lhs is Const); + AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.K, IR.CPU.PackedMatMul.PackKind.K | IR.CPU.PackedMatMul.PackKind.N/* , transB: lhs is Const */); } return rets; } - private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPack, bool transA = false, bool transB = false) + private void AddCandidate(RuleContext context, IR.CPU.PackedMatMul.PackKind lhsPack, IR.CPU.PackedMatMul.PackKind rhsPack, bool transA = false, bool transB = false) { var (rets, lhs, rhs, candidate, _, _) = context; var lhsShape = context.LhsShape.ToArray(); @@ -228,19 +219,19 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac var (rk, rn) = transB ? (rhsShape.Length - 1, rhsShape.Length - 2) : (rhsShape.Length - 2, rhsShape.Length - 1); switch (lhsPack) { - case PackKind.None: + case IR.CPU.PackedMatMul.PackKind.None: lhsLanes = Array.Empty(); lhsPackedAxes = Array.Empty(); break; - case PackKind.M: + case IR.CPU.PackedMatMul.PackKind.M: lhsLanes = [Lane]; lhsPackedAxes = [lm]; break; - case PackKind.K: + case IR.CPU.PackedMatMul.PackKind.K: lhsLanes = [Lane]; lhsPackedAxes = [lk]; break; - case PackKind.M | PackKind.K: + case IR.CPU.PackedMatMul.PackKind.M | IR.CPU.PackedMatMul.PackKind.K: lhsLanes = [Lane, Lane]; lhsPackedAxes = [lm, lk]; break; @@ -252,19 +243,19 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac int[] rhsPackedAxes; switch (rhsPack) { - case PackKind.None: + case IR.CPU.PackedMatMul.PackKind.None: rhsLanes = Array.Empty(); rhsPackedAxes = Array.Empty(); break; - case PackKind.N: + case IR.CPU.PackedMatMul.PackKind.N: rhsLanes = [Lane]; rhsPackedAxes = [rn]; break; - case PackKind.K: + case IR.CPU.PackedMatMul.PackKind.K: rhsLanes = [Lane]; rhsPackedAxes = 
[rk]; break; - case PackKind.K | PackKind.N: + case IR.CPU.PackedMatMul.PackKind.K | IR.CPU.PackedMatMul.PackKind.N: rhsLanes = [Lane, Lane]; rhsPackedAxes = [rk, rn]; break; @@ -290,7 +281,7 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac var unpackAxes = new List(); var unpadNums = new List(); var unpackLanes = new List(); - if (lhsPack.HasFlag(PackKind.M)) + if (lhsPack.HasFlag(IR.CPU.PackedMatMul.PackKind.M)) { var mPackIndex = Array.IndexOf(lhsPackedAxes, lm); unpackAxes.Add(outRank - 2); @@ -298,7 +289,7 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac unpackLanes.Add(Lane); } - if (rhsPack.HasFlag(PackKind.N)) + if (rhsPack.HasFlag(IR.CPU.PackedMatMul.PackKind.N)) { var nPackIndex = Array.IndexOf(rhsPackedAxes, rn); unpackAxes.Add(outRank - 1);