Skip to content

Commit

Permalink
Add mem_type
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed Jul 7, 2023
1 parent aa7b758 commit 65b52ce
Show file tree
Hide file tree
Showing 24 changed files with 60 additions and 107 deletions.
4 changes: 2 additions & 2 deletions benchmark/cublas/blas3/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ void register_benchmark(blas_benchmark::Args& args,
run<scalar_t>(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(t1s, t2s, m, k,
n)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t1s, t2s, m, k, n, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/cublas/blas3/gemm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t1s, t2s, m, k, n, batch_count, batch_type)
t1s, t2s, m, k, n, batch_count, batch_type,
blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_count,
batch_type, success)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/cublas/blas3/gemm_batched_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ void register_benchmark(blas_benchmark::Args& args,
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul,
stride_c_mul)
stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size,
stride_a_mul, stride_b_mul, stride_c_mul, success)
Expand Down
3 changes: 2 additions & 1 deletion benchmark/cublas/blas3/symm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
s_side, s_uplo, m, n, alpha, beta)
s_side, s_uplo, m, n, alpha, beta,
blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, s_side, s_uplo, m, n, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/cublas/blas3/syr2k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
s_uplo, s_trans, n, k, alpha, beta)
s_uplo, s_trans, n, k, alpha, beta,
blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, s_uplo, s_trans, n, k, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/cublas/blas3/syrk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
s_uplo, s_trans, n, k, alpha, beta)
s_uplo, s_trans, n, k, alpha, beta,
blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, s_uplo, s_trans, n, k, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/cublas/blas3/trmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
s_side, s_uplo, s_t, s_diag, m, n)
s_side, s_uplo, s_t, s_diag, m, n,
blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, s_side, s_uplo, s_t, s_diag, m, n, alpha,
success)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/cublas/blas3/trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n)
side, uplo, trans, diag, m, n, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, side, uplo, trans, diag, m, n, alpha,
success)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/cublas/blas3/trsm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ void register_benchmark(blas_benchmark::Args& args,
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
s_side, s_uplo, s_t, s_diag, m, n, batch_count, stride_a_mul,
stride_b_mul)
stride_b_mul, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, cuda_handle_ptr, s_side, s_uplo, s_t, s_diag, m, n, alpha,
batch_count, stride_a_mul, stride_b_mul, success)
Expand Down
4 changes: 2 additions & 2 deletions benchmark/rocblas/blas3/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
run<scalar_t>(st, rb_handle, t1i, t2i, m, k, n, alpha, beta, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(t_a, t_b, m, k,
n)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t_a, t_b, m, k, n, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/rocblas/blas3/gemm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t_a, t_b, m, k, n, batch_size, batch_type)
t_a, t_b, m, k, n, batch_size, batch_type,
blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size,
batch_type, success)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/rocblas/blas3/gemm_batched_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t_a, t_b, m, k, n, batch_size, stride_a_mul, stride_b_mul,
stride_c_mul)
stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size,
stride_a_mul, stride_b_mul, stride_c_mul, success)
Expand Down
4 changes: 2 additions & 2 deletions benchmark/rocblas/blas3/symm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
run<scalar_t>(st, rb_handle, side, uplo, m, n, alpha, beta, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(side, uplo, m,
n, alpha, beta)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, m, n, alpha, beta, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, side, uplo, m, n, alpha, beta, success)
->UseRealTime();
Expand Down
4 changes: 2 additions & 2 deletions benchmark/rocblas/blas3/syr2k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
run<scalar_t>(st, rb_handle, uplo, trans, n, k, alpha, beta, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(uplo, trans, n,
k, alpha, beta)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
uplo, trans, n, k, alpha, beta, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, uplo, trans, n, k, alpha, beta, success)
->UseRealTime();
Expand Down
4 changes: 2 additions & 2 deletions benchmark/rocblas/blas3/syrk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
run<scalar_t>(st, rb_handle, uplo, trans, n, k, alpha, beta, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(uplo, trans, n,
k, alpha, beta)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
uplo, trans, n, k, alpha, beta, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, uplo, trans, n, k, alpha, beta, success)
->UseRealTime();
Expand Down
2 changes: 1 addition & 1 deletion benchmark/rocblas/blas3/trmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n)
side, uplo, trans, diag, m, n, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, side, uplo, trans, diag, m, n, alpha, success)
->UseRealTime();
Expand Down
2 changes: 1 addition & 1 deletion benchmark/rocblas/blas3/trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n)
side, uplo, trans, diag, m, n, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, side, uplo, trans, diag, m, n, alpha, success)
->UseRealTime();
Expand Down
2 changes: 1 addition & 1 deletion benchmark/rocblas/blas3/trsm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle,
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
s_side, s_uplo, s_t, s_diag, m, n, batch_size, stride_a_mul,
stride_b_mul)
stride_b_mul, blas_benchmark::utils::MEM_TYPE_USM)
.c_str(),
BM_lambda, rb_handle, s_side, s_uplo, s_t, s_diag, m, n, alpha,
batch_size, stride_a_mul, stride_b_mul, success)
Expand Down
4 changes: 2 additions & 2 deletions benchmark/syclblas/blas3/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ void register_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_p
run<scalar_t>(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(t1s, t2s, m, k,
n)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t1s, t2s, m, k, n, blas_benchmark::utils::MEM_TYPE_BUFFER)
.c_str(),
BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/syclblas/blas3/gemm_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ void register_benchmark(blas_benchmark::Args& args,
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t1s, t2s, m, k, n, batch_size, batch_type)
t1s, t2s, m, k, n, batch_size, batch_type,
blas_benchmark::utils::MEM_TYPE_BUFFER)
.c_str(),
BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size,
batch_type, success)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/syclblas/blas3/gemm_batched_strided.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void register_benchmark(blas_benchmark::Args& args,
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul,
stride_c_mul)
stride_c_mul, blas_benchmark::utils::MEM_TYPE_BUFFER)
.c_str(),
BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size,
stride_a_mul, stride_b_mul, stride_c_mul, success)
Expand Down
5 changes: 3 additions & 2 deletions benchmark/syclblas/blas3/symm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ void register_benchmark(blas_benchmark::Args& args,
success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(side, uplo, m,
n, alpha, beta)
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, m, n, alpha, beta,
blas_benchmark::utils::MEM_TYPE_BUFFER)
.c_str(),
BM_lambda, sb_handle_ptr, side_c, uplo_c, m, n, alpha, beta, success)
->UseRealTime();
Expand Down
3 changes: 2 additions & 1 deletion benchmark/syclblas/blas3/trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ void register_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_p
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
side, uplo, trans, diag, m, n)
side, uplo, trans, diag, m, n,
blas_benchmark::utils::MEM_TYPE_BUFFER)
.c_str(),
BM_lambda, sb_handle_ptr, side, uplo, trans, diag, m, n, alpha, success)
->UseRealTime();
Expand Down
98 changes: 21 additions & 77 deletions common/include/common/benchmark_names.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,114 +179,58 @@ get_name(std::string uplo, std::string t, std::string diag, index_t n,
return internal::get_name<op, scalar_t>(uplo, t, diag, n, mem_type);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::gbmv, std::string>::type
get_name(std::string t, index_t m, index_t n, index_t kl, index_t ku) {
return internal::get_name<op, scalar_t>(t, m, n, kl, ku);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::gemv || op == Level2Op::sbmv,
std::string>::type
get_name(std::string t, index_t m, index_t n) {
return internal::get_name<op, scalar_t>(t, m, n);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::ger, std::string>::type get_name(
index_t m, index_t n) {
return internal::get_name<op, scalar_t>(m, n);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::syr || op == Level2Op::syr2,
std::string>::type
get_name(std::string uplo, index_t n, scalar_t alpha) {
return internal::get_name<op, scalar_t>(uplo, n, alpha);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::spr, std::string>::type get_name(
std::string uplo, index_t n, scalar_t alpha, index_t incx) {
return internal::get_name<op, scalar_t>(uplo, n, alpha, incx);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::spmv || op == Level2Op::symv,
std::string>::type
get_name(std::string uplo, index_t n, scalar_t alpha, scalar_t beta) {
return internal::get_name<op, scalar_t>(uplo, n, alpha, beta);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::spr2, std::string>::type
get_name(std::string uplo, index_t n, scalar_t alpha, index_t incx,
index_t incy) {
return internal::get_name<op, scalar_t>(uplo, n, alpha, incx, incy);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::tbmv || op == Level2Op::tbsv,
std::string>::type
get_name(std::string uplo, std::string t, std::string diag, index_t n,
index_t k) {
return internal::get_name<op, scalar_t>(uplo, t, diag, n, k);
}

template <Level2Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level2Op::tpmv || op == Level2Op::trmv ||
op == Level2Op::trsv,
std::string>::type
get_name(std::string uplo, std::string t, std::string diag, index_t n) {
return internal::get_name<op, scalar_t>(uplo, t, diag, n);
}

template <Level3Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level3Op::gemm, std::string>::type
get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n) {
return internal::get_name<op, scalar_t>(t1, t2, m, k, n);
get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n,
std::string mem_type) {
return internal::get_name<op, scalar_t>(t1, t2, m, k, n, mem_type);
}

template <Level3Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level3Op::gemm_batched, std::string>::type
get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n,
index_t batch_size, int batch_type) {
return internal::get_name<op, scalar_t>(t1, t2, m, k, n, batch_size,
batch_type_to_str(batch_type));
index_t batch_size, int batch_type, std::string mem_type) {
return internal::get_name<op, scalar_t>(
t1, t2, m, k, n, batch_size, batch_type_to_str(batch_type), mem_type);
}

template <Level3Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level3Op::gemm_batched_strided,
std::string>::type
get_name(std::string t1, std::string t2, index_t m, index_t k, index_t n,
index_t batch_size, index_t stride_a_mul, index_t stride_b_mul,
index_t stride_c_mul) {
return internal::get_name<op, scalar_t>(
t1, t2, m, k, n, batch_size, stride_a_mul, stride_b_mul, stride_c_mul);
index_t stride_c_mul, std::string mem_type) {
return internal::get_name<op, scalar_t>(t1, t2, m, k, n, batch_size,
stride_a_mul, stride_b_mul,
stride_c_mul, mem_type);
}

template <Level3Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level3Op::symm || op == Level3Op::syr2k ||
op == Level3Op::syrk,
std::string>::type
get_name(std::string s1, std::string s2, index_t m, index_t n, scalar_t alpha,
scalar_t beta) {
return internal::get_name<op, scalar_t>(s1, s2, m, n, alpha, beta);
scalar_t beta, std::string mem_type) {
return internal::get_name<op, scalar_t>(s1, s2, m, n, alpha, beta, mem_type);
}

template <Level3Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level3Op::trsm || op == Level3Op::trmm,
std::string>::type
get_name(char side, char uplo, char trans, char diag, index_t m, index_t n) {
return internal::get_name<op, scalar_t>(side, uplo, trans, diag, m, n);
get_name(char side, char uplo, char trans, char diag, index_t m, index_t n,
std::string mem_type) {
return internal::get_name<op, scalar_t>(side, uplo, trans, diag, m, n,
mem_type);
}

template <Level3Op op, typename scalar_t, typename index_t>
inline typename std::enable_if<op == Level3Op::trsm_batched, std::string>::type
get_name(char side, char uplo, char trans, char diag, index_t m, index_t n,
index_t batch_size, index_t stride_a_mul, index_t stride_b_mul) {
return internal::get_name<op, scalar_t>(
side, uplo, trans, diag, m, n, batch_size, stride_a_mul, stride_b_mul);
index_t batch_size, index_t stride_a_mul, index_t stride_b_mul,
std::string mem_type) {
return internal::get_name<op, scalar_t>(side, uplo, trans, diag, m, n,
batch_size, stride_a_mul,
stride_b_mul, mem_type);
}

} // namespace utils
Expand Down

0 comments on commit 65b52ce

Please sign in to comment.