8 Commits

15 changed files with 984 additions and 162 deletions

View File

@@ -80,7 +80,7 @@ python run.py --mode api # API only (port 5000)
python run.py --mode ui # UI only (port 5100)
```
**Default Credentials:** `localadmin` / `localadmin`
**Credentials:** Generated automatically on first run and printed to the console. If missed, check the IAM config file at `<STORAGE_ROOT>/.myfsio.sys/config/iam.json`.
- **Web Console:** http://127.0.0.1:5100/ui
- **API Endpoint:** http://127.0.0.1:5000

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
import html as html_module
import itertools
import logging
import mimetypes
import os
import shutil
import sys
import time
import uuid
from logging.handlers import RotatingFileHandler
from pathlib import Path
from datetime import timedelta
@@ -39,6 +39,8 @@ from .storage import ObjectStorage, StorageError
from .version import get_version
from .website_domains import WebsiteDomainStore
_request_counter = itertools.count(1)
def _migrate_config_file(active_path: Path, legacy_paths: List[Path]) -> Path:
"""Migrate config file from legacy locations to the active path.
@@ -481,7 +483,7 @@ def _configure_logging(app: Flask) -> None:
@app.before_request
def _log_request_start() -> None:
g.request_id = uuid.uuid4().hex
g.request_id = f"{os.getpid():x}{next(_request_counter):012x}"
g.request_started_at = time.perf_counter()
g.request_bytes_in = request.content_length or 0
app.logger.info(
@@ -616,7 +618,7 @@ def _configure_logging(app: Flask) -> None:
duration_ms = 0.0
if hasattr(g, "request_started_at"):
duration_ms = (time.perf_counter() - g.request_started_at) * 1000
request_id = getattr(g, "request_id", uuid.uuid4().hex)
request_id = getattr(g, "request_id", f"{os.getpid():x}{next(_request_counter):012x}")
response.headers.setdefault("X-Request-ID", request_id)
app.logger.info(
"Request completed",

View File

@@ -19,6 +19,13 @@ from cryptography.hazmat.primitives import hashes
if sys.platform != "win32":
import fcntl
try:
import myfsio_core as _rc
_HAS_RUST = True
except ImportError:
_rc = None
_HAS_RUST = False
logger = logging.getLogger(__name__)
@@ -338,6 +345,69 @@ class StreamingEncryptor:
output.seek(0)
return output
def encrypt_file(self, input_path: str, output_path: str) -> EncryptionMetadata:
data_key, encrypted_data_key = self.provider.generate_data_key()
base_nonce = secrets.token_bytes(12)
if _HAS_RUST:
_rc.encrypt_stream_chunked(
input_path, output_path, data_key, base_nonce, self.chunk_size
)
else:
with open(input_path, "rb") as stream:
aesgcm = AESGCM(data_key)
with open(output_path, "wb") as out:
out.write(b"\x00\x00\x00\x00")
chunk_index = 0
while True:
chunk = stream.read(self.chunk_size)
if not chunk:
break
chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
encrypted_chunk = aesgcm.encrypt(chunk_nonce, chunk, None)
out.write(len(encrypted_chunk).to_bytes(self.HEADER_SIZE, "big"))
out.write(encrypted_chunk)
chunk_index += 1
out.seek(0)
out.write(chunk_index.to_bytes(4, "big"))
return EncryptionMetadata(
algorithm="AES256",
key_id=self.provider.KEY_ID if hasattr(self.provider, "KEY_ID") else "local",
nonce=base_nonce,
encrypted_data_key=encrypted_data_key,
)
def decrypt_file(self, input_path: str, output_path: str,
metadata: EncryptionMetadata) -> None:
data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id)
base_nonce = metadata.nonce
if _HAS_RUST:
_rc.decrypt_stream_chunked(input_path, output_path, data_key, base_nonce)
else:
with open(input_path, "rb") as stream:
chunk_count_bytes = stream.read(4)
if len(chunk_count_bytes) < 4:
raise EncryptionError("Invalid encrypted stream: missing header")
chunk_count = int.from_bytes(chunk_count_bytes, "big")
aesgcm = AESGCM(data_key)
with open(output_path, "wb") as out:
for chunk_index in range(chunk_count):
size_bytes = stream.read(self.HEADER_SIZE)
if len(size_bytes) < self.HEADER_SIZE:
raise EncryptionError(f"Invalid encrypted stream: truncated at chunk {chunk_index}")
chunk_size = int.from_bytes(size_bytes, "big")
encrypted_chunk = stream.read(chunk_size)
if len(encrypted_chunk) < chunk_size:
raise EncryptionError(f"Invalid encrypted stream: incomplete chunk {chunk_index}")
chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
try:
decrypted_chunk = aesgcm.decrypt(chunk_nonce, encrypted_chunk, None)
out.write(decrypted_chunk)
except Exception as exc:
raise EncryptionError(f"Failed to decrypt chunk {chunk_index}: {exc}") from exc
class EncryptionManager:
"""Manages encryption providers and operations."""

View File

@@ -5,6 +5,7 @@ import logging
import random
import threading
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
@@ -138,8 +139,8 @@ class OperationMetricsCollector:
self.interval_seconds = interval_minutes * 60
self.retention_hours = retention_hours
self._lock = threading.Lock()
self._by_method: Dict[str, OperationStats] = {}
self._by_endpoint: Dict[str, OperationStats] = {}
self._by_method: Dict[str, OperationStats] = defaultdict(OperationStats)
self._by_endpoint: Dict[str, OperationStats] = defaultdict(OperationStats)
self._by_status_class: Dict[str, int] = {}
self._error_codes: Dict[str, int] = {}
self._totals = OperationStats()
@@ -211,8 +212,8 @@ class OperationMetricsCollector:
self._prune_old_snapshots()
self._save_history()
self._by_method.clear()
self._by_endpoint.clear()
self._by_method = defaultdict(OperationStats)
self._by_endpoint = defaultdict(OperationStats)
self._by_status_class.clear()
self._error_codes.clear()
self._totals = OperationStats()
@@ -232,12 +233,7 @@ class OperationMetricsCollector:
status_class = f"{status_code // 100}xx"
with self._lock:
if method not in self._by_method:
self._by_method[method] = OperationStats()
self._by_method[method].record(latency_ms, success, bytes_in, bytes_out)
if endpoint_type not in self._by_endpoint:
self._by_endpoint[endpoint_type] = OperationStats()
self._by_endpoint[endpoint_type].record(latency_ms, success, bytes_in, bytes_out)
self._by_status_class[status_class] = self._by_status_class.get(status_class, 0) + 1

View File

@@ -293,9 +293,7 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
raise IamError("Required headers not signed")
canonical_uri = _get_canonical_uri(req)
payload_hash = req.headers.get("X-Amz-Content-Sha256")
if not payload_hash:
payload_hash = hashlib.sha256(req.get_data()).hexdigest()
payload_hash = req.headers.get("X-Amz-Content-Sha256") or "UNSIGNED-PAYLOAD"
if _HAS_RUST:
query_params = list(req.args.items(multi=True))

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import copy
import hashlib
import json
import os
@@ -196,7 +195,9 @@ class ObjectStorage:
self.root.mkdir(parents=True, exist_ok=True)
self._ensure_system_roots()
self._object_cache: OrderedDict[str, tuple[Dict[str, ObjectMeta], float, float]] = OrderedDict()
self._cache_lock = threading.Lock()
self._obj_cache_lock = threading.Lock()
self._meta_cache_lock = threading.Lock()
self._registry_lock = threading.Lock()
self._bucket_locks: Dict[str, threading.Lock] = {}
self._cache_version: Dict[str, int] = {}
self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {}
@@ -209,10 +210,17 @@ class ObjectStorage:
self._meta_read_cache: OrderedDict[tuple, Optional[Dict[str, Any]]] = OrderedDict()
self._meta_read_cache_max = 2048
self._cleanup_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ParentCleanup")
self._stats_mem: Dict[str, Dict[str, int]] = {}
self._stats_serial: Dict[str, int] = {}
self._stats_mem_time: Dict[str, float] = {}
self._stats_lock = threading.Lock()
self._stats_dirty: set[str] = set()
self._stats_flush_timer: Optional[threading.Timer] = None
self._etag_index_dirty: set[str] = set()
self._etag_index_flush_timer: Optional[threading.Timer] = None
def _get_bucket_lock(self, bucket_id: str) -> threading.Lock:
"""Get or create a lock for a specific bucket. Reduces global lock contention."""
with self._cache_lock:
with self._registry_lock:
if bucket_id not in self._bucket_locks:
self._bucket_locks[bucket_id] = threading.Lock()
return self._bucket_locks[bucket_id]
@@ -260,26 +268,24 @@ class ObjectStorage:
self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True)
def bucket_stats(self, bucket_name: str, cache_ttl: int = 60) -> dict[str, int]:
"""Return object count and total size for the bucket (cached).
Args:
bucket_name: Name of the bucket
cache_ttl: Cache time-to-live in seconds (default 60)
"""
bucket_path = self._bucket_path(bucket_name)
if not bucket_path.exists():
raise BucketNotFoundError("Bucket does not exist")
with self._stats_lock:
if bucket_name in self._stats_mem:
cached_at = self._stats_mem_time.get(bucket_name, 0.0)
if (time.monotonic() - cached_at) < cache_ttl:
return dict(self._stats_mem[bucket_name])
self._stats_mem.pop(bucket_name, None)
self._stats_mem_time.pop(bucket_name, None)
cache_path = self._system_bucket_root(bucket_name) / "stats.json"
cached_stats = None
cache_fresh = False
if cache_path.exists():
try:
cache_fresh = time.time() - cache_path.stat().st_mtime < cache_ttl
cached_stats = json.loads(cache_path.read_text(encoding="utf-8"))
if cache_fresh:
return cached_stats
except (OSError, json.JSONDecodeError):
pass
@@ -348,6 +354,11 @@ class ObjectStorage:
"_cache_serial": existing_serial,
}
with self._stats_lock:
self._stats_mem[bucket_name] = stats
self._stats_mem_time[bucket_name] = time.monotonic()
self._stats_serial[bucket_name] = existing_serial
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_text(json.dumps(stats), encoding="utf-8")
@@ -357,7 +368,11 @@ class ObjectStorage:
return stats
def _invalidate_bucket_stats_cache(self, bucket_id: str) -> None:
"""Invalidate the cached bucket statistics."""
with self._stats_lock:
self._stats_mem.pop(bucket_id, None)
self._stats_mem_time.pop(bucket_id, None)
self._stats_serial[bucket_id] = self._stats_serial.get(bucket_id, 0) + 1
self._stats_dirty.discard(bucket_id)
cache_path = self._system_bucket_root(bucket_id) / "stats.json"
try:
cache_path.unlink(missing_ok=True)
@@ -373,29 +388,52 @@ class ObjectStorage:
version_bytes_delta: int = 0,
version_count_delta: int = 0,
) -> None:
"""Incrementally update cached bucket statistics instead of invalidating.
with self._stats_lock:
if bucket_id not in self._stats_mem:
self._stats_mem[bucket_id] = {
"objects": 0, "bytes": 0, "version_count": 0,
"version_bytes": 0, "total_objects": 0, "total_bytes": 0,
"_cache_serial": 0,
}
data = self._stats_mem[bucket_id]
data["objects"] = max(0, data["objects"] + objects_delta)
data["bytes"] = max(0, data["bytes"] + bytes_delta)
data["version_count"] = max(0, data["version_count"] + version_count_delta)
data["version_bytes"] = max(0, data["version_bytes"] + version_bytes_delta)
data["total_objects"] = max(0, data["total_objects"] + objects_delta + version_count_delta)
data["total_bytes"] = max(0, data["total_bytes"] + bytes_delta + version_bytes_delta)
data["_cache_serial"] = data["_cache_serial"] + 1
self._stats_serial[bucket_id] = self._stats_serial.get(bucket_id, 0) + 1
self._stats_mem_time[bucket_id] = time.monotonic()
self._stats_dirty.add(bucket_id)
self._schedule_stats_flush()
This avoids expensive full directory scans on every PUT/DELETE by
adjusting the cached values directly. Also signals cross-process cache
invalidation by incrementing _cache_serial.
"""
cache_path = self._system_bucket_root(bucket_id) / "stats.json"
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
if cache_path.exists():
data = json.loads(cache_path.read_text(encoding="utf-8"))
else:
data = {"objects": 0, "bytes": 0, "version_count": 0, "version_bytes": 0, "total_objects": 0, "total_bytes": 0, "_cache_serial": 0}
data["objects"] = max(0, data.get("objects", 0) + objects_delta)
data["bytes"] = max(0, data.get("bytes", 0) + bytes_delta)
data["version_count"] = max(0, data.get("version_count", 0) + version_count_delta)
data["version_bytes"] = max(0, data.get("version_bytes", 0) + version_bytes_delta)
data["total_objects"] = max(0, data.get("total_objects", 0) + objects_delta + version_count_delta)
data["total_bytes"] = max(0, data.get("total_bytes", 0) + bytes_delta + version_bytes_delta)
data["_cache_serial"] = data.get("_cache_serial", 0) + 1
cache_path.write_text(json.dumps(data), encoding="utf-8")
except (OSError, json.JSONDecodeError):
pass
def _schedule_stats_flush(self) -> None:
if self._stats_flush_timer is None or not self._stats_flush_timer.is_alive():
self._stats_flush_timer = threading.Timer(3.0, self._flush_stats)
self._stats_flush_timer.daemon = True
self._stats_flush_timer.start()
def _flush_stats(self) -> None:
with self._stats_lock:
dirty = list(self._stats_dirty)
self._stats_dirty.clear()
snapshots = {b: dict(self._stats_mem[b]) for b in dirty if b in self._stats_mem}
for bucket_id, data in snapshots.items():
cache_path = self._system_bucket_root(bucket_id) / "stats.json"
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_text(json.dumps(data), encoding="utf-8")
except OSError:
pass
def shutdown_stats(self) -> None:
if self._stats_flush_timer is not None:
self._stats_flush_timer.cancel()
self._flush_stats()
if self._etag_index_flush_timer is not None:
self._etag_index_flush_timer.cancel()
self._flush_etag_indexes()
def delete_bucket(self, bucket_name: str) -> None:
bucket_path = self._bucket_path(bucket_name)
@@ -413,13 +451,20 @@ class ObjectStorage:
self._remove_tree(self._system_bucket_root(bucket_id))
self._remove_tree(self._multipart_bucket_root(bucket_id))
self._bucket_config_cache.pop(bucket_id, None)
with self._cache_lock:
with self._obj_cache_lock:
self._object_cache.pop(bucket_id, None)
self._cache_version.pop(bucket_id, None)
self._sorted_key_cache.pop(bucket_id, None)
with self._meta_cache_lock:
stale = [k for k in self._meta_read_cache if k[0] == bucket_id]
for k in stale:
del self._meta_read_cache[k]
with self._stats_lock:
self._stats_mem.pop(bucket_id, None)
self._stats_mem_time.pop(bucket_id, None)
self._stats_serial.pop(bucket_id, None)
self._stats_dirty.discard(bucket_id)
self._etag_index_dirty.discard(bucket_id)
def list_objects(
self,
@@ -841,40 +886,78 @@ class ObjectStorage:
tmp_dir = self._system_root_path() / self.SYSTEM_TMP_DIR
tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp"
try:
checksum = hashlib.md5()
with tmp_path.open("wb") as target:
shutil.copyfileobj(_HashingReader(stream, checksum), target)
new_size = tmp_path.stat().st_size
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
shutil.move(str(tmp_path), str(destination))
finally:
if _HAS_RUST:
tmp_path = None
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
tmp_path_str, etag, new_size = _rc.stream_to_file_with_md5(
stream, str(tmp_dir)
)
tmp_path = Path(tmp_path_str)
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
shutil.move(str(tmp_path), str(destination))
finally:
if tmp_path:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
else:
tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp"
try:
with tmp_path.open("wb") as target:
shutil.copyfileobj(stream, target)
new_size = tmp_path.stat().st_size
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
checksum = hashlib.md5()
with tmp_path.open("rb") as f:
while True:
chunk = f.read(1048576)
if not chunk:
break
checksum.update(chunk)
etag = checksum.hexdigest()
shutil.move(str(tmp_path), str(destination))
finally:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
stat = destination.stat()
etag = checksum.hexdigest()
internal_meta = {"__etag__": etag, "__size__": str(stat.st_size)}
combined_meta = {**internal_meta, **(metadata or {})}
@@ -1465,14 +1548,24 @@ class ObjectStorage:
if not upload_root.exists():
raise StorageError("Multipart upload not found")
checksum = hashlib.md5()
part_filename = f"part-{part_number:05d}.part"
part_path = upload_root / part_filename
temp_path = upload_root / f".{part_filename}.tmp"
try:
with temp_path.open("wb") as target:
shutil.copyfileobj(_HashingReader(stream, checksum), target)
shutil.copyfileobj(stream, target)
if _HAS_RUST:
part_etag = _rc.md5_file(str(temp_path))
else:
checksum = hashlib.md5()
with temp_path.open("rb") as f:
while True:
chunk = f.read(1048576)
if not chunk:
break
checksum.update(chunk)
part_etag = checksum.hexdigest()
temp_path.replace(part_path)
except OSError:
try:
@@ -1482,7 +1575,7 @@ class ObjectStorage:
raise
record = {
"etag": checksum.hexdigest(),
"etag": part_etag,
"size": part_path.stat().st_size,
"filename": part_filename,
}
@@ -1682,19 +1775,29 @@ class ObjectStorage:
if versioning_enabled and destination.exists():
archived_version_size = destination.stat().st_size
self._archive_current_version(bucket_id, safe_key, reason="overwrite")
checksum = hashlib.md5()
with destination.open("wb") as target:
if _HAS_RUST:
part_paths = []
for _, record in validated:
part_path = upload_root / record["filename"]
if not part_path.exists():
pp = upload_root / record["filename"]
if not pp.exists():
raise StorageError(f"Missing part file {record['filename']}")
with part_path.open("rb") as chunk:
while True:
data = chunk.read(1024 * 1024)
if not data:
break
checksum.update(data)
target.write(data)
part_paths.append(str(pp))
checksum_hex = _rc.assemble_parts_with_md5(part_paths, str(destination))
else:
checksum = hashlib.md5()
with destination.open("wb") as target:
for _, record in validated:
part_path = upload_root / record["filename"]
if not part_path.exists():
raise StorageError(f"Missing part file {record['filename']}")
with part_path.open("rb") as chunk:
while True:
data = chunk.read(1024 * 1024)
if not data:
break
checksum.update(data)
target.write(data)
checksum_hex = checksum.hexdigest()
except BlockingIOError:
raise StorageError("Another upload to this key is in progress")
@@ -1709,7 +1812,7 @@ class ObjectStorage:
)
stat = destination.stat()
etag = checksum.hexdigest()
etag = checksum_hex
metadata = manifest.get("metadata")
internal_meta = {"__etag__": etag, "__size__": str(stat.st_size)}
@@ -2073,19 +2176,19 @@ class ObjectStorage:
now = time.time()
current_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock:
with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id)
if cached:
objects, timestamp, cached_stats_mtime = cached
if now - timestamp < self._cache_ttl and current_stats_mtime == cached_stats_mtime:
self._object_cache.move_to_end(bucket_id)
return objects
cache_version = self._cache_version.get(bucket_id, 0)
bucket_lock = self._get_bucket_lock(bucket_id)
with bucket_lock:
now = time.time()
current_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock:
with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id)
if cached:
objects, timestamp, cached_stats_mtime = cached
@@ -2096,31 +2199,23 @@ class ObjectStorage:
objects = self._build_object_cache(bucket_path)
new_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock:
current_version = self._cache_version.get(bucket_id, 0)
if current_version != cache_version:
objects = self._build_object_cache(bucket_path)
new_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._obj_cache_lock:
while len(self._object_cache) >= self._object_cache_max_size:
self._object_cache.popitem(last=False)
self._object_cache[bucket_id] = (objects, time.time(), new_stats_mtime)
self._object_cache.move_to_end(bucket_id)
self._cache_version[bucket_id] = current_version + 1
self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
self._sorted_key_cache.pop(bucket_id, None)
return objects
def _invalidate_object_cache(self, bucket_id: str) -> None:
"""Invalidate the object cache and etag index for a bucket.
Increments version counter to signal stale reads.
Cross-process invalidation is handled by checking stats.json mtime.
"""
with self._cache_lock:
with self._obj_cache_lock:
self._object_cache.pop(bucket_id, None)
self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
self._etag_index_dirty.discard(bucket_id)
etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json"
try:
etag_index_path.unlink(missing_ok=True)
@@ -2128,22 +2223,10 @@ class ObjectStorage:
pass
def _get_cache_marker_mtime(self, bucket_id: str) -> float:
"""Get a cache marker combining serial and object count for cross-process invalidation.
Returns a combined value that changes if either _cache_serial or object count changes.
This handles cases where the serial was reset but object count differs.
"""
stats_path = self._system_bucket_root(bucket_id) / "stats.json"
try:
data = json.loads(stats_path.read_text(encoding="utf-8"))
serial = data.get("_cache_serial", 0)
count = data.get("objects", 0)
return float(serial * 1000000 + count)
except (OSError, json.JSONDecodeError):
return 0
return float(self._stats_serial.get(bucket_id, 0))
def _update_object_cache_entry(self, bucket_id: str, key: str, meta: Optional[ObjectMeta]) -> None:
with self._cache_lock:
with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id)
if cached:
objects, timestamp, stats_mtime = cached
@@ -2154,23 +2237,32 @@ class ObjectStorage:
self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
self._sorted_key_cache.pop(bucket_id, None)
self._update_etag_index(bucket_id, key, meta.etag if meta else None)
self._etag_index_dirty.add(bucket_id)
self._schedule_etag_index_flush()
def _update_etag_index(self, bucket_id: str, key: str, etag: Optional[str]) -> None:
etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json"
if not etag_index_path.exists():
return
try:
with open(etag_index_path, 'r', encoding='utf-8') as f:
index = json.load(f)
if etag is None:
index.pop(key, None)
else:
index[key] = etag
with open(etag_index_path, 'w', encoding='utf-8') as f:
json.dump(index, f)
except (OSError, json.JSONDecodeError):
pass
def _schedule_etag_index_flush(self) -> None:
if self._etag_index_flush_timer is None or not self._etag_index_flush_timer.is_alive():
self._etag_index_flush_timer = threading.Timer(5.0, self._flush_etag_indexes)
self._etag_index_flush_timer.daemon = True
self._etag_index_flush_timer.start()
def _flush_etag_indexes(self) -> None:
dirty = set(self._etag_index_dirty)
self._etag_index_dirty.clear()
for bucket_id in dirty:
with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id)
if not cached:
continue
objects = cached[0]
index = {k: v.etag for k, v in objects.items() if v.etag}
etag_index_path = self._system_bucket_root(bucket_id) / "etag_index.json"
try:
etag_index_path.parent.mkdir(parents=True, exist_ok=True)
with open(etag_index_path, 'w', encoding='utf-8') as f:
json.dump(index, f)
except OSError:
pass
def warm_cache(self, bucket_names: Optional[List[str]] = None) -> None:
"""Pre-warm the object cache for specified buckets or all buckets.
@@ -2227,12 +2319,7 @@ class ObjectStorage:
if cached:
config, cached_time, cached_mtime = cached
if now - cached_time < self._bucket_config_cache_ttl:
try:
current_mtime = config_path.stat().st_mtime if config_path.exists() else 0.0
except OSError:
current_mtime = 0.0
if current_mtime == cached_mtime:
return config.copy()
return config.copy()
if not config_path.exists():
self._bucket_config_cache[bucket_name] = ({}, now, 0.0)
@@ -2301,19 +2388,19 @@ class ObjectStorage:
return meta_root / parent / "_index.json", entry_name
def _get_meta_index_lock(self, index_path: str) -> threading.Lock:
with self._cache_lock:
with self._registry_lock:
if index_path not in self._meta_index_locks:
self._meta_index_locks[index_path] = threading.Lock()
return self._meta_index_locks[index_path]
def _read_index_entry(self, bucket_name: str, key: Path) -> Optional[Dict[str, Any]]:
cache_key = (bucket_name, str(key))
with self._cache_lock:
with self._meta_cache_lock:
hit = self._meta_read_cache.get(cache_key)
if hit is not None:
self._meta_read_cache.move_to_end(cache_key)
cached = hit[0]
return copy.deepcopy(cached) if cached is not None else None
return dict(cached) if cached is not None else None
index_path, entry_name = self._index_file_for_key(bucket_name, key)
if _HAS_RUST:
@@ -2328,16 +2415,16 @@ class ObjectStorage:
except (OSError, json.JSONDecodeError):
result = None
with self._cache_lock:
with self._meta_cache_lock:
while len(self._meta_read_cache) >= self._meta_read_cache_max:
self._meta_read_cache.popitem(last=False)
self._meta_read_cache[cache_key] = (copy.deepcopy(result) if result is not None else None,)
self._meta_read_cache[cache_key] = (dict(result) if result is not None else None,)
return result
def _invalidate_meta_read_cache(self, bucket_name: str, key: Path) -> None:
cache_key = (bucket_name, str(key))
with self._cache_lock:
with self._meta_cache_lock:
self._meta_read_cache.pop(cache_key, None)
def _write_index_entry(self, bucket_name: str, key: Path, entry: Dict[str, Any]) -> None:

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
APP_VERSION = "0.3.3"
APP_VERSION = "0.3.5"
def get_version() -> str:

View File

@@ -19,3 +19,6 @@ regex = "1"
lru = "0.14"
parking_lot = "0.12"
percent-encoding = "2"
aes-gcm = "0.10"
hkdf = "0.12"
uuid = { version = "1", features = ["v4"] }

192
myfsio_core/src/crypto.rs Normal file
View File

@@ -0,0 +1,192 @@
use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use hkdf::Hkdf;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use sha2::Sha256;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
const DEFAULT_CHUNK_SIZE: usize = 65536;
const HEADER_SIZE: usize = 4;
fn read_exact_chunk(reader: &mut impl Read, buf: &mut [u8]) -> std::io::Result<usize> {
let mut filled = 0;
while filled < buf.len() {
match reader.read(&mut buf[filled..]) {
Ok(0) => break,
Ok(n) => filled += n,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e),
}
}
Ok(filled)
}
fn derive_chunk_nonce(base_nonce: &[u8], chunk_index: u32) -> Result<[u8; 12], String> {
let hkdf = Hkdf::<Sha256>::new(Some(base_nonce), b"chunk_nonce");
let mut okm = [0u8; 12];
hkdf.expand(&chunk_index.to_be_bytes(), &mut okm)
.map_err(|e| format!("HKDF expand failed: {}", e))?;
Ok(okm)
}
#[pyfunction]
#[pyo3(signature = (input_path, output_path, key, base_nonce, chunk_size=DEFAULT_CHUNK_SIZE))]
pub fn encrypt_stream_chunked(
py: Python<'_>,
input_path: &str,
output_path: &str,
key: &[u8],
base_nonce: &[u8],
chunk_size: usize,
) -> PyResult<u32> {
if key.len() != 32 {
return Err(PyValueError::new_err(format!(
"Key must be 32 bytes, got {}",
key.len()
)));
}
if base_nonce.len() != 12 {
return Err(PyValueError::new_err(format!(
"Base nonce must be 12 bytes, got {}",
base_nonce.len()
)));
}
let chunk_size = if chunk_size == 0 {
DEFAULT_CHUNK_SIZE
} else {
chunk_size
};
let inp = input_path.to_owned();
let out = output_path.to_owned();
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
py.detach(move || {
let cipher = Aes256Gcm::new(&key_arr.into());
let mut infile = File::open(&inp)
.map_err(|e| PyIOError::new_err(format!("Failed to open input: {}", e)))?;
let mut outfile = File::create(&out)
.map_err(|e| PyIOError::new_err(format!("Failed to create output: {}", e)))?;
outfile
.write_all(&[0u8; 4])
.map_err(|e| PyIOError::new_err(format!("Failed to write header: {}", e)))?;
let mut buf = vec![0u8; chunk_size];
let mut chunk_index: u32 = 0;
loop {
let n = read_exact_chunk(&mut infile, &mut buf)
.map_err(|e| PyIOError::new_err(format!("Failed to read: {}", e)))?;
if n == 0 {
break;
}
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)
.map_err(|e| PyValueError::new_err(e))?;
let nonce = Nonce::from_slice(&nonce_bytes);
let encrypted = cipher
.encrypt(nonce, &buf[..n])
.map_err(|e| PyValueError::new_err(format!("Encrypt failed: {}", e)))?;
let size = encrypted.len() as u32;
outfile
.write_all(&size.to_be_bytes())
.map_err(|e| PyIOError::new_err(format!("Failed to write chunk size: {}", e)))?;
outfile
.write_all(&encrypted)
.map_err(|e| PyIOError::new_err(format!("Failed to write chunk: {}", e)))?;
chunk_index += 1;
}
outfile
.seek(SeekFrom::Start(0))
.map_err(|e| PyIOError::new_err(format!("Failed to seek: {}", e)))?;
outfile
.write_all(&chunk_index.to_be_bytes())
.map_err(|e| PyIOError::new_err(format!("Failed to write chunk count: {}", e)))?;
Ok(chunk_index)
})
}
#[pyfunction]
pub fn decrypt_stream_chunked(
py: Python<'_>,
input_path: &str,
output_path: &str,
key: &[u8],
base_nonce: &[u8],
) -> PyResult<u32> {
if key.len() != 32 {
return Err(PyValueError::new_err(format!(
"Key must be 32 bytes, got {}",
key.len()
)));
}
if base_nonce.len() != 12 {
return Err(PyValueError::new_err(format!(
"Base nonce must be 12 bytes, got {}",
base_nonce.len()
)));
}
let inp = input_path.to_owned();
let out = output_path.to_owned();
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
py.detach(move || {
let cipher = Aes256Gcm::new(&key_arr.into());
let mut infile = File::open(&inp)
.map_err(|e| PyIOError::new_err(format!("Failed to open input: {}", e)))?;
let mut outfile = File::create(&out)
.map_err(|e| PyIOError::new_err(format!("Failed to create output: {}", e)))?;
let mut header = [0u8; HEADER_SIZE];
infile
.read_exact(&mut header)
.map_err(|e| PyIOError::new_err(format!("Failed to read header: {}", e)))?;
let chunk_count = u32::from_be_bytes(header);
let mut size_buf = [0u8; HEADER_SIZE];
for chunk_index in 0..chunk_count {
infile
.read_exact(&mut size_buf)
.map_err(|e| {
PyIOError::new_err(format!(
"Failed to read chunk {} size: {}",
chunk_index, e
))
})?;
let chunk_size = u32::from_be_bytes(size_buf) as usize;
let mut encrypted = vec![0u8; chunk_size];
infile.read_exact(&mut encrypted).map_err(|e| {
PyIOError::new_err(format!("Failed to read chunk {}: {}", chunk_index, e))
})?;
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)
.map_err(|e| PyValueError::new_err(e))?;
let nonce = Nonce::from_slice(&nonce_bytes);
let decrypted = cipher.decrypt(nonce, encrypted.as_ref()).map_err(|e| {
PyValueError::new_err(format!("Decrypt chunk {} failed: {}", chunk_index, e))
})?;
outfile.write_all(&decrypted).map_err(|e| {
PyIOError::new_err(format!("Failed to write chunk {}: {}", chunk_index, e))
})?;
}
Ok(chunk_count)
})
}

View File

@@ -1,7 +1,9 @@
mod crypto;
mod hashing;
mod metadata;
mod sigv4;
mod storage;
mod streaming;
mod validation;
use pyo3::prelude::*;
@@ -38,6 +40,12 @@ mod myfsio_core {
m.add_function(wrap_pyfunction!(storage::search_objects_scan, m)?)?;
m.add_function(wrap_pyfunction!(storage::build_object_cache, m)?)?;
m.add_function(wrap_pyfunction!(streaming::stream_to_file_with_md5, m)?)?;
m.add_function(wrap_pyfunction!(streaming::assemble_parts_with_md5, m)?)?;
m.add_function(wrap_pyfunction!(crypto::encrypt_stream_chunked, m)?)?;
m.add_function(wrap_pyfunction!(crypto::decrypt_stream_chunked, m)?)?;
Ok(())
}
}

View File

@@ -0,0 +1,107 @@
use md5::{Digest, Md5};
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use std::fs::{self, File};
use std::io::{Read, Write};
use uuid::Uuid;
const DEFAULT_CHUNK_SIZE: usize = 262144;
#[pyfunction]
#[pyo3(signature = (stream, tmp_dir, chunk_size=DEFAULT_CHUNK_SIZE))]
pub fn stream_to_file_with_md5(
py: Python<'_>,
stream: &Bound<'_, PyAny>,
tmp_dir: &str,
chunk_size: usize,
) -> PyResult<(String, String, u64)> {
let chunk_size = if chunk_size == 0 {
DEFAULT_CHUNK_SIZE
} else {
chunk_size
};
fs::create_dir_all(tmp_dir)
.map_err(|e| PyIOError::new_err(format!("Failed to create tmp dir: {}", e)))?;
let tmp_name = format!("{}.tmp", Uuid::new_v4().as_hyphenated());
let tmp_path_buf = std::path::PathBuf::from(tmp_dir).join(&tmp_name);
let tmp_path = tmp_path_buf.to_string_lossy().into_owned();
let mut file = File::create(&tmp_path)
.map_err(|e| PyIOError::new_err(format!("Failed to create temp file: {}", e)))?;
let mut hasher = Md5::new();
let mut total_bytes: u64 = 0;
let result: PyResult<()> = (|| {
loop {
let chunk: Vec<u8> = stream.call_method1("read", (chunk_size,))?.extract()?;
if chunk.is_empty() {
break;
}
hasher.update(&chunk);
file.write_all(&chunk)
.map_err(|e| PyIOError::new_err(format!("Failed to write: {}", e)))?;
total_bytes += chunk.len() as u64;
py.check_signals()?;
}
Ok(())
})();
if let Err(e) = result {
drop(file);
let _ = fs::remove_file(&tmp_path);
return Err(e);
}
drop(file);
let md5_hex = format!("{:x}", hasher.finalize());
Ok((tmp_path, md5_hex, total_bytes))
}
#[pyfunction]
pub fn assemble_parts_with_md5(
py: Python<'_>,
part_paths: Vec<String>,
dest_path: &str,
) -> PyResult<String> {
if part_paths.is_empty() {
return Err(PyValueError::new_err("No parts to assemble"));
}
let dest = dest_path.to_owned();
let parts = part_paths;
py.detach(move || {
if let Some(parent) = std::path::Path::new(&dest).parent() {
fs::create_dir_all(parent)
.map_err(|e| PyIOError::new_err(format!("Failed to create dest dir: {}", e)))?;
}
let mut target = File::create(&dest)
.map_err(|e| PyIOError::new_err(format!("Failed to create dest file: {}", e)))?;
let mut hasher = Md5::new();
let mut buf = vec![0u8; 1024 * 1024];
for part_path in &parts {
let mut part = File::open(part_path)
.map_err(|e| PyIOError::new_err(format!("Failed to open part {}: {}", part_path, e)))?;
loop {
let n = part
.read(&mut buf)
.map_err(|e| PyIOError::new_err(format!("Failed to read part: {}", e)))?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
target
.write_all(&buf[..n])
.map_err(|e| PyIOError::new_err(format!("Failed to write: {}", e)))?;
}
}
Ok(format!("{:x}", hasher.finalize()))
})
}

View File

@@ -321,7 +321,7 @@
`;
};
const bucketTotalObjects = objectsContainer ? parseInt(objectsContainer.dataset.bucketTotalObjects || '0', 10) : 0;
let bucketTotalObjects = objectsContainer ? parseInt(objectsContainer.dataset.bucketTotalObjects || '0', 10) : 0;
const updateObjectCountBadge = () => {
if (!objectCountBadge) return;
@@ -702,6 +702,7 @@
flushPendingStreamObjects();
hasMoreObjects = false;
totalObjectCount = loadedObjectCount;
if (!currentPrefix) bucketTotalObjects = totalObjectCount;
updateObjectCountBadge();
if (objectsLoadingRow && objectsLoadingRow.parentNode) {
@@ -766,6 +767,7 @@
}
totalObjectCount = data.total_count || 0;
if (!append && !currentPrefix) bucketTotalObjects = totalObjectCount;
nextContinuationToken = data.next_continuation_token;
if (!append && objectsLoadingRow) {

View File

@@ -43,6 +43,11 @@ def app(tmp_path: Path):
}
)
yield flask_app
storage = flask_app.extensions.get("object_storage")
if storage:
base = getattr(storage, "storage", storage)
if hasattr(base, "shutdown_stats"):
base.shutdown_stats()
@pytest.fixture()

View File

@@ -53,7 +53,9 @@ def test_special_characters_in_metadata(tmp_path: Path):
assert meta["special"] == "!@#$%^&*()"
def test_disk_full_scenario(tmp_path: Path, monkeypatch):
# Simulate disk full by mocking write to fail
import app.storage as _storage_mod
monkeypatch.setattr(_storage_mod, "_HAS_RUST", False)
storage = ObjectStorage(tmp_path)
storage.create_bucket("full")

View File

@@ -0,0 +1,350 @@
import hashlib
import io
import os
import secrets
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
try:
import myfsio_core as _rc
HAS_RUST = True
except ImportError:
_rc = None
HAS_RUST = False
pytestmark = pytest.mark.skipif(not HAS_RUST, reason="myfsio_core not available")
class TestStreamToFileWithMd5:
def test_basic_write(self, tmp_path):
data = b"hello world" * 1000
stream = io.BytesIO(data)
tmp_dir = str(tmp_path / "tmp")
tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(stream, tmp_dir)
assert size == len(data)
assert md5_hex == hashlib.md5(data).hexdigest()
assert Path(tmp_path_str).exists()
assert Path(tmp_path_str).read_bytes() == data
def test_empty_stream(self, tmp_path):
stream = io.BytesIO(b"")
tmp_dir = str(tmp_path / "tmp")
tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(stream, tmp_dir)
assert size == 0
assert md5_hex == hashlib.md5(b"").hexdigest()
assert Path(tmp_path_str).read_bytes() == b""
def test_large_data(self, tmp_path):
data = os.urandom(1024 * 1024 * 2)
stream = io.BytesIO(data)
tmp_dir = str(tmp_path / "tmp")
tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(stream, tmp_dir)
assert size == len(data)
assert md5_hex == hashlib.md5(data).hexdigest()
def test_custom_chunk_size(self, tmp_path):
data = b"x" * 10000
stream = io.BytesIO(data)
tmp_dir = str(tmp_path / "tmp")
tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(
stream, tmp_dir, chunk_size=128
)
assert size == len(data)
assert md5_hex == hashlib.md5(data).hexdigest()
class TestAssemblePartsWithMd5:
def test_basic_assembly(self, tmp_path):
parts = []
combined = b""
for i in range(3):
data = f"part{i}data".encode() * 100
combined += data
p = tmp_path / f"part{i}"
p.write_bytes(data)
parts.append(str(p))
dest = str(tmp_path / "output")
md5_hex = _rc.assemble_parts_with_md5(parts, dest)
assert md5_hex == hashlib.md5(combined).hexdigest()
assert Path(dest).read_bytes() == combined
def test_single_part(self, tmp_path):
data = b"single part data"
p = tmp_path / "part0"
p.write_bytes(data)
dest = str(tmp_path / "output")
md5_hex = _rc.assemble_parts_with_md5([str(p)], dest)
assert md5_hex == hashlib.md5(data).hexdigest()
assert Path(dest).read_bytes() == data
def test_empty_parts_list(self):
with pytest.raises(ValueError, match="No parts"):
_rc.assemble_parts_with_md5([], "dummy")
def test_missing_part_file(self, tmp_path):
with pytest.raises(OSError):
_rc.assemble_parts_with_md5(
[str(tmp_path / "nonexistent")], str(tmp_path / "out")
)
def test_large_parts(self, tmp_path):
parts = []
combined = b""
for i in range(5):
data = os.urandom(512 * 1024)
combined += data
p = tmp_path / f"part{i}"
p.write_bytes(data)
parts.append(str(p))
dest = str(tmp_path / "output")
md5_hex = _rc.assemble_parts_with_md5(parts, dest)
assert md5_hex == hashlib.md5(combined).hexdigest()
assert Path(dest).read_bytes() == combined
class TestEncryptDecryptStreamChunked:
def _python_derive_chunk_nonce(self, base_nonce, chunk_index):
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=12,
salt=base_nonce,
info=chunk_index.to_bytes(4, "big"),
)
return hkdf.derive(b"chunk_nonce")
def test_encrypt_decrypt_roundtrip(self, tmp_path):
data = b"Hello, encryption!" * 500
key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
input_path = str(tmp_path / "plaintext")
encrypted_path = str(tmp_path / "encrypted")
decrypted_path = str(tmp_path / "decrypted")
Path(input_path).write_bytes(data)
chunk_count = _rc.encrypt_stream_chunked(
input_path, encrypted_path, key, base_nonce
)
assert chunk_count > 0
chunk_count_dec = _rc.decrypt_stream_chunked(
encrypted_path, decrypted_path, key, base_nonce
)
assert chunk_count_dec == chunk_count
assert Path(decrypted_path).read_bytes() == data
def test_empty_file(self, tmp_path):
key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
input_path = str(tmp_path / "empty")
encrypted_path = str(tmp_path / "encrypted")
decrypted_path = str(tmp_path / "decrypted")
Path(input_path).write_bytes(b"")
chunk_count = _rc.encrypt_stream_chunked(
input_path, encrypted_path, key, base_nonce
)
assert chunk_count == 0
chunk_count_dec = _rc.decrypt_stream_chunked(
encrypted_path, decrypted_path, key, base_nonce
)
assert chunk_count_dec == 0
assert Path(decrypted_path).read_bytes() == b""
def test_custom_chunk_size(self, tmp_path):
data = os.urandom(10000)
key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
input_path = str(tmp_path / "plaintext")
encrypted_path = str(tmp_path / "encrypted")
decrypted_path = str(tmp_path / "decrypted")
Path(input_path).write_bytes(data)
chunk_count = _rc.encrypt_stream_chunked(
input_path, encrypted_path, key, base_nonce, chunk_size=1024
)
assert chunk_count == 10
_rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce)
assert Path(decrypted_path).read_bytes() == data
def test_invalid_key_length(self, tmp_path):
input_path = str(tmp_path / "in")
Path(input_path).write_bytes(b"data")
with pytest.raises(ValueError, match="32 bytes"):
_rc.encrypt_stream_chunked(
input_path, str(tmp_path / "out"), b"short", secrets.token_bytes(12)
)
def test_invalid_nonce_length(self, tmp_path):
input_path = str(tmp_path / "in")
Path(input_path).write_bytes(b"data")
with pytest.raises(ValueError, match="12 bytes"):
_rc.encrypt_stream_chunked(
input_path, str(tmp_path / "out"), secrets.token_bytes(32), b"short"
)
def test_wrong_key_fails_decrypt(self, tmp_path):
data = b"sensitive data"
key = secrets.token_bytes(32)
wrong_key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
input_path = str(tmp_path / "plaintext")
encrypted_path = str(tmp_path / "encrypted")
decrypted_path = str(tmp_path / "decrypted")
Path(input_path).write_bytes(data)
_rc.encrypt_stream_chunked(input_path, encrypted_path, key, base_nonce)
with pytest.raises((ValueError, OSError)):
_rc.decrypt_stream_chunked(
encrypted_path, decrypted_path, wrong_key, base_nonce
)
def test_cross_compat_python_encrypt_rust_decrypt(self, tmp_path):
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
data = b"cross compat test data" * 100
key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
chunk_size = 1024
encrypted_path = str(tmp_path / "py_encrypted")
with open(encrypted_path, "wb") as f:
f.write(b"\x00\x00\x00\x00")
aesgcm = AESGCM(key)
chunk_index = 0
offset = 0
while offset < len(data):
chunk = data[offset:offset + chunk_size]
nonce = self._python_derive_chunk_nonce(base_nonce, chunk_index)
enc = aesgcm.encrypt(nonce, chunk, None)
f.write(len(enc).to_bytes(4, "big"))
f.write(enc)
chunk_index += 1
offset += chunk_size
f.seek(0)
f.write(chunk_index.to_bytes(4, "big"))
decrypted_path = str(tmp_path / "rust_decrypted")
_rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce)
assert Path(decrypted_path).read_bytes() == data
def test_cross_compat_rust_encrypt_python_decrypt(self, tmp_path):
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
data = b"cross compat reverse test" * 100
key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
chunk_size = 1024
input_path = str(tmp_path / "plaintext")
encrypted_path = str(tmp_path / "rust_encrypted")
Path(input_path).write_bytes(data)
chunk_count = _rc.encrypt_stream_chunked(
input_path, encrypted_path, key, base_nonce, chunk_size=chunk_size
)
aesgcm = AESGCM(key)
with open(encrypted_path, "rb") as f:
count_bytes = f.read(4)
assert int.from_bytes(count_bytes, "big") == chunk_count
decrypted = b""
for i in range(chunk_count):
size = int.from_bytes(f.read(4), "big")
enc_chunk = f.read(size)
nonce = self._python_derive_chunk_nonce(base_nonce, i)
decrypted += aesgcm.decrypt(nonce, enc_chunk, None)
assert decrypted == data
def test_large_file_roundtrip(self, tmp_path):
data = os.urandom(1024 * 1024)
key = secrets.token_bytes(32)
base_nonce = secrets.token_bytes(12)
input_path = str(tmp_path / "large")
encrypted_path = str(tmp_path / "encrypted")
decrypted_path = str(tmp_path / "decrypted")
Path(input_path).write_bytes(data)
_rc.encrypt_stream_chunked(input_path, encrypted_path, key, base_nonce)
_rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce)
assert Path(decrypted_path).read_bytes() == data
class TestStreamingEncryptorFileMethods:
def test_encrypt_file_decrypt_file_roundtrip(self, tmp_path):
from app.encryption import LocalKeyEncryption, StreamingEncryptor
master_key_path = tmp_path / "master.key"
provider = LocalKeyEncryption(master_key_path)
encryptor = StreamingEncryptor(provider, chunk_size=512)
data = b"file method test data" * 200
input_path = str(tmp_path / "input")
encrypted_path = str(tmp_path / "encrypted")
decrypted_path = str(tmp_path / "decrypted")
Path(input_path).write_bytes(data)
metadata = encryptor.encrypt_file(input_path, encrypted_path)
assert metadata.algorithm == "AES256"
encryptor.decrypt_file(encrypted_path, decrypted_path, metadata)
assert Path(decrypted_path).read_bytes() == data
def test_encrypt_file_matches_encrypt_stream(self, tmp_path):
from app.encryption import LocalKeyEncryption, StreamingEncryptor
master_key_path = tmp_path / "master.key"
provider = LocalKeyEncryption(master_key_path)
encryptor = StreamingEncryptor(provider, chunk_size=512)
data = b"stream vs file comparison" * 100
input_path = str(tmp_path / "input")
Path(input_path).write_bytes(data)
file_encrypted_path = str(tmp_path / "file_enc")
metadata_file = encryptor.encrypt_file(input_path, file_encrypted_path)
file_decrypted_path = str(tmp_path / "file_dec")
encryptor.decrypt_file(file_encrypted_path, file_decrypted_path, metadata_file)
assert Path(file_decrypted_path).read_bytes() == data
stream_enc, metadata_stream = encryptor.encrypt_stream(io.BytesIO(data))
stream_dec = encryptor.decrypt_stream(stream_enc, metadata_stream)
assert stream_dec.read() == data