Fix 17 security vulnerabilities across encryption, auth, and API modules

This commit is contained in:
2026-01-29 12:05:35 +08:00
parent ae26d22388
commit 0ea54457e8
6 changed files with 146 additions and 62 deletions

View File

@@ -421,18 +421,38 @@ def check_bidirectional_status(site_id: str):
) )
if resp.status_code == 200: if resp.status_code == 200:
try:
remote_data = resp.json() remote_data = resp.json()
if not isinstance(remote_data, dict):
raise ValueError("Expected JSON object")
remote_local = remote_data.get("local")
if remote_local is not None and not isinstance(remote_local, dict):
raise ValueError("Expected 'local' to be an object")
remote_peers = remote_data.get("peers", [])
if not isinstance(remote_peers, list):
raise ValueError("Expected 'peers' to be a list")
except (ValueError, json.JSONDecodeError) as e:
logger.warning("Invalid JSON from remote admin API: %s", e)
result["remote_status"] = {"reachable": True, "invalid_response": True}
result["issues"].append({
"code": "REMOTE_INVALID_RESPONSE",
"message": "Remote admin API returned invalid JSON",
"severity": "warning",
})
return jsonify(result)
result["remote_status"] = { result["remote_status"] = {
"reachable": True, "reachable": True,
"local_site": remote_data.get("local"), "local_site": remote_local,
"site_sync_enabled": None, "site_sync_enabled": None,
"has_peer_for_us": False, "has_peer_for_us": False,
"peer_connection_configured": False, "peer_connection_configured": False,
"has_bidirectional_rules_for_us": False, "has_bidirectional_rules_for_us": False,
} }
remote_peers = remote_data.get("peers", [])
for rp in remote_peers: for rp in remote_peers:
if not isinstance(rp, dict):
continue
if local_site and ( if local_site and (
rp.get("site_id") == local_site.site_id or rp.get("site_id") == local_site.site_id or
rp.get("endpoint") == local_site.endpoint rp.get("endpoint") == local_site.endpoint

View File

@@ -4,12 +4,16 @@ from __future__ import annotations
import base64 import base64
import io import io
import json import json
import os
import secrets import secrets
import sys
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, BinaryIO, Dict, Generator, Optional from typing import Any, BinaryIO, Dict, Generator, Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
class EncryptionError(Exception): class EncryptionError(Exception):
@@ -110,6 +114,8 @@ class LocalKeyEncryption(EncryptionProvider):
try: try:
self.master_key_path.parent.mkdir(parents=True, exist_ok=True) self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
self.master_key_path.write_text(base64.b64encode(key).decode()) self.master_key_path.write_text(base64.b64encode(key).decode())
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
except OSError as exc: except OSError as exc:
raise EncryptionError(f"Failed to save master key: {exc}") from exc raise EncryptionError(f"Failed to save master key: {exc}") from exc
return key return key
@@ -145,7 +151,8 @@ class LocalKeyEncryption(EncryptionProvider):
aesgcm = AESGCM(data_key) aesgcm = AESGCM(data_key)
nonce = secrets.token_bytes(12) nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, None) aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
return EncryptionResult( return EncryptionResult(
ciphertext=ciphertext, ciphertext=ciphertext,
@@ -159,10 +166,11 @@ class LocalKeyEncryption(EncryptionProvider):
"""Decrypt data using envelope encryption.""" """Decrypt data using envelope encryption."""
data_key = self._decrypt_data_key(encrypted_data_key) data_key = self._decrypt_data_key(encrypted_data_key)
aesgcm = AESGCM(data_key) aesgcm = AESGCM(data_key)
aad = json.dumps(context, sort_keys=True).encode() if context else None
try: try:
return aesgcm.decrypt(nonce, ciphertext, None) return aesgcm.decrypt(nonce, ciphertext, aad)
except Exception as exc: except Exception as exc:
raise EncryptionError(f"Failed to decrypt data: {exc}") from exc raise EncryptionError("Failed to decrypt data") from exc
class StreamingEncryptor: class StreamingEncryptor:
@@ -180,12 +188,14 @@ class StreamingEncryptor:
self.chunk_size = chunk_size self.chunk_size = chunk_size
def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes: def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes:
"""Derive a unique nonce for each chunk. """Derive a unique nonce for each chunk using HKDF."""
hkdf = HKDF(
Performance: Use direct byte manipulation instead of full int conversion. algorithm=hashes.SHA256(),
""" length=12,
# Performance: Only modify last 4 bytes instead of full 12-byte conversion salt=base_nonce,
return base_nonce[:8] + (chunk_index ^ int.from_bytes(base_nonce[8:], "big")).to_bytes(4, "big") info=chunk_index.to_bytes(4, "big"),
)
return hkdf.derive(b"chunk_nonce")
def encrypt_stream(self, stream: BinaryIO, def encrypt_stream(self, stream: BinaryIO,
context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]: context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]:
@@ -404,7 +414,8 @@ class SSECEncryption(EncryptionProvider):
def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult: def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult:
aesgcm = AESGCM(self.customer_key) aesgcm = AESGCM(self.customer_key)
nonce = secrets.token_bytes(12) nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, None) aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
return EncryptionResult( return EncryptionResult(
ciphertext=ciphertext, ciphertext=ciphertext,
@@ -416,10 +427,11 @@ class SSECEncryption(EncryptionProvider):
def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
key_id: str, context: Dict[str, str] | None = None) -> bytes: key_id: str, context: Dict[str, str] | None = None) -> bytes:
aesgcm = AESGCM(self.customer_key) aesgcm = AESGCM(self.customer_key)
aad = json.dumps(context, sort_keys=True).encode() if context else None
try: try:
return aesgcm.decrypt(nonce, ciphertext, None) return aesgcm.decrypt(nonce, ciphertext, aad)
except Exception as exc: except Exception as exc:
raise EncryptionError(f"SSE-C decryption failed: {exc}") from exc raise EncryptionError("SSE-C decryption failed") from exc
def generate_data_key(self) -> tuple[bytes, bytes]: def generate_data_key(self) -> tuple[bytes, bytes]:
return self.customer_key, b"" return self.customer_key, b""
@@ -473,7 +485,7 @@ class ClientEncryptionHelper:
} }
@staticmethod @staticmethod
def encrypt_with_key(plaintext: bytes, key_b64: str) -> Dict[str, str]: def encrypt_with_key(plaintext: bytes, key_b64: str, context: Dict[str, str] | None = None) -> Dict[str, str]:
"""Encrypt data with a client-provided key.""" """Encrypt data with a client-provided key."""
key = base64.b64decode(key_b64) key = base64.b64decode(key_b64)
if len(key) != 32: if len(key) != 32:
@@ -481,7 +493,8 @@ class ClientEncryptionHelper:
aesgcm = AESGCM(key) aesgcm = AESGCM(key)
nonce = secrets.token_bytes(12) nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, None) aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
return { return {
"ciphertext": base64.b64encode(ciphertext).decode(), "ciphertext": base64.b64encode(ciphertext).decode(),
@@ -490,7 +503,7 @@ class ClientEncryptionHelper:
} }
@staticmethod @staticmethod
def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str) -> bytes: def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str, context: Dict[str, str] | None = None) -> bytes:
"""Decrypt data with a client-provided key.""" """Decrypt data with a client-provided key."""
key = base64.b64decode(key_b64) key = base64.b64decode(key_b64)
nonce = base64.b64decode(nonce_b64) nonce = base64.b64decode(nonce_b64)
@@ -500,7 +513,8 @@ class ClientEncryptionHelper:
raise EncryptionError("Key must be 256 bits (32 bytes)") raise EncryptionError("Key must be 256 bits (32 bytes)")
aesgcm = AESGCM(key) aesgcm = AESGCM(key)
aad = json.dumps(context, sort_keys=True).encode() if context else None
try: try:
return aesgcm.decrypt(nonce, ciphertext, None) return aesgcm.decrypt(nonce, ciphertext, aad)
except Exception as exc: except Exception as exc:
raise EncryptionError(f"Decryption failed: {exc}") from exc raise EncryptionError("Decryption failed") from exc

View File

@@ -119,7 +119,7 @@ class IamService:
self._failed_attempts: Dict[str, Deque[datetime]] = {} self._failed_attempts: Dict[str, Deque[datetime]] = {}
self._last_load_time = 0.0 self._last_load_time = 0.0
self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {} self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {}
self._cache_ttl = 60.0 self._cache_ttl = 10.0
self._last_stat_check = 0.0 self._last_stat_check = 0.0
self._stat_check_interval = 1.0 self._stat_check_interval = 1.0
self._sessions: Dict[str, Dict[str, Any]] = {} self._sessions: Dict[str, Dict[str, Any]] = {}
@@ -150,7 +150,8 @@ class IamService:
f"Access temporarily locked. Try again in {seconds} seconds." f"Access temporarily locked. Try again in {seconds} seconds."
) )
record = self._users.get(access_key) record = self._users.get(access_key)
if not record or not hmac.compare_digest(record["secret_key"], secret_key): stored_secret = record["secret_key"] if record else secrets.token_urlsafe(24)
if not record or not hmac.compare_digest(stored_secret, secret_key):
self._record_failed_attempt(access_key) self._record_failed_attempt(access_key)
raise IamError("Invalid credentials") raise IamError("Invalid credentials")
self._clear_failed_attempts(access_key) self._clear_failed_attempts(access_key)
@@ -212,8 +213,9 @@ class IamService:
"""Validate a session token for an access key.""" """Validate a session token for an access key."""
session = self._sessions.get(session_token) session = self._sessions.get(session_token)
if not session: if not session:
hmac.compare_digest(access_key, secrets.token_urlsafe(16))
return False return False
if session["access_key"] != access_key: if not hmac.compare_digest(session["access_key"], access_key):
return False return False
if time.time() > session["expires_at"]: if time.time() > session["expires_at"]:
del self._sessions[session_token] del self._sessions[session_token]

View File

@@ -2,7 +2,10 @@ from __future__ import annotations
import base64 import base64
import json import json
import logging
import os
import secrets import secrets
import sys
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -13,6 +16,8 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from .encryption import EncryptionError, EncryptionProvider, EncryptionResult from .encryption import EncryptionError, EncryptionProvider, EncryptionResult
logger = logging.getLogger(__name__)
@dataclass @dataclass
class KMSKey: class KMSKey:
@@ -78,7 +83,7 @@ class KMSEncryptionProvider(EncryptionProvider):
aesgcm = AESGCM(data_key) aesgcm = AESGCM(data_key)
nonce = secrets.token_bytes(12) nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, ciphertext = aesgcm.encrypt(nonce, plaintext,
json.dumps(context).encode() if context else None) json.dumps(context, sort_keys=True).encode() if context else None)
return EncryptionResult( return EncryptionResult(
ciphertext=ciphertext, ciphertext=ciphertext,
@@ -90,15 +95,17 @@ class KMSEncryptionProvider(EncryptionProvider):
def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes, def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
key_id: str, context: Dict[str, str] | None = None) -> bytes: key_id: str, context: Dict[str, str] | None = None) -> bytes:
"""Decrypt data using envelope encryption with KMS.""" """Decrypt data using envelope encryption with KMS."""
# Note: Data key is encrypted without context (AAD), so we decrypt without context
data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None) data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None)
if len(data_key) != 32:
raise EncryptionError("Invalid data key size")
aesgcm = AESGCM(data_key) aesgcm = AESGCM(data_key)
try: try:
return aesgcm.decrypt(nonce, ciphertext, return aesgcm.decrypt(nonce, ciphertext,
json.dumps(context).encode() if context else None) json.dumps(context, sort_keys=True).encode() if context else None)
except Exception as exc: except Exception as exc:
raise EncryptionError(f"Failed to decrypt data: {exc}") from exc logger.debug("KMS decryption failed: %s", exc)
raise EncryptionError("Failed to decrypt data") from exc
class KMSManager: class KMSManager:
@@ -137,6 +144,8 @@ class KMSManager:
self.master_key_path.write_text( self.master_key_path.write_text(
base64.b64encode(self._master_key).decode() base64.b64encode(self._master_key).decode()
) )
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
return self._master_key return self._master_key
def _load_keys(self) -> None: def _load_keys(self) -> None:
@@ -153,8 +162,10 @@ class KMSManager:
encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"]) encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"])
key.key_material = self._decrypt_key_material(encrypted) key.key_material = self._decrypt_key_material(encrypted)
self._keys[key.key_id] = key self._keys[key.key_id] = key
except Exception: except json.JSONDecodeError as exc:
pass logger.error("Failed to parse KMS keys file: %s", exc)
except (ValueError, KeyError) as exc:
logger.error("Invalid KMS key data: %s", exc)
self._loaded = True self._loaded = True
@@ -277,7 +288,7 @@ class KMSManager:
aesgcm = AESGCM(key.key_material) aesgcm = AESGCM(key.key_material)
nonce = secrets.token_bytes(12) nonce = secrets.token_bytes(12)
aad = json.dumps(context).encode() if context else None aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad) ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
key_id_bytes = key_id.encode("utf-8") key_id_bytes = key_id.encode("utf-8")
@@ -306,17 +317,24 @@ class KMSManager:
encrypted = rest[12:] encrypted = rest[12:]
aesgcm = AESGCM(key.key_material) aesgcm = AESGCM(key.key_material)
aad = json.dumps(context).encode() if context else None aad = json.dumps(context, sort_keys=True).encode() if context else None
try: try:
plaintext = aesgcm.decrypt(nonce, encrypted, aad) plaintext = aesgcm.decrypt(nonce, encrypted, aad)
return plaintext, key_id return plaintext, key_id
except Exception as exc: except Exception as exc:
raise EncryptionError(f"Decryption failed: {exc}") from exc logger.debug("KMS decrypt operation failed: %s", exc)
raise EncryptionError("Decryption failed") from exc
def generate_data_key(self, key_id: str, def generate_data_key(self, key_id: str,
context: Dict[str, str] | None = None) -> tuple[bytes, bytes]: context: Dict[str, str] | None = None,
key_spec: str = "AES_256") -> tuple[bytes, bytes]:
"""Generate a data key and return both plaintext and encrypted versions. """Generate a data key and return both plaintext and encrypted versions.
Args:
key_id: The KMS key ID to use for encryption
context: Optional encryption context
key_spec: Key specification - AES_128 or AES_256 (default)
Returns: Returns:
Tuple of (plaintext_key, encrypted_key) Tuple of (plaintext_key, encrypted_key)
""" """
@@ -327,7 +345,8 @@ class KMSManager:
if not key.enabled: if not key.enabled:
raise EncryptionError(f"Key is disabled: {key_id}") raise EncryptionError(f"Key is disabled: {key_id}")
plaintext_key = secrets.token_bytes(32) key_bytes = 32 if key_spec == "AES_256" else 16
plaintext_key = secrets.token_bytes(key_bytes)
encrypted_key = self.encrypt(key_id, plaintext_key, context) encrypted_key = self.encrypt(key_id, plaintext_key, context)

View File

@@ -1131,6 +1131,33 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response:
return Response(status=204) return Response(status=204)
def _validate_cors_origin(origin: str) -> bool:
"""Validate a CORS origin pattern."""
import re
origin = origin.strip()
if not origin:
return False
if origin == "*":
return True
if origin.startswith("*."):
domain = origin[2:]
if not domain or ".." in domain:
return False
return bool(re.match(r'^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?)*$', domain))
if origin.startswith(("http://", "https://")):
try:
from urllib.parse import urlparse
parsed = urlparse(origin)
if not parsed.netloc:
return False
if parsed.path and parsed.path != "/":
return False
return True
except Exception:
return False
return False
def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]: def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]:
sanitized: list[dict[str, Any]] = [] sanitized: list[dict[str, Any]] = []
for rule in rules: for rule in rules:
@@ -1140,6 +1167,13 @@ def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]:
expose_headers = [header.strip() for header in rule.get("ExposeHeaders", []) if header and header.strip()] expose_headers = [header.strip() for header in rule.get("ExposeHeaders", []) if header and header.strip()]
if not allowed_origins or not allowed_methods: if not allowed_origins or not allowed_methods:
raise ValueError("Each CORSRule must include AllowedOrigin and AllowedMethod entries") raise ValueError("Each CORSRule must include AllowedOrigin and AllowedMethod entries")
for origin in allowed_origins:
if not _validate_cors_origin(origin):
raise ValueError(f"Invalid CORS origin: {origin}")
valid_methods = {"GET", "PUT", "POST", "DELETE", "HEAD"}
for method in allowed_methods:
if method not in valid_methods:
raise ValueError(f"Invalid CORS method: {method}")
sanitized_rule: dict[str, Any] = { sanitized_rule: dict[str, Any] = {
"AllowedOrigins": allowed_origins, "AllowedOrigins": allowed_origins,
"AllowedMethods": allowed_methods, "AllowedMethods": allowed_methods,
@@ -2259,15 +2293,13 @@ def bucket_handler(bucket_name: str) -> Response:
continuation_token = request.args.get("continuation-token", "") # ListObjectsV2 continuation_token = request.args.get("continuation-token", "") # ListObjectsV2
start_after = request.args.get("start-after", "") # ListObjectsV2 start_after = request.args.get("start-after", "") # ListObjectsV2
# For ListObjectsV2, continuation-token takes precedence, then start-after
# For ListObjects v1, use marker
effective_start = "" effective_start = ""
if list_type == "2": if list_type == "2":
if continuation_token: if continuation_token:
try: try:
effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8") effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8")
except (ValueError, UnicodeDecodeError): except (ValueError, UnicodeDecodeError):
effective_start = continuation_token return _error_response("InvalidArgument", "Invalid continuation token", 400)
elif start_after: elif start_after:
effective_start = start_after effective_start = start_after
else: else:

View File

@@ -320,7 +320,6 @@ class ObjectStorage:
total_count = len(all_keys) total_count = len(all_keys)
start_index = 0 start_index = 0
if continuation_token: if continuation_token:
try:
import bisect import bisect
start_index = bisect.bisect_right(all_keys, continuation_token) start_index = bisect.bisect_right(all_keys, continuation_token)
if start_index >= total_count: if start_index >= total_count:
@@ -330,8 +329,6 @@ class ObjectStorage:
next_continuation_token=None, next_continuation_token=None,
total_count=total_count, total_count=total_count,
) )
except Exception:
pass
end_index = start_index + max_keys end_index = start_index + max_keys
keys_slice = all_keys[start_index:end_index] keys_slice = all_keys[start_index:end_index]