From 8d4172a6f9c7e7e4f773ff14120f2fa89982aa14 Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Tue, 27 Aug 2024 15:29:04 -0300 Subject: [PATCH 01/15] feat: integrating with HCI App Models --- app/config/base.py | 11 ++- app/routers/frontend.py | 192 ++++++++-------------------------------- app/types/frontend.py | 1 + 3 files changed, 48 insertions(+), 156 deletions(-) diff --git a/app/config/base.py b/app/config/base.py index 7729784..7b41b71 100644 --- a/app/config/base.py +++ b/app/config/base.py @@ -6,8 +6,17 @@ # Logging LOG_LEVEL = getenv_or_action("LOG_LEVEL", default="INFO") -# BigQuery Project +# BigQuery Integration BIGQUERY_PROJECT = getenv_or_action("BIGQUERY_PROJECT", action="raise") +BIGQUERY_PATIENT_HEADER_TABLE_ID = getenv_or_action( + "BIGQUERY_PATIENT_HEADER_TABLE_ID", action="raise" +) +BIGQUERY_PATIENT_SUMMARY_TABLE_ID = getenv_or_action( + "BIGQUERY_PATIENT_SUMMARY_TABLE_ID", action="raise" +) +BIGQUERY_PATIENT_ENCOUNTERS_TABLE_ID = getenv_or_action( + "BIGQUERY_PATIENT_ENCOUNTERS_TABLE_ID", action="raise" +) # JWT configuration JWT_SECRET_KEY = getenv_or_action("JWT_SECRET_KEY", default=token_bytes(32).hex()) diff --git a/app/routers/frontend.py b/app/routers/frontend.py index f255d01..aa71077 100644 --- a/app/routers/frontend.py +++ b/app/routers/frontend.py @@ -4,6 +4,7 @@ from typing import Annotated, List from fastapi import APIRouter, Depends, HTTPException from basedosdados import read_sql +from tortoise.exceptions import ValidationError from app.dependencies import ( get_current_frontend_user @@ -15,8 +16,13 @@ Encounter, UserInfo, ) -from app.config import BIGQUERY_PROJECT -from app.utils import read_timestamp, normalize_case +from app.validators import CPFValidator +from app.config import ( + BIGQUERY_PROJECT, + BIGQUERY_PATIENT_HEADER_TABLE_ID, + BIGQUERY_PATIENT_SUMMARY_TABLE_ID, + BIGQUERY_PATIENT_ENCOUNTERS_TABLE_ID +) router = APIRouter(prefix="/frontend", tags=["Frontend Application"]) @@ -45,80 +51,29 @@ async def get_patient_header( _: Annotated[User, Depends(get_current_frontend_user)], cpf: str, ) -> PatientHeader: + validator = CPFValidator() + try: + validator(cpf) + except ValidationError: + raise HTTPException(status_code=400, detail="Invalid CPF") + results_json = read_sql( f""" SELECT * - FROM `{BIGQUERY_PROJECT}`.`saude_dados_mestres`.`paciente` + FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_HEADER_TABLE_ID} WHERE cpf = '{cpf}' """, from_file="/tmp/credentials.json", ).to_json(orient="records") + try: + results = json.loads(results_json) + except Exception: + results = [] - results = json.loads(results_json) - - if len(results) > 0: - patient_record = results[0] - else: + if len(results) == 0: raise HTTPException(status_code=404, detail="Patient not found") - - data = patient_record["dados"] - - cns_principal = None - if len(patient_record["cns"]) > 0: - cns_principal = patient_record["cns"][0] - - telefone_principal = None - if len(patient_record["contato"]["telefone"]) > 0: - telefone_principal = patient_record["contato"]["telefone"][0]["valor"] - - clinica_principal, equipe_principal = {}, {} - medicos, enfermeiros = [], [] - if len(patient_record["equipe_saude_familia"]) > 0: - equipe_principal = patient_record["equipe_saude_familia"][0] - - # Pega Clínica da Família - if equipe_principal["clinica_familia"]: - clinica_principal = equipe_principal["clinica_familia"] - - for equipe in patient_record["equipe_saude_familia"]: - medicos.extend(equipe["medicos"]) - enfermeiros.extend(equipe["enfermeiros"]) - - for medico in medicos: - medico['registry'] = medico.pop('id_profissional_sus') - medico['name'] = medico.pop('nome') - - for enfermeiro in enfermeiros: - enfermeiro['registry'] = enfermeiro.pop('id_profissional_sus') - enfermeiro['name'] = enfermeiro.pop('nome') - - data_nascimento = None - if data.get("data_nascimento") is not None: - data_nascimento = read_timestamp(data.get("data_nascimento"), output_format='date') - - return { - "registration_name": data.get("nome"), - "social_name": data.get("nome_social"), - "cpf": f"{cpf[:3]}.{cpf[3:6]}.{cpf[6:9]}-{cpf[9:]}", - "cns": cns_principal, - "birth_date": data_nascimento, - "gender": data.get("genero"), - "race": data.get("raca"), - "phone": telefone_principal, - "family_clinic": { - "cnes": clinica_principal.get("id_cnes"), - "name": clinica_principal.get("nome"), - "phone": clinica_principal.get("telefone"), - }, - "family_health_team": { - "ine_code": equipe_principal.get("id_ine"), - "name": equipe_principal.get("nome"), - "phone": equipe_principal.get("telefone"), - }, - "medical_responsible": medicos, - "nursing_responsible": enfermeiros, - "validated": data.get("identidade_validada_indicador"), - } + else: + return results[0] @@ -128,51 +83,19 @@ async def get_patient_summary( cpf: str, ) -> PatientSummary: - query = f""" - with - base as (select '{cpf}' as cpf), - alergias_grouped as ( - select - cpf, - alergias as allergies - from `saude_historico_clinico.alergia` - where cpf = '{cpf}' - ), - medicamentos_cronicos_single as ( - select - cpf, - med.nome as nome_medicamento - from `saude_historico_clinico.medicamentos_cronicos`, - unnest(medicamentos) as med - where cpf = '{cpf}' - ), - medicamentos_cronicos_grouped as ( - select - cpf, - array_agg(nome_medicamento) as continuous_use_medications - from medicamentos_cronicos_single - group by cpf - ) - select - alergias_grouped.allergies, - medicamentos_cronicos_grouped.continuous_use_medications - from base - left join alergias_grouped on alergias_grouped.cpf = base.cpf - left join medicamentos_cronicos_grouped on medicamentos_cronicos_grouped.cpf = base.cpf - """ results_json = read_sql( - query, - from_file="/tmp/credentials.json" + f""" + SELECT * + FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_SUMMARY_TABLE_ID} + WHERE cpf = '{cpf}' + """, + from_file="/tmp/credentials.json", ).to_json(orient="records") - - result = json.loads(results_json) - if len(result) > 0: - return result[0] - - return { - "allergies": [], - "continuous_use_medications": [] - } + results = json.loads(results_json) + if len(results) == 0: + raise HTTPException(status_code=404, detail="Patient not found") + else: + return results[0] @router.get("/patient/filter_tags") async def get_filter_tags( @@ -199,51 +122,10 @@ async def get_patient_encounters( results_json = read_sql( f""" SELECT * - FROM `{BIGQUERY_PROJECT}`.`saude_historico_clinico`.`episodio_assistencial` - WHERE paciente.cpf = '{cpf}' + FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_ENCOUNTERS_TABLE_ID} + WHERE cpf = '{cpf}' """, from_file="/tmp/credentials.json", ).to_json(orient="records") - - encounters = [] - for result in json.loads(results_json): - # Responsible professional - professional = result.get('profissional_saude_responsavel') - if professional: - if isinstance(professional, list): - professional = professional[0] if len(professional) > 0 else {} - - if not professional['nome'] and not professional['especialidade']: - professional = None - else: - professional = { - "name": professional.get('nome'), - "role": professional.get('especialidade') - } - - # Filter Tags - unit_type = result['estabelecimento']['estabelecimento_tipo'] - if unit_type in [ - 'CLINICA DA FAMILIA', - 'CENTRO MUNICIPAL DE SAUDE' - ]: - unit_type = 'CF/CMS' - - encounter = { - "entry_datetime": read_timestamp(result['entrada_datahora'], output_format='datetime'), - "exit_datetime": read_timestamp(result['saida_datahora'], output_format='datetime'), - "location": result['estabelecimento']['nome'], - "type": result['tipo'], - "subtype": result['subtipo'], - "active_cids": [cid['descricao'] for cid in result['condicoes'] if cid['descricao']], - "responsible": professional, - "clinical_motivation": normalize_case(result['motivo_atendimento']), - "clinical_outcome": normalize_case(result['desfecho_atendimento']), - "filter_tags": [unit_type], - } - encounters.append(encounter) - - # Sort Encounters by entry_datetime - encounters = sorted(encounters, key=lambda x: x['entry_datetime'], reverse=True) - - return encounters + results = json.loads(results_json) + return results diff --git a/app/types/frontend.py b/app/types/frontend.py index 24b6bc5..b4432bf 100644 --- a/app/types/frontend.py +++ b/app/types/frontend.py @@ -36,6 +36,7 @@ class Encounter(BaseModel): location: str type: str subtype: Optional[str] + exhibition_type: str = 'default' active_cids: List[str] responsible: Optional[Responsible] clinical_motivation: Optional[str] From 65be0021f2265c8edf21c54329867767f891b75c Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Wed, 28 Aug 2024 10:29:23 -0300 Subject: [PATCH 02/15] feat: Restrict Patients and Encounters data --- app/routers/frontend.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/app/routers/frontend.py b/app/routers/frontend.py index aa71077..4c59468 100644 --- a/app/routers/frontend.py +++ b/app/routers/frontend.py @@ -72,8 +72,15 @@ async def get_patient_header( if len(results) == 0: raise HTTPException(status_code=404, detail="Patient not found") - else: - return results[0] + + dados = results[0] + configuracao_exibicao = dados.get('exibicao', {}) + + if configuracao_exibicao.get('indicador', False) is False: + message = ",".join(configuracao_exibicao.get('motivos', [])) + raise HTTPException(status_code=204, detail=message) + + return dados @@ -123,7 +130,7 @@ async def get_patient_encounters( f""" SELECT * FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_ENCOUNTERS_TABLE_ID} - WHERE cpf = '{cpf}' + WHERE cpf = '{cpf}' and exibicao.indicador = true """, from_file="/tmp/credentials.json", ).to_json(orient="records") From cfc8262e56976e4072f04f69f6e646156bedb414 Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Wed, 28 Aug 2024 14:53:58 -0300 Subject: [PATCH 03/15] feat: Update exit_datetime field in Encounter model to be optional --- app/types/frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/types/frontend.py b/app/types/frontend.py index b4432bf..94297f8 100644 --- a/app/types/frontend.py +++ b/app/types/frontend.py @@ -32,7 +32,7 @@ class Responsible(BaseModel): # Medical Visit model class Encounter(BaseModel): entry_datetime: str - exit_datetime: str + exit_datetime: Optional[str] location: str type: str subtype: Optional[str] From 9e8b1715af6973200579afe3fe13991d86265837 Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Wed, 28 Aug 2024 15:59:18 -0300 Subject: [PATCH 04/15] feat: Using Query Preview from Big Query --- app/config/__init__.py | 2 +- app/routers/frontend.py | 20 +++++++------------- app/utils.py | 37 ++++++++++++++++++------------------- poetry.lock | 16 +++++++++++++++- pyproject.toml | 1 + 5 files changed, 42 insertions(+), 34 deletions(-) diff --git a/app/config/__init__.py b/app/config/__init__.py index 6c2d783..43b52c5 100644 --- a/app/config/__init__.py +++ b/app/config/__init__.py @@ -81,7 +81,7 @@ def inject_environment_variables(environment: str): f"Injecting {len(secrets)} environment variables from Infisical:") for secret in secrets: logger.info( - f" - {secret.secret_name}: {'*' * len(secret.secret_value)}") + f" - {secret.secret_name}: {len(secret.secret_value)} chars") environment = getenv_or_action("ENVIRONMENT", action="warn", default="dev") diff --git a/app/routers/frontend.py b/app/routers/frontend.py index 4c59468..10424f0 100644 --- a/app/routers/frontend.py +++ b/app/routers/frontend.py @@ -3,7 +3,6 @@ from typing import Annotated, List from fastapi import APIRouter, Depends, HTTPException -from basedosdados import read_sql from tortoise.exceptions import ValidationError from app.dependencies import ( @@ -16,6 +15,7 @@ Encounter, UserInfo, ) +from app.utils import read_bq from app.validators import CPFValidator from app.config import ( BIGQUERY_PROJECT, @@ -57,18 +57,14 @@ async def get_patient_header( except ValidationError: raise HTTPException(status_code=400, detail="Invalid CPF") - results_json = read_sql( + results = await read_bq( f""" SELECT * FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_HEADER_TABLE_ID} WHERE cpf = '{cpf}' """, from_file="/tmp/credentials.json", - ).to_json(orient="records") - try: - results = json.loads(results_json) - except Exception: - results = [] + ) if len(results) == 0: raise HTTPException(status_code=404, detail="Patient not found") @@ -90,15 +86,14 @@ async def get_patient_summary( cpf: str, ) -> PatientSummary: - results_json = read_sql( + results = await read_bq( f""" SELECT * FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_SUMMARY_TABLE_ID} WHERE cpf = '{cpf}' """, from_file="/tmp/credentials.json", - ).to_json(orient="records") - results = json.loads(results_json) + ) if len(results) == 0: raise HTTPException(status_code=404, detail="Patient not found") else: @@ -126,13 +121,12 @@ async def get_patient_encounters( cpf: str, ) -> List[Encounter]: - results_json = read_sql( + results = await read_bq( f""" SELECT * FROM `{BIGQUERY_PROJECT}`.{BIGQUERY_PATIENT_ENCOUNTERS_TABLE_ID} WHERE cpf = '{cpf}' and exibicao.indicador = true """, from_file="/tmp/credentials.json", - ).to_json(orient="records") - results = json.loads(results_json) + ) return results diff --git a/app/utils.py b/app/utils.py index 5dc8003..67be48f 100644 --- a/app/utils.py +++ b/app/utils.py @@ -3,7 +3,11 @@ import jwt import hashlib import json -from typing import Literal +import os + +from google.cloud import bigquery +from google.oauth2 import service_account +from asyncer import asyncify from loguru import logger from passlib.context import CryptContext @@ -124,24 +128,19 @@ async def get_instance(Model, table, slug=None, code=None): return table[slug] -def read_timestamp(timestamp: int, output_format=Literal['date','datetime']) -> str: - if output_format == 'date': - denominator = 1000 - str_format = "%Y-%m-%d" - elif output_format == 'datetime': - denominator = 1 - str_format = "%Y-%m-%d %H:%M:%S" - else: - raise ValueError("Invalid format") +async def read_bq(query, from_file="/tmp/credentials.json"): + logger.debug(f"""Reading BigQuery with query (QUERY_PREVIEW_ENABLED={ + os.environ['QUERY_PREVIEW_ENABLED'] + }): {query}""") - try: - value = datetime(1970, 1, 1) + timedelta(seconds=timestamp/denominator) - except Exception as exc: - logger.error(f"Invalid timestamp: {timestamp} from {exc}") - return None + def execute_job(): + credentials = service_account.Credentials.from_service_account_file( + from_file, + ) + client = bigquery.Client(credentials=credentials) + row_iterator = client.query_and_wait(query) + return [dict(row) for row in row_iterator] - return value.strftime(str_format) + rows = await asyncify(execute_job)() -def normalize_case(text): - # TODO - return text + return rows \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index e0ad362..bed4472 100644 --- a/poetry.lock +++ b/poetry.lock @@ -85,6 +85,20 @@ files = [ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "asyncer" +version = "0.0.8" +description = "Asyncer, async and await, focused on developer experience." +optional = false +python-versions = ">=3.8" +files = [ + {file = "asyncer-0.0.8-py3-none-any.whl", hash = "sha256:5920d48fc99c8f8f0f1576e1882f5022885589c5fcbc46ce4224ec3e53776eeb"}, + {file = "asyncer-0.0.8.tar.gz", hash = "sha256:a589d980f57e20efb07ed91d0dbe67f1d2fd343e7142c66d3a099f05c620739c"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5.0" + [[package]] name = "asyncpg" version = "0.29.0" @@ -2793,4 +2807,4 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "af2cee69c2de80a6861a61b21a5ac4dacbcbdf1a79ee275beab6ea8bc72cba7c" +content-hash = "1ec4944b80ec680487b4e0ffc0144b035cc4fc7efeb12e81f5b585127b8c0271" diff --git a/pyproject.toml b/pyproject.toml index 92157cb..37f0d01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ urllib3 = "2.0.7" idna = "3.7" basedosdados = "^2.0.0b16" nltk = "^3.9.1" +asyncer = "^0.0.8" [tool.poetry.group.dev.dependencies] From d9241dd81724cbf6ae9c1ab1ab02ea33003c2e5f Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Wed, 28 Aug 2024 16:16:34 -0300 Subject: [PATCH 05/15] feat: Remove unused import in frontend.py --- app/routers/frontend.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/app/routers/frontend.py b/app/routers/frontend.py index 10424f0..7a97c4e 100644 --- a/app/routers/frontend.py +++ b/app/routers/frontend.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -import json - from typing import Annotated, List from fastapi import APIRouter, Depends, HTTPException from tortoise.exceptions import ValidationError From bd9ddcb8173018b93dca032efc90f1e68b236801 Mon Sep 17 00:00:00 2001 From: Pedro Marques Date: Thu, 29 Aug 2024 16:18:52 -0300 Subject: [PATCH 06/15] Include in Request Return Clinical Exams --- app/types/frontend.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/app/types/frontend.py b/app/types/frontend.py index 94297f8..92f750e 100644 --- a/app/types/frontend.py +++ b/app/types/frontend.py @@ -16,6 +16,10 @@ class FamilyHealthTeam(BaseModel): name: Optional[str] phone: Optional[str] +# Clinical Exam Model +class ClinicalExam(BaseModel): + type: str + description: Optional[str] # Medical Conditions model class PatientSummary(BaseModel): @@ -41,6 +45,7 @@ class Encounter(BaseModel): responsible: Optional[Responsible] clinical_motivation: Optional[str] clinical_outcome: Optional[str] + clinical_exames: List[ClinicalExam] filter_tags: List[str] From d8636c7ff0bdaffd208dd4cb020a32db09b5828e Mon Sep 17 00:00:00 2001 From: Pedro Marques Date: Thu, 29 Aug 2024 16:39:59 -0300 Subject: [PATCH 07/15] feat: Fix typo in variable name in frontend.py --- app/types/frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/types/frontend.py b/app/types/frontend.py index 92f750e..60743cc 100644 --- a/app/types/frontend.py +++ b/app/types/frontend.py @@ -45,7 +45,7 @@ class Encounter(BaseModel): responsible: Optional[Responsible] clinical_motivation: Optional[str] clinical_outcome: Optional[str] - clinical_exames: List[ClinicalExam] + clinical_exams: List[ClinicalExam] filter_tags: List[str] From 39b388390b2dc92c806702b6d90399a9cce111ec Mon Sep 17 00:00:00 2001 From: Pedro Marques Date: Thu, 29 Aug 2024 17:26:05 -0300 Subject: [PATCH 08/15] Fix HTTP status code in get_patient_header function --- app/routers/frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/routers/frontend.py b/app/routers/frontend.py index 7a97c4e..c780cc8 100644 --- a/app/routers/frontend.py +++ b/app/routers/frontend.py @@ -72,7 +72,7 @@ async def get_patient_header( if configuracao_exibicao.get('indicador', False) is False: message = ",".join(configuracao_exibicao.get('motivos', [])) - raise HTTPException(status_code=204, detail=message) + raise HTTPException(status_code=403, detail=message) return dados From 095bcfa82314cc9b2340cc279209462a3e1b9326 Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Fri, 30 Aug 2024 14:27:46 -0300 Subject: [PATCH 09/15] Basic Adapting to 2FA use --- app/models.py | 1 + app/routers/auth.py | 52 +++++++++++++++++++++++------- app/security.py | 78 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 11 deletions(-) create mode 100644 app/security.py diff --git a/app/models.py b/app/models.py index 56dd11f..a20df6a 100644 --- a/app/models.py +++ b/app/models.py @@ -273,6 +273,7 @@ class User(Model): cpf = fields.CharField(max_length=11, unique=True, null=True, validators=[CPFValidator()]) email = fields.CharField(max_length=255, unique=True) password = fields.CharField(max_length=255) + #secret_key = fields.CharField(max_length=255, null=True) is_active = fields.BooleanField(default=True) is_superuser = fields.BooleanField(default=False) user_class = fields.CharEnumField(enum_type=UserClassEnum, null=True, default=UserClassEnum.PIPELINE_USER) diff --git a/app/routers/auth.py b/app/routers/auth.py index 3e23581..8c81d2c 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,14 +1,17 @@ # -*- coding: utf-8 -*- +import io from datetime import timedelta from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm +from fastapi.responses import StreamingResponse from app import config from app.models import User from app.types.pydantic_models import Token from app.utils import authenticate_user, create_access_token +from app.security import TwoFactorAuth, get_two_factor_auth router = APIRouter(prefix="/auth", tags=["Autenticação"]) @@ -23,21 +26,48 @@ async def login_for_access_token( if not user: raise HTTPException( - status_code = status.HTTP_401_UNAUTHORIZED, - detail = "Incorrect username or password", - headers = {"WWW-Authenticate": "Bearer"}, + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta( - minutes = config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES - ) + access_token_expires = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) access_token = create_access_token( - data = {"sub": user.username}, - expires_delta = access_token_expires + data={"sub": user.username}, expires_delta=access_token_expires ) + return {"access_token": access_token, "token_type": "bearer"} + + +@router.post("/enable-2fa/{user_id}") +async def enable_2fa( + two_factor_auth: TwoFactorAuth = Depends(get_two_factor_auth) +): + return { + "secret_key": two_factor_auth.secret_key + } + + +@router.get("/generate-qr/{user_id}") +async def generate_qr( + two_factor_auth: TwoFactorAuth = Depends(get_two_factor_auth) +): + qr_code = two_factor_auth.qr_code + if qr_code is None: + raise HTTPException(status_code=404, detail="User not found") + + return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") + + +@router.post("/verify-totp/{user_id}") +async def verify_totp( + totp_code: str, + two_factor_auth: TwoFactorAuth = Depends(get_two_factor_auth), +): + is_valid = two_factor_auth.verify_totp_code(totp_code) + if not is_valid: + raise HTTPException(status_code=400, detail="Code invalid") return { - "access_token": access_token, - "token_type": "bearer" - } \ No newline at end of file + "valid": is_valid + } diff --git a/app/security.py b/app/security.py new file mode 100644 index 0000000..5688285 --- /dev/null +++ b/app/security.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +import base64 +import io +import secrets +from typing import Optional + +import qrcode +from fastapi import Depends +from pyotp import TOTP + +from app.models import User + + +class TwoFactorAuth: + + def __init__(self, user_id: str, secret_key: str): + self._user_id = user_id + self._secret_key = secret_key + self._totp = TOTP(self._secret_key) + self._qr_cache: Optional[bytes] = None + + @property + def totp(self) -> TOTP: + return self._totp + + @property + def secret_key(self) -> str: + return self._secret_key + + @staticmethod + def _generate_secret_key() -> str: + secret_bytes = secrets.token_bytes(20) + secret_key = base64.b32encode(secret_bytes).decode("utf-8") + return secret_key + + @staticmethod + async def get_or_create_secret_key(user_id: str) -> str: + user = await User.get_or_none(id=user_id) + + if not user: + raise ValueError(f"User with id {user_id} not found") + + # If User doesn't have a secret_key, create one + if not user.secret_key: + secret_key = TwoFactorAuth._generate_secret_key() + user.secret_key = secret_key + user.save() + + return user.secret_key + + def _create_qr_code(self) -> bytes: + uri = self.totp.provisioning_uri( + name=self._user_id, + issuer_name="2FA", + ) + img = qrcode.make(uri) + img_byte_array = io.BytesIO() + img.save(img_byte_array, format="PNG") + img_byte_array.seek(0) + return img_byte_array.getvalue() + + @property + def qr_code(self) -> bytes: + if self._qr_cache is None: + self._qr_cache = self._create_qr_code() + return self._qr_cache + + def verify_totp_code(self, totp_code: str) -> bool: + return self.totp.verify(totp_code) + + +async def get_two_factor_auth( + user_id: str +) -> TwoFactorAuth: + secret_key = await TwoFactorAuth.get_or_create_secret_key( + user_id + ) + return TwoFactorAuth(user_id, secret_key) From 90d764f82ee16a5a8f12634909a18ab3903434f3 Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Fri, 30 Aug 2024 14:35:04 -0300 Subject: [PATCH 10/15] Adding packages --- poetry.lock | 50 +++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index bed4472..7dad60e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1938,6 +1938,20 @@ cffi = ">=1.4.1" docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] +[[package]] +name = "pyotp" +version = "2.9.0" +description = "Python One Time Password Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612"}, + {file = "pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63"}, +] + +[package.extras] +test = ["coverage", "mypy", "ruff", "wheel"] + [[package]] name = "pyparsing" version = "3.1.2" @@ -1963,6 +1977,17 @@ files = [ {file = "pypika_tortoise-0.1.6-py3-none-any.whl", hash = "sha256:2d68bbb7e377673743cff42aa1059f3a80228d411fbcae591e4465e173109fd8"}, ] +[[package]] +name = "pypng" +version = "0.20220715.0" +description = "Pure Python library for saving and loading PNG images" +optional = false +python-versions = "*" +files = [ + {file = "pypng-0.20220715.0-py3-none-any.whl", hash = "sha256:4a43e969b8f5aaafb2a415536c1a8ec7e341cd6a3f957fd5b5f32a4cfeed902c"}, + {file = "pypng-0.20220715.0.tar.gz", hash = "sha256:739c433ba96f078315de54c0db975aee537cbc3e1d0ae4ed9aab0ca1e427e2c1"}, +] + [[package]] name = "pytest" version = "7.4.4" @@ -2146,6 +2171,29 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "qrcode" +version = "7.4.2" +description = "QR Code image generator" +optional = false +python-versions = ">=3.7" +files = [ + {file = "qrcode-7.4.2-py3-none-any.whl", hash = "sha256:581dca7a029bcb2deef5d01068e39093e80ef00b4a61098a2182eac59d01643a"}, + {file = "qrcode-7.4.2.tar.gz", hash = "sha256:9dd969454827e127dbd93696b20747239e6d540e082937c90f14ac95b30f5845"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} +pypng = "*" +typing-extensions = "*" + +[package.extras] +all = ["pillow (>=9.1.0)", "pytest", "pytest-cov", "tox", "zest.releaser[recommended]"] +dev = ["pytest", "pytest-cov", "tox"] +maintainer = ["zest.releaser[recommended]"] +pil = ["pillow (>=9.1.0)"] +test = ["coverage", "pytest"] + [[package]] name = "regex" version = "2024.7.24" @@ -2807,4 +2855,4 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "1ec4944b80ec680487b4e0ffc0144b035cc4fc7efeb12e81f5b585127b8c0271" +content-hash = "5715de4721d690e2ecff456af1ed2d255f810d0c660d5cb7f26bd307d92e4199" diff --git a/pyproject.toml b/pyproject.toml index 37f0d01..f9022a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ idna = "3.7" basedosdados = "^2.0.0b16" nltk = "^3.9.1" asyncer = "^0.0.8" +qrcode = "^7.4.2" +pyotp = "^2.9.0" [tool.poetry.group.dev.dependencies] From 603e6a417ed671de302134cc5136a048cc77a809 Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Fri, 30 Aug 2024 15:57:30 -0300 Subject: [PATCH 11/15] Implemented 2FA new login endpoint --- app/models.py | 2 +- app/routers/auth.py | 91 +++++++++++++++++----- app/security.py | 6 +- migrations/app/27_20240830145823_update.py | 12 +++ 4 files changed, 87 insertions(+), 24 deletions(-) create mode 100644 migrations/app/27_20240830145823_update.py diff --git a/app/models.py b/app/models.py index a20df6a..8ccc082 100644 --- a/app/models.py +++ b/app/models.py @@ -273,7 +273,7 @@ class User(Model): cpf = fields.CharField(max_length=11, unique=True, null=True, validators=[CPFValidator()]) email = fields.CharField(max_length=255, unique=True) password = fields.CharField(max_length=255) - #secret_key = fields.CharField(max_length=255, null=True) + secret_key = fields.CharField(max_length=255, null=True) is_active = fields.BooleanField(default=True) is_superuser = fields.BooleanField(default=False) user_class = fields.CharEnumField(enum_type=UserClassEnum, null=True, default=UserClassEnum.PIPELINE_USER) diff --git a/app/routers/auth.py b/app/routers/auth.py index 8c81d2c..909d139 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import io from datetime import timedelta -from typing import Annotated +from typing import Annotated, Optional from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm @@ -11,7 +11,10 @@ from app.models import User from app.types.pydantic_models import Token from app.utils import authenticate_user, create_access_token -from app.security import TwoFactorAuth, get_two_factor_auth +from app.security import TwoFactorAuth +from app.dependencies import ( + get_current_frontend_user +) router = APIRouter(prefix="/auth", tags=["Autenticação"]) @@ -39,35 +42,83 @@ async def login_for_access_token( return {"access_token": access_token, "token_type": "bearer"} +@router.post("/2fa/login") +async def login( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + totp_code: Optional[str] = None, +) -> Token: + + user: User = await authenticate_user(form_data.username, form_data.password) + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # IF 2FA is not Enabled + is_2fa_disabled = user.secret_key is None + is_trying_2fa = totp_code is not None + + # If 2FA is enabled and user is initializing the session + if not is_2fa_disabled and not is_trying_2fa: + # User must provide a OTP + return { + "2fa_enabled": True + } + + # If 2FA is enabled and user is trying to login with OTP + if is_trying_2fa and not is_2fa_disabled: + secret_key = await TwoFactorAuth.get_or_create_secret_key(user.id) + two_factor_auth = TwoFactorAuth(user.id, secret_key) + + is_valid = two_factor_auth.verify_totp_code(totp_code) + if not is_valid: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect OTP", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # If 2FA is disabled + if is_2fa_disabled: + is_valid = True + + # Generate Token + access_token_expires = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) -@router.post("/enable-2fa/{user_id}") + return { + "2fa_enabled": False, + "access_token": access_token, + "token_type": "bearer" + } + + +@router.post("/2fa/enable/") async def enable_2fa( - two_factor_auth: TwoFactorAuth = Depends(get_two_factor_auth) + current_user: Annotated[User, Depends(get_current_frontend_user)], ): + secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) + two_factor_auth = TwoFactorAuth(current_user.id, secret_key) + return { "secret_key": two_factor_auth.secret_key } -@router.get("/generate-qr/{user_id}") +@router.get("/2fa/generate-qr/") async def generate_qr( - two_factor_auth: TwoFactorAuth = Depends(get_two_factor_auth) + current_user: Annotated[User, Depends(get_current_frontend_user)], ): + secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) + two_factor_auth = TwoFactorAuth(current_user.id, secret_key) + qr_code = two_factor_auth.qr_code if qr_code is None: raise HTTPException(status_code=404, detail="User not found") - return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") - - -@router.post("/verify-totp/{user_id}") -async def verify_totp( - totp_code: str, - two_factor_auth: TwoFactorAuth = Depends(get_two_factor_auth), -): - is_valid = two_factor_auth.verify_totp_code(totp_code) - if not is_valid: - raise HTTPException(status_code=400, detail="Code invalid") - return { - "valid": is_valid - } + return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") \ No newline at end of file diff --git a/app/security.py b/app/security.py index 5688285..84fd716 100644 --- a/app/security.py +++ b/app/security.py @@ -44,18 +44,18 @@ async def get_or_create_secret_key(user_id: str) -> str: if not user.secret_key: secret_key = TwoFactorAuth._generate_secret_key() user.secret_key = secret_key - user.save() + await user.save() return user.secret_key def _create_qr_code(self) -> bytes: uri = self.totp.provisioning_uri( - name=self._user_id, + name=str(self._user_id), issuer_name="2FA", ) img = qrcode.make(uri) img_byte_array = io.BytesIO() - img.save(img_byte_array, format="PNG") + img.save(img_byte_array) img_byte_array.seek(0) return img_byte_array.getvalue() diff --git a/migrations/app/27_20240830145823_update.py b/migrations/app/27_20240830145823_update.py new file mode 100644 index 0000000..f17dac5 --- /dev/null +++ b/migrations/app/27_20240830145823_update.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +from tortoise import BaseDBAsyncClient + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" ADD "secret_key" VARCHAR(255);""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" DROP COLUMN "secret_key";""" From 8d064e32c9f53836fd12228ede95eeaa157369fb Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Fri, 30 Aug 2024 17:46:04 -0300 Subject: [PATCH 12/15] refactor: 2FA logic improving use flow --- app/models.py | 5 +- app/routers/auth.py | 127 +++++++++++---------- app/security.py | 2 +- app/utils.py | 9 ++ migrations/app/28_20240830163720_update.py | 12 ++ migrations/app/29_20240830164032_update.py | 12 ++ migrations/app/30_20240830164307_update.py | 16 +++ 7 files changed, 123 insertions(+), 60 deletions(-) create mode 100644 migrations/app/28_20240830163720_update.py create mode 100644 migrations/app/29_20240830164032_update.py create mode 100644 migrations/app/30_20240830164307_update.py diff --git a/app/models.py b/app/models.py index 8ccc082..2432836 100644 --- a/app/models.py +++ b/app/models.py @@ -273,11 +273,14 @@ class User(Model): cpf = fields.CharField(max_length=11, unique=True, null=True, validators=[CPFValidator()]) email = fields.CharField(max_length=255, unique=True) password = fields.CharField(max_length=255) - secret_key = fields.CharField(max_length=255, null=True) is_active = fields.BooleanField(default=True) is_superuser = fields.BooleanField(default=False) user_class = fields.CharEnumField(enum_type=UserClassEnum, null=True, default=UserClassEnum.PIPELINE_USER) data_source = fields.ForeignKeyField("app.DataSource", related_name="users", null=True) + # 2FA + secret_key = fields.CharField(max_length=255, null=True) + is_2fa_required = fields.BooleanField(default=False) + is_2fa_activated = fields.BooleanField(default=False) created_at = fields.DatetimeField(auto_now_add=True) updated_at = fields.DatetimeField(auto_now=True) diff --git a/app/routers/auth.py b/app/routers/auth.py index 909d139..606c0e5 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,16 +1,14 @@ # -*- coding: utf-8 -*- import io -from datetime import timedelta from typing import Annotated, Optional from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from fastapi.responses import StreamingResponse -from app import config from app.models import User from app.types.pydantic_models import Token -from app.utils import authenticate_user, create_access_token +from app.utils import authenticate_user, generate_user_token from app.security import TwoFactorAuth from app.dependencies import ( get_current_frontend_user @@ -21,12 +19,11 @@ @router.post("/token") -async def login_for_access_token( - form_data: Annotated[OAuth2PasswordRequestForm, Depends()] +async def login_without_2fa( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> Token: - user: User = await authenticate_user(form_data.username, form_data.password) - + user = await authenticate_user(form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -34,22 +31,26 @@ async def login_for_access_token( headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + if user.is_2fa_required: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="2FA required. Use the /2fa/login/ endpoint", + headers={"WWW-Authenticate": "Bearer"}, + ) - access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires - ) + return { + "access_token": generate_user_token(user), + "token_type": "bearer" + } - return {"access_token": access_token, "token_type": "bearer"} -@router.post("/2fa/login") -async def login( +@router.post("/2fa/login/") +async def login_with_2fa( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - totp_code: Optional[str] = None, + totp_code: str, ) -> Token: - user: User = await authenticate_user(form_data.username, form_data.password) - + user = await authenticate_user(form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -57,47 +58,30 @@ async def login( headers={"WWW-Authenticate": "Bearer"}, ) - # IF 2FA is not Enabled - is_2fa_disabled = user.secret_key is None - is_trying_2fa = totp_code is not None - - # If 2FA is enabled and user is initializing the session - if not is_2fa_disabled and not is_trying_2fa: - # User must provide a OTP - return { - "2fa_enabled": True - } - - # If 2FA is enabled and user is trying to login with OTP - if is_trying_2fa and not is_2fa_disabled: - secret_key = await TwoFactorAuth.get_or_create_secret_key(user.id) - two_factor_auth = TwoFactorAuth(user.id, secret_key) - - is_valid = two_factor_auth.verify_totp_code(totp_code) - if not is_valid: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect OTP", - headers={"WWW-Authenticate": "Bearer"}, - ) - - # If 2FA is disabled - if is_2fa_disabled: - is_valid = True - - # Generate Token - access_token_expires = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires - ) + # Caso 1: Usuário não registrou 2FA e está tentando logar + if user.is_2fa_required and not user.is_2fa_activated: + raise HTTPException( + status_code=status.HTTP_412_PRECONDITION_FAILED, + detail="2FA not activated. Use the /2fa/enable/ endpoint", + headers={"WWW-Authenticate": "Bearer"}, + ) + # Caso 2: Usuário registrou 2FA e está tentando logar + secret_key = await TwoFactorAuth.get_or_create_secret_key(user.id) + two_factor_auth = TwoFactorAuth(user.id, secret_key) + + is_valid = two_factor_auth.verify_totp_code(totp_code) + if not is_valid: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect OTP", + headers={"WWW-Authenticate": "Bearer"}, + ) return { - "2fa_enabled": False, - "access_token": access_token, - "token_type": "bearer" + "access_token": generate_user_token(user), + "token_type": "bearer", } - @router.post("/2fa/enable/") async def enable_2fa( current_user: Annotated[User, Depends(get_current_frontend_user)], @@ -110,10 +94,18 @@ async def enable_2fa( } -@router.get("/2fa/generate-qr/") -async def generate_qr( - current_user: Annotated[User, Depends(get_current_frontend_user)], +@router.get("/2fa/activate/generate-qrcode/") +async def generate_qrcode( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ): + current_user = await authenticate_user(form_data.username, form_data.password) + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) two_factor_auth = TwoFactorAuth(current_user.id, secret_key) @@ -121,4 +113,23 @@ async def generate_qr( if qr_code is None: raise HTTPException(status_code=404, detail="User not found") - return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") \ No newline at end of file + return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") + + +@router.post('/2fa/activate/verify-code/') +async def verify_code( + totp_code: str, + current_user: Annotated[User, Depends(get_current_frontend_user)], +): + secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) + two_factor_auth = TwoFactorAuth(current_user.id, secret_key) + + is_valid_totp = two_factor_auth.verify_totp_code(totp_code) + + if is_valid_totp: + current_user.is_2fa_activated = True + await current_user.save() + + return { + 'success': is_valid_totp + } diff --git a/app/security.py b/app/security.py index 84fd716..70e3a46 100644 --- a/app/security.py +++ b/app/security.py @@ -75,4 +75,4 @@ async def get_two_factor_auth( secret_key = await TwoFactorAuth.get_or_create_secret_key( user_id ) - return TwoFactorAuth(user_id, secret_key) + return TwoFactorAuth(user_id, secret_key) \ No newline at end of file diff --git a/app/utils.py b/app/utils.py index 67be48f..25aa6cf 100644 --- a/app/utils.py +++ b/app/utils.py @@ -18,6 +18,15 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +def generate_user_token(user: User) -> str: + access_token_expires = timedelta(minutes=config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) + + return access_token + async def authenticate_user(username: str, password: str) -> User: """Authenticate a user. diff --git a/migrations/app/28_20240830163720_update.py b/migrations/app/28_20240830163720_update.py new file mode 100644 index 0000000..0070d08 --- /dev/null +++ b/migrations/app/28_20240830163720_update.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +from tortoise import BaseDBAsyncClient + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" ADD "is_2fa_enabled" BOOL NOT NULL DEFAULT False;""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" DROP COLUMN "is_2fa_enabled";""" diff --git a/migrations/app/29_20240830164032_update.py b/migrations/app/29_20240830164032_update.py new file mode 100644 index 0000000..a54b4d4 --- /dev/null +++ b/migrations/app/29_20240830164032_update.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +from tortoise import BaseDBAsyncClient + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" ADD "is_2fa_active" BOOL NOT NULL DEFAULT False;""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" DROP COLUMN "is_2fa_active";""" diff --git a/migrations/app/30_20240830164307_update.py b/migrations/app/30_20240830164307_update.py new file mode 100644 index 0000000..0c4c781 --- /dev/null +++ b/migrations/app/30_20240830164307_update.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from tortoise import BaseDBAsyncClient + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" RENAME COLUMN "is_2fa_enabled" TO "is_2fa_required"; + ALTER TABLE "user" RENAME COLUMN "is_2fa_active" TO "is_2fa_activated";""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + ALTER TABLE "user" RENAME COLUMN "is_2fa_required" TO "is_2fa_active"; + ALTER TABLE "user" RENAME COLUMN "is_2fa_activated" TO "is_2fa_active"; + ALTER TABLE "user" RENAME COLUMN "is_2fa_required" TO "is_2fa_enabled"; + ALTER TABLE "user" RENAME COLUMN "is_2fa_activated" TO "is_2fa_enabled";""" From 5de1bc9135016359691664c033489e49a9f1b6eb Mon Sep 17 00:00:00 2001 From: Pedro Nascimento Date: Fri, 30 Aug 2024 18:45:28 -0300 Subject: [PATCH 13/15] chore: Remove unused import in auth.py --- app/routers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/routers/auth.py b/app/routers/auth.py index 606c0e5..1155c88 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import io -from typing import Annotated, Optional +from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm From 42ff25b53d5300aa1a4dfc440d80320467e82c81 Mon Sep 17 00:00:00 2001 From: Pedro Marques Date: Fri, 30 Aug 2024 21:55:23 -0300 Subject: [PATCH 14/15] refactor: Improve 2FA login flow and activate 2FA when logging in without 2FA --- app/routers/auth.py | 50 +++++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/app/routers/auth.py b/app/routers/auth.py index 1155c88..f6d123f 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -44,6 +44,21 @@ async def login_without_2fa( } +@router.post("/2fa/is-2fa-active/") +async def is_2fa_active( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], +) -> bool: + user = await authenticate_user(form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return user.is_2fa_activated + + @router.post("/2fa/login/") async def login_with_2fa( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], @@ -58,15 +73,6 @@ async def login_with_2fa( headers={"WWW-Authenticate": "Bearer"}, ) - # Caso 1: Usuário não registrou 2FA e está tentando logar - if user.is_2fa_required and not user.is_2fa_activated: - raise HTTPException( - status_code=status.HTTP_412_PRECONDITION_FAILED, - detail="2FA not activated. Use the /2fa/enable/ endpoint", - headers={"WWW-Authenticate": "Bearer"}, - ) - - # Caso 2: Usuário registrou 2FA e está tentando logar secret_key = await TwoFactorAuth.get_or_create_secret_key(user.id) two_factor_auth = TwoFactorAuth(user.id, secret_key) @@ -77,11 +83,16 @@ async def login_with_2fa( detail="Incorrect OTP", headers={"WWW-Authenticate": "Bearer"}, ) + if not user.is_2fa_activated: + user.is_2fa_activated = True + await user.save() + return { "access_token": generate_user_token(user), "token_type": "bearer", } + @router.post("/2fa/enable/") async def enable_2fa( current_user: Annotated[User, Depends(get_current_frontend_user)], @@ -94,7 +105,7 @@ async def enable_2fa( } -@router.get("/2fa/activate/generate-qrcode/") +@router.get("/2fa/generate-qrcode/") async def generate_qrcode( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ): @@ -114,22 +125,3 @@ async def generate_qrcode( raise HTTPException(status_code=404, detail="User not found") return StreamingResponse(io.BytesIO(qr_code), media_type="image/png") - - -@router.post('/2fa/activate/verify-code/') -async def verify_code( - totp_code: str, - current_user: Annotated[User, Depends(get_current_frontend_user)], -): - secret_key = await TwoFactorAuth.get_or_create_secret_key(current_user.id) - two_factor_auth = TwoFactorAuth(current_user.id, secret_key) - - is_valid_totp = two_factor_auth.verify_totp_code(totp_code) - - if is_valid_totp: - current_user.is_2fa_activated = True - await current_user.save() - - return { - 'success': is_valid_totp - } From 575446cb8a8f84efb13b7725c02426f691c48812 Mon Sep 17 00:00:00 2001 From: Pedro Marques Date: Fri, 30 Aug 2024 22:01:23 -0300 Subject: [PATCH 15/15] refactor: Remove unused import in security.py --- app/security.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/security.py b/app/security.py index 70e3a46..b3e2298 100644 --- a/app/security.py +++ b/app/security.py @@ -5,7 +5,6 @@ from typing import Optional import qrcode -from fastapi import Depends from pyotp import TOTP from app.models import User