diff --git a/app/iam.py b/app/iam.py index f4d895d..7e8f3fa 100644 --- a/app/iam.py +++ b/app/iam.py @@ -398,9 +398,11 @@ class IamService: record = self._user_records.get(user_id) if record: self._check_expiry(access_key, record) + self._enforce_key_and_user_status(access_key) return principal self._maybe_reload() + self._enforce_key_and_user_status(access_key) user_id = self._key_index.get(access_key) if not user_id: raise IamError("Unknown access key") @@ -414,6 +416,7 @@ class IamService: def secret_for_key(self, access_key: str) -> str: self._maybe_reload() + self._enforce_key_and_user_status(access_key) secret = self._key_secrets.get(access_key) if not secret: raise IamError("Unknown access key") @@ -1028,6 +1031,16 @@ class IamService: user, _ = self._resolve_raw_user(access_key) return user + def _enforce_key_and_user_status(self, access_key: str) -> None: + key_status = self._key_status.get(access_key, "active") + if key_status != "active": + raise IamError("Access key is inactive") + user_id = self._key_index.get(access_key) + if user_id: + record = self._user_records.get(user_id) + if record and not record.get("enabled", True): + raise IamError("User account is disabled") + def get_secret_key(self, access_key: str) -> str | None: now = time.time() cached = self._secret_key_cache.get(access_key) @@ -1039,6 +1052,7 @@ class IamService: record = self._user_records.get(user_id) if record: self._check_expiry(access_key, record) + self._enforce_key_and_user_status(access_key) return secret_key self._maybe_reload() @@ -1049,6 +1063,7 @@ class IamService: record = self._user_records.get(user_id) if record: self._check_expiry(access_key, record) + self._enforce_key_and_user_status(access_key) self._secret_key_cache[access_key] = (secret, now) return secret return None @@ -1064,9 +1079,11 @@ class IamService: record = self._user_records.get(user_id) if record: self._check_expiry(access_key, record) + self._enforce_key_and_user_status(access_key) return principal self._maybe_reload() + self._enforce_key_and_user_status(access_key) user_id = self._key_index.get(access_key) if user_id: record = self._user_records.get(user_id) diff --git a/app/s3_api.py b/app/s3_api.py index c4c3864..5cde319 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -534,21 +534,6 @@ def _authorize_action(principal: Principal | None, bucket_name: str | None, acti raise iam_error or IamError("Access denied") -def _enforce_bucket_policy(principal: Principal | None, bucket_name: str | None, object_key: str | None, action: str) -> None: - if not bucket_name: - return - policy_context = _build_policy_context() - decision = _bucket_policies().evaluate( - principal.access_key if principal else None, - bucket_name, - object_key, - action, - policy_context, - ) - if decision == "deny": - raise IamError("Access denied by bucket policy") - - def _object_principal(action: str, bucket_name: str, object_key: str): principal, error = _require_principal() try: @@ -557,121 +542,7 @@ def _object_principal(action: str, bucket_name: str, object_key: str): except IamError as exc: if not error: return None, _error_response("AccessDenied", str(exc), 403) - if not _has_presign_params(): return None, error - try: - principal = _validate_presigned_request(action, bucket_name, object_key) - _enforce_bucket_policy(principal, bucket_name, object_key, action) - return principal, None - except IamError as exc: - return None, _error_response("AccessDenied", str(exc), 403) - - -def _has_presign_params() -> bool: - return bool(request.args.get("X-Amz-Algorithm")) - - -def _validate_presigned_request(action: str, bucket_name: str, object_key: str) -> Principal: - algorithm = request.args.get("X-Amz-Algorithm") - credential = request.args.get("X-Amz-Credential") - amz_date = request.args.get("X-Amz-Date") - signed_headers = request.args.get("X-Amz-SignedHeaders") - expires = request.args.get("X-Amz-Expires") - signature = request.args.get("X-Amz-Signature") - if not all([algorithm, credential, amz_date, signed_headers, expires, signature]): - raise IamError("Malformed presigned URL") - if algorithm != "AWS4-HMAC-SHA256": - raise IamError("Unsupported signing algorithm") - - parts = credential.split("/") - if len(parts) != 5: - raise IamError("Invalid credential scope") - access_key, date_stamp, region, service, terminal = parts - if terminal != "aws4_request": - raise IamError("Invalid credential scope") - config_region = current_app.config["AWS_REGION"] - config_service = current_app.config["AWS_SERVICE"] - if region != config_region or service != config_service: - raise IamError("Credential scope mismatch") - - try: - expiry = int(expires) - except ValueError as exc: - raise IamError("Invalid expiration") from exc - min_expiry = current_app.config.get("PRESIGNED_URL_MIN_EXPIRY_SECONDS", 1) - max_expiry = current_app.config.get("PRESIGNED_URL_MAX_EXPIRY_SECONDS", 604800) - if expiry < min_expiry or expiry > max_expiry: - raise IamError(f"Expiration must be between {min_expiry} second(s) and {max_expiry} seconds") - - try: - request_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc) - except ValueError as exc: - raise IamError("Invalid X-Amz-Date") from exc - now = datetime.now(timezone.utc) - tolerance = timedelta(seconds=current_app.config.get("SIGV4_TIMESTAMP_TOLERANCE_SECONDS", 900)) - if request_time > now + tolerance: - raise IamError("Request date is too far in the future") - if now > request_time + timedelta(seconds=expiry): - raise IamError("Presigned URL expired") - - signed_headers_list = [header.strip().lower() for header in signed_headers.split(";") if header] - signed_headers_list.sort() - canonical_headers = _canonical_headers_from_request(signed_headers_list) - canonical_query = _canonical_query_from_request() - payload_hash = request.args.get("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD") - canonical_request = "\n".join( - [ - request.method, - _canonical_uri(bucket_name, object_key), - canonical_query, - canonical_headers, - ";".join(signed_headers_list), - payload_hash, - ] - ) - hashed_request = hashlib.sha256(canonical_request.encode()).hexdigest() - scope = f"{date_stamp}/{region}/{service}/aws4_request" - string_to_sign = "\n".join([ - "AWS4-HMAC-SHA256", - amz_date, - scope, - hashed_request, - ]) - secret = _iam().secret_for_key(access_key) - signing_key = _derive_signing_key(secret, date_stamp, region, service) - expected = hmac.new(signing_key, string_to_sign.encode(), hashlib.sha256).hexdigest() - if not hmac.compare_digest(expected, signature): - raise IamError("Signature mismatch") - return _iam().principal_for_key(access_key) - - -def _canonical_query_from_request() -> str: - parts = [] - for key in sorted(request.args.keys()): - if key == "X-Amz-Signature": - continue - values = request.args.getlist(key) - encoded_key = quote(str(key), safe="-_.~") - for value in sorted(values): - encoded_value = quote(str(value), safe="-_.~") - parts.append(f"{encoded_key}={encoded_value}") - return "&".join(parts) - - -def _canonical_headers_from_request(headers: list[str]) -> str: - lines = [] - for header in headers: - if header == "host": - api_base = current_app.config.get("API_BASE_URL") - if api_base: - value = urlparse(api_base).netloc - else: - value = request.host - else: - value = request.headers.get(header, "") - canonical_value = " ".join(value.strip().split()) if value else "" - lines.append(f"{header}:{canonical_value}") - return "\n".join(lines) + "\n" def _canonical_uri(bucket_name: str, object_key: str | None) -> str: @@ -737,8 +608,8 @@ def _generate_presigned_url( host = parsed.netloc scheme = parsed.scheme else: - host = request.headers.get("X-Forwarded-Host", request.host) - scheme = request.headers.get("X-Forwarded-Proto", request.scheme or "http") + host = request.host + scheme = request.scheme or "http" canonical_headers = f"host:{host}\n" canonical_request = "\n".join( diff --git a/tests/test_api.py b/tests/test_api.py index 3e95b48..3b650ec 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,3 +1,56 @@ +import hashlib +import hmac +from datetime import datetime, timezone +from urllib.parse import quote + + +def _build_presigned_query(path: str, *, access_key: str = "test", secret_key: str = "secret", expires: int = 60) -> str: + now = datetime.now(timezone.utc) + amz_date = now.strftime("%Y%m%dT%H%M%SZ") + date_stamp = now.strftime("%Y%m%d") + region = "us-east-1" + service = "s3" + credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" + + query_items = [ + ("X-Amz-Algorithm", "AWS4-HMAC-SHA256"), + ("X-Amz-Content-Sha256", "UNSIGNED-PAYLOAD"), + ("X-Amz-Credential", f"{access_key}/{credential_scope}"), + ("X-Amz-Date", amz_date), + ("X-Amz-Expires", str(expires)), + ("X-Amz-SignedHeaders", "host"), + ] + canonical_query = "&".join( + f"{quote(k, safe='-_.~')}={quote(v, safe='-_.~')}" for k, v in sorted(query_items) + ) + + canonical_request = "\n".join([ + "GET", + quote(path, safe="/-_.~"), + canonical_query, + "host:localhost\n", + "host", + "UNSIGNED-PAYLOAD", + ]) + hashed_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest() + string_to_sign = "\n".join([ + "AWS4-HMAC-SHA256", + amz_date, + credential_scope, + hashed_request, + ]) + + def _sign(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + k_date = _sign(("AWS4" + secret_key).encode("utf-8"), date_stamp) + k_region = _sign(k_date, region) + k_service = _sign(k_region, service) + signing_key = _sign(k_service, "aws4_request") + signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() + return canonical_query + f"&X-Amz-Signature={signature}" + + def test_bucket_and_object_lifecycle(client, signer): headers = signer("PUT", "/photos") response = client.put("/photos", headers=headers) @@ -114,6 +167,45 @@ def test_missing_credentials_denied(client): assert response.status_code == 403 +def test_presigned_url_denied_for_disabled_user(client, signer): + headers = signer("PUT", "/secure") + assert client.put("/secure", headers=headers).status_code == 200 + + payload = b"hello" + headers = signer("PUT", "/secure/file.txt", body=payload) + assert client.put("/secure/file.txt", headers=headers, data=payload).status_code == 200 + + iam = client.application.extensions["iam"] + iam.disable_user("test") + + query = _build_presigned_query("/secure/file.txt") + response = client.get(f"/secure/file.txt?{query}", headers={"Host": "localhost"}) + assert response.status_code == 403 + assert b"User account is disabled" in response.data + + +def test_presigned_url_denied_for_inactive_key(client, signer): + headers = signer("PUT", "/secure2") + assert client.put("/secure2", headers=headers).status_code == 200 + + payload = b"hello" + headers = signer("PUT", "/secure2/file.txt", body=payload) + assert client.put("/secure2/file.txt", headers=headers, data=payload).status_code == 200 + + iam = client.application.extensions["iam"] + for user in iam._raw_config.get("users", []): + for key_info in user.get("access_keys", []): + if key_info.get("access_key") == "test": + key_info["status"] = "inactive" + iam._save() + iam._load() + + query = _build_presigned_query("/secure2/file.txt") + response = client.get(f"/secure2/file.txt?{query}", headers={"Host": "localhost"}) + assert response.status_code == 403 + assert b"Access key is inactive" in response.data + + def test_bucket_policies_deny_reads(client, signer): import json