diff --git a/benchmark/cublas/blas3/gemm.cpp b/benchmark/cublas/blas3/gemm.cpp index 8a354cb5c..8e442bbe4 100644 --- a/benchmark/cublas/blas3/gemm.cpp +++ b/benchmark/cublas/blas3/gemm.cpp @@ -160,8 +160,8 @@ void register_benchmark(blas_benchmark::Args& args, run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(t1s, t2s, m, k, - n) + blas_benchmark::utils::get_name( + t1s, t2s, m, k, n, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/cublas/blas3/gemm_batched.cpp b/benchmark/cublas/blas3/gemm_batched.cpp index 709bb7ea3..e575e18c5 100644 --- a/benchmark/cublas/blas3/gemm_batched.cpp +++ b/benchmark/cublas/blas3/gemm_batched.cpp @@ -200,7 +200,8 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - t1s, t2s, m, k, n, batch_count, batch_type) + t1s, t2s, m, k, n, batch_count, batch_type, + blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_count, batch_type, success) diff --git a/benchmark/cublas/blas3/gemm_batched_strided.cpp b/benchmark/cublas/blas3/gemm_batched_strided.cpp index a5378f0cc..52cefc411 100644 --- a/benchmark/cublas/blas3/gemm_batched_strided.cpp +++ b/benchmark/cublas/blas3/gemm_batched_strided.cpp @@ -200,7 +200,7 @@ void register_benchmark(blas_benchmark::Args& args, benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, - stride_c_mul) + stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, stride_a_mul, stride_b_mul, stride_c_mul, success) diff --git a/benchmark/cublas/blas3/symm.cpp b/benchmark/cublas/blas3/symm.cpp index be1e4b250..29fddb038 100644 --- a/benchmark/cublas/blas3/symm.cpp +++ b/benchmark/cublas/blas3/symm.cpp @@ -158,7 +158,8 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - s_side, s_uplo, m, n, alpha, beta) + s_side, s_uplo, m, n, alpha, beta, + blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, s_side, s_uplo, m, n, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/cublas/blas3/syr2k.cpp b/benchmark/cublas/blas3/syr2k.cpp index 184b556cd..54fdbc0f0 100644 --- a/benchmark/cublas/blas3/syr2k.cpp +++ b/benchmark/cublas/blas3/syr2k.cpp @@ -157,7 +157,8 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - s_uplo, s_trans, n, k, alpha, beta) + s_uplo, s_trans, n, k, alpha, beta, + blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, s_uplo, s_trans, n, k, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/cublas/blas3/syrk.cpp b/benchmark/cublas/blas3/syrk.cpp index 5743402fe..c5cea7fef 100644 --- a/benchmark/cublas/blas3/syrk.cpp +++ b/benchmark/cublas/blas3/syrk.cpp @@ -153,7 +153,8 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - s_uplo, s_trans, n, k, alpha, beta) + s_uplo, s_trans, n, k, alpha, beta, + blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, s_uplo, s_trans, n, k, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/cublas/blas3/trmm.cpp b/benchmark/cublas/blas3/trmm.cpp index c76b12da7..47bfb36e7 100644 --- a/benchmark/cublas/blas3/trmm.cpp +++ b/benchmark/cublas/blas3/trmm.cpp @@ -172,7 +172,8 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - s_side, s_uplo, s_t, s_diag, m, n) + s_side, s_uplo, s_t, s_diag, m, n, + blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, s_side, s_uplo, s_t, s_diag, m, n, alpha, success) diff --git a/benchmark/cublas/blas3/trsm.cpp b/benchmark/cublas/blas3/trsm.cpp index d2d27d570..4d39b712f 100644 --- a/benchmark/cublas/blas3/trsm.cpp +++ b/benchmark/cublas/blas3/trsm.cpp @@ -172,7 +172,7 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - side, uplo, trans, diag, m, n) + side, uplo, trans, diag, m, n, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, side, uplo, trans, diag, m, n, alpha, success) diff --git a/benchmark/cublas/blas3/trsm_batched.cpp b/benchmark/cublas/blas3/trsm_batched.cpp index b52b6c86a..6bdb6bffc 100644 --- a/benchmark/cublas/blas3/trsm_batched.cpp +++ b/benchmark/cublas/blas3/trsm_batched.cpp @@ -210,7 +210,7 @@ void register_benchmark(blas_benchmark::Args& args, benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( s_side, s_uplo, s_t, s_diag, m, n, batch_count, stride_a_mul, - stride_b_mul) + stride_b_mul, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, cuda_handle_ptr, s_side, s_uplo, s_t, s_diag, m, n, alpha, batch_count, stride_a_mul, stride_b_mul, success) diff --git a/benchmark/rocblas/blas3/gemm.cpp b/benchmark/rocblas/blas3/gemm.cpp index 44f92e1f3..c254ff377 100644 --- a/benchmark/rocblas/blas3/gemm.cpp +++ b/benchmark/rocblas/blas3/gemm.cpp @@ -175,8 +175,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, run(st, rb_handle, t1i, t2i, m, k, n, alpha, beta, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(t_a, t_b, m, k, - n) + blas_benchmark::utils::get_name( + t_a, t_b, m, k, n, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/rocblas/blas3/gemm_batched.cpp b/benchmark/rocblas/blas3/gemm_batched.cpp index a4ca0aa53..62b1cbced 100644 --- a/benchmark/rocblas/blas3/gemm_batched.cpp +++ b/benchmark/rocblas/blas3/gemm_batched.cpp @@ -200,7 +200,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - t_a, t_b, m, k, n, batch_size, batch_type) + t_a, t_b, m, k, n, batch_size, batch_type, + blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size, batch_type, success) diff --git a/benchmark/rocblas/blas3/gemm_batched_strided.cpp b/benchmark/rocblas/blas3/gemm_batched_strided.cpp index b64ec2a6f..cdac699ca 100644 --- a/benchmark/rocblas/blas3/gemm_batched_strided.cpp +++ b/benchmark/rocblas/blas3/gemm_batched_strided.cpp @@ -211,7 +211,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( t_a, t_b, m, k, n, batch_size, stride_a_mul, stride_b_mul, - stride_c_mul) + stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size, stride_a_mul, stride_b_mul, stride_c_mul, success) diff --git a/benchmark/rocblas/blas3/symm.cpp b/benchmark/rocblas/blas3/symm.cpp index 0daa71f29..f14413f51 100644 --- a/benchmark/rocblas/blas3/symm.cpp +++ b/benchmark/rocblas/blas3/symm.cpp @@ -165,8 +165,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, run(st, rb_handle, side, uplo, m, n, alpha, beta, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(side, uplo, m, - n, alpha, beta) + blas_benchmark::utils::get_name( + side, uplo, m, n, alpha, beta, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, side, uplo, m, n, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/rocblas/blas3/syr2k.cpp b/benchmark/rocblas/blas3/syr2k.cpp index 503db036e..1e10b2c0d 100644 --- a/benchmark/rocblas/blas3/syr2k.cpp +++ b/benchmark/rocblas/blas3/syr2k.cpp @@ -162,8 +162,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, run(st, rb_handle, uplo, trans, n, k, alpha, beta, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(uplo, trans, n, - k, alpha, beta) + blas_benchmark::utils::get_name( + uplo, trans, n, k, alpha, beta, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, uplo, trans, n, k, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/rocblas/blas3/syrk.cpp b/benchmark/rocblas/blas3/syrk.cpp index ab96d4191..579066b13 100644 --- a/benchmark/rocblas/blas3/syrk.cpp +++ b/benchmark/rocblas/blas3/syrk.cpp @@ -159,8 +159,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, run(st, rb_handle, uplo, trans, n, k, alpha, beta, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(uplo, trans, n, - k, alpha, beta) + blas_benchmark::utils::get_name( + uplo, trans, n, k, alpha, beta, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, uplo, trans, n, k, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/rocblas/blas3/trmm.cpp b/benchmark/rocblas/blas3/trmm.cpp index ac4fe1b0d..09e1d4933 100644 --- a/benchmark/rocblas/blas3/trmm.cpp +++ b/benchmark/rocblas/blas3/trmm.cpp @@ -173,7 +173,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - side, uplo, trans, diag, m, n) + side, uplo, trans, diag, m, n, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, side, uplo, trans, diag, m, n, alpha, success) ->UseRealTime(); diff --git a/benchmark/rocblas/blas3/trsm.cpp b/benchmark/rocblas/blas3/trsm.cpp index 01f2487a8..dbf72bf3e 100644 --- a/benchmark/rocblas/blas3/trsm.cpp +++ b/benchmark/rocblas/blas3/trsm.cpp @@ -175,7 +175,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - side, uplo, trans, diag, m, n) + side, uplo, trans, diag, m, n, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, side, uplo, trans, diag, m, n, alpha, success) ->UseRealTime(); diff --git a/benchmark/rocblas/blas3/trsm_batched.cpp b/benchmark/rocblas/blas3/trsm_batched.cpp index 3b2eb5288..e6d9d89f8 100644 --- a/benchmark/rocblas/blas3/trsm_batched.cpp +++ b/benchmark/rocblas/blas3/trsm_batched.cpp @@ -208,7 +208,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( s_side, s_uplo, s_t, s_diag, m, n, batch_size, stride_a_mul, - stride_b_mul) + stride_b_mul, blas_benchmark::utils::MEM_TYPE_USM) .c_str(), BM_lambda, rb_handle, s_side, s_uplo, s_t, s_diag, m, n, alpha, batch_size, stride_a_mul, stride_b_mul, success) diff --git a/benchmark/syclblas/blas3/gemm.cpp b/benchmark/syclblas/blas3/gemm.cpp index d91388ba8..a87c10559 100644 --- a/benchmark/syclblas/blas3/gemm.cpp +++ b/benchmark/syclblas/blas3/gemm.cpp @@ -133,8 +133,8 @@ void register_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_p run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(t1s, t2s, m, k, - n) + blas_benchmark::utils::get_name( + t1s, t2s, m, k, n, blas_benchmark::utils::MEM_TYPE_BUFFER) .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/syclblas/blas3/gemm_batched.cpp b/benchmark/syclblas/blas3/gemm_batched.cpp index 488a92384..64c928d3b 100644 --- a/benchmark/syclblas/blas3/gemm_batched.cpp +++ b/benchmark/syclblas/blas3/gemm_batched.cpp @@ -202,7 +202,8 @@ void register_benchmark(blas_benchmark::Args& args, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - t1s, t2s, m, k, n, batch_size, batch_type) + t1s, t2s, m, k, n, batch_size, batch_type, + blas_benchmark::utils::MEM_TYPE_BUFFER) .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, batch_type, success) diff --git a/benchmark/syclblas/blas3/gemm_batched_strided.cpp b/benchmark/syclblas/blas3/gemm_batched_strided.cpp index 4a93124d7..c47ab55ad 100644 --- a/benchmark/syclblas/blas3/gemm_batched_strided.cpp +++ b/benchmark/syclblas/blas3/gemm_batched_strided.cpp @@ -175,7 +175,7 @@ void register_benchmark(blas_benchmark::Args& args, benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, - stride_c_mul) + stride_c_mul, blas_benchmark::utils::MEM_TYPE_BUFFER) .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, stride_a_mul, stride_b_mul, stride_c_mul, success) diff --git a/benchmark/syclblas/blas3/symm.cpp b/benchmark/syclblas/blas3/symm.cpp index 17e03a1c7..1448ab7d4 100644 --- a/benchmark/syclblas/blas3/symm.cpp +++ b/benchmark/syclblas/blas3/symm.cpp @@ -132,8 +132,9 @@ void register_benchmark(blas_benchmark::Args& args, success); }; benchmark::RegisterBenchmark( - blas_benchmark::utils::get_name(side, uplo, m, - n, alpha, beta) + blas_benchmark::utils::get_name( + side, uplo, m, n, alpha, beta, + blas_benchmark::utils::MEM_TYPE_BUFFER) .c_str(), BM_lambda, sb_handle_ptr, side_c, uplo_c, m, n, alpha, beta, success) ->UseRealTime(); diff --git a/benchmark/syclblas/blas3/trsm.cpp b/benchmark/syclblas/blas3/trsm.cpp index 9bb7f86b3..bdafd4558 100644 --- a/benchmark/syclblas/blas3/trsm.cpp +++ b/benchmark/syclblas/blas3/trsm.cpp @@ -168,7 +168,8 @@ void register_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_p }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - side, uplo, trans, diag, m, n) + side, uplo, trans, diag, m, n, + blas_benchmark::utils::MEM_TYPE_BUFFER) .c_str(), BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success) ->UseRealTime(); diff --git a/common/include/common/benchmark_names.hpp b/common/include/common/benchmark_names.hpp index d97e5bfbe..c9ea50de0 100644 --- a/common/include/common/benchmark_names.hpp +++ b/common/include/common/benchmark_names.hpp @@ -179,80 +179,19 @@ get_name(std::string uplo, std::string t, std::string diag, index_t n, return internal::get_name(uplo, t, diag, n, mem_type); } -template -inline typename std::enable_if::type -get_name(std::string t, index_t m, index_t n, index_t kl, index_t ku) { - return internal::get_name(t, m, n, kl, ku); -} - -template -inline typename std::enable_if::type -get_name(std::string t, index_t m, index_t n) { - return internal::get_name(t, m, n); -} - -template -inline typename std::enable_if::type get_name( - index_t m, index_t n) { - return internal::get_name(m, n); -} - -template -inline typename std::enable_if::type -get_name(std::string uplo, index_t n, scalar_t alpha) { - return internal::get_name(uplo, n, alpha); -} - -template -inline typename std::enable_if::type get_name( - std::string uplo, index_t n, scalar_t alpha, index_t incx) { - return internal::get_name(uplo, n, alpha, incx); -} - -template -inline typename std::enable_if::type -get_name(std::string uplo, index_t n, scalar_t alpha, scalar_t beta) { - return internal::get_name(uplo, n, alpha, beta); -} - -template -inline typename std::enable_if::type -get_name(std::string uplo, index_t n, scalar_t alpha, index_t incx, - index_t incy) { - return internal::get_name(uplo, n, alpha, incx, incy); -} - -template -inline typename std::enable_if::type -get_name(std::string uplo, std::string t, std::string diag, index_t n, - index_t k) { - return internal::get_name(uplo, t, diag, n, k); -} - -template -inline typename std::enable_if::type -get_name(std::string uplo, std::string t, std::string diag, index_t n) { - return internal::get_name(uplo, t, diag, n); -} - template inline typename std::enable_if::type -get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n) { - return internal::get_name(t1, t2, m, k, n); +get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n, + std::string mem_type) { + return internal::get_name(t1, t2, m, k, n, mem_type); } template inline typename std::enable_if::type get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n, - index_t batch_size, int batch_type) { - return internal::get_name(t1, t2, m, k, n, batch_size, - batch_type_to_str(batch_type)); + index_t batch_size, int batch_type, std::string mem_type) { + return internal::get_name( + t1, t2, m, k, n, batch_size, batch_type_to_str(batch_type), mem_type); } template @@ -260,9 +199,10 @@ inline typename std::enable_if::type get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n, index_t batch_size, index_t stride_a_mul, index_t stride_b_mul, - index_t stride_c_mul) { - return internal::get_name( - t1, t2, m, k, n, batch_size, stride_a_mul, stride_b_mul, stride_c_mul); + index_t stride_c_mul, std::string mem_type) { + return internal::get_name(t1, t2, m, k, n, batch_size, + stride_a_mul, stride_b_mul, + stride_c_mul, mem_type); } template @@ -270,23 +210,27 @@ inline typename std::enable_if::type get_name(std::string s1, std::string s2, index_t m, index_t n, scalar_t alpha, - scalar_t beta) { - return internal::get_name(s1, s2, m, n, alpha, beta); + scalar_t beta, std::string mem_type) { + return internal::get_name(s1, s2, m, n, alpha, beta, mem_type); } template inline typename std::enable_if::type -get_name(char side, char uplo, char trans, char diag, index_t m, index_t n) { - return internal::get_name(side, uplo, trans, diag, m, n); +get_name(char side, char uplo, char trans, char diag, index_t m, index_t n, + std::string mem_type) { + return internal::get_name(side, uplo, trans, diag, m, n, + mem_type); } template inline typename std::enable_if::type get_name(char side, char uplo, char trans, char diag, index_t m, index_t n, - index_t batch_size, index_t stride_a_mul, index_t stride_b_mul) { - return internal::get_name( - side, uplo, trans, diag, m, n, batch_size, stride_a_mul, stride_b_mul); + index_t batch_size, index_t stride_a_mul, index_t stride_b_mul, + std::string mem_type) { + return internal::get_name(side, uplo, trans, diag, m, n, + batch_size, stride_a_mul, + stride_b_mul, mem_type); } } // namespace utils