Add 4 new S3 APIs: UploadPartCopy, Bucket Replication, PostObject, SelectObjectContent

This commit is contained in:
2026-01-29 12:51:00 +08:00
parent 0ea54457e8
commit 9385d1fe1c
5 changed files with 742 additions and 4 deletions

View File

@@ -137,6 +137,7 @@ class ReplicationRule:
stats: ReplicationStats = field(default_factory=ReplicationStats)
sync_deletions: bool = True
last_pull_at: Optional[float] = None
filter_prefix: Optional[str] = None
def to_dict(self) -> dict:
return {
@@ -149,6 +150,7 @@ class ReplicationRule:
"stats": self.stats.to_dict(),
"sync_deletions": self.sync_deletions,
"last_pull_at": self.last_pull_at,
"filter_prefix": self.filter_prefix,
}
@classmethod
@@ -162,6 +164,8 @@ class ReplicationRule:
data["sync_deletions"] = True
if "last_pull_at" not in data:
data["last_pull_at"] = None
if "filter_prefix" not in data:
data["filter_prefix"] = None
rule = cls(**data)
rule.stats = ReplicationStats.from_dict(stats_data) if stats_data else ReplicationStats()
return rule

View File

@@ -962,6 +962,7 @@ def _maybe_handle_bucket_subresource(bucket_name: str) -> Response | None:
"logging": _bucket_logging_handler,
"uploads": _bucket_uploads_handler,
"policy": _bucket_policy_handler,
"replication": _bucket_replication_handler,
}
requested = [key for key in handlers if key in request.args]
if not requested:
@@ -2196,6 +2197,134 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
return _xml_response(result, status=200)
def _post_object(bucket_name: str) -> Response:
    """Handle a browser-based POST upload (S3 PostObject).

    Validates the multipart/form-data request against its signed POST
    policy (AWS Signature V4), stores the uploaded file, and responds
    according to the optional success_action_redirect /
    success_action_status form fields.
    """
    import json

    storage = _storage()
    if not storage.bucket_exists(bucket_name):
        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
    object_key = request.form.get("key")
    policy_b64 = request.form.get("policy")
    signature = request.form.get("x-amz-signature")
    credential = request.form.get("x-amz-credential")
    algorithm = request.form.get("x-amz-algorithm")
    amz_date = request.form.get("x-amz-date")
    if not all([object_key, policy_b64, signature, credential, algorithm, amz_date]):
        return _error_response("InvalidArgument", "Missing required form fields", 400)
    if algorithm != "AWS4-HMAC-SHA256":
        return _error_response("InvalidArgument", "Unsupported signing algorithm", 400)
    # base64, UTF-8 and JSON failures all surface as ValueError subclasses
    # (binascii.Error, UnicodeDecodeError, json.JSONDecodeError), so one
    # handler covers decode-and-parse.
    try:
        policy = json.loads(base64.b64decode(policy_b64).decode("utf-8"))
    except ValueError as exc:
        return _error_response("InvalidPolicyDocument", f"Invalid policy: {exc}", 400)
    expiration = policy.get("expiration")
    if expiration:
        try:
            exp_time = datetime.fromisoformat(expiration.replace("Z", "+00:00"))
        except ValueError:
            return _error_response("InvalidPolicyDocument", "Invalid expiration format", 400)
        if datetime.now(timezone.utc) > exp_time:
            return _error_response("AccessDenied", "Policy expired", 403)
    conditions = policy.get("conditions", [])
    validation_error = _validate_post_policy_conditions(
        bucket_name, object_key, conditions, request.form, request.content_length or 0
    )
    if validation_error:
        return _error_response("AccessDenied", validation_error, 403)
    # Credential format: <access-key>/<date>/<region>/<service>/aws4_request
    parts = credential.split("/")
    if len(parts) != 5:
        return _error_response("InvalidArgument", "Invalid credential format", 400)
    access_key, date_stamp, region, service, _ = parts
    secret_key = _iam().get_secret_key(access_key)
    if not secret_key:
        return _error_response("AccessDenied", "Invalid access key", 403)
    # SigV4 for POST uploads signs the base64 policy document itself.
    signing_key = _derive_signing_key(secret_key, date_stamp, region, service)
    expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
    if not hmac.compare_digest(expected_signature, signature):
        return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)
    file = request.files.get("file")
    if not file:
        return _error_response("InvalidArgument", "Missing file field", 400)
    # NOTE(review): the S3 spec substitutes "${filename}" in the key —
    # confirm this placeholder token matches what clients actually send.
    if "$(unknown)" in object_key:
        object_key = object_key.replace("$(unknown)", file.filename or "upload")
    # Collect user metadata from x-amz-meta-* form fields (prefix matched
    # case-insensitively; the metadata key keeps its original casing).
    meta_prefix = "x-amz-meta-"
    metadata = {}
    for field_name, value in request.form.items():
        if field_name.lower().startswith(meta_prefix):
            meta_key = field_name[len(meta_prefix):]
            if meta_key:
                metadata[meta_key] = value
    try:
        meta = storage.put_object(bucket_name, object_key, file.stream, metadata=metadata or None)
    except QuotaExceededError as exc:
        return _error_response("QuotaExceeded", str(exc), 403)
    except StorageError as exc:
        return _error_response("InvalidArgument", str(exc), 400)
    current_app.logger.info(
        "Object uploaded via POST",
        extra={"bucket": bucket_name, "key": object_key, "size": meta.size},
    )
    success_action_redirect = request.form.get("success_action_redirect")
    if success_action_redirect:
        redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
        return Response(status=303, headers={"Location": redirect_url})
    success_action_status = request.form.get("success_action_status", "204")
    # 200 and 201 share the same XML body; only the status code differs.
    if success_action_status in ("200", "201"):
        root = Element("PostResponse")
        SubElement(root, "Location").text = f"/{bucket_name}/{object_key}"
        SubElement(root, "Bucket").text = bucket_name
        SubElement(root, "Key").text = object_key
        SubElement(root, "ETag").text = f'"{meta.etag}"'
        return _xml_response(root, status=int(success_action_status))
    return Response(status=204)
def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditions: list, form_data, content_length: int) -> Optional[str]:
for condition in conditions:
if isinstance(condition, dict):
for key, expected_value in condition.items():
if key == "bucket":
if bucket_name != expected_value:
return f"Bucket must be {expected_value}"
elif key == "key":
if object_key != expected_value:
return f"Key must be {expected_value}"
else:
actual_value = form_data.get(key, "")
if actual_value != expected_value:
return f"Field {key} must be {expected_value}"
elif isinstance(condition, list) and len(condition) >= 2:
operator = condition[0].lower() if isinstance(condition[0], str) else ""
if operator == "starts-with" and len(condition) == 3:
field = condition[1].lstrip("$")
prefix = condition[2]
if field == "key":
if not object_key.startswith(prefix):
return f"Key must start with {prefix}"
else:
actual_value = form_data.get(field, "")
if not actual_value.startswith(prefix):
return f"Field {field} must start with {prefix}"
elif operator == "eq" and len(condition) == 3:
field = condition[1].lstrip("$")
expected = condition[2]
if field == "key":
if object_key != expected:
return f"Key must equal {expected}"
else:
actual_value = form_data.get(field, "")
if actual_value != expected:
return f"Field {field} must equal {expected}"
elif operator == "content-length-range" and len(condition) == 3:
min_size, max_size = condition[1], condition[2]
if content_length < min_size or content_length > max_size:
return f"Content length must be between {min_size} and {max_size}"
return None
@s3_api_bp.get("/")
@limiter.limit(_get_list_buckets_limit)
def list_buckets() -> Response:
@@ -2233,9 +2362,12 @@ def bucket_handler(bucket_name: str) -> Response:
return subresource_response
if request.method == "POST":
if "delete" not in request.args:
return _method_not_allowed(["GET", "PUT", "DELETE"])
return _bulk_delete_handler(bucket_name)
if "delete" in request.args:
return _bulk_delete_handler(bucket_name)
content_type = request.headers.get("Content-Type", "")
if "multipart/form-data" in content_type:
return _post_object(bucket_name)
return _method_not_allowed(["GET", "PUT", "DELETE"])
if request.method == "PUT":
principal, error = _require_principal()
@@ -2433,6 +2565,8 @@ def object_handler(bucket_name: str, object_key: str):
return _initiate_multipart_upload(bucket_name, object_key)
if "uploadId" in request.args:
return _complete_multipart_upload(bucket_name, object_key)
if "select" in request.args:
return _select_object_content(bucket_name, object_key)
return _method_not_allowed(["GET", "PUT", "DELETE", "HEAD", "POST"])
if request.method == "PUT":
@@ -2732,6 +2866,120 @@ def _bucket_policy_handler(bucket_name: str) -> Response:
return Response(status=204)
def _bucket_replication_handler(bucket_name: str) -> Response:
    """Serve GET / PUT / DELETE for the ?replication bucket subresource."""
    if request.method not in {"GET", "PUT", "DELETE"}:
        return _method_not_allowed(["GET", "PUT", "DELETE"])
    principal, error = _require_principal()
    if error:
        return error
    try:
        _authorize_action(principal, bucket_name, "policy")
    except IamError as exc:
        return _error_response("AccessDenied", str(exc), 403)
    if not _storage().bucket_exists(bucket_name):
        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
    manager = _replication_manager()
    if request.method == "GET":
        existing = manager.get_rule(bucket_name)
        if not existing:
            return _error_response("ReplicationConfigurationNotFoundError", "Replication configuration not found", 404)
        return _xml_response(_render_replication_config(existing))
    if request.method == "DELETE":
        manager.delete_rule(bucket_name)
        current_app.logger.info("Bucket replication removed", extra={"bucket": bucket_name})
        return Response(status=204)
    # PUT: parse the XML body and store the new configuration.
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    body = request.get_data(cache=False) or b""
    try:
        new_rule = _parse_replication_config(bucket_name, body)
    except ValueError as exc:
        return _error_response("MalformedXML", str(exc), 400)
    manager.set_rule(new_rule)
    current_app.logger.info("Bucket replication updated", extra={"bucket": bucket_name})
    return Response(status=200)
def _parse_replication_config(bucket_name: str, payload: bytes):
    """Parse a PutBucketReplication XML body into a ReplicationRule.

    Only the first <Rule> element is honored. Raises ValueError when the
    document is malformed or required elements are missing.
    """
    from .replication import ReplicationRule, REPLICATION_MODE_ALL

    root = fromstring(payload)
    if _strip_ns(root.tag) != "ReplicationConfiguration":
        raise ValueError("Root element must be ReplicationConfiguration")
    rule_el = next((child for child in root if _strip_ns(child.tag) == "Rule"), None)
    if rule_el is None:
        raise ValueError("At least one Rule is required")
    status_el = _find_element(rule_el, "Status")
    status_text = status_el.text if status_el is not None and status_el.text else "Enabled"
    # Optional Filter/Prefix narrows replication to matching keys.
    prefix_value = None
    filter_el = _find_element(rule_el, "Filter")
    if filter_el is not None:
        prefix_el = _find_element(filter_el, "Prefix")
        if prefix_el is not None and prefix_el.text:
            prefix_value = prefix_el.text
    dest_el = _find_element(rule_el, "Destination")
    if dest_el is None:
        raise ValueError("Destination element is required")
    bucket_el = _find_element(dest_el, "Bucket")
    if bucket_el is None or not bucket_el.text:
        raise ValueError("Destination Bucket is required")
    target_bucket, target_connection_id = _parse_destination_arn(bucket_el.text)
    # DeleteMarkerReplication controls whether deletions propagate.
    delete_sync = True
    dm_el = _find_element(rule_el, "DeleteMarkerReplication")
    if dm_el is not None:
        dm_status_el = _find_element(dm_el, "Status")
        if dm_status_el is not None and dm_status_el.text:
            delete_sync = dm_status_el.text.lower() == "enabled"
    return ReplicationRule(
        bucket_name=bucket_name,
        target_connection_id=target_connection_id,
        target_bucket=target_bucket,
        enabled=status_text.lower() == "enabled",
        mode=REPLICATION_MODE_ALL,
        sync_deletions=delete_sync,
        filter_prefix=prefix_value,
    )
def _parse_destination_arn(arn: str) -> tuple:
if not arn.startswith("arn:aws:s3:::"):
raise ValueError(f"Invalid ARN format: {arn}")
bucket_part = arn[13:]
if "/" in bucket_part:
connection_id, bucket_name = bucket_part.split("/", 1)
else:
connection_id = "local"
bucket_name = bucket_part
return bucket_name, connection_id
def _render_replication_config(rule) -> Element:
root = Element("ReplicationConfiguration")
SubElement(root, "Role").text = "arn:aws:iam::000000000000:role/replication"
rule_el = SubElement(root, "Rule")
SubElement(rule_el, "ID").text = f"{rule.bucket_name}-replication"
SubElement(rule_el, "Status").text = "Enabled" if rule.enabled else "Disabled"
SubElement(rule_el, "Priority").text = "1"
filter_el = SubElement(rule_el, "Filter")
if rule.filter_prefix:
SubElement(filter_el, "Prefix").text = rule.filter_prefix
dest_el = SubElement(rule_el, "Destination")
if rule.target_connection_id == "local":
arn = f"arn:aws:s3:::{rule.target_bucket}"
else:
arn = f"arn:aws:s3:::{rule.target_connection_id}/{rule.target_bucket}"
SubElement(dest_el, "Bucket").text = arn
dm_el = SubElement(rule_el, "DeleteMarkerReplication")
SubElement(dm_el, "Status").text = "Enabled" if rule.sync_deletions else "Disabled"
return root
@s3_api_bp.route("/<bucket_name>", methods=["HEAD"])
@limiter.limit(_get_head_ops_limit)
def head_bucket(bucket_name: str) -> Response:
@@ -3009,6 +3257,10 @@ def _initiate_multipart_upload(bucket_name: str, object_key: str) -> Response:
def _upload_part(bucket_name: str, object_key: str) -> Response:
copy_source = request.headers.get("x-amz-copy-source")
if copy_source:
return _upload_part_copy(bucket_name, object_key, copy_source)
principal, error = _object_principal("write", bucket_name, object_key)
if error:
return error
@@ -3042,6 +3294,62 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
return response
def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Response:
    """Handle UploadPartCopy: copy (a range of) an existing object into a
    multipart upload part."""
    principal, error = _object_principal("write", bucket_name, object_key)
    if error:
        return error
    upload_id = request.args.get("uploadId")
    raw_part_number = request.args.get("partNumber")
    if not upload_id or not raw_part_number:
        return _error_response("InvalidArgument", "uploadId and partNumber are required", 400)
    try:
        part_number = int(raw_part_number)
    except ValueError:
        return _error_response("InvalidArgument", "partNumber must be an integer", 400)
    # Copy source is "/bucket/key" (URL-encoded); strip one leading slash.
    source = unquote(copy_source)
    if source.startswith("/"):
        source = source[1:]
    pieces = source.split("/", 1)
    if len(pieces) != 2:
        return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
    source_bucket, source_key = pieces
    # Caller must also be allowed to read the source object.
    _, read_error = _object_principal("read", source_bucket, source_key)
    if read_error:
        return read_error
    start_byte = end_byte = None
    range_header = request.headers.get("x-amz-copy-source-range")
    if range_header:
        range_match = re.match(r"bytes=(\d+)-(\d+)", range_header)
        if not range_match:
            return _error_response("InvalidArgument", "Invalid x-amz-copy-source-range format", 400)
        start_byte = int(range_match.group(1))
        end_byte = int(range_match.group(2))
    try:
        result = _storage().upload_part_copy(
            bucket_name, upload_id, part_number,
            source_bucket, source_key,
            start_byte, end_byte,
        )
    except ObjectNotFoundError:
        return _error_response("NoSuchKey", "Source object not found", 404)
    except StorageError as exc:
        message = str(exc)
        if "Multipart upload not found" in message:
            return _error_response("NoSuchUpload", message, 404)
        if "Invalid byte range" in message:
            return _error_response("InvalidRange", message, 416)
        return _error_response("InvalidArgument", message, 400)
    result_el = Element("CopyPartResult")
    SubElement(result_el, "LastModified").text = result["last_modified"].strftime("%Y-%m-%dT%H:%M:%S.000Z")
    SubElement(result_el, "ETag").text = f'"{result["etag"]}"'
    return _xml_response(result_el)
def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
principal, error = _object_principal("write", bucket_name, object_key)
if error:
@@ -3126,6 +3434,164 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
return Response(status=204)
def _select_object_content(bucket_name: str, object_key: str) -> Response:
    """Handle SelectObjectContent: run a SQL expression over an object and
    stream the results back in the AWS event-stream framing."""
    _, error = _object_principal("read", bucket_name, object_key)
    if error:
        return error
    ct_error = _require_xml_content_type()
    if ct_error:
        return ct_error
    body = request.get_data(cache=False) or b""
    try:
        root = fromstring(body)
    except ParseError:
        return _error_response("MalformedXML", "Unable to parse XML document", 400)
    if _strip_ns(root.tag) != "SelectObjectContentRequest":
        return _error_response("MalformedXML", "Root element must be SelectObjectContentRequest", 400)
    expression_el = _find_element(root, "Expression")
    if expression_el is None or not expression_el.text:
        return _error_response("InvalidRequest", "Expression is required", 400)
    expression = expression_el.text
    type_el = _find_element(root, "ExpressionType")
    expression_type = type_el.text if type_el is not None and type_el.text else "SQL"
    if expression_type.upper() != "SQL":
        return _error_response("InvalidRequest", "Only SQL expression type is supported", 400)
    input_el = _find_element(root, "InputSerialization")
    if input_el is None:
        return _error_response("InvalidRequest", "InputSerialization is required", 400)
    try:
        input_format, input_config = _parse_select_input_serialization(input_el)
    except ValueError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    output_el = _find_element(root, "OutputSerialization")
    if output_el is None:
        return _error_response("InvalidRequest", "OutputSerialization is required", 400)
    try:
        output_format, output_config = _parse_select_output_serialization(output_el)
    except ValueError as exc:
        return _error_response("InvalidRequest", str(exc), 400)
    try:
        path = _storage().get_object_path(bucket_name, object_key)
    except (ObjectNotFoundError, StorageError):
        return _error_response("NoSuchKey", "Object not found", 404)
    from .select_content import execute_select_query, SelectError

    try:
        result_stream = execute_select_query(
            file_path=path,
            expression=expression,
            input_format=input_format,
            input_config=input_config,
            output_format=output_format,
            output_config=output_config,
        )
    except SelectError as exc:
        return _error_response("InvalidRequest", str(exc), 400)

    def stream_events():
        # Emit Records events, then a Stats summary, then the End marker.
        scanned = 0
        returned = 0
        for chunk in result_stream:
            returned += len(chunk)
            yield _encode_select_event("Records", chunk)
        yield _encode_select_event("Stats", _build_stats_xml(scanned, returned))
        yield _encode_select_event("End", b"")

    return Response(
        stream_events(),
        mimetype="application/octet-stream",
        headers={"x-amz-request-charged": "requester"},
    )
def _parse_select_input_serialization(el: Element) -> tuple:
    """Extract (format, config) from an InputSerialization element.

    Raises ValueError when no supported format child is present.
    """
    csv_el = _find_element(el, "CSV")
    if csv_el is not None:
        header_el = _find_element(csv_el, "FileHeaderInfo")
        header_info = header_el.text.upper() if header_el is not None and header_el.text else "NONE"
        return "CSV", {
            "file_header_info": header_info,
            "comments": _find_element_text(csv_el, "Comments", "#"),
            "field_delimiter": _find_element_text(csv_el, "FieldDelimiter", ","),
            "record_delimiter": _find_element_text(csv_el, "RecordDelimiter", "\n"),
            "quote_character": _find_element_text(csv_el, "QuoteCharacter", '"'),
            "quote_escape_character": _find_element_text(csv_el, "QuoteEscapeCharacter", '"'),
        }
    json_el = _find_element(el, "JSON")
    if json_el is not None:
        type_el = _find_element(json_el, "Type")
        json_type = type_el.text.upper() if type_el is not None and type_el.text else "DOCUMENT"
        return "JSON", {"type": json_type}
    if _find_element(el, "Parquet") is not None:
        return "Parquet", {}
    raise ValueError("InputSerialization must specify CSV, JSON, or Parquet")
def _parse_select_output_serialization(el: Element) -> tuple:
    """Extract (format, config) from an OutputSerialization element.

    Raises ValueError when no supported format child is present.
    """
    csv_el = _find_element(el, "CSV")
    if csv_el is not None:
        return "CSV", {
            "field_delimiter": _find_element_text(csv_el, "FieldDelimiter", ","),
            "record_delimiter": _find_element_text(csv_el, "RecordDelimiter", "\n"),
            "quote_character": _find_element_text(csv_el, "QuoteCharacter", '"'),
            "quote_fields": _find_element_text(csv_el, "QuoteFields", "ASNEEDED").upper(),
        }
    json_el = _find_element(el, "JSON")
    if json_el is not None:
        return "JSON", {"record_delimiter": _find_element_text(json_el, "RecordDelimiter", "\n")}
    raise ValueError("OutputSerialization must specify CSV or JSON")
def _encode_select_event(event_type: str, payload: bytes) -> bytes:
    """Frame *payload* as one AWS event-stream message.

    Layout: total-length(4) | headers-length(4) | prelude-CRC(4) |
    headers | payload | message-CRC(4), all big-endian, CRC32 checksums.
    """
    import binascii
    import struct

    headers = _build_event_headers(event_type)
    # prelude (8) + prelude CRC (4) + headers + payload + message CRC (4)
    total_length = 12 + len(headers) + len(payload) + 4
    prelude = struct.pack(">II", total_length, len(headers))
    frame = prelude + struct.pack(">I", binascii.crc32(prelude) & 0xffffffff)
    frame += headers + payload
    frame += struct.pack(">I", binascii.crc32(frame) & 0xffffffff)
    return frame
def _build_event_headers(event_type: str) -> bytes:
    """Assemble the event-stream headers for one event."""
    content_types = {"Records": "application/octet-stream", "Stats": "text/xml"}
    parts = [_encode_select_header(":event-type", event_type)]
    if event_type in content_types:
        parts.append(_encode_select_header(":content-type", content_types[event_type]))
    parts.append(_encode_select_header(":message-type", "event"))
    return b"".join(parts)
def _encode_select_header(name: str, value: str) -> bytes:
import struct
name_bytes = name.encode("utf-8")
value_bytes = value.encode("utf-8")
header = struct.pack("B", len(name_bytes)) + name_bytes
header += struct.pack("B", 7)
header += struct.pack(">H", len(value_bytes)) + value_bytes
return header
def _build_stats_xml(bytes_scanned: int, bytes_returned: int) -> bytes:
stats = Element("Stats")
SubElement(stats, "BytesScanned").text = str(bytes_scanned)
SubElement(stats, "BytesProcessed").text = str(bytes_scanned)
SubElement(stats, "BytesReturned").text = str(bytes_returned)
return tostring(stats, encoding="utf-8")
@s3_api_bp.before_request
def resolve_principal():
g.principal = None

171
app/select_content.py Normal file
View File

@@ -0,0 +1,171 @@
"""S3 SelectObjectContent SQL query execution using DuckDB."""
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any, Dict, Generator, Optional
try:
import duckdb
DUCKDB_AVAILABLE = True
except ImportError:
DUCKDB_AVAILABLE = False
class SelectError(Exception):
    """Raised when a SelectObjectContent SQL query cannot be executed."""
def execute_select_query(
    file_path: Path,
    expression: str,
    input_format: str,
    input_config: Dict[str, Any],
    output_format: str,
    output_config: Dict[str, Any],
    chunk_size: int = 65536,
) -> Generator[bytes, None, None]:
    """Execute an S3 Select SQL expression against a local object file.

    Loads the object into an in-memory DuckDB table named ``data``,
    rewrites the S3 Select pseudo-table name onto it, then streams the
    results as CSV or JSON byte chunks.

    Raises:
        SelectError: DuckDB is unavailable, a format is unsupported, or
            the SQL expression fails to execute.
    """
    if not DUCKDB_AVAILABLE:
        raise SelectError("DuckDB is not installed. Install with: pip install duckdb")
    conn = duckdb.connect(":memory:")
    try:
        if input_format == "CSV":
            _load_csv(conn, file_path, input_config)
        elif input_format == "JSON":
            _load_json(conn, file_path, input_config)
        elif input_format == "Parquet":
            _load_parquet(conn, file_path)
        else:
            raise SelectError(f"Unsupported input format: {input_format}")
        # Map the "s3object" pseudo-table (any capitalization, whole word
        # only) onto the "data" table, without corrupting identifiers that
        # merely contain the substring.
        normalized_expression = re.sub(r"\bs3object\b", "data", expression, flags=re.IGNORECASE)
        try:
            result = conn.execute(normalized_expression)
        except duckdb.Error as exc:
            raise SelectError(f"SQL execution error: {exc}") from exc
        if output_format == "CSV":
            yield from _output_csv(result, output_config, chunk_size)
        elif output_format == "JSON":
            yield from _output_json(result, output_config, chunk_size)
        else:
            raise SelectError(f"Unsupported output format: {output_format}")
    finally:
        conn.close()
def _load_csv(conn, file_path: Path, config: Dict[str, Any]) -> None:
    """Load a CSV file into the DuckDB table ``data``.

    ``field_delimiter`` and ``quote_character`` come from the client's
    request, so embedded single quotes are doubled (standard SQL literal
    escaping) before interpolation to keep them inert in the query text.
    """
    file_header_info = config.get("file_header_info", "NONE")

    def sql_quote(value) -> str:
        # Double any embedded single quote so the value cannot break out
        # of the surrounding SQL string literal.
        return str(value).replace("'", "''")

    delimiter = sql_quote(config.get("field_delimiter", ","))
    quote = sql_quote(config.get("quote_character", '"'))
    # Both "USE" and "IGNORE" mean the first row is a header line.
    header = file_header_info in ("USE", "IGNORE")
    path_str = sql_quote(str(file_path).replace("\\", "/"))
    conn.execute(f"""
        CREATE TABLE data AS
        SELECT * FROM read_csv('{path_str}',
            header={header},
            delim='{delimiter}',
            quote='{quote}'
        )
    """)
def _load_json(conn, file_path: Path, config: Dict[str, Any]) -> None:
    """Load a JSON file into the DuckDB table ``data``.

    ``type`` "LINES" selects newline-delimited JSON; anything else is
    treated as a single document/array.
    """
    normalized = str(file_path).replace("\\", "/")
    json_format = "newline_delimited" if config.get("type", "DOCUMENT") == "LINES" else "array"
    conn.execute(f"""
        CREATE TABLE data AS
        SELECT * FROM read_json_auto('{normalized}', format='{json_format}')
    """)
def _load_parquet(conn, file_path: Path) -> None:
    """Load a Parquet file into the DuckDB table ``data``."""
    normalized = str(file_path).replace("\\", "/")
    query = f"CREATE TABLE data AS SELECT * FROM read_parquet('{normalized}')"
    conn.execute(query)
def _output_csv(
result,
config: Dict[str, Any],
chunk_size: int,
) -> Generator[bytes, None, None]:
"""Output query results as CSV."""
delimiter = config.get("field_delimiter", ",")
record_delimiter = config.get("record_delimiter", "\n")
quote = config.get("quote_character", '"')
buffer = ""
while True:
rows = result.fetchmany(1000)
if not rows:
break
for row in rows:
fields = []
for value in row:
if value is None:
fields.append("")
elif isinstance(value, str):
if delimiter in value or quote in value or record_delimiter in value:
escaped = value.replace(quote, quote + quote)
fields.append(f'{quote}{escaped}{quote}')
else:
fields.append(value)
else:
fields.append(str(value))
buffer += delimiter.join(fields) + record_delimiter
while len(buffer) >= chunk_size:
yield buffer[:chunk_size].encode("utf-8")
buffer = buffer[chunk_size:]
if buffer:
yield buffer.encode("utf-8")
def _output_json(
result,
config: Dict[str, Any],
chunk_size: int,
) -> Generator[bytes, None, None]:
"""Output query results as JSON Lines."""
record_delimiter = config.get("record_delimiter", "\n")
columns = [desc[0] for desc in result.description]
buffer = ""
while True:
rows = result.fetchmany(1000)
if not rows:
break
for row in rows:
record = dict(zip(columns, row))
buffer += json.dumps(record, default=str) + record_delimiter
while len(buffer) >= chunk_size:
yield buffer[:chunk_size].encode("utf-8")
buffer = buffer[chunk_size:]
if buffer:
yield buffer.encode("utf-8")

View File

@@ -999,6 +999,102 @@ class ObjectStorage:
return record["etag"]
def upload_part_copy(
    self,
    bucket_name: str,
    upload_id: str,
    part_number: int,
    source_bucket: str,
    source_key: str,
    start_byte: Optional[int] = None,
    end_byte: Optional[int] = None,
) -> Dict[str, Any]:
    """Copy a range from an existing object as a multipart part.

    The range is inclusive on both ends (HTTP bytes=start-end semantics);
    omitting either bound defaults to the whole source object. The part
    is written to a temp file first, then atomically renamed into place,
    and the upload manifest is updated under a file lock with retries.

    Returns a dict with the part's "etag" (MD5 of copied bytes) and
    "last_modified" (UTC timestamp of the part file).

    Raises StorageError for an invalid part number, invalid byte range,
    missing upload, or an unreadable/unwritable manifest.
    """
    # S3 limits multipart uploads to parts 1..10000.
    if part_number < 1 or part_number > 10000:
        raise StorageError("part_number must be between 1 and 10000")
    source_path = self.get_object_path(source_bucket, source_key)
    source_size = source_path.stat().st_size
    # Default to copying the entire source object when no range is given.
    if start_byte is None:
        start_byte = 0
    if end_byte is None:
        end_byte = source_size - 1
    if start_byte < 0 or end_byte >= source_size or start_byte > end_byte:
        raise StorageError("Invalid byte range")
    bucket_path = self._bucket_path(bucket_name)
    upload_root = self._multipart_dir(bucket_path.name, upload_id)
    if not upload_root.exists():
        # Fall back to the legacy multipart directory layout.
        upload_root = self._legacy_multipart_dir(bucket_path.name, upload_id)
    if not upload_root.exists():
        raise StorageError("Multipart upload not found")
    checksum = hashlib.md5()
    part_filename = f"part-{part_number:05d}.part"
    part_path = upload_root / part_filename
    # Write to a hidden temp file, then rename so readers never see a
    # partially written part.
    temp_path = upload_root / f".{part_filename}.tmp"
    try:
        with source_path.open("rb") as src:
            src.seek(start_byte)
            bytes_to_copy = end_byte - start_byte + 1
            with temp_path.open("wb") as target:
                remaining = bytes_to_copy
                while remaining > 0:
                    chunk_size = min(65536, remaining)
                    chunk = src.read(chunk_size)
                    if not chunk:
                        # Source shorter than expected (e.g. truncated
                        # concurrently) — stop copying what we have.
                        break
                    checksum.update(chunk)
                    target.write(chunk)
                    remaining -= len(chunk)
        temp_path.replace(part_path)
    except OSError:
        # Best-effort cleanup of the temp file before re-raising.
        try:
            temp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise
    record = {
        "etag": checksum.hexdigest(),
        "size": part_path.stat().st_size,
        "filename": part_filename,
    }
    manifest_path = upload_root / self.MULTIPART_MANIFEST
    lock_path = upload_root / ".manifest.lock"
    # Retry manifest updates with linear backoff: concurrent part uploads
    # may transiently hold the lock or leave the manifest mid-write.
    max_retries = 3
    for attempt in range(max_retries):
        try:
            with lock_path.open("w") as lock_file:
                with _file_lock(lock_file):
                    try:
                        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
                    except (OSError, json.JSONDecodeError) as exc:
                        if attempt < max_retries - 1:
                            time.sleep(0.1 * (attempt + 1))
                            continue
                        raise StorageError("Multipart manifest unreadable") from exc
                    parts = manifest.setdefault("parts", {})
                    parts[str(part_number)] = record
                    manifest_path.write_text(json.dumps(manifest), encoding="utf-8")
                    break
        except OSError as exc:
            if attempt < max_retries - 1:
                time.sleep(0.1 * (attempt + 1))
                continue
            raise StorageError(f"Failed to update multipart manifest: {exc}") from exc
    return {
        "etag": record["etag"],
        "last_modified": datetime.fromtimestamp(part_path.stat().st_mtime, timezone.utc),
    }
def complete_multipart_upload(
self,
bucket_name: str,

View File

@@ -9,4 +9,5 @@ boto3>=1.42.14
waitress>=3.0.2
psutil>=7.1.3
cryptography>=46.0.3
defusedxml>=0.7.1
duckdb>=1.4.4