Skip to content

Commit

Permalink
Refactor template argument to fix compilation error with ComputeCpp
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick committed Jun 28, 2023
1 parent f637af5 commit 2be9541
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 21 deletions.
10 changes: 4 additions & 6 deletions include/operations/extension/matcopy_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@

namespace blas {

enum class matcopy_op : int { inplace = 0, outplace = 1, outplaceadd = 2 };

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
struct Matcopy_batch {
public:
Expand Down Expand Up @@ -57,16 +55,16 @@ struct Matcopy_batch {
void adjust_access_displacement();
};

template <blas::matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t> make_matcopy_batch(
Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t> make_matcopy_batch(
lhs_t lhs, rhs_t rhs_1, rhs_t rhs_2, typename rhs_t::value_t alpha,
typename rhs_t::value_t beta, typename rhs_t::index_t m,
typename rhs_t::index_t n, typename rhs_t::index_t lhs_ld,
typename rhs_t::index_t rhs_ld, typename rhs_t::index_t rhs_2_ld,
typename rhs_t::index_t lhs_stride, typename rhs_t::index_t rhs_stride,
typename rhs_t::index_t rhs_2_stride, typename rhs_t::index_t batch_size) {
return Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>(
return Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t>(
lhs, rhs_1, rhs_2, alpha, beta, m, n, lhs_ld, rhs_ld, rhs_2_ld,
lhs_stride, rhs_stride, rhs_2_stride, batch_size);
}
Expand Down
2 changes: 1 addition & 1 deletion src/interface/extension_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ typename sb_handle_t::event_t _matcopy_batch_impl(
auto out_view =
make_matrix_view<col_major>(out_memory, m, n, ld_out);
auto copy_batch_tree =
make_matcopy_batch<matcopy_op::outplace, TileSize, TilePerWG>(
make_matcopy_batch<false, TileSize, TilePerWG>(
out_view, in_view, in_view, alpha, 0, m, n, ld_out, ld_in, 1,
out_stride, in_stride, 1, batch_size);
constexpr index_t local_size = TileSize * TilePerWG;
Expand Down
28 changes: 14 additions & 14 deletions src/operations/extension/matcopy_batch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@

namespace blas {

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::Matcopy_batch(
Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t>::Matcopy_batch(
lhs_t lhs, rhs_t rhs_1, rhs_t rhs_2, typename lhs_t::value_t alpha,
typename lhs_t::value_t beta, typename rhs_t::index_t m,
typename rhs_t::index_t n, typename rhs_t::index_t lhs_ld,
Expand All @@ -53,15 +53,15 @@ Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::Matcopy_batch(
rhs_2_stride_(rhs_2_stride),
batch_size_(batch_size) {}

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
typename lhs_t::value_t
Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::eval(index_t i) {}
Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t>::eval(index_t i) {}

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
typename lhs_t::value_t
Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::eval(
Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t>::eval(
cl::sycl::nd_item<1> ndItem) {
const index_t m{m_};
const index_t n{n_};
Expand Down Expand Up @@ -142,33 +142,33 @@ Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::eval(
return 0;
}

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
SYCL_BLAS_INLINE void Matcopy_batch<op, TileSize, TilePerWG, lhs_t,
SYCL_BLAS_INLINE void Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t,
rhs_t>::bind(cl::sycl::handler& h) {
lhs_.bind(h);
rhs_1_.bind(h);
}

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
SYCL_BLAS_INLINE void Matcopy_batch<op, TileSize, TilePerWG, lhs_t,
SYCL_BLAS_INLINE void Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t,
rhs_t>::adjust_access_displacement() {
lhs_.adjust_access_displacement();
rhs_1_.adjust_access_displacement();
}

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
SYCL_BLAS_INLINE typename rhs_t::index_t
Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::get_size() const {
Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t>::get_size() const {
return m_ * n_;
}

template <matcopy_op op, int TileSize, int TilePerWG, typename lhs_t,
template <bool is_add, int TileSize, int TilePerWG, typename lhs_t,
typename rhs_t>
SYCL_BLAS_INLINE bool
Matcopy_batch<op, TileSize, TilePerWG, lhs_t, rhs_t>::valid_thread(
Matcopy_batch<is_add, TileSize, TilePerWG, lhs_t, rhs_t>::valid_thread(
cl::sycl::nd_item<1> ndItem) const {
return true;
}
Expand Down

0 comments on commit 2be9541

Please sign in to comment.