Separate Python and Rust into python/ and rust/ with per-stack Dockerfiles
This commit is contained in:
422
python/app/kms.py
Normal file
422
python/app/kms.py
Normal file
@@ -0,0 +1,422 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from .encryption import EncryptionError, EncryptionProvider, EncryptionResult
|
||||
|
||||
if sys.platform != "win32":
|
||||
import fcntl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _set_secure_file_permissions(file_path: Path) -> None:
|
||||
"""Set restrictive file permissions (owner read/write only)."""
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
username = os.environ.get("USERNAME", "")
|
||||
if username:
|
||||
subprocess.run(
|
||||
["icacls", str(file_path), "/inheritance:r",
|
||||
"/grant:r", f"{username}:F"],
|
||||
check=True, capture_output=True
|
||||
)
|
||||
else:
|
||||
logger.warning("Could not set secure permissions on %s: USERNAME not set", file_path)
|
||||
except (subprocess.SubprocessError, OSError) as exc:
|
||||
logger.warning("Failed to set secure permissions on %s: %s", file_path, exc)
|
||||
else:
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KMSKey:
|
||||
"""Represents a KMS encryption key."""
|
||||
key_id: str
|
||||
description: str
|
||||
created_at: str
|
||||
enabled: bool = True
|
||||
key_material: bytes = field(default_factory=lambda: b"", repr=False)
|
||||
|
||||
@property
|
||||
def arn(self) -> str:
|
||||
return f"arn:aws:kms:local:000000000000:key/{self.key_id}"
|
||||
|
||||
def to_dict(self, include_key: bool = False) -> Dict[str, Any]:
|
||||
data = {
|
||||
"KeyId": self.key_id,
|
||||
"Arn": self.arn,
|
||||
"Description": self.description,
|
||||
"CreationDate": self.created_at,
|
||||
"Enabled": self.enabled,
|
||||
"KeyState": "Enabled" if self.enabled else "Disabled",
|
||||
"KeyUsage": "ENCRYPT_DECRYPT",
|
||||
"KeySpec": "SYMMETRIC_DEFAULT",
|
||||
}
|
||||
if include_key:
|
||||
data["KeyMaterial"] = base64.b64encode(self.key_material).decode()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "KMSKey":
|
||||
key_material = b""
|
||||
if "KeyMaterial" in data:
|
||||
key_material = base64.b64decode(data["KeyMaterial"])
|
||||
return cls(
|
||||
key_id=data["KeyId"],
|
||||
description=data.get("Description", ""),
|
||||
created_at=data.get("CreationDate", datetime.now(timezone.utc).isoformat()),
|
||||
enabled=data.get("Enabled", True),
|
||||
key_material=key_material,
|
||||
)
|
||||
|
||||
|
||||
class KMSEncryptionProvider(EncryptionProvider):
|
||||
"""Encryption provider using a specific KMS key."""
|
||||
|
||||
def __init__(self, kms: "KMSManager", key_id: str):
|
||||
self.kms = kms
|
||||
self.key_id = key_id
|
||||
|
||||
@property
|
||||
def KEY_ID(self) -> str:
|
||||
return self.key_id
|
||||
|
||||
def generate_data_key(self) -> tuple[bytes, bytes]:
|
||||
"""Generate a data key encrypted with the KMS key."""
|
||||
return self.kms.generate_data_key(self.key_id)
|
||||
|
||||
def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult:
|
||||
"""Encrypt data using envelope encryption with KMS."""
|
||||
data_key, encrypted_data_key = self.generate_data_key()
|
||||
|
||||
aesgcm = AESGCM(data_key)
|
||||
nonce = secrets.token_bytes(12)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext,
|
||||
json.dumps(context, sort_keys=True).encode() if context else None)
|
||||
|
||||
return EncryptionResult(
|
||||
ciphertext=ciphertext,
|
||||
nonce=nonce,
|
||||
key_id=self.key_id,
|
||||
encrypted_data_key=encrypted_data_key,
|
||||
)
|
||||
|
||||
def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
|
||||
key_id: str, context: Dict[str, str] | None = None) -> bytes:
|
||||
"""Decrypt data using envelope encryption with KMS."""
|
||||
data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None)
|
||||
if len(data_key) != 32:
|
||||
raise EncryptionError("Invalid data key size")
|
||||
|
||||
aesgcm = AESGCM(data_key)
|
||||
try:
|
||||
return aesgcm.decrypt(nonce, ciphertext,
|
||||
json.dumps(context, sort_keys=True).encode() if context else None)
|
||||
except Exception as exc:
|
||||
logger.debug("KMS decryption failed: %s", exc)
|
||||
raise EncryptionError("Failed to decrypt data") from exc
|
||||
|
||||
def decrypt_data_key(self, encrypted_data_key: bytes, key_id: str | None = None) -> bytes:
|
||||
"""Decrypt an encrypted data key using KMS."""
|
||||
if key_id is None:
|
||||
key_id = self.key_id
|
||||
data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None)
|
||||
if len(data_key) != 32:
|
||||
raise EncryptionError("Invalid data key size")
|
||||
return data_key
|
||||
|
||||
|
||||
class KMSManager:
|
||||
"""Manages KMS keys and operations.
|
||||
|
||||
This is a local implementation that mimics AWS KMS functionality.
|
||||
Keys are stored encrypted on disk.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
keys_path: Path,
|
||||
master_key_path: Path,
|
||||
generate_data_key_min_bytes: int = 1,
|
||||
generate_data_key_max_bytes: int = 1024,
|
||||
):
|
||||
self.keys_path = keys_path
|
||||
self.master_key_path = master_key_path
|
||||
self.generate_data_key_min_bytes = generate_data_key_min_bytes
|
||||
self.generate_data_key_max_bytes = generate_data_key_max_bytes
|
||||
self._keys: Dict[str, KMSKey] = {}
|
||||
self._master_key: bytes | None = None
|
||||
self._master_aesgcm: AESGCM | None = None
|
||||
self._loaded = False
|
||||
|
||||
@property
|
||||
def master_key(self) -> bytes:
|
||||
"""Load or create the master key for encrypting KMS keys (with file locking)."""
|
||||
if self._master_key is None:
|
||||
lock_path = self.master_key_path.with_suffix(".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(lock_path, "w") as lock_file:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
try:
|
||||
if self.master_key_path.exists():
|
||||
self._master_key = base64.b64decode(
|
||||
self.master_key_path.read_text().strip()
|
||||
)
|
||||
else:
|
||||
self._master_key = secrets.token_bytes(32)
|
||||
self.master_key_path.write_text(
|
||||
base64.b64encode(self._master_key).decode()
|
||||
)
|
||||
_set_secure_file_permissions(self.master_key_path)
|
||||
finally:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
self._master_aesgcm = AESGCM(self._master_key)
|
||||
return self._master_key
|
||||
|
||||
def _load_keys(self) -> None:
|
||||
"""Load keys from disk."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
if self.keys_path.exists():
|
||||
try:
|
||||
data = json.loads(self.keys_path.read_text(encoding="utf-8"))
|
||||
for key_data in data.get("keys", []):
|
||||
key = KMSKey.from_dict(key_data)
|
||||
if key_data.get("EncryptedKeyMaterial"):
|
||||
encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"])
|
||||
key.key_material = self._decrypt_key_material(encrypted)
|
||||
self._keys[key.key_id] = key
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.error("Failed to parse KMS keys file: %s", exc)
|
||||
except (ValueError, KeyError) as exc:
|
||||
logger.error("Invalid KMS key data: %s", exc)
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def _save_keys(self) -> None:
|
||||
"""Save keys to disk (with encrypted key material)."""
|
||||
keys_data = []
|
||||
for key in self._keys.values():
|
||||
data = key.to_dict(include_key=False)
|
||||
encrypted = self._encrypt_key_material(key.key_material)
|
||||
data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode()
|
||||
keys_data.append(data)
|
||||
|
||||
self.keys_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.keys_path.write_text(
|
||||
json.dumps({"keys": keys_data}, indent=2),
|
||||
encoding="utf-8"
|
||||
)
|
||||
_set_secure_file_permissions(self.keys_path)
|
||||
|
||||
def _encrypt_key_material(self, key_material: bytes) -> bytes:
|
||||
_ = self.master_key
|
||||
nonce = secrets.token_bytes(12)
|
||||
ciphertext = self._master_aesgcm.encrypt(nonce, key_material, None)
|
||||
return nonce + ciphertext
|
||||
|
||||
def _decrypt_key_material(self, encrypted: bytes) -> bytes:
|
||||
_ = self.master_key
|
||||
nonce = encrypted[:12]
|
||||
ciphertext = encrypted[12:]
|
||||
return self._master_aesgcm.decrypt(nonce, ciphertext, None)
|
||||
|
||||
def create_key(self, description: str = "", key_id: str | None = None) -> KMSKey:
|
||||
"""Create a new KMS key."""
|
||||
self._load_keys()
|
||||
|
||||
if key_id is None:
|
||||
key_id = str(uuid.uuid4())
|
||||
|
||||
if key_id in self._keys:
|
||||
raise EncryptionError(f"Key already exists: {key_id}")
|
||||
|
||||
key = KMSKey(
|
||||
key_id=key_id,
|
||||
description=description,
|
||||
created_at=datetime.now(timezone.utc).isoformat(),
|
||||
enabled=True,
|
||||
key_material=secrets.token_bytes(32),
|
||||
)
|
||||
|
||||
self._keys[key_id] = key
|
||||
self._save_keys()
|
||||
return key
|
||||
|
||||
def get_key(self, key_id: str) -> KMSKey | None:
|
||||
"""Get a key by ID."""
|
||||
self._load_keys()
|
||||
return self._keys.get(key_id)
|
||||
|
||||
def list_keys(self) -> List[KMSKey]:
|
||||
"""List all keys."""
|
||||
self._load_keys()
|
||||
return list(self._keys.values())
|
||||
|
||||
def get_default_key_id(self) -> str:
|
||||
"""Get the default KMS key ID, creating one if none exist."""
|
||||
self._load_keys()
|
||||
for key in self._keys.values():
|
||||
if key.enabled:
|
||||
return key.key_id
|
||||
default_key = self.create_key(description="Default KMS Key")
|
||||
return default_key.key_id
|
||||
|
||||
def get_provider(self, key_id: str | None = None) -> "KMSEncryptionProvider":
|
||||
"""Get a KMS encryption provider for the specified key."""
|
||||
if key_id is None:
|
||||
key_id = self.get_default_key_id()
|
||||
key = self.get_key(key_id)
|
||||
if not key:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
if not key.enabled:
|
||||
raise EncryptionError(f"Key is disabled: {key_id}")
|
||||
return KMSEncryptionProvider(self, key_id)
|
||||
|
||||
def enable_key(self, key_id: str) -> None:
|
||||
"""Enable a key."""
|
||||
self._load_keys()
|
||||
key = self._keys.get(key_id)
|
||||
if not key:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
key.enabled = True
|
||||
self._save_keys()
|
||||
|
||||
def disable_key(self, key_id: str) -> None:
|
||||
"""Disable a key."""
|
||||
self._load_keys()
|
||||
key = self._keys.get(key_id)
|
||||
if not key:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
key.enabled = False
|
||||
self._save_keys()
|
||||
|
||||
def delete_key(self, key_id: str) -> None:
|
||||
"""Delete a key (schedule for deletion in real KMS)."""
|
||||
self._load_keys()
|
||||
if key_id not in self._keys:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
del self._keys[key_id]
|
||||
self._save_keys()
|
||||
|
||||
def encrypt(self, key_id: str, plaintext: bytes,
|
||||
context: Dict[str, str] | None = None) -> bytes:
|
||||
"""Encrypt data directly with a KMS key."""
|
||||
self._load_keys()
|
||||
key = self._keys.get(key_id)
|
||||
if not key:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
if not key.enabled:
|
||||
raise EncryptionError(f"Key is disabled: {key_id}")
|
||||
|
||||
aesgcm = AESGCM(key.key_material)
|
||||
nonce = secrets.token_bytes(12)
|
||||
aad = json.dumps(context, sort_keys=True).encode() if context else None
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
|
||||
|
||||
key_id_bytes = key_id.encode("utf-8")
|
||||
return len(key_id_bytes).to_bytes(2, "big") + key_id_bytes + nonce + ciphertext
|
||||
|
||||
def decrypt(self, ciphertext: bytes,
|
||||
context: Dict[str, str] | None = None) -> tuple[bytes, str]:
|
||||
"""Decrypt data directly with a KMS key.
|
||||
|
||||
Returns:
|
||||
Tuple of (plaintext, key_id)
|
||||
"""
|
||||
self._load_keys()
|
||||
|
||||
key_id_len = int.from_bytes(ciphertext[:2], "big")
|
||||
key_id = ciphertext[2:2 + key_id_len].decode("utf-8")
|
||||
rest = ciphertext[2 + key_id_len:]
|
||||
|
||||
key = self._keys.get(key_id)
|
||||
if not key:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
if not key.enabled:
|
||||
raise EncryptionError(f"Key is disabled: {key_id}")
|
||||
|
||||
nonce = rest[:12]
|
||||
encrypted = rest[12:]
|
||||
|
||||
aesgcm = AESGCM(key.key_material)
|
||||
aad = json.dumps(context, sort_keys=True).encode() if context else None
|
||||
try:
|
||||
plaintext = aesgcm.decrypt(nonce, encrypted, aad)
|
||||
return plaintext, key_id
|
||||
except Exception as exc:
|
||||
logger.debug("KMS decrypt operation failed: %s", exc)
|
||||
raise EncryptionError("Decryption failed") from exc
|
||||
|
||||
def generate_data_key(self, key_id: str,
|
||||
context: Dict[str, str] | None = None,
|
||||
key_spec: str = "AES_256") -> tuple[bytes, bytes]:
|
||||
"""Generate a data key and return both plaintext and encrypted versions.
|
||||
|
||||
Args:
|
||||
key_id: The KMS key ID to use for encryption
|
||||
context: Optional encryption context
|
||||
key_spec: Key specification - AES_128 or AES_256 (default)
|
||||
|
||||
Returns:
|
||||
Tuple of (plaintext_key, encrypted_key)
|
||||
"""
|
||||
self._load_keys()
|
||||
key = self._keys.get(key_id)
|
||||
if not key:
|
||||
raise EncryptionError(f"Key not found: {key_id}")
|
||||
if not key.enabled:
|
||||
raise EncryptionError(f"Key is disabled: {key_id}")
|
||||
|
||||
key_bytes = 32 if key_spec == "AES_256" else 16
|
||||
plaintext_key = secrets.token_bytes(key_bytes)
|
||||
|
||||
encrypted_key = self.encrypt(key_id, plaintext_key, context)
|
||||
|
||||
return plaintext_key, encrypted_key
|
||||
|
||||
def decrypt_data_key(self, key_id: str, encrypted_key: bytes,
|
||||
context: Dict[str, str] | None = None) -> bytes:
|
||||
"""Decrypt a data key."""
|
||||
plaintext, _ = self.decrypt(encrypted_key, context)
|
||||
return plaintext
|
||||
|
||||
def re_encrypt(self, ciphertext: bytes, destination_key_id: str,
|
||||
source_context: Dict[str, str] | None = None,
|
||||
destination_context: Dict[str, str] | None = None) -> bytes:
|
||||
"""Re-encrypt data with a different key."""
|
||||
|
||||
plaintext, source_key_id = self.decrypt(ciphertext, source_context)
|
||||
|
||||
return self.encrypt(destination_key_id, plaintext, destination_context)
|
||||
|
||||
def generate_random(self, num_bytes: int = 32) -> bytes:
|
||||
"""Generate cryptographically secure random bytes."""
|
||||
if num_bytes < self.generate_data_key_min_bytes or num_bytes > self.generate_data_key_max_bytes:
|
||||
raise EncryptionError(
|
||||
f"Number of bytes must be between {self.generate_data_key_min_bytes} and {self.generate_data_key_max_bytes}"
|
||||
)
|
||||
return secrets.token_bytes(num_bytes)
|
||||
Reference in New Issue
Block a user