Skip to content

Commit

Permalink
[BYO-FT] Improve check for available memory (#210)
Browse files Browse the repository at this point in the history
Previously, any parameters whose allocations was performed by the
relax VM would be double-counted.
  • Loading branch information
Lunderberg committed Feb 15, 2024
1 parent 6302271 commit 5e36d1b
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions serve/mlc_serve/model/tvm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,20 @@ def __init__(

self.cache_blocks = None

def get_used_memory(self):
def get_param_nbytes(self):
"""Get the total size of the parameters"""
if self.disco_session:
params = self.params.debug_get_from_remote(0)
else:
params = self.params

return sum(
math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params
)

def get_used_memory(self):
"""Get the total memory allocated by the VM"""
if self.disco_session:
get_used_memory_func = self.disco_session.get_global_func(
"vm.memory_manager.get_used_memory"
)
Expand All @@ -197,18 +207,12 @@ def get_used_memory(self):
tvm.device("cuda", 0)
).debug_get_from_remote(0)
else:
params = self.params

get_used_memory_func = tvm.get_global_func(
"vm.memory_manager.get_used_memory"
)
peak_memory = get_used_memory_func(self.dev)

param_bytes = sum(
math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params
)

return peak_memory + param_bytes
return peak_memory

def profile_memory_usage(self, seq_lens):
input_ids = [0] * sum(seq_lens)
Expand All @@ -217,6 +221,8 @@ def profile_memory_usage(self, seq_lens):
for s in seq_lens:
positions += range(s)

vm_alloc_before = self.get_used_memory()

input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), self.dev)
positions = tvm.nd.array(np.array(positions, dtype="int32"), self.dev)
seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), self.dev)
Expand Down Expand Up @@ -244,7 +250,9 @@ def profile_memory_usage(self, seq_lens):
self.mod["evaluate"](input_ids, positions, seq_lens, self.params)
stop_profiling_func()

return self.get_used_memory()
vm_alloc_after = self.get_used_memory()

return self.get_param_nbytes() + (vm_alloc_after - vm_alloc_before)

def generate_multi_query(
self,
Expand Down

0 comments on commit 5e36d1b

Please sign in to comment.