Reduce CPU/lock contention under concurrent uploads: split cache lock, in-memory stats, dict copy, lightweight request IDs, defaultdict metrics

This commit is contained in:
2026-03-02 22:05:54 +08:00
parent 5536330aeb
commit 8552f193de
5 changed files with 97 additions and 80 deletions

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
import html as html_module
import itertools
import logging
import mimetypes
import os
import shutil
import sys
import time
import uuid
from logging.handlers import RotatingFileHandler
from pathlib import Path
from datetime import timedelta
@@ -39,6 +39,8 @@ from .storage import ObjectStorage, StorageError
from .version import get_version
from .website_domains import WebsiteDomainStore
_request_counter = itertools.count(1)
def _migrate_config_file(active_path: Path, legacy_paths: List[Path]) -> Path: def _migrate_config_file(active_path: Path, legacy_paths: List[Path]) -> Path:
"""Migrate config file from legacy locations to the active path. """Migrate config file from legacy locations to the active path.
@@ -481,7 +483,7 @@ def _configure_logging(app: Flask) -> None:
@app.before_request @app.before_request
def _log_request_start() -> None: def _log_request_start() -> None:
g.request_id = uuid.uuid4().hex g.request_id = f"{os.getpid():x}{next(_request_counter):012x}"
g.request_started_at = time.perf_counter() g.request_started_at = time.perf_counter()
g.request_bytes_in = request.content_length or 0 g.request_bytes_in = request.content_length or 0
app.logger.info( app.logger.info(
@@ -616,7 +618,7 @@ def _configure_logging(app: Flask) -> None:
duration_ms = 0.0 duration_ms = 0.0
if hasattr(g, "request_started_at"): if hasattr(g, "request_started_at"):
duration_ms = (time.perf_counter() - g.request_started_at) * 1000 duration_ms = (time.perf_counter() - g.request_started_at) * 1000
request_id = getattr(g, "request_id", uuid.uuid4().hex) request_id = getattr(g, "request_id", f"{os.getpid():x}{next(_request_counter):012x}")
response.headers.setdefault("X-Request-ID", request_id) response.headers.setdefault("X-Request-ID", request_id)
app.logger.info( app.logger.info(
"Request completed", "Request completed",

View File

@@ -5,6 +5,7 @@ import logging
import random
import threading
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
@@ -138,8 +139,8 @@ class OperationMetricsCollector:
self.interval_seconds = interval_minutes * 60
self.retention_hours = retention_hours
self._lock = threading.Lock()
self._by_method: Dict[str, OperationStats] = {} self._by_method: Dict[str, OperationStats] = defaultdict(OperationStats)
self._by_endpoint: Dict[str, OperationStats] = {} self._by_endpoint: Dict[str, OperationStats] = defaultdict(OperationStats)
self._by_status_class: Dict[str, int] = {} self._by_status_class: Dict[str, int] = {}
self._error_codes: Dict[str, int] = {} self._error_codes: Dict[str, int] = {}
self._totals = OperationStats() self._totals = OperationStats()
@@ -211,8 +212,8 @@ class OperationMetricsCollector:
self._prune_old_snapshots() self._prune_old_snapshots()
self._save_history() self._save_history()
self._by_method.clear() self._by_method = defaultdict(OperationStats)
self._by_endpoint.clear() self._by_endpoint = defaultdict(OperationStats)
self._by_status_class.clear() self._by_status_class.clear()
self._error_codes.clear() self._error_codes.clear()
self._totals = OperationStats() self._totals = OperationStats()
@@ -232,12 +233,7 @@ class OperationMetricsCollector:
status_class = f"{status_code // 100}xx" status_class = f"{status_code // 100}xx"
with self._lock: with self._lock:
if method not in self._by_method:
self._by_method[method] = OperationStats()
self._by_method[method].record(latency_ms, success, bytes_in, bytes_out) self._by_method[method].record(latency_ms, success, bytes_in, bytes_out)
if endpoint_type not in self._by_endpoint:
self._by_endpoint[endpoint_type] = OperationStats()
self._by_endpoint[endpoint_type].record(latency_ms, success, bytes_in, bytes_out) self._by_endpoint[endpoint_type].record(latency_ms, success, bytes_in, bytes_out)
self._by_status_class[status_class] = self._by_status_class.get(status_class, 0) + 1 self._by_status_class[status_class] = self._by_status_class.get(status_class, 0) + 1

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import copy
import hashlib
import json
import os
@@ -196,7 +195,9 @@ class ObjectStorage:
self.root.mkdir(parents=True, exist_ok=True) self.root.mkdir(parents=True, exist_ok=True)
self._ensure_system_roots() self._ensure_system_roots()
self._object_cache: OrderedDict[str, tuple[Dict[str, ObjectMeta], float, float]] = OrderedDict() self._object_cache: OrderedDict[str, tuple[Dict[str, ObjectMeta], float, float]] = OrderedDict()
self._cache_lock = threading.Lock() self._obj_cache_lock = threading.Lock()
self._meta_cache_lock = threading.Lock()
self._registry_lock = threading.Lock()
self._bucket_locks: Dict[str, threading.Lock] = {} self._bucket_locks: Dict[str, threading.Lock] = {}
self._cache_version: Dict[str, int] = {} self._cache_version: Dict[str, int] = {}
self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {} self._bucket_config_cache: Dict[str, tuple[dict[str, Any], float]] = {}
@@ -209,10 +210,14 @@ class ObjectStorage:
self._meta_read_cache: OrderedDict[tuple, Optional[Dict[str, Any]]] = OrderedDict() self._meta_read_cache: OrderedDict[tuple, Optional[Dict[str, Any]]] = OrderedDict()
self._meta_read_cache_max = 2048 self._meta_read_cache_max = 2048
self._cleanup_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ParentCleanup") self._cleanup_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ParentCleanup")
self._stats_mem: Dict[str, Dict[str, int]] = {}
self._stats_serial: Dict[str, int] = {}
self._stats_lock = threading.Lock()
self._stats_dirty: set[str] = set()
self._stats_flush_timer: Optional[threading.Timer] = None
def _get_bucket_lock(self, bucket_id: str) -> threading.Lock: def _get_bucket_lock(self, bucket_id: str) -> threading.Lock:
"""Get or create a lock for a specific bucket. Reduces global lock contention.""" with self._registry_lock:
with self._cache_lock:
if bucket_id not in self._bucket_locks: if bucket_id not in self._bucket_locks:
self._bucket_locks[bucket_id] = threading.Lock() self._bucket_locks[bucket_id] = threading.Lock()
return self._bucket_locks[bucket_id] return self._bucket_locks[bucket_id]
@@ -260,26 +265,20 @@ class ObjectStorage:
self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True) self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True)
def bucket_stats(self, bucket_name: str, cache_ttl: int = 60) -> dict[str, int]: def bucket_stats(self, bucket_name: str, cache_ttl: int = 60) -> dict[str, int]:
"""Return object count and total size for the bucket (cached).
Args:
bucket_name: Name of the bucket
cache_ttl: Cache time-to-live in seconds (default 60)
"""
bucket_path = self._bucket_path(bucket_name) bucket_path = self._bucket_path(bucket_name)
if not bucket_path.exists(): if not bucket_path.exists():
raise BucketNotFoundError("Bucket does not exist") raise BucketNotFoundError("Bucket does not exist")
with self._stats_lock:
if bucket_name in self._stats_mem:
return dict(self._stats_mem[bucket_name])
cache_path = self._system_bucket_root(bucket_name) / "stats.json" cache_path = self._system_bucket_root(bucket_name) / "stats.json"
cached_stats = None cached_stats = None
cache_fresh = False
if cache_path.exists(): if cache_path.exists():
try: try:
cache_fresh = time.time() - cache_path.stat().st_mtime < cache_ttl
cached_stats = json.loads(cache_path.read_text(encoding="utf-8")) cached_stats = json.loads(cache_path.read_text(encoding="utf-8"))
if cache_fresh:
return cached_stats
except (OSError, json.JSONDecodeError): except (OSError, json.JSONDecodeError):
pass pass
@@ -348,6 +347,11 @@ class ObjectStorage:
"_cache_serial": existing_serial, "_cache_serial": existing_serial,
} }
with self._stats_lock:
if bucket_name not in self._stats_mem:
self._stats_mem[bucket_name] = stats
self._stats_serial[bucket_name] = existing_serial
try: try:
cache_path.parent.mkdir(parents=True, exist_ok=True) cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_text(json.dumps(stats), encoding="utf-8") cache_path.write_text(json.dumps(stats), encoding="utf-8")
@@ -357,7 +361,10 @@ class ObjectStorage:
return stats return stats
def _invalidate_bucket_stats_cache(self, bucket_id: str) -> None: def _invalidate_bucket_stats_cache(self, bucket_id: str) -> None:
"""Invalidate the cached bucket statistics.""" with self._stats_lock:
self._stats_mem.pop(bucket_id, None)
self._stats_serial[bucket_id] = self._stats_serial.get(bucket_id, 0) + 1
self._stats_dirty.discard(bucket_id)
cache_path = self._system_bucket_root(bucket_id) / "stats.json" cache_path = self._system_bucket_root(bucket_id) / "stats.json"
try: try:
cache_path.unlink(missing_ok=True) cache_path.unlink(missing_ok=True)
@@ -373,29 +380,48 @@ class ObjectStorage:
version_bytes_delta: int = 0, version_bytes_delta: int = 0,
version_count_delta: int = 0, version_count_delta: int = 0,
) -> None: ) -> None:
"""Incrementally update cached bucket statistics instead of invalidating. with self._stats_lock:
if bucket_id not in self._stats_mem:
self._stats_mem[bucket_id] = {
"objects": 0, "bytes": 0, "version_count": 0,
"version_bytes": 0, "total_objects": 0, "total_bytes": 0,
"_cache_serial": 0,
}
data = self._stats_mem[bucket_id]
data["objects"] = max(0, data["objects"] + objects_delta)
data["bytes"] = max(0, data["bytes"] + bytes_delta)
data["version_count"] = max(0, data["version_count"] + version_count_delta)
data["version_bytes"] = max(0, data["version_bytes"] + version_bytes_delta)
data["total_objects"] = max(0, data["total_objects"] + objects_delta + version_count_delta)
data["total_bytes"] = max(0, data["total_bytes"] + bytes_delta + version_bytes_delta)
data["_cache_serial"] = data["_cache_serial"] + 1
self._stats_serial[bucket_id] = self._stats_serial.get(bucket_id, 0) + 1
self._stats_dirty.add(bucket_id)
self._schedule_stats_flush()
This avoids expensive full directory scans on every PUT/DELETE by def _schedule_stats_flush(self) -> None:
adjusting the cached values directly. Also signals cross-process cache if self._stats_flush_timer is None or not self._stats_flush_timer.is_alive():
invalidation by incrementing _cache_serial. self._stats_flush_timer = threading.Timer(3.0, self._flush_stats)
""" self._stats_flush_timer.daemon = True
cache_path = self._system_bucket_root(bucket_id) / "stats.json" self._stats_flush_timer.start()
try:
cache_path.parent.mkdir(parents=True, exist_ok=True) def _flush_stats(self) -> None:
if cache_path.exists(): with self._stats_lock:
data = json.loads(cache_path.read_text(encoding="utf-8")) dirty = list(self._stats_dirty)
else: self._stats_dirty.clear()
data = {"objects": 0, "bytes": 0, "version_count": 0, "version_bytes": 0, "total_objects": 0, "total_bytes": 0, "_cache_serial": 0} snapshots = {b: dict(self._stats_mem[b]) for b in dirty if b in self._stats_mem}
data["objects"] = max(0, data.get("objects", 0) + objects_delta) for bucket_id, data in snapshots.items():
data["bytes"] = max(0, data.get("bytes", 0) + bytes_delta) cache_path = self._system_bucket_root(bucket_id) / "stats.json"
data["version_count"] = max(0, data.get("version_count", 0) + version_count_delta) try:
data["version_bytes"] = max(0, data.get("version_bytes", 0) + version_bytes_delta) cache_path.parent.mkdir(parents=True, exist_ok=True)
data["total_objects"] = max(0, data.get("total_objects", 0) + objects_delta + version_count_delta) cache_path.write_text(json.dumps(data), encoding="utf-8")
data["total_bytes"] = max(0, data.get("total_bytes", 0) + bytes_delta + version_bytes_delta) except OSError:
data["_cache_serial"] = data.get("_cache_serial", 0) + 1 pass
cache_path.write_text(json.dumps(data), encoding="utf-8")
except (OSError, json.JSONDecodeError): def shutdown_stats(self) -> None:
pass if self._stats_flush_timer is not None:
self._stats_flush_timer.cancel()
self._flush_stats()
def delete_bucket(self, bucket_name: str) -> None: def delete_bucket(self, bucket_name: str) -> None:
bucket_path = self._bucket_path(bucket_name) bucket_path = self._bucket_path(bucket_name)
@@ -413,13 +439,18 @@ class ObjectStorage:
self._remove_tree(self._system_bucket_root(bucket_id)) self._remove_tree(self._system_bucket_root(bucket_id))
self._remove_tree(self._multipart_bucket_root(bucket_id)) self._remove_tree(self._multipart_bucket_root(bucket_id))
self._bucket_config_cache.pop(bucket_id, None) self._bucket_config_cache.pop(bucket_id, None)
with self._cache_lock: with self._obj_cache_lock:
self._object_cache.pop(bucket_id, None) self._object_cache.pop(bucket_id, None)
self._cache_version.pop(bucket_id, None) self._cache_version.pop(bucket_id, None)
self._sorted_key_cache.pop(bucket_id, None) self._sorted_key_cache.pop(bucket_id, None)
with self._meta_cache_lock:
stale = [k for k in self._meta_read_cache if k[0] == bucket_id] stale = [k for k in self._meta_read_cache if k[0] == bucket_id]
for k in stale: for k in stale:
del self._meta_read_cache[k] del self._meta_read_cache[k]
with self._stats_lock:
self._stats_mem.pop(bucket_id, None)
self._stats_serial.pop(bucket_id, None)
self._stats_dirty.discard(bucket_id)
def list_objects( def list_objects(
self, self,
@@ -2131,7 +2162,7 @@ class ObjectStorage:
now = time.time() now = time.time()
current_stats_mtime = self._get_cache_marker_mtime(bucket_id) current_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock: with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id) cached = self._object_cache.get(bucket_id)
if cached: if cached:
objects, timestamp, cached_stats_mtime = cached objects, timestamp, cached_stats_mtime = cached
@@ -2143,7 +2174,7 @@ class ObjectStorage:
bucket_lock = self._get_bucket_lock(bucket_id) bucket_lock = self._get_bucket_lock(bucket_id)
with bucket_lock: with bucket_lock:
current_stats_mtime = self._get_cache_marker_mtime(bucket_id) current_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock: with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id) cached = self._object_cache.get(bucket_id)
if cached: if cached:
objects, timestamp, cached_stats_mtime = cached objects, timestamp, cached_stats_mtime = cached
@@ -2154,7 +2185,7 @@ class ObjectStorage:
objects = self._build_object_cache(bucket_path) objects = self._build_object_cache(bucket_path)
new_stats_mtime = self._get_cache_marker_mtime(bucket_id) new_stats_mtime = self._get_cache_marker_mtime(bucket_id)
with self._cache_lock: with self._obj_cache_lock:
current_version = self._cache_version.get(bucket_id, 0) current_version = self._cache_version.get(bucket_id, 0)
if current_version != cache_version: if current_version != cache_version:
objects = self._build_object_cache(bucket_path) objects = self._build_object_cache(bucket_path)
@@ -2170,12 +2201,7 @@ class ObjectStorage:
return objects return objects
def _invalidate_object_cache(self, bucket_id: str) -> None: def _invalidate_object_cache(self, bucket_id: str) -> None:
"""Invalidate the object cache and etag index for a bucket. with self._obj_cache_lock:
Increments version counter to signal stale reads.
Cross-process invalidation is handled by checking stats.json mtime.
"""
with self._cache_lock:
self._object_cache.pop(bucket_id, None) self._object_cache.pop(bucket_id, None)
self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1 self._cache_version[bucket_id] = self._cache_version.get(bucket_id, 0) + 1
@@ -2186,22 +2212,10 @@ class ObjectStorage:
pass pass
def _get_cache_marker_mtime(self, bucket_id: str) -> float: def _get_cache_marker_mtime(self, bucket_id: str) -> float:
"""Get a cache marker combining serial and object count for cross-process invalidation. return float(self._stats_serial.get(bucket_id, 0))
Returns a combined value that changes if either _cache_serial or object count changes.
This handles cases where the serial was reset but object count differs.
"""
stats_path = self._system_bucket_root(bucket_id) / "stats.json"
try:
data = json.loads(stats_path.read_text(encoding="utf-8"))
serial = data.get("_cache_serial", 0)
count = data.get("objects", 0)
return float(serial * 1000000 + count)
except (OSError, json.JSONDecodeError):
return 0
def _update_object_cache_entry(self, bucket_id: str, key: str, meta: Optional[ObjectMeta]) -> None: def _update_object_cache_entry(self, bucket_id: str, key: str, meta: Optional[ObjectMeta]) -> None:
with self._cache_lock: with self._obj_cache_lock:
cached = self._object_cache.get(bucket_id) cached = self._object_cache.get(bucket_id)
if cached: if cached:
objects, timestamp, stats_mtime = cached objects, timestamp, stats_mtime = cached
@@ -2359,19 +2373,19 @@ class ObjectStorage:
return meta_root / parent / "_index.json", entry_name return meta_root / parent / "_index.json", entry_name
def _get_meta_index_lock(self, index_path: str) -> threading.Lock: def _get_meta_index_lock(self, index_path: str) -> threading.Lock:
with self._cache_lock: with self._registry_lock:
if index_path not in self._meta_index_locks: if index_path not in self._meta_index_locks:
self._meta_index_locks[index_path] = threading.Lock() self._meta_index_locks[index_path] = threading.Lock()
return self._meta_index_locks[index_path] return self._meta_index_locks[index_path]
def _read_index_entry(self, bucket_name: str, key: Path) -> Optional[Dict[str, Any]]: def _read_index_entry(self, bucket_name: str, key: Path) -> Optional[Dict[str, Any]]:
cache_key = (bucket_name, str(key)) cache_key = (bucket_name, str(key))
with self._cache_lock: with self._meta_cache_lock:
hit = self._meta_read_cache.get(cache_key) hit = self._meta_read_cache.get(cache_key)
if hit is not None: if hit is not None:
self._meta_read_cache.move_to_end(cache_key) self._meta_read_cache.move_to_end(cache_key)
cached = hit[0] cached = hit[0]
return copy.deepcopy(cached) if cached is not None else None return dict(cached) if cached is not None else None
index_path, entry_name = self._index_file_for_key(bucket_name, key) index_path, entry_name = self._index_file_for_key(bucket_name, key)
if _HAS_RUST: if _HAS_RUST:
@@ -2386,16 +2400,16 @@ class ObjectStorage:
except (OSError, json.JSONDecodeError): except (OSError, json.JSONDecodeError):
result = None result = None
with self._cache_lock: with self._meta_cache_lock:
while len(self._meta_read_cache) >= self._meta_read_cache_max: while len(self._meta_read_cache) >= self._meta_read_cache_max:
self._meta_read_cache.popitem(last=False) self._meta_read_cache.popitem(last=False)
self._meta_read_cache[cache_key] = (copy.deepcopy(result) if result is not None else None,) self._meta_read_cache[cache_key] = (dict(result) if result is not None else None,)
return result return result
def _invalidate_meta_read_cache(self, bucket_name: str, key: Path) -> None: def _invalidate_meta_read_cache(self, bucket_name: str, key: Path) -> None:
cache_key = (bucket_name, str(key)) cache_key = (bucket_name, str(key))
with self._cache_lock: with self._meta_cache_lock:
self._meta_read_cache.pop(cache_key, None) self._meta_read_cache.pop(cache_key, None)
def _write_index_entry(self, bucket_name: str, key: Path, entry: Dict[str, Any]) -> None: def _write_index_entry(self, bucket_name: str, key: Path, entry: Dict[str, Any]) -> None:

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
APP_VERSION = "0.3.4" APP_VERSION = "0.3.5"
def get_version() -> str: def get_version() -> str:

View File

@@ -43,6 +43,11 @@ def app(tmp_path: Path):
} }
) )
yield flask_app yield flask_app
storage = flask_app.extensions.get("object_storage")
if storage:
base = getattr(storage, "storage", storage)
if hasattr(base, "shutdown_stats"):
base.shutdown_stats()
@pytest.fixture() @pytest.fixture()