Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workgroup strided transforms #143

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 36 additions & 48 deletions src/portfft/committed_descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include "common/exceptions.hpp"
#include "common/subgroup.hpp"
#include "common/workgroup.hpp"
#include "defines.hpp"
#include "enums.hpp"
#include "specialization_constant.hpp"
Expand Down Expand Up @@ -215,57 +216,44 @@ class committed_descriptor_impl {
throw unsupported_configuration("portFFT only supports complex to complex transforms");
}

std::vector<sycl::kernel_id> ids;
std::vector<Idx> factors;
IdxGlobal fft_size = static_cast<IdxGlobal>(params.lengths[kernel_num]);
if (detail::fits_in_wi<Scalar>(fft_size)) {
ids = detail::get_ids<detail::workitem_kernel, Scalar, Domain, SubgroupSize>();
PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size);
return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}};
}
if (detail::fits_in_sg<Scalar>(fft_size, SubgroupSize)) {
Idx factor_sg = detail::factorize_sg(static_cast<Idx>(fft_size), SubgroupSize);
Idx factor_wi = static_cast<Idx>(fft_size) / factor_sg;
// This factorization is duplicated in the dispatch logic on the device.
// The CT and spec constant factors should match.
factors.push_back(factor_wi);
factors.push_back(factor_sg);
ids = detail::get_ids<detail::subgroup_kernel, Scalar, Domain, SubgroupSize>();
PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg);
return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}};
}
IdxGlobal n_idx_global = detail::factorize(fft_size);
if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) &&
detail::can_cast_safely<IdxGlobal, Idx>(fft_size / n_idx_global)) {
if (n_idx_global == 1) {
throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported");
if (static_cast<size_t>(fft_size) * 2 * sizeof(Scalar) <= static_cast<size_t>(local_memory_size)) {
// These implementations only work if the size fits in local memory.
// They still may not be suitable if the extra local memory needed for the algorithm exceeds the available memory.

if (detail::fits_in_wi<Scalar>(fft_size)) {
auto ids = detail::get_ids<detail::workitem_kernel, Scalar, Domain, SubgroupSize>();
PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size);
return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, {}}}};
}
Idx n = static_cast<Idx>(n_idx_global);
Idx m = static_cast<Idx>(fft_size / n_idx_global);
Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize);
Idx factor_wi_n = n / factor_sg_n;
Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize);
Idx factor_wi_m = m / factor_sg_m;
Idx temp_num_sgs_in_wg;
std::size_t local_memory_usage =
num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
{factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
layout::PACKED) *
sizeof(Scalar);
// Checks for PACKED layout only at the moment, as the other layout will not be supported
// by the global implementation. For such sizes, only PACKED layout will be supported
if (detail::fits_in_wi<Scalar>(factor_wi_n) && detail::fits_in_wi<Scalar>(factor_wi_m) &&
(local_memory_usage <= static_cast<std::size_t>(local_memory_size))) {
factors.push_back(factor_wi_n);
factors.push_back(factor_sg_n);
factors.push_back(factor_wi_m);
factors.push_back(factor_sg_m);
// This factorization of N and M is duplicated in the dispatch logic on the device.
if (detail::fits_in_sg<Scalar>(fft_size, SubgroupSize)) {
Idx factor_sg = detail::factorize_sg(static_cast<Idx>(fft_size), SubgroupSize);
Idx factor_wi = static_cast<Idx>(fft_size) / factor_sg;
// This factorization is duplicated in the dispatch logic on the device.
// The CT and spec constant factors should match.
ids = detail::get_ids<detail::workgroup_kernel, Scalar, Domain, SubgroupSize>();
PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n,
" factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m);
return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}};
auto ids = detail::get_ids<detail::subgroup_kernel, Scalar, Domain, SubgroupSize>();
PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg);
return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, {factor_wi, factor_sg}}}};
}
if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value();
Idx temp_num_sgs_in_wg;
std::size_t local_memory_usage =
num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
{factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
layout::PACKED) *
sizeof(Scalar);
// Checks for PACKED layout only at the moment, as the other layout will not be supported
// by the global implementation. For such sizes, only PACKED layout will be supported
if (local_memory_usage <= static_cast<std::size_t>(local_memory_size)) {
// This factorization of N and M is duplicated in the dispatch logic on the device.
// The CT and spec constant factors should match.
auto ids = detail::get_ids<detail::workgroup_kernel, Scalar, Domain, SubgroupSize>();
PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n,
" factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m);
return {detail::level::WORKGROUP,
{{detail::level::WORKGROUP, ids, {factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m}}}};
}
}
}
PORTFFT_LOG_TRACE("Preparing global impl");
Expand Down
44 changes: 44 additions & 0 deletions src/portfft/common/workgroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
#ifndef PORTFFT_COMMON_WORKGROUP_HPP
#define PORTFFT_COMMON_WORKGROUP_HPP

#include <optional>

#include "helpers.hpp"
#include "logging.hpp"
#include "memory_views.hpp"
#include "portfft/defines.hpp"
#include "portfft/enums.hpp"
#include "portfft/traits.hpp"
#include "portfft/utils.hpp"
#include "subgroup.hpp"
#include "transfers.hpp"

namespace portfft {

Expand All @@ -53,6 +58,45 @@ constexpr T bank_lines_per_pad_wg(T row_size) {
}

namespace detail {

// Result type of factorize_for_wg: the four factors describing a workgroup-level factorization.
// factorize_for_wg first splits the FFT size as n * m (n = detail::factorize(fft_size), m = fft_size / n),
// then splits each of n and m using detail::factorize_sg.
// NOTE: the member order is relied upon by aggregate initialization and structured bindings at the use sites,
// so it must not be changed.
struct wg_factorization {
// n / factor_sg_n
Idx factor_wi_n;
// detail::factorize_sg(n, subgroup_size)
Idx factor_sg_n;
// m / factor_sg_m
Idx factor_wi_m;
// detail::factorize_sg(m, subgroup_size)
Idx factor_sg_m;
};

/** Calculate a valid factorization for workgroup dfts, assuming there is sufficient local memory.
* @tparam Scalar scalar type of the transform data
* @param fft_size the number of elements in the transforms
* @param subgroup_size the size of subgroup used for the transform
* @return a factorization for workgroup dft or std::nullopt if the size won't work with the implementation of workgroup dfts.
*/
template <typename Scalar>
std::optional<wg_factorization> factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this to utils.hpp; we only have device-callable functions in the common folder.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we have the factorization functions for workitem and subgroup in common/workitem.hpp and common/subgroup.hpp respectively, so I was following the example there.
factorize_sg is not called from device anywhere, along with fits_in_sg and fits_in_wi, so I wouldn't say we only have device callable functions in the common folder.
If we do want to refactor to put the factorization functions in a utility file, then we should group them and put them in a "factorization.hpp" or something like that. Generic util files are a bit of a code smell imo (though I am guilty of committing that sin).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd agree that a factorisation.hpp would be better than having everything in utils.hpp.

IdxGlobal n_idx_global = detail::factorize(fft_size);
hjabird marked this conversation as resolved.
Show resolved Hide resolved
if (n_idx_global == 1) {
return std::nullopt;
}

IdxGlobal m_idx_global = fft_size / n_idx_global;
if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) && detail::can_cast_safely<IdxGlobal, Idx>(m_idx_global)) {
Idx n = static_cast<Idx>(n_idx_global);
Idx m = static_cast<Idx>(m_idx_global);
Idx factor_sg_n = detail::factorize_sg(n, subgroup_size);
Idx factor_wi_n = n / factor_sg_n;
Idx factor_sg_m = detail::factorize_sg(m, subgroup_size);
Idx factor_wi_m = m / factor_sg_m;

if (fits_in_wi<Scalar>(factor_wi_n) && fits_in_wi<Scalar>(factor_wi_m)) {
return wg_factorization{factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m};
}
}

return std::nullopt;
}

/**
* Calculate all dfts in one dimension of the data stored in local memory.
*
Expand Down
7 changes: 4 additions & 3 deletions src/portfft/descriptor_validation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
#include <string_view>

#include "common/exceptions.hpp"
#include "common/subgroup.hpp"
#include "common/workgroup.hpp"
#include "common/workitem.hpp"
#include "enums.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -67,8 +68,8 @@ inline void validate_layout(const std::vector<std::size_t>& lengths, portfft::de
if (forward_layout == portfft::detail::layout::UNPACKED || backward_layout == portfft::detail::layout::UNPACKED) {
bool fits_subgroup = false;
for (auto sg_size : {PORTFFT_SUBGROUP_SIZES}) {
fits_subgroup =
fits_subgroup || portfft::detail::fits_in_sg<Scalar>(static_cast<IdxGlobal>(lengths.back()), sg_size);
fits_subgroup = fits_subgroup || portfft::detail::fits_in_wi<Scalar>(lengths.back()) ||
portfft::detail::factorize_for_wg<Scalar>(static_cast<IdxGlobal>(lengths.back()), sg_size);
if (fits_subgroup) {
break;
}
Expand Down
Loading
Loading