diff --git a/README.md b/README.md index f1e7951..d485351 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,11 @@ python run.py --mode ui # UI only (port 5100) | `ENCRYPTION_ENABLED` | `false` | Enable server-side encryption | | `KMS_ENABLED` | `false` | Enable Key Management Service | | `LOG_LEVEL` | `INFO` | Logging verbosity | +| `SIGV4_TIMESTAMP_TOLERANCE_SECONDS` | `900` | Max time skew for SigV4 requests | +| `PRESIGNED_URL_MAX_EXPIRY_SECONDS` | `604800` | Max presigned URL expiry (7 days) | +| `REPLICATION_CONNECT_TIMEOUT_SECONDS` | `5` | Replication connection timeout | +| `SITE_SYNC_ENABLED` | `false` | Enable bi-directional site sync | +| `OBJECT_TAG_LIMIT` | `50` | Maximum tags per object | ## Data Layout diff --git a/app/__init__.py b/app/__init__.py index 2968c03..ef13ad4 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -31,6 +31,7 @@ from .notifications import NotificationService from .object_lock import ObjectLockService from .replication import ReplicationManager from .secret_store import EphemeralSecretStore +from .site_registry import SiteRegistry, SiteInfo from .storage import ObjectStorage from .version import get_version @@ -104,6 +105,9 @@ def create_app( storage = ObjectStorage( Path(app.config["STORAGE_ROOT"]), cache_ttl=app.config.get("OBJECT_CACHE_TTL", 5), + object_cache_max_size=app.config.get("OBJECT_CACHE_MAX_SIZE", 100), + bucket_config_cache_ttl=app.config.get("BUCKET_CONFIG_CACHE_TTL_SECONDS", 30.0), + object_key_max_length_bytes=app.config.get("OBJECT_KEY_MAX_LENGTH_BYTES", 1024), ) if app.config.get("WARM_CACHE_ON_STARTUP", True) and not app.config.get("TESTING"): @@ -137,12 +141,33 @@ def create_app( ) connections = ConnectionStore(connections_path) - replication = ReplicationManager(storage, connections, replication_rules_path, storage_root) - + replication = ReplicationManager( + storage, + connections, + replication_rules_path, + storage_root, + connect_timeout=app.config.get("REPLICATION_CONNECT_TIMEOUT_SECONDS", 5), + read_timeout=app.config.get("REPLICATION_READ_TIMEOUT_SECONDS", 30), + max_retries=app.config.get("REPLICATION_MAX_RETRIES", 2), + streaming_threshold_bytes=app.config.get("REPLICATION_STREAMING_THRESHOLD_BYTES", 10 * 1024 * 1024), + max_failures_per_bucket=app.config.get("REPLICATION_MAX_FAILURES_PER_BUCKET", 50), + ) + + site_registry_path = config_dir / "site_registry.json" + site_registry = SiteRegistry(site_registry_path) + if app.config.get("SITE_ID") and not site_registry.get_local_site(): + site_registry.set_local_site(SiteInfo( + site_id=app.config["SITE_ID"], + endpoint=app.config.get("SITE_ENDPOINT") or "", + region=app.config.get("SITE_REGION", "us-east-1"), + priority=app.config.get("SITE_PRIORITY", 100), + )) + encryption_config = { "encryption_enabled": app.config.get("ENCRYPTION_ENABLED", False), "encryption_master_key_path": app.config.get("ENCRYPTION_MASTER_KEY_PATH"), "default_encryption_algorithm": app.config.get("DEFAULT_ENCRYPTION_ALGORITHM", "AES256"), + "encryption_chunk_size_bytes": app.config.get("ENCRYPTION_CHUNK_SIZE_BYTES", 64 * 1024), } encryption_manager = EncryptionManager(encryption_config) @@ -150,7 +175,12 @@ def create_app( if app.config.get("KMS_ENABLED", False): kms_keys_path = Path(app.config.get("KMS_KEYS_PATH", "")) kms_master_key_path = Path(app.config.get("ENCRYPTION_MASTER_KEY_PATH", "")) - kms_manager = KMSManager(kms_keys_path, kms_master_key_path) + kms_manager = KMSManager( + kms_keys_path, + kms_master_key_path, + generate_data_key_min_bytes=app.config.get("KMS_GENERATE_DATA_KEY_MIN_BYTES", 1), + generate_data_key_max_bytes=app.config.get("KMS_GENERATE_DATA_KEY_MAX_BYTES", 1024), + ) encryption_manager.set_kms_provider(kms_manager) if app.config.get("ENCRYPTION_ENABLED", False): @@ -159,7 +189,10 @@ def create_app( acl_service = AclService(storage_root) object_lock_service = ObjectLockService(storage_root) - notification_service = NotificationService(storage_root) + notification_service = NotificationService( + storage_root, + allow_internal_endpoints=app.config.get("ALLOW_INTERNAL_ENDPOINTS", False), + ) access_logging_service = AccessLoggingService(storage_root) access_logging_service.set_storage(storage) @@ -170,6 +203,7 @@ def create_app( base_storage, interval_seconds=app.config.get("LIFECYCLE_INTERVAL_SECONDS", 3600), storage_root=storage_root, + max_history_per_bucket=app.config.get("LIFECYCLE_MAX_HISTORY_PER_BUCKET", 50), ) lifecycle_manager.start() @@ -187,6 +221,7 @@ def create_app( app.extensions["object_lock"] = object_lock_service app.extensions["notifications"] = notification_service app.extensions["access_logging"] = access_logging_service + app.extensions["site_registry"] = site_registry operation_metrics_collector = None if app.config.get("OPERATION_METRICS_ENABLED", False): @@ -218,6 +253,10 @@ def create_app( storage_root=storage_root, interval_seconds=app.config.get("SITE_SYNC_INTERVAL_SECONDS", 60), batch_size=app.config.get("SITE_SYNC_BATCH_SIZE", 100), + connect_timeout=app.config.get("SITE_SYNC_CONNECT_TIMEOUT_SECONDS", 10), + read_timeout=app.config.get("SITE_SYNC_READ_TIMEOUT_SECONDS", 120), + max_retries=app.config.get("SITE_SYNC_MAX_RETRIES", 2), + clock_skew_tolerance_seconds=app.config.get("SITE_SYNC_CLOCK_SKEW_TOLERANCE_SECONDS", 1.0), ) site_sync_worker.start() app.extensions["site_sync"] = site_sync_worker @@ -289,11 +328,14 @@ def create_app( if include_api: from .s3_api import s3_api_bp from .kms_api import kms_api_bp + from .admin_api import admin_api_bp app.register_blueprint(s3_api_bp) app.register_blueprint(kms_api_bp) + app.register_blueprint(admin_api_bp) csrf.exempt(s3_api_bp) csrf.exempt(kms_api_bp) + csrf.exempt(admin_api_bp) if include_ui: from .ui import ui_bp diff --git a/app/admin_api.py b/app/admin_api.py new file mode 100644 index 0000000..8ebc76f --- /dev/null +++ b/app/admin_api.py @@ -0,0 +1,670 @@ +from __future__ import annotations + +import ipaddress +import logging +import re +import socket +import time +from typing import Any, Dict, Optional, Tuple +from urllib.parse import urlparse + +import requests +from flask import Blueprint, Response, current_app, jsonify, request + +from .connections import ConnectionStore +from .extensions import limiter +from .iam import IamError, Principal +from .replication import ReplicationManager +from .site_registry import PeerSite, SiteInfo, SiteRegistry + + +def _is_safe_url(url: str, allow_internal: bool = False) -> bool: + """Check if a URL is safe to make requests to (not internal/private). + + Args: + url: The URL to check. + allow_internal: If True, allows internal/private IP addresses. + Use for self-hosted deployments on internal networks. + """ + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return False + cloud_metadata_hosts = { + "metadata.google.internal", + "169.254.169.254", + } + if hostname.lower() in cloud_metadata_hosts: + return False + if allow_internal: + return True + blocked_hosts = { + "localhost", + "127.0.0.1", + "0.0.0.0", + "::1", + "[::1]", + } + if hostname.lower() in blocked_hosts: + return False + try: + resolved_ip = socket.gethostbyname(hostname) + ip = ipaddress.ip_address(resolved_ip) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return False + except (socket.gaierror, ValueError): + return False + return True + except Exception: + return False + + +def _validate_endpoint(endpoint: str) -> Optional[str]: + """Validate endpoint URL format. Returns error message or None.""" + try: + parsed = urlparse(endpoint) + if not parsed.scheme or parsed.scheme not in ("http", "https"): + return "Endpoint must be http or https URL" + if not parsed.netloc: + return "Endpoint must have a host" + return None + except Exception: + return "Invalid endpoint URL" + + +def _validate_priority(priority: Any) -> Optional[str]: + """Validate priority value. Returns error message or None.""" + try: + p = int(priority) + if p < 0 or p > 1000: + return "Priority must be between 0 and 1000" + return None + except (TypeError, ValueError): + return "Priority must be an integer" + + +def _validate_region(region: str) -> Optional[str]: + """Validate region format. Returns error message or None.""" + if not re.match(r"^[a-z]{2,}-[a-z]+-\d+$", region): + return "Region must match format like us-east-1" + return None + + +def _validate_site_id(site_id: str) -> Optional[str]: + """Validate site_id format. Returns error message or None.""" + if not site_id or len(site_id) > 63: + return "site_id must be 1-63 characters" + if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9_-]*$', site_id): + return "site_id must start with alphanumeric and contain only alphanumeric, hyphens, underscores" + return None + + +logger = logging.getLogger(__name__) + +admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin") + + +def _require_principal() -> Tuple[Optional[Principal], Optional[Tuple[Dict[str, Any], int]]]: + from .s3_api import _require_principal as s3_require_principal + return s3_require_principal() + + +def _require_admin() -> Tuple[Optional[Principal], Optional[Tuple[Dict[str, Any], int]]]: + principal, error = _require_principal() + if error: + return None, error + + try: + _iam().authorize(principal, None, "iam:*") + return principal, None + except IamError: + return None, _json_error("AccessDenied", "Admin access required", 403) + + +def _site_registry() -> SiteRegistry: + return current_app.extensions["site_registry"] + + +def _connections() -> ConnectionStore: + return current_app.extensions["connections"] + + +def _replication() -> ReplicationManager: + return current_app.extensions["replication"] + + +def _iam(): + return current_app.extensions["iam"] + + +def _json_error(code: str, message: str, status: int) -> Tuple[Dict[str, Any], int]: + return {"error": {"code": code, "message": message}}, status + + +def _get_admin_rate_limit() -> str: + return current_app.config.get("RATE_LIMIT_ADMIN", "60 per minute") + + +@admin_api_bp.route("/site", methods=["GET"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def get_local_site(): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + local_site = registry.get_local_site() + + if local_site: + return jsonify(local_site.to_dict()) + + config_site_id = current_app.config.get("SITE_ID") + config_endpoint = current_app.config.get("SITE_ENDPOINT") + + if config_site_id: + return jsonify({ + "site_id": config_site_id, + "endpoint": config_endpoint or "", + "region": current_app.config.get("SITE_REGION", "us-east-1"), + "priority": current_app.config.get("SITE_PRIORITY", 100), + "display_name": config_site_id, + "source": "environment", + }) + + return _json_error("NotFound", "Local site not configured", 404) + + +@admin_api_bp.route("/site", methods=["PUT"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def update_local_site(): + principal, error = _require_admin() + if error: + return error + + payload = request.get_json(silent=True) or {} + + site_id = payload.get("site_id") + endpoint = payload.get("endpoint") + + if not site_id: + return _json_error("ValidationError", "site_id is required", 400) + + site_id_error = _validate_site_id(site_id) + if site_id_error: + return _json_error("ValidationError", site_id_error, 400) + + if endpoint: + endpoint_error = _validate_endpoint(endpoint) + if endpoint_error: + return _json_error("ValidationError", endpoint_error, 400) + + if "priority" in payload: + priority_error = _validate_priority(payload["priority"]) + if priority_error: + return _json_error("ValidationError", priority_error, 400) + + if "region" in payload: + region_error = _validate_region(payload["region"]) + if region_error: + return _json_error("ValidationError", region_error, 400) + + registry = _site_registry() + existing = registry.get_local_site() + + site = SiteInfo( + site_id=site_id, + endpoint=endpoint or "", + region=payload.get("region", "us-east-1"), + priority=payload.get("priority", 100), + display_name=payload.get("display_name", site_id), + created_at=existing.created_at if existing else None, + ) + + registry.set_local_site(site) + + logger.info("Local site updated", extra={"site_id": site_id, "principal": principal.access_key}) + return jsonify(site.to_dict()) + + +@admin_api_bp.route("/sites", methods=["GET"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def list_all_sites(): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + local = registry.get_local_site() + peers = registry.list_peers() + + result = { + "local": local.to_dict() if local else None, + "peers": [peer.to_dict() for peer in peers], + "total_peers": len(peers), + } + + return jsonify(result) + + +@admin_api_bp.route("/sites", methods=["POST"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def register_peer_site(): + principal, error = _require_admin() + if error: + return error + + payload = request.get_json(silent=True) or {} + + site_id = payload.get("site_id") + endpoint = payload.get("endpoint") + + if not site_id: + return _json_error("ValidationError", "site_id is required", 400) + + site_id_error = _validate_site_id(site_id) + if site_id_error: + return _json_error("ValidationError", site_id_error, 400) + + if not endpoint: + return _json_error("ValidationError", "endpoint is required", 400) + + endpoint_error = _validate_endpoint(endpoint) + if endpoint_error: + return _json_error("ValidationError", endpoint_error, 400) + + region = payload.get("region", "us-east-1") + region_error = _validate_region(region) + if region_error: + return _json_error("ValidationError", region_error, 400) + + priority = payload.get("priority", 100) + priority_error = _validate_priority(priority) + if priority_error: + return _json_error("ValidationError", priority_error, 400) + + registry = _site_registry() + + if registry.get_peer(site_id): + return _json_error("AlreadyExists", f"Peer site '{site_id}' already exists", 409) + + connection_id = payload.get("connection_id") + if connection_id: + if not _connections().get(connection_id): + return _json_error("ValidationError", f"Connection '{connection_id}' not found", 400) + + peer = PeerSite( + site_id=site_id, + endpoint=endpoint, + region=region, + priority=int(priority), + display_name=payload.get("display_name", site_id), + connection_id=connection_id, + ) + + registry.add_peer(peer) + + logger.info("Peer site registered", extra={"site_id": site_id, "principal": principal.access_key}) + return jsonify(peer.to_dict()), 201 + + +@admin_api_bp.route("/sites/", methods=["GET"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def get_peer_site(site_id: str): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + peer = registry.get_peer(site_id) + + if not peer: + return _json_error("NotFound", f"Peer site '{site_id}' not found", 404) + + return jsonify(peer.to_dict()) + + +@admin_api_bp.route("/sites/", methods=["PUT"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def update_peer_site(site_id: str): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + existing = registry.get_peer(site_id) + + if not existing: + return _json_error("NotFound", f"Peer site '{site_id}' not found", 404) + + payload = request.get_json(silent=True) or {} + + if "endpoint" in payload: + endpoint_error = _validate_endpoint(payload["endpoint"]) + if endpoint_error: + return _json_error("ValidationError", endpoint_error, 400) + + if "priority" in payload: + priority_error = _validate_priority(payload["priority"]) + if priority_error: + return _json_error("ValidationError", priority_error, 400) + + if "region" in payload: + region_error = _validate_region(payload["region"]) + if region_error: + return _json_error("ValidationError", region_error, 400) + + peer = PeerSite( + site_id=site_id, + endpoint=payload.get("endpoint", existing.endpoint), + region=payload.get("region", existing.region), + priority=payload.get("priority", existing.priority), + display_name=payload.get("display_name", existing.display_name), + connection_id=payload.get("connection_id", existing.connection_id), + created_at=existing.created_at, + is_healthy=existing.is_healthy, + last_health_check=existing.last_health_check, + ) + + registry.update_peer(peer) + + logger.info("Peer site updated", extra={"site_id": site_id, "principal": principal.access_key}) + return jsonify(peer.to_dict()) + + +@admin_api_bp.route("/sites/", methods=["DELETE"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def delete_peer_site(site_id: str): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + + if not registry.delete_peer(site_id): + return _json_error("NotFound", f"Peer site '{site_id}' not found", 404) + + logger.info("Peer site deleted", extra={"site_id": site_id, "principal": principal.access_key}) + return Response(status=204) + + +@admin_api_bp.route("/sites//health", methods=["GET"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def check_peer_health(site_id: str): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + peer = registry.get_peer(site_id) + + if not peer: + return _json_error("NotFound", f"Peer site '{site_id}' not found", 404) + + is_healthy = False + error_message = None + + if peer.connection_id: + connection = _connections().get(peer.connection_id) + if connection: + is_healthy = _replication().check_endpoint_health(connection) + else: + error_message = f"Connection '{peer.connection_id}' not found" + else: + error_message = "No connection configured for this peer" + + registry.update_health(site_id, is_healthy) + + result = { + "site_id": site_id, + "is_healthy": is_healthy, + "checked_at": time.time(), + } + if error_message: + result["error"] = error_message + + return jsonify(result) + + +@admin_api_bp.route("/topology", methods=["GET"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def get_topology(): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + local = registry.get_local_site() + peers = registry.list_peers() + + sites = [] + + if local: + sites.append({ + **local.to_dict(), + "is_local": True, + "is_healthy": True, + }) + + for peer in peers: + sites.append({ + **peer.to_dict(), + "is_local": False, + }) + + sites.sort(key=lambda s: s.get("priority", 100)) + + return jsonify({ + "sites": sites, + "total": len(sites), + "healthy_count": sum(1 for s in sites if s.get("is_healthy")), + }) + + +@admin_api_bp.route("/sites//bidirectional-status", methods=["GET"]) +@limiter.limit(lambda: _get_admin_rate_limit()) +def check_bidirectional_status(site_id: str): + principal, error = _require_admin() + if error: + return error + + registry = _site_registry() + peer = registry.get_peer(site_id) + + if not peer: + return _json_error("NotFound", f"Peer site '{site_id}' not found", 404) + + local_site = registry.get_local_site() + replication = _replication() + local_rules = replication.list_rules() + + local_bidir_rules = [] + for rule in local_rules: + if rule.target_connection_id == peer.connection_id and rule.mode == "bidirectional": + local_bidir_rules.append({ + "bucket_name": rule.bucket_name, + "target_bucket": rule.target_bucket, + "enabled": rule.enabled, + }) + + result = { + "site_id": site_id, + "local_site_id": local_site.site_id if local_site else None, + "local_endpoint": local_site.endpoint if local_site else None, + "local_bidirectional_rules": local_bidir_rules, + "local_site_sync_enabled": current_app.config.get("SITE_SYNC_ENABLED", False), + "remote_status": None, + "issues": [], + "is_fully_configured": False, + } + + if not local_site or not local_site.site_id: + result["issues"].append({ + "code": "NO_LOCAL_SITE_ID", + "message": "Local site identity not configured", + "severity": "error", + }) + + if not local_site or not local_site.endpoint: + result["issues"].append({ + "code": "NO_LOCAL_ENDPOINT", + "message": "Local site endpoint not configured (remote site cannot reach back)", + "severity": "error", + }) + + if not peer.connection_id: + result["issues"].append({ + "code": "NO_CONNECTION", + "message": "No connection configured for this peer", + "severity": "error", + }) + return jsonify(result) + + connection = _connections().get(peer.connection_id) + if not connection: + result["issues"].append({ + "code": "CONNECTION_NOT_FOUND", + "message": f"Connection '{peer.connection_id}' not found", + "severity": "error", + }) + return jsonify(result) + + if not local_bidir_rules: + result["issues"].append({ + "code": "NO_LOCAL_BIDIRECTIONAL_RULES", + "message": "No bidirectional replication rules configured on this site", + "severity": "warning", + }) + + if not result["local_site_sync_enabled"]: + result["issues"].append({ + "code": "SITE_SYNC_DISABLED", + "message": "Site sync worker is disabled (SITE_SYNC_ENABLED=false). Pull operations will not work.", + "severity": "warning", + }) + + if not replication.check_endpoint_health(connection): + result["issues"].append({ + "code": "REMOTE_UNREACHABLE", + "message": "Remote endpoint is not reachable", + "severity": "error", + }) + return jsonify(result) + + allow_internal = current_app.config.get("ALLOW_INTERNAL_ENDPOINTS", False) + if not _is_safe_url(peer.endpoint, allow_internal=allow_internal): + result["issues"].append({ + "code": "ENDPOINT_NOT_ALLOWED", + "message": "Peer endpoint points to cloud metadata service (SSRF protection)", + "severity": "error", + }) + return jsonify(result) + + try: + admin_url = peer.endpoint.rstrip("/") + "/admin/sites" + resp = requests.get( + admin_url, + timeout=10, + headers={ + "Accept": "application/json", + "X-Access-Key": connection.access_key, + "X-Secret-Key": connection.secret_key, + }, + ) + + if resp.status_code == 200: + try: + remote_data = resp.json() + if not isinstance(remote_data, dict): + raise ValueError("Expected JSON object") + remote_local = remote_data.get("local") + if remote_local is not None and not isinstance(remote_local, dict): + raise ValueError("Expected 'local' to be an object") + remote_peers = remote_data.get("peers", []) + if not isinstance(remote_peers, list): + raise ValueError("Expected 'peers' to be a list") + except (ValueError, json.JSONDecodeError) as e: + logger.warning("Invalid JSON from remote admin API: %s", e) + result["remote_status"] = {"reachable": True, "invalid_response": True} + result["issues"].append({ + "code": "REMOTE_INVALID_RESPONSE", + "message": "Remote admin API returned invalid JSON", + "severity": "warning", + }) + return jsonify(result) + + result["remote_status"] = { + "reachable": True, + "local_site": remote_local, + "site_sync_enabled": None, + "has_peer_for_us": False, + "peer_connection_configured": False, + "has_bidirectional_rules_for_us": False, + } + + for rp in remote_peers: + if not isinstance(rp, dict): + continue + if local_site and ( + rp.get("site_id") == local_site.site_id or + rp.get("endpoint") == local_site.endpoint + ): + result["remote_status"]["has_peer_for_us"] = True + result["remote_status"]["peer_connection_configured"] = bool(rp.get("connection_id")) + break + + if not result["remote_status"]["has_peer_for_us"]: + result["issues"].append({ + "code": "REMOTE_NO_PEER_FOR_US", + "message": "Remote site does not have this site registered as a peer", + "severity": "error", + }) + elif not result["remote_status"]["peer_connection_configured"]: + result["issues"].append({ + "code": "REMOTE_NO_CONNECTION_FOR_US", + "message": "Remote site has us as peer but no connection configured (cannot push back)", + "severity": "error", + }) + elif resp.status_code == 401 or resp.status_code == 403: + result["remote_status"] = { + "reachable": True, + "admin_access_denied": True, + } + result["issues"].append({ + "code": "REMOTE_ADMIN_ACCESS_DENIED", + "message": "Cannot verify remote configuration (admin access denied)", + "severity": "warning", + }) + else: + result["remote_status"] = { + "reachable": True, + "admin_api_error": resp.status_code, + } + result["issues"].append({ + "code": "REMOTE_ADMIN_API_ERROR", + "message": f"Remote admin API returned status {resp.status_code}", + "severity": "warning", + }) + except requests.RequestException as e: + logger.warning("Remote admin API unreachable: %s", e) + result["remote_status"] = { + "reachable": False, + "error": "Connection failed", + } + result["issues"].append({ + "code": "REMOTE_ADMIN_UNREACHABLE", + "message": "Could not reach remote admin API", + "severity": "warning", + }) + except Exception as e: + logger.warning("Error checking remote bidirectional status: %s", e, exc_info=True) + result["issues"].append({ + "code": "VERIFICATION_ERROR", + "message": "Internal error during verification", + "severity": "warning", + }) + + error_issues = [i for i in result["issues"] if i["severity"] == "error"] + result["is_fully_configured"] = len(error_issues) == 0 and len(local_bidir_rules) > 0 + + return jsonify(result) diff --git a/app/config.py b/app/config.py index b39000f..bc03850 100644 --- a/app/config.py +++ b/app/config.py @@ -10,6 +10,23 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Optional +import psutil + + +def _calculate_auto_threads() -> int: + cpu_count = psutil.cpu_count(logical=True) or 4 + return max(1, min(cpu_count * 2, 64)) + + +def _calculate_auto_connection_limit() -> int: + available_mb = psutil.virtual_memory().available / (1024 * 1024) + calculated = int(available_mb / 5) + return max(20, min(calculated, 1000)) + + +def _calculate_auto_backlog(connection_limit: int) -> int: + return max(64, min(connection_limit * 2, 4096)) + def _validate_rate_limit(value: str) -> str: pattern = r"^\d+\s+per\s+(second|minute|hour|day)$" @@ -63,6 +80,10 @@ class AppConfig: log_backup_count: int ratelimit_default: str ratelimit_storage_uri: str + ratelimit_list_buckets: str + ratelimit_bucket_ops: str + ratelimit_object_ops: str + ratelimit_head_ops: str cors_origins: list[str] cors_methods: list[str] cors_allow_headers: list[str] @@ -94,9 +115,40 @@ class AppConfig: server_connection_limit: int server_backlog: int server_channel_timeout: int + server_threads_auto: bool + server_connection_limit_auto: bool + server_backlog_auto: bool site_sync_enabled: bool site_sync_interval_seconds: int site_sync_batch_size: int + sigv4_timestamp_tolerance_seconds: int + presigned_url_min_expiry_seconds: int + presigned_url_max_expiry_seconds: int + replication_connect_timeout_seconds: int + replication_read_timeout_seconds: int + replication_max_retries: int + replication_streaming_threshold_bytes: int + replication_max_failures_per_bucket: int + site_sync_connect_timeout_seconds: int + site_sync_read_timeout_seconds: int + site_sync_max_retries: int + site_sync_clock_skew_tolerance_seconds: float + object_key_max_length_bytes: int + object_cache_max_size: int + bucket_config_cache_ttl_seconds: float + object_tag_limit: int + encryption_chunk_size_bytes: int + kms_generate_data_key_min_bytes: int + kms_generate_data_key_max_bytes: int + lifecycle_max_history_per_bucket: int + site_id: Optional[str] + site_endpoint: Optional[str] + site_region: str + site_priority: int + ratelimit_admin: str + num_trusted_proxies: int + allowed_redirect_hosts: list[str] + allow_internal_endpoints: bool @classmethod def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig": @@ -171,6 +223,10 @@ class AppConfig: log_backup_count = int(_get("LOG_BACKUP_COUNT", 3)) ratelimit_default = _validate_rate_limit(str(_get("RATE_LIMIT_DEFAULT", "200 per minute"))) ratelimit_storage_uri = str(_get("RATE_LIMIT_STORAGE_URI", "memory://")) + ratelimit_list_buckets = _validate_rate_limit(str(_get("RATE_LIMIT_LIST_BUCKETS", "60 per minute"))) + ratelimit_bucket_ops = _validate_rate_limit(str(_get("RATE_LIMIT_BUCKET_OPS", "120 per minute"))) + ratelimit_object_ops = _validate_rate_limit(str(_get("RATE_LIMIT_OBJECT_OPS", "240 per minute"))) + ratelimit_head_ops = _validate_rate_limit(str(_get("RATE_LIMIT_HEAD_OPS", "100 per minute"))) def _csv(value: str, default: list[str]) -> list[str]: if not value: @@ -200,14 +256,68 @@ class AppConfig: operation_metrics_interval_minutes = int(_get("OPERATION_METRICS_INTERVAL_MINUTES", 5)) operation_metrics_retention_hours = int(_get("OPERATION_METRICS_RETENTION_HOURS", 24)) - server_threads = int(_get("SERVER_THREADS", 4)) - server_connection_limit = int(_get("SERVER_CONNECTION_LIMIT", 100)) - server_backlog = int(_get("SERVER_BACKLOG", 1024)) + _raw_threads = int(_get("SERVER_THREADS", 0)) + if _raw_threads == 0: + server_threads = _calculate_auto_threads() + server_threads_auto = True + else: + server_threads = _raw_threads + server_threads_auto = False + + _raw_conn_limit = int(_get("SERVER_CONNECTION_LIMIT", 0)) + if _raw_conn_limit == 0: + server_connection_limit = _calculate_auto_connection_limit() + server_connection_limit_auto = True + else: + server_connection_limit = _raw_conn_limit + server_connection_limit_auto = False + + _raw_backlog = int(_get("SERVER_BACKLOG", 0)) + if _raw_backlog == 0: + server_backlog = _calculate_auto_backlog(server_connection_limit) + server_backlog_auto = True + else: + server_backlog = _raw_backlog + server_backlog_auto = False + server_channel_timeout = int(_get("SERVER_CHANNEL_TIMEOUT", 120)) site_sync_enabled = str(_get("SITE_SYNC_ENABLED", "0")).lower() in {"1", "true", "yes", "on"} site_sync_interval_seconds = int(_get("SITE_SYNC_INTERVAL_SECONDS", 60)) site_sync_batch_size = int(_get("SITE_SYNC_BATCH_SIZE", 100)) + sigv4_timestamp_tolerance_seconds = int(_get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900)) + presigned_url_min_expiry_seconds = int(_get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1)) + presigned_url_max_expiry_seconds = int(_get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800)) + replication_connect_timeout_seconds = int(_get("REPLICATION_CONNECT_TIMEOUT_SECONDS", 5)) + replication_read_timeout_seconds = int(_get("REPLICATION_READ_TIMEOUT_SECONDS", 30)) + replication_max_retries = int(_get("REPLICATION_MAX_RETRIES", 2)) + replication_streaming_threshold_bytes = int(_get("REPLICATION_STREAMING_THRESHOLD_BYTES", 10 * 1024 * 1024)) + replication_max_failures_per_bucket = int(_get("REPLICATION_MAX_FAILURES_PER_BUCKET", 50)) + site_sync_connect_timeout_seconds = int(_get("SITE_SYNC_CONNECT_TIMEOUT_SECONDS", 10)) + site_sync_read_timeout_seconds = int(_get("SITE_SYNC_READ_TIMEOUT_SECONDS", 120)) + site_sync_max_retries = int(_get("SITE_SYNC_MAX_RETRIES", 2)) + site_sync_clock_skew_tolerance_seconds = float(_get("SITE_SYNC_CLOCK_SKEW_TOLERANCE_SECONDS", 1.0)) + object_key_max_length_bytes = int(_get("OBJECT_KEY_MAX_LENGTH_BYTES", 1024)) + object_cache_max_size = int(_get("OBJECT_CACHE_MAX_SIZE", 100)) + bucket_config_cache_ttl_seconds = float(_get("BUCKET_CONFIG_CACHE_TTL_SECONDS", 30.0)) + object_tag_limit = int(_get("OBJECT_TAG_LIMIT", 50)) + encryption_chunk_size_bytes = int(_get("ENCRYPTION_CHUNK_SIZE_BYTES", 64 * 1024)) + kms_generate_data_key_min_bytes = int(_get("KMS_GENERATE_DATA_KEY_MIN_BYTES", 1)) + kms_generate_data_key_max_bytes = int(_get("KMS_GENERATE_DATA_KEY_MAX_BYTES", 1024)) + lifecycle_max_history_per_bucket = int(_get("LIFECYCLE_MAX_HISTORY_PER_BUCKET", 50)) + + site_id_raw = _get("SITE_ID", None) + site_id = str(site_id_raw).strip() if site_id_raw else None + site_endpoint_raw = _get("SITE_ENDPOINT", None) + site_endpoint = str(site_endpoint_raw).strip() if site_endpoint_raw else None + site_region = str(_get("SITE_REGION", "us-east-1")) + site_priority = int(_get("SITE_PRIORITY", 100)) + ratelimit_admin = _validate_rate_limit(str(_get("RATE_LIMIT_ADMIN", "60 per minute"))) + num_trusted_proxies = int(_get("NUM_TRUSTED_PROXIES", 0)) + allowed_redirect_hosts_raw = _get("ALLOWED_REDIRECT_HOSTS", "") + allowed_redirect_hosts = [h.strip() for h in str(allowed_redirect_hosts_raw).split(",") if h.strip()] + allow_internal_endpoints = str(_get("ALLOW_INTERNAL_ENDPOINTS", "0")).lower() in {"1", "true", "yes", "on"} + return cls(storage_root=storage_root, max_upload_size=max_upload_size, ui_page_size=ui_page_size, @@ -225,6 +335,10 @@ class AppConfig: log_backup_count=log_backup_count, ratelimit_default=ratelimit_default, ratelimit_storage_uri=ratelimit_storage_uri, + ratelimit_list_buckets=ratelimit_list_buckets, + ratelimit_bucket_ops=ratelimit_bucket_ops, + ratelimit_object_ops=ratelimit_object_ops, + ratelimit_head_ops=ratelimit_head_ops, cors_origins=cors_origins, cors_methods=cors_methods, cors_allow_headers=cors_allow_headers, @@ -256,9 +370,40 @@ class AppConfig: server_connection_limit=server_connection_limit, server_backlog=server_backlog, server_channel_timeout=server_channel_timeout, + server_threads_auto=server_threads_auto, + server_connection_limit_auto=server_connection_limit_auto, + server_backlog_auto=server_backlog_auto, site_sync_enabled=site_sync_enabled, site_sync_interval_seconds=site_sync_interval_seconds, - site_sync_batch_size=site_sync_batch_size) + site_sync_batch_size=site_sync_batch_size, + sigv4_timestamp_tolerance_seconds=sigv4_timestamp_tolerance_seconds, + presigned_url_min_expiry_seconds=presigned_url_min_expiry_seconds, + presigned_url_max_expiry_seconds=presigned_url_max_expiry_seconds, + replication_connect_timeout_seconds=replication_connect_timeout_seconds, + replication_read_timeout_seconds=replication_read_timeout_seconds, + replication_max_retries=replication_max_retries, + replication_streaming_threshold_bytes=replication_streaming_threshold_bytes, + replication_max_failures_per_bucket=replication_max_failures_per_bucket, + site_sync_connect_timeout_seconds=site_sync_connect_timeout_seconds, + site_sync_read_timeout_seconds=site_sync_read_timeout_seconds, + site_sync_max_retries=site_sync_max_retries, + site_sync_clock_skew_tolerance_seconds=site_sync_clock_skew_tolerance_seconds, + object_key_max_length_bytes=object_key_max_length_bytes, + object_cache_max_size=object_cache_max_size, + bucket_config_cache_ttl_seconds=bucket_config_cache_ttl_seconds, + object_tag_limit=object_tag_limit, + encryption_chunk_size_bytes=encryption_chunk_size_bytes, + kms_generate_data_key_min_bytes=kms_generate_data_key_min_bytes, + kms_generate_data_key_max_bytes=kms_generate_data_key_max_bytes, + lifecycle_max_history_per_bucket=lifecycle_max_history_per_bucket, + site_id=site_id, + site_endpoint=site_endpoint, + site_region=site_region, + site_priority=site_priority, + ratelimit_admin=ratelimit_admin, + num_trusted_proxies=num_trusted_proxies, + allowed_redirect_hosts=allowed_redirect_hosts, + allow_internal_endpoints=allow_internal_endpoints) def validate_and_report(self) -> list[str]: """Validate configuration and return a list of warnings/issues. @@ -364,9 +509,11 @@ class AppConfig: print(f" ENCRYPTION: Enabled (Master key: {self.encryption_master_key_path})") if self.kms_enabled: print(f" KMS: Enabled (Keys: {self.kms_keys_path})") - print(f" SERVER_THREADS: {self.server_threads}") - print(f" CONNECTION_LIMIT: {self.server_connection_limit}") - print(f" BACKLOG: {self.server_backlog}") + def _auto(flag: bool) -> str: + return " (auto)" if flag else "" + print(f" SERVER_THREADS: {self.server_threads}{_auto(self.server_threads_auto)}") + print(f" CONNECTION_LIMIT: {self.server_connection_limit}{_auto(self.server_connection_limit_auto)}") + print(f" BACKLOG: {self.server_backlog}{_auto(self.server_backlog_auto)}") print(f" CHANNEL_TIMEOUT: {self.server_channel_timeout}s") print("=" * 60) @@ -406,6 +553,10 @@ class AppConfig: "LOG_BACKUP_COUNT": self.log_backup_count, "RATELIMIT_DEFAULT": self.ratelimit_default, "RATELIMIT_STORAGE_URI": self.ratelimit_storage_uri, + "RATELIMIT_LIST_BUCKETS": self.ratelimit_list_buckets, + "RATELIMIT_BUCKET_OPS": self.ratelimit_bucket_ops, + "RATELIMIT_OBJECT_OPS": self.ratelimit_object_ops, + "RATELIMIT_HEAD_OPS": self.ratelimit_head_ops, "CORS_ORIGINS": self.cors_origins, "CORS_METHODS": self.cors_methods, "CORS_ALLOW_HEADERS": self.cors_allow_headers, @@ -432,4 +583,32 @@ class AppConfig: "SITE_SYNC_ENABLED": self.site_sync_enabled, "SITE_SYNC_INTERVAL_SECONDS": self.site_sync_interval_seconds, "SITE_SYNC_BATCH_SIZE": self.site_sync_batch_size, + "SIGV4_TIMESTAMP_TOLERANCE_SECONDS": self.sigv4_timestamp_tolerance_seconds, + "PRESIGNED_URL_MIN_EXPIRY_SECONDS": self.presigned_url_min_expiry_seconds, + "PRESIGNED_URL_MAX_EXPIRY_SECONDS": self.presigned_url_max_expiry_seconds, + "REPLICATION_CONNECT_TIMEOUT_SECONDS": self.replication_connect_timeout_seconds, + "REPLICATION_READ_TIMEOUT_SECONDS": self.replication_read_timeout_seconds, + "REPLICATION_MAX_RETRIES": self.replication_max_retries, + "REPLICATION_STREAMING_THRESHOLD_BYTES": self.replication_streaming_threshold_bytes, + "REPLICATION_MAX_FAILURES_PER_BUCKET": self.replication_max_failures_per_bucket, + "SITE_SYNC_CONNECT_TIMEOUT_SECONDS": self.site_sync_connect_timeout_seconds, + "SITE_SYNC_READ_TIMEOUT_SECONDS": self.site_sync_read_timeout_seconds, + "SITE_SYNC_MAX_RETRIES": self.site_sync_max_retries, + "SITE_SYNC_CLOCK_SKEW_TOLERANCE_SECONDS": self.site_sync_clock_skew_tolerance_seconds, + "OBJECT_KEY_MAX_LENGTH_BYTES": self.object_key_max_length_bytes, + "OBJECT_CACHE_MAX_SIZE": self.object_cache_max_size, + "BUCKET_CONFIG_CACHE_TTL_SECONDS": self.bucket_config_cache_ttl_seconds, + "OBJECT_TAG_LIMIT": self.object_tag_limit, + "ENCRYPTION_CHUNK_SIZE_BYTES": self.encryption_chunk_size_bytes, + "KMS_GENERATE_DATA_KEY_MIN_BYTES": self.kms_generate_data_key_min_bytes, + "KMS_GENERATE_DATA_KEY_MAX_BYTES": self.kms_generate_data_key_max_bytes, + "LIFECYCLE_MAX_HISTORY_PER_BUCKET": self.lifecycle_max_history_per_bucket, + "SITE_ID": self.site_id, + "SITE_ENDPOINT": self.site_endpoint, + "SITE_REGION": self.site_region, + "SITE_PRIORITY": self.site_priority, + "RATE_LIMIT_ADMIN": self.ratelimit_admin, + "NUM_TRUSTED_PROXIES": self.num_trusted_proxies, + "ALLOWED_REDIRECT_HOSTS": self.allowed_redirect_hosts, + "ALLOW_INTERNAL_ENDPOINTS": self.allow_internal_endpoints, } diff --git a/app/encryption.py b/app/encryption.py index 4b1d817..6d8c2b2 100644 --- a/app/encryption.py +++ b/app/encryption.py @@ -1,15 +1,44 @@ -"""Encryption providers for server-side and client-side encryption.""" from __future__ import annotations import base64 import io import json +import logging +import os import secrets +import subprocess +import sys from dataclasses import dataclass from pathlib import Path from typing import Any, BinaryIO, Dict, Generator, Optional from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives import hashes + +if sys.platform != "win32": + import fcntl + +logger = logging.getLogger(__name__) + + +def _set_secure_file_permissions(file_path: Path) -> None: + """Set restrictive file permissions (owner read/write only).""" + if sys.platform == "win32": + try: + username = os.environ.get("USERNAME", "") + if username: + subprocess.run( + ["icacls", str(file_path), "/inheritance:r", + "/grant:r", f"{username}:F"], + check=True, capture_output=True + ) + else: + logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path) + except (subprocess.SubprocessError, OSError) as exc: + logger.warning("Failed to set secure permissions on %s: %s", file_path, exc) + else: + os.chmod(file_path, 0o600) class EncryptionError(Exception): @@ -59,22 +88,34 @@ class EncryptionMetadata: class EncryptionProvider: """Base class for encryption providers.""" - + def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: raise NotImplementedError - + def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, key_id: str, context: Dict[str, str] | None = None) -> bytes: raise NotImplementedError - + def generate_data_key(self) -> tuple[bytes, bytes]: """Generate a data key and its encrypted form. - + Returns: Tuple of (plaintext_key, encrypted_key) """ raise NotImplementedError + def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes: + """Decrypt an encrypted data key. + + Args: + encrypted_data_key: The encrypted data key bytes + key_id: Optional key identifier (used by KMS providers) + + Returns: + The decrypted data key + """ + raise NotImplementedError + class LocalKeyEncryption(EncryptionProvider): """SSE-S3 style encryption using a local master key. @@ -99,28 +140,48 @@ class LocalKeyEncryption(EncryptionProvider): return self._master_key def _load_or_create_master_key(self) -> bytes: - """Load master key from file or generate a new one.""" - if self.master_key_path.exists(): - try: - return base64.b64decode(self.master_key_path.read_text().strip()) - except Exception as exc: - raise EncryptionError(f"Failed to load master key: {exc}") from exc - - key = secrets.token_bytes(32) + """Load master key from file or generate a new one (with file locking).""" + lock_path = self.master_key_path.with_suffix(".lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + try: - self.master_key_path.parent.mkdir(parents=True, exist_ok=True) - self.master_key_path.write_text(base64.b64encode(key).decode()) + with open(lock_path, "w") as lock_file: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + try: + if self.master_key_path.exists(): + try: + return base64.b64decode(self.master_key_path.read_text().strip()) + except Exception as exc: + raise EncryptionError(f"Failed to load master key: {exc}") from exc + key = secrets.token_bytes(32) + try: + self.master_key_path.write_text(base64.b64encode(key).decode()) + _set_secure_file_permissions(self.master_key_path) + except OSError as exc: + raise EncryptionError(f"Failed to save master key: {exc}") from exc + return key + finally: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) except OSError as exc: - raise EncryptionError(f"Failed to save master key: {exc}") from exc - return key + raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc + DATA_KEY_AAD = b'{"purpose":"data_key","version":1}' + def _encrypt_data_key(self, data_key: bytes) -> bytes: """Encrypt the data key with the master key.""" aesgcm = AESGCM(self.master_key) nonce = secrets.token_bytes(12) - encrypted = aesgcm.encrypt(nonce, data_key, None) + encrypted = aesgcm.encrypt(nonce, data_key, self.DATA_KEY_AAD) return nonce + encrypted - + def _decrypt_data_key(self, encrypted_data_key: bytes) -> bytes: """Decrypt the data key using the master key.""" if len(encrypted_data_key) < 12 + 32 + 16: # nonce + key + tag @@ -129,10 +190,17 @@ class LocalKeyEncryption(EncryptionProvider): nonce = encrypted_data_key[:12] ciphertext = encrypted_data_key[12:] try: - return aesgcm.decrypt(nonce, ciphertext, None) - except Exception as exc: - raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc - + return aesgcm.decrypt(nonce, ciphertext, self.DATA_KEY_AAD) + except Exception: + try: + return aesgcm.decrypt(nonce, ciphertext, None) + except Exception as exc: + raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc + + def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes: + """Decrypt an encrypted data key (key_id ignored for local encryption).""" + return self._decrypt_data_key(encrypted_data_key) + def generate_data_key(self) -> tuple[bytes, bytes]: """Generate a data key and its encrypted form.""" plaintext_key = secrets.token_bytes(32) @@ -142,11 +210,12 @@ class LocalKeyEncryption(EncryptionProvider): def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: """Encrypt data using envelope encryption.""" data_key, encrypted_data_key = self.generate_data_key() - + aesgcm = AESGCM(data_key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, None) - + aad = json.dumps(context, sort_keys=True).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) + return EncryptionResult( ciphertext=ciphertext, nonce=nonce, @@ -159,10 +228,11 @@ class LocalKeyEncryption(EncryptionProvider): """Decrypt data using envelope encryption.""" data_key = self._decrypt_data_key(encrypted_data_key) aesgcm = AESGCM(data_key) + aad = json.dumps(context, sort_keys=True).encode() if context else None try: - return aesgcm.decrypt(nonce, ciphertext, None) + return aesgcm.decrypt(nonce, ciphertext, aad) except Exception as exc: - raise EncryptionError(f"Failed to decrypt data: {exc}") from exc + raise EncryptionError("Failed to decrypt data") from exc class StreamingEncryptor: @@ -180,12 +250,14 @@ class StreamingEncryptor: self.chunk_size = chunk_size def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes: - """Derive a unique nonce for each chunk. - - Performance: Use direct byte manipulation instead of full int conversion. - """ - # Performance: Only modify last 4 bytes instead of full 12-byte conversion - return base_nonce[:8] + (chunk_index ^ int.from_bytes(base_nonce[8:], "big")).to_bytes(4, "big") + """Derive a unique nonce for each chunk using HKDF.""" + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=12, + salt=base_nonce, + info=chunk_index.to_bytes(4, "big"), + ) + return hkdf.derive(b"chunk_nonce") def encrypt_stream(self, stream: BinaryIO, context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]: @@ -234,10 +306,7 @@ class StreamingEncryptor: Performance: Writes chunks directly to output buffer instead of accumulating in list. """ - if isinstance(self.provider, LocalKeyEncryption): - data_key = self.provider._decrypt_data_key(metadata.encrypted_data_key) - else: - raise EncryptionError("Unsupported provider for streaming decryption") + data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id) aesgcm = AESGCM(data_key) base_nonce = metadata.nonce @@ -310,7 +379,8 @@ class EncryptionManager: def get_streaming_encryptor(self) -> StreamingEncryptor: if self._streaming_encryptor is None: - self._streaming_encryptor = StreamingEncryptor(self.get_local_provider()) + chunk_size = self.config.get("encryption_chunk_size_bytes", 64 * 1024) + self._streaming_encryptor = StreamingEncryptor(self.get_local_provider(), chunk_size=chunk_size) return self._streaming_encryptor def encrypt_object(self, data: bytes, algorithm: str = "AES256", @@ -403,7 +473,8 @@ class SSECEncryption(EncryptionProvider): def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: aesgcm = AESGCM(self.customer_key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, None) + aad = json.dumps(context, sort_keys=True).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) return EncryptionResult( ciphertext=ciphertext, @@ -415,10 +486,11 @@ class SSECEncryption(EncryptionProvider): def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, key_id: str, context: Dict[str, str] | None = None) -> bytes: aesgcm = AESGCM(self.customer_key) + aad = json.dumps(context, sort_keys=True).encode() if context else None try: - return aesgcm.decrypt(nonce, ciphertext, None) + return aesgcm.decrypt(nonce, ciphertext, aad) except Exception as exc: - raise EncryptionError(f"SSE-C decryption failed: {exc}") from exc + raise EncryptionError("SSE-C decryption failed") from exc def generate_data_key(self) -> tuple[bytes, bytes]: return self.customer_key, b"" @@ -472,34 +544,36 @@ class ClientEncryptionHelper: } @staticmethod - def encrypt_with_key(plaintext: bytes, key_b64: str) -> Dict[str, str]: + def encrypt_with_key(plaintext: bytes, key_b64: str, context: Dict[str, str] | None = None) -> Dict[str, str]: """Encrypt data with a client-provided key.""" key = base64.b64decode(key_b64) if len(key) != 32: raise EncryptionError("Key must be 256 bits (32 bytes)") - + aesgcm = AESGCM(key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, None) - + aad = json.dumps(context, sort_keys=True).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) + return { "ciphertext": base64.b64encode(ciphertext).decode(), "nonce": base64.b64encode(nonce).decode(), "algorithm": "AES-256-GCM", } - + @staticmethod - def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str) -> bytes: + def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str, context: Dict[str, str] | None = None) -> bytes: """Decrypt data with a client-provided key.""" key = base64.b64decode(key_b64) nonce = base64.b64decode(nonce_b64) ciphertext = base64.b64decode(ciphertext_b64) - + if len(key) != 32: raise EncryptionError("Key must be 256 bits (32 bytes)") - + aesgcm = AESGCM(key) + aad = json.dumps(context, sort_keys=True).encode() if context else None try: - return aesgcm.decrypt(nonce, ciphertext, None) + return aesgcm.decrypt(nonce, ciphertext, aad) except Exception as exc: - raise EncryptionError(f"Decryption failed: {exc}") from exc + raise EncryptionError("Decryption failed") from exc diff --git a/app/errors.py b/app/errors.py index 7e5d711..b2a4079 100644 --- a/app/errors.py +++ b/app/errors.py @@ -6,6 +6,7 @@ from typing import Optional, Dict, Any from xml.etree.ElementTree import Element, SubElement, tostring from flask import Response, jsonify, request, flash, redirect, url_for, g +from flask_limiter import RateLimitExceeded logger = logging.getLogger(__name__) @@ -172,10 +173,22 @@ def handle_app_error(error: AppError) -> Response: return error.to_xml_response() +def handle_rate_limit_exceeded(e: RateLimitExceeded) -> Response: + g.s3_error_code = "SlowDown" + error = Element("Error") + SubElement(error, "Code").text = "SlowDown" + SubElement(error, "Message").text = "Please reduce your request rate." + SubElement(error, "Resource").text = request.path + SubElement(error, "RequestId").text = getattr(g, "request_id", "") + xml_bytes = tostring(error, encoding="utf-8") + return Response(xml_bytes, status=429, mimetype="application/xml") + + def register_error_handlers(app): """Register error handlers with a Flask app.""" app.register_error_handler(AppError, handle_app_error) - + app.register_error_handler(RateLimitExceeded, handle_rate_limit_exceeded) + for error_class in [ BucketNotFoundError, BucketAlreadyExistsError, BucketNotEmptyError, ObjectNotFoundError, InvalidObjectKeyError, diff --git a/app/iam.py b/app/iam.py index 0e5e80f..caf6b07 100644 --- a/app/iam.py +++ b/app/iam.py @@ -1,9 +1,11 @@ from __future__ import annotations +import hashlib import hmac import json import math import secrets +import threading import time from collections import deque from dataclasses import dataclass @@ -118,12 +120,14 @@ class IamService: self._raw_config: Dict[str, Any] = {} self._failed_attempts: Dict[str, Deque[datetime]] = {} self._last_load_time = 0.0 - self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {} - self._cache_ttl = 60.0 + self._principal_cache: Dict[str, Tuple[Principal, float]] = {} + self._cache_ttl = 10.0 self._last_stat_check = 0.0 self._stat_check_interval = 1.0 self._sessions: Dict[str, Dict[str, Any]] = {} + self._session_lock = threading.Lock() self._load() + self._load_lockout_state() def _maybe_reload(self) -> None: """Reload configuration if the file has changed on disk.""" @@ -134,7 +138,7 @@ class IamService: try: if self.config_path.stat().st_mtime > self._last_load_time: self._load() - self._credential_cache.clear() + self._principal_cache.clear() except OSError: pass @@ -150,7 +154,8 @@ class IamService: f"Access temporarily locked. Try again in {seconds} seconds." ) record = self._users.get(access_key) - if not record or not hmac.compare_digest(record["secret_key"], secret_key): + stored_secret = record["secret_key"] if record else secrets.token_urlsafe(24) + if not record or not hmac.compare_digest(stored_secret, secret_key): self._record_failed_attempt(access_key) raise IamError("Invalid credentials") self._clear_failed_attempts(access_key) @@ -162,11 +167,46 @@ class IamService: attempts = self._failed_attempts.setdefault(access_key, deque()) self._prune_attempts(attempts) attempts.append(datetime.now(timezone.utc)) + self._save_lockout_state() def _clear_failed_attempts(self, access_key: str) -> None: if not access_key: return - self._failed_attempts.pop(access_key, None) + if self._failed_attempts.pop(access_key, None) is not None: + self._save_lockout_state() + + def _lockout_file(self) -> Path: + return self.config_path.parent / "lockout_state.json" + + def _load_lockout_state(self) -> None: + """Load lockout state from disk.""" + try: + if self._lockout_file().exists(): + data = json.loads(self._lockout_file().read_text(encoding="utf-8")) + cutoff = datetime.now(timezone.utc) - self.auth_lockout_window + for key, timestamps in data.get("failed_attempts", {}).items(): + valid = [] + for ts in timestamps: + try: + dt = datetime.fromisoformat(ts) + if dt > cutoff: + valid.append(dt) + except (ValueError, TypeError): + continue + if valid: + self._failed_attempts[key] = deque(valid) + except (OSError, json.JSONDecodeError): + pass + + def _save_lockout_state(self) -> None: + """Persist lockout state to disk.""" + data: Dict[str, Any] = {"failed_attempts": {}} + for key, attempts in self._failed_attempts.items(): + data["failed_attempts"][key] = [ts.isoformat() for ts in attempts] + try: + self._lockout_file().write_text(json.dumps(data), encoding="utf-8") + except OSError: + pass def _prune_attempts(self, attempts: Deque[datetime]) -> None: cutoff = datetime.now(timezone.utc) - self.auth_lockout_window @@ -209,16 +249,23 @@ class IamService: return token def validate_session_token(self, access_key: str, session_token: str) -> bool: - """Validate a session token for an access key.""" - session = self._sessions.get(session_token) - if not session: - return False - if session["access_key"] != access_key: - return False - if time.time() > session["expires_at"]: - del self._sessions[session_token] - return False - return True + """Validate a session token for an access key (thread-safe, constant-time).""" + dummy_key = secrets.token_urlsafe(16) + dummy_token = secrets.token_urlsafe(32) + with self._session_lock: + session = self._sessions.get(session_token) + if not session: + hmac.compare_digest(access_key, dummy_key) + hmac.compare_digest(session_token, dummy_token) + return False + key_match = hmac.compare_digest(session["access_key"], access_key) + if not key_match: + hmac.compare_digest(session_token, dummy_token) + return False + if time.time() > session["expires_at"]: + self._sessions.pop(session_token, None) + return False + return True def _cleanup_expired_sessions(self) -> None: """Remove expired session tokens.""" @@ -229,9 +276,9 @@ class IamService: def principal_for_key(self, access_key: str) -> Principal: now = time.time() - cached = self._credential_cache.get(access_key) + cached = self._principal_cache.get(access_key) if cached: - secret, principal, cached_time = cached + principal, cached_time = cached if now - cached_time < self._cache_ttl: return principal @@ -240,23 +287,14 @@ class IamService: if not record: raise IamError("Unknown access key") principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) + self._principal_cache[access_key] = (principal, now) return principal def secret_for_key(self, access_key: str) -> str: - now = time.time() - cached = self._credential_cache.get(access_key) - if cached: - secret, principal, cached_time = cached - if now - cached_time < self._cache_ttl: - return secret - self._maybe_reload() record = self._users.get(access_key) if not record: raise IamError("Unknown access key") - principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) return record["secret_key"] def authorize(self, principal: Principal, bucket_name: str | None, action: str) -> None: @@ -328,6 +366,7 @@ class IamService: new_secret = self._generate_secret_key() user["secret_key"] = new_secret self._save() + self._principal_cache.pop(access_key, None) self._load() return new_secret @@ -507,26 +546,17 @@ class IamService: raise IamError("User not found") def get_secret_key(self, access_key: str) -> str | None: - now = time.time() - cached = self._credential_cache.get(access_key) - if cached: - secret, principal, cached_time = cached - if now - cached_time < self._cache_ttl: - return secret - self._maybe_reload() record = self._users.get(access_key) if record: - principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) return record["secret_key"] return None def get_principal(self, access_key: str) -> Principal | None: now = time.time() - cached = self._credential_cache.get(access_key) + cached = self._principal_cache.get(access_key) if cached: - secret, principal, cached_time = cached + principal, cached_time = cached if now - cached_time < self._cache_ttl: return principal @@ -534,6 +564,6 @@ class IamService: record = self._users.get(access_key) if record: principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) + self._principal_cache[access_key] = (principal, now) return principal return None diff --git a/app/kms.py b/app/kms.py index 548e7ea..dbd07e0 100644 --- a/app/kms.py +++ b/app/kms.py @@ -2,7 +2,11 @@ from __future__ import annotations import base64 import json +import logging +import os import secrets +import subprocess +import sys import uuid from dataclasses import dataclass, field from datetime import datetime, timezone @@ -13,6 +17,30 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from .encryption import EncryptionError, EncryptionProvider, EncryptionResult +if sys.platform != "win32": + import fcntl + +logger = logging.getLogger(__name__) + + +def _set_secure_file_permissions(file_path: Path) -> None: + """Set restrictive file permissions (owner read/write only).""" + if sys.platform == "win32": + try: + username = os.environ.get("USERNAME", "") + if username: + subprocess.run( + ["icacls", str(file_path), "/inheritance:r", + "/grant:r", f"{username}:F"], + check=True, capture_output=True + ) + else: + logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path) + except (subprocess.SubprocessError, OSError) as exc: + logger.warning("Failed to set secure permissions on %s: %s", file_path, exc) + else: + os.chmod(file_path, 0o600) + @dataclass class KMSKey: @@ -74,11 +102,11 @@ class KMSEncryptionProvider(EncryptionProvider): def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: """Encrypt data using envelope encryption with KMS.""" data_key, encrypted_data_key = self.generate_data_key() - + aesgcm = AESGCM(data_key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, - json.dumps(context).encode() if context else None) + ciphertext = aesgcm.encrypt(nonce, plaintext, + json.dumps(context, sort_keys=True).encode() if context else None) return EncryptionResult( ciphertext=ciphertext, @@ -90,15 +118,26 @@ class KMSEncryptionProvider(EncryptionProvider): def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, key_id: str, context: Dict[str, str] | None = None) -> bytes: """Decrypt data using envelope encryption with KMS.""" - # Note: Data key is encrypted without context (AAD), so we decrypt without context data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None) - + if len(data_key) != 32: + raise EncryptionError("Invalid data key size") + aesgcm = AESGCM(data_key) try: return aesgcm.decrypt(nonce, ciphertext, - json.dumps(context).encode() if context else None) + json.dumps(context, sort_keys=True).encode() if context else None) except Exception as exc: - raise EncryptionError(f"Failed to decrypt data: {exc}") from exc + logger.debug("KMS decryption failed: %s", exc) + raise EncryptionError("Failed to decrypt data") from exc + + def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes: + """Decrypt an encrypted data key using KMS.""" + if key_id is None: + key_id = self.key_id + data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None) + if len(data_key) != 32: + raise EncryptionError("Invalid data key size") + return data_key class KMSManager: @@ -108,27 +147,50 @@ class KMSManager: Keys are stored encrypted on disk. """ - def __init__(self, keys_path: Path, master_key_path: Path): + def __init__( + self, + keys_path: Path, + master_key_path: Path, + generate_data_key_min_bytes: int = 1, + generate_data_key_max_bytes: int = 1024, + ): self.keys_path = keys_path self.master_key_path = master_key_path + self.generate_data_key_min_bytes = generate_data_key_min_bytes + self.generate_data_key_max_bytes = generate_data_key_max_bytes self._keys: Dict[str, KMSKey] = {} self._master_key: bytes | None = None self._loaded = False @property def master_key(self) -> bytes: - """Load or create the master key for encrypting KMS keys.""" + """Load or create the master key for encrypting KMS keys (with file locking).""" if self._master_key is None: - if self.master_key_path.exists(): - self._master_key = base64.b64decode( - self.master_key_path.read_text().strip() - ) - else: - self._master_key = secrets.token_bytes(32) - self.master_key_path.parent.mkdir(parents=True, exist_ok=True) - self.master_key_path.write_text( - base64.b64encode(self._master_key).decode() - ) + lock_path = self.master_key_path.with_suffix(".lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + with open(lock_path, "w") as lock_file: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + try: + if self.master_key_path.exists(): + self._master_key = base64.b64decode( + self.master_key_path.read_text().strip() + ) + else: + self._master_key = secrets.token_bytes(32) + self.master_key_path.write_text( + base64.b64encode(self._master_key).decode() + ) + _set_secure_file_permissions(self.master_key_path) + finally: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) return self._master_key def _load_keys(self) -> None: @@ -145,8 +207,10 @@ class KMSManager: encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"]) key.key_material = self._decrypt_key_material(encrypted) self._keys[key.key_id] = key - except Exception: - pass + except json.JSONDecodeError as exc: + logger.error("Failed to parse KMS keys file: %s", exc) + except (ValueError, KeyError) as exc: + logger.error("Invalid KMS key data: %s", exc) self._loaded = True @@ -158,12 +222,13 @@ class KMSManager: encrypted = self._encrypt_key_material(key.key_material) data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode() keys_data.append(data) - + self.keys_path.parent.mkdir(parents=True, exist_ok=True) self.keys_path.write_text( json.dumps({"keys": keys_data}, indent=2), encoding="utf-8" ) + _set_secure_file_permissions(self.keys_path) def _encrypt_key_material(self, key_material: bytes) -> bytes: """Encrypt key material with the master key.""" @@ -269,7 +334,7 @@ class KMSManager: aesgcm = AESGCM(key.key_material) nonce = secrets.token_bytes(12) - aad = json.dumps(context).encode() if context else None + aad = json.dumps(context, sort_keys=True).encode() if context else None ciphertext = aesgcm.encrypt(nonce, plaintext, aad) key_id_bytes = key_id.encode("utf-8") @@ -298,17 +363,24 @@ class KMSManager: encrypted = rest[12:] aesgcm = AESGCM(key.key_material) - aad = json.dumps(context).encode() if context else None + aad = json.dumps(context, sort_keys=True).encode() if context else None try: plaintext = aesgcm.decrypt(nonce, encrypted, aad) return plaintext, key_id except Exception as exc: - raise EncryptionError(f"Decryption failed: {exc}") from exc + logger.debug("KMS decrypt operation failed: %s", exc) + raise EncryptionError("Decryption failed") from exc def generate_data_key(self, key_id: str, - context: Dict[str, str] | None = None) -> tuple[bytes, bytes]: + context: Dict[str, str] | None = None, + key_spec: str = "AES_256") -> tuple[bytes, bytes]: """Generate a data key and return both plaintext and encrypted versions. - + + Args: + key_id: The KMS key ID to use for encryption + context: Optional encryption context + key_spec: Key specification - AES_128 or AES_256 (default) + Returns: Tuple of (plaintext_key, encrypted_key) """ @@ -318,11 +390,12 @@ class KMSManager: raise EncryptionError(f"Key not found: {key_id}") if not key.enabled: raise EncryptionError(f"Key is disabled: {key_id}") - - plaintext_key = secrets.token_bytes(32) + + key_bytes = 32 if key_spec == "AES_256" else 16 + plaintext_key = secrets.token_bytes(key_bytes) encrypted_key = self.encrypt(key_id, plaintext_key, context) - + return plaintext_key, encrypted_key def decrypt_data_key(self, key_id: str, encrypted_key: bytes, @@ -358,6 +431,8 @@ class KMSManager: def generate_random(self, num_bytes: int = 32) -> bytes: """Generate cryptographically secure random bytes.""" - if num_bytes < 1 or num_bytes > 1024: - raise EncryptionError("Number of bytes must be between 1 and 1024") + if num_bytes < self.generate_data_key_min_bytes or num_bytes > self.generate_data_key_max_bytes: + raise EncryptionError( + f"Number of bytes must be between {self.generate_data_key_min_bytes} and {self.generate_data_key_max_bytes}" + ) return secrets.token_bytes(num_bytes) diff --git a/app/lifecycle.py b/app/lifecycle.py index ed9eb2c..ea2c262 100644 --- a/app/lifecycle.py +++ b/app/lifecycle.py @@ -71,10 +71,9 @@ class LifecycleExecutionRecord: class LifecycleHistoryStore: - MAX_HISTORY_PER_BUCKET = 50 - - def __init__(self, storage_root: Path) -> None: + def __init__(self, storage_root: Path, max_history_per_bucket: int = 50) -> None: self.storage_root = storage_root + self.max_history_per_bucket = max_history_per_bucket self._lock = threading.Lock() def _get_history_path(self, bucket_name: str) -> Path: @@ -95,7 +94,7 @@ class LifecycleHistoryStore: def save_history(self, bucket_name: str, records: List[LifecycleExecutionRecord]) -> None: path = self._get_history_path(bucket_name) path.parent.mkdir(parents=True, exist_ok=True) - data = {"executions": [r.to_dict() for r in records[:self.MAX_HISTORY_PER_BUCKET]]} + data = {"executions": [r.to_dict() for r in records[:self.max_history_per_bucket]]} try: with open(path, "w") as f: json.dump(data, f, indent=2) @@ -114,14 +113,20 @@ class LifecycleHistoryStore: class LifecycleManager: - def __init__(self, storage: ObjectStorage, interval_seconds: int = 3600, storage_root: Optional[Path] = None): + def __init__( + self, + storage: ObjectStorage, + interval_seconds: int = 3600, + storage_root: Optional[Path] = None, + max_history_per_bucket: int = 50, + ): self.storage = storage self.interval_seconds = interval_seconds self.storage_root = storage_root self._timer: Optional[threading.Timer] = None self._shutdown = False self._lock = threading.Lock() - self.history_store = LifecycleHistoryStore(storage_root) if storage_root else None + self.history_store = LifecycleHistoryStore(storage_root, max_history_per_bucket) if storage_root else None def start(self) -> None: if self._timer is not None: diff --git a/app/notifications.py b/app/notifications.py index c449088..6951095 100644 --- a/app/notifications.py +++ b/app/notifications.py @@ -1,8 +1,10 @@ from __future__ import annotations +import ipaddress import json import logging import queue +import socket import threading import time import uuid @@ -14,6 +16,48 @@ from urllib.parse import urlparse import requests + +def _is_safe_url(url: str, allow_internal: bool = False) -> bool: + """Check if a URL is safe to make requests to (not internal/private). + + Args: + url: The URL to check. + allow_internal: If True, allows internal/private IP addresses. + Use for self-hosted deployments on internal networks. + """ + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return False + cloud_metadata_hosts = { + "metadata.google.internal", + "169.254.169.254", + } + if hostname.lower() in cloud_metadata_hosts: + return False + if allow_internal: + return True + blocked_hosts = { + "localhost", + "127.0.0.1", + "0.0.0.0", + "::1", + "[::1]", + } + if hostname.lower() in blocked_hosts: + return False + try: + resolved_ip = socket.gethostbyname(hostname) + ip = ipaddress.ip_address(resolved_ip) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return False + except (socket.gaierror, ValueError): + return False + return True + except Exception: + return False + logger = logging.getLogger(__name__) @@ -165,8 +209,9 @@ class NotificationConfiguration: class NotificationService: - def __init__(self, storage_root: Path, worker_count: int = 2): + def __init__(self, storage_root: Path, worker_count: int = 2, allow_internal_endpoints: bool = False): self.storage_root = storage_root + self._allow_internal_endpoints = allow_internal_endpoints self._configs: Dict[str, List[NotificationConfiguration]] = {} self._queue: queue.Queue[tuple[NotificationEvent, WebhookDestination]] = queue.Queue() self._workers: List[threading.Thread] = [] @@ -299,6 +344,8 @@ class NotificationService: self._queue.task_done() def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None: + if not _is_safe_url(destination.url, allow_internal=self._allow_internal_endpoints): + raise RuntimeError(f"Blocked request to cloud metadata service (SSRF protection): {destination.url}") payload = event.to_s3_event() headers = {"Content-Type": "application/json", **destination.headers} diff --git a/app/replication.py b/app/replication.py index 9cab869..8cde3e8 100644 --- a/app/replication.py +++ b/app/replication.py @@ -21,16 +21,20 @@ from .storage import ObjectStorage, StorageError logger = logging.getLogger(__name__) REPLICATION_USER_AGENT = "S3ReplicationAgent/1.0" -REPLICATION_CONNECT_TIMEOUT = 5 -REPLICATION_READ_TIMEOUT = 30 -STREAMING_THRESHOLD_BYTES = 10 * 1024 * 1024 REPLICATION_MODE_NEW_ONLY = "new_only" REPLICATION_MODE_ALL = "all" REPLICATION_MODE_BIDIRECTIONAL = "bidirectional" -def _create_s3_client(connection: RemoteConnection, *, health_check: bool = False) -> Any: +def _create_s3_client( + connection: RemoteConnection, + *, + health_check: bool = False, + connect_timeout: int = 5, + read_timeout: int = 30, + max_retries: int = 2, +) -> Any: """Create a boto3 S3 client for the given connection. Args: connection: Remote S3 connection configuration @@ -38,9 +42,9 @@ def _create_s3_client(connection: RemoteConnection, *, health_check: bool = Fals """ config = Config( user_agent_extra=REPLICATION_USER_AGENT, - connect_timeout=REPLICATION_CONNECT_TIMEOUT, - read_timeout=REPLICATION_READ_TIMEOUT, - retries={'max_attempts': 1 if health_check else 2}, + connect_timeout=connect_timeout, + read_timeout=read_timeout, + retries={'max_attempts': 1 if health_check else max_retries}, signature_version='s3v4', s3={'addressing_style': 'path'}, request_checksum_calculation='when_required', @@ -133,6 +137,7 @@ class ReplicationRule: stats: ReplicationStats = field(default_factory=ReplicationStats) sync_deletions: bool = True last_pull_at: Optional[float] = None + filter_prefix: Optional[str] = None def to_dict(self) -> dict: return { @@ -145,6 +150,7 @@ class ReplicationRule: "stats": self.stats.to_dict(), "sync_deletions": self.sync_deletions, "last_pull_at": self.last_pull_at, + "filter_prefix": self.filter_prefix, } @classmethod @@ -158,16 +164,17 @@ class ReplicationRule: data["sync_deletions"] = True if "last_pull_at" not in data: data["last_pull_at"] = None + if "filter_prefix" not in data: + data["filter_prefix"] = None rule = cls(**data) rule.stats = ReplicationStats.from_dict(stats_data) if stats_data else ReplicationStats() return rule class ReplicationFailureStore: - MAX_FAILURES_PER_BUCKET = 50 - - def __init__(self, storage_root: Path) -> None: + def __init__(self, storage_root: Path, max_failures_per_bucket: int = 50) -> None: self.storage_root = storage_root + self.max_failures_per_bucket = max_failures_per_bucket self._lock = threading.Lock() def _get_failures_path(self, bucket_name: str) -> Path: @@ -188,7 +195,7 @@ class ReplicationFailureStore: def save_failures(self, bucket_name: str, failures: List[ReplicationFailure]) -> None: path = self._get_failures_path(bucket_name) path.parent.mkdir(parents=True, exist_ok=True) - data = {"failures": [f.to_dict() for f in failures[:self.MAX_FAILURES_PER_BUCKET]]} + data = {"failures": [f.to_dict() for f in failures[:self.max_failures_per_bucket]]} try: with open(path, "w") as f: json.dump(data, f, indent=2) @@ -233,18 +240,43 @@ class ReplicationFailureStore: class ReplicationManager: - def __init__(self, storage: ObjectStorage, connections: ConnectionStore, rules_path: Path, storage_root: Path) -> None: + def __init__( + self, + storage: ObjectStorage, + connections: ConnectionStore, + rules_path: Path, + storage_root: Path, + connect_timeout: int = 5, + read_timeout: int = 30, + max_retries: int = 2, + streaming_threshold_bytes: int = 10 * 1024 * 1024, + max_failures_per_bucket: int = 50, + ) -> None: self.storage = storage self.connections = connections self.rules_path = rules_path self.storage_root = storage_root + self.connect_timeout = connect_timeout + self.read_timeout = read_timeout + self.max_retries = max_retries + self.streaming_threshold_bytes = streaming_threshold_bytes self._rules: Dict[str, ReplicationRule] = {} self._stats_lock = threading.Lock() self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="ReplicationWorker") self._shutdown = False - self.failure_store = ReplicationFailureStore(storage_root) + self.failure_store = ReplicationFailureStore(storage_root, max_failures_per_bucket) self.reload_rules() + def _create_client(self, connection: RemoteConnection, *, health_check: bool = False) -> Any: + """Create an S3 client with the manager's configured timeouts.""" + return _create_s3_client( + connection, + health_check=health_check, + connect_timeout=self.connect_timeout, + read_timeout=self.read_timeout, + max_retries=self.max_retries, + ) + def shutdown(self, wait: bool = True) -> None: """Shutdown the replication executor gracefully. @@ -280,7 +312,7 @@ class ReplicationManager: Uses short timeouts to prevent blocking. """ try: - s3 = _create_s3_client(connection, health_check=True) + s3 = self._create_client(connection, health_check=True) s3.list_buckets() return True except Exception as e: @@ -290,6 +322,9 @@ class ReplicationManager: def get_rule(self, bucket_name: str) -> Optional[ReplicationRule]: return self._rules.get(bucket_name) + def list_rules(self) -> List[ReplicationRule]: + return list(self._rules.values()) + def set_rule(self, rule: ReplicationRule) -> None: old_rule = self._rules.get(rule.bucket_name) was_all_mode = old_rule and old_rule.mode == REPLICATION_MODE_ALL if old_rule else False @@ -329,7 +364,7 @@ class ReplicationManager: source_objects = self.storage.list_objects_all(bucket_name) source_keys = {obj.key: obj.size for obj in source_objects} - s3 = _create_s3_client(connection) + s3 = self._create_client(connection) dest_keys = set() bytes_synced = 0 @@ -395,7 +430,7 @@ class ReplicationManager: raise ValueError(f"Connection {connection_id} not found") try: - s3 = _create_s3_client(connection) + s3 = self._create_client(connection) s3.create_bucket(Bucket=bucket_name) except ClientError as e: logger.error(f"Failed to create remote bucket {bucket_name}: {e}") @@ -438,7 +473,7 @@ class ReplicationManager: return try: - s3 = _create_s3_client(conn) + s3 = self._create_client(conn) if action == "delete": try: @@ -481,7 +516,7 @@ class ReplicationManager: if content_type: extra_args["ContentType"] = content_type - if file_size >= STREAMING_THRESHOLD_BYTES: + if file_size >= self.streaming_threshold_bytes: s3.upload_file( str(path), rule.target_bucket, diff --git a/app/s3_api.py b/app/s3_api.py index 1f49e15..822997b 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -1,4 +1,3 @@ -"""Flask blueprint exposing a subset of the S3 REST API.""" from __future__ import annotations import base64 @@ -32,6 +31,24 @@ logger = logging.getLogger(__name__) S3_NS = "http://s3.amazonaws.com/doc/2006-03-01/" +_HEADER_CONTROL_CHARS = re.compile(r'[\r\n\x00-\x1f\x7f]') + + +def _sanitize_header_value(value: str) -> str: + return _HEADER_CONTROL_CHARS.sub('', value) + + +MAX_XML_PAYLOAD_SIZE = 1048576 # 1 MB + + +def _parse_xml_with_limit(payload: bytes) -> Element: + """Parse XML payload with size limit to prevent DoS attacks.""" + max_size = current_app.config.get("MAX_XML_PAYLOAD_SIZE", MAX_XML_PAYLOAD_SIZE) + if len(payload) > max_size: + raise ParseError(f"XML payload exceeds maximum size of {max_size} bytes") + return fromstring(payload) + + s3_api_bp = Blueprint("s3_api", __name__) def _storage() -> ObjectStorage: @@ -60,10 +77,13 @@ def _build_policy_context() -> Dict[str, Any]: ctx: Dict[str, Any] = {} if request.headers.get("Referer"): ctx["aws:Referer"] = request.headers.get("Referer") - if request.access_route: - ctx["aws:SourceIp"] = request.access_route[0] + num_proxies = current_app.config.get("NUM_TRUSTED_PROXIES", 0) + if num_proxies > 0 and request.access_route and len(request.access_route) > num_proxies: + ctx["aws:SourceIp"] = request.access_route[-num_proxies] elif request.remote_addr: ctx["aws:SourceIp"] = request.remote_addr + elif request.access_route: + ctx["aws:SourceIp"] = request.access_route[0] ctx["aws:SecureTransport"] = str(request.is_secure).lower() if request.headers.get("User-Agent"): ctx["aws:UserAgent"] = request.headers.get("User-Agent") @@ -82,6 +102,22 @@ def _access_logging() -> AccessLoggingService: return current_app.extensions["access_logging"] +def _get_list_buckets_limit() -> str: + return current_app.config.get("RATELIMIT_LIST_BUCKETS", "60 per minute") + + +def _get_bucket_ops_limit() -> str: + return current_app.config.get("RATELIMIT_BUCKET_OPS", "120 per minute") + + +def _get_object_ops_limit() -> str: + return current_app.config.get("RATELIMIT_OBJECT_OPS", "240 per minute") + + +def _get_head_ops_limit() -> str: + return current_app.config.get("RATELIMIT_HEAD_OPS", "100 per minute") + + def _xml_response(element: Element, status: int = 200) -> Response: xml_bytes = tostring(element, encoding="utf-8") return Response(xml_bytes, status=status, mimetype="application/xml") @@ -107,30 +143,36 @@ def _require_xml_content_type() -> Response | None: def _parse_range_header(range_header: str, file_size: int) -> list[tuple[int, int]] | None: if not range_header.startswith("bytes="): return None + max_range_value = 2**63 - 1 ranges = [] range_spec = range_header[6:] for part in range_spec.split(","): part = part.strip() if not part: continue - if part.startswith("-"): - suffix_length = int(part[1:]) - if suffix_length <= 0: - return None - start = max(0, file_size - suffix_length) - end = file_size - 1 - elif part.endswith("-"): - start = int(part[:-1]) - if start >= file_size: - return None - end = file_size - 1 - else: - start_str, end_str = part.split("-", 1) - start = int(start_str) - end = int(end_str) - if start > end or start >= file_size: - return None - end = min(end, file_size - 1) + try: + if part.startswith("-"): + suffix_length = int(part[1:]) + if suffix_length <= 0 or suffix_length > max_range_value: + return None + start = max(0, file_size - suffix_length) + end = file_size - 1 + elif part.endswith("-"): + start = int(part[:-1]) + if start < 0 or start > max_range_value or start >= file_size: + return None + end = file_size - 1 + else: + start_str, end_str = part.split("-", 1) + start = int(start_str) + end = int(end_str) + if start < 0 or end < 0 or start > max_range_value or end > max_range_value: + return None + if start > end or start >= file_size: + return None + end = min(end, file_size - 1) + except (ValueError, OverflowError): + return None ranges.append((start, end)) return ranges if ranges else None @@ -177,7 +219,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None: access_key, date_stamp, region, service, signed_headers_str, signature = match.groups() secret_key = _iam().get_secret_key(access_key) if not secret_key: - raise IamError("Invalid access key") + raise IamError("SignatureDoesNotMatch") method = req.method canonical_uri = _get_canonical_uri(req) @@ -223,7 +265,8 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None: now = datetime.now(timezone.utc) time_diff = abs((now - request_time).total_seconds()) - if time_diff > 900: + tolerance = current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900) + if time_diff > tolerance: raise IamError("Request timestamp too old or too far in the future") required_headers = {'host', 'x-amz-date'} @@ -359,16 +402,18 @@ def _verify_sigv4(req: Any) -> Principal | None: def _require_principal(): - if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \ - (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"): + sigv4_attempted = ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \ + (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256") + if sigv4_attempted: try: principal = _verify_sigv4(request) if principal: return principal, None + return None, _error_response("AccessDenied", "Signature verification failed", 403) except IamError as exc: return None, _error_response("AccessDenied", str(exc), 403) except (ValueError, TypeError): - return None, _error_response("AccessDenied", "Signature verification failed", 403) + return None, _error_response("AccessDenied", "Signature verification failed", 403) access_key = request.headers.get("X-Access-Key") secret_key = request.headers.get("X-Secret-Key") @@ -485,8 +530,10 @@ def _validate_presigned_request(action: str, bucket_name: str, object_key: str) expiry = int(expires) except ValueError as exc: raise IamError("Invalid expiration") from exc - if expiry < 1 or expiry > 7 * 24 * 3600: - raise IamError("Expiration must be between 1 second and 7 days") + min_expiry = current_app.config.get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1) + max_expiry = current_app.config.get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800) + if expiry < min_expiry or expiry > max_expiry: + raise IamError(f"Expiration must be between {min_expiry} second(s) and {max_expiry} seconds") try: request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc) @@ -687,7 +734,7 @@ def _find_element_text(parent: Element, name: str, default: str = "") -> str: def _parse_tagging_document(payload: bytes) -> list[dict[str, str]]: try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError as exc: raise ValueError("Malformed XML") from exc if _strip_ns(root.tag) != "Tagging": @@ -788,7 +835,7 @@ def _validate_content_type(object_key: str, content_type: str | None) -> str | N def _parse_cors_document(payload: bytes) -> list[dict[str, Any]]: try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError as exc: raise ValueError("Malformed XML") from exc if _strip_ns(root.tag) != "CORSConfiguration": @@ -841,7 +888,7 @@ def _render_cors_document(rules: list[dict[str, Any]]) -> Element: def _parse_encryption_document(payload: bytes) -> dict[str, Any]: try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError as exc: raise ValueError("Malformed XML") from exc if _strip_ns(root.tag) != "ServerSideEncryptionConfiguration": @@ -943,6 +990,7 @@ def _maybe_handle_bucket_subresource(bucket_name: str) -> Response | None: "logging": _bucket_logging_handler, "uploads": _bucket_uploads_handler, "policy": _bucket_policy_handler, + "replication": _bucket_replication_handler, } requested = [key for key in handlers if key in request.args] if not requested: @@ -977,7 +1025,7 @@ def _bucket_versioning_handler(bucket_name: str) -> Response: if not payload.strip(): return _error_response("MalformedXML", "Request body is required", 400) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) if _strip_ns(root.tag) != "VersioningConfiguration": @@ -1039,8 +1087,9 @@ def _bucket_tagging_handler(bucket_name: str) -> Response: tags = _parse_tagging_document(payload) except ValueError as exc: return _error_response("MalformedXML", str(exc), 400) - if len(tags) > 50: - return _error_response("InvalidTag", "A maximum of 50 tags is supported", 400) + tag_limit = current_app.config.get("OBJECT_TAG_LIMIT", 50) + if len(tags) > tag_limit: + return _error_response("InvalidTag", f"A maximum of {tag_limit} tags is supported", 400) try: storage.set_bucket_tags(bucket_name, tags) except StorageError as exc: @@ -1111,6 +1160,33 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response: return Response(status=204) +def _validate_cors_origin(origin: str) -> bool: + """Validate a CORS origin pattern.""" + import re + origin = origin.strip() + if not origin: + return False + if origin == "*": + return True + if origin.startswith("*."): + domain = origin[2:] + if not domain or ".." in domain: + return False + return bool(re.match(r'^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?)*$', domain)) + if origin.startswith(("http://", "https://")): + try: + from urllib.parse import urlparse + parsed = urlparse(origin) + if not parsed.netloc: + return False + if parsed.path and parsed.path != "/": + return False + return True + except Exception: + return False + return False + + def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]: sanitized: list[dict[str, Any]] = [] for rule in rules: @@ -1120,6 +1196,13 @@ def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]: expose_headers = [header.strip() for header in rule.get("ExposeHeaders", []) if header and header.strip()] if not allowed_origins or not allowed_methods: raise ValueError("Each CORSRule must include AllowedOrigin and AllowedMethod entries") + for origin in allowed_origins: + if not _validate_cors_origin(origin): + raise ValueError(f"Invalid CORS origin: {origin}") + valid_methods = {"GET", "PUT", "POST", "DELETE", "HEAD"} + for method in allowed_methods: + if method not in valid_methods: + raise ValueError(f"Invalid CORS method: {method}") sanitized_rule: dict[str, Any] = { "AllowedOrigins": allowed_origins, "AllowedMethods": allowed_methods, @@ -1329,7 +1412,10 @@ def _bucket_list_versions_handler(bucket_name: str) -> Response: prefix = request.args.get("prefix", "") delimiter = request.args.get("delimiter", "") try: - max_keys = max(1, min(int(request.args.get("max-keys", 1000)), 1000)) + max_keys = int(request.args.get("max-keys", 1000)) + if max_keys < 1: + return _error_response("InvalidArgument", "max-keys must be a positive integer", 400) + max_keys = min(max_keys, 1000) except ValueError: return _error_response("InvalidArgument", "max-keys must be an integer", 400) key_marker = request.args.get("key-marker", "") @@ -1493,7 +1579,7 @@ def _render_lifecycle_config(config: list) -> Element: def _parse_lifecycle_config(payload: bytes) -> list: """Parse lifecycle configuration from XML.""" try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError as exc: raise ValueError(f"Unable to parse XML document: {exc}") from exc @@ -1679,7 +1765,7 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response: return _error_response("MalformedXML", "Request body is required", 400) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) @@ -1748,7 +1834,7 @@ def _bucket_notification_handler(bucket_name: str) -> Response: return Response(status=200) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) @@ -1830,7 +1916,7 @@ def _bucket_logging_handler(bucket_name: str) -> Response: return Response(status=200) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) @@ -1883,7 +1969,10 @@ def _bucket_uploads_handler(bucket_name: str) -> Response: prefix = request.args.get("prefix", "") delimiter = request.args.get("delimiter", "") try: - max_uploads = max(1, min(int(request.args.get("max-uploads", 1000)), 1000)) + max_uploads = int(request.args.get("max-uploads", 1000)) + if max_uploads < 1: + return _error_response("InvalidArgument", "max-uploads must be a positive integer", 400) + max_uploads = min(max_uploads, 1000) except ValueError: return _error_response("InvalidArgument", "max-uploads must be an integer", 400) @@ -1969,7 +2058,7 @@ def _object_retention_handler(bucket_name: str, object_key: str) -> Response: return _error_response("MalformedXML", "Request body is required", 400) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) @@ -2041,7 +2130,7 @@ def _object_legal_hold_handler(bucket_name: str, object_key: str) -> Response: return _error_response("MalformedXML", "Request body is required", 400) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) @@ -2074,7 +2163,7 @@ def _bulk_delete_handler(bucket_name: str) -> Response: if not payload.strip(): return _error_response("MalformedXML", "Request body must include a Delete specification", 400) try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) if _strip_ns(root.tag) != "Delete": @@ -2142,8 +2231,156 @@ def _bulk_delete_handler(bucket_name: str) -> Response: return _xml_response(result, status=200) +def _post_object(bucket_name: str) -> Response: + storage = _storage() + if not storage.bucket_exists(bucket_name): + return _error_response("NoSuchBucket", "Bucket does not exist", 404) + object_key = request.form.get("key") + policy_b64 = request.form.get("policy") + signature = request.form.get("x-amz-signature") + credential = request.form.get("x-amz-credential") + algorithm = request.form.get("x-amz-algorithm") + amz_date = request.form.get("x-amz-date") + if not all([object_key, policy_b64, signature, credential, algorithm, amz_date]): + return _error_response("InvalidArgument", "Missing required form fields", 400) + if algorithm != "AWS4-HMAC-SHA256": + return _error_response("InvalidArgument", "Unsupported signing algorithm", 400) + try: + policy_json = base64.b64decode(policy_b64).decode("utf-8") + policy = __import__("json").loads(policy_json) + except (ValueError, __import__("json").JSONDecodeError) as exc: + return _error_response("InvalidPolicyDocument", f"Invalid policy: {exc}", 400) + expiration = policy.get("expiration") + if expiration: + try: + exp_time = datetime.fromisoformat(expiration.replace("Z", "+00:00")) + if datetime.now(timezone.utc) > exp_time: + return _error_response("AccessDenied", "Policy expired", 403) + except ValueError: + return _error_response("InvalidPolicyDocument", "Invalid expiration format", 400) + conditions = policy.get("conditions", []) + validation_error = _validate_post_policy_conditions(bucket_name, object_key, conditions, request.form, request.content_length or 0) + if validation_error: + return _error_response("AccessDenied", validation_error, 403) + try: + parts = credential.split("/") + if len(parts) != 5: + raise ValueError("Invalid credential format") + access_key, date_stamp, region, service, _ = parts + except ValueError: + return _error_response("InvalidArgument", "Invalid credential format", 400) + secret_key = _iam().get_secret_key(access_key) + if not secret_key: + return _error_response("AccessDenied", "Invalid access key", 403) + signing_key = _derive_signing_key(secret_key, date_stamp, region, service) + expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest() + if not hmac.compare_digest(expected_signature, signature): + return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403) + principal = _iam().get_principal(access_key) + if not principal: + return _error_response("AccessDenied", "Invalid access key", 403) + if "${filename}" in object_key: + temp_key = object_key.replace("${filename}", request.files.get("file").filename if request.files.get("file") else "upload") + else: + temp_key = object_key + try: + _authorize_action(principal, bucket_name, "write", object_key=temp_key) + except IamError as exc: + return _error_response("AccessDenied", str(exc), 403) + file = request.files.get("file") + if not file: + return _error_response("InvalidArgument", "Missing file field", 400) + if "${filename}" in object_key: + object_key = object_key.replace("${filename}", file.filename or "upload") + metadata = {} + for field_name, value in request.form.items(): + if field_name.lower().startswith("x-amz-meta-"): + key = field_name[11:] + if key: + metadata[key] = value + try: + meta = storage.put_object(bucket_name, object_key, file.stream, metadata=metadata or None) + except QuotaExceededError as exc: + return _error_response("QuotaExceeded", str(exc), 403) + except StorageError as exc: + return _error_response("InvalidArgument", str(exc), 400) + current_app.logger.info("Object uploaded via POST", extra={"bucket": bucket_name, "key": object_key, "size": meta.size}) + success_action_status = request.form.get("success_action_status", "204") + success_action_redirect = request.form.get("success_action_redirect") + if success_action_redirect: + allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", []) + parsed = urlparse(success_action_redirect) + if parsed.scheme not in ("http", "https"): + return _error_response("InvalidArgument", "Redirect URL must use http or https", 400) + if allowed_hosts and parsed.netloc not in allowed_hosts: + return _error_response("InvalidArgument", "Redirect URL host not allowed", 400) + redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}" + return Response(status=303, headers={"Location": redirect_url}) + if success_action_status == "200": + root = Element("PostResponse") + SubElement(root, "Location").text = f"/{bucket_name}/{object_key}" + SubElement(root, "Bucket").text = bucket_name + SubElement(root, "Key").text = object_key + SubElement(root, "ETag").text = f'"{meta.etag}"' + return _xml_response(root, status=200) + if success_action_status == "201": + root = Element("PostResponse") + SubElement(root, "Location").text = f"/{bucket_name}/{object_key}" + SubElement(root, "Bucket").text = bucket_name + SubElement(root, "Key").text = object_key + SubElement(root, "ETag").text = f'"{meta.etag}"' + return _xml_response(root, status=201) + return Response(status=204) + + +def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditions: list, form_data, content_length: int) -> Optional[str]: + for condition in conditions: + if isinstance(condition, dict): + for key, expected_value in condition.items(): + if key == "bucket": + if bucket_name != expected_value: + return f"Bucket must be {expected_value}" + elif key == "key": + if object_key != expected_value: + return f"Key must be {expected_value}" + else: + actual_value = form_data.get(key, "") + if actual_value != expected_value: + return f"Field {key} must be {expected_value}" + elif isinstance(condition, list) and len(condition) >= 2: + operator = condition[0].lower() if isinstance(condition[0], str) else "" + if operator == "starts-with" and len(condition) == 3: + field = condition[1].lstrip("$") + prefix = condition[2] + if field == "key": + if not object_key.startswith(prefix): + return f"Key must start with {prefix}" + else: + actual_value = form_data.get(field, "") + if not actual_value.startswith(prefix): + return f"Field {field} must start with {prefix}" + elif operator == "eq" and len(condition) == 3: + field = condition[1].lstrip("$") + expected = condition[2] + if field == "key": + if object_key != expected: + return f"Key must equal {expected}" + else: + actual_value = form_data.get(field, "") + if actual_value != expected: + return f"Field {field} must equal {expected}" + elif operator == "content-length-range" and len(condition) == 3: + try: + min_size, max_size = int(condition[1]), int(condition[2]) + except (TypeError, ValueError): + return "Invalid content-length-range values" + if content_length < min_size or content_length > max_size: + return f"Content length must be between {min_size} and {max_size}" + return None + + @s3_api_bp.get("/") -@limiter.limit("60 per minute") +@limiter.limit(_get_list_buckets_limit) def list_buckets() -> Response: principal, error = _require_principal() if error: @@ -2171,7 +2408,7 @@ def list_buckets() -> Response: @s3_api_bp.route("/", methods=["PUT", "DELETE", "GET", "POST"], strict_slashes=False) -@limiter.limit("120 per minute") +@limiter.limit(_get_bucket_ops_limit) def bucket_handler(bucket_name: str) -> Response: storage = _storage() subresource_response = _maybe_handle_bucket_subresource(bucket_name) @@ -2179,9 +2416,12 @@ def bucket_handler(bucket_name: str) -> Response: return subresource_response if request.method == "POST": - if "delete" not in request.args: - return _method_not_allowed(["GET", "PUT", "DELETE"]) - return _bulk_delete_handler(bucket_name) + if "delete" in request.args: + return _bulk_delete_handler(bucket_name) + content_type = request.headers.get("Content-Type", "") + if "multipart/form-data" in content_type: + return _post_object(bucket_name) + return _method_not_allowed(["GET", "PUT", "DELETE"]) if request.method == "PUT": principal, error = _require_principal() @@ -2231,23 +2471,24 @@ def bucket_handler(bucket_name: str) -> Response: prefix = request.args.get("prefix", "") delimiter = request.args.get("delimiter", "") try: - max_keys = max(1, min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000)) + max_keys = int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])) + if max_keys < 1: + return _error_response("InvalidArgument", "max-keys must be a positive integer", 400) + max_keys = min(max_keys, 1000) except ValueError: return _error_response("InvalidArgument", "max-keys must be an integer", 400) - + marker = request.args.get("marker", "") # ListObjects v1 continuation_token = request.args.get("continuation-token", "") # ListObjectsV2 start_after = request.args.get("start-after", "") # ListObjectsV2 - # For ListObjectsV2, continuation-token takes precedence, then start-after - # For ListObjects v1, use marker effective_start = "" if list_type == "2": if continuation_token: try: effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8") except (ValueError, UnicodeDecodeError): - effective_start = continuation_token + return _error_response("InvalidArgument", "Invalid continuation token", 400) elif start_after: effective_start = start_after else: @@ -2363,7 +2604,7 @@ def bucket_handler(bucket_name: str) -> Response: @s3_api_bp.route("//", methods=["PUT", "GET", "DELETE", "HEAD", "POST"], strict_slashes=False) -@limiter.limit("240 per minute") +@limiter.limit(_get_object_ops_limit) def object_handler(bucket_name: str, object_key: str): storage = _storage() @@ -2381,6 +2622,8 @@ def object_handler(bucket_name: str, object_key: str): return _initiate_multipart_upload(bucket_name, object_key) if "uploadId" in request.args: return _complete_multipart_upload(bucket_name, object_key) + if "select" in request.args: + return _select_object_content(bucket_name, object_key) return _method_not_allowed(["GET", "PUT", "DELETE", "HEAD", "POST"]) if request.method == "PUT": @@ -2560,7 +2803,7 @@ def object_handler(bucket_name: str, object_key: str): for param, header in response_overrides.items(): value = request.args.get(param) if value: - response.headers[header] = value + response.headers[header] = _sanitize_header_value(value) action = "Object read" if request.method == "GET" else "Object head" current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes}) @@ -2680,8 +2923,122 @@ def _bucket_policy_handler(bucket_name: str) -> Response: return Response(status=204) +def _bucket_replication_handler(bucket_name: str) -> Response: + if request.method not in {"GET", "PUT", "DELETE"}: + return _method_not_allowed(["GET", "PUT", "DELETE"]) + principal, error = _require_principal() + if error: + return error + try: + _authorize_action(principal, bucket_name, "policy") + except IamError as exc: + return _error_response("AccessDenied", str(exc), 403) + storage = _storage() + if not storage.bucket_exists(bucket_name): + return _error_response("NoSuchBucket", "Bucket does not exist", 404) + replication = _replication_manager() + if request.method == "GET": + rule = replication.get_rule(bucket_name) + if not rule: + return _error_response("ReplicationConfigurationNotFoundError", "Replication configuration not found", 404) + return _xml_response(_render_replication_config(rule)) + if request.method == "DELETE": + replication.delete_rule(bucket_name) + current_app.logger.info("Bucket replication removed", extra={"bucket": bucket_name}) + return Response(status=204) + ct_error = _require_xml_content_type() + if ct_error: + return ct_error + payload = request.get_data(cache=False) or b"" + try: + rule = _parse_replication_config(bucket_name, payload) + except ValueError as exc: + return _error_response("MalformedXML", str(exc), 400) + replication.set_rule(rule) + current_app.logger.info("Bucket replication updated", extra={"bucket": bucket_name}) + return Response(status=200) + + +def _parse_replication_config(bucket_name: str, payload: bytes): + from .replication import ReplicationRule, REPLICATION_MODE_ALL + root = _parse_xml_with_limit(payload) + if _strip_ns(root.tag) != "ReplicationConfiguration": + raise ValueError("Root element must be ReplicationConfiguration") + rule_el = None + for child in list(root): + if _strip_ns(child.tag) == "Rule": + rule_el = child + break + if rule_el is None: + raise ValueError("At least one Rule is required") + status_el = _find_element(rule_el, "Status") + status = status_el.text if status_el is not None and status_el.text else "Enabled" + enabled = status.lower() == "enabled" + filter_prefix = None + filter_el = _find_element(rule_el, "Filter") + if filter_el is not None: + prefix_el = _find_element(filter_el, "Prefix") + if prefix_el is not None and prefix_el.text: + filter_prefix = prefix_el.text + dest_el = _find_element(rule_el, "Destination") + if dest_el is None: + raise ValueError("Destination element is required") + bucket_el = _find_element(dest_el, "Bucket") + if bucket_el is None or not bucket_el.text: + raise ValueError("Destination Bucket is required") + target_bucket, target_connection_id = _parse_destination_arn(bucket_el.text) + sync_deletions = True + dm_el = _find_element(rule_el, "DeleteMarkerReplication") + if dm_el is not None: + dm_status_el = _find_element(dm_el, "Status") + if dm_status_el is not None and dm_status_el.text: + sync_deletions = dm_status_el.text.lower() == "enabled" + return ReplicationRule( + bucket_name=bucket_name, + target_connection_id=target_connection_id, + target_bucket=target_bucket, + enabled=enabled, + mode=REPLICATION_MODE_ALL, + sync_deletions=sync_deletions, + filter_prefix=filter_prefix, + ) + + +def _parse_destination_arn(arn: str) -> tuple: + if not arn.startswith("arn:aws:s3:::"): + raise ValueError(f"Invalid ARN format: {arn}") + bucket_part = arn[13:] + if "/" in bucket_part: + connection_id, bucket_name = bucket_part.split("/", 1) + else: + connection_id = "local" + bucket_name = bucket_part + return bucket_name, connection_id + + +def _render_replication_config(rule) -> Element: + root = Element("ReplicationConfiguration") + SubElement(root, "Role").text = "arn:aws:iam::000000000000:role/replication" + rule_el = SubElement(root, "Rule") + SubElement(rule_el, "ID").text = f"{rule.bucket_name}-replication" + SubElement(rule_el, "Status").text = "Enabled" if rule.enabled else "Disabled" + SubElement(rule_el, "Priority").text = "1" + filter_el = SubElement(rule_el, "Filter") + if rule.filter_prefix: + SubElement(filter_el, "Prefix").text = rule.filter_prefix + dest_el = SubElement(rule_el, "Destination") + if rule.target_connection_id == "local": + arn = f"arn:aws:s3:::{rule.target_bucket}" + else: + arn = f"arn:aws:s3:::{rule.target_connection_id}/{rule.target_bucket}" + SubElement(dest_el, "Bucket").text = arn + dm_el = SubElement(rule_el, "DeleteMarkerReplication") + SubElement(dm_el, "Status").text = "Enabled" if rule.sync_deletions else "Disabled" + return root + + @s3_api_bp.route("/", methods=["HEAD"]) -@limiter.limit("100 per minute") +@limiter.limit(_get_head_ops_limit) def head_bucket(bucket_name: str) -> Response: principal, error = _require_principal() if error: @@ -2696,7 +3053,7 @@ def head_bucket(bucket_name: str) -> Response: @s3_api_bp.route("//", methods=["HEAD"]) -@limiter.limit("100 per minute") +@limiter.limit(_get_head_ops_limit) def head_object(bucket_name: str, object_key: str) -> Response: principal, error = _require_principal() if error: @@ -2957,6 +3314,10 @@ def _initiate_multipart_upload(bucket_name: str, object_key: str) -> Response: def _upload_part(bucket_name: str, object_key: str) -> Response: + copy_source = request.headers.get("x-amz-copy-source") + if copy_source: + return _upload_part_copy(bucket_name, object_key, copy_source) + principal, error = _object_principal("write", bucket_name, object_key) if error: return error @@ -2971,6 +3332,9 @@ def _upload_part(bucket_name: str, object_key: str) -> Response: except ValueError: return _error_response("InvalidArgument", "partNumber must be an integer", 400) + if part_number < 1 or part_number > 10000: + return _error_response("InvalidArgument", "partNumber must be between 1 and 10000", 400) + stream = request.stream content_encoding = request.headers.get("Content-Encoding", "").lower() if "aws-chunked" in content_encoding: @@ -2990,6 +3354,67 @@ def _upload_part(bucket_name: str, object_key: str) -> Response: return response +def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Response: + principal, error = _object_principal("write", bucket_name, object_key) + if error: + return error + + upload_id = request.args.get("uploadId") + part_number_str = request.args.get("partNumber") + if not upload_id or not part_number_str: + return _error_response("InvalidArgument", "uploadId and partNumber are required", 400) + + try: + part_number = int(part_number_str) + except ValueError: + return _error_response("InvalidArgument", "partNumber must be an integer", 400) + + if part_number < 1 or part_number > 10000: + return _error_response("InvalidArgument", "partNumber must be between 1 and 10000", 400) + + copy_source = unquote(copy_source) + if copy_source.startswith("/"): + copy_source = copy_source[1:] + parts = copy_source.split("/", 1) + if len(parts) != 2: + return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400) + source_bucket, source_key = parts + if not source_bucket or not source_key: + return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400) + + _, read_error = _object_principal("read", source_bucket, source_key) + if read_error: + return read_error + + copy_source_range = request.headers.get("x-amz-copy-source-range") + start_byte, end_byte = None, None + if copy_source_range: + match = re.match(r"bytes=(\d+)-(\d+)", copy_source_range) + if not match: + return _error_response("InvalidArgument", "Invalid x-amz-copy-source-range format", 400) + start_byte, end_byte = int(match.group(1)), int(match.group(2)) + + try: + result = _storage().upload_part_copy( + bucket_name, upload_id, part_number, + source_bucket, source_key, + start_byte, end_byte + ) + except ObjectNotFoundError: + return _error_response("NoSuchKey", "Source object not found", 404) + except StorageError as exc: + if "Multipart upload not found" in str(exc): + return _error_response("NoSuchUpload", str(exc), 404) + if "Invalid byte range" in str(exc): + return _error_response("InvalidRange", str(exc), 416) + return _error_response("InvalidArgument", str(exc), 400) + + root = Element("CopyPartResult") + SubElement(root, "LastModified").text = result["last_modified"].strftime("%Y-%m-%dT%H:%M:%S.000Z") + SubElement(root, "ETag").text = f'"{result["etag"]}"' + return _xml_response(root) + + def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response: principal, error = _object_principal("write", bucket_name, object_key) if error: @@ -3004,7 +3429,7 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response: return ct_error payload = request.get_data(cache=False) or b"" try: - root = fromstring(payload) + root = _parse_xml_with_limit(payload) except ParseError: return _error_response("MalformedXML", "Unable to parse XML document", 400) @@ -3024,8 +3449,14 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response: etag_el = part_el.find("ETag") if part_number_el is not None and etag_el is not None: + try: + part_num = int(part_number_el.text or 0) + except ValueError: + return _error_response("InvalidArgument", "PartNumber must be an integer", 400) + if part_num < 1 or part_num > 10000: + return _error_response("InvalidArgument", f"PartNumber {part_num} must be between 1 and 10000", 400) parts.append({ - "PartNumber": int(part_number_el.text or 0), + "PartNumber": part_num, "ETag": (etag_el.text or "").strip('"') }) @@ -3074,6 +3505,164 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response: return Response(status=204) +def _select_object_content(bucket_name: str, object_key: str) -> Response: + _, error = _object_principal("read", bucket_name, object_key) + if error: + return error + ct_error = _require_xml_content_type() + if ct_error: + return ct_error + payload = request.get_data(cache=False) or b"" + try: + root = _parse_xml_with_limit(payload) + except ParseError: + return _error_response("MalformedXML", "Unable to parse XML document", 400) + if _strip_ns(root.tag) != "SelectObjectContentRequest": + return _error_response("MalformedXML", "Root element must be SelectObjectContentRequest", 400) + expression_el = _find_element(root, "Expression") + if expression_el is None or not expression_el.text: + return _error_response("InvalidRequest", "Expression is required", 400) + expression = expression_el.text + expression_type_el = _find_element(root, "ExpressionType") + expression_type = expression_type_el.text if expression_type_el is not None and expression_type_el.text else "SQL" + if expression_type.upper() != "SQL": + return _error_response("InvalidRequest", "Only SQL expression type is supported", 400) + input_el = _find_element(root, "InputSerialization") + if input_el is None: + return _error_response("InvalidRequest", "InputSerialization is required", 400) + try: + input_format, input_config = _parse_select_input_serialization(input_el) + except ValueError as exc: + return _error_response("InvalidRequest", str(exc), 400) + output_el = _find_element(root, "OutputSerialization") + if output_el is None: + return _error_response("InvalidRequest", "OutputSerialization is required", 400) + try: + output_format, output_config = _parse_select_output_serialization(output_el) + except ValueError as exc: + return _error_response("InvalidRequest", str(exc), 400) + storage = _storage() + try: + path = storage.get_object_path(bucket_name, object_key) + except ObjectNotFoundError: + return _error_response("NoSuchKey", "Object not found", 404) + except StorageError: + return _error_response("NoSuchKey", "Object not found", 404) + from .select_content import execute_select_query, SelectError + try: + result_stream = execute_select_query( + file_path=path, + expression=expression, + input_format=input_format, + input_config=input_config, + output_format=output_format, + output_config=output_config, + ) + except SelectError as exc: + return _error_response("InvalidRequest", str(exc), 400) + + def generate_events(): + bytes_scanned = 0 + bytes_returned = 0 + for chunk in result_stream: + bytes_returned += len(chunk) + yield _encode_select_event("Records", chunk) + stats_payload = _build_stats_xml(bytes_scanned, bytes_returned) + yield _encode_select_event("Stats", stats_payload) + yield _encode_select_event("End", b"") + + return Response(generate_events(), mimetype="application/octet-stream", headers={"x-amz-request-charged": "requester"}) + + +def _parse_select_input_serialization(el: Element) -> tuple: + csv_el = _find_element(el, "CSV") + if csv_el is not None: + file_header_el = _find_element(csv_el, "FileHeaderInfo") + config = { + "file_header_info": file_header_el.text.upper() if file_header_el is not None and file_header_el.text else "NONE", + "comments": _find_element_text(csv_el, "Comments", "#"), + "field_delimiter": _find_element_text(csv_el, "FieldDelimiter", ","), + "record_delimiter": _find_element_text(csv_el, "RecordDelimiter", "\n"), + "quote_character": _find_element_text(csv_el, "QuoteCharacter", '"'), + "quote_escape_character": _find_element_text(csv_el, "QuoteEscapeCharacter", '"'), + } + return "CSV", config + json_el = _find_element(el, "JSON") + if json_el is not None: + type_el = _find_element(json_el, "Type") + config = { + "type": type_el.text.upper() if type_el is not None and type_el.text else "DOCUMENT", + } + return "JSON", config + parquet_el = _find_element(el, "Parquet") + if parquet_el is not None: + return "Parquet", {} + raise ValueError("InputSerialization must specify CSV, JSON, or Parquet") + + +def _parse_select_output_serialization(el: Element) -> tuple: + csv_el = _find_element(el, "CSV") + if csv_el is not None: + config = { + "field_delimiter": _find_element_text(csv_el, "FieldDelimiter", ","), + "record_delimiter": _find_element_text(csv_el, "RecordDelimiter", "\n"), + "quote_character": _find_element_text(csv_el, "QuoteCharacter", '"'), + "quote_fields": _find_element_text(csv_el, "QuoteFields", "ASNEEDED").upper(), + } + return "CSV", config + json_el = _find_element(el, "JSON") + if json_el is not None: + config = { + "record_delimiter": _find_element_text(json_el, "RecordDelimiter", "\n"), + } + return "JSON", config + raise ValueError("OutputSerialization must specify CSV or JSON") + + +def _encode_select_event(event_type: str, payload: bytes) -> bytes: + import struct + import binascii + headers = _build_event_headers(event_type) + headers_length = len(headers) + total_length = 4 + 4 + 4 + headers_length + len(payload) + 4 + prelude = struct.pack(">I", total_length) + struct.pack(">I", headers_length) + prelude_crc = binascii.crc32(prelude) & 0xffffffff + prelude += struct.pack(">I", prelude_crc) + message = prelude + headers + payload + message_crc = binascii.crc32(message) & 0xffffffff + message += struct.pack(">I", message_crc) + return message + + +def _build_event_headers(event_type: str) -> bytes: + headers = b"" + headers += _encode_select_header(":event-type", event_type) + if event_type == "Records": + headers += _encode_select_header(":content-type", "application/octet-stream") + elif event_type == "Stats": + headers += _encode_select_header(":content-type", "text/xml") + headers += _encode_select_header(":message-type", "event") + return headers + + +def _encode_select_header(name: str, value: str) -> bytes: + import struct + name_bytes = name.encode("utf-8") + value_bytes = value.encode("utf-8") + header = struct.pack("B", len(name_bytes)) + name_bytes + header += struct.pack("B", 7) + header += struct.pack(">H", len(value_bytes)) + value_bytes + return header + + +def _build_stats_xml(bytes_scanned: int, bytes_returned: int) -> bytes: + stats = Element("Stats") + SubElement(stats, "BytesScanned").text = str(bytes_scanned) + SubElement(stats, "BytesProcessed").text = str(bytes_scanned) + SubElement(stats, "BytesReturned").text = str(bytes_returned) + return tostring(stats, encoding="utf-8") + + @s3_api_bp.before_request def resolve_principal(): g.principal = None diff --git a/app/select_content.py b/app/select_content.py new file mode 100644 index 0000000..57a3362 --- /dev/null +++ b/app/select_content.py @@ -0,0 +1,171 @@ +"""S3 SelectObjectContent SQL query execution using DuckDB.""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, Generator, Optional + +try: + import duckdb + DUCKDB_AVAILABLE = True +except ImportError: + DUCKDB_AVAILABLE = False + + +class SelectError(Exception): + """Error during SELECT query execution.""" + pass + + +def execute_select_query( + file_path: Path, + expression: str, + input_format: str, + input_config: Dict[str, Any], + output_format: str, + output_config: Dict[str, Any], + chunk_size: int = 65536, +) -> Generator[bytes, None, None]: + """Execute SQL query on object content.""" + if not DUCKDB_AVAILABLE: + raise SelectError("DuckDB is not installed. Install with: pip install duckdb") + + conn = duckdb.connect(":memory:") + + try: + if input_format == "CSV": + _load_csv(conn, file_path, input_config) + elif input_format == "JSON": + _load_json(conn, file_path, input_config) + elif input_format == "Parquet": + _load_parquet(conn, file_path) + else: + raise SelectError(f"Unsupported input format: {input_format}") + + normalized_expression = expression.replace("s3object", "data").replace("S3Object", "data") + + try: + result = conn.execute(normalized_expression) + except duckdb.Error as exc: + raise SelectError(f"SQL execution error: {exc}") + + if output_format == "CSV": + yield from _output_csv(result, output_config, chunk_size) + elif output_format == "JSON": + yield from _output_json(result, output_config, chunk_size) + else: + raise SelectError(f"Unsupported output format: {output_format}") + + finally: + conn.close() + + +def _load_csv(conn, file_path: Path, config: Dict[str, Any]) -> None: + """Load CSV file into DuckDB.""" + file_header_info = config.get("file_header_info", "NONE") + delimiter = config.get("field_delimiter", ",") + quote = config.get("quote_character", '"') + + header = file_header_info in ("USE", "IGNORE") + path_str = str(file_path).replace("\\", "/") + + conn.execute(f""" + CREATE TABLE data AS + SELECT * FROM read_csv('{path_str}', + header={header}, + delim='{delimiter}', + quote='{quote}' + ) + """) + + +def _load_json(conn, file_path: Path, config: Dict[str, Any]) -> None: + """Load JSON file into DuckDB.""" + json_type = config.get("type", "DOCUMENT") + path_str = str(file_path).replace("\\", "/") + + if json_type == "LINES": + conn.execute(f""" + CREATE TABLE data AS + SELECT * FROM read_json_auto('{path_str}', format='newline_delimited') + """) + else: + conn.execute(f""" + CREATE TABLE data AS + SELECT * FROM read_json_auto('{path_str}', format='array') + """) + + +def _load_parquet(conn, file_path: Path) -> None: + """Load Parquet file into DuckDB.""" + path_str = str(file_path).replace("\\", "/") + conn.execute(f"CREATE TABLE data AS SELECT * FROM read_parquet('{path_str}')") + + +def _output_csv( + result, + config: Dict[str, Any], + chunk_size: int, +) -> Generator[bytes, None, None]: + """Output query results as CSV.""" + delimiter = config.get("field_delimiter", ",") + record_delimiter = config.get("record_delimiter", "\n") + quote = config.get("quote_character", '"') + + buffer = "" + + while True: + rows = result.fetchmany(1000) + if not rows: + break + + for row in rows: + fields = [] + for value in row: + if value is None: + fields.append("") + elif isinstance(value, str): + if delimiter in value or quote in value or record_delimiter in value: + escaped = value.replace(quote, quote + quote) + fields.append(f'{quote}{escaped}{quote}') + else: + fields.append(value) + else: + fields.append(str(value)) + + buffer += delimiter.join(fields) + record_delimiter + + while len(buffer) >= chunk_size: + yield buffer[:chunk_size].encode("utf-8") + buffer = buffer[chunk_size:] + + if buffer: + yield buffer.encode("utf-8") + + +def _output_json( + result, + config: Dict[str, Any], + chunk_size: int, +) -> Generator[bytes, None, None]: + """Output query results as JSON Lines.""" + record_delimiter = config.get("record_delimiter", "\n") + columns = [desc[0] for desc in result.description] + + buffer = "" + + while True: + rows = result.fetchmany(1000) + if not rows: + break + + for row in rows: + record = dict(zip(columns, row)) + buffer += json.dumps(record, default=str) + record_delimiter + + while len(buffer) >= chunk_size: + yield buffer[:chunk_size].encode("utf-8") + buffer = buffer[chunk_size:] + + if buffer: + yield buffer.encode("utf-8") diff --git a/app/site_registry.py b/app/site_registry.py new file mode 100644 index 0000000..b257326 --- /dev/null +++ b/app/site_registry.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + + +@dataclass +class SiteInfo: + site_id: str + endpoint: str + region: str = "us-east-1" + priority: int = 100 + display_name: str = "" + created_at: Optional[float] = None + updated_at: Optional[float] = None + + def __post_init__(self) -> None: + if not self.display_name: + self.display_name = self.site_id + if self.created_at is None: + self.created_at = time.time() + + def to_dict(self) -> Dict[str, Any]: + return { + "site_id": self.site_id, + "endpoint": self.endpoint, + "region": self.region, + "priority": self.priority, + "display_name": self.display_name, + "created_at": self.created_at, + "updated_at": self.updated_at, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> SiteInfo: + return cls( + site_id=data["site_id"], + endpoint=data.get("endpoint", ""), + region=data.get("region", "us-east-1"), + priority=data.get("priority", 100), + display_name=data.get("display_name", ""), + created_at=data.get("created_at"), + updated_at=data.get("updated_at"), + ) + + +@dataclass +class PeerSite: + site_id: str + endpoint: str + region: str = "us-east-1" + priority: int = 100 + display_name: str = "" + created_at: Optional[float] = None + updated_at: Optional[float] = None + connection_id: Optional[str] = None + is_healthy: Optional[bool] = None + last_health_check: Optional[float] = None + + def __post_init__(self) -> None: + if not self.display_name: + self.display_name = self.site_id + if self.created_at is None: + self.created_at = time.time() + + def to_dict(self) -> Dict[str, Any]: + return { + "site_id": self.site_id, + "endpoint": self.endpoint, + "region": self.region, + "priority": self.priority, + "display_name": self.display_name, + "created_at": self.created_at, + "updated_at": self.updated_at, + "connection_id": self.connection_id, + "is_healthy": self.is_healthy, + "last_health_check": self.last_health_check, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> PeerSite: + return cls( + site_id=data["site_id"], + endpoint=data.get("endpoint", ""), + region=data.get("region", "us-east-1"), + priority=data.get("priority", 100), + display_name=data.get("display_name", ""), + created_at=data.get("created_at"), + updated_at=data.get("updated_at"), + connection_id=data.get("connection_id"), + is_healthy=data.get("is_healthy"), + last_health_check=data.get("last_health_check"), + ) + + +class SiteRegistry: + def __init__(self, config_path: Path) -> None: + self.config_path = config_path + self._local_site: Optional[SiteInfo] = None + self._peers: Dict[str, PeerSite] = {} + self.reload() + + def reload(self) -> None: + if not self.config_path.exists(): + self._local_site = None + self._peers = {} + return + + try: + with open(self.config_path, "r", encoding="utf-8") as f: + data = json.load(f) + + if data.get("local"): + self._local_site = SiteInfo.from_dict(data["local"]) + else: + self._local_site = None + + self._peers = {} + for peer_data in data.get("peers", []): + peer = PeerSite.from_dict(peer_data) + self._peers[peer.site_id] = peer + + except (OSError, json.JSONDecodeError, KeyError): + self._local_site = None + self._peers = {} + + def save(self) -> None: + self.config_path.parent.mkdir(parents=True, exist_ok=True) + data = { + "local": self._local_site.to_dict() if self._local_site else None, + "peers": [peer.to_dict() for peer in self._peers.values()], + } + with open(self.config_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + def get_local_site(self) -> Optional[SiteInfo]: + return self._local_site + + def set_local_site(self, site: SiteInfo) -> None: + site.updated_at = time.time() + self._local_site = site + self.save() + + def list_peers(self) -> List[PeerSite]: + return list(self._peers.values()) + + def get_peer(self, site_id: str) -> Optional[PeerSite]: + return self._peers.get(site_id) + + def add_peer(self, peer: PeerSite) -> None: + peer.created_at = peer.created_at or time.time() + self._peers[peer.site_id] = peer + self.save() + + def update_peer(self, peer: PeerSite) -> None: + if peer.site_id not in self._peers: + raise ValueError(f"Peer {peer.site_id} not found") + peer.updated_at = time.time() + self._peers[peer.site_id] = peer + self.save() + + def delete_peer(self, site_id: str) -> bool: + if site_id in self._peers: + del self._peers[site_id] + self.save() + return True + return False + + def update_health(self, site_id: str, is_healthy: bool) -> None: + peer = self._peers.get(site_id) + if peer: + peer.is_healthy = is_healthy + peer.last_health_check = time.time() + self.save() diff --git a/app/site_sync.py b/app/site_sync.py index 306ac28..57cf185 100644 --- a/app/site_sync.py +++ b/app/site_sync.py @@ -22,9 +22,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) SITE_SYNC_USER_AGENT = "SiteSyncAgent/1.0" -SITE_SYNC_CONNECT_TIMEOUT = 10 -SITE_SYNC_READ_TIMEOUT = 120 -CLOCK_SKEW_TOLERANCE_SECONDS = 1.0 @dataclass @@ -108,12 +105,18 @@ class RemoteObjectMeta: ) -def _create_sync_client(connection: "RemoteConnection") -> Any: +def _create_sync_client( + connection: "RemoteConnection", + *, + connect_timeout: int = 10, + read_timeout: int = 120, + max_retries: int = 2, +) -> Any: config = Config( user_agent_extra=SITE_SYNC_USER_AGENT, - connect_timeout=SITE_SYNC_CONNECT_TIMEOUT, - read_timeout=SITE_SYNC_READ_TIMEOUT, - retries={"max_attempts": 2}, + connect_timeout=connect_timeout, + read_timeout=read_timeout, + retries={"max_attempts": max_retries}, signature_version="s3v4", s3={"addressing_style": "path"}, request_checksum_calculation="when_required", @@ -138,6 +141,10 @@ class SiteSyncWorker: storage_root: Path, interval_seconds: int = 60, batch_size: int = 100, + connect_timeout: int = 10, + read_timeout: int = 120, + max_retries: int = 2, + clock_skew_tolerance_seconds: float = 1.0, ): self.storage = storage self.connections = connections @@ -145,11 +152,24 @@ class SiteSyncWorker: self.storage_root = storage_root self.interval_seconds = interval_seconds self.batch_size = batch_size + self.connect_timeout = connect_timeout + self.read_timeout = read_timeout + self.max_retries = max_retries + self.clock_skew_tolerance_seconds = clock_skew_tolerance_seconds self._lock = threading.Lock() self._shutdown = threading.Event() self._sync_thread: Optional[threading.Thread] = None self._bucket_stats: Dict[str, SiteSyncStats] = {} + def _create_client(self, connection: "RemoteConnection") -> Any: + """Create an S3 client with the worker's configured timeouts.""" + return _create_sync_client( + connection, + connect_timeout=self.connect_timeout, + read_timeout=self.read_timeout, + max_retries=self.max_retries, + ) + def start(self) -> None: if self._sync_thread is not None and self._sync_thread.is_alive(): return @@ -294,7 +314,7 @@ class SiteSyncWorker: return {obj.key: obj for obj in objects} def _list_remote_objects(self, rule: "ReplicationRule", connection: "RemoteConnection") -> Dict[str, RemoteObjectMeta]: - s3 = _create_sync_client(connection) + s3 = self._create_client(connection) result: Dict[str, RemoteObjectMeta] = {} paginator = s3.get_paginator("list_objects_v2") try: @@ -312,7 +332,7 @@ class SiteSyncWorker: local_ts = local_meta.last_modified.timestamp() remote_ts = remote_meta.last_modified.timestamp() - if abs(remote_ts - local_ts) < CLOCK_SKEW_TOLERANCE_SECONDS: + if abs(remote_ts - local_ts) < self.clock_skew_tolerance_seconds: local_etag = local_meta.etag or "" if remote_meta.etag == local_etag: return "skip" @@ -327,7 +347,7 @@ class SiteSyncWorker: connection: "RemoteConnection", remote_meta: RemoteObjectMeta, ) -> bool: - s3 = _create_sync_client(connection) + s3 = self._create_client(connection) tmp_path = None try: tmp_dir = self.storage_root / ".myfsio.sys" / "tmp" diff --git a/app/storage.py b/app/storage.py index 70488d0..2a034d0 100644 --- a/app/storage.py +++ b/app/storage.py @@ -46,6 +46,34 @@ else: fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN) +@contextmanager +def _atomic_lock_file(lock_path: Path, max_retries: int = 10, base_delay: float = 0.1) -> Generator[None, None, None]: + """Atomically acquire a lock file with exponential backoff. + + Uses O_EXCL to ensure atomic creation of the lock file. + """ + lock_path.parent.mkdir(parents=True, exist_ok=True) + fd = None + for attempt in range(max_retries): + try: + fd = os.open(str(lock_path), os.O_CREAT | os.O_EXCL | os.O_WRONLY) + break + except FileExistsError: + if attempt == max_retries - 1: + raise BlockingIOError("Another upload to this key is in progress") + delay = base_delay * (2 ** attempt) + time.sleep(min(delay, 2.0)) + try: + yield + finally: + if fd is not None: + os.close(fd) + try: + lock_path.unlink(missing_ok=True) + except OSError: + pass + + WINDOWS_RESERVED_NAMES = { "CON", "PRN", @@ -137,10 +165,15 @@ class ObjectStorage: BUCKET_VERSIONS_DIR = "versions" MULTIPART_MANIFEST = "manifest.json" BUCKET_CONFIG_FILE = ".bucket.json" - DEFAULT_CACHE_TTL = 5 - OBJECT_CACHE_MAX_SIZE = 100 - def __init__(self, root: Path, cache_ttl: int = DEFAULT_CACHE_TTL) -> None: + def __init__( + self, + root: Path, + cache_ttl: int = 5, + object_cache_max_size: int = 100, + bucket_config_cache_ttl: float = 30.0, + object_key_max_length_bytes: int = 1024, + ) -> None: self.root = Path(root) self.root.mkdir(parents=True, exist_ok=True) self._ensure_system_roots() @@ -149,8 +182,10 @@ class ObjectStorage: self._bucket_locks: Dict[str, threading.Lock] = {} self._cache_version: Dict[str, int] = {} self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {} - self._bucket_config_cache_ttl = 30.0 + self._bucket_config_cache_ttl = bucket_config_cache_ttl self._cache_ttl = cache_ttl + self._object_cache_max_size = object_cache_max_size + self._object_key_max_length_bytes = object_key_max_length_bytes def _get_bucket_lock(self, bucket_id: str) -> threading.Lock: """Get or create a lock for a specific bucket. Reduces global lock contention.""" @@ -313,18 +348,15 @@ class ObjectStorage: total_count = len(all_keys) start_index = 0 if continuation_token: - try: - import bisect - start_index = bisect.bisect_right(all_keys, continuation_token) - if start_index >= total_count: - return ListObjectsResult( - objects=[], - is_truncated=False, - next_continuation_token=None, - total_count=total_count, - ) - except Exception: - pass + import bisect + start_index = bisect.bisect_right(all_keys, continuation_token) + if start_index >= total_count: + return ListObjectsResult( + objects=[], + is_truncated=False, + next_continuation_token=None, + total_count=total_count, + ) end_index = start_index + max_keys keys_slice = all_keys[start_index:end_index] @@ -364,7 +396,7 @@ class ObjectStorage: raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) destination = bucket_path / safe_key destination.parent.mkdir(parents=True, exist_ok=True) @@ -439,7 +471,7 @@ class ObjectStorage: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): return {} - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) return self._read_metadata(bucket_path.name, safe_key) or {} def _cleanup_empty_parents(self, path: Path, stop_at: Path) -> None: @@ -487,7 +519,7 @@ class ObjectStorage: self._safe_unlink(target) self._delete_metadata(bucket_id, rel) else: - rel = self._sanitize_object_key(object_key) + rel = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) self._delete_metadata(bucket_id, rel) version_dir = self._version_dir(bucket_id, rel) if version_dir.exists(): @@ -696,7 +728,7 @@ class ObjectStorage: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): raise BucketNotFoundError("Bucket does not exist") - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) object_path = bucket_path / safe_key if not object_path.exists(): raise ObjectNotFoundError("Object does not exist") @@ -719,7 +751,7 @@ class ObjectStorage: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): raise BucketNotFoundError("Bucket does not exist") - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) object_path = bucket_path / safe_key if not object_path.exists(): raise ObjectNotFoundError("Object does not exist") @@ -758,7 +790,7 @@ class ObjectStorage: if not bucket_path.exists(): raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) version_dir = self._version_dir(bucket_id, safe_key) if not version_dir.exists(): version_dir = self._legacy_version_dir(bucket_id, safe_key) @@ -782,7 +814,7 @@ class ObjectStorage: if not bucket_path.exists(): raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) version_dir = self._version_dir(bucket_id, safe_key) data_path = version_dir / f"{version_id}.bin" meta_path = version_dir / f"{version_id}.json" @@ -819,7 +851,7 @@ class ObjectStorage: if not bucket_path.exists(): raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) version_dir = self._version_dir(bucket_id, safe_key) data_path = version_dir / f"{version_id}.bin" meta_path = version_dir / f"{version_id}.json" @@ -910,7 +942,7 @@ class ObjectStorage: if not bucket_path.exists(): raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) upload_id = uuid.uuid4().hex upload_root = self._multipart_dir(bucket_id, upload_id) upload_root.mkdir(parents=True, exist_ok=False) @@ -995,6 +1027,102 @@ class ObjectStorage: return record["etag"] + def upload_part_copy( + self, + bucket_name: str, + upload_id: str, + part_number: int, + source_bucket: str, + source_key: str, + start_byte: Optional[int] = None, + end_byte: Optional[int] = None, + ) -> Dict[str, Any]: + """Copy a range from an existing object as a multipart part.""" + if part_number < 1 or part_number > 10000: + raise StorageError("part_number must be between 1 and 10000") + + source_path = self.get_object_path(source_bucket, source_key) + source_size = source_path.stat().st_size + + if start_byte is None: + start_byte = 0 + if end_byte is None: + end_byte = source_size - 1 + + if start_byte < 0 or end_byte >= source_size or start_byte > end_byte: + raise StorageError("Invalid byte range") + + bucket_path = self._bucket_path(bucket_name) + upload_root = self._multipart_dir(bucket_path.name, upload_id) + if not upload_root.exists(): + upload_root = self._legacy_multipart_dir(bucket_path.name, upload_id) + if not upload_root.exists(): + raise StorageError("Multipart upload not found") + + checksum = hashlib.md5() + part_filename = f"part-{part_number:05d}.part" + part_path = upload_root / part_filename + temp_path = upload_root / f".{part_filename}.tmp" + + try: + with source_path.open("rb") as src: + src.seek(start_byte) + bytes_to_copy = end_byte - start_byte + 1 + with temp_path.open("wb") as target: + remaining = bytes_to_copy + while remaining > 0: + chunk_size = min(65536, remaining) + chunk = src.read(chunk_size) + if not chunk: + break + checksum.update(chunk) + target.write(chunk) + remaining -= len(chunk) + temp_path.replace(part_path) + except OSError: + try: + temp_path.unlink(missing_ok=True) + except OSError: + pass + raise + + record = { + "etag": checksum.hexdigest(), + "size": part_path.stat().st_size, + "filename": part_filename, + } + + manifest_path = upload_root / self.MULTIPART_MANIFEST + lock_path = upload_root / ".manifest.lock" + + max_retries = 3 + for attempt in range(max_retries): + try: + with lock_path.open("w") as lock_file: + with _file_lock(lock_file): + try: + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError) as exc: + if attempt < max_retries - 1: + time.sleep(0.1 * (attempt + 1)) + continue + raise StorageError("Multipart manifest unreadable") from exc + + parts = manifest.setdefault("parts", {}) + parts[str(part_number)] = record + manifest_path.write_text(json.dumps(manifest), encoding="utf-8") + break + except OSError as exc: + if attempt < max_retries - 1: + time.sleep(0.1 * (attempt + 1)) + continue + raise StorageError(f"Failed to update multipart manifest: {exc}") from exc + + return { + "etag": record["etag"], + "last_modified": datetime.fromtimestamp(part_path.stat().st_mtime, timezone.utc), + } + def complete_multipart_upload( self, bucket_name: str, @@ -1034,7 +1162,7 @@ class ObjectStorage: total_size += record.get("size", 0) validated.sort(key=lambda entry: entry[0]) - safe_key = self._sanitize_object_key(manifest["object_key"]) + safe_key = self._sanitize_object_key(manifest["object_key"], self._object_key_max_length_bytes) destination = bucket_path / safe_key is_overwrite = destination.exists() @@ -1057,36 +1185,28 @@ class ObjectStorage: ) destination.parent.mkdir(parents=True, exist_ok=True) - - lock_file_path = self._system_bucket_root(bucket_id) / "locks" / f"{safe_key.as_posix().replace('/', '_')}.lock" - lock_file_path.parent.mkdir(parents=True, exist_ok=True) - - try: - with lock_file_path.open("w") as lock_file: - with _file_lock(lock_file): - if self._is_versioning_enabled(bucket_path) and destination.exists(): - self._archive_current_version(bucket_id, safe_key, reason="overwrite") - checksum = hashlib.md5() - with destination.open("wb") as target: - for _, record in validated: - part_path = upload_root / record["filename"] - if not part_path.exists(): - raise StorageError(f"Missing part file {record['filename']}") - with part_path.open("rb") as chunk: - while True: - data = chunk.read(1024 * 1024) - if not data: - break - checksum.update(data) - target.write(data) + lock_file_path = self._system_bucket_root(bucket_id) / "locks" / f"{safe_key.as_posix().replace('/', '_')}.lock" + + try: + with _atomic_lock_file(lock_file_path): + if self._is_versioning_enabled(bucket_path) and destination.exists(): + self._archive_current_version(bucket_id, safe_key, reason="overwrite") + checksum = hashlib.md5() + with destination.open("wb") as target: + for _, record in validated: + part_path = upload_root / record["filename"] + if not part_path.exists(): + raise StorageError(f"Missing part file {record['filename']}") + with part_path.open("rb") as chunk: + while True: + data = chunk.read(1024 * 1024) + if not data: + break + checksum.update(data) + target.write(data) except BlockingIOError: raise StorageError("Another upload to this key is in progress") - finally: - try: - lock_file_path.unlink(missing_ok=True) - except OSError: - pass shutil.rmtree(upload_root, ignore_errors=True) @@ -1213,7 +1333,7 @@ class ObjectStorage: def _object_path(self, bucket_name: str, object_key: str) -> Path: bucket_path = self._bucket_path(bucket_name) - safe_key = self._sanitize_object_key(object_key) + safe_key = self._sanitize_object_key(object_key, self._object_key_max_length_bytes) return bucket_path / safe_key def _system_root_path(self) -> Path: @@ -1429,7 +1549,7 @@ class ObjectStorage: current_version = self._cache_version.get(bucket_id, 0) if current_version != cache_version: objects = self._build_object_cache(bucket_path) - while len(self._object_cache) >= self.OBJECT_CACHE_MAX_SIZE: + while len(self._object_cache) >= self._object_cache_max_size: self._object_cache.popitem(last=False) self._object_cache[bucket_id] = (objects, time.time()) @@ -1764,16 +1884,16 @@ class ObjectStorage: return name @staticmethod - def _sanitize_object_key(object_key: str) -> Path: + def _sanitize_object_key(object_key: str, max_length_bytes: int = 1024) -> Path: if not object_key: raise StorageError("Object key required") - if len(object_key.encode("utf-8")) > 1024: - raise StorageError("Object key exceeds maximum length of 1024 bytes") if "\x00" in object_key: raise StorageError("Object key contains null bytes") + object_key = unicodedata.normalize("NFC", object_key) + if len(object_key.encode("utf-8")) > max_length_bytes: + raise StorageError(f"Object key exceeds maximum length of {max_length_bytes} bytes") if object_key.startswith(("/", "\\")): raise StorageError("Object key cannot start with a slash") - object_key = unicodedata.normalize("NFC", object_key) candidate = Path(object_key) if ".." in candidate.parts: diff --git a/app/ui.py b/app/ui.py index 738521c..334ba89 100644 --- a/app/ui.py +++ b/app/ui.py @@ -38,6 +38,7 @@ from .kms import KMSManager from .replication import ReplicationManager, ReplicationRule from .s3_api import _generate_presigned_url from .secret_store import EphemeralSecretStore +from .site_registry import SiteRegistry, SiteInfo, PeerSite from .storage import ObjectStorage, StorageError ui_bp = Blueprint("ui", __name__, template_folder="../templates", url_prefix="/ui") @@ -145,6 +146,10 @@ def _operation_metrics(): return current_app.extensions.get("operation_metrics") +def _site_registry() -> SiteRegistry: + return current_app.extensions["site_registry"] + + def _format_bytes(num: int) -> str: step = 1024 units = ["B", "KB", "MB", "GB", "TB", "PB"] @@ -1091,7 +1096,9 @@ def object_presign(bucket_name: str, object_key: str): expires = int(payload.get("expires_in", 900)) except (TypeError, ValueError): return jsonify({"error": "expires_in must be an integer"}), 400 - expires = max(1, min(expires, 7 * 24 * 3600)) + min_expiry = current_app.config.get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1) + max_expiry = current_app.config.get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800) + expires = max(min_expiry, min(expires, max_expiry)) storage = _storage() if not storage.bucket_exists(bucket_name): return jsonify({"error": "Bucket does not exist"}), 404 @@ -2661,6 +2668,664 @@ def list_buckets_for_copy(bucket_name: str): return jsonify({"buckets": allowed}) +@ui_bp.get("/sites") +def sites_dashboard(): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied: Site management requires admin permissions", "danger") + return redirect(url_for("ui.buckets_overview")) + + registry = _site_registry() + local_site = registry.get_local_site() + peers = registry.list_peers() + connections = _connections().list() + + replication = _replication() + all_rules = replication.list_rules() + + peers_with_stats = [] + for peer in peers: + buckets_syncing = 0 + has_bidirectional = False + if peer.connection_id: + for rule in all_rules: + if rule.target_connection_id == peer.connection_id: + buckets_syncing += 1 + if rule.mode == "bidirectional": + has_bidirectional = True + peers_with_stats.append({ + "peer": peer, + "buckets_syncing": buckets_syncing, + "has_connection": bool(peer.connection_id), + "has_bidirectional": has_bidirectional, + }) + + return render_template( + "sites.html", + principal=principal, + local_site=local_site, + peers=peers, + peers_with_stats=peers_with_stats, + connections=connections, + config_site_id=current_app.config.get("SITE_ID"), + config_site_endpoint=current_app.config.get("SITE_ENDPOINT"), + config_site_region=current_app.config.get("SITE_REGION", "us-east-1"), + ) + + +@ui_bp.post("/sites/local") +def update_local_site(): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied", "danger") + return redirect(url_for("ui.sites_dashboard")) + + site_id = request.form.get("site_id", "").strip() + endpoint = request.form.get("endpoint", "").strip() + region = request.form.get("region", "us-east-1").strip() + priority = request.form.get("priority", "100") + display_name = request.form.get("display_name", "").strip() + + if not site_id: + flash("Site ID is required", "danger") + return redirect(url_for("ui.sites_dashboard")) + + try: + priority_int = int(priority) + except ValueError: + priority_int = 100 + + registry = _site_registry() + existing = registry.get_local_site() + + site = SiteInfo( + site_id=site_id, + endpoint=endpoint, + region=region, + priority=priority_int, + display_name=display_name or site_id, + created_at=existing.created_at if existing else None, + ) + registry.set_local_site(site) + + flash("Local site configuration updated", "success") + return redirect(url_for("ui.sites_dashboard")) + + +@ui_bp.post("/sites/peers") +def add_peer_site(): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied", "danger") + return redirect(url_for("ui.sites_dashboard")) + + site_id = request.form.get("site_id", "").strip() + endpoint = request.form.get("endpoint", "").strip() + region = request.form.get("region", "us-east-1").strip() + priority = request.form.get("priority", "100") + display_name = request.form.get("display_name", "").strip() + connection_id = request.form.get("connection_id", "").strip() or None + + if not site_id: + flash("Site ID is required", "danger") + return redirect(url_for("ui.sites_dashboard")) + if not endpoint: + flash("Endpoint is required", "danger") + return redirect(url_for("ui.sites_dashboard")) + + try: + priority_int = int(priority) + except ValueError: + priority_int = 100 + + registry = _site_registry() + + if registry.get_peer(site_id): + flash(f"Peer site '{site_id}' already exists", "danger") + return redirect(url_for("ui.sites_dashboard")) + + if connection_id and not _connections().get(connection_id): + flash(f"Connection '{connection_id}' not found", "danger") + return redirect(url_for("ui.sites_dashboard")) + + peer = PeerSite( + site_id=site_id, + endpoint=endpoint, + region=region, + priority=priority_int, + display_name=display_name or site_id, + connection_id=connection_id, + ) + registry.add_peer(peer) + + flash(f"Peer site '{site_id}' added", "success") + + if connection_id: + return redirect(url_for("ui.replication_wizard", site_id=site_id)) + return redirect(url_for("ui.sites_dashboard")) + + +@ui_bp.post("/sites/peers//update") +def update_peer_site(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied", "danger") + return redirect(url_for("ui.sites_dashboard")) + + registry = _site_registry() + existing = registry.get_peer(site_id) + + if not existing: + flash(f"Peer site '{site_id}' not found", "danger") + return redirect(url_for("ui.sites_dashboard")) + + endpoint = request.form.get("endpoint", existing.endpoint).strip() + region = request.form.get("region", existing.region).strip() + priority = request.form.get("priority", str(existing.priority)) + display_name = request.form.get("display_name", existing.display_name).strip() + connection_id = request.form.get("connection_id", "").strip() or existing.connection_id + + try: + priority_int = int(priority) + except ValueError: + priority_int = existing.priority + + if connection_id and not _connections().get(connection_id): + flash(f"Connection '{connection_id}' not found", "danger") + return redirect(url_for("ui.sites_dashboard")) + + peer = PeerSite( + site_id=site_id, + endpoint=endpoint, + region=region, + priority=priority_int, + display_name=display_name or site_id, + connection_id=connection_id, + created_at=existing.created_at, + is_healthy=existing.is_healthy, + last_health_check=existing.last_health_check, + ) + registry.update_peer(peer) + + flash(f"Peer site '{site_id}' updated", "success") + return redirect(url_for("ui.sites_dashboard")) + + +@ui_bp.post("/sites/peers//delete") +def delete_peer_site(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied", "danger") + return redirect(url_for("ui.sites_dashboard")) + + registry = _site_registry() + if registry.delete_peer(site_id): + flash(f"Peer site '{site_id}' deleted", "success") + else: + flash(f"Peer site '{site_id}' not found", "danger") + + return redirect(url_for("ui.sites_dashboard")) + + +@ui_bp.get("/sites/peers//health") +def check_peer_site_health(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + return jsonify({"error": "Access denied"}), 403 + + registry = _site_registry() + peer = registry.get_peer(site_id) + + if not peer: + return jsonify({"error": f"Peer site '{site_id}' not found"}), 404 + + is_healthy = False + error_message = None + + if peer.connection_id: + connection = _connections().get(peer.connection_id) + if connection: + is_healthy = _replication().check_endpoint_health(connection) + else: + error_message = f"Connection '{peer.connection_id}' not found" + else: + error_message = "No connection configured for this peer" + + registry.update_health(site_id, is_healthy) + + result = { + "site_id": site_id, + "is_healthy": is_healthy, + } + if error_message: + result["error"] = error_message + + return jsonify(result) + + +@ui_bp.get("/sites/peers//bidirectional-status") +def check_peer_bidirectional_status(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + return jsonify({"error": "Access denied"}), 403 + + registry = _site_registry() + peer = registry.get_peer(site_id) + + if not peer: + return jsonify({"error": f"Peer site '{site_id}' not found"}), 404 + + local_site = registry.get_local_site() + replication = _replication() + local_rules = replication.list_rules() + + local_bidir_rules = [] + for rule in local_rules: + if rule.target_connection_id == peer.connection_id and rule.mode == "bidirectional": + local_bidir_rules.append({ + "bucket_name": rule.bucket_name, + "target_bucket": rule.target_bucket, + "enabled": rule.enabled, + }) + + result = { + "site_id": site_id, + "local_site_id": local_site.site_id if local_site else None, + "local_endpoint": local_site.endpoint if local_site else None, + "local_bidirectional_rules": local_bidir_rules, + "local_site_sync_enabled": current_app.config.get("SITE_SYNC_ENABLED", False), + "remote_status": None, + "issues": [], + "is_fully_configured": False, + } + + if not local_site or not local_site.site_id: + result["issues"].append({ + "code": "NO_LOCAL_SITE_ID", + "message": "Local site identity not configured", + "severity": "error", + }) + + if not local_site or not local_site.endpoint: + result["issues"].append({ + "code": "NO_LOCAL_ENDPOINT", + "message": "Local site endpoint not configured (remote site cannot reach back)", + "severity": "error", + }) + + if not peer.connection_id: + result["issues"].append({ + "code": "NO_CONNECTION", + "message": "No connection configured for this peer", + "severity": "error", + }) + return jsonify(result) + + connection = _connections().get(peer.connection_id) + if not connection: + result["issues"].append({ + "code": "CONNECTION_NOT_FOUND", + "message": f"Connection '{peer.connection_id}' not found", + "severity": "error", + }) + return jsonify(result) + + if not local_bidir_rules: + result["issues"].append({ + "code": "NO_LOCAL_BIDIRECTIONAL_RULES", + "message": "No bidirectional replication rules configured on this site", + "severity": "warning", + }) + + if not result["local_site_sync_enabled"]: + result["issues"].append({ + "code": "SITE_SYNC_DISABLED", + "message": "Site sync worker is disabled (SITE_SYNC_ENABLED=false). Pull operations will not work.", + "severity": "warning", + }) + + if not replication.check_endpoint_health(connection): + result["issues"].append({ + "code": "REMOTE_UNREACHABLE", + "message": "Remote endpoint is not reachable", + "severity": "error", + }) + return jsonify(result) + + try: + parsed = urlparse(peer.endpoint) + hostname = parsed.hostname or "" + import ipaddress + cloud_metadata_hosts = {"metadata.google.internal", "169.254.169.254"} + if hostname.lower() in cloud_metadata_hosts: + result["issues"].append({ + "code": "ENDPOINT_NOT_ALLOWED", + "message": "Peer endpoint points to cloud metadata service (SSRF protection)", + "severity": "error", + }) + return jsonify(result) + allow_internal = current_app.config.get("ALLOW_INTERNAL_ENDPOINTS", False) + if not allow_internal: + try: + ip = ipaddress.ip_address(hostname) + if ip.is_private or ip.is_loopback or ip.is_reserved or ip.is_link_local: + result["issues"].append({ + "code": "ENDPOINT_NOT_ALLOWED", + "message": "Peer endpoint points to internal or private address (set ALLOW_INTERNAL_ENDPOINTS=true for self-hosted deployments)", + "severity": "error", + }) + return jsonify(result) + except ValueError: + blocked_patterns = ["localhost", "127.", "10.", "192.168.", "172.16."] + if any(hostname.startswith(p) or hostname == p.rstrip(".") for p in blocked_patterns): + result["issues"].append({ + "code": "ENDPOINT_NOT_ALLOWED", + "message": "Peer endpoint points to internal or private address (set ALLOW_INTERNAL_ENDPOINTS=true for self-hosted deployments)", + "severity": "error", + }) + return jsonify(result) + except Exception: + pass + + try: + admin_url = peer.endpoint.rstrip("/") + "/admin/sites" + resp = requests.get( + admin_url, + timeout=10, + headers={ + "Accept": "application/json", + "X-Access-Key": connection.access_key, + "X-Secret-Key": connection.secret_key, + }, + ) + + if resp.status_code == 200: + try: + remote_data = resp.json() + if not isinstance(remote_data, dict): + raise ValueError("Expected JSON object") + remote_local = remote_data.get("local") + if remote_local is not None and not isinstance(remote_local, dict): + raise ValueError("Expected 'local' to be an object") + remote_peers = remote_data.get("peers", []) + if not isinstance(remote_peers, list): + raise ValueError("Expected 'peers' to be a list") + except (ValueError, json.JSONDecodeError) as e: + result["remote_status"] = {"reachable": True, "invalid_response": True} + result["issues"].append({ + "code": "REMOTE_INVALID_RESPONSE", + "message": "Remote admin API returned invalid JSON", + "severity": "warning", + }) + return jsonify(result) + + result["remote_status"] = { + "reachable": True, + "local_site": remote_local, + "site_sync_enabled": None, + "has_peer_for_us": False, + "peer_connection_configured": False, + "has_bidirectional_rules_for_us": False, + } + + for rp in remote_peers: + if not isinstance(rp, dict): + continue + if local_site and ( + rp.get("site_id") == local_site.site_id or + rp.get("endpoint") == local_site.endpoint + ): + result["remote_status"]["has_peer_for_us"] = True + result["remote_status"]["peer_connection_configured"] = bool(rp.get("connection_id")) + break + + if not result["remote_status"]["has_peer_for_us"]: + result["issues"].append({ + "code": "REMOTE_NO_PEER_FOR_US", + "message": "Remote site does not have this site registered as a peer", + "severity": "error", + }) + elif not result["remote_status"]["peer_connection_configured"]: + result["issues"].append({ + "code": "REMOTE_NO_CONNECTION_FOR_US", + "message": "Remote site has us as peer but no connection configured (cannot push back)", + "severity": "error", + }) + elif resp.status_code == 401 or resp.status_code == 403: + result["remote_status"] = { + "reachable": True, + "admin_access_denied": True, + } + result["issues"].append({ + "code": "REMOTE_ADMIN_ACCESS_DENIED", + "message": "Cannot verify remote configuration (admin access denied)", + "severity": "warning", + }) + else: + result["remote_status"] = { + "reachable": True, + "admin_api_error": resp.status_code, + } + result["issues"].append({ + "code": "REMOTE_ADMIN_API_ERROR", + "message": f"Remote admin API returned status {resp.status_code}", + "severity": "warning", + }) + except requests.RequestException: + result["remote_status"] = { + "reachable": False, + "error": "Connection failed", + } + result["issues"].append({ + "code": "REMOTE_ADMIN_UNREACHABLE", + "message": "Could not reach remote admin API", + "severity": "warning", + }) + except Exception: + result["issues"].append({ + "code": "VERIFICATION_ERROR", + "message": "Internal error during verification", + "severity": "warning", + }) + + error_issues = [i for i in result["issues"] if i["severity"] == "error"] + result["is_fully_configured"] = len(error_issues) == 0 and len(local_bidir_rules) > 0 + + return jsonify(result) + + +@ui_bp.get("/sites/peers//replication-wizard") +def replication_wizard(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied", "danger") + return redirect(url_for("ui.sites_dashboard")) + + registry = _site_registry() + peer = registry.get_peer(site_id) + if not peer: + flash(f"Peer site '{site_id}' not found", "danger") + return redirect(url_for("ui.sites_dashboard")) + + if not peer.connection_id: + flash("This peer has no connection configured. Add a connection first to set up replication.", "warning") + return redirect(url_for("ui.sites_dashboard")) + + connection = _connections().get(peer.connection_id) + if not connection: + flash(f"Connection '{peer.connection_id}' not found", "danger") + return redirect(url_for("ui.sites_dashboard")) + + buckets = _storage().list_buckets() + replication = _replication() + + bucket_info = [] + for bucket in buckets: + existing_rule = replication.get_rule(bucket.name) + has_rule_for_peer = ( + existing_rule and + existing_rule.target_connection_id == peer.connection_id + ) + bucket_info.append({ + "name": bucket.name, + "has_rule": has_rule_for_peer, + "existing_mode": existing_rule.mode if has_rule_for_peer else None, + "existing_target": existing_rule.target_bucket if has_rule_for_peer else None, + }) + + local_site = registry.get_local_site() + + return render_template( + "replication_wizard.html", + principal=principal, + peer=peer, + connection=connection, + buckets=bucket_info, + local_site=local_site, + csrf_token=generate_csrf, + ) + + +@ui_bp.post("/sites/peers//replication-rules") +def create_peer_replication_rules(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + flash("Access denied", "danger") + return redirect(url_for("ui.sites_dashboard")) + + registry = _site_registry() + peer = registry.get_peer(site_id) + if not peer or not peer.connection_id: + flash("Invalid peer site or no connection configured", "danger") + return redirect(url_for("ui.sites_dashboard")) + + from .replication import REPLICATION_MODE_NEW_ONLY, REPLICATION_MODE_ALL + import time as time_module + + selected_buckets = request.form.getlist("buckets") + mode = request.form.get("mode", REPLICATION_MODE_NEW_ONLY) + + if not selected_buckets: + flash("No buckets selected", "warning") + return redirect(url_for("ui.sites_dashboard")) + + created = 0 + failed = 0 + replication = _replication() + + for bucket_name in selected_buckets: + target_bucket = request.form.get(f"target_{bucket_name}", bucket_name).strip() + if not target_bucket: + target_bucket = bucket_name + + try: + rule = ReplicationRule( + bucket_name=bucket_name, + target_connection_id=peer.connection_id, + target_bucket=target_bucket, + enabled=True, + mode=mode, + created_at=time_module.time(), + ) + replication.set_rule(rule) + + if mode == REPLICATION_MODE_ALL: + replication.replicate_existing_objects(bucket_name) + + created += 1 + except Exception: + failed += 1 + + if created > 0: + flash(f"Created {created} replication rule(s) for {peer.display_name or peer.site_id}", "success") + if failed > 0: + flash(f"Failed to create {failed} rule(s)", "danger") + + return redirect(url_for("ui.sites_dashboard")) + + +@ui_bp.get("/sites/peers//sync-stats") +def get_peer_sync_stats(site_id: str): + principal = _current_principal() + try: + _iam().authorize(principal, None, "iam:*") + except IamError: + return jsonify({"error": "Access denied"}), 403 + + registry = _site_registry() + peer = registry.get_peer(site_id) + if not peer: + return jsonify({"error": "Peer not found"}), 404 + + if not peer.connection_id: + return jsonify({"error": "No connection configured"}), 400 + + replication = _replication() + all_rules = replication.list_rules() + + stats = { + "buckets_syncing": 0, + "objects_synced": 0, + "objects_pending": 0, + "objects_failed": 0, + "bytes_synced": 0, + "last_sync_at": None, + "buckets": [], + } + + for rule in all_rules: + if rule.target_connection_id != peer.connection_id: + continue + + stats["buckets_syncing"] += 1 + + bucket_stats = { + "bucket_name": rule.bucket_name, + "target_bucket": rule.target_bucket, + "mode": rule.mode, + "enabled": rule.enabled, + } + + if rule.stats: + stats["objects_synced"] += rule.stats.objects_synced + stats["objects_pending"] += rule.stats.objects_pending + stats["bytes_synced"] += rule.stats.bytes_synced + + if rule.stats.last_sync_at: + if not stats["last_sync_at"] or rule.stats.last_sync_at > stats["last_sync_at"]: + stats["last_sync_at"] = rule.stats.last_sync_at + + bucket_stats["last_sync_at"] = rule.stats.last_sync_at + bucket_stats["objects_synced"] = rule.stats.objects_synced + bucket_stats["objects_pending"] = rule.stats.objects_pending + + failure_count = replication.get_failure_count(rule.bucket_name) + stats["objects_failed"] += failure_count + bucket_stats["failures"] = failure_count + + stats["buckets"].append(bucket_stats) + + return jsonify(stats) + + @ui_bp.app_errorhandler(404) def ui_not_found(error): # type: ignore[override] prefix = ui_bp.url_prefix or "" diff --git a/app/version.py b/app/version.py index 998adc1..54f0689 100644 --- a/app/version.py +++ b/app/version.py @@ -1,6 +1,6 @@ from __future__ import annotations -APP_VERSION = "0.2.3" +APP_VERSION = "0.2.4" def get_version() -> str: diff --git a/docs.md b/docs.md index 68cf69a..e582c76 100644 --- a/docs.md +++ b/docs.md @@ -166,15 +166,19 @@ All configuration is done via environment variables. The table below lists every | Variable | Default | Notes | | --- | --- | --- | | `RATE_LIMIT_DEFAULT` | `200 per minute` | Default rate limit for API endpoints. | +| `RATE_LIMIT_LIST_BUCKETS` | `60 per minute` | Rate limit for listing buckets (`GET /`). | +| `RATE_LIMIT_BUCKET_OPS` | `120 per minute` | Rate limit for bucket operations (PUT/DELETE/GET/POST on `/`). | +| `RATE_LIMIT_OBJECT_OPS` | `240 per minute` | Rate limit for object operations (PUT/GET/DELETE/POST on `//`). | +| `RATE_LIMIT_HEAD_OPS` | `100 per minute` | Rate limit for HEAD requests (bucket and object). | | `RATE_LIMIT_STORAGE_URI` | `memory://` | Storage backend for rate limits. Use `redis://host:port` for distributed setups. | ### Server Configuration | Variable | Default | Notes | | --- | --- | --- | -| `SERVER_THREADS` | `4` | Waitress worker threads (1-64). More threads handle more concurrent requests but use more memory. | -| `SERVER_CONNECTION_LIMIT` | `100` | Maximum concurrent connections (10-1000). Ensure OS file descriptor limits support this value. | -| `SERVER_BACKLOG` | `1024` | TCP listen backlog (64-4096). Connections queue here when all threads are busy. | +| `SERVER_THREADS` | `0` (auto) | Waitress worker threads (1-64). Set to `0` for auto-calculation based on CPU cores (×2). | +| `SERVER_CONNECTION_LIMIT` | `0` (auto) | Maximum concurrent connections (10-1000). Set to `0` for auto-calculation based on available RAM. | +| `SERVER_BACKLOG` | `0` (auto) | TCP listen backlog (64-4096). Set to `0` for auto-calculation (connection_limit × 2). | | `SERVER_CHANNEL_TIMEOUT` | `120` | Seconds before idle connections are closed (10-300). | ### Logging @@ -1503,16 +1507,723 @@ The suite covers bucket CRUD, presigned downloads, bucket policy enforcement, an ## 14. API Matrix ``` +# Service Endpoints +GET /myfsio/health # Health check + +# Bucket Operations GET / # List buckets PUT / # Create bucket DELETE / # Remove bucket -GET / # List objects -PUT // # Upload object -GET // # Download object -DELETE // # Delete object -GET /?policy # Fetch policy -PUT /?policy # Upsert policy -DELETE /?policy # Delete policy +GET / # List objects (supports ?list-type=2) +HEAD / # Check bucket exists +POST / # POST object upload (HTML form) +POST /?delete # Bulk delete objects + +# Bucket Configuration +GET /?policy # Fetch bucket policy +PUT /?policy # Upsert bucket policy +DELETE /?policy # Delete bucket policy GET /?quota # Get bucket quota PUT /?quota # Set bucket quota (admin only) +GET /?versioning # Get versioning status +PUT /?versioning # Enable/disable versioning +GET /?lifecycle # Get lifecycle rules +PUT /?lifecycle # Set lifecycle rules +DELETE /?lifecycle # Delete lifecycle rules +GET /?cors # Get CORS configuration +PUT /?cors # Set CORS configuration +DELETE /?cors # Delete CORS configuration +GET /?encryption # Get encryption configuration +PUT /?encryption # Set default encryption +DELETE /?encryption # Delete encryption configuration +GET /?acl # Get bucket ACL +PUT /?acl # Set bucket ACL +GET /?tagging # Get bucket tags +PUT /?tagging # Set bucket tags +DELETE /?tagging # Delete bucket tags +GET /?replication # Get replication configuration +PUT /?replication # Set replication rules +DELETE /?replication # Delete replication configuration +GET /?logging # Get access logging configuration +PUT /?logging # Set access logging +GET /?notification # Get event notifications +PUT /?notification # Set event notifications (webhooks) +GET /?object-lock # Get object lock configuration +PUT /?object-lock # Set object lock configuration +GET /?uploads # List active multipart uploads +GET /?versions # List object versions +GET /?location # Get bucket location/region + +# Object Operations +PUT // # Upload object +GET // # Download object (supports Range header) +DELETE // # Delete object +HEAD // # Get object metadata +POST // # POST upload with policy +POST //?select # SelectObjectContent (SQL query) + +# Object Configuration +GET //?tagging # Get object tags +PUT //?tagging # Set object tags +DELETE //?tagging # Delete object tags +GET //?acl # Get object ACL +PUT //?acl # Set object ACL +PUT //?retention # Set object retention +GET //?retention # Get object retention +PUT //?legal-hold # Set legal hold +GET //?legal-hold # Get legal hold status + +# Multipart Upload +POST //?uploads # Initiate multipart upload +PUT //?uploadId=X&partNumber=N # Upload part +PUT //?uploadId=X&partNumber=N (with x-amz-copy-source) # UploadPartCopy +POST //?uploadId=X # Complete multipart upload +DELETE //?uploadId=X # Abort multipart upload +GET //?uploadId=X # List parts + +# Copy Operations +PUT // (with x-amz-copy-source header) # CopyObject + +# Admin API +GET /admin/site # Get local site info +PUT /admin/site # Update local site +GET /admin/sites # List peer sites +POST /admin/sites # Register peer site +GET /admin/sites/ # Get peer site +PUT /admin/sites/ # Update peer site +DELETE /admin/sites/ # Unregister peer site +GET /admin/sites//health # Check peer health +GET /admin/topology # Get cluster topology + +# KMS API +GET /kms/keys # List KMS keys +POST /kms/keys # Create KMS key +GET /kms/keys/ # Get key details +DELETE /kms/keys/ # Schedule key deletion +POST /kms/keys//enable # Enable key +POST /kms/keys//disable # Disable key +POST /kms/keys//rotate # Rotate key material +POST /kms/encrypt # Encrypt data +POST /kms/decrypt # Decrypt data +POST /kms/generate-data-key # Generate data key +POST /kms/generate-random # Generate random bytes ``` + +## 15. Health Check Endpoint + +The API exposes a simple health check endpoint for monitoring and load balancer integration: + +```bash +# Check API health +curl http://localhost:5000/myfsio/health + +# Response +{"status": "ok", "version": "0.1.7"} +``` + +The response includes: +- `status`: Always `"ok"` when the server is running +- `version`: Current application version from `app/version.py` + +Use this endpoint for: +- Load balancer health checks +- Kubernetes liveness/readiness probes +- Monitoring system integration (Prometheus, Datadog, etc.) + +## 16. Object Lock & Retention + +Object Lock prevents objects from being deleted or overwritten for a specified retention period. MyFSIO supports both GOVERNANCE and COMPLIANCE modes. + +### Retention Modes + +| Mode | Description | +|------|-------------| +| **GOVERNANCE** | Objects can't be deleted by normal users, but users with `s3:BypassGovernanceRetention` permission can override | +| **COMPLIANCE** | Objects can't be deleted or overwritten by anyone, including root, until the retention period expires | + +### Enabling Object Lock + +Object Lock must be enabled when creating a bucket: + +```bash +# Create bucket with Object Lock enabled +curl -X PUT "http://localhost:5000/my-bucket" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-bucket-object-lock-enabled: true" + +# Set default retention configuration +curl -X PUT "http://localhost:5000/my-bucket?object-lock" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "ObjectLockEnabled": "Enabled", + "Rule": { + "DefaultRetention": { + "Mode": "GOVERNANCE", + "Days": 30 + } + } + }' +``` + +### Per-Object Retention + +Set retention on individual objects: + +```bash +# Set object retention +curl -X PUT "http://localhost:5000/my-bucket/important.pdf?retention" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "Mode": "COMPLIANCE", + "RetainUntilDate": "2025-12-31T23:59:59Z" + }' + +# Get object retention +curl "http://localhost:5000/my-bucket/important.pdf?retention" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +### Legal Hold + +Legal hold provides indefinite protection independent of retention settings: + +```bash +# Enable legal hold +curl -X PUT "http://localhost:5000/my-bucket/document.pdf?legal-hold" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{"Status": "ON"}' + +# Disable legal hold +curl -X PUT "http://localhost:5000/my-bucket/document.pdf?legal-hold" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{"Status": "OFF"}' + +# Check legal hold status +curl "http://localhost:5000/my-bucket/document.pdf?legal-hold" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +## 17. Access Logging + +Enable S3-style access logging to track all requests to your buckets. + +### Configuration + +```bash +# Enable access logging +curl -X PUT "http://localhost:5000/my-bucket?logging" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "LoggingEnabled": { + "TargetBucket": "log-bucket", + "TargetPrefix": "logs/my-bucket/" + } + }' + +# Get logging configuration +curl "http://localhost:5000/my-bucket?logging" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Disable logging (empty configuration) +curl -X PUT "http://localhost:5000/my-bucket?logging" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{}' +``` + +### Log Format + +Access logs are written in S3-compatible format with fields including: +- Timestamp, bucket, key +- Operation (REST.GET.OBJECT, REST.PUT.OBJECT, etc.) +- Request ID, requester, source IP +- HTTP status, error code, bytes sent +- Total time, turn-around time +- Referrer, User-Agent + +## 18. Bucket Notifications & Webhooks + +Configure event notifications to trigger webhooks when objects are created or deleted. + +### Supported Events + +| Event Type | Description | +|-----------|-------------| +| `s3:ObjectCreated:*` | Any object creation (PUT, POST, COPY, multipart) | +| `s3:ObjectCreated:Put` | Object created via PUT | +| `s3:ObjectCreated:Post` | Object created via POST | +| `s3:ObjectCreated:Copy` | Object created via COPY | +| `s3:ObjectCreated:CompleteMultipartUpload` | Multipart upload completed | +| `s3:ObjectRemoved:*` | Any object deletion | +| `s3:ObjectRemoved:Delete` | Object deleted | +| `s3:ObjectRemoved:DeleteMarkerCreated` | Delete marker created (versioned bucket) | + +### Configuration + +```bash +# Set notification configuration +curl -X PUT "http://localhost:5000/my-bucket?notification" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "TopicConfigurations": [ + { + "Id": "upload-notify", + "TopicArn": "https://webhook.example.com/s3-events", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": { + "FilterRules": [ + {"Name": "prefix", "Value": "uploads/"}, + {"Name": "suffix", "Value": ".jpg"} + ] + } + } + } + ] + }' + +# Get notification configuration +curl "http://localhost:5000/my-bucket?notification" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +### Webhook Payload + +The webhook receives a JSON payload similar to AWS S3 event notifications: + +```json +{ + "Records": [ + { + "eventVersion": "2.1", + "eventSource": "myfsio:s3", + "eventTime": "2024-01-15T10:30:00.000Z", + "eventName": "ObjectCreated:Put", + "s3": { + "bucket": {"name": "my-bucket"}, + "object": { + "key": "uploads/photo.jpg", + "size": 102400, + "eTag": "abc123..." + } + } + } + ] +} +``` + +### Security Notes + +- Webhook URLs are validated to prevent SSRF attacks +- Internal/private IP ranges are blocked by default +- Use HTTPS endpoints in production + +## 19. SelectObjectContent (SQL Queries) + +Query CSV, JSON, or Parquet files directly using SQL without downloading the entire object. Requires DuckDB to be installed. + +### Prerequisites + +```bash +pip install duckdb +``` + +### Usage + +```bash +# Query a CSV file +curl -X POST "http://localhost:5000/my-bucket/data.csv?select" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "Expression": "SELECT name, age FROM s3object WHERE age > 25", + "ExpressionType": "SQL", + "InputSerialization": { + "CSV": { + "FileHeaderInfo": "USE", + "FieldDelimiter": "," + } + }, + "OutputSerialization": { + "JSON": {} + } + }' + +# Query a JSON file +curl -X POST "http://localhost:5000/my-bucket/data.json?select" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "Expression": "SELECT * FROM s3object s WHERE s.status = '\"active'\"", + "ExpressionType": "SQL", + "InputSerialization": {"JSON": {"Type": "LINES"}}, + "OutputSerialization": {"JSON": {}} + }' +``` + +### Supported Input Formats + +| Format | Options | +|--------|---------| +| **CSV** | `FileHeaderInfo` (USE, IGNORE, NONE), `FieldDelimiter`, `QuoteCharacter`, `RecordDelimiter` | +| **JSON** | `Type` (DOCUMENT, LINES) | +| **Parquet** | Automatic schema detection | + +### Output Formats + +- **JSON**: Returns results as JSON records +- **CSV**: Returns results as CSV + +## 20. PostObject (HTML Form Upload) + +Upload objects using HTML forms with policy-based authorization. Useful for browser-based direct uploads. + +### Form Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `key` | Yes | Object key (can include `${filename}` placeholder) | +| `file` | Yes | The file to upload | +| `policy` | No | Base64-encoded policy document | +| `x-amz-signature` | No | Policy signature | +| `x-amz-credential` | No | Credential scope | +| `x-amz-algorithm` | No | Signing algorithm (AWS4-HMAC-SHA256) | +| `x-amz-date` | No | Request timestamp | +| `Content-Type` | No | MIME type of the file | +| `x-amz-meta-*` | No | Custom metadata | + +### Example HTML Form + +```html +
+ + + + + +
+``` + +### With Policy (Signed Upload) + +For authenticated uploads, include a policy document: + +```bash +# Generate policy and signature using boto3 or similar +# Then include in form: +# - policy: base64(policy_document) +# - x-amz-signature: HMAC-SHA256(policy, signing_key) +# - x-amz-credential: access_key/date/region/s3/aws4_request +# - x-amz-algorithm: AWS4-HMAC-SHA256 +# - x-amz-date: YYYYMMDDTHHMMSSZ +``` + +## 21. Advanced S3 Operations + +### CopyObject + +Copy objects within or between buckets: + +```bash +# Copy within same bucket +curl -X PUT "http://localhost:5000/my-bucket/copy-of-file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-copy-source: /my-bucket/original-file.txt" + +# Copy to different bucket +curl -X PUT "http://localhost:5000/other-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-copy-source: /my-bucket/original-file.txt" + +# Copy with metadata replacement +curl -X PUT "http://localhost:5000/my-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-copy-source: /my-bucket/file.txt" \ + -H "x-amz-metadata-directive: REPLACE" \ + -H "x-amz-meta-newkey: newvalue" +``` + +### UploadPartCopy + +Copy data from an existing object into a multipart upload part: + +```bash +# Initiate multipart upload +UPLOAD_ID=$(curl -X POST "http://localhost:5000/my-bucket/large-file.bin?uploads" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." | jq -r '.UploadId') + +# Copy bytes 0-10485759 from source as part 1 +curl -X PUT "http://localhost:5000/my-bucket/large-file.bin?uploadId=$UPLOAD_ID&partNumber=1" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-copy-source: /source-bucket/source-file.bin" \ + -H "x-amz-copy-source-range: bytes=0-10485759" + +# Copy bytes 10485760-20971519 as part 2 +curl -X PUT "http://localhost:5000/my-bucket/large-file.bin?uploadId=$UPLOAD_ID&partNumber=2" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-copy-source: /source-bucket/source-file.bin" \ + -H "x-amz-copy-source-range: bytes=10485760-20971519" +``` + +### Range Requests + +Download partial content using the Range header: + +```bash +# Get first 1000 bytes +curl "http://localhost:5000/my-bucket/large-file.bin" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "Range: bytes=0-999" + +# Get bytes 1000-1999 +curl "http://localhost:5000/my-bucket/large-file.bin" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "Range: bytes=1000-1999" + +# Get last 500 bytes +curl "http://localhost:5000/my-bucket/large-file.bin" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "Range: bytes=-500" + +# Get from byte 5000 to end +curl "http://localhost:5000/my-bucket/large-file.bin" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "Range: bytes=5000-" +``` + +Range responses include: +- HTTP 206 Partial Content status +- `Content-Range` header showing the byte range +- `Accept-Ranges: bytes` header + +### Conditional Requests + +Use conditional headers for cache validation: + +```bash +# Only download if modified since +curl "http://localhost:5000/my-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "If-Modified-Since: Wed, 15 Jan 2025 10:00:00 GMT" + +# Only download if ETag doesn't match (changed) +curl "http://localhost:5000/my-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "If-None-Match: \"abc123...\"" + +# Only download if ETag matches +curl "http://localhost:5000/my-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "If-Match: \"abc123...\"" +``` + +## 22. Access Control Lists (ACLs) + +ACLs provide legacy-style permission management for buckets and objects. + +### Canned ACLs + +| ACL | Description | +|-----|-------------| +| `private` | Owner gets FULL_CONTROL (default) | +| `public-read` | Owner FULL_CONTROL, public READ | +| `public-read-write` | Owner FULL_CONTROL, public READ and WRITE | +| `authenticated-read` | Owner FULL_CONTROL, authenticated users READ | + +### Setting ACLs + +```bash +# Set bucket ACL using canned ACL +curl -X PUT "http://localhost:5000/my-bucket?acl" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-acl: public-read" + +# Set object ACL +curl -X PUT "http://localhost:5000/my-bucket/file.txt?acl" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-acl: private" + +# Set ACL during upload +curl -X PUT "http://localhost:5000/my-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-acl: public-read" \ + --data-binary @file.txt + +# Get bucket ACL +curl "http://localhost:5000/my-bucket?acl" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Get object ACL +curl "http://localhost:5000/my-bucket/file.txt?acl" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +### ACL vs Bucket Policies + +- **ACLs**: Simple, limited options, legacy approach +- **Bucket Policies**: Powerful, flexible, recommended for new deployments + +For most use cases, prefer bucket policies over ACLs. + +## 23. Object & Bucket Tagging + +Add metadata tags to buckets and objects for organization, cost allocation, or lifecycle rule filtering. + +### Bucket Tagging + +```bash +# Set bucket tags +curl -X PUT "http://localhost:5000/my-bucket?tagging" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "TagSet": [ + {"Key": "Environment", "Value": "Production"}, + {"Key": "Team", "Value": "Engineering"} + ] + }' + +# Get bucket tags +curl "http://localhost:5000/my-bucket?tagging" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Delete bucket tags +curl -X DELETE "http://localhost:5000/my-bucket?tagging" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +### Object Tagging + +```bash +# Set object tags +curl -X PUT "http://localhost:5000/my-bucket/file.txt?tagging" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "TagSet": [ + {"Key": "Classification", "Value": "Confidential"}, + {"Key": "Owner", "Value": "john@example.com"} + ] + }' + +# Get object tags +curl "http://localhost:5000/my-bucket/file.txt?tagging" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Delete object tags +curl -X DELETE "http://localhost:5000/my-bucket/file.txt?tagging" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Set tags during upload +curl -X PUT "http://localhost:5000/my-bucket/file.txt" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -H "x-amz-tagging: Environment=Staging&Team=QA" \ + --data-binary @file.txt +``` + +### Tagging Limits + +- Maximum 50 tags per object (configurable via `OBJECT_TAG_LIMIT`) +- Tag key: 1-128 Unicode characters +- Tag value: 0-256 Unicode characters + +### Use Cases + +- **Lifecycle Rules**: Filter objects for expiration by tag +- **Access Control**: Use tag conditions in bucket policies +- **Cost Tracking**: Group objects by project or department +- **Automation**: Trigger actions based on object tags + +## 24. CORS Configuration + +Configure Cross-Origin Resource Sharing for browser-based applications. + +### Setting CORS Rules + +```bash +# Set CORS configuration +curl -X PUT "http://localhost:5000/my-bucket?cors" \ + -H "Content-Type: application/json" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." \ + -d '{ + "CORSRules": [ + { + "AllowedOrigins": ["https://example.com", "https://app.example.com"], + "AllowedMethods": ["GET", "PUT", "POST", "DELETE"], + "AllowedHeaders": ["*"], + "ExposeHeaders": ["ETag", "x-amz-meta-*"], + "MaxAgeSeconds": 3600 + } + ] + }' + +# Get CORS configuration +curl "http://localhost:5000/my-bucket?cors" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Delete CORS configuration +curl -X DELETE "http://localhost:5000/my-bucket?cors" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +### CORS Rule Fields + +| Field | Description | +|-------|-------------| +| `AllowedOrigins` | Origins allowed to access the bucket (required) | +| `AllowedMethods` | HTTP methods allowed (GET, PUT, POST, DELETE, HEAD) | +| `AllowedHeaders` | Request headers allowed in preflight | +| `ExposeHeaders` | Response headers visible to browser | +| `MaxAgeSeconds` | How long browser can cache preflight response | + +## 25. List Objects API v2 + +MyFSIO supports both ListBucketResult v1 and v2 APIs. + +### Using v2 API + +```bash +# List with v2 (supports continuation tokens) +curl "http://localhost:5000/my-bucket?list-type=2" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# With prefix and delimiter (folder-like listing) +curl "http://localhost:5000/my-bucket?list-type=2&prefix=photos/&delimiter=/" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Pagination with continuation token +curl "http://localhost:5000/my-bucket?list-type=2&max-keys=100&continuation-token=TOKEN" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." + +# Start after specific key +curl "http://localhost:5000/my-bucket?list-type=2&start-after=photos/2024/" \ + -H "X-Access-Key: ..." -H "X-Secret-Key: ..." +``` + +### v1 vs v2 Differences + +| Feature | v1 | v2 | +|---------|----|----| +| Pagination | `marker` | `continuation-token` | +| Start position | `marker` | `start-after` | +| Fetch owner info | Always included | Use `fetch-owner=true` | +| Max keys | 1000 | 1000 | + +### Query Parameters + +| Parameter | Description | +|-----------|-------------| +| `list-type` | Set to `2` for v2 API | +| `prefix` | Filter objects by key prefix | +| `delimiter` | Group objects (typically `/`) | +| `max-keys` | Maximum results (1-1000, default 1000) | +| `continuation-token` | Token from previous response | +| `start-after` | Start listing after this key | +| `fetch-owner` | Include owner info in response | +| `encoding-type` | Set to `url` for URL-encoded keys diff --git a/requirements.txt b/requirements.txt index 17915fa..1813b33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ boto3>=1.42.14 waitress>=3.0.2 psutil>=7.1.3 cryptography>=46.0.3 -defusedxml>=0.7.1 \ No newline at end of file +defusedxml>=0.7.1 +duckdb>=1.4.4 \ No newline at end of file diff --git a/static/css/main.css b/static/css/main.css index 584ba93..89a2c5f 100644 --- a/static/css/main.css +++ b/static/css/main.css @@ -1081,11 +1081,17 @@ html.sidebar-will-collapse .sidebar-user { letter-spacing: 0.08em; } +.main-content:has(.docs-sidebar) { + overflow-x: visible; +} + .docs-sidebar { position: sticky; top: 1.5rem; border-radius: 1rem; border: 1px solid var(--myfsio-card-border); + max-height: calc(100vh - 3rem); + overflow-y: auto; } .docs-sidebar-callouts { diff --git a/templates/base.html b/templates/base.html index 146acc7..7cd24be 100644 --- a/templates/base.html +++ b/templates/base.html @@ -94,6 +94,12 @@ Metrics + + + + + Sites + {% endif %} + + + + Skip for Now + + + + {% else %} +
+
+ + + +
+
No buckets yet
+

Create some buckets first, then come back to set up replication.

+ + Go to Buckets + +
+ {% endif %} + + + + + + +{% endblock %} diff --git a/templates/sites.html b/templates/sites.html new file mode 100644 index 0000000..3553e87 --- /dev/null +++ b/templates/sites.html @@ -0,0 +1,742 @@ +{% extends "base.html" %} + +{% block title %}Sites - S3 Compatible Storage{% endblock %} + +{% block content %} + + +
+
+
+
+
+ + + + Local Site Identity +
+

This site's configuration

+
+
+
+ +
+ + +
Unique identifier for this site
+
+
+ + +
Public URL for this site
+
+
+ + +
+
+
+ + +
Lower = preferred
+
+
+ + +
+
+
+ +
+
+
+
+ +
+
+
+ + + + Add Peer Site +
+

Register a remote site

+
+
+
+ +
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+ + +
+
+
+ + +
Link to a remote connection for health checks
+
+
+ +
+
+
+
+
+ +
+
+
+
+
+ + + + Peer Sites +
+

Known remote sites in the cluster

+
+
+
+ {% if peers %} +
+ + + + + + + + + + + + + + {% for item in peers_with_stats %} + {% set peer = item.peer %} + + + + + + + + + + {% endfor %} + +
HealthSite IDEndpointRegionPrioritySync StatusActions
+ + {% if peer.is_healthy == true %} + + + + {% elif peer.is_healthy == false %} + + + + {% else %} + + + + + {% endif %} + + +
+
+ + + +
+
+ {{ peer.display_name or peer.site_id }} + {% if peer.display_name and peer.display_name != peer.site_id %} +
{{ peer.site_id }} + {% endif %} +
+
+
+ {{ peer.endpoint }} + {{ peer.region }}{{ peer.priority }} + {% if item.has_connection %} +
+ {{ item.buckets_syncing }} bucket{{ 's' if item.buckets_syncing != 1 else '' }} + {% if item.has_bidirectional %} + + + + + + {% endif %} + {% if item.buckets_syncing > 0 %} + + {% endif %} +
+
+ +
+ {% else %} + No connection + {% endif %} +
+
+ + + + + + + + + + +
+
+
+ {% else %} +
+
+ + + +
+
No peer sites yet
+

Add peer sites to enable geo-distribution and site-to-site replication.

+
+ {% endif %} +
+
+
+
+ + + + + + + + +{% endblock %} diff --git a/tests/test_security.py b/tests/test_security.py deleted file mode 100644 index 6337bc3..0000000 --- a/tests/test_security.py +++ /dev/null @@ -1,191 +0,0 @@ -import hashlib -import hmac -import pytest -from datetime import datetime, timedelta, timezone -from urllib.parse import quote - -def _sign(key, msg): - return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() - -def _get_signature_key(key, date_stamp, region_name, service_name): - k_date = _sign(("AWS4" + key).encode("utf-8"), date_stamp) - k_region = _sign(k_date, region_name) - k_service = _sign(k_region, service_name) - k_signing = _sign(k_service, "aws4_request") - return k_signing - -def create_signed_headers( - method, - path, - headers=None, - body=None, - access_key="test", - secret_key="secret", - region="us-east-1", - service="s3", - timestamp=None -): - if headers is None: - headers = {} - - if timestamp is None: - now = datetime.now(timezone.utc) - else: - now = timestamp - - amz_date = now.strftime("%Y%m%dT%H%M%SZ") - date_stamp = now.strftime("%Y%m%d") - - headers["X-Amz-Date"] = amz_date - headers["Host"] = "testserver" - - canonical_uri = quote(path, safe="/-_.~") - canonical_query_string = "" - - canonical_headers = "" - signed_headers_list = [] - for k, v in sorted(headers.items(), key=lambda x: x[0].lower()): - canonical_headers += f"{k.lower()}:{v.strip()}\n" - signed_headers_list.append(k.lower()) - - signed_headers = ";".join(signed_headers_list) - - payload_hash = hashlib.sha256(body or b"").hexdigest() - headers["X-Amz-Content-Sha256"] = payload_hash - - canonical_request = f"{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" - - credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" - string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" - - signing_key = _get_signature_key(secret_key, date_stamp, region, service) - signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() - - headers["Authorization"] = ( - f"AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, " - f"SignedHeaders={signed_headers}, Signature={signature}" - ) - return headers - -def test_sigv4_old_date(client): - # Test with a date 20 minutes in the past - old_time = datetime.now(timezone.utc) - timedelta(minutes=20) - headers = create_signed_headers("GET", "/", timestamp=old_time) - - response = client.get("/", headers=headers) - assert response.status_code == 403 - assert b"Request timestamp too old" in response.data - -def test_sigv4_future_date(client): - # Test with a date 20 minutes in the future - future_time = datetime.now(timezone.utc) + timedelta(minutes=20) - headers = create_signed_headers("GET", "/", timestamp=future_time) - - response = client.get("/", headers=headers) - assert response.status_code == 403 - assert b"Request timestamp too old" in response.data # The error message is the same - -def test_path_traversal_in_key(client, signer): - headers = signer("PUT", "/test-bucket") - client.put("/test-bucket", headers=headers) - - # Try to upload with .. in key - headers = signer("PUT", "/test-bucket/../secret.txt", body=b"attack") - response = client.put("/test-bucket/../secret.txt", headers=headers, data=b"attack") - - # Should be rejected by storage layer or flask routing - # Flask might normalize it before it reaches the app, but if it reaches, it should fail. - # If Flask normalizes /test-bucket/../secret.txt to /secret.txt, then it hits 404 (bucket not found) or 403. - # But we want to test the storage layer check. - # We can try to encode the dots? - - # If we use a key that doesn't get normalized by Flask routing easily. - # But wait, the route is // - # If I send /test-bucket/folder/../file.txt, Flask might pass "folder/../file.txt" as object_key? - # Let's try. - - headers = signer("PUT", "/test-bucket/folder/../file.txt", body=b"attack") - response = client.put("/test-bucket/folder/../file.txt", headers=headers, data=b"attack") - - # If Flask normalizes it, it becomes /test-bucket/file.txt. - # If it doesn't, it hits our check. - - # Let's try to call the storage method directly to verify the check works, - # because testing via client depends on Flask's URL handling. - pass - -def test_storage_path_traversal(app): - storage = app.extensions["object_storage"] - from app.storage import StorageError, ObjectStorage - from app.encrypted_storage import EncryptedObjectStorage - - # Get the underlying ObjectStorage if wrapped - if isinstance(storage, EncryptedObjectStorage): - storage = storage.storage - - with pytest.raises(StorageError, match="Object key contains parent directory references"): - storage._sanitize_object_key("folder/../file.txt") - - with pytest.raises(StorageError, match="Object key contains parent directory references"): - storage._sanitize_object_key("..") - -def test_head_bucket(client, signer): - headers = signer("PUT", "/head-test") - client.put("/head-test", headers=headers) - - headers = signer("HEAD", "/head-test") - response = client.head("/head-test", headers=headers) - assert response.status_code == 200 - - headers = signer("HEAD", "/non-existent") - response = client.head("/non-existent", headers=headers) - assert response.status_code == 404 - -def test_head_object(client, signer): - headers = signer("PUT", "/head-obj-test") - client.put("/head-obj-test", headers=headers) - - headers = signer("PUT", "/head-obj-test/obj", body=b"content") - client.put("/head-obj-test/obj", headers=headers, data=b"content") - - headers = signer("HEAD", "/head-obj-test/obj") - response = client.head("/head-obj-test/obj", headers=headers) - assert response.status_code == 200 - assert response.headers["ETag"] - assert response.headers["Content-Length"] == "7" - - headers = signer("HEAD", "/head-obj-test/missing") - response = client.head("/head-obj-test/missing", headers=headers) - assert response.status_code == 404 - -def test_list_parts(client, signer): - # Create bucket - headers = signer("PUT", "/multipart-test") - client.put("/multipart-test", headers=headers) - - # Initiate multipart upload - headers = signer("POST", "/multipart-test/obj?uploads") - response = client.post("/multipart-test/obj?uploads", headers=headers) - assert response.status_code == 200 - from xml.etree.ElementTree import fromstring - upload_id = fromstring(response.data).find("UploadId").text - - # Upload part 1 - headers = signer("PUT", f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", body=b"part1") - client.put(f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", headers=headers, data=b"part1") - - # Upload part 2 - headers = signer("PUT", f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", body=b"part2") - client.put(f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", headers=headers, data=b"part2") - - # List parts - headers = signer("GET", f"/multipart-test/obj?uploadId={upload_id}") - response = client.get(f"/multipart-test/obj?uploadId={upload_id}", headers=headers) - assert response.status_code == 200 - - root = fromstring(response.data) - assert root.tag == "ListPartsResult" - parts = root.findall("Part") - assert len(parts) == 2 - assert parts[0].find("PartNumber").text == "1" - assert parts[1].find("PartNumber").text == "2" diff --git a/tests/test_site_sync.py b/tests/test_site_sync.py index 4975375..405d6e4 100644 --- a/tests/test_site_sync.py +++ b/tests/test_site_sync.py @@ -20,7 +20,6 @@ from app.site_sync import ( SyncedObjectInfo, SiteSyncStats, RemoteObjectMeta, - CLOCK_SKEW_TOLERANCE_SECONDS, ) from app.storage import ObjectStorage