Skip to content

Commit

Permalink
update auth tests
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Aug 23, 2024
1 parent d603a6b commit 14a396a
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 59 deletions.
4 changes: 3 additions & 1 deletion py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ def refresh_access_token(self, refresh_token: str) -> Dict[str, Token]:
new_access_token = self.create_access_token(
data={"sub": token_data.email}
)
new_refresh_token = self.create_refresh_token()
new_refresh_token = self.create_refresh_token(
data={"sub": token_data.email}
)
return {
"access_token": Token(token=new_access_token, token_type="access"),
"refresh_token": Token(
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class GroupMixin(DatabaseMixin):
def create_table(self) -> None:
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name('groups')} (
group_id UUID PRIMARY KEY DEFAULT public.uuid_generate_v4(),
group_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class BlacklistedTokensMixin(DatabaseMixin):
def create_table(self):
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name('blacklisted_tokens')} (
id UUID PRIMARY KEY DEFAULT public.uuid_generate_v4(),
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
token TEXT NOT NULL,
blacklisted_at TIMESTAMPTZ DEFAULT NOW()
);
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class UserMixin(DatabaseMixin):
def create_table(self):
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name('users')} (
user_id UUID PRIMARY KEY DEFAULT public.uuid_generate_v4(),
user_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
email TEXT UNIQUE NOT NULL,
hashed_password TEXT NOT NULL,
is_superuser BOOLEAN DEFAULT FALSE,
Expand Down
7 changes: 2 additions & 5 deletions py/core/providers/database/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,11 +800,8 @@ def parse_condition(key, value):
else:
# Handle JSON-based filters
json_col = self.table.c.metadata
if not key.startswith("metadata."):
raise FilterError(
"metadata key must start with 'metadata.'"
)
key = key.split("metadata.")[1]
if key.startswith("metadata."):
key = key.split("metadata.")[1]
if isinstance(value, dict):
if len(value) > 1:
raise FilterError("only one operator permitted")
Expand Down
66 changes: 16 additions & 50 deletions py/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def test_verify_email_with_expired_code(auth_service, auth_provider):
)

with pytest.raises(R2RException) as exc_info:
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")
assert "Invalid or expired verification code" in str(exc_info.value)


Expand All @@ -243,7 +243,7 @@ async def test_refresh_token_flow(auth_service, auth_provider):
email="[email protected]", password="password123"
)

await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Login to get initial tokens
tokens = await auth_service.login("[email protected]", "password123")
Expand All @@ -252,46 +252,12 @@ async def test_refresh_token_flow(auth_service, auth_provider):

# Use refresh token to get new access token
new_tokens = await auth_service.refresh_access_token(
"[email protected]", refresh_token.token
refresh_token.token
)
assert "access_token" in new_tokens
assert new_tokens["access_token"].token != initial_access_token.token


@pytest.mark.asyncio
async def test_refresh_token_with_wrong_user(auth_service, auth_provider):
with patch.object(
auth_provider.crypto_provider,
"generate_verification_code",
return_value="123456",
):
new_user1 = await auth_service.register(
email="[email protected]", password="password123"
)
with patch.object(
auth_provider.crypto_provider,
"generate_verification_code",
return_value="1234567",
):
new_user2 = await auth_service.register(
email="[email protected]", password="password123"
)

await auth_service.verify_email("123456")
await auth_service.verify_email("1234567")

# Login as user1
tokens = await auth_service.login("[email protected]", "password123")
refresh_token = tokens["refresh_token"]

# Try to use user1's refresh token for user2
with pytest.raises(R2RException) as exc_info:
await auth_service.refresh_access_token(
"[email protected]", refresh_token.token
)
assert "Invalid email address attached to token" in str(exc_info.value)


@pytest.mark.asyncio
async def test_get_current_user_with_expired_token(
auth_service, auth_provider
Expand All @@ -305,7 +271,7 @@ async def test_get_current_user_with_expired_token(
email="[email protected]", password="password123"
)

await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Manually expire the token
auth_provider.access_token_lifetime_in_minutes = (
Expand Down Expand Up @@ -339,7 +305,7 @@ async def test_change_password(auth_service, auth_provider):
new_user = await auth_service.register(
email="[email protected]", password="old_password"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Change password
await auth_service.change_password(
Expand Down Expand Up @@ -370,7 +336,7 @@ async def test_reset_password_flow(
new_user = await auth_service.register(
email="[email protected]", password="old_password"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Request password reset
await auth_service.request_password_reset("[email protected]")
Expand Down Expand Up @@ -411,7 +377,7 @@ async def test_logout(auth_service, auth_provider):
new_user = await auth_service.register(
email="[email protected]", password="password123"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Login to get tokens
tokens = await auth_service.login("[email protected]", "password123")
Expand All @@ -437,7 +403,7 @@ async def test_update_user_profile(auth_service, auth_provider):
new_user = await auth_service.register(
email="[email protected]", password="password123"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Update user profile
updated_profile = await auth_service.update_user(
Expand All @@ -462,7 +428,7 @@ async def test_delete_user_account(auth_service, auth_provider):
new_user = await auth_service.register(
email="[email protected]", password="password123"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Delete user account
await auth_service.delete_user(new_user.id, "password123")
Expand Down Expand Up @@ -491,7 +457,7 @@ async def test_token_blacklist_cleanup(auth_service, auth_provider):
await auth_service.register(
email="[email protected]", password="password123"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

# Login and logout to create a blacklisted token
tokens = await auth_service.login("[email protected]", "password123")
Expand Down Expand Up @@ -539,7 +505,7 @@ async def test_register_and_verify(auth_service, auth_provider):
assert new_user.email == "[email protected]"
assert not new_user.is_verified

await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

new_user = auth_provider.db_provider.relational.get_user_by_email(
"[email protected]"
Expand All @@ -559,7 +525,7 @@ async def test_login_logout(auth_service, auth_provider):
await auth_service.register(
email="[email protected]", password="password123"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

tokens = await auth_service.login("[email protected]", "password123")
assert "access_token" in tokens
Expand All @@ -580,11 +546,11 @@ async def test_refresh_token(auth_service, auth_provider):
await auth_service.register(
email="[email protected]", password="password123"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

tokens = await auth_service.login("[email protected]", "password123")
new_tokens = await auth_service.refresh_access_token(
"[email protected]", tokens["refresh_token"].token
tokens["refresh_token"].token
)
assert new_tokens["access_token"].token != tokens["access_token"].token

Expand All @@ -599,7 +565,7 @@ async def test_change_password(auth_service, auth_provider):
new_user = await auth_service.register(
email="[email protected]", password="oldpassword"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")

result = await auth_service.change_password(
new_user, "oldpassword", "newpassword"
Expand Down Expand Up @@ -636,7 +602,7 @@ async def test_confirm_reset_password(auth_service, auth_provider):
await auth_service.register(
email="[email protected]", password="oldpassword"
)
await auth_service.verify_email("123456")
await auth_service.verify_email("[email protected]", "123456")
await auth_service.request_password_reset("[email protected]")
result = await auth_service.confirm_password_reset(
"123456", "newpassword"
Expand Down

0 comments on commit 14a396a

Please sign in to comment.