Skip to content

Commit

Permalink
[SYCL][UR] Use Windows proxy loader for UR (intel#15262)
Browse files Browse the repository at this point in the history
The issues with DLLs and teardown of global objects on Windows are well
documented, and was the reason for the use of the `pi_win_proxy_loader`
library in SYCL-RT. When we ported from PI to UR, we ported this library
(it's now called `ur_win_proxy_loader`), but it was not actually fully
utilized. SYCL-RT still linked with `ur_loader.dll` and still
experienced issues with race conditions in the teardown of SYCL-RT and
Unified Runtime. See intel#14768.

This PR reintroduces the proxy loader as it was previously used with PI.
The UR loader (`ur_loader.dll`) is loaded via `LoadLibraryEx` at
initialization, and is therefore not cleaned up too early for normal
teardown to occur.

This necessitates changing the signature of `Plugin->call` to look like
it did with PI, taking an enum template argument to specify which UR
entry point to call.

On Windows, when each plugin (which is a wrapper over a UR adapter) is
loaded, it populates a table of function pointers to each API entry
point in the UR loader. When UR entry points are called, the function
pointer is retrieved from the table. This is more or less equivalent to
the previous PI implementation.

On Linux, the UR loader is dynamically linked as before. The
`Plugin->call` methods just use the regular UR functions rather than
programmatically looking up the symbols.

For the unittest executables, the UR loader is still dynamically linked
as before to avoid having to introduce noisy changes to the tests, and
since we aren't concerned about teardown issues there.

The implementation of these changes in the runtime should avoid as much
overhead as possible (and be no worse than PI), but suggestions on how
to improve and tidy things are more than welcome.

Associated UR change:
oneapi-src/unified-runtime#2045
  • Loading branch information
callumfare committed Sep 6, 2024
1 parent e0e7b50 commit 39cd4f3
Show file tree
Hide file tree
Showing 56 changed files with 1,275 additions and 1,097 deletions.
3 changes: 3 additions & 0 deletions sycl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_directory ${sycl_inc_dir}/syclcompat ${SYCL_INCLUDE_BUILD_DIR}/syclcompat
COMMAND ${CMAKE_COMMAND} -E copy ${sycl_inc_dir}/syclcompat.hpp ${SYCL_INCLUDE_BUILD_DIR}/syclcompat.hpp
COMMAND ${CMAKE_COMMAND} -E copy ${UNIFIED_RUNTIME_INCLUDE_DIR}/ur_api.h ${SYCL_INCLUDE_BUILD_DIR}/sycl
COMMAND ${CMAKE_COMMAND} -E copy ${UNIFIED_RUNTIME_INCLUDE_DIR}/ur_api_funcs.def ${SYCL_INCLUDE_BUILD_DIR}/sycl
COMMAND ${CMAKE_COMMAND} -E copy ${UNIFIED_RUNTIME_INCLUDE_DIR}/ur_print.hpp ${SYCL_INCLUDE_BUILD_DIR}/sycl
COMMENT "Copying SYCL headers ...")

Expand All @@ -263,6 +264,8 @@ install(DIRECTORY "${sycl_inc_dir}/syclcompat" DESTINATION ${SYCL_INCLUDE_DIR} C
install(FILES "${sycl_inc_dir}/syclcompat.hpp" DESTINATION ${SYCL_INCLUDE_DIR} COMPONENT sycl-headers)
install(FILES "${UNIFIED_RUNTIME_INCLUDE_DIR}/ur_api.h" DESTINATION ${SYCL_INCLUDE_DIR}/sycl
COMPONENT sycl-headers)
install(FILES "${UNIFIED_RUNTIME_INCLUDE_DIR}/ur_api_funcs.def" DESTINATION ${SYCL_INCLUDE_DIR}/sycl
COMPONENT sycl-headers)
install(FILES "${UNIFIED_RUNTIME_INCLUDE_DIR}/ur_print.hpp" DESTINATION ${SYCL_INCLUDE_DIR}/sycl
COMPONENT sycl-headers)

Expand Down
4 changes: 4 additions & 0 deletions sycl/cmake/modules/AddSYCLUnitTest.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ macro(add_sycl_unittest test_dirname link_variant)
target_link_libraries(${test_dirname} PRIVATE sycl-jit)
endif(SYCL_ENABLE_EXTENSION_JIT)

# On Windows, link the unit test executable directly against the
# UnifiedRuntimeLoader and ur_win_proxy_loader targets.
if(WIN32)
target_link_libraries(${test_dirname} PRIVATE UnifiedRuntimeLoader ur_win_proxy_loader)
endif()

target_include_directories(${test_dirname}
PRIVATE SYSTEM
${sycl_inc_dir}
Expand Down
51 changes: 50 additions & 1 deletion sycl/include/sycl/detail/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <sycl/backend_types.hpp>
#include <sycl/detail/export.hpp>
#include <sycl/detail/os_util.hpp>
#
#include <ur_api.h>

#include <memory>
Expand Down Expand Up @@ -48,6 +47,54 @@ class context;

namespace detail {

// Identifies a Unified Runtime entry point. The enumerators are generated by
// expanding _UR_API over ur_api_funcs.def, so there is one enumerator, named
// after the API function, for each _UR_API(api) entry in that file.
enum class UrApiKind {
#define _UR_API(api) api,
#include <ur_api_funcs.def>
#undef _UR_API
};

// Table of function pointers to UR entry points. For each _UR_API(api) entry
// in ur_api_funcs.def it holds a member named pfn_<api> whose type is a
// pointer to the global function ::api. Every member defaults to nullptr.
struct UrFuncPtrMapT {
#define _UR_API(api) decltype(&::api) pfn_##api = nullptr;
#include <ur_api_funcs.def>
#undef _UR_API
};

// Primary template; it is empty. A specialization for each UrApiKind
// enumerator is generated from ur_api_funcs.def further down in this header.
template <UrApiKind UrApiOffset> struct UrFuncInfo {};

#ifdef _WIN32
// Declared here and defined elsewhere. Used below to look up a symbol by name
// in a loaded module.
// NOTE(review): presumably a thin wrapper over GetProcAddress - confirm at the
// definition, including what it returns when the symbol is missing.
void *GetWinProcAddress(void *module, const char *funcName);
// Fills every pfn_<api> member of *funcs with the result of
// GetWinProcAddress(module, "<api>"), one assignment per _UR_API entry in
// ur_api_funcs.def. The lookup results are stored without being checked.
inline void PopulateUrFuncPtrTable(UrFuncPtrMapT *funcs, void *module) {
#define _UR_API(api) \
funcs->pfn_##api = (decltype(&::api))GetWinProcAddress(module, #api);
#include <ur_api_funcs.def>
#undef _UR_API
}

// Windows specializations of UrFuncInfo, one per UR API entry:
//  - FuncPtrT is the pointer-to-function type of the global ::api.
//  - getFuncName() returns the API name as a string literal.
//  - getFuncPtr(funcs) returns the pointer stored in funcs->pfn_<api>.
//  - getFuncPtrFromModule(module) looks the symbol up by name in module via
//    GetWinProcAddress and casts the result to FuncPtrT.
#define _UR_API(api) \
template <> struct UrFuncInfo<UrApiKind::api> { \
using FuncPtrT = decltype(&::api); \
inline const char *getFuncName() { return #api; } \
inline FuncPtrT getFuncPtr(const UrFuncPtrMapT *funcs) { \
return funcs->pfn_##api; \
} \
inline FuncPtrT getFuncPtrFromModule(void *module) { \
return (FuncPtrT)GetWinProcAddress(module, #api); \
} \
};
#include <ur_api_funcs.def>
#undef _UR_API
#else
// Non-Windows specializations of UrFuncInfo, one per UR API entry. Both
// getFuncPtr and getFuncPtrFromModule ignore their argument and return the
// address of the API function itself, so no table or symbol lookup is used.
#define _UR_API(api) \
template <> struct UrFuncInfo<UrApiKind::api> { \
using FuncPtrT = decltype(&::api); \
inline const char *getFuncName() { return #api; } \
constexpr inline FuncPtrT getFuncPtr(const void *) { return &api; } \
constexpr inline FuncPtrT getFuncPtrFromModule(void *) { return &api; } \
};
#include <ur_api_funcs.def>
#undef _UR_API
#endif

namespace pi {
// This function is deprecated and it should be removed in the next release
// cycle (along with the definition for pi_context_extended_deleter).
Expand Down Expand Up @@ -76,6 +123,8 @@ int unloadOsLibrary(void *Library);
// library, implementation is OS dependent.
void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName);

void *getURLoaderLibrary();

// Performs UR one-time initialization.
std::vector<PluginPtr> &
initializeUr(ur_loader_config_handle_t LoaderConfig = nullptr);
Expand Down
11 changes: 9 additions & 2 deletions sycl/source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ function(add_sycl_rt_library LIB_NAME LIB_OBJ_NAME)
# Link and include UR
target_link_libraries(${LIB_OBJ_NAME}
PRIVATE
UnifiedRuntimeLoader
UnifiedRuntime-Headers
UnifiedRuntimeCommon
)
Expand All @@ -183,11 +182,19 @@ function(add_sycl_rt_library LIB_NAME LIB_OBJ_NAME)

target_link_libraries(${LIB_NAME}
PRIVATE
UnifiedRuntimeLoader
UnifiedRuntime-Headers
UnifiedRuntimeCommon
)

# Link against UnifiedRuntimeLoader only on non-Windows platforms. On Windows
# the library is not linked against it; add_dependencies only orders the build
# so that the UnifiedRuntimeLoader target is built before ${LIB_NAME}.
if (NOT WIN32)
target_link_libraries(${LIB_NAME}
PRIVATE
UnifiedRuntimeLoader
)
else()
add_dependencies(${LIB_NAME} UnifiedRuntimeLoader)
endif()

target_include_directories(${LIB_NAME}
PRIVATE
"${UNIFIED_RUNTIME_SRC_INCLUDE_DIR}"
Expand Down
84 changes: 43 additions & 41 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ platform make_platform(ur_native_handle_t NativeHandle, backend Backend) {

// Create UR platform first.
ur_platform_handle_t UrPlatform = nullptr;
Plugin->call(urPlatformCreateWithNativeHandle, NativeHandle,
Plugin->getUrAdapter(), nullptr, &UrPlatform);
Plugin->call<UrApiKind::urPlatformCreateWithNativeHandle>(
NativeHandle, Plugin->getUrAdapter(), nullptr, &UrPlatform);

return detail::createSyclObjFromImpl<platform>(
platform_impl::getOrMakePlatformImpl(UrPlatform, Plugin));
Expand All @@ -84,8 +84,8 @@ __SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
const auto &Plugin = getPlugin(Backend);

ur_device_handle_t UrDevice = nullptr;
Plugin->call(urDeviceCreateWithNativeHandle, NativeHandle,
Plugin->getUrAdapter(), nullptr, &UrDevice);
Plugin->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Plugin->getUrAdapter(), nullptr, &UrDevice);
// Construct the SYCL device from UR device.
return detail::createSyclObjFromImpl<device>(
std::make_shared<device_impl>(UrDevice, Plugin));
Expand All @@ -105,9 +105,9 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
for (const auto &Dev : DeviceList) {
DeviceHandles.push_back(detail::getSyclObjImpl(Dev)->getHandleRef());
}
Plugin->call(urContextCreateWithNativeHandle, NativeHandle,
Plugin->getUrAdapter(), DeviceHandles.size(),
DeviceHandles.data(), &Properties, &UrContext);
Plugin->call<UrApiKind::urContextCreateWithNativeHandle>(
NativeHandle, Plugin->getUrAdapter(), DeviceHandles.size(),
DeviceHandles.data(), &Properties, &UrContext);
// Construct the SYCL context from UR context.
return detail::createSyclObjFromImpl<context>(std::make_shared<context_impl>(
UrContext, Handler, Plugin, DeviceList, !KeepOwnership));
Expand Down Expand Up @@ -150,9 +150,9 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
// Create UR queue first.
ur_queue_handle_t UrQueue = nullptr;

Plugin->call(urQueueCreateWithNativeHandle, NativeHandle,
ContextImpl->getHandleRef(), UrDevice, &NativeProperties,
&UrQueue);
Plugin->call<UrApiKind::urQueueCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), UrDevice, &NativeProperties,
&UrQueue);
// Construct the SYCL queue from UR queue.
return detail::createSyclObjFromImpl<queue>(
std::make_shared<queue_impl>(UrQueue, ContextImpl, Handler, PropList));
Expand All @@ -174,13 +174,13 @@ __SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle,
Properties.stype = UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;

Plugin->call(urEventCreateWithNativeHandle, NativeHandle,
ContextImpl->getHandleRef(), &Properties, &UrEvent);
Plugin->call<UrApiKind::urEventCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrEvent);
event Event = detail::createSyclObjFromImpl<event>(
std::make_shared<event_impl>(UrEvent, Context));

if (Backend == backend::opencl)
Plugin->call(urEventRetain, UrEvent);
Plugin->call<UrApiKind::urEventRetain>(UrEvent);
return Event;
}

Expand All @@ -196,50 +196,50 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;

Plugin->call(urProgramCreateWithNativeHandle, NativeHandle,
ContextImpl->getHandleRef(), &Properties, &UrProgram);
Plugin->call<UrApiKind::urProgramCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrProgram);
if (UrProgram == nullptr)
throw sycl::exception(
sycl::make_error_code(sycl::errc::invalid),
"urProgramCreateWithNativeHandle resulted in a null program handle.");

if (ContextImpl->getBackend() == backend::opencl)
Plugin->call(urProgramRetain, UrProgram);
Plugin->call<UrApiKind::urProgramRetain>(UrProgram);

std::vector<ur_device_handle_t> ProgramDevices;
uint32_t NumDevices = 0;

Plugin->call(urProgramGetInfo, UrProgram, UR_PROGRAM_INFO_NUM_DEVICES,
sizeof(NumDevices), &NumDevices, nullptr);
Plugin->call<UrApiKind::urProgramGetInfo>(
UrProgram, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(NumDevices), &NumDevices,
nullptr);
ProgramDevices.resize(NumDevices);
Plugin->call(urProgramGetInfo, UrProgram, UR_PROGRAM_INFO_DEVICES,
sizeof(ur_device_handle_t) * NumDevices, ProgramDevices.data(),
nullptr);
Plugin->call<UrApiKind::urProgramGetInfo>(
UrProgram, UR_PROGRAM_INFO_DEVICES,
sizeof(ur_device_handle_t) * NumDevices, ProgramDevices.data(), nullptr);

for (auto &Dev : ProgramDevices) {
ur_program_binary_type_t BinaryType;
Plugin->call(urProgramGetBuildInfo, UrProgram, Dev,
UR_PROGRAM_BUILD_INFO_BINARY_TYPE,
sizeof(ur_program_binary_type_t), &BinaryType, nullptr);
Plugin->call<UrApiKind::urProgramGetBuildInfo>(
UrProgram, Dev, UR_PROGRAM_BUILD_INFO_BINARY_TYPE,
sizeof(ur_program_binary_type_t), &BinaryType, nullptr);
switch (BinaryType) {
case (UR_PROGRAM_BINARY_TYPE_NONE):
if (State == bundle_state::object) {
auto Res = Plugin->call_nocheck(urProgramCompileExp, UrProgram, 1, &Dev,
nullptr);
auto Res = Plugin->call_nocheck<UrApiKind::urProgramCompileExp>(
UrProgram, 1, &Dev, nullptr);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Plugin->call_nocheck(urProgramCompile,
ContextImpl->getHandleRef(), UrProgram,
nullptr);
Res = Plugin->call_nocheck<UrApiKind::urProgramCompile>(
ContextImpl->getHandleRef(), UrProgram, nullptr);
}
Plugin->checkUrResult<errc::build>(Res);
}

else if (State == bundle_state::executable) {
auto Res = Plugin->call_nocheck(urProgramBuildExp, UrProgram, 1, &Dev,
nullptr);
auto Res = Plugin->call_nocheck<UrApiKind::urProgramBuildExp>(
UrProgram, 1, &Dev, nullptr);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Plugin->call_nocheck(
urProgramBuild, ContextImpl->getHandleRef(), UrProgram, nullptr);
Res = Plugin->call_nocheck<UrApiKind::urProgramBuild>(
ContextImpl->getHandleRef(), UrProgram, nullptr);
}
Plugin->checkUrResult<errc::build>(Res);
}
Expand All @@ -254,12 +254,13 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
detail::codeToString(UR_RESULT_ERROR_INVALID_VALUE));
if (State == bundle_state::executable) {
ur_program_handle_t UrLinkedProgram = nullptr;
auto Res =
Plugin->call_nocheck(urProgramLinkExp, ContextImpl->getHandleRef(),
1, &Dev, 1, &UrProgram, nullptr, &UrLinkedProgram);
auto Res = Plugin->call_nocheck<UrApiKind::urProgramLinkExp>(
ContextImpl->getHandleRef(), 1, &Dev, 1, &UrProgram, nullptr,
&UrLinkedProgram);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Plugin->call_nocheck(urProgramLink, ContextImpl->getHandleRef(),
1, &UrProgram, nullptr, &UrLinkedProgram);
Res = Plugin->call_nocheck<UrApiKind::urProgramLink>(
ContextImpl->getHandleRef(), 1, &UrProgram, nullptr,
&UrLinkedProgram);
}
Plugin->checkUrResult<errc::build>(Res);
if (UrLinkedProgram != nullptr) {
Expand Down Expand Up @@ -345,11 +346,12 @@ kernel make_kernel(const context &TargetContext,
ur_kernel_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;
Plugin->call(urKernelCreateWithNativeHandle, NativeHandle,
ContextImpl->getHandleRef(), UrProgram, &Properties, &UrKernel);
Plugin->call<UrApiKind::urKernelCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), UrProgram, &Properties,
&UrKernel);

if (Backend == backend::opencl)
Plugin->call(urKernelRetain, UrKernel);
Plugin->call<UrApiKind::urKernelRetain>(UrKernel);

// Construct the SYCL queue from UR queue.
return detail::createSyclObjFromImpl<kernel>(
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/backend/level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ __SYCL_EXPORT device make_device(const platform &Platform,
const auto &PlatformImpl = getSyclObjImpl(Platform);
// Create UR device first.
ur_device_handle_t UrDevice;
Plugin->call(urDeviceCreateWithNativeHandle, NativeHandle,
Plugin->getUrAdapter(), nullptr, &UrDevice);
Plugin->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Plugin->getUrAdapter(), nullptr, &UrDevice);

return detail::createSyclObjFromImpl<device>(
PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl));
Expand Down
24 changes: 14 additions & 10 deletions sycl/source/backend/opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,17 @@ __SYCL_EXPORT bool has_extension(const sycl::platform &SyclPlatform,
// Manual invocation of plugin API to avoid using deprecated
// info::platform::extensions call.
size_t ResultSize = 0;
Plugin->call(urPlatformGetInfo, PluginPlatform, UR_PLATFORM_INFO_EXTENSIONS,
/*propSize=*/0,
/*pPropValue=*/nullptr, &ResultSize);
Plugin->call<UrApiKind::urPlatformGetInfo>(
PluginPlatform, UR_PLATFORM_INFO_EXTENSIONS,
/*propSize=*/0,
/*pPropValue=*/nullptr, &ResultSize);
if (ResultSize == 0)
return false;

std::unique_ptr<char[]> Result(new char[ResultSize]);
Plugin->call(urPlatformGetInfo, PluginPlatform, UR_PLATFORM_INFO_EXTENSIONS,
ResultSize, Result.get(), nullptr);
Plugin->call<UrApiKind::urPlatformGetInfo>(PluginPlatform,
UR_PLATFORM_INFO_EXTENSIONS,
ResultSize, Result.get(), nullptr);

std::string_view ExtensionsString(Result.get());
return ExtensionsString.find(Extension) != std::string::npos;
Expand All @@ -68,15 +70,17 @@ __SYCL_EXPORT bool has_extension(const sycl::device &SyclDevice,
// Manual invocation of plugin API to avoid using deprecated
// info::device::extensions call.
size_t ResultSize = 0;
Plugin->call(urDeviceGetInfo, PluginDevice, UR_DEVICE_INFO_EXTENSIONS,
/*propSize=*/0,
/*pPropValue=*/nullptr, &ResultSize);
Plugin->call<UrApiKind::urDeviceGetInfo>(PluginDevice,
UR_DEVICE_INFO_EXTENSIONS,
/*propSize=*/0,
/*pPropValue=*/nullptr, &ResultSize);
if (ResultSize == 0)
return false;

std::unique_ptr<char[]> Result(new char[ResultSize]);
Plugin->call(urDeviceGetInfo, PluginDevice, UR_DEVICE_INFO_EXTENSIONS,
ResultSize, Result.get(), nullptr);
Plugin->call<UrApiKind::urDeviceGetInfo>(PluginDevice,
UR_DEVICE_INFO_EXTENSIONS,
ResultSize, Result.get(), nullptr);

std::string_view ExtensionsString(Result.get());
return ExtensionsString.find(Extension) != std::string::npos;
Expand Down
7 changes: 3 additions & 4 deletions sycl/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ context::context(cl_context ClContext, async_handler AsyncHandler) {
ur_context_handle_t hContext = nullptr;
ur_native_handle_t nativeHandle =
reinterpret_cast<ur_native_handle_t>(ClContext);
Plugin->call(urContextCreateWithNativeHandle, nativeHandle, Plugin->getUrAdapter(),
0, nullptr, nullptr,
&hContext);

Plugin->call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, Plugin->getUrAdapter(), 0, nullptr, nullptr, &hContext);

impl = std::make_shared<detail::context_impl>(
hContext, AsyncHandler, Plugin);
}
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/allowlist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
auto DeviceImpl = PlatformImpl->getOrMakeDeviceImpl(Device, PlatformImpl);
// get DeviceType value and put it to DeviceDesc
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
Plugin->call(urDeviceGetInfo, Device, UR_DEVICE_INFO_TYPE,
sizeof(UrDevType), &UrDevType, nullptr);
Plugin->call<UrApiKind::urDeviceGetInfo>(
Device, UR_DEVICE_INFO_TYPE, sizeof(UrDevType), &UrDevType, nullptr);
// TODO need mechanism to do these casts, there's a bunch of this sort of
// thing
sycl::info::device_type DeviceType = info::device_type::all;
Expand Down
Loading

0 comments on commit 39cd4f3

Please sign in to comment.