diff --git a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
index 721f530dc..ad04bb060 100644
--- a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
+++ b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
@@ -43,4 +43,74 @@ class u_pack<M, N, MStrides, true, float, vector<float, 8>> {
 
 template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
     static constexpr size_t unroll = 8;
 };
+
+template <>
+struct u_matmul_policy<mamtul_pack_kind::no_pack, float, float, float, true> {
+    static constexpr size_t m0_tile = 1;
+    static constexpr size_t n0_tile = 1;
+    static constexpr size_t m0_subtile = 0;
+};
+
+// Pack M
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_m, vector<float, 8>, float,
+                       vector<float, 8>, true> {
+    static constexpr size_t m0_tile = 2;
+    static constexpr size_t n0_tile = 4;
+    static constexpr size_t m0_subtile = 0;
+};
+
+// Pack K
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_k, vector<float, 8>,
+                       vector<float, 8>, float, true> {
+    static constexpr size_t m0_tile = 2;
+    static constexpr size_t n0_tile = 2;
+    static constexpr size_t m0_subtile = 0;
+};
+
+// Pack N
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_n, float, vector<float, 8>,
+                       vector<float, 8>, true> {
+    static constexpr size_t m0_tile = 4;
+    static constexpr size_t n0_tile = 2;
+    static constexpr size_t m0_subtile = 0;
+};
+
+// Pack MN
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_mn, vector<float, 8>,
+                       vector<float, 8>, vector<float, 8, 8>, true> {
+    static constexpr size_t m0_tile = 1;
+    static constexpr size_t n0_tile = 2;
+    static constexpr size_t m0_subtile = 4;
+};
+
+// Pack MK
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_mk, vector<float, 8, 8>,
+                       vector<float, 8>, vector<float, 8>, true> {
+    static constexpr size_t m0_tile = 1;
+    static constexpr size_t n0_tile = 1;
+    static constexpr size_t m0_subtile = 0;
+};
+
+// Pack KN
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_kn, vector<float, 8>,
+                       vector<float, 8, 8>, vector<float, 8>, true> {
+    static constexpr size_t m0_tile = 4;
+    static constexpr size_t n0_tile = 2;
+    static constexpr size_t m0_subtile = 0;
+};
+
+// Pack MKN
+template <>
+struct u_matmul_policy<mamtul_pack_kind::pack_mkn, vector<float, 8, 8>,
+                       vector<float, 8, 8>, vector<float, 8, 8>, true> {
+    static constexpr size_t m0_tile = 1;
+    static constexpr size_t n0_tile = 2;
+    static constexpr size_t m0_subtile = 4;
+};
 } // namespace nncase::ntt::ukernels
diff --git a/src/Native/include/nncase/ntt/detail/shape_storage.h b/src/Native/include/nncase/ntt/detail/shape_storage.h
index ee321575d..63ad9bd41 100644
--- a/src/Native/include/nncase/ntt/detail/shape_storage.h
+++ b/src/Native/include/nncase/ntt/detail/shape_storage.h
@@ -30,6 +30,8 @@ template <class Shape> class shape_storage {
 template <size_t... Dims> class shape_storage<fixed_shape<Dims...>> {
   public:
+    constexpr shape_storage(fixed_shape<Dims...> = {}) noexcept {};
+
     static constexpr size_t rank() noexcept { return sizeof...(Dims); }
     static constexpr auto shape() noexcept { return fixed_shape<Dims...>{}; }
 };
@@ -47,6 +49,8 @@ template <class Strides> class strides_storage {
 template <size_t... Dims> class strides_storage<fixed_strides<Dims...>> {
   public:
+    constexpr strides_storage(fixed_strides<Dims...> = {}) noexcept {};
+
     static constexpr auto strides() noexcept { return fixed_strides<Dims...>{}; }
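Note: the new defaulted constructors let the fixed specializations accept (and
ignore) a shape/strides value, so generic code can construct fixed-shape
storage the same way it constructs ranked storage. A minimal sketch of the
call pattern this enables (hypothetical helper, not part of the patch):

    template <class Shape, class Strides>
    auto make_storage(Shape shape, Strides strides) {
        // compiles for both ranked_* values and fixed_* tags now that the
        // fixed specializations have a (defaulted) converting constructor
        return std::make_pair(detail::shape_storage<Shape>(shape),
                              detail::strides_storage<Strides>(strides));
    }

This is what lets tensor::squeeze() further below forward the computed fixed
shape and strides straight into a tensor_view constructor.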
diff --git a/src/Native/include/nncase/ntt/kernels/matmul.h b/src/Native/include/nncase/ntt/kernels/matmul.h
index 91e8deef6..7da990889 100644
--- a/src/Native/include/nncase/ntt/kernels/matmul.h
+++ b/src/Native/include/nncase/ntt/kernels/matmul.h
@@ -15,12 +15,63 @@
 #pragma once
 #include "../apply.h"
 #include "../profiler.h"
-#include "../tensor_ops.h"
-#include "binary.h"
-#include "matmul.h"
+#include "../shape_infer/matmul.h"
+#include "../ukernels.h"
+#include "nncase/ntt/primitive_ops.h"
+#include "nncase/ntt/shape.h"
+#include "nncase/ntt/shape_infer/reduce.h"
+#include "nncase/ntt/utility.h"
+#include <type_traits>
 
 namespace nncase::ntt {
 namespace detail {
+template <class TLhs, class TRhs, class LhsPackedAxes, class RhsPackedAxes>
+constexpr ukernels::mamtul_pack_kind get_matmul_pack_kind() noexcept {
+    if constexpr (LhsPackedAxes::rank() == 0 && RhsPackedAxes::rank() == 0) {
+        return ukernels::mamtul_pack_kind::no_pack;
+    } else if constexpr (LhsPackedAxes::rank() == 1 &&
+                         LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
+                         RhsPackedAxes::rank() == 0) {
+        return ukernels::mamtul_pack_kind::pack_m;
+    } else if constexpr (LhsPackedAxes::rank() == 0 &&
+                         RhsPackedAxes::rank() == 1 &&
+                         RhsPackedAxes::at(0) == TRhs::rank() - 1) {
+        return ukernels::mamtul_pack_kind::pack_n;
+    } else if constexpr (LhsPackedAxes::rank() == 1 &&
+                         LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
+                         RhsPackedAxes::rank() == 1 &&
+                         RhsPackedAxes::at(0) == TRhs::rank() - 2) {
+        return ukernels::mamtul_pack_kind::pack_k;
+    } else if constexpr (LhsPackedAxes::rank() == 1 &&
+                         LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
+                         RhsPackedAxes::rank() == 1 &&
+                         RhsPackedAxes::at(0) == TRhs::rank() - 1) {
+        return ukernels::mamtul_pack_kind::pack_mn;
+    } else if constexpr (LhsPackedAxes::rank() == 2 &&
+                         LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
+                         LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
+                         RhsPackedAxes::rank() == 1 &&
+                         RhsPackedAxes::at(0) == TRhs::rank() - 2) {
+        return ukernels::mamtul_pack_kind::pack_mk;
+    } else if constexpr (LhsPackedAxes::rank() == 1 &&
+                         LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
+                         RhsPackedAxes::rank() == 2 &&
+                         RhsPackedAxes::at(0) == TRhs::rank() - 2 &&
+                         RhsPackedAxes::at(1) == TRhs::rank() - 1) {
+        return ukernels::mamtul_pack_kind::pack_kn;
+    } else if constexpr (LhsPackedAxes::rank() == 2 &&
+                         LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
+                         LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
+                         RhsPackedAxes::rank() == 2 &&
+                         RhsPackedAxes::at(0) == TRhs::rank() - 2 &&
+                         RhsPackedAxes::at(1) == TRhs::rank() - 1) {
+        return ukernels::mamtul_pack_kind::pack_mkn;
+    } else {
+        return ukernels::mamtul_pack_kind::unknown;
+    }
+}
+
 template <bool AccumulateC, class TLhs, class TRhs, class TOut,
@@ -28,172 +79,278 @@ class matmul_impl;
 
 /**
  * @brief Fixed 1D-packed matmul with non transposed A/B
- * @remarks Loop orders: (k, m, n)
+ * @remarks Loop orders: (m, n, k)
  */
 template <bool AccumulateC, class TLhs, class TRhs, class TOut,
           typename LhsPackedAxes, typename RhsPackedAxes>
 class matmul_impl<AccumulateC, TLhs, TRhs, TOut, LhsPackedAxes,
                   RhsPackedAxes> {
+    using TOutElem = typename TOut::element_type;
+
+    static constexpr auto pack_kind =
+        get_matmul_pack_kind<TLhs, TRhs, LhsPackedAxes, RhsPackedAxes>();
+    using policy_t =
+        ntt::ukernels::u_matmul_policy<pack_kind, typename TLhs::element_type,
+                                       typename TRhs::element_type, TOutElem,
+                                       true>;
+    static constexpr auto m0_subtile = policy_t::m0_subtile;
+
   public:
     void operator()(const TLhs &lhs, const TRhs &rhs, TOut &output) {
-        auto lhs_p = lhs.elements().data();
-        auto rhs_p = rhs.elements().data();
-        auto out_p = output.elements().data();
-        apply<0>(lhs, rhs, output, lhs_p, rhs_p, out_p);
+        auto domain =
+            slice_fixed_dims<TOut::rank() - 2>(typename TOut::shape_type{});
+        ntt::apply(domain, [&](auto out_offset_prefix) {
+            ranked_shape<TOut::rank()> out_offset{};
+            std::copy(out_offset_prefix.begin(), out_offset_prefix.end(),
+                      out_offset.begin());
+            auto lhs_offset =
+                shape_infer::reduced_index_by_shape(out_offset, TLhs::shape());
+            auto rhs_offset =
+                shape_infer::reduced_index_by_shape(out_offset, TRhs::shape());
+            auto lhs_shape = shape_infer::sub_matmul_shape(TLhs::shape());
+            auto rhs_shape = shape_infer::sub_matmul_shape(TRhs::shape());
+            auto out_shape = shape_infer::sub_matmul_shape(TOut::shape());
+
+            auto a = lhs.view(lhs_offset, lhs_shape)
+                         .squeeze(make_index_axes<TLhs::rank() - 2>());
+            auto b = rhs.view(rhs_offset, rhs_shape)
+                         .squeeze(make_index_axes<TRhs::rank() - 2>());
+            auto c = output.view(out_offset, out_shape)
+                         .squeeze(make_index_axes<TOut::rank() - 2>());
+            matmul_2d_l1(a, b, c);
+        });
     }
 
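Note: operator() now peels off the batch dimensions itself: it walks every
leading-index tuple of the output and hands matmul_2d_l1 plain 2D views of A,
B and C. Roughly, for one batch index b of a rank-3 case (illustrative values,
not part of the patch):

    // out[b] = lhs[b] x rhs[b], with the batch axis squeezed away
    auto a = lhs.view(make_ranked_shape(b, 0, 0), fixed_shape<1, M, K>{})
                 .squeeze(make_index_axes<1>()); // -> M x K view
    auto b2 = rhs.view(make_ranked_shape(b, 0, 0), fixed_shape<1, K, N>{})
                 .squeeze(make_index_axes<1>()); // -> K x N view

reduced_index_by_shape additionally maps the output index onto possibly
broadcast lhs/rhs batch dimensions.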
   private:
-    template <size_t Axis, class TLhsP, class TRhsP, class TOutP>
-    constexpr void apply(const TLhs &lhs, const TRhs &rhs, TOut &output,
-                         TLhsP lhs_p, TRhsP rhs_p, TOutP out_p) {
-        // 1. Inner matmul ranks
-        if constexpr (Axis == TOut::rank() - 2) {
-            matmul_2d(lhs, rhs, output, lhs_p, rhs_p, out_p);
-        } else {
-            for (size_t i = 0; i < TOut::shape()[Axis]; i++) {
-                apply<Axis + 1>(lhs, rhs, output, lhs_p, rhs_p, out_p);
-                lhs_p +=
-                    utility_detail::get_safe_stride(lhs, Axis, TOut::shape());
-                rhs_p +=
-                    utility_detail::get_safe_stride(rhs, Axis, TOut::shape());
-                out_p += output.strides()[Axis];
+    template <class TA, class TB, class TC>
+    constexpr void matmul_2d_l1(const TA &a, const TB &b, TC &c) {
+        size_t M = c.shape()[c.rank() - 2];
+        size_t N = c.shape()[c.rank() - 1];
+        size_t K = a.shape()[a.rank() - 1];
+        constexpr auto m0_tile = policy_t::m0_tile;
+        constexpr auto n0_tile = policy_t::n0_tile;
+
+        size_t m1 = 0;
+        for (; m1 < M / m0_tile * m0_tile; m1 += m0_tile) {
+            size_t n1 = 0;
+            for (; n1 < N / n0_tile * n0_tile; n1 += n0_tile) {
+                matmul_2d_l0<m0_tile, n0_tile>(a, b, c, K, m1, n1);
             }
-        }
-    }
 
-    template <class TLhsP, class TRhsP, class TOutP>
-    constexpr void matmul_2d(const TLhs &lhs, const TRhs &rhs, TOut &output,
-                             TLhsP lhs_p, TRhsP rhs_p, TOutP out_p) {
-        const size_t M = output.shape()[output.rank() - 2];
-        const size_t K = lhs.shape()[lhs.rank() - 1];
-        const size_t N = output.shape()[output.rank() - 1];
-        const size_t lhs_stride = lhs.strides()[lhs.rank() - 2];
-        const size_t rhs_stride = rhs.strides()[rhs.rank() - 2];
-        const size_t out_stride = output.strides()[output.rank() - 2];
-
-        outer_product<AccumulateC>(lhs_p, rhs_p, out_p, M, N, K, lhs_stride,
-                                   rhs_stride, out_stride);
-
-        if constexpr (LhsPackedAxes::rank() == 1 &&
-                      LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
-                      RhsPackedAxes::rank() == 1 &&
-                      RhsPackedAxes::at(0) == TRhs::rank() - 2) {
-            return;
+            if (N % n0_tile) {
+                for (; n1 < N; n1++) {
+                    matmul_2d_l0<m0_tile, 1>(a, b, c, K, m1, n1);
+                }
+            }
         }
 
-        if constexpr (LhsPackedAxes::rank() == 2 &&
-                      LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
-                      LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
-                      RhsPackedAxes::rank() == 1 &&
-                      RhsPackedAxes::at(0) == TRhs::rank() - 2) {
-            return;
-        }
+        if (M % m0_tile) {
+            for (; m1 < M; m1++) {
+                size_t n1 = 0;
+                for (; n1 < N / n0_tile * n0_tile; n1 += n0_tile) {
+                    matmul_2d_l0<1, n0_tile>(a, b, c, K, m1, n1);
+                }
 
-        for (size_t k = 1; k < K; k++) {
-            outer_product<true>(lhs_p, rhs_p, out_p, M, N, K, lhs_stride,
-                                rhs_stride, out_stride);
+                if (N % n0_tile) {
+                    for (; n1 < N; n1++) {
+                        matmul_2d_l0<1, 1>(a, b, c, K, m1, n1);
+                    }
+                }
+            }
         }
     }
 
-    template <bool AccumulateC>
-    void outer_product(const TLhsElem *&lhs, const TRhsElem *&rhs,
-                       TOutElem *output, size_t M, size_t N,
-                       [[maybe_unused]] size_t K, size_t lhs_stride,
-                       size_t rhs_stride, size_t out_stride) {
-
-        // 1. 1D-packing: pack K
-        if constexpr (LhsPackedAxes::rank() == 1 &&
-                      LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
-                      RhsPackedAxes::rank() == 1 &&
-                      RhsPackedAxes::at(0) == TRhs::rank() - 2) {
-            auto lhs_mp = lhs;
-            for (size_t m = 0; m < M; m++) {
-                mul_add<AccumulateC>(lhs_mp, rhs, output, N, K, rhs_stride);
-                lhs_mp += lhs_stride;
-                output += out_stride;
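Note: matmul_2d_l1 is a classic register-tiling loop nest: full m0_tile x
n0_tile tiles first, then the N remainder with 1-wide tiles, then the M
remainder with 1-high tiles. The shape of the control flow, stripped of the
NTT types (illustrative sketch only):

    template <size_t M0, size_t N0, class Kernel>
    void tile_2d(size_t M, size_t N, Kernel &&kernel) {
        size_t m = 0;
        for (; m + M0 <= M; m += M0) {
            size_t n = 0;
            for (; n + N0 <= N; n += N0)
                kernel(m, n, M0, N0); // full M0 x N0 tile
            for (; n < N; n++)
                kernel(m, n, M0, 1); // N remainder, one column wide
        }
        for (; m < M; m++) { // M remainder, one row high
            size_t n = 0;
            for (; n + N0 <= N; n += N0)
                kernel(m, n, 1, N0);
            for (; n < N; n++)
                kernel(m, n, 1, 1);
        }
    }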
+    template <size_t M0Tile, size_t N0Tile, class TA, class TB, class TC>
+    void matmul_2d_l0(const TA &a, const TB &b, TC &c, size_t K, size_t m1,
+                      size_t n1) {
+        auto c0 =
+            c.view(make_ranked_shape(m1, n1), fixed_shape<M0Tile, N0Tile>{});
+
+        // 1. pack M & N
+        if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn &&
+                      m0_subtile) {
+            using TSubOutElem = ntt::vector<typename TOutElem::element_type,
+                                            TOutElem::shape().last()>;
+            TSubOutElem c0_tmp[m0_subtile][N0Tile];
+
+            for (size_t sm1 = 0; sm1 < TOutElem::shape()[0];
+                 sm1 += m0_subtile) {
+                ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
+                    c0_tmp[index[0]][index[1]] =
+                        AccumulateC ? c0(0, index[1])(sm1 + index[0])
+                                    : TSubOutElem{};
+                });
+
+                for (size_t k1 = 0; k1 < K; k1++) {
+                    outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1,
+                                                  sm1);
+                }
+
+                ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
+                    c0(0, index[1])(sm1 + index[0]) =
+                        c0_tmp[index[0]][index[1]];
+                });
             }
         }
-        // 2. 2D-packing: pack MK & K
-        else if constexpr (LhsPackedAxes::rank() == 2 &&
-                           LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
-                           LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
-                           RhsPackedAxes::rank() == 1 &&
-                           RhsPackedAxes::at(0) == TRhs::rank() - 2) {
-            auto lhs_mp = lhs;
-            for (size_t m = 0; m < M; m++) {
-                mul_add<AccumulateC>(lhs_mp, rhs, output, N, K, rhs_stride);
-                lhs_mp += lhs_stride;
-                output += out_stride;
+        // 2. pack K & KN
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) {
+            using TLhsElem = std::remove_const_t<typename TA::element_type>;
+
+            TOutElem c0_tmp[M0Tile][N0Tile];
+            ntt::apply(c0.shape(), [&](auto index) {
+                c0_tmp[index[0]][index[1]] =
+                    AccumulateC ? c0(index) : TOutElem{};
+            });
+
+            for (size_t k1 = 0; k1 < K; k1++) {
+                for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) {
+                    outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1, 0,
+                                                  sk1);
+                }
+            }
+
+            ntt::apply(c0.shape(), [&](auto index) {
+                c0(index) = c0_tmp[index[0]][index[1]];
+            });
+        }
+        // 3. pack MK & KN
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn &&
+                           m0_subtile) {
+            using TLhsElem = std::remove_const_t<typename TA::element_type>;
+            using TSubOutElem = ntt::vector<typename TOutElem::element_type,
+                                            TOutElem::shape().last()>;
+
+            TSubOutElem c0_tmp[m0_subtile][N0Tile];
+
+            for (size_t sm1 = 0; sm1 < TOutElem::shape()[0];
+                 sm1 += m0_subtile) {
+                ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
+                    c0_tmp[index[0]][index[1]] =
+                        AccumulateC ? c0(0, index[1])(sm1 + index[0])
+                                    : TSubOutElem{};
+                });
+
+                for (size_t k1 = 0; k1 < K; k1++) {
+                    for (size_t sk1 = 0; sk1 < TLhsElem::shape()[0]; sk1++) {
+                        outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1,
+                                                      sm1, sk1);
+                    }
+                }
+
+                ntt::apply(fixed_shape<m0_subtile, N0Tile>{}, [&](auto index) {
+                    c0(0, index[1])(sm1 + index[0]) =
+                        c0_tmp[index[0]][index[1]];
+                });
             }
         }
-        // 3. Other Case
+        // Other packs
         else {
-            auto lhs_mp = lhs;
-            for (size_t m = 0; m < M; m++) {
-                outer_product<AccumulateC>(*lhs_mp, rhs, output, N);
-                lhs_mp += lhs_stride;
-                output += out_stride;
+            TOutElem c0_tmp[M0Tile][N0Tile];
+            ntt::apply(c0.shape(), [&](auto index) {
+                c0_tmp[index[0]][index[1]] =
+                    AccumulateC ? c0(index) : TOutElem{};
+            });
+
+            for (size_t k1 = 0; k1 < K; k1++) {
+                outer_product<M0Tile, N0Tile>(a, b, c0_tmp, m1, k1, n1);
             }
-            lhs += 1;
-            rhs += rhs_stride;
+
+            ntt::apply(c0.shape(), [&](auto index) {
+                c0(index) = c0_tmp[index[0]][index[1]];
+            });
         }
     }
 
-    template <bool AccumulateC>
-    void outer_product(const TLhsElem &lhs, const TRhsElem *rhs,
-                       TOutElem *output, size_t extent) {
-        for (size_t i = 0; i < extent; i++) {
-            mul_add<AccumulateC>(lhs, *rhs++, *output++);
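Note: the c0_tmp arrays exist so the C tile lives in registers across the
whole K loop and memory is touched once per tile (load-or-zero before, single
write-back after). The scalar shape of that pattern, for orientation only:

    template <size_t M0, size_t N0>
    void kernel(const float *a, const float *b, float *c, size_t K,
                size_t lda, size_t ldb, size_t ldc, bool accumulate) {
        float acc[M0][N0] = {}; // register tile, zeroed
        if (accumulate)         // or preloaded from C when accumulating
            for (size_t m = 0; m < M0; m++)
                for (size_t n = 0; n < N0; n++)
                    acc[m][n] = c[m * ldc + n];
        for (size_t k = 0; k < K; k++) // accumulate without touching C
            for (size_t m = 0; m < M0; m++)
                for (size_t n = 0; n < N0; n++)
                    acc[m][n] += a[m * lda + k] * b[k * ldb + n];
        for (size_t m = 0; m < M0; m++) // single write-back
            for (size_t n = 0; n < N0; n++)
                c[m * ldc + n] = acc[m][n];
    }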
+    template <size_t M0Tile, size_t N0Tile, class TA, class TB, class TC>
+    void outer_product(const TA &a, const TB &b, TC &c0_tmp, size_t m1,
+                       size_t k1, size_t n1, size_t sm1 = 0, size_t sk1 = 0) {
+        auto a1 =
+            a.view(make_ranked_shape(m1, k1), make_ranked_shape(M0Tile, 1));
+        auto b1 =
+            b.view(make_ranked_shape(k1, n1), make_ranked_shape(1, N0Tile));
+
+        using TLhsElem = std::remove_const_t<typename TA::element_type>;
+        using TRhsElem = std::remove_const_t<typename TB::element_type>;
+
+        // 1. pack M & N
+        if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn &&
+                      m0_subtile) {
+            using TSubLhsElem = typename TLhsElem::element_type;
+            TSubLhsElem a0_tmp[m0_subtile];
+            TRhsElem b0_tmp[N0Tile];
+
+            ntt::apply(fixed_shape<m0_subtile>{}, [&](auto index) {
+                a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0]);
+            });
+            ntt::apply(fixed_shape<N0Tile>{},
+                       [&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); });
+
+            for (size_t n = 0; n < N0Tile; n++) {
+                for (size_t m = 0; m < m0_subtile; m++) {
+                    mul_add<true>(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]);
+                }
+            }
         }
-    }
+        // 2. pack K & KN
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_kn) {
+            using TSubLhsElem = typename TLhsElem::element_type;
+            using TSubRhsElem = ntt::vector<typename TRhsElem::element_type,
+                                            TRhsElem::shape().last()>;
+            TSubLhsElem a0_tmp[M0Tile];
+            TSubRhsElem b0_tmp[N0Tile];
 
-    // 1. 1D-packing: pack K
-    template <bool AccumulateC>
-    void mul_add(const TLhsElem *lhs, const TRhsElem *rhs, TOutElem *output,
-                 size_t extent, size_t K, size_t rhs_stride) {
-
-        // 1. 1D-packing: pack K
-        if constexpr (LhsPackedAxes::rank() == 1 &&
-                      LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
-                      RhsPackedAxes::rank() == 1 &&
-                      RhsPackedAxes::at(0) == TRhs::rank() - 2) {
-
-            for (size_t i = 0; i < extent; i++) {
-                auto rhs_mp = rhs;
-                auto lhs_mp = lhs;
-                for (size_t k = 0; k < K; k++) {
-                    auto value = ntt::inner_product(*lhs_mp, *rhs_mp);
-                    *output = AccumulateC || k > 0 ? *output + value : value;
-                    lhs_mp++;
-                    rhs_mp += rhs_stride;
+            ntt::apply(fixed_shape<M0Tile>{}, [&](auto index) {
+                a0_tmp[index[0]] = a1(index[0], 0)(sk1);
+            });
+            ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) {
+                b0_tmp[index[0]] = b1(0, index[0])(sk1);
+            });
+
+            for (size_t n = 0; n < N0Tile; n++) {
+                for (size_t m = 0; m < M0Tile; m++) {
+                    mul_add<true>(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]);
                 }
-                rhs++;
-                output++;
             }
         }
+        // 3. pack MK & KN
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn &&
+                           m0_subtile) {
+            using TSubLhsElem = typename TLhsElem::element_type;
+            using TSubRhsElem = ntt::vector<typename TRhsElem::element_type,
+                                            TRhsElem::shape().last()>;
+            TSubLhsElem a0_tmp[m0_subtile];
+            TSubRhsElem b0_tmp[N0Tile];
 
-        // 2. 2D-packing: pack MK & K
-        else if constexpr (LhsPackedAxes::rank() == 2 &&
-                           LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
-                           LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
-                           RhsPackedAxes::rank() == 1 &&
-                           RhsPackedAxes::at(0) == TRhs::rank() - 2) {
-
-            for (size_t i = 0; i < extent; i++) {
-                auto rhs_mp = rhs;
-                auto lhs_mp = lhs;
-                for (size_t k = 0; k < K; k++) {
-                    for (size_t m = 0; m < TLhsElem::shape().at(0); m++) {
-                        auto value = ntt::inner_product((*lhs_mp)(m), *rhs_mp);
-                        (*output)(m) =
-                            AccumulateC || k > 0 ? (*output)(m) + value : value;
-                    }
-                    lhs_mp++;
-                    rhs_mp += rhs_stride;
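Note: a plausible reading of the x86-64 pack-MN policy (m0_tile = 1,
n0_tile = 2, m0_subtile = 4) is the AVX register budget — with vector<float, 8>
and 16 ymm registers available:

    // c0_tmp: m0_subtile x N0Tile = 4 x 2 accumulators -> 8 ymm
    // a0_tmp: m0_subtile broadcast lanes               -> 4 ymm
    // b0_tmp: N0Tile loaded rows                       -> 2 ymm
    // total: 14 of 16 ymm, leaving headroom for addressing

The m0_subtile split exists precisely so the 8x8 output element never has to
live in registers all at once.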
                }
-                rhs++;
-                output++;
-            }
-        }
-    }
+            ntt::apply(fixed_shape<m0_subtile>{}, [&](auto index) {
+                a0_tmp[index[0]] = a1(0, 0)(sm1 + index[0], sk1);
+            });
+            ntt::apply(fixed_shape<N0Tile>{}, [&](auto index) {
+                b0_tmp[index[0]] = b1(0, index[0])(sk1);
+            });
+
+            for (size_t n = 0; n < N0Tile; n++) {
+                for (size_t m = 0; m < m0_subtile; m++) {
+                    auto &output = c0_tmp[m][n];
+                    auto value = ntt::outer_product(a0_tmp[m], b0_tmp[n]);
+                    output = output + value;
+                }
+            }
+        }
+        // Other packs
+        else {
+            TLhsElem a0_tmp[M0Tile];
+            TRhsElem b0_tmp[N0Tile];
+
+            ntt::apply(fixed_shape<M0Tile>{},
+                       [&](auto index) { a0_tmp[index[0]] = a1(index[0], 0); });
+            ntt::apply(fixed_shape<N0Tile>{},
+                       [&](auto index) { b0_tmp[index[0]] = b1(0, index[0]); });
+
+            for (size_t n = 0; n < N0Tile; n++) {
+                for (size_t m = 0; m < M0Tile; m++) {
+                    mul_add<true>(a0_tmp[m], b0_tmp[n], c0_tmp[m][n]);
+                }
+            }
+        }
+    }
 
@@ -201,51 +358,38 @@ class matmul_impl
     template <bool AccC, class TLhsElem, class TRhsElem, class TOutElem>
     void mul_add(const TLhsElem &lhs, const TRhsElem &rhs, TOutElem &output) {
         // 1. 0D-packing
-        if constexpr (LhsPackedAxes::rank() == 0 &&
-                      RhsPackedAxes::rank() == 0) {
+        if constexpr (pack_kind == ukernels::mamtul_pack_kind::no_pack) {
             output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs);
         }
         // 2. 1D-packing
         // 2.1. pack M
-        else if constexpr (LhsPackedAxes::rank() == 1 &&
-                           LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
-                           RhsPackedAxes::rank() == 0) {
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_m) {
            output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs);
         }
-        // 2.2. pack N
-        else if constexpr (LhsPackedAxes::rank() == 0 &&
-                           RhsPackedAxes::rank() == 1 &&
-                           RhsPackedAxes::at(0) == TRhs::rank() - 1) {
+        // 2.2. pack K
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_k) {
+            auto value = ntt::inner_product(lhs, rhs);
+            output = AccC ? output + value : value;
+        }
+        // 2.3. pack N
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_n) {
            output = AccC ? ntt::mul_add(lhs, rhs, output) : ntt::mul(lhs, rhs);
         }
-        // 2.3. pack M & N
-        else if constexpr (LhsPackedAxes::rank() == 1 &&
-                           LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
-                           RhsPackedAxes::rank() == 1 &&
-                           RhsPackedAxes::at(0) == TRhs::rank() - 1) {
+        // 2.4. pack M & N, or pack K & KN
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mn ||
+                           pack_kind == ukernels::mamtul_pack_kind::pack_kn) {
             auto value = ntt::outer_product(lhs, rhs);
             output = AccC ? output + value : value;
         }
-        // 3.2. pack K & KN
-        else if constexpr (LhsPackedAxes::rank() == 1 &&
-                           LhsPackedAxes::at(0) == TLhs::rank() - 1 &&
-                           RhsPackedAxes::rank() == 2 &&
-                           RhsPackedAxes::at(0) == TRhs::rank() - 2 &&
-                           RhsPackedAxes::at(1) == TRhs::rank() - 1) {
-            fixed_tensor_alike_t<TLhsElem, 1, TLhsElem::shape().at(0)> lhs_2d{
-                {lhs}};
-            fixed_tensor_alike_t<TOutElem, 1, TOutElem::shape().at(0)>
-                output_2d{{output}};
-            output_2d = ntt::mma<AccC>(lhs_2d, rhs, output_2d);
-            output = output_2d(0);
+        // 3.1. pack MK & K
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mk) {
+            for (size_t m = 0; m < lhs.shape()[0]; m++) {
+                auto value = ntt::inner_product(lhs(m), rhs);
+                output(m) = AccC ? output(m) + value : value;
+            }
         }
-        // 3.3. pack MK & KN
-        else if constexpr (LhsPackedAxes::rank() == 2 &&
-                           LhsPackedAxes::at(0) == TLhs::rank() - 2 &&
-                           LhsPackedAxes::at(1) == TLhs::rank() - 1 &&
-                           RhsPackedAxes::rank() == 2 &&
-                           RhsPackedAxes::at(0) == TRhs::rank() - 2 &&
-                           RhsPackedAxes::at(1) == TRhs::rank() - 1) {
+        // 3.2. pack MK & KN
+        else if constexpr (pack_kind == ukernels::mamtul_pack_kind::pack_mkn) {
             output = ntt::mma<AccC>(lhs, rhs, output);
         } else {
             static_assert(sizeof(TLhsElem) == 0, "Unsupported packing.");
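Note: putting mul_add's dispatch side by side with the x86-64 policy table
above (element types are the AVX instantiations; summary, not normative):

    pack kind    lhs elem          rhs elem          per-element update
    no_pack      float             float             ntt::mul_add
    pack_m       vector<f, 8>      float             ntt::mul_add (per lane)
    pack_k       vector<f, 8>      vector<f, 8>      ntt::inner_product
    pack_n       float             vector<f, 8>      ntt::mul_add
    pack_mn/kn   vector            vector            ntt::outer_product
    pack_mk      vector<f, 8, 8>   vector<f, 8>      row-wise inner_product
    pack_mkn     vector<f, 8, 8>   vector<f, 8, 8>   ntt::mma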
diff --git a/src/Native/include/nncase/ntt/primitive_ops.h b/src/Native/include/nncase/ntt/primitive_ops.h
index 50a3c46d7..979bfa765 100644
--- a/src/Native/include/nncase/ntt/primitive_ops.h
+++ b/src/Native/include/nncase/ntt/primitive_ops.h
@@ -502,8 +502,8 @@ mma<AccC, T1, T2, TResult>::operator()(const T1 &lhs, const T2 &rhs,
                   TResult::rank() == 2, "only support 2d mma");
 
     TResult output = v3;
-    for (size_t m = 0; m < T1::shape().at(0); m++) {
-        for (size_t k = 0; k < T2::shape().at(0); k++) {
+    for (size_t k = 0; k < T2::shape().at(0); k++) {
+        for (size_t m = 0; m < T1::shape().at(0); m++) {
             output(m) = (k != 0 || AccC)
                             ? ntt::mul_add(lhs(m, k), rhs(k), output(m))
                             : ntt::mul(lhs(m, k), rhs(k));
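Note: this is a plain loop interchange — the arithmetic is unchanged, only the
order of the independent per-(m, k) updates moves:

    // before: for m { for k { out[m] += lhs[m][k] * rhs[k]; } }
    // after:  for k { for m { out[m] += lhs[m][k] * rhs[k]; } }

With k outermost, rhs(k) is loaded once per k and reused across every m row,
and the per-m dependency chains are interleaved instead of serialized, which
helps instruction-level parallelism. The (k != 0 || AccC) select behaves
identically in both orders, since it only distinguishes the first k step.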
diff --git a/src/Native/include/nncase/ntt/shape.h b/src/Native/include/nncase/ntt/shape.h
index a5683b5d4..39fe50f2b 100644
--- a/src/Native/include/nncase/ntt/shape.h
+++ b/src/Native/include/nncase/ntt/shape.h
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include <utility>
 #include
 
 namespace nncase::ntt {
@@ -58,6 +59,9 @@ template <size_t Rank> struct ranked_dims_base {
     constexpr auto begin() const noexcept { return dims_.begin(); }
     constexpr auto end() const noexcept { return dims_.end(); }
 
+    constexpr auto begin() noexcept { return dims_.begin(); }
+    constexpr auto end() noexcept { return dims_.end(); }
+
     constexpr size_t last() const noexcept { return at(rank() - 1); }
     constexpr size_t &last() noexcept { return at(rank() - 1); }
@@ -123,6 +127,32 @@ template <size_t Value, size_t... Dims>
 struct repeat_shape_impl<Value, std::index_sequence<Dims...>> {
     using shape_t = fixed_shape<((void)Dims, Value)...>;
 };
+
+#define DEFINE_SQUEEZE_FIXED_DIMS_IMPL(name)                                   \
+    template <class TShape, class TAxes, size_t Index>                         \
+    struct squeeze_fixed_##name##_impl {                                       \
+        static constexpr auto get_##name() noexcept {                          \
+            if constexpr (Index >= TShape::rank()) {                           \
+                return fixed_##name<>{};                                       \
+            } else {                                                           \
+                using src_type =                                               \
+                    typename squeeze_fixed_##name##_impl<TShape, TAxes,        \
+                                                         Index + 1>::name##_t; \
+                if constexpr (TAxes::contains(Index)) {                        \
+                    return src_type{};                                         \
+                } else {                                                       \
+                    using type =                                               \
+                        src_type::template prepend<TShape::at(Index)>::type;   \
+                    return type{};                                             \
+                }                                                              \
+            }                                                                  \
+        }                                                                      \
+                                                                               \
+        using name##_t = decltype(get_##name());                               \
+    };
+
+DEFINE_SQUEEZE_FIXED_DIMS_IMPL(shape)
+DEFINE_SQUEEZE_FIXED_DIMS_IMPL(strides)
 } // namespace detail
 
 template <class T> struct is_fixed_dims : std::false_type {};
@@ -206,6 +236,27 @@ using repeat_shape_t =
 template <size_t Rank>
 using zero_shape_t = repeat_shape_t<0, Rank>;
 
+#define DEFINE_SQUEEZE_DIMS_TYPE(name)                                         \
+    template <class TShape, size_t... Axes> struct squeeze_##name##_type;      \
+                                                                               \
+    template <size_t Rank, size_t... Axes>                                     \
+    struct squeeze_##name##_type<ranked_##name<Rank>, Axes...> {               \
+        using type = ranked_##name<Rank - sizeof...(Axes)>;                    \
+    };                                                                         \
+                                                                               \
+    template <size_t... Dims, size_t... Axes>                                  \
+    struct squeeze_##name##_type<fixed_##name<Dims...>, Axes...> {             \
+        using type = detail::squeeze_fixed_##name##_impl<                      \
+            fixed_##name<Dims...>, fixed_##name<Axes...>, 0>::name##_t;        \
+    };                                                                         \
+                                                                               \
+    template <class TShape, size_t... Axes>                                    \
+    using squeeze_##name##_t =                                                 \
+        typename squeeze_##name##_type<TShape, Axes...>::type;
+
+DEFINE_SQUEEZE_DIMS_TYPE(shape)
+DEFINE_SQUEEZE_DIMS_TYPE(strides)
+
 template <class... Args> auto make_ranked_shape(Args &&...args) noexcept {
     return ranked_shape<sizeof...(Args)>{
         static_cast<size_t>(std::forward<Args>(args))...};
 }
@@ -216,6 +267,18 @@ template <class... Args> auto make_ranked_strides(Args &&...args) noexcept {
         static_cast<size_t>(std::forward<Args>(args))...};
 }
 
+template <size_t Rank, size_t... Is>
+constexpr auto to_fixed_shape(ranked_shape<Rank> shape,
+                              std::index_sequence<Is...>) noexcept {
+    return fixed_shape<shape.at(Is)...>{};
+}
+
+template <size_t Rank, size_t... Is>
+constexpr auto to_fixed_strides(ranked_strides<Rank> strides,
+                                std::index_sequence<Is...>) noexcept {
+    return fixed_strides<strides.at(Is)...>{};
+}
+
 template <class Shape>
 constexpr auto default_strides(const Shape &shape) noexcept {
     if constexpr (is_fixed_dims_v<Shape>) {
@@ -330,6 +393,37 @@ ranked_shape<Rank> get_reduced_offset(Index in_offset) {
     return off;
 }
 
+template <size_t... Axes, class TShape>
+constexpr auto squeeze_shape(fixed_shape<Axes...> axes, TShape shape) noexcept {
+    if constexpr (is_fixed_dims_v<std::decay_t<TShape>>) {
+        return squeeze_shape_t<std::decay_t<TShape>, Axes...>{};
+    } else {
+        ranked_shape<TShape::rank() - sizeof...(Axes)> new_shape;
+        size_t cnt = 0;
+        for (size_t axis = 0; axis < shape.rank(); axis++) {
+            if (!axes.contains(axis)) {
+                new_shape[cnt++] = shape[axis];
+            }
+        }
+        return new_shape;
+    }
+}
+
+template <size_t... Axes, class TStrides>
+constexpr auto squeeze_strides(fixed_shape<Axes...> axes,
+                               TStrides strides) noexcept {
+    if constexpr (is_fixed_dims_v<std::decay_t<TStrides>>) {
+        return squeeze_strides_t<std::decay_t<TStrides>, Axes...>{};
+    } else {
+        ranked_strides<TStrides::rank() - sizeof...(Axes)> new_strides;
+        size_t cnt = 0;
+        for (size_t axis = 0; axis < strides.rank(); axis++) {
+            if (!axes.contains(axis)) {
+                new_strides[cnt++] = strides[axis];
+            }
+        }
+        return new_strides;
+    }
+}
+
 template <size_t Rank>
 bool operator==(const ranked_shape<Rank> &lhs,
                 const ranked_shape<Rank> &rhs) noexcept {
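Note: behavior sketch of the new squeeze helpers on the fixed path (assuming
only the declarations above; the listed axes are removed, the rest keep their
order):

    static_assert(std::is_same_v<
        decltype(squeeze_shape(fixed_shape<0>{}, fixed_shape<1, 3, 4>{})),
        fixed_shape<3, 4>>);

On the ranked path the same result is computed at runtime by the copy loop;
the two fixes worth noting against the raw patch are the loop bound (it must
scan all dims of the input, not just the squeezed axes) and the missing
`return new_shape;` / `return new_strides;` in the ranked branches.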
diff --git a/src/Native/include/nncase/ntt/shape_infer/matmul.h b/src/Native/include/nncase/ntt/shape_infer/matmul.h
new file mode 100644
index 000000000..90c536750
--- /dev/null
+++ b/src/Native/include/nncase/ntt/shape_infer/matmul.h
@@ -0,0 +1,61 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../shape.h"
+#include <utility>
+
+namespace nncase::ntt::shape_infer {
+namespace detail {
+template <class Shape>
+constexpr size_t sub_matmul_shape_dim(const Shape &shape, size_t axis) {
+    return axis >= Shape::rank() - 2 ? shape[axis] : 1;
+}
+
+template <class Shape, class Axes> struct ranked_sub_matmul_shape_impl;
+
+template <class Shape, size_t... Axes>
+struct ranked_sub_matmul_shape_impl<Shape, std::index_sequence<Axes...>> {
+    using type = ranked_shape<Shape::rank()>;
+
+    static constexpr type value(const Shape &shape) {
+        return type{sub_matmul_shape_dim(shape, Axes)...};
+    }
+};
+
+template <class Shape, class Axes> struct fixed_sub_matmul_shape_impl;
+
+template <class Shape, size_t... Axes>
+struct fixed_sub_matmul_shape_impl<Shape, std::index_sequence<Axes...>> {
+    using type = fixed_shape<sub_matmul_shape_dim(Shape{}, Axes)...>;
+
+    static constexpr type value(const Shape &) { return type{}; }
+};
+
+template <class Shape>
+struct sub_matmul_shape_impl
+    : ranked_sub_matmul_shape_impl<
+          Shape, std::make_index_sequence<Shape::rank()>> {};
+
+template <size_t... Dims>
+struct sub_matmul_shape_impl<fixed_shape<Dims...>>
+    : fixed_sub_matmul_shape_impl<fixed_shape<Dims...>,
+                                  std::make_index_sequence<sizeof...(Dims)>> {
+};
+} // namespace detail
+
+template <class Shape> constexpr auto sub_matmul_shape(const Shape &shape) {
+    return detail::sub_matmul_shape_impl<Shape>::value(shape);
+}
+
+} // namespace nncase::ntt::shape_infer
diff --git a/src/Native/include/nncase/ntt/tensor.h b/src/Native/include/nncase/ntt/tensor.h
index 82b15552c..c94dbb9ed 100644
--- a/src/Native/include/nncase/ntt/tensor.h
+++ b/src/Native/include/nncase/ntt/tensor.h
@@ -16,7 +16,9 @@
 #include "detail/shape_storage.h"
 #include "detail/tensor_storage.h"
 #include "nncase/ntt/shape.h"
+#include "nncase/ntt/utility.h"
 #include "tensor_traits.h"
+#include <type_traits>
 
 namespace nncase::ntt {
 template <class T, class Shape, class Strides, size_t MaxSize, bool IsView>
@@ -234,6 +236,15 @@ class basic_tensor
         return {buffer(), shape, default_strides(shape)};
     }
 
+    template <size_t... Axes>
+    constexpr auto squeeze(fixed_shape<Axes...> axes) noexcept {
+        constexpr auto new_shape = squeeze_shape(axes, shape());
+        constexpr auto new_strides = squeeze_strides(axes, strides());
+        return tensor_view<element_type, std::decay_t<decltype(new_shape)>,
+                           std::decay_t<decltype(new_strides)>>(
+            buffer(), new_shape, new_strides);
+    }
+
     constexpr tensor_view<element_type, Shape, Strides> view() noexcept {
         return view(zero_shape_t<rank()>{}, shape());
     }
diff --git a/src/Native/include/nncase/ntt/ukernels.h b/src/Native/include/nncase/ntt/ukernels.h
index e4045b721..2a79d325a 100644
--- a/src/Native/include/nncase/ntt/ukernels.h
+++ b/src/Native/include/nncase/ntt/ukernels.h
@@ -123,6 +123,26 @@ template <reduce_op Op, class T, bool Arch> struct u_reduce {
         }
     }
 };
+
+enum class mamtul_pack_kind {
+    unknown,
+    no_pack,
+    pack_m,
+    pack_k,
+    pack_n,
+    pack_mn,
+    pack_mk,
+    pack_kn,
+    pack_mkn,
+};
+
+template <mamtul_pack_kind PackKind, class TLhsElem, class TRhsElem,
+          class TOutElem, bool Arch>
+struct u_matmul_policy {
+    static constexpr size_t m0_tile = 1;
+    static constexpr size_t n0_tile = 1;
+    static constexpr size_t m0_subtile = 0;
+};
 } // namespace nncase::ntt::ukernels
 
 namespace nncase::ntt {
@@ -136,8 +156,16 @@ constexpr void u_pack(const TIn *input, TOut *output) noexcept {
 }
 
 template <reduce_op Op, class T>
 constexpr T u_reduce(const T *input, size_t input_stride, size_t count,
-                     T init_value) {
+                     T init_value) noexcept {
     ukernels::u_reduce<Op, T, true> impl;
     return impl(input, input_stride, count, init_value);
 }
+
+// template <class TLhsElem, class TRhsElem, class TOutElem>
+// constexpr void u_matmul(const TLhsElem *&lhs, const TRhsElem *&rhs,
+//                         TOutElem *output, size_t M, size_t N, size_t K,
+//                         size_t lhs_stride, size_t rhs_stride,
+//                         size_t out_stride) noexcept {

+// }
 } // namespace nncase::ntt
diff --git a/src/Native/include/nncase/ntt/utility.h b/src/Native/include/nncase/ntt/utility.h
index a72aa0bb1..98376f8eb 100644
--- a/src/Native/include/nncase/ntt/utility.h
+++ b/src/Native/include/nncase/ntt/utility.h
@@ -149,4 +149,13 @@ reduce_source_offset(ranked_shape<Rank> out_index) noexcept {
         return in_index;
     }
 }
+
+template <size_t... Axes>
+constexpr auto make_index_axes(std::index_sequence<Axes...>) noexcept {
+    return fixed_shape<Axes...>{};
+}
+
+template <size_t Rank> constexpr auto make_index_axes() noexcept {
+    return make_index_axes(std::make_index_sequence<Rank>());
+}
 } // namespace nncase::ntt
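Note: the two small helpers compose. For a rank-4 output, make_index_axes<2>()
yields fixed_shape<0, 1>{} (the batch axes matmul squeezes away), and
sub_matmul_shape keeps only the trailing matrix dims, e.g. (hypothetical
values):

    sub_matmul_shape(fixed_shape<2, 3, 8, 16>{}) // -> fixed_shape<1, 1, 8, 16>
    make_index_axes<2>()                         // -> fixed_shape<0, 1>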
diff --git a/src/Native/test/benchmark_test/benchmark_ntt_clamp.cpp b/src/Native/test/benchmark_test/benchmark_ntt_clamp.cpp
index 383b09dac..42a42cb87 100644
--- a/src/Native/test/benchmark_test/benchmark_ntt_clamp.cpp
+++ b/src/Native/test/benchmark_test/benchmark_ntt_clamp.cpp
@@ -45,8 +45,10 @@ void benchmark_ntt_clamp(T init_low, T init_high, T clamp_low, T clamp_high) {
 
     // run
     auto t1 = NttTest::get_cpu_cycle();
-    for (size_t i = 0; i < run_size; i++)
+    for (size_t i = 0; i < run_size; i++) {
         ntt::clamp(*ntt_input, *ntt_output, clamp_low, clamp_high);
+        asm volatile("" ::"g"(ntt_output));
+    }
     auto t2 = NttTest::get_cpu_cycle();
 
     std::cout << __FUNCTION__ << "_" << pack_mode << " took "
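Note: the empty inline asm is the usual "do not optimize away" barrier — it
tells the compiler that ntt_output may be observed, so the clamp call inside
the timing loop can no longer be hoisted out or dead-code-eliminated. A
generic form of the idiom (comparable to benchmark::DoNotOptimize in Google
Benchmark):

    template <class T> inline void do_not_optimize(T const &value) {
        asm volatile("" : : "g"(value) : "memory");
    }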