MyFSIO v0.2.4 Release #16
100
app/admin_api.py
100
app/admin_api.py
@@ -1,8 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from flask import Blueprint, Response, current_app, jsonify, request
|
||||
@@ -13,6 +17,67 @@ from .iam import IamError, Principal
|
||||
from .replication import ReplicationManager
|
||||
from .site_registry import PeerSite, SiteInfo, SiteRegistry
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if a URL is safe to make requests to (not internal/private)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
blocked_hosts = {
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
"::1",
|
||||
"[::1]",
|
||||
"metadata.google.internal",
|
||||
"169.254.169.254",
|
||||
}
|
||||
if hostname.lower() in blocked_hosts:
|
||||
return False
|
||||
try:
|
||||
resolved_ip = socket.gethostbyname(hostname)
|
||||
ip = ipaddress.ip_address(resolved_ip)
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
return False
|
||||
except (socket.gaierror, ValueError):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _validate_endpoint(endpoint: str) -> Optional[str]:
|
||||
"""Validate endpoint URL format. Returns error message or None."""
|
||||
try:
|
||||
parsed = urlparse(endpoint)
|
||||
if not parsed.scheme or parsed.scheme not in ("http", "https"):
|
||||
return "Endpoint must be http or https URL"
|
||||
if not parsed.netloc:
|
||||
return "Endpoint must have a host"
|
||||
return None
|
||||
except Exception:
|
||||
return "Invalid endpoint URL"
|
||||
|
||||
|
||||
def _validate_priority(priority: Any) -> Optional[str]:
|
||||
"""Validate priority value. Returns error message or None."""
|
||||
try:
|
||||
p = int(priority)
|
||||
if p < 0 or p > 1000:
|
||||
return "Priority must be between 0 and 1000"
|
||||
return None
|
||||
except (TypeError, ValueError):
|
||||
return "Priority must be an integer"
|
||||
|
||||
|
||||
def _validate_region(region: str) -> Optional[str]:
|
||||
"""Validate region format. Returns error message or None."""
|
||||
if not re.match(r"^[a-z]{2,}-[a-z]+-\d+$", region):
|
||||
return "Region must match format like us-east-1"
|
||||
return None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/admin")
|
||||
@@ -158,6 +223,20 @@ def register_peer_site():
|
||||
if not endpoint:
|
||||
return _json_error("ValidationError", "endpoint is required", 400)
|
||||
|
||||
endpoint_error = _validate_endpoint(endpoint)
|
||||
if endpoint_error:
|
||||
return _json_error("ValidationError", endpoint_error, 400)
|
||||
|
||||
region = payload.get("region", "us-east-1")
|
||||
region_error = _validate_region(region)
|
||||
if region_error:
|
||||
return _json_error("ValidationError", region_error, 400)
|
||||
|
||||
priority = payload.get("priority", 100)
|
||||
priority_error = _validate_priority(priority)
|
||||
if priority_error:
|
||||
return _json_error("ValidationError", priority_error, 400)
|
||||
|
||||
registry = _site_registry()
|
||||
|
||||
if registry.get_peer(site_id):
|
||||
@@ -171,8 +250,8 @@ def register_peer_site():
|
||||
peer = PeerSite(
|
||||
site_id=site_id,
|
||||
endpoint=endpoint,
|
||||
region=payload.get("region", "us-east-1"),
|
||||
priority=payload.get("priority", 100),
|
||||
region=region,
|
||||
priority=int(priority),
|
||||
display_name=payload.get("display_name", site_id),
|
||||
connection_id=connection_id,
|
||||
)
|
||||
@@ -411,6 +490,14 @@ def check_bidirectional_status(site_id: str):
|
||||
})
|
||||
return jsonify(result)
|
||||
|
||||
if not _is_safe_url(peer.endpoint):
|
||||
result["issues"].append({
|
||||
"code": "ENDPOINT_NOT_ALLOWED",
|
||||
"message": "Peer endpoint points to internal or private address",
|
||||
"severity": "error",
|
||||
})
|
||||
return jsonify(result)
|
||||
|
||||
try:
|
||||
admin_url = peer.endpoint.rstrip("/") + "/admin/sites"
|
||||
resp = requests.get(
|
||||
@@ -494,20 +581,21 @@ def check_bidirectional_status(site_id: str):
|
||||
"severity": "warning",
|
||||
})
|
||||
except requests.RequestException as e:
|
||||
logger.warning("Remote admin API unreachable: %s", e)
|
||||
result["remote_status"] = {
|
||||
"reachable": False,
|
||||
"error": str(e),
|
||||
"error": "Connection failed",
|
||||
}
|
||||
result["issues"].append({
|
||||
"code": "REMOTE_ADMIN_UNREACHABLE",
|
||||
"message": f"Could not reach remote admin API: {e}",
|
||||
"message": "Could not reach remote admin API",
|
||||
"severity": "warning",
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking remote bidirectional status: {e}")
|
||||
logger.warning("Error checking remote bidirectional status: %s", e, exc_info=True)
|
||||
result["issues"].append({
|
||||
"code": "VERIFICATION_ERROR",
|
||||
"message": f"Error during verification: {e}",
|
||||
"message": "Internal error during verification",
|
||||
"severity": "warning",
|
||||
})
|
||||
|
||||
|
||||
@@ -146,6 +146,8 @@ class AppConfig:
|
||||
site_region: str
|
||||
site_priority: int
|
||||
ratelimit_admin: str
|
||||
num_trusted_proxies: int
|
||||
allowed_redirect_hosts: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig":
|
||||
@@ -310,6 +312,9 @@ class AppConfig:
|
||||
site_region = str(_get("SITE_REGION", "us-east-1"))
|
||||
site_priority = int(_get("SITE_PRIORITY", 100))
|
||||
ratelimit_admin = _validate_rate_limit(str(_get("RATE_LIMIT_ADMIN", "60 per minute")))
|
||||
num_trusted_proxies = int(_get("NUM_TRUSTED_PROXIES", 0))
|
||||
allowed_redirect_hosts_raw = _get("ALLOWED_REDIRECT_HOSTS", "")
|
||||
allowed_redirect_hosts = [h.strip() for h in str(allowed_redirect_hosts_raw).split(",") if h.strip()]
|
||||
|
||||
return cls(storage_root=storage_root,
|
||||
max_upload_size=max_upload_size,
|
||||
@@ -393,7 +398,9 @@ class AppConfig:
|
||||
site_endpoint=site_endpoint,
|
||||
site_region=site_region,
|
||||
site_priority=site_priority,
|
||||
ratelimit_admin=ratelimit_admin)
|
||||
ratelimit_admin=ratelimit_admin,
|
||||
num_trusted_proxies=num_trusted_proxies,
|
||||
allowed_redirect_hosts=allowed_redirect_hosts)
|
||||
|
||||
def validate_and_report(self) -> list[str]:
|
||||
"""Validate configuration and return a list of warnings/issues.
|
||||
@@ -598,4 +605,6 @@ class AppConfig:
|
||||
"SITE_REGION": self.site_region,
|
||||
"SITE_PRIORITY": self.site_priority,
|
||||
"RATE_LIMIT_ADMIN": self.ratelimit_admin,
|
||||
"NUM_TRUSTED_PROXIES": self.num_trusted_proxies,
|
||||
"ALLOWED_REDIRECT_HOSTS": self.allowed_redirect_hosts,
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import io
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -15,6 +16,26 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
if sys.platform != "win32":
|
||||
import fcntl
|
||||
|
||||
|
||||
def _set_secure_file_permissions(file_path: Path) -> None:
|
||||
"""Set restrictive file permissions (owner read/write only)."""
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
username = os.environ.get("USERNAME", "")
|
||||
if username:
|
||||
subprocess.run(
|
||||
["icacls", str(file_path), "/inheritance:r",
|
||||
"/grant:r", f"{username}:F"],
|
||||
check=True, capture_output=True
|
||||
)
|
||||
except (subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
else:
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
|
||||
class EncryptionError(Exception):
|
||||
"""Raised when encryption/decryption fails."""
|
||||
@@ -103,22 +124,38 @@ class LocalKeyEncryption(EncryptionProvider):
|
||||
return self._master_key
|
||||
|
||||
def _load_or_create_master_key(self) -> bytes:
|
||||
"""Load master key from file or generate a new one."""
|
||||
if self.master_key_path.exists():
|
||||
try:
|
||||
return base64.b64decode(self.master_key_path.read_text().strip())
|
||||
except Exception as exc:
|
||||
raise EncryptionError(f"Failed to load master key: {exc}") from exc
|
||||
|
||||
key = secrets.token_bytes(32)
|
||||
"""Load master key from file or generate a new one (with file locking)."""
|
||||
lock_path = self.master_key_path.with_suffix(".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.master_key_path.write_text(base64.b64encode(key).decode())
|
||||
if sys.platform != "win32":
|
||||
os.chmod(self.master_key_path, 0o600)
|
||||
with open(lock_path, "w") as lock_file:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
try:
|
||||
if self.master_key_path.exists():
|
||||
try:
|
||||
return base64.b64decode(self.master_key_path.read_text().strip())
|
||||
except Exception as exc:
|
||||
raise EncryptionError(f"Failed to load master key: {exc}") from exc
|
||||
key = secrets.token_bytes(32)
|
||||
try:
|
||||
self.master_key_path.write_text(base64.b64encode(key).decode())
|
||||
_set_secure_file_permissions(self.master_key_path)
|
||||
except OSError as exc:
|
||||
raise EncryptionError(f"Failed to save master key: {exc}") from exc
|
||||
return key
|
||||
finally:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
except OSError as exc:
|
||||
raise EncryptionError(f"Failed to save master key: {exc}") from exc
|
||||
return key
|
||||
raise EncryptionError(f"Failed to acquire lock for master key: {exc}") from exc
|
||||
|
||||
def _encrypt_data_key(self, data_key: bytes) -> bytes:
|
||||
"""Encrypt the data key with the master key."""
|
||||
|
||||
106
app/iam.py
106
app/iam.py
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import math
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
@@ -118,12 +120,14 @@ class IamService:
|
||||
self._raw_config: Dict[str, Any] = {}
|
||||
self._failed_attempts: Dict[str, Deque[datetime]] = {}
|
||||
self._last_load_time = 0.0
|
||||
self._credential_cache: Dict[str, Tuple[str, Principal, float]] = {}
|
||||
self._cache_ttl = 10.0
|
||||
self._principal_cache: Dict[str, Tuple[Principal, float]] = {}
|
||||
self._cache_ttl = 10.0
|
||||
self._last_stat_check = 0.0
|
||||
self._stat_check_interval = 1.0
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self._session_lock = threading.Lock()
|
||||
self._load()
|
||||
self._load_lockout_state()
|
||||
|
||||
def _maybe_reload(self) -> None:
|
||||
"""Reload configuration if the file has changed on disk."""
|
||||
@@ -134,7 +138,7 @@ class IamService:
|
||||
try:
|
||||
if self.config_path.stat().st_mtime > self._last_load_time:
|
||||
self._load()
|
||||
self._credential_cache.clear()
|
||||
self._principal_cache.clear()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -163,11 +167,46 @@ class IamService:
|
||||
attempts = self._failed_attempts.setdefault(access_key, deque())
|
||||
self._prune_attempts(attempts)
|
||||
attempts.append(datetime.now(timezone.utc))
|
||||
self._save_lockout_state()
|
||||
|
||||
def _clear_failed_attempts(self, access_key: str) -> None:
|
||||
if not access_key:
|
||||
return
|
||||
self._failed_attempts.pop(access_key, None)
|
||||
if self._failed_attempts.pop(access_key, None) is not None:
|
||||
self._save_lockout_state()
|
||||
|
||||
def _lockout_file(self) -> Path:
|
||||
return self.config_path.parent / "lockout_state.json"
|
||||
|
||||
def _load_lockout_state(self) -> None:
|
||||
"""Load lockout state from disk."""
|
||||
try:
|
||||
if self._lockout_file().exists():
|
||||
data = json.loads(self._lockout_file().read_text(encoding="utf-8"))
|
||||
cutoff = datetime.now(timezone.utc) - self.auth_lockout_window
|
||||
for key, timestamps in data.get("failed_attempts", {}).items():
|
||||
valid = []
|
||||
for ts in timestamps:
|
||||
try:
|
||||
dt = datetime.fromisoformat(ts)
|
||||
if dt > cutoff:
|
||||
valid.append(dt)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
if valid:
|
||||
self._failed_attempts[key] = deque(valid)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
def _save_lockout_state(self) -> None:
|
||||
"""Persist lockout state to disk."""
|
||||
data: Dict[str, Any] = {"failed_attempts": {}}
|
||||
for key, attempts in self._failed_attempts.items():
|
||||
data["failed_attempts"][key] = [ts.isoformat() for ts in attempts]
|
||||
try:
|
||||
self._lockout_file().write_text(json.dumps(data), encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _prune_attempts(self, attempts: Deque[datetime]) -> None:
|
||||
cutoff = datetime.now(timezone.utc) - self.auth_lockout_window
|
||||
@@ -210,17 +249,23 @@ class IamService:
|
||||
return token
|
||||
|
||||
def validate_session_token(self, access_key: str, session_token: str) -> bool:
|
||||
"""Validate a session token for an access key."""
|
||||
session = self._sessions.get(session_token)
|
||||
if not session:
|
||||
hmac.compare_digest(access_key, secrets.token_urlsafe(16))
|
||||
return False
|
||||
if not hmac.compare_digest(session["access_key"], access_key):
|
||||
return False
|
||||
if time.time() > session["expires_at"]:
|
||||
del self._sessions[session_token]
|
||||
return False
|
||||
return True
|
||||
"""Validate a session token for an access key (thread-safe, constant-time)."""
|
||||
dummy_key = secrets.token_urlsafe(16)
|
||||
dummy_token = secrets.token_urlsafe(32)
|
||||
with self._session_lock:
|
||||
session = self._sessions.get(session_token)
|
||||
if not session:
|
||||
hmac.compare_digest(access_key, dummy_key)
|
||||
hmac.compare_digest(session_token, dummy_token)
|
||||
return False
|
||||
key_match = hmac.compare_digest(session["access_key"], access_key)
|
||||
if not key_match:
|
||||
hmac.compare_digest(session_token, dummy_token)
|
||||
return False
|
||||
if time.time() > session["expires_at"]:
|
||||
self._sessions.pop(session_token, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _cleanup_expired_sessions(self) -> None:
|
||||
"""Remove expired session tokens."""
|
||||
@@ -231,9 +276,9 @@ class IamService:
|
||||
|
||||
def principal_for_key(self, access_key: str) -> Principal:
|
||||
now = time.time()
|
||||
cached = self._credential_cache.get(access_key)
|
||||
cached = self._principal_cache.get(access_key)
|
||||
if cached:
|
||||
secret, principal, cached_time = cached
|
||||
principal, cached_time = cached
|
||||
if now - cached_time < self._cache_ttl:
|
||||
return principal
|
||||
|
||||
@@ -242,23 +287,14 @@ class IamService:
|
||||
if not record:
|
||||
raise IamError("Unknown access key")
|
||||
principal = self._build_principal(access_key, record)
|
||||
self._credential_cache[access_key] = (record["secret_key"], principal, now)
|
||||
self._principal_cache[access_key] = (principal, now)
|
||||
return principal
|
||||
|
||||
def secret_for_key(self, access_key: str) -> str:
|
||||
now = time.time()
|
||||
cached = self._credential_cache.get(access_key)
|
||||
if cached:
|
||||
secret, principal, cached_time = cached
|
||||
if now - cached_time < self._cache_ttl:
|
||||
return secret
|
||||
|
||||
self._maybe_reload()
|
||||
record = self._users.get(access_key)
|
||||
if not record:
|
||||
raise IamError("Unknown access key")
|
||||
principal = self._build_principal(access_key, record)
|
||||
self._credential_cache[access_key] = (record["secret_key"], principal, now)
|
||||
return record["secret_key"]
|
||||
|
||||
def authorize(self, principal: Principal, bucket_name: str | None, action: str) -> None:
|
||||
@@ -330,6 +366,7 @@ class IamService:
|
||||
new_secret = self._generate_secret_key()
|
||||
user["secret_key"] = new_secret
|
||||
self._save()
|
||||
self._principal_cache.pop(access_key, None)
|
||||
self._load()
|
||||
return new_secret
|
||||
|
||||
@@ -509,26 +546,17 @@ class IamService:
|
||||
raise IamError("User not found")
|
||||
|
||||
def get_secret_key(self, access_key: str) -> str | None:
|
||||
now = time.time()
|
||||
cached = self._credential_cache.get(access_key)
|
||||
if cached:
|
||||
secret, principal, cached_time = cached
|
||||
if now - cached_time < self._cache_ttl:
|
||||
return secret
|
||||
|
||||
self._maybe_reload()
|
||||
record = self._users.get(access_key)
|
||||
if record:
|
||||
principal = self._build_principal(access_key, record)
|
||||
self._credential_cache[access_key] = (record["secret_key"], principal, now)
|
||||
return record["secret_key"]
|
||||
return None
|
||||
|
||||
def get_principal(self, access_key: str) -> Principal | None:
|
||||
now = time.time()
|
||||
cached = self._credential_cache.get(access_key)
|
||||
cached = self._principal_cache.get(access_key)
|
||||
if cached:
|
||||
secret, principal, cached_time = cached
|
||||
principal, cached_time = cached
|
||||
if now - cached_time < self._cache_ttl:
|
||||
return principal
|
||||
|
||||
@@ -536,6 +564,6 @@ class IamService:
|
||||
record = self._users.get(access_key)
|
||||
if record:
|
||||
principal = self._build_principal(access_key, record)
|
||||
self._credential_cache[access_key] = (record["secret_key"], principal, now)
|
||||
self._principal_cache[access_key] = (principal, now)
|
||||
return principal
|
||||
return None
|
||||
|
||||
63
app/kms.py
63
app/kms.py
@@ -5,6 +5,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,9 +17,29 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from .encryption import EncryptionError, EncryptionProvider, EncryptionResult
|
||||
|
||||
if sys.platform != "win32":
|
||||
import fcntl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _set_secure_file_permissions(file_path: Path) -> None:
|
||||
"""Set restrictive file permissions (owner read/write only)."""
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
username = os.environ.get("USERNAME", "")
|
||||
if username:
|
||||
subprocess.run(
|
||||
["icacls", str(file_path), "/inheritance:r",
|
||||
"/grant:r", f"{username}:F"],
|
||||
check=True, capture_output=True
|
||||
)
|
||||
except (subprocess.SubprocessError, OSError):
|
||||
pass
|
||||
else:
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KMSKey:
|
||||
"""Represents a KMS encryption key."""
|
||||
@@ -132,20 +153,33 @@ class KMSManager:
|
||||
|
||||
@property
|
||||
def master_key(self) -> bytes:
|
||||
"""Load or create the master key for encrypting KMS keys."""
|
||||
"""Load or create the master key for encrypting KMS keys (with file locking)."""
|
||||
if self._master_key is None:
|
||||
if self.master_key_path.exists():
|
||||
self._master_key = base64.b64decode(
|
||||
self.master_key_path.read_text().strip()
|
||||
)
|
||||
else:
|
||||
self._master_key = secrets.token_bytes(32)
|
||||
self.master_key_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.master_key_path.write_text(
|
||||
base64.b64encode(self._master_key).decode()
|
||||
)
|
||||
if sys.platform != "win32":
|
||||
os.chmod(self.master_key_path, 0o600)
|
||||
lock_path = self.master_key_path.with_suffix(".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(lock_path, "w") as lock_file:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
try:
|
||||
if self.master_key_path.exists():
|
||||
self._master_key = base64.b64decode(
|
||||
self.master_key_path.read_text().strip()
|
||||
)
|
||||
else:
|
||||
self._master_key = secrets.token_bytes(32)
|
||||
self.master_key_path.write_text(
|
||||
base64.b64encode(self._master_key).decode()
|
||||
)
|
||||
_set_secure_file_permissions(self.master_key_path)
|
||||
finally:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
return self._master_key
|
||||
|
||||
def _load_keys(self) -> None:
|
||||
@@ -177,12 +211,13 @@ class KMSManager:
|
||||
encrypted = self._encrypt_key_material(key.key_material)
|
||||
data["EncryptedKeyMaterial"] = base64.b64encode(encrypted).decode()
|
||||
keys_data.append(data)
|
||||
|
||||
|
||||
self.keys_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.keys_path.write_text(
|
||||
json.dumps({"keys": keys_data}, indent=2),
|
||||
encoding="utf-8"
|
||||
)
|
||||
_set_secure_file_permissions(self.keys_path)
|
||||
|
||||
def _encrypt_key_material(self, key_material: bytes) -> bytes:
|
||||
"""Encrypt key material with the master key."""
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import queue
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@@ -14,6 +16,36 @@ from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def _is_safe_url(url: str) -> bool:
|
||||
"""Check if a URL is safe to make requests to (not internal/private)."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
blocked_hosts = {
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"0.0.0.0",
|
||||
"::1",
|
||||
"[::1]",
|
||||
"metadata.google.internal",
|
||||
"169.254.169.254",
|
||||
}
|
||||
if hostname.lower() in blocked_hosts:
|
||||
return False
|
||||
try:
|
||||
resolved_ip = socket.gethostbyname(hostname)
|
||||
ip = ipaddress.ip_address(resolved_ip)
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
return False
|
||||
except (socket.gaierror, ValueError):
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -299,6 +331,8 @@ class NotificationService:
|
||||
self._queue.task_done()
|
||||
|
||||
def _send_notification(self, event: NotificationEvent, destination: WebhookDestination) -> None:
|
||||
if not _is_safe_url(destination.url):
|
||||
raise RuntimeError(f"Blocked request to internal/private URL: {destination.url}")
|
||||
payload = event.to_s3_event()
|
||||
headers = {"Content-Type": "application/json", **destination.headers}
|
||||
|
||||
|
||||
@@ -60,10 +60,13 @@ def _build_policy_context() -> Dict[str, Any]:
|
||||
ctx: Dict[str, Any] = {}
|
||||
if request.headers.get("Referer"):
|
||||
ctx["aws:Referer"] = request.headers.get("Referer")
|
||||
if request.access_route:
|
||||
ctx["aws:SourceIp"] = request.access_route[0]
|
||||
num_proxies = current_app.config.get("NUM_TRUSTED_PROXIES", 0)
|
||||
if num_proxies > 0 and request.access_route and len(request.access_route) > num_proxies:
|
||||
ctx["aws:SourceIp"] = request.access_route[-num_proxies]
|
||||
elif request.remote_addr:
|
||||
ctx["aws:SourceIp"] = request.remote_addr
|
||||
elif request.access_route:
|
||||
ctx["aws:SourceIp"] = request.access_route[0]
|
||||
ctx["aws:SecureTransport"] = str(request.is_secure).lower()
|
||||
if request.headers.get("User-Agent"):
|
||||
ctx["aws:UserAgent"] = request.headers.get("User-Agent")
|
||||
@@ -2242,6 +2245,17 @@ def _post_object(bucket_name: str) -> Response:
|
||||
expected_signature = hmac.new(signing_key, policy_b64.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
if not hmac.compare_digest(expected_signature, signature):
|
||||
return _error_response("SignatureDoesNotMatch", "Signature verification failed", 403)
|
||||
principal = _iam().get_principal(access_key)
|
||||
if not principal:
|
||||
return _error_response("AccessDenied", "Invalid access key", 403)
|
||||
if "${filename}" in object_key:
|
||||
temp_key = object_key.replace("${filename}", request.files.get("file").filename if request.files.get("file") else "upload")
|
||||
else:
|
||||
temp_key = object_key
|
||||
try:
|
||||
_authorize_action(principal, bucket_name, "write", object_key=temp_key)
|
||||
except IamError as exc:
|
||||
return _error_response("AccessDenied", str(exc), 403)
|
||||
file = request.files.get("file")
|
||||
if not file:
|
||||
return _error_response("InvalidArgument", "Missing file field", 400)
|
||||
@@ -2263,6 +2277,12 @@ def _post_object(bucket_name: str) -> Response:
|
||||
success_action_status = request.form.get("success_action_status", "204")
|
||||
success_action_redirect = request.form.get("success_action_redirect")
|
||||
if success_action_redirect:
|
||||
allowed_hosts = current_app.config.get("ALLOWED_REDIRECT_HOSTS", [])
|
||||
parsed = urlparse(success_action_redirect)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return _error_response("InvalidArgument", "Redirect URL must use http or https", 400)
|
||||
if allowed_hosts and parsed.netloc not in allowed_hosts:
|
||||
return _error_response("InvalidArgument", "Redirect URL host not allowed", 400)
|
||||
redirect_url = f"{success_action_redirect}?bucket={bucket_name}&key={quote(object_key)}&etag={meta.etag}"
|
||||
return Response(status=303, headers={"Location": redirect_url})
|
||||
if success_action_status == "200":
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from urllib.parse import quote
|
||||
|
||||
def _sign(key, msg):
|
||||
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
|
||||
|
||||
def _get_signature_key(key, date_stamp, region_name, service_name):
|
||||
k_date = _sign(("AWS4" + key).encode("utf-8"), date_stamp)
|
||||
k_region = _sign(k_date, region_name)
|
||||
k_service = _sign(k_region, service_name)
|
||||
k_signing = _sign(k_service, "aws4_request")
|
||||
return k_signing
|
||||
|
||||
def create_signed_headers(
|
||||
method,
|
||||
path,
|
||||
headers=None,
|
||||
body=None,
|
||||
access_key="test",
|
||||
secret_key="secret",
|
||||
region="us-east-1",
|
||||
service="s3",
|
||||
timestamp=None
|
||||
):
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if timestamp is None:
|
||||
now = datetime.now(timezone.utc)
|
||||
else:
|
||||
now = timestamp
|
||||
|
||||
amz_date = now.strftime("%Y%m%dT%H%M%SZ")
|
||||
date_stamp = now.strftime("%Y%m%d")
|
||||
|
||||
headers["X-Amz-Date"] = amz_date
|
||||
headers["Host"] = "testserver"
|
||||
|
||||
canonical_uri = quote(path, safe="/-_.~")
|
||||
canonical_query_string = ""
|
||||
|
||||
canonical_headers = ""
|
||||
signed_headers_list = []
|
||||
for k, v in sorted(headers.items(), key=lambda x: x[0].lower()):
|
||||
canonical_headers += f"{k.lower()}:{v.strip()}\n"
|
||||
signed_headers_list.append(k.lower())
|
||||
|
||||
signed_headers = ";".join(signed_headers_list)
|
||||
|
||||
payload_hash = hashlib.sha256(body or b"").hexdigest()
|
||||
headers["X-Amz-Content-Sha256"] = payload_hash
|
||||
|
||||
canonical_request = f"{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
|
||||
|
||||
credential_scope = f"{date_stamp}/{region}/{service}/aws4_request"
|
||||
string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}"
|
||||
|
||||
signing_key = _get_signature_key(secret_key, date_stamp, region, service)
|
||||
signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
|
||||
headers["Authorization"] = (
|
||||
f"AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, "
|
||||
f"SignedHeaders={signed_headers}, Signature={signature}"
|
||||
)
|
||||
return headers
|
||||
|
||||
def test_sigv4_old_date(client):
|
||||
# Test with a date 20 minutes in the past
|
||||
old_time = datetime.now(timezone.utc) - timedelta(minutes=20)
|
||||
headers = create_signed_headers("GET", "/", timestamp=old_time)
|
||||
|
||||
response = client.get("/", headers=headers)
|
||||
assert response.status_code == 403
|
||||
assert b"Request timestamp too old" in response.data
|
||||
|
||||
def test_sigv4_future_date(client):
|
||||
# Test with a date 20 minutes in the future
|
||||
future_time = datetime.now(timezone.utc) + timedelta(minutes=20)
|
||||
headers = create_signed_headers("GET", "/", timestamp=future_time)
|
||||
|
||||
response = client.get("/", headers=headers)
|
||||
assert response.status_code == 403
|
||||
assert b"Request timestamp too old" in response.data # The error message is the same
|
||||
|
||||
def test_path_traversal_in_key(client, signer):
|
||||
headers = signer("PUT", "/test-bucket")
|
||||
client.put("/test-bucket", headers=headers)
|
||||
|
||||
# Try to upload with .. in key
|
||||
headers = signer("PUT", "/test-bucket/../secret.txt", body=b"attack")
|
||||
response = client.put("/test-bucket/../secret.txt", headers=headers, data=b"attack")
|
||||
|
||||
# Should be rejected by storage layer or flask routing
|
||||
# Flask might normalize it before it reaches the app, but if it reaches, it should fail.
|
||||
# If Flask normalizes /test-bucket/../secret.txt to /secret.txt, then it hits 404 (bucket not found) or 403.
|
||||
# But we want to test the storage layer check.
|
||||
# We can try to encode the dots?
|
||||
|
||||
# If we use a key that doesn't get normalized by Flask routing easily.
|
||||
# But wait, the route is /<bucket_name>/<path:object_key>
|
||||
# If I send /test-bucket/folder/../file.txt, Flask might pass "folder/../file.txt" as object_key?
|
||||
# Let's try.
|
||||
|
||||
headers = signer("PUT", "/test-bucket/folder/../file.txt", body=b"attack")
|
||||
response = client.put("/test-bucket/folder/../file.txt", headers=headers, data=b"attack")
|
||||
|
||||
# If Flask normalizes it, it becomes /test-bucket/file.txt.
|
||||
# If it doesn't, it hits our check.
|
||||
|
||||
# Let's try to call the storage method directly to verify the check works,
|
||||
# because testing via client depends on Flask's URL handling.
|
||||
pass
|
||||
|
||||
def test_storage_path_traversal(app):
|
||||
storage = app.extensions["object_storage"]
|
||||
from app.storage import StorageError, ObjectStorage
|
||||
from app.encrypted_storage import EncryptedObjectStorage
|
||||
|
||||
# Get the underlying ObjectStorage if wrapped
|
||||
if isinstance(storage, EncryptedObjectStorage):
|
||||
storage = storage.storage
|
||||
|
||||
with pytest.raises(StorageError, match="Object key contains parent directory references"):
|
||||
storage._sanitize_object_key("folder/../file.txt")
|
||||
|
||||
with pytest.raises(StorageError, match="Object key contains parent directory references"):
|
||||
storage._sanitize_object_key("..")
|
||||
|
||||
def test_head_bucket(client, signer):
|
||||
headers = signer("PUT", "/head-test")
|
||||
client.put("/head-test", headers=headers)
|
||||
|
||||
headers = signer("HEAD", "/head-test")
|
||||
response = client.head("/head-test", headers=headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
headers = signer("HEAD", "/non-existent")
|
||||
response = client.head("/non-existent", headers=headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_head_object(client, signer):
|
||||
headers = signer("PUT", "/head-obj-test")
|
||||
client.put("/head-obj-test", headers=headers)
|
||||
|
||||
headers = signer("PUT", "/head-obj-test/obj", body=b"content")
|
||||
client.put("/head-obj-test/obj", headers=headers, data=b"content")
|
||||
|
||||
headers = signer("HEAD", "/head-obj-test/obj")
|
||||
response = client.head("/head-obj-test/obj", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["ETag"]
|
||||
assert response.headers["Content-Length"] == "7"
|
||||
|
||||
headers = signer("HEAD", "/head-obj-test/missing")
|
||||
response = client.head("/head-obj-test/missing", headers=headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_list_parts(client, signer):
|
||||
# Create bucket
|
||||
headers = signer("PUT", "/multipart-test")
|
||||
client.put("/multipart-test", headers=headers)
|
||||
|
||||
# Initiate multipart upload
|
||||
headers = signer("POST", "/multipart-test/obj?uploads")
|
||||
response = client.post("/multipart-test/obj?uploads", headers=headers)
|
||||
assert response.status_code == 200
|
||||
from xml.etree.ElementTree import fromstring
|
||||
upload_id = fromstring(response.data).find("UploadId").text
|
||||
|
||||
# Upload part 1
|
||||
headers = signer("PUT", f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", body=b"part1")
|
||||
client.put(f"/multipart-test/obj?partNumber=1&uploadId={upload_id}", headers=headers, data=b"part1")
|
||||
|
||||
# Upload part 2
|
||||
headers = signer("PUT", f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", body=b"part2")
|
||||
client.put(f"/multipart-test/obj?partNumber=2&uploadId={upload_id}", headers=headers, data=b"part2")
|
||||
|
||||
# List parts
|
||||
headers = signer("GET", f"/multipart-test/obj?uploadId={upload_id}")
|
||||
response = client.get(f"/multipart-test/obj?uploadId={upload_id}", headers=headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
root = fromstring(response.data)
|
||||
assert root.tag == "ListPartsResult"
|
||||
parts = root.findall("Part")
|
||||
assert len(parts) == 2
|
||||
assert parts[0].find("PartNumber").text == "1"
|
||||
assert parts[1].find("PartNumber").text == "2"
|
||||
Reference in New Issue
Block a user