11 Commits

Author SHA1 Message Date
eb0e435a5a MyFSIO v0.3.6 Release
Reviewed-on: #29
2026-03-08 04:46:31 +00:00
72f5d9d70c Restore data integrity guarantees: Content-MD5 validation, fsync durability, atomic metadata writes, concurrent write protection 2026-03-07 17:54:00 +08:00
be63e27c15 Reduce per-request CPU overhead: eliminate double stat(), cache content type and policy context, gate logging, configurable stat intervals 2026-03-07 14:08:23 +08:00
7633007a08 MyFSIO v0.3.5 Release
Reviewed-on: #28
2026-03-07 05:53:02 +00:00
81ef0fe4c7 Fix stale object count in bucket header and metrics dashboard after deletes 2026-03-03 19:42:37 +08:00
5f24bd920d Reduce P99 tail latency: defer etag index writes, eliminate double cache rebuild, skip redundant stat() in bucket config 2026-03-02 22:39:37 +08:00
8552f193de Reduce CPU/lock contention under concurrent uploads: split cache lock, in-memory stats, dict copy, lightweight request IDs, defaultdict metrics 2026-03-02 22:05:54 +08:00
de0d869c9f Merge pull request 'MyFSIO v0.3.4 Release' (#27) from next into main
Reviewed-on: #27
2026-03-02 08:31:32 +00:00
5536330aeb Move performance-critical Python functions to Rust: streaming I/O, multipart assembly, and AES-256-GCM encryption 2026-02-27 22:55:20 +08:00
d4657c389d Fix misleading default credentials in README to match actual random generation behavior 2026-02-27 21:58:10 +08:00
3827235232 Reduce CPU usage on heavy uploads: skip SHA256 body hashing in SigV4, use Rust md5_file post-write instead of per-chunk _HashingReader 2026-02-27 21:57:13 +08:00
17 changed files with 1114 additions and 219 deletions

View File

@@ -80,7 +80,7 @@ python run.py --mode api # API only (port 5000)
python run.py --mode ui # UI only (port 5100) python run.py --mode ui # UI only (port 5100)
``` ```
**Default Credentials:** `localadmin` / `localadmin` **Credentials:** Generated automatically on first run and printed to the console. If missed, check the IAM config file at `<STORAGE_ROOT>/.myfsio.sys/config/iam.json`.
- **Web Console:** http://127.0.0.1:5100/ui - **Web Console:** http://127.0.0.1:5100/ui
- **API Endpoint:** http://127.0.0.1:5000 - **API Endpoint:** http://127.0.0.1:5000

View File

@@ -1,13 +1,13 @@
from __future__ import annotations from __future__ import annotations
import html as html_module import html as html_module
import itertools
import logging import logging
import mimetypes import mimetypes
import os import os
import shutil import shutil
import sys import sys
import time import time
import uuid
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
from pathlib import Path from pathlib import Path
from datetime import timedelta from datetime import timedelta
@@ -39,6 +39,8 @@ from .storage import ObjectStorage, StorageError
from .version import get_version from .version import get_version
from .website_domains import WebsiteDomainStore from .website_domains import WebsiteDomainStore
_request_counter = itertools.count(1)
def _migrate_config_file(active_path: Path, legacy_paths: List[Path]) -> Path: def _migrate_config_file(active_path: Path, legacy_paths: List[Path]) -> Path:
"""Migrate config file from legacy locations to the active path. """Migrate config file from legacy locations to the active path.
@@ -481,13 +483,9 @@ def _configure_logging(app: Flask) -> None:
@app.before_request @app.before_request
def _log_request_start() -> None: def _log_request_start() -> None:
g.request_id = uuid.uuid4().hex g.request_id = f"{os.getpid():x}{next(_request_counter):012x}"
g.request_started_at = time.perf_counter() g.request_started_at = time.perf_counter()
g.request_bytes_in = request.content_length or 0 g.request_bytes_in = request.content_length or 0
app.logger.info(
"Request started",
extra={"path": request.path, "method": request.method, "remote_addr": request.remote_addr},
)
@app.before_request @app.before_request
def _maybe_serve_website(): def _maybe_serve_website():
@@ -616,16 +614,17 @@ def _configure_logging(app: Flask) -> None:
duration_ms = 0.0 duration_ms = 0.0
if hasattr(g, "request_started_at"): if hasattr(g, "request_started_at"):
duration_ms = (time.perf_counter() - g.request_started_at) * 1000 duration_ms = (time.perf_counter() - g.request_started_at) * 1000
request_id = getattr(g, "request_id", uuid.uuid4().hex) request_id = getattr(g, "request_id", f"{os.getpid():x}{next(_request_counter):012x}")
response.headers.setdefault("X-Request-ID", request_id) response.headers.setdefault("X-Request-ID", request_id)
app.logger.info( if app.logger.isEnabledFor(logging.INFO):
"Request completed", app.logger.info(
extra={ "Request completed",
"path": request.path, extra={
"method": request.method, "path": request.path,
"remote_addr": request.remote_addr, "method": request.method,
}, "remote_addr": request.remote_addr,
) },
)
response.headers["X-Request-Duration-ms"] = f"{duration_ms:.2f}" response.headers["X-Request-Duration-ms"] = f"{duration_ms:.2f}"
operation_metrics = app.extensions.get("operation_metrics") operation_metrics = app.extensions.get("operation_metrics")

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import ipaddress import ipaddress
import json import json
import os
import re import re
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -268,7 +269,7 @@ class BucketPolicyStore:
self._last_mtime = self._current_mtime() self._last_mtime = self._current_mtime()
# Performance: Avoid stat() on every request # Performance: Avoid stat() on every request
self._last_stat_check = 0.0 self._last_stat_check = 0.0
self._stat_check_interval = 1.0 # Only check mtime every 1 second self._stat_check_interval = float(os.environ.get("BUCKET_POLICY_STAT_CHECK_INTERVAL_SECONDS", "2.0"))
def maybe_reload(self) -> None: def maybe_reload(self) -> None:
# Performance: Skip stat check if we checked recently # Performance: Skip stat check if we checked recently

View File

@@ -19,6 +19,13 @@ from cryptography.hazmat.primitives import hashes
if sys.platform != "win32": if sys.platform != "win32":
import fcntl import fcntl
try:
import myfsio_core as _rc
_HAS_RUST = True
except ImportError:
_rc = None
_HAS_RUST = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -338,6 +345,69 @@ class StreamingEncryptor:
output.seek(0) output.seek(0)
return output return output
def encrypt_file(self, input_path: str, output_path: str) -> EncryptionMetadata:
data_key, encrypted_data_key = self.provider.generate_data_key()
base_nonce = secrets.token_bytes(12)
if _HAS_RUST:
_rc.encrypt_stream_chunked(
input_path, output_path, data_key, base_nonce, self.chunk_size
)
else:
with open(input_path, "rb") as stream:
aesgcm = AESGCM(data_key)
with open(output_path, "wb") as out:
out.write(b"\x00\x00\x00\x00")
chunk_index = 0
while True:
chunk = stream.read(self.chunk_size)
if not chunk:
break
chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
encrypted_chunk = aesgcm.encrypt(chunk_nonce, chunk, None)
out.write(len(encrypted_chunk).to_bytes(self.HEADER_SIZE, "big"))
out.write(encrypted_chunk)
chunk_index += 1
out.seek(0)
out.write(chunk_index.to_bytes(4, "big"))
return EncryptionMetadata(
algorithm="AES256",
key_id=self.provider.KEY_ID if hasattr(self.provider, "KEY_ID") else "local",
nonce=base_nonce,
encrypted_data_key=encrypted_data_key,
)
def decrypt_file(self, input_path: str, output_path: str,
metadata: EncryptionMetadata) -> None:
data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id)
base_nonce = metadata.nonce
if _HAS_RUST:
_rc.decrypt_stream_chunked(input_path, output_path, data_key, base_nonce)
else:
with open(input_path, "rb") as stream:
chunk_count_bytes = stream.read(4)
if len(chunk_count_bytes) < 4:
raise EncryptionError("Invalid encrypted stream: missing header")
chunk_count = int.from_bytes(chunk_count_bytes, "big")
aesgcm = AESGCM(data_key)
with open(output_path, "wb") as out:
for chunk_index in range(chunk_count):
size_bytes = stream.read(self.HEADER_SIZE)
if len(size_bytes) < self.HEADER_SIZE:
raise EncryptionError(f"Invalid encrypted stream: truncated at chunk {chunk_index}")
chunk_size = int.from_bytes(size_bytes, "big")
encrypted_chunk = stream.read(chunk_size)
if len(encrypted_chunk) < chunk_size:
raise EncryptionError(f"Invalid encrypted stream: incomplete chunk {chunk_index}")
chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
try:
decrypted_chunk = aesgcm.decrypt(chunk_nonce, encrypted_chunk, None)
out.write(decrypted_chunk)
except Exception as exc:
raise EncryptionError(f"Failed to decrypt chunk {chunk_index}: {exc}") from exc
class EncryptionManager: class EncryptionManager:
"""Manages encryption providers and operations.""" """Manages encryption providers and operations."""

View File

@@ -125,7 +125,7 @@ class IamService:
self._secret_key_cache: Dict[str, Tuple[str, float]] = {} self._secret_key_cache: Dict[str, Tuple[str, float]] = {}
self._cache_ttl = float(os.environ.get("IAM_CACHE_TTL_SECONDS", "5.0")) self._cache_ttl = float(os.environ.get("IAM_CACHE_TTL_SECONDS", "5.0"))
self._last_stat_check = 0.0 self._last_stat_check = 0.0
self._stat_check_interval = 1.0 self._stat_check_interval = float(os.environ.get("IAM_STAT_CHECK_INTERVAL_SECONDS", "2.0"))
self._sessions: Dict[str, Dict[str, Any]] = {} self._sessions: Dict[str, Dict[str, Any]] = {}
self._session_lock = threading.Lock() self._session_lock = threading.Lock()
self._load() self._load()

View File

@@ -5,6 +5,7 @@ import logging
import random import random
import threading import threading
import time import time
from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@@ -138,8 +139,8 @@ class OperationMetricsCollector:
self.interval_seconds = interval_minutes * 60 self.interval_seconds = interval_minutes * 60
self.retention_hours = retention_hours self.retention_hours = retention_hours
self._lock = threading.Lock() self._lock = threading.Lock()
self._by_method: Dict[str, OperationStats] = {} self._by_method: Dict[str, OperationStats] = defaultdict(OperationStats)
self._by_endpoint: Dict[str, OperationStats] = {} self._by_endpoint: Dict[str, OperationStats] = defaultdict(OperationStats)
self._by_status_class: Dict[str, int] = {} self._by_status_class: Dict[str, int] = {}
self._error_codes: Dict[str, int] = {} self._error_codes: Dict[str, int] = {}
self._totals = OperationStats() self._totals = OperationStats()
@@ -211,8 +212,8 @@ class OperationMetricsCollector:
self._prune_old_snapshots() self._prune_old_snapshots()
self._save_history() self._save_history()
self._by_method.clear() self._by_method = defaultdict(OperationStats)
self._by_endpoint.clear() self._by_endpoint = defaultdict(OperationStats)
self._by_status_class.clear() self._by_status_class.clear()
self._error_codes.clear() self._error_codes.clear()
self._totals = OperationStats() self._totals = OperationStats()
@@ -232,12 +233,7 @@ class OperationMetricsCollector:
status_class = f"{status_code // 100}xx" status_class = f"{status_code // 100}xx"
with self._lock: with self._lock:
if method not in self._by_method:
self._by_method[method] = OperationStats()
self._by_method[method].record(latency_ms, success, bytes_in, bytes_out) self._by_method[method].record(latency_ms, success, bytes_in, bytes_out)
if endpoint_type not in self._by_endpoint:
self._by_endpoint[endpoint_type] = OperationStats()
self._by_endpoint[endpoint_type].record(latency_ms, success, bytes_in, bytes_out) self._by_endpoint[endpoint_type].record(latency_ms, success, bytes_in, bytes_out)
self._by_status_class[status_class] = self._by_status_class.get(status_class, 0) + 1 self._by_status_class[status_class] = self._by_status_class.get(status_class, 0) + 1

View File

@@ -85,6 +85,9 @@ def _bucket_policies() -> BucketPolicyStore:
def _build_policy_context() -> Dict[str, Any]: def _build_policy_context() -> Dict[str, Any]:
cached = getattr(g, "_policy_context", None)
if cached is not None:
return cached
ctx: Dict[str, Any] = {} ctx: Dict[str, Any] = {}
if request.headers.get("Referer"): if request.headers.get("Referer"):
ctx["aws:Referer"] = request.headers.get("Referer") ctx["aws:Referer"] = request.headers.get("Referer")
@@ -98,6 +101,7 @@ def _build_policy_context() -> Dict[str, Any]:
ctx["aws:SecureTransport"] = str(request.is_secure).lower() ctx["aws:SecureTransport"] = str(request.is_secure).lower()
if request.headers.get("User-Agent"): if request.headers.get("User-Agent"):
ctx["aws:UserAgent"] = request.headers.get("User-Agent") ctx["aws:UserAgent"] = request.headers.get("User-Agent")
g._policy_context = ctx
return ctx return ctx
@@ -293,9 +297,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
raise IamError("Required headers not signed") raise IamError("Required headers not signed")
canonical_uri = _get_canonical_uri(req) canonical_uri = _get_canonical_uri(req)
payload_hash = req.headers.get("X-Amz-Content-Sha256") payload_hash = req.headers.get("X-Amz-Content-Sha256") or "UNSIGNED-PAYLOAD"
if not payload_hash:
payload_hash = hashlib.sha256(req.get_data()).hexdigest()
if _HAS_RUST: if _HAS_RUST:
query_params = list(req.args.items(multi=True)) query_params = list(req.args.items(multi=True))
@@ -1023,11 +1025,15 @@ def _apply_object_headers(
file_stat, file_stat,
metadata: Dict[str, str] | None, metadata: Dict[str, str] | None,
etag: str, etag: str,
size_override: int | None = None,
mtime_override: float | None = None,
) -> None: ) -> None:
if file_stat is not None: effective_size = size_override if size_override is not None else (file_stat.st_size if file_stat is not None else None)
if response.status_code != 206: effective_mtime = mtime_override if mtime_override is not None else (file_stat.st_mtime if file_stat is not None else None)
response.headers["Content-Length"] = str(file_stat.st_size) if effective_size is not None and response.status_code != 206:
response.headers["Last-Modified"] = http_date(file_stat.st_mtime) response.headers["Content-Length"] = str(effective_size)
if effective_mtime is not None:
response.headers["Last-Modified"] = http_date(effective_mtime)
response.headers["ETag"] = f'"{etag}"' response.headers["ETag"] = f'"{etag}"'
response.headers["Accept-Ranges"] = "bytes" response.headers["Accept-Ranges"] = "bytes"
for key, value in (metadata or {}).items(): for key, value in (metadata or {}).items():
@@ -2822,6 +2828,8 @@ def object_handler(bucket_name: str, object_key: str):
if validation_error: if validation_error:
return _error_response("InvalidArgument", validation_error, 400) return _error_response("InvalidArgument", validation_error, 400)
metadata["__content_type__"] = content_type or mimetypes.guess_type(object_key)[0] or "application/octet-stream"
try: try:
meta = storage.put_object( meta = storage.put_object(
bucket_name, bucket_name,
@@ -2836,10 +2844,23 @@ def object_handler(bucket_name: str, object_key: str):
if "Bucket" in message: if "Bucket" in message:
return _error_response("NoSuchBucket", message, 404) return _error_response("NoSuchBucket", message, 404)
return _error_response("InvalidArgument", message, 400) return _error_response("InvalidArgument", message, 400)
current_app.logger.info(
"Object uploaded", content_md5 = request.headers.get("Content-MD5")
extra={"bucket": bucket_name, "key": object_key, "size": meta.size}, if content_md5 and meta.etag:
) try:
expected_md5 = base64.b64decode(content_md5).hex()
except Exception:
storage.delete_object(bucket_name, object_key)
return _error_response("InvalidDigest", "Content-MD5 header is not valid base64", 400)
if expected_md5 != meta.etag:
storage.delete_object(bucket_name, object_key)
return _error_response("BadDigest", "The Content-MD5 you specified did not match what we received", 400)
if current_app.logger.isEnabledFor(logging.INFO):
current_app.logger.info(
"Object uploaded",
extra={"bucket": bucket_name, "key": object_key, "size": meta.size},
)
response = Response(status=200) response = Response(status=200)
if meta.etag: if meta.etag:
response.headers["ETag"] = f'"{meta.etag}"' response.headers["ETag"] = f'"{meta.etag}"'
@@ -2873,7 +2894,7 @@ def object_handler(bucket_name: str, object_key: str):
except StorageError as exc: except StorageError as exc:
return _error_response("NoSuchKey", str(exc), 404) return _error_response("NoSuchKey", str(exc), 404)
metadata = storage.get_object_metadata(bucket_name, object_key) metadata = storage.get_object_metadata(bucket_name, object_key)
mimetype = mimetypes.guess_type(object_key)[0] or "application/octet-stream" mimetype = metadata.get("__content_type__") or mimetypes.guess_type(object_key)[0] or "application/octet-stream"
is_encrypted = "x-amz-server-side-encryption" in metadata is_encrypted = "x-amz-server-side-encryption" in metadata
@@ -2965,10 +2986,7 @@ def object_handler(bucket_name: str, object_key: str):
response.headers["Content-Type"] = mimetype response.headers["Content-Type"] = mimetype
logged_bytes = 0 logged_bytes = 0
try: file_stat = stat if not is_encrypted else None
file_stat = path.stat() if not is_encrypted else None
except (PermissionError, OSError):
file_stat = None
_apply_object_headers(response, file_stat=file_stat, metadata=metadata, etag=etag) _apply_object_headers(response, file_stat=file_stat, metadata=metadata, etag=etag)
if request.method == "GET": if request.method == "GET":
@@ -2985,8 +3003,9 @@ def object_handler(bucket_name: str, object_key: str):
if value: if value:
response.headers[header] = _sanitize_header_value(value) response.headers[header] = _sanitize_header_value(value)
action = "Object read" if request.method == "GET" else "Object head" if current_app.logger.isEnabledFor(logging.INFO):
current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes}) action = "Object read" if request.method == "GET" else "Object head"
current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes})
return response return response
if "uploadId" in request.args: if "uploadId" in request.args:
@@ -3004,7 +3023,8 @@ def object_handler(bucket_name: str, object_key: str):
storage.delete_object(bucket_name, object_key) storage.delete_object(bucket_name, object_key)
lock_service.delete_object_lock_metadata(bucket_name, object_key) lock_service.delete_object_lock_metadata(bucket_name, object_key)
current_app.logger.info("Object deleted", extra={"bucket": bucket_name, "key": object_key}) if current_app.logger.isEnabledFor(logging.INFO):
current_app.logger.info("Object deleted", extra={"bucket": bucket_name, "key": object_key})
principal, _ = _require_principal() principal, _ = _require_principal()
_notifications().emit_object_removed( _notifications().emit_object_removed(
@@ -3345,12 +3365,20 @@ def head_object(bucket_name: str, object_key: str) -> Response:
_authorize_action(principal, bucket_name, "read", object_key=object_key) _authorize_action(principal, bucket_name, "read", object_key=object_key)
path = _storage().get_object_path(bucket_name, object_key) path = _storage().get_object_path(bucket_name, object_key)
metadata = _storage().get_object_metadata(bucket_name, object_key) metadata = _storage().get_object_metadata(bucket_name, object_key)
stat = path.stat()
etag = metadata.get("__etag__") or _storage()._compute_etag(path) etag = metadata.get("__etag__") or _storage()._compute_etag(path)
response = Response(status=200) cached_size = metadata.get("__size__")
_apply_object_headers(response, file_stat=stat, metadata=metadata, etag=etag) cached_mtime = metadata.get("__last_modified__")
response.headers["Content-Type"] = mimetypes.guess_type(object_key)[0] or "application/octet-stream" if cached_size is not None and cached_mtime is not None:
size_val = int(cached_size)
mtime_val = float(cached_mtime)
response = Response(status=200)
_apply_object_headers(response, file_stat=None, metadata=metadata, etag=etag, size_override=size_val, mtime_override=mtime_val)
else:
stat = path.stat()
response = Response(status=200)
_apply_object_headers(response, file_stat=stat, metadata=metadata, etag=etag)
response.headers["Content-Type"] = metadata.get("__content_type__") or mimetypes.guess_type(object_key)[0] or "application/octet-stream"
return response return response
except (StorageError, FileNotFoundError): except (StorageError, FileNotFoundError):
return _error_response("NoSuchKey", "Object not found", 404) return _error_response("NoSuchKey", "Object not found", 404)
@@ -3580,6 +3608,8 @@ def _initiate_multipart_upload(bucket_name: str, object_key: str) -> Response:
return error return error
metadata = _extract_request_metadata() metadata = _extract_request_metadata()
content_type = request.headers.get("Content-Type")
metadata["__content_type__"] = content_type or mimetypes.guess_type(object_key)[0] or "application/octet-stream"
try: try:
upload_id = _storage().initiate_multipart_upload( upload_id = _storage().initiate_multipart_upload(
bucket_name, bucket_name,
@@ -3632,6 +3662,15 @@ def _upload_part(bucket_name: str, object_key: str) -> Response:
return _error_response("NoSuchUpload", str(exc), 404) return _error_response("NoSuchUpload", str(exc), 404)
return _error_response("InvalidArgument", str(exc), 400) return _error_response("InvalidArgument", str(exc), 400)
content_md5 = request.headers.get("Content-MD5")
if content_md5 and etag:
try:
expected_md5 = base64.b64decode(content_md5).hex()
except Exception:
return _error_response("InvalidDigest", "Content-MD5 header is not valid base64", 400)
if expected_md5 != etag:
return _error_response("BadDigest", "The Content-MD5 you specified did not match what we received", 400)
response = Response(status=200) response = Response(status=200)
response.headers["ETag"] = f'"{etag}"' response.headers["ETag"] = f'"{etag}"'
return response return response

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import copy
import hashlib import hashlib
import json import json
import os import os
@@ -196,7 +195,9 @@ class ObjectStorage:
self.root.mkdir(parents=True, exist_ok=True) self.root.mkdir(parents=True, exist_ok=True)
self._ensure_system_roots() self._ensure_system_roots()
self._object_cache: OrderedDict[str, tuple[Dict[str, ObjectMeta], float, float]] = OrderedDict() self._object_cache: OrderedDict[str, tuple[Dict[str, ObjectMeta], float, float]] = OrderedDict()
self._cache_lock = threading.Lock() self._obj_cache_lock = threading.Lock()
self._meta_cache_lock = threading.Lock()
self._registry_lock = threading.Lock()
self._bucket_locks: Dict[str, threading.Lock] = {} self._bucket_locks: Dict[str, threading.Lock] = {}
self._cache_version: Dict[str, int] = {} self._cache_version: Dict[str, int] = {}
self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {} self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {}
@@ -209,10 +210,17 @@ class ObjectStorage:
self._meta_read_cache: OrderedDict[tuple, Optional[Dict[str, Any]]] = OrderedDict() self._meta_read_cache: OrderedDict[tuple, Optional[Dict[str, Any]]] = OrderedDict()
self._meta_read_cache_max = 2048 self._meta_read_cache_max = 2048
self._cleanup_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ParentCleanup") self._cleanup_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ParentCleanup")
self._stats_mem: Dict[str, Dict[str, int]] = {}
self._stats_serial: Dict[str, int] = {}
self._stats_mem_time: Dict[str, float] = {}
self._stats_lock = threading.Lock()
self._stats_dirty: set[str] = set()
self._stats_flush_timer: Optional[threading.Timer] = None
self._etag_index_dirty: set[str] = set()
self._etag_index_flush_timer: Optional[threading.Timer] = None
def _get_bucket_lock(self, bucket_id: str) -> threading.Lock: def _get_bucket_lock(self, bucket_id: str) -> threading.Lock:
"""Get or create a lock for a specific bucket. Reduces global lock contention.""" with self._registry_lock:
with self._cache_lock:
if bucket_id not in self._bucket_locks: if bucket_id not in self._bucket_locks:
self._bucket_locks[bucket_id] = threading.Lock() self._bucket_locks[bucket_id] = threading.Lock()
return self._bucket_locks[bucket_id] return self._bucket_locks[bucket_id]
@@ -260,26 +268,24 @@ class ObjectStorage:
self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True) self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True)
def bucket_stats(self, bucket_name: str, cache_ttl: int = 60) -> dict[str, int]: def bucket_stats(self, bucket_name: str, cache_ttl: int = 60) -> dict[str, int]:
"""Return object count and total size for the bucket (cached).
Args:
bucket_name: Name of the bucket
cache_ttl: Cache time-to-live in seconds (default 60)
"""
bucket_path = self._bucket_path(bucket_name) bucket_path = self._bucket_path(bucket_name)
if not bucket_path.exists(): if not bucket_path.exists():
raise BucketNotFoundError("Bucket does not exist") raise BucketNotFoundError("Bucket does not exist")
with self._stats_lock:
if bucket_name in self._stats_mem:
cached_at = self._stats_mem_time.get(bucket_name, 0.0)
if (time.monotonic() - cached_at) < cache_ttl:
return dict(self._stats_mem[bucket_name])
self._stats_mem.pop(bucket_name, None)
self._stats_mem_time.pop(bucket_name, None)
cache_path = self._system_bucket_root(bucket_name) / "stats.json" cache_path = self._system_bucket_root(bucket_name) / "stats.json"
cached_stats = None cached_stats = None
cache_fresh = False
if cache_path.exists(): if cache_path.exists():
try: try:
cache_fresh = time.time() - cache_path.stat().st_mtime < cache_ttl
cached_stats = json.loads(cache_path.read_text(encoding="utf-8")) cached_stats = json.loads(cache_path.read_text(encoding="utf-8"))
if cache_fresh:
return cached_stats
except (OSError, json.JSONDecodeError): except (OSError, json.JSONDecodeError):
pass pass
@@ -348,16 +354,25 @@ class ObjectStorage:
"_cache_serial": existing_serial, "_cache_serial": existing_serial,
} }
with self._stats_lock:
self._stats_mem[bucket_name] = stats
self._stats_mem_time[bucket_name] = time.monotonic()
self._stats_serial[bucket_name] = existing_serial
try: try:
cache_path.parent.mkdir(parents=True, exist_ok=True) cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_text(json.dumps(stats), encoding="utf-8") self._atomic_write_json(cache_path, stats)
except OSError: except OSError:
pass pass
return stats return stats
def _invalidate_bucket_stats_cache(self, bucket_id: str) -> None: def _invalidate_bucket_stats_cache(self, bucket_id: str) -> None:
"""Invalidate the cached bucket statistics.""" with self._stats_lock:
self._stats_mem.pop(bucket_id, None)
self._stats_mem_time.pop(bucket_id, None)
self._stats_serial[bucket_id] = self._stats_serial.get(bucket_id, 0) + 1
self._stats_dirty.discard(bucket_id)
cache_path = self._system_bucket_root(bucket_id) / "stats.json" cache_path = self._system_bucket_root(bucket_id) / "stats.json"
try: try:
cache_path.unlink(missing_ok=True) cache_path.unlink(missing_ok=True)
@@ -373,29 +388,52 @@ class ObjectStorage:
version_bytes_delta: int = 0, version_bytes_delta: int = 0,
version_count_delta: int = 0, version_count_delta: int = 0,
) -> None: ) -> None:
"""Incrementally update cached bucket statistics instead of invalidating. with self._stats_lock:
if bucket_id not in self._stats_mem:
self._stats_mem[bucket_id] = {
"objects": 0, "bytes": 0, "version_count": 0,
"version_bytes": 0, "total_objects": 0, "total_bytes": 0,
"_cache_serial": 0,
}
data = self._stats_mem[bucket_id]
data["objects"] = max(0, data["objects"] + objects_delta)
data["bytes"] = max(0, data["bytes"] + bytes_delta)
data["version_count"] = max(0, data["version_count"] + version_count_delta)
data["version_bytes"] = max(0, data["version_bytes"] + version_bytes_delta)
data["total_objects"] = max(0, data["total_objects"] + objects_delta + version_count_delta)
data["total_bytes"] = max(0, data["total_bytes"] + bytes_delta + version_bytes_delta)
data["_cache_serial"] = data["_cache_serial"] + 1
self._stats_serial[bucket_id] = self._stats_serial.get(bucket_id, 0) + 1
self._stats_mem_time[bucket_id] = time.monotonic()
self._stats_dirty.add(bucket_id)
self._schedule_stats_flush()
This avoids expensive full directory scans on every PUT/DELETE by def _schedule_stats_flush(self) -> None:
adjusting the cached values directly. Also signals cross-process cache if self._stats_flush_timer is None or not self._stats_flush_timer.is_alive():
invalidation by incrementing _cache_serial. self._stats_flush_timer = threading.Timer(3.0, self._flush_stats)
""" self._stats_flush_timer.daemon = True
cache_path = self._system_bucket_root(bucket_id) / "stats.json" self._stats_flush_timer.start()
try:
cache_path.parent.mkdir(parents=True, exist_ok=True) def _flush_stats(self) -> None:
if cache_path.exists(): with self._stats_lock:
data = json.loads(cache_path.read_text(encoding="utf-8")) dirty = list(self._stats_dirty)
else: self._stats_dirty.clear()
data = {"objects": 0, "bytes": 0, "version_count": 0, "version_bytes": 0, "total_objects": 0, "total_bytes": 0, "_cache_serial": 0} snapshots = {b: dict(self._stats_mem[b]) for b in dirty if b in self._stats_mem}
data["objects"] = max(0, data.get("objects", 0) + objects_delta) for bucket_id, data in snapshots.items():
data["bytes"] = max(0, data.get("bytes", 0) + bytes_delta) cache_path = self._system_bucket_root(bucket_id) / "stats.json"
data["version_count"] = max(0, data.get("version_count", 0) + version_count_delta) try:
data["version_bytes"] = max(0, data.get("version_bytes", 0) + version_bytes_delta) cache_path.parent.mkdir(parents=True, exist_ok=True)
data["total_objects"] = max(0, data.get("total_objects", 0) + objects_delta + version_count_delta) self._atomic_write_json(cache_path, data)
data["total_bytes"] = max(0, data.get("total_bytes", 0) + bytes_delta + version_bytes_delta) except OSError:
data["_cache_serial"] = data.get("_cache_serial", 0) + 1 pass
cache_path.write_text(json.dumps(data), encoding="utf-8")
except (OSError, json.JSONDecodeError): def shutdown_stats(self) -> None:
pass if self._stats_flush_timer is not None:
self._stats_flush_timer.cancel()
self._flush_stats()
if self._etag_index_flush_timer is not None:
self._etag_index_flush_timer.cancel()
self._flush_etag_indexes()
def delete_bucket(self, bucket_name: str) -> None: def delete_bucket(self, bucket_name: str) -> None:
bucket_path = self._bucket_path(bucket_name) bucket_path = self._bucket_path(bucket_name)
@@ -413,13 +451,20 @@ class ObjectStorage:
self._remove_tree(self._system_bucket_root(bucket_id)) self._remove_tree(self._system_bucket_root(bucket_id))
self._remove_tree(self._multipart_bucket_root(bucket_id)) self._remove_tree(self._multipart_bucket_root(bucket_id))
self._bucket_config_cache.pop(bucket_id, None) self._bucket_config_cache.pop(bucket_id, None)
with self._cache_lock: with self._obj_cache_lock:
self._object_cache.pop(bucket_id, None) self._object_cache.pop(bucket_id, None)
self._cache_version.pop(bucket_id, None) self._cache_version.pop(bucket_id, None)
self._sorted_key_cache.pop(bucket_id, None) self._sorted_key_cache.pop(bucket_id, None)
with self._meta_cache_lock:
stale = [k for k in self._meta_read_cache if k[0] == bucket_id] stale = [k for k in self._meta_read_cache if k[0] == bucket_id]
for k in stale: for k in stale:
del self._meta_read_cache[k] del self._meta_read_cache[k]
with self._stats_lock:
self._stats_mem.pop(bucket_id, None)
self._stats_mem_time.pop(bucket_id, None)
self._stats_serial.pop(bucket_id, None)
self._stats_dirty.discard(bucket_id)
self._etag_index_dirty.discard(bucket_id)
def list_objects( def list_objects(
self, self,
@@ -834,51 +879,102 @@ class ObjectStorage:
is_overwrite = destination.exists() is_overwrite = destination.exists()
existing_size = destination.stat().st_size if is_overwrite else 0 existing_size = destination.stat().st_size if is_overwrite else 0
archived_version_size = 0
if self._is_versioning_enabled(bucket_path) and is_overwrite:
archived_version_size = existing_size
self._archive_current_version(bucket_id, safe_key, reason="overwrite")
tmp_dir = self._system_root_path() / self.SYSTEM_TMP_DIR tmp_dir = self._system_root_path() / self.SYSTEM_TMP_DIR
tmp_dir.mkdir(parents=True, exist_ok=True) tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp"
try: if _HAS_RUST:
checksum = hashlib.md5() tmp_path = None
with tmp_path.open("wb") as target:
shutil.copyfileobj(_HashingReader(stream, checksum), target)
new_size = tmp_path.stat().st_size
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
shutil.move(str(tmp_path), str(destination))
finally:
try: try:
tmp_path.unlink(missing_ok=True) tmp_path_str, etag, new_size = _rc.stream_to_file_with_md5(
stream, str(tmp_dir)
)
tmp_path = Path(tmp_path_str)
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
except BaseException:
if tmp_path:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
raise
else:
tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp"
try:
checksum = hashlib.md5()
with tmp_path.open("wb") as target:
shutil.copyfileobj(_HashingReader(stream, checksum), target)
target.flush()
os.fsync(target.fileno())
new_size = tmp_path.stat().st_size
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
etag = checksum.hexdigest()
except BaseException:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
raise
lock_file_path = self._system_bucket_root(bucket_id) / "locks" / f"{safe_key.as_posix().replace('/', '_')}.lock"
try:
with _atomic_lock_file(lock_file_path):
archived_version_size = 0
if self._is_versioning_enabled(bucket_path) and is_overwrite:
archived_version_size = existing_size
self._archive_current_version(bucket_id, safe_key, reason="overwrite")
shutil.move(str(tmp_path), str(destination))
tmp_path = None
stat = destination.stat()
internal_meta = {"__etag__": etag, "__size__": str(stat.st_size), "__last_modified__": str(stat.st_mtime)}
combined_meta = {**internal_meta, **(metadata or {})}
self._write_metadata(bucket_id, safe_key, combined_meta)
except BlockingIOError:
try:
if tmp_path:
tmp_path.unlink(missing_ok=True)
except OSError: except OSError:
pass pass
raise StorageError("Another upload to this key is in progress")
stat = destination.stat() finally:
etag = checksum.hexdigest() if tmp_path:
try:
internal_meta = {"__etag__": etag, "__size__": str(stat.st_size)} tmp_path.unlink(missing_ok=True)
combined_meta = {**internal_meta, **(metadata or {})} except OSError:
self._write_metadata(bucket_id, safe_key, combined_meta) pass
self._update_bucket_stats_cache( self._update_bucket_stats_cache(
bucket_id, bucket_id,
@@ -1465,14 +1561,22 @@ class ObjectStorage:
if not upload_root.exists(): if not upload_root.exists():
raise StorageError("Multipart upload not found") raise StorageError("Multipart upload not found")
checksum = hashlib.md5()
part_filename = f"part-{part_number:05d}.part" part_filename = f"part-{part_number:05d}.part"
part_path = upload_root / part_filename part_path = upload_root / part_filename
temp_path = upload_root / f".{part_filename}.tmp" temp_path = upload_root / f".{part_filename}.tmp"
try: try:
with temp_path.open("wb") as target: if _HAS_RUST:
shutil.copyfileobj(_HashingReader(stream, checksum), target) with temp_path.open("wb") as target:
shutil.copyfileobj(stream, target)
part_etag = _rc.md5_file(str(temp_path))
else:
checksum = hashlib.md5()
with temp_path.open("wb") as target:
shutil.copyfileobj(_HashingReader(stream, checksum), target)
target.flush()
os.fsync(target.fileno())
part_etag = checksum.hexdigest()
temp_path.replace(part_path) temp_path.replace(part_path)
except OSError: except OSError:
try: try:
@@ -1482,7 +1586,7 @@ class ObjectStorage:
raise raise
record = { record = {
"etag": checksum.hexdigest(), "etag": part_etag,
"size": part_path.stat().st_size, "size": part_path.stat().st_size,
"filename": part_filename, "filename": part_filename,
} }
@@ -1505,7 +1609,7 @@ class ObjectStorage:
parts = manifest.setdefault("parts", {}) parts = manifest.setdefault("parts", {})
parts[str(part_number)] = record parts[str(part_number)] = record
manifest_path.write_text(json.dumps(manifest), encoding="utf-8") self._atomic_write_json(manifest_path, manifest)
break break
except OSError as exc: except OSError as exc:
if attempt < max_retries - 1: if attempt < max_retries - 1:
@@ -1598,7 +1702,7 @@ class ObjectStorage:
parts = manifest.setdefault("parts", {}) parts = manifest.setdefault("parts", {})
parts[str(part_number)] = record parts[str(part_number)] = record
manifest_path.write_text(json.dumps(manifest), encoding="utf-8") self._atomic_write_json(manifest_path, manifest)
break break
except OSError as exc: except OSError as exc:
if attempt < max_retries - 1: if attempt < max_retries - 1:
@@ -1682,19 +1786,31 @@ class ObjectStorage:
if versioning_enabled and destination.exists(): if versioning_enabled and destination.exists():
archived_version_size = destination.stat().st_size archived_version_size = destination.stat().st_size
self._archive_current_version(bucket_id, safe_key, reason="overwrite") self._archive_current_version(bucket_id, safe_key, reason="overwrite")
checksum = hashlib.md5() if _HAS_RUST:
with destination.open("wb") as target: part_paths = []
for _, record in validated: for _, record in validated:
part_path = upload_root / record["filename"] pp = upload_root / record["filename"]
if not part_path.exists(): if not pp.exists():
raise StorageError(f"Missing part file {record['filename']}") raise StorageError(f"Missing part file {record['filename']}")
with part_path.open("rb") as chunk: part_paths.append(str(pp))
while True: checksum_hex = _rc.assemble_parts_with_md5(part_paths, str(destination))
data = chunk.read(1024 * 1024) else:
if not data: checksum = hashlib.md5()
break with destination.open("wb") as target:
checksum.update(data) for _, record in validated:
target.write(data) part_path = upload_root / record["filename"]
if not part_path.exists():
raise StorageError(f"Missing part file {record['filename']}")
with part_path.open("rb") as chunk:
while True:
data = chunk.read(1024 * 1024)
if not data:
break
checksum.update(data)
target.write(data)
target.flush()
os.fsync(target.fileno())
checksum_hex = checksum.hexdigest()
except BlockingIOError: except BlockingIOError:
raise StorageError("Another upload to this key is in progress") raise StorageError("Another upload to this key is in progress")
@@ -1709,10 +1825,10 @@ class ObjectStorage:
) )
stat = destination.stat() stat = destination.stat()
etag = checksum.hexdigest() etag = checksum_hex
metadata = manifest.get("metadata") metadata = manifest.get("metadata")
internal_meta = {"__etag__": etag, "__size__": str(stat.st_size)} internal_meta = {"__etag__": etag, "__size__": str(stat.st_size), "__last_modified__": str(stat.st_mtime)}
combined_meta = {**internal_meta, **(metadata or {})} combined_meta = {**internal_meta, **(metadata or {})}
self._write_metadata(bucket_id, safe_key, combined_meta) self._write_metadata(bucket_id, safe_key, combined_meta)
@@ -2073,19 +2189,19 @@ class ObjectStorage:
now = time.time() now = time.time()
current_stats_mtime = self._get_cache_marker_mtime(bucket_id) current_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock: with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id) cached = self._object_cache.get(bucket_id)
if cached: if cached:
objects, timestamp, cached_stats_mtime = cached objects, timestamp, cached_stats_mtime = cached
if now - timestamp < self._cache_ttl and current_stats_mtime == cached_stats_mtime: if now - timestamp < self._cache_ttl and current_stats_mtime == cached_stats_mtime:
self._object_cache.move_to_end(bucket_id) self._object_cache.move_to_end(bucket_id)
return objects return objects
cache_version = self._cache_version.get(bucket_id, 0)
bucket_lock = self._get_bucket_lock(bucket_id) bucket_lock = self._get_bucket_lock(bucket_id)
with bucket_lock: with bucket_lock:
now = time.time()
current_stats_mtime = self._get_cache_marker_mtime(bucket_id) current_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock: with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id) cached = self._object_cache.get(bucket_id)
if cached: if cached:
objects, timestamp, cached_stats_mtime = cached objects, timestamp, cached_stats_mtime = cached
@@ -2096,31 +2212,23 @@ class ObjectStorage:
objects = self._build_object_cache(bucket_path) objects = self._build_object_cache(bucket_path)
new_stats_mtime = self._get_cache_marker_mtime(bucket_id) new_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock: with self._obj_cache_lock:
current_version = self._cache_version.get(bucket_id, 0)
if current_version != cache_version:
objects = self._build_object_cache(bucket_path)
new_stats_mtime = self._get_cache_marker_mtime(bucket_id)
while len(self._object_cache) >= self._object_cache_max_size: while len(self._object_cache) >= self._object_cache_max_size:
self._object_cache.popitem(last=False) self._object_cache.popitem(last=False)
self._object_cache[bucket_id] = (objects, time.time(), new_stats_mtime) self._object_cache[bucket_id] = (objects, time.time(), new_stats_mtime)
self._object_cache.move_to_end(bucket_id) self._object_cache.move_to_end(bucket_id)
self._cache_version[bucket_id] = current_version + 1 self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
self._sorted_key_cache.pop(bucket_id, None) self._sorted_key_cache.pop(bucket_id, None)
return objects return objects
def _invalidate_object_cache(self, bucket_id: str) -> None: def _invalidate_object_cache(self, bucket_id: str) -> None:
"""Invalidate the object cache and etag index for a bucket. with self._obj_cache_lock:
Increments version counter to signal stale reads.
Cross-process invalidation is handled by checking stats.json mtime.
"""
with self._cache_lock:
self._object_cache.pop(bucket_id, None) self._object_cache.pop(bucket_id, None)
self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1 self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
self._etag_index_dirty.discard(bucket_id)
etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json" etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json"
try: try:
etag_index_path.unlink(missing_ok=True) etag_index_path.unlink(missing_ok=True)
@@ -2128,22 +2236,10 @@ class ObjectStorage:
pass pass
def _get_cache_marker_mtime(self, bucket_id: str) -> float: def _get_cache_marker_mtime(self, bucket_id: str) -> float:
"""Get a cache marker combining serial and object count for cross-process invalidation. return float(self._stats_serial.get(bucket_id, 0))
Returns a combined value that changes if either _cache_serial or object count changes.
This handles cases where the serial was reset but object count differs.
"""
stats_path = self._system_bucket_root(bucket_id) / "stats.json"
try:
data = json.loads(stats_path.read_text(encoding="utf-8"))
serial = data.get("_cache_serial", 0)
count = data.get("objects", 0)
return float(serial * 1000000 + count)
except (OSError, json.JSONDecodeError):
return 0
def _update_object_cache_entry(self, bucket_id: str, key: str, meta: Optional[ObjectMeta]) -> None: def _update_object_cache_entry(self, bucket_id: str, key: str, meta: Optional[ObjectMeta]) -> None:
with self._cache_lock: with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id) cached = self._object_cache.get(bucket_id)
if cached: if cached:
objects, timestamp, stats_mtime = cached objects, timestamp, stats_mtime = cached
@@ -2154,23 +2250,32 @@ class ObjectStorage:
self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1 self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
self._sorted_key_cache.pop(bucket_id, None) self._sorted_key_cache.pop(bucket_id, None)
self._update_etag_index(bucket_id, key, meta.etag if meta else None) self._etag_index_dirty.add(bucket_id)
self._schedule_etag_index_flush()
def _update_etag_index(self, bucket_id: str, key: str, etag: Optional[str]) -> None: def _schedule_etag_index_flush(self) -> None:
etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json" if self._etag_index_flush_timer is None or not self._etag_index_flush_timer.is_alive():
if not etag_index_path.exists(): self._etag_index_flush_timer = threading.Timer(5.0, self._flush_etag_indexes)
return self._etag_index_flush_timer.daemon = True
try: self._etag_index_flush_timer.start()
with open(etag_index_path, 'r', encoding='utf-8') as f:
index = json.load(f) def _flush_etag_indexes(self) -> None:
if etag is None: dirty = set(self._etag_index_dirty)
index.pop(key, None) self._etag_index_dirty.clear()
else: for bucket_id in dirty:
index[key] = etag with self._obj_cache_lock:
with open(etag_index_path, 'w', encoding='utf-8') as f: cached = self._object_cache.get(bucket_id)
json.dump(index, f) if not cached:
except (OSError, json.JSONDecodeError): continue
pass objects = cached[0]
index = {k: v.etag for k, v in objects.items() if v.etag}
etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json"
try:
etag_index_path.parent.mkdir(parents=True, exist_ok=True)
with open(etag_index_path, 'w', encoding='utf-8') as f:
json.dump(index, f)
except OSError:
pass
def warm_cache(self, bucket_names: Optional[List[str]] = None) -> None: def warm_cache(self, bucket_names: Optional[List[str]] = None) -> None:
"""Pre-warm the object cache for specified buckets or all buckets. """Pre-warm the object cache for specified buckets or all buckets.
@@ -2211,6 +2316,23 @@ class ObjectStorage:
): ):
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
@staticmethod
def _atomic_write_json(path: Path, data: Any) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(".tmp")
try:
with tmp_path.open("w", encoding="utf-8") as f:
json.dump(data, f)
f.flush()
os.fsync(f.fileno())
tmp_path.replace(path)
except BaseException:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
raise
def _multipart_dir(self, bucket_name: str, upload_id: str) -> Path: def _multipart_dir(self, bucket_name: str, upload_id: str) -> Path:
return self._multipart_bucket_root(bucket_name) / upload_id return self._multipart_bucket_root(bucket_name) / upload_id
@@ -2227,12 +2349,7 @@ class ObjectStorage:
if cached: if cached:
config, cached_time, cached_mtime = cached config, cached_time, cached_mtime = cached
if now - cached_time < self._bucket_config_cache_ttl: if now - cached_time < self._bucket_config_cache_ttl:
try: return config.copy()
current_mtime = config_path.stat().st_mtime if config_path.exists() else 0.0
except OSError:
current_mtime = 0.0
if current_mtime == cached_mtime:
return config.copy()
if not config_path.exists(): if not config_path.exists():
self._bucket_config_cache[bucket_name] = ({}, now, 0.0) self._bucket_config_cache[bucket_name] = ({}, now, 0.0)
@@ -2250,7 +2367,7 @@ class ObjectStorage:
def _write_bucket_config(self, bucket_name: str, payload: dict[str, Any]) -> None: def _write_bucket_config(self, bucket_name: str, payload: dict[str, Any]) -> None:
config_path = self._bucket_config_path(bucket_name) config_path = self._bucket_config_path(bucket_name)
config_path.parent.mkdir(parents=True, exist_ok=True) config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(json.dumps(payload), encoding="utf-8") self._atomic_write_json(config_path, payload)
try: try:
mtime = config_path.stat().st_mtime mtime = config_path.stat().st_mtime
except OSError: except OSError:
@@ -2284,8 +2401,7 @@ class ObjectStorage:
def _write_multipart_manifest(self, upload_root: Path, manifest: dict[str, Any]) -> None: def _write_multipart_manifest(self, upload_root: Path, manifest: dict[str, Any]) -> None:
manifest_path = upload_root / self.MULTIPART_MANIFEST manifest_path = upload_root / self.MULTIPART_MANIFEST
manifest_path.parent.mkdir(parents=True, exist_ok=True) self._atomic_write_json(manifest_path, manifest)
manifest_path.write_text(json.dumps(manifest), encoding="utf-8")
def _metadata_file(self, bucket_name: str, key: Path) -> Path: def _metadata_file(self, bucket_name: str, key: Path) -> Path:
meta_root = self._bucket_meta_root(bucket_name) meta_root = self._bucket_meta_root(bucket_name)
@@ -2301,19 +2417,19 @@ class ObjectStorage:
return meta_root / parent / "_index.json", entry_name return meta_root / parent / "_index.json", entry_name
def _get_meta_index_lock(self, index_path: str) -> threading.Lock: def _get_meta_index_lock(self, index_path: str) -> threading.Lock:
with self._cache_lock: with self._registry_lock:
if index_path not in self._meta_index_locks: if index_path not in self._meta_index_locks:
self._meta_index_locks[index_path] = threading.Lock() self._meta_index_locks[index_path] = threading.Lock()
return self._meta_index_locks[index_path] return self._meta_index_locks[index_path]
def _read_index_entry(self, bucket_name: str, key: Path) -> Optional[Dict[str, Any]]: def _read_index_entry(self, bucket_name: str, key: Path) -> Optional[Dict[str, Any]]:
cache_key = (bucket_name, str(key)) cache_key = (bucket_name, str(key))
with self._cache_lock: with self._meta_cache_lock:
hit = self._meta_read_cache.get(cache_key) hit = self._meta_read_cache.get(cache_key)
if hit is not None: if hit is not None:
self._meta_read_cache.move_to_end(cache_key) self._meta_read_cache.move_to_end(cache_key)
cached = hit[0] cached = hit[0]
return copy.deepcopy(cached) if cached is not None else None return dict(cached) if cached is not None else None
index_path, entry_name = self._index_file_for_key(bucket_name, key) index_path, entry_name = self._index_file_for_key(bucket_name, key)
if _HAS_RUST: if _HAS_RUST:
@@ -2328,16 +2444,16 @@ class ObjectStorage:
except (OSError, json.JSONDecodeError): except (OSError, json.JSONDecodeError):
result = None result = None
with self._cache_lock: with self._meta_cache_lock:
while len(self._meta_read_cache) >= self._meta_read_cache_max: while len(self._meta_read_cache) >= self._meta_read_cache_max:
self._meta_read_cache.popitem(last=False) self._meta_read_cache.popitem(last=False)
self._meta_read_cache[cache_key] = (copy.deepcopy(result) if result is not None else None,) self._meta_read_cache[cache_key] = (dict(result) if result is not None else None,)
return result return result
def _invalidate_meta_read_cache(self, bucket_name: str, key: Path) -> None: def _invalidate_meta_read_cache(self, bucket_name: str, key: Path) -> None:
cache_key = (bucket_name, str(key)) cache_key = (bucket_name, str(key))
with self._cache_lock: with self._meta_cache_lock:
self._meta_read_cache.pop(cache_key, None) self._meta_read_cache.pop(cache_key, None)
def _write_index_entry(self, bucket_name: str, key: Path, entry: Dict[str, Any]) -> None: def _write_index_entry(self, bucket_name: str, key: Path, entry: Dict[str, Any]) -> None:
@@ -2355,7 +2471,7 @@ class ObjectStorage:
except (OSError, json.JSONDecodeError): except (OSError, json.JSONDecodeError):
pass pass
index_data[entry_name] = entry index_data[entry_name] = entry
index_path.write_text(json.dumps(index_data), encoding="utf-8") self._atomic_write_json(index_path, index_data)
self._invalidate_meta_read_cache(bucket_name, key) self._invalidate_meta_read_cache(bucket_name, key)
def _delete_index_entry(self, bucket_name: str, key: Path) -> None: def _delete_index_entry(self, bucket_name: str, key: Path) -> None:
@@ -2376,7 +2492,7 @@ class ObjectStorage:
if entry_name in index_data: if entry_name in index_data:
del index_data[entry_name] del index_data[entry_name]
if index_data: if index_data:
index_path.write_text(json.dumps(index_data), encoding="utf-8") self._atomic_write_json(index_path, index_data)
else: else:
try: try:
index_path.unlink() index_path.unlink()
@@ -2425,7 +2541,7 @@ class ObjectStorage:
"reason": reason, "reason": reason,
} }
manifest_path = version_dir / f"{version_id}.json" manifest_path = version_dir / f"{version_id}.json"
manifest_path.write_text(json.dumps(record), encoding="utf-8") self._atomic_write_json(manifest_path, record)
def _read_metadata(self, bucket_name: str, key: Path) -> Dict[str, str]: def _read_metadata(self, bucket_name: str, key: Path) -> Dict[str, str]:
entry = self._read_index_entry(bucket_name, key) entry = self._read_index_entry(bucket_name, key)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
APP_VERSION = "0.3.3" APP_VERSION = "0.3.6"
def get_version() -> str: def get_version() -> str:

View File

@@ -19,3 +19,6 @@ regex = "1"
lru = "0.14" lru = "0.14"
parking_lot = "0.12" parking_lot = "0.12"
percent-encoding = "2" percent-encoding = "2"
aes-gcm = "0.10"
hkdf = "0.12"
uuid = { version = "1", features = ["v4"] }

192
myfsio_core/src/crypto.rs Normal file
View File

@@ -0,0 +1,192 @@
use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use hkdf::Hkdf;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use sha2::Sha256;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
const DEFAULT_CHUNK_SIZE: usize = 65536;
const HEADER_SIZE: usize = 4;
fn read_exact_chunk(reader: &mut impl Read, buf: &mut [u8]) -> std::io::Result<usize> {
let mut filled = 0;
while filled < buf.len() {
match reader.read(&mut buf[filled..]) {
Ok(0) => break,
Ok(n) => filled += n,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e),
}
}
Ok(filled)
}
fn derive_chunk_nonce(base_nonce: &[u8], chunk_index: u32) -> Result<[u8; 12], String> {
let hkdf = Hkdf::<Sha256>::new(Some(base_nonce), b"chunk_nonce");
let mut okm = [0u8; 12];
hkdf.expand(&chunk_index.to_be_bytes(), &mut okm)
.map_err(|e| format!("HKDF expand failed: {}", e))?;
Ok(okm)
}
#[pyfunction]
#[pyo3(signature = (input_path, output_path, key, base_nonce, chunk_size=DEFAULT_CHUNK_SIZE))]
pub fn encrypt_stream_chunked(
py: Python<'_>,
input_path: &str,
output_path: &str,
key: &[u8],
base_nonce: &[u8],
chunk_size: usize,
) -> PyResult<u32> {
if key.len() != 32 {
return Err(PyValueError::new_err(format!(
"Key must be 32 bytes, got {}",
key.len()
)));
}
if base_nonce.len() != 12 {
return Err(PyValueError::new_err(format!(
"Base nonce must be 12 bytes, got {}",
base_nonce.len()
)));
}
let chunk_size = if chunk_size == 0 {
DEFAULT_CHUNK_SIZE
} else {
chunk_size
};
let inp = input_path.to_owned();
let out = output_path.to_owned();
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
py.detach(move || {
let cipher = Aes256Gcm::new(&key_arr.into());
let mut infile = File::open(&inp)
.map_err(|e| PyIOError::new_err(format!("Failed to open input: {}", e)))?;
let mut outfile = File::create(&out)
.map_err(|e| PyIOError::new_err(format!("Failed to create output: {}", e)))?;
outfile
.write_all(&[0u8; 4])
.map_err(|e| PyIOError::new_err(format!("Failed to write header: {}", e)))?;
let mut buf = vec![0u8; chunk_size];
let mut chunk_index: u32 = 0;
loop {
let n = read_exact_chunk(&mut infile, &mut buf)
.map_err(|e| PyIOError::new_err(format!("Failed to read: {}", e)))?;
if n == 0 {
break;
}
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)
.map_err(|e| PyValueError::new_err(e))?;
let nonce = Nonce::from_slice(&nonce_bytes);
let encrypted = cipher
.encrypt(nonce, &buf[..n])
.map_err(|e| PyValueError::new_err(format!("Encrypt failed: {}", e)))?;
let size = encrypted.len() as u32;
outfile
.write_all(&size.to_be_bytes())
.map_err(|e| PyIOError::new_err(format!("Failed to write chunk size: {}", e)))?;
outfile
.write_all(&encrypted)
.map_err(|e| PyIOError::new_err(format!("Failed to write chunk: {}", e)))?;
chunk_index += 1;
}
outfile
.seek(SeekFrom::Start(0))
.map_err(|e| PyIOError::new_err(format!("Failed to seek: {}", e)))?;
outfile
.write_all(&chunk_index.to_be_bytes())
.map_err(|e| PyIOError::new_err(format!("Failed to write chunk count: {}", e)))?;
Ok(chunk_index)
})
}
#[pyfunction]
pub fn decrypt_stream_chunked(
py: Python<'_>,
input_path: &str,
output_path: &str,
key: &[u8],
base_nonce: &[u8],
) -> PyResult<u32> {
if key.len() != 32 {
return Err(PyValueError::new_err(format!(
"Key must be 32 bytes, got {}",
key.len()
)));
}
if base_nonce.len() != 12 {
return Err(PyValueError::new_err(format!(
"Base nonce must be 12 bytes, got {}",
base_nonce.len()
)));
}
let inp = input_path.to_owned();
let out = output_path.to_owned();
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
py.detach(move || {
let cipher = Aes256Gcm::new(&key_arr.into());
let mut infile = File::open(&inp)
.map_err(|e| PyIOError::new_err(format!("Failed to open input: {}", e)))?;
let mut outfile = File::create(&out)
.map_err(|e| PyIOError::new_err(format!("Failed to create output: {}", e)))?;
let mut header = [0u8; HEADER_SIZE];
infile
.read_exact(&mut header)
.map_err(|e| PyIOError::new_err(format!("Failed to read header: {}", e)))?;
let chunk_count = u32::from_be_bytes(header);
let mut size_buf = [0u8; HEADER_SIZE];
for chunk_index in 0..chunk_count {
infile
.read_exact(&mut size_buf)
.map_err(|e| {
PyIOError::new_err(format!(
"Failed to read chunk {} size: {}",
chunk_index, e
))
})?;
let chunk_size = u32::from_be_bytes(size_buf) as usize;
let mut encrypted = vec![0u8; chunk_size];
infile.read_exact(&mut encrypted).map_err(|e| {
PyIOError::new_err(format!("Failed to read chunk {}: {}", chunk_index, e))
})?;
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)
.map_err(|e| PyValueError::new_err(e))?;
let nonce = Nonce::from_slice(&nonce_bytes);
let decrypted = cipher.decrypt(nonce, encrypted.as_ref()).map_err(|e| {
PyValueError::new_err(format!("Decrypt chunk {} failed: {}", chunk_index, e))
})?;
outfile.write_all(&decrypted).map_err(|e| {
PyIOError::new_err(format!("Failed to write chunk {}: {}", chunk_index, e))
})?;
}
Ok(chunk_count)
})
}

View File

@@ -1,7 +1,9 @@
mod crypto;
mod hashing; mod hashing;
mod metadata; mod metadata;
mod sigv4; mod sigv4;
mod storage; mod storage;
mod streaming;
mod validation; mod validation;
use pyo3::prelude::*; use pyo3::prelude::*;
@@ -38,6 +40,12 @@ mod myfsio_core {
m.add_function(wrap_pyfunction!(storage::search_objects_scan, m)?)?; m.add_function(wrap_pyfunction!(storage::search_objects_scan, m)?)?;
m.add_function(wrap_pyfunction!(storage::build_object_cache, m)?)?; m.add_function(wrap_pyfunction!(storage::build_object_cache, m)?)?;
m.add_function(wrap_pyfunction!(streaming::stream_to_file_with_md5, m)?)?;
m.add_function(wrap_pyfunction!(streaming::assemble_parts_with_md5, m)?)?;
m.add_function(wrap_pyfunction!(crypto::encrypt_stream_chunked, m)?)?;
m.add_function(wrap_pyfunction!(crypto::decrypt_stream_chunked, m)?)?;
Ok(()) Ok(())
} }
} }

View File

@@ -0,0 +1,112 @@
use md5::{Digest, Md5};
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use std::fs::{self, File};
use std::io::{Read, Write};
use uuid::Uuid;
const DEFAULT_CHUNK_SIZE: usize = 262144;
#[pyfunction]
#[pyo3(signature = (stream, tmp_dir, chunk_size=DEFAULT_CHUNK_SIZE))]
pub fn stream_to_file_with_md5(
py: Python<'_>,
stream: &Bound<'_, PyAny>,
tmp_dir: &str,
chunk_size: usize,
) -> PyResult<(String, String, u64)> {
let chunk_size = if chunk_size == 0 {
DEFAULT_CHUNK_SIZE
} else {
chunk_size
};
fs::create_dir_all(tmp_dir)
.map_err(|e| PyIOError::new_err(format!("Failed to create tmp dir: {}", e)))?;
let tmp_name = format!("{}.tmp", Uuid::new_v4().as_hyphenated());
let tmp_path_buf = std::path::PathBuf::from(tmp_dir).join(&tmp_name);
let tmp_path = tmp_path_buf.to_string_lossy().into_owned();
let mut file = File::create(&tmp_path)
.map_err(|e| PyIOError::new_err(format!("Failed to create temp file: {}", e)))?;
let mut hasher = Md5::new();
let mut total_bytes: u64 = 0;
let result: PyResult<()> = (|| {
loop {
let chunk: Vec<u8> = stream.call_method1("read", (chunk_size,))?.extract()?;
if chunk.is_empty() {
break;
}
hasher.update(&chunk);
file.write_all(&chunk)
.map_err(|e| PyIOError::new_err(format!("Failed to write: {}", e)))?;
total_bytes += chunk.len() as u64;
py.check_signals()?;
}
file.sync_all()
.map_err(|e| PyIOError::new_err(format!("Failed to fsync: {}", e)))?;
Ok(())
})();
if let Err(e) = result {
drop(file);
let _ = fs::remove_file(&tmp_path);
return Err(e);
}
drop(file);
let md5_hex = format!("{:x}", hasher.finalize());
Ok((tmp_path, md5_hex, total_bytes))
}
#[pyfunction]
pub fn assemble_parts_with_md5(
py: Python<'_>,
part_paths: Vec<String>,
dest_path: &str,
) -> PyResult<String> {
if part_paths.is_empty() {
return Err(PyValueError::new_err("No parts to assemble"));
}
let dest = dest_path.to_owned();
let parts = part_paths;
py.detach(move || {
if let Some(parent) = std::path::Path::new(&dest).parent() {
fs::create_dir_all(parent)
.map_err(|e| PyIOError::new_err(format!("Failed to create dest dir: {}", e)))?;
}
let mut target = File::create(&dest)
.map_err(|e| PyIOError::new_err(format!("Failed to create dest file: {}", e)))?;
let mut hasher = Md5::new();
let mut buf = vec![0u8; 1024 * 1024];
for part_path in &parts {
let mut part = File::open(part_path)
.map_err(|e| PyIOError::new_err(format!("Failed to open part {}: {}", part_path, e)))?;
loop {
let n = part
.read(&mut buf)
.map_err(|e| PyIOError::new_err(format!("Failed to read part: {}", e)))?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
target
.write_all(&buf[..n])
.map_err(|e| PyIOError::new_err(format!("Failed to write: {}", e)))?;
}
}
target.sync_all()
.map_err(|e| PyIOError::new_err(format!("Failed to fsync: {}", e)))?;
Ok(format!("{:x}", hasher.finalize()))
})
}

View File

@@ -321,7 +321,7 @@
`; `;
}; };
const bucketTotalObjects = objectsContainer ? parseInt(objectsContainer.dataset.bucketTotalObjects || '0', 10) : 0; let bucketTotalObjects = objectsContainer ? parseInt(objectsContainer.dataset.bucketTotalObjects || '0', 10) : 0;
const updateObjectCountBadge = () => { const updateObjectCountBadge = () => {
if (!objectCountBadge) return; if (!objectCountBadge) return;
@@ -702,6 +702,7 @@
flushPendingStreamObjects(); flushPendingStreamObjects();
hasMoreObjects = false; hasMoreObjects = false;
totalObjectCount = loadedObjectCount; totalObjectCount = loadedObjectCount;
if (!currentPrefix) bucketTotalObjects = totalObjectCount;
updateObjectCountBadge(); updateObjectCountBadge();
if (objectsLoadingRow && objectsLoadingRow.parentNode) { if (objectsLoadingRow && objectsLoadingRow.parentNode) {
@@ -766,6 +767,7 @@
} }
totalObjectCount = data.total_count || 0; totalObjectCount = data.total_count || 0;
if (!append && !currentPrefix) bucketTotalObjects = totalObjectCount;
nextContinuationToken = data.next_continuation_token; nextContinuationToken = data.next_continuation_token;
if (!append && objectsLoadingRow) { if (!append && objectsLoadingRow) {

View File

@@ -43,6 +43,11 @@ def app(tmp_path: Path):
} }
) )
yield flask_app yield flask_app
storage = flask_app.extensions.get("object_storage")
if storage:
base = getattr(storage, "storage", storage)
if hasattr(base, "shutdown_stats"):
base.shutdown_stats()
@pytest.fixture() @pytest.fixture()

View File

@@ -53,7 +53,9 @@ def test_special_characters_in_metadata(tmp_path: Path):
assert meta["special"] == "!@#$%^&*()" assert meta["special"] == "!@#$%^&*()"
def test_disk_full_scenario(tmp_path: Path, monkeypatch): def test_disk_full_scenario(tmp_path: Path, monkeypatch):
# Simulate disk full by mocking write to fail import app.storage as _storage_mod
monkeypatch.setattr(_storage_mod, "_HAS_RUST", False)
storage = ObjectStorage(tmp_path) storage = ObjectStorage(tmp_path)
storage.create_bucket("full") storage.create_bucket("full")

View File

@@ -0,0 +1,350 @@
import hashlib
import io
import os
import secrets
import sys
from pathlib import Path
import pytest
# Make the project root importable when pytest is launched from elsewhere.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
try:
    # Compiled Rust extension under test; it may not be built in every env.
    import myfsio_core as _rc
    HAS_RUST = True
except ImportError:
    _rc = None
    HAS_RUST = False
# Skip this whole module when the native extension is unavailable.
pytestmark = pytest.mark.skipif(not HAS_RUST, reason="myfsio_core not available")
class TestStreamToFileWithMd5:
    """Tests for myfsio_core.stream_to_file_with_md5: drain a Python
    file-like object into a temp file under the given directory and return
    a (path, md5 hexdigest, byte count) triple."""

    def _run(self, tmp_path, payload, **kwargs):
        # Feed *payload* through the Rust streamer; staging dir is tmp_path/"tmp".
        return _rc.stream_to_file_with_md5(
            io.BytesIO(payload), str(tmp_path / "tmp"), **kwargs
        )

    def test_basic_write(self, tmp_path):
        # A moderately sized payload is written verbatim with a matching MD5.
        payload = b"hello world" * 1000
        out_path, digest, written = self._run(tmp_path, payload)
        assert written == len(payload)
        assert digest == hashlib.md5(payload).hexdigest()
        target = Path(out_path)
        assert target.exists()
        assert target.read_bytes() == payload

    def test_empty_stream(self, tmp_path):
        # Zero-byte input still produces a file and the MD5 of b"".
        out_path, digest, written = self._run(tmp_path, b"")
        assert written == 0
        assert digest == hashlib.md5(b"").hexdigest()
        assert Path(out_path).read_bytes() == b""

    def test_large_data(self, tmp_path):
        # 2 MiB of random bytes exercises multiple internal read chunks.
        payload = os.urandom(1024 * 1024 * 2)
        _, digest, written = self._run(tmp_path, payload)
        assert written == len(payload)
        assert digest == hashlib.md5(payload).hexdigest()

    def test_custom_chunk_size(self, tmp_path):
        # A tiny chunk_size forces many read iterations; result is unchanged.
        payload = b"x" * 10000
        _, digest, written = self._run(tmp_path, payload, chunk_size=128)
        assert written == len(payload)
        assert digest == hashlib.md5(payload).hexdigest()
class TestAssemblePartsWithMd5:
    """Tests for myfsio_core.assemble_parts_with_md5: concatenate part files
    into one destination file and return the MD5 hexdigest of the result."""

    def _write_parts(self, tmp_path, payloads):
        # Materialise each payload as a numbered part file; return their paths.
        paths = []
        for idx, payload in enumerate(payloads):
            part = tmp_path / f"part{idx}"
            part.write_bytes(payload)
            paths.append(str(part))
        return paths

    def test_basic_assembly(self, tmp_path):
        # Three parts concatenate in order; digest covers the whole output.
        payloads = [f"part{i}data".encode() * 100 for i in range(3)]
        part_paths = self._write_parts(tmp_path, payloads)
        expected = b"".join(payloads)
        dest = str(tmp_path / "output")
        digest = _rc.assemble_parts_with_md5(part_paths, dest)
        assert digest == hashlib.md5(expected).hexdigest()
        assert Path(dest).read_bytes() == expected

    def test_single_part(self, tmp_path):
        # A single part degenerates to a copy.
        payload = b"single part data"
        part_paths = self._write_parts(tmp_path, [payload])
        dest = str(tmp_path / "output")
        digest = _rc.assemble_parts_with_md5(part_paths, dest)
        assert digest == hashlib.md5(payload).hexdigest()
        assert Path(dest).read_bytes() == payload

    def test_empty_parts_list(self):
        # An empty parts list is rejected rather than producing an empty file.
        with pytest.raises(ValueError, match="No parts"):
            _rc.assemble_parts_with_md5([], "dummy")

    def test_missing_part_file(self, tmp_path):
        # A nonexistent part surfaces as an OSError from the Rust side.
        with pytest.raises(OSError):
            _rc.assemble_parts_with_md5(
                [str(tmp_path / "nonexistent")], str(tmp_path / "out")
            )

    def test_large_parts(self, tmp_path):
        # Five 512 KiB random parts exercise multi-buffer copying.
        payloads = [os.urandom(512 * 1024) for _ in range(5)]
        part_paths = self._write_parts(tmp_path, payloads)
        expected = b"".join(payloads)
        dest = str(tmp_path / "output")
        digest = _rc.assemble_parts_with_md5(part_paths, dest)
        assert digest == hashlib.md5(expected).hexdigest()
        assert Path(dest).read_bytes() == expected
class TestEncryptDecryptStreamChunked:
    """Round-trip and cross-implementation tests for the Rust chunked
    AES-256-GCM stream codec (encrypt_stream_chunked / decrypt_stream_chunked).

    On-disk format exercised by the cross-compat tests below: a 4-byte
    big-endian chunk count header, then per chunk a 4-byte big-endian
    ciphertext length followed by the AES-GCM ciphertext (tag included).
    """

    def _python_derive_chunk_nonce(self, base_nonce, chunk_index):
        # Pure-Python mirror of the per-chunk nonce derivation: HKDF-SHA256
        # with salt = base nonce, info = 4-byte big-endian chunk index,
        # IKM = the literal b"chunk_nonce", producing a 12-byte GCM nonce.
        from cryptography.hazmat.primitives.kdf.hkdf import HKDF
        from cryptography.hazmat.primitives import hashes
        hkdf = HKDF(
            algorithm=hashes.SHA256(),
            length=12,
            salt=base_nonce,
            info=chunk_index.to_bytes(4, "big"),
        )
        return hkdf.derive(b"chunk_nonce")

    def test_encrypt_decrypt_roundtrip(self, tmp_path):
        # Encrypt then decrypt: plaintext must survive and chunk counts agree.
        data = b"Hello, encryption!" * 500
        key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        input_path = str(tmp_path / "plaintext")
        encrypted_path = str(tmp_path / "encrypted")
        decrypted_path = str(tmp_path / "decrypted")
        Path(input_path).write_bytes(data)
        chunk_count = _rc.encrypt_stream_chunked(
            input_path, encrypted_path, key, base_nonce
        )
        assert chunk_count > 0
        chunk_count_dec = _rc.decrypt_stream_chunked(
            encrypted_path, decrypted_path, key, base_nonce
        )
        assert chunk_count_dec == chunk_count
        assert Path(decrypted_path).read_bytes() == data

    def test_empty_file(self, tmp_path):
        # An empty plaintext yields zero chunks and an empty decryption.
        key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        input_path = str(tmp_path / "empty")
        encrypted_path = str(tmp_path / "encrypted")
        decrypted_path = str(tmp_path / "decrypted")
        Path(input_path).write_bytes(b"")
        chunk_count = _rc.encrypt_stream_chunked(
            input_path, encrypted_path, key, base_nonce
        )
        assert chunk_count == 0
        chunk_count_dec = _rc.decrypt_stream_chunked(
            encrypted_path, decrypted_path, key, base_nonce
        )
        assert chunk_count_dec == 0
        assert Path(decrypted_path).read_bytes() == b""

    def test_custom_chunk_size(self, tmp_path):
        # 10000 bytes at chunk_size=1024 -> ceil(10000/1024) = 10 chunks.
        data = os.urandom(10000)
        key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        input_path = str(tmp_path / "plaintext")
        encrypted_path = str(tmp_path / "encrypted")
        decrypted_path = str(tmp_path / "decrypted")
        Path(input_path).write_bytes(data)
        chunk_count = _rc.encrypt_stream_chunked(
            input_path, encrypted_path, key, base_nonce, chunk_size=1024
        )
        assert chunk_count == 10
        _rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce)
        assert Path(decrypted_path).read_bytes() == data

    def test_invalid_key_length(self, tmp_path):
        # AES-256 requires a 32-byte key; shorter keys are rejected eagerly.
        input_path = str(tmp_path / "in")
        Path(input_path).write_bytes(b"data")
        with pytest.raises(ValueError, match="32 bytes"):
            _rc.encrypt_stream_chunked(
                input_path, str(tmp_path / "out"), b"short", secrets.token_bytes(12)
            )

    def test_invalid_nonce_length(self, tmp_path):
        # GCM base nonce must be exactly 12 bytes.
        input_path = str(tmp_path / "in")
        Path(input_path).write_bytes(b"data")
        with pytest.raises(ValueError, match="12 bytes"):
            _rc.encrypt_stream_chunked(
                input_path, str(tmp_path / "out"), secrets.token_bytes(32), b"short"
            )

    def test_wrong_key_fails_decrypt(self, tmp_path):
        # GCM authentication must fail under a different key; the error
        # surfaces as ValueError or OSError depending on where it is caught.
        data = b"sensitive data"
        key = secrets.token_bytes(32)
        wrong_key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        input_path = str(tmp_path / "plaintext")
        encrypted_path = str(tmp_path / "encrypted")
        decrypted_path = str(tmp_path / "decrypted")
        Path(input_path).write_bytes(data)
        _rc.encrypt_stream_chunked(input_path, encrypted_path, key, base_nonce)
        with pytest.raises((ValueError, OSError)):
            _rc.decrypt_stream_chunked(
                encrypted_path, decrypted_path, wrong_key, base_nonce
            )

    def test_cross_compat_python_encrypt_rust_decrypt(self, tmp_path):
        # Hand-build the container format in Python, then verify the Rust
        # decryptor reads it: proves both sides agree on the byte layout.
        from cryptography.hazmat.primitives.ciphers.aead import AESGCM
        data = b"cross compat test data" * 100
        key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        chunk_size = 1024
        encrypted_path = str(tmp_path / "py_encrypted")
        with open(encrypted_path, "wb") as f:
            # Placeholder chunk-count header; patched after the loop.
            f.write(b"\x00\x00\x00\x00")
            aesgcm = AESGCM(key)
            chunk_index = 0
            offset = 0
            while offset < len(data):
                chunk = data[offset:offset + chunk_size]
                nonce = self._python_derive_chunk_nonce(base_nonce, chunk_index)
                enc = aesgcm.encrypt(nonce, chunk, None)
                # 4-byte big-endian length prefix, then ciphertext+tag.
                f.write(len(enc).to_bytes(4, "big"))
                f.write(enc)
                chunk_index += 1
                offset += chunk_size
            # Rewind and write the real chunk count into the header.
            f.seek(0)
            f.write(chunk_index.to_bytes(4, "big"))
        decrypted_path = str(tmp_path / "rust_decrypted")
        _rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce)
        assert Path(decrypted_path).read_bytes() == data

    def test_cross_compat_rust_encrypt_python_decrypt(self, tmp_path):
        # Encrypt with Rust, then parse and decrypt the container in Python.
        from cryptography.hazmat.primitives.ciphers.aead import AESGCM
        data = b"cross compat reverse test" * 100
        key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        chunk_size = 1024
        input_path = str(tmp_path / "plaintext")
        encrypted_path = str(tmp_path / "rust_encrypted")
        Path(input_path).write_bytes(data)
        chunk_count = _rc.encrypt_stream_chunked(
            input_path, encrypted_path, key, base_nonce, chunk_size=chunk_size
        )
        aesgcm = AESGCM(key)
        with open(encrypted_path, "rb") as f:
            # Header must contain the same chunk count the API returned.
            count_bytes = f.read(4)
            assert int.from_bytes(count_bytes, "big") == chunk_count
            decrypted = b""
            for i in range(chunk_count):
                size = int.from_bytes(f.read(4), "big")
                enc_chunk = f.read(size)
                nonce = self._python_derive_chunk_nonce(base_nonce, i)
                decrypted += aesgcm.decrypt(nonce, enc_chunk, None)
        assert decrypted == data

    def test_large_file_roundtrip(self, tmp_path):
        # 1 MiB of random data: many chunks at the default chunk size.
        data = os.urandom(1024 * 1024)
        key = secrets.token_bytes(32)
        base_nonce = secrets.token_bytes(12)
        input_path = str(tmp_path / "large")
        encrypted_path = str(tmp_path / "encrypted")
        decrypted_path = str(tmp_path / "decrypted")
        Path(input_path).write_bytes(data)
        _rc.encrypt_stream_chunked(input_path, encrypted_path, key, base_nonce)
        _rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce)
        assert Path(decrypted_path).read_bytes() == data
class TestStreamingEncryptorFileMethods:
    """Tests for the file-path based encrypt_file/decrypt_file helpers on
    app.encryption.StreamingEncryptor, and their agreement with the
    stream-based encrypt_stream/decrypt_stream API."""

    def test_encrypt_file_decrypt_file_roundtrip(self, tmp_path):
        from app.encryption import LocalKeyEncryption, StreamingEncryptor
        # LocalKeyEncryption is given a master-key path under tmp_path;
        # presumably it creates the key file on first use — TODO confirm.
        master_key_path = tmp_path / "master.key"
        provider = LocalKeyEncryption(master_key_path)
        encryptor = StreamingEncryptor(provider, chunk_size=512)
        data = b"file method test data" * 200
        input_path = str(tmp_path / "input")
        encrypted_path = str(tmp_path / "encrypted")
        decrypted_path = str(tmp_path / "decrypted")
        Path(input_path).write_bytes(data)
        metadata = encryptor.encrypt_file(input_path, encrypted_path)
        # Metadata records the algorithm; decrypt_file needs it to reverse.
        assert metadata.algorithm == "AES256"
        encryptor.decrypt_file(encrypted_path, decrypted_path, metadata)
        assert Path(decrypted_path).read_bytes() == data

    def test_encrypt_file_matches_encrypt_stream(self, tmp_path):
        # Both the file-based and stream-based APIs must round-trip the
        # same plaintext; each is verified against its own metadata.
        from app.encryption import LocalKeyEncryption, StreamingEncryptor
        master_key_path = tmp_path / "master.key"
        provider = LocalKeyEncryption(master_key_path)
        encryptor = StreamingEncryptor(provider, chunk_size=512)
        data = b"stream vs file comparison" * 100
        input_path = str(tmp_path / "input")
        Path(input_path).write_bytes(data)
        # File-based round trip.
        file_encrypted_path = str(tmp_path / "file_enc")
        metadata_file = encryptor.encrypt_file(input_path, file_encrypted_path)
        file_decrypted_path = str(tmp_path / "file_dec")
        encryptor.decrypt_file(file_encrypted_path, file_decrypted_path, metadata_file)
        assert Path(file_decrypted_path).read_bytes() == data
        # Stream-based round trip on the same plaintext.
        stream_enc, metadata_stream = encryptor.encrypt_stream(io.BytesIO(data))
        stream_dec = encryptor.decrypt_stream(stream_enc, metadata_stream)
        assert stream_dec.read() == data