MyFSIO v0.3.0 Release #22

Merged
kqjy merged 15 commits from next into main 2026-02-22 10:22:36 +00:00
7 changed files with 72 additions and 28 deletions
Showing only changes of commit fb32ca0a7d

View File

@@ -93,7 +93,9 @@ def create_app(
     app.config.setdefault("WTF_CSRF_ENABLED", False)
     # Trust X-Forwarded-* headers from proxies
-    app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)
+    num_proxies = app.config.get("NUM_TRUSTED_PROXIES", 0)
+    if num_proxies:
+        app.wsgi_app = ProxyFix(app.wsgi_app, x_for=num_proxies, x_proto=num_proxies, x_host=num_proxies, x_prefix=num_proxies)
     # Enable gzip compression for responses (10-20x smaller JSON payloads)
     if app.config.get("ENABLE_GZIP", True):

View File

@@ -75,7 +75,7 @@ def _evaluate_condition_operator(
         expected_null = condition_values[0].lower() in ("true", "1", "yes") if condition_values else True
         return is_null == expected_null
-    return True
+    return False

 ACTION_ALIASES = {
     "s3:listbucket": "list",

View File

@@ -164,9 +164,14 @@ class IamService:
         self._clear_failed_attempts(access_key)
         return self._build_principal(access_key, record)

+    _MAX_LOCKOUT_KEYS = 10000
+
     def _record_failed_attempt(self, access_key: str) -> None:
         if not access_key:
             return
+        if access_key not in self._failed_attempts and len(self._failed_attempts) >= self._MAX_LOCKOUT_KEYS:
+            oldest_key = min(self._failed_attempts, key=lambda k: self._failed_attempts[k][0] if self._failed_attempts[k] else datetime.min.replace(tzinfo=timezone.utc))
+            del self._failed_attempts[oldest_key]
         attempts = self._failed_attempts.setdefault(access_key, deque())
         self._prune_attempts(attempts)
         attempts.append(datetime.now(timezone.utc))
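The eviction rule added here, restated as a self-contained sketch (the module-level dict and the tiny cap are illustrative): when the failed-attempts map is already at its cap and a brand-new access key arrives, the key whose oldest attempt is the most stale gets dropped, so spraying unique access keys cannot grow the map without bound.

```python
from collections import deque
from datetime import datetime, timezone

MAX_LOCKOUT_KEYS = 3  # deliberately tiny; the diff uses 10000
failed_attempts: dict[str, deque] = {}

def record_failed_attempt(access_key: str) -> None:
    if not access_key:
        return
    if access_key not in failed_attempts and len(failed_attempts) >= MAX_LOCKOUT_KEYS:
        # Evict the key whose oldest recorded attempt is the most stale.
        oldest_key = min(
            failed_attempts,
            key=lambda k: failed_attempts[k][0]
            if failed_attempts[k]
            else datetime.min.replace(tzinfo=timezone.utc),
        )
        del failed_attempts[oldest_key]
    failed_attempts.setdefault(access_key, deque()).append(datetime.now(timezone.utc))

for key in ("AKIAALPHA", "AKIABRAVO", "AKIACHARLIE", "AKIADELTA"):
    record_failed_attempt(key)

print(sorted(failed_attempts))  # never more than 3 keys; the stalest entry was evicted
```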

View File

@@ -15,29 +15,23 @@ from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

 import requests
+from urllib3.util.connection import create_connection as _urllib3_create_connection


-def _is_safe_url(url: str, allow_internal: bool = False) -> bool:
-    """Check if a URL is safe to make requests to (not internal/private).
-
-    Args:
-        url: The URL to check.
-        allow_internal: If True, allows internal/private IP addresses.
-            Use for self-hosted deployments on internal networks.
-    """
+def _resolve_and_check_url(url: str, allow_internal: bool = False) -> Optional[str]:
     try:
         parsed = urlparse(url)
         hostname = parsed.hostname
         if not hostname:
-            return False
+            return None
         cloud_metadata_hosts = {
             "metadata.google.internal",
             "169.254.169.254",
         }
         if hostname.lower() in cloud_metadata_hosts:
-            return False
+            return None
         if allow_internal:
-            return True
+            return hostname
         blocked_hosts = {
             "localhost",
             "127.0.0.1",
@@ -46,17 +40,46 @@ def _is_safe_url(url: str, allow_internal: bool = False) -> bool:
             "[::1]",
         }
         if hostname.lower() in blocked_hosts:
-            return False
+            return None
         try:
             resolved_ip = socket.gethostbyname(hostname)
             ip = ipaddress.ip_address(resolved_ip)
             if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
-                return False
+                return None
+            return resolved_ip
         except (socket.gaierror, ValueError):
-            return False
-        return True
+            return None
     except Exception:
-        return False
+        return None
+
+
+def _is_safe_url(url: str, allow_internal: bool = False) -> bool:
+    return _resolve_and_check_url(url, allow_internal) is not None
+
+
+_dns_pin_lock = threading.Lock()
+
+
+def _pinned_post(url: str, pinned_ip: str, **kwargs: Any) -> requests.Response:
+    parsed = urlparse(url)
+    hostname = parsed.hostname or ""
+    session = requests.Session()
+    original_create = _urllib3_create_connection
+
+    def _create_pinned(address: Any, *args: Any, **kw: Any) -> Any:
+        host, req_port = address
+        if host == hostname:
+            return original_create((pinned_ip, req_port), *args, **kw)
+        return original_create(address, *args, **kw)
+
+    import urllib3.util.connection as _conn_mod
+
+    with _dns_pin_lock:
+        _conn_mod.create_connection = _create_pinned
+        try:
+            return session.post(url, **kwargs)
+        finally:
+            _conn_mod.create_connection = original_create


 logger = logging.getLogger(__name__)
@@ -344,16 +367,18 @@ class NotificationService:
             self._queue.task_done()

     def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None:
-        if not _is_safe_url(destination.url, allow_internal=self._allow_internal_endpoints):
-            raise RuntimeError(f"Blocked request to cloud metadata service (SSRF protection): {destination.url}")
+        resolved_ip = _resolve_and_check_url(destination.url, allow_internal=self._allow_internal_endpoints)
+        if not resolved_ip:
+            raise RuntimeError(f"Blocked request (SSRF protection): {destination.url}")
         payload = event.to_s3_event()
         headers = {"Content-Type": "application/json", **destination.headers}
         last_error = None
         for attempt in range(destination.retry_count):
             try:
-                response = requests.post(
+                response = _pinned_post(
                     destination.url,
+                    resolved_ip,
                     json=payload,
                     headers=headers,
                     timeout=destination.timeout_seconds,
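How the two new helpers compose, as an illustrative snippet (the webhook URL, payload, and IP address are invented): the IP returned by the safety check is the same one the socket ultimately connects to, which closes the DNS-rebinding window between the "check" and the "use". Since the swap of urllib3's module-level `create_connection` is guarded by `_dns_pin_lock`, the patch is process-wide rather than per-session.

```python
# Illustrative only; hostname, IP and payload are invented.
url = "https://hooks.example.net/myfsio-events"

pinned_ip = _resolve_and_check_url(url)  # None if the target is private/loopback/metadata
if pinned_ip is None:
    raise RuntimeError(f"Blocked request (SSRF protection): {url}")

# The connection is opened to pinned_ip even if DNS answers differently a moment later.
response = _pinned_post(url, pinned_ip, json={"ping": True}, timeout=5)
print(response.status_code)
```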

View File

@@ -372,12 +372,19 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
         raise IamError("Invalid Date format")

     now = datetime.now(timezone.utc)
+    tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900))
+    if req_time > now + tolerance:
+        raise IamError("Request date is too far in the future")
     try:
         expires_seconds = int(expires)
         if expires_seconds <= 0:
             raise IamError("Invalid Expires value: must be positive")
     except ValueError:
         raise IamError("Invalid Expires value: must be an integer")
+    min_expiry = current_app.config.get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1)
+    max_expiry = current_app.config.get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800)
+    if expires_seconds < min_expiry or expires_seconds > max_expiry:
+        raise IamError(f"Expiration must be between {min_expiry} second(s) and {max_expiry} seconds")
     if now > req_time + timedelta(seconds=expires_seconds):
         raise IamError("Request expired")
@@ -595,7 +602,11 @@ def _validate_presigned_request(action: str, bucket_name: str, object_key: str)
         request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc)
     except ValueError as exc:
         raise IamError("Invalid X-Amz-Date") from exc
-    if datetime.now(timezone.utc) > request_time + timedelta(seconds=expiry):
+    now = datetime.now(timezone.utc)
+    tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900))
+    if request_time > now + tolerance:
+        raise IamError("Request date is too far in the future")
+    if now > request_time + timedelta(seconds=expiry):
         raise IamError("Presigned URL expired")

     signed_headers_list = [header.strip().lower() for header in signed_headers.split(";") if header]
@@ -2470,7 +2481,7 @@ def _post_object(bucket_name: str) -> Response:
     if success_action_redirect:
         allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", [])
         if not allowed_hosts:
-            allowed_hosts = [request.host]
+            return _error_response("InvalidArgument", "Redirect not allowed: ALLOWED_REDIRECT_HOSTS not configured", 400)
         parsed = urlparse(success_action_redirect)
         if parsed.scheme not in ("http", "https"):
             return _error_response("InvalidArgument", "Redirect URL must use http or https", 400)

View File

@@ -743,7 +743,6 @@ def initiate_multipart_upload(bucket_name: str):

 @ui_bp.put("/buckets/<bucket_name>/multipart/<upload_id>/parts")
-@limiter.exempt
 @csrf.exempt
 def upload_multipart_part(bucket_name: str, upload_id: str):
     principal = _current_principal()

View File

@@ -321,8 +321,9 @@ class TestNotificationService:
         assert "events_sent" in stats
         assert "events_failed" in stats

-    @patch("app.notifications.requests.post")
-    def test_send_notification_success(self, mock_post, notification_service):
+    @patch("app.notifications._pinned_post")
+    @patch("app.notifications._resolve_and_check_url", return_value="93.184.216.34")
+    def test_send_notification_success(self, mock_resolve, mock_post, notification_service):
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_post.return_value = mock_response
@@ -337,8 +338,9 @@ class TestNotificationService:
         notification_service._send_notification(event, destination)
         mock_post.assert_called_once()

-    @patch("app.notifications.requests.post")
-    def test_send_notification_retry_on_failure(self, mock_post, notification_service):
+    @patch("app.notifications._pinned_post")
+    @patch("app.notifications._resolve_and_check_url", return_value="93.184.216.34")
+    def test_send_notification_retry_on_failure(self, mock_resolve, mock_post, notification_service):
         mock_response = MagicMock()
         mock_response.status_code = 500
         mock_response.text = "Internal Server Error"