MyFSIO v0.2.4 Release #16
@@ -78,6 +78,16 @@ def _validate_region(region: str) -> Optional[str]:
|
||||
return "Region must match format like us-east-1"
|
||||
return None
|
||||
|
||||
|
||||
def _validate_site_id(site_id: str) -> Optional[str]:
|
||||
"""Validate site_id format. Returns error message or None."""
|
||||
if not site_id or len(site_id) > 63:
|
||||
return "site_id must be 1-63 characters"
|
||||
if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9_-]*$', site_id):
|
||||
return "site_id must start with alphanumeric and contain only alphanumeric, hyphens, underscores"
|
||||
return None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin")
|
||||
@@ -168,6 +178,25 @@ def update_local_site():
|
||||
if not site_id:
|
||||
return _json_error("ValidationError", "site_id is required", 400)
|
||||
|
||||
site_id_error = _validate_site_id(site_id)
|
||||
if site_id_error:
|
||||
return _json_error("ValidationError", site_id_error, 400)
|
||||
|
||||
if endpoint:
|
||||
endpoint_error = _validate_endpoint(endpoint)
|
||||
if endpoint_error:
|
||||
return _json_error("ValidationError", endpoint_error, 400)
|
||||
|
||||
if "priority" in payload:
|
||||
priority_error = _validate_priority(payload["priority"])
|
||||
if priority_error:
|
||||
return _json_error("ValidationError", priority_error, 400)
|
||||
|
||||
if "region" in payload:
|
||||
region_error = _validate_region(payload["region"])
|
||||
if region_error:
|
||||
return _json_error("ValidationError", region_error, 400)
|
||||
|
||||
registry = _site_registry()
|
||||
existing = registry.get_local_site()
|
||||
|
||||
@@ -220,6 +249,11 @@ def register_peer_site():
|
||||
|
||||
if not site_id:
|
||||
return _json_error("ValidationError", "site_id is required", 400)
|
||||
|
||||
site_id_error = _validate_site_id(site_id)
|
||||
if site_id_error:
|
||||
return _json_error("ValidationError", site_id_error, 400)
|
||||
|
||||
if not endpoint:
|
||||
return _json_error("ValidationError", "endpoint is required", 400)
|
||||
|
||||
@@ -293,6 +327,21 @@ def update_peer_site(site_id: str):
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
|
||||
if "endpoint" in payload:
|
||||
endpoint_error = _validate_endpoint(payload["endpoint"])
|
||||
if endpoint_error:
|
||||
return _json_error("ValidationError", endpoint_error, 400)
|
||||
|
||||
if "priority" in payload:
|
||||
priority_error = _validate_priority(payload["priority"])
|
||||
if priority_error:
|
||||
return _json_error("ValidationError", priority_error, 400)
|
||||
|
||||
if "region" in payload:
|
||||
region_error = _validate_region(payload["region"])
|
||||
if region_error:
|
||||
return _json_error("ValidationError", region_error, 400)
|
||||
|
||||
peer = PeerSite(
|
||||
site_id=site_id,
|
||||
endpoint=payload.get("endpoint", existing.endpoint),
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Encryption providers for server-side and client-side encryption."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
@@ -19,6 +19,8 @@ from cryptography.hazmat.primitives import hashes
|
||||
if sys.platform != "win32":
|
||||
import fcntl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _set_secure_file_permissions(file_path: Path) -> None:
|
||||
"""Set restrictive file permissions (owner read/write only)."""
|
||||
@@ -31,8 +33,10 @@ def _set_secure_file_permissions(file_path: Path) -> None:
|
||||
"/grant:r", f"{username}:F"],
|
||||
check=True, capture_output=True
|
||||
)
|
||||
except (subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
else:
|
||||
logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path)
|
||||
except (subprocess.SubprocessError, OSError) as exc:
|
||||
logger.warning("Failed to set secure permissions on %s: %s", file_path, exc)
|
||||
else:
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
@@ -100,6 +104,18 @@ class EncryptionProvider:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
    """Turn an encrypted data key back into its plaintext form.

    This base implementation is a placeholder and always raises.

    Args:
        encrypted_data_key: The encrypted data key bytes
        key_id: Optional key identifier (used by KMS providers)

    Returns:
        The decrypted data key

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
class LocalKeyEncryption(EncryptionProvider):
|
||||
"""SSE-S3 style encryption using a local master key.
|
||||
@@ -157,11 +173,13 @@ class LocalKeyEncryption(EncryptionProvider):
|
||||
except OSError as exc:
|
||||
raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc
|
||||
|
||||
DATA_KEY_AAD = b'{"purpose":"data_key","version":1}'
|
||||
|
||||
def _encrypt_data_key(self, data_key: bytes) -> bytes:
    """Encrypt the data key with the master key.

    Args:
        data_key: Plaintext data key bytes to wrap.

    Returns:
        The 12-byte random nonce followed by the output of ``AESGCM.encrypt``,
        with ``DATA_KEY_AAD`` supplied as associated data.
    """
    aesgcm = AESGCM(self.master_key)
    nonce = secrets.token_bytes(12)
    # Encrypt exactly once per nonce. A second call with the same nonce and a
    # different AAD would only produce a discarded result.
    encrypted = aesgcm.encrypt(nonce, data_key, self.DATA_KEY_AAD)
    return nonce + encrypted
|
||||
|
||||
def _decrypt_data_key(self, encrypted_data_key: bytes) -> bytes:
|
||||
@@ -172,9 +190,16 @@ class LocalKeyEncryption(EncryptionProvider):
|
||||
nonce = encrypted_data_key[:12]
|
||||
ciphertext = encrypted_data_key[12:]
|
||||
try:
|
||||
return aesgcm.decrypt(nonce, ciphertext, None)
|
||||
except Exception as exc:
|
||||
raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc
|
||||
return aesgcm.decrypt(nonce, ciphertext, self.DATA_KEY_AAD)
|
||||
except Exception:
|
||||
try:
|
||||
return aesgcm.decrypt(nonce, ciphertext, None)
|
||||
except Exception as exc:
|
||||
raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc
|
||||
|
||||
def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
    """Return the plaintext data key for *encrypted_data_key*.

    Delegates to ``_decrypt_data_key``. *key_id* is accepted but not used by
    this provider.
    """
    plaintext_key = self._decrypt_data_key(encrypted_data_key)
    return plaintext_key
|
||||
|
||||
def generate_data_key(self) -> tuple[bytes, bytes]:
|
||||
"""Generate a data key and its encrypted form."""
|
||||
@@ -281,10 +306,7 @@ class StreamingEncryptor:
|
||||
|
||||
Performance: Writes chunks directly to output buffer instead of accumulating in list.
|
||||
"""
|
||||
if isinstance(self.provider, LocalKeyEncryption):
|
||||
data_key = self.provider._decrypt_data_key(metadata.encrypted_data_key)
|
||||
else:
|
||||
raise EncryptionError("Unsupported provider for streaming decryption")
|
||||
data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id)
|
||||
|
||||
aesgcm = AESGCM(data_key)
|
||||
base_nonce = metadata.nonce
|
||||
|
||||
15
app/kms.py
15
app/kms.py
@@ -34,8 +34,10 @@ def _set_secure_file_permissions(file_path: Path) -> None:
|
||||
"/grant:r", f"{username}:F"],
|
||||
check=True, capture_output=True
|
||||
)
|
||||
except (subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
else:
|
||||
logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path)
|
||||
except (subprocess.SubprocessError, OSError) as exc:
|
||||
logger.warning("Failed to set secure permissions on %s: %s", file_path, exc)
|
||||
else:
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
@@ -128,6 +130,15 @@ class KMSEncryptionProvider(EncryptionProvider):
|
||||
logger.debug("KMS decryption failed: %s", exc)
|
||||
raise EncryptionError("Failed to decrypt data") from exc
|
||||
|
||||
def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
    """Decrypt an encrypted data key through ``self.kms``.

    Uses ``self.key_id`` when *key_id* is None and always passes
    ``context=None`` to ``self.kms.decrypt_data_key``.

    Raises:
        EncryptionError: The decrypted key is not exactly 32 bytes long.
    """
    effective_key_id = self.key_id if key_id is None else key_id
    plaintext_key = self.kms.decrypt_data_key(effective_key_id, encrypted_data_key, context=None)
    if len(plaintext_key) == 32:
        return plaintext_key
    raise EncryptionError("Invalid data key size")
|
||||
|
||||
|
||||
class KMSManager:
|
||||
"""Manages KMS keys and operations.
|
||||
|
||||
137
app/s3_api.py
137
app/s3_api.py
@@ -1,4 +1,3 @@
|
||||
"""Flask blueprint exposing a subset of the S3 REST API."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
@@ -32,6 +31,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
S3_NS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
||||
|
||||
_HEADER_CONTROL_CHARS = re.compile(r'[\r\n\x00-\x1f\x7f]')
|
||||
|
||||
|
||||
def _sanitize_header_value(value: str) -> str:
|
||||
return _HEADER_CONTROL_CHARS.sub('', value)
|
||||
|
||||
|
||||
MAX_XML_PAYLOAD_SIZE = 1048576 # 1 MB
|
||||
|
||||
|
||||
def _parse_xml_with_limit(payload: bytes) -> Element:
    """Parse *payload* as XML after enforcing a maximum byte length.

    The limit is read from the app config key ``MAX_XML_PAYLOAD_SIZE`` and
    falls back to the module constant of the same name.

    Raises:
        ParseError: The payload is longer than the limit. ``fromstring`` may
            also raise while parsing.
    """
    max_size = current_app.config.get("MAX_XML_PAYLOAD_SIZE", MAX_XML_PAYLOAD_SIZE)
    if len(payload) <= max_size:
        return fromstring(payload)
    raise ParseError(f"XML payload exceeds maximum size of {max_size} bytes")
|
||||
|
||||
|
||||
s3_api_bp = Blueprint("s3_api", __name__)
|
||||
|
||||
def _storage() -> ObjectStorage:
|
||||
@@ -126,30 +143,36 @@ def _require_xml_content_type() -> Response | None:
|
||||
def _parse_range_header(range_header: str, file_size: int) -> list[tuple[int, int]] | None:
|
||||
if not range_header.startswith("bytes="):
|
||||
return None
|
||||
max_range_value = 2**63 - 1
|
||||
ranges = []
|
||||
range_spec = range_header[6:]
|
||||
for part in range_spec.split(","):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
if part.startswith("-"):
|
||||
suffix_length = int(part[1:])
|
||||
if suffix_length <= 0:
|
||||
return None
|
||||
start = max(0, file_size - suffix_length)
|
||||
end = file_size - 1
|
||||
elif part.endswith("-"):
|
||||
start = int(part[:-1])
|
||||
if start >= file_size:
|
||||
return None
|
||||
end = file_size - 1
|
||||
else:
|
||||
start_str, end_str = part.split("-", 1)
|
||||
start = int(start_str)
|
||||
end = int(end_str)
|
||||
if start > end or start >= file_size:
|
||||
return None
|
||||
end = min(end, file_size - 1)
|
||||
try:
|
||||
if part.startswith("-"):
|
||||
suffix_length = int(part[1:])
|
||||
if suffix_length <= 0 or suffix_length > max_range_value:
|
||||
return None
|
||||
start = max(0, file_size - suffix_length)
|
||||
end = file_size - 1
|
||||
elif part.endswith("-"):
|
||||
start = int(part[:-1])
|
||||
if start < 0 or start > max_range_value or start >= file_size:
|
||||
return None
|
||||
end = file_size - 1
|
||||
else:
|
||||
start_str, end_str = part.split("-", 1)
|
||||
start = int(start_str)
|
||||
end = int(end_str)
|
||||
if start < 0 or end < 0 or start > max_range_value or end > max_range_value:
|
||||
return None
|
||||
if start > end or start >= file_size:
|
||||
return None
|
||||
end = min(end, file_size - 1)
|
||||
except (ValueError, OverflowError):
|
||||
return None
|
||||
ranges.append((start, end))
|
||||
return ranges if ranges else None
|
||||
|
||||
@@ -196,7 +219,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
|
||||
access_key, date_stamp, region, service, signed_headers_str, signature = match.groups()
|
||||
secret_key = _iam().get_secret_key(access_key)
|
||||
if not secret_key:
|
||||
raise IamError("Invalid access key")
|
||||
raise IamError("SignatureDoesNotMatch")
|
||||
|
||||
method = req.method
|
||||
canonical_uri = _get_canonical_uri(req)
|
||||
@@ -379,16 +402,18 @@ def _verify_sigv4(req: Any) -> Principal | None:
|
||||
|
||||
|
||||
def _require_principal():
|
||||
if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
|
||||
(request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
|
||||
sigv4_attempted = ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
|
||||
(request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256")
|
||||
if sigv4_attempted:
|
||||
try:
|
||||
principal = _verify_sigv4(request)
|
||||
if principal:
|
||||
return principal, None
|
||||
return None, _error_response("AccessDenied", "Signature verification failed", 403)
|
||||
except IamError as exc:
|
||||
return None, _error_response("AccessDenied", str(exc), 403)
|
||||
except (ValueError, TypeError):
|
||||
return None, _error_response("AccessDenied", "Signature verification failed", 403)
|
||||
return None, _error_response("AccessDenied", "Signature verification failed", 403)
|
||||
|
||||
access_key = request.headers.get("X-Access-Key")
|
||||
secret_key = request.headers.get("X-Secret-Key")
|
||||
@@ -709,7 +734,7 @@ def _find_element_text(parent: Element, name: str, default: str = "") -> str:
|
||||
|
||||
def _parse_tagging_document(payload: bytes) -> list[dict[str, str]]:
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError as exc:
|
||||
raise ValueError("Malformed XML") from exc
|
||||
if _strip_ns(root.tag) != "Tagging":
|
||||
@@ -810,7 +835,7 @@ def _validate_content_type(object_key: str, content_type: str | None) -> str | N
|
||||
|
||||
def _parse_cors_document(payload: bytes) -> list[dict[str, Any]]:
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError as exc:
|
||||
raise ValueError("Malformed XML") from exc
|
||||
if _strip_ns(root.tag) != "CORSConfiguration":
|
||||
@@ -863,7 +888,7 @@ def _render_cors_document(rules: list[dict[str, Any]]) -> Element:
|
||||
|
||||
def _parse_encryption_document(payload: bytes) -> dict[str, Any]:
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError as exc:
|
||||
raise ValueError("Malformed XML") from exc
|
||||
if _strip_ns(root.tag) != "ServerSideEncryptionConfiguration":
|
||||
@@ -1000,7 +1025,7 @@ def _bucket_versioning_handler(bucket_name: str) -> Response:
|
||||
if not payload.strip():
|
||||
return _error_response("MalformedXML", "Request body is required", 400)
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
if _strip_ns(root.tag) != "VersioningConfiguration":
|
||||
@@ -1387,7 +1412,10 @@ def _bucket_list_versions_handler(bucket_name: str) -> Response:
|
||||
prefix = request.args.get("prefix", "")
|
||||
delimiter = request.args.get("delimiter", "")
|
||||
try:
|
||||
max_keys = max(1, min(int(request.args.get("max-keys", 1000)), 1000))
|
||||
max_keys = int(request.args.get("max-keys", 1000))
|
||||
if max_keys < 1:
|
||||
return _error_response("InvalidArgument", "max-keys must be a positive integer", 400)
|
||||
max_keys = min(max_keys, 1000)
|
||||
except ValueError:
|
||||
return _error_response("InvalidArgument", "max-keys must be an integer", 400)
|
||||
key_marker = request.args.get("key-marker", "")
|
||||
@@ -1551,7 +1579,7 @@ def _render_lifecycle_config(config: list) -> Element:
|
||||
def _parse_lifecycle_config(payload: bytes) -> list:
|
||||
"""Parse lifecycle configuration from XML."""
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError as exc:
|
||||
raise ValueError(f"Unable to parse XML document: {exc}") from exc
|
||||
|
||||
@@ -1737,7 +1765,7 @@ def _bucket_object_lock_handler(bucket_name: str) -> Response:
|
||||
return _error_response("MalformedXML", "Request body is required", 400)
|
||||
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
|
||||
@@ -1806,7 +1834,7 @@ def _bucket_notification_handler(bucket_name: str) -> Response:
|
||||
return Response(status=200)
|
||||
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
|
||||
@@ -1888,7 +1916,7 @@ def _bucket_logging_handler(bucket_name: str) -> Response:
|
||||
return Response(status=200)
|
||||
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
|
||||
@@ -1941,7 +1969,10 @@ def _bucket_uploads_handler(bucket_name: str) -> Response:
|
||||
prefix = request.args.get("prefix", "")
|
||||
delimiter = request.args.get("delimiter", "")
|
||||
try:
|
||||
max_uploads = max(1, min(int(request.args.get("max-uploads", 1000)), 1000))
|
||||
max_uploads = int(request.args.get("max-uploads", 1000))
|
||||
if max_uploads < 1:
|
||||
return _error_response("InvalidArgument", "max-uploads must be a positive integer", 400)
|
||||
max_uploads = min(max_uploads, 1000)
|
||||
except ValueError:
|
||||
return _error_response("InvalidArgument", "max-uploads must be an integer", 400)
|
||||
|
||||
@@ -2027,7 +2058,7 @@ def _object_retention_handler(bucket_name: str, object_key: str) -> Response:
|
||||
return _error_response("MalformedXML", "Request body is required", 400)
|
||||
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
|
||||
@@ -2099,7 +2130,7 @@ def _object_legal_hold_handler(bucket_name: str, object_key: str) -> Response:
|
||||
return _error_response("MalformedXML", "Request body is required", 400)
|
||||
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
|
||||
@@ -2132,7 +2163,7 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
|
||||
if not payload.strip():
|
||||
return _error_response("MalformedXML", "Request body must include a Delete specification", 400)
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
if _strip_ns(root.tag) != "Delete":
|
||||
@@ -2339,7 +2370,10 @@ def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditio
|
||||
if actual_value != expected:
|
||||
return f"Field {field} must equal {expected}"
|
||||
elif operator == "content-length-range" and len(condition) == 3:
|
||||
min_size, max_size = condition[1], condition[2]
|
||||
try:
|
||||
min_size, max_size = int(condition[1]), int(condition[2])
|
||||
except (TypeError, ValueError):
|
||||
return "Invalid content-length-range values"
|
||||
if content_length < min_size or content_length > max_size:
|
||||
return f"Content length must be between {min_size} and {max_size}"
|
||||
return None
|
||||
@@ -2437,7 +2471,10 @@ def bucket_handler(bucket_name: str) -> Response:
|
||||
prefix = request.args.get("prefix", "")
|
||||
delimiter = request.args.get("delimiter", "")
|
||||
try:
|
||||
max_keys = max(1, min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000))
|
||||
max_keys = int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"]))
|
||||
if max_keys < 1:
|
||||
return _error_response("InvalidArgument", "max-keys must be a positive integer", 400)
|
||||
max_keys = min(max_keys, 1000)
|
||||
except ValueError:
|
||||
return _error_response("InvalidArgument", "max-keys must be an integer", 400)
|
||||
|
||||
@@ -2766,7 +2803,7 @@ def object_handler(bucket_name: str, object_key: str):
|
||||
for param, header in response_overrides.items():
|
||||
value = request.args.get(param)
|
||||
if value:
|
||||
response.headers[header] = value
|
||||
response.headers[header] = _sanitize_header_value(value)
|
||||
|
||||
action = "Object read" if request.method == "GET" else "Object head"
|
||||
current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes})
|
||||
@@ -2924,7 +2961,7 @@ def _bucket_replication_handler(bucket_name: str) -> Response:
|
||||
|
||||
def _parse_replication_config(bucket_name: str, payload: bytes):
|
||||
from .replication import ReplicationRule, REPLICATION_MODE_ALL
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
if _strip_ns(root.tag) != "ReplicationConfiguration":
|
||||
raise ValueError("Root element must be ReplicationConfiguration")
|
||||
rule_el = None
|
||||
@@ -3295,6 +3332,9 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
|
||||
except ValueError:
|
||||
return _error_response("InvalidArgument", "partNumber must be an integer", 400)
|
||||
|
||||
if part_number < 1 or part_number > 10000:
|
||||
return _error_response("InvalidArgument", "partNumber must be between 1 and 10000", 400)
|
||||
|
||||
stream = request.stream
|
||||
content_encoding = request.headers.get("Content-Encoding", "").lower()
|
||||
if "aws-chunked" in content_encoding:
|
||||
@@ -3329,6 +3369,9 @@ def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Re
|
||||
except ValueError:
|
||||
return _error_response("InvalidArgument", "partNumber must be an integer", 400)
|
||||
|
||||
if part_number < 1 or part_number > 10000:
|
||||
return _error_response("InvalidArgument", "partNumber must be between 1 and 10000", 400)
|
||||
|
||||
copy_source = unquote(copy_source)
|
||||
if copy_source.startswith("/"):
|
||||
copy_source = copy_source[1:]
|
||||
@@ -3336,6 +3379,8 @@ def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Re
|
||||
if len(parts) != 2:
|
||||
return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
|
||||
source_bucket, source_key = parts
|
||||
if not source_bucket or not source_key:
|
||||
return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
|
||||
|
||||
_, read_error = _object_principal("read", source_bucket, source_key)
|
||||
if read_error:
|
||||
@@ -3384,7 +3429,7 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
|
||||
return ct_error
|
||||
payload = request.get_data(cache=False) or b""
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
|
||||
@@ -3404,8 +3449,14 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
|
||||
etag_el = part_el.find("ETag")
|
||||
|
||||
if part_number_el is not None and etag_el is not None:
|
||||
try:
|
||||
part_num = int(part_number_el.text or 0)
|
||||
except ValueError:
|
||||
return _error_response("InvalidArgument", "PartNumber must be an integer", 400)
|
||||
if part_num < 1 or part_num > 10000:
|
||||
return _error_response("InvalidArgument", f"PartNumber {part_num} must be between 1 and 10000", 400)
|
||||
parts.append({
|
||||
"PartNumber": int(part_number_el.text or 0),
|
||||
"PartNumber": part_num,
|
||||
"ETag": (etag_el.text or "").strip('"')
|
||||
})
|
||||
|
||||
@@ -3463,7 +3514,7 @@ def _select_object_content(bucket_name: str, object_key: str) -> Response:
|
||||
return ct_error
|
||||
payload = request.get_data(cache=False) or b""
|
||||
try:
|
||||
root = fromstring(payload)
|
||||
root = _parse_xml_with_limit(payload)
|
||||
except ParseError:
|
||||
return _error_response("MalformedXML", "Unable to parse XML document", 400)
|
||||
if _strip_ns(root.tag) != "SelectObjectContentRequest":
|
||||
|
||||
@@ -46,6 +46,34 @@ else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _atomic_lock_file(lock_path: Path, max_retries: int = 10, base_delay: float = 0.1) -> Generator[None, None, None]:
|
||||
"""Atomically acquire a lock file with exponential backoff.
|
||||
|
||||
Uses O_EXCL to ensure atomic creation of the lock file.
|
||||
"""
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
fd = os.open(str(lock_path), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
|
||||
break
|
||||
except FileExistsError:
|
||||
if attempt == max_retries - 1:
|
||||
raise BlockingIOError("Another upload to this key is in progress")
|
||||
delay = base_delay * (2 ** attempt)
|
||||
time.sleep(min(delay, 2.0))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if fd is not None:
|
||||
os.close(fd)
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
WINDOWS_RESERVED_NAMES = {
|
||||
"CON",
|
||||
"PRN",
|
||||
@@ -1159,34 +1187,26 @@ class ObjectStorage:
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
lock_file_path = self._system_bucket_root(bucket_id) / "locks" / f"{safe_key.as_posix().replace('/', '_')}.lock"
|
||||
lock_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with lock_file_path.open("w") as lock_file:
|
||||
with _file_lock(lock_file):
|
||||
if self._is_versioning_enabled(bucket_path) and destination.exists():
|
||||
self._archive_current_version(bucket_id, safe_key, reason="overwrite")
|
||||
checksum = hashlib.md5()
|
||||
with destination.open("wb") as target:
|
||||
for _, record in validated:
|
||||
part_path = upload_root / record["filename"]
|
||||
if not part_path.exists():
|
||||
raise StorageError(f"Missing part file {record['filename']}")
|
||||
with part_path.open("rb") as chunk:
|
||||
while True:
|
||||
data = chunk.read(1024 * 1024)
|
||||
if not data:
|
||||
break
|
||||
checksum.update(data)
|
||||
target.write(data)
|
||||
|
||||
with _atomic_lock_file(lock_file_path):
|
||||
if self._is_versioning_enabled(bucket_path) and destination.exists():
|
||||
self._archive_current_version(bucket_id, safe_key, reason="overwrite")
|
||||
checksum = hashlib.md5()
|
||||
with destination.open("wb") as target:
|
||||
for _, record in validated:
|
||||
part_path = upload_root / record["filename"]
|
||||
if not part_path.exists():
|
||||
raise StorageError(f"Missing part file {record['filename']}")
|
||||
with part_path.open("rb") as chunk:
|
||||
while True:
|
||||
data = chunk.read(1024 * 1024)
|
||||
if not data:
|
||||
break
|
||||
checksum.update(data)
|
||||
target.write(data)
|
||||
except BlockingIOError:
|
||||
raise StorageError("Another upload to this key is in progress")
|
||||
finally:
|
||||
try:
|
||||
lock_file_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
shutil.rmtree(upload_root, ignore_errors=True)
|
||||
|
||||
@@ -1867,13 +1887,13 @@ class ObjectStorage:
|
||||
def _sanitize_object_key(object_key: str, max_length_bytes: int = 1024) -> Path:
|
||||
if not object_key:
|
||||
raise StorageError("Object key required")
|
||||
if len(object_key.encode("utf-8")) > max_length_bytes:
|
||||
raise StorageError(f"Object key exceeds maximum length of {max_length_bytes} bytes")
|
||||
if "\x00" in object_key:
|
||||
raise StorageError("Object key contains null bytes")
|
||||
object_key = unicodedata.normalize("NFC", object_key)
|
||||
if len(object_key.encode("utf-8")) > max_length_bytes:
|
||||
raise StorageError(f"Object key exceeds maximum length of {max_length_bytes} bytes")
|
||||
if object_key.startswith(("/", "\\")):
|
||||
raise StorageError("Object key cannot start with a slash")
|
||||
object_key = unicodedata.normalize("NFC", object_key)
|
||||
|
||||
candidate = Path(object_key)
|
||||
if ".." in candidate.parts:
|
||||
|
||||
Reference in New Issue
Block a user