diff --git a/app/extensions.py b/app/extensions.py
index 1f8b71a..0fc97a6 100644
--- a/app/extensions.py
+++ b/app/extensions.py
@@ -1,10 +1,17 @@
 """Application-wide extension instances."""
+from flask import g
 from flask_limiter import Limiter
 from flask_limiter.util import get_remote_address
 from flask_wtf import CSRFProtect
 
+def get_rate_limit_key():
+    """Generate rate limit key based on authenticated user."""
+    if hasattr(g, 'principal') and g.principal:
+        return g.principal.access_key
+    return get_remote_address()
+
 # Shared rate limiter instance; configured in app factory.
-limiter = Limiter(key_func=get_remote_address)
+limiter = Limiter(key_func=get_rate_limit_key)
 
 # Global CSRF protection for UI routes.
 csrf = CSRFProtect()
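Note: with `key_func=get_rate_limit_key`, any limit declared on a route is counted per access key for authenticated requests and falls back to the caller's IP for anonymous ones. A minimal usage sketch follows; the route and limit string are illustrative, not part of this patch, and it assumes the new `before_request` hook in app/s3_api.py has populated `g.principal` before the limiter evaluates the key.

    from app.extensions import limiter

    @s3_api_bp.route("/<bucket_name>", methods=["GET"])
    @limiter.limit("100 per minute")  # keyed by g.principal.access_key when set
    def list_bucket(bucket_name: str):
        ...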
diff --git a/app/s3_api.py b/app/s3_api.py
index 31c9142..7d0c64f 100644
--- a/app/s3_api.py
+++ b/app/s3_api.py
@@ -11,7 +11,7 @@ from typing import Any, Dict
 from urllib.parse import quote, urlencode, urlparse
 from xml.etree.ElementTree import Element, SubElement, tostring, fromstring, ParseError
 
-from flask import Blueprint, Response, current_app, jsonify, request
+from flask import Blueprint, Response, current_app, jsonify, request, g
 from werkzeug.http import http_date
 
 from .bucket_policies import BucketPolicyStore
@@ -127,14 +127,33 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None:
     if not amz_date:
         raise IamError("Missing Date header")
 
+    try:
+        request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc)
+    except ValueError:
+        raise IamError("Invalid X-Amz-Date format")
+
+    now = datetime.now(timezone.utc)
+    time_diff = abs((now - request_time).total_seconds())
+    if time_diff > 900:  # 15 minutes
+        raise IamError("Request timestamp too old or too far in the future")
+
+    required_headers = {'host', 'x-amz-date'}
+    signed_headers_set = set(signed_headers_str.split(';'))
+    if not required_headers.issubset(signed_headers_set):
+        # Some clients might sign 'date' instead of 'x-amz-date'
+        if 'date' in signed_headers_set:
+            required_headers.remove('x-amz-date')
+            required_headers.add('date')
+
+        if not required_headers.issubset(signed_headers_set):
+            raise IamError("Required headers not signed")
+
     credential_scope = f"{date_stamp}/{region}/{service}/aws4_request"
     string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}"
-
-    # Calculate Signature
     signing_key = _get_signature_key(secret_key, date_stamp, region, service)
     calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
 
-    if calculated_signature != signature:
+    if not hmac.compare_digest(calculated_signature, signature):
         raise IamError("SignatureDoesNotMatch")
 
     return _iam().get_principal(access_key)
@@ -155,7 +174,6 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
     except ValueError:
         raise IamError("Invalid Credential format")
 
-    # Check expiration
     try:
         req_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc)
     except ValueError:
@@ -190,7 +208,6 @@ def _verify_sigv4_query(req: Any) -> Principal | None:
     canonical_headers_parts = []
     for header in signed_headers_list:
         val = req.headers.get(header, "").strip()
-        # Collapse multiple spaces
         val = " ".join(val.split())
         canonical_headers_parts.append(f"{header}:{val}\n")
     canonical_headers = "".join(canonical_headers_parts)
@@ -240,7 +257,6 @@ def _verify_sigv4(req: Any) -> Principal | None:
 
 
 def _require_principal():
-    # Try SigV4 first
     if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
             (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
         try:
@@ -1132,6 +1148,9 @@ def object_handler(bucket_name: str, object_key: str):
         return response
 
     if request.method in {"GET", "HEAD"}:
+        if request.method == "GET" and "uploadId" in request.args:
+            return _list_parts(bucket_name, object_key)
+
         _, error = _object_principal("read", bucket_name, object_key)
         if error:
             return error
@@ -1157,7 +1176,6 @@ def object_handler(bucket_name: str, object_key: str):
         current_app.logger.info(action, extra={"bucket": bucket_name, "key": object_key, "bytes": logged_bytes})
         return response
 
-    # DELETE
     if "uploadId" in request.args:
         return _abort_multipart_upload(bucket_name, object_key)
 
@@ -1175,6 +1193,51 @@ def object_handler(bucket_name: str, object_key: str):
     return Response(status=204)
 
+
+def _list_parts(bucket_name: str, object_key: str) -> Response:
+    principal, error = _require_principal()
+    if error:
+        return error
+    try:
+        _authorize_action(principal, bucket_name, "read", object_key=object_key)
+    except IamError as exc:
+        return _error_response("AccessDenied", str(exc), 403)
+
+    upload_id = request.args.get("uploadId")
+    if not upload_id:
+        return _error_response("InvalidArgument", "uploadId is required", 400)
+
+    try:
+        parts = _storage().list_multipart_parts(bucket_name, upload_id)
+    except StorageError as exc:
+        return _error_response("NoSuchUpload", str(exc), 404)
+
+    root = Element("ListPartsResult")
+    SubElement(root, "Bucket").text = bucket_name
+    SubElement(root, "Key").text = object_key
+    SubElement(root, "UploadId").text = upload_id
+
+    initiator = SubElement(root, "Initiator")
+    SubElement(initiator, "ID").text = principal.access_key
+    SubElement(initiator, "DisplayName").text = principal.display_name
+
+    owner = SubElement(root, "Owner")
+    SubElement(owner, "ID").text = principal.access_key
+    SubElement(owner, "DisplayName").text = principal.display_name
+
+    SubElement(root, "StorageClass").text = "STANDARD"
+    SubElement(root, "PartNumberMarker").text = "0"
+    SubElement(root, "NextPartNumberMarker").text = str(parts[-1]["PartNumber"]) if parts else "0"
+    SubElement(root, "MaxParts").text = "1000"
+    SubElement(root, "IsTruncated").text = "false"
+
+    for part in parts:
+        p = SubElement(root, "Part")
+        SubElement(p, "PartNumber").text = str(part["PartNumber"])
+        SubElement(p, "LastModified").text = part["LastModified"].isoformat()
+        SubElement(p, "ETag").text = f'"{part["ETag"]}"'
+        SubElement(p, "Size").text = str(part["Size"])
+
+    return _xml_response(root)
 
 
 @s3_api_bp.route("/bucket-policy/<bucket_name>", methods=["GET", "PUT", "DELETE"])
@@ -1504,3 +1567,25 @@ def _abort_multipart_upload(bucket_name: str, object_key: str) -> Response:
         return _error_response("NoSuchBucket", str(exc), 404)
 
     return Response(status=204)
+
+
+@s3_api_bp.before_request
+def resolve_principal():
+    g.principal = None
+    # Try SigV4
+    try:
+        if ("Authorization" in request.headers and request.headers["Authorization"].startswith("AWS4-HMAC-SHA256")) or \
+                (request.args.get("X-Amz-Algorithm") == "AWS4-HMAC-SHA256"):
+            g.principal = _verify_sigv4(request)
+            return
+    except Exception:
+        pass
+
+    # Try simple auth headers (internal/testing)
+    access_key = request.headers.get("X-Access-Key")
+    secret_key = request.headers.get("X-Secret-Key")
+    if access_key and secret_key:
+        try:
+            g.principal = _iam().authenticate(access_key, secret_key)
+        except Exception:
+            pass
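Note: `_get_signature_key` is called above but not shown in this patch. For reference, a sketch of the standard AWS SigV4 key-derivation chain it is expected to implement (names follow the AWS documentation; this is an assumption about the helper, not its actual body):

    import hashlib
    import hmac

    def _get_signature_key(secret_key: str, date_stamp: str, region: str, service: str) -> bytes:
        # Each step HMAC-SHA256s the next scope component, using the previous digest as the key.
        k_date = hmac.new(("AWS4" + secret_key).encode("utf-8"), date_stamp.encode("utf-8"), hashlib.sha256).digest()
        k_region = hmac.new(k_date, region.encode("utf-8"), hashlib.sha256).digest()
        k_service = hmac.new(k_region, service.encode("utf-8"), hashlib.sha256).digest()
        return hmac.new(k_service, b"aws4_request", hashlib.sha256).digest()

The switch to `hmac.compare_digest` above matters because it compares in constant time, so an attacker cannot recover a valid signature byte-by-byte from response timing.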
diff --git a/app/storage.py b/app/storage.py
index 37b31db..6c617e3 100644
--- a/app/storage.py
+++ b/app/storage.py
@@ -120,10 +120,22 @@ class ObjectStorage:
         self._system_bucket_root(bucket_path.name).mkdir(parents=True, exist_ok=True)
 
     def bucket_stats(self, bucket_name: str) -> dict[str, int]:
-        """Return object count and total size for the bucket without hashing files."""
+        """Return object count and total size for the bucket (cached)."""
        bucket_path = self._bucket_path(bucket_name)
         if not bucket_path.exists():
             raise StorageError("Bucket does not exist")
+
+        # Try to read from cache
+        cache_path = self._system_bucket_root(bucket_name) / "stats.json"
+        if cache_path.exists():
+            try:
+                # Check if cache is fresh (e.g., < 60 seconds old)
+                if time.time() - cache_path.stat().st_mtime < 60:
+                    return json.loads(cache_path.read_text(encoding="utf-8"))
+            except (OSError, json.JSONDecodeError):
+                pass
+
+        # Calculate fresh stats
         object_count = 0
         total_bytes = 0
         for path in bucket_path.rglob("*"):
@@ -134,7 +146,17 @@ class ObjectStorage:
             stat = path.stat()
             object_count += 1
             total_bytes += stat.st_size
-        return {"objects": object_count, "bytes": total_bytes}
+
+        stats = {"objects": object_count, "bytes": total_bytes}
+
+        # Write to cache
+        try:
+            cache_path.parent.mkdir(parents=True, exist_ok=True)
+            cache_path.write_text(json.dumps(stats), encoding="utf-8")
+        except OSError:
+            pass
+
+        return stats
 
     def delete_bucket(self, bucket_name: str) -> None:
         bucket_path = self._bucket_path(bucket_name)
@@ -239,7 +261,6 @@ class ObjectStorage:
             rel = path.relative_to(bucket_path)
             self._safe_unlink(path)
             self._delete_metadata(bucket_id, rel)
-            # Clean up now empty parents inside the bucket.
             for parent in path.parents:
                 if parent == bucket_path:
                     break
@@ -592,6 +613,33 @@ class ObjectStorage:
         if legacy_root.exists():
             shutil.rmtree(legacy_root, ignore_errors=True)
 
+    def list_multipart_parts(self, bucket_name: str, upload_id: str) -> List[Dict[str, Any]]:
+        """List uploaded parts for a multipart upload."""
+        bucket_path = self._bucket_path(bucket_name)
+        manifest, upload_root = self._load_multipart_manifest(bucket_path.name, upload_id)
+
+        parts = []
+        parts_map = manifest.get("parts", {})
+        for part_num_str, record in parts_map.items():
+            part_num = int(part_num_str)
+            part_filename = record.get("filename")
+            if not part_filename:
+                continue
+            part_path = upload_root / part_filename
+            if not part_path.exists():
+                continue
+
+            stat = part_path.stat()
+            parts.append({
+                "PartNumber": part_num,
+                "Size": stat.st_size,
+                "ETag": record.get("etag"),
+                "LastModified": datetime.fromtimestamp(stat.st_mtime, timezone.utc)
+            })
+
+        parts.sort(key=lambda x: x["PartNumber"])
+        return parts
+
     # ---------------------- internal helpers ----------------------
     def _bucket_path(self, bucket_name: str) -> Path:
         safe_name = self._sanitize_bucket_name(bucket_name)
@@ -886,7 +934,11 @@ class ObjectStorage:
         normalized = unicodedata.normalize("NFC", object_key)
         if normalized != object_key:
             raise StorageError("Object key must use normalized Unicode")
+
         candidate = Path(normalized)
+        if ".." in candidate.parts:
+            raise StorageError("Object key contains parent directory references")
+
         if candidate.is_absolute():
             raise StorageError("Absolute object keys are not allowed")
         if getattr(candidate, "drive", ""):
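Note: the caching hunks above call `time.time()` and `json.loads`/`json.dumps` but add no import hunk, so this assumes `time` and `json` are already imported at the top of app/storage.py. A hypothetical usage sketch of the cache behaviour (the constructor argument and bucket name are placeholders):

    storage = ObjectStorage("/srv/storage-root")  # constructor args are illustrative

    first = storage.bucket_stats("demo")   # full rglob scan; writes stats.json under the system bucket root
    second = storage.bucket_stats("demo")  # within 60 seconds: served from the cached stats.json
    assert first == second

One consequence of the 60-second freshness window is that dashboards built on `bucket_stats` may report counts that lag writes by up to a minute.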
diff --git a/app/ui.py b/app/ui.py
index 9693240..f5a9cfd 100644
--- a/app/ui.py
+++ b/app/ui.py
@@ -3,6 +3,8 @@ from __future__ import annotations
 
 import json
 import uuid
+import psutil
+import shutil
 from typing import Any
 from urllib.parse import urlparse
@@ -469,8 +471,6 @@ def complete_multipart_upload(bucket_name: str, upload_id: str):
         normalized.append({"part_number": number, "etag": etag})
     try:
         result = _storage().complete_multipart_upload(bucket_name, upload_id, normalized)
-
-        # Trigger replication
         _replication().trigger_replication(bucket_name, result["key"])
         return jsonify(result)
@@ -1209,6 +1209,54 @@ def connections_dashboard():
     return render_template("connections.html", connections=connections, principal=principal)
 
 
+@ui_bp.get("/metrics")
+def metrics_dashboard():
+    principal = _current_principal()
+
+    cpu_percent = psutil.cpu_percent(interval=None)
+    memory = psutil.virtual_memory()
+
+    storage_root = current_app.config["STORAGE_ROOT"]
+    disk = psutil.disk_usage(storage_root)
+
+    storage = _storage()
+    buckets = storage.list_buckets()
+    total_buckets = len(buckets)
+
+    total_objects = 0
+    total_bytes_used = 0
+
+    # Note: Uses cached stats from storage layer to improve performance
+    for bucket in buckets:
+        stats = storage.bucket_stats(bucket.name)
+        total_objects += stats["objects"]
+        total_bytes_used += stats["bytes"]
+
+    return render_template(
+        "metrics.html",
+        principal=principal,
+        cpu_percent=cpu_percent,
+        memory={
+            "total": _format_bytes(memory.total),
+            "available": _format_bytes(memory.available),
+            "used": _format_bytes(memory.used),
+            "percent": memory.percent,
+        },
+        disk={
+            "total": _format_bytes(disk.total),
+            "free": _format_bytes(disk.free),
+            "used": _format_bytes(disk.used),
+            "percent": disk.percent,
+        },
+        app={
+            "buckets": total_buckets,
+            "objects": total_objects,
+            "storage_used": _format_bytes(total_bytes_used),
+            "storage_raw": total_bytes_used,
+        }
+    )
+
+
 @ui_bp.app_errorhandler(404)
 def ui_not_found(error):  # type: ignore[override]
     prefix = ui_bp.url_prefix or ""
diff --git a/requirements.txt b/requirements.txt
index 43f1ae7..7c2c75d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ pytest>=7.4
 requests>=2.31
 boto3>=1.34
 waitress>=2.1.2
+psutil>=5.9.0
diff --git a/templates/base.html b/templates/base.html
index 1bf96ce..d53dbeb 100644
--- a/templates/base.html
+++ b/templates/base.html
@@ -63,6 +63,9 @@
           {% if not can_manage_iam %}Restricted{% endif %}
         </a>
+        <a href="{{ url_for('ui.metrics_dashboard') }}">
+          Metrics
+        </a>
       {% endif %}
       {% if principal %}
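Note: an end-to-end sketch of the new ListParts support using boto3; the endpoint URL, credentials, bucket, and key are placeholders for a local deployment:

    import boto3

    s3 = boto3.client(
        "s3",
        endpoint_url="http://localhost:8000",        # placeholder
        aws_access_key_id="example-access-key",      # placeholder
        aws_secret_access_key="example-secret-key",  # placeholder
    )

    upload = s3.create_multipart_upload(Bucket="demo", Key="big.bin")
    s3.upload_part(Bucket="demo", Key="big.bin", UploadId=upload["UploadId"],
                   PartNumber=1, Body=b"part one")

    # Exercises GET /demo/big.bin?uploadId=... -> _list_parts
    parts = s3.list_parts(Bucket="demo", Key="big.bin", UploadId=upload["UploadId"])
    for part in parts.get("Parts", []):
        print(part["PartNumber"], part["ETag"], part["Size"])

boto3 signs these requests with SigV4, so the same round trip also exercises the new timestamp and signed-header checks in `_verify_sigv4_header`.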