diff --git a/app/__init__.py b/app/__init__.py index 4d7d68f..6fcc8dc 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -16,6 +16,7 @@ from flask_cors import CORS from flask_wtf.csrf import CSRFError from werkzeug.middleware.proxy_fix import ProxyFix +from .acl import AclService from .bucket_policies import BucketPolicyStore from .config import AppConfig from .connections import ConnectionStore @@ -23,6 +24,7 @@ from .encryption import EncryptionManager from .extensions import limiter, csrf from .iam import IamService from .kms import KMSManager +from .lifecycle import LifecycleManager from .replication import ReplicationManager from .secret_store import EphemeralSecretStore from .storage import ObjectStorage @@ -140,6 +142,17 @@ def create_app( from .encrypted_storage import EncryptedObjectStorage storage = EncryptedObjectStorage(storage, encryption_manager) + acl_service = AclService(storage_root) + + lifecycle_manager = None + if app.config.get("LIFECYCLE_ENABLED", False): + base_storage = storage.storage if hasattr(storage, 'storage') else storage + lifecycle_manager = LifecycleManager( + base_storage, + interval_seconds=app.config.get("LIFECYCLE_INTERVAL_SECONDS", 3600), + ) + lifecycle_manager.start() + app.extensions["object_storage"] = storage app.extensions["iam"] = iam app.extensions["bucket_policies"] = bucket_policies @@ -149,6 +162,8 @@ def create_app( app.extensions["replication"] = replication app.extensions["encryption"] = encryption_manager app.extensions["kms"] = kms_manager + app.extensions["acl"] = acl_service + app.extensions["lifecycle"] = lifecycle_manager @app.errorhandler(500) def internal_error(error): diff --git a/app/acl.py b/app/acl.py new file mode 100644 index 0000000..7f78a11 --- /dev/null +++ b/app/acl.py @@ -0,0 +1,205 @@ +"""S3-compatible Access Control List (ACL) management.""" +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + + +ACL_PERMISSION_FULL_CONTROL = "FULL_CONTROL" +ACL_PERMISSION_WRITE = "WRITE" +ACL_PERMISSION_WRITE_ACP = "WRITE_ACP" +ACL_PERMISSION_READ = "READ" +ACL_PERMISSION_READ_ACP = "READ_ACP" + +ALL_PERMISSIONS = { + ACL_PERMISSION_FULL_CONTROL, + ACL_PERMISSION_WRITE, + ACL_PERMISSION_WRITE_ACP, + ACL_PERMISSION_READ, + ACL_PERMISSION_READ_ACP, +} + +PERMISSION_TO_ACTIONS = { + ACL_PERMISSION_FULL_CONTROL: {"read", "write", "delete", "list", "share"}, + ACL_PERMISSION_WRITE: {"write", "delete"}, + ACL_PERMISSION_WRITE_ACP: {"share"}, + ACL_PERMISSION_READ: {"read", "list"}, + ACL_PERMISSION_READ_ACP: {"share"}, +} + +GRANTEE_ALL_USERS = "*" +GRANTEE_AUTHENTICATED_USERS = "authenticated" + + +@dataclass +class AclGrant: + grantee: str + permission: str + + def to_dict(self) -> Dict[str, str]: + return {"grantee": self.grantee, "permission": self.permission} + + @classmethod + def from_dict(cls, data: Dict[str, str]) -> "AclGrant": + return cls(grantee=data["grantee"], permission=data["permission"]) + + +@dataclass +class Acl: + owner: str + grants: List[AclGrant] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "owner": self.owner, + "grants": [g.to_dict() for g in self.grants], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Acl": + return cls( + owner=data.get("owner", ""), + grants=[AclGrant.from_dict(g) for g in data.get("grants", [])], + ) + + def get_allowed_actions(self, principal_id: Optional[str], is_authenticated: bool = True) -> Set[str]: + 
actions: Set[str] = set() + if principal_id and principal_id == self.owner: + actions.update(PERMISSION_TO_ACTIONS[ACL_PERMISSION_FULL_CONTROL]) + for grant in self.grants: + if grant.grantee == GRANTEE_ALL_USERS: + actions.update(PERMISSION_TO_ACTIONS.get(grant.permission, set())) + elif grant.grantee == GRANTEE_AUTHENTICATED_USERS and is_authenticated: + actions.update(PERMISSION_TO_ACTIONS.get(grant.permission, set())) + elif principal_id and grant.grantee == principal_id: + actions.update(PERMISSION_TO_ACTIONS.get(grant.permission, set())) + return actions + + +CANNED_ACLS = { + "private": lambda owner: Acl( + owner=owner, + grants=[AclGrant(grantee=owner, permission=ACL_PERMISSION_FULL_CONTROL)], + ), + "public-read": lambda owner: Acl( + owner=owner, + grants=[ + AclGrant(grantee=owner, permission=ACL_PERMISSION_FULL_CONTROL), + AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_READ), + ], + ), + "public-read-write": lambda owner: Acl( + owner=owner, + grants=[ + AclGrant(grantee=owner, permission=ACL_PERMISSION_FULL_CONTROL), + AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_READ), + AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_WRITE), + ], + ), + "authenticated-read": lambda owner: Acl( + owner=owner, + grants=[ + AclGrant(grantee=owner, permission=ACL_PERMISSION_FULL_CONTROL), + AclGrant(grantee=GRANTEE_AUTHENTICATED_USERS, permission=ACL_PERMISSION_READ), + ], + ), + "bucket-owner-read": lambda owner: Acl( + owner=owner, + grants=[ + AclGrant(grantee=owner, permission=ACL_PERMISSION_FULL_CONTROL), + ], + ), + "bucket-owner-full-control": lambda owner: Acl( + owner=owner, + grants=[ + AclGrant(grantee=owner, permission=ACL_PERMISSION_FULL_CONTROL), + ], + ), +} + + +def create_canned_acl(canned_acl: str, owner: str) -> Acl: + factory = CANNED_ACLS.get(canned_acl) + if not factory: + return CANNED_ACLS["private"](owner) + return factory(owner) + + +class AclService: + def __init__(self, storage_root: Path): + self.storage_root = storage_root + self._bucket_acl_cache: Dict[str, Acl] = {} + + def _bucket_acl_path(self, bucket_name: str) -> Path: + return self.storage_root / ".myfsio.sys" / "buckets" / bucket_name / ".acl.json" + + def get_bucket_acl(self, bucket_name: str) -> Optional[Acl]: + if bucket_name in self._bucket_acl_cache: + return self._bucket_acl_cache[bucket_name] + acl_path = self._bucket_acl_path(bucket_name) + if not acl_path.exists(): + return None + try: + data = json.loads(acl_path.read_text(encoding="utf-8")) + acl = Acl.from_dict(data) + self._bucket_acl_cache[bucket_name] = acl + return acl + except (OSError, json.JSONDecodeError): + return None + + def set_bucket_acl(self, bucket_name: str, acl: Acl) -> None: + acl_path = self._bucket_acl_path(bucket_name) + acl_path.parent.mkdir(parents=True, exist_ok=True) + acl_path.write_text(json.dumps(acl.to_dict(), indent=2), encoding="utf-8") + self._bucket_acl_cache[bucket_name] = acl + + def set_bucket_canned_acl(self, bucket_name: str, canned_acl: str, owner: str) -> Acl: + acl = create_canned_acl(canned_acl, owner) + self.set_bucket_acl(bucket_name, acl) + return acl + + def delete_bucket_acl(self, bucket_name: str) -> None: + acl_path = self._bucket_acl_path(bucket_name) + if acl_path.exists(): + acl_path.unlink() + self._bucket_acl_cache.pop(bucket_name, None) + + def evaluate_bucket_acl( + self, + bucket_name: str, + principal_id: Optional[str], + action: str, + is_authenticated: bool = True, + ) -> bool: + acl = self.get_bucket_acl(bucket_name) + if not acl: + return 
False + allowed_actions = acl.get_allowed_actions(principal_id, is_authenticated) + return action in allowed_actions + + def get_object_acl(self, bucket_name: str, object_key: str, object_metadata: Dict[str, Any]) -> Optional[Acl]: + acl_data = object_metadata.get("__acl__") + if not acl_data: + return None + try: + return Acl.from_dict(acl_data) + except (TypeError, KeyError): + return None + + def create_object_acl_metadata(self, acl: Acl) -> Dict[str, Any]: + return {"__acl__": acl.to_dict()} + + def evaluate_object_acl( + self, + object_metadata: Dict[str, Any], + principal_id: Optional[str], + action: str, + is_authenticated: bool = True, + ) -> bool: + acl = self.get_object_acl("", "", object_metadata) + if not acl: + return False + allowed_actions = acl.get_allowed_actions(principal_id, is_authenticated) + return action in allowed_actions diff --git a/app/config.py b/app/config.py index 206cce6..e6bd127 100644 --- a/app/config.py +++ b/app/config.py @@ -74,6 +74,8 @@ class AppConfig: kms_keys_path: Path default_encryption_algorithm: str display_timezone: str + lifecycle_enabled: bool + lifecycle_interval_seconds: int @classmethod def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig": @@ -91,6 +93,8 @@ class AppConfig: secret_ttl_seconds = int(_get("SECRET_TTL_SECONDS", 300)) stream_chunk_size = int(_get("STREAM_CHUNK_SIZE", 64 * 1024)) multipart_min_part_size = int(_get("MULTIPART_MIN_PART_SIZE", 5 * 1024 * 1024)) + lifecycle_enabled = _get("LIFECYCLE_ENABLED", "false").lower() in ("true", "1", "yes") + lifecycle_interval_seconds = int(_get("LIFECYCLE_INTERVAL_SECONDS", 3600)) default_secret = "dev-secret-key" secret_key = str(_get("SECRET_KEY", default_secret)) @@ -198,7 +202,9 @@ class AppConfig: kms_enabled=kms_enabled, kms_keys_path=kms_keys_path, default_encryption_algorithm=default_encryption_algorithm, - display_timezone=display_timezone) + display_timezone=display_timezone, + lifecycle_enabled=lifecycle_enabled, + lifecycle_interval_seconds=lifecycle_interval_seconds) def validate_and_report(self) -> list[str]: """Validate configuration and return a list of warnings/issues. 
diff --git a/app/iam.py b/app/iam.py index d93ccc5..5e5c26a 100644 --- a/app/iam.py +++ b/app/iam.py @@ -121,6 +121,7 @@ class IamService: self._cache_ttl = 60.0 # Cache credentials for 60 seconds self._last_stat_check = 0.0 self._stat_check_interval = 1.0 # Only stat() file every 1 second + self._sessions: Dict[str, Dict[str, Any]] = {} self._load() def _maybe_reload(self) -> None: @@ -192,6 +193,40 @@ class IamService: elapsed = (datetime.now(timezone.utc) - oldest).total_seconds() return int(max(0, self.auth_lockout_window.total_seconds() - elapsed)) + def create_session_token(self, access_key: str, duration_seconds: int = 3600) -> str: + """Create a temporary session token for an access key.""" + self._maybe_reload() + record = self._users.get(access_key) + if not record: + raise IamError("Unknown access key") + self._cleanup_expired_sessions() + token = secrets.token_urlsafe(32) + expires_at = time.time() + duration_seconds + self._sessions[token] = { + "access_key": access_key, + "expires_at": expires_at, + } + return token + + def validate_session_token(self, access_key: str, session_token: str) -> bool: + """Validate a session token for an access key.""" + session = self._sessions.get(session_token) + if not session: + return False + if session["access_key"] != access_key: + return False + if time.time() > session["expires_at"]: + del self._sessions[session_token] + return False + return True + + def _cleanup_expired_sessions(self) -> None: + """Remove expired session tokens.""" + now = time.time() + expired = [token for token, data in self._sessions.items() if now > data["expires_at"]] + for token in expired: + del self._sessions[token] + def principal_for_key(self, access_key: str) -> Principal: # Performance: Check cache first now = time.time() diff --git a/app/kms.py b/app/kms.py index 4ed72da..65e98ab 100644 --- a/app/kms.py +++ b/app/kms.py @@ -211,7 +211,27 @@ class KMSManager: """List all keys.""" self._load_keys() return list(self._keys.values()) - + + def get_default_key_id(self) -> str: + """Get the default KMS key ID, creating one if none exist.""" + self._load_keys() + for key in self._keys.values(): + if key.enabled: + return key.key_id + default_key = self.create_key(description="Default KMS Key") + return default_key.key_id + + def get_provider(self, key_id: str | None = None) -> "KMSEncryptionProvider": + """Get a KMS encryption provider for the specified key.""" + if key_id is None: + key_id = self.get_default_key_id() + key = self.get_key(key_id) + if not key: + raise EncryptionError(f"Key not found: {key_id}") + if not key.enabled: + raise EncryptionError(f"Key is disabled: {key_id}") + return KMSEncryptionProvider(self, key_id) + def enable_key(self, key_id: str) -> None: """Enable a key.""" self._load_keys() diff --git a/app/lifecycle.py b/app/lifecycle.py new file mode 100644 index 0000000..bacf32e --- /dev/null +++ b/app/lifecycle.py @@ -0,0 +1,236 @@ +"""Lifecycle rule enforcement for S3-compatible storage.""" +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .storage import ObjectStorage, StorageError + +logger = logging.getLogger(__name__) + + +@dataclass +class LifecycleResult: + bucket_name: str + objects_deleted: int = 0 + versions_deleted: int = 0 + uploads_aborted: int = 0 + errors: List[str] = field(default_factory=list) + 
execution_time_seconds: float = 0.0 + + +class LifecycleManager: + def __init__(self, storage: ObjectStorage, interval_seconds: int = 3600): + self.storage = storage + self.interval_seconds = interval_seconds + self._timer: Optional[threading.Timer] = None + self._shutdown = False + self._lock = threading.Lock() + + def start(self) -> None: + if self._timer is not None: + return + self._shutdown = False + self._schedule_next() + logger.info(f"Lifecycle manager started with interval {self.interval_seconds}s") + + def stop(self) -> None: + self._shutdown = True + if self._timer: + self._timer.cancel() + self._timer = None + logger.info("Lifecycle manager stopped") + + def _schedule_next(self) -> None: + if self._shutdown: + return + self._timer = threading.Timer(self.interval_seconds, self._run_enforcement) + self._timer.daemon = True + self._timer.start() + + def _run_enforcement(self) -> None: + if self._shutdown: + return + try: + self.enforce_all_buckets() + except Exception as e: + logger.error(f"Lifecycle enforcement failed: {e}") + finally: + self._schedule_next() + + def enforce_all_buckets(self) -> Dict[str, LifecycleResult]: + results = {} + try: + buckets = self.storage.list_buckets() + for bucket in buckets: + result = self.enforce_rules(bucket.name) + if result.objects_deleted > 0 or result.versions_deleted > 0 or result.uploads_aborted > 0: + results[bucket.name] = result + except StorageError as e: + logger.error(f"Failed to list buckets for lifecycle: {e}") + return results + + def enforce_rules(self, bucket_name: str) -> LifecycleResult: + start_time = time.time() + result = LifecycleResult(bucket_name=bucket_name) + + try: + lifecycle = self.storage.get_bucket_lifecycle(bucket_name) + if not lifecycle: + return result + + for rule in lifecycle: + if rule.get("Status") != "Enabled": + continue + rule_id = rule.get("ID", "unknown") + prefix = rule.get("Prefix", rule.get("Filter", {}).get("Prefix", "")) + + self._enforce_expiration(bucket_name, rule, prefix, result) + self._enforce_noncurrent_expiration(bucket_name, rule, prefix, result) + self._enforce_abort_multipart(bucket_name, rule, result) + + except StorageError as e: + result.errors.append(str(e)) + logger.error(f"Lifecycle enforcement error for {bucket_name}: {e}") + + result.execution_time_seconds = time.time() - start_time + if result.objects_deleted > 0 or result.versions_deleted > 0 or result.uploads_aborted > 0: + logger.info( + f"Lifecycle enforcement for {bucket_name}: " + f"deleted={result.objects_deleted}, versions={result.versions_deleted}, " + f"aborted={result.uploads_aborted}, time={result.execution_time_seconds:.2f}s" + ) + return result + + def _enforce_expiration( + self, bucket_name: str, rule: Dict[str, Any], prefix: str, result: LifecycleResult + ) -> None: + expiration = rule.get("Expiration", {}) + if not expiration: + return + + days = expiration.get("Days") + date_str = expiration.get("Date") + + if days: + cutoff = datetime.now(timezone.utc) - timedelta(days=days) + elif date_str: + try: + cutoff = datetime.fromisoformat(date_str.replace("Z", "+00:00")) + except ValueError: + return + else: + return + + try: + objects = self.storage.list_objects_all(bucket_name) + for obj in objects: + if prefix and not obj.key.startswith(prefix): + continue + if obj.last_modified < cutoff: + try: + self.storage.delete_object(bucket_name, obj.key) + result.objects_deleted += 1 + except StorageError as e: + result.errors.append(f"Failed to delete {obj.key}: {e}") + except StorageError as e: + 
result.errors.append(f"Failed to list objects: {e}") + + def _enforce_noncurrent_expiration( + self, bucket_name: str, rule: Dict[str, Any], prefix: str, result: LifecycleResult + ) -> None: + noncurrent = rule.get("NoncurrentVersionExpiration", {}) + noncurrent_days = noncurrent.get("NoncurrentDays") + if not noncurrent_days: + return + + cutoff = datetime.now(timezone.utc) - timedelta(days=noncurrent_days) + + try: + objects = self.storage.list_objects_all(bucket_name) + for obj in objects: + if prefix and not obj.key.startswith(prefix): + continue + try: + versions = self.storage.list_object_versions(bucket_name, obj.key) + for version in versions: + archived_at_str = version.get("archived_at", "") + if not archived_at_str: + continue + try: + archived_at = datetime.fromisoformat(archived_at_str.replace("Z", "+00:00")) + if archived_at < cutoff: + version_id = version.get("version_id") + if version_id: + self.storage.delete_object_version(bucket_name, obj.key, version_id) + result.versions_deleted += 1 + except (ValueError, StorageError) as e: + result.errors.append(f"Failed to process version: {e}") + except StorageError: + pass + except StorageError as e: + result.errors.append(f"Failed to list objects: {e}") + + try: + orphaned = self.storage.list_orphaned_objects(bucket_name) + for item in orphaned: + obj_key = item.get("key", "") + if prefix and not obj_key.startswith(prefix): + continue + try: + versions = self.storage.list_object_versions(bucket_name, obj_key) + for version in versions: + archived_at_str = version.get("archived_at", "") + if not archived_at_str: + continue + try: + archived_at = datetime.fromisoformat(archived_at_str.replace("Z", "+00:00")) + if archived_at < cutoff: + version_id = version.get("version_id") + if version_id: + self.storage.delete_object_version(bucket_name, obj_key, version_id) + result.versions_deleted += 1 + except (ValueError, StorageError) as e: + result.errors.append(f"Failed to process orphaned version: {e}") + except StorageError: + pass + except StorageError as e: + result.errors.append(f"Failed to list orphaned objects: {e}") + + def _enforce_abort_multipart( + self, bucket_name: str, rule: Dict[str, Any], result: LifecycleResult + ) -> None: + abort_config = rule.get("AbortIncompleteMultipartUpload", {}) + days_after = abort_config.get("DaysAfterInitiation") + if not days_after: + return + + cutoff = datetime.now(timezone.utc) - timedelta(days=days_after) + + try: + uploads = self.storage.list_multipart_uploads(bucket_name) + for upload in uploads: + created_at_str = upload.get("created_at", "") + if not created_at_str: + continue + try: + created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + if created_at < cutoff: + upload_id = upload.get("upload_id") + if upload_id: + self.storage.abort_multipart_upload(bucket_name, upload_id) + result.uploads_aborted += 1 + except (ValueError, StorageError) as e: + result.errors.append(f"Failed to abort upload: {e}") + except StorageError as e: + result.errors.append(f"Failed to list multipart uploads: {e}") + + def run_now(self, bucket_name: Optional[str] = None) -> Dict[str, LifecycleResult]: + if bucket_name: + return {bucket_name: self.enforce_rules(bucket_name)} + return self.enforce_all_buckets() diff --git a/app/replication.py b/app/replication.py index 4703bf9..7301dc2 100644 --- a/app/replication.py +++ b/app/replication.py @@ -182,9 +182,15 @@ class ReplicationManager: return self._rules.get(bucket_name) def set_rule(self, rule: ReplicationRule) -> None: + old_rule = 
self._rules.get(rule.bucket_name) + was_all_mode = old_rule and old_rule.mode == REPLICATION_MODE_ALL if old_rule else False self._rules[rule.bucket_name] = rule self.save_rules() + if rule.mode == REPLICATION_MODE_ALL and rule.enabled and not was_all_mode: + logger.info(f"Replication mode ALL enabled for {rule.bucket_name}, triggering sync of existing objects") + self._executor.submit(self.replicate_existing_objects, rule.bucket_name) + def delete_rule(self, bucket_name: str) -> None: if bucket_name in self._rules: del self._rules[bucket_name] diff --git a/app/s3_api.py b/app/s3_api.py index ba2d268..5e7fb31 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -16,6 +16,7 @@ from xml.etree.ElementTree import Element, SubElement, tostring, fromstring, Par from flask import Blueprint, Response, current_app, jsonify, request, g from werkzeug.http import http_date +from .acl import AclService from .bucket_policies import BucketPolicyStore from .extensions import limiter from .iam import IamError, Principal @@ -30,6 +31,10 @@ def _storage() -> ObjectStorage: return current_app.extensions["object_storage"] +def _acl() -> AclService: + return current_app.extensions["acl"] + + def _iam(): return current_app.extensions["iam"] @@ -58,6 +63,37 @@ def _error_response(code: str, message: str, status: int) -> Response: return _xml_response(error, status) +def _parse_range_header(range_header: str, file_size: int) -> list[tuple[int, int]] | None: + if not range_header.startswith("bytes="): + return None + ranges = [] + range_spec = range_header[6:] + for part in range_spec.split(","): + part = part.strip() + if not part: + continue + if part.startswith("-"): + suffix_length = int(part[1:]) + if suffix_length <= 0: + return None + start = max(0, file_size - suffix_length) + end = file_size - 1 + elif part.endswith("-"): + start = int(part[:-1]) + if start >= file_size: + return None + end = file_size - 1 + else: + start_str, end_str = part.split("-", 1) + start = int(start_str) + end = int(end_str) + if start > end or start >= file_size: + return None + end = min(end, file_size - 1) + ranges.append((start, end)) + return ranges if ranges else None + + def _sign(key: bytes, msg: str) -> bytes: return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() @@ -179,6 +215,11 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None: ) raise IamError("SignatureDoesNotMatch") + session_token = req.headers.get("X-Amz-Security-Token") + if session_token: + if not _iam().validate_session_token(access_key, session_token): + raise IamError("InvalidToken") + return _iam().get_principal(access_key) @@ -257,10 +298,15 @@ def _verify_sigv4_query(req: Any) -> Principal | None: signing_key = _get_signature_key(secret_key, date_stamp, region, service) calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() - + if not hmac.compare_digest(calculated_signature, signature): raise IamError("SignatureDoesNotMatch") - + + session_token = req.args.get("X-Amz-Security-Token") + if session_token: + if not _iam().validate_session_token(access_key, session_token): + raise IamError("InvalidToken") + return _iam().get_principal(access_key) @@ -321,6 +367,19 @@ def _authorize_action(principal: Principal | None, bucket_name: str | None, acti return if policy_decision == "allow": return + + acl_allowed = False + if bucket_name: + acl_service = _acl() + acl_allowed = acl_service.evaluate_bucket_acl( + bucket_name, + access_key, + action, + is_authenticated=principal is 
not None, + ) + if acl_allowed: + return + raise iam_error or IamError("Access denied") @@ -1131,6 +1190,8 @@ def _bucket_location_handler(bucket_name: str) -> Response: def _bucket_acl_handler(bucket_name: str) -> Response: + from .acl import create_canned_acl, Acl, AclGrant, GRANTEE_ALL_USERS, GRANTEE_AUTHENTICATED_USERS + if request.method not in {"GET", "PUT"}: return _method_not_allowed(["GET", "PUT"]) principal, error = _require_principal() @@ -1143,26 +1204,41 @@ def _bucket_acl_handler(bucket_name: str) -> Response: storage = _storage() if not storage.bucket_exists(bucket_name): return _error_response("NoSuchBucket", "Bucket does not exist", 404) - + + acl_service = _acl() + owner_id = principal.access_key if principal else "anonymous" + if request.method == "PUT": - # Accept canned ACL headers for S3 compatibility (not fully implemented) canned_acl = request.headers.get("x-amz-acl", "private") - current_app.logger.info("Bucket ACL set (canned)", extra={"bucket": bucket_name, "acl": canned_acl}) + acl = acl_service.set_bucket_canned_acl(bucket_name, canned_acl, owner_id) + current_app.logger.info("Bucket ACL set", extra={"bucket": bucket_name, "acl": canned_acl}) return Response(status=200) - + + acl = acl_service.get_bucket_acl(bucket_name) + if not acl: + acl = create_canned_acl("private", owner_id) + root = Element("AccessControlPolicy") - owner = SubElement(root, "Owner") - SubElement(owner, "ID").text = principal.access_key if principal else "anonymous" - SubElement(owner, "DisplayName").text = principal.display_name if principal else "Anonymous" - - acl = SubElement(root, "AccessControlList") - grant = SubElement(acl, "Grant") - grantee = SubElement(grant, "Grantee") - grantee.set("{http://www.w3.org/2001/XMLSchema-instance}type", "CanonicalUser") - SubElement(grantee, "ID").text = principal.access_key if principal else "anonymous" - SubElement(grantee, "DisplayName").text = principal.display_name if principal else "Anonymous" - SubElement(grant, "Permission").text = "FULL_CONTROL" - + owner_el = SubElement(root, "Owner") + SubElement(owner_el, "ID").text = acl.owner + SubElement(owner_el, "DisplayName").text = acl.owner + + acl_el = SubElement(root, "AccessControlList") + for grant in acl.grants: + grant_el = SubElement(acl_el, "Grant") + grantee = SubElement(grant_el, "Grantee") + if grant.grantee == GRANTEE_ALL_USERS: + grantee.set("{http://www.w3.org/2001/XMLSchema-instance}type", "Group") + SubElement(grantee, "URI").text = "http://acs.amazonaws.com/groups/global/AllUsers" + elif grant.grantee == GRANTEE_AUTHENTICATED_USERS: + grantee.set("{http://www.w3.org/2001/XMLSchema-instance}type", "Group") + SubElement(grantee, "URI").text = "http://acs.amazonaws.com/groups/global/AuthenticatedUsers" + else: + grantee.set("{http://www.w3.org/2001/XMLSchema-instance}type", "CanonicalUser") + SubElement(grantee, "ID").text = grant.grantee + SubElement(grantee, "DisplayName").text = grant.grantee + SubElement(grant_el, "Permission").text = grant.permission + return _xml_response(root) @@ -1537,29 +1613,28 @@ def _bulk_delete_handler(bucket_name: str) -> Response: return _error_response("MalformedXML", "A maximum of 1000 objects can be deleted per request", 400) storage = _storage() - deleted: list[str] = [] + deleted: list[dict[str, str | None]] = [] errors: list[dict[str, str]] = [] for entry in objects: key = entry["Key"] or "" version_id = entry.get("VersionId") - if version_id: - errors.append({ - "Key": key, - "Code": "InvalidRequest", - "Message": "VersionId is not supported 
for bulk deletes", - }) - continue try: - storage.delete_object(bucket_name, key) - deleted.append(key) + if version_id: + storage.delete_object_version(bucket_name, key, version_id) + deleted.append({"Key": key, "VersionId": version_id}) + else: + storage.delete_object(bucket_name, key) + deleted.append({"Key": key, "VersionId": None}) except StorageError as exc: errors.append({"Key": key, "Code": "InvalidRequest", "Message": str(exc)}) result = Element("DeleteResult") if not quiet: - for key in deleted: + for item in deleted: deleted_el = SubElement(result, "Deleted") - SubElement(deleted_el, "Key").text = key + SubElement(deleted_el, "Key").text = item["Key"] + if item.get("VersionId"): + SubElement(deleted_el, "VersionId").text = item["VersionId"] for err in errors: error_el = SubElement(result, "Error") SubElement(error_el, "Key").text = err.get("Key", "") @@ -1870,20 +1945,67 @@ def object_handler(bucket_name: str, object_key: str): is_encrypted = "x-amz-server-side-encryption" in metadata if request.method == "GET": + range_header = request.headers.get("Range") + if is_encrypted and hasattr(storage, 'get_object_data'): try: data, clean_metadata = storage.get_object_data(bucket_name, object_key) - response = Response(data, mimetype=mimetype) - logged_bytes = len(data) - response.headers["Content-Length"] = len(data) + file_size = len(data) etag = hashlib.md5(data).hexdigest() + + if range_header: + try: + ranges = _parse_range_header(range_header, file_size) + except (ValueError, TypeError): + ranges = None + if ranges is None: + return _error_response("InvalidRange", "Range Not Satisfiable", 416) + start, end = ranges[0] + partial_data = data[start:end + 1] + response = Response(partial_data, status=206, mimetype=mimetype) + response.headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" + response.headers["Content-Length"] = len(partial_data) + logged_bytes = len(partial_data) + else: + response = Response(data, mimetype=mimetype) + response.headers["Content-Length"] = file_size + logged_bytes = file_size except StorageError as exc: return _error_response("InternalError", str(exc), 500) else: stat = path.stat() - response = Response(_stream_file(path), mimetype=mimetype, direct_passthrough=True) - logged_bytes = stat.st_size + file_size = stat.st_size etag = storage._compute_etag(path) + + if range_header: + try: + ranges = _parse_range_header(range_header, file_size) + except (ValueError, TypeError): + ranges = None + if ranges is None: + return _error_response("InvalidRange", "Range Not Satisfiable", 416) + start, end = ranges[0] + length = end - start + 1 + + def stream_range(file_path, start_pos, length_to_read): + with open(file_path, "rb") as f: + f.seek(start_pos) + remaining = length_to_read + while remaining > 0: + chunk_size = min(65536, remaining) + chunk = f.read(chunk_size) + if not chunk: + break + remaining -= len(chunk) + yield chunk + + response = Response(stream_range(path, start, length), status=206, mimetype=mimetype, direct_passthrough=True) + response.headers["Content-Range"] = f"bytes {start}-{end}/{file_size}" + response.headers["Content-Length"] = length + logged_bytes = length + else: + response = Response(_stream_file(path), mimetype=mimetype, direct_passthrough=True) + logged_bytes = file_size else: if is_encrypted and hasattr(storage, 'get_object_data'): try: @@ -1901,6 +2023,21 @@ def object_handler(bucket_name: str, object_key: str): logged_bytes = 0 _apply_object_headers(response, file_stat=path.stat() if not is_encrypted else None, 
metadata=metadata, etag=etag) + + if request.method == "GET": + response_overrides = { + "response-content-type": "Content-Type", + "response-content-language": "Content-Language", + "response-expires": "Expires", + "response-cache-control": "Cache-Control", + "response-content-disposition": "Content-Disposition", + "response-content-encoding": "Content-Encoding", + } + for param, header in response_overrides.items(): + value = request.args.get(param) + if value: + response.headers[header] = value + action = "Object read" if request.method == "GET" else "Object head" current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes}) return response @@ -2119,9 +2256,45 @@ def _copy_object(dest_bucket: str, dest_key: str, copy_source: str) -> Response: source_path = storage.get_object_path(source_bucket, source_key) except StorageError: return _error_response("NoSuchKey", "Source object not found", 404) - + + source_stat = source_path.stat() + source_etag = storage._compute_etag(source_path) + source_mtime = datetime.fromtimestamp(source_stat.st_mtime, timezone.utc) + + copy_source_if_match = request.headers.get("x-amz-copy-source-if-match") + if copy_source_if_match: + expected_etag = copy_source_if_match.strip('"') + if source_etag != expected_etag: + return _error_response("PreconditionFailed", "Source ETag does not match", 412) + + copy_source_if_none_match = request.headers.get("x-amz-copy-source-if-none-match") + if copy_source_if_none_match: + not_expected_etag = copy_source_if_none_match.strip('"') + if source_etag == not_expected_etag: + return _error_response("PreconditionFailed", "Source ETag matches", 412) + + copy_source_if_modified_since = request.headers.get("x-amz-copy-source-if-modified-since") + if copy_source_if_modified_since: + from email.utils import parsedate_to_datetime + try: + if_modified = parsedate_to_datetime(copy_source_if_modified_since) + if source_mtime <= if_modified: + return _error_response("PreconditionFailed", "Source not modified since specified date", 412) + except (TypeError, ValueError): + pass + + copy_source_if_unmodified_since = request.headers.get("x-amz-copy-source-if-unmodified-since") + if copy_source_if_unmodified_since: + from email.utils import parsedate_to_datetime + try: + if_unmodified = parsedate_to_datetime(copy_source_if_unmodified_since) + if source_mtime > if_unmodified: + return _error_response("PreconditionFailed", "Source modified since specified date", 412) + except (TypeError, ValueError): + pass + source_metadata = storage.get_object_metadata(source_bucket, source_key) - + metadata_directive = request.headers.get("x-amz-metadata-directive", "COPY").upper() if metadata_directive == "REPLACE": metadata = _extract_request_metadata() diff --git a/app/storage.py b/app/storage.py index 7d17b7d..27a6e75 100644 --- a/app/storage.py +++ b/app/storage.py @@ -809,6 +809,29 @@ class ObjectStorage: metadata=metadata or None, ) + def delete_object_version(self, bucket_name: str, object_key: str, version_id: str) -> None: + bucket_path = self._bucket_path(bucket_name) + if not bucket_path.exists(): + raise StorageError("Bucket does not exist") + bucket_id = bucket_path.name + safe_key = self._sanitize_object_key(object_key) + version_dir = self._version_dir(bucket_id, safe_key) + data_path = version_dir / f"{version_id}.bin" + meta_path = version_dir / f"{version_id}.json" + if not data_path.exists() and not meta_path.exists(): + legacy_version_dir = self._legacy_version_dir(bucket_id, safe_key) + 
data_path = legacy_version_dir / f"{version_id}.bin" + meta_path = legacy_version_dir / f"{version_id}.json" + if not data_path.exists() and not meta_path.exists(): + raise StorageError(f"Version {version_id} not found") + if data_path.exists(): + data_path.unlink() + if meta_path.exists(): + meta_path.unlink() + parent = data_path.parent + if parent.exists() and not any(parent.iterdir()): + parent.rmdir() + def list_orphaned_objects(self, bucket_name: str) -> List[Dict[str, Any]]: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): @@ -1124,6 +1147,49 @@ class ObjectStorage: parts.sort(key=lambda x: x["PartNumber"]) return parts + def list_multipart_uploads(self, bucket_name: str) -> List[Dict[str, Any]]: + """List all active multipart uploads for a bucket.""" + bucket_path = self._bucket_path(bucket_name) + if not bucket_path.exists(): + raise StorageError("Bucket does not exist") + bucket_id = bucket_path.name + uploads = [] + multipart_root = self._bucket_multipart_root(bucket_id) + if multipart_root.exists(): + for upload_dir in multipart_root.iterdir(): + if not upload_dir.is_dir(): + continue + manifest_path = upload_dir / "manifest.json" + if not manifest_path.exists(): + continue + try: + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + uploads.append({ + "upload_id": manifest.get("upload_id", upload_dir.name), + "object_key": manifest.get("object_key", ""), + "created_at": manifest.get("created_at", ""), + }) + except (OSError, json.JSONDecodeError): + continue + legacy_root = self._legacy_multipart_root(bucket_id) + if legacy_root.exists(): + for upload_dir in legacy_root.iterdir(): + if not upload_dir.is_dir(): + continue + manifest_path = upload_dir / "manifest.json" + if not manifest_path.exists(): + continue + try: + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + uploads.append({ + "upload_id": manifest.get("upload_id", upload_dir.name), + "object_key": manifest.get("object_key", ""), + "created_at": manifest.get("created_at", ""), + }) + except (OSError, json.JSONDecodeError): + continue + return uploads + def _bucket_path(self, bucket_name: str) -> Path: safe_name = self._sanitize_bucket_name(bucket_name) return self.root / safe_name
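
A minimal usage sketch of the new ACL primitives introduced in `app/acl.py` above, showing how a canned ACL translates into allowed actions. The temporary storage root, bucket name, and principal IDs are illustrative only; real callers receive the configured `storage_root` via `create_app`.

```python
from pathlib import Path
from tempfile import TemporaryDirectory

from app.acl import AclService

# Illustrative throwaway storage root; the app wires AclService(storage_root) in create_app.
with TemporaryDirectory() as tmp:
    acl_service = AclService(Path(tmp))

    # Owner "alice" applies the canned "public-read" ACL to bucket "photos":
    # the owner keeps FULL_CONTROL, the AllUsers grantee ("*") gets READ.
    acl_service.set_bucket_canned_acl("photos", "public-read", owner="alice")

    # Anonymous callers map READ -> {"read", "list"} but get no write access.
    assert acl_service.evaluate_bucket_acl("photos", None, "read", is_authenticated=False)
    assert not acl_service.evaluate_bucket_acl("photos", None, "write", is_authenticated=False)

    # The owner is granted FULL_CONTROL implicitly by Acl.get_allowed_actions.
    assert acl_service.evaluate_bucket_acl("photos", "alice", "delete")
```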