# Files
# MyFSIO/app/encryption.py

# 506 lines
# 19 KiB
# Python

"""Encryption providers for server-side and client-side encryption."""
from __future__ import annotations
import base64
import io
import json
import secrets
from dataclasses import dataclass
from pathlib import Path
from typing import Any, BinaryIO, Dict, Generator, Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
class EncryptionError(Exception):
    """Error type used by this module for any encryption or decryption failure.

    Providers in this file wrap lower-level failures (bad key material,
    malformed streams, failed authenticated decryption) in this exception.
    """
@dataclass
class EncryptionResult:
    """Value object returned by a provider's ``encrypt`` call.

    Attributes:
        ciphertext: Encrypted payload produced by the provider.
        nonce: Nonce that was used to encrypt the payload.
        key_id: Identifier of the key/provider that produced the result.
        encrypted_data_key: Wrapped data key needed to decrypt the payload;
            providers that do not wrap a data key supply ``b""``.
    """

    # Field order is part of the positional constructor signature.
    ciphertext: bytes
    nonce: bytes
    key_id: str
    encrypted_data_key: bytes
@dataclass
class EncryptionMetadata:
    """Encryption parameters persisted next to an encrypted object.

    Attributes:
        algorithm: Value stored under ``x-amz-server-side-encryption``.
        key_id: Identifier of the key that protects the data key.
        nonce: Raw nonce bytes (base64-encoded when serialised).
        encrypted_data_key: Wrapped data key bytes (base64-encoded when
            serialised).
    """

    algorithm: str
    key_id: str
    nonce: bytes
    encrypted_data_key: bytes

    def to_dict(self) -> Dict[str, str]:
        """Serialise to a string-only header/metadata mapping.

        Returns:
            A dict with the four ``x-amz-*`` keys; binary fields are base64 text.
        """
        nonce_b64 = base64.b64encode(self.nonce).decode()
        wrapped_key_b64 = base64.b64encode(self.encrypted_data_key).decode()
        return {
            "x-amz-server-side-encryption": self.algorithm,
            "x-amz-encryption-key-id": self.key_id,
            "x-amz-encryption-nonce": nonce_b64,
            "x-amz-encrypted-data-key": wrapped_key_b64,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, str]) -> Optional["EncryptionMetadata"]:
        """Rebuild metadata from a mapping produced by :meth:`to_dict`.

        Args:
            data: Mapping that may contain the ``x-amz-*`` encryption keys.

        Returns:
            The parsed metadata, or ``None`` when no algorithm entry is present
            or when any field fails to decode.
        """
        algorithm = data.get("x-amz-server-side-encryption")
        if not algorithm:
            # No (or empty) algorithm entry: treated as "no encryption metadata".
            return None

        try:
            key_id = data.get("x-amz-encryption-key-id", "local")
            nonce = base64.b64decode(data.get("x-amz-encryption-nonce", ""))
            wrapped_key = base64.b64decode(data.get("x-amz-encrypted-data-key", ""))
            return cls(
                algorithm=algorithm,
                key_id=key_id,
                nonce=nonce,
                encrypted_data_key=wrapped_key,
            )
        except Exception:
            # Undecodable values are reported as "no usable metadata".
            return None
class EncryptionProvider:
    """Interface that every encryption provider in this module implements.

    The base class carries no state; each method raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def encrypt(
        self,
        plaintext: bytes,
        context: Dict[str, str] | None = None,
    ) -> EncryptionResult:
        """Encrypt ``plaintext`` and return the ciphertext with its parameters.

        Args:
            plaintext: Data to encrypt.
            context: Optional string mapping passed through by callers.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def decrypt(
        self,
        ciphertext: bytes,
        nonce: bytes,
        encrypted_data_key: bytes,
        key_id: str,
        context: Dict[str, str] | None = None,
    ) -> bytes:
        """Decrypt ``ciphertext`` using the supplied parameters.

        Args:
            ciphertext: Data to decrypt.
            nonce: Nonce that was used for encryption.
            encrypted_data_key: Wrapped data key stored with the object.
            key_id: Identifier of the key that wrapped the data key.
            context: Optional string mapping passed through by callers.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def generate_data_key(self) -> tuple[bytes, bytes]:
        """Produce a data key together with its encrypted form.

        Returns:
            Tuple of (plaintext_key, encrypted_key)

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()
class LocalKeyEncryption(EncryptionProvider):
    """SSE-S3 style encryption using a local master key.

    Uses envelope encryption:
    1. Generate a unique data key for each object
    2. Encrypt the data with the data key (AES-256-GCM)
    3. Encrypt the data key with the master key
    4. Store the encrypted data key alongside the ciphertext

    The master key is read lazily from ``master_key_path`` (base64 text) on
    first use, and generated and written there if the file does not exist.
    """

    KEY_ID = "local"
    # AES-GCM only accepts 128-, 192- or 256-bit keys; any other length on disk
    # means the key file is corrupt or truncated.
    _VALID_MASTER_KEY_SIZES = (16, 24, 32)

    def __init__(self, master_key_path: Path):
        """Remember where the master key lives; the key itself is loaded lazily.

        Args:
            master_key_path: File holding the base64-encoded master key.
        """
        self.master_key_path = master_key_path
        self._master_key: bytes | None = None

    @property
    def master_key(self) -> bytes:
        """Master key bytes, loaded (or created) on first access and cached."""
        if self._master_key is None:
            self._master_key = self._load_or_create_master_key()
        return self._master_key

    def _load_or_create_master_key(self) -> bytes:
        """Load master key from file or generate a new one.

        Returns:
            The raw master key bytes.

        Raises:
            EncryptionError: If the existing key file cannot be read or decoded,
                decodes to an invalid AES key length, or a new key cannot be
                written.
        """
        if self.master_key_path.exists():
            try:
                key = base64.b64decode(self.master_key_path.read_text().strip())
            except Exception as exc:
                raise EncryptionError(f"Failed to load master key: {exc}") from exc
            # Fail here with a clear error instead of letting AESGCM raise a
            # bare ValueError on first use of a corrupt/truncated key file.
            if len(key) not in self._VALID_MASTER_KEY_SIZES:
                raise EncryptionError(
                    f"Failed to load master key: invalid key length ({len(key)} bytes)"
                )
            return key

        key = secrets.token_bytes(32)
        try:
            self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
            # Create the file owner-only *before* writing the secret so the key
            # is never on disk with default (umask-dependent) permissions.
            self.master_key_path.touch(mode=0o600, exist_ok=True)
            self.master_key_path.write_text(base64.b64encode(key).decode())
        except OSError as exc:
            raise EncryptionError(f"Failed to save master key: {exc}") from exc
        return key

    def _encrypt_data_key(self, data_key: bytes) -> bytes:
        """Encrypt the data key with the master key.

        Returns:
            A fresh 12-byte nonce followed by the AES-GCM output for the key.
        """
        aesgcm = AESGCM(self.master_key)
        nonce = secrets.token_bytes(12)
        encrypted = aesgcm.encrypt(nonce, data_key, None)
        return nonce + encrypted

    def _decrypt_data_key(self, encrypted_data_key: bytes) -> bytes:
        """Decrypt the data key using the master key.

        Args:
            encrypted_data_key: Blob produced by :meth:`_encrypt_data_key`.

        Raises:
            EncryptionError: If the blob is too short or decryption fails.
        """
        if len(encrypted_data_key) < 12 + 32 + 16:  # nonce + key + tag
            raise EncryptionError("Invalid encrypted data key")
        aesgcm = AESGCM(self.master_key)
        nonce = encrypted_data_key[:12]
        ciphertext = encrypted_data_key[12:]
        try:
            return aesgcm.decrypt(nonce, ciphertext, None)
        except Exception as exc:
            raise EncryptionError(f"Failed to decrypt data key: {exc}") from exc

    def generate_data_key(self) -> tuple[bytes, bytes]:
        """Generate a data key and its encrypted form.

        Returns:
            Tuple of (plaintext_key, encrypted_key); the plaintext key is 32
            random bytes.
        """
        plaintext_key = secrets.token_bytes(32)
        encrypted_key = self._encrypt_data_key(plaintext_key)
        return plaintext_key, encrypted_key

    def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult:
        """Encrypt data using envelope encryption.

        Args:
            plaintext: Data to encrypt.
            context: Accepted for interface compatibility; not used here.

        Returns:
            The ciphertext, its nonce, ``KEY_ID`` and the wrapped data key.
        """
        data_key, encrypted_data_key = self.generate_data_key()
        aesgcm = AESGCM(data_key)
        nonce = secrets.token_bytes(12)
        ciphertext = aesgcm.encrypt(nonce, plaintext, None)
        return EncryptionResult(
            ciphertext=ciphertext,
            nonce=nonce,
            key_id=self.KEY_ID,
            encrypted_data_key=encrypted_data_key,
        )

    def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
                key_id: str, context: Dict[str, str] | None = None) -> bytes:
        """Decrypt data using envelope encryption.

        Args:
            ciphertext: Output of :meth:`encrypt`.
            nonce: Nonce returned by :meth:`encrypt`.
            encrypted_data_key: Wrapped data key returned by :meth:`encrypt`.
            key_id: Accepted for interface compatibility; not used here.
            context: Accepted for interface compatibility; not used here.

        Raises:
            EncryptionError: If the data key or the payload cannot be decrypted.
        """
        data_key = self._decrypt_data_key(encrypted_data_key)
        aesgcm = AESGCM(data_key)
        try:
            return aesgcm.decrypt(nonce, ciphertext, None)
        except Exception as exc:
            raise EncryptionError(f"Failed to decrypt data: {exc}") from exc
class StreamingEncryptor:
    """Encrypts/decrypts data in chunks for large files.

    Each chunk is encrypted with the same data key but a unique nonce derived
    from the base nonce + chunk index.

    Output layout: a 4-byte big-endian chunk count, then for every chunk a
    ``HEADER_SIZE``-byte big-endian length followed by the encrypted chunk.
    """

    CHUNK_SIZE = 64 * 1024
    HEADER_SIZE = 4
    # The chunk count is stored in 4 bytes and the nonce derivation only varies
    # the last 4 nonce bytes, so a stream may hold at most 2**32 - 1 chunks.
    _MAX_CHUNKS = 0xFFFFFFFF

    def __init__(self, provider: EncryptionProvider, chunk_size: int = CHUNK_SIZE):
        """Bind the encryptor to a provider and a plaintext chunk size.

        Args:
            provider: Source of data keys for encryption.
            chunk_size: Number of plaintext bytes read per chunk; must be >= 1.

        Raises:
            EncryptionError: If ``chunk_size`` is not positive. A size of 0
                would make every read return no data, so the "encrypted"
                stream would silently contain none of the input.
        """
        if chunk_size <= 0:
            raise EncryptionError(f"chunk_size must be a positive integer, got {chunk_size}")
        self.provider = provider
        self.chunk_size = chunk_size

    def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes:
        """Derive a unique nonce for each chunk.

        The first 8 bytes of ``base_nonce`` are kept and the remaining bytes
        (interpreted as a big-endian integer) are XORed with ``chunk_index``
        and written back as 4 bytes.
        """
        return base_nonce[:8] + (chunk_index ^ int.from_bytes(base_nonce[8:], "big")).to_bytes(4, "big")

    def encrypt_stream(self, stream: BinaryIO,
                       context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]:
        """Encrypt a stream and return encrypted stream + metadata.

        Args:
            stream: Readable binary stream; consumed until ``read`` returns
                no data.
            context: Accepted for interface compatibility; not used here.

        Returns:
            An in-memory stream positioned at offset 0 containing the chunked
            ciphertext, and the metadata needed to decrypt it.

        Raises:
            EncryptionError: If the input needs more chunks than the format
                can represent.
        """
        data_key, encrypted_data_key = self.provider.generate_data_key()
        base_nonce = secrets.token_bytes(12)
        aesgcm = AESGCM(data_key)

        output = io.BytesIO()
        output.write(b"\x00\x00\x00\x00")  # Placeholder for chunk count
        chunk_index = 0
        while True:
            chunk = stream.read(self.chunk_size)
            if not chunk:
                break
            # Refuse up front rather than fail with a raw OverflowError when the
            # index no longer fits the 4-byte counter / nonce field.
            if chunk_index >= self._MAX_CHUNKS:
                raise EncryptionError("Stream too large: maximum chunk count exceeded")
            chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
            encrypted_chunk = aesgcm.encrypt(chunk_nonce, chunk, None)
            # Size prefix + encrypted chunk
            output.write(len(encrypted_chunk).to_bytes(self.HEADER_SIZE, "big"))
            output.write(encrypted_chunk)
            chunk_index += 1

        # Back-patch the real chunk count into the placeholder, then rewind.
        output.seek(0)
        output.write(chunk_index.to_bytes(4, "big"))
        output.seek(0)

        metadata = EncryptionMetadata(
            algorithm="AES256",
            key_id=getattr(self.provider, "KEY_ID", "local"),
            nonce=base_nonce,
            encrypted_data_key=encrypted_data_key,
        )
        return output, metadata

    def decrypt_stream(self, stream: BinaryIO, metadata: EncryptionMetadata) -> BinaryIO:
        """Decrypt a stream using the provided metadata.

        Args:
            stream: Readable binary stream in the layout written by
                :meth:`encrypt_stream`.
            metadata: Metadata returned by :meth:`encrypt_stream`.

        Returns:
            An in-memory stream positioned at offset 0 containing the plaintext.

        Raises:
            EncryptionError: If the provider is not a ``LocalKeyEncryption``,
                the stream is truncated or malformed, or a chunk fails to
                decrypt.
        """
        if isinstance(self.provider, LocalKeyEncryption):
            data_key = self.provider._decrypt_data_key(metadata.encrypted_data_key)
        else:
            raise EncryptionError("Unsupported provider for streaming decryption")
        aesgcm = AESGCM(data_key)
        base_nonce = metadata.nonce

        chunk_count_bytes = stream.read(4)
        if len(chunk_count_bytes) < 4:
            raise EncryptionError("Invalid encrypted stream: missing header")
        chunk_count = int.from_bytes(chunk_count_bytes, "big")

        output = io.BytesIO()
        for chunk_index in range(chunk_count):
            size_bytes = stream.read(self.HEADER_SIZE)
            if len(size_bytes) < self.HEADER_SIZE:
                raise EncryptionError(f"Invalid encrypted stream: truncated at chunk {chunk_index}")
            chunk_size = int.from_bytes(size_bytes, "big")
            encrypted_chunk = stream.read(chunk_size)
            if len(encrypted_chunk) < chunk_size:
                raise EncryptionError(f"Invalid encrypted stream: incomplete chunk {chunk_index}")
            chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
            # Keep the try limited to the call whose failure we are translating.
            try:
                decrypted_chunk = aesgcm.decrypt(chunk_nonce, encrypted_chunk, None)
            except Exception as exc:
                raise EncryptionError(f"Failed to decrypt chunk {chunk_index}: {exc}") from exc
            output.write(decrypted_chunk)
        output.seek(0)
        return output
class EncryptionManager:
    """Front door for encryption: picks a provider and runs the operation.

    Providers and the streaming encryptor are created on first use and then
    reused for the lifetime of the manager.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration mapping; nothing is loaded eagerly.

        Args:
            config: Settings mapping read via ``.get`` for
                ``encryption_enabled``, ``default_encryption_algorithm`` and
                ``encryption_master_key_path``.
        """
        self.config = config
        self._local_provider: LocalKeyEncryption | None = None
        self._kms_provider: Any = None  # Injected later via set_kms_provider
        self._streaming_encryptor: StreamingEncryptor | None = None

    @property
    def enabled(self) -> bool:
        """Value of ``encryption_enabled`` in the config (``False`` if absent)."""
        return self.config.get("encryption_enabled", False)

    @property
    def default_algorithm(self) -> str:
        """Value of ``default_encryption_algorithm`` (``"AES256"`` if absent)."""
        return self.config.get("default_encryption_algorithm", "AES256")

    def get_local_provider(self) -> LocalKeyEncryption:
        """Return the shared local-key provider, creating it on first call."""
        if self._local_provider is not None:
            return self._local_provider
        configured_path = self.config.get(
            "encryption_master_key_path", "data/.myfsio.sys/keys/master.key"
        )
        self._local_provider = LocalKeyEncryption(Path(configured_path))
        return self._local_provider

    def set_kms_provider(self, kms_provider: Any) -> None:
        """Register the KMS provider object (supplied by the KMS module)."""
        self._kms_provider = kms_provider

    def get_provider(self, algorithm: str, kms_key_id: str | None = None) -> EncryptionProvider:
        """Resolve the provider that handles ``algorithm``.

        Args:
            algorithm: ``"AES256"`` for the local provider or ``"aws:kms"``
                for the registered KMS provider.
            kms_key_id: Key id forwarded to the KMS provider's ``get_provider``.

        Raises:
            EncryptionError: If the algorithm is unknown, or KMS was requested
                but no KMS provider has been registered.
        """
        if algorithm == "AES256":
            return self.get_local_provider()
        if algorithm != "aws:kms":
            raise EncryptionError(f"Unsupported encryption algorithm: {algorithm}")
        kms = self._kms_provider
        if kms is None:
            raise EncryptionError("KMS is not configured")
        return kms.get_provider(kms_key_id)

    def get_streaming_encryptor(self) -> StreamingEncryptor:
        """Return the shared streaming encryptor (backed by the local provider)."""
        if self._streaming_encryptor is not None:
            return self._streaming_encryptor
        self._streaming_encryptor = StreamingEncryptor(self.get_local_provider())
        return self._streaming_encryptor

    def encrypt_object(self, data: bytes, algorithm: str = "AES256",
                       kms_key_id: str | None = None,
                       context: Dict[str, str] | None = None) -> tuple[bytes, EncryptionMetadata]:
        """Encrypt an in-memory object.

        Returns:
            The ciphertext and the metadata required to decrypt it later.
        """
        outcome = self.get_provider(algorithm, kms_key_id).encrypt(data, context)
        return outcome.ciphertext, EncryptionMetadata(
            algorithm=algorithm,
            key_id=outcome.key_id,
            nonce=outcome.nonce,
            encrypted_data_key=outcome.encrypted_data_key,
        )

    def decrypt_object(self, ciphertext: bytes, metadata: EncryptionMetadata,
                       context: Dict[str, str] | None = None) -> bytes:
        """Decrypt an in-memory object using its stored metadata.

        The provider is selected from ``metadata.algorithm``; ``metadata.key_id``
        is passed as the KMS key id.
        """
        chosen = self.get_provider(metadata.algorithm, metadata.key_id)
        return chosen.decrypt(
            ciphertext,
            metadata.nonce,
            metadata.encrypted_data_key,
            metadata.key_id,
            context,
        )

    def encrypt_stream(self, stream: BinaryIO, algorithm: str = "AES256",
                       context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]:
        """Encrypt a stream for large files.

        NOTE(review): ``algorithm`` is accepted but not consulted here; the
        shared streaming encryptor is always used.
        """
        return self.get_streaming_encryptor().encrypt_stream(stream, context)

    def decrypt_stream(self, stream: BinaryIO, metadata: EncryptionMetadata) -> BinaryIO:
        """Decrypt a stream produced by :meth:`encrypt_stream`."""
        return self.get_streaming_encryptor().decrypt_stream(stream, metadata)
class SSECEncryption(EncryptionProvider):
    """SSE-C: Server-Side Encryption with Customer-Provided Keys.

    The client provides the encryption key with each request.
    Server encrypts/decrypts but never stores the key.

    Required headers for PUT:
    - x-amz-server-side-encryption-customer-algorithm: AES256
    - x-amz-server-side-encryption-customer-key: Base64-encoded 256-bit key
    - x-amz-server-side-encryption-customer-key-MD5: Base64-encoded MD5 of key
    """

    KEY_ID = "customer-provided"

    def __init__(self, customer_key: bytes):
        """Create a provider bound to one customer key.

        Args:
            customer_key: Raw 32-byte key.

        Raises:
            EncryptionError: If the key is not exactly 32 bytes long.
        """
        if len(customer_key) != 32:
            raise EncryptionError("Customer key must be exactly 256 bits (32 bytes)")
        self.customer_key = customer_key

    @classmethod
    def from_headers(cls, headers: Dict[str, str]) -> "SSECEncryption":
        """Build a provider from the SSE-C request headers.

        Header names are matched case-insensitively, so both a plain dict
        carrying ``...-customer-key-MD5`` and one with lowercase names are
        honoured; previously a mixed-case MD5 header in a case-sensitive
        mapping was ignored and the key integrity check was skipped.

        Args:
            headers: Mapping of header name to value.

        Returns:
            A provider holding the decoded customer key.

        Raises:
            EncryptionError: If the algorithm is not AES256, the key header is
                missing or not valid base64, the key is not 256 bits, or a
                supplied MD5 does not match the key.
        """
        normalized = {str(name).lower(): value for name, value in headers.items()}

        algorithm = normalized.get("x-amz-server-side-encryption-customer-algorithm", "")
        if algorithm.upper() != "AES256":
            raise EncryptionError(f"Unsupported SSE-C algorithm: {algorithm}. Only AES256 is supported.")

        key_b64 = normalized.get("x-amz-server-side-encryption-customer-key", "")
        if not key_b64:
            raise EncryptionError("Missing x-amz-server-side-encryption-customer-key header")
        key_md5_b64 = normalized.get("x-amz-server-side-encryption-customer-key-md5", "")

        try:
            customer_key = base64.b64decode(key_b64)
        except Exception as e:
            raise EncryptionError(f"Invalid base64 in customer key: {e}") from e
        if len(customer_key) != 32:
            raise EncryptionError(f"Customer key must be 256 bits, got {len(customer_key) * 8} bits")

        # The MD5 is an integrity check on the transmitted key; it is only
        # verified when the client actually sent it.
        if key_md5_b64:
            import hashlib
            expected_md5 = base64.b64encode(hashlib.md5(customer_key).digest()).decode()
            if key_md5_b64 != expected_md5:
                raise EncryptionError("Customer key MD5 mismatch")
        return cls(customer_key)

    def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult:
        """Encrypt ``plaintext`` directly with the customer key.

        Args:
            plaintext: Data to encrypt.
            context: Accepted for interface compatibility; not used here.

        Returns:
            The ciphertext and a fresh 12-byte nonce; ``encrypted_data_key`` is
            empty because no data key is wrapped or stored.
        """
        aesgcm = AESGCM(self.customer_key)
        nonce = secrets.token_bytes(12)
        ciphertext = aesgcm.encrypt(nonce, plaintext, None)
        return EncryptionResult(
            ciphertext=ciphertext,
            nonce=nonce,
            key_id=self.KEY_ID,
            encrypted_data_key=b"",
        )

    def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
                key_id: str, context: Dict[str, str] | None = None) -> bytes:
        """Decrypt ``ciphertext`` with the customer key.

        ``encrypted_data_key``, ``key_id`` and ``context`` are accepted for
        interface compatibility and not used here.

        Raises:
            EncryptionError: If decryption fails.
        """
        aesgcm = AESGCM(self.customer_key)
        try:
            return aesgcm.decrypt(nonce, ciphertext, None)
        except Exception as exc:
            raise EncryptionError(f"SSE-C decryption failed: {exc}") from exc

    def generate_data_key(self) -> tuple[bytes, bytes]:
        """Return the customer key itself as the data key, with no wrapped form."""
        return self.customer_key, b""
@dataclass
class SSECMetadata:
    """Parameters stored with an SSE-C encrypted object (never the key itself).

    Attributes:
        algorithm: Customer algorithm name; defaults to ``"AES256"``.
        nonce: Raw nonce bytes (base64-encoded when serialised).
        key_md5: Value stored under the customer-key-MD5 entry.
    """

    algorithm: str = "AES256"
    nonce: bytes = b""
    key_md5: str = ""

    def to_dict(self) -> Dict[str, str]:
        """Serialise to a string-only mapping with the three SSE-C keys."""
        encoded_nonce = base64.b64encode(self.nonce).decode()
        return {
            "x-amz-server-side-encryption-customer-algorithm": self.algorithm,
            "x-amz-encryption-nonce": encoded_nonce,
            "x-amz-server-side-encryption-customer-key-MD5": self.key_md5,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, str]) -> Optional["SSECMetadata"]:
        """Rebuild metadata from a mapping produced by :meth:`to_dict`.

        Returns:
            The parsed metadata, or ``None`` when the algorithm entry is missing
            or empty, or when the nonce cannot be decoded.
        """
        algorithm = data.get("x-amz-server-side-encryption-customer-algorithm")
        if not algorithm:
            return None

        try:
            return cls(
                algorithm=algorithm,
                nonce=base64.b64decode(data.get("x-amz-encryption-nonce", "")),
                key_md5=data.get("x-amz-server-side-encryption-customer-key-MD5", ""),
            )
        except Exception:
            # Undecodable values are reported as "no usable metadata".
            return None
class ClientEncryptionHelper:
    """Utility functions for clients that encrypt on their own side.

    The helper holds no state: it can mint a key and perform one-shot
    encrypt/decrypt with a base64-encoded key supplied by the caller.
    """

    @staticmethod
    def generate_client_key() -> Dict[str, str]:
        """Create a fresh random key for client-side use.

        Returns:
            A dict with the base64 ``key`` (32 random bytes), the ``algorithm``
            label and a UTC ISO-8601 ``created_at`` timestamp.
        """
        from datetime import datetime, timezone

        raw_key = secrets.token_bytes(32)
        encoded_key = base64.b64encode(raw_key).decode()
        timestamp = datetime.now(timezone.utc).isoformat()
        return {
            "key": encoded_key,
            "algorithm": "AES-256-GCM",
            "created_at": timestamp,
        }

    @staticmethod
    def encrypt_with_key(plaintext: bytes, key_b64: str) -> Dict[str, str]:
        """Encrypt ``plaintext`` with a caller-supplied base64 key.

        Args:
            plaintext: Data to encrypt.
            key_b64: Base64 text that must decode to 32 bytes.

        Returns:
            A dict with base64 ``ciphertext``, base64 ``nonce`` and the
            ``algorithm`` label.

        Raises:
            EncryptionError: If the decoded key is not 32 bytes long.
        """
        raw_key = base64.b64decode(key_b64)
        if len(raw_key) != 32:
            raise EncryptionError("Key must be 256 bits (32 bytes)")

        cipher = AESGCM(raw_key)
        fresh_nonce = secrets.token_bytes(12)
        sealed = cipher.encrypt(fresh_nonce, plaintext, None)
        return {
            "ciphertext": base64.b64encode(sealed).decode(),
            "nonce": base64.b64encode(fresh_nonce).decode(),
            "algorithm": "AES-256-GCM",
        }

    @staticmethod
    def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str) -> bytes:
        """Decrypt data produced by :meth:`encrypt_with_key`.

        Args:
            ciphertext_b64: Base64 ciphertext.
            nonce_b64: Base64 nonce.
            key_b64: Base64 text that must decode to 32 bytes.

        Raises:
            EncryptionError: If the decoded key is not 32 bytes long or the
                decryption itself fails.
        """
        # All three inputs are decoded before the key length is validated.
        raw_key = base64.b64decode(key_b64)
        raw_nonce = base64.b64decode(nonce_b64)
        sealed = base64.b64decode(ciphertext_b64)
        if len(raw_key) != 32:
            raise EncryptionError("Key must be 256 bits (32 bytes)")

        cipher = AESGCM(raw_key)
        try:
            return cipher.decrypt(raw_nonce, sealed, None)
        except Exception as exc:
            raise EncryptionError(f"Decryption failed: {exc}") from exc