Skip to content

Commit

Permalink
Add ntt.store, optimize u_matmul for RVV
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Sep 29, 2024
1 parent 81f3318 commit 4f4a95c
Show file tree
Hide file tree
Showing 10 changed files with 561 additions and 399 deletions.
11 changes: 11 additions & 0 deletions src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
*/
#pragma once
#include "../../primitive_ops.h"
#include "nncase/ntt/arch/riscv64/arch_types.h"
#include "nncase/ntt/vector.h"
#include "rvv_mathfun.h"

#ifdef __riscv_vector
Expand All @@ -29,6 +31,15 @@ namespace nncase::ntt::ops {
kernel(1, 32) kernel(2, 16) kernel(4, 8) kernel(8, 4)
#endif

template <>
struct store<ntt::vector<float, NTT_VLEN / 32>,
ntt::vector<float, NTT_VLEN / 32>> {
void operator()(ntt::vector<float, NTT_VLEN / 32> &dest,
const ntt::vector<float, NTT_VLEN / 32> &v) const noexcept {
__riscv_vse32_v_f32m1((float *)&dest, v, NTT_VLEN / 32);
}
};

#define RVV_UNARY_OP(op, dtype, vl, kernel) \
template <> struct op<ntt::vector<dtype, vl>> { \
ntt::vector<dtype, vl> \
Expand Down
38 changes: 22 additions & 16 deletions src/Native/include/nncase/ntt/arch/riscv64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
*/
#pragma once
#include "../../ukernels.h"
#include "arch_types.h"
#include "nncase/ntt/arch/riscv64/arch_types.h"
#include "nncase/ntt/vector.h"
#include <vector>
#include <riscv_vector.h>

namespace nncase::ntt::ukernels {
template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
Expand All @@ -32,62 +32,68 @@ struct u_matmul_policy<mamtul_pack_kind::no_pack, float, float, float, true> {

// Pack M
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_m, vector<float, 4>, float,
vector<float, 4>, true> {
struct u_matmul_policy<mamtul_pack_kind::pack_m, vector<float, NTT_VLEN / 32>,
float, vector<float, NTT_VLEN / 32>, true> {
static constexpr size_t m0_tile = 2;
static constexpr size_t n0_tile = 4;
static constexpr size_t m0_subtile = 0;
};

// Pack K
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_k, vector<float, 8>,
vector<float, 8>, float, true> {
struct u_matmul_policy<mamtul_pack_kind::pack_k, vector<float, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32>, float, true> {
static constexpr size_t m0_tile = 2;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 0;
};

// Pack N
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_n, float, vector<float, 8>,
vector<float, 8>, true> {
struct u_matmul_policy<mamtul_pack_kind::pack_n, float,
vector<float, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32>, true> {
static constexpr size_t m0_tile = 4;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 0;
};

// Pack MN
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_mn, vector<float, 8>,
vector<float, 8>, vector<float, 8, 8>, true> {
struct u_matmul_policy<mamtul_pack_kind::pack_mn, vector<float, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32, NTT_VLEN / 32>, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 4;
};

// Pack MK
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_mk, vector<float, 8, 8>,
vector<float, 8>, vector<float, 8>, true> {
struct u_matmul_policy<
mamtul_pack_kind::pack_mk, vector<float, NTT_VLEN / 32, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32>, vector<float, NTT_VLEN / 32>, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 1;
static constexpr size_t m0_subtile = 0;
};

// Pack KN
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_kn, vector<float, 8>,
vector<float, 8, 8>, vector<float, 8>, true> {
struct u_matmul_policy<mamtul_pack_kind::pack_kn, vector<float, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32>, true> {
static constexpr size_t m0_tile = 4;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 0;
};

// Pack MKN
template <>
struct u_matmul_policy<mamtul_pack_kind::pack_mkn, vector<float, 8, 8>,
vector<float, 8, 8>, vector<float, 8, 8>, true> {
struct u_matmul_policy<mamtul_pack_kind::pack_mkn,
vector<float, NTT_VLEN / 32, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32, NTT_VLEN / 32>,
vector<float, NTT_VLEN / 32, NTT_VLEN / 32>, true> {
static constexpr size_t m0_tile = 1;
static constexpr size_t n0_tile = 2;
static constexpr size_t m0_subtile = 4;
Expand Down
2 changes: 0 additions & 2 deletions src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
*/
#pragma once
#include "../../ukernels.h"
#include "arch_types.h"
#include "nncase/ntt/vector.h"
#include <vector>

namespace nncase::ntt::ukernels {
template <size_t M, size_t N, size_t MStrides>
Expand Down
228 changes: 3 additions & 225 deletions src/Native/include/nncase/ntt/kernels/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,234 +166,12 @@ class matmul_impl<false, false, AccumulateC, TLhs, TRhs, TOut, LhsPackedAxes,
size_t n1) {
auto c0 =
c.view(make_ranked_shape(m1, n1), fixed_shape<M0Tile, N0Tile>{});

// 1. pack M & N
if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn &&
m0_subtile) {
using TSubOutElem = ntt::vector<typename TOutElem::element_type,
TOutElem::shape().last()>;
TSubOutElem c0_tmp[m0_subtile][N0Tile];

for (size_t sm1 = 0; sm1 < TOutElem::shape()[0];
sm1 += m0_subtile) {
ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
c0_tmp[index[0]][index[1]] =
AccumulateC ? c0(0, index[1])(sm1 + index[0])
: TSubOutElem{};
});

for (size_t k1 = 0; k1 < K; k1++) {
outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1,
sm1);
}

ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
c0(0, index[1])(sm1 + index[0]) =
c0_tmp[index[0]][index[1]];
});
}
}
// 2. pack K & KN
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) {
using TLhsElem = std::remove_const_t<typename TA::element_type>;

TOutElem c0_tmp[M0Tile][N0Tile];
ntt::apply(c0.shape(), [&](auto index) {
c0_tmp[index[0]][index[1]] =
AccumulateC ? c0(index) : TOutElem{};
});

for (size_t k1 = 0; k1 < K; k1++) {
for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) {
outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1, 0,
sk1);
}
}

ntt::apply(c0.shape(), [&](auto index) {
c0(index) = c0_tmp[index[0]][index[1]];
});
}
// 3. pack MK & KN
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn &&
m0_subtile) {
using TLhsElem = std::remove_const_t<typename TA::element_type>;
using TSubOutElem = ntt::vector<typename TOutElem::element_type,
TOutElem::shape().last()>;

TSubOutElem c0_tmp[m0_subtile][N0Tile];

for (size_t sm1 = 0; sm1 < TOutElem::shape()[0];
sm1 += m0_subtile) {
ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
c0_tmp[index[0]][index[1]] =
AccumulateC ? c0(0, index[1])(sm1 + index[0])
: TSubOutElem{};
});

for (size_t k1 = 0; k1 < K; k1++) {
for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) {
outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1,
sm1, sk1);
}
}

ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
c0(0, index[1])(sm1 + index[0]) =
c0_tmp[index[0]][index[1]];
});
}
}
// Other packs
else {
TOutElem c0_tmp[M0Tile][N0Tile];
ntt::apply(c0.shape(), [&](auto index) {
c0_tmp[index[0]][index[1]] =
AccumulateC ? c0(index) : TOutElem{};
});

for (size_t k1 = 0; k1 < K; k1++) {
outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1);
}

ntt::apply(c0.shape(), [&](auto index) {
c0(index) = c0_tmp[index[0]][index[1]];
});
}
}

template <size_t M0Tile, size_t N0Tile, class TA, class TB, class TC>
void outer_product(const TA &a, const TB &b, TC &c0_tmp, size_t m1,
size_t k1, size_t n1, size_t sm1 = 0, size_t sk1 = 0) {
auto a1 =
a.view(make_ranked_shape(m1, k1), make_ranked_shape(M0Tile, 1));
a.view(make_ranked_shape(m1, 0), make_ranked_shape(M0Tile, K));
auto b1 =
b.view(make_ranked_shape(k1, n1), make_ranked_shape(1, N0Tile));

using TLhsElem = std::remove_const_t<typename TA::element_type>;
using TRhsElem = std::remove_const_t<typename TB::element_type>;

// 1. pack M & N
if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn &&
m0_subtile) {
using TSubLhsElem = typename TLhsElem::element_type;
TSubLhsElem a0_tmp[m0_subtile];
TRhsElem b0_tmp[N0Tile];
b.view(make_ranked_shape(0, n1), make_ranked_shape(K, N0Tile));

ntt::apply(fixed_shape<m0_subtile>{}, [&](auto index) {
a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0]);
});
ntt::apply(fixed_shape<N0Tile>{},
[&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); });

for (size_t n = 0; n < N0Tile; n++) {
for (size_t m = 0; m < m0_subtile; m++) {
mul_add<true>(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]);
}
}
}
// 2. pack K & KN
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) {
using TSubLhsElem = typename TLhsElem::element_type;
using TSubRhsElem = ntt::vector<typename TRhsElem::element_type,
TRhsElem::shape().last()>;
TSubLhsElem a0_tmp[M0Tile];
TSubRhsElem b0_tmp[N0Tile];

ntt::apply(fixed_shape<M0Tile>{}, [&](auto index) {
a0_tmp[index[0]] = a1(index[0], 0)(sk1);
});
ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) {
b0_tmp[index[0]] = b1(0, index[0])(sk1);
});

for (size_t n = 0; n < N0Tile; n++) {
for (size_t m = 0; m < M0Tile; m++) {
mul_add<true>(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]);
}
}
}
// 1. pack MK & KN
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn &&
m0_subtile) {
using TSubLhsElem = typename TLhsElem::element_type;
using TSubRhsElem = ntt::vector<typename TRhsElem::element_type,
TRhsElem::shape().last()>;
TSubLhsElem a0_tmp[m0_subtile];
TSubRhsElem b0_tmp[N0Tile];

ntt::apply(fixed_shape<m0_subtile>{}, [&](auto index) {
a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0], sk1);
});
ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) {
b0_tmp[index[0]] = b1(0, index[0])(sk1);
});

for (size_t n = 0; n < N0Tile; n++) {
for (size_t m = 0; m < m0_subtile; m++) {
auto &output = c0_tmp[m][n];
auto value = ntt::outer_product(a0_tmp[m], b0_tmp[n]);
output = output + value;
}
}
}
// Other packs
else {
TLhsElem a0_tmp[M0Tile];
TRhsElem b0_tmp[N0Tile];

ntt::apply(fixed_shape<M0Tile>{},
[&](auto index) { a0_tmp[index[0]] = a1(index[0], 0); });
ntt::apply(fixed_shape<N0Tile>{},
[&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); });

for (size_t n = 0; n < N0Tile; n++) {
for (size_t m = 0; m < M0Tile; m++) {
mul_add<true>(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]);
}
}
}
}

template <bool AccC, class TLhsElem, class TRhsElem, class TOutElem>
void mul_add(const TLhsElem &lhs, const TRhsElem &rhs, TOutElem &output) {
// 1. 0D-packing
if constexpr (pack_kind == ukernels::mamtul_pack_kind::no_pack) {
output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs);
}
// 2. 1D-packing
// 2.1. pack M
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_m) {
output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs);
}
// 2.2. pack K
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_k) {
auto value = ntt::inner_product(lhs, rhs);
output = AccC ? output + value : value;
}
// 2.3. pack N
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_n) {
output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs);
}
// 2.4. pack M & N
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn ||
pack_kind == ukernels::mamtul_pack_kind::pack_kn) {
auto value = ntt::outer_product(lhs, rhs);
output = AccC ? output + value : value;
}
// 3.1. pack MK & K
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mk) {
for (size_t m = 0; m < lhs.shape()[0]; m++) {
auto value = ntt::inner_product(lhs(m), rhs);
output(m) = AccC ? output(m) + value : value;
}
}
// 3.2. pack MK & KN
else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn) {
output = ntt::mma<AccC>(lhs, rhs, output);
} else {
static_assert(sizeof(TLhsElem) == 0, "Unsupported packing.");
}
ntt::u_matmul<pack_kind, AccumulateC, M0Tile, N0Tile>(a1, b1, c0, K);
}
};
} // namespace detail
Expand Down
16 changes: 16 additions & 0 deletions src/Native/include/nncase/ntt/primitive_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ enum class reduce_op {

namespace ops {

/**
* @defgroup Load/Store operation functors
* @{
*/

template <class TDest, class TSource> struct store {
constexpr void operator()(TDest &dest, const TSource &v) const noexcept {
dest = v;
}
};

/**
* @defgroup Unary operation functors
* @{
Expand Down Expand Up @@ -277,6 +288,11 @@ template <class T1, class T2> struct clamp {
return ntt::reduce<op>(v, init_value); \
}

template <class TDest, class TSource>
constexpr void store(TDest &dest, const TSource &v) noexcept {
ops::store<std::decay_t<TDest>, std::decay_t<TSource>>()(dest, v);
}

NTT_DEFINE_UNARY_FUNC_IMPL(abs)
NTT_DEFINE_UNARY_FUNC_IMPL(acos)
NTT_DEFINE_UNARY_FUNC_IMPL(acosh)
Expand Down
Loading

0 comments on commit 4f4a95c

Please sign in to comment.