Skip to content

Commit

Permalink
add matmul mininal tile size
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Sep 29, 2024
1 parent 1e0de9b commit 464b23b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 25 deletions.
25 changes: 25 additions & 0 deletions modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Matmul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.IR.CPU;
using Nncase.Schedule;
using Nncase.TIR.CPU;

Expand All @@ -17,6 +18,30 @@ public MicroKernelInfo Visit(Matmul op, MicroKernelContext context)
var domain = context.AccessMaps[0].Domains;
var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
var multipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), domain.Length).ToArray();

var (k, m, n) = (context.BufferShapes[0][^1], context.BufferShapes[2][^2], context.BufferShapes[2][^1]);
var (lpack, rpack) = PackedMatMul.GetPackKind(op.LhsPackedAxes, op.RhsPackedAxes);
switch (lpack, rpack)
{
case (PackedMatMul.PackKind.M | PackedMatMul.PackKind.K, PackedMatMul.PackKind.K | PackedMatMul.PackKind.N):
if (m % 2 == 0)
{
multipliers[^3].Min = 2;
}

if (k % 2 == 0)
{
multipliers[^2].Min = 2;
}

if (n % 4 == 0)
{
multipliers[^1].Min = 4;
}

break;
}

var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
var opt = (ICpuTargetOptions)context.TargetOptions;
bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
Expand Down
39 changes: 39 additions & 0 deletions modules/Nncase.Modules.CPU/IR/CPU/PackedMatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ public sealed partial class PackedMatMul : PackedOp
/// </summary>
public static readonly ParameterInfo Rhs = new(typeof(PackedMatMul), 1, "rhs", ParameterKind.Input);

[Flags]
public enum PackKind : byte
{
None = 1 << 0,
M = 1 << 1,
K = 1 << 2,
N = 1 << 3,
}

public IRArray<int> LhsPackedAxes { get; }

public IRArray<int> LhsPadedNums { get; }
Expand All @@ -30,5 +39,35 @@ public sealed partial class PackedMatMul : PackedOp

public bool TransposeB { get; }

public static (PackKind Lhs, PackKind Rhs) GetPackKind(IRArray<int> lhsPackedAxes, IRArray<int> rhsPackedAxes)
{
switch (lhsPackedAxes.Count, rhsPackedAxes.Count)
{
case (0, 0):
return (PackKind.None, PackKind.None);
case (0, 1):
return (PackKind.None, PackKind.N);
case (1, 0):
return (PackKind.M, PackKind.None);
case (1, 1):
return (PackKind.M, PackKind.N);
case (1, 2):
return (PackKind.K, PackKind.K | PackKind.N);
case (2, 1):
return (PackKind.M | PackKind.K, PackKind.K);
case (2, 2):
return (PackKind.M | PackKind.K, PackKind.K | PackKind.N);
default:
throw new NotSupportedException($"{lhsPackedAxes.Count}, {rhsPackedAxes.Count}");
}
}

public static ((int LM, int LK) Lhs, (int RK, int RN) Rhs) GetAxes(PackedMatMul target, int[] lhs, int[] rhs)
{
var (lm, lk) = target.TransposeA ? (lhs.Rank - 1, lhs.Rank - 2) : (lhs.Rank - 2, lhs.Rank - 1);
var (rk, rn) = target.TransposeB ? (rhs.Rank - 1, rhs.Rank - 2) : (rhs.Rank - 2, rhs.Rank - 1);
return ((lm, lk), (rk, rn));
}

public override string DisplayProperty() => $"LhsPackedAxes: {LhsPackedAxes}, LhsPadedNums: {LhsPadedNums}, RhsPackedAxes: {RhsPackedAxes}, RhsPadedNums: {RhsPadedNums}, TransposeA: {TransposeA}, TransposeB: {TransposeB}";
}
41 changes: 16 additions & 25 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,6 @@ public PackMatMul(int rank, int lane)
{
}

[Flags]
public enum PackKind : byte
{
None = 1 << 0,
M = 1 << 1,
K = 1 << 2,
N = 1 << 3,
}

public override Pattern Pattern { get; } = IsMatMul(
"target",
IsWildcard("lhs", e => e is not Call { Target: IR.CPU.Unpack }) with { TypePattern = IsFloat() & !IsVector() },
Expand All @@ -182,26 +173,26 @@ public override List<Expr> GetReplaceCandidates(IMatchResult result, RunPassCont
// AddCandidate(rcontext, PackKind.M, PackKind.None);

// only pack B's n
AddCandidate(rcontext, PackKind.None, PackKind.N, transB: rhs is Const);
AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.None, IR.CPU.PackedMatMul.PackKind.N/* , transB: rhs is Const */);
if (Rank > 1)
{
// pack A's m and B's n, when B is const, force transpose
AddCandidate(rcontext, PackKind.M, PackKind.N, transB: rhs is Const);
AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.M, IR.CPU.PackedMatMul.PackKind.N/* , transB: rhs is Const */);

// pack A's m,k and B's k,n
AddCandidate(rcontext, PackKind.M | PackKind.K, PackKind.K | PackKind.N, transB: rhs is Const);
AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.M | IR.CPU.PackedMatMul.PackKind.K, IR.CPU.PackedMatMul.PackKind.K | IR.CPU.PackedMatMul.PackKind.N/* , transB: rhs is Const */);

// pack A's m,k and B's k
// AddCandidate(rcontext, PackKind.M | PackKind.K, PackKind.K);
// AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.M | IR.CPU.PackedMatMul.PackKind.K, IR.CPU.PackedMatMul.PackKind.K);

// pack A's k and B's k,n
AddCandidate(rcontext, PackKind.K, PackKind.K | PackKind.N, transB: lhs is Const);
AddCandidate(rcontext, IR.CPU.PackedMatMul.PackKind.K, IR.CPU.PackedMatMul.PackKind.K | IR.CPU.PackedMatMul.PackKind.N/* , transB: lhs is Const */);
}

return rets;
}

private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPack, bool transA = false, bool transB = false)
private void AddCandidate(RuleContext context, IR.CPU.PackedMatMul.PackKind lhsPack, IR.CPU.PackedMatMul.PackKind rhsPack, bool transA = false, bool transB = false)
{
var (rets, lhs, rhs, candidate, _, _) = context;
var lhsShape = context.LhsShape.ToArray();
Expand All @@ -228,19 +219,19 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac
var (rk, rn) = transB ? (rhsShape.Length - 1, rhsShape.Length - 2) : (rhsShape.Length - 2, rhsShape.Length - 1);
switch (lhsPack)
{
case PackKind.None:
case IR.CPU.PackedMatMul.PackKind.None:
lhsLanes = Array.Empty<int>();
lhsPackedAxes = Array.Empty<int>();
break;
case PackKind.M:
case IR.CPU.PackedMatMul.PackKind.M:
lhsLanes = [Lane];
lhsPackedAxes = [lm];
break;
case PackKind.K:
case IR.CPU.PackedMatMul.PackKind.K:
lhsLanes = [Lane];
lhsPackedAxes = [lk];
break;
case PackKind.M | PackKind.K:
case IR.CPU.PackedMatMul.PackKind.M | IR.CPU.PackedMatMul.PackKind.K:
lhsLanes = [Lane, Lane];
lhsPackedAxes = [lm, lk];
break;
Expand All @@ -252,19 +243,19 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac
int[] rhsPackedAxes;
switch (rhsPack)
{
case PackKind.None:
case IR.CPU.PackedMatMul.PackKind.None:
rhsLanes = Array.Empty<int>();
rhsPackedAxes = Array.Empty<int>();
break;
case PackKind.N:
case IR.CPU.PackedMatMul.PackKind.N:
rhsLanes = [Lane];
rhsPackedAxes = [rn];
break;
case PackKind.K:
case IR.CPU.PackedMatMul.PackKind.K:
rhsLanes = [Lane];
rhsPackedAxes = [rk];
break;
case PackKind.K | PackKind.N:
case IR.CPU.PackedMatMul.PackKind.K | IR.CPU.PackedMatMul.PackKind.N:
rhsLanes = [Lane, Lane];
rhsPackedAxes = [rk, rn];
break;
Expand All @@ -290,15 +281,15 @@ private void AddCandidate(RuleContext context, PackKind lhsPack, PackKind rhsPac
var unpackAxes = new List<int>();
var unpadNums = new List<int>();
var unpackLanes = new List<int>();
if (lhsPack.HasFlag(PackKind.M))
if (lhsPack.HasFlag(IR.CPU.PackedMatMul.PackKind.M))
{
var mPackIndex = Array.IndexOf(lhsPackedAxes, lm);
unpackAxes.Add(outRank - 2);
unpadNums.Add(lhsPadNums[mPackIndex]);
unpackLanes.Add(Lane);
}

if (rhsPack.HasFlag(PackKind.N))
if (rhsPack.HasFlag(IR.CPU.PackedMatMul.PackKind.N))
{
var nPackIndex = Array.IndexOf(rhsPackedAxes, rn);
unpackAxes.Add(outRank - 1);
Expand Down

0 comments on commit 464b23b

Please sign in to comment.