forked from rvankoert/loghi-htr
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
367 additions
and
275 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,79 +1,142 @@ | ||
# Imports | ||
|
||
# > Standard library | ||
import asyncio | ||
from contextlib import asynccontextmanager | ||
import socket | ||
import multiprocessing as mp | ||
|
||
# > Third-party dependencies | ||
from fastapi import FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from uvicorn.config import Config | ||
from uvicorn.server import Server | ||
|
||
# > Local dependencies | ||
import errors | ||
from routes import main | ||
from app_utils import setup_logging, get_env_variable, start_workers | ||
from simple_security import SimpleSecurity | ||
from app_utils import (setup_logging, get_env_variable, | ||
start_workers, stop_workers) | ||
from routes import create_router | ||
|
||
# > Third-party dependencies | ||
from flask import Flask | ||
# Set up logging | ||
logging_level = get_env_variable("LOGGING_LEVEL", "INFO") | ||
logger = setup_logging(logging_level) | ||
|
||
# Get Loghi-HTR options from environment variables | ||
logger.info("Getting Loghi-HTR options from environment variables") | ||
batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256")) | ||
model_path = get_env_variable("LOGHI_MODEL_PATH") | ||
output_path = get_env_variable("LOGHI_OUTPUT_PATH") | ||
max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000")) | ||
patience = float(get_env_variable("LOGHI_PATIENCE", "0.5")) | ||
|
||
def create_app() -> Flask: | ||
""" | ||
Create and configure a Flask app for image prediction. | ||
# Get GPU options from environment variables | ||
logger.info("Getting GPU options from environment variables") | ||
gpus = get_env_variable("LOGHI_GPUS", "0") | ||
|
||
This function initializes a Flask app, sets up necessary configurations, | ||
starts image preparation and batch prediction processes, and returns the | ||
configured app instance. | ||
|
||
Returns | ||
------- | ||
Flask | ||
Configured Flask app instance ready for serving. | ||
@asynccontextmanager | ||
async def lifespan(app: FastAPI): | ||
""" | ||
Manage the lifespan of the FastAPI application. | ||
Parameters | ||
---------- | ||
app : FastAPI | ||
The FastAPI application instance. | ||
Side Effects | ||
------------ | ||
- Initializes and starts preparation, prediction, and decoding processes. | ||
- Logs various messages regarding the app and process initialization. | ||
Yields | ||
------ | ||
None | ||
""" | ||
# Create a stop event | ||
stop_event = mp.Event() | ||
|
||
# Set up logging | ||
logging_level = get_env_variable("LOGGING_LEVEL", "INFO") | ||
logger = setup_logging(logging_level) | ||
|
||
# Get Loghi-HTR options from environment variables | ||
logger.info("Getting Loghi-HTR options from environment variables") | ||
batch_size = int(get_env_variable("LOGHI_BATCH_SIZE", "256")) | ||
model_path = get_env_variable("LOGHI_MODEL_PATH") | ||
output_path = get_env_variable("LOGHI_OUTPUT_PATH") | ||
max_queue_size = int(get_env_variable("LOGHI_MAX_QUEUE_SIZE", "10000")) | ||
patience = float(get_env_variable("LOGHI_PATIENCE", "0.5")) | ||
|
||
# Get GPU options from environment variables | ||
logger.info("Getting GPU options from environment variables") | ||
gpus = get_env_variable("LOGHI_GPUS", "0") | ||
|
||
# Create Flask app | ||
logger.info("Creating Flask app") | ||
app = Flask(__name__) | ||
|
||
# Register error handler | ||
app.register_error_handler(ValueError, errors.handle_invalid_usage) | ||
app.register_error_handler(405, errors.method_not_allowed) | ||
|
||
# Add security to app | ||
security_config = \ | ||
{"enabled": get_env_variable("SECURITY_ENABLED", "False"), | ||
"key_user_json": get_env_variable("API_KEY_USER_JSON_STRING", "{}")} | ||
security = SimpleSecurity(app, security_config) | ||
logger.info(f"Security enabled: {security.enabled}") | ||
|
||
# Start the worker processes | ||
# Startup: Start the worker processes | ||
logger.info("Starting worker processes") | ||
workers, queues = start_workers(batch_size, max_queue_size, output_path, | ||
gpus, model_path, patience) | ||
gpus, model_path, patience, stop_event) | ||
# Add request queue and stop event to the app | ||
app.state.request_queue = queues["Request"] | ||
app.state.stop_event = stop_event | ||
app.state.workers = workers | ||
|
||
yield | ||
|
||
# Add request queue to the app | ||
app.request_queue = queues["Request"] | ||
# Shutdown: Stop all workers and join them | ||
logger.info("Shutting down worker processes") | ||
stop_workers(app.state.workers, app.state.stop_event) | ||
logger.info("All workers have been stopped and joined") | ||
|
||
# Add the workers to the app | ||
app.workers = workers | ||
|
||
# Register blueprints | ||
app.register_blueprint(main) | ||
def create_app() -> FastAPI: | ||
""" | ||
Create and configure the FastAPI application. | ||
Returns | ||
------- | ||
FastAPI | ||
The configured FastAPI application instance. | ||
""" | ||
app = FastAPI( | ||
title="Loghi-HTR API", | ||
description="API for Loghi-HTR", | ||
lifespan=lifespan | ||
) | ||
|
||
# Add CORS middleware | ||
app.add_middleware( | ||
CORSMiddleware, | ||
allow_origins=["*"], # Allows all origins | ||
allow_credentials=True, | ||
allow_methods=["*"], # Allows all methods | ||
allow_headers=["*"], # Allows all headers | ||
) | ||
|
||
# Include the router | ||
router = create_router(app) | ||
app.include_router(router) | ||
|
||
return app | ||
|
||
|
||
app = create_app() | ||
|
||
|
||
async def run_server(): | ||
""" | ||
Run the FastAPI server. | ||
Returns | ||
------- | ||
None | ||
""" | ||
host = get_env_variable("UVICORN_HOST", "127.0.0.1") | ||
port = int(get_env_variable("UVICORN_PORT", "5000")) | ||
|
||
# Attempt to resolve the hostname | ||
try: | ||
socket.gethostbyname(host) | ||
except socket.gaierror: | ||
logger.error( | ||
f"Unable to resolve hostname: {host}. Falling back to localhost.") | ||
host = "127.0.0.1" | ||
|
||
config = Config("app:app", host=host, port=port, workers=1) | ||
server = Server(config=config) | ||
|
||
try: | ||
await server.serve() | ||
except OSError as e: | ||
logger.error(f"Error starting server: {e}") | ||
if e.errno == 98: # Address already in use | ||
logger.error( | ||
f"Port {port} is already in use. Try a different port.") | ||
elif e.errno == 13: # Permission denied | ||
logger.error( | ||
f"Permission denied when trying to bind to port {port}. Try a " | ||
"port number > 1024 or run with sudo.") | ||
except Exception as e: | ||
logger.error(f"Unexpected error occurred: {e}") | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(run_server()) |
Oops, something went wrong.