From 5ab62a00ff0255210ae46990a97eab46878a31b2 Mon Sep 17 00:00:00 2001
From: kqjy
Date: Sun, 18 Jan 2026 17:18:12 +0800
Subject: [PATCH] Fix security vulnerabilities: XXE, timing attacks, info leaks

---
 app/__init__.py   |   2 +-
 app/config.py     |  10 +++-
 app/iam.py        |   3 +-
 app/s3_api.py     | 118 +++++++++++++++++++++++++++++++---------
 app/storage.py    |   8 ++--
 requirements.txt  |   3 +-
 tests/test_api.py |   4 +-
 7 files changed, 98 insertions(+), 50 deletions(-)

diff --git a/app/__init__.py b/app/__init__.py
index c5f38a9..eac24c7 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -280,7 +280,7 @@ def create_app(
 
     @app.get("/healthz")
     def healthcheck() -> Dict[str, str]:
-        return {"status": "ok", "version": app.config.get("APP_VERSION", "unknown")}
+        return {"status": "ok"}
 
     return app
 
diff --git a/app/config.py b/app/config.py
index f585b51..80420d9 100644
--- a/app/config.py
+++ b/app/config.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import re
 import secrets
 import shutil
 import sys
@@ -9,6 +10,13 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+
+def _validate_rate_limit(value: str) -> str:
+    pattern = r"^\d+\s+per\s+(second|minute|hour|day)$"
+    if not re.match(pattern, value):
+        raise ValueError(f"Invalid rate limit format: {value}. Expected format: '200 per minute'")
+    return value
+
 if getattr(sys, "frozen", False):
     # Running in a PyInstaller bundle
     PROJECT_ROOT = Path(sys._MEIPASS)
@@ -151,7 +159,7 @@ class AppConfig:
         log_path = log_dir / str(_get("LOG_FILE", "app.log"))
         log_max_bytes = int(_get("LOG_MAX_BYTES", 5 * 1024 * 1024))
         log_backup_count = int(_get("LOG_BACKUP_COUNT", 3))
-        ratelimit_default = str(_get("RATE_LIMIT_DEFAULT", "200 per minute"))
+        ratelimit_default = _validate_rate_limit(str(_get("RATE_LIMIT_DEFAULT", "200 per minute")))
         ratelimit_storage_uri = str(_get("RATE_LIMIT_STORAGE_URI", "memory://"))
 
         def _csv(value: str, default: list[str]) -> list[str]:
diff --git a/app/iam.py b/app/iam.py
index 8ff7a4c..0e5e80f 100644
--- a/app/iam.py
+++ b/app/iam.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hmac
 import json
 import math
 import secrets
@@ -149,7 +150,7 @@ class IamService:
                 f"Access temporarily locked. Try again in {seconds} seconds."
             )
         record = self._users.get(access_key)
-        if not record or record["secret_key"] != secret_key:
+        if not record or not hmac.compare_digest(record["secret_key"], secret_key):
             self._record_failed_attempt(access_key)
             raise IamError("Invalid credentials")
         self._clear_failed_attempts(access_key)
diff --git a/app/s3_api.py b/app/s3_api.py
index 20fd821..b184b35 100644
--- a/app/s3_api.py
+++ b/app/s3_api.py
@@ -11,7 +11,8 @@ import uuid
 from datetime import datetime, timedelta, timezone
 from typing import Any, Dict, Optional
 from urllib.parse import quote, urlencode, urlparse, unquote
-from xml.etree.ElementTree import Element, SubElement, tostring, fromstring, ParseError
+from xml.etree.ElementTree import Element, SubElement, tostring, ParseError
+from defusedxml.ElementTree import fromstring
 
 from flask import Blueprint, Response, current_app, jsonify, request, g
 from werkzeug.http import http_date
@@ -29,6 +30,8 @@ from .storage import ObjectStorage, StorageError, QuotaExceededError, BucketNotF
 
 logger = logging.getLogger(__name__)
 
+S3_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
+
 s3_api_bp = Blueprint("s3_api", __name__)
 
 def _storage() -> ObjectStorage:
@@ -93,6 +96,13 @@ def _error_response(code: str, message: str, status: int) -> Response:
     return _xml_response(error, status)
 
 
+def _require_xml_content_type() -> Response | None:
+    ct = request.headers.get("Content-Type", "")
+    if ct and not ct.startswith(("application/xml", "text/xml")):
+        return _error_response("InvalidRequest", "Content-Type must be application/xml or text/xml", 400)
+    return None
+
+
 def _parse_range_header(range_header: str, file_size: int) -> list[tuple[int, int]] | None:
     if not range_header.startswith("bytes="):
         return None
@@ -232,16 +242,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
 
     if not hmac.compare_digest(calculated_signature, signature):
         if current_app.config.get("DEBUG_SIGV4"):
-            logger.warning(
-                "SigV4 signature mismatch",
-                extra={
-                    "path": req.path,
-                    "method": method,
-                    "signed_headers": signed_headers_str,
-                    "content_type": req.headers.get("Content-Type"),
-                    "content_length": req.headers.get("Content-Length"),
-                }
-            )
+            logger.warning("SigV4 signature mismatch for %s %s", method, req.path)
         raise IamError("SignatureDoesNotMatch")
 
     session_token = req.headers.get("X-Amz-Security-Token")
@@ -307,7 +308,7 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
         if header.lower() == 'expect' and val == "":
             val = "100-continue"
         val = " ".join(val.split())
-        canonical_headers_parts.append(f"{header}:{val}\n")
+        canonical_headers_parts.append(f"{header.lower()}:{val}\n")
     canonical_headers = "".join(canonical_headers_parts)
 
     payload_hash = "UNSIGNED-PAYLOAD"
@@ -661,11 +662,11 @@ def _strip_ns(tag: str | None) -> str:
 
 
 def _find_element(parent: Element, name: str) -> Optional[Element]:
-    """Find a child element by name, trying both namespaced and non-namespaced variants.
+    """Find a child element by name, trying S3 namespace then no namespace.
 
     This handles XML documents that may or may not include namespace prefixes.
     """
-    el = parent.find(f"{{*}}{name}")
+    el = parent.find(f"{{{S3_NS}}}{name}")
     if el is None:
         el = parent.find(name)
     return el
@@ -689,7 +690,7 @@ def _parse_tagging_document(payload: bytes) -> list[dict[str, str]]:
         raise ValueError("Malformed XML") from exc
     if _strip_ns(root.tag) != "Tagging":
         raise ValueError("Root element must be Tagging")
-    tagset = root.find(".//{*}TagSet")
+    tagset = root.find(".//{http://s3.amazonaws.com/doc/2006-03-01/}TagSet")
     if tagset is None:
         tagset = root.find("TagSet")
     if tagset is None:
@@ -857,13 +858,13 @@ def _parse_encryption_document(payload: bytes) -> dict[str, Any]:
             bucket_key_el = child
         if default_el is None:
             continue
-        algo_el = default_el.find("{*}SSEAlgorithm")
+        algo_el = default_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}SSEAlgorithm")
         if algo_el is None:
             algo_el = default_el.find("SSEAlgorithm")
         if algo_el is None or not (algo_el.text or "").strip():
             raise ValueError("SSEAlgorithm is required")
         rule: dict[str, Any] = {"SSEAlgorithm": algo_el.text.strip()}
-        kms_el = default_el.find("{*}KMSMasterKeyID")
+        kms_el = default_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}KMSMasterKeyID")
         if kms_el is None:
             kms_el = default_el.find("KMSMasterKeyID")
         if kms_el is not None and kms_el.text:
@@ -964,8 +965,11 @@ def _bucket_versioning_handler(bucket_name: str) -> Response:
     except IamError as exc:
         return _error_response("AccessDenied", str(exc), 403)
     storage = _storage()
-
+
     if request.method == "PUT":
+        ct_error = _require_xml_content_type()
+        if ct_error:
+            return ct_error
         payload = request.get_data(cache=False) or b""
         if not payload.strip():
             return _error_response("MalformedXML", "Request body is required", 400)
@@ -975,7 +979,7 @@ def _bucket_versioning_handler(bucket_name: str) -> Response:
             return _error_response("MalformedXML", "Unable to parse XML document", 400)
         if _strip_ns(root.tag) != "VersioningConfiguration":
             return _error_response("MalformedXML", "Root element must be VersioningConfiguration", 400)
-        status_el = root.find("{*}Status")
+        status_el = root.find("{http://s3.amazonaws.com/doc/2006-03-01/}Status")
         if status_el is None:
             status_el = root.find("Status")
         status = (status_el.text or "").strip() if status_el is not None else ""
@@ -1024,6 +1028,9 @@ def _bucket_tagging_handler(bucket_name: str) -> Response:
         current_app.logger.info("Bucket tags deleted", extra={"bucket": bucket_name})
         return Response(status=204)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     try:
         tags = _parse_tagging_document(payload)
@@ -1079,6 +1086,9 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response:
         current_app.logger.info("Object tags deleted", extra={"bucket": bucket_name, "key": object_key})
         return Response(status=204)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     try:
         tags = _parse_tagging_document(payload)
@@ -1148,6 +1158,9 @@ def _bucket_cors_handler(bucket_name: str) -> Response:
         current_app.logger.info("Bucket CORS deleted", extra={"bucket": bucket_name})
         return Response(status=204)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         try:
@@ -1194,6 +1207,9 @@ def _bucket_encryption_handler(bucket_name: str) -> Response:
                 404,
             )
         return _xml_response(_render_encryption_document(config))
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         try:
@@ -1415,6 +1431,9 @@ def _bucket_lifecycle_handler(bucket_name: str) -> Response:
         current_app.logger.info("Bucket lifecycle deleted", extra={"bucket": bucket_name})
         return Response(status=204)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         return _error_response("MalformedXML", "Request body is required", 400)
@@ -1479,49 +1498,49 @@ def _parse_lifecycle_config(payload: bytes) -> list:
         raise ValueError("Root element must be LifecycleConfiguration")
 
     rules = []
-    for rule_el in root.findall("{*}Rule") or root.findall("Rule"):
+    for rule_el in root.findall("{http://s3.amazonaws.com/doc/2006-03-01/}Rule") or root.findall("Rule"):
         rule: dict = {}
-        id_el = rule_el.find("{*}ID") or rule_el.find("ID")
+        id_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}ID") or rule_el.find("ID")
         if id_el is not None and id_el.text:
             rule["ID"] = id_el.text.strip()
 
-        filter_el = rule_el.find("{*}Filter") or rule_el.find("Filter")
+        filter_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Filter") or rule_el.find("Filter")
         if filter_el is not None:
-            prefix_el = filter_el.find("{*}Prefix") or filter_el.find("Prefix")
+            prefix_el = filter_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Prefix") or filter_el.find("Prefix")
             if prefix_el is not None and prefix_el.text:
                 rule["Prefix"] = prefix_el.text
         if "Prefix" not in rule:
-            prefix_el = rule_el.find("{*}Prefix") or rule_el.find("Prefix")
+            prefix_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Prefix") or rule_el.find("Prefix")
             if prefix_el is not None:
                 rule["Prefix"] = prefix_el.text or ""
 
-        status_el = rule_el.find("{*}Status") or rule_el.find("Status")
+        status_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Status") or rule_el.find("Status")
         rule["Status"] = (status_el.text or "Enabled").strip() if status_el is not None else "Enabled"
 
-        exp_el = rule_el.find("{*}Expiration") or rule_el.find("Expiration")
+        exp_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Expiration") or rule_el.find("Expiration")
         if exp_el is not None:
             expiration: dict = {}
-            days_el = exp_el.find("{*}Days") or exp_el.find("Days")
+            days_el = exp_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Days") or exp_el.find("Days")
             if days_el is not None and days_el.text:
                 days_val = int(days_el.text.strip())
                 if days_val <= 0:
                     raise ValueError("Expiration Days must be a positive integer")
                 expiration["Days"] = days_val
 
-            date_el = exp_el.find("{*}Date") or exp_el.find("Date")
+            date_el = exp_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Date") or exp_el.find("Date")
             if date_el is not None and date_el.text:
                 expiration["Date"] = date_el.text.strip()
 
-            eodm_el = exp_el.find("{*}ExpiredObjectDeleteMarker") or exp_el.find("ExpiredObjectDeleteMarker")
+            eodm_el = exp_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}ExpiredObjectDeleteMarker") or exp_el.find("ExpiredObjectDeleteMarker")
             if eodm_el is not None and (eodm_el.text or "").strip().lower() in {"true", "1"}:
                 expiration["ExpiredObjectDeleteMarker"] = True
             if expiration:
                 rule["Expiration"] = expiration
 
-        nve_el = rule_el.find("{*}NoncurrentVersionExpiration") or rule_el.find("NoncurrentVersionExpiration")
+        nve_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}NoncurrentVersionExpiration") or rule_el.find("NoncurrentVersionExpiration")
        if nve_el is not None:
             nve: dict = {}
-            days_el = nve_el.find("{*}NoncurrentDays") or nve_el.find("NoncurrentDays")
+            days_el = nve_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}NoncurrentDays") or nve_el.find("NoncurrentDays")
             if days_el is not None and days_el.text:
                 noncurrent_days = int(days_el.text.strip())
                 if noncurrent_days <= 0:
@@ -1530,10 +1549,10 @@ def _parse_lifecycle_config(payload: bytes) -> list:
         if nve:
             rule["NoncurrentVersionExpiration"] = nve
 
-        aimu_el = rule_el.find("{*}AbortIncompleteMultipartUpload") or rule_el.find("AbortIncompleteMultipartUpload")
+        aimu_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}AbortIncompleteMultipartUpload") or rule_el.find("AbortIncompleteMultipartUpload")
         if aimu_el is not None:
             aimu: dict = {}
-            days_el = aimu_el.find("{*}DaysAfterInitiation") or aimu_el.find("DaysAfterInitiation")
+            days_el = aimu_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}DaysAfterInitiation") or aimu_el.find("DaysAfterInitiation")
             if days_el is not None and days_el.text:
                 days_after = int(days_el.text.strip())
                 if days_after <= 0:
@@ -1649,6 +1668,9 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response:
         SubElement(root, "ObjectLockEnabled").text = "Enabled" if config.enabled else "Disabled"
         return _xml_response(root)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         return _error_response("MalformedXML", "Request body is required", 400)
@@ -1658,7 +1680,7 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response:
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
-    enabled_el = root.find("{*}ObjectLockEnabled") or root.find("ObjectLockEnabled")
+    enabled_el = root.find("{http://s3.amazonaws.com/doc/2006-03-01/}ObjectLockEnabled") or root.find("ObjectLockEnabled")
     enabled = (enabled_el.text or "").strip() == "Enabled" if enabled_el is not None else False
 
     config = ObjectLockConfig(enabled=enabled)
@@ -1714,6 +1736,9 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
         current_app.logger.info("Bucket notifications deleted", extra={"bucket": bucket_name})
         return Response(status=204)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         notification_service.delete_bucket_notifications(bucket_name)
@@ -1725,9 +1750,9 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
     configs: list[NotificationConfiguration] = []
-    for webhook_el in root.findall("{*}WebhookConfiguration") or root.findall("WebhookConfiguration"):
+    for webhook_el in root.findall("{http://s3.amazonaws.com/doc/2006-03-01/}WebhookConfiguration") or root.findall("WebhookConfiguration"):
         config_id = _find_element_text(webhook_el, "Id") or uuid.uuid4().hex
-        events = [el.text for el in webhook_el.findall("{*}Event") or webhook_el.findall("Event") if el.text]
+        events = [el.text for el in webhook_el.findall("{http://s3.amazonaws.com/doc/2006-03-01/}Event") or webhook_el.findall("Event") if el.text]
 
         dest_el = _find_element(webhook_el, "Destination")
         url = _find_element_text(dest_el, "Url") if dest_el else ""
@@ -1740,7 +1765,7 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
         if filter_el:
             key_el = _find_element(filter_el, "S3Key")
             if key_el:
-                for rule_el in key_el.findall("{*}FilterRule") or key_el.findall("FilterRule"):
+                for rule_el in key_el.findall("{http://s3.amazonaws.com/doc/2006-03-01/}FilterRule") or key_el.findall("FilterRule"):
                     name = _find_element_text(rule_el, "Name")
                     value = _find_element_text(rule_el, "Value")
                     if name == "prefix":
@@ -1793,6 +1818,9 @@ def _bucket_logging_handler(bucket_name: str) -> Response:
         current_app.logger.info("Bucket logging deleted", extra={"bucket": bucket_name})
         return Response(status=204)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         logging_service.delete_bucket_logging(bucket_name)
@@ -1930,6 +1958,9 @@ def _object_retention_handler(bucket_name: str, object_key: str) -> Response:
         SubElement(root, "RetainUntilDate").text = retention.retain_until_date.strftime("%Y-%m-%dT%H:%M:%S.000Z")
         return _xml_response(root)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         return _error_response("MalformedXML", "Request body is required", 400)
@@ -1999,6 +2030,9 @@ def _object_legal_hold_handler(bucket_name: str, object_key: str) -> Response:
         SubElement(root, "Status").text = "ON" if enabled else "OFF"
         return _xml_response(root)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         return _error_response("MalformedXML", "Request body is required", 400)
@@ -2030,6 +2064,9 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
     except IamError as exc:
         return _error_response("AccessDenied", str(exc), 403)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     if not payload.strip():
         return _error_response("MalformedXML", "Request body must include a Delete specification", 400)
@@ -3003,6 +3040,9 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
     if not upload_id:
         return _error_response("InvalidArgument", "uploadId is required", 400)
 
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
     payload = request.get_data(cache=False) or b""
     try:
         root = fromstring(payload)
@@ -3016,11 +3056,11 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
     for part_el in list(root):
         if _strip_ns(part_el.tag) != "Part":
             continue
-        part_number_el = part_el.find("{*}PartNumber")
+        part_number_el = part_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}PartNumber")
         if part_number_el is None:
             part_number_el = part_el.find("PartNumber")
 
-        etag_el = part_el.find("{*}ETag")
+        etag_el = part_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}ETag")
         if etag_el is None:
             etag_el = part_el.find("ETag")
 
diff --git a/app/storage.py b/app/storage.py
index b5e38be..70488d0 100644
--- a/app/storage.py
+++ b/app/storage.py
@@ -1773,11 +1773,9 @@ class ObjectStorage:
             raise StorageError("Object key contains null bytes")
         if object_key.startswith(("/", "\\")):
             raise StorageError("Object key cannot start with a slash")
-        normalized = unicodedata.normalize("NFC", object_key)
-        if normalized != object_key:
-            raise StorageError("Object key must use normalized Unicode")
-
-        candidate = Path(normalized)
+        object_key = unicodedata.normalize("NFC", object_key)
+
+        candidate = Path(object_key)
         if ".." in candidate.parts:
             raise StorageError("Object key contains parent directory references")
 
diff --git a/requirements.txt b/requirements.txt
index 8fe9bb3..17915fa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@ requests>=2.32.5
 boto3>=1.42.14
 waitress>=3.0.2
 psutil>=7.1.3
-cryptography>=46.0.3
\ No newline at end of file
+cryptography>=46.0.3
+defusedxml>=0.7.1
\ No newline at end of file
diff --git a/tests/test_api.py b/tests/test_api.py
index affe9ec..b2859cb 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -104,12 +104,12 @@ def test_request_id_header_present(client, signer):
     assert response.headers.get("X-Request-ID")
 
 
-def test_healthcheck_returns_version(client):
+def test_healthcheck_returns_status(client):
     response = client.get("/healthz")
     data = response.get_json()
     assert response.status_code == 200
     assert data["status"] == "ok"
-    assert "version" in data
+    assert "version" not in data
 
 
 def test_missing_credentials_denied(client):