Add new tests; fix typo and tighten validations

2026-01-03 23:29:07 +08:00
parent 2d60e36fbf
commit b9cfc45aa2
14 changed files with 1970 additions and 125 deletions


@@ -25,7 +25,7 @@ from .iam import IamError, Principal
 from .notifications import NotificationService, NotificationConfiguration, WebhookDestination
 from .object_lock import ObjectLockService, ObjectLockRetention, ObjectLockConfig, ObjectLockError, RetentionMode
 from .replication import ReplicationManager
-from .storage import ObjectStorage, StorageError, QuotaExceededError
+from .storage import ObjectStorage, StorageError, QuotaExceededError, BucketNotFoundError, ObjectNotFoundError
 logger = logging.getLogger(__name__)
@@ -217,7 +217,6 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
     calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
     if not hmac.compare_digest(calculated_signature, signature):
-        # Only log detailed signature debug info if DEBUG_SIGV4 is enabled
         if current_app.config.get("DEBUG_SIGV4"):
             logger.warning(
                 "SigV4 signature mismatch",
@@ -260,7 +259,13 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
raise IamError("Invalid Date format")
now = datetime.now(timezone.utc)
if now > req_time + timedelta(seconds=int(expires)):
try:
expires_seconds = int(expires)
if expires_seconds <= 0:
raise IamError("Invalid Expires value: must be positive")
except ValueError:
raise IamError("Invalid Expires value: must be an integer")
if now > req_time + timedelta(seconds=expires_seconds):
raise IamError("Request expired")
secret_key = _iam().get_secret_key(access_key)
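
The expiry check now parses X-Amz-Expires defensively before doing any date math. The same pattern in isolation, as a minimal sketch (the function name and the local IamError stand-in are illustrative, not the handler's):

    from datetime import datetime, timedelta, timezone

    class IamError(Exception):
        """Stand-in for the project's IamError."""

    def check_presigned_expiry(req_time: datetime, expires: str) -> None:
        # Parse first, then validate the sign, then compare: int("abc") and
        # negative lifetimes are rejected instead of surfacing as 500s.
        try:
            expires_seconds = int(expires)
        except ValueError:
            raise IamError("Invalid Expires value: must be an integer")
        if expires_seconds <= 0:
            raise IamError("Invalid Expires value: must be positive")
        if datetime.now(timezone.utc) > req_time + timedelta(seconds=expires_seconds):
            raise IamError("Request expired")

Keeping the positivity check outside the try block also guarantees the IamError can never be masked by the except ValueError branch, which matters if IamError ever subclasses ValueError.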
@@ -1036,21 +1041,23 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response:
     if request.method == "GET":
         try:
             tags = storage.get_object_tags(bucket_name, object_key)
+        except BucketNotFoundError as exc:
+            return _error_response("NoSuchBucket", str(exc), 404)
+        except ObjectNotFoundError as exc:
+            return _error_response("NoSuchKey", str(exc), 404)
         except StorageError as exc:
-            message = str(exc)
-            if "Bucket" in message:
-                return _error_response("NoSuchBucket", message, 404)
-            return _error_response("NoSuchKey", message, 404)
+            return _error_response("InternalError", str(exc), 500)
         return _xml_response(_render_tagging_document(tags))
     if request.method == "DELETE":
         try:
             storage.delete_object_tags(bucket_name, object_key)
+        except BucketNotFoundError as exc:
+            return _error_response("NoSuchBucket", str(exc), 404)
+        except ObjectNotFoundError as exc:
+            return _error_response("NoSuchKey", str(exc), 404)
         except StorageError as exc:
-            message = str(exc)
-            if "Bucket" in message:
-                return _error_response("NoSuchBucket", message, 404)
-            return _error_response("NoSuchKey", message, 404)
+            return _error_response("InternalError", str(exc), 500)
         current_app.logger.info("Object tags deleted", extra={"bucket": bucket_name, "key": object_key})
         return Response(status=204)
@@ -1063,11 +1070,12 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response:
return _error_response("InvalidTag", "A maximum of 10 tags is supported for objects", 400)
try:
storage.set_object_tags(bucket_name, object_key, tags)
except BucketNotFoundError as exc:
return _error_response("NoSuchBucket", str(exc), 404)
except ObjectNotFoundError as exc:
return _error_response("NoSuchKey", str(exc), 404)
except StorageError as exc:
message = str(exc)
if "Bucket" in message:
return _error_response("NoSuchBucket", message, 404)
return _error_response("NoSuchKey", message, 404)
return _error_response("InternalError", str(exc), 500)
current_app.logger.info("Object tags updated", extra={"bucket": bucket_name, "key": object_key, "tags": len(tags)})
return Response(status=204)
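
Branching on exception types replaces the fragile substring checks on str(exc). This presumes BucketNotFoundError and ObjectNotFoundError subclass StorageError (the shared .storage import suggests so, but the hierarchy is not shown in this diff); the except order then matters, since a leading except StorageError would shadow both. A sketch of that assumed hierarchy:

    class StorageError(Exception): ...          # assumed hierarchy, not from the diff
    class BucketNotFoundError(StorageError): ...
    class ObjectNotFoundError(StorageError): ...

    def classify(exc: StorageError) -> tuple[str, int]:
        # Subclasses must be tested before the base class, mirroring the
        # except order in the handlers above.
        if isinstance(exc, BucketNotFoundError):
            return "NoSuchBucket", 404
        if isinstance(exc, ObjectNotFoundError):
            return "NoSuchKey", 404
        return "InternalError", 500

    assert classify(ObjectNotFoundError("missing")) == ("NoSuchKey", 404)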
@@ -1283,7 +1291,10 @@ def _bucket_list_versions_handler(bucket_name: str) -> Response:
     prefix = request.args.get("prefix", "")
     delimiter = request.args.get("delimiter", "")
-    max_keys = min(int(request.args.get("max-keys", 1000)), 1000)
+    try:
+        max_keys = max(1, min(int(request.args.get("max-keys", 1000)), 1000))
+    except ValueError:
+        return _error_response("InvalidArgument", "max-keys must be an integer", 400)
     key_marker = request.args.get("key-marker", "")
     if prefix:
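
The same clamp appears again in bucket_handler further down. As a reusable guard it could look like this (a hypothetical helper, not part of the commit):

    def parse_max_keys(raw: str | None, default: int = 1000, ceiling: int = 1000) -> int:
        # Raises ValueError on non-integers (mapped to InvalidArgument by the
        # caller); clamps 0, negatives, and oversized values into [1, ceiling].
        value = int(raw) if raw is not None else default
        return max(1, min(value, ceiling))

    assert parse_max_keys("5000") == 1000
    assert parse_max_keys("-3") == 1
    assert parse_max_keys(None) == 1000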
@@ -1476,7 +1487,10 @@ def _parse_lifecycle_config(payload: bytes) -> list:
         expiration: dict = {}
         days_el = exp_el.find("{*}Days") or exp_el.find("Days")
         if days_el is not None and days_el.text:
-            expiration["Days"] = int(days_el.text.strip())
+            days_val = int(days_el.text.strip())
+            if days_val <= 0:
+                raise ValueError("Expiration Days must be a positive integer")
+            expiration["Days"] = days_val
         date_el = exp_el.find("{*}Date") or exp_el.find("Date")
         if date_el is not None and date_el.text:
             expiration["Date"] = date_el.text.strip()
@@ -1491,7 +1505,10 @@ def _parse_lifecycle_config(payload: bytes) -> list:
         nve: dict = {}
         days_el = nve_el.find("{*}NoncurrentDays") or nve_el.find("NoncurrentDays")
         if days_el is not None and days_el.text:
-            nve["NoncurrentDays"] = int(days_el.text.strip())
+            noncurrent_days = int(days_el.text.strip())
+            if noncurrent_days <= 0:
+                raise ValueError("NoncurrentDays must be a positive integer")
+            nve["NoncurrentDays"] = noncurrent_days
         if nve:
             rule["NoncurrentVersionExpiration"] = nve
@@ -1500,7 +1517,10 @@ def _parse_lifecycle_config(payload: bytes) -> list:
         aimu: dict = {}
         days_el = aimu_el.find("{*}DaysAfterInitiation") or aimu_el.find("DaysAfterInitiation")
         if days_el is not None and days_el.text:
-            aimu["DaysAfterInitiation"] = int(days_el.text.strip())
+            days_after = int(days_el.text.strip())
+            if days_after <= 0:
+                raise ValueError("DaysAfterInitiation must be a positive integer")
+            aimu["DaysAfterInitiation"] = days_after
         if aimu:
             rule["AbortIncompleteMultipartUpload"] = aimu
@@ -2086,7 +2106,10 @@ def bucket_handler(bucket_name: str) -> Response:
     list_type = request.args.get("list-type")
     prefix = request.args.get("prefix", "")
     delimiter = request.args.get("delimiter", "")
-    max_keys = min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000)
+    try:
+        max_keys = max(1, min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000))
+    except ValueError:
+        return _error_response("InvalidArgument", "max-keys must be an integer", 400)
     marker = request.args.get("marker", "") # ListObjects v1
     continuation_token = request.args.get("continuation-token", "") # ListObjectsV2
@@ -2099,7 +2122,7 @@ def bucket_handler(bucket_name: str) -> Response:
     if continuation_token:
         try:
             effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8")
-        except Exception:
+        except (ValueError, UnicodeDecodeError):
             effective_start = continuation_token
     elif start_after:
         effective_start = start_after
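
Narrowing except Exception here is safe because both realistic failure modes are still covered: binascii.Error, raised on malformed base64, is a ValueError subclass, and .decode("utf-8") raises UnicodeDecodeError. A quick check:

    import base64
    import binascii

    assert issubclass(binascii.Error, ValueError)

    token = base64.urlsafe_b64encode(b"photos/2024/").decode()
    assert base64.urlsafe_b64decode(token.encode()).decode("utf-8") == "photos/2024/"

    try:
        base64.urlsafe_b64decode(b"abc")  # bad padding raises binascii.Error
    except ValueError:
        pass                              # handler falls back to the raw token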
@@ -2742,7 +2765,7 @@ class AwsChunkedDecoder:
     def __init__(self, stream):
         self.stream = stream
-        self._read_buffer = bytearray() # Performance: Pre-allocated buffer
+        self._read_buffer = bytearray()
         self.chunk_remaining = 0
         self.finished = False
@@ -2753,20 +2776,15 @@ class AwsChunkedDecoder:
"""
line = bytearray()
while True:
# Check if we have data in buffer
if self._read_buffer:
# Look for CRLF in buffer
idx = self._read_buffer.find(b"\r\n")
if idx != -1:
# Found CRLF - extract line and update buffer
line.extend(self._read_buffer[: idx + 2])
del self._read_buffer[: idx + 2]
return bytes(line)
# No CRLF yet - consume entire buffer
line.extend(self._read_buffer)
self._read_buffer.clear()
# Read more data in larger chunks (64 bytes is enough for chunk headers)
chunk = self.stream.read(64)
if not chunk:
return bytes(line) if line else b""
@@ -2775,14 +2793,11 @@ class AwsChunkedDecoder:
     def _read_exact(self, n: int) -> bytes:
         """Read exactly n bytes, using buffer first."""
         result = bytearray()
-        # Use buffered data first
         if self._read_buffer:
             take = min(len(self._read_buffer), n)
             result.extend(self._read_buffer[:take])
             del self._read_buffer[:take]
             n -= take
-        # Read remaining directly from stream
         if n > 0:
             data = self.stream.read(n)
             if data:
@@ -2794,7 +2809,7 @@ class AwsChunkedDecoder:
         if self.finished:
             return b""
-        result = bytearray() # Performance: Use bytearray for building result
+        result = bytearray()
         while size == -1 or len(result) < size:
             if self.chunk_remaining > 0:
                 to_read = self.chunk_remaining
@@ -2828,7 +2843,6 @@ class AwsChunkedDecoder:
                 if chunk_size == 0:
                     self.finished = True
-                    # Skip trailing headers
                     while True:
                         trailer = self._read_line()
                         if trailer == b"\r\n" or not trailer:
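
For reference, the decoder consumes aws-chunked framing: a hex chunk size with an optional chunk-signature extension, CRLF, the payload, CRLF, terminated by a zero-size chunk and trailer lines. A sketch of an end-to-end call against an in-memory stream, assuming read() is the public entry point implied by the size-driven loop above (signature validation and exact trailer handling live outside this hunk):

    import io

    body = (
        b"5;chunk-signature=deadbeef\r\n"  # 0x5 payload bytes follow
        b"hello\r\n"
        b"0;chunk-signature=deadbeef\r\n"  # zero-size chunk ends the stream
        b"x-amz-trailer-example:value\r\n"
        b"\r\n"
    )
    assert AwsChunkedDecoder(io.BytesIO(body)).read() == b"hello"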
@@ -2969,10 +2983,11 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
     try:
         _storage().abort_multipart_upload(bucket_name, upload_id)
+    except BucketNotFoundError as exc:
+        return _error_response("NoSuchBucket", str(exc), 404)
     except StorageError as exc:
-        if "Bucket does not exist" in str(exc):
-            return _error_response("NoSuchBucket", str(exc), 404)
         current_app.logger.warning(f"Error aborting multipart upload: {exc}")
     return Response(status=204)
@@ -2984,13 +2999,15 @@ def resolve_principal():
        (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
         try:
             g.principal = _verify_sigv4(request)
             return
-        except Exception:
-            pass
+        except IamError as exc:
+            logger.debug(f"SigV4 authentication failed: {exc}")
+        except (ValueError, KeyError) as exc:
+            logger.debug(f"SigV4 parsing error: {exc}")
     access_key = request.headers.get("X-Access-Key")
     secret_key = request.headers.get("X-Secret-Key")
     if access_key and secret_key:
         try:
             g.principal = _iam().authenticate(access_key, secret_key)
-        except Exception:
-            pass
+        except IamError as exc:
+            logger.debug(f"Header authentication failed: {exc}")