Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swap vecs collection for users #1014

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 84 additions & 16 deletions .github/workflows/build-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,29 @@ on:
type: string

jobs:
build-and-publish:
runs-on: [ self-hosted, Linux ]
prepare:
runs-on: ubuntu-latest
outputs:
release_version: ${{ steps.version.outputs.RELEASE_VERSION }}
steps:
- name: Determine version to use
id: version
run: |
if [ -n "${{ github.event.inputs.version }}" ]; then
echo "RELEASE_VERSION=${{ github.event.inputs.version }}" >> $GITHUB_OUTPUT
else
echo "RELEASE_VERSION=main" >> $GITHUB_OUTPUT
fi

build-and-publish-amd64:
needs: prepare
runs-on: amd2
permissions:
packages: write
contents: read
id-token: write
actions: write
timeout-minutes: 720
steps:
- name: Checkout Repository
uses: actions/checkout@v4
Expand All @@ -40,17 +56,69 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
with:
images: |
ragtoriches/prod
us-east1-docker.pkg.dev/alert-rush-397022/sciphi-r2r/r2r
tags: |
type=raw,value=${{ needs.prepare.outputs.release_version }}
type=raw,value=latest

- name: Determine version to use
id: version
run: |
if [ -n "${{ github.event.inputs.version }}" ]; then
echo "RELEASE_VERSION=${{ github.event.inputs.version }}" >> $GITHUB_OUTPUT
else
echo "RELEASE_VERSION=main" >> $GITHUB_OUTPUT
fi
- name: Build and Push Docker Image
uses: docker/build-push-action@v5
with:
context: ./py
file: ./py/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
platforms: linux/amd64

- name: Build and Push Docker Image (Unstructured)
uses: docker/build-push-action@v5
with:
context: ./py
file: ./py/Dockerfile.unstructured
push: true
tags: ${{ steps.meta.outputs.tags }}-unstructured
labels: ${{ steps.meta.outputs.labels }}-unstructured
platforms: linux/amd64

build-and-publish-arm64:
needs: prepare
runs-on: arm2
permissions:
packages: write
contents: read
id-token: write
actions: write
timeout-minutes: 720
steps:
- name: Checkout Repository
uses: actions/checkout@v4

- name: Google Auth
uses: 'google-github-actions/auth@v2'
with:
credentials_json: '${{ secrets.GCP_SA_KEY }}'

- name: Set up Cloud SDK
uses: 'google-github-actions/setup-gcloud@v2'

- name: Configure SDK
run: 'gcloud auth configure-docker us-east1-docker.pkg.dev'

- name: Docker Auth
uses: docker/login-action@v3
with:
username: ${{ secrets.RAGTORICHES_DOCKER_UNAME }}
password: ${{ secrets.RAGTORICHES_DOCKER_TOKEN }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Extract metadata (tags, labels) for Docker
id: meta
Expand All @@ -60,7 +128,7 @@ jobs:
ragtoriches/prod
us-east1-docker.pkg.dev/alert-rush-397022/sciphi-r2r/r2r
tags: |
type=raw,value=${{ steps.version.outputs.RELEASE_VERSION }}
type=raw,value=${{ needs.prepare.outputs.release_version }}
type=raw,value=latest

- name: Build and Push Docker Image
Expand All @@ -71,14 +139,14 @@ jobs:
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
platforms: linux/amd64,linux/arm64
platforms: linux/arm64

- name: Build and Push Docker Image
- name: Build and Push Docker Image (Unstructured)
uses: docker/build-push-action@v5
with:
context: ./py
file: ./py/Dockerfile.unstructured
push: true
tags: ${{ steps.meta.outputs.tags }}-unstructured
labels: ${{ steps.meta.outputs.labels }}-unstructured
platforms: linux/amd64,linux/arm64
platforms: linux/arm64
4 changes: 2 additions & 2 deletions .github/workflows/py-ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
jobs:
pre-commit:
continue-on-error: true
runs-on: [ self-hosted, Linux ]
runs-on: amd1

steps:
- name: Checkout code
Expand Down Expand Up @@ -41,7 +41,7 @@ jobs:

pytest:
continue-on-error: true
runs-on: [ self-hosted, Linux ]
runs-on: amd1
timeout-minutes: 15

env:
Expand Down
1 change: 1 addition & 0 deletions py/core/base/abstractions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ class UserStats(BaseModel):
num_files: int
total_size_in_bytes: int
document_ids: list[UUID]
vecs_collection: str
1 change: 1 addition & 0 deletions py/core/base/api/models/auth/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class UserResponse(BaseModel):
updated_at: datetime = datetime.now()
is_verified: bool = False
group_ids: list[UUID] = []
vecs_collection: Optional[str] = None

# Optional fields (to update or set at creation)
hashed_password: Optional[str] = None
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def run(
state: Optional[AsyncState] = None,
stream: bool = False,
run_manager: Optional[RunManager] = None,
vecs_collection: Optional[str] = None,
*args: Any,
**kwargs: Any,
):
Expand All @@ -64,6 +65,7 @@ async def run(
pipe_num,
current_input,
run_manager,
vecs_collection,
*args,
**kwargs,
)
Expand Down
8 changes: 7 additions & 1 deletion py/core/base/pipes/base_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def run(
input: Input,
state: AsyncState,
run_manager: Optional[RunManager] = None,
vecs_collection: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> AsyncGenerator[Any, None]:
Expand All @@ -140,7 +141,12 @@ async def wrapped_run() -> AsyncGenerator[Any, None]:
)
try:
async for result in self._run_logic(
input, state, run_id=run_id, *args, **kwargs
input,
state,
run_id=run_id,
vecs_collection=vecs_collection,
*args,
**kwargs,
):
yield result
finally:
Expand Down
16 changes: 15 additions & 1 deletion py/core/main/api/routes/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ async def search_app(
Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.

"""
vecs_collection = None
if (
hasattr(auth_user, "vecs_collection")
and auth_user.vecs_collection is not None
):
vecs_collection = auth_user.vecs_collection

user_groups = set(auth_user.group_ids)
selected_groups = set(vector_search_settings.selected_group_ids)
allowed_groups = user_groups.intersection(selected_groups)
Expand All @@ -85,7 +92,6 @@ async def search_app(
filters = {
"$or": [
{"user_id": {"$eq": str(auth_user.id)}},
# {"group_ids": {"$any": list([str(ele) for ele in allowed_groups])}},
{"group_ids": {"$overlap": list(allowed_groups)}},
]
}
Expand All @@ -97,6 +103,7 @@ async def search_app(
query=query,
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
vecs_collection=vecs_collection,
)
return results

Expand Down Expand Up @@ -136,6 +143,12 @@ async def rag_app(

The generation process can be customized using the rag_generation_config parameter.
"""
vecs_collection = None
if (
hasattr(auth_user, "vecs_collection")
and auth_user.vecs_collection is not None
):
vecs_collection = auth_user.vecs_collection
allowed_groups = set(auth_user.group_ids)
filters = {
"$or": [
Expand All @@ -154,6 +167,7 @@ async def rag_app(
kg_search_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
task_prompt_override=task_prompt_override,
vecs_collection=vecs_collection,
)

if rag_generation_config.stream:
Expand Down
7 changes: 7 additions & 0 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def search(
query: str,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
vecs_collection: Optional[str] = None,
*args,
**kwargs,
) -> SearchResponse:
Expand Down Expand Up @@ -97,6 +98,7 @@ async def search(
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
run_manager=self.run_manager,
vecs_collection=vecs_collection,
*args,
**kwargs,
)
Expand All @@ -119,6 +121,7 @@ async def rag(
rag_generation_config: GenerationConfig,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
vecs_collection: Optional[str] = None,
*args,
**kwargs,
) -> RAGResponse:
Expand Down Expand Up @@ -151,6 +154,7 @@ async def rag(
rag_generation_config,
vector_search_settings,
kg_search_settings,
vecs_collection=vecs_collection,
*args,
**kwargs,
)
Expand All @@ -161,6 +165,7 @@ async def rag(
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
vecs_collection=vecs_collection,
*args,
**kwargs,
)
Expand Down Expand Up @@ -213,6 +218,7 @@ async def stream_rag_response(
rag_generation_config,
vector_search_settings,
kg_search_settings,
vecs_collection=None,
*args,
**kwargs,
):
Expand All @@ -227,6 +233,7 @@ async def stream_response():
kg_search_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
completion_record=completion_record,
vecs_collection=vecs_collection,
*args,
**kwargs,
):
Expand Down
2 changes: 2 additions & 0 deletions py/core/pipelines/rag_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ async def run(
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
rag_generation_config: GenerationConfig = GenerationConfig(),
vecs_collection: Optional[str] = None,
*args: Any,
**kwargs: Any,
):
Expand All @@ -55,6 +56,7 @@ async def multi_query_generator(input):
run_manager=run_manager,
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
vecs_collection=vecs_collection,
*args,
**kwargs,
)
Expand Down
2 changes: 2 additions & 0 deletions py/core/pipelines/search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def run(
run_manager: Optional[RunManager] = None,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
kg_search_settings: KGSearchSettings = KGSearchSettings(),
vecs_collection: Optional[str] = None,
*args: Any,
**kwargs: Any,
):
Expand Down Expand Up @@ -76,6 +77,7 @@ async def enqueue_requests():
stream,
run_manager,
vector_search_settings=vector_search_settings,
vecs_collection=vecs_collection,
*args,
**kwargs,
)
Expand Down
4 changes: 4 additions & 0 deletions py/core/pipes/retrieval/vector_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def search(
message: str,
run_id: UUID,
vector_search_settings: VectorSearchSettings,
vecs_collection: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> AsyncGenerator[VectorSearchResult, None]:
Expand Down Expand Up @@ -70,6 +71,7 @@ async def search(
else self.database_provider.vector.semantic_search(
query_vector=query_vector,
search_settings=vector_search_settings,
vecs_collection=vecs_collection,
)
)
reranked_results = self.embedding_provider.rerank(
Expand Down Expand Up @@ -104,6 +106,7 @@ async def _run_logic(
state: AsyncState,
run_id: UUID,
vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
vecs_collection: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> AsyncGenerator[VectorSearchResult, None]:
Expand All @@ -115,6 +118,7 @@ async def _run_logic(
message=search_request,
run_id=run_id,
vector_search_settings=vector_search_settings,
vecs_collection=vecs_collection,
*args,
**kwargs,
):
Expand Down
Loading
Loading