diff --git a/app/s3_api.py b/app/s3_api.py index 18777f4..83755e7 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -267,39 +267,6 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None: if not secret_key: raise IamError("SignatureDoesNotMatch") - method = req.method - canonical_uri = _get_canonical_uri(req) - - query_args = [] - for key, value in req.args.items(multi=True): - query_args.append((key, value)) - query_args.sort(key=lambda x: (x[0], x[1])) - - canonical_query_parts = [] - for k, v in query_args: - canonical_query_parts.append(f"{quote(k, safe='-_.~')}={quote(v, safe='-_.~')}") - canonical_query_string = "&".join(canonical_query_parts) - - signed_headers_list = signed_headers_str.split(";") - canonical_headers_parts = [] - for header in signed_headers_list: - header_val = req.headers.get(header) - if header_val is None: - header_val = "" - - if header.lower() == 'expect' and header_val == "": - header_val = "100-continue" - - header_val = " ".join(header_val.split()) - canonical_headers_parts.append(f"{header.lower()}:{header_val}\n") - canonical_headers = "".join(canonical_headers_parts) - - payload_hash = req.headers.get("X-Amz-Content-Sha256") - if not payload_hash: - payload_hash = hashlib.sha256(req.get_data()).hexdigest() - - canonical_request = f"{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers_str}\n{payload_hash}" - amz_date = req.headers.get("X-Amz-Date") or req.headers.get("Date") if not amz_date: raise IamError("Missing Date header") @@ -321,23 +288,60 @@ def _verify_sigv4_header(req: Any, auth_header: str) -> Principal | None: if 'date' in signed_headers_set: required_headers.remove('x-amz-date') required_headers.add('date') - + if not required_headers.issubset(signed_headers_set): raise IamError("Required headers not signed") - credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" - signing_key = _get_signature_key(secret_key, date_stamp, region, service) + canonical_uri = _get_canonical_uri(req) + payload_hash = req.headers.get("X-Amz-Content-Sha256") + if not payload_hash: + payload_hash = hashlib.sha256(req.get_data()).hexdigest() + if _HAS_RUST: - string_to_sign = _rc.build_string_to_sign(amz_date, credential_scope, canonical_request) - calculated_signature = _rc.compute_signature(signing_key, string_to_sign) + query_params = list(req.args.items(multi=True)) + header_values = [(h, req.headers.get(h) or "") for h in signed_headers_str.split(";")] + if not _rc.verify_sigv4_signature( + req.method, canonical_uri, query_params, signed_headers_str, + header_values, payload_hash, amz_date, date_stamp, region, + service, secret_key, signature, + ): + if current_app.config.get("DEBUG_SIGV4"): + logger.warning("SigV4 signature mismatch for %s %s", req.method, req.path) + raise IamError("SignatureDoesNotMatch") else: + method = req.method + query_args = [] + for key, value in req.args.items(multi=True): + query_args.append((key, value)) + query_args.sort(key=lambda x: (x[0], x[1])) + + canonical_query_parts = [] + for k, v in query_args: + canonical_query_parts.append(f"{quote(k, safe='-_.~')}={quote(v, safe='-_.~')}") + canonical_query_string = "&".join(canonical_query_parts) + + signed_headers_list = signed_headers_str.split(";") + canonical_headers_parts = [] + for header in signed_headers_list: + header_val = req.headers.get(header) + if header_val is None: + header_val = "" + if header.lower() == 'expect' and header_val == "": + header_val = "100-continue" + header_val = " ".join(header_val.split()) + canonical_headers_parts.append(f"{header.lower()}:{header_val}\n") + canonical_headers = "".join(canonical_headers_parts) + + canonical_request = f"{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n{signed_headers_str}\n{payload_hash}" + + credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" + signing_key = _get_signature_key(secret_key, date_stamp, region, service) string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() - - if not hmac.compare_digest(calculated_signature, signature): - if current_app.config.get("DEBUG_SIGV4"): - logger.warning("SigV4 signature mismatch for %s %s", method, req.path) - raise IamError("SignatureDoesNotMatch") + if not hmac.compare_digest(calculated_signature, signature): + if current_app.config.get("DEBUG_SIGV4"): + logger.warning("SigV4 signature mismatch for %s %s", method, req.path) + raise IamError("SignatureDoesNotMatch") session_token = req.headers.get("X-Amz-Security-Token") if session_token: @@ -366,7 +370,7 @@ def _verify_sigv4_query(req: Any) -> Principal | None: req_time = datetime.strptime(amz_date, "%Y%m%dT%H%M%SZ").replace(tzinfo=timezone.utc) except ValueError: raise IamError("Invalid Date format") - + now = datetime.now(timezone.utc) try: expires_seconds = int(expires) @@ -381,53 +385,58 @@ def _verify_sigv4_query(req: Any) -> Principal | None: if not secret_key: raise IamError("Invalid access key") - method = req.method canonical_uri = _get_canonical_uri(req) - - query_args = [] - for key, value in req.args.items(multi=True): - if key != "X-Amz-Signature": - query_args.append((key, value)) - query_args.sort(key=lambda x: (x[0], x[1])) - - canonical_query_parts = [] - for k, v in query_args: - canonical_query_parts.append(f"{quote(k, safe='-_.~')}={quote(v, safe='-_.~')}") - canonical_query_string = "&".join(canonical_query_parts) - - signed_headers_list = signed_headers_str.split(";") - canonical_headers_parts = [] - for header in signed_headers_list: - val = req.headers.get(header, "").strip() - if header.lower() == 'expect' and val == "": - val = "100-continue" - val = " ".join(val.split()) - canonical_headers_parts.append(f"{header.lower()}:{val}\n") - canonical_headers = "".join(canonical_headers_parts) - - payload_hash = "UNSIGNED-PAYLOAD" - - canonical_request = "\n".join([ - method, - canonical_uri, - canonical_query_string, - canonical_headers, - signed_headers_str, - payload_hash - ]) - - credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" - signing_key = _get_signature_key(secret_key, date_stamp, region, service) + if _HAS_RUST: - string_to_sign = _rc.build_string_to_sign(amz_date, credential_scope, canonical_request) - calculated_signature = _rc.compute_signature(signing_key, string_to_sign) + query_params = [(k, v) for k, v in req.args.items(multi=True) if k != "X-Amz-Signature"] + header_values = [(h, req.headers.get(h) or "") for h in signed_headers_str.split(";")] + if not _rc.verify_sigv4_signature( + req.method, canonical_uri, query_params, signed_headers_str, + header_values, "UNSIGNED-PAYLOAD", amz_date, date_stamp, region, + service, secret_key, signature, + ): + raise IamError("SignatureDoesNotMatch") else: + method = req.method + query_args = [] + for key, value in req.args.items(multi=True): + if key != "X-Amz-Signature": + query_args.append((key, value)) + query_args.sort(key=lambda x: (x[0], x[1])) + + canonical_query_parts = [] + for k, v in query_args: + canonical_query_parts.append(f"{quote(k, safe='-_.~')}={quote(v, safe='-_.~')}") + canonical_query_string = "&".join(canonical_query_parts) + + signed_headers_list = signed_headers_str.split(";") + canonical_headers_parts = [] + for header in signed_headers_list: + val = req.headers.get(header, "").strip() + if header.lower() == 'expect' and val == "": + val = "100-continue" + val = " ".join(val.split()) + canonical_headers_parts.append(f"{header.lower()}:{val}\n") + canonical_headers = "".join(canonical_headers_parts) + + payload_hash = "UNSIGNED-PAYLOAD" + + canonical_request = "\n".join([ + method, + canonical_uri, + canonical_query_string, + canonical_headers, + signed_headers_str, + payload_hash + ]) + + credential_scope = f"{date_stamp}/{region}/{service}/aws4_request" + signing_key = _get_signature_key(secret_key, date_stamp, region, service) hashed_request = hashlib.sha256(canonical_request.encode('utf-8')).hexdigest() string_to_sign = f"AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{hashed_request}" calculated_signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() - - if not hmac.compare_digest(calculated_signature, signature): - raise IamError("SignatureDoesNotMatch") + if not hmac.compare_digest(calculated_signature, signature): + raise IamError("SignatureDoesNotMatch") session_token = req.args.get("X-Amz-Security-Token") if session_token: diff --git a/myfsio_core/Cargo.toml b/myfsio_core/Cargo.toml index 5c9fa0e..2bff9cc 100644 --- a/myfsio_core/Cargo.toml +++ b/myfsio_core/Cargo.toml @@ -18,3 +18,4 @@ serde_json = "1" regex = "1" lru = "0.14" parking_lot = "0.12" +percent-encoding = "2" diff --git a/myfsio_core/src/lib.rs b/myfsio_core/src/lib.rs index 1321077..fc1b9f3 100644 --- a/myfsio_core/src/lib.rs +++ b/myfsio_core/src/lib.rs @@ -11,6 +11,7 @@ mod myfsio_core { #[pymodule_init] fn init(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sigv4::verify_sigv4_signature, m)?)?; m.add_function(wrap_pyfunction!(sigv4::derive_signing_key, m)?)?; m.add_function(wrap_pyfunction!(sigv4::compute_signature, m)?)?; m.add_function(wrap_pyfunction!(sigv4::build_string_to_sign, m)?)?; diff --git a/myfsio_core/src/sigv4.rs b/myfsio_core/src/sigv4.rs index d52ca78..904a853 100644 --- a/myfsio_core/src/sigv4.rs +++ b/myfsio_core/src/sigv4.rs @@ -1,6 +1,7 @@ use hmac::{Hmac, Mac}; use lru::LruCache; use parking_lot::Mutex; +use percent_encoding::{percent_encode, AsciiSet, NON_ALPHANUMERIC}; use pyo3::prelude::*; use sha2::{Digest, Sha256}; use std::num::NonZeroUsize; @@ -19,14 +20,29 @@ static SIGNING_KEY_CACHE: LazyLock Vec { let mut mac = HmacSha256::new_from_slice(key).expect("HMAC key length is always valid"); mac.update(msg); mac.finalize().into_bytes().to_vec() } -#[pyfunction] -pub fn derive_signing_key( +fn sha256_hex(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(data); + hex::encode(hasher.finalize()) +} + +fn aws_uri_encode(input: &str) -> String { + percent_encode(input.as_bytes(), AWS_ENCODE_SET).to_string() +} + +fn derive_signing_key_cached( secret_key: &str, date_stamp: &str, region: &str, @@ -68,18 +84,91 @@ pub fn derive_signing_key( k_signing } +fn constant_time_compare_inner(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + let mut result: u8 = 0; + for (x, y) in a.iter().zip(b.iter()) { + result |= x ^ y; + } + result == 0 +} + +#[pyfunction] +pub fn verify_sigv4_signature( + method: &str, + canonical_uri: &str, + query_params: Vec<(String, String)>, + signed_headers_str: &str, + header_values: Vec<(String, String)>, + payload_hash: &str, + amz_date: &str, + date_stamp: &str, + region: &str, + service: &str, + secret_key: &str, + provided_signature: &str, +) -> bool { + let mut sorted_params = query_params; + sorted_params.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1))); + + let canonical_query_string = sorted_params + .iter() + .map(|(k, v)| format!("{}={}", aws_uri_encode(k), aws_uri_encode(v))) + .collect::>() + .join("&"); + + let mut canonical_headers = String::new(); + for (name, value) in &header_values { + let lower_name = name.to_lowercase(); + let normalized = value.split_whitespace().collect::>().join(" "); + let final_value = if lower_name == "expect" && normalized.is_empty() { + "100-continue" + } else { + &normalized + }; + canonical_headers.push_str(&lower_name); + canonical_headers.push(':'); + canonical_headers.push_str(final_value); + canonical_headers.push('\n'); + } + + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + method, canonical_uri, canonical_query_string, canonical_headers, signed_headers_str, payload_hash + ); + + let credential_scope = format!("{}/{}/{}/aws4_request", date_stamp, region, service); + let cr_hash = sha256_hex(canonical_request.as_bytes()); + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{}\n{}\n{}", + amz_date, credential_scope, cr_hash + ); + + let signing_key = derive_signing_key_cached(secret_key, date_stamp, region, service); + let calculated = hmac_sha256(&signing_key, string_to_sign.as_bytes()); + let calculated_hex = hex::encode(&calculated); + + constant_time_compare_inner(calculated_hex.as_bytes(), provided_signature.as_bytes()) +} + +#[pyfunction] +pub fn derive_signing_key( + secret_key: &str, + date_stamp: &str, + region: &str, + service: &str, +) -> Vec { + derive_signing_key_cached(secret_key, date_stamp, region, service) +} + #[pyfunction] pub fn compute_signature(signing_key: &[u8], string_to_sign: &str) -> String { let sig = hmac_sha256(signing_key, string_to_sign.as_bytes()); hex::encode(sig) } -fn sha256_hex(data: &[u8]) -> String { - let mut hasher = Sha256::new(); - hasher.update(data); - hex::encode(hasher.finalize()) -} - #[pyfunction] pub fn build_string_to_sign( amz_date: &str, @@ -87,19 +176,15 @@ pub fn build_string_to_sign( canonical_request: &str, ) -> String { let cr_hash = sha256_hex(canonical_request.as_bytes()); - format!("AWS4-HMAC-SHA256\n{}\n{}\n{}", amz_date, credential_scope, cr_hash) + format!( + "AWS4-HMAC-SHA256\n{}\n{}\n{}", + amz_date, credential_scope, cr_hash + ) } #[pyfunction] pub fn constant_time_compare(a: &str, b: &str) -> bool { - if a.len() != b.len() { - return false; - } - let mut result: u8 = 0; - for (x, y) in a.bytes().zip(b.bytes()) { - result |= x ^ y; - } - result == 0 + constant_time_compare_inner(a.as_bytes(), b.as_bytes()) } #[pyfunction]