Skip to content

Commit

Permalink
[benchmark] Refactor omatadd benchmark state counter init
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick committed Jun 26, 2023
1 parent 3059575 commit d80c3aa
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 36 deletions.
17 changes: 5 additions & 12 deletions benchmark/cublas/extension/omatadd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ static inline void cublas_routine(args_t&&... args) {

template <typename scalar_t>
void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int ti_a,
int ti_b, index_t m, index_t n, index_t lda_mul, index_t ldb_mul,
index_t ldc_mul, scalar_t alpha, scalar_t beta, bool* success) {
int ti_b, index_t m, index_t n, scalar_t alpha, scalar_t beta,
index_t lda_mul, index_t ldb_mul, index_t ldc_mul, bool* success) {
// initialize the state label
blas_benchmark::utils::set_benchmark_label<scalar_t>(state);

Expand All @@ -70,16 +70,9 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int ti_a,
const auto size_b = ldb * ((*t_str_b == 't') ? m : n);
const auto size_c = ldc * n;

blas_benchmark::utils::init_level_1_counters<
blas_benchmark::utils::Level1Op::copy, scalar_t>(state, 3 * m * n);

state.counters["n_fl_ops"] = 3 * static_cast<double>(m * n);
state.counters["lda_m"] = (double)lda_mul;
state.counters["ldb_m"] = (double)ldb_mul;
state.counters["trans_a"] = (double)((*t_str_a == 't') ? 1 : 0);
state.counters["trans_b"] = (double)((*t_str_b == 't') ? 1 : 0);
state.counters["m"] = (double)m;
state.counters["n"] = (double)n;
blas_benchmark::utils::init_extension_counters<
blas_benchmark::utils::ExtensionOP::omatadd, scalar_t>(
state, t_str_a, t_str_b, m, n, lda_mul, ldb_mul, ldc_mul);

cublasHandle_t& cuda_handle = *cuda_handle_ptr;

Expand Down
17 changes: 5 additions & 12 deletions benchmark/rocblas/extension/omatadd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ static inline void rocblas_geam_f(args_t&&... args) {

template <typename scalar_t>
void run(benchmark::State& state, rocblas_handle& rb_handle, int ti_a, int ti_b,
index_t m, index_t n, index_t lda_mul, index_t ldb_mul,
index_t ldc_mul, scalar_t alpha, scalar_t beta, bool* success) {
index_t m, index_t n, scalar_t alpha, scalar_t beta, index_t lda_mul,
index_t ldb_mul, index_t ldc_mul, bool* success) {
// initialize the state label
blas_benchmark::utils::set_benchmark_label<scalar_t>(state);

Expand All @@ -70,16 +70,9 @@ void run(benchmark::State& state, rocblas_handle& rb_handle, int ti_a, int ti_b,
const auto size_b = ldb * ((*t_str_b == 't') ? m : n);
const auto size_c = ldc * n;

blas_benchmark::utils::init_level_1_counters<
blas_benchmark::utils::Level1Op::copy, scalar_t>(state, 3 * m * n);

state.counters["n_fl_ops"] = 3 * static_cast<double>(m * n);
state.counters["lda_m"] = (double)lda_mul;
state.counters["ldb_m"] = (double)ldb_mul;
state.counters["trans_a"] = (double)((*t_str_a == 't') ? 1 : 0);
state.counters["trans_b"] = (double)((*t_str_b == 't') ? 1 : 0);
state.counters["m"] = (double)m;
state.counters["n"] = (double)n;
blas_benchmark::utils::init_extension_counters<
blas_benchmark::utils::ExtensionOP::omatadd, scalar_t>(
state, t_str_a, t_str_b, m, n, lda_mul, ldb_mul, ldc_mul);

// Input matrix/vector, output vector.
std::vector<scalar_t> m_a =
Expand Down
17 changes: 5 additions & 12 deletions benchmark/syclblas/extension/omatadd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ std::string get_name(std::string ts_a, std::string ts_b, int m, int n,

template <typename scalar_t>
void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int ti_a,
int ti_b, index_t m, index_t n, index_t lda_mul, index_t ldb_mul,
index_t ldc_mul, scalar_t alpha, scalar_t beta, bool* success) {
int ti_b, index_t m, index_t n, scalar_t alpha, scalar_t beta,
index_t lda_mul, index_t ldb_mul, index_t ldc_mul, bool* success) {
// initiliaze the state label
blas_benchmark::utils::set_benchmark_label<scalar_t>(
state, sb_handle_ptr->get_queue());
Expand All @@ -61,16 +61,9 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int ti_a,
const auto size_b = ldb * ((*t_str_b == 't') ? m : n);
const auto size_c = ldc * n;

blas_benchmark::utils::init_level_1_counters<
blas_benchmark::utils::Level1Op::copy, scalar_t>(state, 3 * m * n);

state.counters["n_fl_ops"] = 3 * static_cast<double>(m * n);
state.counters["lda_m"] = (double)lda_mul;
state.counters["ldb_m"] = (double)ldb_mul;
state.counters["trans_a"] = (double)((*t_str_a == 't') ? 1 : 0);
state.counters["trans_b"] = (double)((*t_str_b == 't') ? 1 : 0);
state.counters["m"] = (double)m;
state.counters["n"] = (double)n;
blas_benchmark::utils::init_extension_counters<
blas_benchmark::utils::ExtensionOP::omatadd, scalar_t>(
state, t_str_a, t_str_b, m, n, lda_mul, ldb_mul, ldc_mul);

blas::SB_Handle& sb_handle = *sb_handle_ptr;

Expand Down

0 comments on commit d80c3aa

Please sign in to comment.