-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #214 from prefeitura-rio/feat/2fa
2FA Implementation
- Loading branch information
Showing
10 changed files
with
296 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,43 +1,127 @@ | ||
# -*- coding: utf-8 -*- | ||
from datetime import timedelta | ||
import io | ||
from typing import Annotated | ||
|
||
from fastapi import APIRouter, Depends, HTTPException, status | ||
from fastapi.security import OAuth2PasswordRequestForm | ||
from fastapi.responses import StreamingResponse | ||
|
||
from app import config | ||
from app.models import User | ||
from app.types.pydantic_models import Token | ||
from app.utils import authenticate_user, create_access_token | ||
from app.utils import authenticate_user, generate_user_token | ||
from app.security import TwoFactorAuth | ||
from app.dependencies import ( | ||
get_current_frontend_user | ||
) | ||
|
||
|
||
router = APIRouter(prefix="/auth", tags=["Autenticação"]) | ||
|
||
|
||
@router.post("/token") | ||
async def login_for_access_token( | ||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()] | ||
async def login_without_2fa( | ||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], | ||
) -> Token: | ||
|
||
user: User = await authenticate_user(form_data.username, form_data.password) | ||
user = await authenticate_user(form_data.username, form_data.password) | ||
if not user: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect username or password", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
|
||
if user.is_2fa_required: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="2FA required. Use the /2fa/login/ endpoint", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
|
||
return { | ||
"access_token": generate_user_token(user), | ||
"token_type": "bearer" | ||
} | ||
|
||
|
||
@router.post("/2fa/is-2fa-active/") | ||
async def is_2fa_active( | ||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], | ||
) -> bool: | ||
user = await authenticate_user(form_data.username, form_data.password) | ||
if not user: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect username or password", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
|
||
return user.is_2fa_activated | ||
|
||
|
||
@router.post("/2fa/login/") | ||
async def login_with_2fa( | ||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], | ||
totp_code: str, | ||
) -> Token: | ||
|
||
user = await authenticate_user(form_data.username, form_data.password) | ||
if not user: | ||
raise HTTPException( | ||
status_code = status.HTTP_401_UNAUTHORIZED, | ||
detail = "Incorrect username or password", | ||
headers = {"WWW-Authenticate": "Bearer"}, | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect username or password", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
|
||
access_token_expires = timedelta( | ||
minutes = config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES | ||
) | ||
secret_key = await TwoFactorAuth.get_or_create_secret_key(user.id) | ||
two_factor_auth = TwoFactorAuth(user.id, secret_key) | ||
|
||
access_token = create_access_token( | ||
data = {"sub": user.username}, | ||
expires_delta = access_token_expires | ||
) | ||
is_valid = two_factor_auth.verify_totp_code(totp_code) | ||
if not is_valid: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect OTP", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
if not user.is_2fa_activated: | ||
user.is_2fa_activated = True | ||
await user.save() | ||
|
||
return { | ||
"access_token": access_token, | ||
"token_type": "bearer" | ||
} | ||
"access_token": generate_user_token(user), | ||
"token_type": "bearer", | ||
} | ||
|
||
|
||
@router.post("/2fa/enable/") | ||
async def enable_2fa( | ||
current_user: Annotated[User, Depends(get_current_frontend_user)], | ||
): | ||
secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) | ||
two_factor_auth = TwoFactorAuth(current_user.id, secret_key) | ||
|
||
return { | ||
"secret_key": two_factor_auth.secret_key | ||
} | ||
|
||
|
||
@router.get("/2fa/generate-qrcode/") | ||
async def generate_qrcode( | ||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], | ||
): | ||
current_user = await authenticate_user(form_data.username, form_data.password) | ||
if not current_user: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Incorrect username or password", | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
) | ||
|
||
secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) | ||
two_factor_auth = TwoFactorAuth(current_user.id, secret_key) | ||
|
||
qr_code = two_factor_auth.qr_code | ||
if qr_code is None: | ||
raise HTTPException(status_code=404, detail="User not found") | ||
|
||
return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# -*- coding: utf-8 -*- | ||
import base64 | ||
import io | ||
import secrets | ||
from typing import Optional | ||
|
||
import qrcode | ||
from pyotp import TOTP | ||
|
||
from app.models import User | ||
|
||
|
||
class TwoFactorAuth: | ||
|
||
def __init__(self, user_id: str, secret_key: str): | ||
self._user_id = user_id | ||
self._secret_key = secret_key | ||
self._totp = TOTP(self._secret_key) | ||
self._qr_cache: Optional[bytes] = None | ||
|
||
@property | ||
def totp(self) -> TOTP: | ||
return self._totp | ||
|
||
@property | ||
def secret_key(self) -> str: | ||
return self._secret_key | ||
|
||
@staticmethod | ||
def _generate_secret_key() -> str: | ||
secret_bytes = secrets.token_bytes(20) | ||
secret_key = base64.b32encode(secret_bytes).decode("utf-8") | ||
return secret_key | ||
|
||
@staticmethod | ||
async def get_or_create_secret_key(user_id: str) -> str: | ||
user = await User.get_or_none(id=user_id) | ||
|
||
if not user: | ||
raise ValueError(f"User with id {user_id} not found") | ||
|
||
# If User doesn't have a secret_key, create one | ||
if not user.secret_key: | ||
secret_key = TwoFactorAuth._generate_secret_key() | ||
user.secret_key = secret_key | ||
await user.save() | ||
|
||
return user.secret_key | ||
|
||
def _create_qr_code(self) -> bytes: | ||
uri = self.totp.provisioning_uri( | ||
name=str(self._user_id), | ||
issuer_name="2FA", | ||
) | ||
img = qrcode.make(uri) | ||
img_byte_array = io.BytesIO() | ||
img.save(img_byte_array) | ||
img_byte_array.seek(0) | ||
return img_byte_array.getvalue() | ||
|
||
@property | ||
def qr_code(self) -> bytes: | ||
if self._qr_cache is None: | ||
self._qr_cache = self._create_qr_code() | ||
return self._qr_cache | ||
|
||
def verify_totp_code(self, totp_code: str) -> bool: | ||
return self.totp.verify(totp_code) | ||
|
||
|
||
async def get_two_factor_auth( | ||
user_id: str | ||
) -> TwoFactorAuth: | ||
secret_key = await TwoFactorAuth.get_or_create_secret_key( | ||
user_id | ||
) | ||
return TwoFactorAuth(user_id, secret_key) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# -*- coding: utf-8 -*- | ||
from tortoise import BaseDBAsyncClient | ||
|
||
|
||
async def upgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" ADD "secret_key" VARCHAR(255);""" | ||
|
||
|
||
async def downgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" DROP COLUMN "secret_key";""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# -*- coding: utf-8 -*- | ||
from tortoise import BaseDBAsyncClient | ||
|
||
|
||
async def upgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" ADD "is_2fa_enabled" BOOL NOT NULL DEFAULT False;""" | ||
|
||
|
||
async def downgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" DROP COLUMN "is_2fa_enabled";""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# -*- coding: utf-8 -*- | ||
from tortoise import BaseDBAsyncClient | ||
|
||
|
||
async def upgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" ADD "is_2fa_active" BOOL NOT NULL DEFAULT False;""" | ||
|
||
|
||
async def downgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" DROP COLUMN "is_2fa_active";""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# -*- coding: utf-8 -*- | ||
from tortoise import BaseDBAsyncClient | ||
|
||
|
||
async def upgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" RENAME COLUMN "is_2fa_enabled" TO "is_2fa_required"; | ||
ALTER TABLE "user" RENAME COLUMN "is_2fa_active" TO "is_2fa_activated";""" | ||
|
||
|
||
async def downgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
ALTER TABLE "user" RENAME COLUMN "is_2fa_required" TO "is_2fa_active"; | ||
ALTER TABLE "user" RENAME COLUMN "is_2fa_activated" TO "is_2fa_active"; | ||
ALTER TABLE "user" RENAME COLUMN "is_2fa_required" TO "is_2fa_enabled"; | ||
ALTER TABLE "user" RENAME COLUMN "is_2fa_activated" TO "is_2fa_enabled";""" |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters