MyFSIO v0.2.4 Release #16

Merged
kqjy merged 15 commits from next into main 2026-02-01 10:27:11 +00:00
6 changed files with 146 additions and 62 deletions
Showing only changes of commit 0ea54457e8 - Show all commits

View File

@@ -421,18 +421,38 @@ def check_bidirectional_status(site_id: str):
)
if resp.status_code == 200:
try:
remote_data = resp.json()
if not isinstance(remote_data, dict):
raise ValueError("Expected JSON object")
remote_local = remote_data.get("local")
if remote_local is not None and not isinstance(remote_local, dict):
raise ValueError("Expected 'local' to be an object")
remote_peers = remote_data.get("peers", [])
if not isinstance(remote_peers, list):
raise ValueError("Expected 'peers' to be a list")
except (ValueError, json.JSONDecodeError) as e:
logger.warning("Invalid JSON from remote admin API: %s", e)
result["remote_status"] = {"reachable": True, "invalid_response": True}
result["issues"].append({
"code": "REMOTE_INVALID_RESPONSE",
"message": "Remote admin API returned invalid JSON",
"severity": "warning",
})
return jsonify(result)
result["remote_status"] = {
"reachable": True,
"local_site": remote_data.get("local"),
"local_site": remote_local,
"site_sync_enabled": None,
"has_peer_for_us": False,
"peer_connection_configured": False,
"has_bidirectional_rules_for_us": False,
}
remote_peers = remote_data.get("peers", [])
for rp in remote_peers:
if not isinstance(rp, dict):
continue
if local_site and (
rp.get("site_id") == local_site.site_id or
rp.get("endpoint") == local_site.endpoint

View File

@@ -4,12 +4,16 @@ from __future__ import annotations
import base64
import io
import json
import os
import secrets
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, BinaryIO, Dict, Generator, Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
class EncryptionError(Exception):
@@ -110,6 +114,8 @@ class LocalKeyEncryption(EncryptionProvider):
try:
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
self.master_key_path.write_text(base64.b64encode(key).decode())
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
except OSError as exc:
raise EncryptionError(f"Failed to save master key: {exc}") from exc
return key
@@ -145,7 +151,8 @@ class LocalKeyEncryption(EncryptionProvider):
aesgcm = AESGCM(data_key)
nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
return EncryptionResult(
ciphertext=ciphertext,
@@ -159,10 +166,11 @@ class LocalKeyEncryption(EncryptionProvider):
"""Decrypt data using envelope encryption."""
data_key = self._decrypt_data_key(encrypted_data_key)
aesgcm = AESGCM(data_key)
aad = json.dumps(context, sort_keys=True).encode() if context else None
try:
return aesgcm.decrypt(nonce, ciphertext, None)
return aesgcm.decrypt(nonce, ciphertext, aad)
except Exception as exc:
raise EncryptionError(f"Failed to decrypt data: {exc}") from exc
raise EncryptionError("Failed to decrypt data") from exc
class StreamingEncryptor:
@@ -180,12 +188,14 @@ class StreamingEncryptor:
self.chunk_size = chunk_size
def _derive_chunk_nonce(self, base_nonce: bytes, chunk_index: int) -> bytes:
"""Derive a unique nonce for each chunk.
Performance: Use direct byte manipulation instead of full int conversion.
"""
# Performance: Only modify last 4 bytes instead of full 12-byte conversion
return base_nonce[:8] + (chunk_index ^ int.from_bytes(base_nonce[8:], "big")).to_bytes(4, "big")
"""Derive a unique nonce for each chunk using HKDF."""
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=12,
salt=base_nonce,
info=chunk_index.to_bytes(4, "big"),
)
return hkdf.derive(b"chunk_nonce")
def encrypt_stream(self, stream: BinaryIO,
context: Dict[str, str] | None = None) -> tuple[BinaryIO, EncryptionMetadata]:
@@ -404,7 +414,8 @@ class SSECEncryption(EncryptionProvider):
def encrypt(self, plaintext: bytes, context: Dict[str, str] | None = None) -> EncryptionResult:
aesgcm = AESGCM(self.customer_key)
nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
return EncryptionResult(
ciphertext=ciphertext,
@@ -416,10 +427,11 @@ class SSECEncryption(EncryptionProvider):
def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
key_id: str, context: Dict[str, str] | None = None) -> bytes:
aesgcm = AESGCM(self.customer_key)
aad = json.dumps(context, sort_keys=True).encode() if context else None
try:
return aesgcm.decrypt(nonce, ciphertext, None)
return aesgcm.decrypt(nonce, ciphertext, aad)
except Exception as exc:
raise EncryptionError(f"SSE-C decryption failed: {exc}") from exc
raise EncryptionError("SSE-C decryption failed") from exc
def generate_data_key(self) -> tuple[bytes, bytes]:
return self.customer_key, b""
@@ -473,7 +485,7 @@ class ClientEncryptionHelper:
}
@staticmethod
def encrypt_with_key(plaintext: bytes, key_b64: str) -> Dict[str, str]:
def encrypt_with_key(plaintext: bytes, key_b64: str, context: Dict[str, str] | None = None) -> Dict[str, str]:
"""Encrypt data with a client-provided key."""
key = base64.b64decode(key_b64)
if len(key) != 32:
@@ -481,7 +493,8 @@ class ClientEncryptionHelper:
aesgcm = AESGCM(key)
nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext, None)
aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
return {
"ciphertext": base64.b64encode(ciphertext).decode(),
@@ -490,7 +503,7 @@ class ClientEncryptionHelper:
}
@staticmethod
def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str) -> bytes:
def decrypt_with_key(ciphertext_b64: str, nonce_b64: str, key_b64: str, context: Dict[str, str] | None = None) -> bytes:
"""Decrypt data with a client-provided key."""
key = base64.b64decode(key_b64)
nonce = base64.b64decode(nonce_b64)
@@ -500,7 +513,8 @@ class ClientEncryptionHelper:
raise EncryptionError("Key must be 256 bits (32 bytes)")
aesgcm = AESGCM(key)
aad = json.dumps(context, sort_keys=True).encode() if context else None
try:
return aesgcm.decrypt(nonce, ciphertext, None)
return aesgcm.decrypt(nonce, ciphertext, aad)
except Exception as exc:
raise EncryptionError(f"Decryption failed: {exc}") from exc
raise EncryptionError("Decryption failed") from exc

View File

@@ -119,7 +119,7 @@ class IamService:
self._failed_attempts: Dict[str, Deque[datetime]] = {}
self._last_load_time = 0.0
self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {}
self._cache_ttl = 60.0
self._cache_ttl = 10.0
self._last_stat_check = 0.0
self._stat_check_interval = 1.0
self._sessions: Dict[str, Dict[str, Any]] = {}
@@ -150,7 +150,8 @@ class IamService:
f"Access temporarily locked. Try again in {seconds} seconds."
)
record = self._users.get(access_key)
if not record or not hmac.compare_digest(record["secret_key"], secret_key):
stored_secret = record["secret_key"] if record else secrets.token_urlsafe(24)
if not record or not hmac.compare_digest(stored_secret, secret_key):
self._record_failed_attempt(access_key)
raise IamError("Invalid credentials")
self._clear_failed_attempts(access_key)
@@ -212,8 +213,9 @@ class IamService:
"""Validate a session token for an access key."""
session = self._sessions.get(session_token)
if not session:
hmac.compare_digest(access_key, secrets.token_urlsafe(16))
return False
if session["access_key"] != access_key:
if not hmac.compare_digest(session["access_key"], access_key):
return False
if time.time() > session["expires_at"]:
del self._sessions[session_token]

View File

@@ -2,7 +2,10 @@ from __future__ import annotations
import base64
import json
import logging
import os
import secrets
import sys
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
@@ -13,6 +16,8 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from .encryption import EncryptionError, EncryptionProvider, EncryptionResult
logger = logging.getLogger(__name__)
@dataclass
class KMSKey:
@@ -78,7 +83,7 @@ class KMSEncryptionProvider(EncryptionProvider):
aesgcm = AESGCM(data_key)
nonce = secrets.token_bytes(12)
ciphertext = aesgcm.encrypt(nonce, plaintext,
json.dumps(context).encode() if context else None)
json.dumps(context, sort_keys=True).encode() if context else None)
return EncryptionResult(
ciphertext=ciphertext,
@@ -90,15 +95,17 @@ class KMSEncryptionProvider(EncryptionProvider):
def decrypt(self, ciphertext: bytes, nonce: bytes, encrypted_data_key: bytes,
key_id: str, context: Dict[str, str] | None = None) -> bytes:
"""Decrypt data using envelope encryption with KMS."""
# Note: Data key is encrypted without context (AAD), so we decrypt without context
data_key = self.kms.decrypt_data_key(key_id, encrypted_data_key, context=None)
if len(data_key) != 32:
raise EncryptionError("Invalid data key size")
aesgcm = AESGCM(data_key)
try:
return aesgcm.decrypt(nonce, ciphertext,
json.dumps(context).encode() if context else None)
json.dumps(context, sort_keys=True).encode() if context else None)
except Exception as exc:
raise EncryptionError(f"Failed to decrypt data: {exc}") from exc
logger.debug("KMS decryption failed: %s", exc)
raise EncryptionError("Failed to decrypt data") from exc
class KMSManager:
@@ -137,6 +144,8 @@ class KMSManager:
self.master_key_path.write_text(
base64.b64encode(self._master_key).decode()
)
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
return self._master_key
def _load_keys(self) -> None:
@@ -153,8 +162,10 @@ class KMSManager:
encrypted = base64.b64decode(key_data["EncryptedKeyMaterial"])
key.key_material = self._decrypt_key_material(encrypted)
self._keys[key.key_id] = key
except Exception:
pass
except json.JSONDecodeError as exc:
logger.error("Failed to parse KMS keys file: %s", exc)
except (ValueError, KeyError) as exc:
logger.error("Invalid KMS key data: %s", exc)
self._loaded = True
@@ -277,7 +288,7 @@ class KMSManager:
aesgcm = AESGCM(key.key_material)
nonce = secrets.token_bytes(12)
aad = json.dumps(context).encode() if context else None
aad = json.dumps(context, sort_keys=True).encode() if context else None
ciphertext = aesgcm.encrypt(nonce, plaintext, aad)
key_id_bytes = key_id.encode("utf-8")
@@ -306,17 +317,24 @@ class KMSManager:
encrypted = rest[12:]
aesgcm = AESGCM(key.key_material)
aad = json.dumps(context).encode() if context else None
aad = json.dumps(context, sort_keys=True).encode() if context else None
try:
plaintext = aesgcm.decrypt(nonce, encrypted, aad)
return plaintext, key_id
except Exception as exc:
raise EncryptionError(f"Decryption failed: {exc}") from exc
logger.debug("KMS decrypt operation failed: %s", exc)
raise EncryptionError("Decryption failed") from exc
def generate_data_key(self, key_id: str,
context: Dict[str, str] | None = None) -> tuple[bytes, bytes]:
context: Dict[str, str] | None = None,
key_spec: str = "AES_256") -> tuple[bytes, bytes]:
"""Generate a data key and return both plaintext and encrypted versions.
Args:
key_id: The KMS key ID to use for encryption
context: Optional encryption context
key_spec: Key specification - AES_128 or AES_256 (default)
Returns:
Tuple of (plaintext_key, encrypted_key)
"""
@@ -327,7 +345,8 @@ class KMSManager:
if not key.enabled:
raise EncryptionError(f"Key is disabled: {key_id}")
plaintext_key = secrets.token_bytes(32)
key_bytes = 32 if key_spec == "AES_256" else 16
plaintext_key = secrets.token_bytes(key_bytes)
encrypted_key = self.encrypt(key_id, plaintext_key, context)

View File

@@ -1131,6 +1131,33 @@ def _object_tagging_handler(bucket_name: str, object_key: str) -> Response:
return Response(status=204)
def _validate_cors_origin(origin: str) -> bool:
"""Validate a CORS origin pattern."""
import re
origin = origin.strip()
if not origin:
return False
if origin == "*":
return True
if origin.startswith("*."):
domain = origin[2:]
if not domain or ".." in domain:
return False
return bool(re.match(r'^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?)*$', domain))
if origin.startswith(("http://", "https://")):
try:
from urllib.parse import urlparse
parsed = urlparse(origin)
if not parsed.netloc:
return False
if parsed.path and parsed.path != "/":
return False
return True
except Exception:
return False
return False
def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]:
sanitized: list[dict[str, Any]] = []
for rule in rules:
@@ -1140,6 +1167,13 @@ def _sanitize_cors_rules(rules: list[dict[str, Any]]) -> list[dict[str, Any]]:
expose_headers = [header.strip() for header in rule.get("ExposeHeaders", []) if header and header.strip()]
if not allowed_origins or not allowed_methods:
raise ValueError("Each CORSRule must include AllowedOrigin and AllowedMethod entries")
for origin in allowed_origins:
if not _validate_cors_origin(origin):
raise ValueError(f"Invalid CORS origin: {origin}")
valid_methods = {"GET", "PUT", "POST", "DELETE", "HEAD"}
for method in allowed_methods:
if method not in valid_methods:
raise ValueError(f"Invalid CORS method: {method}")
sanitized_rule: dict[str, Any] = {
"AllowedOrigins": allowed_origins,
"AllowedMethods": allowed_methods,
@@ -2259,15 +2293,13 @@ def bucket_handler(bucket_name: str) -> Response:
continuation_token = request.args.get("continuation-token", "") # ListObjectsV2
start_after = request.args.get("start-after", "") # ListObjectsV2
# For ListObjectsV2, continuation-token takes precedence, then start-after
# For ListObjects v1, use marker
effective_start = ""
if list_type == "2":
if continuation_token:
try:
effective_start = base64.urlsafe_b64decode(continuation_token.encode()).decode("utf-8")
except (ValueError, UnicodeDecodeError):
effective_start = continuation_token
return _error_response("InvalidArgument", "Invalid continuation token", 400)
elif start_after:
effective_start = start_after
else:

View File

@@ -320,7 +320,6 @@ class ObjectStorage:
total_count = len(all_keys)
start_index = 0
if continuation_token:
try:
import bisect
start_index = bisect.bisect_right(all_keys, continuation_token)
if start_index >= total_count:
@@ -330,8 +329,6 @@ class ObjectStorage:
next_continuation_token=None,
total_count=total_count,
)
except Exception:
pass
end_index = start_index + max_keys
keys_slice = all_keys[start_index:end_index]