diff --git a/app/admin_api.py b/app/admin_api.py index f480614..c554579 100644 --- a/app/admin_api.py +++ b/app/admin_api.py @@ -421,18 +421,38 @@ def check_bidirectional_status(site_id: str): ) if resp.status_code == 200: - remote_data = resp.json() + try: + remote_data = resp.json() + if not isinstance(remote_data, dict): + raise ValueError("Expected JSON object") + remote_local = remote_data.get("local") + if remote_local is not None and not isinstance(remote_local, dict): + raise ValueError("Expected 'local' to be an object") + remote_peers = remote_data.get("peers", []) + if not isinstance(remote_peers, list): + raise ValueError("Expected 'peers' to be a list") + except (ValueError, json.JSONDecodeError) as e: + logger.warning("Invalid JSON from remote admin API: %s", e) + result["remote_status"] = {"reachable": True, "invalid_response": True} + result["issues"].append({ + "code": "REMOTE_INVALID_RESPONSE", + "message": "Remote admin API returned invalid JSON", + "severity": "warning", + }) + return jsonify(result) + result["remote_status"] = { "reachable": True, - "local_site": remote_data.get("local"), + "local_site": remote_local, "site_sync_enabled": None, "has_peer_for_us": False, "peer_connection_configured": False, "has_bidirectional_rules_for_us": False, } - remote_peers = remote_data.get("peers", []) for rp in remote_peers: + if not isinstance(rp, dict): + continue if local_site and ( rp.get("site_id") == local_site.site_id or rp.get("endpoint") == local_site.endpoint diff --git a/app/encryption.py b/app/encryption.py index 3cc18cc..e490a0d 100644 --- a/app/encryption.py +++ b/app/encryption.py @@ -4,12 +4,16 @@ from __future__ import annotations import base64 import io import json +import os import secrets +import sys from dataclasses import dataclass from pathlib import Path from typing import Any, BinaryIO, Dict, Generator, Optional from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives import hashes class EncryptionError(Exception): @@ -110,6 +114,8 @@ class LocalKeyEncryption(EncryptionProvider): try: self.master_key_path.parent.mkdir(parents=True, exist_ok=True) self.master_key_path.write_text(base64.b64encode(key).decode()) + if sys.platform != "win32": + os.chmod(self.master_key_path, 0o600) except OSError as exc: raise EncryptionError(f"Failed to save master key: {exc}") from exc return key @@ -142,11 +148,12 @@ class LocalKeyEncryption(EncryptionProvider): def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: """Encrypt data using envelope encryption.""" data_key, encrypted_data_key = self.generate_data_key() - + aesgcm = AESGCM(data_key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, None) - + aad = json.dumps(context, sort_keys=True).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) + return EncryptionResult( ciphertext=ciphertext, nonce=nonce, @@ -159,10 +166,11 @@ class LocalKeyEncryption(EncryptionProvider): """Decrypt data using envelope encryption.""" data_key = self._decrypt_data_key(encrypted_data_key) aesgcm = AESGCM(data_key) + aad = json.dumps(context, sort_keys=True).encode() if context else None try: - return aesgcm.decrypt(nonce, ciphertext, None) + return aesgcm.decrypt(nonce, ciphertext, aad) except Exception as exc: - raise EncryptionError(f"Failed to decrypt data: {exc}") from exc + raise EncryptionError("Failed to decrypt data") from exc class StreamingEncryptor: @@ -180,12 +188,14 @@ class StreamingEncryptor: self.chunk_size = chunk_size def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes: - """Derive a unique nonce for each chunk. - - Performance: Use direct byte manipulation instead of full int conversion. - """ - # Performance: Only modify last 4 bytes instead of full 12-byte conversion - return base_nonce[:8] + (chunk_index ^ int.from_bytes(base_nonce[8:], "big")).to_bytes(4, "big") + """Derive a unique nonce for each chunk using HKDF.""" + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=12, + salt=base_nonce, + info=chunk_index.to_bytes(4, "big"), + ) + return hkdf.derive(b"chunk_nonce") def encrypt_stream(self, stream: BinaryIO, context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]: @@ -404,7 +414,8 @@ class SSECEncryption(EncryptionProvider): def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: aesgcm = AESGCM(self.customer_key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, None) + aad = json.dumps(context, sort_keys=True).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) return EncryptionResult( ciphertext=ciphertext, @@ -416,10 +427,11 @@ class SSECEncryption(EncryptionProvider): def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, key_id: str, context: Dict[str, str] | None = None) -> bytes: aesgcm = AESGCM(self.customer_key) + aad = json.dumps(context, sort_keys=True).encode() if context else None try: - return aesgcm.decrypt(nonce, ciphertext, None) + return aesgcm.decrypt(nonce, ciphertext, aad) except Exception as exc: - raise EncryptionError(f"SSE-C decryption failed: {exc}") from exc + raise EncryptionError("SSE-C decryption failed") from exc def generate_data_key(self) -> tuple[bytes, bytes]: return self.customer_key, b"" @@ -473,34 +485,36 @@ class ClientEncryptionHelper: } @staticmethod - def encrypt_with_key(plaintext: bytes, key_b64: str) -> Dict[str, str]: + def encrypt_with_key(plaintext: bytes, key_b64: str, context: Dict[str, str] | None = None) -> Dict[str, str]: """Encrypt data with a client-provided key.""" key = base64.b64decode(key_b64) if len(key) != 32: raise EncryptionError("Key must be 256 bits (32 bytes)") - + aesgcm = AESGCM(key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, None) - + aad = json.dumps(context, sort_keys=True).encode() if context else None + ciphertext = aesgcm.encrypt(nonce, plaintext, aad) + return { "ciphertext": base64.b64encode(ciphertext).decode(), "nonce": base64.b64encode(nonce).decode(), "algorithm": "AES-256-GCM", } - + @staticmethod - def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str) -> bytes: + def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str, context: Dict[str, str] | None = None) -> bytes: """Decrypt data with a client-provided key.""" key = base64.b64decode(key_b64) nonce = base64.b64decode(nonce_b64) ciphertext = base64.b64decode(ciphertext_b64) - + if len(key) != 32: raise EncryptionError("Key must be 256 bits (32 bytes)") - + aesgcm = AESGCM(key) + aad = json.dumps(context, sort_keys=True).encode() if context else None try: - return aesgcm.decrypt(nonce, ciphertext, None) + return aesgcm.decrypt(nonce, ciphertext, aad) except Exception as exc: - raise EncryptionError(f"Decryption failed: {exc}") from exc + raise EncryptionError("Decryption failed") from exc diff --git a/app/iam.py b/app/iam.py index 0e5e80f..2098f8a 100644 --- a/app/iam.py +++ b/app/iam.py @@ -119,7 +119,7 @@ class IamService: self._failed_attempts: Dict[str, Deque[datetime]] = {} self._last_load_time = 0.0 self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {} - self._cache_ttl = 60.0 + self._cache_ttl = 10.0 self._last_stat_check = 0.0 self._stat_check_interval = 1.0 self._sessions: Dict[str, Dict[str, Any]] = {} @@ -150,7 +150,8 @@ class IamService: f"Access temporarily locked. Try again in {seconds} seconds." ) record = self._users.get(access_key) - if not record or not hmac.compare_digest(record["secret_key"], secret_key): + stored_secret = record["secret_key"] if record else secrets.token_urlsafe(24) + if not record or not hmac.compare_digest(stored_secret, secret_key): self._record_failed_attempt(access_key) raise IamError("Invalid credentials") self._clear_failed_attempts(access_key) @@ -212,8 +213,9 @@ class IamService: """Validate a session token for an access key.""" session = self._sessions.get(session_token) if not session: + hmac.compare_digest(access_key, secrets.token_urlsafe(16)) return False - if session["access_key"] != access_key: + if not hmac.compare_digest(session["access_key"], access_key): return False if time.time() > session["expires_at"]: del self._sessions[session_token] diff --git a/app/kms.py b/app/kms.py index 9326e2d..6928f67 100644 --- a/app/kms.py +++ b/app/kms.py @@ -2,7 +2,10 @@ from __future__ import annotations import base64 import json +import logging +import os import secrets +import sys import uuid from dataclasses import dataclass, field from datetime import datetime, timezone @@ -13,6 +16,8 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from .encryption import EncryptionError, EncryptionProvider, EncryptionResult +logger = logging.getLogger(__name__) + @dataclass class KMSKey: @@ -74,11 +79,11 @@ class KMSEncryptionProvider(EncryptionProvider): def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: """Encrypt data using envelope encryption with KMS.""" data_key, encrypted_data_key = self.generate_data_key() - + aesgcm = AESGCM(data_key) nonce = secrets.token_bytes(12) - ciphertext = aesgcm.encrypt(nonce, plaintext, - json.dumps(context).encode() if context else None) + ciphertext = aesgcm.encrypt(nonce, plaintext, + json.dumps(context, sort_keys=True).encode() if context else None) return EncryptionResult( ciphertext=ciphertext, @@ -90,15 +95,17 @@ class KMSEncryptionProvider(EncryptionProvider): def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, key_id: str, context: Dict[str, str] | None = None) -> bytes: """Decrypt data using envelope encryption with KMS.""" - # Note: Data key is encrypted without context (AAD), so we decrypt without context data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None) - + if len(data_key) != 32: + raise EncryptionError("Invalid data key size") + aesgcm = AESGCM(data_key) try: return aesgcm.decrypt(nonce, ciphertext, - json.dumps(context).encode() if context else None) + json.dumps(context, sort_keys=True).encode() if context else None) except Exception as exc: - raise EncryptionError(f"Failed to decrypt data: {exc}") from exc + logger.debug("KMS decryption failed: %s", exc) + raise EncryptionError("Failed to decrypt data") from exc class KMSManager: @@ -137,6 +144,8 @@ class KMSManager: self.master_key_path.write_text( base64.b64encode(self._master_key).decode() ) + if sys.platform != "win32": + os.chmod(self.master_key_path, 0o600) return self._master_key def _load_keys(self) -> None: @@ -153,8 +162,10 @@ class KMSManager: encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"]) key.key_material = self._decrypt_key_material(encrypted) self._keys[key.key_id] = key - except Exception: - pass + except json.JSONDecodeError as exc: + logger.error("Failed to parse KMS keys file: %s", exc) + except (ValueError, KeyError) as exc: + logger.error("Invalid KMS key data: %s", exc) self._loaded = True @@ -277,7 +288,7 @@ class KMSManager: aesgcm = AESGCM(key.key_material) nonce = secrets.token_bytes(12) - aad = json.dumps(context).encode() if context else None + aad = json.dumps(context, sort_keys=True).encode() if context else None ciphertext = aesgcm.encrypt(nonce, plaintext, aad) key_id_bytes = key_id.encode("utf-8") @@ -306,17 +317,24 @@ class KMSManager: encrypted = rest[12:] aesgcm = AESGCM(key.key_material) - aad = json.dumps(context).encode() if context else None + aad = json.dumps(context, sort_keys=True).encode() if context else None try: plaintext = aesgcm.decrypt(nonce, encrypted, aad) return plaintext, key_id except Exception as exc: - raise EncryptionError(f"Decryption failed: {exc}") from exc + logger.debug("KMS decrypt operation failed: %s", exc) + raise EncryptionError("Decryption failed") from exc def generate_data_key(self, key_id: str, - context: Dict[str, str] | None = None) -> tuple[bytes, bytes]: + context: Dict[str, str] | None = None, + key_spec: str = "AES_256") -> tuple[bytes, bytes]: """Generate a data key and return both plaintext and encrypted versions. - + + Args: + key_id: The KMS key ID to use for encryption + context: Optional encryption context + key_spec: Key specification - AES_128 or AES_256 (default) + Returns: Tuple of (plaintext_key, encrypted_key) """ @@ -326,11 +344,12 @@ class KMSManager: raise EncryptionError(f"Key not found: {key_id}") if not key.enabled: raise EncryptionError(f"Key is disabled: {key_id}") - - plaintext_key = secrets.token_bytes(32) + + key_bytes = 32 if key_spec == "AES_256" else 16 + plaintext_key = secrets.token_bytes(key_bytes) encrypted_key = self.encrypt(key_id, plaintext_key, context) - + return plaintext_key, encrypted_key def decrypt_data_key(self, key_id: str, encrypted_key: bytes, diff --git a/app/s3_api.py b/app/s3_api.py index 47251cd..d6a79c6 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -1131,6 +1131,33 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response: return Response(status=204) +def _validate_cors_origin(origin: str) -> bool: + """Validate a CORS origin pattern.""" + import re + origin = origin.strip() + if not origin: + return False + if origin == "*": + return True + if origin.startswith("*."): + domain = origin[2:] + if not domain or ".." in domain: + return False + return bool(re.match(r'^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?)*$', domain)) + if origin.startswith(("http://", "https://")): + try: + from urllib.parse import urlparse + parsed = urlparse(origin) + if not parsed.netloc: + return False + if parsed.path and parsed.path != "/": + return False + return True + except Exception: + return False + return False + + def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]: sanitized: list[dict[str, Any]] = [] for rule in rules: @@ -1140,6 +1167,13 @@ def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]: expose_headers = [header.strip() for header in rule.get("ExposeHeaders", []) if header and header.strip()] if not allowed_origins or not allowed_methods: raise ValueError("Each CORSRule must include AllowedOrigin and AllowedMethod entries") + for origin in allowed_origins: + if not _validate_cors_origin(origin): + raise ValueError(f"Invalid CORS origin: {origin}") + valid_methods = {"GET", "PUT", "POST", "DELETE", "HEAD"} + for method in allowed_methods: + if method not in valid_methods: + raise ValueError(f"Invalid CORS method: {method}") sanitized_rule: dict[str, Any] = { "AllowedOrigins": allowed_origins, "AllowedMethods": allowed_methods, @@ -2259,15 +2293,13 @@ def bucket_handler(bucket_name: str) -> Response: continuation_token = request.args.get("continuation-token", "") # ListObjectsV2 start_after = request.args.get("start-after", "") # ListObjectsV2 - # For ListObjectsV2, continuation-token takes precedence, then start-after - # For ListObjects v1, use marker effective_start = "" if list_type == "2": if continuation_token: try: effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8") except (ValueError, UnicodeDecodeError): - effective_start = continuation_token + return _error_response("InvalidArgument", "Invalid continuation token", 400) elif start_after: effective_start = start_after else: diff --git a/app/storage.py b/app/storage.py index ea0ec9b..645cbd2 100644 --- a/app/storage.py +++ b/app/storage.py @@ -320,18 +320,15 @@ class ObjectStorage: total_count = len(all_keys) start_index = 0 if continuation_token: - try: - import bisect - start_index = bisect.bisect_right(all_keys, continuation_token) - if start_index >= total_count: - return ListObjectsResult( - objects=[], - is_truncated=False, - next_continuation_token=None, - total_count=total_count, - ) - except Exception: - pass + import bisect + start_index = bisect.bisect_right(all_keys, continuation_token) + if start_index >= total_count: + return ListObjectsResult( + objects=[], + is_truncated=False, + next_continuation_token=None, + total_count=total_count, + ) end_index = start_index + max_keys keys_slice = all_keys[start_index:end_index]