Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed blocking heuristics for M,K,N in BRGEMM Matmul to enable better memory/thread utilization in aarch64 #2103

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/cpu/aarch64/brgemm/brgemm_types.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2023 FUJITSU LIMITED
* Copyright 2023-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -191,6 +191,8 @@ struct brgemm_t {
int LDB = 0;
int LDC = 0;
int LDD = 0;

int M, K, N;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M, K, and N are literally declared above as bcast_dim, reduce_dim and load_dim. Any specific reason to add duplicated entries for them?

// we use two isa_ variables
// isa_user to store the user provided isa value
// isa_impl to store actual implementation. This can change until the kernel
Expand Down
91 changes: 31 additions & 60 deletions src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,10 +766,38 @@ void jit_brgemm_kernel_t::read_params() {
void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, bool is_bdb_tail,
int ld_block2, bool is_ld_tail, bool skip_accumulation) {
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
const bool need_to_apply_beta = brg.beta != 0.f;
for_(int bd = 0; bd < bd_block; bd++)
for (int ld = 0; ld < ld_block2; ld++) {
auto zmm = accm(ld_block2, bd, ld);
eor(zmm.d, zmm.d, zmm.d);
// This part is moved here from apply_alpha_beta function so that fadd instruction can be avoided.
// This is also required only when K is blocked.
if (need_to_apply_beta && brg.K != brg.reduce_dim) {
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;

const int offset = C_offset(bd, ld);

int base_offset = 0;
auto x_addr = reg_aux_C;

if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
base_offset = offset;
x_addr = reg_tmp_;
}
LD_MUL_VL(ld1w, zmm.s, k_mask, x_addr, offset - base_offset, 4);

const bool need_init_beta_vmm = brg.beta != 1.f;
auto vmm_beta = z_tail_mask();
if (need_init_beta_vmm) {
auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
dup(vmm_beta.s, wreg_tmp);
fmul(zmm.s, zmm.s, vmm_beta.s);
}
} else
eor(zmm.d, zmm.d, zmm.d);
}
}

Expand All @@ -791,57 +819,7 @@ void jit_brgemm_kernel_t::apply_alpha_beta(
if (apply_alpha) { fmul(vmm.s, vmm.s, vmm_alpha.s); }
}

if (brg.beta == 0.f) return;
const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
const bool need_init_beta_vmm = brg.beta != 1.f;
auto vmm_prev_dst = z_tmp_1();
auto vmm_beta = z_tail_mask();
if (need_init_beta_vmm) {
auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
dup(vmm_beta.s, wreg_tmp);
}

int base_offset = 0;
auto x_addr = reg_aux_C;
for_(int bd = 0; bd < bd_block; bd++)
for (int ld = 0; ld < ld_block2; ld++) {
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
auto vmm = accm(ld_block2, bd, ld);
if (use_vadd_for_beta) {
if (brg.is_int8) {
assert(!"unsupported\n");
} else {
ZRegS z_masked = vmm.s;
ZRegS z(vmm.getIdx());

const int offset = C_offset(bd, ld);

if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
base_offset = offset;
x_addr = reg_tmp_;
}
LD_MUL_VL(ld1w, vmm_prev_dst.s, k_mask, x_addr,
offset - base_offset, 4);
if (is_ld_tail) {
movprfx(z_masked, k_mask / T_z, z);
fadd(z_masked, k_mask / T_m, vmm_prev_dst.s);
} else {
fadd(z_masked, z_masked, vmm_prev_dst.s);
}
}
} else {
add_imm(X_DEFAULT_ADDR, reg_aux_C, C_offset(bd, ld), X_TMP_0);
ld1w(vmm_prev_dst.s, k_mask / T_z, ptr(X_DEFAULT_ADDR));
if (brg.beta == 1.f) {
fadd(vmm.s, vmm.s, vmm_prev_dst.s);
} else {
fmla(vmm.s, P_ALL_ONE / T_m, vmm_prev_dst.s, vmm_beta.s);
}
}
}
// This part is moved to the function zero_accumulators.
}

void jit_brgemm_kernel_t::apply_post_ops(
Expand Down Expand Up @@ -1464,7 +1442,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
int base_offset = 0;

for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
int prefetch_count_B = 0;
for (int ld = 0; ld < ld_block2; ld++) {
const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
if (brg.dt_b == data_type::f16) {
Expand Down Expand Up @@ -1496,13 +1473,7 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
broadcast(bcst(), A_offset(bd, rd),
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
}
if (prefetch_count_B < ld_block2) {
add_imm(X_DEFAULT_ADDR, reg_aux_B,
B_offset(prefetch_count_B++, rd)
+ brg.LDB * brg.rd_block * brg.typesize_B,
X_TMP_0);
prfm(PLDL1KEEP, ptr(X_DEFAULT_ADDR));
}
//The current implementaion of prefetch is not giving any gain in performance but is rather introducing some latency. Therefore it is removed util a new useful implementation is deviced.
for (int ld = 0; ld < ld_block2; ld++) {
auto zmm = accm(ld_block2, bd, ld);
if (is_emdbd) {
Expand Down
3 changes: 3 additions & 0 deletions src/cpu/aarch64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
? (dim_t)bgmmc_.wei_k_blk
: bgmmc_.LDA;
const auto kernel_isa = i_M == max_m_ker_idx - 1 ? backup_isa : isa;
brg.M = bgmmc_.M;
brg.K = bgmmc_.K;
brg.N = bgmmc_.N;
CHECK(brgemm_desc_init(&brg, kernel_isa, bgmmc_.brg_type, bgmmc_.src_dt,
bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta,
LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK));
Expand Down
3 changes: 2 additions & 1 deletion src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init(
matmul_conf_for_reorder_.K = dims[ndims - 2];
matmul_conf_for_reorder_.N = dims[ndims - 1];
matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk
= matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag);
= matmul_conf_for_reorder_.LDB
= matmul::get_default_n_block(otag, matmul_conf_for_reorder_);
matmul_conf_for_reorder_.N_tail
= matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk;
matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity;
Expand Down
50 changes: 43 additions & 7 deletions src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,7 +48,8 @@ using namespace dnnl::impl::utils;
using namespace data_type;
using namespace format_tag;

int get_default_n_block(format_tag_t matrix_b_tag) {
int get_default_n_block(
format_tag_t matrix_b_tag, brgemm_matmul_conf_t &bgmmc) {
// Note: consider using weights mem_descriptor 'inner_blks' to
// return B's inner block for non-default cases.
switch (matrix_b_tag) {
Expand Down Expand Up @@ -75,7 +77,23 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
case BA16a16b:
case BA16a16b2a:
case BA16a16b4a: return 16;
default: return 64;
default: {
if (bgmmc.N == 16 || bgmmc.N == 32 || bgmmc.N == 64) return bgmmc.N;
if (!mayiuse(sve_512)) {
if (bgmmc.N <= 16)
return 16;
else {
// It is observed that for M,K>512, N block of 64 works better provided that thread distribution is not hindered.
if (bgmmc.N / 64 >= bgmmc.nthr && bgmmc.K > 512
&& bgmmc.M > 512)
return 64;
else
return 32;
}

} else
return 64;
}
}
}

Expand Down Expand Up @@ -178,7 +196,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(

if (B_any_layout) {
const int default_n_block = init_n_tag
? get_default_n_block(format_tag::undef)
? get_default_n_block(format_tag::undef, bgmmc)
: bgmmc.N_blk;
bgmmc.wei_tag = blocked_B_layouts_allowed
? this->pick_blocked_B_layout(default_n_block)
Expand Down Expand Up @@ -580,14 +598,17 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
const int nthr = bgmmc.nthr;

const int max_m_blk = nstl::min(/*64*/ 256, matmul.M);
int min_m_blk = nstl::min(32, matmul.M); // max_m_blk
// It is found that for 2d shapes min_m_blk = 128 works better than 32 for most of the shapes.
int min_m = (matmul.batch > 1) ? 32 : 128;
int min_m_blk = nstl::min(min_m, matmul.M); // max_m_blk

int n_blk = bgmmc.N_blk;
const int n_chunks = div_up(matmul.N, n_blk);
const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);

int default_k_blk = 1024;
//It is found that for M<512 k_blk of 128 works better than 1024 for most of the shapes.
int default_k_blk = (matmul.M >= 512) ? 1024 : 128;
int k_blk = nstl::min(matmul.K, default_k_blk);
int start_nthr_k = 1;

Expand All @@ -597,7 +618,22 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
if (low_parallel_work) {

min_m_blk = nstl::min(matmul.M, 16);
int best_m_blk = 0;
float scr = 0, best_scr = 16 * nthr;
for (int i = 16; i >= 4; i--) {
scr = 0.7 * (matmul.M % i)
+ 0.3 * std::abs(nthr - ((float)matmul.M / (float)i));
if (scr < best_scr) {
best_scr = scr;
best_m_blk = i;
}
}
min_m_blk = nstl::min(matmul.M, best_m_blk);
// Here min_m_blk is set based on M value and no.of threads. Decreasing m_blk size will
// increase no.of m blocks which might make better utilisation of threads. But it is found
// that m_blk being a factor of M is more important than max thread utilisation.Therefore
// in scoring that has been given more weightage(0.7). This was experimentally verified to
// be the best hueristics with multiple shapes.

bool low_spatial_work = matmul.M <= 40;
if (low_spatial_work) {
Expand Down Expand Up @@ -834,7 +870,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,

VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);

bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);
bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag, bgmmc);

bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();
Expand Down
3 changes: 2 additions & 1 deletion src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -312,7 +313,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const brgemm_matmul_conf_t &bgmmc);

int get_default_n_block(format_tag_t matrix_b_tag);
int get_default_n_block(format_tag_t, brgemm_matmul_conf_t &bgmmc);

} // namespace matmul
} // namespace aarch64
Expand Down
Loading