Fix security vulnerabilities: XXE, timing attacks, info leaks

2026-01-18 17:18:12 +08:00
parent 9c3518de63
commit 5ab62a00ff
7 changed files with 98 additions and 50 deletions

View File

@@ -280,7 +280,7 @@ def create_app(
@app.get("/healthz")
def healthcheck() -> Dict[str, str]:
return {"status": "ok", "version": app.config.get("APP_VERSION", "unknown")}
return {"status": "ok"}
return app

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import os
import re
import secrets
import shutil
import sys
@@ -9,6 +10,13 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional
def _validate_rate_limit(value: str) -> str:
pattern = r"^\d+\s+per\s+(second|minute|hour|day)$"
if not re.match(pattern, value):
raise ValueError(f"Invalid rate limit format: {value}. Expected format: '200 per minute'")
return value
if getattr(sys, "frozen", False):
# Running in a PyInstaller bundle
PROJECT_ROOT = Path(sys._MEIPASS)
@@ -151,7 +159,7 @@ class AppConfig:
log_path = log_dir / str(_get("LOG_FILE", "app.log"))
log_max_bytes = int(_get("LOG_MAX_BYTES", 5 * 1024 * 1024))
log_backup_count = int(_get("LOG_BACKUP_COUNT", 3))
ratelimit_default = str(_get("RATE_LIMIT_DEFAULT", "200 per minute"))
ratelimit_default = _validate_rate_limit(str(_get("RATE_LIMIT_DEFAULT", "200 per minute")))
ratelimit_storage_uri = str(_get("RATE_LIMIT_STORAGE_URI", "memory://"))
def _csv(value: str, default: list[str]) -> list[str]:
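A quick standalone illustration of the validation added above (the function body is reproduced from the hunk; the sample inputs are made up): the regex only lets Flask-Limiter-style "<count> per <unit>" strings through, so a malformed or hostile RATE_LIMIT_DEFAULT fails at startup instead of being passed on to the limiter.

import re

def _validate_rate_limit(value: str) -> str:
    # Accept only "<count> per <unit>" where unit is second/minute/hour/day.
    pattern = r"^\d+\s+per\s+(second|minute|hour|day)$"
    if not re.match(pattern, value):
        raise ValueError(f"Invalid rate limit format: {value}. Expected format: '200 per minute'")
    return value

assert _validate_rate_limit("200 per minute") == "200 per minute"
try:
    _validate_rate_limit("200 per fortnight")  # rejected before the app is created
except ValueError:
    pass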

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import hmac
import json
import math
import secrets
@@ -149,7 +150,7 @@ class IamService:
f"Access temporarily locked. Try again in {seconds} seconds."
)
record = self._users.get(access_key)
if not record or record["secret_key"] != secret_key:
if not record or not hmac.compare_digest(record["secret_key"], secret_key):
self._record_failed_attempt(access_key)
raise IamError("Invalid credentials")
self._clear_failed_attempts(access_key)
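A minimal sketch of the rationale for the comparison swap (assumed from the commit title, not stated in the diff): `!=` on strings can return as soon as the first byte differs, so response timing leaks how long a matching prefix the attacker has guessed, whereas `hmac.compare_digest` takes time independent of where the inputs differ.

import hmac

stored_secret = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"  # illustrative value only

def check_naive(candidate: str) -> bool:
    # Early-exit comparison: runtime depends on the length of the matching prefix.
    return stored_secret == candidate

def check_constant_time(candidate: str) -> bool:
    # Compares the full inputs regardless of where they first differ.
    return hmac.compare_digest(stored_secret, candidate)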

View File

@@ -11,7 +11,8 @@ import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
from urllib.parse import quote, urlencode, urlparse, unquote
from xml.etree.ElementTree import Element, SubElement, tostring, fromstring, ParseError
from xml.etree.ElementTree import Element, SubElement, tostring, ParseError
from defusedxml.ElementTree import fromstring
from flask import Blueprint, Response, current_app, jsonify, request, g
from werkzeug.http import http_date
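For context on the import swap (a hedged sketch, not taken from this repository's tests): the standard-library parser accepts DTDs and expands internal entities, which is the classic entity-expansion / XXE attack surface, while `defusedxml.ElementTree.fromstring` refuses any entity declaration outright.

from defusedxml import DefusedXmlException  # EntitiesForbidden is the specific subclass raised here
from defusedxml.ElementTree import fromstring

hostile_payload = b"""<?xml version="1.0"?>
<!DOCTYPE Tagging [<!ENTITY a "boom"><!ENTITY b "&a;&a;&a;&a;&a;&a;&a;&a;">]>
<Tagging><TagSet><Tag><Key>&b;</Key><Value>v</Value></Tag></TagSet></Tagging>"""

try:
    fromstring(hostile_payload)
except DefusedXmlException:
    # defusedxml raises on the entity declarations instead of expanding them,
    # whereas xml.etree.ElementTree.fromstring would parse and expand this document.
    pass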
@@ -29,6 +30,8 @@ from .storage import ObjectStorage, StorageError, QuotaExceededError, BucketNotF
logger = logging.getLogger(__name__)
S3_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
s3_api_bp = Blueprint("s3_api", __name__)
def _storage() -> ObjectStorage:
@@ -93,6 +96,13 @@ def _error_response(code: str, message: str, status: int) -> Response:
return _xml_response(error, status)
def _require_xml_content_type() -> Response | None:
ct = request.headers.get("Content-Type", "")
if ct and not ct.startswith(("application/xml", "text/xml")):
return _error_response("InvalidRequest", "Content-Type must be application/xml or text/xml", 400)
return None
def _parse_range_header(range_header: str, file_size: int) -> list[tuple[int, int]] | None:
if not range_header.startswith("bytes="):
return None
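A rough usage sketch for the `_require_xml_content_type` helper above (throwaway Flask app and route, for illustration only; the real handlers call the helper the same way): a body labelled as JSON is rejected before any XML parsing runs, while a missing Content-Type is still allowed, matching the permissive `if ct and ...` check.

from __future__ import annotations
from flask import Flask, Response, request

app = Flask(__name__)  # demo app, not the project's create_app()

def _require_xml_content_type() -> Response | None:
    ct = request.headers.get("Content-Type", "")
    if ct and not ct.startswith(("application/xml", "text/xml")):
        return Response("Content-Type must be application/xml or text/xml", status=400)
    return None

@app.put("/demo")
def demo() -> Response:
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    return Response(status=200)

with app.test_client() as client:
    assert client.put("/demo", data=b"{}", content_type="application/json").status_code == 400
    assert client.put("/demo", data=b"<x/>", content_type="application/xml").status_code == 200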
@@ -232,16 +242,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
if not hmac.compare_digest(calculated_signature, signature):
if current_app.config.get("DEBUG_SIGV4"):
logger.warning(
"SigV4 signature mismatch",
extra={
"path": req.path,
"method": method,
"signed_headers": signed_headers_str,
"content_type": req.headers.get("Content-Type"),
"content_length": req.headers.get("Content-Length"),
}
)
logger.warning("SigV4 signature mismatch for %s %s", method, req.path)
raise IamError("SignatureDoesNotMatch")
session_token = req.headers.get("X-Amz-Security-Token")
@@ -307,7 +308,7 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
if header.lower() == 'expect' and val == "":
val = "100-continue"
val = " ".join(val.split())
canonical_headers_parts.append(f"{header}:{val}\n")
canonical_headers_parts.append(f"{header.lower()}:{val}\n")
canonical_headers = "".join(canonical_headers_parts)
payload_hash = "UNSIGNED-PAYLOAD"
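The lowercasing fix above matters because SigV4 hashes a canonical request in which header names must be lowercase, values whitespace-collapsed, and entries sorted by name; if the server canonicalizes differently from the client, every presigned URL fails verification. A standalone sketch of that canonicalization step (not the project's verification code):

def canonical_headers(headers: dict[str, str]) -> tuple[str, str]:
    # SigV4 canonical form: lowercase names, collapse internal whitespace in values,
    # sort by name, emit one "name:value\n" line each; SignedHeaders is the ;-joined names.
    items = sorted((name.lower(), " ".join(value.split())) for name, value in headers.items())
    canonical = "".join(f"{name}:{value}\n" for name, value in items)
    signed = ";".join(name for name, _ in items)
    return canonical, signed

canonical, signed = canonical_headers({"Host": "example.test", "X-Amz-Date": "20260118T091812Z"})
# canonical == "host:example.test\nx-amz-date:20260118T091812Z\n"
# signed == "host;x-amz-date"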
@@ -661,11 +662,11 @@ def _strip_ns(tag: str | None) -> str:
def _find_element(parent: Element, name: str) -> Optional[Element]:
"""Find a child element by name, trying both namespaced and non-namespaced variants.
"""Find a child element by name, trying S3 namespace then no namespace.
This handles XML documents that may or may not include namespace prefixes.
"""
el = parent.find(f"{{*}}{name}")
el = parent.find(f"{{{S3_NS}}}{name}")
if el is None:
el = parent.find(name)
return el
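For readers unfamiliar with ElementTree namespace syntax: the `{namespace}tag` form used throughout the replacements below matches only elements in the S3 document namespace (presumably the point of dropping the `{*}` wildcard), with a bare-name fallback for clients that omit the namespace. A compact standalone sketch using the same S3_NS constant:

from defusedxml.ElementTree import fromstring

S3_NS = "http://s3.amazonaws.com/doc/2006-03-01/"

namespaced = fromstring(
    b'<Tagging xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><TagSet/></Tagging>'
)
plain = fromstring(b"<Tagging><TagSet/></Tagging>")

def find_tagset(root):
    # Try the namespace-qualified name first, then the bare name.
    el = root.find(f"{{{S3_NS}}}TagSet")
    if el is None:
        el = root.find("TagSet")
    return el

assert find_tagset(namespaced) is not None
assert find_tagset(plain) is not None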
@@ -689,7 +690,7 @@ def _parse_tagging_document(payload: bytes) -> list[dict[str, str]]:
raise ValueError("Malformed XML") from exc
if _strip_ns(root.tag) != "Tagging":
raise ValueError("Root element must be Tagging")
tagset = root.find(".//{*}TagSet")
tagset = root.find(".//{http://s3.amazonaws.com/doc/2006-03-01/}TagSet")
if tagset is None:
tagset = root.find("TagSet")
if tagset is None:
@@ -857,13 +858,13 @@ def _parse_encryption_document(payload: bytes) -> dict[str, Any]:
bucket_key_el = child
if default_el is None:
continue
algo_el = default_el.find("{*}SSEAlgorithm")
algo_el = default_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}SSEAlgorithm")
if algo_el is None:
algo_el = default_el.find("SSEAlgorithm")
if algo_el is None or not (algo_el.text or "").strip():
raise ValueError("SSEAlgorithm is required")
rule: dict[str, Any] = {"SSEAlgorithm": algo_el.text.strip()}
kms_el = default_el.find("{*}KMSMasterKeyID")
kms_el = default_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}KMSMasterKeyID")
if kms_el is None:
kms_el = default_el.find("KMSMasterKeyID")
if kms_el is not None and kms_el.text:
@@ -964,8 +965,11 @@ def _bucket_versioning_handler(bucket_name: str) -> Response:
except IamError as exc:
return _error_response("AccessDenied", str(exc), 403)
storage = _storage()
if request.method == "PUT":
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
return _error_response("MalformedXML", "Request body is required", 400)
@@ -975,7 +979,7 @@ def _bucket_versioning_handler(bucket_name: str) -> Response:
return _error_response("MalformedXML", "Unable to parse XML document", 400)
if _strip_ns(root.tag) != "VersioningConfiguration":
return _error_response("MalformedXML", "Root element must be VersioningConfiguration", 400)
status_el = root.find("{*}Status")
status_el = root.find("{http://s3.amazonaws.com/doc/2006-03-01/}Status")
if status_el is None:
status_el = root.find("Status")
status = (status_el.text or "").strip() if status_el is not None else ""
@@ -1024,6 +1028,9 @@ def _bucket_tagging_handler(bucket_name: str) -> Response:
current_app.logger.info("Bucket tags deleted", extra={"bucket": bucket_name})
return Response(status=204)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
try:
tags = _parse_tagging_document(payload)
@@ -1079,6 +1086,9 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response:
current_app.logger.info("Object tags deleted", extra={"bucket": bucket_name, "key": object_key})
return Response(status=204)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
try:
tags = _parse_tagging_document(payload)
@@ -1148,6 +1158,9 @@ def _bucket_cors_handler(bucket_name: str) -> Response:
current_app.logger.info("Bucket CORS deleted", extra={"bucket": bucket_name})
return Response(status=204)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
try:
@@ -1194,6 +1207,9 @@ def _bucket_encryption_handler(bucket_name: str) -> Response:
404,
)
return _xml_response(_render_encryption_document(config))
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
try:
@@ -1415,6 +1431,9 @@ def _bucket_lifecycle_handler(bucket_name: str) -> Response:
current_app.logger.info("Bucket lifecycle deleted", extra={"bucket": bucket_name})
return Response(status=204)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
return _error_response("MalformedXML", "Request body is required", 400)
@@ -1479,49 +1498,49 @@ def _parse_lifecycle_config(payload: bytes) -> list:
raise ValueError("Root element must be LifecycleConfiguration")
rules = []
for rule_el in root.findall("{*}Rule") or root.findall("Rule"):
for rule_el in root.findall("{http://s3.amazonaws.com/doc/2006-03-01/}Rule") or root.findall("Rule"):
rule: dict = {}
id_el = rule_el.find("{*}ID") or rule_el.find("ID")
id_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}ID") or rule_el.find("ID")
if id_el is not None and id_el.text:
rule["ID"] = id_el.text.strip()
filter_el = rule_el.find("{*}Filter") or rule_el.find("Filter")
filter_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Filter") or rule_el.find("Filter")
if filter_el is not None:
prefix_el = filter_el.find("{*}Prefix") or filter_el.find("Prefix")
prefix_el = filter_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Prefix") or filter_el.find("Prefix")
if prefix_el is not None and prefix_el.text:
rule["Prefix"] = prefix_el.text
if "Prefix" not in rule:
prefix_el = rule_el.find("{*}Prefix") or rule_el.find("Prefix")
prefix_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Prefix") or rule_el.find("Prefix")
if prefix_el is not None:
rule["Prefix"] = prefix_el.text or ""
status_el = rule_el.find("{*}Status") or rule_el.find("Status")
status_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Status") or rule_el.find("Status")
rule["Status"] = (status_el.text or "Enabled").strip() if status_el is not None else "Enabled"
exp_el = rule_el.find("{*}Expiration") or rule_el.find("Expiration")
exp_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Expiration") or rule_el.find("Expiration")
if exp_el is not None:
expiration: dict = {}
days_el = exp_el.find("{*}Days") or exp_el.find("Days")
days_el = exp_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Days") or exp_el.find("Days")
if days_el is not None and days_el.text:
days_val = int(days_el.text.strip())
if days_val <= 0:
raise ValueError("Expiration Days must be a positive integer")
expiration["Days"] = days_val
date_el = exp_el.find("{*}Date") or exp_el.find("Date")
date_el = exp_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}Date") or exp_el.find("Date")
if date_el is not None and date_el.text:
expiration["Date"] = date_el.text.strip()
eodm_el = exp_el.find("{*}ExpiredObjectDeleteMarker") or exp_el.find("ExpiredObjectDeleteMarker")
eodm_el = exp_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}ExpiredObjectDeleteMarker") or exp_el.find("ExpiredObjectDeleteMarker")
if eodm_el is not None and (eodm_el.text or "").strip().lower() in {"true", "1"}:
expiration["ExpiredObjectDeleteMarker"] = True
if expiration:
rule["Expiration"] = expiration
nve_el = rule_el.find("{*}NoncurrentVersionExpiration") or rule_el.find("NoncurrentVersionExpiration")
nve_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}NoncurrentVersionExpiration") or rule_el.find("NoncurrentVersionExpiration")
if nve_el is not None:
nve: dict = {}
days_el = nve_el.find("{*}NoncurrentDays") or nve_el.find("NoncurrentDays")
days_el = nve_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}NoncurrentDays") or nve_el.find("NoncurrentDays")
if days_el is not None and days_el.text:
noncurrent_days = int(days_el.text.strip())
if noncurrent_days <= 0:
@@ -1530,10 +1549,10 @@ def _parse_lifecycle_config(payload: bytes) -> list:
if nve:
rule["NoncurrentVersionExpiration"] = nve
aimu_el = rule_el.find("{*}AbortIncompleteMultipartUpload") or rule_el.find("AbortIncompleteMultipartUpload")
aimu_el = rule_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}AbortIncompleteMultipartUpload") or rule_el.find("AbortIncompleteMultipartUpload")
if aimu_el is not None:
aimu: dict = {}
days_el = aimu_el.find("{*}DaysAfterInitiation") or aimu_el.find("DaysAfterInitiation")
days_el = aimu_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}DaysAfterInitiation") or aimu_el.find("DaysAfterInitiation")
if days_el is not None and days_el.text:
days_after = int(days_el.text.strip())
if days_after <= 0:
@@ -1649,6 +1668,9 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response:
SubElement(root, "ObjectLockEnabled").text = "Enabled" if config.enabled else "Disabled"
return _xml_response(root)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
return _error_response("MalformedXML", "Request body is required", 400)
@@ -1658,7 +1680,7 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response:
except ParseError:
return _error_response("MalformedXML", "Unable to parse XML document", 400)
enabled_el = root.find("{*}ObjectLockEnabled") or root.find("ObjectLockEnabled")
enabled_el = root.find("{http://s3.amazonaws.com/doc/2006-03-01/}ObjectLockEnabled") or root.find("ObjectLockEnabled")
enabled = (enabled_el.text or "").strip() == "Enabled" if enabled_el is not None else False
config = ObjectLockConfig(enabled=enabled)
@@ -1714,6 +1736,9 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
current_app.logger.info("Bucket notifications deleted", extra={"bucket": bucket_name})
return Response(status=204)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
notification_service.delete_bucket_notifications(bucket_name)
@@ -1725,9 +1750,9 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
return _error_response("MalformedXML", "Unable to parse XML document", 400)
configs: list[NotificationConfiguration] = []
for webhook_el in root.findall("{*}WebhookConfiguration") or root.findall("WebhookConfiguration"):
for webhook_el in root.findall("{http://s3.amazonaws.com/doc/2006-03-01/}WebhookConfiguration") or root.findall("WebhookConfiguration"):
config_id = _find_element_text(webhook_el, "Id") or uuid.uuid4().hex
events = [el.text for el in webhook_el.findall("{*}Event") or webhook_el.findall("Event") if el.text]
events = [el.text for el in webhook_el.findall("{http://s3.amazonaws.com/doc/2006-03-01/}Event") or webhook_el.findall("Event") if el.text]
dest_el = _find_element(webhook_el, "Destination")
url = _find_element_text(dest_el, "Url") if dest_el else ""
@@ -1740,7 +1765,7 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
if filter_el:
key_el = _find_element(filter_el, "S3Key")
if key_el:
for rule_el in key_el.findall("{*}FilterRule") or key_el.findall("FilterRule"):
for rule_el in key_el.findall("{http://s3.amazonaws.com/doc/2006-03-01/}FilterRule") or key_el.findall("FilterRule"):
name = _find_element_text(rule_el, "Name")
value = _find_element_text(rule_el, "Value")
if name == "prefix":
@@ -1793,6 +1818,9 @@ def _bucket_logging_handler(bucket_name: str) -> Response:
current_app.logger.info("Bucket logging deleted", extra={"bucket": bucket_name})
return Response(status=204)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
logging_service.delete_bucket_logging(bucket_name)
@@ -1930,6 +1958,9 @@ def _object_retention_handler(bucket_name: str, object_key: str) -> Response:
SubElement(root, "RetainUntilDate").text = retention.retain_until_date.strftime("%Y-%m-%dT%H:%M:%S.000Z")
return _xml_response(root)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
return _error_response("MalformedXML", "Request body is required", 400)
@@ -1999,6 +2030,9 @@ def _object_legal_hold_handler(bucket_name: str, object_key: str) -> Response:
SubElement(root, "Status").text = "ON" if enabled else "OFF"
return _xml_response(root)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
return _error_response("MalformedXML", "Request body is required", 400)
@@ -2030,6 +2064,9 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
except IamError as exc:
return _error_response("AccessDenied", str(exc), 403)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
if not payload.strip():
return _error_response("MalformedXML", "Request body must include a Delete specification", 400)
@@ -3003,6 +3040,9 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
if not upload_id:
return _error_response("InvalidArgument", "uploadId is required", 400)
ct_error = _require_xml_content_type()
if ct_error:
return ct_error
payload = request.get_data(cache=False) or b""
try:
root = fromstring(payload)
@@ -3016,11 +3056,11 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
for part_el in list(root):
if _strip_ns(part_el.tag) != "Part":
continue
part_number_el = part_el.find("{*}PartNumber")
part_number_el = part_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}PartNumber")
if part_number_el is None:
part_number_el = part_el.find("PartNumber")
etag_el = part_el.find("{*}ETag")
etag_el = part_el.find("{http://s3.amazonaws.com/doc/2006-03-01/}ETag")
if etag_el is None:
etag_el = part_el.find("ETag")

View File

@@ -1773,11 +1773,9 @@ class ObjectStorage:
raise StorageError("Object key contains null bytes")
if object_key.startswith(("/", "\\")):
raise StorageError("Object key cannot start with a slash")
normalized = unicodedata.normalize("NFC", object_key)
if normalized != object_key:
raise StorageError("Object key must use normalized Unicode")
candidate = Path(normalized)
object_key = unicodedata.normalize("NFC", object_key)
candidate = Path(object_key)
if ".." in candidate.parts:
raise StorageError("Object key contains parent directory references")
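A small sketch of what the relaxed check above changes (illustrative key strings only): some clients, notably macOS tooling, send NFD-decomposed Unicode keys, which the old code rejected with a StorageError; the new code normalizes them to NFC before running the path-safety checks.

import unicodedata

composed = "caf\u00e9/object.txt"        # NFC: single precomposed é code point
decomposed = "cafe\u0301/object.txt"     # NFD: 'e' + combining acute, as some clients send

assert composed != decomposed                                   # previously: rejected outright
assert unicodedata.normalize("NFC", decomposed) == composed     # now: accepted and normalized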

View File

@@ -8,4 +8,5 @@ requests>=2.32.5
boto3>=1.42.14
waitress>=3.0.2
psutil>=7.1.3
cryptography>=46.0.3
cryptography>=46.0.3
defusedxml>=0.7.1

View File

@@ -104,12 +104,12 @@ def test_request_id_header_present(client, signer):
assert response.headers.get("X-Request-ID")
def test_healthcheck_returns_version(client):
def test_healthcheck_returns_status(client):
response = client.get("/healthz")
data = response.get_json()
assert response.status_code == 200
assert data["status"] == "ok"
assert "version" in data
assert "version" not in data
def test_missing_credentials_denied(client):