diff --git a/app/encryption.py b/app/encryption.py index 6d8c2b2..cb40199 100644 --- a/app/encryption.py +++ b/app/encryption.py @@ -19,6 +19,13 @@ from cryptography.hazmat.primitives import hashes if sys.platform != "win32": import fcntl +try: + import myfsio_core as _rc + _HAS_RUST = True +except ImportError: + _rc = None + _HAS_RUST = False + logger = logging.getLogger(__name__) @@ -338,6 +345,69 @@ class StreamingEncryptor: output.seek(0) return output + def encrypt_file(self, input_path: str, output_path: str) -> EncryptionMetadata: + data_key, encrypted_data_key = self.provider.generate_data_key() + base_nonce = secrets.token_bytes(12) + + if _HAS_RUST: + _rc.encrypt_stream_chunked( + input_path, output_path, data_key, base_nonce, self.chunk_size + ) + else: + with open(input_path, "rb") as stream: + aesgcm = AESGCM(data_key) + with open(output_path, "wb") as out: + out.write(b"\x00\x00\x00\x00") + chunk_index = 0 + while True: + chunk = stream.read(self.chunk_size) + if not chunk: + break + chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index) + encrypted_chunk = aesgcm.encrypt(chunk_nonce, chunk, None) + out.write(len(encrypted_chunk).to_bytes(self.HEADER_SIZE, "big")) + out.write(encrypted_chunk) + chunk_index += 1 + out.seek(0) + out.write(chunk_index.to_bytes(4, "big")) + + return EncryptionMetadata( + algorithm="AES256", + key_id=self.provider.KEY_ID if hasattr(self.provider, "KEY_ID") else "local", + nonce=base_nonce, + encrypted_data_key=encrypted_data_key, + ) + + def decrypt_file(self, input_path: str, output_path: str, + metadata: EncryptionMetadata) -> None: + data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id) + base_nonce = metadata.nonce + + if _HAS_RUST: + _rc.decrypt_stream_chunked(input_path, output_path, data_key, base_nonce) + else: + with open(input_path, "rb") as stream: + chunk_count_bytes = stream.read(4) + if len(chunk_count_bytes) < 4: + raise EncryptionError("Invalid encrypted stream: 
missing header") + chunk_count = int.from_bytes(chunk_count_bytes, "big") + aesgcm = AESGCM(data_key) + with open(output_path, "wb") as out: + for chunk_index in range(chunk_count): + size_bytes = stream.read(self.HEADER_SIZE) + if len(size_bytes) < self.HEADER_SIZE: + raise EncryptionError(f"Invalid encrypted stream: truncated at chunk {chunk_index}") + chunk_size = int.from_bytes(size_bytes, "big") + encrypted_chunk = stream.read(chunk_size) + if len(encrypted_chunk) < chunk_size: + raise EncryptionError(f"Invalid encrypted stream: incomplete chunk {chunk_index}") + chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index) + try: + decrypted_chunk = aesgcm.decrypt(chunk_nonce, encrypted_chunk, None) + out.write(decrypted_chunk) + except Exception as exc: + raise EncryptionError(f"Failed to decrypt chunk {chunk_index}: {exc}") from exc + class EncryptionManager: """Manages encryption providers and operations.""" diff --git a/app/storage.py b/app/storage.py index a582ad9..7296abb 100644 --- a/app/storage.py +++ b/app/storage.py @@ -841,32 +841,61 @@ class ObjectStorage: tmp_dir = self._system_root_path() / self.SYSTEM_TMP_DIR tmp_dir.mkdir(parents=True, exist_ok=True) - tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp" - - try: - with tmp_path.open("wb") as target: - shutil.copyfileobj(stream, target) - new_size = tmp_path.stat().st_size - size_delta = new_size - existing_size - object_delta = 0 if is_overwrite else 1 - - if enforce_quota: - quota_check = self.check_quota( - bucket_name, - additional_bytes=max(0, size_delta), - additional_objects=object_delta, + if _HAS_RUST: + tmp_path = None + try: + tmp_path_str, etag, new_size = _rc.stream_to_file_with_md5( + stream, str(tmp_dir) ) - if not quota_check["allowed"]: - raise QuotaExceededError( - quota_check["message"] or "Quota exceeded", - quota_check["quota"], - quota_check["usage"], - ) + tmp_path = Path(tmp_path_str) + + size_delta = new_size - existing_size + object_delta = 0 if is_overwrite else 1 + + if 
enforce_quota: + quota_check = self.check_quota( + bucket_name, + additional_bytes=max(0, size_delta), + additional_objects=object_delta, + ) + if not quota_check["allowed"]: + raise QuotaExceededError( + quota_check["message"] or "Quota exceeded", + quota_check["quota"], + quota_check["usage"], + ) + + shutil.move(str(tmp_path), str(destination)) + finally: + if tmp_path: + try: + tmp_path.unlink(missing_ok=True) + except OSError: + pass + else: + tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp" + try: + with tmp_path.open("wb") as target: + shutil.copyfileobj(stream, target) + + new_size = tmp_path.stat().st_size + size_delta = new_size - existing_size + object_delta = 0 if is_overwrite else 1 + + if enforce_quota: + quota_check = self.check_quota( + bucket_name, + additional_bytes=max(0, size_delta), + additional_objects=object_delta, + ) + if not quota_check["allowed"]: + raise QuotaExceededError( + quota_check["message"] or "Quota exceeded", + quota_check["quota"], + quota_check["usage"], + ) - if _HAS_RUST: - etag = _rc.md5_file(str(tmp_path)) - else: checksum = hashlib.md5() with tmp_path.open("rb") as f: while True: @@ -876,13 +905,12 @@ class ObjectStorage: checksum.update(chunk) etag = checksum.hexdigest() - shutil.move(str(tmp_path), str(destination)) - - finally: - try: - tmp_path.unlink(missing_ok=True) - except OSError: - pass + shutil.move(str(tmp_path), str(destination)) + finally: + try: + tmp_path.unlink(missing_ok=True) + except OSError: + pass stat = destination.stat() @@ -1702,19 +1730,29 @@ class ObjectStorage: if versioning_enabled and destination.exists(): archived_version_size = destination.stat().st_size self._archive_current_version(bucket_id, safe_key, reason="overwrite") - checksum = hashlib.md5() - with destination.open("wb") as target: + if _HAS_RUST: + part_paths = [] for _, record in validated: - part_path = upload_root / record["filename"] - if not part_path.exists(): + pp = upload_root / record["filename"] + if not pp.exists(): 
raise StorageError(f"Missing part file {record['filename']}") - with part_path.open("rb") as chunk: - while True: - data = chunk.read(1024 * 1024) - if not data: - break - checksum.update(data) - target.write(data) + part_paths.append(str(pp)) + checksum_hex = _rc.assemble_parts_with_md5(part_paths, str(destination)) + else: + checksum = hashlib.md5() + with destination.open("wb") as target: + for _, record in validated: + part_path = upload_root / record["filename"] + if not part_path.exists(): + raise StorageError(f"Missing part file {record['filename']}") + with part_path.open("rb") as chunk: + while True: + data = chunk.read(1024 * 1024) + if not data: + break + checksum.update(data) + target.write(data) + checksum_hex = checksum.hexdigest() except BlockingIOError: raise StorageError("Another upload to this key is in progress") @@ -1729,7 +1767,7 @@ class ObjectStorage: ) stat = destination.stat() - etag = checksum.hexdigest() + etag = checksum_hex metadata = manifest.get("metadata") internal_meta = {"__etag__": etag, "__size__": str(stat.st_size)} diff --git a/myfsio_core/Cargo.toml b/myfsio_core/Cargo.toml index 2bff9cc..6f900e0 100644 --- a/myfsio_core/Cargo.toml +++ b/myfsio_core/Cargo.toml @@ -19,3 +19,6 @@ regex = "1" lru = "0.14" parking_lot = "0.12" percent-encoding = "2" +aes-gcm = "0.10" +hkdf = "0.12" +uuid = { version = "1", features = ["v4"] } diff --git a/myfsio_core/src/crypto.rs b/myfsio_core/src/crypto.rs new file mode 100644 index 0000000..082814d --- /dev/null +++ b/myfsio_core/src/crypto.rs @@ -0,0 +1,192 @@ +use aes_gcm::aead::Aead; +use aes_gcm::{Aes256Gcm, KeyInit, Nonce}; +use hkdf::Hkdf; +use pyo3::exceptions::{PyIOError, PyValueError}; +use pyo3::prelude::*; +use sha2::Sha256; +use std::fs::File; +use std::io::{Read, Seek, SeekFrom, Write}; + +const DEFAULT_CHUNK_SIZE: usize = 65536; +const HEADER_SIZE: usize = 4; + +fn read_exact_chunk(reader: &mut impl Read, buf: &mut [u8]) -> std::io::Result<usize> { + let mut filled = 0; + while filled < 
buf.len() { + match reader.read(&mut buf[filled..]) { + Ok(0) => break, + Ok(n) => filled += n, + Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue, + Err(e) => return Err(e), + } + } + Ok(filled) +} + +fn derive_chunk_nonce(base_nonce: &[u8], chunk_index: u32) -> Result<[u8; 12], String> { + let hkdf = Hkdf::<Sha256>::new(Some(base_nonce), b"chunk_nonce"); + let mut okm = [0u8; 12]; + hkdf.expand(&chunk_index.to_be_bytes(), &mut okm) + .map_err(|e| format!("HKDF expand failed: {}", e))?; + Ok(okm) +} + +#[pyfunction] +#[pyo3(signature = (input_path, output_path, key, base_nonce, chunk_size=DEFAULT_CHUNK_SIZE))] +pub fn encrypt_stream_chunked( + py: Python<'_>, + input_path: &str, + output_path: &str, + key: &[u8], + base_nonce: &[u8], + chunk_size: usize, +) -> PyResult<u32> { + if key.len() != 32 { + return Err(PyValueError::new_err(format!( + "Key must be 32 bytes, got {}", + key.len() + ))); + } + if base_nonce.len() != 12 { + return Err(PyValueError::new_err(format!( + "Base nonce must be 12 bytes, got {}", + base_nonce.len() + ))); + } + + let chunk_size = if chunk_size == 0 { + DEFAULT_CHUNK_SIZE + } else { + chunk_size + }; + + let inp = input_path.to_owned(); + let out = output_path.to_owned(); + let key_arr: [u8; 32] = key.try_into().unwrap(); + let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap(); + + py.detach(move || { + let cipher = Aes256Gcm::new(&key_arr.into()); + + let mut infile = File::open(&inp) + .map_err(|e| PyIOError::new_err(format!("Failed to open input: {}", e)))?; + let mut outfile = File::create(&out) + .map_err(|e| PyIOError::new_err(format!("Failed to create output: {}", e)))?; + + outfile + .write_all(&[0u8; 4]) + .map_err(|e| PyIOError::new_err(format!("Failed to write header: {}", e)))?; + + let mut buf = vec![0u8; chunk_size]; + let mut chunk_index: u32 = 0; + + loop { + let n = read_exact_chunk(&mut infile, &mut buf) + .map_err(|e| PyIOError::new_err(format!("Failed to read: {}", e)))?; + if n == 0 { + break; + } + + 
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index) + .map_err(|e| PyValueError::new_err(e))?; + let nonce = Nonce::from_slice(&nonce_bytes); + + let encrypted = cipher + .encrypt(nonce, &buf[..n]) + .map_err(|e| PyValueError::new_err(format!("Encrypt failed: {}", e)))?; + + let size = encrypted.len() as u32; + outfile + .write_all(&size.to_be_bytes()) + .map_err(|e| PyIOError::new_err(format!("Failed to write chunk size: {}", e)))?; + outfile + .write_all(&encrypted) + .map_err(|e| PyIOError::new_err(format!("Failed to write chunk: {}", e)))?; + + chunk_index += 1; + } + + outfile + .seek(SeekFrom::Start(0)) + .map_err(|e| PyIOError::new_err(format!("Failed to seek: {}", e)))?; + outfile + .write_all(&chunk_index.to_be_bytes()) + .map_err(|e| PyIOError::new_err(format!("Failed to write chunk count: {}", e)))?; + + Ok(chunk_index) + }) +} + +#[pyfunction] +pub fn decrypt_stream_chunked( + py: Python<'_>, + input_path: &str, + output_path: &str, + key: &[u8], + base_nonce: &[u8], +) -> PyResult<u32> { + if key.len() != 32 { + return Err(PyValueError::new_err(format!( + "Key must be 32 bytes, got {}", + key.len() + ))); + } + if base_nonce.len() != 12 { + return Err(PyValueError::new_err(format!( + "Base nonce must be 12 bytes, got {}", + base_nonce.len() + ))); + } + + let inp = input_path.to_owned(); + let out = output_path.to_owned(); + let key_arr: [u8; 32] = key.try_into().unwrap(); + let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap(); + + py.detach(move || { + let cipher = Aes256Gcm::new(&key_arr.into()); + + let mut infile = File::open(&inp) + .map_err(|e| PyIOError::new_err(format!("Failed to open input: {}", e)))?; + let mut outfile = File::create(&out) + .map_err(|e| PyIOError::new_err(format!("Failed to create output: {}", e)))?; + + let mut header = [0u8; HEADER_SIZE]; + infile + .read_exact(&mut header) + .map_err(|e| PyIOError::new_err(format!("Failed to read header: {}", e)))?; + let chunk_count = u32::from_be_bytes(header); + + let mut 
size_buf = [0u8; HEADER_SIZE]; + for chunk_index in 0..chunk_count { + infile + .read_exact(&mut size_buf) + .map_err(|e| { + PyIOError::new_err(format!( + "Failed to read chunk {} size: {}", + chunk_index, e + )) + })?; + let chunk_size = u32::from_be_bytes(size_buf) as usize; + + let mut encrypted = vec![0u8; chunk_size]; + infile.read_exact(&mut encrypted).map_err(|e| { + PyIOError::new_err(format!("Failed to read chunk {}: {}", chunk_index, e)) + })?; + + let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index) + .map_err(|e| PyValueError::new_err(e))?; + let nonce = Nonce::from_slice(&nonce_bytes); + + let decrypted = cipher.decrypt(nonce, encrypted.as_ref()).map_err(|e| { + PyValueError::new_err(format!("Decrypt chunk {} failed: {}", chunk_index, e)) + })?; + + outfile.write_all(&decrypted).map_err(|e| { + PyIOError::new_err(format!("Failed to write chunk {}: {}", chunk_index, e)) + })?; + } + + Ok(chunk_count) + }) +} diff --git a/myfsio_core/src/lib.rs b/myfsio_core/src/lib.rs index 3ff04f9..f10dde3 100644 --- a/myfsio_core/src/lib.rs +++ b/myfsio_core/src/lib.rs @@ -1,7 +1,9 @@ +mod crypto; mod hashing; mod metadata; mod sigv4; mod storage; +mod streaming; mod validation; use pyo3::prelude::*; @@ -38,6 +40,12 @@ mod myfsio_core { m.add_function(wrap_pyfunction!(storage::search_objects_scan, m)?)?; m.add_function(wrap_pyfunction!(storage::build_object_cache, m)?)?; + m.add_function(wrap_pyfunction!(streaming::stream_to_file_with_md5, m)?)?; + m.add_function(wrap_pyfunction!(streaming::assemble_parts_with_md5, m)?)?; + + m.add_function(wrap_pyfunction!(crypto::encrypt_stream_chunked, m)?)?; + m.add_function(wrap_pyfunction!(crypto::decrypt_stream_chunked, m)?)?; + Ok(()) } } diff --git a/myfsio_core/src/streaming.rs b/myfsio_core/src/streaming.rs new file mode 100644 index 0000000..1ff13f6 --- /dev/null +++ b/myfsio_core/src/streaming.rs @@ -0,0 +1,107 @@ +use md5::{Digest, Md5}; +use pyo3::exceptions::{PyIOError, PyValueError}; +use pyo3::prelude::*; 
+use std::fs::{self, File}; +use std::io::{Read, Write}; +use uuid::Uuid; + +const DEFAULT_CHUNK_SIZE: usize = 262144; + +#[pyfunction] +#[pyo3(signature = (stream, tmp_dir, chunk_size=DEFAULT_CHUNK_SIZE))] +pub fn stream_to_file_with_md5( + py: Python<'_>, + stream: &Bound<'_, PyAny>, + tmp_dir: &str, + chunk_size: usize, +) -> PyResult<(String, String, u64)> { + let chunk_size = if chunk_size == 0 { + DEFAULT_CHUNK_SIZE + } else { + chunk_size + }; + + fs::create_dir_all(tmp_dir) + .map_err(|e| PyIOError::new_err(format!("Failed to create tmp dir: {}", e)))?; + + let tmp_name = format!("{}.tmp", Uuid::new_v4().as_hyphenated()); + let tmp_path_buf = std::path::PathBuf::from(tmp_dir).join(&tmp_name); + let tmp_path = tmp_path_buf.to_string_lossy().into_owned(); + + let mut file = File::create(&tmp_path) + .map_err(|e| PyIOError::new_err(format!("Failed to create temp file: {}", e)))?; + let mut hasher = Md5::new(); + let mut total_bytes: u64 = 0; + + let result: PyResult<()> = (|| { + loop { + let chunk: Vec<u8> = stream.call_method1("read", (chunk_size,))?.extract()?; + if chunk.is_empty() { + break; + } + hasher.update(&chunk); + file.write_all(&chunk) + .map_err(|e| PyIOError::new_err(format!("Failed to write: {}", e)))?; + total_bytes += chunk.len() as u64; + + py.check_signals()?; + } + Ok(()) + })(); + + if let Err(e) = result { + drop(file); + let _ = fs::remove_file(&tmp_path); + return Err(e); + } + + drop(file); + + let md5_hex = format!("{:x}", hasher.finalize()); + Ok((tmp_path, md5_hex, total_bytes)) +} + +#[pyfunction] +pub fn assemble_parts_with_md5( + py: Python<'_>, + part_paths: Vec<String>, + dest_path: &str, +) -> PyResult<String> { + if part_paths.is_empty() { + return Err(PyValueError::new_err("No parts to assemble")); + } + + let dest = dest_path.to_owned(); + let parts = part_paths; + + py.detach(move || { + if let Some(parent) = std::path::Path::new(&dest).parent() { + fs::create_dir_all(parent) + .map_err(|e| PyIOError::new_err(format!("Failed to create dest 
dir: {}", e)))?; + } + + let mut target = File::create(&dest) + .map_err(|e| PyIOError::new_err(format!("Failed to create dest file: {}", e)))?; + let mut hasher = Md5::new(); + let mut buf = vec![0u8; 1024 * 1024]; + + for part_path in &parts { + let mut part = File::open(part_path) + .map_err(|e| PyIOError::new_err(format!("Failed to open part {}: {}", part_path, e)))?; + loop { + let n = part + .read(&mut buf) + .map_err(|e| PyIOError::new_err(format!("Failed to read part: {}", e)))?; + if n == 0 { + break; + } + hasher.update(&buf[..n]); + target + .write_all(&buf[..n]) + .map_err(|e| PyIOError::new_err(format!("Failed to write: {}", e)))?; + } + } + + Ok(format!("{:x}", hasher.finalize())) + }) +} diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 47181a2..c4ee8c2 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -53,15 +53,17 @@ def test_special_characters_in_metadata(tmp_path: Path): assert meta["special"] == "!@#$%^&*()" def test_disk_full_scenario(tmp_path: Path, monkeypatch): - # Simulate disk full by mocking write to fail + import app.storage as _storage_mod + monkeypatch.setattr(_storage_mod, "_HAS_RUST", False) + storage = ObjectStorage(tmp_path) storage.create_bucket("full") - + def mock_copyfileobj(*args, **kwargs): raise OSError(28, "No space left on device") - + import shutil monkeypatch.setattr(shutil, "copyfileobj", mock_copyfileobj) - + with pytest.raises(OSError, match="No space left on device"): storage.put_object("full", "file", io.BytesIO(b"data")) diff --git a/tests/test_rust_extensions.py b/tests/test_rust_extensions.py new file mode 100644 index 0000000..31280cb --- /dev/null +++ b/tests/test_rust_extensions.py @@ -0,0 +1,350 @@ +import hashlib +import io +import os +import secrets +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +try: + import myfsio_core as _rc + HAS_RUST = True +except ImportError: + _rc = None + HAS_RUST = 
False + +pytestmark = pytest.mark.skipif(not HAS_RUST, reason="myfsio_core not available") + + +class TestStreamToFileWithMd5: + def test_basic_write(self, tmp_path): + data = b"hello world" * 1000 + stream = io.BytesIO(data) + tmp_dir = str(tmp_path / "tmp") + + tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(stream, tmp_dir) + + assert size == len(data) + assert md5_hex == hashlib.md5(data).hexdigest() + assert Path(tmp_path_str).exists() + assert Path(tmp_path_str).read_bytes() == data + + def test_empty_stream(self, tmp_path): + stream = io.BytesIO(b"") + tmp_dir = str(tmp_path / "tmp") + + tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(stream, tmp_dir) + + assert size == 0 + assert md5_hex == hashlib.md5(b"").hexdigest() + assert Path(tmp_path_str).read_bytes() == b"" + + def test_large_data(self, tmp_path): + data = os.urandom(1024 * 1024 * 2) + stream = io.BytesIO(data) + tmp_dir = str(tmp_path / "tmp") + + tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5(stream, tmp_dir) + + assert size == len(data) + assert md5_hex == hashlib.md5(data).hexdigest() + + def test_custom_chunk_size(self, tmp_path): + data = b"x" * 10000 + stream = io.BytesIO(data) + tmp_dir = str(tmp_path / "tmp") + + tmp_path_str, md5_hex, size = _rc.stream_to_file_with_md5( + stream, tmp_dir, chunk_size=128 + ) + + assert size == len(data) + assert md5_hex == hashlib.md5(data).hexdigest() + + +class TestAssemblePartsWithMd5: + def test_basic_assembly(self, tmp_path): + parts = [] + combined = b"" + for i in range(3): + data = f"part{i}data".encode() * 100 + combined += data + p = tmp_path / f"part{i}" + p.write_bytes(data) + parts.append(str(p)) + + dest = str(tmp_path / "output") + md5_hex = _rc.assemble_parts_with_md5(parts, dest) + + assert md5_hex == hashlib.md5(combined).hexdigest() + assert Path(dest).read_bytes() == combined + + def test_single_part(self, tmp_path): + data = b"single part data" + p = tmp_path / "part0" + p.write_bytes(data) + + dest = 
str(tmp_path / "output") + md5_hex = _rc.assemble_parts_with_md5([str(p)], dest) + + assert md5_hex == hashlib.md5(data).hexdigest() + assert Path(dest).read_bytes() == data + + def test_empty_parts_list(self): + with pytest.raises(ValueError, match="No parts"): + _rc.assemble_parts_with_md5([], "dummy") + + def test_missing_part_file(self, tmp_path): + with pytest.raises(OSError): + _rc.assemble_parts_with_md5( + [str(tmp_path / "nonexistent")], str(tmp_path / "out") + ) + + def test_large_parts(self, tmp_path): + parts = [] + combined = b"" + for i in range(5): + data = os.urandom(512 * 1024) + combined += data + p = tmp_path / f"part{i}" + p.write_bytes(data) + parts.append(str(p)) + + dest = str(tmp_path / "output") + md5_hex = _rc.assemble_parts_with_md5(parts, dest) + + assert md5_hex == hashlib.md5(combined).hexdigest() + assert Path(dest).read_bytes() == combined + + +class TestEncryptDecryptStreamChunked: + def _python_derive_chunk_nonce(self, base_nonce, chunk_index): + from cryptography.hazmat.primitives.kdf.hkdf import HKDF + from cryptography.hazmat.primitives import hashes + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=12, + salt=base_nonce, + info=chunk_index.to_bytes(4, "big"), + ) + return hkdf.derive(b"chunk_nonce") + + def test_encrypt_decrypt_roundtrip(self, tmp_path): + data = b"Hello, encryption!" 
* 500 + key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + + input_path = str(tmp_path / "plaintext") + encrypted_path = str(tmp_path / "encrypted") + decrypted_path = str(tmp_path / "decrypted") + + Path(input_path).write_bytes(data) + + chunk_count = _rc.encrypt_stream_chunked( + input_path, encrypted_path, key, base_nonce + ) + assert chunk_count > 0 + + chunk_count_dec = _rc.decrypt_stream_chunked( + encrypted_path, decrypted_path, key, base_nonce + ) + assert chunk_count_dec == chunk_count + assert Path(decrypted_path).read_bytes() == data + + def test_empty_file(self, tmp_path): + key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + + input_path = str(tmp_path / "empty") + encrypted_path = str(tmp_path / "encrypted") + decrypted_path = str(tmp_path / "decrypted") + + Path(input_path).write_bytes(b"") + + chunk_count = _rc.encrypt_stream_chunked( + input_path, encrypted_path, key, base_nonce + ) + assert chunk_count == 0 + + chunk_count_dec = _rc.decrypt_stream_chunked( + encrypted_path, decrypted_path, key, base_nonce + ) + assert chunk_count_dec == 0 + assert Path(decrypted_path).read_bytes() == b"" + + def test_custom_chunk_size(self, tmp_path): + data = os.urandom(10000) + key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + + input_path = str(tmp_path / "plaintext") + encrypted_path = str(tmp_path / "encrypted") + decrypted_path = str(tmp_path / "decrypted") + + Path(input_path).write_bytes(data) + + chunk_count = _rc.encrypt_stream_chunked( + input_path, encrypted_path, key, base_nonce, chunk_size=1024 + ) + assert chunk_count == 10 + + _rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce) + assert Path(decrypted_path).read_bytes() == data + + def test_invalid_key_length(self, tmp_path): + input_path = str(tmp_path / "in") + Path(input_path).write_bytes(b"data") + + with pytest.raises(ValueError, match="32 bytes"): + _rc.encrypt_stream_chunked( + input_path, str(tmp_path / 
"out"), b"short", secrets.token_bytes(12) + ) + + def test_invalid_nonce_length(self, tmp_path): + input_path = str(tmp_path / "in") + Path(input_path).write_bytes(b"data") + + with pytest.raises(ValueError, match="12 bytes"): + _rc.encrypt_stream_chunked( + input_path, str(tmp_path / "out"), secrets.token_bytes(32), b"short" + ) + + def test_wrong_key_fails_decrypt(self, tmp_path): + data = b"sensitive data" + key = secrets.token_bytes(32) + wrong_key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + + input_path = str(tmp_path / "plaintext") + encrypted_path = str(tmp_path / "encrypted") + decrypted_path = str(tmp_path / "decrypted") + + Path(input_path).write_bytes(data) + _rc.encrypt_stream_chunked(input_path, encrypted_path, key, base_nonce) + + with pytest.raises((ValueError, OSError)): + _rc.decrypt_stream_chunked( + encrypted_path, decrypted_path, wrong_key, base_nonce + ) + + def test_cross_compat_python_encrypt_rust_decrypt(self, tmp_path): + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + data = b"cross compat test data" * 100 + key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + chunk_size = 1024 + + encrypted_path = str(tmp_path / "py_encrypted") + with open(encrypted_path, "wb") as f: + f.write(b"\x00\x00\x00\x00") + aesgcm = AESGCM(key) + chunk_index = 0 + offset = 0 + while offset < len(data): + chunk = data[offset:offset + chunk_size] + nonce = self._python_derive_chunk_nonce(base_nonce, chunk_index) + enc = aesgcm.encrypt(nonce, chunk, None) + f.write(len(enc).to_bytes(4, "big")) + f.write(enc) + chunk_index += 1 + offset += chunk_size + f.seek(0) + f.write(chunk_index.to_bytes(4, "big")) + + decrypted_path = str(tmp_path / "rust_decrypted") + _rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce) + assert Path(decrypted_path).read_bytes() == data + + def test_cross_compat_rust_encrypt_python_decrypt(self, tmp_path): + from cryptography.hazmat.primitives.ciphers.aead 
import AESGCM + + data = b"cross compat reverse test" * 100 + key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + chunk_size = 1024 + + input_path = str(tmp_path / "plaintext") + encrypted_path = str(tmp_path / "rust_encrypted") + Path(input_path).write_bytes(data) + + chunk_count = _rc.encrypt_stream_chunked( + input_path, encrypted_path, key, base_nonce, chunk_size=chunk_size + ) + + aesgcm = AESGCM(key) + with open(encrypted_path, "rb") as f: + count_bytes = f.read(4) + assert int.from_bytes(count_bytes, "big") == chunk_count + + decrypted = b"" + for i in range(chunk_count): + size = int.from_bytes(f.read(4), "big") + enc_chunk = f.read(size) + nonce = self._python_derive_chunk_nonce(base_nonce, i) + decrypted += aesgcm.decrypt(nonce, enc_chunk, None) + + assert decrypted == data + + def test_large_file_roundtrip(self, tmp_path): + data = os.urandom(1024 * 1024) + key = secrets.token_bytes(32) + base_nonce = secrets.token_bytes(12) + + input_path = str(tmp_path / "large") + encrypted_path = str(tmp_path / "encrypted") + decrypted_path = str(tmp_path / "decrypted") + + Path(input_path).write_bytes(data) + + _rc.encrypt_stream_chunked(input_path, encrypted_path, key, base_nonce) + _rc.decrypt_stream_chunked(encrypted_path, decrypted_path, key, base_nonce) + + assert Path(decrypted_path).read_bytes() == data + + +class TestStreamingEncryptorFileMethods: + def test_encrypt_file_decrypt_file_roundtrip(self, tmp_path): + from app.encryption import LocalKeyEncryption, StreamingEncryptor + + master_key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(master_key_path) + encryptor = StreamingEncryptor(provider, chunk_size=512) + + data = b"file method test data" * 200 + input_path = str(tmp_path / "input") + encrypted_path = str(tmp_path / "encrypted") + decrypted_path = str(tmp_path / "decrypted") + + Path(input_path).write_bytes(data) + + metadata = encryptor.encrypt_file(input_path, encrypted_path) + assert metadata.algorithm == 
"AES256" + + encryptor.decrypt_file(encrypted_path, decrypted_path, metadata) + assert Path(decrypted_path).read_bytes() == data + + def test_encrypt_file_matches_encrypt_stream(self, tmp_path): + from app.encryption import LocalKeyEncryption, StreamingEncryptor + + master_key_path = tmp_path / "master.key" + provider = LocalKeyEncryption(master_key_path) + encryptor = StreamingEncryptor(provider, chunk_size=512) + + data = b"stream vs file comparison" * 100 + input_path = str(tmp_path / "input") + Path(input_path).write_bytes(data) + + file_encrypted_path = str(tmp_path / "file_enc") + metadata_file = encryptor.encrypt_file(input_path, file_encrypted_path) + + file_decrypted_path = str(tmp_path / "file_dec") + encryptor.decrypt_file(file_encrypted_path, file_decrypted_path, metadata_file) + assert Path(file_decrypted_path).read_bytes() == data + + stream_enc, metadata_stream = encryptor.encrypt_stream(io.BytesIO(data)) + stream_dec = encryptor.decrypt_stream(stream_enc, metadata_stream) + assert stream_dec.read() == data