diff --git a/app/__init__.py b/app/__init__.py index 68ff222..19a890c 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -289,17 +289,17 @@ def _configure_logging(app: Flask) -> None: formatter = logging.Formatter( "%(asctime)s | %(levelname)s | %(request_id)s | %(method)s %(path)s | %(message)s" ) - - # Stream Handler (stdout) - Primary for Docker + stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setFormatter(formatter) stream_handler.addFilter(_RequestContextFilter()) logger = app.logger + for handler in logger.handlers[:]: + handler.close() logger.handlers.clear() logger.addHandler(stream_handler) - # File Handler (optional, if configured) if app.config.get("LOG_TO_FILE"): log_file = Path(app.config["LOG_FILE"]) log_file.parent.mkdir(parents=True, exist_ok=True) diff --git a/app/access_logging.py b/app/access_logging.py index 03132a8..f07ac99 100644 --- a/app/access_logging.py +++ b/app/access_logging.py @@ -196,18 +196,21 @@ class AccessLoggingService: ) target_key = f"{config.target_bucket}:{config.target_prefix}" + should_flush = False with self._buffer_lock: if target_key not in self._buffer: self._buffer[target_key] = [] self._buffer[target_key].append(entry) + should_flush = len(self._buffer[target_key]) >= self.max_buffer_size - if len(self._buffer[target_key]) >= self.max_buffer_size: - self._flush_buffer(target_key) + if should_flush: + self._flush_buffer(target_key) def _flush_loop(self) -> None: while not self._shutdown.is_set(): - time.sleep(self.flush_interval) - self._flush_all() + self._shutdown.wait(timeout=self.flush_interval) + if not self._shutdown.is_set(): + self._flush_all() def _flush_all(self) -> None: with self._buffer_lock: diff --git a/app/config.py b/app/config.py index 241bcdf..02f72db 100644 --- a/app/config.py +++ b/app/config.py @@ -84,7 +84,7 @@ class AppConfig: return overrides.get(name, os.getenv(name, default)) storage_root = Path(_get("STORAGE_ROOT", PROJECT_ROOT / "data")).resolve() - max_upload_size = int(_get("MAX_UPLOAD_SIZE", 1024 * 1024 * 1024)) # 1 GiB default + max_upload_size = int(_get("MAX_UPLOAD_SIZE", 1024 * 1024 * 1024)) ui_page_size = int(_get("UI_PAGE_SIZE", 100)) auth_max_attempts = int(_get("AUTH_MAX_ATTEMPTS", 5)) auth_lockout_minutes = int(_get("AUTH_LOCKOUT_MINUTES", 15)) @@ -108,6 +108,10 @@ class AppConfig: try: secret_file.parent.mkdir(parents=True, exist_ok=True) secret_file.write_text(generated) + try: + os.chmod(secret_file, 0o600) + except OSError: + pass secret_key = generated except OSError: secret_key = generated diff --git a/app/replication.py b/app/replication.py index 1a91af9..6620fff 100644 --- a/app/replication.py +++ b/app/replication.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) REPLICATION_USER_AGENT = "S3ReplicationAgent/1.0" REPLICATION_CONNECT_TIMEOUT = 5 REPLICATION_READ_TIMEOUT = 30 -STREAMING_THRESHOLD_BYTES = 10 * 1024 * 1024 # 10 MiB - use streaming for larger files +STREAMING_THRESHOLD_BYTES = 10 * 1024 * 1024 REPLICATION_MODE_NEW_ONLY = "new_only" REPLICATION_MODE_ALL = "all" @@ -307,7 +307,6 @@ class ReplicationManager: if self._shutdown: return - # Re-check if rule is still enabled (may have been paused after task was submitted) current_rule = self.get_rule(bucket_name) if not current_rule or not current_rule.enabled: logger.debug(f"Replication skipped for {bucket_name}/{object_key}: rule disabled or removed") @@ -358,7 +357,6 @@ class ReplicationManager: extra_args["ContentType"] = content_type if file_size >= STREAMING_THRESHOLD_BYTES: - # Use multipart upload for 
large files s3.upload_file( str(path), rule.target_bucket, @@ -366,7 +364,6 @@ class ReplicationManager: ExtraArgs=extra_args if extra_args else None, ) else: - # Read small files into memory file_content = path.read_bytes() put_kwargs = { "Bucket": rule.target_bucket, diff --git a/app/s3_api.py b/app/s3_api.py index e87cf15..f12d5f6 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -25,7 +25,7 @@ from .iam import IamError, Principal from .notifications import NotificationService, NotificationConfiguration, WebhookDestination from .object_lock import ObjectLockService, ObjectLockRetention, ObjectLockConfig, ObjectLockError, RetentionMode from .replication import ReplicationManager -from .storage import ObjectStorage, StorageError, QuotaExceededError +from .storage import ObjectStorage, StorageError, QuotaExceededError, BucketNotFoundError, ObjectNotFoundError logger = logging.getLogger(__name__) @@ -217,7 +217,6 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None: calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() if not hmac.compare_digest(calculated_signature, signature): - # Only log detailed signature debug info if DEBUG_SIGV4 is enabled if current_app.config.get("DEBUG_SIGV4"): logger.warning( "SigV4 signature mismatch", @@ -260,7 +259,13 @@ def _verify_sigv4_query(req: Any) -> Principal | None: raise IamError("Invalid Date format") now = datetime.now(timezone.utc) - if now > req_time + timedelta(seconds=int(expires)): + try: + expires_seconds = int(expires) + if expires_seconds <= 0: + raise IamError("Invalid Expires value: must be positive") + except ValueError: + raise IamError("Invalid Expires value: must be an integer") + if now > req_time + timedelta(seconds=expires_seconds): raise IamError("Request expired") secret_key = _iam().get_secret_key(access_key) @@ -1036,21 +1041,23 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response: if request.method == "GET": try: tags = storage.get_object_tags(bucket_name, object_key) + except BucketNotFoundError as exc: + return _error_response("NoSuchBucket", str(exc), 404) + except ObjectNotFoundError as exc: + return _error_response("NoSuchKey", str(exc), 404) except StorageError as exc: - message = str(exc) - if "Bucket" in message: - return _error_response("NoSuchBucket", message, 404) - return _error_response("NoSuchKey", message, 404) + return _error_response("InternalError", str(exc), 500) return _xml_response(_render_tagging_document(tags)) if request.method == "DELETE": try: storage.delete_object_tags(bucket_name, object_key) + except BucketNotFoundError as exc: + return _error_response("NoSuchBucket", str(exc), 404) + except ObjectNotFoundError as exc: + return _error_response("NoSuchKey", str(exc), 404) except StorageError as exc: - message = str(exc) - if "Bucket" in message: - return _error_response("NoSuchBucket", message, 404) - return _error_response("NoSuchKey", message, 404) + return _error_response("InternalError", str(exc), 500) current_app.logger.info("Object tags deleted", extra={"bucket": bucket_name, "key": object_key}) return Response(status=204) @@ -1063,11 +1070,12 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response: return _error_response("InvalidTag", "A maximum of 10 tags is supported for objects", 400) try: storage.set_object_tags(bucket_name, object_key, tags) + except BucketNotFoundError as exc: + return _error_response("NoSuchBucket", str(exc), 404) + except ObjectNotFoundError as exc: 
+ return _error_response("NoSuchKey", str(exc), 404) except StorageError as exc: - message = str(exc) - if "Bucket" in message: - return _error_response("NoSuchBucket", message, 404) - return _error_response("NoSuchKey", message, 404) + return _error_response("InternalError", str(exc), 500) current_app.logger.info("Object tags updated", extra={"bucket": bucket_name, "key": object_key, "tags": len(tags)}) return Response(status=204) @@ -1283,7 +1291,10 @@ def _bucket_list_versions_handler(bucket_name: str) -> Response: prefix = request.args.get("prefix", "") delimiter = request.args.get("delimiter", "") - max_keys = min(int(request.args.get("max-keys", 1000)), 1000) + try: + max_keys = max(1, min(int(request.args.get("max-keys", 1000)), 1000)) + except ValueError: + return _error_response("InvalidArgument", "max-keys must be an integer", 400) key_marker = request.args.get("key-marker", "") if prefix: @@ -1476,7 +1487,10 @@ def _parse_lifecycle_config(payload: bytes) -> list: expiration: dict = {} days_el = exp_el.find("{*}Days") or exp_el.find("Days") if days_el is not None and days_el.text: - expiration["Days"] = int(days_el.text.strip()) + days_val = int(days_el.text.strip()) + if days_val <= 0: + raise ValueError("Expiration Days must be a positive integer") + expiration["Days"] = days_val date_el = exp_el.find("{*}Date") or exp_el.find("Date") if date_el is not None and date_el.text: expiration["Date"] = date_el.text.strip() @@ -1491,7 +1505,10 @@ def _parse_lifecycle_config(payload: bytes) -> list: nve: dict = {} days_el = nve_el.find("{*}NoncurrentDays") or nve_el.find("NoncurrentDays") if days_el is not None and days_el.text: - nve["NoncurrentDays"] = int(days_el.text.strip()) + noncurrent_days = int(days_el.text.strip()) + if noncurrent_days <= 0: + raise ValueError("NoncurrentDays must be a positive integer") + nve["NoncurrentDays"] = noncurrent_days if nve: rule["NoncurrentVersionExpiration"] = nve @@ -1500,7 +1517,10 @@ def _parse_lifecycle_config(payload: bytes) -> list: aimu: dict = {} days_el = aimu_el.find("{*}DaysAfterInitiation") or aimu_el.find("DaysAfterInitiation") if days_el is not None and days_el.text: - aimu["DaysAfterInitiation"] = int(days_el.text.strip()) + days_after = int(days_el.text.strip()) + if days_after <= 0: + raise ValueError("DaysAfterInitiation must be a positive integer") + aimu["DaysAfterInitiation"] = days_after if aimu: rule["AbortIncompleteMultipartUpload"] = aimu @@ -2086,7 +2106,10 @@ def bucket_handler(bucket_name: str) -> Response: list_type = request.args.get("list-type") prefix = request.args.get("prefix", "") delimiter = request.args.get("delimiter", "") - max_keys = min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000) + try: + max_keys = max(1, min(int(request.args.get("max-keys", current_app.config["UI_PAGE_SIZE"])), 1000)) + except ValueError: + return _error_response("InvalidArgument", "max-keys must be an integer", 400) marker = request.args.get("marker", "") # ListObjects v1 continuation_token = request.args.get("continuation-token", "") # ListObjectsV2 @@ -2099,7 +2122,7 @@ def bucket_handler(bucket_name: str) -> Response: if continuation_token: try: effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8") - except Exception: + except (ValueError, UnicodeDecodeError): effective_start = continuation_token elif start_after: effective_start = start_after @@ -2742,7 +2765,7 @@ class AwsChunkedDecoder: def __init__(self, stream): self.stream = stream - self._read_buffer = 
bytearray() # Performance: Pre-allocated buffer + self._read_buffer = bytearray() self.chunk_remaining = 0 self.finished = False @@ -2753,20 +2776,15 @@ class AwsChunkedDecoder: """ line = bytearray() while True: - # Check if we have data in buffer if self._read_buffer: - # Look for CRLF in buffer idx = self._read_buffer.find(b"\r\n") if idx != -1: - # Found CRLF - extract line and update buffer line.extend(self._read_buffer[: idx + 2]) del self._read_buffer[: idx + 2] return bytes(line) - # No CRLF yet - consume entire buffer line.extend(self._read_buffer) self._read_buffer.clear() - # Read more data in larger chunks (64 bytes is enough for chunk headers) chunk = self.stream.read(64) if not chunk: return bytes(line) if line else b"" @@ -2775,14 +2793,11 @@ class AwsChunkedDecoder: def _read_exact(self, n: int) -> bytes: """Read exactly n bytes, using buffer first.""" result = bytearray() - # Use buffered data first if self._read_buffer: take = min(len(self._read_buffer), n) result.extend(self._read_buffer[:take]) del self._read_buffer[:take] n -= take - - # Read remaining directly from stream if n > 0: data = self.stream.read(n) if data: @@ -2794,7 +2809,7 @@ class AwsChunkedDecoder: if self.finished: return b"" - result = bytearray() # Performance: Use bytearray for building result + result = bytearray() while size == -1 or len(result) < size: if self.chunk_remaining > 0: to_read = self.chunk_remaining @@ -2828,7 +2843,6 @@ class AwsChunkedDecoder: if chunk_size == 0: self.finished = True - # Skip trailing headers while True: trailer = self._read_line() if trailer == b"\r\n" or not trailer: @@ -2969,10 +2983,11 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response: try: _storage().abort_multipart_upload(bucket_name, upload_id) + except BucketNotFoundError as exc: + return _error_response("NoSuchBucket", str(exc), 404) except StorageError as exc: - if "Bucket does not exist" in str(exc): - return _error_response("NoSuchBucket", str(exc), 404) - + current_app.logger.warning(f"Error aborting multipart upload: {exc}") + return Response(status=204) @@ -2984,13 +2999,15 @@ def resolve_principal(): (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"): g.principal = _verify_sigv4(request) return - except Exception: - pass - + except IamError as exc: + logger.debug(f"SigV4 authentication failed: {exc}") + except (ValueError, KeyError) as exc: + logger.debug(f"SigV4 parsing error: {exc}") + access_key = request.headers.get("X-Access-Key") secret_key = request.headers.get("X-Secret-Key") if access_key and secret_key: try: g.principal = _iam().authenticate(access_key, secret_key) - except Exception: - pass + except IamError as exc: + logger.debug(f"Header authentication failed: {exc}") diff --git a/app/storage.py b/app/storage.py index 14d522b..32403ec 100644 --- a/app/storage.py +++ b/app/storage.py @@ -76,6 +76,14 @@ class StorageError(RuntimeError): """Raised when the storage layer encounters an unrecoverable problem.""" +class BucketNotFoundError(StorageError): + """Raised when the bucket does not exist.""" + + +class ObjectNotFoundError(StorageError): + """Raised when the object does not exist.""" + + class QuotaExceededError(StorageError): """Raised when an operation would exceed bucket quota limits.""" @@ -106,7 +114,7 @@ class ListObjectsResult: objects: List[ObjectMeta] is_truncated: bool next_continuation_token: Optional[str] - total_count: Optional[int] = None # Total objects in bucket (from stats cache) + total_count: Optional[int] = None def _utcnow() -> 
datetime: @@ -130,22 +138,18 @@ class ObjectStorage: MULTIPART_MANIFEST = "manifest.json" BUCKET_CONFIG_FILE = ".bucket.json" KEY_INDEX_CACHE_TTL = 30 - OBJECT_CACHE_MAX_SIZE = 100 # Maximum number of buckets to cache + OBJECT_CACHE_MAX_SIZE = 100 def __init__(self, root: Path) -> None: self.root = Path(root) self.root.mkdir(parents=True, exist_ok=True) self._ensure_system_roots() - # LRU cache for object metadata with thread-safe access self._object_cache: OrderedDict[str, tuple[Dict[str, ObjectMeta], float]] = OrderedDict() - self._cache_lock = threading.Lock() # Global lock for cache structure - # Performance: Per-bucket locks to reduce contention + self._cache_lock = threading.Lock() self._bucket_locks: Dict[str, threading.Lock] = {} - # Cache version counter for detecting stale reads self._cache_version: Dict[str, int] = {} - # Performance: Bucket config cache with TTL self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {} - self._bucket_config_cache_ttl = 30.0 # 30 second TTL + self._bucket_config_cache_ttl = 30.0 def _get_bucket_lock(self, bucket_id: str) -> threading.Lock: """Get or create a lock for a specific bucket. Reduces global lock contention.""" @@ -170,6 +174,11 @@ class ObjectStorage: def bucket_exists(self, bucket_name: str) -> bool: return self._bucket_path(bucket_name).exists() + def _require_bucket_exists(self, bucket_path: Path) -> None: + """Raise BucketNotFoundError if bucket does not exist.""" + if not bucket_path.exists(): + raise BucketNotFoundError("Bucket does not exist") + def _validate_bucket_name(self, bucket_name: str) -> None: if len(bucket_name) < 3 or len(bucket_name) > 63: raise StorageError("Bucket name must be between 3 and 63 characters") @@ -188,14 +197,14 @@ class ObjectStorage: def bucket_stats(self, bucket_name: str, cache_ttl: int = 60) -> dict[str, int]: """Return object count and total size for the bucket (cached). 
- + Args: bucket_name: Name of the bucket cache_ttl: Cache time-to-live in seconds (default 60) """ bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") cache_path = self._system_bucket_root(bucket_name) / "stats.json" if cache_path.exists(): @@ -257,8 +266,7 @@ class ObjectStorage: def delete_bucket(self, bucket_name: str) -> None: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") - # Performance: Single check instead of three separate traversals + raise BucketNotFoundError("Bucket does not exist") has_objects, has_versions, has_multipart = self._check_bucket_contents(bucket_path) if has_objects: raise StorageError("Bucket not empty") @@ -291,7 +299,7 @@ class ObjectStorage: """ bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name object_cache = self._get_object_cache(bucket_id, bucket_path) @@ -352,7 +360,7 @@ class ObjectStorage: ) -> ObjectMeta: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name safe_key = self._sanitize_object_key(object_key) @@ -409,7 +417,6 @@ class ObjectStorage: self._invalidate_bucket_stats_cache(bucket_id) - # Performance: Lazy update - only update the affected key instead of invalidating whole cache obj_meta = ObjectMeta( key=safe_key.as_posix(), size=stat.st_size, @@ -424,7 +431,7 @@ class ObjectStorage: def get_object_path(self, bucket_name: str, object_key: str) -> Path: path = self._object_path(bucket_name, object_key) if not path.exists(): - raise StorageError("Object not found") + raise ObjectNotFoundError("Object not found") return path def get_object_metadata(self, bucket_name: str, object_key: str) -> Dict[str, str]: @@ -467,7 +474,6 @@ class ObjectStorage: self._delete_metadata(bucket_id, rel) self._invalidate_bucket_stats_cache(bucket_id) - # Performance: Lazy update - only remove the affected key instead of invalidating whole cache self._update_object_cache_entry(bucket_id, safe_key.as_posix(), None) self._cleanup_empty_parents(path, bucket_path) @@ -490,14 +496,13 @@ class ObjectStorage: shutil.rmtree(legacy_version_dir, ignore_errors=True) self._invalidate_bucket_stats_cache(bucket_id) - # Performance: Lazy update - only remove the affected key instead of invalidating whole cache self._update_object_cache_entry(bucket_id, rel.as_posix(), None) self._cleanup_empty_parents(target, bucket_path) def is_versioning_enabled(self, bucket_name: str) -> bool: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") return self._is_versioning_enabled(bucket_path) def set_bucket_versioning(self, bucket_name: str, enabled: bool) -> None: @@ -689,11 +694,11 @@ class ObjectStorage: """Get tags for an object.""" bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") safe_key = self._sanitize_object_key(object_key) object_path = bucket_path / safe_key if not object_path.exists(): - raise StorageError("Object does not exist") + raise 
ObjectNotFoundError("Object does not exist") for meta_file in (self._metadata_file(bucket_path.name, safe_key), self._legacy_metadata_file(bucket_path.name, safe_key)): if not meta_file.exists(): @@ -712,11 +717,11 @@ class ObjectStorage: """Set tags for an object.""" bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") safe_key = self._sanitize_object_key(object_key) object_path = bucket_path / safe_key if not object_path.exists(): - raise StorageError("Object does not exist") + raise ObjectNotFoundError("Object does not exist") meta_file = self._metadata_file(bucket_path.name, safe_key) @@ -750,7 +755,7 @@ class ObjectStorage: def list_object_versions(self, bucket_name: str, object_key: str) -> List[Dict[str, Any]]: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name safe_key = self._sanitize_object_key(object_key) version_dir = self._version_dir(bucket_id, safe_key) @@ -774,7 +779,7 @@ class ObjectStorage: def restore_object_version(self, bucket_name: str, object_key: str, version_id: str) -> ObjectMeta: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name safe_key = self._sanitize_object_key(object_key) version_dir = self._version_dir(bucket_id, safe_key) @@ -811,7 +816,7 @@ class ObjectStorage: def delete_object_version(self, bucket_name: str, object_key: str, version_id: str) -> None: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name safe_key = self._sanitize_object_key(object_key) version_dir = self._version_dir(bucket_id, safe_key) @@ -834,7 +839,7 @@ class ObjectStorage: def list_orphaned_objects(self, bucket_name: str) -> List[Dict[str, Any]]: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name version_roots = [self._bucket_versions_root(bucket_id), self._legacy_versions_root(bucket_id)] if not any(root.exists() for root in version_roots): @@ -902,7 +907,7 @@ class ObjectStorage: ) -> str: bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name safe_key = self._sanitize_object_key(object_key) upload_id = uuid.uuid4().hex @@ -929,8 +934,8 @@ class ObjectStorage: Uses file locking to safely update the manifest and handle concurrent uploads. 
""" - if part_number < 1: - raise StorageError("part_number must be >= 1") + if part_number < 1 or part_number > 10000: + raise StorageError("part_number must be between 1 and 10000") bucket_path = self._bucket_path(bucket_name) upload_root = self._multipart_dir(bucket_path.name, upload_id) @@ -939,7 +944,6 @@ class ObjectStorage: if not upload_root.exists(): raise StorageError("Multipart upload not found") - # Write part to temporary file first, then rename atomically checksum = hashlib.md5() part_filename = f"part-{part_number:05d}.part" part_path = upload_root / part_filename @@ -948,11 +952,8 @@ class ObjectStorage: try: with temp_path.open("wb") as target: shutil.copyfileobj(_HashingReader(stream, checksum), target) - - # Atomic rename (or replace on Windows) temp_path.replace(part_path) except OSError: - # Clean up temp file on failure try: temp_path.unlink(missing_ok=True) except OSError: @@ -968,7 +969,6 @@ class ObjectStorage: manifest_path = upload_root / self.MULTIPART_MANIFEST lock_path = upload_root / ".manifest.lock" - # Retry loop for handling transient lock/read failures max_retries = 3 for attempt in range(max_retries): try: @@ -1151,10 +1151,10 @@ class ObjectStorage: """List all active multipart uploads for a bucket.""" bucket_path = self._bucket_path(bucket_name) if not bucket_path.exists(): - raise StorageError("Bucket does not exist") + raise BucketNotFoundError("Bucket does not exist") bucket_id = bucket_path.name uploads = [] - multipart_root = self._bucket_multipart_root(bucket_id) + multipart_root = self._multipart_bucket_root(bucket_id) if multipart_root.exists(): for upload_dir in multipart_root.iterdir(): if not upload_dir.is_dir(): @@ -1171,7 +1171,7 @@ class ObjectStorage: }) except (OSError, json.JSONDecodeError): continue - legacy_root = self._legacy_multipart_root(bucket_id) + legacy_root = self._legacy_multipart_bucket_root(bucket_id) if legacy_root.exists(): for upload_dir in legacy_root.iterdir(): if not upload_dir.is_dir(): @@ -1394,7 +1394,6 @@ class ObjectStorage: """ now = time.time() - # Quick check with global lock (brief) with self._cache_lock: cached = self._object_cache.get(bucket_id) if cached: @@ -1404,10 +1403,8 @@ class ObjectStorage: return objects cache_version = self._cache_version.get(bucket_id, 0) - # Use per-bucket lock for cache building (allows parallel builds for different buckets) bucket_lock = self._get_bucket_lock(bucket_id) with bucket_lock: - # Double-check cache after acquiring per-bucket lock with self._cache_lock: cached = self._object_cache.get(bucket_id) if cached: @@ -1415,17 +1412,12 @@ class ObjectStorage: if now - timestamp < self.KEY_INDEX_CACHE_TTL: self._object_cache.move_to_end(bucket_id) return objects - - # Build cache with per-bucket lock held (prevents duplicate work) objects = self._build_object_cache(bucket_path) with self._cache_lock: - # Check if cache was invalidated while we were building current_version = self._cache_version.get(bucket_id, 0) if current_version != cache_version: objects = self._build_object_cache(bucket_path) - - # Evict oldest entries if cache is full while len(self._object_cache) >= self.OBJECT_CACHE_MAX_SIZE: self._object_cache.popitem(last=False) @@ -1459,12 +1451,9 @@ class ObjectStorage: if cached: objects, timestamp = cached if meta is None: - # Delete operation - remove key from cache objects.pop(key, None) else: - # Put operation - update/add key in cache objects[key] = meta - # Keep same timestamp - don't reset TTL for single key updates def _ensure_system_roots(self) -> None: 
for path in ( @@ -1485,13 +1474,12 @@ class ObjectStorage: return self._system_bucket_root(bucket_name) / self.BUCKET_CONFIG_FILE def _read_bucket_config(self, bucket_name: str) -> dict[str, Any]: - # Performance: Check cache first now = time.time() cached = self._bucket_config_cache.get(bucket_name) if cached: config, cached_time = cached if now - cached_time < self._bucket_config_cache_ttl: - return config.copy() # Return copy to prevent mutation + return config.copy() config_path = self._bucket_config_path(bucket_name) if not config_path.exists(): @@ -1510,7 +1498,6 @@ class ObjectStorage: config_path = self._bucket_config_path(bucket_name) config_path.parent.mkdir(parents=True, exist_ok=True) config_path.write_text(json.dumps(payload), encoding="utf-8") - # Performance: Update cache immediately after write self._bucket_config_cache[bucket_name] = (payload.copy(), time.time()) def _set_bucket_config_entry(self, bucket_name: str, key: str, value: Any | None) -> None: @@ -1636,7 +1623,6 @@ class ObjectStorage: def _check_bucket_contents(self, bucket_path: Path) -> tuple[bool, bool, bool]: """Check bucket for objects, versions, and multipart uploads in a single pass. - Performance optimization: Combines three separate rglob traversals into one. Returns (has_visible_objects, has_archived_versions, has_active_multipart_uploads). Uses early exit when all three are found. """ @@ -1645,7 +1631,6 @@ class ObjectStorage: has_multipart = False bucket_name = bucket_path.name - # Check visible objects in bucket for path in bucket_path.rglob("*"): if has_objects: break @@ -1656,7 +1641,6 @@ class ObjectStorage: continue has_objects = True - # Check archived versions (only if needed) for version_root in ( self._bucket_versions_root(bucket_name), self._legacy_versions_root(bucket_name), @@ -1669,7 +1653,6 @@ class ObjectStorage: has_versions = True break - # Check multipart uploads (only if needed) for uploads_root in ( self._multipart_bucket_root(bucket_name), self._legacy_multipart_bucket_root(bucket_name), @@ -1703,7 +1686,7 @@ class ObjectStorage: try: os.chmod(target_path, stat.S_IRWXU) func(target_path) - except Exception as exc: # pragma: no cover - fallback failure + except Exception as exc: raise StorageError(f"Unable to delete bucket contents: {exc}") from exc try: diff --git a/app/ui.py b/app/ui.py index 233ef62..7ddd6ea 100644 --- a/app/ui.py +++ b/app/ui.py @@ -371,7 +371,7 @@ def bucket_detail(bucket_name: str): kms_keys = kms_manager.list_keys() if kms_manager else [] kms_enabled = current_app.config.get("KMS_ENABLED", False) encryption_enabled = current_app.config.get("ENCRYPTION_ENABLED", False) - can_manage_encryption = can_manage_versioning # Same as other bucket properties + can_manage_encryption = can_manage_versioning bucket_quota = storage.get_bucket_quota(bucket_name) bucket_stats = storage.bucket_stats(bucket_name) @@ -450,8 +450,6 @@ def list_bucket_objects(bucket_name: str): except StorageError: versioning_enabled = False - # Pre-compute URL templates once (not per-object) for performance - # Frontend will construct actual URLs by replacing KEY_PLACEHOLDER preview_template = url_for("ui.object_preview", bucket_name=bucket_name, object_key="KEY_PLACEHOLDER") delete_template = url_for("ui.delete_object", bucket_name=bucket_name, object_key="KEY_PLACEHOLDER") presign_template = url_for("ui.object_presign", bucket_name=bucket_name, object_key="KEY_PLACEHOLDER") @@ -527,8 +525,6 @@ def upload_object(bucket_name: str): try: _authorize_ui(principal, bucket_name, "write") 
_storage().put_object(bucket_name, object_key, file.stream, metadata=metadata) - - # Trigger replication _replication().trigger_replication(bucket_name, object_key) message = f"Uploaded '{object_key}'" @@ -765,20 +761,18 @@ def bulk_download_objects(bucket_name: str): if not cleaned: return jsonify({"error": "Select at least one object to download"}), 400 - MAX_KEYS = current_app.config.get("BULK_DELETE_MAX_KEYS", 500) # Reuse same limit for now + MAX_KEYS = current_app.config.get("BULK_DELETE_MAX_KEYS", 500) if len(cleaned) > MAX_KEYS: return jsonify({"error": f"A maximum of {MAX_KEYS} objects can be downloaded per request"}), 400 unique_keys = list(dict.fromkeys(cleaned)) storage = _storage() - # Verify permission to read bucket contents try: _authorize_ui(principal, bucket_name, "read") except IamError as exc: return jsonify({"error": str(exc)}), 403 - # Create ZIP archive of selected objects buffer = io.BytesIO() with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf: for key in unique_keys: @@ -795,7 +789,6 @@ def bulk_download_objects(bucket_name: str): path = storage.get_object_path(bucket_name, key) zf.write(path, arcname=key) except (StorageError, IamError): - # Skip objects that can't be accessed continue buffer.seek(0) @@ -846,7 +839,6 @@ def object_preview(bucket_name: str, object_key: str) -> Response: download = request.args.get("download") == "1" - # Check if object is encrypted and needs decryption is_encrypted = "x-amz-server-side-encryption" in metadata if is_encrypted and hasattr(storage, 'get_object_data'): try: @@ -882,7 +874,6 @@ def object_presign(bucket_name: str, object_key: str): encoded_key = quote(object_key, safe="/") url = f"{api_base}/presign/{bucket_name}/{encoded_key}" - # Use API base URL for forwarded headers so presigned URLs point to API, not UI parsed_api = urlparse(api_base) headers = _api_headers() headers["X-Forwarded-Host"] = parsed_api.netloc or "127.0.0.1:5000" @@ -1027,7 +1018,6 @@ def update_bucket_quota(bucket_name: str): """Update bucket quota configuration (admin only).""" principal = _current_principal() - # Quota management is admin-only is_admin = False try: _iam().authorize(principal, None, "iam:list_users") @@ -1049,7 +1039,6 @@ def update_bucket_quota(bucket_name: str): flash(_friendly_error_message(exc), "danger") return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) - # Parse quota values max_mb_str = request.form.get("max_mb", "").strip() max_objects_str = request.form.get("max_objects", "").strip() @@ -1061,7 +1050,7 @@ def update_bucket_quota(bucket_name: str): max_mb = int(max_mb_str) if max_mb < 1: raise ValueError("Size must be at least 1 MB") - max_bytes = max_mb * 1024 * 1024 # Convert MB to bytes + max_bytes = max_mb * 1024 * 1024 except ValueError as exc: flash(f"Invalid size value: {exc}", "danger") return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) @@ -1114,7 +1103,6 @@ def update_bucket_encryption(bucket_name: str): flash("Invalid encryption algorithm", "danger") return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) - # Build encryption configuration in AWS S3 format encryption_config: dict[str, Any] = { "Rules": [ { @@ -1505,7 +1493,6 @@ def update_bucket_replication(bucket_name: str): if rule: rule.enabled = True _replication().set_rule(rule) - # When resuming, sync any pending objects that accumulated while paused if rule.mode == REPLICATION_MODE_ALL: _replication().replicate_existing_objects(bucket_name) 
flash("Replication resumed. Syncing pending objects in background.", "success") diff --git a/pytest.ini b/pytest.ini index 3271334..d1a0271 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] testpaths = tests norecursedirs = data .git __pycache__ .venv +markers = + integration: marks tests as integration tests (may require external services) diff --git a/tests/test_access_logging.py b/tests/test_access_logging.py new file mode 100644 index 0000000..b2fb12a --- /dev/null +++ b/tests/test_access_logging.py @@ -0,0 +1,339 @@ +import io +import json +import time +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from app.access_logging import ( + AccessLogEntry, + AccessLoggingService, + LoggingConfiguration, +) +from app.storage import ObjectStorage + + +class TestAccessLogEntry: + def test_default_values(self): + entry = AccessLogEntry() + assert entry.bucket_owner == "-" + assert entry.bucket == "-" + assert entry.remote_ip == "-" + assert entry.requester == "-" + assert entry.operation == "-" + assert entry.http_status == 200 + assert len(entry.request_id) == 16 + + def test_to_log_line(self): + entry = AccessLogEntry( + bucket_owner="owner123", + bucket="my-bucket", + remote_ip="192.168.1.1", + requester="user456", + request_id="REQ123456789012", + operation="REST.PUT.OBJECT", + key="test/key.txt", + request_uri="PUT /my-bucket/test/key.txt HTTP/1.1", + http_status=200, + bytes_sent=1024, + object_size=2048, + total_time_ms=150, + referrer="http://example.com", + user_agent="aws-cli/2.0", + version_id="v1", + ) + log_line = entry.to_log_line() + + assert "owner123" in log_line + assert "my-bucket" in log_line + assert "192.168.1.1" in log_line + assert "user456" in log_line + assert "REST.PUT.OBJECT" in log_line + assert "test/key.txt" in log_line + assert "200" in log_line + + def test_to_dict(self): + entry = AccessLogEntry( + bucket_owner="owner", + bucket="bucket", + remote_ip="10.0.0.1", + requester="admin", + request_id="ABC123", + operation="REST.GET.OBJECT", + key="file.txt", + request_uri="GET /bucket/file.txt HTTP/1.1", + http_status=200, + bytes_sent=512, + object_size=512, + total_time_ms=50, + ) + result = entry.to_dict() + + assert result["bucket_owner"] == "owner" + assert result["bucket"] == "bucket" + assert result["remote_ip"] == "10.0.0.1" + assert result["requester"] == "admin" + assert result["operation"] == "REST.GET.OBJECT" + assert result["key"] == "file.txt" + assert result["http_status"] == 200 + assert result["bytes_sent"] == 512 + + +class TestLoggingConfiguration: + def test_default_values(self): + config = LoggingConfiguration(target_bucket="log-bucket") + assert config.target_bucket == "log-bucket" + assert config.target_prefix == "" + assert config.enabled is True + + def test_to_dict(self): + config = LoggingConfiguration( + target_bucket="logs", + target_prefix="access-logs/", + enabled=True, + ) + result = config.to_dict() + + assert "LoggingEnabled" in result + assert result["LoggingEnabled"]["TargetBucket"] == "logs" + assert result["LoggingEnabled"]["TargetPrefix"] == "access-logs/" + + def test_from_dict(self): + data = { + "LoggingEnabled": { + "TargetBucket": "my-logs", + "TargetPrefix": "bucket-logs/", + } + } + config = LoggingConfiguration.from_dict(data) + + assert config is not None + assert config.target_bucket == "my-logs" + assert config.target_prefix == "bucket-logs/" + assert config.enabled is True + + def test_from_dict_no_logging(self): + data 
= {} + config = LoggingConfiguration.from_dict(data) + assert config is None + + +@pytest.fixture +def storage(tmp_path: Path): + storage_root = tmp_path / "data" + storage_root.mkdir(parents=True) + return ObjectStorage(storage_root) + + +@pytest.fixture +def logging_service(tmp_path: Path, storage): + service = AccessLoggingService( + tmp_path, + flush_interval=3600, + max_buffer_size=10, + ) + service.set_storage(storage) + yield service + service.shutdown() + + +class TestAccessLoggingService: + def test_get_bucket_logging_not_configured(self, logging_service): + result = logging_service.get_bucket_logging("unconfigured-bucket") + assert result is None + + def test_set_and_get_bucket_logging(self, logging_service): + config = LoggingConfiguration( + target_bucket="log-bucket", + target_prefix="logs/", + ) + logging_service.set_bucket_logging("source-bucket", config) + + retrieved = logging_service.get_bucket_logging("source-bucket") + assert retrieved is not None + assert retrieved.target_bucket == "log-bucket" + assert retrieved.target_prefix == "logs/" + + def test_delete_bucket_logging(self, logging_service): + config = LoggingConfiguration(target_bucket="logs") + logging_service.set_bucket_logging("to-delete", config) + assert logging_service.get_bucket_logging("to-delete") is not None + + logging_service.delete_bucket_logging("to-delete") + logging_service._configs.clear() + assert logging_service.get_bucket_logging("to-delete") is None + + def test_log_request_no_config(self, logging_service): + logging_service.log_request( + "no-config-bucket", + operation="REST.GET.OBJECT", + key="test.txt", + ) + stats = logging_service.get_stats() + assert stats["buffered_entries"] == 0 + + def test_log_request_with_config(self, logging_service, storage): + storage.create_bucket("log-target") + + config = LoggingConfiguration( + target_bucket="log-target", + target_prefix="access/", + ) + logging_service.set_bucket_logging("source-bucket", config) + + logging_service.log_request( + "source-bucket", + operation="REST.PUT.OBJECT", + key="uploaded.txt", + remote_ip="192.168.1.100", + requester="test-user", + http_status=200, + bytes_sent=1024, + ) + + stats = logging_service.get_stats() + assert stats["buffered_entries"] == 1 + + def test_log_request_disabled_config(self, logging_service): + config = LoggingConfiguration( + target_bucket="logs", + enabled=False, + ) + logging_service.set_bucket_logging("disabled-bucket", config) + + logging_service.log_request( + "disabled-bucket", + operation="REST.GET.OBJECT", + key="test.txt", + ) + + stats = logging_service.get_stats() + assert stats["buffered_entries"] == 0 + + def test_flush_buffer(self, logging_service, storage): + storage.create_bucket("flush-target") + + config = LoggingConfiguration( + target_bucket="flush-target", + target_prefix="logs/", + ) + logging_service.set_bucket_logging("flush-source", config) + + for i in range(3): + logging_service.log_request( + "flush-source", + operation="REST.GET.OBJECT", + key=f"file{i}.txt", + ) + + logging_service.flush() + + objects = storage.list_objects_all("flush-target") + assert len(objects) >= 1 + + def test_auto_flush_on_buffer_size(self, logging_service, storage): + storage.create_bucket("auto-flush-target") + + config = LoggingConfiguration( + target_bucket="auto-flush-target", + target_prefix="", + ) + logging_service.set_bucket_logging("auto-source", config) + + for i in range(15): + logging_service.log_request( + "auto-source", + operation="REST.GET.OBJECT", + key=f"file{i}.txt", + ) + + 
objects = storage.list_objects_all("auto-flush-target") + assert len(objects) >= 1 + + def test_get_stats(self, logging_service, storage): + storage.create_bucket("stats-target") + config = LoggingConfiguration(target_bucket="stats-target") + logging_service.set_bucket_logging("stats-bucket", config) + + logging_service.log_request( + "stats-bucket", + operation="REST.GET.OBJECT", + key="test.txt", + ) + + stats = logging_service.get_stats() + assert "buffered_entries" in stats + assert "target_buckets" in stats + assert stats["buffered_entries"] >= 1 + + def test_shutdown_flushes_buffer(self, tmp_path, storage): + storage.create_bucket("shutdown-target") + + service = AccessLoggingService(tmp_path, flush_interval=3600, max_buffer_size=100) + service.set_storage(storage) + + config = LoggingConfiguration(target_bucket="shutdown-target") + service.set_bucket_logging("shutdown-source", config) + + service.log_request( + "shutdown-source", + operation="REST.PUT.OBJECT", + key="final.txt", + ) + + service.shutdown() + + objects = storage.list_objects_all("shutdown-target") + assert len(objects) >= 1 + + def test_logging_caching(self, logging_service): + config = LoggingConfiguration(target_bucket="cached-logs") + logging_service.set_bucket_logging("cached-bucket", config) + + logging_service.get_bucket_logging("cached-bucket") + assert "cached-bucket" in logging_service._configs + + def test_log_request_all_fields(self, logging_service, storage): + storage.create_bucket("detailed-target") + + config = LoggingConfiguration(target_bucket="detailed-target", target_prefix="detailed/") + logging_service.set_bucket_logging("detailed-source", config) + + logging_service.log_request( + "detailed-source", + operation="REST.PUT.OBJECT", + key="detailed/file.txt", + remote_ip="10.0.0.1", + requester="admin-user", + request_uri="PUT /detailed-source/detailed/file.txt HTTP/1.1", + http_status=201, + error_code="", + bytes_sent=2048, + object_size=2048, + total_time_ms=100, + referrer="http://admin.example.com", + user_agent="curl/7.68.0", + version_id="v1.0", + request_id="CUSTOM_REQ_ID", + ) + + stats = logging_service.get_stats() + assert stats["buffered_entries"] == 1 + + def test_failed_flush_returns_to_buffer(self, logging_service): + config = LoggingConfiguration(target_bucket="nonexistent-target") + logging_service.set_bucket_logging("fail-source", config) + + logging_service.log_request( + "fail-source", + operation="REST.GET.OBJECT", + key="test.txt", + ) + + initial_count = logging_service.get_stats()["buffered_entries"] + logging_service.flush() + + final_count = logging_service.get_stats()["buffered_entries"] + assert final_count >= initial_count diff --git a/tests/test_acl.py b/tests/test_acl.py new file mode 100644 index 0000000..a5b15ab --- /dev/null +++ b/tests/test_acl.py @@ -0,0 +1,284 @@ +import json +from pathlib import Path + +import pytest + +from app.acl import ( + Acl, + AclGrant, + AclService, + ACL_PERMISSION_FULL_CONTROL, + ACL_PERMISSION_READ, + ACL_PERMISSION_WRITE, + ACL_PERMISSION_READ_ACP, + ACL_PERMISSION_WRITE_ACP, + GRANTEE_ALL_USERS, + GRANTEE_AUTHENTICATED_USERS, + PERMISSION_TO_ACTIONS, + create_canned_acl, + CANNED_ACLS, +) + + +class TestAclGrant: + def test_to_dict(self): + grant = AclGrant(grantee="user123", permission=ACL_PERMISSION_READ) + result = grant.to_dict() + assert result == {"grantee": "user123", "permission": "READ"} + + def test_from_dict(self): + data = {"grantee": "admin", "permission": "FULL_CONTROL"} + grant = AclGrant.from_dict(data) + assert 
grant.grantee == "admin" + assert grant.permission == ACL_PERMISSION_FULL_CONTROL + + +class TestAcl: + def test_to_dict(self): + acl = Acl( + owner="owner-user", + grants=[ + AclGrant(grantee="owner-user", permission=ACL_PERMISSION_FULL_CONTROL), + AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_READ), + ], + ) + result = acl.to_dict() + assert result["owner"] == "owner-user" + assert len(result["grants"]) == 2 + assert result["grants"][0]["grantee"] == "owner-user" + assert result["grants"][1]["grantee"] == "*" + + def test_from_dict(self): + data = { + "owner": "the-owner", + "grants": [ + {"grantee": "the-owner", "permission": "FULL_CONTROL"}, + {"grantee": "authenticated", "permission": "READ"}, + ], + } + acl = Acl.from_dict(data) + assert acl.owner == "the-owner" + assert len(acl.grants) == 2 + assert acl.grants[0].grantee == "the-owner" + assert acl.grants[1].grantee == GRANTEE_AUTHENTICATED_USERS + + def test_from_dict_empty_grants(self): + data = {"owner": "solo-owner"} + acl = Acl.from_dict(data) + assert acl.owner == "solo-owner" + assert len(acl.grants) == 0 + + def test_get_allowed_actions_owner(self): + acl = Acl(owner="owner123", grants=[]) + actions = acl.get_allowed_actions("owner123", is_authenticated=True) + assert actions == PERMISSION_TO_ACTIONS[ACL_PERMISSION_FULL_CONTROL] + + def test_get_allowed_actions_all_users(self): + acl = Acl( + owner="owner", + grants=[AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_READ)], + ) + actions = acl.get_allowed_actions(None, is_authenticated=False) + assert "read" in actions + assert "list" in actions + assert "write" not in actions + + def test_get_allowed_actions_authenticated_users(self): + acl = Acl( + owner="owner", + grants=[AclGrant(grantee=GRANTEE_AUTHENTICATED_USERS, permission=ACL_PERMISSION_WRITE)], + ) + actions_authenticated = acl.get_allowed_actions("some-user", is_authenticated=True) + assert "write" in actions_authenticated + assert "delete" in actions_authenticated + + actions_anonymous = acl.get_allowed_actions(None, is_authenticated=False) + assert "write" not in actions_anonymous + + def test_get_allowed_actions_specific_grantee(self): + acl = Acl( + owner="owner", + grants=[ + AclGrant(grantee="user-abc", permission=ACL_PERMISSION_READ), + AclGrant(grantee="user-xyz", permission=ACL_PERMISSION_WRITE), + ], + ) + abc_actions = acl.get_allowed_actions("user-abc", is_authenticated=True) + assert "read" in abc_actions + assert "list" in abc_actions + assert "write" not in abc_actions + + xyz_actions = acl.get_allowed_actions("user-xyz", is_authenticated=True) + assert "write" in xyz_actions + assert "read" not in xyz_actions + + def test_get_allowed_actions_combined(self): + acl = Acl( + owner="owner", + grants=[ + AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_READ), + AclGrant(grantee="special-user", permission=ACL_PERMISSION_WRITE), + ], + ) + actions = acl.get_allowed_actions("special-user", is_authenticated=True) + assert "read" in actions + assert "list" in actions + assert "write" in actions + assert "delete" in actions + + +class TestCannedAcls: + def test_private_acl(self): + acl = create_canned_acl("private", "the-owner") + assert acl.owner == "the-owner" + assert len(acl.grants) == 1 + assert acl.grants[0].grantee == "the-owner" + assert acl.grants[0].permission == ACL_PERMISSION_FULL_CONTROL + + def test_public_read_acl(self): + acl = create_canned_acl("public-read", "owner") + assert acl.owner == "owner" + has_owner_full_control = any( + g.grantee == "owner" 
and g.permission == ACL_PERMISSION_FULL_CONTROL for g in acl.grants + ) + has_public_read = any( + g.grantee == GRANTEE_ALL_USERS and g.permission == ACL_PERMISSION_READ for g in acl.grants + ) + assert has_owner_full_control + assert has_public_read + + def test_public_read_write_acl(self): + acl = create_canned_acl("public-read-write", "owner") + assert acl.owner == "owner" + has_public_read = any( + g.grantee == GRANTEE_ALL_USERS and g.permission == ACL_PERMISSION_READ for g in acl.grants + ) + has_public_write = any( + g.grantee == GRANTEE_ALL_USERS and g.permission == ACL_PERMISSION_WRITE for g in acl.grants + ) + assert has_public_read + assert has_public_write + + def test_authenticated_read_acl(self): + acl = create_canned_acl("authenticated-read", "owner") + has_authenticated_read = any( + g.grantee == GRANTEE_AUTHENTICATED_USERS and g.permission == ACL_PERMISSION_READ for g in acl.grants + ) + assert has_authenticated_read + + def test_unknown_canned_acl_defaults_to_private(self): + acl = create_canned_acl("unknown-acl", "owner") + private_acl = create_canned_acl("private", "owner") + assert acl.to_dict() == private_acl.to_dict() + + +@pytest.fixture +def acl_service(tmp_path: Path): + return AclService(tmp_path) + + +class TestAclService: + def test_get_bucket_acl_not_exists(self, acl_service): + result = acl_service.get_bucket_acl("nonexistent-bucket") + assert result is None + + def test_set_and_get_bucket_acl(self, acl_service): + acl = Acl( + owner="bucket-owner", + grants=[AclGrant(grantee="bucket-owner", permission=ACL_PERMISSION_FULL_CONTROL)], + ) + acl_service.set_bucket_acl("my-bucket", acl) + + retrieved = acl_service.get_bucket_acl("my-bucket") + assert retrieved is not None + assert retrieved.owner == "bucket-owner" + assert len(retrieved.grants) == 1 + + def test_bucket_acl_caching(self, acl_service): + acl = Acl(owner="cached-owner", grants=[]) + acl_service.set_bucket_acl("cached-bucket", acl) + + acl_service.get_bucket_acl("cached-bucket") + assert "cached-bucket" in acl_service._bucket_acl_cache + + retrieved = acl_service.get_bucket_acl("cached-bucket") + assert retrieved.owner == "cached-owner" + + def test_set_bucket_canned_acl(self, acl_service): + result = acl_service.set_bucket_canned_acl("new-bucket", "public-read", "the-owner") + assert result.owner == "the-owner" + + retrieved = acl_service.get_bucket_acl("new-bucket") + assert retrieved is not None + has_public_read = any( + g.grantee == GRANTEE_ALL_USERS and g.permission == ACL_PERMISSION_READ for g in retrieved.grants + ) + assert has_public_read + + def test_delete_bucket_acl(self, acl_service): + acl = Acl(owner="to-delete-owner", grants=[]) + acl_service.set_bucket_acl("delete-me", acl) + assert acl_service.get_bucket_acl("delete-me") is not None + + acl_service.delete_bucket_acl("delete-me") + acl_service._bucket_acl_cache.clear() + assert acl_service.get_bucket_acl("delete-me") is None + + def test_evaluate_bucket_acl_allowed(self, acl_service): + acl = Acl( + owner="owner", + grants=[AclGrant(grantee=GRANTEE_ALL_USERS, permission=ACL_PERMISSION_READ)], + ) + acl_service.set_bucket_acl("public-bucket", acl) + + result = acl_service.evaluate_bucket_acl("public-bucket", None, "read", is_authenticated=False) + assert result is True + + def test_evaluate_bucket_acl_denied(self, acl_service): + acl = Acl( + owner="owner", + grants=[AclGrant(grantee="owner", permission=ACL_PERMISSION_FULL_CONTROL)], + ) + acl_service.set_bucket_acl("private-bucket", acl) + + result = 
acl_service.evaluate_bucket_acl("private-bucket", "other-user", "write", is_authenticated=True) + assert result is False + + def test_evaluate_bucket_acl_no_acl(self, acl_service): + result = acl_service.evaluate_bucket_acl("no-acl-bucket", "anyone", "read") + assert result is False + + def test_get_object_acl_from_metadata(self, acl_service): + metadata = { + "__acl__": { + "owner": "object-owner", + "grants": [{"grantee": "object-owner", "permission": "FULL_CONTROL"}], + } + } + result = acl_service.get_object_acl("bucket", "key", metadata) + assert result is not None + assert result.owner == "object-owner" + + def test_get_object_acl_no_acl_in_metadata(self, acl_service): + metadata = {"Content-Type": "text/plain"} + result = acl_service.get_object_acl("bucket", "key", metadata) + assert result is None + + def test_create_object_acl_metadata(self, acl_service): + acl = Acl(owner="obj-owner", grants=[]) + result = acl_service.create_object_acl_metadata(acl) + assert "__acl__" in result + assert result["__acl__"]["owner"] == "obj-owner" + + def test_evaluate_object_acl(self, acl_service): + metadata = { + "__acl__": { + "owner": "obj-owner", + "grants": [{"grantee": "*", "permission": "READ"}], + } + } + result = acl_service.evaluate_object_acl(metadata, None, "read", is_authenticated=False) + assert result is True + + result = acl_service.evaluate_object_acl(metadata, None, "write", is_authenticated=False) + assert result is False diff --git a/tests/test_lifecycle.py b/tests/test_lifecycle.py new file mode 100644 index 0000000..a92cf4c --- /dev/null +++ b/tests/test_lifecycle.py @@ -0,0 +1,238 @@ +import io +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from app.lifecycle import LifecycleManager, LifecycleResult +from app.storage import ObjectStorage + + +@pytest.fixture +def storage(tmp_path: Path): + storage_root = tmp_path / "data" + storage_root.mkdir(parents=True) + return ObjectStorage(storage_root) + + +@pytest.fixture +def lifecycle_manager(storage): + manager = LifecycleManager(storage, interval_seconds=3600) + yield manager + manager.stop() + + +class TestLifecycleResult: + def test_default_values(self): + result = LifecycleResult(bucket_name="test-bucket") + assert result.bucket_name == "test-bucket" + assert result.objects_deleted == 0 + assert result.versions_deleted == 0 + assert result.uploads_aborted == 0 + assert result.errors == [] + assert result.execution_time_seconds == 0.0 + + +class TestLifecycleManager: + def test_start_and_stop(self, lifecycle_manager): + lifecycle_manager.start() + assert lifecycle_manager._timer is not None + assert lifecycle_manager._shutdown is False + + lifecycle_manager.stop() + assert lifecycle_manager._shutdown is True + assert lifecycle_manager._timer is None + + def test_start_only_once(self, lifecycle_manager): + lifecycle_manager.start() + first_timer = lifecycle_manager._timer + + lifecycle_manager.start() + assert lifecycle_manager._timer is first_timer + + def test_enforce_rules_no_lifecycle(self, lifecycle_manager, storage): + storage.create_bucket("no-lifecycle-bucket") + + result = lifecycle_manager.enforce_rules("no-lifecycle-bucket") + assert result.bucket_name == "no-lifecycle-bucket" + assert result.objects_deleted == 0 + + def test_enforce_rules_disabled_rule(self, lifecycle_manager, storage): + storage.create_bucket("disabled-bucket") + storage.set_bucket_lifecycle("disabled-bucket", [ + { + "ID": "disabled-rule", + 
"Status": "Disabled", + "Prefix": "", + "Expiration": {"Days": 1}, + } + ]) + + old_object = storage.put_object( + "disabled-bucket", + "old-file.txt", + io.BytesIO(b"old content"), + ) + + result = lifecycle_manager.enforce_rules("disabled-bucket") + assert result.objects_deleted == 0 + + def test_enforce_expiration_by_days(self, lifecycle_manager, storage): + storage.create_bucket("expire-bucket") + storage.set_bucket_lifecycle("expire-bucket", [ + { + "ID": "expire-30-days", + "Status": "Enabled", + "Prefix": "", + "Expiration": {"Days": 30}, + } + ]) + + storage.put_object( + "expire-bucket", + "recent-file.txt", + io.BytesIO(b"recent content"), + ) + + result = lifecycle_manager.enforce_rules("expire-bucket") + assert result.objects_deleted == 0 + + def test_enforce_expiration_with_prefix(self, lifecycle_manager, storage): + storage.create_bucket("prefix-bucket") + storage.set_bucket_lifecycle("prefix-bucket", [ + { + "ID": "expire-logs", + "Status": "Enabled", + "Prefix": "logs/", + "Expiration": {"Days": 1}, + } + ]) + + storage.put_object("prefix-bucket", "logs/old.log", io.BytesIO(b"log data")) + storage.put_object("prefix-bucket", "data/keep.txt", io.BytesIO(b"keep this")) + + result = lifecycle_manager.enforce_rules("prefix-bucket") + + def test_enforce_all_buckets(self, lifecycle_manager, storage): + storage.create_bucket("bucket1") + storage.create_bucket("bucket2") + + results = lifecycle_manager.enforce_all_buckets() + assert isinstance(results, dict) + + def test_run_now_single_bucket(self, lifecycle_manager, storage): + storage.create_bucket("run-now-bucket") + + results = lifecycle_manager.run_now("run-now-bucket") + assert "run-now-bucket" in results + + def test_run_now_all_buckets(self, lifecycle_manager, storage): + storage.create_bucket("all-bucket-1") + storage.create_bucket("all-bucket-2") + + results = lifecycle_manager.run_now() + assert isinstance(results, dict) + + def test_enforce_abort_multipart(self, lifecycle_manager, storage): + storage.create_bucket("multipart-bucket") + storage.set_bucket_lifecycle("multipart-bucket", [ + { + "ID": "abort-old-uploads", + "Status": "Enabled", + "Prefix": "", + "AbortIncompleteMultipartUpload": {"DaysAfterInitiation": 7}, + } + ]) + + upload_id = storage.initiate_multipart_upload("multipart-bucket", "large-file.bin") + + result = lifecycle_manager.enforce_rules("multipart-bucket") + assert result.uploads_aborted == 0 + + def test_enforce_noncurrent_version_expiration(self, lifecycle_manager, storage): + storage.create_bucket("versioned-bucket") + storage.set_bucket_versioning("versioned-bucket", True) + storage.set_bucket_lifecycle("versioned-bucket", [ + { + "ID": "expire-old-versions", + "Status": "Enabled", + "Prefix": "", + "NoncurrentVersionExpiration": {"NoncurrentDays": 30}, + } + ]) + + storage.put_object("versioned-bucket", "file.txt", io.BytesIO(b"v1")) + storage.put_object("versioned-bucket", "file.txt", io.BytesIO(b"v2")) + + result = lifecycle_manager.enforce_rules("versioned-bucket") + assert result.bucket_name == "versioned-bucket" + + def test_execution_time_tracking(self, lifecycle_manager, storage): + storage.create_bucket("timed-bucket") + storage.set_bucket_lifecycle("timed-bucket", [ + { + "ID": "timer-test", + "Status": "Enabled", + "Expiration": {"Days": 1}, + } + ]) + + result = lifecycle_manager.enforce_rules("timed-bucket") + assert result.execution_time_seconds >= 0 + + def test_enforce_rules_with_error(self, lifecycle_manager, storage): + result = 
lifecycle_manager.enforce_rules("nonexistent-bucket") + assert len(result.errors) > 0 or result.objects_deleted == 0 + + def test_lifecycle_with_date_expiration(self, lifecycle_manager, storage): + storage.create_bucket("date-bucket") + past_date = (datetime.now(timezone.utc) - timedelta(days=1)).strftime("%Y-%m-%dT00:00:00Z") + storage.set_bucket_lifecycle("date-bucket", [ + { + "ID": "expire-by-date", + "Status": "Enabled", + "Prefix": "", + "Expiration": {"Date": past_date}, + } + ]) + + storage.put_object("date-bucket", "should-expire.txt", io.BytesIO(b"content")) + + result = lifecycle_manager.enforce_rules("date-bucket") + + def test_enforce_with_filter_prefix(self, lifecycle_manager, storage): + storage.create_bucket("filter-bucket") + storage.set_bucket_lifecycle("filter-bucket", [ + { + "ID": "filter-prefix-rule", + "Status": "Enabled", + "Filter": {"Prefix": "archive/"}, + "Expiration": {"Days": 1}, + } + ]) + + result = lifecycle_manager.enforce_rules("filter-bucket") + assert result.bucket_name == "filter-bucket" + + +class TestLifecycleManagerScheduling: + def test_schedule_next_respects_shutdown(self, storage): + manager = LifecycleManager(storage, interval_seconds=1) + manager._shutdown = True + manager._schedule_next() + assert manager._timer is None + + @patch.object(LifecycleManager, "enforce_all_buckets") + def test_run_enforcement_catches_exceptions(self, mock_enforce, storage): + mock_enforce.side_effect = Exception("Test error") + manager = LifecycleManager(storage, interval_seconds=3600) + manager._shutdown = True + manager._run_enforcement() + + def test_shutdown_flag_prevents_scheduling(self, storage): + manager = LifecycleManager(storage, interval_seconds=1) + manager.start() + manager.stop() + assert manager._shutdown is True diff --git a/tests/test_notifications.py b/tests/test_notifications.py new file mode 100644 index 0000000..36d7e03 --- /dev/null +++ b/tests/test_notifications.py @@ -0,0 +1,374 @@ +import json +import time +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from app.notifications import ( + NotificationConfiguration, + NotificationEvent, + NotificationService, + WebhookDestination, +) + + +class TestNotificationEvent: + def test_default_values(self): + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="test-bucket", + object_key="test/key.txt", + ) + assert event.event_name == "s3:ObjectCreated:Put" + assert event.bucket_name == "test-bucket" + assert event.object_key == "test/key.txt" + assert event.object_size == 0 + assert event.etag == "" + assert event.version_id is None + assert event.request_id != "" + + def test_to_s3_event(self): + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="my-bucket", + object_key="my/object.txt", + object_size=1024, + etag="abc123", + version_id="v1", + source_ip="192.168.1.1", + user_identity="user123", + ) + result = event.to_s3_event() + + assert "Records" in result + assert len(result["Records"]) == 1 + + record = result["Records"][0] + assert record["eventVersion"] == "2.1" + assert record["eventSource"] == "myfsio:s3" + assert record["eventName"] == "s3:ObjectCreated:Put" + assert record["s3"]["bucket"]["name"] == "my-bucket" + assert record["s3"]["object"]["key"] == "my/object.txt" + assert record["s3"]["object"]["size"] == 1024 + assert record["s3"]["object"]["eTag"] == "abc123" + assert record["s3"]["object"]["versionId"] == "v1" + assert 
record["userIdentity"]["principalId"] == "user123" + assert record["requestParameters"]["sourceIPAddress"] == "192.168.1.1" + + +class TestWebhookDestination: + def test_default_values(self): + dest = WebhookDestination(url="http://example.com/webhook") + assert dest.url == "http://example.com/webhook" + assert dest.headers == {} + assert dest.timeout_seconds == 30 + assert dest.retry_count == 3 + assert dest.retry_delay_seconds == 1 + + def test_to_dict(self): + dest = WebhookDestination( + url="http://example.com/webhook", + headers={"X-Custom": "value"}, + timeout_seconds=60, + retry_count=5, + retry_delay_seconds=2, + ) + result = dest.to_dict() + assert result["url"] == "http://example.com/webhook" + assert result["headers"] == {"X-Custom": "value"} + assert result["timeout_seconds"] == 60 + assert result["retry_count"] == 5 + assert result["retry_delay_seconds"] == 2 + + def test_from_dict(self): + data = { + "url": "http://hook.example.com", + "headers": {"Authorization": "Bearer token"}, + "timeout_seconds": 45, + "retry_count": 2, + "retry_delay_seconds": 5, + } + dest = WebhookDestination.from_dict(data) + assert dest.url == "http://hook.example.com" + assert dest.headers == {"Authorization": "Bearer token"} + assert dest.timeout_seconds == 45 + assert dest.retry_count == 2 + assert dest.retry_delay_seconds == 5 + + +class TestNotificationConfiguration: + def test_matches_event_exact_match(self): + config = NotificationConfiguration( + id="config1", + events=["s3:ObjectCreated:Put"], + destination=WebhookDestination(url="http://example.com"), + ) + assert config.matches_event("s3:ObjectCreated:Put", "any/key.txt") is True + assert config.matches_event("s3:ObjectCreated:Post", "any/key.txt") is False + + def test_matches_event_wildcard(self): + config = NotificationConfiguration( + id="config1", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com"), + ) + assert config.matches_event("s3:ObjectCreated:Put", "key.txt") is True + assert config.matches_event("s3:ObjectCreated:Copy", "key.txt") is True + assert config.matches_event("s3:ObjectRemoved:Delete", "key.txt") is False + + def test_matches_event_with_prefix_filter(self): + config = NotificationConfiguration( + id="config1", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com"), + prefix_filter="logs/", + ) + assert config.matches_event("s3:ObjectCreated:Put", "logs/app.log") is True + assert config.matches_event("s3:ObjectCreated:Put", "data/file.txt") is False + + def test_matches_event_with_suffix_filter(self): + config = NotificationConfiguration( + id="config1", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com"), + suffix_filter=".jpg", + ) + assert config.matches_event("s3:ObjectCreated:Put", "photos/image.jpg") is True + assert config.matches_event("s3:ObjectCreated:Put", "photos/image.png") is False + + def test_matches_event_with_both_filters(self): + config = NotificationConfiguration( + id="config1", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com"), + prefix_filter="images/", + suffix_filter=".png", + ) + assert config.matches_event("s3:ObjectCreated:Put", "images/photo.png") is True + assert config.matches_event("s3:ObjectCreated:Put", "images/photo.jpg") is False + assert config.matches_event("s3:ObjectCreated:Put", "documents/file.png") is False + + def test_to_dict(self): + config = NotificationConfiguration( + id="my-config", + 
events=["s3:ObjectCreated:Put", "s3:ObjectRemoved:Delete"], + destination=WebhookDestination(url="http://example.com"), + prefix_filter="logs/", + suffix_filter=".log", + ) + result = config.to_dict() + assert result["Id"] == "my-config" + assert result["Events"] == ["s3:ObjectCreated:Put", "s3:ObjectRemoved:Delete"] + assert "Destination" in result + assert result["Filter"]["Key"]["FilterRules"][0]["Value"] == "logs/" + assert result["Filter"]["Key"]["FilterRules"][1]["Value"] == ".log" + + def test_from_dict(self): + data = { + "Id": "parsed-config", + "Events": ["s3:ObjectCreated:*"], + "Destination": {"url": "http://hook.example.com"}, + "Filter": { + "Key": { + "FilterRules": [ + {"Name": "prefix", "Value": "data/"}, + {"Name": "suffix", "Value": ".csv"}, + ] + } + }, + } + config = NotificationConfiguration.from_dict(data) + assert config.id == "parsed-config" + assert config.events == ["s3:ObjectCreated:*"] + assert config.destination.url == "http://hook.example.com" + assert config.prefix_filter == "data/" + assert config.suffix_filter == ".csv" + + +@pytest.fixture +def notification_service(tmp_path: Path): + service = NotificationService(tmp_path, worker_count=1) + yield service + service.shutdown() + + +class TestNotificationService: + def test_get_bucket_notifications_empty(self, notification_service): + result = notification_service.get_bucket_notifications("nonexistent-bucket") + assert result == [] + + def test_set_and_get_bucket_notifications(self, notification_service): + configs = [ + NotificationConfiguration( + id="config1", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com/webhook1"), + ), + NotificationConfiguration( + id="config2", + events=["s3:ObjectRemoved:*"], + destination=WebhookDestination(url="http://example.com/webhook2"), + ), + ] + notification_service.set_bucket_notifications("my-bucket", configs) + + retrieved = notification_service.get_bucket_notifications("my-bucket") + assert len(retrieved) == 2 + assert retrieved[0].id == "config1" + assert retrieved[1].id == "config2" + + def test_delete_bucket_notifications(self, notification_service): + configs = [ + NotificationConfiguration( + id="to-delete", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com"), + ), + ] + notification_service.set_bucket_notifications("delete-bucket", configs) + assert len(notification_service.get_bucket_notifications("delete-bucket")) == 1 + + notification_service.delete_bucket_notifications("delete-bucket") + notification_service._configs.clear() + assert len(notification_service.get_bucket_notifications("delete-bucket")) == 0 + + def test_emit_event_no_config(self, notification_service): + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="no-config-bucket", + object_key="test.txt", + ) + notification_service.emit_event(event) + assert notification_service._stats["events_queued"] == 0 + + def test_emit_event_matching_config(self, notification_service): + configs = [ + NotificationConfiguration( + id="match-config", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com/webhook"), + ), + ] + notification_service.set_bucket_notifications("event-bucket", configs) + + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="event-bucket", + object_key="test.txt", + ) + notification_service.emit_event(event) + assert notification_service._stats["events_queued"] == 1 + + def test_emit_event_non_matching_config(self, 
notification_service): + configs = [ + NotificationConfiguration( + id="delete-only", + events=["s3:ObjectRemoved:*"], + destination=WebhookDestination(url="http://example.com/webhook"), + ), + ] + notification_service.set_bucket_notifications("delete-bucket", configs) + + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="delete-bucket", + object_key="test.txt", + ) + notification_service.emit_event(event) + assert notification_service._stats["events_queued"] == 0 + + def test_emit_object_created(self, notification_service): + configs = [ + NotificationConfiguration( + id="create-config", + events=["s3:ObjectCreated:Put"], + destination=WebhookDestination(url="http://example.com/webhook"), + ), + ] + notification_service.set_bucket_notifications("create-bucket", configs) + + notification_service.emit_object_created( + "create-bucket", + "new-file.txt", + size=1024, + etag="abc123", + operation="Put", + ) + assert notification_service._stats["events_queued"] == 1 + + def test_emit_object_removed(self, notification_service): + configs = [ + NotificationConfiguration( + id="remove-config", + events=["s3:ObjectRemoved:Delete"], + destination=WebhookDestination(url="http://example.com/webhook"), + ), + ] + notification_service.set_bucket_notifications("remove-bucket", configs) + + notification_service.emit_object_removed( + "remove-bucket", + "deleted-file.txt", + operation="Delete", + ) + assert notification_service._stats["events_queued"] == 1 + + def test_get_stats(self, notification_service): + stats = notification_service.get_stats() + assert "events_queued" in stats + assert "events_sent" in stats + assert "events_failed" in stats + + @patch("app.notifications.requests.post") + def test_send_notification_success(self, mock_post, notification_service): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="test-bucket", + object_key="test.txt", + ) + destination = WebhookDestination(url="http://example.com/webhook") + + notification_service._send_notification(event, destination) + mock_post.assert_called_once() + + @patch("app.notifications.requests.post") + def test_send_notification_retry_on_failure(self, mock_post, notification_service): + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_post.return_value = mock_response + + event = NotificationEvent( + event_name="s3:ObjectCreated:Put", + bucket_name="test-bucket", + object_key="test.txt", + ) + destination = WebhookDestination( + url="http://example.com/webhook", + retry_count=2, + retry_delay_seconds=0, + ) + + with pytest.raises(RuntimeError) as exc_info: + notification_service._send_notification(event, destination) + assert "Failed after 2 attempts" in str(exc_info.value) + assert mock_post.call_count == 2 + + def test_notification_caching(self, notification_service): + configs = [ + NotificationConfiguration( + id="cached-config", + events=["s3:ObjectCreated:*"], + destination=WebhookDestination(url="http://example.com"), + ), + ] + notification_service.set_bucket_notifications("cached-bucket", configs) + + notification_service.get_bucket_notifications("cached-bucket") + assert "cached-bucket" in notification_service._configs diff --git a/tests/test_object_lock.py b/tests/test_object_lock.py new file mode 100644 index 0000000..fa8da8b --- /dev/null +++ b/tests/test_object_lock.py @@ -0,0 +1,332 @@ +import 
json +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +from app.object_lock import ( + ObjectLockConfig, + ObjectLockError, + ObjectLockRetention, + ObjectLockService, + RetentionMode, +) + + +class TestRetentionMode: + def test_governance_mode(self): + assert RetentionMode.GOVERNANCE.value == "GOVERNANCE" + + def test_compliance_mode(self): + assert RetentionMode.COMPLIANCE.value == "COMPLIANCE" + + +class TestObjectLockRetention: + def test_to_dict(self): + retain_until = datetime(2025, 12, 31, 23, 59, 59, tzinfo=timezone.utc) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=retain_until, + ) + result = retention.to_dict() + assert result["Mode"] == "GOVERNANCE" + assert "2025-12-31" in result["RetainUntilDate"] + + def test_from_dict(self): + data = { + "Mode": "COMPLIANCE", + "RetainUntilDate": "2030-06-15T12:00:00+00:00", + } + retention = ObjectLockRetention.from_dict(data) + assert retention is not None + assert retention.mode == RetentionMode.COMPLIANCE + assert retention.retain_until_date.year == 2030 + + def test_from_dict_empty(self): + result = ObjectLockRetention.from_dict({}) + assert result is None + + def test_from_dict_missing_mode(self): + data = {"RetainUntilDate": "2030-06-15T12:00:00+00:00"} + result = ObjectLockRetention.from_dict(data) + assert result is None + + def test_from_dict_missing_date(self): + data = {"Mode": "GOVERNANCE"} + result = ObjectLockRetention.from_dict(data) + assert result is None + + def test_is_expired_future_date(self): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + assert retention.is_expired() is False + + def test_is_expired_past_date(self): + past = datetime.now(timezone.utc) - timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=past, + ) + assert retention.is_expired() is True + + +class TestObjectLockConfig: + def test_to_dict_enabled(self): + config = ObjectLockConfig(enabled=True) + result = config.to_dict() + assert result["ObjectLockEnabled"] == "Enabled" + + def test_to_dict_disabled(self): + config = ObjectLockConfig(enabled=False) + result = config.to_dict() + assert result["ObjectLockEnabled"] == "Disabled" + + def test_from_dict_enabled(self): + data = {"ObjectLockEnabled": "Enabled"} + config = ObjectLockConfig.from_dict(data) + assert config.enabled is True + + def test_from_dict_disabled(self): + data = {"ObjectLockEnabled": "Disabled"} + config = ObjectLockConfig.from_dict(data) + assert config.enabled is False + + def test_from_dict_with_default_retention_days(self): + data = { + "ObjectLockEnabled": "Enabled", + "Rule": { + "DefaultRetention": { + "Mode": "GOVERNANCE", + "Days": 30, + } + }, + } + config = ObjectLockConfig.from_dict(data) + assert config.enabled is True + assert config.default_retention is not None + assert config.default_retention.mode == RetentionMode.GOVERNANCE + + def test_from_dict_with_default_retention_years(self): + data = { + "ObjectLockEnabled": "Enabled", + "Rule": { + "DefaultRetention": { + "Mode": "COMPLIANCE", + "Years": 1, + } + }, + } + config = ObjectLockConfig.from_dict(data) + assert config.enabled is True + assert config.default_retention is not None + assert config.default_retention.mode == RetentionMode.COMPLIANCE + + +@pytest.fixture +def lock_service(tmp_path: Path): + return ObjectLockService(tmp_path) + + +class 
TestObjectLockService: + def test_get_bucket_lock_config_default(self, lock_service): + config = lock_service.get_bucket_lock_config("nonexistent-bucket") + assert config.enabled is False + assert config.default_retention is None + + def test_set_and_get_bucket_lock_config(self, lock_service): + config = ObjectLockConfig(enabled=True) + lock_service.set_bucket_lock_config("my-bucket", config) + + retrieved = lock_service.get_bucket_lock_config("my-bucket") + assert retrieved.enabled is True + + def test_enable_bucket_lock(self, lock_service): + lock_service.enable_bucket_lock("lock-bucket") + + config = lock_service.get_bucket_lock_config("lock-bucket") + assert config.enabled is True + + def test_is_bucket_lock_enabled(self, lock_service): + assert lock_service.is_bucket_lock_enabled("new-bucket") is False + + lock_service.enable_bucket_lock("new-bucket") + assert lock_service.is_bucket_lock_enabled("new-bucket") is True + + def test_get_object_retention_not_set(self, lock_service): + result = lock_service.get_object_retention("bucket", "key.txt") + assert result is None + + def test_set_and_get_object_retention(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "key.txt", retention) + + retrieved = lock_service.get_object_retention("bucket", "key.txt") + assert retrieved is not None + assert retrieved.mode == RetentionMode.GOVERNANCE + + def test_cannot_modify_compliance_retention(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.COMPLIANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "locked.txt", retention) + + new_retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future + timedelta(days=10), + ) + with pytest.raises(ObjectLockError) as exc_info: + lock_service.set_object_retention("bucket", "locked.txt", new_retention) + assert "COMPLIANCE" in str(exc_info.value) + + def test_cannot_modify_governance_without_bypass(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "gov.txt", retention) + + new_retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future + timedelta(days=10), + ) + with pytest.raises(ObjectLockError) as exc_info: + lock_service.set_object_retention("bucket", "gov.txt", new_retention) + assert "GOVERNANCE" in str(exc_info.value) + + def test_can_modify_governance_with_bypass(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "bypassable.txt", retention) + + new_retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future + timedelta(days=10), + ) + lock_service.set_object_retention("bucket", "bypassable.txt", new_retention, bypass_governance=True) + retrieved = lock_service.get_object_retention("bucket", "bypassable.txt") + assert retrieved.retain_until_date > future + + def test_can_modify_expired_retention(self, lock_service): + past = datetime.now(timezone.utc) - timedelta(days=30) + retention = ObjectLockRetention( + 
mode=RetentionMode.COMPLIANCE, + retain_until_date=past, + ) + lock_service.set_object_retention("bucket", "expired.txt", retention) + + future = datetime.now(timezone.utc) + timedelta(days=30) + new_retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "expired.txt", new_retention) + retrieved = lock_service.get_object_retention("bucket", "expired.txt") + assert retrieved.mode == RetentionMode.GOVERNANCE + + def test_get_legal_hold_not_set(self, lock_service): + result = lock_service.get_legal_hold("bucket", "key.txt") + assert result is False + + def test_set_and_get_legal_hold(self, lock_service): + lock_service.set_legal_hold("bucket", "held.txt", True) + assert lock_service.get_legal_hold("bucket", "held.txt") is True + + lock_service.set_legal_hold("bucket", "held.txt", False) + assert lock_service.get_legal_hold("bucket", "held.txt") is False + + def test_can_delete_object_no_lock(self, lock_service): + can_delete, reason = lock_service.can_delete_object("bucket", "unlocked.txt") + assert can_delete is True + assert reason == "" + + def test_cannot_delete_object_with_legal_hold(self, lock_service): + lock_service.set_legal_hold("bucket", "held.txt", True) + + can_delete, reason = lock_service.can_delete_object("bucket", "held.txt") + assert can_delete is False + assert "legal hold" in reason.lower() + + def test_cannot_delete_object_with_compliance_retention(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.COMPLIANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "compliant.txt", retention) + + can_delete, reason = lock_service.can_delete_object("bucket", "compliant.txt") + assert can_delete is False + assert "COMPLIANCE" in reason + + def test_cannot_delete_governance_without_bypass(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "governed.txt", retention) + + can_delete, reason = lock_service.can_delete_object("bucket", "governed.txt") + assert can_delete is False + assert "GOVERNANCE" in reason + + def test_can_delete_governance_with_bypass(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "governed.txt", retention) + + can_delete, reason = lock_service.can_delete_object("bucket", "governed.txt", bypass_governance=True) + assert can_delete is True + assert reason == "" + + def test_can_delete_expired_retention(self, lock_service): + past = datetime.now(timezone.utc) - timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.COMPLIANCE, + retain_until_date=past, + ) + lock_service.set_object_retention("bucket", "expired.txt", retention) + + can_delete, reason = lock_service.can_delete_object("bucket", "expired.txt") + assert can_delete is True + + def test_can_overwrite_is_same_as_delete(self, lock_service): + future = datetime.now(timezone.utc) + timedelta(days=30) + retention = ObjectLockRetention( + mode=RetentionMode.GOVERNANCE, + retain_until_date=future, + ) + lock_service.set_object_retention("bucket", "overwrite.txt", retention) + + can_overwrite, _ = lock_service.can_overwrite_object("bucket", 
"overwrite.txt") + can_delete, _ = lock_service.can_delete_object("bucket", "overwrite.txt") + assert can_overwrite == can_delete + + def test_delete_object_lock_metadata(self, lock_service): + lock_service.set_legal_hold("bucket", "cleanup.txt", True) + lock_service.delete_object_lock_metadata("bucket", "cleanup.txt") + + assert lock_service.get_legal_hold("bucket", "cleanup.txt") is False + + def test_config_caching(self, lock_service): + config = ObjectLockConfig(enabled=True) + lock_service.set_bucket_lock_config("cached-bucket", config) + + lock_service.get_bucket_lock_config("cached-bucket") + assert "cached-bucket" in lock_service._config_cache diff --git a/tests/test_replication.py b/tests/test_replication.py new file mode 100644 index 0000000..3cb0c06 --- /dev/null +++ b/tests/test_replication.py @@ -0,0 +1,285 @@ +import json +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from app.connections import ConnectionStore, RemoteConnection +from app.replication import ( + ReplicationManager, + ReplicationRule, + ReplicationStats, + REPLICATION_MODE_ALL, + REPLICATION_MODE_NEW_ONLY, + _create_s3_client, +) +from app.storage import ObjectStorage + + +@pytest.fixture +def storage(tmp_path: Path): + storage_root = tmp_path / "data" + storage_root.mkdir(parents=True) + return ObjectStorage(storage_root) + + +@pytest.fixture +def connections(tmp_path: Path): + connections_path = tmp_path / "connections.json" + store = ConnectionStore(connections_path) + conn = RemoteConnection( + id="test-conn", + name="Test Remote", + endpoint_url="http://localhost:9000", + access_key="remote-access", + secret_key="remote-secret", + region="us-east-1", + ) + store.add(conn) + return store + + +@pytest.fixture +def replication_manager(storage, connections, tmp_path): + rules_path = tmp_path / "replication_rules.json" + manager = ReplicationManager(storage, connections, rules_path) + yield manager + manager.shutdown(wait=False) + + +class TestReplicationStats: + def test_to_dict(self): + stats = ReplicationStats( + objects_synced=10, + objects_pending=5, + objects_orphaned=2, + bytes_synced=1024, + last_sync_at=1234567890.0, + last_sync_key="test/key.txt", + ) + result = stats.to_dict() + assert result["objects_synced"] == 10 + assert result["objects_pending"] == 5 + assert result["objects_orphaned"] == 2 + assert result["bytes_synced"] == 1024 + assert result["last_sync_at"] == 1234567890.0 + assert result["last_sync_key"] == "test/key.txt" + + def test_from_dict(self): + data = { + "objects_synced": 15, + "objects_pending": 3, + "objects_orphaned": 1, + "bytes_synced": 2048, + "last_sync_at": 9876543210.0, + "last_sync_key": "another/key.txt", + } + stats = ReplicationStats.from_dict(data) + assert stats.objects_synced == 15 + assert stats.objects_pending == 3 + assert stats.objects_orphaned == 1 + assert stats.bytes_synced == 2048 + assert stats.last_sync_at == 9876543210.0 + assert stats.last_sync_key == "another/key.txt" + + def test_from_dict_with_defaults(self): + stats = ReplicationStats.from_dict({}) + assert stats.objects_synced == 0 + assert stats.objects_pending == 0 + assert stats.objects_orphaned == 0 + assert stats.bytes_synced == 0 + assert stats.last_sync_at is None + assert stats.last_sync_key is None + + +class TestReplicationRule: + def test_to_dict(self): + rule = ReplicationRule( + bucket_name="source-bucket", + target_connection_id="test-conn", + target_bucket="dest-bucket", + enabled=True, + mode=REPLICATION_MODE_ALL, + 
created_at=1234567890.0, + ) + result = rule.to_dict() + assert result["bucket_name"] == "source-bucket" + assert result["target_connection_id"] == "test-conn" + assert result["target_bucket"] == "dest-bucket" + assert result["enabled"] is True + assert result["mode"] == REPLICATION_MODE_ALL + assert result["created_at"] == 1234567890.0 + assert "stats" in result + + def test_from_dict(self): + data = { + "bucket_name": "my-bucket", + "target_connection_id": "conn-123", + "target_bucket": "remote-bucket", + "enabled": False, + "mode": REPLICATION_MODE_NEW_ONLY, + "created_at": 1111111111.0, + "stats": {"objects_synced": 5}, + } + rule = ReplicationRule.from_dict(data) + assert rule.bucket_name == "my-bucket" + assert rule.target_connection_id == "conn-123" + assert rule.target_bucket == "remote-bucket" + assert rule.enabled is False + assert rule.mode == REPLICATION_MODE_NEW_ONLY + assert rule.created_at == 1111111111.0 + assert rule.stats.objects_synced == 5 + + def test_from_dict_defaults_mode(self): + data = { + "bucket_name": "my-bucket", + "target_connection_id": "conn-123", + "target_bucket": "remote-bucket", + } + rule = ReplicationRule.from_dict(data) + assert rule.mode == REPLICATION_MODE_NEW_ONLY + assert rule.created_at is None + + +class TestReplicationManager: + def test_get_rule_not_exists(self, replication_manager): + rule = replication_manager.get_rule("nonexistent-bucket") + assert rule is None + + def test_set_and_get_rule(self, replication_manager): + rule = ReplicationRule( + bucket_name="my-bucket", + target_connection_id="test-conn", + target_bucket="remote-bucket", + enabled=True, + mode=REPLICATION_MODE_NEW_ONLY, + created_at=time.time(), + ) + replication_manager.set_rule(rule) + + retrieved = replication_manager.get_rule("my-bucket") + assert retrieved is not None + assert retrieved.bucket_name == "my-bucket" + assert retrieved.target_connection_id == "test-conn" + assert retrieved.target_bucket == "remote-bucket" + + def test_delete_rule(self, replication_manager): + rule = ReplicationRule( + bucket_name="to-delete", + target_connection_id="test-conn", + target_bucket="remote-bucket", + ) + replication_manager.set_rule(rule) + assert replication_manager.get_rule("to-delete") is not None + + replication_manager.delete_rule("to-delete") + assert replication_manager.get_rule("to-delete") is None + + def test_save_and_reload_rules(self, replication_manager, tmp_path): + rule = ReplicationRule( + bucket_name="persistent-bucket", + target_connection_id="test-conn", + target_bucket="remote-bucket", + enabled=True, + ) + replication_manager.set_rule(rule) + + rules_path = tmp_path / "replication_rules.json" + assert rules_path.exists() + data = json.loads(rules_path.read_text()) + assert "persistent-bucket" in data + + @patch("app.replication._create_s3_client") + def test_check_endpoint_health_success(self, mock_create_client, replication_manager, connections): + mock_client = MagicMock() + mock_client.list_buckets.return_value = {"Buckets": []} + mock_create_client.return_value = mock_client + + conn = connections.get("test-conn") + result = replication_manager.check_endpoint_health(conn) + assert result is True + mock_client.list_buckets.assert_called_once() + + @patch("app.replication._create_s3_client") + def test_check_endpoint_health_failure(self, mock_create_client, replication_manager, connections): + mock_client = MagicMock() + mock_client.list_buckets.side_effect = Exception("Connection refused") + mock_create_client.return_value = mock_client + + conn = 
connections.get("test-conn") + result = replication_manager.check_endpoint_health(conn) + assert result is False + + def test_trigger_replication_no_rule(self, replication_manager): + replication_manager.trigger_replication("no-such-bucket", "test.txt", "write") + + def test_trigger_replication_disabled_rule(self, replication_manager): + rule = ReplicationRule( + bucket_name="disabled-bucket", + target_connection_id="test-conn", + target_bucket="remote-bucket", + enabled=False, + ) + replication_manager.set_rule(rule) + replication_manager.trigger_replication("disabled-bucket", "test.txt", "write") + + def test_trigger_replication_missing_connection(self, replication_manager): + rule = ReplicationRule( + bucket_name="orphan-bucket", + target_connection_id="missing-conn", + target_bucket="remote-bucket", + enabled=True, + ) + replication_manager.set_rule(rule) + replication_manager.trigger_replication("orphan-bucket", "test.txt", "write") + + def test_replicate_task_path_traversal_blocked(self, replication_manager, connections): + rule = ReplicationRule( + bucket_name="secure-bucket", + target_connection_id="test-conn", + target_bucket="remote-bucket", + enabled=True, + ) + replication_manager.set_rule(rule) + conn = connections.get("test-conn") + + replication_manager._replicate_task("secure-bucket", "../../../etc/passwd", rule, conn, "write") + replication_manager._replicate_task("secure-bucket", "/root/secret", rule, conn, "write") + replication_manager._replicate_task("secure-bucket", "..\\..\\windows\\system32", rule, conn, "write") + + +class TestCreateS3Client: + @patch("app.replication.boto3.client") + def test_creates_client_with_correct_config(self, mock_boto_client): + conn = RemoteConnection( + id="test", + name="Test", + endpoint_url="http://localhost:9000", + access_key="access", + secret_key="secret", + region="eu-west-1", + ) + _create_s3_client(conn) + + mock_boto_client.assert_called_once() + call_kwargs = mock_boto_client.call_args[1] + assert call_kwargs["endpoint_url"] == "http://localhost:9000" + assert call_kwargs["aws_access_key_id"] == "access" + assert call_kwargs["aws_secret_access_key"] == "secret" + assert call_kwargs["region_name"] == "eu-west-1" + + @patch("app.replication.boto3.client") + def test_health_check_mode_minimal_retries(self, mock_boto_client): + conn = RemoteConnection( + id="test", + name="Test", + endpoint_url="http://localhost:9000", + access_key="access", + secret_key="secret", + ) + _create_s3_client(conn, health_check=True) + + call_kwargs = mock_boto_client.call_args[1] + config = call_kwargs["config"] + assert config.retries["max_attempts"] == 1