280 lines
11 KiB
Python
280 lines
11 KiB
Python
"""Encrypted storage layer that wraps ObjectStorage with encryption support."""
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
from pathlib import Path
|
|
from typing import Any, BinaryIO, Dict, Optional
|
|
|
|
from .encryption import EncryptionManager, EncryptionMetadata, EncryptionError
|
|
from .storage import ObjectStorage, ObjectMeta, StorageError
|
|
|
|
|
|
class EncryptedObjectStorage:
    """Object storage with transparent server-side encryption.

    This class wraps ObjectStorage and provides transparent encryption/decryption
    of objects based on bucket encryption configuration.

    Encryption is applied only while the encryption manager reports itself as
    enabled, and then when either:
    1. The client explicitly requests encryption ("AES256" or "aws:kms[:<key-id>]");
       a recognised explicit request takes precedence over the bucket default.
    2. The bucket has default encryption configured (first rule of its
       encryption configuration).

    The encryption metadata is stored alongside object metadata.

    Every other storage operation is forwarded unchanged to the wrapped
    ``ObjectStorage`` instance.
    """

    # 64 KiB. Not referenced by any method of this class.
    # NOTE(review): presumably consulted by callers deciding when to stream; confirm.
    STREAMING_THRESHOLD = 64 * 1024

    def __init__(self, storage: ObjectStorage, encryption_manager: EncryptionManager):
        """Wrap ``storage`` and use ``encryption_manager`` for encrypt/decrypt calls."""
        self.storage = storage
        self.encryption = encryption_manager

    @property
    def root(self) -> Path:
        """Root path of the wrapped storage."""
        return self.storage.root

    def _should_encrypt(self, bucket_name: str,
                        server_side_encryption: str | None = None) -> tuple[bool, str, str | None]:
        """Determine if object should be encrypted.

        Resolution order:
        1. If the encryption manager is disabled, never encrypt.
        2. A recognised explicit request ("AES256", or a value starting with
           "aws:kms", optionally followed by ":<key-id>") wins.
        3. Otherwise the first rule of the bucket's encryption configuration
           is used, defaulting the algorithm to "AES256".

        An explicit value that is not recognised falls through to step 3.
        A ``StorageError`` raised while reading the bucket configuration is
        treated as "no default encryption configured".

        Returns:
            Tuple of (should_encrypt, algorithm, kms_key_id)
        """
        if not self.encryption.enabled:
            return False, "", None

        if server_side_encryption:
            if server_side_encryption == "AES256":
                return True, "AES256", None
            if server_side_encryption.startswith("aws:kms"):
                # maxsplit=2 keeps everything after "aws:kms:" as the key id,
                # so key ids that contain colons themselves (e.g. KMS key
                # ARNs) are not truncated at their first colon.
                parts = server_side_encryption.split(":", 2)
                inline_key_id = parts[2] if len(parts) > 2 else None
                # An empty inline id ("aws:kms:") means "no key id supplied".
                return True, "aws:kms", inline_key_id or None

        try:
            encryption_config = self.storage.get_bucket_encryption(bucket_name)
            if encryption_config and encryption_config.get("Rules"):
                rule = encryption_config["Rules"][0]
                # AWS format: Rules[].ApplyServerSideEncryptionByDefault.SSEAlgorithm
                sse_default = rule.get("ApplyServerSideEncryptionByDefault", {})
                algorithm = sse_default.get("SSEAlgorithm", "AES256")
                kms_key_id = sse_default.get("KMSMasterKeyID")
                return True, algorithm, kms_key_id
        except StorageError:
            # Best effort: an unreadable bucket configuration means no default.
            pass

        return False, "", None

    def _is_encrypted(self, metadata: Dict[str, str]) -> bool:
        """Check if object is encrypted based on its metadata."""
        return "x-amz-server-side-encryption" in metadata

    def put_object(
        self,
        bucket_name: str,
        object_key: str,
        stream: BinaryIO,
        *,
        metadata: Optional[Dict[str, str]] = None,
        server_side_encryption: Optional[str] = None,
        kms_key_id: Optional[str] = None,
    ) -> ObjectMeta:
        """Store an object, optionally with encryption.

        When encryption applies, the whole stream is read into memory,
        encrypted in one call, and the ciphertext is stored together with the
        caller's metadata merged with the encryption metadata. Otherwise the
        stream and metadata are handed to the wrapped storage untouched.

        Args:
            bucket_name: Name of the bucket
            object_key: Key for the object
            stream: Binary stream of object data
            metadata: Optional user metadata (not mutated; a copy is extended)
            server_side_encryption: Encryption algorithm ("AES256", "aws:kms",
                or "aws:kms:<key-id>")
            kms_key_id: KMS key ID (for aws:kms encryption); when omitted, the
                key id resolved from the request or bucket default is used

        Returns:
            ObjectMeta with object information. On the encrypted path its
            ``metadata`` attribute is set to the merged metadata.

        Raises:
            StorageError: If an ``EncryptionError`` occurs on the encrypted path.
        """
        should_encrypt, algorithm, detected_kms_key = self._should_encrypt(
            bucket_name, server_side_encryption
        )

        # An explicitly passed key id wins over the detected one.
        if kms_key_id is None:
            kms_key_id = detected_kms_key

        if not should_encrypt:
            return self.storage.put_object(
                bucket_name,
                object_key,
                stream,
                metadata=metadata,
            )

        data = stream.read()

        try:
            ciphertext, enc_metadata = self.encryption.encrypt_object(
                data,
                algorithm=algorithm,
                kms_key_id=kms_key_id,
                context={"bucket": bucket_name, "key": object_key},
            )

            combined_metadata = metadata.copy() if metadata else {}
            combined_metadata.update(enc_metadata.to_dict())

            encrypted_stream = io.BytesIO(ciphertext)
            result = self.storage.put_object(
                bucket_name,
                object_key,
                encrypted_stream,
                metadata=combined_metadata,
            )

            result.metadata = combined_metadata
            return result

        except EncryptionError as exc:
            raise StorageError(f"Encryption failed: {exc}") from exc

    def get_object_data(self, bucket_name: str, object_key: str) -> tuple[bytes, Dict[str, str]]:
        """Get object data, decrypting if necessary.

        The stored bytes are read fully into memory. If
        ``EncryptionMetadata.from_dict`` yields a truthy value for the
        object's metadata, the bytes are decrypted with it.

        Returns:
            Tuple of (data, metadata). Metadata keys starting with
            "x-amz-encryption" and the key "x-amz-encrypted-data-key" are
            removed; all other keys are returned as stored.

        Raises:
            StorageError: If decryption raises an ``EncryptionError``.
        """
        path = self.storage.get_object_path(bucket_name, object_key)
        metadata = self.storage.get_object_metadata(bucket_name, object_key)

        with path.open("rb") as f:
            data = f.read()

        enc_metadata = EncryptionMetadata.from_dict(metadata)
        if enc_metadata:
            try:
                data = self.encryption.decrypt_object(
                    data,
                    enc_metadata,
                    context={"bucket": bucket_name, "key": object_key},
                )
            except EncryptionError as exc:
                raise StorageError(f"Decryption failed: {exc}") from exc

        clean_metadata = {
            k: v for k, v in metadata.items()
            if not k.startswith("x-amz-encryption")
            and k != "x-amz-encrypted-data-key"
        }

        return data, clean_metadata

    def get_object_stream(self, bucket_name: str, object_key: str) -> tuple[BinaryIO, Dict[str, str], int]:
        """Get object as a stream, decrypting if necessary.

        Built on ``get_object_data``, so the full (decrypted) payload is held
        in memory and wrapped in a ``BytesIO``.

        Returns:
            Tuple of (stream, metadata, original_size), where the size is the
            length of the returned (decrypted) data.
        """
        data, metadata = self.get_object_data(bucket_name, object_key)
        return io.BytesIO(data), metadata, len(data)

    # ------------------------------------------------------------------
    # Pass-through operations: forwarded unchanged to the wrapped storage.
    # ------------------------------------------------------------------

    # -- Buckets --

    def list_buckets(self):
        return self.storage.list_buckets()

    def bucket_exists(self, bucket_name: str) -> bool:
        return self.storage.bucket_exists(bucket_name)

    def create_bucket(self, bucket_name: str) -> None:
        return self.storage.create_bucket(bucket_name)

    def delete_bucket(self, bucket_name: str) -> None:
        return self.storage.delete_bucket(bucket_name)

    def bucket_stats(self, bucket_name: str, cache_ttl: int = 60):
        return self.storage.bucket_stats(bucket_name, cache_ttl)

    # -- Objects --

    def list_objects(self, bucket_name: str, **kwargs):
        return self.storage.list_objects(bucket_name, **kwargs)

    def list_objects_all(self, bucket_name: str):
        return self.storage.list_objects_all(bucket_name)

    def get_object_path(self, bucket_name: str, object_key: str):
        # Path to the bytes as stored; no decryption happens here (use
        # get_object_data / get_object_stream for decrypted content).
        return self.storage.get_object_path(bucket_name, object_key)

    def get_object_metadata(self, bucket_name: str, object_key: str):
        # Raw stored metadata; unlike get_object_data, nothing is filtered out.
        return self.storage.get_object_metadata(bucket_name, object_key)

    def delete_object(self, bucket_name: str, object_key: str) -> None:
        return self.storage.delete_object(bucket_name, object_key)

    def purge_object(self, bucket_name: str, object_key: str) -> None:
        return self.storage.purge_object(bucket_name, object_key)

    # -- Bucket configuration --

    def is_versioning_enabled(self, bucket_name: str) -> bool:
        return self.storage.is_versioning_enabled(bucket_name)

    def set_bucket_versioning(self, bucket_name: str, enabled: bool) -> None:
        return self.storage.set_bucket_versioning(bucket_name, enabled)

    def get_bucket_tags(self, bucket_name: str):
        return self.storage.get_bucket_tags(bucket_name)

    def set_bucket_tags(self, bucket_name: str, tags):
        return self.storage.set_bucket_tags(bucket_name, tags)

    def get_bucket_cors(self, bucket_name: str):
        return self.storage.get_bucket_cors(bucket_name)

    def set_bucket_cors(self, bucket_name: str, rules):
        return self.storage.set_bucket_cors(bucket_name, rules)

    def get_bucket_encryption(self, bucket_name: str):
        return self.storage.get_bucket_encryption(bucket_name)

    def set_bucket_encryption(self, bucket_name: str, config_payload):
        return self.storage.set_bucket_encryption(bucket_name, config_payload)

    def get_bucket_lifecycle(self, bucket_name: str):
        return self.storage.get_bucket_lifecycle(bucket_name)

    def set_bucket_lifecycle(self, bucket_name: str, rules):
        return self.storage.set_bucket_lifecycle(bucket_name, rules)

    # -- Object tags and versions --

    def get_object_tags(self, bucket_name: str, object_key: str):
        return self.storage.get_object_tags(bucket_name, object_key)

    def set_object_tags(self, bucket_name: str, object_key: str, tags):
        return self.storage.set_object_tags(bucket_name, object_key, tags)

    def delete_object_tags(self, bucket_name: str, object_key: str):
        return self.storage.delete_object_tags(bucket_name, object_key)

    def list_object_versions(self, bucket_name: str, object_key: str):
        return self.storage.list_object_versions(bucket_name, object_key)

    def restore_object_version(self, bucket_name: str, object_key: str, version_id: str):
        return self.storage.restore_object_version(bucket_name, object_key, version_id)

    def list_orphaned_objects(self, bucket_name: str):
        return self.storage.list_orphaned_objects(bucket_name)

    # -- Multipart uploads --
    # These forward straight to the wrapped storage; this wrapper applies no
    # encryption step of its own on the multipart path.

    def initiate_multipart_upload(self, bucket_name: str, object_key: str, *, metadata=None) -> str:
        return self.storage.initiate_multipart_upload(bucket_name, object_key, metadata=metadata)

    def upload_multipart_part(self, bucket_name: str, upload_id: str, part_number: int, stream: BinaryIO) -> str:
        return self.storage.upload_multipart_part(bucket_name, upload_id, part_number, stream)

    def complete_multipart_upload(self, bucket_name: str, upload_id: str, ordered_parts):
        return self.storage.complete_multipart_upload(bucket_name, upload_id, ordered_parts)

    def abort_multipart_upload(self, bucket_name: str, upload_id: str) -> None:
        return self.storage.abort_multipart_upload(bucket_name, upload_id)

    def list_multipart_parts(self, bucket_name: str, upload_id: str):
        return self.storage.list_multipart_parts(bucket_name, upload_id)

    # -- Quotas --

    def get_bucket_quota(self, bucket_name: str):
        return self.storage.get_bucket_quota(bucket_name)

    def set_bucket_quota(self, bucket_name: str, *, max_bytes=None, max_objects=None):
        return self.storage.set_bucket_quota(bucket_name, max_bytes=max_bytes, max_objects=max_objects)

    # -- Internal helpers --

    def _compute_etag(self, path: Path) -> str:
        return self.storage._compute_etag(path)
|