Add new metrics function

2025-11-23 23:33:53 +08:00
parent e6ee341b93
commit 85ee5b9388
8 changed files with 569 additions and 14 deletions

View File

@@ -1,10 +1,17 @@
"""Application-wide extension instances."""
from flask import g
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_wtf import CSRFProtect
def get_rate_limit_key():
"""Generate rate limit key based on authenticated user."""
if hasattr(g, 'principal') and g.principal:
return g.principal.access_key
return get_remote_address()
# Shared rate limiter instance; configured in app factory.
limiter = Limiter(key_func=get_remote_address)
limiter = Limiter(key_func=get_rate_limit_key)
# Global CSRF protection for UI routes.
csrf = CSRFProtect()
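
With get_rate_limit_key in place, signed requests are throttled per access key while anonymous traffic still falls back to the client IP. A minimal sketch of attaching the shared limiter to a route; the import path, endpoint, and the "100 per minute" quota are illustrative and not part of this commit:

from flask import Blueprint, jsonify
from .extensions import limiter  # assumed module path for this extensions file

api_bp = Blueprint("api", __name__)

@api_bp.get("/v1/objects")
@limiter.limit("100 per minute")  # keyed by g.principal.access_key when set
def list_objects():
    return jsonify([])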

View File

@@ -11,7 +11,7 @@ from typing import Any, Dict
from urllib.parse import quote, urlencode, urlparse
from xml.etree.ElementTree import Element, SubElement, tostring, fromstring, ParseError
from flask import Blueprint, Response, current_app, jsonify, request
from flask import Blueprint, Response, current_app, jsonify, request, g
from werkzeug.http import http_date
from .bucket_policies import BucketPolicyStore
@@ -127,14 +127,33 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
if not amz_date:
raise IamError("Missing Date header")
try:
request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc)
except ValueError:
raise IamError("Invalid X-Amz-Date format")
now = datetime.now(timezone.utc)
time_diff = abs((now - request_time).total_seconds())
if time_diff > 900: # 15 minutes
raise IamError("Request timestamp too old or too far in the future")
required_headers = {'host', 'x-amz-date'}
signed_headers_set = set(signed_headers_str.split(';'))
if not required_headers.issubset(signed_headers_set):
# Some clients might sign 'date' instead of 'x-amz-date'
if 'date' in signed_headers_set:
required_headers.remove('x-amz-date')
required_headers.add('date')
if not required_headers.issubset(signed_headers_set):
raise IamError("Required headers not signed")
credential_scope = f"{date_stamp}/{region}/{service}/aws4_request"
string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}"
# Calculate Signature
signing_key = _get_signature_key(secret_key, date_stamp, region, service)
calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
if calculated_signature != signature:
if not hmac.compare_digest(calculated_signature, signature):
raise IamError("SignatureDoesNotMatch")
return _iam().get_principal(access_key)
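
For reference, the signing key used above follows the standard SigV4 derivation chain; _get_signature_key is assumed to implement the same steps (sketch only, per the AWS SigV4 specification):

import hashlib
import hmac

def _sign(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()

def derive_signing_key(secret_key: str, date_stamp: str, region: str, service: str) -> bytes:
    # AWS4-HMAC-SHA256 key derivation: date -> region -> service -> "aws4_request"
    k_date = _sign(("AWS4" + secret_key).encode("utf-8"), date_stamp)
    k_region = _sign(k_date, region)
    k_service = _sign(k_region, service)
    return _sign(k_service, "aws4_request")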
@@ -155,7 +174,6 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
except ValueError:
raise IamError("Invalid Credential format")
# Check expiration
try:
req_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc)
except ValueError:
@@ -190,7 +208,6 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
canonical_headers_parts = []
for header in signed_headers_list:
val = req.headers.get(header, "").strip()
# Collapse multiple spaces
val = " ".join(val.split())
canonical_headers_parts.append(f"{header}:{val}\n")
canonical_headers = "".join(canonical_headers_parts)
@@ -240,7 +257,6 @@ def _verify_sigv4(req: Any) -> Principal | None:
def _require_principal():
# Try SigV4 first
if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
(request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
try:
@@ -1132,6 +1148,9 @@ def object_handler(bucket_name: str, object_key: str):
return response
if request.method in {"GET", "HEAD"}:
if request.method == "GET" and "uploadId" in request.args:
return _list_parts(bucket_name, object_key)
_, error = _object_principal("read", bucket_name, object_key)
if error:
return error
@@ -1157,7 +1176,6 @@ def object_handler(bucket_name: str, object_key: str):
current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes})
return response
# DELETE
if "uploadId" in request.args:
return _abort_multipart_upload(bucket_name, object_key)
@@ -1175,6 +1193,51 @@ def object_handler(bucket_name: str, object_key: str):
return Response(status=204)
def _list_parts(bucket_name: str, object_key: str) -> Response:
principal, error = _require_principal()
if error:
return error
try:
_authorize_action(principal, bucket_name, "read", object_key=object_key)
except IamError as exc:
return _error_response("AccessDenied", str(exc), 403)
upload_id = request.args.get("uploadId")
if not upload_id:
return _error_response("InvalidArgument", "uploadId is required", 400)
try:
parts = _storage().list_multipart_parts(bucket_name, upload_id)
except StorageError as exc:
return _error_response("NoSuchUpload", str(exc), 404)
root = Element("ListPartsResult")
SubElement(root, "Bucket").text = bucket_name
SubElement(root, "Key").text = object_key
SubElement(root, "UploadId").text = upload_id
initiator = SubElement(root, "Initiator")
SubElement(initiator, "ID").text = principal.access_key
SubElement(initiator, "DisplayName").text = principal.display_name
owner = SubElement(root, "Owner")
SubElement(owner, "ID").text = principal.access_key
SubElement(owner, "DisplayName").text = principal.display_name
SubElement(root, "StorageClass").text = "STANDARD"
SubElement(root, "PartNumberMarker").text = "0"
SubElement(root, "NextPartNumberMarker").text = str(parts[-1]["PartNumber"]) if parts else "0"
SubElement(root, "MaxParts").text = "1000"
SubElement(root, "IsTruncated").text = "false"
for part in parts:
p = SubElement(root, "Part")
SubElement(p, "PartNumber").text = str(part["PartNumber"])
SubElement(p, "LastModified").text = part["LastModified"].isoformat()
SubElement(p, "ETag").text = f'"{part["ETag"]}"'
SubElement(p, "Size").text = str(part["Size"])
return _xml_response(root)
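
The XML document produced by _list_parts has the following shape; the values below are placeholders for illustration only:

<ListPartsResult>
  <Bucket>photos</Bucket>
  <Key>trip/day1.jpg</Key>
  <UploadId>example-upload-id</UploadId>
  <Initiator><ID>AKIAEXAMPLE</ID><DisplayName>example-user</DisplayName></Initiator>
  <Owner><ID>AKIAEXAMPLE</ID><DisplayName>example-user</DisplayName></Owner>
  <StorageClass>STANDARD</StorageClass>
  <PartNumberMarker>0</PartNumberMarker>
  <NextPartNumberMarker>1</NextPartNumberMarker>
  <MaxParts>1000</MaxParts>
  <IsTruncated>false</IsTruncated>
  <Part>
    <PartNumber>1</PartNumber>
    <LastModified>2025-11-23T15:00:00+00:00</LastModified>
    <ETag>"d41d8cd98f00b204e9800998ecf8427e"</ETag>
    <Size>5242880</Size>
  </Part>
</ListPartsResult>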
@s3_api_bp.route("/bucket-policy/<bucket_name>", methods=["GET", "PUT", "DELETE"])
@@ -1504,3 +1567,25 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
return _error_response("NoSuchBucket", str(exc), 404)
return Response(status=204)
@s3_api_bp.before_request
def resolve_principal():
g.principal = None
# Try SigV4
try:
if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
(request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
g.principal = _verify_sigv4(request)
return
except Exception:
pass
# Try simple auth headers (internal/testing)
access_key = request.headers.get("X-Access-Key")
secret_key = request.headers.get("X-Secret-Key")
if access_key and secret_key:
try:
g.principal = _iam().authenticate(access_key, secret_key)
except Exception:
pass
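
Once this hook populates g.principal, get_rate_limit_key (from the extensions change above) throttles by access key instead of by client IP. A hypothetical test-client call exercising the simple-auth fallback; the bucket, key, and credentials are placeholders:

def test_simple_auth_sets_principal(client):
    # Placeholder credentials; real keys come from the IAM store.
    resp = client.get(
        "/my-bucket/report.csv",
        headers={"X-Access-Key": "AKIAEXAMPLE", "X-Secret-Key": "example-secret"},
    )
    assert resp.status_code in (200, 403, 404)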

View File

@@ -120,10 +120,22 @@ class ObjectStorage:
self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True)
def bucket_stats(self, bucket_name: str) -> dict[str, int]:
"""Return object count and total size for the bucket without hashing files."""
"""Return object count and total size for the bucket (cached)."""
bucket_path = self._bucket_path(bucket_name)
if not bucket_path.exists():
raise StorageError("Bucket does not exist")
# Try to read from cache
cache_path = self._system_bucket_root(bucket_name) / "stats.json"
if cache_path.exists():
try:
# Check if cache is fresh (e.g., < 60 seconds old)
if time.time() - cache_path.stat().st_mtime < 60:
return json.loads(cache_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
pass
# Calculate fresh stats
object_count = 0
total_bytes = 0
for path in bucket_path.rglob("*"):
@@ -134,7 +146,17 @@ class ObjectStorage:
stat = path.stat()
object_count += 1
total_bytes += stat.st_size
return {"objects": object_count, "bytes": total_bytes}
stats = {"objects": object_count, "bytes": total_bytes}
# Write to cache
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_text(json.dumps(stats), encoding="utf-8")
except OSError:
pass
return stats
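
The cache written above is a plain JSON document under the bucket's system root, for example (illustrative values):

# <system bucket root>/stats.json
{"objects": 42, "bytes": 1073741824}

As written, freshness relies only on the 60-second mtime window, so dashboard figures can lag recent writes by up to a minute.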
def delete_bucket(self, bucket_name: str) -> None:
bucket_path = self._bucket_path(bucket_name)
@@ -239,7 +261,6 @@ class ObjectStorage:
rel = path.relative_to(bucket_path)
self._safe_unlink(path)
self._delete_metadata(bucket_id, rel)
# Clean up now-empty parents inside the bucket.
for parent in path.parents:
if parent == bucket_path:
break
@@ -592,6 +613,33 @@ class ObjectStorage:
if legacy_root.exists():
shutil.rmtree(legacy_root, ignore_errors=True)
def list_multipart_parts(self, bucket_name: str, upload_id: str) -> List[Dict[str, Any]]:
"""List uploaded parts for a multipart upload."""
bucket_path = self._bucket_path(bucket_name)
manifest, upload_root = self._load_multipart_manifest(bucket_path.name, upload_id)
parts = []
parts_map = manifest.get("parts", {})
for part_num_str, record in parts_map.items():
part_num = int(part_num_str)
part_filename = record.get("filename")
if not part_filename:
continue
part_path = upload_root / part_filename
if not part_path.exists():
continue
stat = part_path.stat()
parts.append({
"PartNumber": part_num,
"Size": stat.st_size,
"ETag": record.get("etag"),
"LastModified": datetime.fromtimestamp(stat.st_mtime, timezone.utc)
})
parts.sort(key=lambda x: x["PartNumber"])
return parts
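
For context, list_multipart_parts assumes a manifest whose "parts" mapping looks roughly like the sketch below; the authoritative layout is whatever _load_multipart_manifest and the upload path actually write:

# {
#   "parts": {
#     "1": {"filename": "part-00001", "etag": "9b2cf535f27731c974343645a3985328"},
#     "2": {"filename": "part-00002", "etag": "d41d8cd98f00b204e9800998ecf8427e"}
#   }
# }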
# ---------------------- internal helpers ----------------------
def _bucket_path(self, bucket_name: str) -> Path:
safe_name = self._sanitize_bucket_name(bucket_name)
@@ -886,7 +934,11 @@ class ObjectStorage:
normalized = unicodedata.normalize("NFC", object_key)
if normalized != object_key:
raise StorageError("Object key must use normalized Unicode")
candidate = Path(normalized)
if ".." in candidate.parts:
raise StorageError("Object key contains parent directory references")
if candidate.is_absolute():
raise StorageError("Absolute object keys are not allowed")
if getattr(candidate, "drive", ""):

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
import json
import uuid
import psutil
import shutil
from typing import Any
from urllib.parse import urlparse
@@ -469,8 +471,6 @@ def complete_multipart_upload(bucket_name: str, upload_id: str):
normalized.append({"part_number": number, "etag": etag})
try:
result = _storage().complete_multipart_upload(bucket_name, upload_id, normalized)
# Trigger replication
_replication().trigger_replication(bucket_name, result["key"])
return jsonify(result)
@@ -1209,6 +1209,54 @@ def connections_dashboard():
return render_template("connections.html", connections=connections, principal=principal)
@ui_bp.get("/metrics")
def metrics_dashboard():
principal = _current_principal()
cpu_percent = psutil.cpu_percent(interval=None)
memory = psutil.virtual_memory()
storage_root = current_app.config["STORAGE_ROOT"]
disk = psutil.disk_usage(storage_root)
storage = _storage()
buckets = storage.list_buckets()
total_buckets = len(buckets)
total_objects = 0
total_bytes_used = 0
# Note: uses cached stats from the storage layer to improve performance
for bucket in buckets:
stats = storage.bucket_stats(bucket.name)
total_objects += stats["objects"]
total_bytes_used += stats["bytes"]
return render_template(
"metrics.html",
principal=principal,
cpu_percent=cpu_percent,
memory={
"total": _format_bytes(memory.total),
"available": _format_bytes(memory.available),
"used": _format_bytes(memory.used),
"percent": memory.percent,
},
disk={
"total": _format_bytes(disk.total),
"free": _format_bytes(disk.free),
"used": _format_bytes(disk.used),
"percent": disk.percent,
},
app={
"buckets": total_buckets,
"objects": total_objects,
"storage_used": _format_bytes(total_bytes_used),
"storage_raw": total_bytes_used,
}
)
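
_format_bytes is referenced above but defined elsewhere in this module; a plausible sketch of the kind of helper it is (the real implementation may differ):

def _format_bytes(num_bytes: int) -> str:
    """Render a byte count as a human-readable string, e.g. 1073741824 -> '1.0 GB'."""
    value = float(num_bytes)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    return f"{value:.1f} PB"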
@ui_bp.app_errorhandler(404)
def ui_not_found(error): # type: ignore[override]
prefix = ui_bp.url_prefix or ""