Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeviceSanitizer] Unmap/Release virtual memory when there is no dependency #2065

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 161 additions & 37 deletions source/loader/layers/sanitizer/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,47 +111,56 @@ ur_result_t enqueueMemSetShadow(ur_context_handle_t Context,

ur_physical_mem_properties_t Desc{
UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES, nullptr, 0};
static ur_physical_mem_handle_t PhysicalMem{};

// Make sure [Ptr, Ptr + Size] is mapped to physical memory
for (auto MappedPtr = RoundDownTo(ShadowBegin, PageSize);
MappedPtr <= ShadowEnd; MappedPtr += PageSize) {
if (!PhysicalMem) {
auto URes = getContext()->urDdiTable.PhysicalMem.pfnCreate(
Context, DeviceInfo->Handle, PageSize, &Desc,
&PhysicalMem);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error("urPhysicalMemCreate(): {}",
URes);
return URes;
}
}

getContext()->logger.debug("urVirtualMemMap: {} ~ {}",
(void *)MappedPtr,
(void *)(MappedPtr + PageSize - 1));

// FIXME: No flag to check the failed reason is VA is already mapped
auto URes = getContext()->urDdiTable.VirtualMem.pfnMap(
Context, (void *)MappedPtr, PageSize, PhysicalMem, 0,
UR_VIRTUAL_MEM_ACCESS_FLAG_READ_WRITE);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.debug("urVirtualMemMap({}, {}): {}",
(void *)MappedPtr, PageSize,
URes);
}

// Initialize to zero
if (URes == UR_RESULT_SUCCESS) {
// Reset PhysicalMem to null since it's been mapped
PhysicalMem = nullptr;

auto URes =
urEnqueueUSMSet(Queue, (void *)MappedPtr, 0, PageSize);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error("urEnqueueUSMFill(): {}",
URes);
return URes;
{
std::scoped_lock<ur_mutex> Guard(DeviceInfo->Mutex);
if (DeviceInfo->VirtualMemMaps.find(MappedPtr) ==
DeviceInfo->VirtualMemMaps.end()) {
ur_physical_mem_handle_t PhysicalMem{};
auto URes =
getContext()->urDdiTable.PhysicalMem.pfnCreate(
Context, DeviceInfo->Handle, PageSize, &Desc,
&PhysicalMem);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error(
"urPhysicalMemCreate(): {}", URes);
return URes;
}

URes = getContext()->urDdiTable.VirtualMem.pfnMap(
Context, (void *)MappedPtr, PageSize, PhysicalMem,
0, UR_VIRTUAL_MEM_ACCESS_FLAG_READ_WRITE);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.debug(
"urVirtualMemMap({}, {}): {}",
(void *)MappedPtr, PageSize, URes);
return URes;
}

getContext()->logger.debug(
"urVirtualMemMap: {} ~ {}", (void *)MappedPtr,
(void *)(MappedPtr + PageSize - 1));

// Initialize to zero
URes = urEnqueueUSMSet(Queue, (void *)MappedPtr, 0,
PageSize);
if (URes != UR_RESULT_SUCCESS) {
getContext()->logger.error("urEnqueueUSMFill(): {}",
URes);
return URes;
}

auto AllocInfoIt =
getContext()->interceptor->findAllocInfoByAddress(
Ptr);
assert(AllocInfoIt);
DeviceInfo->VirtualMemMaps[MappedPtr].first =
PhysicalMem;
DeviceInfo->VirtualMemMaps[MappedPtr].second.insert(
(*AllocInfoIt)->second);
}
}
}
Expand Down Expand Up @@ -349,6 +358,13 @@ ur_result_t SanitizerInterceptor::releaseMemory(ur_context_handle_t Context,
getContext()->logger.debug("Free: {}", (void *)AllocInfo->AllocBegin);
std::scoped_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);
m_AllocationMap.erase(AllocInfoIt);
if (AllocInfo->Type == AllocType::HOST_USM) {
UR_CALL(releasePhysicalMem(Context, ContextInfo->DeviceList,
AllocInfo));
} else {
UR_CALL(
releasePhysicalMem(Context, {AllocInfo->Device}, AllocInfo));
}
return getContext()->urDdiTable.USM.pfnFree(
Context, (void *)(AllocInfo->AllocBegin));
}
Expand All @@ -360,6 +376,13 @@ ur_result_t SanitizerInterceptor::releaseMemory(ur_context_handle_t Context,
getContext()->logger.info("Quarantine Free: {}",
(void *)It->second->AllocBegin);
m_AllocationMap.erase(It);
if (AllocInfo->Type == AllocType::HOST_USM) {
UR_CALL(releasePhysicalMem(Context, ContextInfo->DeviceList,
AllocInfo));
} else {
UR_CALL(releasePhysicalMem(Context, {AllocInfo->Device},
AllocInfo));
}
UR_CALL(getContext()->urDdiTable.USM.pfnFree(
Context, (void *)(It->second->AllocBegin)));
}
Expand All @@ -368,6 +391,59 @@ ur_result_t SanitizerInterceptor::releaseMemory(ur_context_handle_t Context,
return UR_RESULT_SUCCESS;
}

ur_result_t SanitizerInterceptor::releasePhysicalMem(
ur_context_handle_t Context, const std::vector<ur_device_handle_t> &Devices,
std::shared_ptr<struct AllocInfo> AI) {
for (auto Device : Devices) {
auto DeviceInfo = getDeviceInfo(Device);

if (DeviceInfo->Type != DeviceType::GPU_PVC &&
DeviceInfo->Type != DeviceType::GPU_DG2) {
continue;
}

uptr ShadowBegin = 0, ShadowEnd = 0;

if (DeviceInfo->Type == DeviceType::GPU_PVC) {
ShadowBegin =
MemToShadow_PVC(DeviceInfo->ShadowOffset, AI->AllocBegin);
ShadowEnd = MemToShadow_PVC(DeviceInfo->ShadowOffset,
AI->AllocBegin + AI->AllocSize - 1);
} else if (DeviceInfo->Type == DeviceType::GPU_DG2) {
ShadowBegin =
MemToShadow_DG2(DeviceInfo->ShadowOffset, AI->AllocBegin);
ShadowEnd = MemToShadow_DG2(DeviceInfo->ShadowOffset,
AI->AllocBegin + AI->AllocSize - 1);
}

assert(ShadowBegin <= ShadowEnd);

static const size_t PageSize =
GetVirtualMemGranularity(Context, DeviceInfo->Handle);

for (auto MappedPtr = RoundDownTo(ShadowBegin, PageSize);
MappedPtr <= ShadowEnd; MappedPtr += PageSize) {
std::scoped_lock<ur_mutex> Guard(DeviceInfo->Mutex);
if (DeviceInfo->VirtualMemMaps.find(MappedPtr) ==
DeviceInfo->VirtualMemMaps.end()) {
continue;
}
DeviceInfo->VirtualMemMaps[MappedPtr].second.erase(AI);
if (DeviceInfo->VirtualMemMaps[MappedPtr].second.empty()) {
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnUnmap(
Context, (void *)MappedPtr, PageSize));
UR_CALL(getContext()->urDdiTable.PhysicalMem.pfnRelease(
DeviceInfo->VirtualMemMaps[MappedPtr].first));
getContext()->logger.debug("urVirtualMemUnmap: {} ~ {}",
(void *)MappedPtr,
(void *)(MappedPtr + PageSize - 1));
}
}
}

return UR_RESULT_SUCCESS;
}

ur_result_t SanitizerInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel,
ur_queue_handle_t Queue,
USMLaunchInfo &LaunchInfo) {
Expand Down Expand Up @@ -584,6 +660,54 @@ SanitizerInterceptor::registerDeviceGlobals(ur_context_handle_t Context,
{}});

ContextInfo->insertAllocInfo({Device}, AI);

// For VirtualMem Unmap
{
std::scoped_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);
m_AllocationMap.emplace(AI->AllocBegin, std::move(AI));
}
}
}

return UR_RESULT_SUCCESS;
}

ur_result_t
SanitizerInterceptor::unregisterDeviceGlobals(ur_context_handle_t Context,
ur_program_handle_t Program) {
std::vector<ur_device_handle_t> Devices = GetProgramDevices(Program);

for (auto Device : Devices) {
ManagedQueue Queue(Context, Device);

uint64_t NumOfDeviceGlobal;
auto Result =
getContext()->urDdiTable.Enqueue.pfnDeviceGlobalVariableRead(
Queue, Program, kSPIR_AsanDeviceGlobalCount, true,
sizeof(NumOfDeviceGlobal), 0, &NumOfDeviceGlobal, 0, nullptr,
nullptr);
if (Result != UR_RESULT_SUCCESS) {
getContext()->logger.info("No device globals");
continue;
}

std::vector<DeviceGlobalInfo> GVInfos(NumOfDeviceGlobal);
Result = getContext()->urDdiTable.Enqueue.pfnDeviceGlobalVariableRead(
Queue, Program, kSPIR_AsanDeviceGlobalMetadata, true,
sizeof(DeviceGlobalInfo) * NumOfDeviceGlobal, 0, &GVInfos[0], 0,
nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
getContext()->logger.error("Device Global[{}] Read Failed: {}",
kSPIR_AsanDeviceGlobalMetadata, Result);
return Result;
}

for (size_t i = 0; i < NumOfDeviceGlobal; i++) {
std::scoped_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);
auto AllocInfo = m_AllocationMap[GVInfos[i].Addr];
UR_CALL(
releasePhysicalMem(Context, {AllocInfo->Device}, AllocInfo));
m_AllocationMap.erase(GVInfos[i].Addr);
}
}

Expand Down
12 changes: 12 additions & 0 deletions source/loader/layers/sanitizer/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ struct DeviceInfo {

ur_mutex Mutex;
std::queue<std::shared_ptr<AllocInfo>> Quarantine;
std::unordered_map<
uptr, std::pair<ur_physical_mem_handle_t,
std::unordered_set<std::shared_ptr<AllocInfo>>>>
VirtualMemMaps;
size_t QuarantineSize = 0;

// Device handles are special and alive in the whole process lifetime,
Expand Down Expand Up @@ -173,9 +177,17 @@ class SanitizerInterceptor {
AllocType Type, void **ResultPtr);
ur_result_t releaseMemory(ur_context_handle_t Context, void *Ptr);

ur_result_t
releasePhysicalMem(ur_context_handle_t Context,
const std::vector<ur_device_handle_t> &Devices,
std::shared_ptr<struct AllocInfo> AI);

ur_result_t registerDeviceGlobals(ur_context_handle_t Context,
ur_program_handle_t Program);

ur_result_t unregisterDeviceGlobals(ur_context_handle_t Context,
ur_program_handle_t Program);

ur_result_t preLaunchKernel(ur_kernel_handle_t Kernel,
ur_queue_handle_t Queue,
USMLaunchInfo &LaunchInfo);
Expand Down
22 changes: 22 additions & 0 deletions source/loader/layers/sanitizer/ur_sanddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,28 @@ ur_result_t UR_APICALL urProgramLinkExp(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urProgramRelease
ur_result_t UR_APICALL urProgramRelease(
ur_program_handle_t
hProgram ///< [in][release] handle for the Program to release
) {
auto pfnProgramRelease = getContext()->urDdiTable.Program.pfnRelease;

if (nullptr == pfnProgramRelease) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_CALL(getContext()->interceptor->unregisterDeviceGlobals(
GetContext(hProgram), hProgram));

UR_CALL(pfnProgramRelease(hProgram));

getContext()->logger.debug("==== urProgramRelease");

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueKernelLaunch
__urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
Expand Down
Loading