diff --git a/app/__init__.py b/app/__init__.py index 2a95281..eb99fd2 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -18,8 +18,10 @@ from werkzeug.middleware.proxy_fix import ProxyFix from .bucket_policies import BucketPolicyStore from .config import AppConfig from .connections import ConnectionStore +from .encryption import EncryptionManager from .extensions import limiter, csrf from .iam import IamService +from .kms import KMSManager from .replication import ReplicationManager from .secret_store import EphemeralSecretStore from .storage import ObjectStorage @@ -77,6 +79,21 @@ def create_app( connections = ConnectionStore(connections_path) replication = ReplicationManager(storage, connections, replication_rules_path) + + # Initialize encryption and KMS + encryption_config = { + "encryption_enabled": app.config.get("ENCRYPTION_ENABLED", False), + "encryption_master_key_path": app.config.get("ENCRYPTION_MASTER_KEY_PATH"), + "default_encryption_algorithm": app.config.get("DEFAULT_ENCRYPTION_ALGORITHM", "AES256"), + } + encryption_manager = EncryptionManager(encryption_config) + + kms_manager = None + if app.config.get("KMS_ENABLED", False): + kms_keys_path = Path(app.config.get("KMS_KEYS_PATH", "")) + kms_master_key_path = Path(app.config.get("ENCRYPTION_MASTER_KEY_PATH", "")) + kms_manager = KMSManager(kms_keys_path, kms_master_key_path) + encryption_manager.set_kms_provider(kms_manager) app.extensions["object_storage"] = storage app.extensions["iam"] = iam @@ -85,6 +102,8 @@ def create_app( app.extensions["limiter"] = limiter app.extensions["connections"] = connections app.extensions["replication"] = replication + app.extensions["encryption"] = encryption_manager + app.extensions["kms"] = kms_manager @app.errorhandler(500) def internal_error(error): @@ -119,9 +138,12 @@ def create_app( if include_api: from .s3_api import s3_api_bp + from .kms_api import kms_api_bp app.register_blueprint(s3_api_bp) + app.register_blueprint(kms_api_bp) csrf.exempt(s3_api_bp) + csrf.exempt(kms_api_bp) if include_ui: from .ui import ui_bp diff --git a/app/config.py b/app/config.py index a5d3783..d8d9555 100644 --- a/app/config.py +++ b/app/config.py @@ -66,6 +66,11 @@ class AppConfig: stream_chunk_size: int multipart_min_part_size: int bucket_stats_cache_ttl: int + encryption_enabled: bool + encryption_master_key_path: Path + kms_enabled: bool + kms_keys_path: Path + default_encryption_algorithm: str @classmethod def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig": @@ -155,6 +160,14 @@ class AppConfig: ]) session_lifetime_days = int(_get("SESSION_LIFETIME_DAYS", 30)) bucket_stats_cache_ttl = int(_get("BUCKET_STATS_CACHE_TTL", 60)) # Default 60 seconds + + # Encryption settings + encryption_enabled = str(_get("ENCRYPTION_ENABLED", "0")).lower() in {"1", "true", "yes", "on"} + encryption_keys_dir = storage_root / ".myfsio.sys" / "keys" + encryption_master_key_path = Path(_get("ENCRYPTION_MASTER_KEY_PATH", encryption_keys_dir / "master.key")).resolve() + kms_enabled = str(_get("KMS_ENABLED", "0")).lower() in {"1", "true", "yes", "on"} + kms_keys_path = Path(_get("KMS_KEYS_PATH", encryption_keys_dir / "kms_keys.json")).resolve() + default_encryption_algorithm = str(_get("DEFAULT_ENCRYPTION_ALGORITHM", "AES256")) return cls(storage_root=storage_root, max_upload_size=max_upload_size, @@ -182,7 +195,12 @@ class AppConfig: secret_ttl_seconds=secret_ttl_seconds, stream_chunk_size=stream_chunk_size, multipart_min_part_size=multipart_min_part_size, - bucket_stats_cache_ttl=bucket_stats_cache_ttl) + bucket_stats_cache_ttl=bucket_stats_cache_ttl, + encryption_enabled=encryption_enabled, + encryption_master_key_path=encryption_master_key_path, + kms_enabled=kms_enabled, + kms_keys_path=kms_keys_path, + default_encryption_algorithm=default_encryption_algorithm) def to_flask_config(self) -> Dict[str, Any]: return { @@ -213,4 +231,9 @@ class AppConfig: "CORS_METHODS": self.cors_methods, "CORS_ALLOW_HEADERS": self.cors_allow_headers, "SESSION_LIFETIME_DAYS": self.session_lifetime_days, + "ENCRYPTION_ENABLED": self.encryption_enabled, + "ENCRYPTION_MASTER_KEY_PATH": str(self.encryption_master_key_path), + "KMS_ENABLED": self.kms_enabled, + "KMS_KEYS_PATH": str(self.kms_keys_path), + "DEFAULT_ENCRYPTION_ALGORITHM": self.default_encryption_algorithm, } diff --git a/app/encrypted_storage.py b/app/encrypted_storage.py new file mode 100644 index 0000000..ca4f138 --- /dev/null +++ b/app/encrypted_storage.py @@ -0,0 +1,268 @@ +"""Encrypted storage layer that wraps ObjectStorage with encryption support.""" +from __future__ import annotations + +import io +from pathlib import Path +from typing import Any, BinaryIO, Dict, Optional + +from .encryption import EncryptionManager, EncryptionMetadata, EncryptionError +from .storage import ObjectStorage, ObjectMeta, StorageError + + +class EncryptedObjectStorage: + """Object storage with transparent server-side encryption. + + This class wraps ObjectStorage and provides transparent encryption/decryption + of objects based on bucket encryption configuration. + + Encryption is applied when: + 1. Bucket has default encryption configured (SSE-S3 or SSE-KMS) + 2. Client explicitly requests encryption via headers + + The encryption metadata is stored alongside object metadata. + """ + + STREAMING_THRESHOLD = 64 * 1024 + + def __init__(self, storage: ObjectStorage, encryption_manager: EncryptionManager): + self.storage = storage + self.encryption = encryption_manager + + @property + def root(self) -> Path: + return self.storage.root + + def _should_encrypt(self, bucket_name: str, + server_side_encryption: str | None = None) -> tuple[bool, str, str | None]: + """Determine if object should be encrypted. + + Returns: + Tuple of (should_encrypt, algorithm, kms_key_id) + """ + if not self.encryption.enabled: + return False, "", None + + if server_side_encryption: + if server_side_encryption == "AES256": + return True, "AES256", None + elif server_side_encryption.startswith("aws:kms"): + parts = server_side_encryption.split(":") + kms_key_id = parts[2] if len(parts) > 2 else None + return True, "aws:kms", kms_key_id + + try: + encryption_config = self.storage.get_bucket_encryption(bucket_name) + if encryption_config and encryption_config.get("Rules"): + rule = encryption_config["Rules"][0] + algorithm = rule.get("SSEAlgorithm", "AES256") + kms_key_id = rule.get("KMSMasterKeyID") + return True, algorithm, kms_key_id + except StorageError: + pass + + return False, "", None + + def _is_encrypted(self, metadata: Dict[str, str]) -> bool: + """Check if object is encrypted based on its metadata.""" + return "x-amz-server-side-encryption" in metadata + + def put_object( + self, + bucket_name: str, + object_key: str, + stream: BinaryIO, + *, + metadata: Optional[Dict[str, str]] = None, + server_side_encryption: Optional[str] = None, + kms_key_id: Optional[str] = None, + ) -> ObjectMeta: + """Store an object, optionally with encryption. + + Args: + bucket_name: Name of the bucket + object_key: Key for the object + stream: Binary stream of object data + metadata: Optional user metadata + server_side_encryption: Encryption algorithm ("AES256" or "aws:kms") + kms_key_id: KMS key ID (for aws:kms encryption) + + Returns: + ObjectMeta with object information + """ + should_encrypt, algorithm, detected_kms_key = self._should_encrypt( + bucket_name, server_side_encryption + ) + + if kms_key_id is None: + kms_key_id = detected_kms_key + + if should_encrypt: + data = stream.read() + + try: + ciphertext, enc_metadata = self.encryption.encrypt_object( + data, + algorithm=algorithm, + kms_key_id=kms_key_id, + context={"bucket": bucket_name, "key": object_key}, + ) + + combined_metadata = metadata.copy() if metadata else {} + combined_metadata.update(enc_metadata.to_dict()) + + encrypted_stream = io.BytesIO(ciphertext) + result = self.storage.put_object( + bucket_name, + object_key, + encrypted_stream, + metadata=combined_metadata, + ) + + result.metadata = combined_metadata + return result + + except EncryptionError as exc: + raise StorageError(f"Encryption failed: {exc}") from exc + else: + return self.storage.put_object( + bucket_name, + object_key, + stream, + metadata=metadata, + ) + + def get_object_data(self, bucket_name: str, object_key: str) -> tuple[bytes, Dict[str, str]]: + """Get object data, decrypting if necessary. + + Returns: + Tuple of (data, metadata) + """ + path = self.storage.get_object_path(bucket_name, object_key) + metadata = self.storage.get_object_metadata(bucket_name, object_key) + + with path.open("rb") as f: + data = f.read() + + enc_metadata = EncryptionMetadata.from_dict(metadata) + if enc_metadata: + try: + data = self.encryption.decrypt_object( + data, + enc_metadata, + context={"bucket": bucket_name, "key": object_key}, + ) + except EncryptionError as exc: + raise StorageError(f"Decryption failed: {exc}") from exc + + clean_metadata = { + k: v for k, v in metadata.items() + if not k.startswith("x-amz-encryption") + and k != "x-amz-encrypted-data-key" + } + + return data, clean_metadata + + def get_object_stream(self, bucket_name: str, object_key: str) -> tuple[BinaryIO, Dict[str, str], int]: + """Get object as a stream, decrypting if necessary. + + Returns: + Tuple of (stream, metadata, original_size) + """ + data, metadata = self.get_object_data(bucket_name, object_key) + return io.BytesIO(data), metadata, len(data) + + def list_buckets(self): + return self.storage.list_buckets() + + def bucket_exists(self, bucket_name: str) -> bool: + return self.storage.bucket_exists(bucket_name) + + def create_bucket(self, bucket_name: str) -> None: + return self.storage.create_bucket(bucket_name) + + def delete_bucket(self, bucket_name: str) -> None: + return self.storage.delete_bucket(bucket_name) + + def bucket_stats(self, bucket_name: str, cache_ttl: int = 60): + return self.storage.bucket_stats(bucket_name, cache_ttl) + + def list_objects(self, bucket_name: str): + return self.storage.list_objects(bucket_name) + + def get_object_path(self, bucket_name: str, object_key: str): + return self.storage.get_object_path(bucket_name, object_key) + + def get_object_metadata(self, bucket_name: str, object_key: str): + return self.storage.get_object_metadata(bucket_name, object_key) + + def delete_object(self, bucket_name: str, object_key: str) -> None: + return self.storage.delete_object(bucket_name, object_key) + + def purge_object(self, bucket_name: str, object_key: str) -> None: + return self.storage.purge_object(bucket_name, object_key) + + def is_versioning_enabled(self, bucket_name: str) -> bool: + return self.storage.is_versioning_enabled(bucket_name) + + def set_bucket_versioning(self, bucket_name: str, enabled: bool) -> None: + return self.storage.set_bucket_versioning(bucket_name, enabled) + + def get_bucket_tags(self, bucket_name: str): + return self.storage.get_bucket_tags(bucket_name) + + def set_bucket_tags(self, bucket_name: str, tags): + return self.storage.set_bucket_tags(bucket_name, tags) + + def get_bucket_cors(self, bucket_name: str): + return self.storage.get_bucket_cors(bucket_name) + + def set_bucket_cors(self, bucket_name: str, rules): + return self.storage.set_bucket_cors(bucket_name, rules) + + def get_bucket_encryption(self, bucket_name: str): + return self.storage.get_bucket_encryption(bucket_name) + + def set_bucket_encryption(self, bucket_name: str, config_payload): + return self.storage.set_bucket_encryption(bucket_name, config_payload) + + def get_bucket_lifecycle(self, bucket_name: str): + return self.storage.get_bucket_lifecycle(bucket_name) + + def set_bucket_lifecycle(self, bucket_name: str, rules): + return self.storage.set_bucket_lifecycle(bucket_name, rules) + + def get_object_tags(self, bucket_name: str, object_key: str): + return self.storage.get_object_tags(bucket_name, object_key) + + def set_object_tags(self, bucket_name: str, object_key: str, tags): + return self.storage.set_object_tags(bucket_name, object_key, tags) + + def delete_object_tags(self, bucket_name: str, object_key: str): + return self.storage.delete_object_tags(bucket_name, object_key) + + def list_object_versions(self, bucket_name: str, object_key: str): + return self.storage.list_object_versions(bucket_name, object_key) + + def restore_object_version(self, bucket_name: str, object_key: str, version_id: str): + return self.storage.restore_object_version(bucket_name, object_key, version_id) + + def list_orphaned_objects(self, bucket_name: str): + return self.storage.list_orphaned_objects(bucket_name) + + def initiate_multipart_upload(self, bucket_name: str, object_key: str, *, metadata=None) -> str: + return self.storage.initiate_multipart_upload(bucket_name, object_key, metadata=metadata) + + def upload_multipart_part(self, bucket_name: str, upload_id: str, part_number: int, stream: BinaryIO) -> str: + return self.storage.upload_multipart_part(bucket_name, upload_id, part_number, stream) + + def complete_multipart_upload(self, bucket_name: str, upload_id: str, ordered_parts): + return self.storage.complete_multipart_upload(bucket_name, upload_id, ordered_parts) + + def abort_multipart_upload(self, bucket_name: str, upload_id: str) -> None: + return self.storage.abort_multipart_upload(bucket_name, upload_id) + + def list_multipart_parts(self, bucket_name: str, upload_id: str): + return self.storage.list_multipart_parts(bucket_name, upload_id) + + def _compute_etag(self, path: Path) -> str: + return self.storage._compute_etag(path) diff --git a/app/encryption.py b/app/encryption.py new file mode 100644 index 0000000..aa98cb4 --- /dev/null +++ b/app/encryption.py @@ -0,0 +1,395 @@ +"""Encryption providers for server-side and client-side encryption.""" +from __future__ import annotations + +import base64 +import io +import json +import secrets +from dataclasses import dataclass +from pathlib import Path +from typing import Any, BinaryIO, Dict, Generator, Optional + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + +class EncryptionError(Exception): + """Raised when encryption/decryption fails.""" + + +@dataclass +class EncryptionResult: + """Result of encrypting data.""" + ciphertext: bytes + nonce: bytes + key_id: str + encrypted_data_key: bytes + + +@dataclass +class EncryptionMetadata: + """Metadata stored with encrypted objects.""" + algorithm: str + key_id: str + nonce: bytes + encrypted_data_key: bytes + + def to_dict(self) -> Dict[str, str]: + return { + "x-amz-server-side-encryption": self.algorithm, + "x-amz-encryption-key-id": self.key_id, + "x-amz-encryption-nonce": base64.b64encode(self.nonce).decode(), + "x-amz-encrypted-data-key": base64.b64encode(self.encrypted_data_key).decode(), + } + + @classmethod + def from_dict(cls, data: Dict[str, str]) -> Optional["EncryptionMetadata"]: + algorithm = data.get("x-amz-server-side-encryption") + if not algorithm: + return None + try: + return cls( + algorithm=algorithm, + key_id=data.get("x-amz-encryption-key-id", "local"), + nonce=base64.b64decode(data.get("x-amz-encryption-nonce", "")), + encrypted_data_key=base64.b64decode(data.get("x-amz-encrypted-data-key", "")), + ) + except Exception: + return None + + +class EncryptionProvider: + """Base class for encryption providers.""" + + def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: + raise NotImplementedError + + def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, + key_id: str, context: Dict[str, str] | None = None) -> bytes: + raise NotImplementedError + + def generate_data_key(self) -> tuple[bytes, bytes]: + """Generate a data key and its encrypted form. + + Returns: + Tuple of (plaintext_key, encrypted_key) + """ + raise NotImplementedError + + +class LocalKeyEncryption(EncryptionProvider): + """SSE-S3 style encryption using a local master key. + + Uses envelope encryption: + 1. Generate a unique data key for each object + 2. Encrypt the data with the data key (AES-256-GCM) + 3. Encrypt the data key with the master key + 4. Store the encrypted data key alongside the ciphertext + """ + + KEY_ID = "local" + + def __init__(self, master_key_path: Path): + self.master_key_path = master_key_path + self._master_key: bytes | None = None + + @property + def master_key(self) -> bytes: + if self._master_key is None: + self._master_key = self._load_or_create_master_key() + return self._master_key + + def _load_or_create_master_key(self) -> bytes: + """Load master key from file or generate a new one.""" + if self.master_key_path.exists(): + try: + return base64.b64decode(self.master_key_path.read_text().strip()) + except Exception as exc: + raise EncryptionError(f"Failed to load master key: {exc}") from exc + + key = secrets.token_bytes(32) + try: + self.master_key_path.parent.mkdir(parents=True, exist_ok=True) + self.master_key_path.write_text(base64.b64encode(key).decode()) + except OSError as exc: + raise EncryptionError(f"Failed to save master key: {exc}") from exc + return key + + def _encrypt_data_key(self, data_key: bytes) -> bytes: + """Encrypt the data key with the master key.""" + aesgcm = AESGCM(self.master_key) + nonce = secrets.token_bytes(12) + encrypted = aesgcm.encrypt(nonce, data_key, None) + return nonce + encrypted + + def _decrypt_data_key(self, encrypted_data_key: bytes) -> bytes: + """Decrypt the data key using the master key.""" + if len(encrypted_data_key) < 12 + 32 + 16: # nonce + key + tag + raise EncryptionError("Invalid encrypted data key") + aesgcm = AESGCM(self.master_key) + nonce = encrypted_data_key[:12] + ciphertext = encrypted_data_key[12:] + try: + return aesgcm.decrypt(nonce, ciphertext, None) + except Exception as exc: + raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc + + def generate_data_key(self) -> tuple[bytes, bytes]: + """Generate a data key and its encrypted form.""" + plaintext_key = secrets.token_bytes(32) + encrypted_key = self._encrypt_data_key(plaintext_key) + return plaintext_key, encrypted_key + + def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: + """Encrypt data using envelope encryption.""" + data_key, encrypted_data_key = self.generate_data_key() + + aesgcm = AESGCM(data_key) + nonce = secrets.token_bytes(12) + ciphertext = aesgcm.encrypt(nonce, plaintext, None) + + return EncryptionResult( + ciphertext=ciphertext, + nonce=nonce, + key_id=self.KEY_ID, + encrypted_data_key=encrypted_data_key, + ) + + def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, + key_id: str, context: Dict[str, str] | None = None) -> bytes: + """Decrypt data using envelope encryption.""" + # Decrypt the data key + data_key = self._decrypt_data_key(encrypted_data_key) + + # Decrypt the data + aesgcm = AESGCM(data_key) + try: + return aesgcm.decrypt(nonce, ciphertext, None) + except Exception as exc: + raise EncryptionError(f"Failed to decrypt data: {exc}") from exc + + +class StreamingEncryptor: + """Encrypts/decrypts data in streaming fashion for large files. + + For large files, we encrypt in chunks. Each chunk is encrypted with the + same data key but a unique nonce derived from the base nonce + chunk index. + """ + + CHUNK_SIZE = 64 * 1024 + HEADER_SIZE = 4 + + def __init__(self, provider: EncryptionProvider, chunk_size: int = CHUNK_SIZE): + self.provider = provider + self.chunk_size = chunk_size + + def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes: + """Derive a unique nonce for each chunk.""" + # XOR the base nonce with the chunk index + nonce_int = int.from_bytes(base_nonce, "big") + derived = nonce_int ^ chunk_index + return derived.to_bytes(12, "big") + + def encrypt_stream(self, stream: BinaryIO, + context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]: + """Encrypt a stream and return encrypted stream + metadata.""" + + data_key, encrypted_data_key = self.provider.generate_data_key() + base_nonce = secrets.token_bytes(12) + + aesgcm = AESGCM(data_key) + encrypted_chunks = [] + chunk_index = 0 + + while True: + chunk = stream.read(self.chunk_size) + if not chunk: + break + + chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index) + encrypted_chunk = aesgcm.encrypt(chunk_nonce, chunk, None) + + size_prefix = len(encrypted_chunk).to_bytes(self.HEADER_SIZE, "big") + encrypted_chunks.append(size_prefix + encrypted_chunk) + chunk_index += 1 + + header = chunk_index.to_bytes(4, "big") + encrypted_data = header + b"".join(encrypted_chunks) + + metadata = EncryptionMetadata( + algorithm="AES256", + key_id=self.provider.KEY_ID if hasattr(self.provider, "KEY_ID") else "local", + nonce=base_nonce, + encrypted_data_key=encrypted_data_key, + ) + + return io.BytesIO(encrypted_data), metadata + + def decrypt_stream(self, stream: BinaryIO, metadata: EncryptionMetadata) -> BinaryIO: + """Decrypt a stream using the provided metadata.""" + if isinstance(self.provider, LocalKeyEncryption): + data_key = self.provider._decrypt_data_key(metadata.encrypted_data_key) + else: + raise EncryptionError("Unsupported provider for streaming decryption") + + aesgcm = AESGCM(data_key) + base_nonce = metadata.nonce + + chunk_count_bytes = stream.read(4) + if len(chunk_count_bytes) < 4: + raise EncryptionError("Invalid encrypted stream: missing header") + chunk_count = int.from_bytes(chunk_count_bytes, "big") + + decrypted_chunks = [] + for chunk_index in range(chunk_count): + size_bytes = stream.read(self.HEADER_SIZE) + if len(size_bytes) < self.HEADER_SIZE: + raise EncryptionError(f"Invalid encrypted stream: truncated at chunk {chunk_index}") + chunk_size = int.from_bytes(size_bytes, "big") + + encrypted_chunk = stream.read(chunk_size) + if len(encrypted_chunk) < chunk_size: + raise EncryptionError(f"Invalid encrypted stream: incomplete chunk {chunk_index}") + + chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index) + try: + decrypted_chunk = aesgcm.decrypt(chunk_nonce, encrypted_chunk, None) + decrypted_chunks.append(decrypted_chunk) + except Exception as exc: + raise EncryptionError(f"Failed to decrypt chunk {chunk_index}: {exc}") from exc + + return io.BytesIO(b"".join(decrypted_chunks)) + + +class EncryptionManager: + """Manages encryption providers and operations.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self._local_provider: LocalKeyEncryption | None = None + self._kms_provider: Any = None # Set by KMS module + self._streaming_encryptor: StreamingEncryptor | None = None + + @property + def enabled(self) -> bool: + return self.config.get("encryption_enabled", False) + + @property + def default_algorithm(self) -> str: + return self.config.get("default_encryption_algorithm", "AES256") + + def get_local_provider(self) -> LocalKeyEncryption: + if self._local_provider is None: + key_path = Path(self.config.get("encryption_master_key_path", "data/.myfsio.sys/keys/master.key")) + self._local_provider = LocalKeyEncryption(key_path) + return self._local_provider + + def set_kms_provider(self, kms_provider: Any) -> None: + """Set the KMS provider (injected from kms module).""" + self._kms_provider = kms_provider + + def get_provider(self, algorithm: str, kms_key_id: str | None = None) -> EncryptionProvider: + """Get the appropriate encryption provider for the algorithm.""" + if algorithm == "AES256": + return self.get_local_provider() + elif algorithm == "aws:kms": + if self._kms_provider is None: + raise EncryptionError("KMS is not configured") + return self._kms_provider.get_provider(kms_key_id) + else: + raise EncryptionError(f"Unsupported encryption algorithm: {algorithm}") + + def get_streaming_encryptor(self) -> StreamingEncryptor: + if self._streaming_encryptor is None: + self._streaming_encryptor = StreamingEncryptor(self.get_local_provider()) + return self._streaming_encryptor + + def encrypt_object(self, data: bytes, algorithm: str = "AES256", + kms_key_id: str | None = None, + context: Dict[str, str] | None = None) -> tuple[bytes, EncryptionMetadata]: + """Encrypt object data.""" + provider = self.get_provider(algorithm, kms_key_id) + result = provider.encrypt(data, context) + + metadata = EncryptionMetadata( + algorithm=algorithm, + key_id=result.key_id, + nonce=result.nonce, + encrypted_data_key=result.encrypted_data_key, + ) + + return result.ciphertext, metadata + + def decrypt_object(self, ciphertext: bytes, metadata: EncryptionMetadata, + context: Dict[str, str] | None = None) -> bytes: + """Decrypt object data.""" + provider = self.get_provider(metadata.algorithm, metadata.key_id) + return provider.decrypt( + ciphertext, + metadata.nonce, + metadata.encrypted_data_key, + metadata.key_id, + context, + ) + + def encrypt_stream(self, stream: BinaryIO, algorithm: str = "AES256", + context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]: + """Encrypt a stream for large files.""" + encryptor = self.get_streaming_encryptor() + return encryptor.encrypt_stream(stream, context) + + def decrypt_stream(self, stream: BinaryIO, metadata: EncryptionMetadata) -> BinaryIO: + """Decrypt a stream.""" + encryptor = self.get_streaming_encryptor() + return encryptor.decrypt_stream(stream, metadata) + + +class ClientEncryptionHelper: + """Helpers for client-side encryption. + + Client-side encryption is performed by the client, but this helper + provides key generation and materials for clients that need them. + """ + + @staticmethod + def generate_client_key() -> Dict[str, str]: + """Generate a new client encryption key.""" + from datetime import datetime, timezone + key = secrets.token_bytes(32) + return { + "key": base64.b64encode(key).decode(), + "algorithm": "AES-256-GCM", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + @staticmethod + def encrypt_with_key(plaintext: bytes, key_b64: str) -> Dict[str, str]: + """Encrypt data with a client-provided key.""" + key = base64.b64decode(key_b64) + if len(key) != 32: + raise EncryptionError("Key must be 256 bits (32 bytes)") + + aesgcm = AESGCM(key) + nonce = secrets.token_bytes(12) + ciphertext = aesgcm.encrypt(nonce, plaintext, None) + + return { + "ciphertext": base64.b64encode(ciphertext).decode(), + "nonce": base64.b64encode(nonce).decode(), + "algorithm": "AES-256-GCM", + } + + @staticmethod + def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str) -> bytes: + """Decrypt data with a client-provided key.""" + key = base64.b64decode(key_b64) + nonce = base64.b64decode(nonce_b64) + ciphertext = base64.b64decode(ciphertext_b64) + + if len(key) != 32: + raise EncryptionError("Key must be 256 bits (32 bytes)") + + aesgcm = AESGCM(key) + try: + return aesgcm.decrypt(nonce, ciphertext, None) + except Exception as exc: + raise EncryptionError(f"Decryption failed: {exc}") from exc diff --git a/app/kms.py b/app/kms.py new file mode 100644 index 0000000..8323749 --- /dev/null +++ b/app/kms.py @@ -0,0 +1,343 @@ +"""Key Management Service (KMS) for encryption key management.""" +from __future__ import annotations + +import base64 +import json +import secrets +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from .encryption import EncryptionError, EncryptionProvider, EncryptionResult + + +@dataclass +class KMSKey: + """Represents a KMS encryption key.""" + key_id: str + description: str + created_at: str + enabled: bool = True + key_material: bytes = field(default_factory=lambda: b"", repr=False) + + @property + def arn(self) -> str: + return f"arn:aws:kms:local:000000000000:key/{self.key_id}" + + def to_dict(self, include_key: bool = False) -> Dict[str, Any]: + data = { + "KeyId": self.key_id, + "Arn": self.arn, + "Description": self.description, + "CreationDate": self.created_at, + "Enabled": self.enabled, + "KeyState": "Enabled" if self.enabled else "Disabled", + "KeyUsage": "ENCRYPT_DECRYPT", + "KeySpec": "SYMMETRIC_DEFAULT", + } + if include_key: + data["KeyMaterial"] = base64.b64encode(self.key_material).decode() + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "KMSKey": + key_material = b"" + if "KeyMaterial" in data: + key_material = base64.b64decode(data["KeyMaterial"]) + return cls( + key_id=data["KeyId"], + description=data.get("Description", ""), + created_at=data.get("CreationDate", datetime.now(timezone.utc).isoformat()), + enabled=data.get("Enabled", True), + key_material=key_material, + ) + + +class KMSEncryptionProvider(EncryptionProvider): + """Encryption provider using a specific KMS key.""" + + def __init__(self, kms: "KMSManager", key_id: str): + self.kms = kms + self.key_id = key_id + + @property + def KEY_ID(self) -> str: + return self.key_id + + def generate_data_key(self) -> tuple[bytes, bytes]: + """Generate a data key encrypted with the KMS key.""" + return self.kms.generate_data_key(self.key_id) + + def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: + """Encrypt data using envelope encryption with KMS.""" + data_key, encrypted_data_key = self.generate_data_key() + + aesgcm = AESGCM(data_key) + nonce = secrets.token_bytes(12) + ciphertext = aesgcm.encrypt(nonce, plaintext, + json.dumps(context).encode() if context else None) + + return EncryptionResult( + ciphertext=ciphertext, + nonce=nonce, + key_id=self.key_id, + encrypted_data_key=encrypted_data_key, + ) + + def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, + key_id: str, context: Dict[str, str] | None = None) -> bytes: + """Decrypt data using envelope encryption with KMS.""" + data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context) + + aesgcm = AESGCM(data_key) + try: + return aesgcm.decrypt(nonce, ciphertext, + json.dumps(context).encode() if context else None) + except Exception as exc: + raise EncryptionError(f"Failed to decrypt data: {exc}") from exc + + +class KMSManager: + """Manages KMS keys and operations. + + This is a local implementation that mimics AWS KMS functionality. + Keys are stored encrypted on disk. + """ + + def __init__(self, keys_path: Path, master_key_path: Path): + self.keys_path = keys_path + self.master_key_path = master_key_path + self._keys: Dict[str, KMSKey] = {} + self._master_key: bytes | None = None + self._loaded = False + + @property + def master_key(self) -> bytes: + """Load or create the master key for encrypting KMS keys.""" + if self._master_key is None: + if self.master_key_path.exists(): + self._master_key = base64.b64decode( + self.master_key_path.read_text().strip() + ) + else: + self._master_key = secrets.token_bytes(32) + self.master_key_path.parent.mkdir(parents=True, exist_ok=True) + self.master_key_path.write_text( + base64.b64encode(self._master_key).decode() + ) + return self._master_key + + def _load_keys(self) -> None: + """Load keys from disk.""" + if self._loaded: + return + + if self.keys_path.exists(): + try: + data = json.loads(self.keys_path.read_text(encoding="utf-8")) + for key_data in data.get("keys", []): + key = KMSKey.from_dict(key_data) + if key_data.get("EncryptedKeyMaterial"): + encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"]) + key.key_material = self._decrypt_key_material(encrypted) + self._keys[key.key_id] = key + except Exception: + pass + + self._loaded = True + + def _save_keys(self) -> None: + """Save keys to disk (with encrypted key material).""" + keys_data = [] + for key in self._keys.values(): + data = key.to_dict(include_key=False) + encrypted = self._encrypt_key_material(key.key_material) + data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode() + keys_data.append(data) + + self.keys_path.parent.mkdir(parents=True, exist_ok=True) + self.keys_path.write_text( + json.dumps({"keys": keys_data}, indent=2), + encoding="utf-8" + ) + + def _encrypt_key_material(self, key_material: bytes) -> bytes: + """Encrypt key material with the master key.""" + aesgcm = AESGCM(self.master_key) + nonce = secrets.token_bytes(12) + ciphertext = aesgcm.encrypt(nonce, key_material, None) + return nonce + ciphertext + + def _decrypt_key_material(self, encrypted: bytes) -> bytes: + """Decrypt key material with the master key.""" + aesgcm = AESGCM(self.master_key) + nonce = encrypted[:12] + ciphertext = encrypted[12:] + return aesgcm.decrypt(nonce, ciphertext, None) + + def create_key(self, description: str = "", key_id: str | None = None) -> KMSKey: + """Create a new KMS key.""" + self._load_keys() + + if key_id is None: + key_id = str(uuid.uuid4()) + + if key_id in self._keys: + raise EncryptionError(f"Key already exists: {key_id}") + + key = KMSKey( + key_id=key_id, + description=description, + created_at=datetime.now(timezone.utc).isoformat(), + enabled=True, + key_material=secrets.token_bytes(32), + ) + + self._keys[key_id] = key + self._save_keys() + return key + + def get_key(self, key_id: str) -> KMSKey | None: + """Get a key by ID.""" + self._load_keys() + return self._keys.get(key_id) + + def list_keys(self) -> List[KMSKey]: + """List all keys.""" + self._load_keys() + return list(self._keys.values()) + + def enable_key(self, key_id: str) -> None: + """Enable a key.""" + self._load_keys() + key = self._keys.get(key_id) + if not key: + raise EncryptionError(f"Key not found: {key_id}") + key.enabled = True + self._save_keys() + + def disable_key(self, key_id: str) -> None: + """Disable a key.""" + self._load_keys() + key = self._keys.get(key_id) + if not key: + raise EncryptionError(f"Key not found: {key_id}") + key.enabled = False + self._save_keys() + + def delete_key(self, key_id: str) -> None: + """Delete a key (schedule for deletion in real KMS).""" + self._load_keys() + if key_id not in self._keys: + raise EncryptionError(f"Key not found: {key_id}") + del self._keys[key_id] + self._save_keys() + + def encrypt(self, key_id: str, plaintext: bytes, + context: Dict[str, str] | None = None) -> bytes: + """Encrypt data directly with a KMS key.""" + self._load_keys() + key = self._keys.get(key_id) + if not key: + raise EncryptionError(f"Key not found: {key_id}") + if not key.enabled: + raise EncryptionError(f"Key is disabled: {key_id}") + + aesgcm = AESGCM(key.key_material) + nonce = secrets.token_bytes(12) + aad = json.dumps(context).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) + + key_id_bytes = key_id.encode("utf-8") + return len(key_id_bytes).to_bytes(2, "big") + key_id_bytes + nonce + ciphertext + + def decrypt(self, ciphertext: bytes, + context: Dict[str, str] | None = None) -> tuple[bytes, str]: + """Decrypt data directly with a KMS key. + + Returns: + Tuple of (plaintext, key_id) + """ + self._load_keys() + + key_id_len = int.from_bytes(ciphertext[:2], "big") + key_id = ciphertext[2:2 + key_id_len].decode("utf-8") + rest = ciphertext[2 + key_id_len:] + + key = self._keys.get(key_id) + if not key: + raise EncryptionError(f"Key not found: {key_id}") + if not key.enabled: + raise EncryptionError(f"Key is disabled: {key_id}") + + nonce = rest[:12] + encrypted = rest[12:] + + aesgcm = AESGCM(key.key_material) + aad = json.dumps(context).encode() if context else None + try: + plaintext = aesgcm.decrypt(nonce, encrypted, aad) + return plaintext, key_id + except Exception as exc: + raise EncryptionError(f"Decryption failed: {exc}") from exc + + def generate_data_key(self, key_id: str, + context: Dict[str, str] | None = None) -> tuple[bytes, bytes]: + """Generate a data key and return both plaintext and encrypted versions. + + Returns: + Tuple of (plaintext_key, encrypted_key) + """ + self._load_keys() + key = self._keys.get(key_id) + if not key: + raise EncryptionError(f"Key not found: {key_id}") + if not key.enabled: + raise EncryptionError(f"Key is disabled: {key_id}") + + plaintext_key = secrets.token_bytes(32) + + encrypted_key = self.encrypt(key_id, plaintext_key, context) + + return plaintext_key, encrypted_key + + def decrypt_data_key(self, key_id: str, encrypted_key: bytes, + context: Dict[str, str] | None = None) -> bytes: + """Decrypt a data key.""" + plaintext, _ = self.decrypt(encrypted_key, context) + return plaintext + + def get_provider(self, key_id: str | None = None) -> KMSEncryptionProvider: + """Get an encryption provider for a specific key.""" + self._load_keys() + + if key_id is None: + if not self._keys: + key = self.create_key("Default KMS Key") + key_id = key.key_id + else: + key_id = next(iter(self._keys.keys())) + + if key_id not in self._keys: + raise EncryptionError(f"Key not found: {key_id}") + + return KMSEncryptionProvider(self, key_id) + + def re_encrypt(self, ciphertext: bytes, destination_key_id: str, + source_context: Dict[str, str] | None = None, + destination_context: Dict[str, str] | None = None) -> bytes: + """Re-encrypt data with a different key.""" + + plaintext, source_key_id = self.decrypt(ciphertext, source_context) + + return self.encrypt(destination_key_id, plaintext, destination_context) + + def generate_random(self, num_bytes: int = 32) -> bytes: + """Generate cryptographically secure random bytes.""" + if num_bytes < 1 or num_bytes > 1024: + raise EncryptionError("Number of bytes must be between 1 and 1024") + return secrets.token_bytes(num_bytes) diff --git a/app/kms_api.py b/app/kms_api.py new file mode 100644 index 0000000..551d262 --- /dev/null +++ b/app/kms_api.py @@ -0,0 +1,463 @@ +"""KMS and encryption API endpoints.""" +from __future__ import annotations + +import base64 +import uuid +from typing import Any, Dict + +from flask import Blueprint, Response, current_app, jsonify, request + +from .encryption import ClientEncryptionHelper, EncryptionError +from .extensions import limiter +from .iam import IamError + +kms_api_bp = Blueprint("kms_api", __name__, url_prefix="/kms") + + +def _require_principal(): + """Require authentication for KMS operations.""" + from .s3_api import _require_principal as s3_require_principal + return s3_require_principal() + + +def _kms(): + """Get KMS manager from app extensions.""" + return current_app.extensions.get("kms") + + +def _encryption(): + """Get encryption manager from app extensions.""" + return current_app.extensions.get("encryption") + + +def _error_response(code: str, message: str, status: int) -> tuple[Dict[str, Any], int]: + return {"__type": code, "message": message}, status + + +# ---------------------- Key Management ---------------------- + +@kms_api_bp.route("/keys", methods=["GET", "POST"]) +@limiter.limit("30 per minute") +def list_or_create_keys(): + """List all KMS keys or create a new key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + if request.method == "POST": + payload = request.get_json(silent=True) or {} + key_id = payload.get("KeyId") or payload.get("key_id") + description = payload.get("Description") or payload.get("description", "") + + try: + key = kms.create_key(description=description, key_id=key_id) + current_app.logger.info( + "KMS key created", + extra={"key_id": key.key_id, "principal": principal.access_key}, + ) + return jsonify({ + "KeyMetadata": key.to_dict(), + }) + except EncryptionError as exc: + return _error_response("KMSInternalException", str(exc), 400) + + # GET - List keys + keys = kms.list_keys() + return jsonify({ + "Keys": [{"KeyId": k.key_id, "KeyArn": k.arn} for k in keys], + "Truncated": False, + }) + + +@kms_api_bp.route("/keys/", methods=["GET", "DELETE"]) +@limiter.limit("30 per minute") +def get_or_delete_key(key_id: str): + """Get or delete a specific KMS key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + if request.method == "DELETE": + try: + kms.delete_key(key_id) + current_app.logger.info( + "KMS key deleted", + extra={"key_id": key_id, "principal": principal.access_key}, + ) + return Response(status=204) + except EncryptionError as exc: + return _error_response("NotFoundException", str(exc), 404) + + # GET + key = kms.get_key(key_id) + if not key: + return _error_response("NotFoundException", f"Key not found: {key_id}", 404) + + return jsonify({"KeyMetadata": key.to_dict()}) + + +@kms_api_bp.route("/keys//enable", methods=["POST"]) +@limiter.limit("30 per minute") +def enable_key(key_id: str): + """Enable a KMS key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + try: + kms.enable_key(key_id) + current_app.logger.info( + "KMS key enabled", + extra={"key_id": key_id, "principal": principal.access_key}, + ) + return Response(status=200) + except EncryptionError as exc: + return _error_response("NotFoundException", str(exc), 404) + + +@kms_api_bp.route("/keys//disable", methods=["POST"]) +@limiter.limit("30 per minute") +def disable_key(key_id: str): + """Disable a KMS key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + try: + kms.disable_key(key_id) + current_app.logger.info( + "KMS key disabled", + extra={"key_id": key_id, "principal": principal.access_key}, + ) + return Response(status=200) + except EncryptionError as exc: + return _error_response("NotFoundException", str(exc), 404) + + +# ---------------------- Encryption Operations ---------------------- + +@kms_api_bp.route("/encrypt", methods=["POST"]) +@limiter.limit("60 per minute") +def encrypt_data(): + """Encrypt data using a KMS key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + key_id = payload.get("KeyId") + plaintext_b64 = payload.get("Plaintext") + context = payload.get("EncryptionContext") + + if not key_id: + return _error_response("ValidationException", "KeyId is required", 400) + if not plaintext_b64: + return _error_response("ValidationException", "Plaintext is required", 400) + + try: + plaintext = base64.b64decode(plaintext_b64) + except Exception: + return _error_response("ValidationException", "Plaintext must be base64 encoded", 400) + + try: + ciphertext = kms.encrypt(key_id, plaintext, context) + return jsonify({ + "CiphertextBlob": base64.b64encode(ciphertext).decode(), + "KeyId": key_id, + "EncryptionAlgorithm": "SYMMETRIC_DEFAULT", + }) + except EncryptionError as exc: + return _error_response("KMSInternalException", str(exc), 400) + + +@kms_api_bp.route("/decrypt", methods=["POST"]) +@limiter.limit("60 per minute") +def decrypt_data(): + """Decrypt data using a KMS key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + ciphertext_b64 = payload.get("CiphertextBlob") + context = payload.get("EncryptionContext") + + if not ciphertext_b64: + return _error_response("ValidationException", "CiphertextBlob is required", 400) + + try: + ciphertext = base64.b64decode(ciphertext_b64) + except Exception: + return _error_response("ValidationException", "CiphertextBlob must be base64 encoded", 400) + + try: + plaintext, key_id = kms.decrypt(ciphertext, context) + return jsonify({ + "Plaintext": base64.b64encode(plaintext).decode(), + "KeyId": key_id, + "EncryptionAlgorithm": "SYMMETRIC_DEFAULT", + }) + except EncryptionError as exc: + return _error_response("InvalidCiphertextException", str(exc), 400) + + +@kms_api_bp.route("/generate-data-key", methods=["POST"]) +@limiter.limit("60 per minute") +def generate_data_key(): + """Generate a data encryption key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + key_id = payload.get("KeyId") + context = payload.get("EncryptionContext") + key_spec = payload.get("KeySpec", "AES_256") + + if not key_id: + return _error_response("ValidationException", "KeyId is required", 400) + + if key_spec not in {"AES_256", "AES_128"}: + return _error_response("ValidationException", "KeySpec must be AES_256 or AES_128", 400) + + try: + plaintext_key, encrypted_key = kms.generate_data_key(key_id, context) + + # Trim key if AES_128 requested + if key_spec == "AES_128": + plaintext_key = plaintext_key[:16] + + return jsonify({ + "Plaintext": base64.b64encode(plaintext_key).decode(), + "CiphertextBlob": base64.b64encode(encrypted_key).decode(), + "KeyId": key_id, + }) + except EncryptionError as exc: + return _error_response("KMSInternalException", str(exc), 400) + + +@kms_api_bp.route("/generate-data-key-without-plaintext", methods=["POST"]) +@limiter.limit("60 per minute") +def generate_data_key_without_plaintext(): + """Generate a data encryption key without returning the plaintext.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + key_id = payload.get("KeyId") + context = payload.get("EncryptionContext") + + if not key_id: + return _error_response("ValidationException", "KeyId is required", 400) + + try: + _, encrypted_key = kms.generate_data_key(key_id, context) + return jsonify({ + "CiphertextBlob": base64.b64encode(encrypted_key).decode(), + "KeyId": key_id, + }) + except EncryptionError as exc: + return _error_response("KMSInternalException", str(exc), 400) + + +@kms_api_bp.route("/re-encrypt", methods=["POST"]) +@limiter.limit("30 per minute") +def re_encrypt(): + """Re-encrypt data with a different key.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + ciphertext_b64 = payload.get("CiphertextBlob") + destination_key_id = payload.get("DestinationKeyId") + source_context = payload.get("SourceEncryptionContext") + destination_context = payload.get("DestinationEncryptionContext") + + if not ciphertext_b64: + return _error_response("ValidationException", "CiphertextBlob is required", 400) + if not destination_key_id: + return _error_response("ValidationException", "DestinationKeyId is required", 400) + + try: + ciphertext = base64.b64decode(ciphertext_b64) + except Exception: + return _error_response("ValidationException", "CiphertextBlob must be base64 encoded", 400) + + try: + # First decrypt, get source key id + plaintext, source_key_id = kms.decrypt(ciphertext, source_context) + + # Re-encrypt with destination key + new_ciphertext = kms.encrypt(destination_key_id, plaintext, destination_context) + + return jsonify({ + "CiphertextBlob": base64.b64encode(new_ciphertext).decode(), + "SourceKeyId": source_key_id, + "KeyId": destination_key_id, + }) + except EncryptionError as exc: + return _error_response("KMSInternalException", str(exc), 400) + + +@kms_api_bp.route("/generate-random", methods=["POST"]) +@limiter.limit("60 per minute") +def generate_random(): + """Generate random bytes.""" + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + num_bytes = payload.get("NumberOfBytes", 32) + + try: + num_bytes = int(num_bytes) + except (TypeError, ValueError): + return _error_response("ValidationException", "NumberOfBytes must be an integer", 400) + + try: + random_bytes = kms.generate_random(num_bytes) + return jsonify({ + "Plaintext": base64.b64encode(random_bytes).decode(), + }) + except EncryptionError as exc: + return _error_response("ValidationException", str(exc), 400) + + +# ---------------------- Client-Side Encryption Helpers ---------------------- + +@kms_api_bp.route("/client/generate-key", methods=["POST"]) +@limiter.limit("30 per minute") +def generate_client_key(): + """Generate a client-side encryption key.""" + principal, error = _require_principal() + if error: + return error + + key_info = ClientEncryptionHelper.generate_client_key() + return jsonify(key_info) + + +@kms_api_bp.route("/client/encrypt", methods=["POST"]) +@limiter.limit("60 per minute") +def client_encrypt(): + """Encrypt data using client-side encryption.""" + principal, error = _require_principal() + if error: + return error + + payload = request.get_json(silent=True) or {} + plaintext_b64 = payload.get("Plaintext") + key_b64 = payload.get("Key") + + if not plaintext_b64 or not key_b64: + return _error_response("ValidationException", "Plaintext and Key are required", 400) + + try: + plaintext = base64.b64decode(plaintext_b64) + result = ClientEncryptionHelper.encrypt_with_key(plaintext, key_b64) + return jsonify(result) + except Exception as exc: + return _error_response("EncryptionError", str(exc), 400) + + +@kms_api_bp.route("/client/decrypt", methods=["POST"]) +@limiter.limit("60 per minute") +def client_decrypt(): + """Decrypt data using client-side encryption.""" + principal, error = _require_principal() + if error: + return error + + payload = request.get_json(silent=True) or {} + ciphertext_b64 = payload.get("Ciphertext") or payload.get("ciphertext") + nonce_b64 = payload.get("Nonce") or payload.get("nonce") + key_b64 = payload.get("Key") or payload.get("key") + + if not ciphertext_b64 or not nonce_b64 or not key_b64: + return _error_response("ValidationException", "Ciphertext, Nonce, and Key are required", 400) + + try: + plaintext = ClientEncryptionHelper.decrypt_with_key(ciphertext_b64, nonce_b64, key_b64) + return jsonify({ + "Plaintext": base64.b64encode(plaintext).decode(), + }) + except Exception as exc: + return _error_response("DecryptionError", str(exc), 400) + + +# ---------------------- Encryption Materials for S3 Client-Side Encryption ---------------------- + +@kms_api_bp.route("/materials/", methods=["POST"]) +@limiter.limit("60 per minute") +def get_encryption_materials(key_id: str): + """Get encryption materials for client-side S3 encryption. + + This is used by S3 encryption clients that want to use KMS for + key management but perform encryption client-side. + """ + principal, error = _require_principal() + if error: + return error + + kms = _kms() + if not kms: + return _error_response("KMSNotEnabled", "KMS is not configured", 400) + + payload = request.get_json(silent=True) or {} + context = payload.get("EncryptionContext") + + try: + plaintext_key, encrypted_key = kms.generate_data_key(key_id, context) + + return jsonify({ + "PlaintextKey": base64.b64encode(plaintext_key).decode(), + "EncryptedKey": base64.b64encode(encrypted_key).decode(), + "KeyId": key_id, + "Algorithm": "AES-256-GCM", + "KeyWrapAlgorithm": "kms", + }) + except EncryptionError as exc: + return _error_response("KMSInternalException", str(exc), 400) diff --git a/app/ui.py b/app/ui.py index e68c60e..20e1d14 100644 --- a/app/ui.py +++ b/app/ui.py @@ -30,6 +30,7 @@ from .bucket_policies import BucketPolicyStore from .connections import ConnectionStore, RemoteConnection from .extensions import limiter from .iam import IamError +from .kms import KMSManager from .replication import ReplicationManager, ReplicationRule from .secret_store import EphemeralSecretStore from .storage import ObjectStorage, StorageError @@ -50,6 +51,9 @@ def _iam(): return current_app.extensions["iam"] +def _kms() -> KMSManager | None: + return current_app.extensions.get("kms") + def _bucket_policies() -> BucketPolicyStore: store: BucketPolicyStore = current_app.extensions["bucket_policies"] @@ -360,6 +364,14 @@ def bucket_detail(bucket_name: str): # Load connections for admin, or for non-admin if there's an existing rule (to show target name) connections = _connections().list() if (is_replication_admin or replication_rule) else [] + # Encryption settings + encryption_config = storage.get_bucket_encryption(bucket_name) + kms_manager = _kms() + kms_keys = kms_manager.list_keys() if kms_manager else [] + kms_enabled = current_app.config.get("KMS_ENABLED", False) + encryption_enabled = current_app.config.get("ENCRYPTION_ENABLED", False) + can_manage_encryption = can_manage_versioning # Same as other bucket properties + return render_template( "bucket_detail.html", bucket_name=bucket_name, @@ -370,11 +382,16 @@ def bucket_detail(bucket_name: str): can_edit_policy=can_edit_policy, can_manage_versioning=can_manage_versioning, can_manage_replication=can_manage_replication, + can_manage_encryption=can_manage_encryption, is_replication_admin=is_replication_admin, default_policy=default_policy, versioning_enabled=versioning_enabled, replication_rule=replication_rule, connections=connections, + encryption_config=encryption_config, + kms_keys=kms_keys, + kms_enabled=kms_enabled, + encryption_enabled=encryption_enabled, ) @@ -878,6 +895,62 @@ def update_bucket_versioning(bucket_name: str): return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) +@ui_bp.post("/buckets//encryption") +def update_bucket_encryption(bucket_name: str): + """Update bucket default encryption configuration.""" + principal = _current_principal() + try: + _authorize_ui(principal, bucket_name, "write") + except IamError as exc: + flash(_friendly_error_message(exc), "danger") + return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) + + action = request.form.get("action", "enable") + + if action == "disable": + # Disable encryption + try: + _storage().set_bucket_encryption(bucket_name, None) + flash("Default encryption disabled", "info") + except StorageError as exc: + flash(_friendly_error_message(exc), "danger") + return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) + + # Enable or update encryption + algorithm = request.form.get("algorithm", "AES256") + kms_key_id = request.form.get("kms_key_id", "").strip() or None + + # Validate algorithm + if algorithm not in ("AES256", "aws:kms"): + flash("Invalid encryption algorithm", "danger") + return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) + + # Build encryption config following AWS format + encryption_config: dict[str, Any] = { + "Rules": [ + { + "ApplyServerSideEncryptionByDefault": { + "SSEAlgorithm": algorithm, + } + } + ] + } + + if algorithm == "aws:kms" and kms_key_id: + encryption_config["Rules"][0]["ApplyServerSideEncryptionByDefault"]["KMSMasterKeyID"] = kms_key_id + + try: + _storage().set_bucket_encryption(bucket_name, encryption_config) + if algorithm == "aws:kms": + flash("Default KMS encryption enabled", "success") + else: + flash("Default AES-256 encryption enabled", "success") + except StorageError as exc: + flash(_friendly_error_message(exc), "danger") + + return redirect(url_for("ui.bucket_detail", bucket_name=bucket_name, tab="properties")) + + @ui_bp.get("/iam") def iam_dashboard(): principal = _current_principal() diff --git a/requirements.txt b/requirements.txt index 7c2c75d..1f225cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ requests>=2.31 boto3>=1.34 waitress>=2.1.2 psutil>=5.9.0 +cryptography>=41.0.0 diff --git a/templates/bucket_detail.html b/templates/bucket_detail.html index 3e6adcc..5017476 100644 --- a/templates/bucket_detail.html +++ b/templates/bucket_detail.html @@ -607,6 +607,129 @@ {% endif %} + + + {% if encryption_enabled %} +
+
+ + + + + Default Encryption +
+
+ {% set enc_rules = encryption_config.get('Rules', []) %} + {% set enc_default = enc_rules[0].get('ApplyServerSideEncryptionByDefault', {}) if enc_rules else {} %} + {% set enc_algorithm = enc_default.get('SSEAlgorithm', '') %} + {% set enc_kms_key = enc_default.get('KMSMasterKeyID', '') %} + + {% if enc_algorithm %} + + + {% else %} + + + {% endif %} + + {% if can_manage_encryption %} +
+ + + + +
+ +
+
+
+ + +
+ {% if kms_enabled %} +
+ + +
+ {% endif %} +
+
+
+ + + {% if kms_enabled %} +
+ + +
Select a KMS key to encrypt objects. Leave empty to use the default key.
+
+ {% endif %} + +
+ + {% if enc_algorithm %} + + {% endif %} +
+
+ {% else %} +
+ + + +

You do not have permission to modify encryption settings for this bucket.

+
+ {% endif %} +
+
+ {% endif %} @@ -656,6 +779,35 @@ {% endif %} + + {% if encryption_enabled %} +
+
+
+ + + + About Encryption +
+

+ Server-side encryption protects data at rest. Objects are encrypted when stored and decrypted when retrieved. +

+ +
Encryption Types
+
    +
  • SSE-S3 (AES-256) — S3-managed keys, automatic encryption
  • +
  • SSE-KMS — KMS-managed keys with audit trail and key rotation
  • +
+ +
How It Works
+
    +
  • New objects are encrypted using the default setting
  • +
  • Existing objects are not automatically re-encrypted
  • +
  • Decryption is transparent during download
  • +
+
+
+ {% endif %} @@ -3147,5 +3299,33 @@ loadReplicationStats(); }); } + + // Encryption settings + const algoAes256Radio = document.getElementById('algo_aes256'); + const algoKmsRadio = document.getElementById('algo_kms'); + const kmsKeySection = document.getElementById('kmsKeySection'); + const encryptionForm = document.getElementById('encryptionForm'); + const encryptionAction = document.getElementById('encryptionAction'); + const disableEncryptionBtn = document.getElementById('disableEncryptionBtn'); + + // Toggle KMS key section visibility based on selected algorithm + const updateKmsKeyVisibility = () => { + if (!kmsKeySection) return; + const showKms = algoKmsRadio?.checked; + kmsKeySection.style.display = showKms ? '' : 'none'; + }; + + algoAes256Radio?.addEventListener('change', updateKmsKeyVisibility); + algoKmsRadio?.addEventListener('change', updateKmsKeyVisibility); + + // Handle disable encryption button + disableEncryptionBtn?.addEventListener('click', () => { + if (encryptionAction && encryptionForm) { + if (confirm('Are you sure you want to disable default encryption? New objects will not be encrypted automatically.')) { + encryptionAction.value = 'disable'; + encryptionForm.submit(); + } + } + }); {% endblock %} diff --git a/tests/test_encryption.py b/tests/test_encryption.py new file mode 100644 index 0000000..3493622 --- /dev/null +++ b/tests/test_encryption.py @@ -0,0 +1,763 @@ +"""Tests for encryption functionality.""" +from __future__ import annotations + +import base64 +import io +import json +import os +import secrets +import tempfile +from pathlib import Path + +import pytest + + +class TestLocalKeyEncryption: + """Tests for LocalKeyEncryption provider.""" + + def test_create_master_key(self, tmp_path): + """Test that master key is created if it doesn't exist.""" + from app.encryption import LocalKeyEncryption + + key_path = tmp_path / "keys" / "master.key" + provider = LocalKeyEncryption(key_path) + + # Access master key to trigger creation + key = provider.master_key + + assert key_path.exists() + assert len(key) == 32 # 256-bit key + + def test_load_existing_master_key(self, tmp_path): + """Test loading an existing master key.""" + from app.encryption import LocalKeyEncryption + + key_path = tmp_path / "master.key" + original_key = secrets.token_bytes(32) + key_path.write_text(base64.b64encode(original_key).decode()) + + provider = LocalKeyEncryption(key_path) + loaded_key = provider.master_key + + assert loaded_key == original_key + + def test_encrypt_decrypt_roundtrip(self, tmp_path): + """Test that data can be encrypted and decrypted correctly.""" + from app.encryption import LocalKeyEncryption + + key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(key_path) + + plaintext = b"Hello, World! This is a test message." + + # Encrypt + result = provider.encrypt(plaintext) + + assert result.ciphertext != plaintext + assert result.key_id == "local" + assert len(result.nonce) == 12 + assert len(result.encrypted_data_key) > 0 + + # Decrypt + decrypted = provider.decrypt( + result.ciphertext, + result.nonce, + result.encrypted_data_key, + result.key_id, + ) + + assert decrypted == plaintext + + def test_different_data_keys_per_encryption(self, tmp_path): + """Test that each encryption uses a different data key.""" + from app.encryption import LocalKeyEncryption + + key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(key_path) + + plaintext = b"Same message" + + result1 = provider.encrypt(plaintext) + result2 = provider.encrypt(plaintext) + + # Different encrypted data keys + assert result1.encrypted_data_key != result2.encrypted_data_key + # Different nonces + assert result1.nonce != result2.nonce + # Different ciphertexts + assert result1.ciphertext != result2.ciphertext + + def test_generate_data_key(self, tmp_path): + """Test data key generation.""" + from app.encryption import LocalKeyEncryption + + key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(key_path) + + plaintext_key, encrypted_key = provider.generate_data_key() + + assert len(plaintext_key) == 32 + assert len(encrypted_key) > 32 # nonce + ciphertext + tag + + # Verify we can decrypt the key + decrypted_key = provider._decrypt_data_key(encrypted_key) + assert decrypted_key == plaintext_key + + def test_decrypt_with_wrong_key_fails(self, tmp_path): + """Test that decryption fails with wrong master key.""" + from app.encryption import LocalKeyEncryption, EncryptionError + + # Create two providers with different keys + key_path1 = tmp_path / "master1.key" + key_path2 = tmp_path / "master2.key" + + provider1 = LocalKeyEncryption(key_path1) + provider2 = LocalKeyEncryption(key_path2) + + # Encrypt with provider1 + plaintext = b"Secret message" + result = provider1.encrypt(plaintext) + + # Try to decrypt with provider2 + with pytest.raises(EncryptionError): + provider2.decrypt( + result.ciphertext, + result.nonce, + result.encrypted_data_key, + result.key_id, + ) + + +class TestEncryptionMetadata: + """Tests for EncryptionMetadata class.""" + + def test_to_dict(self): + """Test converting metadata to dictionary.""" + from app.encryption import EncryptionMetadata + + nonce = secrets.token_bytes(12) + encrypted_key = secrets.token_bytes(60) + + metadata = EncryptionMetadata( + algorithm="AES256", + key_id="local", + nonce=nonce, + encrypted_data_key=encrypted_key, + ) + + result = metadata.to_dict() + + assert result["x-amz-server-side-encryption"] == "AES256" + assert result["x-amz-encryption-key-id"] == "local" + assert base64.b64decode(result["x-amz-encryption-nonce"]) == nonce + assert base64.b64decode(result["x-amz-encrypted-data-key"]) == encrypted_key + + def test_from_dict(self): + """Test creating metadata from dictionary.""" + from app.encryption import EncryptionMetadata + + nonce = secrets.token_bytes(12) + encrypted_key = secrets.token_bytes(60) + + data = { + "x-amz-server-side-encryption": "AES256", + "x-amz-encryption-key-id": "local", + "x-amz-encryption-nonce": base64.b64encode(nonce).decode(), + "x-amz-encrypted-data-key": base64.b64encode(encrypted_key).decode(), + } + + metadata = EncryptionMetadata.from_dict(data) + + assert metadata is not None + assert metadata.algorithm == "AES256" + assert metadata.key_id == "local" + assert metadata.nonce == nonce + assert metadata.encrypted_data_key == encrypted_key + + def test_from_dict_returns_none_for_unencrypted(self): + """Test that from_dict returns None for unencrypted objects.""" + from app.encryption import EncryptionMetadata + + data = {"some-other-key": "value"} + + metadata = EncryptionMetadata.from_dict(data) + + assert metadata is None + + +class TestStreamingEncryptor: + """Tests for streaming encryption.""" + + def test_encrypt_decrypt_stream(self, tmp_path): + """Test streaming encryption and decryption.""" + from app.encryption import LocalKeyEncryption, StreamingEncryptor + + key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(key_path) + encryptor = StreamingEncryptor(provider, chunk_size=1024) + + # Create test data + original_data = b"A" * 5000 + b"B" * 5000 + b"C" * 5000 # 15KB + stream = io.BytesIO(original_data) + + # Encrypt + encrypted_stream, metadata = encryptor.encrypt_stream(stream) + encrypted_data = encrypted_stream.read() + + assert encrypted_data != original_data + assert metadata.algorithm == "AES256" + + # Decrypt + encrypted_stream = io.BytesIO(encrypted_data) + decrypted_stream = encryptor.decrypt_stream(encrypted_stream, metadata) + decrypted_data = decrypted_stream.read() + + assert decrypted_data == original_data + + def test_encrypt_small_data(self, tmp_path): + """Test encrypting data smaller than chunk size.""" + from app.encryption import LocalKeyEncryption, StreamingEncryptor + + key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(key_path) + encryptor = StreamingEncryptor(provider, chunk_size=1024) + + original_data = b"Small data" + stream = io.BytesIO(original_data) + + encrypted_stream, metadata = encryptor.encrypt_stream(stream) + encrypted_stream.seek(0) + + decrypted_stream = encryptor.decrypt_stream(encrypted_stream, metadata) + decrypted_data = decrypted_stream.read() + + assert decrypted_data == original_data + + def test_encrypt_empty_data(self, tmp_path): + """Test encrypting empty data.""" + from app.encryption import LocalKeyEncryption, StreamingEncryptor + + key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(key_path) + encryptor = StreamingEncryptor(provider) + + stream = io.BytesIO(b"") + + encrypted_stream, metadata = encryptor.encrypt_stream(stream) + encrypted_stream.seek(0) + + decrypted_stream = encryptor.decrypt_stream(encrypted_stream, metadata) + decrypted_data = decrypted_stream.read() + + assert decrypted_data == b"" + + +class TestEncryptionManager: + """Tests for EncryptionManager.""" + + def test_encryption_disabled_by_default(self, tmp_path): + """Test that encryption is disabled by default.""" + from app.encryption import EncryptionManager + + config = { + "encryption_enabled": False, + "encryption_master_key_path": str(tmp_path / "master.key"), + } + + manager = EncryptionManager(config) + + assert not manager.enabled + + def test_encryption_enabled(self, tmp_path): + """Test enabling encryption.""" + from app.encryption import EncryptionManager + + config = { + "encryption_enabled": True, + "encryption_master_key_path": str(tmp_path / "master.key"), + "default_encryption_algorithm": "AES256", + } + + manager = EncryptionManager(config) + + assert manager.enabled + assert manager.default_algorithm == "AES256" + + def test_encrypt_decrypt_object(self, tmp_path): + """Test encrypting and decrypting an object.""" + from app.encryption import EncryptionManager + + config = { + "encryption_enabled": True, + "encryption_master_key_path": str(tmp_path / "master.key"), + } + + manager = EncryptionManager(config) + + plaintext = b"Object data to encrypt" + + ciphertext, metadata = manager.encrypt_object(plaintext) + + assert ciphertext != plaintext + assert metadata.algorithm == "AES256" + + decrypted = manager.decrypt_object(ciphertext, metadata) + + assert decrypted == plaintext + + +class TestClientEncryptionHelper: + """Tests for client-side encryption helpers.""" + + def test_generate_client_key(self): + """Test generating a client encryption key.""" + from app.encryption import ClientEncryptionHelper + + key_info = ClientEncryptionHelper.generate_client_key() + + assert "key" in key_info + assert key_info["algorithm"] == "AES-256-GCM" + assert "created_at" in key_info + + # Verify key is 256 bits + key = base64.b64decode(key_info["key"]) + assert len(key) == 32 + + def test_encrypt_with_key(self): + """Test encrypting data with a client key.""" + from app.encryption import ClientEncryptionHelper + + key = base64.b64encode(secrets.token_bytes(32)).decode() + plaintext = b"Client-side encrypted data" + + result = ClientEncryptionHelper.encrypt_with_key(plaintext, key) + + assert "ciphertext" in result + assert "nonce" in result + assert result["algorithm"] == "AES-256-GCM" + + def test_encrypt_decrypt_with_key(self): + """Test round-trip client-side encryption.""" + from app.encryption import ClientEncryptionHelper + + key = base64.b64encode(secrets.token_bytes(32)).decode() + plaintext = b"Client-side encrypted data" + + encrypted = ClientEncryptionHelper.encrypt_with_key(plaintext, key) + + decrypted = ClientEncryptionHelper.decrypt_with_key( + encrypted["ciphertext"], + encrypted["nonce"], + key, + ) + + assert decrypted == plaintext + + def test_wrong_key_fails(self): + """Test that decryption with wrong key fails.""" + from app.encryption import ClientEncryptionHelper, EncryptionError + + key1 = base64.b64encode(secrets.token_bytes(32)).decode() + key2 = base64.b64encode(secrets.token_bytes(32)).decode() + plaintext = b"Secret data" + + encrypted = ClientEncryptionHelper.encrypt_with_key(plaintext, key1) + + with pytest.raises(EncryptionError): + ClientEncryptionHelper.decrypt_with_key( + encrypted["ciphertext"], + encrypted["nonce"], + key2, + ) + + +class TestKMSManager: + """Tests for KMS key management.""" + + def test_create_key(self, tmp_path): + """Test creating a KMS key.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + key = kms.create_key("Test key", key_id="test-key-1") + + assert key.key_id == "test-key-1" + assert key.description == "Test key" + assert key.enabled + assert keys_path.exists() + + def test_list_keys(self, tmp_path): + """Test listing KMS keys.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Key 1", key_id="key-1") + kms.create_key("Key 2", key_id="key-2") + + keys = kms.list_keys() + + assert len(keys) == 2 + key_ids = {k.key_id for k in keys} + assert "key-1" in key_ids + assert "key-2" in key_ids + + def test_get_key(self, tmp_path): + """Test getting a specific key.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Test key", key_id="test-key") + + key = kms.get_key("test-key") + + assert key is not None + assert key.key_id == "test-key" + + # Non-existent key + assert kms.get_key("non-existent") is None + + def test_enable_disable_key(self, tmp_path): + """Test enabling and disabling keys.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Test key", key_id="test-key") + + # Initially enabled + assert kms.get_key("test-key").enabled + + # Disable + kms.disable_key("test-key") + assert not kms.get_key("test-key").enabled + + # Enable + kms.enable_key("test-key") + assert kms.get_key("test-key").enabled + + def test_delete_key(self, tmp_path): + """Test deleting a key.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Test key", key_id="test-key") + assert kms.get_key("test-key") is not None + + kms.delete_key("test-key") + assert kms.get_key("test-key") is None + + def test_encrypt_decrypt(self, tmp_path): + """Test KMS encrypt and decrypt.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + key = kms.create_key("Test key", key_id="test-key") + + plaintext = b"Secret data to encrypt" + + ciphertext = kms.encrypt("test-key", plaintext) + + assert ciphertext != plaintext + + decrypted, key_id = kms.decrypt(ciphertext) + + assert decrypted == plaintext + assert key_id == "test-key" + + def test_encrypt_with_context(self, tmp_path): + """Test encryption with encryption context.""" + from app.kms import KMSManager, EncryptionError + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Test key", key_id="test-key") + + plaintext = b"Secret data" + context = {"bucket": "test-bucket", "key": "test-key"} + + ciphertext = kms.encrypt("test-key", plaintext, context) + + # Decrypt with same context succeeds + decrypted, _ = kms.decrypt(ciphertext, context) + assert decrypted == plaintext + + # Decrypt with different context fails + with pytest.raises(EncryptionError): + kms.decrypt(ciphertext, {"different": "context"}) + + def test_generate_data_key(self, tmp_path): + """Test generating a data key.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Test key", key_id="test-key") + + plaintext_key, encrypted_key = kms.generate_data_key("test-key") + + assert len(plaintext_key) == 32 + assert len(encrypted_key) > 0 + + # Decrypt the encrypted key + decrypted_key = kms.decrypt_data_key("test-key", encrypted_key) + + assert decrypted_key == plaintext_key + + def test_disabled_key_cannot_encrypt(self, tmp_path): + """Test that disabled keys cannot be used for encryption.""" + from app.kms import KMSManager, EncryptionError + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Test key", key_id="test-key") + kms.disable_key("test-key") + + with pytest.raises(EncryptionError, match="disabled"): + kms.encrypt("test-key", b"data") + + def test_re_encrypt(self, tmp_path): + """Test re-encrypting data with a different key.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + kms.create_key("Key 1", key_id="key-1") + kms.create_key("Key 2", key_id="key-2") + + plaintext = b"Data to re-encrypt" + + # Encrypt with key-1 + ciphertext1 = kms.encrypt("key-1", plaintext) + + # Re-encrypt with key-2 + ciphertext2 = kms.re_encrypt(ciphertext1, "key-2") + + # Decrypt with key-2 + decrypted, key_id = kms.decrypt(ciphertext2) + + assert decrypted == plaintext + assert key_id == "key-2" + + def test_generate_random(self, tmp_path): + """Test generating random bytes.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + + random1 = kms.generate_random(32) + random2 = kms.generate_random(32) + + assert len(random1) == 32 + assert len(random2) == 32 + assert random1 != random2 # Very unlikely to be equal + + def test_keys_persist_across_instances(self, tmp_path): + """Test that keys persist and can be loaded by new instances.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + # Create key with first instance + kms1 = KMSManager(keys_path, master_key_path) + kms1.create_key("Test key", key_id="test-key") + + plaintext = b"Persistent encryption test" + ciphertext = kms1.encrypt("test-key", plaintext) + + # Create new instance and verify key works + kms2 = KMSManager(keys_path, master_key_path) + + decrypted, key_id = kms2.decrypt(ciphertext) + + assert decrypted == plaintext + assert key_id == "test-key" + + +class TestKMSEncryptionProvider: + """Tests for KMS encryption provider.""" + + def test_kms_encryption_provider(self, tmp_path): + """Test using KMS as an encryption provider.""" + from app.kms import KMSManager + + keys_path = tmp_path / "kms_keys.json" + master_key_path = tmp_path / "master.key" + + kms = KMSManager(keys_path, master_key_path) + kms.create_key("Test key", key_id="test-key") + + provider = kms.get_provider("test-key") + + plaintext = b"Data encrypted with KMS provider" + + result = provider.encrypt(plaintext) + + assert result.key_id == "test-key" + assert result.ciphertext != plaintext + + decrypted = provider.decrypt( + result.ciphertext, + result.nonce, + result.encrypted_data_key, + result.key_id, + ) + + assert decrypted == plaintext + + +class TestEncryptedStorage: + """Tests for encrypted storage layer.""" + + def test_put_and_get_encrypted_object(self, tmp_path): + """Test storing and retrieving an encrypted object.""" + from app.storage import ObjectStorage + from app.encryption import EncryptionManager + from app.encrypted_storage import EncryptedObjectStorage + + storage_root = tmp_path / "storage" + storage = ObjectStorage(storage_root) + + config = { + "encryption_enabled": True, + "encryption_master_key_path": str(tmp_path / "master.key"), + "default_encryption_algorithm": "AES256", + } + encryption = EncryptionManager(config) + + encrypted_storage = EncryptedObjectStorage(storage, encryption) + + # Create bucket with encryption config + storage.create_bucket("test-bucket") + storage.set_bucket_encryption("test-bucket", { + "Rules": [{"SSEAlgorithm": "AES256"}] + }) + + # Put object + original_data = b"This is secret data that should be encrypted" + stream = io.BytesIO(original_data) + + meta = encrypted_storage.put_object( + "test-bucket", + "secret.txt", + stream, + ) + + assert meta is not None + + # Verify file on disk is encrypted (not plaintext) + file_path = storage_root / "test-bucket" / "secret.txt" + stored_data = file_path.read_bytes() + assert stored_data != original_data + + # Get object - should be decrypted + data, metadata = encrypted_storage.get_object_data("test-bucket", "secret.txt") + + assert data == original_data + + def test_no_encryption_without_config(self, tmp_path): + """Test that objects are not encrypted without bucket config.""" + from app.storage import ObjectStorage + from app.encryption import EncryptionManager + from app.encrypted_storage import EncryptedObjectStorage + + storage_root = tmp_path / "storage" + storage = ObjectStorage(storage_root) + + config = { + "encryption_enabled": True, + "encryption_master_key_path": str(tmp_path / "master.key"), + } + encryption = EncryptionManager(config) + + encrypted_storage = EncryptedObjectStorage(storage, encryption) + + storage.create_bucket("test-bucket") + # No encryption config + + original_data = b"Unencrypted data" + stream = io.BytesIO(original_data) + + encrypted_storage.put_object("test-bucket", "plain.txt", stream) + + # Verify file on disk is NOT encrypted + file_path = storage_root / "test-bucket" / "plain.txt" + stored_data = file_path.read_bytes() + assert stored_data == original_data + + def test_explicit_encryption_request(self, tmp_path): + """Test explicitly requesting encryption.""" + from app.storage import ObjectStorage + from app.encryption import EncryptionManager + from app.encrypted_storage import EncryptedObjectStorage + + storage_root = tmp_path / "storage" + storage = ObjectStorage(storage_root) + + config = { + "encryption_enabled": True, + "encryption_master_key_path": str(tmp_path / "master.key"), + } + encryption = EncryptionManager(config) + + encrypted_storage = EncryptedObjectStorage(storage, encryption) + + storage.create_bucket("test-bucket") + + original_data = b"Explicitly encrypted data" + stream = io.BytesIO(original_data) + + # Request encryption explicitly + encrypted_storage.put_object( + "test-bucket", + "encrypted.txt", + stream, + server_side_encryption="AES256", + ) + + # Verify file is encrypted + file_path = storage_root / "test-bucket" / "encrypted.txt" + stored_data = file_path.read_bytes() + assert stored_data != original_data + + # Get object - should be decrypted + data, _ = encrypted_storage.get_object_data("test-bucket", "encrypted.txt") + assert data == original_data diff --git a/tests/test_kms_api.py b/tests/test_kms_api.py new file mode 100644 index 0000000..e015f7a --- /dev/null +++ b/tests/test_kms_api.py @@ -0,0 +1,506 @@ +"""Tests for KMS API endpoints.""" +from __future__ import annotations + +import base64 +import json +import secrets + +import pytest + + +@pytest.fixture +def kms_client(tmp_path): + """Create a test client with KMS enabled.""" + from app import create_app + + app = create_app({ + "TESTING": True, + "STORAGE_ROOT": str(tmp_path / "storage"), + "IAM_CONFIG": str(tmp_path / "iam.json"), + "BUCKET_POLICY_PATH": str(tmp_path / "policies.json"), + "ENCRYPTION_ENABLED": True, + "KMS_ENABLED": True, + "ENCRYPTION_MASTER_KEY_PATH": str(tmp_path / "master.key"), + "KMS_KEYS_PATH": str(tmp_path / "kms_keys.json"), + }) + + # Create default IAM config with admin user + iam_config = { + "users": [ + { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "display_name": "Test User", + "permissions": ["*"] + } + ] + } + (tmp_path / "iam.json").write_text(json.dumps(iam_config)) + + return app.test_client() + + +@pytest.fixture +def auth_headers(): + """Get authentication headers.""" + return { + "X-Access-Key": "test-access-key", + "X-Secret-Key": "test-secret-key", + } + + +class TestKMSKeyManagement: + """Tests for KMS key management endpoints.""" + + def test_create_key(self, kms_client, auth_headers): + """Test creating a KMS key.""" + response = kms_client.post( + "/kms/keys", + json={"Description": "Test encryption key"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + assert "KeyMetadata" in data + assert data["KeyMetadata"]["Description"] == "Test encryption key" + assert data["KeyMetadata"]["Enabled"] is True + assert "KeyId" in data["KeyMetadata"] + + def test_create_key_with_custom_id(self, kms_client, auth_headers): + """Test creating a key with a custom ID.""" + response = kms_client.post( + "/kms/keys", + json={"KeyId": "my-custom-key", "Description": "Custom key"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + assert data["KeyMetadata"]["KeyId"] == "my-custom-key" + + def test_list_keys(self, kms_client, auth_headers): + """Test listing KMS keys.""" + # Create some keys + kms_client.post("/kms/keys", json={"Description": "Key 1"}, headers=auth_headers) + kms_client.post("/kms/keys", json={"Description": "Key 2"}, headers=auth_headers) + + response = kms_client.get("/kms/keys", headers=auth_headers) + + assert response.status_code == 200 + data = response.get_json() + + assert "Keys" in data + assert len(data["Keys"]) == 2 + + def test_get_key(self, kms_client, auth_headers): + """Test getting a specific key.""" + # Create a key + create_response = kms_client.post( + "/kms/keys", + json={"KeyId": "test-key", "Description": "Test key"}, + headers=auth_headers, + ) + + response = kms_client.get("/kms/keys/test-key", headers=auth_headers) + + assert response.status_code == 200 + data = response.get_json() + + assert data["KeyMetadata"]["KeyId"] == "test-key" + assert data["KeyMetadata"]["Description"] == "Test key" + + def test_get_nonexistent_key(self, kms_client, auth_headers): + """Test getting a key that doesn't exist.""" + response = kms_client.get("/kms/keys/nonexistent", headers=auth_headers) + + assert response.status_code == 404 + + def test_delete_key(self, kms_client, auth_headers): + """Test deleting a key.""" + # Create a key + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + # Delete it + response = kms_client.delete("/kms/keys/test-key", headers=auth_headers) + + assert response.status_code == 204 + + # Verify it's gone + get_response = kms_client.get("/kms/keys/test-key", headers=auth_headers) + assert get_response.status_code == 404 + + def test_enable_disable_key(self, kms_client, auth_headers): + """Test enabling and disabling a key.""" + # Create a key + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + # Disable + response = kms_client.post("/kms/keys/test-key/disable", headers=auth_headers) + assert response.status_code == 200 + + # Verify disabled + get_response = kms_client.get("/kms/keys/test-key", headers=auth_headers) + assert get_response.get_json()["KeyMetadata"]["Enabled"] is False + + # Enable + response = kms_client.post("/kms/keys/test-key/enable", headers=auth_headers) + assert response.status_code == 200 + + # Verify enabled + get_response = kms_client.get("/kms/keys/test-key", headers=auth_headers) + assert get_response.get_json()["KeyMetadata"]["Enabled"] is True + + +class TestKMSEncryption: + """Tests for KMS encryption operations.""" + + def test_encrypt_decrypt(self, kms_client, auth_headers): + """Test encrypting and decrypting data.""" + # Create a key + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + plaintext = b"Hello, World!" + plaintext_b64 = base64.b64encode(plaintext).decode() + + # Encrypt + encrypt_response = kms_client.post( + "/kms/encrypt", + json={"KeyId": "test-key", "Plaintext": plaintext_b64}, + headers=auth_headers, + ) + + assert encrypt_response.status_code == 200 + encrypt_data = encrypt_response.get_json() + + assert "CiphertextBlob" in encrypt_data + assert encrypt_data["KeyId"] == "test-key" + + # Decrypt + decrypt_response = kms_client.post( + "/kms/decrypt", + json={"CiphertextBlob": encrypt_data["CiphertextBlob"]}, + headers=auth_headers, + ) + + assert decrypt_response.status_code == 200 + decrypt_data = decrypt_response.get_json() + + decrypted = base64.b64decode(decrypt_data["Plaintext"]) + assert decrypted == plaintext + + def test_encrypt_with_context(self, kms_client, auth_headers): + """Test encryption with encryption context.""" + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + plaintext = b"Contextualized data" + plaintext_b64 = base64.b64encode(plaintext).decode() + context = {"purpose": "testing", "bucket": "my-bucket"} + + # Encrypt with context + encrypt_response = kms_client.post( + "/kms/encrypt", + json={ + "KeyId": "test-key", + "Plaintext": plaintext_b64, + "EncryptionContext": context, + }, + headers=auth_headers, + ) + + assert encrypt_response.status_code == 200 + ciphertext = encrypt_response.get_json()["CiphertextBlob"] + + # Decrypt with same context succeeds + decrypt_response = kms_client.post( + "/kms/decrypt", + json={ + "CiphertextBlob": ciphertext, + "EncryptionContext": context, + }, + headers=auth_headers, + ) + + assert decrypt_response.status_code == 200 + + # Decrypt with wrong context fails + wrong_context_response = kms_client.post( + "/kms/decrypt", + json={ + "CiphertextBlob": ciphertext, + "EncryptionContext": {"wrong": "context"}, + }, + headers=auth_headers, + ) + + assert wrong_context_response.status_code == 400 + + def test_encrypt_missing_key_id(self, kms_client, auth_headers): + """Test encryption without KeyId.""" + response = kms_client.post( + "/kms/encrypt", + json={"Plaintext": base64.b64encode(b"data").decode()}, + headers=auth_headers, + ) + + assert response.status_code == 400 + assert "KeyId is required" in response.get_json()["message"] + + def test_encrypt_missing_plaintext(self, kms_client, auth_headers): + """Test encryption without Plaintext.""" + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + response = kms_client.post( + "/kms/encrypt", + json={"KeyId": "test-key"}, + headers=auth_headers, + ) + + assert response.status_code == 400 + assert "Plaintext is required" in response.get_json()["message"] + + +class TestKMSDataKey: + """Tests for KMS data key generation.""" + + def test_generate_data_key(self, kms_client, auth_headers): + """Test generating a data key.""" + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + response = kms_client.post( + "/kms/generate-data-key", + json={"KeyId": "test-key"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + assert "Plaintext" in data + assert "CiphertextBlob" in data + assert data["KeyId"] == "test-key" + + # Verify plaintext key is 256 bits (32 bytes) + plaintext_key = base64.b64decode(data["Plaintext"]) + assert len(plaintext_key) == 32 + + def test_generate_data_key_aes_128(self, kms_client, auth_headers): + """Test generating an AES-128 data key.""" + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + response = kms_client.post( + "/kms/generate-data-key", + json={"KeyId": "test-key", "KeySpec": "AES_128"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + # Verify plaintext key is 128 bits (16 bytes) + plaintext_key = base64.b64decode(data["Plaintext"]) + assert len(plaintext_key) == 16 + + def test_generate_data_key_without_plaintext(self, kms_client, auth_headers): + """Test generating a data key without plaintext.""" + kms_client.post("/kms/keys", json={"KeyId": "test-key"}, headers=auth_headers) + + response = kms_client.post( + "/kms/generate-data-key-without-plaintext", + json={"KeyId": "test-key"}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + assert "CiphertextBlob" in data + assert "Plaintext" not in data + + +class TestKMSReEncrypt: + """Tests for KMS re-encryption.""" + + def test_re_encrypt(self, kms_client, auth_headers): + """Test re-encrypting data with a different key.""" + # Create two keys + kms_client.post("/kms/keys", json={"KeyId": "key-1"}, headers=auth_headers) + kms_client.post("/kms/keys", json={"KeyId": "key-2"}, headers=auth_headers) + + # Encrypt with key-1 + plaintext = b"Data to re-encrypt" + encrypt_response = kms_client.post( + "/kms/encrypt", + json={ + "KeyId": "key-1", + "Plaintext": base64.b64encode(plaintext).decode(), + }, + headers=auth_headers, + ) + + ciphertext = encrypt_response.get_json()["CiphertextBlob"] + + # Re-encrypt with key-2 + re_encrypt_response = kms_client.post( + "/kms/re-encrypt", + json={ + "CiphertextBlob": ciphertext, + "DestinationKeyId": "key-2", + }, + headers=auth_headers, + ) + + assert re_encrypt_response.status_code == 200 + data = re_encrypt_response.get_json() + + assert data["SourceKeyId"] == "key-1" + assert data["KeyId"] == "key-2" + + # Verify new ciphertext can be decrypted + decrypt_response = kms_client.post( + "/kms/decrypt", + json={"CiphertextBlob": data["CiphertextBlob"]}, + headers=auth_headers, + ) + + decrypted = base64.b64decode(decrypt_response.get_json()["Plaintext"]) + assert decrypted == plaintext + + +class TestKMSRandom: + """Tests for random number generation.""" + + def test_generate_random(self, kms_client, auth_headers): + """Test generating random bytes.""" + response = kms_client.post( + "/kms/generate-random", + json={"NumberOfBytes": 64}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + random_bytes = base64.b64decode(data["Plaintext"]) + assert len(random_bytes) == 64 + + def test_generate_random_default_size(self, kms_client, auth_headers): + """Test generating random bytes with default size.""" + response = kms_client.post( + "/kms/generate-random", + json={}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + random_bytes = base64.b64decode(data["Plaintext"]) + assert len(random_bytes) == 32 # Default is 32 bytes + + +class TestClientSideEncryption: + """Tests for client-side encryption helpers.""" + + def test_generate_client_key(self, kms_client, auth_headers): + """Test generating a client encryption key.""" + response = kms_client.post( + "/kms/client/generate-key", + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + assert "key" in data + assert data["algorithm"] == "AES-256-GCM" + + key = base64.b64decode(data["key"]) + assert len(key) == 32 + + def test_client_encrypt_decrypt(self, kms_client, auth_headers): + """Test client-side encryption and decryption.""" + # Generate a key + key_response = kms_client.post("/kms/client/generate-key", headers=auth_headers) + key = key_response.get_json()["key"] + + # Encrypt + plaintext = b"Client-side encrypted data" + encrypt_response = kms_client.post( + "/kms/client/encrypt", + json={ + "Plaintext": base64.b64encode(plaintext).decode(), + "Key": key, + }, + headers=auth_headers, + ) + + assert encrypt_response.status_code == 200 + encrypted = encrypt_response.get_json() + + # Decrypt + decrypt_response = kms_client.post( + "/kms/client/decrypt", + json={ + "Ciphertext": encrypted["ciphertext"], + "Nonce": encrypted["nonce"], + "Key": key, + }, + headers=auth_headers, + ) + + assert decrypt_response.status_code == 200 + decrypted = base64.b64decode(decrypt_response.get_json()["Plaintext"]) + assert decrypted == plaintext + + +class TestEncryptionMaterials: + """Tests for S3 encryption materials endpoint.""" + + def test_get_encryption_materials(self, kms_client, auth_headers): + """Test getting encryption materials for client-side S3 encryption.""" + # Create a key + kms_client.post("/kms/keys", json={"KeyId": "s3-key"}, headers=auth_headers) + + response = kms_client.post( + "/kms/materials/s3-key", + json={}, + headers=auth_headers, + ) + + assert response.status_code == 200 + data = response.get_json() + + assert "PlaintextKey" in data + assert "EncryptedKey" in data + assert data["KeyId"] == "s3-key" + assert data["Algorithm"] == "AES-256-GCM" + + # Verify key is 256 bits + key = base64.b64decode(data["PlaintextKey"]) + assert len(key) == 32 + + +class TestKMSAuthentication: + """Tests for KMS authentication requirements.""" + + def test_unauthenticated_request_fails(self, kms_client): + """Test that unauthenticated requests are rejected.""" + response = kms_client.get("/kms/keys") + + # Should fail with 403 (no credentials) + assert response.status_code == 403 + + def test_invalid_credentials_fail(self, kms_client): + """Test that invalid credentials are rejected.""" + response = kms_client.get( + "/kms/keys", + headers={ + "X-Access-Key": "wrong-key", + "X-Secret-Key": "wrong-secret", + }, + ) + + assert response.status_code == 403 diff --git a/tests/test_ui_encryption.py b/tests/test_ui_encryption.py new file mode 100644 index 0000000..a5d533b --- /dev/null +++ b/tests/test_ui_encryption.py @@ -0,0 +1,268 @@ +"""Tests for UI-based encryption configuration.""" +import json +from pathlib import Path + +import pytest + +from app import create_app + + +def get_csrf_token(response): + """Extract CSRF token from response HTML.""" + html = response.data.decode("utf-8") + import re + match = re.search(r'name="csrf_token"\s+value="([^"]+)"', html) + return match.group(1) if match else None + + +def _make_encryption_app(tmp_path: Path, *, kms_enabled: bool = True): + """Create an app with encryption enabled.""" + storage_root = tmp_path / "data" + iam_config = tmp_path / "iam.json" + bucket_policies = tmp_path / "bucket_policies.json" + iam_payload = { + "users": [ + { + "access_key": "test", + "secret_key": "secret", + "display_name": "Test User", + "policies": [{"bucket": "*", "actions": ["list", "read", "write", "delete", "policy"]}], + }, + { + "access_key": "readonly", + "secret_key": "secret", + "display_name": "Read Only User", + "policies": [{"bucket": "*", "actions": ["list", "read"]}], + }, + ] + } + iam_config.write_text(json.dumps(iam_payload)) + + config = { + "TESTING": True, + "STORAGE_ROOT": storage_root, + "IAM_CONFIG": iam_config, + "BUCKET_POLICY_PATH": bucket_policies, + "API_BASE_URL": "http://testserver", + "SECRET_KEY": "testing", + "ENCRYPTION_ENABLED": True, + } + + if kms_enabled: + config["KMS_ENABLED"] = True + config["KMS_KEYS_PATH"] = str(tmp_path / "kms_keys.json") + config["ENCRYPTION_MASTER_KEY_PATH"] = str(tmp_path / "master.key") + + app = create_app(config) + storage = app.extensions["object_storage"] + storage.create_bucket("test-bucket") + return app + + +class TestUIBucketEncryption: + """Test bucket encryption configuration via UI.""" + + def test_bucket_detail_shows_encryption_card(self, tmp_path): + """Encryption card should be visible on bucket detail page.""" + app = _make_encryption_app(tmp_path) + client = app.test_client() + + # Login first + client.post("/ui/login", data={"access_key": "test", "secret_key": "secret"}, follow_redirects=True) + + response = client.get("/ui/buckets/test-bucket?tab=properties") + assert response.status_code == 200 + + html = response.data.decode("utf-8") + assert "Default Encryption" in html + assert "Encryption Algorithm" in html or "Default encryption disabled" in html + + def test_enable_aes256_encryption(self, tmp_path): + """Should be able to enable AES-256 encryption.""" + app = _make_encryption_app(tmp_path) + client = app.test_client() + + # Login + client.post("/ui/login", data={"access_key": "test", "secret_key": "secret"}, follow_redirects=True) + + # Get CSRF token + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + # Enable AES-256 encryption + response = client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "enable", + "algorithm": "AES256", + }, + follow_redirects=True, + ) + + assert response.status_code == 200 + html = response.data.decode("utf-8") + # Should see success message or enabled state + assert "AES-256" in html or "encryption enabled" in html.lower() + + def test_enable_kms_encryption(self, tmp_path): + """Should be able to enable KMS encryption.""" + app = _make_encryption_app(tmp_path, kms_enabled=True) + client = app.test_client() + + # Create a KMS key first + with app.app_context(): + kms = app.extensions.get("kms") + if kms: + key = kms.create_key("test-key") + key_id = key.key_id + else: + pytest.skip("KMS not available") + + # Login + client.post("/ui/login", data={"access_key": "test", "secret_key": "secret"}, follow_redirects=True) + + # Get CSRF token + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + # Enable KMS encryption + response = client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "enable", + "algorithm": "aws:kms", + "kms_key_id": key_id, + }, + follow_redirects=True, + ) + + assert response.status_code == 200 + html = response.data.decode("utf-8") + assert "KMS" in html or "encryption enabled" in html.lower() + + def test_disable_encryption(self, tmp_path): + """Should be able to disable encryption.""" + app = _make_encryption_app(tmp_path) + client = app.test_client() + + # Login + client.post("/ui/login", data={"access_key": "test", "secret_key": "secret"}, follow_redirects=True) + + # First enable encryption + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "enable", + "algorithm": "AES256", + }, + ) + + # Now disable it + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + response = client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "disable", + }, + follow_redirects=True, + ) + + assert response.status_code == 200 + html = response.data.decode("utf-8") + assert "disabled" in html.lower() or "Default encryption disabled" in html + + def test_invalid_algorithm_rejected(self, tmp_path): + """Invalid encryption algorithm should be rejected.""" + app = _make_encryption_app(tmp_path) + client = app.test_client() + + # Login + client.post("/ui/login", data={"access_key": "test", "secret_key": "secret"}, follow_redirects=True) + + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + response = client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "enable", + "algorithm": "INVALID", + }, + follow_redirects=True, + ) + + assert response.status_code == 200 + html = response.data.decode("utf-8") + assert "Invalid" in html or "danger" in html + + def test_encryption_persists_in_config(self, tmp_path): + """Encryption config should persist in bucket config.""" + app = _make_encryption_app(tmp_path) + client = app.test_client() + + # Login + client.post("/ui/login", data={"access_key": "test", "secret_key": "secret"}, follow_redirects=True) + + # Enable encryption + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "enable", + "algorithm": "AES256", + }, + ) + + # Verify it's stored + with app.app_context(): + storage = app.extensions["object_storage"] + config = storage.get_bucket_encryption("test-bucket") + + assert "Rules" in config + assert len(config["Rules"]) == 1 + assert config["Rules"][0]["ApplyServerSideEncryptionByDefault"]["SSEAlgorithm"] == "AES256" + + +class TestUIEncryptionWithoutPermission: + """Test encryption UI when user lacks permissions.""" + + def test_readonly_user_cannot_change_encryption(self, tmp_path): + """Read-only user should not be able to change encryption settings.""" + app = _make_encryption_app(tmp_path) + client = app.test_client() + + # Login as readonly user + client.post("/ui/login", data={"access_key": "readonly", "secret_key": "secret"}, follow_redirects=True) + + # This should fail or be rejected + response = client.get("/ui/buckets/test-bucket?tab=properties") + csrf_token = get_csrf_token(response) + + response = client.post( + "/ui/buckets/test-bucket/encryption", + data={ + "csrf_token": csrf_token, + "action": "enable", + "algorithm": "AES256", + }, + follow_redirects=True, + ) + + # Should either redirect with error or show permission denied + assert response.status_code == 200 + html = response.data.decode("utf-8") + # Should contain error about permission denied + assert "Access denied" in html or "permission" in html.lower() or "not authorized" in html.lower()