Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API Updates: routing table, models endpoint, inference routing, rebase on safety_refactor #85

Open
wants to merge 45 commits into
base: safety_refactor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
9bb6ce5
example config
yanxi0830 Sep 20, 2024
7d4135d
add new resolve_impls_with_routing
yanxi0830 Sep 20, 2024
cda6111
migrate router for memory wip
yanxi0830 Sep 20, 2024
9c33587
delete router from providers
yanxi0830 Sep 20, 2024
3787408
clean up
yanxi0830 Sep 20, 2024
8df53ac
simple run config
yanxi0830 Sep 20, 2024
308a1d1
backward compatibility
yanxi0830 Sep 20, 2024
a6be32b
stage tmp changes
yanxi0830 Sep 20, 2024
06abd7e
update MemoryToolDefinition
yanxi0830 Sep 21, 2024
32beecb
Add a special header per-client call to parser provider data
ashwinb Sep 18, 2024
7e40eea
safety API cleanup part 1
ashwinb Sep 20, 2024
82ddd85
Update the meta reference safety implementation to match new API
Sep 20, 2024
d6a41d9
Update safety implementation inside agents
Sep 20, 2024
9252e81
test safety against safety client
ashwinb Sep 20, 2024
a57411b
Further bug fixes
ashwinb Sep 20, 2024
446914e
Add a special header per-client call to parser provider data
ashwinb Sep 18, 2024
e3fc36d
Revert "Add a special header per-client call to parser provider data"
yanxi0830 Sep 21, 2024
73133fb
Revert "stage tmp changes"
yanxi0830 Sep 21, 2024
5f9a7dc
Revert "backward compatibility"
yanxi0830 Sep 21, 2024
cbd4fa6
Revert "simple run config"
yanxi0830 Sep 21, 2024
ee77431
Revert "clean up"
yanxi0830 Sep 21, 2024
665ab1f
Revert "delete router from providers"
yanxi0830 Sep 21, 2024
39c27a3
Revert "migrate router for memory wip"
yanxi0830 Sep 21, 2024
32b9907
Revert "add new resolve_impls_with_routing"
yanxi0830 Sep 21, 2024
abe312c
Revert "example config"
yanxi0830 Sep 21, 2024
2dc14cb
stage tmp changes
yanxi0830 Sep 20, 2024
85d927a
skeleton unified routing table, api routers
yanxi0830 Sep 21, 2024
951cc9d
router table registration works
yanxi0830 Sep 21, 2024
04f480d
router method wrapper
yanxi0830 Sep 21, 2024
f058025
memory routers working
yanxi0830 Sep 21, 2024
8bf8c07
Respect user sent instructions in agent config and add them to system…
Sep 21, 2024
20a4302
models API
yanxi0830 Sep 22, 2024
c019902
supported models wip
yanxi0830 Sep 22, 2024
0348f26
models endpoint testing
yanxi0830 Sep 22, 2024
d29405d
update MemoryToolDefinition
yanxi0830 Sep 21, 2024
8e757ed
Respect user sent instructions in agent config and add them to system…
Sep 21, 2024
0b715c0
Add a special header per-client call to parser provider data
ashwinb Sep 18, 2024
9380661
Add a special header per-client call to parser provider data
ashwinb Sep 18, 2024
bafb0ce
Revert "Add a special header per-client call to parser provider data"
yanxi0830 Sep 21, 2024
d027eab
Merge branch 'main' into new_router
yanxi0830 Sep 22, 2024
c0f2f94
delete docs
yanxi0830 Sep 22, 2024
b5217fe
fix configure
yanxi0830 Sep 22, 2024
e42b555
Merge branch 'safety_refactor' into new_router
yanxi0830 Sep 22, 2024
44fe099
update example run files
yanxi0830 Sep 22, 2024
b8914bb
add safety/list_shields to query available shields
yanxi0830 Sep 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llama_stack/apis/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class CustomMemoryQueryGeneratorConfig(BaseModel):
]


@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
Expand Down
8 changes: 3 additions & 5 deletions llama_stack/apis/inference/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,9 @@ async def run_main(host: str, port: int, stream: bool):
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
ChatCompletionRequest(
model="Meta-Llama3.1-8B-Instruct",
messages=[message],
stream=stream,
)
model="Meta-Llama3.1-8B-Instruct",
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
Expand Down
72 changes: 72 additions & 0 deletions llama_stack/apis/models/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
from pathlib import Path

from typing import Any, Dict, List, Optional

import fire
import httpx

from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint

from .models import * # noqa: F403


class ModelsClient(Models):
    """HTTP client for the Models API of a running llama-stack server."""

    def __init__(self, base_url: str):
        # Server root, e.g. "http://localhost:5000"; routes are appended to it.
        self.base_url = base_url

    async def initialize(self) -> None:
        # Nothing to set up: an httpx client is created per request.
        pass

    async def shutdown(self) -> None:
        # No persistent resources to release.
        pass

    async def list_models(self) -> ModelsListResponse:
        """Fetch the serving specs of every model known to the stack."""
        url = f"{self.base_url}/models/list"
        headers = {"Content-Type": "application/json"}
        async with httpx.AsyncClient() as http:
            resp = await http.get(url, headers=headers)
            resp.raise_for_status()
            return ModelsListResponse(**resp.json())

    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
        """Look up the serving spec for a single core model id."""
        url = f"{self.base_url}/models/get"
        headers = {"Content-Type": "application/json"}
        payload = {
            "core_model_id": core_model_id,
        }
        async with httpx.AsyncClient() as http:
            resp = await http.post(url, json=payload, headers=headers)
            resp.raise_for_status()
            return ModelsGetResponse(**resp.json())


async def run_main(host: str, port: int, stream: bool):
    """Exercise the Models endpoints against a server at host:port."""
    client = ModelsClient(f"http://{host}:{port}")

    listing = await client.list_models()
    cprint(f"list_models response={listing}", "green")

    instruct_spec = await client.get_model("Meta-Llama3.1-8B-Instruct")
    cprint(f"get_model response={instruct_spec}", "blue")

    guard_spec = await client.get_model("Llama-Guard-3-8B")
    cprint(f"get_model response={guard_spec}", "red")


def main(host: str, port: int, stream: bool = True):
    # Synchronous CLI wrapper (invoked via fire) around the async runner.
    # NOTE(review): `stream` is forwarded but run_main does not appear to use
    # it for the models endpoints — confirm whether it is still needed.
    asyncio.run(run_main(host, port, stream))


if __name__ == "__main__":
    fire.Fire(main)
39 changes: 35 additions & 4 deletions llama_stack/apis/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,42 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol
from typing import Any, Dict, List, Optional, Protocol

from llama_models.schema_utils import webmethod # noqa: F401
from llama_models.llama3.api.datatypes import Model

from pydantic import BaseModel # noqa: F401
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field


class Models(Protocol): ...
@json_schema_type
class ModelServingSpec(BaseModel):
    """Describes one served model: its llama metadata, the provider that
    serves it, and which API it is served under."""

    llama_model: Model = Field(
        # Fixed typo "metadatas" and kept the sku_list pointer for discoverability.
        description="All metadata associated with llama model (defined in llama_models.models.sku_list).",
    )
    provider_config: GenericProviderConfig = Field(
        # Trailing space removed from the schema description.
        description="Provider config for the model, including provider_id, and corresponding config.",
    )
    api: str = Field(
        description="The API that this model is serving (e.g. inference / safety).",
        default="inference",
    )


@json_schema_type
class ModelsListResponse(BaseModel):
    # Response payload for GET /models/list: one spec per served model.
    models_list: List[ModelServingSpec]


@json_schema_type
class ModelsGetResponse(BaseModel):
    # Response payload for POST /models/get; None when the requested
    # core_model_id is not served by the stack.
    core_model_spec: Optional[ModelServingSpec] = None


class Models(Protocol):
    """Protocol for the Models API: inspect which models the stack serves."""

    @webmethod(route="/models/list", method="GET")
    async def list_models(self) -> ModelsListResponse: ...

    @webmethod(route="/models/get", method="POST")
    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
24 changes: 22 additions & 2 deletions llama_stack/apis/safety/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
import httpx

from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import RemoteProviderConfig
from pydantic import BaseModel
from termcolor import cprint

from llama_stack.distribution.datatypes import RemoteProviderConfig

from llama_stack.apis.safety import * # noqa: F403


Expand Down Expand Up @@ -62,6 +61,24 @@ async def run_shield(
content = response.json()
return RunShieldResponse(**content)

async def list_shields(self) -> ListShieldsResponse:
    """POST /safety/list_shields and return the available shield names.

    Prints the error and raises Exception on any non-200 response,
    mirroring the error handling of run_shield above.
    """
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{self.base_url}/safety/list_shields",
            json={},
            headers={"Content-Type": "application/json"},
            timeout=20,
        )

        if response.status_code != 200:
            content = await response.aread()
            error = f"Error: HTTP {response.status_code} {content.decode()}"
            cprint(error, "red")
            raise Exception(error)

        content = response.json()
        return ListShieldsResponse(**content)


async def run_main(host: str, port: int):
client = SafetyClient(f"http://{host}:{port}")
Expand All @@ -83,6 +100,9 @@ async def run_main(host: str, port: int):
)
print(response)

response = await client.list_shields()
print(response)


def main(host: str, port: int):
asyncio.run(run_main(host, port))
Expand Down
10 changes: 9 additions & 1 deletion llama_stack/apis/safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, Protocol
from typing import Any, Dict, List, Protocol

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
Expand Down Expand Up @@ -37,8 +37,16 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None


@json_schema_type
class ListShieldsResponse(BaseModel):
    """Response for /safety/list_shields: identifiers of available shields."""

    # Original annotated `List[str] = None`, which contradicts the annotation.
    # Make the Optional explicit; the runtime default stays None, so existing
    # callers and serialized payloads are unaffected.
    shields: Optional[List[str]] = None


class Safety(Protocol):
    """Protocol for the Safety API: run a shield over messages and
    enumerate the shields the stack provides."""

    @webmethod(route="/safety/run_shield")
    async def run_shield(
        self, shield: str, messages: List[Message], params: Dict[str, Any] = None
    ) -> RunShieldResponse: ...

    @webmethod(route="/safety/list_shields")
    async def list_shields(self) -> ListShieldsResponse: ...
4 changes: 2 additions & 2 deletions llama_stack/cli/stack/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None:
import pkg_resources

import yaml
from termcolor import cprint

from llama_stack.distribution.build import ImageType
from llama_stack.distribution.utils.exec import run_with_pty
from termcolor import cprint

docker_image = None

Expand Down Expand Up @@ -121,10 +121,10 @@ def _configure_llama_distribution(
from pathlib import Path

import yaml
from termcolor import cprint

from llama_stack.distribution.configure import configure_api_providers
from llama_stack.distribution.utils.serialize import EnumEncoder
from termcolor import cprint

builds_dir = BUILDS_BASE_DIR / build_config.image_type
if output_dir:
Expand Down
76 changes: 63 additions & 13 deletions llama_stack/distribution/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Api(Enum):
agents = "agents"
memory = "memory"
telemetry = "telemetry"
models = "models"


@json_schema_type
Expand All @@ -43,31 +44,65 @@ class ProviderSpec(BaseModel):
)


class GenericProviderConfig(BaseModel):
    # Which provider implementation to use.
    provider_id: str
    # Provider-specific configuration; opaque at this level.
    config: Dict[str, Any]


@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
    # Key used to route a request to this provider's config — e.g. a model id
    # such as "Meta-Llama3.1-8B-Instruct" for the inference API.
    routing_key: str


@json_schema_type
class RouterProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""

docker_image: Optional[str] = None

inner_specs: List[ProviderSpec]
routing_table: List[ProviderRoutingEntry] = Field(
default_factory=list,
description="Routing table entries corresponding to the API",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
Fully-qualified name of the module to import. The module is expected to have:

- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)

@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on RouterProviderSpec")


class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
@json_schema_type
class BuiltinProviderSpec(ProviderSpec):
    # Spec for a provider implemented directly in-process ("builtin").
    provider_id: str = "builtin"
    config_class: str = ""
    docker_image: Optional[str] = None
    api_dependencies: List[Api] = []
    provider_data_validator: Optional[str] = Field(
        default=None,
    )
    # NOTE(review): the description below mentions `get_router_impl`, which
    # looks copy-pasted from RouterProviderSpec — confirm the expected entry
    # point name for a builtin provider module.
    module: str = Field(
        ...,
        description="""
Fully-qualified name of the module to import. The module is expected to have:

 - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
    )
    pip_packages: List[str] = Field(
        default_factory=list,
        description="The pip dependencies needed for this implementation",
    )


@json_schema_type
Expand Down Expand Up @@ -95,6 +130,10 @@ class AdapterSpec(BaseModel):
provider_data_validator: Optional[str] = Field(
default=None,
)
supported_model_ids: List[str] = Field(
default_factory=list,
description="The list of model ids that this adapter supports",
)


@json_schema_type
Expand Down Expand Up @@ -204,12 +243,7 @@ class DistributionSpec(BaseModel):
)


@json_schema_type
class ProviderRoutingEntry(GenericProviderConfig):
routing_key: str


ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
# Each API now maps to exactly one provider config; per-key routing lives in
# StackRunConfig.provider_routing_table instead of per-API entry lists.
ProviderMapEntry = GenericProviderConfig


@json_schema_type
Expand Down Expand Up @@ -248,6 +282,22 @@ class StackRunConfig(BaseModel):

The key may support wild-cards alsothe routing_key to route to the correct provider.""",
)
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
default_factory=dict,
description="""
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.

E.g. The following is a ProviderRoutingEntry for inference API:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
""",
)


@json_schema_type
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/distribution/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry

Expand Down Expand Up @@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
}

for api, protocol in protocols.items():
Expand Down
Loading