From 5e36d1bb538d59e927cc624d261c02f16c9fc2a3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 15 Feb 2024 13:31:26 -0600 Subject: [PATCH] [BYO-FT] Improve check for available memory (#210) Previously, any parameters whose allocations was performed by the relax VM would be double-counted. --- serve/mlc_serve/model/tvm_model.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 653da250d8..5e117836ca 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -185,10 +185,20 @@ def __init__( self.cache_blocks = None - def get_used_memory(self): + def get_param_nbytes(self): + """Get the total size of the parameters""" if self.disco_session: params = self.params.debug_get_from_remote(0) + else: + params = self.params + + return sum( + math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params + ) + def get_used_memory(self): + """Get the total memory allocated by the VM""" + if self.disco_session: get_used_memory_func = self.disco_session.get_global_func( "vm.memory_manager.get_used_memory" ) @@ -197,18 +207,12 @@ def get_used_memory(self): tvm.device("cuda", 0) ).debug_get_from_remote(0) else: - params = self.params - get_used_memory_func = tvm.get_global_func( "vm.memory_manager.get_used_memory" ) peak_memory = get_used_memory_func(self.dev) - param_bytes = sum( - math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params - ) - - return peak_memory + param_bytes + return peak_memory def profile_memory_usage(self, seq_lens): input_ids = [0] * sum(seq_lens) @@ -217,6 +221,8 @@ def profile_memory_usage(self, seq_lens): for s in seq_lens: positions += range(s) + vm_alloc_before = self.get_used_memory() + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), self.dev) positions = tvm.nd.array(np.array(positions, dtype="int32"), self.dev) seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), self.dev) @@ -244,7 +250,9 @@ def profile_memory_usage(self, seq_lens): self.mod["evaluate"](input_ids, positions, seq_lens, self.params) stop_profiling_func() - return self.get_used_memory() + vm_alloc_after = self.get_used_memory() + + return self.get_param_nbytes() + (vm_alloc_after - vm_alloc_before) def generate_multi_query( self,