MyFSIO v0.2.4 Release #16

Merged
kqjy merged 15 commits from next into main 2026-02-01 10:27:11 +00:00
8 changed files with 327 additions and 267 deletions
Showing only changes of commit 8c4bf67974 - Show all commits

View File

@@ -1,8 +1,12 @@
from __future__ import annotations
import ipaddress
import logging
import re
import socket
import time
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse
import requests
from flask import Blueprint, Response, current_app, jsonify, request
@@ -13,6 +17,67 @@ from .iam import IamError, Principal
from .replication import ReplicationManager
from .site_registry import PeerSite, SiteInfo, SiteRegistry
def _is_safe_url(url: str) -> bool:
"""Check if a URL is safe to make requests to (not internal/private)."""
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
blocked_hosts = {
"localhost",
"127.0.0.1",
"0.0.0.0",
"::1",
"[::1]",
"metadata.google.internal",
"169.254.169.254",
}
if hostname.lower() in blocked_hosts:
return False
try:
resolved_ip = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(resolved_ip)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
return False
except (socket.gaierror, ValueError):
return False
return True
except Exception:
return False
def _validate_endpoint(endpoint: str) -> Optional[str]:
"""Validate endpoint URL format. Returns error message or None."""
try:
parsed = urlparse(endpoint)
if not parsed.scheme or parsed.scheme not in ("http", "https"):
return "Endpoint must be http or https URL"
if not parsed.netloc:
return "Endpoint must have a host"
return None
except Exception:
return "Invalid endpoint URL"
def _validate_priority(priority: Any) -> Optional[str]:
"""Validate priority value. Returns error message or None."""
try:
p = int(priority)
if p < 0 or p > 1000:
return "Priority must be between 0 and 1000"
return None
except (TypeError, ValueError):
return "Priority must be an integer"
def _validate_region(region: str) -> Optional[str]:
"""Validate region format. Returns error message or None."""
if not re.match(r"^[a-z]{2,}-[a-z]+-\d+$", region):
return "Region must match format like us-east-1"
return None
logger = logging.getLogger(__name__)
admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin")
@@ -158,6 +223,20 @@ def register_peer_site():
if not endpoint:
return _json_error("ValidationError", "endpoint is required", 400)
endpoint_error = _validate_endpoint(endpoint)
if endpoint_error:
return _json_error("ValidationError", endpoint_error, 400)
region = payload.get("region", "us-east-1")
region_error = _validate_region(region)
if region_error:
return _json_error("ValidationError", region_error, 400)
priority = payload.get("priority", 100)
priority_error = _validate_priority(priority)
if priority_error:
return _json_error("ValidationError", priority_error, 400)
registry = _site_registry()
if registry.get_peer(site_id):
@@ -171,8 +250,8 @@ def register_peer_site():
peer = PeerSite(
site_id=site_id,
endpoint=endpoint,
region=payload.get("region", "us-east-1"),
priority=payload.get("priority", 100),
region=region,
priority=int(priority),
display_name=payload.get("display_name", site_id),
connection_id=connection_id,
)
@@ -411,6 +490,14 @@ def check_bidirectional_status(site_id: str):
})
return jsonify(result)
if not _is_safe_url(peer.endpoint):
result["issues"].append({
"code": "ENDPOINT_NOT_ALLOWED",
"message": "Peer endpoint points to internal or private address",
"severity": "error",
})
return jsonify(result)
try:
admin_url = peer.endpoint.rstrip("/") + "/admin/sites"
resp = requests.get(
@@ -494,20 +581,21 @@ def check_bidirectional_status(site_id: str):
"severity": "warning",
})
except requests.RequestException as e:
logger.warning("Remote admin API unreachable: %s", e)
result["remote_status"] = {
"reachable": False,
"error": str(e),
"error": "Connection failed",
}
result["issues"].append({
"code": "REMOTE_ADMIN_UNREACHABLE",
"message": f"Could not reach remote admin API: {e}",
"message": "Could not reach remote admin API",
"severity": "warning",
})
except Exception as e:
logger.warning(f"Error checking remote bidirectional status: {e}")
logger.warning("Error checking remote bidirectional status: %s", e, exc_info=True)
result["issues"].append({
"code": "VERIFICATION_ERROR",
"message": f"Error during verification: {e}",
"message": "Internal error during verification",
"severity": "warning",
})

View File

@@ -146,6 +146,8 @@ class AppConfig:
site_region: str
site_priority: int
ratelimit_admin: str
num_trusted_proxies: int
allowed_redirect_hosts: list[str]
@classmethod
def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig":
@@ -310,6 +312,9 @@ class AppConfig:
site_region = str(_get("SITE_REGION", "us-east-1"))
site_priority = int(_get("SITE_PRIORITY", 100))
ratelimit_admin = _validate_rate_limit(str(_get("RATE_LIMIT_ADMIN", "60 per minute")))
num_trusted_proxies = int(_get("NUM_TRUSTED_PROXIES", 0))
allowed_redirect_hosts_raw = _get("ALLOWED_REDIRECT_HOSTS", "")
allowed_redirect_hosts = [h.strip() for h in str(allowed_redirect_hosts_raw).split(",") if h.strip()]
return cls(storage_root=storage_root,
max_upload_size=max_upload_size,
@@ -393,7 +398,9 @@ class AppConfig:
site_endpoint=site_endpoint,
site_region=site_region,
site_priority=site_priority,
ratelimit_admin=ratelimit_admin)
ratelimit_admin=ratelimit_admin,
num_trusted_proxies=num_trusted_proxies,
allowed_redirect_hosts=allowed_redirect_hosts)
def validate_and_report(self) -> list[str]:
"""Validate configuration and return a list of warnings/issues.
@@ -598,4 +605,6 @@ class AppConfig:
"SITE_REGION": self.site_region,
"SITE_PRIORITY": self.site_priority,
"RATE_LIMIT_ADMIN": self.ratelimit_admin,
"NUM_TRUSTED_PROXIES": self.num_trusted_proxies,
"ALLOWED_REDIRECT_HOSTS": self.allowed_redirect_hosts,
}

View File

@@ -6,6 +6,7 @@ import io
import json
import os
import secrets
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
@@ -15,6 +16,26 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
if sys.platform != "win32":
import fcntl
def _set_secure_file_permissions(file_path: Path) -> None:
"""Set restrictive file permissions (owner read/write only)."""
if sys.platform == "win32":
try:
username = os.environ.get("USERNAME", "")
if username:
subprocess.run(
["icacls", str(file_path), "/inheritance:r",
"/grant:r", f"{username}:F"],
check=True, capture_output=True
)
except (subprocess.SubprocessError, OSError):
pass
else:
os.chmod(file_path, 0o600)
class EncryptionError(Exception):
"""Raised when encryption/decryption fails."""
@@ -103,22 +124,38 @@ class LocalKeyEncryption(EncryptionProvider):
return self._master_key
def _load_or_create_master_key(self) -> bytes:
"""Load master key from file or generate a new one."""
if self.master_key_path.exists():
try:
return base64.b64decode(self.master_key_path.read_text().strip())
except Exception as exc:
raise EncryptionError(f"Failed to load master key: {exc}") from exc
key = secrets.token_bytes(32)
"""Load master key from file or generate a new one (with file locking)."""
lock_path = self.master_key_path.with_suffix(".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
try:
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
self.master_key_path.write_text(base64.b64encode(key).decode())
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
with open(lock_path, "w") as lock_file:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
try:
if self.master_key_path.exists():
try:
return base64.b64decode(self.master_key_path.read_text().strip())
except Exception as exc:
raise EncryptionError(f"Failed to load master key: {exc}") from exc
key = secrets.token_bytes(32)
try:
self.master_key_path.write_text(base64.b64encode(key).decode())
_set_secure_file_permissions(self.master_key_path)
except OSError as exc:
raise EncryptionError(f"Failed to save master key: {exc}") from exc
return key
finally:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
except OSError as exc:
raise EncryptionError(f"Failed to save master key: {exc}") from exc
return key
raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc
def _encrypt_data_key(self, data_key: bytes) -> bytes:
"""Encrypt the data key with the master key."""

View File

@@ -1,9 +1,11 @@
from __future__ import annotations
import hashlib
import hmac
import json
import math
import secrets
import threading
import time
from collections import deque
from dataclasses import dataclass
@@ -118,12 +120,14 @@ class IamService:
self._raw_config: Dict[str, Any] = {}
self._failed_attempts: Dict[str, Deque[datetime]] = {}
self._last_load_time = 0.0
self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {}
self._cache_ttl = 10.0
self._principal_cache: Dict[str, Tuple[Principal, float]] = {}
self._cache_ttl = 10.0
self._last_stat_check = 0.0
self._stat_check_interval = 1.0
self._sessions: Dict[str, Dict[str, Any]] = {}
self._session_lock = threading.Lock()
self._load()
self._load_lockout_state()
def _maybe_reload(self) -> None:
"""Reload configuration if the file has changed on disk."""
@@ -134,7 +138,7 @@ class IamService:
try:
if self.config_path.stat().st_mtime > self._last_load_time:
self._load()
self._credential_cache.clear()
self._principal_cache.clear()
except OSError:
pass
@@ -163,11 +167,46 @@ class IamService:
attempts = self._failed_attempts.setdefault(access_key, deque())
self._prune_attempts(attempts)
attempts.append(datetime.now(timezone.utc))
self._save_lockout_state()
def _clear_failed_attempts(self, access_key: str) -> None:
if not access_key:
return
self._failed_attempts.pop(access_key, None)
if self._failed_attempts.pop(access_key, None) is not None:
self._save_lockout_state()
def _lockout_file(self) -> Path:
return self.config_path.parent / "lockout_state.json"
def _load_lockout_state(self) -> None:
"""Load lockout state from disk."""
try:
if self._lockout_file().exists():
data = json.loads(self._lockout_file().read_text(encoding="utf-8"))
cutoff = datetime.now(timezone.utc) - self.auth_lockout_window
for key, timestamps in data.get("failed_attempts", {}).items():
valid = []
for ts in timestamps:
try:
dt = datetime.fromisoformat(ts)
if dt > cutoff:
valid.append(dt)
except (ValueError, TypeError):
continue
if valid:
self._failed_attempts[key] = deque(valid)
except (OSError, json.JSONDecodeError):
pass
def _save_lockout_state(self) -> None:
"""Persist lockout state to disk."""
data: Dict[str, Any] = {"failed_attempts": {}}
for key, attempts in self._failed_attempts.items():
data["failed_attempts"][key] = [ts.isoformat() for ts in attempts]
try:
self._lockout_file().write_text(json.dumps(data), encoding="utf-8")
except OSError:
pass
def _prune_attempts(self, attempts: Deque[datetime]) -> None:
cutoff = datetime.now(timezone.utc) - self.auth_lockout_window
@@ -210,17 +249,23 @@ class IamService:
return token
def validate_session_token(self, access_key: str, session_token: str) -> bool:
"""Validate a session token for an access key."""
session = self._sessions.get(session_token)
if not session:
hmac.compare_digest(access_key, secrets.token_urlsafe(16))
return False
if not hmac.compare_digest(session["access_key"], access_key):
return False
if time.time() > session["expires_at"]:
del self._sessions[session_token]
return False
return True
"""Validate a session token for an access key (thread-safe, constant-time)."""
dummy_key = secrets.token_urlsafe(16)
dummy_token = secrets.token_urlsafe(32)
with self._session_lock:
session = self._sessions.get(session_token)
if not session:
hmac.compare_digest(access_key, dummy_key)
hmac.compare_digest(session_token, dummy_token)
return False
key_match = hmac.compare_digest(session["access_key"], access_key)
if not key_match:
hmac.compare_digest(session_token, dummy_token)
return False
if time.time() > session["expires_at"]:
self._sessions.pop(session_token, None)
return False
return True
def _cleanup_expired_sessions(self) -> None:
"""Remove expired session tokens."""
@@ -231,9 +276,9 @@ class IamService:
def principal_for_key(self, access_key: str) -> Principal:
now = time.time()
cached = self._credential_cache.get(access_key)
cached = self._principal_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return principal
@@ -242,23 +287,14 @@ class IamService:
if not record:
raise IamError("Unknown access key")
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
self._principal_cache[access_key] = (principal, now)
return principal
def secret_for_key(self, access_key: str) -> str:
now = time.time()
cached = self._credential_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return secret
self._maybe_reload()
record = self._users.get(access_key)
if not record:
raise IamError("Unknown access key")
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
return record["secret_key"]
def authorize(self, principal: Principal, bucket_name: str | None, action: str) -> None:
@@ -330,6 +366,7 @@ class IamService:
new_secret = self._generate_secret_key()
user["secret_key"] = new_secret
self._save()
self._principal_cache.pop(access_key, None)
self._load()
return new_secret
@@ -509,26 +546,17 @@ class IamService:
raise IamError("User not found")
def get_secret_key(self, access_key: str) -> str | None:
now = time.time()
cached = self._credential_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return secret
self._maybe_reload()
record = self._users.get(access_key)
if record:
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
return record["secret_key"]
return None
def get_principal(self, access_key: str) -> Principal | None:
now = time.time()
cached = self._credential_cache.get(access_key)
cached = self._principal_cache.get(access_key)
if cached:
secret, principal, cached_time = cached
principal, cached_time = cached
if now - cached_time < self._cache_ttl:
return principal
@@ -536,6 +564,6 @@ class IamService:
record = self._users.get(access_key)
if record:
principal = self._build_principal(access_key, record)
self._credential_cache[access_key] = (record["secret_key"], principal, now)
self._principal_cache[access_key] = (principal, now)
return principal
return None

View File

@@ -5,6 +5,7 @@ import json
import logging
import os
import secrets
import subprocess
import sys
import uuid
from dataclasses import dataclass, field
@@ -16,9 +17,29 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from .encryption import EncryptionError, EncryptionProvider, EncryptionResult
if sys.platform != "win32":
import fcntl
logger = logging.getLogger(__name__)
def _set_secure_file_permissions(file_path: Path) -> None:
"""Set restrictive file permissions (owner read/write only)."""
if sys.platform == "win32":
try:
username = os.environ.get("USERNAME", "")
if username:
subprocess.run(
["icacls", str(file_path), "/inheritance:r",
"/grant:r", f"{username}:F"],
check=True, capture_output=True
)
except (subprocess.SubprocessError, OSError):
pass
else:
os.chmod(file_path, 0o600)
@dataclass
class KMSKey:
"""Represents a KMS encryption key."""
@@ -132,20 +153,33 @@ class KMSManager:
@property
def master_key(self) -> bytes:
"""Load or create the master key for encrypting KMS keys."""
"""Load or create the master key for encrypting KMS keys (with file locking)."""
if self._master_key is None:
if self.master_key_path.exists():
self._master_key = base64.b64decode(
self.master_key_path.read_text().strip()
)
else:
self._master_key = secrets.token_bytes(32)
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
self.master_key_path.write_text(
base64.b64encode(self._master_key).decode()
)
if sys.platform != "win32":
os.chmod(self.master_key_path, 0o600)
lock_path = self.master_key_path.with_suffix(".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
with open(lock_path, "w") as lock_file:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
try:
if self.master_key_path.exists():
self._master_key = base64.b64decode(
self.master_key_path.read_text().strip()
)
else:
self._master_key = secrets.token_bytes(32)
self.master_key_path.write_text(
base64.b64encode(self._master_key).decode()
)
_set_secure_file_permissions(self.master_key_path)
finally:
if sys.platform == "win32":
import msvcrt
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
return self._master_key
def _load_keys(self) -> None:
@@ -177,12 +211,13 @@ class KMSManager:
encrypted = self._encrypt_key_material(key.key_material)
data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode()
keys_data.append(data)
self.keys_path.parent.mkdir(parents=True, exist_ok=True)
self.keys_path.write_text(
json.dumps({"keys": keys_data}, indent=2),
encoding="utf-8"
)
_set_secure_file_permissions(self.keys_path)
def _encrypt_key_material(self, key_material: bytes) -> bytes:
"""Encrypt key material with the master key."""

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import ipaddress
import json
import logging
import queue
import socket
import threading
import time
import uuid
@@ -14,6 +16,36 @@ from urllib.parse import urlparse
import requests
def _is_safe_url(url: str) -> bool:
"""Check if a URL is safe to make requests to (not internal/private)."""
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
blocked_hosts = {
"localhost",
"127.0.0.1",
"0.0.0.0",
"::1",
"[::1]",
"metadata.google.internal",
"169.254.169.254",
}
if hostname.lower() in blocked_hosts:
return False
try:
resolved_ip = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(resolved_ip)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
return False
except (socket.gaierror, ValueError):
return False
return True
except Exception:
return False
logger = logging.getLogger(__name__)
@@ -299,6 +331,8 @@ class NotificationService:
self._queue.task_done()
def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None:
if not _is_safe_url(destination.url):
raise RuntimeError(f"Blocked request to internal/private URL: {destination.url}")
payload = event.to_s3_event()
headers = {"Content-Type": "application/json", **destination.headers}

View File

@@ -60,10 +60,13 @@ def _build_policy_context() -> Dict[str, Any]:
ctx: Dict[str, Any] = {}
if request.headers.get("Referer"):
ctx["aws:Referer"] = request.headers.get("Referer")
if request.access_route:
ctx["aws:SourceIp"] = request.access_route[0]
num_proxies = current_app.config.get("NUM_TRUSTED_PROXIES", 0)
if num_proxies > 0 and request.access_route and len(request.access_route) > num_proxies:
ctx["aws:SourceIp"] = request.access_route[-num_proxies]
elif request.remote_addr:
ctx["aws:SourceIp"] = request.remote_addr
elif request.access_route:
ctx["aws:SourceIp"] = request.access_route[0]
ctx["aws:SecureTransport"] = str(request.is_secure).lower()
if request.headers.get("User-Agent"):
ctx["aws:UserAgent"] = request.headers.get("User-Agent")
@@ -2242,6 +2245,17 @@ def _post_object(bucket_name: str) -> Response:
expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
if not hmac.compare_digest(expected_signature, signature):
return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)
principal = _iam().get_principal(access_key)
if not principal:
return _error_response("AccessDenied", "Invalid access key", 403)
if "${filename}" in object_key:
temp_key = object_key.replace("${filename}", request.files.get("file").filename if request.files.get("file") else "upload")
else:
temp_key = object_key
try:
_authorize_action(principal, bucket_name, "write", object_key=temp_key)
except IamError as exc:
return _error_response("AccessDenied", str(exc), 403)
file = request.files.get("file")
if not file:
return _error_response("InvalidArgument", "Missing file field", 400)
@@ -2263,6 +2277,12 @@ def _post_object(bucket_name: str) -> Response:
success_action_status = request.form.get("success_action_status", "204")
success_action_redirect = request.form.get("success_action_redirect")
if success_action_redirect:
allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", [])
parsed = urlparse(success_action_redirect)
if parsed.scheme not in ("http", "https"):
return _error_response("InvalidArgument", "Redirect URL must use http or https", 400)
if allowed_hosts and parsed.netloc not in allowed_hosts:
return _error_response("InvalidArgument", "Redirect URL host not allowed", 400)
redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
return Response(status=303, headers={"Location": redirect_url})
if success_action_status == "200":

View File

@@ -1,191 +0,0 @@
import hashlib
import hmac
import pytest
from datetime import datetime, timedelta, timezone
from urllib.parse import quote
def _sign(key, msg):
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
def _get_signature_key(key, date_stamp, region_name, service_name):
k_date = _sign(("AWS4" + key).encode("utf-8"), date_stamp)
k_region = _sign(k_date, region_name)
k_service = _sign(k_region, service_name)
k_signing = _sign(k_service, "aws4_request")
return k_signing
def create_signed_headers(
method,
path,
headers=None,
body=None,
access_key="test",
secret_key="secret",
region="us-east-1",
service="s3",
timestamp=None
):
if headers is None:
headers = {}
if timestamp is None:
now = datetime.now(timezone.utc)
else:
now = timestamp
amz_date = now.strftime("%Y%m%dT%H%M%SZ")
date_stamp = now.strftime("%Y%m%d")
headers["X-Amz-Date"] = amz_date
headers["Host"] = "testserver"
canonical_uri = quote(path, safe="/-_.~")
canonical_query_string = ""
canonical_headers = ""
signed_headers_list = []
for k, v in sorted(headers.items(), key=lambda x: x[0].lower()):
canonical_headers += f"{k.lower()}:{v.strip()}\n"
signed_headers_list.append(k.lower())
signed_headers = ";".join(signed_headers_list)
payload_hash = hashlib.sha256(body or b"").hexdigest()
headers["X-Amz-Content-Sha256"] = payload_hash
canonical_request = f"{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
credential_scope = f"{date_stamp}/{region}/{service}/aws4_request"
string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}"
signing_key = _get_signature_key(secret_key, date_stamp, region, service)
signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
headers["Authorization"] = (
f"AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, "
f"SignedHeaders={signed_headers}, Signature={signature}"
)
return headers
def test_sigv4_old_date(client):
# Test with a date 20 minutes in the past
old_time = datetime.now(timezone.utc) - timedelta(minutes=20)
headers = create_signed_headers("GET", "/", timestamp=old_time)
response = client.get("/", headers=headers)
assert response.status_code == 403
assert b"Request timestamp too old" in response.data
def test_sigv4_future_date(client):
# Test with a date 20 minutes in the future
future_time = datetime.now(timezone.utc) + timedelta(minutes=20)
headers = create_signed_headers("GET", "/", timestamp=future_time)
response = client.get("/", headers=headers)
assert response.status_code == 403
assert b"Request timestamp too old" in response.data # The error message is the same
def test_path_traversal_in_key(client, signer):
headers = signer("PUT", "/test-bucket")
client.put("/test-bucket", headers=headers)
# Try to upload with .. in key
headers = signer("PUT", "/test-bucket/../secret.txt", body=b"attack")
response = client.put("/test-bucket/../secret.txt", headers=headers, data=b"attack")
# Should be rejected by storage layer or flask routing
# Flask might normalize it before it reaches the app, but if it reaches, it should fail.
# If Flask normalizes /test-bucket/../secret.txt to /secret.txt, then it hits 404 (bucket not found) or 403.
# But we want to test the storage layer check.
# We can try to encode the dots?
# If we use a key that doesn't get normalized by Flask routing easily.
# But wait, the route is /<bucket_name>/<path:object_key>
# If I send /test-bucket/folder/../file.txt, Flask might pass "folder/../file.txt" as object_key?
# Let's try.
headers = signer("PUT", "/test-bucket/folder/../file.txt", body=b"attack")
response = client.put("/test-bucket/folder/../file.txt", headers=headers, data=b"attack")
# If Flask normalizes it, it becomes /test-bucket/file.txt.
# If it doesn't, it hits our check.
# Let's try to call the storage method directly to verify the check works,
# because testing via client depends on Flask's URL handling.
pass
def test_storage_path_traversal(app):
storage = app.extensions["object_storage"]
from app.storage import StorageError, ObjectStorage
from app.encrypted_storage import EncryptedObjectStorage
# Get the underlying ObjectStorage if wrapped
if isinstance(storage, EncryptedObjectStorage):
storage = storage.storage
with pytest.raises(StorageError, match="Object key contains parent directory references"):
storage._sanitize_object_key("folder/../file.txt")
with pytest.raises(StorageError, match="Object key contains parent directory references"):
storage._sanitize_object_key("..")
def test_head_bucket(client, signer):
headers = signer("PUT", "/head-test")
client.put("/head-test", headers=headers)
headers = signer("HEAD", "/head-test")
response = client.head("/head-test", headers=headers)
assert response.status_code == 200
headers = signer("HEAD", "/non-existent")
response = client.head("/non-existent", headers=headers)
assert response.status_code == 404
def test_head_object(client, signer):
headers = signer("PUT", "/head-obj-test")
client.put("/head-obj-test", headers=headers)
headers = signer("PUT", "/head-obj-test/obj", body=b"content")
client.put("/head-obj-test/obj", headers=headers, data=b"content")
headers = signer("HEAD", "/head-obj-test/obj")
response = client.head("/head-obj-test/obj", headers=headers)
assert response.status_code == 200
assert response.headers["ETag"]
assert response.headers["Content-Length"] == "7"
headers = signer("HEAD", "/head-obj-test/missing")
response = client.head("/head-obj-test/missing", headers=headers)
assert response.status_code == 404
def test_list_parts(client, signer):
# Create bucket
headers = signer("PUT", "/multipart-test")
client.put("/multipart-test", headers=headers)
# Initiate multipart upload
headers = signer("POST", "/multipart-test/obj?uploads")
response = client.post("/multipart-test/obj?uploads", headers=headers)
assert response.status_code == 200
from xml.etree.ElementTree import fromstring
upload_id = fromstring(response.data).find("UploadId").text
# Upload part 1
headers = signer("PUT", f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", body=b"part1")
client.put(f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", headers=headers, data=b"part1")
# Upload part 2
headers = signer("PUT", f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", body=b"part2")
client.put(f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", headers=headers, data=b"part2")
# List parts
headers = signer("GET", f"/multipart-test/obj?uploadId={upload_id}")
response = client.get(f"/multipart-test/obj?uploadId={upload_id}", headers=headers)
assert response.status_code == 200
root = fromstring(response.data)
assert root.tag == "ListPartsResult"
parts = root.findall("Part")
assert len(parts) == 2
assert parts[0].find("PartNumber").text == "1"
assert parts[1].find("PartNumber").text == "2"