diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 1af8bf951..0e2be342a 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -24,6 +24,9 @@ assert ( version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0") ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0" + assert version.parse("1.1.10") <= version.parse( + hivemind.__version__ + ), "Please install a proper hivemind version: pip install hivemind>=1.1.10" def _override_bfloat16_mode_default(): diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 34d24c7ec..8789be7da 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -4,18 +4,16 @@ import itertools import time import uuid -from typing import AsyncIterator, List, Optional, Tuple +from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple import torch -from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor +from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2P from hivemind.proto import runtime_pb2 -from hivemind.utils.tensor_descr import BatchTensorDescriptor +from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten -from petals.client.config import ClientConfig from petals.client.routing import RemoteSequenceManager, maybe_log_traceback -from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo +from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy from petals.utils.packaging import pack_args_kwargs @@ -32,23 +30,21 @@ class _ServerInferenceSession: def __init__( self, - config: ClientConfig, + sequence_manager: RemoteSequenceManager, span: RemoteSpanInfo, - uid: ModuleUID, - rpc_info: RPCInfo, + span_uids: Sequence[ModuleUID], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, - *, + *block_kwargs, max_length: int, - **metadata, ): - self.config = config - self.span, self.uid, self.rpc_info = span, uid, rpc_info - self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 + self.sequence_manager = sequence_manager + self.span, self.span_uids = span, span_uids + self.num_blocks = len(span_uids) self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter self.session_id = str(uuid.uuid4()) - self.session_metadata = dict(max_length=max_length, **metadata) + self.max_length = max_length self.stepped = False self.closed = False @@ -56,24 +52,26 @@ def __init__( self.history = None # Used in case of server failures to regenerate attention caches on new servers self.next_session = None + self.block_kwargs = block_kwargs + assert len(self.block_kwargs) in (0, self.num_blocks) + @classmethod async def create( cls, - config: ClientConfig, - p2p: P2P, + sequence_manager: RemoteSequenceManager, span: RemoteSpanInfo, - uid: ModuleUID, - rpc_info: RPCInfo, - **metadata, + span_uids: Sequence[ModuleUID], + *block_kwargs: Dict[str, Any], + **kwargs, ) -> _ServerInferenceSession: """Create a new session for a given remote module. 
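For illustration, InferenceSession._enter_server_sessions (further down in this diff) invokes it roughly as follows; the local names here are placeholders for the corresponding attributes:
            session = RemoteExpertWorker.run_coroutine(
                _ServerInferenceSession.create(
                    sequence_manager,
                    span,
                    sequence_manager.block_uids[span.start : span.end],
                    *block_kwargs[span.start : span.end],
                    max_length=max_length,
                )
            )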
This code is meant to be run inside RemoteExpertWorker""" - stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - config.connect_timeout, + sequence_manager.config.connect_timeout, ) - return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) + return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: @@ -84,11 +82,16 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ break # this message means "done sending" def step( - self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str + self, + inputs: torch.Tensor, + prompts: Optional[torch.Tensor] = None, + *, + hypo_ids: Optional[torch.Tensor] = None, + step_id: str, ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs - :prompts: optional DEEP prompts, added to a prefix of each layer's outputs, + :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] """ if self.closed: @@ -106,41 +109,70 @@ def step( if not self.stepped: inputs = self.history # Pass full inputs including prefix + block_kwargs = self.block_kwargs else: inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + block_kwargs = [] - # serialize inputs and put them into the queue - input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids) + if prompts is None or is_dummy(prompts): + prompts = DUMMY + else: + assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" + assert prompts.shape[0] == self.num_blocks + assert prompts.shape[1] in (inputs.shape[0], 1) + assert prompts.shape[2] <= inputs.shape[1] + assert prompts.shape[3] == inputs.shape[2] - request_metadata = dict(session_id=self.session_id, step_id=step_id) - if not self.stepped: - request_metadata.update(self.session_metadata) - elif self.config.use_server_to_server: + if hypo_ids is None or is_dummy(hypo_ids): + hypo_ids = DUMMY_INT64 + else: + assert len(hypo_ids) == len(inputs) + assert hypo_ids.dtype == torch.int64 + + metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length) + metadata.update( + self.sequence_manager.get_request_metadata( + self.span.peer_id, + "rpc_inference", + self.span_uids, + inputs, + prompts, + *block_kwargs, + max_length=self.max_length, + session_id=self.session_id, + step_id=step_id, + ) + ) + if self.stepped and self.sequence_manager.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: - request_metadata["next_servers"] = next_servers + metadata["next_servers"] = next_servers - request_metadata["args_structure"] = args_structure + codecs = self.sequence_manager.get_compression_codecs( + self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs + ) - # TODO: make possible to use different compression method for different tensors - server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"] - compression = 
server_side_inference_schema[0].compression - inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors) + # serialize inputs and put them into the queue + input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs) + args_structure = metadata.setdefault("args_structure", args_structure) - # TODO: create more explicit way to check servers schema and client's structure - assert len(input_tensors) >= len( - server_side_inference_schema - ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step" + if codecs is None: + codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors) + else: + codecs = list(nested_flatten(codecs)) + assert len(codecs) == len( + input_tensors + ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs" outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( - uid=self.uid, + uid=CHAIN_DELIMITER.join(self.span_uids), tensors=[ - serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(input_tensors, inference_schema) + serialize_torch_tensor(tensor, compression) + for tensor, compression in zip(input_tensors, codecs) ], - metadata=MSGPackSerializer.dumps(request_metadata), + metadata=MSGPackSerializer.dumps(metadata), ) ) ) @@ -167,7 +199,7 @@ async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_p """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True - return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout) + return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout) def close(self): """Finish a given inference session, close the underlying connection""" @@ -204,7 +236,7 @@ class InferenceSession: An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ - def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): + def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]): self._sequence_manager = sequence_manager self._closed = False self._server_sessions = [] @@ -212,6 +244,12 @@ def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): self._max_length = max_length self.output_ids = None + num_blocks = len(self._sequence_manager) + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * num_blocks + assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}" + self.block_kwargs = block_kwargs + @property def num_blocks(self) -> int: return len(self._sequence_manager) @@ -224,17 +262,13 @@ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_Se server_sessions = [] try: for span in chosen_spans: - span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) - metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( - self._sequence_manager.config, - self._sequence_manager.state.p2p, + self._sequence_manager, span, - span_uids, - rpc_info=self._sequence_manager.rpc_info, + self._sequence_manager.block_uids[span.start : span.end], + *self.block_kwargs[span.start : span.end], max_length=self._max_length, - 
**metadata, ) ) server_sessions.append(session) @@ -256,8 +290,12 @@ def __enter__(self) -> "InferenceSession": return self def step( - self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None + self, + inputs: torch.Tensor, + prompts: Optional[torch.Tensor] = None, + hypo_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") @@ -302,7 +340,10 @@ def step( server_session = self._server_sessions[server_idx] inputs = server_session.step( - inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id + inputs, + prompts[server_session.span.start : server_session.span.end], + hypo_ids=hypo_ids, + step_id=step_id, ) server_idx += 1 @@ -328,7 +369,7 @@ def step( outputs = outputs.to(device=inputs_device, dtype=inputs_dtype) return outputs - def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int: + def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int): # If there is a failed server session, this code closes it self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 44abe2686..45a30c150 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -2,20 +2,22 @@ Utility functions that call RPC forward or backward on a single remote server """ import asyncio -from typing import Iterable, List, Optional, Sequence, Tuple +from typing import Iterable, List, Sequence, Tuple import torch -from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor +from hivemind import PeerID, nested_flatten, serialize_torch_tensor from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor -from hivemind.p2p import StubBase from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter +from hivemind.utils.serializer import MSGPackSerializer from hivemind.utils.streaming import split_for_streaming -from hivemind.utils.tensor_descr import BatchTensorDescriptor from petals.client.config import ClientConfig -from petals.data_structures import ModuleUID, RPCInfo +from petals.client.routing import RemoteSequenceManager +from petals.data_structures import CHAIN_DELIMITER, ModuleUID +from petals.server.handler import TransformerConnectionHandler +from petals.utils.packaging import pack_args_kwargs async def _forward_unary( @@ -65,85 +67,93 @@ async def _backward_stream( async def run_remote_forward( - uid: ModuleUID, - stub: StubBase, - rpc_info: RPCInfo, - *inputs: torch.Tensor, - config: ClientConfig, - metadata: Optional[bytes] = None, - **kwargs, + sequence_manager: RemoteSequenceManager, + peer_id: PeerID, + span_uids: Sequence[ModuleUID], + *args: torch.Tensor, + **kwargs: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: """ Serializes input tensors and calls "rpc_forward" on a remote server. Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
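    Illustrative call, mirroring how sequential_forward (see sequential_autograd.py in this diff) invokes this function; span and block_kwargs come from that caller:
        (outputs,) = await run_remote_forward(
            sequence_manager,
            span.peer_id,
            sequence_manager.block_uids[span.start : span.end],
            inputs,
            prompts[span.start : span.end],
            *block_kwargs[span.start : span.end],
        )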
""" - - # Note: *inputs are flattened input tensors that follow the expert's info['input_schema'] - # detach to avoid pickling the computation graph - assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}" - kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]} - - # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors - forward_inputs = tuple(nested_flatten((inputs, kwargs))) - args_schema, kwargs_schema = rpc_info["forward_schema"] - compression = args_schema[0].compression - forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs) - inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs) - # TODO: create more explicit way to check servers schema and client's structure - assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step" + merged_uid = CHAIN_DELIMITER.join(span_uids) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id) + metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs) + codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs) + flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs) + flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors) + args_structure = metadata.setdefault("args_structure", args_structure) + if codecs is None: + codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors) + else: + codecs = list(nested_flatten(codecs)) + assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs" # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( - loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs, forward_schema) + loop.run_in_executor(None, serialize_torch_tensor, tensor, compression) + for tensor, compression in zip(flat_tensors, codecs) ) ) # call RPC on remote server - size = sum(t.element_size() * t.nelement() for t in inputs) + size = sum(t.element_size() * t.nelement() for t in flat_tensors) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary - # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) - return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) + # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR + output_tensors = await forward_fn( + merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata) + ) + # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591 + requires_grad = any(tensor.requires_grad for tensor in flat_tensors) + output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors] + return output_tensors async def run_remote_backward( - uid: ModuleUID, - stub: StubBase, - rpc_info: RPCInfo, - *inputs_and_grad_outputs: torch.Tensor, - config: ClientConfig, - metadata: Optional[bytes] = None, - **kwargs, + sequence_manager: 
RemoteSequenceManager, + peer_id: PeerID, + span_uids: Sequence[ModuleUID], + grad_outputs: Sequence[torch.Tensor], + *args: torch.Tensor, + **kwargs: torch.Tensor, ) -> Sequence[torch.Tensor]: """ Serializes grad outputs and calls "rpc_backward" on a remote server. Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. """ - args_schema, kwargs_schema = rpc_info["forward_schema"] - outputs_schema = rpc_info["outputs_schema"] - compression = args_schema[0].compression - backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs) - # TODO: create more explicit way to check servers schema and client's structure - assert ( - len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1 - ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step" + merged_uid = CHAIN_DELIMITER.join(span_uids) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id) + metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs) + codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs) + flat_tensors, args_structure = pack_args_kwargs(grad_outputs, *args, **kwargs) + flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors) + args_structure = metadata.setdefault("args_structure", args_structure) + + if codecs is None: + codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors) + else: + codecs = list(nested_flatten(codecs)) + assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs" # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( - loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs_and_grad_outputs, backward_schema) + loop.run_in_executor(None, serialize_torch_tensor, tensor, compression) + for tensor, compression in zip(flat_tensors, codecs) ) ) + for tensor, serialized in zip(flat_tensors, serialized_tensors): + serialized.requires_grad = tensor.requires_grad # see https://github.com/learning-at-home/hivemind/pull/591 - size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) + size = sum(t.element_size() * t.nelement() for t in flat_tensors) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) - return deserialized_grad_inputs + return await backward_fn( + merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata) + ) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index c6d2833d1..4d43f310a 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -49,13 +49,13 @@ def __init__( self._active_session = ContextVar("active_session", default=None) - def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + def 
forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor: assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" if self.active_session is None: assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}" - return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) + return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args) else: - return self.active_session.step(inputs, prompts, **kwargs) + return self.active_session.step(inputs, prompts, *args, **kwargs) @property def active_session(self) -> Optional[InferenceSession]: diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index ed5224cbd..c9cc94ba0 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -471,21 +471,33 @@ def get_retry_delay(self, attempt_no: int) -> float: return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff) def get_request_metadata( - self, protocol: str, args_structure: Any = None, *args, **kwargs + self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs ) -> Optional[Dict[str, Any]]: """ + :param peer_id: remote server's PeerID :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" - :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging - :param args: request-specific inputs, typically block uids and input tensors - :param kwargs: additional request context, such as remote peer ID - :returns: msgpack-serialized metadata dict that will be passed alongside a given request + :param args: request-specific input tensors + :param kwargs: additional request keyword arguments + :returns: metadata dict that will be passed alongside a given request """ return dict( points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter, - args_structure=args_structure, ) + def get_compression_codecs( + self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs + ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]: + """ + return a sequence of compression codecs for client-side compression (applied to tensors sent to remote server) + :param peer_id: remote server's PeerID + :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" + :param args: request-specific input tensors + :param kwargs: additional request keyword arguments + :returns: compressions for each input tensor; contains as many elements as there are tensors in (args, kwargs) + """ + return None + def shutdown(self): self._thread.shutdown() diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 9d965d2a5..6d450e370 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -4,19 +4,16 @@ import asyncio import itertools from collections import deque -from typing import List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import torch -from hivemind import MSGPackSerializer from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.utils.logging import get_logger from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward from petals.client.routing import RemoteSequenceManager, 
maybe_log_traceback -from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo -from petals.server.handler import TransformerConnectionHandler +from petals.data_structures import RemoteSpanInfo from petals.utils.misc import DUMMY, is_dummy -from petals.utils.packaging import pack_args_kwargs logger = get_logger(__name__) @@ -24,19 +21,26 @@ async def sequential_forward( + sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor, - sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None, + *block_kwargs: Dict[str, Any], ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]: """ Constructs a routing path from start_index to end_index. Performs chained forward for each subsequence of blocks on the path. If some subsequence fails, reconstructs the remaining path and tries to finish the forward. - """ - assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" + :param inputs: initial hidden states of shape [batch_size, sequence length, hidden_size] + :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs; + if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] + :param sequence_manager: a running SequenceManager used to select remote servers and handle failures + :param start_index: run remote blocks starting from this index + :param end_index: run remote blocks up to (but not including) this index + :param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block + """ inputs_device = inputs.device inputs_dtype = inputs.dtype @@ -45,6 +49,12 @@ async def sequential_forward( end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * (end_index - start_index) + assert ( + not block_kwargs or len(block_kwargs) == end_index - start_index + ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" + assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" assert is_dummy(prompts) or len(prompts) == len( sequence_manager.block_uids ) # should be n_layers - 1 but add extra prompts for convenience @@ -67,20 +77,13 @@ async def sequential_forward( span = sequences.popleft() - stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) - flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end]) - - span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) - metadata = sequence_manager.get_request_metadata( - "rpc_forward", args_structure, span_uids, *flat_tensors - ) (outputs,) = await run_remote_forward( - span_uids, - stub, - sequence_manager.rpc_info, - *flat_tensors, - config=sequence_manager.config, - metadata=MSGPackSerializer.dumps(metadata), + sequence_manager, + span.peer_id, + sequence_manager.block_uids[span.start : span.end], + inputs, + prompts[span.start : span.end], + *block_kwargs[span.start : span.end], ) assert isinstance(outputs, torch.Tensor) @@ -111,11 +114,12 @@ async def sequential_forward( async def sequential_backward( + sequence_manager: RemoteSequenceManager, + forward_sequences: List[RemoteSpanInfo], grad_outputs: Sequence[torch.Tensor], intermediate_inputs: List[torch.Tensor], prompts: torch.Tensor, - forward_sequences: List[RemoteSpanInfo], - 
sequence_manager: RemoteSequenceManager, + *block_kwargs: Dict[str, Any], ) -> Tuple[Sequence[torch.Tensor], torch.Tensor]: """ Performs chained backward for each forward subsequence. @@ -141,7 +145,7 @@ async def sequential_backward( try: if attempt_no >= 1: _, backup_inputs, backup_sequences = await sequential_forward( - inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end + sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end ) assert len(backup_inputs) == len(backup_sequences) assert backup_sequences[0].start == span.start @@ -152,23 +156,14 @@ async def sequential_backward( inputs = intermediate_inputs.pop() span = forward_sequences.pop() - grad_outputs_cpu = [grad.cpu() for grad in grad_outputs] - flat_tensors, args_structure = pack_args_kwargs( - inputs, *grad_outputs_cpu, prompts[span.start : span.end] - ) - - span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) - stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) - metadata = sequence_manager.get_request_metadata( - "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id - ) grad_outputs, *span_grad_prompts = await run_remote_backward( - span_uids, - stub, - sequence_manager.rpc_info, - *flat_tensors, - config=sequence_manager.config, - metadata=MSGPackSerializer.dumps(metadata), + sequence_manager, + span.peer_id, + sequence_manager.block_uids[span.start : span.end], + grad_outputs, + inputs, + prompts[span.start : span.end], + *block_kwargs[span.start : span.end], ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) @@ -200,7 +195,7 @@ async def _gather_forward(input_batches, prompt_batches, sequence_manager): """Wrapper for asyncio.gather to perform parallel sequential forwards""" return await asyncio.gather( *[ - sequential_forward(input_batch, prompt_batch, sequence_manager) + sequential_forward(sequence_manager, input_batch, prompt_batch) for input_batch, prompt_batch in zip(input_batches, prompt_batches) ] ) @@ -212,7 +207,7 @@ async def _gather_backward( """Wrapper for asyncio.gather to perform parallel sequential backwards""" return await asyncio.gather( *[ - sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager) + sequential_backward(sequence_manager, spans, (grad_output,), input_batch, prompt_batch) for grad_output, input_batch, prompt_batch, spans in zip( grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences ) @@ -227,15 +222,17 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager): + def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor): + # TODO add kwargs here; figure out a way to split kwargs across servers batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1) input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size) + input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches) if prompts is None or is_dummy(prompts): prompt_batches = [DUMMY] * len(input_batches) else: prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1) + prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches) - sequence_manager.rpc_info # lazy init outputs = 
RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager)) assert len(outputs) == len(input_batches) @@ -274,4 +271,5 @@ def backward(ctx, grad_outputs: torch.Tensor): grad_inputs = torch.cat(grad_input_batches, dim=0) dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches] grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None - return (grad_inputs, grad_prompts, None) + # TODO return grads w.r.t. kwargs here + return (None, grad_inputs, grad_prompts) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 3a9b63ef4..2f9684ff2 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch -from hivemind import BatchTensorDescriptor, TensorDescriptor +from hivemind import BatchTensorDescriptor, TensorDescriptor, nested_flatten, nested_map from hivemind.moe.expert_uid import ExpertUID from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger @@ -96,22 +96,29 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S cache_tensors.extend((keys, values)) return cache_tensors - def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: - *inputs, active_adapter = inputs - with self._peft_module.using_adapter(active_adapter): - return super().forward(*inputs) - - def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: - *inputs, active_adapter = inputs - with self._peft_module.using_adapter(active_adapter): - return super().backward(*inputs) + def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]: + with self._peft_module.using_adapter(active_adapter), torch.no_grad(): + return self.module(*args, **kwargs) + + def backward( + self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs + ) -> Tuple[Union[torch.Tensor, Any], ...]: + with self._peft_module.using_adapter(active_adapter), torch.enable_grad(): + (outputs,) = self.module(*args, **kwargs) + assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape + torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False) + return nested_map(self._get_grad_if_required, (*args, kwargs)) + + @staticmethod + def _get_grad_if_required(input: Any) -> Optional[torch.Tensor]: + """Get grad w.r.t. 
input if input is a tensor that requires grad; otherwise return None""" + if isinstance(input, torch.Tensor) and input.requires_grad: + return input.grad if input.grad is not None else torch.zeros_like(input) + return None @torch.inference_mode() def inference_step( - self, - hidden_states: torch.Tensor, - hypo_ids: torch.LongTensor, - inference_info: InferenceMetadata, + self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, **kwargs ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" seq_len = hidden_states.shape[1] @@ -129,8 +136,9 @@ def inference_step( layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) for offset in range(0, seq_len, max_chunk_length): hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :] + kwargs_chunk = self._select_kwargs_chunk(kwargs, seq_len, offset, max_chunk_length) output_hidden_states_chunk, new_kvs = self.module.forward( - hidden_states_chunk, layer_past=layer_past, use_cache=True + hidden_states_chunk, layer_past=layer_past, use_cache=True, **kwargs_chunk ) if seq_len > max_chunk_length: output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk @@ -178,6 +186,17 @@ def _update_cache_inplace( new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim) cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :] + @staticmethod + def _select_kwargs_chunk(kwargs: Dict[str, Any], seq_len: int, offset: int, max_chunk_length: int): + if offset == 0 and max_chunk_length >= seq_len: + return kwargs + kwargs_chunk = {} + for key, value in kwargs.items(): + if isinstance(value, torch.Tensor) and value.ndim >= 2 and value.shape[-2] == seq_len: + value = value[:, offset : offset + max_chunk_length] + kwargs_chunk[key] = value + return kwargs_chunk + def get_pools(self) -> Sequence[PrioritizedTaskPool]: return self.forward_pool, self.backward_pool, self.inference_pool @@ -200,8 +219,9 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]) """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values()) first_pool = next(iter(backends.values())).inference_pool + merged_inference_func = _MergedInferenceStep(backends) merged_pool = PrioritizedTaskPool( - _MergedInferenceStep(backends), + merged_inference_func, max_batch_size=first_pool.max_batch_size, device=first_pool.device, name=f"merged_inference", @@ -222,12 +242,15 @@ def __call__( hypo_ids: torch.LongTensor, inference_infos: Sequence[InferenceMetadata], *optional_prompts: Optional[torch.Tensor], + block_kwargs: Sequence[Dict[str, torch.Tensor]], ) -> Tuple[torch.Tensor, ...]: - assert len(inference_infos) == len( - optional_prompts - ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts" - for inference_info, optional_prompt in zip(inference_infos, optional_prompts): + assert ( + len(inference_infos) == len(optional_prompts) == len(block_kwargs) + ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(block_kwargs)} kwargs" + for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, block_kwargs): if optional_prompt is not None: hidden_states[:, : optional_prompt.shape[1]] += optional_prompt - (hidden_states,) = 
self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info) + (hidden_states,) = self.backends[inference_info.uid].inference_step( + hidden_states, hypo_ids, inference_info, **kwargs + ) return (hidden_states,) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 2c375666d..d4898c98a 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -18,7 +18,7 @@ from petals.server.task_prioritizer import TaskPrioritizerBase from petals.utils.convert_block import QuantType from petals.utils.misc import DUMMY, is_dummy -from petals.utils.packaging import unpack_args_kwargs +from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs # We prioritize short inference requests and make them use a *merged* inference pool, # so they are processed without interruptions and extra overheads @@ -31,36 +31,40 @@ async def run_rpc_forward( *flat_tensors: torch.Tensor, + args_structure: Any, requested_backends: Sequence[TransformerBackend], active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - args_structure: Any = None, ) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors - :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) + :param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg + :note: see pack_args_kwargs function for the definition of args_structure :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass + :param active_adapter: the name of LoRA adapter to use; defaults to no adapter + :param prioritizer: assigns priorities to each sub-request based on the number of points + :param points: client-specified number of points, used to assign priorities + :param args_structure: :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ - if args_structure is not None: - # TODO: kwargs currently is unused, it can be used later for peft-like adaptation - flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) - hidden_states, prompts, *_ = flat_tensors - + requires_grad = any(tensor.requires_grad for tensor in flat_tensors) + flat_tensors = tuple(tensor.detach() for tensor in flat_tensors) + (hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) assert hidden_states.ndim == 3 + num_tokens = hidden_states.shape[0] * hidden_states.shape[1] if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a chain of requested backends - for backend, prompt in zip(requested_backends, prompts): + for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt @@ -69,16 +73,18 @@ async def run_rpc_forward( hidden_states, points=points / len(requested_backends), backend=backend, type="forward" ) (hidden_states,) = await backend.forward_pool.submit_task( - hidden_states, active_adapter, + hidden_states, + **kwargs, priority=priority, + 
size=num_tokens, ) assert isinstance(hidden_states, torch.Tensor) assert ( hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - return hidden_states + return hidden_states.requires_grad_(requires_grad) async def run_rpc_backward( @@ -87,58 +93,70 @@ async def run_rpc_backward( active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - args_structure: Any = None, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - if args_structure is not None: - # TODO: kwargs currently is unused, it can be used later for peft-like adaptation - flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) - inputs, grad_outputs, prompts, *_ = flat_tensors + args_structure: Any, +) -> Tuple[Sequence[torch.Tensor], Any]: + """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests""" + assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad" + ((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs( + requested_backends, flat_tensors, args_structure + ) + input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad # Cast inputs & grad outputs to backend dtype - inputs = inputs.to(requested_backends[0].dtype) - grad_outputs = grad_outputs.to(requested_backends[-1].dtype) + num_tokens = hidden_states.shape[0] * hidden_states.shape[1] + hidden_states = hidden_states.detach().to(requested_backends[0].dtype) + grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype) if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)] # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output inter_inputs = [] - for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): - assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" + for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs): + assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" if not is_dummy(prompt): - inputs[:, : prompt.shape[1]] += prompt - inter_inputs.append(inputs) + hidden_states[:, : prompt.shape[1]] += prompt + inter_inputs.append(hidden_states) assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( - inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" + hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward" ) - (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) - - assert isinstance(inputs, torch.Tensor) + (hidden_states,) = await backend.forward_pool.submit_task( + active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens + ) + assert isinstance(hidden_states, torch.Tensor), "intermediate hidden states is not a tensor" if not is_dummy(prompts[-1]): - inputs[:, : prompts[-1].shape[1]] += prompts[-1] - inter_inputs.append(inputs) + hidden_states[:, : prompts[-1].shape[1]] += prompts[-1] + inter_inputs.append(hidden_states) assert 
len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" grad_prompts_reversed = [] + grad_block_kwargs_reversed = [] + # Run a chain of requested backends - for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): + for hidden_states, prompt, backend, kwargs in reversed( + list(zip(inter_inputs, prompts, requested_backends, block_kwargs)) + ): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + hidden_states = hidden_states.detach().requires_grad_(True) priority = prioritizer.prioritize( - inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + ) + (grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task( + active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) - assert isinstance(grad_outputs, torch.Tensor) - if not is_dummy(prompt): + if not is_dummy(prompt) and prompts_requires_grad: grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) + grad_block_kwargs_reversed.append(grad_kwargs) grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY - return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape + grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] + return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed)))) async def iterate_rpc_inference( @@ -161,12 +179,11 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) - if args_structure is not None: - # TODO: kwargs currently is unused, it can be used later for peft-like adaptation - flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) - - hidden_states, prompts, hypo_ids, *_ = flat_tensors + (hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs( + requested_backends, flat_tensors, args_structure + ) batch_size, length_increment, _ = hidden_states.shape + num_tokens = batch_size * length_increment # Cast inputs to backend dtype hidden_states = hidden_states.to(requested_backends[0].dtype) @@ -209,13 +226,27 @@ async def iterate_rpc_inference( for uid, handles in zip(requested_uids, cache_handles) ) (hidden_states,) = await requested_backends[0].inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, *prompts, priority=priority + hidden_states, + hypo_ids, + inference_infos, + *prompts, + block_kwargs=block_kwargs, + priority=priority, + size=num_tokens, ) else: - for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts): + for backend, uid, handles, prompt, kwargs in zip( + requested_backends, requested_uids, cache_handles, prompts, block_kwargs + ): inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),) (hidden_states,) = await backend.inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, prompt, priority=priority + hidden_states, + hypo_ids, + inference_infos, + prompt, + block_kwargs=(kwargs,), + priority=priority, + size=num_tokens, ) # serialize and send last layer outputs @@ 
-228,3 +259,29 @@ async def iterate_rpc_inference( # prepare for next step prefix_length += length_increment + + +def _check_inputs( + requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any +): + if len(flat_tensors) == 3: # backward compatibility for rpc_backward, remove after 2.3 + if flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad: + hidden_states, grad_outputs, prompts = flat_tensors + flat_tensors = grad_outputs, hidden_states, prompts + if args_structure is not None: + args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure) + else: + args, *block_kwargs = flat_tensors, {} # backward compatibility for grad structure, remove at 2.2 + + if len(block_kwargs) not in (1, len(requested_backends)): + raise RuntimeError( + f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts " + f"(one for each block). Found {len(block_kwargs)} instead." + ) + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * len(requested_backends) + assert len(block_kwargs) == len(requested_backends) + for i, kwargs in enumerate(block_kwargs): + if not isinstance(kwargs, dict): + raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}") + return args, block_kwargs diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d8f0ec05e..8e4d84846 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -361,18 +361,19 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) args_structure = metadata.get("args_structure") + assert isinstance( points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" - hidden_states = await run_rpc_forward( *flat_inputs, + args_structure=args_structure, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, - args_structure=args_structure, ) + return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) ) @@ -396,11 +397,11 @@ async def rpc_forward_stream( hidden_states = await run_rpc_forward( *flat_inputs, + args_structure=args_structure, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, - args_structure=args_structure, ) # Split the serialized_output for streaming and respond to client @@ -447,16 +448,18 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" - grads = await run_rpc_backward( + flat_grads, grads_structure = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, args_structure=args_structure, ) - return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) + serialized_flat_grads = self._serialize_grads(flat_grads, flat_tensors, metadata) + serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grads_structure)) + return runtime_pb2.ExpertResponse(tensors=serialized_flat_grads, metadata=serialized_output_metadata) async def rpc_backward_stream( self, requests: 
AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext @@ -474,18 +477,20 @@ points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" - grads = await run_rpc_backward( + flat_grads, grad_structure = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, args_structure=args_structure, ) # Split the serialized_grad_inputs for streaming and respond - for tensor in self._serialize_grads(grads, requested_backends, metadata): + serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grad_structure)) + for tensor in self._serialize_grads(flat_grads, flat_tensors, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): - yield runtime_pb2.ExpertResponse(tensors=[part]) + yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata) + serialized_output_metadata = None # attach metadata to the first response only def _get_active_adapter(self, metadata: dict) -> str: active_adapter = metadata.get("active_adapter", "") @@ -495,28 +500,31 @@ def _get_active_adapter(self, metadata: dict) -> str: def _serialize_grads( self, - grads: Sequence[torch.Tensor], - requested_backends: Sequence[TransformerBackend], - metadata: Dict[str, Any], + flat_grads: Sequence[torch.Tensor], + flat_inputs: Sequence[runtime_pb2.Tensor], + input_metadata: Dict[str, Any], ) -> Sequence[runtime_pb2.Tensor]: """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema""" + inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad) + assert len(flat_grads) == len(inputs_with_grad), ( + f"user provided {len(inputs_with_grad)} inputs with grad, " + f"but backward produced {len(flat_grads)} gradients" + ) # Modify grad_inputs_schema to support grad_prompts - assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize - flat_grads_schema = tuple( - nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema)) - ) # TODO generalize - - if metadata.get("output_compression") is not None: - assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" - output_compression = tuple(metadata["output_compression"]) + if input_metadata.get("output_compression") is not None: + output_compression = input_metadata["output_compression"] + assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list" assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" - assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements" + assert len(output_compression) == len(flat_grads), ( + f"output_compression should have {len(flat_grads)} " + f"elements, one for every tensor that requires grad" + ) else: - output_compression = tuple(tensor.compression for tensor in flat_grads_schema) - + output_compression = tuple(runtime_pb2.NONE for _ in flat_grads) + output_compression = tuple(output_compression) return [ - serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) - for result, proto, compression in zip(grads, flat_grads_schema, output_compression) + serialize_torch_tensor(result.to(input.dtype), compression, allow_inplace=True) + for result, input, 
compression in zip(flat_grads, inputs_with_grad, output_compression) ] def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]: diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 45884e3b7..5769adb33 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -8,7 +8,7 @@ import sys import threading import time -from typing import Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import hivemind import psutil @@ -17,6 +17,7 @@ from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime +from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger from transformers import PretrainedConfig @@ -773,3 +774,15 @@ class RuntimeWithDeduplicatedPools(Runtime): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pools = tuple(set(self.pools)) + + def process_batch( + self, pool: TaskPoolBase, batch_index: int, args: Sequence[Any], kwargs: Dict[str, Any] + ) -> Tuple[Any, int]: + """process one batch of tasks from a given pool, return a batch of results and total batch size""" + outputs = pool.process_func(*args, **kwargs) + batch_size = 1 + for arg in args: + if isinstance(arg, torch.Tensor) and arg.ndim > 2: + batch_size = arg.shape[0] * arg.shape[1] + break + return outputs, batch_size diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index 94bad7904..c39b40157 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -4,13 +4,17 @@ import time from concurrent.futures._base import PENDING from dataclasses import dataclass, field +from functools import partial from queue import PriorityQueue from typing import Any, List, Optional, Sequence, Tuple, Union import torch -from hivemind import get_logger +from hivemind import get_logger, nested_map +from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.utils.mpfuture import ALL_STATES, MPFuture +from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs + logger = get_logger(__name__) @@ -18,8 +22,10 @@ class Task: priority: float time_submitted: float + size: int future: MPFuture = field(compare=False) - args: Sequence[torch.Tensor] = field(compare=False) + flat_tensors: Sequence[torch.Tensor] = field(compare=False) + structure: Any @property def uid(self) -> int: @@ -92,15 +98,14 @@ def terminate(self): def shutdown(self): self.submitted_tasks.put(None) # Shuts down self.run() - def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: + def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture: """Add task to this pool's queue, return Future for its output""" future = MPFuture() # Remove shmem from MPFuture. 
@@ -92,15 +98,14 @@ def terminate(self):
     def shutdown(self):
         self.submitted_tasks.put(None)  # Shuts down self.run()
 
-    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
+    def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
         future = MPFuture()
         # Remove shmem from MPFuture. This disables the .cancel() feature but
         # saves the server from "could not unlink the shared memory file" crashes during rebalancing
         future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
-
-        task = Task(priority, time.monotonic(), future, args)
-        if self.get_task_size(task) > self.max_batch_size:
+        task = Task(priority, time.monotonic(), size, future, *pack_args_kwargs(*args, **kwargs))
+        if task.size > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)
         else:
@@ -110,33 +115,27 @@ def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
             self.priority = (task.priority, task.time_submitted)
         return task.future
 
-    def get_task_size(self, task: Task) -> int:
-        """compute task processing complexity; defaults to the total number of tokens"""
-        if task.args and task.args[0].ndim >= 2:
-            return task.args[0].shape[0] * task.args[0].shape[1]
-        return 1
-
     def load_batch_to_runtime(
         self, timeout: Optional[float] = None, device: Optional[torch.device] = None
-    ) -> Tuple[Any, List[torch.Tensor]]:
+    ) -> Tuple[int, Any]:
         """receive next batch of arrays"""
         device = device if device is not None else self.device
         task = self._ordered_tasks.get(block=True, timeout=timeout)
-        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
+        device_flat_tensors = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.flat_tensors]
         self._dispatched_tasks[task.uid] = task
         self.batch_receiver.recv()  # reduce the number of active batches
         if not self._ordered_tasks.empty():
             first_remaining_task: Task = self._ordered_tasks.queue[0]
             self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
-        return task.uid, batch_inputs
+        return task.uid, unpack_args_kwargs(device_flat_tensors, task.structure)
 
     def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
-        batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
+        batch_outputs = nested_map(partial(_move_to_device_if_tensor, device="cpu", share_memory=True), batch_outputs)
         task = self._dispatched_tasks.pop(uid, None)
         if task is None:
             logger.error(
-                f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
+                f"Internal error: task with index {uid} is missing from the dictionary; Could not set result"
             )
         else:
             task.future.set_result(batch_outputs)
diff --git a/src/petals/utils/packaging.py b/src/petals/utils/packaging.py
index c6d9faa3d..a50e8b25b 100644
--- a/src/petals/utils/packaging.py
+++ b/src/petals/utils/packaging.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Sequence, Tuple
 
 import torch
 from hivemind import nested_flatten, nested_pack
@@ -18,7 +18,7 @@ def _get_tensor_index(item: bytes) -> int:
     return int(item[3:])
 
-def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
+def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]:
     """
     Check the function's arguments and pack all tensors into different flattened lists.
:returns: a flattened list of tensors and args and kwargs, where tensors were masked @@ -35,7 +35,7 @@ def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]: return flat_tensors, nested_pack(masked_flat_values, (args, kwargs)) -def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any): +def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], args_structure: Any): """ Restore arguments after `pack_args_kwargs` function. :returns: list of args and dict of kwargs diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py index 1a0b1da47..15c6de5f4 100644 --- a/tests/test_priority_pool.py +++ b/tests/test_priority_pool.py @@ -4,8 +4,8 @@ import pytest import torch -from hivemind.moe.server.runtime import Runtime +from petals.server.server import RuntimeWithDeduplicatedPools from petals.server.task_pool import PrioritizedTaskPool @@ -57,7 +57,9 @@ def get_pools(self): proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid)) proc.start() - runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0) + runtime = RuntimeWithDeduplicatedPools( + {str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0 + ) runtime.ready = runtime_ready runtime.start() diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 20c6011ef..90c403332 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -73,8 +73,8 @@ def rpc_info(self): rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs) return rpc_info - def get_request_metadata(self, protocol: str, *args, **kwargs): - metadata = super().get_request_metadata(protocol, *args, **kwargs) + def get_request_metadata(self, peer_id, protocol, block_uids, *args, **kwargs): + metadata = super().get_request_metadata(peer_id, protocol, block_uids, *args, **kwargs) if protocol == "rpc_forward": metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) elif protocol == "rpc_backward":