Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

periodic sync upstream KF to midstream ODH #118

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"Area/UI":
- changed-files:
- any-glob-to-any-file: "clients/ui/**"

"Area/MR Python client":
- changed-files:
- any-glob-to-any-file: "clients/python/**"

"Area/Go REST server":
- changed-files:
- any-glob-to-any-file:
- "api/**"
- "cmd/**"
- "internal/**"
- "patches/**"
- "pkg/**"
- "templates/go-server/**"

"Area/CSI":
- changed-files:
- any-glob-to-any-file: "csi/**"

"Area/Manifests":
- changed-files:
- any-glob-to-any-file: "manifests/**"

"Area/Documentation":
- changed-files:
- any-glob-to-any-file: "docs/**"

"Area/GitHub":
- changed-files:
- any-glob-to-any-file: ".github/**"
12 changes: 12 additions & 0 deletions .github/workflows/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
labeler:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v5
9 changes: 8 additions & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
if [[ ${{ matrix.session }} == "tests" ]]; then
make build-mr
nox --python=${{ matrix.python }} -- --cov-report=xml
poetry build
elif [[ ${{ matrix.session }} == "mypy" ]]; then
nox --python=${{ matrix.python }} ||\
echo "::error title='mypy failure'::Check the logs for more details"
Expand All @@ -80,9 +81,15 @@ jobs:
files: coverage.xml
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload dist
if: matrix.session == 'tests' && matrix.python == '3.12'
uses: actions/upload-artifact@v4
with:
name: py-dist
path: clients/python/dist
- name: Upload documentation
if: matrix.session == 'docs-build'
uses: actions/upload-artifact@v4
with:
name: docs
name: py-docs
path: clients/python/docs/_build
4 changes: 2 additions & 2 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ components:
enum:
- CREATE_TIME
- LAST_UPDATE_TIME
- Id
- ID
type: string
Artifact:
oneOf:
Expand Down Expand Up @@ -1661,7 +1661,7 @@ components:
explode: true
examples:
orderBy:
value: Id
value: ID
name: orderBy
description: Specifies the order by criteria for listing entities.
schema:
Expand Down
6 changes: 6 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ model = registry.get_registered_model("my-model")
version = registry.get_model_version("my-model", "2.0.0")

experiment = registry.get_model_artifact("my-model", "2.0.0")

# changing an attribute locally is not reflected on the model version already pushed to the registry
version.description = "Updated model version"

# push the local change to the registry using
registry.update(version)
```

### Importing from S3
Expand Down
5 changes: 4 additions & 1 deletion clients/python/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def lint(session: Session) -> None:
def mypy(session: Session) -> None:
"""Type check using mypy."""
session.install(".")
session.install("mypy")
session.install(
"mypy",
"types-python-dateutil",
)

session.run("mypy", "src/model_registry")

Expand Down
14 changes: 12 additions & 2 deletions clients/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mypy = "^1.7.0"
pytest-asyncio = ">=0.23.7,<0.25.0"
requests = "^2.32.2"
black = "^24.4.2"
types-python-dateutil = "^2.9.0.20240906"

[tool.coverage.run]
branch = true
Expand Down
27 changes: 22 additions & 5 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Any, get_args
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .core import ModelRegistryAPIClient
Expand All @@ -18,6 +18,9 @@
SupportedTypes,
)

ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)


class ModelRegistry:
"""Model registry client."""
Expand All @@ -29,7 +32,7 @@ def __init__(
*,
author: str,
is_secure: bool = True,
user_token: bytes | None = None,
user_token: str | None = None,
custom_ca: str | None = None,
):
"""Constructor.
Expand All @@ -41,8 +44,8 @@ def __init__(
Keyword Args:
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a byte string. Defaults to path on envvar CERT.
user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.
"""
import nest_asyncio

Expand All @@ -55,7 +58,7 @@ def __init__(
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
user_token = Path(sa_token).read_text()
else:
warn("User access token is missing", stacklevel=2)

Expand Down Expand Up @@ -191,6 +194,20 @@ def register_model(

return rm

def update(self, model: TModel) -> TModel:
    """Update an existing model in the registry.

    Args:
        model: The registered model, model version or model artifact to update.
            It must already carry an ID.

    Returns:
        The result of the matching upsert call on the underlying API client.

    Raises:
        StoreError: If `model` is not one of the supported model types, or if
            it has no ID.
    """
    supported = get_args(ModelTypes)
    # Validate the type before touching `model.id`, so an unsupported object is
    # reported as a StoreError instead of failing on attribute access.
    if not isinstance(model, supported):
        msg = f"Model must be one of {supported}"
        raise StoreError(msg)
    if not model.id:
        msg = "Model must have an ID"
        raise StoreError(msg)
    if isinstance(model, RegisteredModel):
        return self.async_runner(self._api.upsert_registered_model(model))
    if isinstance(model, ModelVersion):
        return self.async_runner(self._api.upsert_model_version(model, model.id))
    # Only ModelArtifact remains after the two checks above.
    return self.async_runner(self._api.upsert_model_artifact(model, model.id))

def register_hf_model(
self,
repo: str,
Expand Down
8 changes: 4 additions & 4 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def secure_connection(
server_address: str,
port: int = 443,
*,
user_token: bytes,
user_token: str,
custom_ca: str | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Expand All @@ -52,7 +52,7 @@ def secure_connection(
port: Server port. Defaults to 443.

Keyword Args:
user_token: The PEM-encoded user token as a byte string.
user_token: The PEM-encoded user token as a string.
custom_ca: The path to the PEM-encoded root certificates.
"""
return cls(
Expand All @@ -68,14 +68,14 @@ def insecure_connection(
cls,
server_address: str,
port: int,
user_token: bytes | None = None,
user_token: str | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.

Args:
server_address: Server address.
port: Server port.
user_token: The PEM-encoded user token as a byte string.
user_token: The PEM-encoded user token as a string.
"""
return cls(
Configuration(host=f"{server_address}:{port}", access_token=user_token)
Expand Down
16 changes: 11 additions & 5 deletions clients/python/src/model_registry/types/pager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def restart(self) -> Pager[T]:
This keeps the current options and page function, but resets the internal state.
"""
# as MLMD loops over pages, we need to keep track of the first page or we'll loop forever
self._start = None
self._current_page = None
self._start: str | None = None
self._current_page: list[T] | None = None
# tracks the next item on the current page
self._i = 0
self.options.next_page_token = None
Expand Down Expand Up @@ -112,7 +112,9 @@ async def _anext_page(self) -> list[T]:
return await cast(Awaitable[list[T]], self.page_fn(self.options))

def _needs_fetch(self) -> bool:
return not self._current_page or self._i >= len(self._current_page)
return not self._current_page or (
self._i >= len(self._current_page) and self._start is not None
)

def _next_item(self) -> T:
"""Get the next item in the pager.
Expand All @@ -126,6 +128,8 @@ def _next_item(self) -> T:
self._current_page = self._next_page()
self._i = 0
assert self._current_page
if self._i >= len(self._current_page):
raise StopIteration

item = self._current_page[self._i]
self._i += 1
Expand All @@ -143,6 +147,8 @@ async def _anext_item(self) -> T:
self._current_page = await self._anext_page()
self._i = 0
assert self._current_page
if self._i >= len(self._current_page):
raise StopIteration

item = self._current_page[self._i]
self._i += 1
Expand All @@ -153,7 +159,7 @@ def __next__(self) -> T:

item = self._next_item()

if not self._start:
if self._start is None:
self._start = self.options.next_page_token
elif check_looping and self.options.next_page_token == self._start:
raise StopIteration
Expand All @@ -165,7 +171,7 @@ async def __anext__(self) -> T:

item = await self._anext_item()

if not self._start:
if self._start is None:
self._start = self.options.next_page_token
elif check_looping and self.options.next_page_token == self._start:
raise StopAsyncIteration
Expand Down
2 changes: 1 addition & 1 deletion clients/python/src/mr_openapi/models/order_by_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OrderByField(str, Enum):
"""
CREATE_TIME = "CREATE_TIME"
LAST_UPDATE_TIME = "LAST_UPDATE_TIME"
ID = "Id"
ID = "ID"

@classmethod
def from_json(cls, json_str: str) -> Self:
Expand Down
16 changes: 16 additions & 0 deletions clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import os
import subprocess
import tempfile
import time
from contextlib import asynccontextmanager
from pathlib import Path
Expand Down Expand Up @@ -133,3 +134,18 @@ def event_loop():
@cleanup
def client() -> ModelRegistry:
return ModelRegistry(REGISTRY_HOST, REGISTRY_PORT, author="author", is_secure=False)

@pytest.fixture(scope="module")
def setup_env_user_token():
    """Point KF_PIPELINES_SA_TOKEN_PATH at a temporary token file.

    Writes the bytes ``b"Token"`` to a temporary file, exports the file's path
    through the ``KF_PIPELINES_SA_TOKEN_PATH`` environment variable and yields
    that path. On teardown the previous value of the variable is restored (or
    the variable is unset if it was not defined before) and the temporary file
    is deleted.
    """
    env_var = "KF_PIPELINES_SA_TOKEN_PATH"
    # delete=False keeps the file on disk once the handle is closed; leaving the
    # `with` block closes it so the token is flushed before anyone reads the path.
    with tempfile.NamedTemporaryFile(delete=False) as token_file:
        token_file.write(b"Token")
    old_token_path = os.getenv(env_var)
    os.environ[env_var] = token_file.name

    try:
        yield token_file.name
    finally:
        if old_token_path is None:
            # pop() tolerates a test having already removed the variable, so
            # the file cleanup below is never skipped by a KeyError.
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = old_token_path
        os.remove(token_file.name)
Loading