"""Application factory for the mini S3-compatible object store.""" from __future__ import annotations import logging import sys import time import uuid from logging.handlers import RotatingFileHandler from pathlib import Path from datetime import timedelta from typing import Any, Dict, Optional from flask import Flask, g, has_request_context, redirect, render_template, request, url_for from flask_cors import CORS from flask_wtf.csrf import CSRFError from werkzeug.middleware.proxy_fix import ProxyFix from .bucket_policies import BucketPolicyStore from .config import AppConfig from .connections import ConnectionStore from .encryption import EncryptionManager from .extensions import limiter, csrf from .iam import IamService from .kms import KMSManager from .replication import ReplicationManager from .secret_store import EphemeralSecretStore from .storage import ObjectStorage from .version import get_version def create_app( test_config: Optional[Dict[str, Any]] = None, *, include_api: bool = True, include_ui: bool = True, ) -> Flask: """Create and configure the Flask application.""" config = AppConfig.from_env(test_config) if getattr(sys, "frozen", False): project_root = Path(sys._MEIPASS) else: project_root = Path(__file__).resolve().parent.parent app = Flask( __name__, static_folder=str(project_root / "static"), template_folder=str(project_root / "templates"), ) app.config.update(config.to_flask_config()) if test_config: app.config.update(test_config) app.config.setdefault("APP_VERSION", get_version()) app.permanent_session_lifetime = timedelta(days=int(app.config.get("SESSION_LIFETIME_DAYS", 30))) if app.config.get("TESTING"): app.config.setdefault("WTF_CSRF_ENABLED", False) # Trust X-Forwarded-* headers from proxies app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) _configure_cors(app) _configure_logging(app) limiter.init_app(app) csrf.init_app(app) storage = ObjectStorage(Path(app.config["STORAGE_ROOT"])) iam = IamService( Path(app.config["IAM_CONFIG"]), auth_max_attempts=app.config.get("AUTH_MAX_ATTEMPTS", 5), auth_lockout_minutes=app.config.get("AUTH_LOCKOUT_MINUTES", 15), ) bucket_policies = BucketPolicyStore(Path(app.config["BUCKET_POLICY_PATH"])) secret_store = EphemeralSecretStore(default_ttl=app.config.get("SECRET_TTL_SECONDS", 300)) # Initialize Replication components connections_path = Path(app.config["STORAGE_ROOT"]) / ".connections.json" replication_rules_path = Path(app.config["STORAGE_ROOT"]) / ".replication_rules.json" connections = ConnectionStore(connections_path) replication = ReplicationManager(storage, connections, replication_rules_path) # Initialize encryption and KMS encryption_config = { "encryption_enabled": app.config.get("ENCRYPTION_ENABLED", False), "encryption_master_key_path": app.config.get("ENCRYPTION_MASTER_KEY_PATH"), "default_encryption_algorithm": app.config.get("DEFAULT_ENCRYPTION_ALGORITHM", "AES256"), } encryption_manager = EncryptionManager(encryption_config) kms_manager = None if app.config.get("KMS_ENABLED", False): kms_keys_path = Path(app.config.get("KMS_KEYS_PATH", "")) kms_master_key_path = Path(app.config.get("ENCRYPTION_MASTER_KEY_PATH", "")) kms_manager = KMSManager(kms_keys_path, kms_master_key_path) encryption_manager.set_kms_provider(kms_manager) app.extensions["object_storage"] = storage app.extensions["iam"] = iam app.extensions["bucket_policies"] = bucket_policies app.extensions["secret_store"] = secret_store app.extensions["limiter"] = limiter app.extensions["connections"] = connections app.extensions["replication"] = replication app.extensions["encryption"] = encryption_manager app.extensions["kms"] = kms_manager @app.errorhandler(500) def internal_error(error): return render_template('500.html'), 500 @app.errorhandler(CSRFError) def handle_csrf_error(e): return render_template('csrf_error.html', reason=e.description), 400 @app.template_filter("filesizeformat") def filesizeformat(value: int) -> str: """Format bytes as human-readable file size.""" for unit in ["B", "KB", "MB", "GB", "TB", "PB"]: if abs(value) < 1024.0 or unit == "PB": if unit == "B": return f"{int(value)} {unit}" return f"{value:.1f} {unit}" value /= 1024.0 return f"{value:.1f} PB" @app.template_filter("timestamp_to_datetime") def timestamp_to_datetime(value: float) -> str: """Format Unix timestamp as human-readable datetime.""" from datetime import datetime if not value: return "Never" try: dt = datetime.fromtimestamp(value) return dt.strftime("%Y-%m-%d %H:%M:%S") except (ValueError, OSError): return "Unknown" if include_api: from .s3_api import s3_api_bp from .kms_api import kms_api_bp app.register_blueprint(s3_api_bp) app.register_blueprint(kms_api_bp) csrf.exempt(s3_api_bp) csrf.exempt(kms_api_bp) if include_ui: from .ui import ui_bp app.register_blueprint(ui_bp) if not include_api: @app.get("/") def ui_root_redirect(): return redirect(url_for("ui.buckets_overview")) @app.errorhandler(404) def handle_not_found(error): wants_html = request.accept_mimetypes.accept_html path = request.path or "" if include_ui and wants_html: if not include_api or path.startswith("/ui") or path == "/": return render_template("404.html"), 404 return error @app.get("/healthz") def healthcheck() -> Dict[str, str]: return {"status": "ok", "version": app.config.get("APP_VERSION", "unknown")} return app def create_api_app(test_config: Optional[Dict[str, Any]] = None) -> Flask: return create_app(test_config, include_api=True, include_ui=False) def create_ui_app(test_config: Optional[Dict[str, Any]] = None) -> Flask: return create_app(test_config, include_api=False, include_ui=True) def _configure_cors(app: Flask) -> None: origins = app.config.get("CORS_ORIGINS", ["*"]) methods = app.config.get("CORS_METHODS", ["GET", "PUT", "POST", "DELETE", "OPTIONS"]) allow_headers = app.config.get( "CORS_ALLOW_HEADERS", ["Content-Type", "X-Access-Key", "X-Secret-Key", "X-Amz-Date", "X-Amz-SignedHeaders"], ) CORS( app, resources={r"/*": {"origins": origins, "methods": methods, "allow_headers": allow_headers}}, supports_credentials=True, ) class _RequestContextFilter(logging.Filter): """Inject request-specific attributes into log records.""" def filter(self, record: logging.LogRecord) -> bool: # pragma: no cover - simple boilerplate if has_request_context(): record.request_id = getattr(g, "request_id", "-") record.path = request.path record.method = request.method record.remote_addr = request.remote_addr or "-" else: record.request_id = getattr(record, "request_id", "-") record.path = getattr(record, "path", "-") record.method = getattr(record, "method", "-") record.remote_addr = getattr(record, "remote_addr", "-") return True def _configure_logging(app: Flask) -> None: formatter = logging.Formatter( "%(asctime)s | %(levelname)s | %(request_id)s | %(method)s %(path)s | %(message)s" ) # Stream Handler (stdout) - Primary for Docker stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setFormatter(formatter) stream_handler.addFilter(_RequestContextFilter()) logger = app.logger logger.handlers.clear() logger.addHandler(stream_handler) # File Handler (optional, if configured) if app.config.get("LOG_TO_FILE"): log_file = Path(app.config["LOG_FILE"]) log_file.parent.mkdir(parents=True, exist_ok=True) file_handler = RotatingFileHandler( log_file, maxBytes=int(app.config.get("LOG_MAX_BYTES", 5 * 1024 * 1024)), backupCount=int(app.config.get("LOG_BACKUP_COUNT", 3)), encoding="utf-8", ) file_handler.setFormatter(formatter) file_handler.addFilter(_RequestContextFilter()) logger.addHandler(file_handler) logger.setLevel(getattr(logging, app.config.get("LOG_LEVEL", "INFO"), logging.INFO)) @app.before_request def _log_request_start() -> None: g.request_id = uuid.uuid4().hex g.request_started_at = time.perf_counter() app.logger.info( "Request started", extra={"path": request.path, "method": request.method, "remote_addr": request.remote_addr}, ) @app.after_request def _log_request_end(response): duration_ms = 0.0 if hasattr(g, "request_started_at"): duration_ms = (time.perf_counter() - g.request_started_at) * 1000 request_id = getattr(g, "request_id", uuid.uuid4().hex) response.headers.setdefault("X-Request-ID", request_id) app.logger.info( "Request completed", extra={ "path": request.path, "method": request.method, "remote_addr": request.remote_addr, }, ) response.headers["X-Request-Duration-ms"] = f"{duration_ms:.2f}" return response