Fix 15 security vulnerabilities across auth, storage, and API modules

This commit is contained in:
2026-01-31 00:55:27 +08:00
parent 9385d1fe1c
commit 8c4bf67974
8 changed files with 327 additions and 267 deletions

View File

@@ -1,8 +1,12 @@
from __future__ import annotations
import ipaddress
import logging
import re
import socket
import time
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse
import requests
from flask import Blueprint, Response, current_app, jsonify, request
@@ -13,6 +17,67 @@ from .iam import IamError, Principal
from .replication import ReplicationManager
from .site_registry import PeerSite, SiteInfo, SiteRegistry
def _is_safe_url(url: str) -> bool:
"""Check if a URL is safe to make requests to (not internal/private)."""
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
blocked_hosts = {
"localhost",
"127.0.0.1",
"0.0.0.0",
"::1",
"[::1]",
"metadata.google.internal",
"169.254.169.254",
}
if hostname.lower() in blocked_hosts:
return False
try:
resolved_ip = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(resolved_ip)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
return False
except (socket.gaierror, ValueError):
return False
return True
except Exception:
return False
def _validate_endpoint(endpoint: str) -> Optional[str]:
"""Validate endpoint URL format. Returns error message or None."""
try:
parsed = urlparse(endpoint)
if not parsed.scheme or parsed.scheme not in ("http", "https"):
return "Endpoint must be http or https URL"
if not parsed.netloc:
return "Endpoint must have a host"
return None
except Exception:
return "Invalid endpoint URL"
def _validate_priority(priority: Any) -> Optional[str]:
"""Validate priority value. Returns error message or None."""
try:
p = int(priority)
if p < 0 or p > 1000:
return "Priority must be between 0 and 1000"
return None
except (TypeError, ValueError):
return "Priority must be an integer"
def _validate_region(region: str) -> Optional[str]:
"""Validate region format. Returns error message or None."""
if not re.match(r"^[a-z]{2,}-[a-z]+-\d+$", region):
return "Region must match format like us-east-1"
return None
logger = logging.getLogger(__name__)
admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin")
@@ -158,6 +223,20 @@ def register_peer_site():
if not endpoint:
return _json_error("ValidationError", "endpoint is required", 400)
endpoint_error = _validate_endpoint(endpoint)
if endpoint_error:
return _json_error("ValidationError", endpoint_error, 400)
region = payload.get("region", "us-east-1")
region_error = _validate_region(region)
if region_error:
return _json_error("ValidationError", region_error, 400)
priority = payload.get("priority", 100)
priority_error = _validate_priority(priority)
if priority_error:
return _json_error("ValidationError", priority_error, 400)
registry = _site_registry()
if registry.get_peer(site_id):
@@ -171,8 +250,8 @@ def register_peer_site():
peer = PeerSite(
site_id=site_id,
endpoint=endpoint,
region=payload.get("region", "us-east-1"),
priority=payload.get("priority", 100),
region=region,
priority=int(priority),
display_name=payload.get("display_name", site_id),
connection_id=connection_id,
)
@@ -411,6 +490,14 @@ def check_bidirectional_status(site_id: str):
})
return jsonify(result)
if not _is_safe_url(peer.endpoint):
result["issues"].append({
"code": "ENDPOINT_NOT_ALLOWED",
"message": "Peer endpoint points to internal or private address",
"severity": "error",
})
return jsonify(result)
try:
admin_url = peer.endpoint.rstrip("/") + "/admin/sites"
resp = requests.get(
@@ -494,20 +581,21 @@ def check_bidirectional_status(site_id: str):
"severity": "warning",
})
except requests.RequestException as e:
logger.warning("Remote admin API unreachable: %s", e)
result["remote_status"] = {
"reachable": False,
"error": str(e),
"error": "Connection failed",
}
result["issues"].append({
"code": "REMOTE_ADMIN_UNREACHABLE",
"message": f"Could not reach remote admin API: {e}",
"message": "Could not reach remote admin API",
"severity": "warning",
})
except Exception as e:
logger.warning(f"Error checking remote bidirectional status: {e}")
logger.warning("Error checking remote bidirectional status: %s", e, exc_info=True)
result["issues"].append({
"code": "VERIFICATION_ERROR",
"message": f"Error during verification: {e}",
"message": "Internal error during verification",
"severity": "warning",
})

View File

@@ -146,6 +146,8 @@ class AppConfig:
site_region: str
site_priority: int
ratelimit_admin: str
num_trusted_proxies: int
allowed_redirect_hosts: list[str]
@classmethod
def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig":
@@ -310,6 +312,9 @@ class AppConfig:
site_region = str(_get("SITE_REGION", "us-east-1"))
site_priority = int(_get("SITE_PRIORITY", 100))
ratelimit_admin = _validate_rate_limit(str(_get("RATE_LIMIT_ADMIN", "60 per minute")))
num_trusted_proxies = int(_get("NUM_TRUSTED_PROXIES", 0))
allowed_redirect_hosts_raw = _get("ALLOWED_REDIRECT_HOSTS", "")
allowed_redirect_hosts = [h.strip() for h in str(allowed_redirect_hosts_raw).split(",") if h.strip()]
return cls(storage_root=storage_root,
max_upload_size=max_upload_size,
@@ -393,7 +398,9 @@ class AppConfig:
site_endpoint=site_endpoint,
site_region=site_region,
site_priority=site_priority,
ratelimit_admin=ratelimit_admin)
ratelimit_admin=ratelimit_admin,
num_trusted_proxies=num_trusted_proxies,
allowed_redirect_hosts=allowed_redirect_hosts)
def validate_and_report(self) -> list[str]:
"""Validate configuration and return a list of warnings/issues.
@@ -598,4 +605,6 @@ class AppConfig:
"SITE_REGION": self.site_region,
"SITE_PRIORITY": self.site_priority,
"RATE_LIMIT_ADMIN": self.ratelimit_admin,
"NUM_TRUSTED_PROXIES": self.num_trusted_proxies,
"ALLOWED_REDIRECT_HOSTS": self.allowed_redirect_hosts,
}

View File

@@ -6,6 +6,7 @@ import io
import json
import os
import secrets
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
@@ -15,6 +16,26 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
if sys.platform != "win32":
import fcntl
def _set_secure_file_permissions(file_path: Path) -> None:
"""Set restrictive file permissions (owner read/write only)."""
if sys.platform == "win32":
try:
username = os.environ.get("USERNAME", "")
if username:
subprocess.run(
["icacls", str(file_path), "/inheritance:r",
"/grant:r", f"{username}:F"],
check=True, capture_output=True
)
except (subprocess.SubprocessError, OSError):
pass
else:
os.chmod(file_path, 0o600)
class EncryptionError(Exception):
"""Raised when encryption/decryption fails."""
@@ -103,22 +124,38 @@ class LocalKeyEncryption(EncryptionProvider):
return self._master_key
def _load_or_create_master_key(self) -> bytes:
"""Load master key from file or generate a new one."""
if self.master_key_path.exists():
try:
return base64.b64decode(self.master_key_path.read_text().strip())
except Exception as exc:
raise EncryptionError(f"Failed to load master key: {exc}") from exc
key = secrets.token_bytes(32)
"""Load master key from file or generate a new one (with file locking)."""
lock_path = self.master_key_path.with_suffix(".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
try:
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
self.master_key_path.write_text(base64.b64encode(key).decode())
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
with open(lock_path, "w") as lock_file:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
try:
if self.master_key_path.exists():
try:
return base64.b64decode(self.master_key_path.read_text().strip())
except Exception as exc:
raise EncryptionError(f"Failed to load master key: {exc}") from exc
key = secrets.token_bytes(32)
try:
self.master_key_path.write_text(base64.b64encode(key).decode())
_set_secure_file_permissions(self.master_key_path)
except OSError as exc:
raise EncryptionError(f"Failed to save master key: {exc}") from exc
return key
finally:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
except OSError as exc:
raise EncryptionError(f"Failed to save master key: {exc}") from exc
return key
raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc
def _encrypt_data_key(self, data_key: bytes) -> bytes:
"""Encrypt the data key with the master key."""

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import hashlib
import hmac
import json
import math
import secrets
import threading
import time
from collections import deque
from dataclasses import dataclass
@@ -118,12 +120,14 @@ class IamService:
self._raw_config: Dict[str, Any] = {}
self._failed_attempts: Dict[str, Deque[datetime]] = {}
self._last_load_time = 0.0
self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {}
self._cache_ttl = 10.0
self._principal_cache: Dict[str, Tuple[Principal, float]] = {}
self._cache_ttl = 10.0
self._last_stat_check = 0.0
self._stat_check_interval = 1.0
self._sessions: Dict[str, Dict[str, Any]] = {}
self._session_lock = threading.Lock()
self._load()
self._load_lockout_state()
def _maybe_reload(self) -> None:
"""Reload configuration if the file has changed on disk."""
@@ -134,7 +138,7 @@ class IamService:
try:
if self.config_path.stat().st_mtime > self._last_load_time:
self._load()
self._credential_cache.clear()
self._principal_cache.clear()
except OSError:
pass
@@ -163,11 +167,46 @@ class IamService:
attempts = self._failed_attempts.setdefault(access_key, deque())
self._prune_attempts(attempts)
attempts.append(datetime.now(timezone.utc))
self._save_lockout_state()
def _clear_failed_attempts(self, access_key: str) -> None:
if not access_key:
return
self._failed_attempts.pop(access_key, None)
if self._failed_attempts.pop(access_key, None) is not None:
self._save_lockout_state()
def _lockout_file(self) -> Path:
return self.config_path.parent / "lockout_state.json"
def _load_lockout_state(self) -> None:
"""Load lockout state from disk."""
try:
if self._lockout_file().exists():
data = json.loads(self._lockout_file().read_text(encoding="utf-8"))
cutoff = datetime.now(timezone.utc) - self.auth_lockout_window
for key, timestamps in data.get("failed_attempts", {}).items():
valid = []
for ts in timestamps:
try:
dt = datetime.fromisoformat(ts)
if dt > cutoff:
valid.append(dt)
except (ValueError, TypeError):
continue
if valid:
self._failed_attempts[key] = deque(valid)
except (OSError, json.JSONDecodeError):
pass
def _save_lockout_state(self) -> None:
"""Persist lockout state to disk."""
data: Dict[str, Any] = {"failed_attempts": {}}
for key, attempts in self._failed_attempts.items():
data["failed_attempts"][key] = [ts.isoformat() for ts in attempts]
try:
self._lockout_file().write_text(json.dumps(data), encoding="utf-8")
except OSError:
pass
def _prune_attempts(self, attempts: Deque[datetime]) -> None:
cutoff = datetime.now(timezone.utc) - self.auth_lockout_window
@@ -210,17 +249,23 @@ class IamService:
return token
def validate_session_token(self, access_key: str, session_token: str) -> bool:
"""Validate a session token for an access key."""
session = self._sessions.get(session_token)
if not session:
hmac.compare_digest(access_key, secrets.token_urlsafe(16))
return False
if not hmac.compare_digest(session["access_key"], access_key):
return False
if time.time() > session["expires_at"]:
del self._sessions[session_token]
return False
return True
"""Validate a session token for an access key (thread-safe, constant-time)."""
dummy_key = secrets.token_urlsafe(16)
dummy_token = secrets.token_urlsafe(32)
with self._session_lock:
session = self._sessions.get(session_token)
if not session:
hmac.compare_digest(access_key, dummy_key)
hmac.compare_digest(session_token, dummy_token)
return False
key_match = hmac.compare_digest(session["access_key"], access_key)
if not key_match:
hmac.compare_digest(session_token, dummy_token)
return False
if time.time() > session["expires_at"]:
self._sessions.pop(session_token, None)
return False
return True
def _cleanup_expired_sessions(self) -> None:
"""Remove expired session tokens."""
@@ -231,9 +276,9 @@ class IamService:
def principal_for_key(self, access_key: str) -> Principal:
now = time.time()
cached = self._credential_cache.get(access_key)
cached = self._principal_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return principal
@@ -242,23 +287,14 @@ class IamService:
if not record:
raise IamError("Unknown access key")
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
self._principal_cache[access_key] = (principal, now)
return principal
def secret_for_key(self, access_key: str) -> str:
now = time.time()
cached = self._credential_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return secret
self._maybe_reload()
record = self._users.get(access_key)
if not record:
raise IamError("Unknown access key")
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
return record["secret_key"]
def authorize(self, principal: Principal, bucket_name: str | None, action: str) -> None:
@@ -330,6 +366,7 @@ class IamService:
new_secret = self._generate_secret_key()
user["secret_key"] = new_secret
self._save()
self._principal_cache.pop(access_key, None)
self._load()
return new_secret
@@ -509,26 +546,17 @@ class IamService:
raise IamError("User not found")
def get_secret_key(self, access_key: str) -> str | None:
now = time.time()
cached = self._credential_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return secret
self._maybe_reload()
record = self._users.get(access_key)
if record:
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
return record["secret_key"]
return None
def get_principal(self, access_key: str) -> Principal | None:
now = time.time()
cached = self._credential_cache.get(access_key)
cached = self._principal_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return principal
@@ -536,6 +564,6 @@ class IamService:
record = self._users.get(access_key)
if record:
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
self._principal_cache[access_key] = (principal, now)
return principal
return None

View File

@@ -5,6 +5,7 @@ import json
import logging
import os
import secrets
import subprocess
import sys
import uuid
from dataclasses import dataclass, field
@@ -16,9 +17,29 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from .encryption import EncryptionError, EncryptionProvider, EncryptionResult
if sys.platform != "win32":
import fcntl
logger = logging.getLogger(__name__)
def _set_secure_file_permissions(file_path: Path) -> None:
"""Set restrictive file permissions (owner read/write only)."""
if sys.platform == "win32":
try:
username = os.environ.get("USERNAME", "")
if username:
subprocess.run(
["icacls", str(file_path), "/inheritance:r",
"/grant:r", f"{username}:F"],
check=True, capture_output=True
)
except (subprocess.SubprocessError, OSError):
pass
else:
os.chmod(file_path, 0o600)
@dataclass
class KMSKey:
"""Represents a KMS encryption key."""
@@ -132,20 +153,33 @@ class KMSManager:
@property
def master_key(self) -> bytes:
"""Load or create the master key for encrypting KMS keys."""
"""Load or create the master key for encrypting KMS keys (with file locking)."""
if self._master_key is None:
if self.master_key_path.exists():
self._master_key = base64.b64decode(
self.master_key_path.read_text().strip()
)
else:
self._master_key = secrets.token_bytes(32)
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
self.master_key_path.write_text(
base64.b64encode(self._master_key).decode()
)
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
lock_path = self.master_key_path.with_suffix(".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
with open(lock_path, "w") as lock_file:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
try:
if self.master_key_path.exists():
self._master_key = base64.b64decode(
self.master_key_path.read_text().strip()
)
else:
self._master_key = secrets.token_bytes(32)
self.master_key_path.write_text(
base64.b64encode(self._master_key).decode()
)
_set_secure_file_permissions(self.master_key_path)
finally:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
return self._master_key
def _load_keys(self) -> None:
@@ -177,12 +211,13 @@ class KMSManager:
encrypted = self._encrypt_key_material(key.key_material)
data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode()
keys_data.append(data)
self.keys_path.parent.mkdir(parents=True, exist_ok=True)
self.keys_path.write_text(
json.dumps({"keys": keys_data}, indent=2),
encoding="utf-8"
)
_set_secure_file_permissions(self.keys_path)
def _encrypt_key_material(self, key_material: bytes) -> bytes:
"""Encrypt key material with the master key."""

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import ipaddress
import json
import logging
import queue
import socket
import threading
import time
import uuid
@@ -14,6 +16,36 @@ from urllib.parse import urlparse
import requests
def _is_safe_url(url: str) -> bool:
"""Check if a URL is safe to make requests to (not internal/private)."""
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
blocked_hosts = {
"localhost",
"127.0.0.1",
"0.0.0.0",
"::1",
"[::1]",
"metadata.google.internal",
"169.254.169.254",
}
if hostname.lower() in blocked_hosts:
return False
try:
resolved_ip = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(resolved_ip)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
return False
except (socket.gaierror, ValueError):
return False
return True
except Exception:
return False
logger = logging.getLogger(__name__)
@@ -299,6 +331,8 @@ class NotificationService:
self._queue.task_done()
def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None:
if not _is_safe_url(destination.url):
raise RuntimeError(f"Blocked request to internal/private URL: {destination.url}")
payload = event.to_s3_event()
headers = {"Content-Type": "application/json", **destination.headers}

View File

@@ -60,10 +60,13 @@ def _build_policy_context() -> Dict[str, Any]:
ctx: Dict[str, Any] = {}
if request.headers.get("Referer"):
ctx["aws:Referer"] = request.headers.get("Referer")
if request.access_route:
ctx["aws:SourceIp"] = request.access_route[0]
num_proxies = current_app.config.get("NUM_TRUSTED_PROXIES", 0)
if num_proxies > 0 and request.access_route and len(request.access_route) > num_proxies:
ctx["aws:SourceIp"] = request.access_route[-num_proxies]
elif request.remote_addr:
ctx["aws:SourceIp"] = request.remote_addr
elif request.access_route:
ctx["aws:SourceIp"] = request.access_route[0]
ctx["aws:SecureTransport"] = str(request.is_secure).lower()
if request.headers.get("User-Agent"):
ctx["aws:UserAgent"] = request.headers.get("User-Agent")
@@ -2242,6 +2245,17 @@ def _post_object(bucket_name: str) -> Response:
expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
if not hmac.compare_digest(expected_signature, signature):
return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)
principal = _iam().get_principal(access_key)
if not principal:
return _error_response("AccessDenied", "Invalid access key", 403)
if "${filename}" in object_key:
temp_key = object_key.replace("${filename}", request.files.get("file").filename if request.files.get("file") else "upload")
else:
temp_key = object_key
try:
_authorize_action(principal, bucket_name, "write", object_key=temp_key)
except IamError as exc:
return _error_response("AccessDenied", str(exc), 403)
file = request.files.get("file")
if not file:
return _error_response("InvalidArgument", "Missing file field", 400)
@@ -2263,6 +2277,12 @@ def _post_object(bucket_name: str) -> Response:
success_action_status = request.form.get("success_action_status", "204")
success_action_redirect = request.form.get("success_action_redirect")
if success_action_redirect:
allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", [])
parsed = urlparse(success_action_redirect)
if parsed.scheme not in ("http", "https"):
return _error_response("InvalidArgument", "Redirect URL must use http or https", 400)
if allowed_hosts and parsed.netloc not in allowed_hosts:
return _error_response("InvalidArgument", "Redirect URL host not allowed", 400)
redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
return Response(status=303, headers={"Location": redirect_url})
if success_action_status == "200":