Harden security: fail-closed policies, presigned URL time/expiry validation, SSRF DNS pinning, lockout cap, proxy trust config

This commit is contained in:
2026-02-22 17:55:40 +08:00
parent 6ab702a818
commit fb32ca0a7d
7 changed files with 72 additions and 28 deletions

View File

@@ -93,7 +93,9 @@ def create_app(
app.config.setdefault("WTF_CSRF_ENABLED", False)
# Trust X-Forwarded-* headers from proxies
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)
num_proxies = app.config.get("NUM_TRUSTED_PROXIES", 0)
if num_proxies:
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=num_proxies, x_proto=num_proxies, x_host=num_proxies, x_prefix=num_proxies)
# Enable gzip compression for responses (10-20x smaller JSON payloads)
if app.config.get("ENABLE_GZIP", True):

View File

@@ -75,7 +75,7 @@ def _evaluate_condition_operator(
expected_null = condition_values[0].lower() in ("true", "1", "yes") if condition_values else True
return is_null == expected_null
return True
return False
ACTION_ALIASES = {
"s3:listbucket": "list",

View File

@@ -164,9 +164,14 @@ class IamService:
self._clear_failed_attempts(access_key)
return self._build_principal(access_key, record)
_MAX_LOCKOUT_KEYS = 10000
def _record_failed_attempt(self, access_key: str) -> None:
if not access_key:
return
if access_key not in self._failed_attempts and len(self._failed_attempts) >= self._MAX_LOCKOUT_KEYS:
oldest_key = min(self._failed_attempts, key=lambda k: self._failed_attempts[k][0] if self._failed_attempts[k] else datetime.min.replace(tzinfo=timezone.utc))
del self._failed_attempts[oldest_key]
attempts = self._failed_attempts.setdefault(access_key, deque())
self._prune_attempts(attempts)
attempts.append(datetime.now(timezone.utc))

View File

@@ -15,29 +15,23 @@ from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests
from urllib3.util.connection import create_connection as _urllib3_create_connection
def _is_safe_url(url: str, allow_internal: bool = False) -> bool:
"""Check if a URL is safe to make requests to (not internal/private).
Args:
url: The URL to check.
allow_internal: If True, allows internal/private IP addresses.
Use for self-hosted deployments on internal networks.
"""
def _resolve_and_check_url(url: str, allow_internal: bool = False) -> Optional[str]:
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
return None
cloud_metadata_hosts = {
"metadata.google.internal",
"169.254.169.254",
}
if hostname.lower() in cloud_metadata_hosts:
return False
return None
if allow_internal:
return True
return hostname
blocked_hosts = {
"localhost",
"127.0.0.1",
@@ -46,17 +40,46 @@ def _is_safe_url(url: str, allow_internal: bool = False) -> bool:
"[::1]",
}
if hostname.lower() in blocked_hosts:
return False
return None
try:
resolved_ip = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(resolved_ip)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
return False
return None
return resolved_ip
except (socket.gaierror, ValueError):
return False
return True
return None
except Exception:
return False
return None
def _is_safe_url(url: str, allow_internal: bool = False) -> bool:
return _resolve_and_check_url(url, allow_internal) is not None
_dns_pin_lock = threading.Lock()
def _pinned_post(url: str, pinned_ip: str, **kwargs: Any) -> requests.Response:
parsed = urlparse(url)
hostname = parsed.hostname or ""
session = requests.Session()
original_create = _urllib3_create_connection
def _create_pinned(address: Any, *args: Any, **kw: Any) -> Any:
host, req_port = address
if host == hostname:
return original_create((pinned_ip, req_port), *args, **kw)
return original_create(address, *args, **kw)
import urllib3.util.connection as _conn_mod
with _dns_pin_lock:
_conn_mod.create_connection = _create_pinned
try:
return session.post(url, **kwargs)
finally:
_conn_mod.create_connection = original_create
logger = logging.getLogger(__name__)
@@ -344,16 +367,18 @@ class NotificationService:
self._queue.task_done()
def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None:
if not _is_safe_url(destination.url, allow_internal=self._allow_internal_endpoints):
raise RuntimeError(f"Blocked request to cloud metadata service (SSRF protection): {destination.url}")
resolved_ip = _resolve_and_check_url(destination.url, allow_internal=self._allow_internal_endpoints)
if not resolved_ip:
raise RuntimeError(f"Blocked request (SSRF protection): {destination.url}")
payload = event.to_s3_event()
headers = {"Content-Type": "application/json", **destination.headers}
last_error = None
for attempt in range(destination.retry_count):
try:
response = requests.post(
response = _pinned_post(
destination.url,
resolved_ip,
json=payload,
headers=headers,
timeout=destination.timeout_seconds,

View File

@@ -372,12 +372,19 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
raise IamError("Invalid Date format")
now = datetime.now(timezone.utc)
tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900))
if req_time > now + tolerance:
raise IamError("Request date is too far in the future")
try:
expires_seconds = int(expires)
if expires_seconds <= 0:
raise IamError("Invalid Expires value: must be positive")
except ValueError:
raise IamError("Invalid Expires value: must be an integer")
min_expiry = current_app.config.get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1)
max_expiry = current_app.config.get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800)
if expires_seconds < min_expiry or expires_seconds > max_expiry:
raise IamError(f"Expiration must be between {min_expiry} second(s) and {max_expiry} seconds")
if now > req_time + timedelta(seconds=expires_seconds):
raise IamError("Request expired")
@@ -595,7 +602,11 @@ def _validate_presigned_request(action: str, bucket_name: str, object_key: str)
request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc)
except ValueError as exc:
raise IamError("Invalid X-Amz-Date") from exc
if datetime.now(timezone.utc) > request_time + timedelta(seconds=expiry):
now = datetime.now(timezone.utc)
tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900))
if request_time > now + tolerance:
raise IamError("Request date is too far in the future")
if now > request_time + timedelta(seconds=expiry):
raise IamError("Presigned URL expired")
signed_headers_list = [header.strip().lower() for header in signed_headers.split(";") if header]
@@ -2470,7 +2481,7 @@ def _post_object(bucket_name: str) -> Response:
if success_action_redirect:
allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", [])
if not allowed_hosts:
allowed_hosts = [request.host]
return _error_response("InvalidArgument", "Redirect not allowed: ALLOWED_REDIRECT_HOSTS not configured", 400)
parsed = urlparse(success_action_redirect)
if parsed.scheme not in ("http", "https"):
return _error_response("InvalidArgument", "Redirect URL must use http or https", 400)

View File

@@ -743,7 +743,6 @@ def initiate_multipart_upload(bucket_name: str):
@ui_bp.put("/buckets/<bucket_name>/multipart/<upload_id>/parts")
@limiter.exempt
@csrf.exempt
def upload_multipart_part(bucket_name: str, upload_id: str):
principal = _current_principal()

View File

@@ -321,8 +321,9 @@ class TestNotificationService:
assert "events_sent" in stats
assert "events_failed" in stats
@patch("app.notifications.requests.post")
def test_send_notification_success(self, mock_post, notification_service):
@patch("app.notifications._pinned_post")
@patch("app.notifications._resolve_and_check_url", return_value="93.184.216.34")
def test_send_notification_success(self, mock_resolve, mock_post, notification_service):
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
@@ -337,8 +338,9 @@ class TestNotificationService:
notification_service._send_notification(event, destination)
mock_post.assert_called_once()
@patch("app.notifications.requests.post")
def test_send_notification_retry_on_failure(self, mock_post, notification_service):
@patch("app.notifications._pinned_post")
@patch("app.notifications._resolve_and_check_url", return_value="93.184.216.34")
def test_send_notification_retry_on_failure(self, mock_resolve, mock_post, notification_service):
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"