From 4f4a95c3e29103767df31bff5b65f26e7d5efca0 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Sun, 29 Sep 2024 06:05:36 +0000 Subject: [PATCH] Add ntt.store, optimize u_matmul for RVV --- .../nncase/ntt/arch/riscv64/primitive_ops.h | 11 + .../nncase/ntt/arch/riscv64/ukernels.h | 38 +-- .../include/nncase/ntt/arch/x86_64/ukernels.h | 2 - .../include/nncase/ntt/kernels/matmul.h | 228 +-------------- src/Native/include/nncase/ntt/primitive_ops.h | 16 ++ src/Native/include/nncase/ntt/ukernels.h | 160 +---------- .../include/nncase/ntt/ukernels/u_matmul.h | 270 ++++++++++++++++++ .../include/nncase/ntt/ukernels/u_mul_add.h | 74 +++++ .../include/nncase/ntt/ukernels/u_pack.h | 49 ++++ .../include/nncase/ntt/ukernels/u_reduce.h | 112 ++++++++ 10 files changed, 561 insertions(+), 399 deletions(-) create mode 100644 src/Native/include/nncase/ntt/ukernels/u_matmul.h create mode 100644 src/Native/include/nncase/ntt/ukernels/u_mul_add.h create mode 100644 src/Native/include/nncase/ntt/ukernels/u_pack.h create mode 100644 src/Native/include/nncase/ntt/ukernels/u_reduce.h diff --git a/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h b/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h index 9d3a40736..70554956c 100644 --- a/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h +++ b/src/Native/include/nncase/ntt/arch/riscv64/primitive_ops.h @@ -14,6 +14,8 @@ */ #pragma once #include "../../primitive_ops.h" +#include "nncase/ntt/arch/riscv64/arch_types.h" +#include "nncase/ntt/vector.h" #include "rvv_mathfun.h" #ifdef __riscv_vector @@ -29,6 +31,15 @@ namespace nncase::ntt::ops { kernel(1, 32) kernel(2, 16) kernel(4, 8) kernel(8, 4) #endif +template <> +struct store, + ntt::vector> { + void operator()(ntt::vector &dest, + const ntt::vector &v) const noexcept { + __riscv_vse32_v_f32m1((float *)&dest, v, NTT_VLEN / 32); + } +}; + #define RVV_UNARY_OP(op, dtype, vl, kernel) \ template <> struct op> { \ ntt::vector \ diff --git a/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h b/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h index b40004733..c9649881d 100644 --- a/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h +++ b/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h @@ -14,9 +14,9 @@ */ #pragma once #include "../../ukernels.h" -#include "arch_types.h" +#include "nncase/ntt/arch/riscv64/arch_types.h" #include "nncase/ntt/vector.h" -#include +#include namespace nncase::ntt::ukernels { template struct u_reduce_policy { @@ -32,8 +32,8 @@ struct u_matmul_policy { // Pack M template <> -struct u_matmul_policy, float, - vector, true> { +struct u_matmul_policy, + float, vector, true> { static constexpr size_t m0_tile = 2; static constexpr size_t n0_tile = 4; static constexpr size_t m0_subtile = 0; @@ -41,8 +41,8 @@ struct u_matmul_policy, float, // Pack K template <> -struct u_matmul_policy, - vector, float, true> { +struct u_matmul_policy, + vector, float, true> { static constexpr size_t m0_tile = 2; static constexpr size_t n0_tile = 2; static constexpr size_t m0_subtile = 0; @@ -50,8 +50,9 @@ struct u_matmul_policy, // Pack N template <> -struct u_matmul_policy, - vector, true> { +struct u_matmul_policy, + vector, true> { static constexpr size_t m0_tile = 4; static constexpr size_t n0_tile = 2; static constexpr size_t m0_subtile = 0; @@ -59,8 +60,9 @@ struct u_matmul_policy, // Pack MN template <> -struct u_matmul_policy, - vector, vector, true> { +struct u_matmul_policy, + vector, + vector, true> { static constexpr size_t m0_tile = 1; static constexpr size_t n0_tile = 2; static constexpr size_t m0_subtile = 4; @@ -68,8 +70,9 @@ struct u_matmul_policy, // Pack MK template <> -struct u_matmul_policy, - vector, vector, true> { +struct u_matmul_policy< + mamtul_pack_kind::pack_mk, vector, + vector, vector, true> { static constexpr size_t m0_tile = 1; static constexpr size_t n0_tile = 1; static constexpr size_t m0_subtile = 0; @@ -77,8 +80,9 @@ struct u_matmul_policy, // Pack KN template <> -struct u_matmul_policy, - vector, vector, true> { +struct u_matmul_policy, + vector, + vector, true> { static constexpr size_t m0_tile = 4; static constexpr size_t n0_tile = 2; static constexpr size_t m0_subtile = 0; @@ -86,8 +90,10 @@ struct u_matmul_policy, // Pack MKN template <> -struct u_matmul_policy, - vector, vector, true> { +struct u_matmul_policy, + vector, + vector, true> { static constexpr size_t m0_tile = 1; static constexpr size_t n0_tile = 2; static constexpr size_t m0_subtile = 4; diff --git a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h index ad04bb060..f49ae7be9 100644 --- a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h +++ b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h @@ -14,9 +14,7 @@ */ #pragma once #include "../../ukernels.h" -#include "arch_types.h" #include "nncase/ntt/vector.h" -#include namespace nncase::ntt::ukernels { template diff --git a/src/Native/include/nncase/ntt/kernels/matmul.h b/src/Native/include/nncase/ntt/kernels/matmul.h index 7da990889..aadf3a8c8 100644 --- a/src/Native/include/nncase/ntt/kernels/matmul.h +++ b/src/Native/include/nncase/ntt/kernels/matmul.h @@ -166,234 +166,12 @@ class matmul_impl{}); - - // 1. pack M & N - if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn && - m0_subtile) { - using TSubOutElem = ntt::vector; - TSubOutElem c0_tmp[m0_subtile][N0Tile]; - - for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; - sm1 += m0_subtile) { - ntt::apply(fixed_shape{}, [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(0, index[1])(sm1 + index[0]) - : TSubOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - outer_product(a, b, c0_tmp, m1, k1, n1, - sm1); - } - - ntt::apply(fixed_shape{}, [&](auto index) { - c0(0, index[1])(sm1 + index[0]) = - c0_tmp[index[0]][index[1]]; - }); - } - } - // 2. pack K & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) { - using TLhsElem = std::remove_const_t; - - TOutElem c0_tmp[M0Tile][N0Tile]; - ntt::apply(c0.shape(), [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(index) : TOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) { - outer_product(a, b, c0_tmp, m1, k1, n1, 0, - sk1); - } - } - - ntt::apply(c0.shape(), [&](auto index) { - c0(index) = c0_tmp[index[0]][index[1]]; - }); - } - // 3. pack MK & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn && - m0_subtile) { - using TLhsElem = std::remove_const_t; - using TSubOutElem = ntt::vector; - - TSubOutElem c0_tmp[m0_subtile][N0Tile]; - - for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; - sm1 += m0_subtile) { - ntt::apply(fixed_shape{}, [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(0, index[1])(sm1 + index[0]) - : TSubOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) { - outer_product(a, b, c0_tmp, m1, k1, n1, - sm1, sk1); - } - } - - ntt::apply(fixed_shape{}, [&](auto index) { - c0(0, index[1])(sm1 + index[0]) = - c0_tmp[index[0]][index[1]]; - }); - } - } - // Other packs - else { - TOutElem c0_tmp[M0Tile][N0Tile]; - ntt::apply(c0.shape(), [&](auto index) { - c0_tmp[index[0]][index[1]] = - AccumulateC ? c0(index) : TOutElem{}; - }); - - for (size_t k1 = 0; k1 < K; k1++) { - outer_product(a, b, c0_tmp, m1, k1, n1); - } - - ntt::apply(c0.shape(), [&](auto index) { - c0(index) = c0_tmp[index[0]][index[1]]; - }); - } - } - - template - void outer_product(const TA &a, const TB &b, TC &c0_tmp, size_t m1, - size_t k1, size_t n1, size_t sm1 = 0, size_t sk1 = 0) { auto a1 = - a.view(make_ranked_shape(m1, k1), make_ranked_shape(M0Tile, 1)); + a.view(make_ranked_shape(m1, 0), make_ranked_shape(M0Tile, K)); auto b1 = - b.view(make_ranked_shape(k1, n1), make_ranked_shape(1, N0Tile)); - - using TLhsElem = std::remove_const_t; - using TRhsElem = std::remove_const_t; - - // 1. pack M & N - if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn && - m0_subtile) { - using TSubLhsElem = typename TLhsElem::element_type; - TSubLhsElem a0_tmp[m0_subtile]; - TRhsElem b0_tmp[N0Tile]; + b.view(make_ranked_shape(0, n1), make_ranked_shape(K, N0Tile)); - ntt::apply(fixed_shape{}, [&](auto index) { - a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0]); - }); - ntt::apply(fixed_shape{}, - [&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < m0_subtile; m++) { - mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); - } - } - } - // 2. pack K & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) { - using TSubLhsElem = typename TLhsElem::element_type; - using TSubRhsElem = ntt::vector; - TSubLhsElem a0_tmp[M0Tile]; - TSubRhsElem b0_tmp[N0Tile]; - - ntt::apply(fixed_shape{}, [&](auto index) { - a0_tmp[index[0]] = a1(index[0], 0)(sk1); - }); - ntt::apply(fixed_shape{}, [&](auto index) { - b0_tmp[index[0]] = b1(0, index[0])(sk1); - }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < M0Tile; m++) { - mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); - } - } - } - // 1. pack MK & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn && - m0_subtile) { - using TSubLhsElem = typename TLhsElem::element_type; - using TSubRhsElem = ntt::vector; - TSubLhsElem a0_tmp[m0_subtile]; - TSubRhsElem b0_tmp[N0Tile]; - - ntt::apply(fixed_shape{}, [&](auto index) { - a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0], sk1); - }); - ntt::apply(fixed_shape{}, [&](auto index) { - b0_tmp[index[0]] = b1(0, index[0])(sk1); - }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < m0_subtile; m++) { - auto &output = c0_tmp[m][n]; - auto value = ntt::outer_product(a0_tmp[m], b0_tmp[n]); - output = output + value; - } - } - } - // Other packs - else { - TLhsElem a0_tmp[M0Tile]; - TRhsElem b0_tmp[N0Tile]; - - ntt::apply(fixed_shape{}, - [&](auto index) { a0_tmp[index[0]] = a1(index[0], 0); }); - ntt::apply(fixed_shape{}, - [&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); }); - - for (size_t n = 0; n < N0Tile; n++) { - for (size_t m = 0; m < M0Tile; m++) { - mul_add(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); - } - } - } - } - - template - void mul_add(const TLhsElem &lhs, const TRhsElem &rhs, TOutElem &output) { - // 1. 0D-packing - if constexpr (pack_kind == ukernels::mamtul_pack_kind::no_pack) { - output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); - } - // 2. 1D-packing - // 2.1. pack M - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_m) { - output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); - } - // 2.2. pack K - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_k) { - auto value = ntt::inner_product(lhs, rhs); - output = AccC ? output + value : value; - } - // 2.3. pack N - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_n) { - output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); - } - // 2.4. pack M & N - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn || - pack_kind == ukernels::mamtul_pack_kind::pack_kn) { - auto value = ntt::outer_product(lhs, rhs); - output = AccC ? output + value : value; - } - // 3.1. pack MK & K - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mk) { - for (size_t m = 0; m < lhs.shape()[0]; m++) { - auto value = ntt::inner_product(lhs(m), rhs); - output(m) = AccC ? output(m) + value : value; - } - } - // 3.2. pack MK & KN - else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn) { - output = ntt::mma(lhs, rhs, output); - } else { - static_assert(sizeof(TLhsElem) == 0, "Unsupported packing."); - } + ntt::u_matmul(a1, b1, c0, K); } }; } // namespace detail diff --git a/src/Native/include/nncase/ntt/primitive_ops.h b/src/Native/include/nncase/ntt/primitive_ops.h index b98566314..245cf3d16 100644 --- a/src/Native/include/nncase/ntt/primitive_ops.h +++ b/src/Native/include/nncase/ntt/primitive_ops.h @@ -31,6 +31,17 @@ enum class reduce_op { namespace ops { +/** + * @defgroup Load/Store operation functors + * @{ + */ + +template struct store { + constexpr void operator()(TDest &dest, const TSource &v) const noexcept { + dest = v; + } +}; + /** * @defgroup Unary operation functors * @{ @@ -277,6 +288,11 @@ template struct clamp { return ntt::reduce(v, init_value); \ } +template +constexpr void store(TDest &dest, const TSource &v) noexcept { + ops::store, std::decay_t>()(dest, v); +} + NTT_DEFINE_UNARY_FUNC_IMPL(abs) NTT_DEFINE_UNARY_FUNC_IMPL(acos) NTT_DEFINE_UNARY_FUNC_IMPL(acosh) diff --git a/src/Native/include/nncase/ntt/ukernels.h b/src/Native/include/nncase/ntt/ukernels.h index 2a79d325a..f8e4d51f9 100644 --- a/src/Native/include/nncase/ntt/ukernels.h +++ b/src/Native/include/nncase/ntt/ukernels.h @@ -13,159 +13,7 @@ * limitations under the License. */ #pragma once -#include "apply.h" -#include "primitive_ops.h" -#include "tensor.h" -#include "tensor_traits.h" - -namespace nncase::ntt::ukernels { -template -class u_pack { - public: - constexpr void operator()(const TIn *input, TOut *output) noexcept { - for (size_t j = 0; j < N; j++) { - for (size_t i = 0; i < M; i++) { - output[j](i) = input[i * MStrides + j]; - } - } - - if constexpr (M < TOut::shape_type::length()) { - for (size_t j = 0; j < N; j++) { - for (size_t i = M; i < TOut::shape_type::length(); i++) { - output[j](i) = (TIn)0; - } - } - } - } -}; - -template struct reduce_to_binary_type; - -template <> struct reduce_to_binary_type { - template using type = ops::add; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::min; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::max; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::add; -}; - -template <> struct reduce_to_binary_type { - template using type = ops::mul; -}; - -template struct u_reduce_policy { - static constexpr size_t unroll = 2; -}; - -template struct u_reduce { - public: - constexpr T operator()(const T *input, size_t input_stride, size_t count, - T init_value) noexcept { - using binary_op_t = - typename reduce_to_binary_type::template type; - using policy_t = u_reduce_policy; - constexpr auto unroll = policy_t::unroll; - - if (count / unroll) { - T temp[unroll]; -#if 1 - for (size_t i = 0; i < unroll; i++) { - temp[i] = *input; - input += input_stride; - count--; - } - - while (count / unroll) { - for (size_t i = 0; i < unroll; i++) { - temp[i] = binary_op_t()(temp[i], *input); - input += input_stride; - count--; - } - } - - init_value = binary_op_t()(init_value, tree_reduce(temp)); -#else - while (count / unroll) { - for (size_t i = 0; i < unroll; i++) { - temp[i] = *input; - input += input_stride; - count--; - } - init_value = - binary_op_t()(init_value, tree_reduce(temp)); - } -#endif - } - - for (size_t i = 0; i < count; i++) { - init_value = binary_op_t()(init_value, *input); - input += input_stride; - } - return init_value; - } - - template constexpr T tree_reduce(T *input) noexcept { - using binary_op_t = - typename reduce_to_binary_type::template type; - if constexpr (N == 2) { - return binary_op_t()(input[0], input[1]); - } else { - return binary_op_t()(tree_reduce(input), - tree_reduce(input + N / 2)); - } - } -}; - -enum class mamtul_pack_kind { - unknown, - no_pack, - pack_m, - pack_k, - pack_n, - pack_mn, - pack_mk, - pack_kn, - pack_mkn, -}; - -template -struct u_matmul_policy { - static constexpr size_t m0_tile = 1; - static constexpr size_t n0_tile = 1; - static constexpr size_t m0_subtile = 0; -}; -} // namespace nncase::ntt::ukernels - -namespace nncase::ntt { -template -constexpr void u_pack(const TIn *input, TOut *output) noexcept { - ukernels::u_pack, - std::decay_t> - impl; - impl(input, output); -} - -template -constexpr T u_reduce(const T *input, size_t input_stride, size_t count, - T init_value) noexcept { - ukernels::u_reduce impl; - return impl(input, input_stride, count, init_value); -} - -// template -// constexpr void u_matmul(const TLhsElem *&lhs, const TRhsElem *&rhs, -// TOutElem *output, size_t M, size_t N, size_t K, -// size_t lhs_stride, size_t rhs_stride, -// size_t out_stride) noexcept { - -// } -} // namespace nncase::ntt +#include "ukernels/u_matmul.h" +#include "ukernels/u_mul_add.h" +#include "ukernels/u_pack.h" +#include "ukernels/u_reduce.h" diff --git a/src/Native/include/nncase/ntt/ukernels/u_matmul.h b/src/Native/include/nncase/ntt/ukernels/u_matmul.h new file mode 100644 index 000000000..3d93a05d6 --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_matmul.h @@ -0,0 +1,270 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../apply.h" +#include "nncase/ntt/shape.h" +#include "u_mul_add.h" + +namespace nncase::ntt { +namespace ukernels { +template +struct u_matmul_policy { + static constexpr size_t m0_tile = 1; + static constexpr size_t n0_tile = 1; + static constexpr size_t m0_subtile = 0; +}; + +template +struct u_matmul_generic { + template + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + TOutElem c0_tmp[M0Tile][N0Tile]; + ntt::apply(c0.shape(), [&](auto index) { + c0_tmp[index[0]][index[1]] = AccumulateC ? c0(index) : TOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + auto a0 = + a.view(make_ranked_shape(0, k1), fixed_shape{}); + auto b0 = + b.view(make_ranked_shape(k1, 0), fixed_shape<1, N0Tile>{}); + TLhsElem a0_tmp[M0Tile]; + TRhsElem b0_tmp[N0Tile]; + + ntt::apply(fixed_shape{}, + [&](auto index) { a0_tmp[index[0]] = a0(index[0], 0); }); + ntt::apply(fixed_shape{}, + [&](auto index) { b0_tmp[index[0]] = b0(0, index[0]); }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < M0Tile; m++) { + u_mul_add(a0_tmp[m], b0_tmp[n], + c0_tmp[m][n]); + } + } + } + + ntt::apply(c0.shape(), [&](auto index) { + ntt::store(c0(index), c0_tmp[index[0]][index[1]]); + }); + } +}; + +template +struct u_matmul : u_matmul_generic {}; + +template +struct u_matmul { + template + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + using TSubOutElem = ntt::vector; + using policy_t = + ntt::ukernels::u_matmul_policy; + constexpr auto m0_subtile = policy_t::m0_subtile; + + if constexpr (m0_subtile) { + TSubOutElem c0_tmp[m0_subtile][N0Tile]; + + for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; + sm1 += m0_subtile) { + ntt::apply(fixed_shape{}, [&](auto index) { + c0_tmp[index[0]][index[1]] = + AccumulateC ? c0(0, index[1])(sm1 + index[0]) + : TSubOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + using TSubLhsElem = typename TLhsElem::element_type; + TSubLhsElem a0_tmp[m0_subtile]; + TRhsElem b0_tmp[N0Tile]; + + auto a0 = a.view(make_ranked_shape(0, k1), + fixed_shape{}); + auto b0 = b.view(make_ranked_shape(k1, 0), + fixed_shape<1, N0Tile>{}); + + ntt::apply(fixed_shape{}, [&](auto index) { + a0_tmp[index[0]] = a0(0, 0)(sm1 + index[0]); + }); + ntt::apply(fixed_shape{}, [&](auto index) { + b0_tmp[index[0]] = b0(0, index[0]); + }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < m0_subtile; m++) { + ntt::u_mul_add( + a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); + } + } + } + + ntt::apply(fixed_shape{}, [&](auto index) { + ntt::store(c0(0, index[1])(sm1 + index[0]), + c0_tmp[index[0]][index[1]]); + }); + } + } else { + u_matmul_generic + impl; + impl(a, b, c0, K); + } + } +}; + +template +struct u_matmul { + template + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + TOutElem c0_tmp[M0Tile][N0Tile]; + ntt::apply(c0.shape(), [&](auto index) { + c0_tmp[index[0]][index[1]] = AccumulateC ? c0(index) : TOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + auto a0 = + a.view(make_ranked_shape(0, k1), fixed_shape{}); + auto b0 = + b.view(make_ranked_shape(k1, 0), fixed_shape<1, N0Tile>{}); + for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) { + using TSubLhsElem = typename TLhsElem::element_type; + using TSubRhsElem = ntt::vector; + + TSubLhsElem a0_tmp[M0Tile]; + TSubRhsElem b0_tmp[N0Tile]; + + ntt::apply(fixed_shape{}, [&](auto index) { + a0_tmp[index[0]] = a0(index[0], 0)(sk1); + }); + ntt::apply(fixed_shape{}, [&](auto index) { + b0_tmp[index[0]] = b0(0, index[0])(sk1); + }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < M0Tile; m++) { + ntt::u_mul_add( + a0_tmp[m], b0_tmp[n], c0_tmp[m][n]); + } + } + } + } + + ntt::apply(c0.shape(), [&](auto index) { + ntt::store(c0(index), c0_tmp[index[0]][index[1]]); + }); + } +}; + +template +struct u_matmul { + template + constexpr void operator()(const TA &a, const TB &b, TC &c0, + size_t K) noexcept { + using TSubOutElem = ntt::vector; + using policy_t = + ntt::ukernels::u_matmul_policy; + constexpr auto m0_subtile = policy_t::m0_subtile; + + if constexpr (m0_subtile) { + TSubOutElem c0_tmp[m0_subtile][N0Tile]; + + for (size_t sm1 = 0; sm1 < TOutElem::shape()[0]; + sm1 += m0_subtile) { + ntt::apply(fixed_shape{}, [&](auto index) { + c0_tmp[index[0]][index[1]] = + AccumulateC ? c0(0, index[1])(sm1 + index[0]) + : TSubOutElem{}; + }); + + for (size_t k1 = 0; k1 < K; k1++) { + for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) { + using TSubLhsElem = typename TLhsElem::element_type; + using TSubRhsElem = + ntt::vector; + + auto a0 = a.view(make_ranked_shape(0, k1), + fixed_shape{}); + auto b0 = b.view(make_ranked_shape(k1, 0), + fixed_shape<1, N0Tile>{}); + + TSubLhsElem a0_tmp[m0_subtile]; + TSubRhsElem b0_tmp[N0Tile]; + + ntt::apply(fixed_shape{}, [&](auto index) { + a0_tmp[index[0]] = a0(0, 0)(sm1 + index[0], sk1); + }); + ntt::apply(fixed_shape{}, [&](auto index) { + b0_tmp[index[0]] = b0(0, index[0])(sk1); + }); + + for (size_t n = 0; n < N0Tile; n++) { + for (size_t m = 0; m < m0_subtile; m++) { + auto &output = c0_tmp[m][n]; + auto value = + ntt::outer_product(a0_tmp[m], b0_tmp[n]); + output = output + value; + } + } + } + } + + ntt::apply(fixed_shape{}, [&](auto index) { + ntt::store(c0(0, index[1])(sm1 + index[0]), + c0_tmp[index[0]][index[1]]); + }); + } + } else { + u_matmul_generic + impl; + impl(a, b, c0, K); + } + } +}; +} // namespace ukernels + +template +constexpr void u_matmul(const TA &a, const TB &b, TC &c, size_t K) noexcept { + using TLhsElem = std::decay_t; + using TRhsElem = std::decay_t; + using TOutElem = std::decay_t; + ukernels::u_matmul + impl; + impl(a, b, c, K); +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/ukernels/u_mul_add.h b/src/Native/include/nncase/ntt/ukernels/u_mul_add.h new file mode 100644 index 000000000..c2b2b7463 --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_mul_add.h @@ -0,0 +1,74 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../primitive_ops.h" + +namespace nncase::ntt { +namespace ukernels { +enum class mamtul_pack_kind { + unknown, + no_pack, + pack_m, + pack_k, + pack_n, + pack_mn, + pack_mk, + pack_kn, + pack_mkn, +}; +} // namespace ukernels + +template +void u_mul_add(const TLhsElem &lhs, const TRhsElem &rhs, TOutElem &output) { + // 1. 0D-packing + if constexpr (PackKind == ukernels::mamtul_pack_kind::no_pack) { + output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); + } + // 2. 1D-packing + // 2.1. pack M + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_m) { + output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); + } + // 2.2. pack K + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_k) { + auto value = ntt::inner_product(lhs, rhs); + output = AccC ? output + value : value; + } + // 2.3. pack N + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_n) { + output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs); + } + // 2.4. pack M & N + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_mn || + PackKind == ukernels::mamtul_pack_kind::pack_kn) { + auto value = ntt::outer_product(lhs, rhs); + output = AccC ? output + value : value; + } + // 3.1. pack MK & K + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_mk) { + for (size_t m = 0; m < lhs.shape()[0]; m++) { + auto value = ntt::inner_product(lhs(m), rhs); + output(m) = AccC ? output(m) + value : value; + } + } + // 3.2. pack MK & KN + else if constexpr (PackKind == ukernels::mamtul_pack_kind::pack_mkn) { + output = ntt::mma(lhs, rhs, output); + } else { + static_assert(sizeof(TLhsElem) == 0, "Unsupported packing."); + } +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/ukernels/u_pack.h b/src/Native/include/nncase/ntt/ukernels/u_pack.h new file mode 100644 index 000000000..899dcb9db --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_pack.h @@ -0,0 +1,49 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace nncase::ntt { +namespace ukernels { +template +class u_pack { + public: + constexpr void operator()(const TIn *input, TOut *output) noexcept { + for (size_t j = 0; j < N; j++) { + for (size_t i = 0; i < M; i++) { + output[j](i) = input[i * MStrides + j]; + } + } + + if constexpr (M < TOut::shape_type::length()) { + for (size_t j = 0; j < N; j++) { + for (size_t i = M; i < TOut::shape_type::length(); i++) { + output[j](i) = (TIn)0; + } + } + } + } +}; +} // namespace ukernels + +template +constexpr void u_pack(const TIn *input, TOut *output) noexcept { + ukernels::u_pack, + std::decay_t> + impl; + impl(input, output); +} +} // namespace nncase::ntt diff --git a/src/Native/include/nncase/ntt/ukernels/u_reduce.h b/src/Native/include/nncase/ntt/ukernels/u_reduce.h new file mode 100644 index 000000000..612d7520c --- /dev/null +++ b/src/Native/include/nncase/ntt/ukernels/u_reduce.h @@ -0,0 +1,112 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../primitive_ops.h" + +namespace nncase::ntt { +namespace ukernels { +template struct reduce_to_binary_type; + +template <> struct reduce_to_binary_type { + template using type = ops::add; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::min; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::max; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::add; +}; + +template <> struct reduce_to_binary_type { + template using type = ops::mul; +}; + +template struct u_reduce_policy { + static constexpr size_t unroll = 2; +}; + +template struct u_reduce { + public: + constexpr T operator()(const T *input, size_t input_stride, size_t count, + T init_value) noexcept { + using binary_op_t = + typename reduce_to_binary_type::template type; + using policy_t = u_reduce_policy; + constexpr auto unroll = policy_t::unroll; + + if (count / unroll) { + T temp[unroll]; +#if 1 + for (size_t i = 0; i < unroll; i++) { + temp[i] = *input; + input += input_stride; + count--; + } + + while (count / unroll) { + for (size_t i = 0; i < unroll; i++) { + temp[i] = binary_op_t()(temp[i], *input); + input += input_stride; + count--; + } + } + + init_value = binary_op_t()(init_value, tree_reduce(temp)); +#else + while (count / unroll) { + for (size_t i = 0; i < unroll; i++) { + temp[i] = *input; + input += input_stride; + count--; + } + init_value = + binary_op_t()(init_value, tree_reduce(temp)); + } +#endif + } + + for (size_t i = 0; i < count; i++) { + init_value = binary_op_t()(init_value, *input); + input += input_stride; + } + return init_value; + } + + template constexpr T tree_reduce(T *input) noexcept { + using binary_op_t = + typename reduce_to_binary_type::template type; + if constexpr (N == 2) { + return binary_op_t()(input[0], input[1]); + } else { + return binary_op_t()(tree_reduce(input), + tree_reduce(input + N / 2)); + } + } +}; +} // namespace ukernels + +template +constexpr T u_reduce(const T *input, size_t input_stride, size_t count, + T init_value) noexcept { + ukernels::u_reduce impl; + return impl(input, input_stride, count, init_value); +} +} // namespace nncase::ntt