diff --git a/app/admin_api.py b/app/admin_api.py index c554579..49eb459 100644 --- a/app/admin_api.py +++ b/app/admin_api.py @@ -1,8 +1,12 @@ from __future__ import annotations +import ipaddress import logging +import re +import socket import time from typing import Any, Dict, Optional, Tuple +from urllib.parse import urlparse import requests from flask import Blueprint, Response, current_app, jsonify, request @@ -13,6 +17,67 @@ from .iam import IamError, Principal from .replication import ReplicationManager from .site_registry import PeerSite, SiteInfo, SiteRegistry + +def _is_safe_url(url: str) -> bool: + """Check if a URL is safe to make requests to (not internal/private).""" + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return False + blocked_hosts = { + "localhost", + "127.0.0.1", + "0.0.0.0", + "::1", + "[::1]", + "metadata.google.internal", + "169.254.169.254", + } + if hostname.lower() in blocked_hosts: + return False + try: + resolved_ip = socket.gethostbyname(hostname) + ip = ipaddress.ip_address(resolved_ip) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return False + except (socket.gaierror, ValueError): + return False + return True + except Exception: + return False + + +def _validate_endpoint(endpoint: str) -> Optional[str]: + """Validate endpoint URL format. Returns error message or None.""" + try: + parsed = urlparse(endpoint) + if not parsed.scheme or parsed.scheme not in ("http", "https"): + return "Endpoint must be http or https URL" + if not parsed.netloc: + return "Endpoint must have a host" + return None + except Exception: + return "Invalid endpoint URL" + + +def _validate_priority(priority: Any) -> Optional[str]: + """Validate priority value. Returns error message or None.""" + try: + p = int(priority) + if p < 0 or p > 1000: + return "Priority must be between 0 and 1000" + return None + except (TypeError, ValueError): + return "Priority must be an integer" + + +def _validate_region(region: str) -> Optional[str]: + """Validate region format. Returns error message or None.""" + if not re.match(r"^[a-z]{2,}-[a-z]+-\d+$", region): + return "Region must match format like us-east-1" + return None + logger = logging.getLogger(__name__) admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin") @@ -158,6 +223,20 @@ def register_peer_site(): if not endpoint: return _json_error("ValidationError", "endpoint is required", 400) + endpoint_error = _validate_endpoint(endpoint) + if endpoint_error: + return _json_error("ValidationError", endpoint_error, 400) + + region = payload.get("region", "us-east-1") + region_error = _validate_region(region) + if region_error: + return _json_error("ValidationError", region_error, 400) + + priority = payload.get("priority", 100) + priority_error = _validate_priority(priority) + if priority_error: + return _json_error("ValidationError", priority_error, 400) + registry = _site_registry() if registry.get_peer(site_id): @@ -171,8 +250,8 @@ def register_peer_site(): peer = PeerSite( site_id=site_id, endpoint=endpoint, - region=payload.get("region", "us-east-1"), - priority=payload.get("priority", 100), + region=region, + priority=int(priority), display_name=payload.get("display_name", site_id), connection_id=connection_id, ) @@ -411,6 +490,14 @@ def check_bidirectional_status(site_id: str): }) return jsonify(result) + if not _is_safe_url(peer.endpoint): + result["issues"].append({ + "code": "ENDPOINT_NOT_ALLOWED", + "message": "Peer endpoint points to internal or private address", + "severity": "error", + }) + return jsonify(result) + try: admin_url = peer.endpoint.rstrip("/") + "/admin/sites" resp = requests.get( @@ -494,20 +581,21 @@ def check_bidirectional_status(site_id: str): "severity": "warning", }) except requests.RequestException as e: + logger.warning("Remote admin API unreachable: %s", e) result["remote_status"] = { "reachable": False, - "error": str(e), + "error": "Connection failed", } result["issues"].append({ "code": "REMOTE_ADMIN_UNREACHABLE", - "message": f"Could not reach remote admin API: {e}", + "message": "Could not reach remote admin API", "severity": "warning", }) except Exception as e: - logger.warning(f"Error checking remote bidirectional status: {e}") + logger.warning("Error checking remote bidirectional status: %s", e, exc_info=True) result["issues"].append({ "code": "VERIFICATION_ERROR", - "message": f"Error during verification: {e}", + "message": "Internal error during verification", "severity": "warning", }) diff --git a/app/config.py b/app/config.py index 8b779a6..1b8d083 100644 --- a/app/config.py +++ b/app/config.py @@ -146,6 +146,8 @@ class AppConfig: site_region: str site_priority: int ratelimit_admin: str + num_trusted_proxies: int + allowed_redirect_hosts: list[str] @classmethod def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig": @@ -310,6 +312,9 @@ class AppConfig: site_region = str(_get("SITE_REGION", "us-east-1")) site_priority = int(_get("SITE_PRIORITY", 100)) ratelimit_admin = _validate_rate_limit(str(_get("RATE_LIMIT_ADMIN", "60 per minute"))) + num_trusted_proxies = int(_get("NUM_TRUSTED_PROXIES", 0)) + allowed_redirect_hosts_raw = _get("ALLOWED_REDIRECT_HOSTS", "") + allowed_redirect_hosts = [h.strip() for h in str(allowed_redirect_hosts_raw).split(",") if h.strip()] return cls(storage_root=storage_root, max_upload_size=max_upload_size, @@ -393,7 +398,9 @@ class AppConfig: site_endpoint=site_endpoint, site_region=site_region, site_priority=site_priority, - ratelimit_admin=ratelimit_admin) + ratelimit_admin=ratelimit_admin, + num_trusted_proxies=num_trusted_proxies, + allowed_redirect_hosts=allowed_redirect_hosts) def validate_and_report(self) -> list[str]: """Validate configuration and return a list of warnings/issues. @@ -598,4 +605,6 @@ class AppConfig: "SITE_REGION": self.site_region, "SITE_PRIORITY": self.site_priority, "RATE_LIMIT_ADMIN": self.ratelimit_admin, + "NUM_TRUSTED_PROXIES": self.num_trusted_proxies, + "ALLOWED_REDIRECT_HOSTS": self.allowed_redirect_hosts, } diff --git a/app/encryption.py b/app/encryption.py index e490a0d..d9c1679 100644 --- a/app/encryption.py +++ b/app/encryption.py @@ -6,6 +6,7 @@ import io import json import os import secrets +import subprocess import sys from dataclasses import dataclass from pathlib import Path @@ -15,6 +16,26 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives import hashes +if sys.platform != "win32": + import fcntl + + +def _set_secure_file_permissions(file_path: Path) -> None: + """Set restrictive file permissions (owner read/write only).""" + if sys.platform == "win32": + try: + username = os.environ.get("USERNAME", "") + if username: + subprocess.run( + ["icacls", str(file_path), "/inheritance:r", + "/grant:r", f"{username}:F"], + check=True, capture_output=True + ) + except (subprocess.SubprocessError, OSError): + pass + else: + os.chmod(file_path, 0o600) + class EncryptionError(Exception): """Raised when encryption/decryption fails.""" @@ -103,22 +124,38 @@ class LocalKeyEncryption(EncryptionProvider): return self._master_key def _load_or_create_master_key(self) -> bytes: - """Load master key from file or generate a new one.""" - if self.master_key_path.exists(): - try: - return base64.b64decode(self.master_key_path.read_text().strip()) - except Exception as exc: - raise EncryptionError(f"Failed to load master key: {exc}") from exc - - key = secrets.token_bytes(32) + """Load master key from file or generate a new one (with file locking).""" + lock_path = self.master_key_path.with_suffix(".lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + try: - self.master_key_path.parent.mkdir(parents=True, exist_ok=True) - self.master_key_path.write_text(base64.b64encode(key).decode()) - if sys.platform != "win32": - os.chmod(self.master_key_path, 0o600) + with open(lock_path, "w") as lock_file: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + try: + if self.master_key_path.exists(): + try: + return base64.b64decode(self.master_key_path.read_text().strip()) + except Exception as exc: + raise EncryptionError(f"Failed to load master key: {exc}") from exc + key = secrets.token_bytes(32) + try: + self.master_key_path.write_text(base64.b64encode(key).decode()) + _set_secure_file_permissions(self.master_key_path) + except OSError as exc: + raise EncryptionError(f"Failed to save master key: {exc}") from exc + return key + finally: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) except OSError as exc: - raise EncryptionError(f"Failed to save master key: {exc}") from exc - return key + raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc def _encrypt_data_key(self, data_key: bytes) -> bytes: """Encrypt the data key with the master key.""" diff --git a/app/iam.py b/app/iam.py index 2098f8a..caf6b07 100644 --- a/app/iam.py +++ b/app/iam.py @@ -1,9 +1,11 @@ from __future__ import annotations +import hashlib import hmac import json import math import secrets +import threading import time from collections import deque from dataclasses import dataclass @@ -118,12 +120,14 @@ class IamService: self._raw_config: Dict[str, Any] = {} self._failed_attempts: Dict[str, Deque[datetime]] = {} self._last_load_time = 0.0 - self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {} - self._cache_ttl = 10.0 + self._principal_cache: Dict[str, Tuple[Principal, float]] = {} + self._cache_ttl = 10.0 self._last_stat_check = 0.0 self._stat_check_interval = 1.0 self._sessions: Dict[str, Dict[str, Any]] = {} + self._session_lock = threading.Lock() self._load() + self._load_lockout_state() def _maybe_reload(self) -> None: """Reload configuration if the file has changed on disk.""" @@ -134,7 +138,7 @@ class IamService: try: if self.config_path.stat().st_mtime > self._last_load_time: self._load() - self._credential_cache.clear() + self._principal_cache.clear() except OSError: pass @@ -163,11 +167,46 @@ class IamService: attempts = self._failed_attempts.setdefault(access_key, deque()) self._prune_attempts(attempts) attempts.append(datetime.now(timezone.utc)) + self._save_lockout_state() def _clear_failed_attempts(self, access_key: str) -> None: if not access_key: return - self._failed_attempts.pop(access_key, None) + if self._failed_attempts.pop(access_key, None) is not None: + self._save_lockout_state() + + def _lockout_file(self) -> Path: + return self.config_path.parent / "lockout_state.json" + + def _load_lockout_state(self) -> None: + """Load lockout state from disk.""" + try: + if self._lockout_file().exists(): + data = json.loads(self._lockout_file().read_text(encoding="utf-8")) + cutoff = datetime.now(timezone.utc) - self.auth_lockout_window + for key, timestamps in data.get("failed_attempts", {}).items(): + valid = [] + for ts in timestamps: + try: + dt = datetime.fromisoformat(ts) + if dt > cutoff: + valid.append(dt) + except (ValueError, TypeError): + continue + if valid: + self._failed_attempts[key] = deque(valid) + except (OSError, json.JSONDecodeError): + pass + + def _save_lockout_state(self) -> None: + """Persist lockout state to disk.""" + data: Dict[str, Any] = {"failed_attempts": {}} + for key, attempts in self._failed_attempts.items(): + data["failed_attempts"][key] = [ts.isoformat() for ts in attempts] + try: + self._lockout_file().write_text(json.dumps(data), encoding="utf-8") + except OSError: + pass def _prune_attempts(self, attempts: Deque[datetime]) -> None: cutoff = datetime.now(timezone.utc) - self.auth_lockout_window @@ -210,17 +249,23 @@ class IamService: return token def validate_session_token(self, access_key: str, session_token: str) -> bool: - """Validate a session token for an access key.""" - session = self._sessions.get(session_token) - if not session: - hmac.compare_digest(access_key, secrets.token_urlsafe(16)) - return False - if not hmac.compare_digest(session["access_key"], access_key): - return False - if time.time() > session["expires_at"]: - del self._sessions[session_token] - return False - return True + """Validate a session token for an access key (thread-safe, constant-time).""" + dummy_key = secrets.token_urlsafe(16) + dummy_token = secrets.token_urlsafe(32) + with self._session_lock: + session = self._sessions.get(session_token) + if not session: + hmac.compare_digest(access_key, dummy_key) + hmac.compare_digest(session_token, dummy_token) + return False + key_match = hmac.compare_digest(session["access_key"], access_key) + if not key_match: + hmac.compare_digest(session_token, dummy_token) + return False + if time.time() > session["expires_at"]: + self._sessions.pop(session_token, None) + return False + return True def _cleanup_expired_sessions(self) -> None: """Remove expired session tokens.""" @@ -231,9 +276,9 @@ class IamService: def principal_for_key(self, access_key: str) -> Principal: now = time.time() - cached = self._credential_cache.get(access_key) + cached = self._principal_cache.get(access_key) if cached: - secret, principal, cached_time = cached + principal, cached_time = cached if now - cached_time < self._cache_ttl: return principal @@ -242,23 +287,14 @@ class IamService: if not record: raise IamError("Unknown access key") principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) + self._principal_cache[access_key] = (principal, now) return principal def secret_for_key(self, access_key: str) -> str: - now = time.time() - cached = self._credential_cache.get(access_key) - if cached: - secret, principal, cached_time = cached - if now - cached_time < self._cache_ttl: - return secret - self._maybe_reload() record = self._users.get(access_key) if not record: raise IamError("Unknown access key") - principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) return record["secret_key"] def authorize(self, principal: Principal, bucket_name: str | None, action: str) -> None: @@ -330,6 +366,7 @@ class IamService: new_secret = self._generate_secret_key() user["secret_key"] = new_secret self._save() + self._principal_cache.pop(access_key, None) self._load() return new_secret @@ -509,26 +546,17 @@ class IamService: raise IamError("User not found") def get_secret_key(self, access_key: str) -> str | None: - now = time.time() - cached = self._credential_cache.get(access_key) - if cached: - secret, principal, cached_time = cached - if now - cached_time < self._cache_ttl: - return secret - self._maybe_reload() record = self._users.get(access_key) if record: - principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) return record["secret_key"] return None def get_principal(self, access_key: str) -> Principal | None: now = time.time() - cached = self._credential_cache.get(access_key) + cached = self._principal_cache.get(access_key) if cached: - secret, principal, cached_time = cached + principal, cached_time = cached if now - cached_time < self._cache_ttl: return principal @@ -536,6 +564,6 @@ class IamService: record = self._users.get(access_key) if record: principal = self._build_principal(access_key, record) - self._credential_cache[access_key] = (record["secret_key"], principal, now) + self._principal_cache[access_key] = (principal, now) return principal return None diff --git a/app/kms.py b/app/kms.py index 6928f67..884f975 100644 --- a/app/kms.py +++ b/app/kms.py @@ -5,6 +5,7 @@ import json import logging import os import secrets +import subprocess import sys import uuid from dataclasses import dataclass, field @@ -16,9 +17,29 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from .encryption import EncryptionError, EncryptionProvider, EncryptionResult +if sys.platform != "win32": + import fcntl + logger = logging.getLogger(__name__) +def _set_secure_file_permissions(file_path: Path) -> None: + """Set restrictive file permissions (owner read/write only).""" + if sys.platform == "win32": + try: + username = os.environ.get("USERNAME", "") + if username: + subprocess.run( + ["icacls", str(file_path), "/inheritance:r", + "/grant:r", f"{username}:F"], + check=True, capture_output=True + ) + except (subprocess.SubprocessError, OSError): + pass + else: + os.chmod(file_path, 0o600) + + @dataclass class KMSKey: """Represents a KMS encryption key.""" @@ -132,20 +153,33 @@ class KMSManager: @property def master_key(self) -> bytes: - """Load or create the master key for encrypting KMS keys.""" + """Load or create the master key for encrypting KMS keys (with file locking).""" if self._master_key is None: - if self.master_key_path.exists(): - self._master_key = base64.b64decode( - self.master_key_path.read_text().strip() - ) - else: - self._master_key = secrets.token_bytes(32) - self.master_key_path.parent.mkdir(parents=True, exist_ok=True) - self.master_key_path.write_text( - base64.b64encode(self._master_key).decode() - ) - if sys.platform != "win32": - os.chmod(self.master_key_path, 0o600) + lock_path = self.master_key_path.with_suffix(".lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + with open(lock_path, "w") as lock_file: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + try: + if self.master_key_path.exists(): + self._master_key = base64.b64decode( + self.master_key_path.read_text().strip() + ) + else: + self._master_key = secrets.token_bytes(32) + self.master_key_path.write_text( + base64.b64encode(self._master_key).decode() + ) + _set_secure_file_permissions(self.master_key_path) + finally: + if sys.platform == "win32": + import msvcrt + msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) + else: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) return self._master_key def _load_keys(self) -> None: @@ -177,12 +211,13 @@ class KMSManager: encrypted = self._encrypt_key_material(key.key_material) data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode() keys_data.append(data) - + self.keys_path.parent.mkdir(parents=True, exist_ok=True) self.keys_path.write_text( json.dumps({"keys": keys_data}, indent=2), encoding="utf-8" ) + _set_secure_file_permissions(self.keys_path) def _encrypt_key_material(self, key_material: bytes) -> bytes: """Encrypt key material with the master key.""" diff --git a/app/notifications.py b/app/notifications.py index c449088..46eb165 100644 --- a/app/notifications.py +++ b/app/notifications.py @@ -1,8 +1,10 @@ from __future__ import annotations +import ipaddress import json import logging import queue +import socket import threading import time import uuid @@ -14,6 +16,36 @@ from urllib.parse import urlparse import requests + +def _is_safe_url(url: str) -> bool: + """Check if a URL is safe to make requests to (not internal/private).""" + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return False + blocked_hosts = { + "localhost", + "127.0.0.1", + "0.0.0.0", + "::1", + "[::1]", + "metadata.google.internal", + "169.254.169.254", + } + if hostname.lower() in blocked_hosts: + return False + try: + resolved_ip = socket.gethostbyname(hostname) + ip = ipaddress.ip_address(resolved_ip) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return False + except (socket.gaierror, ValueError): + return False + return True + except Exception: + return False + logger = logging.getLogger(__name__) @@ -299,6 +331,8 @@ class NotificationService: self._queue.task_done() def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None: + if not _is_safe_url(destination.url): + raise RuntimeError(f"Blocked request to internal/private URL: {destination.url}") payload = event.to_s3_event() headers = {"Content-Type": "application/json", **destination.headers} diff --git a/app/s3_api.py b/app/s3_api.py index ab410ea..f784e38 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -60,10 +60,13 @@ def _build_policy_context() -> Dict[str, Any]: ctx: Dict[str, Any] = {} if request.headers.get("Referer"): ctx["aws:Referer"] = request.headers.get("Referer") - if request.access_route: - ctx["aws:SourceIp"] = request.access_route[0] + num_proxies = current_app.config.get("NUM_TRUSTED_PROXIES", 0) + if num_proxies > 0 and request.access_route and len(request.access_route) > num_proxies: + ctx["aws:SourceIp"] = request.access_route[-num_proxies] elif request.remote_addr: ctx["aws:SourceIp"] = request.remote_addr + elif request.access_route: + ctx["aws:SourceIp"] = request.access_route[0] ctx["aws:SecureTransport"] = str(request.is_secure).lower() if request.headers.get("User-Agent"): ctx["aws:UserAgent"] = request.headers.get("User-Agent") @@ -2242,6 +2245,17 @@ def _post_object(bucket_name: str) -> Response: expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest() if not hmac.compare_digest(expected_signature, signature): return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403) + principal = _iam().get_principal(access_key) + if not principal: + return _error_response("AccessDenied", "Invalid access key", 403) + if "${filename}" in object_key: + temp_key = object_key.replace("${filename}", request.files.get("file").filename if request.files.get("file") else "upload") + else: + temp_key = object_key + try: + _authorize_action(principal, bucket_name, "write", object_key=temp_key) + except IamError as exc: + return _error_response("AccessDenied", str(exc), 403) file = request.files.get("file") if not file: return _error_response("InvalidArgument", "Missing file field", 400) @@ -2263,6 +2277,12 @@ def _post_object(bucket_name: str) -> Response: success_action_status = request.form.get("success_action_status", "204") success_action_redirect = request.form.get("success_action_redirect") if success_action_redirect: + allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", []) + parsed = urlparse(success_action_redirect) + if parsed.scheme not in ("http", "https"): + return _error_response("InvalidArgument", "Redirect URL must use http or https", 400) + if allowed_hosts and parsed.netloc not in allowed_hosts: + return _error_response("InvalidArgument", "Redirect URL host not allowed", 400) redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}" return Response(status=303, headers={"Location": redirect_url}) if success_action_status == "200": diff --git a/tests/test_security.py b/tests/test_security.py deleted file mode 100644 index 6337bc3..0000000 --- a/tests/test_security.py +++ /dev/null @@ -1,191 +0,0 @@ -import hashlib -import hmac -import pytest -from datetime import datetime, timedelta, timezone -from urllib.parse import quote - -def _sign(key, msg): - return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() - -def _get_signature_key(key, date_stamp, region_name, service_name): - k_date = _sign(("AWS4" + key).encode("utf-8"), date_stamp) - k_region = _sign(k_date, region_name) - k_service = _sign(k_region, service_name) - k_signing = _sign(k_service, "aws4_request") - return k_signing - -def create_signed_headers( - method, - path, - headers=None, - body=None, - access_key="test", - secret_key="secret", - region="us-east-1", - service="s3", - timestamp=None -): - if headers is None: - headers = {} - - if timestamp is None: - now = datetime.now(timezone.utc) - else: - now = timestamp - - amz_date = now.strftime("%Y%m%dT%H%M%SZ") - date_stamp = now.strftime("%Y%m%d") - - headers["X-Amz-Date"] = amz_date - headers["Host"] = "testserver" - - canonical_uri = quote(path, safe="/-_.~") - canonical_query_string = "" - - canonical_headers = "" - signed_headers_list = [] - for k, v in sorted(headers.items(), key=lambda x: x[0].lower()): - canonical_headers += f"{k.lower()}:{v.strip()}\n" - signed_headers_list.append(k.lower()) - - signed_headers = ";".join(signed_headers_list) - - payload_hash = hashlib.sha256(body or b"").hexdigest() - headers["X-Amz-Content-Sha256"] = payload_hash - - canonical_request = f"{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" - - credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" - string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" - - signing_key = _get_signature_key(secret_key, date_stamp, region, service) - signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() - - headers["Authorization"] = ( - f"AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, " - f"SignedHeaders={signed_headers}, Signature={signature}" - ) - return headers - -def test_sigv4_old_date(client): - # Test with a date 20 minutes in the past - old_time = datetime.now(timezone.utc) - timedelta(minutes=20) - headers = create_signed_headers("GET", "/", timestamp=old_time) - - response = client.get("/", headers=headers) - assert response.status_code == 403 - assert b"Request timestamp too old" in response.data - -def test_sigv4_future_date(client): - # Test with a date 20 minutes in the future - future_time = datetime.now(timezone.utc) + timedelta(minutes=20) - headers = create_signed_headers("GET", "/", timestamp=future_time) - - response = client.get("/", headers=headers) - assert response.status_code == 403 - assert b"Request timestamp too old" in response.data # The error message is the same - -def test_path_traversal_in_key(client, signer): - headers = signer("PUT", "/test-bucket") - client.put("/test-bucket", headers=headers) - - # Try to upload with .. in key - headers = signer("PUT", "/test-bucket/../secret.txt", body=b"attack") - response = client.put("/test-bucket/../secret.txt", headers=headers, data=b"attack") - - # Should be rejected by storage layer or flask routing - # Flask might normalize it before it reaches the app, but if it reaches, it should fail. - # If Flask normalizes /test-bucket/../secret.txt to /secret.txt, then it hits 404 (bucket not found) or 403. - # But we want to test the storage layer check. - # We can try to encode the dots? - - # If we use a key that doesn't get normalized by Flask routing easily. - # But wait, the route is // - # If I send /test-bucket/folder/../file.txt, Flask might pass "folder/../file.txt" as object_key? - # Let's try. - - headers = signer("PUT", "/test-bucket/folder/../file.txt", body=b"attack") - response = client.put("/test-bucket/folder/../file.txt", headers=headers, data=b"attack") - - # If Flask normalizes it, it becomes /test-bucket/file.txt. - # If it doesn't, it hits our check. - - # Let's try to call the storage method directly to verify the check works, - # because testing via client depends on Flask's URL handling. - pass - -def test_storage_path_traversal(app): - storage = app.extensions["object_storage"] - from app.storage import StorageError, ObjectStorage - from app.encrypted_storage import EncryptedObjectStorage - - # Get the underlying ObjectStorage if wrapped - if isinstance(storage, EncryptedObjectStorage): - storage = storage.storage - - with pytest.raises(StorageError, match="Object key contains parent directory references"): - storage._sanitize_object_key("folder/../file.txt") - - with pytest.raises(StorageError, match="Object key contains parent directory references"): - storage._sanitize_object_key("..") - -def test_head_bucket(client, signer): - headers = signer("PUT", "/head-test") - client.put("/head-test", headers=headers) - - headers = signer("HEAD", "/head-test") - response = client.head("/head-test", headers=headers) - assert response.status_code == 200 - - headers = signer("HEAD", "/non-existent") - response = client.head("/non-existent", headers=headers) - assert response.status_code == 404 - -def test_head_object(client, signer): - headers = signer("PUT", "/head-obj-test") - client.put("/head-obj-test", headers=headers) - - headers = signer("PUT", "/head-obj-test/obj", body=b"content") - client.put("/head-obj-test/obj", headers=headers, data=b"content") - - headers = signer("HEAD", "/head-obj-test/obj") - response = client.head("/head-obj-test/obj", headers=headers) - assert response.status_code == 200 - assert response.headers["ETag"] - assert response.headers["Content-Length"] == "7" - - headers = signer("HEAD", "/head-obj-test/missing") - response = client.head("/head-obj-test/missing", headers=headers) - assert response.status_code == 404 - -def test_list_parts(client, signer): - # Create bucket - headers = signer("PUT", "/multipart-test") - client.put("/multipart-test", headers=headers) - - # Initiate multipart upload - headers = signer("POST", "/multipart-test/obj?uploads") - response = client.post("/multipart-test/obj?uploads", headers=headers) - assert response.status_code == 200 - from xml.etree.ElementTree import fromstring - upload_id = fromstring(response.data).find("UploadId").text - - # Upload part 1 - headers = signer("PUT", f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", body=b"part1") - client.put(f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", headers=headers, data=b"part1") - - # Upload part 2 - headers = signer("PUT", f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", body=b"part2") - client.put(f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", headers=headers, data=b"part2") - - # List parts - headers = signer("GET", f"/multipart-test/obj?uploadId={upload_id}") - response = client.get(f"/multipart-test/obj?uploadId={upload_id}", headers=headers) - assert response.status_code == 200 - - root = fromstring(response.data) - assert root.tag == "ListPartsResult" - parts = root.findall("Part") - assert len(parts) == 2 - assert parts[0].find("PartNumber").text == "1" - assert parts[1].find("PartNumber").text == "2"