From 9629507acdd445919d575c1f58d685ba3b3643ec Mon Sep 17 00:00:00 2001
From: kqjy
Date: Sun, 1 Feb 2026 18:12:03 +0800
Subject: [PATCH] Fix auth bypass, user enumeration, xml DoS, multipart race,
 path traversal unicode, silent permissions failures, data key without AAD,
 KMS streaming

---
 app/admin_api.py  |  49 ++++++++++++++++
 app/encryption.py |  56 +++++++++++++------
 app/kms.py        |  15 ++++-
 app/s3_api.py     | 139 +++++++++++++++++++++++++++++++---------------
 app/storage.py    |  80 ++++++++++++++++----------
 5 files changed, 246 insertions(+), 93 deletions(-)

diff --git a/app/admin_api.py b/app/admin_api.py
index b565ba4..e8d8609 100644
--- a/app/admin_api.py
+++ b/app/admin_api.py
@@ -78,6 +78,16 @@ def _validate_region(region: str) -> Optional[str]:
         return "Region must match format like us-east-1"
     return None
 
+
+def _validate_site_id(site_id: str) -> Optional[str]:
+    """Validate site_id format. Returns error message or None."""
+    if not site_id or len(site_id) > 63:
+        return "site_id must be 1-63 characters"
+    if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9_-]*$', site_id):
+        return "site_id must start with alphanumeric and contain only alphanumeric, hyphens, underscores"
+    return None
+
+
 logger = logging.getLogger(__name__)
 
 admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin")
@@ -168,6 +178,25 @@ def update_local_site():
     if not site_id:
         return _json_error("ValidationError", "site_id is required", 400)
 
+    site_id_error = _validate_site_id(site_id)
+    if site_id_error:
+        return _json_error("ValidationError", site_id_error, 400)
+
+    if endpoint:
+        endpoint_error = _validate_endpoint(endpoint)
+        if endpoint_error:
+            return _json_error("ValidationError", endpoint_error, 400)
+
+    if "priority" in payload:
+        priority_error = _validate_priority(payload["priority"])
+        if priority_error:
+            return _json_error("ValidationError", priority_error, 400)
+
+    if "region" in payload:
+        region_error = _validate_region(payload["region"])
+        if region_error:
+            return _json_error("ValidationError", region_error, 400)
+
     registry = _site_registry()
     existing = registry.get_local_site()
 
@@ -220,6 +249,11 @@ def register_peer_site():
 
     if not site_id:
         return _json_error("ValidationError", "site_id is required", 400)
+
+    site_id_error = _validate_site_id(site_id)
+    if site_id_error:
+        return _json_error("ValidationError", site_id_error, 400)
+
     if not endpoint:
         return _json_error("ValidationError", "endpoint is required", 400)
 
@@ -293,6 +327,21 @@ def update_peer_site(site_id: str):
 
     payload = request.get_json(silent=True) or {}
 
+    if "endpoint" in payload:
+        endpoint_error = _validate_endpoint(payload["endpoint"])
+        if endpoint_error:
+            return _json_error("ValidationError", endpoint_error, 400)
+
+    if "priority" in payload:
+        priority_error = _validate_priority(payload["priority"])
+        if priority_error:
+            return _json_error("ValidationError", priority_error, 400)
+
+    if "region" in payload:
+        region_error = _validate_region(payload["region"])
+        if region_error:
+            return _json_error("ValidationError", region_error, 400)
+
     peer = PeerSite(
         site_id=site_id,
         endpoint=payload.get("endpoint", existing.endpoint),
diff --git a/app/encryption.py b/app/encryption.py
index d9c1679..6d8c2b2 100644
--- a/app/encryption.py
+++ b/app/encryption.py
@@ -1,9 +1,9 @@
-"""Encryption providers for server-side and client-side encryption."""
 from __future__ import annotations
 
 import base64
 import io
 import json
+import logging
 import os
 import secrets
 import subprocess
@@ -19,6 +19,8 @@ from cryptography.hazmat.primitives import hashes
 if sys.platform != "win32":
     import fcntl
 
+logger = logging.getLogger(__name__)
+
 
 def _set_secure_file_permissions(file_path: Path) -> None:
     """Set restrictive file permissions (owner read/write only)."""
@@ -31,8 +33,10 @@ def _set_secure_file_permissions(file_path: Path) -> None:
                      "/grant:r", f"{username}:F"],
                     check=True, capture_output=True
                 )
-        except (subprocess.SubprocessError, OSError):
-            pass
+            else:
+                logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path)
+        except (subprocess.SubprocessError, OSError) as exc:
+            logger.warning("Failed to set secure permissions on %s: %s", file_path, exc)
     else:
         os.chmod(file_path, 0o600)
 
@@ -84,22 +88,34 @@ class EncryptionMetadata:
 
 class EncryptionProvider:
     """Base class for encryption providers."""
-    
+
     def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult:
         raise NotImplementedError
-    
+
     def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
                 key_id: str, context: Dict[str, str] | None = None) -> bytes:
         raise NotImplementedError
-    
+
     def generate_data_key(self) -> tuple[bytes, bytes]:
         """Generate a data key and its encrypted form.
-        
+
         Returns:
             Tuple of (plaintext_key, encrypted_key)
         """
         raise NotImplementedError
 
+    def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
+        """Decrypt an encrypted data key.
+
+        Args:
+            encrypted_data_key: The encrypted data key bytes
+            key_id: Optional key identifier (used by KMS providers)
+
+        Returns:
+            The decrypted data key
+        """
+        raise NotImplementedError
+
 
 class LocalKeyEncryption(EncryptionProvider):
     """SSE-S3 style encryption using a local master key.
@@ -157,13 +173,15 @@ class LocalKeyEncryption(EncryptionProvider):
         except OSError as exc:
             raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc
 
+    DATA_KEY_AAD = b'{"purpose":"data_key","version":1}'
+
     def _encrypt_data_key(self, data_key: bytes) -> bytes:
         """Encrypt the data key with the master key."""
         aesgcm = AESGCM(self.master_key)
         nonce = secrets.token_bytes(12)
-        encrypted = aesgcm.encrypt(nonce, data_key, None)
+        encrypted = aesgcm.encrypt(nonce, data_key, self.DATA_KEY_AAD)
         return nonce + encrypted
-    
+
     def _decrypt_data_key(self, encrypted_data_key: bytes) -> bytes:
         """Decrypt the data key using the master key."""
         if len(encrypted_data_key) < 12 + 32 + 16: # nonce + key + tag
@@ -172,10 +190,17 @@ class LocalKeyEncryption(EncryptionProvider):
         nonce = encrypted_data_key[:12]
         ciphertext = encrypted_data_key[12:]
         try:
-            return aesgcm.decrypt(nonce, ciphertext, None)
-        except Exception as exc:
-            raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc
-    
+            return aesgcm.decrypt(nonce, ciphertext, self.DATA_KEY_AAD)
+        except Exception:
+            try:
+                return aesgcm.decrypt(nonce, ciphertext, None)
+            except Exception as exc:
+                raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc
+
+    def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
+        """Decrypt an encrypted data key (key_id ignored for local encryption)."""
+        return self._decrypt_data_key(encrypted_data_key)
+
     def generate_data_key(self) -> tuple[bytes, bytes]:
         """Generate a data key and its encrypted form."""
         plaintext_key = secrets.token_bytes(32)
@@ -281,10 +306,7 @@ class StreamingEncryptor:
 
         Performance: Writes chunks directly to output buffer instead of accumulating in list.
         """
-        if isinstance(self.provider, LocalKeyEncryption):
-            data_key = self.provider._decrypt_data_key(metadata.encrypted_data_key)
-        else:
-            raise EncryptionError("Unsupported provider for streaming decryption")
+        data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id)
         aesgcm = AESGCM(data_key)
         base_nonce = metadata.nonce
 
diff --git a/app/kms.py b/app/kms.py
index 884f975..dbd07e0 100644
--- a/app/kms.py
+++ b/app/kms.py
@@ -34,8 +34,10 @@ def _set_secure_file_permissions(file_path: Path) -> None:
                      "/grant:r", f"{username}:F"],
                     check=True, capture_output=True
                 )
-        except (subprocess.SubprocessError, OSError):
-            pass
+            else:
+                logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path)
+        except (subprocess.SubprocessError, OSError) as exc:
+            logger.warning("Failed to set secure permissions on %s: %s", file_path, exc)
     else:
         os.chmod(file_path, 0o600)
 
@@ -128,6 +130,15 @@ class KMSEncryptionProvider(EncryptionProvider):
             logger.debug("KMS decryption failed: %s", exc)
             raise EncryptionError("Failed to decrypt data") from exc
 
+    def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
+        """Decrypt an encrypted data key using KMS."""
+        if key_id is None:
+            key_id = self.key_id
+        data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None)
+        if len(data_key) != 32:
+            raise EncryptionError("Invalid data key size")
+        return data_key
+
 
 class KMSManager:
     """Manages KMS keys and operations.
diff --git a/app/s3_api.py b/app/s3_api.py
index f784e38..822997b 100644
--- a/app/s3_api.py
+++ b/app/s3_api.py
@@ -1,4 +1,3 @@
-"""Flask blueprint exposing a subset of the S3 REST API."""
 from __future__ import annotations
 
 import base64
@@ -32,6 +31,24 @@ logger = logging.getLogger(__name__)
 S3_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
 
+_HEADER_CONTROL_CHARS = re.compile(r'[\r\n\x00-\x1f\x7f]')
+
+
+def _sanitize_header_value(value: str) -> str:
+    return _HEADER_CONTROL_CHARS.sub('', value)
+
+
+MAX_XML_PAYLOAD_SIZE = 1048576 # 1 MB
+
+
+def _parse_xml_with_limit(payload: bytes) -> Element:
+    """Parse XML payload with size limit to prevent DoS attacks."""
+    max_size = current_app.config.get("MAX_XML_PAYLOAD_SIZE", MAX_XML_PAYLOAD_SIZE)
+    if len(payload) > max_size:
+        raise ParseError(f"XML payload exceeds maximum size of {max_size} bytes")
+    return fromstring(payload)
+
+
 s3_api_bp = Blueprint("s3_api", __name__)
 
 
 def _storage() -> ObjectStorage:
@@ -126,30 +143,36 @@ def _require_xml_content_type() -> Response | None:
 def _parse_range_header(range_header: str, file_size: int) -> list[tuple[int, int]] | None:
     if not range_header.startswith("bytes="):
         return None
+    max_range_value = 2**63 - 1
     ranges = []
     range_spec = range_header[6:]
     for part in range_spec.split(","):
         part = part.strip()
         if not part:
             continue
-        if part.startswith("-"):
-            suffix_length = int(part[1:])
-            if suffix_length <= 0:
-                return None
-            start = max(0, file_size - suffix_length)
-            end = file_size - 1
-        elif part.endswith("-"):
-            start = int(part[:-1])
-            if start >= file_size:
-                return None
-            end = file_size - 1
-        else:
-            start_str, end_str = part.split("-", 1)
-            start = int(start_str)
-            end = int(end_str)
-            if start > end or start >= file_size:
-                return None
-            end = min(end, file_size - 1)
+        try:
+            if part.startswith("-"):
+                suffix_length = int(part[1:])
+                if suffix_length <= 0 or suffix_length > max_range_value:
+                    return None
+                start = max(0, file_size - suffix_length)
+                end = file_size - 1
+            elif part.endswith("-"):
+                start = int(part[:-1])
+                if start < 0 or start > max_range_value or start >= file_size:
+                    return None
+                end = file_size - 1
+            else:
+                start_str, end_str = part.split("-", 1)
+                start = int(start_str)
+                end = int(end_str)
+                if start < 0 or end < 0 or start > max_range_value or end > max_range_value:
+                    return None
+                if start > end or start >= file_size:
+                    return None
+                end = min(end, file_size - 1)
+        except (ValueError, OverflowError):
+            return None
         ranges.append((start, end))
     return ranges if ranges else None
 
@@ -196,7 +219,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
     access_key, date_stamp, region, service, signed_headers_str, signature = match.groups()
     secret_key = _iam().get_secret_key(access_key)
     if not secret_key:
-        raise IamError("Invalid access key")
+        raise IamError("SignatureDoesNotMatch")
 
     method = req.method
     canonical_uri = _get_canonical_uri(req)
@@ -379,16 +402,18 @@ def _verify_sigv4(req: Any) -> Principal | None:
 
 
 def _require_principal():
-    if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
-       (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
+    sigv4_attempted = ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
+        (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256")
+    if sigv4_attempted:
         try:
             principal = _verify_sigv4(request)
             if principal:
                 return principal, None
+            return None, _error_response("AccessDenied", "Signature verification failed", 403)
         except IamError as exc:
             return None, _error_response("AccessDenied", str(exc), 403)
         except (ValueError, TypeError):
-            return None, _error_response("AccessDenied", "Signature verification failed", 403)
+            return None, _error_response("AccessDenied", "Signature verification failed", 403)
 
     access_key = request.headers.get("X-Access-Key")
     secret_key = request.headers.get("X-Secret-Key")
@@ -709,7 +734,7 @@ def _find_element_text(parent: Element, name: str, default: str = "") -> str:
 
 def _parse_tagging_document(payload: bytes) -> list[dict[str, str]]:
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError as exc:
         raise ValueError("Malformed XML") from exc
     if _strip_ns(root.tag) != "Tagging":
@@ -810,7 +835,7 @@ def _validate_content_type(object_key: str, content_type: str | None) -> str | N
 
 def _parse_cors_document(payload: bytes) -> list[dict[str, Any]]:
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError as exc:
         raise ValueError("Malformed XML") from exc
     if _strip_ns(root.tag) != "CORSConfiguration":
@@ -863,7 +888,7 @@ def _render_cors_document(rules: list[dict[str, Any]]) -> Element:
 
 def _parse_encryption_document(payload: bytes) -> dict[str, Any]:
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError as exc:
         raise ValueError("Malformed XML") from exc
     if _strip_ns(root.tag) != "ServerSideEncryptionConfiguration":
@@ -1000,7 +1025,7 @@ def _bucket_versioning_handler(bucket_name: str) -> Response:
     if not payload.strip():
         return _error_response("MalformedXML", "Request body is required", 400)
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
     if _strip_ns(root.tag) != "VersioningConfiguration":
@@ -1387,7 +1412,10 @@ def _bucket_list_versions_handler(bucket_name: str) -> Response:
     prefix = request.args.get("prefix", "")
     delimiter = request.args.get("delimiter", "")
     try:
-        max_keys = max(1, min(int(request.args.get("max-keys", 1000)), 1000))
+        max_keys = int(request.args.get("max-keys", 1000))
+        if max_keys < 1:
+            return _error_response("InvalidArgument", "max-keys must be a positive integer", 400)
+        max_keys = min(max_keys, 1000)
     except ValueError:
         return _error_response("InvalidArgument", "max-keys must be an integer", 400)
     key_marker = request.args.get("key-marker", "")
@@ -1551,7 +1579,7 @@ def _render_lifecycle_config(config: list) -> Element:
 def _parse_lifecycle_config(payload: bytes) -> list:
     """Parse lifecycle configuration from XML."""
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError as exc:
         raise ValueError(f"Unable to parse XML document: {exc}") from exc
 
@@ -1737,7 +1765,7 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response:
         return _error_response("MalformedXML", "Request body is required", 400)
 
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
@@ -1806,7 +1834,7 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
         return Response(status=200)
 
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
    except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
@@ -1888,7 +1916,7 @@ def _bucket_logging_handler(bucket_name: str) -> Response:
         return Response(status=200)
 
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
@@ -1941,7 +1969,10 @@ def _bucket_uploads_handler(bucket_name: str) -> Response:
     prefix = request.args.get("prefix", "")
     delimiter = request.args.get("delimiter", "")
     try:
-        max_uploads = max(1, min(int(request.args.get("max-uploads", 1000)), 1000))
+        max_uploads = int(request.args.get("max-uploads", 1000))
+        if max_uploads < 1:
+            return _error_response("InvalidArgument", "max-uploads must be a positive integer", 400)
+        max_uploads = min(max_uploads, 1000)
     except ValueError:
         return _error_response("InvalidArgument", "max-uploads must be an integer", 400)
 
@@ -2027,7 +2058,7 @@ def _object_retention_handler(bucket_name: str, object_key: str) -> Response:
         return _error_response("MalformedXML", "Request body is required", 400)
 
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
@@ -2099,7 +2130,7 @@ def _object_legal_hold_handler(bucket_name: str, object_key: str) -> Response:
         return _error_response("MalformedXML", "Request body is required", 400)
 
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
@@ -2132,7 +2163,7 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
     if not payload.strip():
         return _error_response("MalformedXML", "Request body must include a Delete specification", 400)
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
     if _strip_ns(root.tag) != "Delete":
@@ -2339,7 +2370,10 @@ def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditio
             if actual_value != expected:
                 return f"Field {field} must equal {expected}"
         elif operator == "content-length-range" and len(condition) == 3:
-            min_size, max_size = condition[1], condition[2]
+            try:
+                min_size, max_size = int(condition[1]), int(condition[2])
+            except (TypeError, ValueError):
+                return "Invalid content-length-range values"
             if content_length < min_size or content_length > max_size:
                 return f"Content length must be between {min_size} and {max_size}"
     return None
@@ -2437,10 +2471,13 @@ def bucket_handler(bucket_name: str) -> Response:
     prefix = request.args.get("prefix", "")
     delimiter = request.args.get("delimiter", "")
     try:
-        max_keys = max(1, min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000))
+        max_keys = int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"]))
+        if max_keys < 1:
+            return _error_response("InvalidArgument", "max-keys must be a positive integer", 400)
+        max_keys = min(max_keys, 1000)
     except ValueError:
         return _error_response("InvalidArgument", "max-keys must be an integer", 400)
-    
+
     marker = request.args.get("marker", "") # ListObjects v1
     continuation_token = request.args.get("continuation-token", "") # ListObjectsV2
     start_after = request.args.get("start-after", "") # ListObjectsV2
@@ -2766,7 +2803,7 @@ def object_handler(bucket_name: str, object_key: str):
         for param, header in response_overrides.items():
             value = request.args.get(param)
             if value:
-                response.headers[header] = value
+                response.headers[header] = _sanitize_header_value(value)
 
         action = "Object read" if request.method == "GET" else "Object head"
         current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes})
@@ -2924,7 +2961,7 @@ def _bucket_replication_handler(bucket_name: str) -> Response:
 def _parse_replication_config(bucket_name: str, payload: bytes):
     from .replication import ReplicationRule, REPLICATION_MODE_ALL
 
-    root = fromstring(payload)
+    root = _parse_xml_with_limit(payload)
     if _strip_ns(root.tag) != "ReplicationConfiguration":
         raise ValueError("Root element must be ReplicationConfiguration")
     rule_el = None
@@ -3295,6 +3332,9 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
     except ValueError:
         return _error_response("InvalidArgument", "partNumber must be an integer", 400)
 
+    if part_number < 1 or part_number > 10000:
+        return _error_response("InvalidArgument", "partNumber must be between 1 and 10000", 400)
+
     stream = request.stream
     content_encoding = request.headers.get("Content-Encoding", "").lower()
     if "aws-chunked" in content_encoding:
@@ -3329,6 +3369,9 @@ def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Re
     except ValueError:
         return _error_response("InvalidArgument", "partNumber must be an integer", 400)
 
+    if part_number < 1 or part_number > 10000:
+        return _error_response("InvalidArgument", "partNumber must be between 1 and 10000", 400)
+
     copy_source = unquote(copy_source)
     if copy_source.startswith("/"):
         copy_source = copy_source[1:]
@@ -3336,6 +3379,8 @@ def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Re
     if len(parts) != 2:
         return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
     source_bucket, source_key = parts
+    if not source_bucket or not source_key:
+        return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
 
     _, read_error = _object_principal("read", source_bucket, source_key)
     if read_error:
@@ -3384,7 +3429,7 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
         return ct_error
     payload = request.get_data(cache=False) or b""
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
 
@@ -3404,8 +3449,14 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
         etag_el = part_el.find("ETag")
 
         if part_number_el is not None and etag_el is not None:
+            try:
+                part_num = int(part_number_el.text or 0)
+            except ValueError:
+                return _error_response("InvalidArgument", "PartNumber must be an integer", 400)
+            if part_num < 1 or part_num > 10000:
+                return _error_response("InvalidArgument", f"PartNumber {part_num} must be between 1 and 10000", 400)
             parts.append({
-                "PartNumber": int(part_number_el.text or 0),
+                "PartNumber": part_num,
                 "ETag": (etag_el.text or "").strip('"')
             })
 
@@ -3463,7 +3514,7 @@ def _select_object_content(bucket_name: str, object_key: str) -> Response:
         return ct_error
     payload = request.get_data(cache=False) or b""
     try:
-        root = fromstring(payload)
+        root = _parse_xml_with_limit(payload)
     except ParseError:
         return _error_response("MalformedXML", "Unable to parse XML document", 400)
     if _strip_ns(root.tag) != "SelectObjectContentRequest":
diff --git a/app/storage.py b/app/storage.py
index 05a2fda..2a034d0 100644
--- a/app/storage.py
+++ b/app/storage.py
@@ -46,6 +46,34 @@ else:
             fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
 
 
+@contextmanager
+def _atomic_lock_file(lock_path: Path, max_retries: int = 10, base_delay: float = 0.1) -> Generator[None, None, None]:
+    """Atomically acquire a lock file with exponential backoff.
+
+    Uses O_EXCL to ensure atomic creation of the lock file.
+    """
+    lock_path.parent.mkdir(parents=True, exist_ok=True)
+    fd = None
+    for attempt in range(max_retries):
+        try:
+            fd = os.open(str(lock_path), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+            break
+        except FileExistsError:
+            if attempt == max_retries - 1:
+                raise BlockingIOError("Another upload to this key is in progress")
+            delay = base_delay * (2 ** attempt)
+            time.sleep(min(delay, 2.0))
+    try:
+        yield
+    finally:
+        if fd is not None:
+            os.close(fd)
+        try:
+            lock_path.unlink(missing_ok=True)
+        except OSError:
+            pass
+
+
 WINDOWS_RESERVED_NAMES = {
     "CON",
     "PRN",
@@ -1157,36 +1185,28 @@ class ObjectStorage:
         )
         destination.parent.mkdir(parents=True, exist_ok=True)
-
-        lock_file_path = self._system_bucket_root(bucket_id) / "locks" / f"{safe_key.as_posix().replace('/', '_')}.lock"
-        lock_file_path.parent.mkdir(parents=True, exist_ok=True)
-
-        try:
-            with lock_file_path.open("w") as lock_file:
-                with _file_lock(lock_file):
-                    if self._is_versioning_enabled(bucket_path) and destination.exists():
-                        self._archive_current_version(bucket_id, safe_key, reason="overwrite")
-                    checksum = hashlib.md5()
-                    with destination.open("wb") as target:
-                        for _, record in validated:
-                            part_path = upload_root / record["filename"]
-                            if not part_path.exists():
-                                raise StorageError(f"Missing part file {record['filename']}")
-                            with part_path.open("rb") as chunk:
-                                while True:
-                                    data = chunk.read(1024 * 1024)
-                                    if not data:
-                                        break
-                                    checksum.update(data)
-                                    target.write(data)
+        lock_file_path = self._system_bucket_root(bucket_id) / "locks" / f"{safe_key.as_posix().replace('/', '_')}.lock"
+
+        try:
+            with _atomic_lock_file(lock_file_path):
+                if self._is_versioning_enabled(bucket_path) and destination.exists():
+                    self._archive_current_version(bucket_id, safe_key, reason="overwrite")
+                checksum = hashlib.md5()
+                with destination.open("wb") as target:
+                    for _, record in validated:
+                        part_path = upload_root / record["filename"]
+                        if not part_path.exists():
+                            raise StorageError(f"Missing part file {record['filename']}")
+                        with part_path.open("rb") as chunk:
+                            while True:
+                                data = chunk.read(1024 * 1024)
+                                if not data:
+                                    break
+                                checksum.update(data)
+                                target.write(data)
         except BlockingIOError:
             raise StorageError("Another upload to this key is in progress")
-        finally:
-            try:
-                lock_file_path.unlink(missing_ok=True)
-            except OSError:
-                pass
 
         shutil.rmtree(upload_root, ignore_errors=True)
 
@@ -1867,13 +1887,13 @@ class ObjectStorage:
 def _sanitize_object_key(object_key: str, max_length_bytes: int = 1024) -> Path:
     if not object_key:
         raise StorageError("Object key required")
-    if len(object_key.encode("utf-8")) > max_length_bytes:
-        raise StorageError(f"Object key exceeds maximum length of {max_length_bytes} bytes")
     if "\x00" in object_key:
         raise StorageError("Object key contains null bytes")
+    object_key = unicodedata.normalize("NFC", object_key)
+    if len(object_key.encode("utf-8")) > max_length_bytes:
+        raise StorageError(f"Object key exceeds maximum length of {max_length_bytes} bytes")
     if object_key.startswith(("/", "\\")):
         raise StorageError("Object key cannot start with a slash")
-    object_key = unicodedata.normalize("NFC", object_key)
     candidate = Path(object_key)
     if ".." in candidate.parts: