Add 4 new S3 APIs: UploadPartCopy, Bucket Replication, PostObject, SelectObjectContent
This commit is contained in:
@@ -137,6 +137,7 @@ class ReplicationRule:
|
||||
stats: ReplicationStats = field(default_factory=ReplicationStats)
|
||||
sync_deletions: bool = True
|
||||
last_pull_at: Optional[float] = None
|
||||
filter_prefix: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
@@ -149,6 +150,7 @@ class ReplicationRule:
|
||||
"stats": self.stats.to_dict(),
|
||||
"sync_deletions": self.sync_deletions,
|
||||
"last_pull_at": self.last_pull_at,
|
||||
"filter_prefix": self.filter_prefix,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -162,6 +164,8 @@ class ReplicationRule:
|
||||
data["sync_deletions"] = True
|
||||
if "last_pull_at" not in data:
|
||||
data["last_pull_at"] = None
|
||||
if "filter_prefix" not in data:
|
||||
data["filter_prefix"] = None
|
||||
rule = cls(**data)
|
||||
rule.stats = ReplicationStats.from_dict(stats_data) if stats_data else ReplicationStats()
|
||||
return rule
|
||||
|
||||
472
app/s3_api.py
472
app/s3_api.py
@@ -962,6 +962,7 @@ def _maybe_handle_bucket_subresource(bucket_name: str) -> Response | None:
|
||||
"logging": _bucket_logging_handler,
|
||||
"uploads": _bucket_uploads_handler,
|
||||
"policy": _bucket_policy_handler,
|
||||
"replication": _bucket_replication_handler,
|
||||
}
|
||||
requested = [key for key in handlers if key in request.args]
|
||||
if not requested:
|
||||
@@ -2196,6 +2197,134 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
|
||||
return _xml_response(result, status=200)
|
||||
|
||||
|
||||
def _post_object(bucket_name: str) -> Response:
    """Handle a browser-based POST upload (S3 PostObject API).

    Validates the required form fields, the base64-encoded policy document
    (expiration and conditions), and the AWS SigV4 signature before storing
    the uploaded file. Honours the success_action_redirect and
    success_action_status response controls.
    """
    import json

    storage = _storage()
    if not storage.bucket_exists(bucket_name):
        return _error_response("NoSuchBucket", "Bucket does not exist", 404)

    object_key = request.form.get("key")
    policy_b64 = request.form.get("policy")
    signature = request.form.get("x-amz-signature")
    credential = request.form.get("x-amz-credential")
    algorithm = request.form.get("x-amz-algorithm")
    amz_date = request.form.get("x-amz-date")
    if not all([object_key, policy_b64, signature, credential, algorithm, amz_date]):
        return _error_response("InvalidArgument", "Missing required form fields", 400)
    if algorithm != "AWS4-HMAC-SHA256":
        return _error_response("InvalidArgument", "Unsupported signing algorithm", 400)

    # Decode the signed policy document. binascii.Error, UnicodeDecodeError and
    # json.JSONDecodeError all subclass ValueError, so one handler suffices.
    try:
        policy = json.loads(base64.b64decode(policy_b64).decode("utf-8"))
    except ValueError as exc:
        return _error_response("InvalidPolicyDocument", f"Invalid policy: {exc}", 400)

    expiration = policy.get("expiration")
    if expiration:
        try:
            exp_time = datetime.fromisoformat(expiration.replace("Z", "+00:00"))
            if datetime.now(timezone.utc) > exp_time:
                return _error_response("AccessDenied", "Policy expired", 403)
        except ValueError:
            return _error_response("InvalidPolicyDocument", "Invalid expiration format", 400)

    # NOTE(review): conditions are evaluated against the key *before*
    # ${filename} substitution; AWS substitutes first — confirm intended.
    conditions = policy.get("conditions", [])
    validation_error = _validate_post_policy_conditions(
        bucket_name, object_key, conditions, request.form, request.content_length or 0
    )
    if validation_error:
        return _error_response("AccessDenied", validation_error, 403)

    # Credential scope format: <access-key>/<date>/<region>/<service>/aws4_request
    try:
        parts = credential.split("/")
        if len(parts) != 5:
            raise ValueError("Invalid credential format")
        access_key, date_stamp, region, service, _ = parts
    except ValueError:
        return _error_response("InvalidArgument", "Invalid credential format", 400)

    secret_key = _iam().get_secret_key(access_key)
    if not secret_key:
        return _error_response("AccessDenied", "Invalid access key", 403)
    # For POST uploads the SigV4 string-to-sign is the base64 policy itself.
    signing_key = _derive_signing_key(secret_key, date_stamp, region, service)
    expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
    if not hmac.compare_digest(expected_signature, signature):
        return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)

    file = request.files.get("file")
    if not file:
        return _error_response("InvalidArgument", "Missing file field", 400)
    # S3 substitutes the documented ${filename} placeholder in the key with
    # the uploaded file's name.
    if "${filename}" in object_key:
        object_key = object_key.replace("${filename}", file.filename or "upload")

    # Collect user metadata from x-amz-meta-* form fields.
    meta_prefix = "x-amz-meta-"
    metadata = {}
    for field_name, value in request.form.items():
        if field_name.lower().startswith(meta_prefix):
            meta_key = field_name[len(meta_prefix):]
            if meta_key:
                metadata[meta_key] = value

    try:
        meta = storage.put_object(bucket_name, object_key, file.stream, metadata=metadata or None)
    except QuotaExceededError as exc:
        return _error_response("QuotaExceeded", str(exc), 403)
    except StorageError as exc:
        return _error_response("InvalidArgument", str(exc), 400)
    current_app.logger.info(
        "Object uploaded via POST",
        extra={"bucket": bucket_name, "key": object_key, "size": meta.size},
    )

    success_action_status = request.form.get("success_action_status", "204")
    success_action_redirect = request.form.get("success_action_redirect")
    if success_action_redirect:
        redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
        return Response(status=303, headers={"Location": redirect_url})
    # 200 and 201 share the same XML body; only the status code differs.
    if success_action_status in ("200", "201"):
        root = Element("PostResponse")
        SubElement(root, "Location").text = f"/{bucket_name}/{object_key}"
        SubElement(root, "Bucket").text = bucket_name
        SubElement(root, "Key").text = object_key
        SubElement(root, "ETag").text = f'"{meta.etag}"'
        return _xml_response(root, status=int(success_action_status))
    return Response(status=204)
|
||||
|
||||
|
||||
def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditions: list, form_data, content_length: int) -> Optional[str]:
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict):
|
||||
for key, expected_value in condition.items():
|
||||
if key == "bucket":
|
||||
if bucket_name != expected_value:
|
||||
return f"Bucket must be {expected_value}"
|
||||
elif key == "key":
|
||||
if object_key != expected_value:
|
||||
return f"Key must be {expected_value}"
|
||||
else:
|
||||
actual_value = form_data.get(key, "")
|
||||
if actual_value != expected_value:
|
||||
return f"Field {key} must be {expected_value}"
|
||||
elif isinstance(condition, list) and len(condition) >= 2:
|
||||
operator = condition[0].lower() if isinstance(condition[0], str) else ""
|
||||
if operator == "starts-with" and len(condition) == 3:
|
||||
field = condition[1].lstrip("$")
|
||||
prefix = condition[2]
|
||||
if field == "key":
|
||||
if not object_key.startswith(prefix):
|
||||
return f"Key must start with {prefix}"
|
||||
else:
|
||||
actual_value = form_data.get(field, "")
|
||||
if not actual_value.startswith(prefix):
|
||||
return f"Field {field} must start with {prefix}"
|
||||
elif operator == "eq" and len(condition) == 3:
|
||||
field = condition[1].lstrip("$")
|
||||
expected = condition[2]
|
||||
if field == "key":
|
||||
if object_key != expected:
|
||||
return f"Key must equal {expected}"
|
||||
else:
|
||||
actual_value = form_data.get(field, "")
|
||||
if actual_value != expected:
|
||||
return f"Field {field} must equal {expected}"
|
||||
elif operator == "content-length-range" and len(condition) == 3:
|
||||
min_size, max_size = condition[1], condition[2]
|
||||
if content_length < min_size or content_length > max_size:
|
||||
return f"Content length must be between {min_size} and {max_size}"
|
||||
return None
|
||||
|
||||
|
||||
@s3_api_bp.get("/")
|
||||
@limiter.limit(_get_list_buckets_limit)
|
||||
def list_buckets() -> Response:
|
||||
@@ -2233,9 +2362,12 @@ def bucket_handler(bucket_name: str) -> Response:
|
||||
return subresource_response
|
||||
|
||||
if request.method == "POST":
|
||||
if "delete" not in request.args:
|
||||
return _method_not_allowed(["GET", "PUT", "DELETE"])
|
||||
return _bulk_delete_handler(bucket_name)
|
||||
if "delete" in request.args:
|
||||
return _bulk_delete_handler(bucket_name)
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "multipart/form-data" in content_type:
|
||||
return _post_object(bucket_name)
|
||||
return _method_not_allowed(["GET", "PUT", "DELETE"])
|
||||
|
||||
if request.method == "PUT":
|
||||
principal, error = _require_principal()
|
||||
@@ -2433,6 +2565,8 @@ def object_handler(bucket_name: str, object_key: str):
|
||||
return _initiate_multipart_upload(bucket_name, object_key)
|
||||
if "uploadId" in request.args:
|
||||
return _complete_multipart_upload(bucket_name, object_key)
|
||||
if "select" in request.args:
|
||||
return _select_object_content(bucket_name, object_key)
|
||||
return _method_not_allowed(["GET", "PUT", "DELETE", "HEAD", "POST"])
|
||||
|
||||
if request.method == "PUT":
|
||||
@@ -2732,6 +2866,120 @@ def _bucket_policy_handler(bucket_name: str) -> Response:
|
||||
return Response(status=204)
|
||||
|
||||
|
||||
def _bucket_replication_handler(bucket_name: str) -> Response:
    """Dispatch GET/PUT/DELETE for the ?replication bucket subresource.

    GET returns the current configuration (404 if none), DELETE removes it,
    and PUT parses and installs a new rule from the XML request body.
    """
    if request.method not in {"GET", "PUT", "DELETE"}:
        return _method_not_allowed(["GET", "PUT", "DELETE"])
    principal, error = _require_principal()
    if error:
        return error
    try:
        # Replication management piggybacks on the bucket "policy" permission.
        _authorize_action(principal, bucket_name, "policy")
    except IamError as exc:
        return _error_response("AccessDenied", str(exc), 403)
    storage = _storage()
    if not storage.bucket_exists(bucket_name):
        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
    replication = _replication_manager()
    if request.method == "GET":
        rule = replication.get_rule(bucket_name)
        if not rule:
            return _error_response("ReplicationConfigurationNotFoundError", "Replication configuration not found", 404)
        return _xml_response(_render_replication_config(rule))
    if request.method == "DELETE":
        replication.delete_rule(bucket_name)
        current_app.logger.info("Bucket replication removed", extra={"bucket": bucket_name})
        return Response(status=204)
    # Remaining method is PUT: replace the bucket's replication rule.
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    payload = request.get_data(cache=False) or b""
    try:
        rule = _parse_replication_config(bucket_name, payload)
    except ValueError as exc:
        return _error_response("MalformedXML", str(exc), 400)
    replication.set_rule(rule)
    current_app.logger.info("Bucket replication updated", extra={"bucket": bucket_name})
    return Response(status=200)
|
||||
|
||||
|
||||
def _parse_replication_config(bucket_name: str, payload: bytes):
    """Parse a PutBucketReplication XML body into a ReplicationRule.

    Raises ValueError for any malformed or incomplete document so that the
    caller can map it to a MalformedXML 400 response.
    """
    from .replication import ReplicationRule, REPLICATION_MODE_ALL

    try:
        root = fromstring(payload)
    except ParseError as exc:
        # ParseError subclasses SyntaxError, not ValueError; without this
        # translation the caller's `except ValueError` would miss malformed
        # XML and the request would fail with a 500 instead of MalformedXML.
        raise ValueError("Unable to parse XML document") from exc
    if _strip_ns(root.tag) != "ReplicationConfiguration":
        raise ValueError("Root element must be ReplicationConfiguration")

    # Only the first Rule element is honoured.
    rule_el = next((child for child in root if _strip_ns(child.tag) == "Rule"), None)
    if rule_el is None:
        raise ValueError("At least one Rule is required")

    status_el = _find_element(rule_el, "Status")
    status = status_el.text if status_el is not None and status_el.text else "Enabled"
    enabled = status.lower() == "enabled"

    # Optional Filter/Prefix limits replication to keys under the prefix.
    filter_prefix = None
    filter_el = _find_element(rule_el, "Filter")
    if filter_el is not None:
        prefix_el = _find_element(filter_el, "Prefix")
        if prefix_el is not None and prefix_el.text:
            filter_prefix = prefix_el.text

    dest_el = _find_element(rule_el, "Destination")
    if dest_el is None:
        raise ValueError("Destination element is required")
    bucket_el = _find_element(dest_el, "Bucket")
    if bucket_el is None or not bucket_el.text:
        raise ValueError("Destination Bucket is required")
    target_bucket, target_connection_id = _parse_destination_arn(bucket_el.text)

    # DeleteMarkerReplication controls whether deletions propagate (default on).
    sync_deletions = True
    dm_el = _find_element(rule_el, "DeleteMarkerReplication")
    if dm_el is not None:
        dm_status_el = _find_element(dm_el, "Status")
        if dm_status_el is not None and dm_status_el.text:
            sync_deletions = dm_status_el.text.lower() == "enabled"

    return ReplicationRule(
        bucket_name=bucket_name,
        target_connection_id=target_connection_id,
        target_bucket=target_bucket,
        enabled=enabled,
        mode=REPLICATION_MODE_ALL,
        sync_deletions=sync_deletions,
        filter_prefix=filter_prefix,
    )
|
||||
|
||||
|
||||
def _parse_destination_arn(arn: str) -> tuple:
|
||||
if not arn.startswith("arn:aws:s3:::"):
|
||||
raise ValueError(f"Invalid ARN format: {arn}")
|
||||
bucket_part = arn[13:]
|
||||
if "/" in bucket_part:
|
||||
connection_id, bucket_name = bucket_part.split("/", 1)
|
||||
else:
|
||||
connection_id = "local"
|
||||
bucket_name = bucket_part
|
||||
return bucket_name, connection_id
|
||||
|
||||
|
||||
def _render_replication_config(rule) -> Element:
|
||||
root = Element("ReplicationConfiguration")
|
||||
SubElement(root, "Role").text = "arn:aws:iam::000000000000:role/replication"
|
||||
rule_el = SubElement(root, "Rule")
|
||||
SubElement(rule_el, "ID").text = f"{rule.bucket_name}-replication"
|
||||
SubElement(rule_el, "Status").text = "Enabled" if rule.enabled else "Disabled"
|
||||
SubElement(rule_el, "Priority").text = "1"
|
||||
filter_el = SubElement(rule_el, "Filter")
|
||||
if rule.filter_prefix:
|
||||
SubElement(filter_el, "Prefix").text = rule.filter_prefix
|
||||
dest_el = SubElement(rule_el, "Destination")
|
||||
if rule.target_connection_id == "local":
|
||||
arn = f"arn:aws:s3:::{rule.target_bucket}"
|
||||
else:
|
||||
arn = f"arn:aws:s3:::{rule.target_connection_id}/{rule.target_bucket}"
|
||||
SubElement(dest_el, "Bucket").text = arn
|
||||
dm_el = SubElement(rule_el, "DeleteMarkerReplication")
|
||||
SubElement(dm_el, "Status").text = "Enabled" if rule.sync_deletions else "Disabled"
|
||||
return root
|
||||
|
||||
|
||||
@s3_api_bp.route("/<bucket_name>", methods=["HEAD"])
|
||||
@limiter.limit(_get_head_ops_limit)
|
||||
def head_bucket(bucket_name: str) -> Response:
|
||||
@@ -3009,6 +3257,10 @@ def _initiate_multipart_upload(bucket_name: str, object_key: str) -> Response:
|
||||
|
||||
|
||||
def _upload_part(bucket_name: str, object_key: str) -> Response:
|
||||
copy_source = request.headers.get("x-amz-copy-source")
|
||||
if copy_source:
|
||||
return _upload_part_copy(bucket_name, object_key, copy_source)
|
||||
|
||||
principal, error = _object_principal("write", bucket_name, object_key)
|
||||
if error:
|
||||
return error
|
||||
@@ -3042,6 +3294,62 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
|
||||
return response
|
||||
|
||||
|
||||
def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Response:
    """UploadPartCopy: copy (a range of) an existing object into a multipart part.

    Requires ``uploadId`` and ``partNumber`` query args, write access to the
    destination, and read access to the source object named by
    x-amz-copy-source. Returns a CopyPartResult XML document.
    """
    principal, error = _object_principal("write", bucket_name, object_key)
    if error:
        return error

    upload_id = request.args.get("uploadId")
    part_number_str = request.args.get("partNumber")
    if not upload_id or not part_number_str:
        return _error_response("InvalidArgument", "uploadId and partNumber are required", 400)

    try:
        part_number = int(part_number_str)
    except ValueError:
        return _error_response("InvalidArgument", "partNumber must be an integer", 400)

    # x-amz-copy-source is "/bucket/key" (URL-encoded, leading slash optional).
    copy_source = unquote(copy_source)
    if copy_source.startswith("/"):
        copy_source = copy_source[1:]
    parts = copy_source.split("/", 1)
    if len(parts) != 2:
        return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
    source_bucket, source_key = parts

    # Caller must also be able to read the source object.
    _, read_error = _object_principal("read", source_bucket, source_key)
    if read_error:
        return read_error

    copy_source_range = request.headers.get("x-amz-copy-source-range")
    start_byte, end_byte = None, None
    if copy_source_range:
        # fullmatch (rather than match) rejects trailing garbage such as
        # "bytes=0-5junk", which a prefix match would silently accept.
        match = re.fullmatch(r"bytes=(\d+)-(\d+)", copy_source_range)
        if not match:
            return _error_response("InvalidArgument", "Invalid x-amz-copy-source-range format", 400)
        start_byte, end_byte = int(match.group(1)), int(match.group(2))

    try:
        result = _storage().upload_part_copy(
            bucket_name, upload_id, part_number,
            source_bucket, source_key,
            start_byte, end_byte
        )
    except ObjectNotFoundError:
        return _error_response("NoSuchKey", "Source object not found", 404)
    except StorageError as exc:
        # Map storage-level failures onto the matching S3 error codes.
        if "Multipart upload not found" in str(exc):
            return _error_response("NoSuchUpload", str(exc), 404)
        if "Invalid byte range" in str(exc):
            return _error_response("InvalidRange", str(exc), 416)
        return _error_response("InvalidArgument", str(exc), 400)

    root = Element("CopyPartResult")
    SubElement(root, "LastModified").text = result["last_modified"].strftime("%Y-%m-%dT%H:%M:%S.000Z")
    SubElement(root, "ETag").text = f'"{result["etag"]}"'
    return _xml_response(root)
|
||||
|
||||
|
||||
def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
|
||||
principal, error = _object_principal("write", bucket_name, object_key)
|
||||
if error:
|
||||
@@ -3126,6 +3434,164 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
|
||||
return Response(status=204)
|
||||
|
||||
|
||||
def _select_object_content(bucket_name: str, object_key: str) -> Response:
    """Handle SelectObjectContent: run a SQL expression over one object.

    Parses the SelectObjectContentRequest XML body, executes the query via
    app.select_content, and streams the results back using the AWS
    event-stream framing (Records, Stats, End events).
    """
    _, error = _object_principal("read", bucket_name, object_key)
    if error:
        return error
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    payload = request.get_data(cache=False) or b""
    try:
        root = fromstring(payload)
    except ParseError:
        return _error_response("MalformedXML", "Unable to parse XML document", 400)
    if _strip_ns(root.tag) != "SelectObjectContentRequest":
        return _error_response("MalformedXML", "Root element must be SelectObjectContentRequest", 400)
    expression_el = _find_element(root, "Expression")
    if expression_el is None or not expression_el.text:
        return _error_response("InvalidRequest", "Expression is required", 400)
    expression = expression_el.text
    # ExpressionType defaults to SQL, which is also the only supported type.
    expression_type_el = _find_element(root, "ExpressionType")
    expression_type = expression_type_el.text if expression_type_el is not None and expression_type_el.text else "SQL"
    if expression_type.upper() != "SQL":
        return _error_response("InvalidRequest", "Only SQL expression type is supported", 400)
    input_el = _find_element(root, "InputSerialization")
    if input_el is None:
        return _error_response("InvalidRequest", "InputSerialization is required", 400)
    try:
        input_format, input_config = _parse_select_input_serialization(input_el)
    except ValueError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    output_el = _find_element(root, "OutputSerialization")
    if output_el is None:
        return _error_response("InvalidRequest", "OutputSerialization is required", 400)
    try:
        output_format, output_config = _parse_select_output_serialization(output_el)
    except ValueError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    storage = _storage()
    try:
        path = storage.get_object_path(bucket_name, object_key)
    except ObjectNotFoundError:
        return _error_response("NoSuchKey", "Object not found", 404)
    except StorageError:
        return _error_response("NoSuchKey", "Object not found", 404)
    # Imported lazily so the optional DuckDB dependency is only required here.
    from .select_content import execute_select_query, SelectError
    try:
        result_stream = execute_select_query(
            file_path=path,
            expression=expression,
            input_format=input_format,
            input_config=input_config,
            output_format=output_format,
            output_config=output_config,
        )
    except SelectError as exc:
        return _error_response("InvalidRequest", str(exc), 400)

    def generate_events():
        # Stream Records events, then a Stats summary, then the End marker.
        # bytes_scanned is not tracked yet and is always reported as 0.
        bytes_scanned = 0
        bytes_returned = 0
        for chunk in result_stream:
            bytes_returned += len(chunk)
            yield _encode_select_event("Records", chunk)
        stats_payload = _build_stats_xml(bytes_scanned, bytes_returned)
        yield _encode_select_event("Stats", stats_payload)
        yield _encode_select_event("End", b"")

    return Response(generate_events(), mimetype="application/octet-stream", headers={"x-amz-request-charged": "requester"})
|
||||
|
||||
|
||||
def _parse_select_input_serialization(el: Element) -> tuple:
    """Map an InputSerialization XML element to a (format_name, config) pair."""
    csv_node = _find_element(el, "CSV")
    if csv_node is not None:
        header_node = _find_element(csv_node, "FileHeaderInfo")
        header_info = header_node.text.upper() if header_node is not None and header_node.text else "NONE"
        return "CSV", {
            "file_header_info": header_info,
            "comments": _find_element_text(csv_node, "Comments", "#"),
            "field_delimiter": _find_element_text(csv_node, "FieldDelimiter", ","),
            "record_delimiter": _find_element_text(csv_node, "RecordDelimiter", "\n"),
            "quote_character": _find_element_text(csv_node, "QuoteCharacter", '"'),
            "quote_escape_character": _find_element_text(csv_node, "QuoteEscapeCharacter", '"'),
        }
    json_node = _find_element(el, "JSON")
    if json_node is not None:
        type_node = _find_element(json_node, "Type")
        json_type = type_node.text.upper() if type_node is not None and type_node.text else "DOCUMENT"
        return "JSON", {"type": json_type}
    if _find_element(el, "Parquet") is not None:
        return "Parquet", {}
    raise ValueError("InputSerialization must specify CSV, JSON, or Parquet")
|
||||
|
||||
|
||||
def _parse_select_output_serialization(el: Element) -> tuple:
    """Map an OutputSerialization XML element to a (format_name, config) pair."""
    csv_node = _find_element(el, "CSV")
    if csv_node is not None:
        return "CSV", {
            "field_delimiter": _find_element_text(csv_node, "FieldDelimiter", ","),
            "record_delimiter": _find_element_text(csv_node, "RecordDelimiter", "\n"),
            "quote_character": _find_element_text(csv_node, "QuoteCharacter", '"'),
            "quote_fields": _find_element_text(csv_node, "QuoteFields", "ASNEEDED").upper(),
        }
    json_node = _find_element(el, "JSON")
    if json_node is not None:
        return "JSON", {
            "record_delimiter": _find_element_text(json_node, "RecordDelimiter", "\n"),
        }
    raise ValueError("OutputSerialization must specify CSV or JSON")
|
||||
|
||||
|
||||
def _encode_select_event(event_type: str, payload: bytes) -> bytes:
    """Frame a payload as one AWS event-stream message.

    Layout: 4-byte total length, 4-byte headers length, 4-byte prelude CRC,
    headers, payload, 4-byte message CRC (all big-endian, CRC32).
    """
    import binascii
    import struct

    headers = _build_event_headers(event_type)
    header_len = len(headers)
    # total = prelude (8) + prelude CRC (4) + headers + payload + message CRC (4)
    total_len = 4 + 4 + 4 + header_len + len(payload) + 4
    prelude = struct.pack(">II", total_len, header_len)
    prelude += struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)
    body = prelude + headers + payload
    return body + struct.pack(">I", binascii.crc32(body) & 0xFFFFFFFF)
|
||||
|
||||
|
||||
def _build_event_headers(event_type: str) -> bytes:
    """Encode the event-stream headers for one event of the given type."""
    content_types = {
        "Records": "application/octet-stream",
        "Stats": "text/xml",
    }
    out = _encode_select_header(":event-type", event_type)
    content_type = content_types.get(event_type)
    if content_type is not None:
        out += _encode_select_header(":content-type", content_type)
    out += _encode_select_header(":message-type", "event")
    return out
|
||||
|
||||
|
||||
def _encode_select_header(name: str, value: str) -> bytes:
|
||||
import struct
|
||||
name_bytes = name.encode("utf-8")
|
||||
value_bytes = value.encode("utf-8")
|
||||
header = struct.pack("B", len(name_bytes)) + name_bytes
|
||||
header += struct.pack("B", 7)
|
||||
header += struct.pack(">H", len(value_bytes)) + value_bytes
|
||||
return header
|
||||
|
||||
|
||||
def _build_stats_xml(bytes_scanned: int, bytes_returned: int) -> bytes:
|
||||
stats = Element("Stats")
|
||||
SubElement(stats, "BytesScanned").text = str(bytes_scanned)
|
||||
SubElement(stats, "BytesProcessed").text = str(bytes_scanned)
|
||||
SubElement(stats, "BytesReturned").text = str(bytes_returned)
|
||||
return tostring(stats, encoding="utf-8")
|
||||
|
||||
|
||||
@s3_api_bp.before_request
|
||||
def resolve_principal():
|
||||
g.principal = None
|
||||
|
||||
171
app/select_content.py
Normal file
171
app/select_content.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""S3 SelectObjectContent SQL query execution using DuckDB."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
try:
|
||||
import duckdb
|
||||
DUCKDB_AVAILABLE = True
|
||||
except ImportError:
|
||||
DUCKDB_AVAILABLE = False
|
||||
|
||||
|
||||
class SelectError(Exception):
    """Error during SELECT query execution.

    Raised when DuckDB is unavailable, the input/output format is
    unsupported, or the SQL expression fails to execute.
    """
    pass
|
||||
|
||||
|
||||
def execute_select_query(
    file_path: Path,
    expression: str,
    input_format: str,
    input_config: Dict[str, Any],
    output_format: str,
    output_config: Dict[str, Any],
    chunk_size: int = 65536,
) -> Generator[bytes, None, None]:
    """Execute SQL query on object content.

    Loads the object into an in-memory DuckDB table named ``data``, runs the
    normalized expression, and yields serialized result chunks.

    Raises SelectError when DuckDB is missing, a format is unsupported, or
    the SQL fails to execute.
    """
    if not DUCKDB_AVAILABLE:
        raise SelectError("DuckDB is not installed. Install with: pip install duckdb")

    conn = duckdb.connect(":memory:")

    try:
        if input_format == "CSV":
            _load_csv(conn, file_path, input_config)
        elif input_format == "JSON":
            _load_json(conn, file_path, input_config)
        elif input_format == "Parquet":
            _load_parquet(conn, file_path)
        else:
            raise SelectError(f"Unsupported input format: {input_format}")

        # S3 Select queries reference the table as "S3Object"; map it to the
        # DuckDB table name. NOTE(review): plain substring replace also
        # rewrites occurrences inside string literals — confirm acceptable.
        normalized_expression = expression.replace("s3object", "data").replace("S3Object", "data")

        try:
            result = conn.execute(normalized_expression)
        except duckdb.Error as exc:
            # Chain the cause so the original DuckDB error is preserved.
            raise SelectError(f"SQL execution error: {exc}") from exc

        if output_format == "CSV":
            yield from _output_csv(result, output_config, chunk_size)
        elif output_format == "JSON":
            yield from _output_json(result, output_config, chunk_size)
        else:
            raise SelectError(f"Unsupported output format: {output_format}")

    finally:
        conn.close()
|
||||
|
||||
|
||||
def _load_csv(conn, file_path: Path, config: Dict[str, Any]) -> None:
    """Load CSV file into DuckDB.

    The delimiter and quote options originate from the client request, so
    single quotes are escaped before being embedded in the SQL string
    literals to prevent SQL injection into the read_csv call.
    """
    file_header_info = config.get("file_header_info", "NONE")
    delimiter = config.get("field_delimiter", ",").replace("'", "''")
    quote = config.get("quote_character", '"').replace("'", "''")

    # USE and IGNORE both mean the first row is a header line.
    header = file_header_info in ("USE", "IGNORE")
    path_str = str(file_path).replace("\\", "/").replace("'", "''")

    conn.execute(f"""
        CREATE TABLE data AS
        SELECT * FROM read_csv('{path_str}',
            header={header},
            delim='{delimiter}',
            quote='{quote}'
        )
    """)
|
||||
|
||||
|
||||
def _load_json(conn, file_path: Path, config: Dict[str, Any]) -> None:
    """Load JSON file into DuckDB.

    Type LINES selects newline-delimited JSON; anything else is read as a
    JSON array/document.
    """
    json_type = config.get("type", "DOCUMENT")
    # Escape single quotes so the path cannot break out of the SQL literal.
    path_str = str(file_path).replace("\\", "/").replace("'", "''")

    json_format = "newline_delimited" if json_type == "LINES" else "array"
    conn.execute(f"""
        CREATE TABLE data AS
        SELECT * FROM read_json_auto('{path_str}', format='{json_format}')
    """)
||||
|
||||
|
||||
def _load_parquet(conn, file_path: Path) -> None:
    """Load Parquet file into DuckDB."""
    # Escape single quotes so the path cannot break out of the SQL literal.
    path_str = str(file_path).replace("\\", "/").replace("'", "''")
    conn.execute(f"CREATE TABLE data AS SELECT * FROM read_parquet('{path_str}')")
|
||||
|
||||
|
||||
def _output_csv(
|
||||
result,
|
||||
config: Dict[str, Any],
|
||||
chunk_size: int,
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""Output query results as CSV."""
|
||||
delimiter = config.get("field_delimiter", ",")
|
||||
record_delimiter = config.get("record_delimiter", "\n")
|
||||
quote = config.get("quote_character", '"')
|
||||
|
||||
buffer = ""
|
||||
|
||||
while True:
|
||||
rows = result.fetchmany(1000)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
fields = []
|
||||
for value in row:
|
||||
if value is None:
|
||||
fields.append("")
|
||||
elif isinstance(value, str):
|
||||
if delimiter in value or quote in value or record_delimiter in value:
|
||||
escaped = value.replace(quote, quote + quote)
|
||||
fields.append(f'{quote}{escaped}{quote}')
|
||||
else:
|
||||
fields.append(value)
|
||||
else:
|
||||
fields.append(str(value))
|
||||
|
||||
buffer += delimiter.join(fields) + record_delimiter
|
||||
|
||||
while len(buffer) >= chunk_size:
|
||||
yield buffer[:chunk_size].encode("utf-8")
|
||||
buffer = buffer[chunk_size:]
|
||||
|
||||
if buffer:
|
||||
yield buffer.encode("utf-8")
|
||||
|
||||
|
||||
def _output_json(
|
||||
result,
|
||||
config: Dict[str, Any],
|
||||
chunk_size: int,
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""Output query results as JSON Lines."""
|
||||
record_delimiter = config.get("record_delimiter", "\n")
|
||||
columns = [desc[0] for desc in result.description]
|
||||
|
||||
buffer = ""
|
||||
|
||||
while True:
|
||||
rows = result.fetchmany(1000)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
record = dict(zip(columns, row))
|
||||
buffer += json.dumps(record, default=str) + record_delimiter
|
||||
|
||||
while len(buffer) >= chunk_size:
|
||||
yield buffer[:chunk_size].encode("utf-8")
|
||||
buffer = buffer[chunk_size:]
|
||||
|
||||
if buffer:
|
||||
yield buffer.encode("utf-8")
|
||||
@@ -999,6 +999,102 @@ class ObjectStorage:
|
||||
|
||||
return record["etag"]
|
||||
|
||||
def upload_part_copy(
    self,
    bucket_name: str,
    upload_id: str,
    part_number: int,
    source_bucket: str,
    source_key: str,
    start_byte: Optional[int] = None,
    end_byte: Optional[int] = None,
) -> Dict[str, Any]:
    """Copy a range from an existing object as a multipart part.

    Writes the copied bytes to a temp file, atomically renames it into the
    upload directory, and records the part in the multipart manifest under a
    file lock. Returns a dict with the part's "etag" and "last_modified".

    Raises StorageError for an invalid part number, byte range, unknown
    upload id, or a manifest that cannot be read/updated.
    """
    if part_number < 1 or part_number > 10000:
        raise StorageError("part_number must be between 1 and 10000")

    source_path = self.get_object_path(source_bucket, source_key)
    source_size = source_path.stat().st_size

    # Absent range bounds default to the whole source object.
    if start_byte is None:
        start_byte = 0
    if end_byte is None:
        end_byte = source_size - 1

    if start_byte < 0 or end_byte >= source_size or start_byte > end_byte:
        raise StorageError("Invalid byte range")

    bucket_path = self._bucket_path(bucket_name)
    upload_root = self._multipart_dir(bucket_path.name, upload_id)
    if not upload_root.exists():
        # Fall back to the pre-migration multipart directory layout.
        upload_root = self._legacy_multipart_dir(bucket_path.name, upload_id)
        if not upload_root.exists():
            raise StorageError("Multipart upload not found")

    checksum = hashlib.md5()
    part_filename = f"part-{part_number:05d}.part"
    part_path = upload_root / part_filename
    temp_path = upload_root / f".{part_filename}.tmp"

    try:
        with source_path.open("rb") as src:
            src.seek(start_byte)
            bytes_to_copy = end_byte - start_byte + 1
            with temp_path.open("wb") as target:
                # Stream in 64 KiB chunks, hashing as we go for the ETag.
                remaining = bytes_to_copy
                while remaining > 0:
                    chunk_size = min(65536, remaining)
                    chunk = src.read(chunk_size)
                    if not chunk:
                        break
                    checksum.update(chunk)
                    target.write(chunk)
                    remaining -= len(chunk)
        # Atomic rename so a concurrent reader never sees a partial part.
        temp_path.replace(part_path)
    except OSError:
        # Best-effort cleanup of the temp file before re-raising.
        try:
            temp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise

    record = {
        "etag": checksum.hexdigest(),
        "size": part_path.stat().st_size,
        "filename": part_filename,
    }

    manifest_path = upload_root / self.MULTIPART_MANIFEST
    lock_path = upload_root / ".manifest.lock"

    # Register the part in the shared manifest under a file lock, retrying
    # with linear backoff on transient read/write failures.
    max_retries = 3
    for attempt in range(max_retries):
        try:
            with lock_path.open("w") as lock_file:
                with _file_lock(lock_file):
                    try:
                        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
                    except (OSError, json.JSONDecodeError) as exc:
                        if attempt < max_retries - 1:
                            time.sleep(0.1 * (attempt + 1))
                            continue
                        raise StorageError("Multipart manifest unreadable") from exc

                    parts = manifest.setdefault("parts", {})
                    parts[str(part_number)] = record
                    manifest_path.write_text(json.dumps(manifest), encoding="utf-8")
                    break
        except OSError as exc:
            if attempt < max_retries - 1:
                time.sleep(0.1 * (attempt + 1))
                continue
            raise StorageError(f"Failed to update multipart manifest: {exc}") from exc

    return {
        "etag": record["etag"],
        "last_modified": datetime.fromtimestamp(part_path.stat().st_mtime, timezone.utc),
    }
|
||||
|
||||
def complete_multipart_upload(
|
||||
self,
|
||||
bucket_name: str,
|
||||
|
||||
@@ -9,4 +9,5 @@ boto3>=1.42.14
|
||||
waitress>=3.0.2
|
||||
psutil>=7.1.3
|
||||
cryptography>=46.0.3
|
||||
defusedxml>=0.7.1
duckdb>=1.4.4
|
||||
Reference in New Issue
Block a user