diff --git a/app/__init__.py b/app/__init__.py index 7befa1d..a8f673d 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -93,7 +93,9 @@ def create_app( app.config.setdefault("WTF_CSRF_ENABLED", False) # Trust X-Forwarded-* headers from proxies - app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1) + num_proxies = app.config.get("NUM_TRUSTED_PROXIES", 0) + if num_proxies: + app.wsgi_app = ProxyFix(app.wsgi_app, x_for=num_proxies, x_proto=num_proxies, x_host=num_proxies, x_prefix=num_proxies) # Enable gzip compression for responses (10-20x smaller JSON payloads) if app.config.get("ENABLE_GZIP", True): diff --git a/app/bucket_policies.py b/app/bucket_policies.py index fcb1e41..1ff9eb6 100644 --- a/app/bucket_policies.py +++ b/app/bucket_policies.py @@ -75,7 +75,7 @@ def _evaluate_condition_operator( expected_null = condition_values[0].lower() in ("true", "1", "yes") if condition_values else True return is_null == expected_null - return True + return False ACTION_ALIASES = { "s3:listbucket": "list", diff --git a/app/iam.py b/app/iam.py index 9e14ee7..65b705a 100644 --- a/app/iam.py +++ b/app/iam.py @@ -164,9 +164,14 @@ class IamService: self._clear_failed_attempts(access_key) return self._build_principal(access_key, record) + _MAX_LOCKOUT_KEYS = 10000 + def _record_failed_attempt(self, access_key: str) -> None: if not access_key: return + if access_key not in self._failed_attempts and len(self._failed_attempts) >= self._MAX_LOCKOUT_KEYS: + oldest_key = min(self._failed_attempts, key=lambda k: self._failed_attempts[k][0] if self._failed_attempts[k] else datetime.min.replace(tzinfo=timezone.utc)) + del self._failed_attempts[oldest_key] attempts = self._failed_attempts.setdefault(access_key, deque()) self._prune_attempts(attempts) attempts.append(datetime.now(timezone.utc)) diff --git a/app/notifications.py b/app/notifications.py index 6951095..ee03ba8 100644 --- a/app/notifications.py +++ b/app/notifications.py @@ -15,29 +15,23 @@ from 
typing import Any, Dict, List, Optional from urllib.parse import urlparse import requests +from urllib3.util.connection import create_connection as _urllib3_create_connection -def _is_safe_url(url: str, allow_internal: bool = False) -> bool: - """Check if a URL is safe to make requests to (not internal/private). - - Args: - url: The URL to check. - allow_internal: If True, allows internal/private IP addresses. - Use for self-hosted deployments on internal networks. - """ +def _resolve_and_check_url(url: str, allow_internal: bool = False) -> Optional[str]: try: parsed = urlparse(url) hostname = parsed.hostname if not hostname: - return False + return None cloud_metadata_hosts = { "metadata.google.internal", "169.254.169.254", } if hostname.lower() in cloud_metadata_hosts: - return False + return None if allow_internal: - return True + return hostname blocked_hosts = { "localhost", "127.0.0.1", @@ -46,17 +40,46 @@ def _is_safe_url(url: str, allow_internal: bool = False) -> bool: "[::1]", } if hostname.lower() in blocked_hosts: - return False + return None try: resolved_ip = socket.gethostbyname(hostname) ip = ipaddress.ip_address(resolved_ip) if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: - return False + return None + return resolved_ip except (socket.gaierror, ValueError): - return False - return True + return None except Exception: - return False + return None + + +def _is_safe_url(url: str, allow_internal: bool = False) -> bool: + return _resolve_and_check_url(url, allow_internal) is not None + + +_dns_pin_lock = threading.Lock() + + +def _pinned_post(url: str, pinned_ip: str, **kwargs: Any) -> requests.Response: + parsed = urlparse(url) + hostname = parsed.hostname or "" + session = requests.Session() + original_create = _urllib3_create_connection + + def _create_pinned(address: Any, *args: Any, **kw: Any) -> Any: + host, req_port = address + if host == hostname: + return original_create((pinned_ip, req_port), *args, **kw) + return 
original_create(address, *args, **kw) + + import urllib3.util.connection as _conn_mod + with _dns_pin_lock: + _conn_mod.create_connection = _create_pinned + try: + return session.post(url, allow_redirects=False, **kwargs) + finally: + _conn_mod.create_connection = original_create + logger = logging.getLogger(__name__) @@ -344,16 +367,18 @@ class NotificationService: self._queue.task_done() def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None: - if not _is_safe_url(destination.url, allow_internal=self._allow_internal_endpoints): - raise RuntimeError(f"Blocked request to cloud metadata service (SSRF protection): {destination.url}") + resolved_ip = _resolve_and_check_url(destination.url, allow_internal=self._allow_internal_endpoints) + if not resolved_ip: + raise RuntimeError(f"Blocked request (SSRF protection): {destination.url}") payload = event.to_s3_event() headers = {"Content-Type": "application/json", **destination.headers} last_error = None for attempt in range(destination.retry_count): try: - response = requests.post( + response = _pinned_post( destination.url, + resolved_ip, json=payload, headers=headers, timeout=destination.timeout_seconds, diff --git a/app/s3_api.py b/app/s3_api.py index 74f4449..7a5e5da 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -372,12 +372,19 @@ def _verify_sigv4_query(req: Any) -> Principal | None: raise IamError("Invalid Date format") now = datetime.now(timezone.utc) + tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900)) + if req_time > now + tolerance: + raise IamError("Request date is too far in the future") try: expires_seconds = int(expires) if expires_seconds <= 0: raise IamError("Invalid Expires value: must be positive") except ValueError: raise IamError("Invalid Expires value: must be an integer") + min_expiry = current_app.config.get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1) + max_expiry = current_app.config.get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800) + if 
expires_seconds < min_expiry or expires_seconds > max_expiry: + raise IamError(f"Expiration must be between {min_expiry} second(s) and {max_expiry} seconds") if now > req_time + timedelta(seconds=expires_seconds): raise IamError("Request expired") @@ -595,7 +602,11 @@ def _validate_presigned_request(action: str, bucket_name: str, object_key: str) request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc) except ValueError as exc: raise IamError("Invalid X-Amz-Date") from exc - if datetime.now(timezone.utc) > request_time + timedelta(seconds=expiry): + now = datetime.now(timezone.utc) + tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900)) + if request_time > now + tolerance: + raise IamError("Request date is too far in the future") + if now > request_time + timedelta(seconds=expiry): raise IamError("Presigned URL expired") signed_headers_list = [header.strip().lower() for header in signed_headers.split(";") if header] @@ -2470,7 +2481,7 @@ def _post_object(bucket_name: str) -> Response: if success_action_redirect: allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", []) if not allowed_hosts: - allowed_hosts = [request.host] + return _error_response("InvalidArgument", "Redirect not allowed: ALLOWED_REDIRECT_HOSTS not configured", 400) parsed = urlparse(success_action_redirect) if parsed.scheme not in ("http", "https"): return _error_response("InvalidArgument", "Redirect URL must use http or https", 400) diff --git a/app/ui.py b/app/ui.py index 68455f2..7dfaf90 100644 --- a/app/ui.py +++ b/app/ui.py @@ -743,7 +743,6 @@ def initiate_multipart_upload(bucket_name: str): @ui_bp.put("/buckets/<bucket_name>/multipart/<upload_id>/parts") -@limiter.exempt @csrf.exempt def upload_multipart_part(bucket_name: str, upload_id: str): principal = _current_principal() diff --git a/tests/test_notifications.py b/tests/test_notifications.py index 36d7e03..eb3b38f 100644 --- a/tests/test_notifications.py +++ 
b/tests/test_notifications.py @@ -321,8 +321,9 @@ class TestNotificationService: assert "events_sent" in stats assert "events_failed" in stats - @patch("app.notifications.requests.post") - def test_send_notification_success(self, mock_post, notification_service): + @patch("app.notifications._pinned_post") + @patch("app.notifications._resolve_and_check_url", return_value="93.184.216.34") + def test_send_notification_success(self, mock_resolve, mock_post, notification_service): mock_response = MagicMock() mock_response.status_code = 200 mock_post.return_value = mock_response @@ -337,8 +338,9 @@ class TestNotificationService: notification_service._send_notification(event, destination) mock_post.assert_called_once() - @patch("app.notifications.requests.post") - def test_send_notification_retry_on_failure(self, mock_post, notification_service): + @patch("app.notifications._pinned_post") + @patch("app.notifications._resolve_and_check_url", return_value="93.184.216.34") + def test_send_notification_retry_on_failure(self, mock_resolve, mock_post, notification_service): mock_response = MagicMock() mock_response.status_code = 500 mock_response.text = "Internal Server Error"