Add 4 new S3 APIs: UploadPartCopy, Bucket Replication, PostObject, SelectObjectContent

This commit is contained in:
2026-01-29 12:51:00 +08:00
parent 0ea54457e8
commit 9385d1fe1c
5 changed files with 742 additions and 4 deletions

View File

@@ -962,6 +962,7 @@ def _maybe_handle_bucket_subresource(bucket_name: str) -> Response | None:
"logging": _bucket_logging_handler,
"uploads": _bucket_uploads_handler,
"policy": _bucket_policy_handler,
"replication": _bucket_replication_handler,
}
requested = [key for key in handlers if key in request.args]
if not requested:
@@ -2196,6 +2197,134 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
return _xml_response(result, status=200)
def _post_object(bucket_name: str) -> Response:
    """Handle a browser-based POST object upload (S3 PostObject API).

    Validates the required multipart form fields, the base64 policy document
    and its AWS SigV4 signature, then stores the uploaded file.  Honors the
    optional ``success_action_redirect`` / ``success_action_status`` fields.

    Returns an S3-style XML error response on any validation failure.
    """
    import json  # local import: the top-of-file import block is outside this view

    storage = _storage()
    if not storage.bucket_exists(bucket_name):
        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
    object_key = request.form.get("key")
    policy_b64 = request.form.get("policy")
    signature = request.form.get("x-amz-signature")
    credential = request.form.get("x-amz-credential")
    algorithm = request.form.get("x-amz-algorithm")
    amz_date = request.form.get("x-amz-date")
    if not all([object_key, policy_b64, signature, credential, algorithm, amz_date]):
        return _error_response("InvalidArgument", "Missing required form fields", 400)
    if algorithm != "AWS4-HMAC-SHA256":
        return _error_response("InvalidArgument", "Unsupported signing algorithm", 400)
    try:
        policy = json.loads(base64.b64decode(policy_b64).decode("utf-8"))
    except ValueError as exc:
        # Covers json.JSONDecodeError and binascii.Error (both ValueError subclasses)
        # as well as UnicodeDecodeError from .decode().
        return _error_response("InvalidPolicyDocument", f"Invalid policy: {exc}", 400)
    expiration = policy.get("expiration")
    if expiration:
        try:
            exp_time = datetime.fromisoformat(expiration.replace("Z", "+00:00"))
            if datetime.now(timezone.utc) > exp_time:
                return _error_response("AccessDenied", "Policy expired", 403)
        except ValueError:
            return _error_response("InvalidPolicyDocument", "Invalid expiration format", 400)
    conditions = policy.get("conditions", [])
    validation_error = _validate_post_policy_conditions(
        bucket_name, object_key, conditions, request.form, request.content_length or 0
    )
    if validation_error:
        return _error_response("AccessDenied", validation_error, 403)
    try:
        parts = credential.split("/")
        if len(parts) != 5:
            raise ValueError("Invalid credential format")
        access_key, date_stamp, region, service, _ = parts
    except ValueError:
        return _error_response("InvalidArgument", "Invalid credential format", 400)
    secret_key = _iam().get_secret_key(access_key)
    if not secret_key:
        return _error_response("AccessDenied", "Invalid access key", 403)
    # For POST uploads the string-to-sign is the raw base64 policy document.
    signing_key = _derive_signing_key(secret_key, date_stamp, region, service)
    expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
    if not hmac.compare_digest(expected_signature, signature):
        return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)
    file = request.files.get("file")
    if not file:
        return _error_response("InvalidArgument", "Missing file field", 400)
    # NOTE(review): AWS substitutes the "${filename}" variable here; the
    # "$(unknown)" token looks project-specific -- confirm against the clients
    # that POST to this API before changing it.
    if "$(unknown)" in object_key:
        object_key = object_key.replace("$(unknown)", file.filename or "upload")
    metadata = {}
    for field_name, value in request.form.items():
        if field_name.lower().startswith("x-amz-meta-"):
            key = field_name[len("x-amz-meta-"):]
            if key:
                metadata[key] = value
    try:
        meta = storage.put_object(bucket_name, object_key, file.stream, metadata=metadata or None)
    except QuotaExceededError as exc:
        return _error_response("QuotaExceeded", str(exc), 403)
    except StorageError as exc:
        return _error_response("InvalidArgument", str(exc), 400)
    current_app.logger.info(
        "Object uploaded via POST",
        extra={"bucket": bucket_name, "key": object_key, "size": meta.size},
    )
    success_action_status = request.form.get("success_action_status", "204")
    success_action_redirect = request.form.get("success_action_redirect")
    if success_action_redirect:
        redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
        return Response(status=303, headers={"Location": redirect_url})
    if success_action_status in ("200", "201"):
        # Identical XML payload for both statuses; only the HTTP code differs.
        root = Element("PostResponse")
        SubElement(root, "Location").text = f"/{bucket_name}/{object_key}"
        SubElement(root, "Bucket").text = bucket_name
        SubElement(root, "Key").text = object_key
        SubElement(root, "ETag").text = f'"{meta.etag}"'
        return _xml_response(root, status=int(success_action_status))
    return Response(status=204)
def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditions: list, form_data, content_length: int) -> Optional[str]:
for condition in conditions:
if isinstance(condition, dict):
for key, expected_value in condition.items():
if key == "bucket":
if bucket_name != expected_value:
return f"Bucket must be {expected_value}"
elif key == "key":
if object_key != expected_value:
return f"Key must be {expected_value}"
else:
actual_value = form_data.get(key, "")
if actual_value != expected_value:
return f"Field {key} must be {expected_value}"
elif isinstance(condition, list) and len(condition) >= 2:
operator = condition[0].lower() if isinstance(condition[0], str) else ""
if operator == "starts-with" and len(condition) == 3:
field = condition[1].lstrip("$")
prefix = condition[2]
if field == "key":
if not object_key.startswith(prefix):
return f"Key must start with {prefix}"
else:
actual_value = form_data.get(field, "")
if not actual_value.startswith(prefix):
return f"Field {field} must start with {prefix}"
elif operator == "eq" and len(condition) == 3:
field = condition[1].lstrip("$")
expected = condition[2]
if field == "key":
if object_key != expected:
return f"Key must equal {expected}"
else:
actual_value = form_data.get(field, "")
if actual_value != expected:
return f"Field {field} must equal {expected}"
elif operator == "content-length-range" and len(condition) == 3:
min_size, max_size = condition[1], condition[2]
if content_length < min_size or content_length > max_size:
return f"Content length must be between {min_size} and {max_size}"
return None
@s3_api_bp.get("/")
@limiter.limit(_get_list_buckets_limit)
def list_buckets() -> Response:
@@ -2233,9 +2362,12 @@ def bucket_handler(bucket_name: str) -> Response:
return subresource_response
if request.method == "POST":
if "delete" not in request.args:
return _method_not_allowed(["GET", "PUT", "DELETE"])
return _bulk_delete_handler(bucket_name)
if "delete" in request.args:
return _bulk_delete_handler(bucket_name)
content_type = request.headers.get("Content-Type", "")
if "multipart/form-data" in content_type:
return _post_object(bucket_name)
return _method_not_allowed(["GET", "PUT", "DELETE"])
if request.method == "PUT":
principal, error = _require_principal()
@@ -2433,6 +2565,8 @@ def object_handler(bucket_name: str, object_key: str):
return _initiate_multipart_upload(bucket_name, object_key)
if "uploadId" in request.args:
return _complete_multipart_upload(bucket_name, object_key)
if "select" in request.args:
return _select_object_content(bucket_name, object_key)
return _method_not_allowed(["GET", "PUT", "DELETE", "HEAD", "POST"])
if request.method == "PUT":
@@ -2732,6 +2866,120 @@ def _bucket_policy_handler(bucket_name: str) -> Response:
return Response(status=204)
def _bucket_replication_handler(bucket_name: str) -> Response:
    """Serve GET/PUT/DELETE for the bucket ``?replication`` subresource."""
    allowed = ["GET", "PUT", "DELETE"]
    if request.method not in set(allowed):
        return _method_not_allowed(allowed)
    principal, error = _require_principal()
    if error:
        return error
    try:
        # Replication management is gated behind the same permission as policy edits.
        _authorize_action(principal, bucket_name, "policy")
    except IamError as exc:
        return _error_response("AccessDenied", str(exc), 403)
    if not _storage().bucket_exists(bucket_name):
        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
    manager = _replication_manager()
    if request.method == "GET":
        existing = manager.get_rule(bucket_name)
        if not existing:
            return _error_response(
                "ReplicationConfigurationNotFoundError",
                "Replication configuration not found",
                404,
            )
        return _xml_response(_render_replication_config(existing))
    if request.method == "DELETE":
        manager.delete_rule(bucket_name)
        current_app.logger.info("Bucket replication removed", extra={"bucket": bucket_name})
        return Response(status=204)
    # PUT: parse and persist a new replication configuration.
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    body = request.get_data(cache=False) or b""
    try:
        new_rule = _parse_replication_config(bucket_name, body)
    except ValueError as exc:
        return _error_response("MalformedXML", str(exc), 400)
    manager.set_rule(new_rule)
    current_app.logger.info("Bucket replication updated", extra={"bucket": bucket_name})
    return Response(status=200)
def _parse_replication_config(bucket_name: str, payload: bytes):
    """Parse a ReplicationConfiguration XML payload into a ReplicationRule.

    Only the first <Rule> element is honored.  Raises ValueError for any
    malformed document so the caller can answer with a MalformedXML error.
    """
    from .replication import ReplicationRule, REPLICATION_MODE_ALL

    try:
        root = fromstring(payload)
    except ParseError as exc:
        # ParseError subclasses SyntaxError, not ValueError; without this
        # translation the caller's `except ValueError` would miss it and the
        # request would fail with an unhandled 500 instead of MalformedXML.
        raise ValueError(f"Unable to parse XML document: {exc}") from exc
    if _strip_ns(root.tag) != "ReplicationConfiguration":
        raise ValueError("Root element must be ReplicationConfiguration")
    rule_el = None
    for child in list(root):
        if _strip_ns(child.tag) == "Rule":
            rule_el = child
            break
    if rule_el is None:
        raise ValueError("At least one Rule is required")
    status_el = _find_element(rule_el, "Status")
    # A missing or empty Status defaults to "Enabled".
    status = status_el.text if status_el is not None and status_el.text else "Enabled"
    enabled = status.lower() == "enabled"
    filter_prefix = None
    filter_el = _find_element(rule_el, "Filter")
    if filter_el is not None:
        prefix_el = _find_element(filter_el, "Prefix")
        if prefix_el is not None and prefix_el.text:
            filter_prefix = prefix_el.text
    dest_el = _find_element(rule_el, "Destination")
    if dest_el is None:
        raise ValueError("Destination element is required")
    bucket_el = _find_element(dest_el, "Bucket")
    if bucket_el is None or not bucket_el.text:
        raise ValueError("Destination Bucket is required")
    target_bucket, target_connection_id = _parse_destination_arn(bucket_el.text)
    # Delete-marker replication defaults to enabled when the element is absent.
    sync_deletions = True
    dm_el = _find_element(rule_el, "DeleteMarkerReplication")
    if dm_el is not None:
        dm_status_el = _find_element(dm_el, "Status")
        if dm_status_el is not None and dm_status_el.text:
            sync_deletions = dm_status_el.text.lower() == "enabled"
    return ReplicationRule(
        bucket_name=bucket_name,
        target_connection_id=target_connection_id,
        target_bucket=target_bucket,
        enabled=enabled,
        mode=REPLICATION_MODE_ALL,
        sync_deletions=sync_deletions,
        filter_prefix=filter_prefix,
    )
def _parse_destination_arn(arn: str) -> tuple:
if not arn.startswith("arn:aws:s3:::"):
raise ValueError(f"Invalid ARN format: {arn}")
bucket_part = arn[13:]
if "/" in bucket_part:
connection_id, bucket_name = bucket_part.split("/", 1)
else:
connection_id = "local"
bucket_name = bucket_part
return bucket_name, connection_id
def _render_replication_config(rule) -> Element:
root = Element("ReplicationConfiguration")
SubElement(root, "Role").text = "arn:aws:iam::000000000000:role/replication"
rule_el = SubElement(root, "Rule")
SubElement(rule_el, "ID").text = f"{rule.bucket_name}-replication"
SubElement(rule_el, "Status").text = "Enabled" if rule.enabled else "Disabled"
SubElement(rule_el, "Priority").text = "1"
filter_el = SubElement(rule_el, "Filter")
if rule.filter_prefix:
SubElement(filter_el, "Prefix").text = rule.filter_prefix
dest_el = SubElement(rule_el, "Destination")
if rule.target_connection_id == "local":
arn = f"arn:aws:s3:::{rule.target_bucket}"
else:
arn = f"arn:aws:s3:::{rule.target_connection_id}/{rule.target_bucket}"
SubElement(dest_el, "Bucket").text = arn
dm_el = SubElement(rule_el, "DeleteMarkerReplication")
SubElement(dm_el, "Status").text = "Enabled" if rule.sync_deletions else "Disabled"
return root
@s3_api_bp.route("/<bucket_name>", methods=["HEAD"])
@limiter.limit(_get_head_ops_limit)
def head_bucket(bucket_name: str) -> Response:
@@ -3009,6 +3257,10 @@ def _initiate_multipart_upload(bucket_name: str, object_key: str) -> Response:
def _upload_part(bucket_name: str, object_key: str) -> Response:
copy_source = request.headers.get("x-amz-copy-source")
if copy_source:
return _upload_part_copy(bucket_name, object_key, copy_source)
principal, error = _object_principal("write", bucket_name, object_key)
if error:
return error
@@ -3042,6 +3294,62 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
return response
def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Response:
    """Handle UploadPartCopy: copy (a range of) an existing object into a part."""
    principal, error = _object_principal("write", bucket_name, object_key)
    if error:
        return error
    upload_id = request.args.get("uploadId")
    raw_part_number = request.args.get("partNumber")
    if not upload_id or not raw_part_number:
        return _error_response("InvalidArgument", "uploadId and partNumber are required", 400)
    try:
        part_number = int(raw_part_number)
    except ValueError:
        return _error_response("InvalidArgument", "partNumber must be an integer", 400)
    source = unquote(copy_source)
    if source.startswith("/"):
        source = source[1:]
    pieces = source.split("/", 1)
    if len(pieces) != 2:
        return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
    source_bucket, source_key = pieces
    # The caller must also be allowed to read the copy source.
    _, read_error = _object_principal("read", source_bucket, source_key)
    if read_error:
        return read_error
    range_header = request.headers.get("x-amz-copy-source-range")
    start_byte = end_byte = None
    if range_header:
        range_match = re.match(r"bytes=(\d+)-(\d+)", range_header)
        if range_match is None:
            return _error_response("InvalidArgument", "Invalid x-amz-copy-source-range format", 400)
        start_byte = int(range_match.group(1))
        end_byte = int(range_match.group(2))
    try:
        result = _storage().upload_part_copy(
            bucket_name,
            upload_id,
            part_number,
            source_bucket,
            source_key,
            start_byte,
            end_byte,
        )
    except ObjectNotFoundError:
        return _error_response("NoSuchKey", "Source object not found", 404)
    except StorageError as exc:
        message = str(exc)
        if "Multipart upload not found" in message:
            return _error_response("NoSuchUpload", message, 404)
        if "Invalid byte range" in message:
            return _error_response("InvalidRange", message, 416)
        return _error_response("InvalidArgument", message, 400)
    response_root = Element("CopyPartResult")
    response_root_lm = SubElement(response_root, "LastModified")
    response_root_lm.text = result["last_modified"].strftime("%Y-%m-%dT%H:%M:%S.000Z")
    SubElement(response_root, "ETag").text = f'"{result["etag"]}"'
    return _xml_response(response_root)
def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
principal, error = _object_principal("write", bucket_name, object_key)
if error:
@@ -3126,6 +3434,164 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
return Response(status=204)
def _select_object_content(bucket_name: str, object_key: str) -> Response:
    """Handle SelectObjectContent: run a SQL expression over one object.

    Parses the SelectObjectContentRequest XML body, validates the input and
    output serialization settings, delegates the query to
    execute_select_query, and streams the result back using AWS
    event-stream framing (Records events, then Stats, then End).
    """
    _, error = _object_principal("read", bucket_name, object_key)
    if error:
        return error
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    payload = request.get_data(cache=False) or b""
    try:
        root = fromstring(payload)
    except ParseError:
        return _error_response("MalformedXML", "Unable to parse XML document", 400)
    if _strip_ns(root.tag) != "SelectObjectContentRequest":
        return _error_response("MalformedXML", "Root element must be SelectObjectContentRequest", 400)
    expression_el = _find_element(root, "Expression")
    if expression_el is None or not expression_el.text:
        return _error_response("InvalidRequest", "Expression is required", 400)
    expression = expression_el.text
    expression_type_el = _find_element(root, "ExpressionType")
    # ExpressionType defaults to SQL when absent; anything else is rejected.
    expression_type = expression_type_el.text if expression_type_el is not None and expression_type_el.text else "SQL"
    if expression_type.upper() != "SQL":
        return _error_response("InvalidRequest", "Only SQL expression type is supported", 400)
    input_el = _find_element(root, "InputSerialization")
    if input_el is None:
        return _error_response("InvalidRequest", "InputSerialization is required", 400)
    try:
        input_format, input_config = _parse_select_input_serialization(input_el)
    except ValueError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    output_el = _find_element(root, "OutputSerialization")
    if output_el is None:
        return _error_response("InvalidRequest", "OutputSerialization is required", 400)
    try:
        output_format, output_config = _parse_select_output_serialization(output_el)
    except ValueError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    storage = _storage()
    try:
        path = storage.get_object_path(bucket_name, object_key)
    except ObjectNotFoundError:
        return _error_response("NoSuchKey", "Object not found", 404)
    except StorageError:
        # Any other storage failure is reported as a missing key as well.
        return _error_response("NoSuchKey", "Object not found", 404)
    from .select_content import execute_select_query, SelectError
    try:
        result_stream = execute_select_query(
            file_path=path,
            expression=expression,
            input_format=input_format,
            input_config=input_config,
            output_format=output_format,
            output_config=output_config,
        )
    except SelectError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    def generate_events():
        # Frame each result chunk as a Records event, then emit Stats and End.
        # NOTE(review): bytes_scanned is never updated, so Stats always
        # reports 0 bytes scanned/processed -- consider tracking it.
        bytes_scanned = 0
        bytes_returned = 0
        for chunk in result_stream:
            bytes_returned += len(chunk)
            yield _encode_select_event("Records", chunk)
        stats_payload = _build_stats_xml(bytes_scanned, bytes_returned)
        yield _encode_select_event("Stats", stats_payload)
        yield _encode_select_event("End", b"")
    return Response(generate_events(), mimetype="application/octet-stream", headers={"x-amz-request-charged": "requester"})
def _parse_select_input_serialization(el: Element) -> tuple:
    """Return ``(format_name, config)`` parsed from an InputSerialization element.

    Raises ValueError when none of CSV, JSON, or Parquet is specified.
    """
    csv_node = _find_element(el, "CSV")
    if csv_node is not None:
        header_node = _find_element(csv_node, "FileHeaderInfo")
        header_info = "NONE"
        if header_node is not None and header_node.text:
            header_info = header_node.text.upper()
        return "CSV", {
            "file_header_info": header_info,
            "comments": _find_element_text(csv_node, "Comments", "#"),
            "field_delimiter": _find_element_text(csv_node, "FieldDelimiter", ","),
            "record_delimiter": _find_element_text(csv_node, "RecordDelimiter", "\n"),
            "quote_character": _find_element_text(csv_node, "QuoteCharacter", '"'),
            "quote_escape_character": _find_element_text(csv_node, "QuoteEscapeCharacter", '"'),
        }
    json_node = _find_element(el, "JSON")
    if json_node is not None:
        type_node = _find_element(json_node, "Type")
        json_type = "DOCUMENT"
        if type_node is not None and type_node.text:
            json_type = type_node.text.upper()
        return "JSON", {"type": json_type}
    if _find_element(el, "Parquet") is not None:
        # Parquet carries no tunable serialization options here.
        return "Parquet", {}
    raise ValueError("InputSerialization must specify CSV, JSON, or Parquet")
def _parse_select_output_serialization(el: Element) -> tuple:
    """Return ``(format_name, config)`` parsed from an OutputSerialization element.

    Raises ValueError when neither CSV nor JSON is specified.
    """
    csv_node = _find_element(el, "CSV")
    if csv_node is not None:
        return "CSV", {
            "field_delimiter": _find_element_text(csv_node, "FieldDelimiter", ","),
            "record_delimiter": _find_element_text(csv_node, "RecordDelimiter", "\n"),
            "quote_character": _find_element_text(csv_node, "QuoteCharacter", '"'),
            "quote_fields": _find_element_text(csv_node, "QuoteFields", "ASNEEDED").upper(),
        }
    json_node = _find_element(el, "JSON")
    if json_node is not None:
        return "JSON", {
            "record_delimiter": _find_element_text(json_node, "RecordDelimiter", "\n"),
        }
    raise ValueError("OutputSerialization must specify CSV or JSON")
def _encode_select_event(event_type: str, payload: bytes) -> bytes:
    """Frame a payload as one AWS event-stream message.

    Layout: 4-byte total length, 4-byte headers length, 4-byte prelude CRC,
    headers, payload, 4-byte message CRC (all big-endian, CRC-32).
    """
    import binascii
    import struct

    header_block = _build_event_headers(event_type)
    # total = prelude (8) + prelude CRC (4) + headers + payload + message CRC (4)
    total_length = 16 + len(header_block) + len(payload)
    prelude = struct.pack(">II", total_length, len(header_block))
    prelude += struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
    body = prelude + header_block + payload
    return body + struct.pack(">I", binascii.crc32(body) & 0xFFFFFFFF)
def _build_event_headers(event_type: str) -> bytes:
    """Build the event-stream header block for one event type."""
    # Only Records and Stats events carry a content-type header.
    content_types = {"Records": "application/octet-stream", "Stats": "text/xml"}
    parts = [_encode_select_header(":event-type", event_type)]
    content_type = content_types.get(event_type)
    if content_type is not None:
        parts.append(_encode_select_header(":content-type", content_type))
    parts.append(_encode_select_header(":message-type", "event"))
    return b"".join(parts)
def _encode_select_header(name: str, value: str) -> bytes:
import struct
name_bytes = name.encode("utf-8")
value_bytes = value.encode("utf-8")
header = struct.pack("B", len(name_bytes)) + name_bytes
header += struct.pack("B", 7)
header += struct.pack(">H", len(value_bytes)) + value_bytes
return header
def _build_stats_xml(bytes_scanned: int, bytes_returned: int) -> bytes:
stats = Element("Stats")
SubElement(stats, "BytesScanned").text = str(bytes_scanned)
SubElement(stats, "BytesProcessed").text = str(bytes_scanned)
SubElement(stats, "BytesReturned").text = str(bytes_returned)
return tostring(stats, encoding="utf-8")
@s3_api_bp.before_request
def resolve_principal():
g.principal = None