Skip to content

Commit

Permalink
[benchmark] Refactor state counters
Browse files Browse the repository at this point in the history
  • Loading branch information
s-Nick committed Jun 26, 2023
1 parent 49f48a5 commit d2da8a7
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
12 changes: 3 additions & 9 deletions benchmark/syclblas/extension/omatcopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,9 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int ti,
const auto size_a = lda * n;
const auto size_b = ldb * ((*t_str == 't') ? m : n);

blas_benchmark::utils::init_level_1_counters<
blas_benchmark::utils::Level1Op::copy, scalar_t>(state, 2 * m * n);

state.counters["n_fl_ops"] = static_cast<double>(m * n);
state.counters["lda_m"] = (double)lda_mul;
state.counters["ldb_m"] = (double)ldb_mul;
state.counters["trans"] = (double)((*t_str == 't') ? 1 : 0);
state.counters["m"] = (double)m;
state.counters["n"] = (double)n;
blas_benchmark::utils::init_extension_counters<
blas_benchmark::utils::ExtensionOP::omatcopy, scalar_t>(
state, t_str, m, n, lda_mul, ldb_mul);

blas::SB_Handle& sb_handle = *sb_handle_ptr;

Expand Down
79 changes: 79 additions & 0 deletions common/include/common/blas_extension_state_counters.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/***************************************************************************
*
* @license
* Copyright (C) Codeplay Software Limited
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* For your convenience, a copy of the License has been included in this
* repository.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SYCL-BLAS: BLAS implementation using SYCL
*
* @filename blas_extension_state_counters.hpp
*
**************************************************************************/

#ifndef COMMON_BLAS_EXTENSION_STATE_COUNTERS
#define COMMON_BLAS_EXTENSION_STATE_COUNTERS

namespace blas_benchmark {
namespace utils {

enum class ExtensionOP : int {
omatcopy = 0,
imatcopy = 1,
omatadd = 2,
omatcopy_batch = 3,
imatcopy_batch = 4,
omatadd_batch = 5
};

template <ExtensionOP op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == ExtensionOP::omatcopy ||
op == ExtensionOP::imatcopy>::type
init_extension_counters(benchmark::State& state, const char* trans, index_t m,
index_t n, index_t lda_mul, index_t ldb_mul) {
// Google-benchmark counters are double.
double size_d = static_cast<double>(m * n);
state.counters["m"] = static_cast<double>(m);
state.counters["n"] = static_cast<double>(n);
state.counters["n_fl_ops"] = size_d;
state.counters["lda_m"] = static_cast<double>(lda_mul);
state.counters["ldb_m"] = static_cast<double>(ldb_mul);
state.counters["trans"] = static_cast<double>((*trans == 't') ? 1 : 0);
state.counters["bytes_processed"] = (2 * size_d + 1) * sizeof(scalar_t);
return;
}

template <ExtensionOP op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == ExtensionOP::omatadd>::type
init_extension_counters(benchmark::State& state, const char* t_a,
const char* t_b, index_t m, index_t n, index_t lda_mul,
index_t ldb_mul, index_t ldc_mul) {
// Google-benchmark counters are double.
double size_d = static_cast<double>(m * n);
state.counters["m"] = static_cast<double>(m);
state.counters["n"] = static_cast<double>(n);
state.counters["n_fl_ops"] = 3 * static_cast<double>(m * n);
state.counters["lda_m"] = static_cast<double>(lda_mul);
state.counters["ldb_m"] = static_cast<double>(ldb_mul);
state.counters["ldc_m"] = static_cast<double>(ldc_mul);
state.counters["trans_a"] = static_cast<double>((*t_a == 't') ? 1 : 0);
state.counters["trans_b"] = static_cast<double>((*t_b == 't') ? 1 : 0);
state.counters["bytes_processed"] = (3 * size_d + 1) * sizeof(scalar_t);
return;
}
} // namespace utils
} // namespace blas_benchmark

#endif // COMMON_BLAS_EXTENSION_STATE_COUNTERS
1 change: 1 addition & 0 deletions common/include/common/common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <common/blas1_state_counters.hpp>
#include <common/blas2_state_counters.hpp>
#include <common/blas3_state_counters.hpp>
#include <common/blas_extension_state_counters.hpp>
#include <common/float_comparison.hpp>
#include <common/set_benchmark_label.hpp>
#include <common/system_reference_blas.hpp>
Expand Down

0 comments on commit d2da8a7

Please sign in to comment.