Skip to content

Commit

Permalink
Addressed PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Aug 16, 2023
1 parent fd61d19 commit 48faa2d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 18 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ For all these operations:
* `A`, `B` and `C` are containers for the column-major matrices A, B and C.
* `lda`, `ldb` and `ldc` are the leading dimensions of the matrices A, B and C
(cf BLAS 2). The leading dimension of a matrix must be greater than or equal
to its number of rows. In the case of in-place transpose, the same matrix `A`
to its number of rows. In the case of in-place copy/transpose, the same matrix `A`
is used with two different leading dimensions for input & output.
* `stride_a`, `stride_b` and `stride_c` are the striding size between consecutive
matrices in a batched entry for inputs/outputs A, B and C.
Expand All @@ -319,6 +319,7 @@ matrices in a batched entry for inputs/outputs A, B and C.
| `_omatcopy2`| `sb_handle`, `transa`, `M`, `N`, `alpha`, `A`, `lda`, `inc_a`, `B`, `ldb`, `inc_b` | Computes two-strided scaling and out-of-place transposition or copying of general dense matrices. |
| `_omatadd`| `sb_handle`, `transa`, `transb`, `M`, `N`, `alpha`, `A`, `lda`, `beta`, `B`, `ldb`, `C`,`ldc` | Computes scaled general dense matrix addition with possibly transposed arguments. |
| `_omatcopy_batch` | `sb_handle`, `transa`, `M`, `N`, `alpha`, `A`, `lda`, `stride_a`, `B`, `ldb`, `stride_b`, `batch_size` | Perform an out-of-place scaled batched-strided matrix transpose or copy operation using a general dense matrix. |
| `_imatcopy_batch` | `sb_handle`, `transa`, `M`, `N`, `alpha`, `A`, `lda`, `ldb`, `stride`, `batch_size` | Perform an in-place scaled batched-strided matrix transpose* or copy operation using a general dense matrix. (*: Currently the transpose case is not supported). |

Other non-official extension operators :
| operation | arguments | description |
Expand Down
9 changes: 5 additions & 4 deletions common/include/common/common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1199,10 +1199,11 @@ static inline std::vector<omatadd_param_t<scalar_t>> get_omatadd_params(
}

/**
*@fn get_matcopy_batch_params *@brief Returns a vector containing the
matcopy_batch benchmark parameters,
*either read from a file according to the command - line args,
or the default *ones.*/
* @fn get_matcopy_batch_params
* @brief Returns a vector containing the matcopy_batch benchmark parameters,
* either read from a file according to the command-line args, or the default
* ones.
*/
template <typename scalar_t>
static inline std::vector<matcopy_batch_param_t<scalar_t>>
get_matcopy_batch_params(Args& args) {
Expand Down
27 changes: 15 additions & 12 deletions src/interface/extension_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ _matcopy_impl(sb_handle_t& sb_handle, index_t m, index_t n, element_t alpha,
} else {
// TODO
// In-place transpose not implemented.
typename sb_handle_t::event_t ret;
return ret;
throw std::runtime_error("In-place transpose not implemented.");
}
}

Expand Down Expand Up @@ -156,9 +155,12 @@ typename sb_handle_t::event_t _matcopy_batch_impl(
index_t ld_out, index_t out_stride, index_t batch_size) {
auto in_view = make_matrix_view<col_major>(in_memory, m, n, ld_in);
auto out_view = make_matrix_view<col_major>(out_memory, m, n, ld_out);
const element_t beta = 0;
const index_t ld_b = 0;
const index_t stride_b = 0;
auto copy_batch_tree = make_matcopy_batch<TileSize, TilePerWG>(
out_view, in_view, in_view, alpha, 0, m, n, ld_out, ld_in, 1, out_stride,
in_stride, 1, batch_size);
out_view, in_view, in_view, alpha, beta, m, n, ld_out, ld_in, ld_b,
out_stride, in_stride, stride_b, batch_size);
constexpr index_t local_size = TileSize * TilePerWG;
const index_t tile_per_matrix =
(((m - 1) / TileSize) + 1) * (((n - 1) / TileSize) + 1);
Expand Down Expand Up @@ -349,8 +351,7 @@ typename sb_handle_t::event_t _matcopy(sb_handle_t& sb_handle, char trans,
// bail out early if the leading dimensions are not correct
if (ld_in < (inc_in * (m - 1) + 1) ||
(ld_out - 1) < (trans == 't' ? inc_out * (n - 1) : inc_out * (m - 1))) {
typename sb_handle_t::event_t ret;
return ret;
throw std::invalid_argument("invalid ld_in and/or ld_out, inc_out, inc_in");
}

const index_t stride = 1;
Expand All @@ -374,11 +375,12 @@ typename sb_handle_t::event_t _matcopy_batch(
in_t in_memory, index_t ld_in, index_t stride_in, out_t out_memory,
index_t ld_out, index_t stride_out, index_t batch_size) {
// bail out early if the leading dimensions / strides are not correct
if (ld_in < m || (ld_out < (trans == 't' ? n : m)) ||
(stride_in < ld_in * n) ||
if (ld_in < m || (ld_out < (trans == 't' ? n : m))) {
throw std::invalid_argument("invalid ld_in and/or ld_out");
}
if ((stride_in < ld_in * n) ||
(stride_out < (ld_out * (trans == 't' ? m : n)))) {
typename sb_handle_t::event_t ret;
return ret;
throw std::invalid_argument("invalid stride_in and/or stride_out");
}

const index_t increment = 1;
Expand Down Expand Up @@ -434,12 +436,13 @@ typename sb_handle_t::event_t _transpose(sb_handle_t& sb_handle, index_t m,
return ret;
}

const element_t alpha = 1;
const index_t inc = 1;
const index_t stride = 1;
const index_t batch_size = 1;

return _matcopy_impl<in_place, true>(sb_handle, m, n, (float)1.0, A, ld_a,
inc, stride, B, ld_b, inc, stride,
return _matcopy_impl<in_place, true>(sb_handle, m, n, alpha, A, ld_a, inc,
stride, B, ld_b, inc, stride,
batch_size);
}

Expand Down
2 changes: 2 additions & 0 deletions src/operations/extension/matcopy_batch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,15 @@ SYCL_BLAS_INLINE void Matcopy_batch<TileSize, TilePerWG, lhs_t, rhs_t>::bind(
cl::sycl::handler& h) {
lhs_.bind(h);
rhs_1_.bind(h);
rhs_2_.bind(h);
}

/*!
 * @brief Calls adjust_access_displacement() on each operand held by this
 * Matcopy_batch node: lhs_, rhs_1_ and rhs_2_.
 *
 * NOTE(review): rhs_2_ is adjusted alongside rhs_1_ so that all three operands
 * are treated uniformly; presumably this is required even when both rhs
 * operands refer to the same input view - confirm against the callers.
 */
template <int TileSize, int TilePerWG, typename lhs_t, typename rhs_t>
SYCL_BLAS_INLINE void
Matcopy_batch<TileSize, TilePerWG, lhs_t, rhs_t>::adjust_access_displacement() {
lhs_.adjust_access_displacement();
rhs_1_.adjust_access_displacement();
rhs_2_.adjust_access_displacement();
}

template <int TileSize, int TilePerWG, typename lhs_t, typename rhs_t>
Expand Down
3 changes: 2 additions & 1 deletion test/unittest/extension/transpose_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ void run_test(const combination_t<scalar_t>& combi) {
ASSERT_TRUE(isAlmostEqual);

} else {
// Inplace Transpose: TODO
// Inplace Transpose currently disabled (TODO)
GTEST_SKIP();
}
}

Expand Down

0 comments on commit 48faa2d

Please sign in to comment.