diff --git a/app/replication.py b/app/replication.py
index c0023dd..8cde3e8 100644
--- a/app/replication.py
+++ b/app/replication.py
@@ -137,6 +137,7 @@ class ReplicationRule:
     stats: ReplicationStats = field(default_factory=ReplicationStats)
     sync_deletions: bool = True
     last_pull_at: Optional[float] = None
+    filter_prefix: Optional[str] = None
 
     def to_dict(self) -> dict:
         return {
@@ -149,6 +150,7 @@ class ReplicationRule:
             "stats": self.stats.to_dict(),
             "sync_deletions": self.sync_deletions,
             "last_pull_at": self.last_pull_at,
+            "filter_prefix": self.filter_prefix,
         }
 
     @classmethod
@@ -162,6 +164,8 @@ class ReplicationRule:
             data["sync_deletions"] = True
         if "last_pull_at" not in data:
             data["last_pull_at"] = None
+        if "filter_prefix" not in data:
+            data["filter_prefix"] = None
         rule = cls(**data)
         rule.stats = ReplicationStats.from_dict(stats_data) if stats_data else ReplicationStats()
         return rule
diff --git a/app/s3_api.py b/app/s3_api.py
index d6a79c6..ab410ea 100644
--- a/app/s3_api.py
+++ b/app/s3_api.py
@@ -962,6 +962,7 @@ def _maybe_handle_bucket_subresource(bucket_name: str) -> Response | None:
         "logging": _bucket_logging_handler,
         "uploads": _bucket_uploads_handler,
         "policy": _bucket_policy_handler,
+        "replication": _bucket_replication_handler,
     }
     requested = [key for key in handlers if key in request.args]
     if not requested:
@@ -2196,6 +2197,134 @@ def _bulk_delete_handler(bucket_name: str) -> Response:
     return _xml_response(result, status=200)
 
 
+def _post_object(bucket_name: str) -> Response:
+    storage = _storage()
+    if not storage.bucket_exists(bucket_name):
+        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
+    object_key = request.form.get("key")
+    policy_b64 = request.form.get("policy")
+    signature = request.form.get("x-amz-signature")
+    credential = request.form.get("x-amz-credential")
+    algorithm = request.form.get("x-amz-algorithm")
+    amz_date = request.form.get("x-amz-date")
+    if not all([object_key, policy_b64, signature, credential, algorithm, amz_date]):
+        return _error_response("InvalidArgument", "Missing required form fields", 400)
+    if algorithm != "AWS4-HMAC-SHA256":
+        return _error_response("InvalidArgument", "Unsupported signing algorithm", 400)
+    try:
+        policy_json = base64.b64decode(policy_b64).decode("utf-8")
+        policy = __import__("json").loads(policy_json)
+    except (ValueError, __import__("json").JSONDecodeError) as exc:
+        return _error_response("InvalidPolicyDocument", f"Invalid policy: {exc}", 400)
+    expiration = policy.get("expiration")
+    if expiration:
+        try:
+            exp_time = datetime.fromisoformat(expiration.replace("Z", "+00:00"))
+            if datetime.now(timezone.utc) > exp_time:
+                return _error_response("AccessDenied", "Policy expired", 403)
+        except ValueError:
+            return _error_response("InvalidPolicyDocument", "Invalid expiration format", 400)
+    conditions = policy.get("conditions", [])
+    validation_error = _validate_post_policy_conditions(bucket_name, object_key, conditions, request.form, request.content_length or 0)
+    if validation_error:
+        return _error_response("AccessDenied", validation_error, 403)
+    try:
+        parts = credential.split("/")
+        if len(parts) != 5:
+            raise ValueError("Invalid credential format")
+        access_key, date_stamp, region, service, _ = parts
+    except ValueError:
+        return _error_response("InvalidArgument", "Invalid credential format", 400)
+    secret_key = _iam().get_secret_key(access_key)
+    if not secret_key:
+        return _error_response("AccessDenied", "Invalid access key", 403)
+    signing_key = _derive_signing_key(secret_key, date_stamp, region, service)
+    expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
+    if not hmac.compare_digest(expected_signature, signature):
+        return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)
+    file = request.files.get("file")
+    if not file:
+        return _error_response("InvalidArgument", "Missing file field", 400)
+    if "${filename}" in object_key:
+        object_key = object_key.replace("${filename}", file.filename or "upload")
+    metadata = {}
+    for field_name, value in request.form.items():
+        if field_name.lower().startswith("x-amz-meta-"):
+            key = field_name[11:]
+            if key:
+                metadata[key] = value
+    try:
+        meta = storage.put_object(bucket_name, object_key, file.stream, metadata=metadata or None)
+    except QuotaExceededError as exc:
+        return _error_response("QuotaExceeded", str(exc), 403)
+    except StorageError as exc:
+        return _error_response("InvalidArgument", str(exc), 400)
+    current_app.logger.info("Object uploaded via POST", extra={"bucket": bucket_name, "key": object_key, "size": meta.size})
+    success_action_status = request.form.get("success_action_status", "204")
+    success_action_redirect = request.form.get("success_action_redirect")
+    if success_action_redirect:
+        redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
+        return Response(status=303, headers={"Location": redirect_url})
+    if success_action_status == "200":
+        root = Element("PostResponse")
+        SubElement(root, "Location").text = f"/{bucket_name}/{object_key}"
+        SubElement(root, "Bucket").text = bucket_name
+        SubElement(root, "Key").text = object_key
+        SubElement(root, "ETag").text = f'"{meta.etag}"'
+        return _xml_response(root, status=200)
+    if success_action_status == "201":
+        root = Element("PostResponse")
+        SubElement(root, "Location").text = f"/{bucket_name}/{object_key}"
+        SubElement(root, "Bucket").text = bucket_name
+        SubElement(root, "Key").text = object_key
+        SubElement(root, "ETag").text = f'"{meta.etag}"'
+        return _xml_response(root, status=201)
+    return Response(status=204)
+
+
+def _validate_post_policy_conditions(bucket_name: str, object_key: str, conditions: list, form_data, content_length: int) -> Optional[str]:
+    for condition in conditions:
+        if isinstance(condition, dict):
+            for key, expected_value in condition.items():
+                if key == "bucket":
+                    if bucket_name != expected_value:
+                        return f"Bucket must be {expected_value}"
+                elif key == "key":
+                    if object_key != expected_value:
+                        return f"Key must be {expected_value}"
+                else:
+                    actual_value = form_data.get(key, "")
+                    if actual_value != expected_value:
+                        return f"Field {key} must be {expected_value}"
+        elif isinstance(condition, list) and len(condition) >= 2:
+            operator = condition[0].lower() if isinstance(condition[0], str) else ""
+            if operator == "starts-with" and len(condition) == 3:
+                field = condition[1].lstrip("$")
+                prefix = condition[2]
+                if field == "key":
+                    if not object_key.startswith(prefix):
+                        return f"Key must start with {prefix}"
+                else:
+                    actual_value = form_data.get(field, "")
+                    if not actual_value.startswith(prefix):
+                        return f"Field {field} must start with {prefix}"
+            elif operator == "eq" and len(condition) == 3:
+                field = condition[1].lstrip("$")
+                expected = condition[2]
+                if field == "key":
+                    if object_key != expected:
+                        return f"Key must equal {expected}"
+                else:
+                    actual_value = form_data.get(field, "")
+                    if actual_value != expected:
+                        return f"Field {field} must equal {expected}"
+            elif operator == "content-length-range" and len(condition) == 3:
+                min_size, max_size = condition[1], condition[2]
+                if content_length < min_size or content_length > max_size:
+                    return f"Content length must be between {min_size} and {max_size}"
+    return None
+
+
 @s3_api_bp.get("/")
 @limiter.limit(_get_list_buckets_limit)
 def list_buckets() -> Response:
@@ -2233,9 +2362,12 @@ def bucket_handler(bucket_name: str) -> Response:
         return subresource_response
 
     if request.method == "POST":
-        if "delete" not in request.args:
-            return _method_not_allowed(["GET", "PUT", "DELETE"])
-        return _bulk_delete_handler(bucket_name)
+        if "delete" in request.args:
+            return _bulk_delete_handler(bucket_name)
+        content_type = request.headers.get("Content-Type", "")
+        if "multipart/form-data" in content_type:
+            return _post_object(bucket_name)
+        return _method_not_allowed(["GET", "PUT", "DELETE"])
 
     if request.method == "PUT":
         principal, error = _require_principal()
@@ -2433,6 +2565,8 @@ def object_handler(bucket_name: str, object_key: str):
             return _initiate_multipart_upload(bucket_name, object_key)
         if "uploadId" in request.args:
            return _complete_multipart_upload(bucket_name, object_key)
+        if "select" in request.args:
+            return _select_object_content(bucket_name, object_key)
         return _method_not_allowed(["GET", "PUT", "DELETE", "HEAD", "POST"])
 
     if request.method == "PUT":
@@ -2732,6 +2866,120 @@ def _bucket_policy_handler(bucket_name: str) -> Response:
     return Response(status=204)
 
 
+def _bucket_replication_handler(bucket_name: str) -> Response:
+    if request.method not in {"GET", "PUT", "DELETE"}:
+        return _method_not_allowed(["GET", "PUT", "DELETE"])
+    principal, error = _require_principal()
+    if error:
+        return error
+    try:
+        _authorize_action(principal, bucket_name, "policy")
+    except IamError as exc:
+        return _error_response("AccessDenied", str(exc), 403)
+    storage = _storage()
+    if not storage.bucket_exists(bucket_name):
+        return _error_response("NoSuchBucket", "Bucket does not exist", 404)
+    replication = _replication_manager()
+    if request.method == "GET":
+        rule = replication.get_rule(bucket_name)
+        if not rule:
+            return _error_response("ReplicationConfigurationNotFoundError", "Replication configuration not found", 404)
+        return _xml_response(_render_replication_config(rule))
+    if request.method == "DELETE":
+        replication.delete_rule(bucket_name)
+        current_app.logger.info("Bucket replication removed", extra={"bucket": bucket_name})
+        return Response(status=204)
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
+    payload = request.get_data(cache=False) or b""
+    try:
+        rule = _parse_replication_config(bucket_name, payload)
+    except ValueError as exc:
+        return _error_response("MalformedXML", str(exc), 400)
+    replication.set_rule(rule)
+    current_app.logger.info("Bucket replication updated", extra={"bucket": bucket_name})
+    return Response(status=200)
+
+
+def _parse_replication_config(bucket_name: str, payload: bytes):
+    from .replication import ReplicationRule, REPLICATION_MODE_ALL
+    root = fromstring(payload)
+    if _strip_ns(root.tag) != "ReplicationConfiguration":
+        raise ValueError("Root element must be ReplicationConfiguration")
+    rule_el = None
+    for child in list(root):
+        if _strip_ns(child.tag) == "Rule":
+            rule_el = child
+            break
+    if rule_el is None:
+        raise ValueError("At least one Rule is required")
+    status_el = _find_element(rule_el, "Status")
+    status = status_el.text if status_el is not None and status_el.text else "Enabled"
+    enabled = status.lower() == "enabled"
+    filter_prefix = None
+    filter_el = _find_element(rule_el, "Filter")
+    if filter_el is not None:
+        prefix_el = _find_element(filter_el, "Prefix")
+        if prefix_el is not None and prefix_el.text:
+            filter_prefix = prefix_el.text
+    dest_el = _find_element(rule_el, "Destination")
+    if dest_el is None:
+        raise ValueError("Destination element is required")
+    bucket_el = _find_element(dest_el, "Bucket")
+    if bucket_el is None or not bucket_el.text:
+        raise ValueError("Destination Bucket is required")
+    target_bucket, target_connection_id = _parse_destination_arn(bucket_el.text)
+    sync_deletions = True
+    dm_el = _find_element(rule_el, "DeleteMarkerReplication")
+    if dm_el is not None:
+        dm_status_el = _find_element(dm_el, "Status")
+        if dm_status_el is not None and dm_status_el.text:
+            sync_deletions = dm_status_el.text.lower() == "enabled"
+    return ReplicationRule(
+        bucket_name=bucket_name,
+        target_connection_id=target_connection_id,
+        target_bucket=target_bucket,
+        enabled=enabled,
+        mode=REPLICATION_MODE_ALL,
+        sync_deletions=sync_deletions,
+        filter_prefix=filter_prefix,
+    )
+
+
+def _parse_destination_arn(arn: str) -> tuple:
+    if not arn.startswith("arn:aws:s3:::"):
+        raise ValueError(f"Invalid ARN format: {arn}")
+    bucket_part = arn[13:]
+    if "/" in bucket_part:
+        connection_id, bucket_name = bucket_part.split("/", 1)
+    else:
+        connection_id = "local"
+        bucket_name = bucket_part
+    return bucket_name, connection_id
+
+
+def _render_replication_config(rule) -> Element:
+    root = Element("ReplicationConfiguration")
+    SubElement(root, "Role").text = "arn:aws:iam::000000000000:role/replication"
+    rule_el = SubElement(root, "Rule")
+    SubElement(rule_el, "ID").text = f"{rule.bucket_name}-replication"
+    SubElement(rule_el, "Status").text = "Enabled" if rule.enabled else "Disabled"
+    SubElement(rule_el, "Priority").text = "1"
+    filter_el = SubElement(rule_el, "Filter")
+    if rule.filter_prefix:
+        SubElement(filter_el, "Prefix").text = rule.filter_prefix
+    dest_el = SubElement(rule_el, "Destination")
+    if rule.target_connection_id == "local":
+        arn = f"arn:aws:s3:::{rule.target_bucket}"
+    else:
+        arn = f"arn:aws:s3:::{rule.target_connection_id}/{rule.target_bucket}"
+    SubElement(dest_el, "Bucket").text = arn
+    dm_el = SubElement(rule_el, "DeleteMarkerReplication")
+    SubElement(dm_el, "Status").text = "Enabled" if rule.sync_deletions else "Disabled"
+    return root
+
+
 @s3_api_bp.route("/", methods=["HEAD"])
 @limiter.limit(_get_head_ops_limit)
 def head_bucket(bucket_name: str) -> Response:
@@ -3009,6 +3257,10 @@ def _initiate_multipart_upload(bucket_name: str, object_key: str) -> Response:
 
 
 def _upload_part(bucket_name: str, object_key: str) -> Response:
+    copy_source = request.headers.get("x-amz-copy-source")
+    if copy_source:
+        return _upload_part_copy(bucket_name, object_key, copy_source)
+
     principal, error = _object_principal("write", bucket_name, object_key)
     if error:
         return error
@@ -3042,6 +3294,62 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
     return response
 
 
+def _upload_part_copy(bucket_name: str, object_key: str, copy_source: str) -> Response:
+    principal, error = _object_principal("write", bucket_name, object_key)
+    if error:
+        return error
+
+    upload_id = request.args.get("uploadId")
+    part_number_str = request.args.get("partNumber")
+    if not upload_id or not part_number_str:
+        return _error_response("InvalidArgument", "uploadId and partNumber are required", 400)
+
+    try:
+        part_number = int(part_number_str)
+    except ValueError:
+        return _error_response("InvalidArgument", "partNumber must be an integer", 400)
+
+    copy_source = unquote(copy_source)
+    if copy_source.startswith("/"):
+        copy_source = copy_source[1:]
+    parts = copy_source.split("/", 1)
+    if len(parts) != 2:
+        return _error_response("InvalidArgument", "Invalid x-amz-copy-source format", 400)
+    source_bucket, source_key = parts
+
+    _, read_error = _object_principal("read", source_bucket, source_key)
+    if read_error:
+        return read_error
+
+    copy_source_range = request.headers.get("x-amz-copy-source-range")
+    start_byte, end_byte = None, None
+    if copy_source_range:
+        match = re.match(r"bytes=(\d+)-(\d+)", copy_source_range)
+        if not match:
+            return _error_response("InvalidArgument", "Invalid x-amz-copy-source-range format", 400)
+        start_byte, end_byte = int(match.group(1)), int(match.group(2))
+
+    try:
+        result = _storage().upload_part_copy(
+            bucket_name, upload_id, part_number,
+            source_bucket, source_key,
+            start_byte, end_byte
+        )
+    except ObjectNotFoundError:
+        return _error_response("NoSuchKey", "Source object not found", 404)
+    except StorageError as exc:
+        if "Multipart upload not found" in str(exc):
+            return _error_response("NoSuchUpload", str(exc), 404)
+        if "Invalid byte range" in str(exc):
+            return _error_response("InvalidRange", str(exc), 416)
+        return _error_response("InvalidArgument", str(exc), 400)
+
+    root = Element("CopyPartResult")
+    SubElement(root, "LastModified").text = result["last_modified"].strftime("%Y-%m-%dT%H:%M:%S.000Z")
+    SubElement(root, "ETag").text = f'"{result["etag"]}"'
+    return _xml_response(root)
+
+
 def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response:
     principal, error = _object_principal("write", bucket_name, object_key)
     if error:
@@ -3126,6 +3434,164 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
     return Response(status=204)
 
 
+def _select_object_content(bucket_name: str, object_key: str) -> Response:
+    _, error = _object_principal("read", bucket_name, object_key)
+    if error:
+        return error
+    ct_error = _require_xml_content_type()
+    if ct_error:
+        return ct_error
+    payload = request.get_data(cache=False) or b""
+    try:
+        root = fromstring(payload)
+    except ParseError:
+        return _error_response("MalformedXML", "Unable to parse XML document", 400)
+    if _strip_ns(root.tag) != "SelectObjectContentRequest":
+        return _error_response("MalformedXML", "Root element must be SelectObjectContentRequest", 400)
+    expression_el = _find_element(root, "Expression")
+    if expression_el is None or not expression_el.text:
+        return _error_response("InvalidRequest", "Expression is required", 400)
+    expression = expression_el.text
+    expression_type_el = _find_element(root, "ExpressionType")
+    expression_type = expression_type_el.text if expression_type_el is not None and expression_type_el.text else "SQL"
+    if expression_type.upper() != "SQL":
+        return _error_response("InvalidRequest", "Only SQL expression type is supported", 400)
+    input_el = _find_element(root, "InputSerialization")
+    if input_el is None:
+        return _error_response("InvalidRequest", "InputSerialization is required", 400)
+    try:
+        input_format, input_config = _parse_select_input_serialization(input_el)
+    except ValueError as exc:
+        return _error_response("InvalidRequest", str(exc), 400)
+    output_el = _find_element(root, "OutputSerialization")
+    if output_el is None:
+        return _error_response("InvalidRequest", "OutputSerialization is required", 400)
+    try:
+        output_format, output_config = _parse_select_output_serialization(output_el)
+    except ValueError as exc:
+        return _error_response("InvalidRequest", str(exc), 400)
+    storage = _storage()
+    try:
+        path = storage.get_object_path(bucket_name, object_key)
+    except ObjectNotFoundError:
+        return _error_response("NoSuchKey", "Object not found", 404)
+    except StorageError:
+        return _error_response("NoSuchKey", "Object not found", 404)
+    from .select_content import execute_select_query, SelectError
+    try:
+        result_stream = execute_select_query(
+            file_path=path,
+            expression=expression,
+            input_format=input_format,
+            input_config=input_config,
+            output_format=output_format,
+            output_config=output_config,
+        )
+    except SelectError as exc:
+        return _error_response("InvalidRequest", str(exc), 400)
+
+    def generate_events():
+        bytes_scanned = 0
+        bytes_returned = 0
+        for chunk in result_stream:
+            bytes_returned += len(chunk)
+            yield _encode_select_event("Records", chunk)
+        stats_payload = _build_stats_xml(bytes_scanned, bytes_returned)
+        yield _encode_select_event("Stats", stats_payload)
+        yield _encode_select_event("End", b"")
+
+    return Response(generate_events(), mimetype="application/octet-stream", headers={"x-amz-request-charged": "requester"})
+
+
+def _parse_select_input_serialization(el: Element) -> tuple:
+    csv_el = _find_element(el, "CSV")
+    if csv_el is not None:
+        file_header_el = _find_element(csv_el, "FileHeaderInfo")
+        config = {
+            "file_header_info": file_header_el.text.upper() if file_header_el is not None and file_header_el.text else "NONE",
+            "comments": _find_element_text(csv_el, "Comments", "#"),
+            "field_delimiter": _find_element_text(csv_el, "FieldDelimiter", ","),
+            "record_delimiter": _find_element_text(csv_el, "RecordDelimiter", "\n"),
+            "quote_character": _find_element_text(csv_el, "QuoteCharacter", '"'),
+            "quote_escape_character": _find_element_text(csv_el, "QuoteEscapeCharacter", '"'),
+        }
+        return "CSV", config
+    json_el = _find_element(el, "JSON")
+    if json_el is not None:
+        type_el = _find_element(json_el, "Type")
+        config = {
+            "type": type_el.text.upper() if type_el is not None and type_el.text else "DOCUMENT",
+        }
+        return "JSON", config
+    parquet_el = _find_element(el, "Parquet")
+    if parquet_el is not None:
+        return "Parquet", {}
+    raise ValueError("InputSerialization must specify CSV, JSON, or Parquet")
+
+
+def _parse_select_output_serialization(el: Element) -> tuple:
+    csv_el = _find_element(el, "CSV")
+    if csv_el is not None:
+        config = {
+            "field_delimiter": _find_element_text(csv_el, "FieldDelimiter", ","),
+            "record_delimiter": _find_element_text(csv_el, "RecordDelimiter", "\n"),
+            "quote_character": _find_element_text(csv_el, "QuoteCharacter", '"'),
+            "quote_fields": _find_element_text(csv_el, "QuoteFields", "ASNEEDED").upper(),
+        }
+        return "CSV", config
+    json_el = _find_element(el, "JSON")
+    if json_el is not None:
+        config = {
+            "record_delimiter": _find_element_text(json_el, "RecordDelimiter", "\n"),
+        }
+        return "JSON", config
+    raise ValueError("OutputSerialization must specify CSV or JSON")
+
+
+def _encode_select_event(event_type: str, payload: bytes) -> bytes:
+    import struct
+    import binascii
+    headers = _build_event_headers(event_type)
+    headers_length = len(headers)
+    total_length = 4 + 4 + 4 + headers_length + len(payload) + 4
+    prelude = struct.pack(">I", total_length) + struct.pack(">I", headers_length)
+    prelude_crc = binascii.crc32(prelude) & 0xffffffff
+    prelude += struct.pack(">I", prelude_crc)
+    message = prelude + headers + payload
+    message_crc = binascii.crc32(message) & 0xffffffff
+    message += struct.pack(">I", message_crc)
+    return message
+
+
+def _build_event_headers(event_type: str) -> bytes:
+    headers = b""
+    headers += _encode_select_header(":event-type", event_type)
+    if event_type == "Records":
+        headers += _encode_select_header(":content-type", "application/octet-stream")
+    elif event_type == "Stats":
+        headers += _encode_select_header(":content-type", "text/xml")
+    headers += _encode_select_header(":message-type", "event")
+    return headers
+
+
+def _encode_select_header(name: str, value: str) -> bytes:
+    import struct
+    name_bytes = name.encode("utf-8")
+    value_bytes = value.encode("utf-8")
+    header = struct.pack("B", len(name_bytes)) + name_bytes
+    header += struct.pack("B", 7)
+    header += struct.pack(">H", len(value_bytes)) + value_bytes
+    return header
+
+
+def _build_stats_xml(bytes_scanned: int, bytes_returned: int) -> bytes:
+    stats = Element("Stats")
+    SubElement(stats, "BytesScanned").text = str(bytes_scanned)
+    SubElement(stats, "BytesProcessed").text = str(bytes_scanned)
+    SubElement(stats, "BytesReturned").text = str(bytes_returned)
+    return tostring(stats, encoding="utf-8")
+
+
 @s3_api_bp.before_request
 def resolve_principal():
     g.principal = None
diff --git a/app/select_content.py b/app/select_content.py
new file mode 100644
index 0000000..57a3362
--- /dev/null
+++ b/app/select_content.py
@@ -0,0 +1,171 @@
+"""S3 SelectObjectContent SQL query execution using DuckDB."""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Generator, Optional
+
+try:
+    import duckdb
+    DUCKDB_AVAILABLE = True
+except ImportError:
+    DUCKDB_AVAILABLE = False
+
+
+class SelectError(Exception):
+    """Error during SELECT query execution."""
+    pass
+
+
+def execute_select_query(
+    file_path: Path,
+    expression: str,
+    input_format: str,
+    input_config: Dict[str, Any],
+    output_format: str,
+    output_config: Dict[str, Any],
+    chunk_size: int = 65536,
+) -> Generator[bytes, None, None]:
+    """Execute SQL query on object content."""
+    if not DUCKDB_AVAILABLE:
+        raise SelectError("DuckDB is not installed. Install with: pip install duckdb")
+
+    conn = duckdb.connect(":memory:")
+
+    try:
+        if input_format == "CSV":
+            _load_csv(conn, file_path, input_config)
+        elif input_format == "JSON":
+            _load_json(conn, file_path, input_config)
+        elif input_format == "Parquet":
+            _load_parquet(conn, file_path)
+        else:
+            raise SelectError(f"Unsupported input format: {input_format}")
+
+        normalized_expression = expression.replace("s3object", "data").replace("S3Object", "data")
+
+        try:
+            result = conn.execute(normalized_expression)
+        except duckdb.Error as exc:
+            raise SelectError(f"SQL execution error: {exc}")
+
+        if output_format == "CSV":
+            yield from _output_csv(result, output_config, chunk_size)
+        elif output_format == "JSON":
+            yield from _output_json(result, output_config, chunk_size)
+        else:
+            raise SelectError(f"Unsupported output format: {output_format}")
+
+    finally:
+        conn.close()
+
+
+def _load_csv(conn, file_path: Path, config: Dict[str, Any]) -> None:
+    """Load CSV file into DuckDB."""
+    file_header_info = config.get("file_header_info", "NONE")
+    delimiter = config.get("field_delimiter", ",")
+    quote = config.get("quote_character", '"')
+
+    header = file_header_info in ("USE", "IGNORE")
+    path_str = str(file_path).replace("\\", "/")
+
+    conn.execute(f"""
+        CREATE TABLE data AS
+        SELECT * FROM read_csv('{path_str}',
+            header={header},
+            delim='{delimiter}',
+            quote='{quote}'
+        )
+    """)
+
+
+def _load_json(conn, file_path: Path, config: Dict[str, Any]) -> None:
+    """Load JSON file into DuckDB."""
+    json_type = config.get("type", "DOCUMENT")
+    path_str = str(file_path).replace("\\", "/")
+
+    if json_type == "LINES":
+        conn.execute(f"""
+            CREATE TABLE data AS
+            SELECT * FROM read_json_auto('{path_str}', format='newline_delimited')
+        """)
+    else:
+        conn.execute(f"""
+            CREATE TABLE data AS
+            SELECT * FROM read_json_auto('{path_str}', format='array')
+        """)
+
+
+def _load_parquet(conn, file_path: Path) -> None:
+    """Load Parquet file into DuckDB."""
+    path_str = str(file_path).replace("\\", "/")
+    conn.execute(f"CREATE TABLE data AS SELECT * FROM read_parquet('{path_str}')")
+
+
+def _output_csv(
+    result,
+    config: Dict[str, Any],
+    chunk_size: int,
+) -> Generator[bytes, None, None]:
+    """Output query results as CSV."""
+    delimiter = config.get("field_delimiter", ",")
+    record_delimiter = config.get("record_delimiter", "\n")
+    quote = config.get("quote_character", '"')
+
+    buffer = ""
+
+    while True:
+        rows = result.fetchmany(1000)
+        if not rows:
+            break
+
+        for row in rows:
+            fields = []
+            for value in row:
+                if value is None:
+                    fields.append("")
+                elif isinstance(value, str):
+                    if delimiter in value or quote in value or record_delimiter in value:
+                        escaped = value.replace(quote, quote + quote)
+                        fields.append(f'{quote}{escaped}{quote}')
+                    else:
+                        fields.append(value)
+                else:
+                    fields.append(str(value))
+
+            buffer += delimiter.join(fields) + record_delimiter
+
+            while len(buffer) >= chunk_size:
+                yield buffer[:chunk_size].encode("utf-8")
+                buffer = buffer[chunk_size:]
+
+    if buffer:
+        yield buffer.encode("utf-8")
+
+
+def _output_json(
+    result,
+    config: Dict[str, Any],
+    chunk_size: int,
+) -> Generator[bytes, None, None]:
+    """Output query results as JSON Lines."""
+    record_delimiter = config.get("record_delimiter", "\n")
+    columns = [desc[0] for desc in result.description]
+
+    buffer = ""
+
+    while True:
+        rows = result.fetchmany(1000)
+        if not rows:
+            break
+
+        for row in rows:
+            record = dict(zip(columns, row))
+            buffer += json.dumps(record, default=str) + record_delimiter
+
+            while len(buffer) >= chunk_size:
+                yield buffer[:chunk_size].encode("utf-8")
+                buffer = buffer[chunk_size:]
+
+    if buffer:
+        yield buffer.encode("utf-8")
diff --git a/app/storage.py b/app/storage.py
index 645cbd2..05a2fda 100644
--- a/app/storage.py
+++ b/app/storage.py
@@ -999,6 +999,102 @@ class ObjectStorage:
 
         return record["etag"]
 
+    def upload_part_copy(
+        self,
+        bucket_name: str,
+        upload_id: str,
+        part_number: int,
+        source_bucket: str,
+        source_key: str,
+        start_byte: Optional[int] = None,
+        end_byte: Optional[int] = None,
+    ) -> Dict[str, Any]:
+        """Copy a range from an existing object as a multipart part."""
+        if part_number < 1 or part_number > 10000:
+            raise StorageError("part_number must be between 1 and 10000")
+
+        source_path = self.get_object_path(source_bucket, source_key)
+        source_size = source_path.stat().st_size
+
+        if start_byte is None:
+            start_byte = 0
+        if end_byte is None:
+            end_byte = source_size - 1
+
+        if start_byte < 0 or end_byte >= source_size or start_byte > end_byte:
+            raise StorageError("Invalid byte range")
+
+        bucket_path = self._bucket_path(bucket_name)
+        upload_root = self._multipart_dir(bucket_path.name, upload_id)
+        if not upload_root.exists():
+            upload_root = self._legacy_multipart_dir(bucket_path.name, upload_id)
+        if not upload_root.exists():
+            raise StorageError("Multipart upload not found")
+
+        checksum = hashlib.md5()
+        part_filename = f"part-{part_number:05d}.part"
+        part_path = upload_root / part_filename
+        temp_path = upload_root / f".{part_filename}.tmp"
+
+        try:
+            with source_path.open("rb") as src:
+                src.seek(start_byte)
+                bytes_to_copy = end_byte - start_byte + 1
+                with temp_path.open("wb") as target:
+                    remaining = bytes_to_copy
+                    while remaining > 0:
+                        chunk_size = min(65536, remaining)
+                        chunk = src.read(chunk_size)
+                        if not chunk:
+                            break
+                        checksum.update(chunk)
+                        target.write(chunk)
+                        remaining -= len(chunk)
+            temp_path.replace(part_path)
+        except OSError:
+            try:
+                temp_path.unlink(missing_ok=True)
+            except OSError:
+                pass
+            raise
+
+        record = {
+            "etag": checksum.hexdigest(),
+            "size": part_path.stat().st_size,
+            "filename": part_filename,
+        }
+
+        manifest_path = upload_root / self.MULTIPART_MANIFEST
+        lock_path = upload_root / ".manifest.lock"
+
+        max_retries = 3
+        for attempt in range(max_retries):
+            try:
+                with lock_path.open("w") as lock_file:
+                    with _file_lock(lock_file):
+                        try:
+                            manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
+                        except (OSError, json.JSONDecodeError) as exc:
+                            if attempt < max_retries - 1:
+                                time.sleep(0.1 * (attempt + 1))
+                                continue
+                            raise StorageError("Multipart manifest unreadable") from exc
+
+                        parts = manifest.setdefault("parts", {})
+                        parts[str(part_number)] = record
+                        manifest_path.write_text(json.dumps(manifest), encoding="utf-8")
+                break
+            except OSError as exc:
+                if attempt < max_retries - 1:
+                    time.sleep(0.1 * (attempt + 1))
+                    continue
+                raise StorageError(f"Failed to update multipart manifest: {exc}") from exc
+
+        return {
+            "etag": record["etag"],
+            "last_modified": datetime.fromtimestamp(part_path.stat().st_mtime, timezone.utc),
+        }
+
     def complete_multipart_upload(
         self,
         bucket_name: str,
diff --git a/requirements.txt b/requirements.txt
index 17915fa..1813b33 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ boto3>=1.42.14
 waitress>=3.0.2
 psutil>=7.1.3
 cryptography>=46.0.3
-defusedxml>=0.7.1
\ No newline at end of file
+defusedxml>=0.7.1
+duckdb>=1.4.4
\ No newline at end of file
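
A few usage sketches for the endpoints this patch adds follow; endpoint URL, credentials, bucket and key names are placeholders. First, the browser-style POST upload served by `_post_object`, assuming boto3's SigV4 presigned POST produces the form fields the handler checks (key, policy, x-amz-algorithm, x-amz-credential, x-amz-date, x-amz-signature) plus a `file` part.

```python
import boto3
import requests
from botocore.client import Config

# Placeholder endpoint and credentials -- adjust to the deployment under test.
s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:5000",
    aws_access_key_id="AKIAEXAMPLE",
    aws_secret_access_key="secret",
    config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
)

post = s3.generate_presigned_post(
    Bucket="demo-bucket",
    Key="uploads/report.csv",
    Conditions=[
        ["starts-with", "$key", "uploads/"],            # exercises the starts-with branch
        ["content-length-range", 1, 10 * 1024 * 1024],  # exercises content-length-range
    ],
    ExpiresIn=300,
)

# The form carries the signed policy; _post_object verifies it and stores the "file" part.
with open("report.csv", "rb") as fh:
    resp = requests.post(post["url"], data=post["fields"], files={"file": ("report.csv", fh)})
print(resp.status_code)  # 204 unless success_action_status or success_action_redirect is set
```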
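
The new `?replication` subresource speaks the standard ReplicationConfiguration document, so boto3's `put_bucket_replication`, `get_bucket_replication`, and `delete_bucket_replication` map onto it directly: `Role` is required by the SDK but ignored by `_parse_replication_config`, the Destination ARN may carry a `connection-id/bucket` suffix per `_parse_destination_arn`, and `DeleteMarkerReplication` drives `sync_deletions`. A sketch reusing the client above; connection and bucket names are made up, and if the SDK were to omit an XML Content-Type the handler's `_require_xml_content_type` check would reject the PUT.

```python
s3.put_bucket_replication(
    Bucket="demo-bucket",
    ReplicationConfiguration={
        "Role": "arn:aws:iam::000000000000:role/replication",  # required by boto3, ignored server-side
        "Rules": [
            {
                "ID": "demo-bucket-replication",
                "Priority": 1,
                "Status": "Enabled",
                "Filter": {"Prefix": "reports/"},  # becomes filter_prefix
                "Destination": {"Bucket": "arn:aws:s3:::peer-conn/backup-bucket"},  # connection-id/bucket
                "DeleteMarkerReplication": {"Status": "Enabled"},  # maps to sync_deletions
            }
        ],
    },
)
print(s3.get_bucket_replication(Bucket="demo-bucket"))
s3.delete_bucket_replication(Bucket="demo-bucket")
```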
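
UploadPartCopy is now routed through `_upload_part` whenever `x-amz-copy-source` is present, so the usual boto3 call should exercise it, including the optional `bytes=start-end` range parsed in `_upload_part_copy`. Keys and sizes here are placeholders; since the copied range is the only (and therefore last) part, the usual 5 MB minimum for non-final parts does not apply.

```python
mpu = s3.create_multipart_upload(Bucket="demo-bucket", Key="assembled/copy.bin")
part = s3.upload_part_copy(
    Bucket="demo-bucket",
    Key="assembled/copy.bin",
    UploadId=mpu["UploadId"],
    PartNumber=1,
    CopySource={"Bucket": "demo-bucket", "Key": "uploads/report.csv"},
    CopySourceRange="bytes=0-1023",  # parsed by the bytes=(\d+)-(\d+) regex; source must be at least 1024 bytes
)
s3.complete_multipart_upload(
    Bucket="demo-bucket",
    Key="assembled/copy.bin",
    UploadId=mpu["UploadId"],
    MultipartUpload={"Parts": [{"ETag": part["CopyPartResult"]["ETag"], "PartNumber": 1}]},
)
```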
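
Finally, SelectObjectContent: a sketch of the boto3 call the new `select` subresource is meant to serve. Column names are assumed, and whether the stream parses end to end depends on botocore accepting the hand-rolled event-stream framing in `_encode_select_event`; note the Stats event currently always reports BytesScanned as 0.

```python
resp = s3.select_object_content(
    Bucket="demo-bucket",
    Key="uploads/report.csv",
    # "s3object" is rewritten to the DuckDB table name server-side; "amount" is an assumed CSV column.
    Expression="SELECT s.name, s.amount FROM s3object s WHERE CAST(s.amount AS DOUBLE) > 100",
    ExpressionType="SQL",
    InputSerialization={"CSV": {"FileHeaderInfo": "USE", "FieldDelimiter": ","}},
    OutputSerialization={"JSON": {"RecordDelimiter": "\n"}},
)
for event in resp["Payload"]:
    if "Records" in event:
        print(event["Records"]["Payload"].decode("utf-8"), end="")
    elif "Stats" in event:
        print(event["Stats"]["Details"])
```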