Files
MyFSIO/crates/myfsio-crypto/src/aes_gcm.rs

546 lines
17 KiB
Rust

use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use hkdf::Hkdf;
use sha2::Sha256;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
use thiserror::Error;
const DEFAULT_CHUNK_SIZE: usize = 65536;
const HEADER_SIZE: usize = 4;
#[derive(Debug, Error)]
pub enum CryptoError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Invalid key size: expected 32 bytes, got {0}")]
InvalidKeySize(usize),
#[error("Invalid nonce size: expected 12 bytes, got {0}")]
InvalidNonceSize(usize),
#[error("Encryption failed: {0}")]
EncryptionFailed(String),
#[error("Decryption failed at chunk {0}")]
DecryptionFailed(u32),
#[error("HKDF expand failed: {0}")]
HkdfFailed(String),
}
fn read_exact_chunk(reader: &mut impl Read, buf: &mut [u8]) -> std::io::Result<usize> {
let mut filled = 0;
while filled < buf.len() {
match reader.read(&mut buf[filled..]) {
Ok(0) => break,
Ok(n) => filled += n,
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => return Err(e),
}
}
Ok(filled)
}
fn derive_chunk_nonce(base_nonce: &[u8], chunk_index: u32) -> Result<[u8; 12], CryptoError> {
let hkdf = Hkdf::<Sha256>::new(Some(base_nonce), b"chunk_nonce");
let mut okm = [0u8; 12];
hkdf.expand(&chunk_index.to_be_bytes(), &mut okm)
.map_err(|e| CryptoError::HkdfFailed(e.to_string()))?;
Ok(okm)
}
pub fn encrypt_stream_chunked(
input_path: &Path,
output_path: &Path,
key: &[u8],
base_nonce: &[u8],
chunk_size: Option<usize>,
) -> Result<u32, CryptoError> {
if key.len() != 32 {
return Err(CryptoError::InvalidKeySize(key.len()));
}
if base_nonce.len() != 12 {
return Err(CryptoError::InvalidNonceSize(base_nonce.len()));
}
let chunk_size = chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE);
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
let cipher = Aes256Gcm::new(&key_arr.into());
let mut infile = File::open(input_path)?;
let mut outfile = File::create(output_path)?;
outfile.write_all(&[0u8; 4])?;
let mut buf = vec![0u8; chunk_size];
let mut chunk_index: u32 = 0;
loop {
let n = read_exact_chunk(&mut infile, &mut buf)?;
if n == 0 {
break;
}
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)?;
let nonce = Nonce::from_slice(&nonce_bytes);
let encrypted = cipher
.encrypt(nonce, &buf[..n])
.map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
let size = encrypted.len() as u32;
outfile.write_all(&size.to_be_bytes())?;
outfile.write_all(&encrypted)?;
chunk_index += 1;
}
outfile.seek(SeekFrom::Start(0))?;
outfile.write_all(&chunk_index.to_be_bytes())?;
Ok(chunk_index)
}
pub fn decrypt_stream_chunked(
input_path: &Path,
output_path: &Path,
key: &[u8],
base_nonce: &[u8],
) -> Result<u32, CryptoError> {
if key.len() != 32 {
return Err(CryptoError::InvalidKeySize(key.len()));
}
if base_nonce.len() != 12 {
return Err(CryptoError::InvalidNonceSize(base_nonce.len()));
}
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
let cipher = Aes256Gcm::new(&key_arr.into());
let mut infile = File::open(input_path)?;
let mut outfile = File::create(output_path)?;
let mut header = [0u8; HEADER_SIZE];
infile.read_exact(&mut header)?;
let chunk_count = u32::from_be_bytes(header);
let mut size_buf = [0u8; HEADER_SIZE];
for chunk_index in 0..chunk_count {
infile.read_exact(&mut size_buf)?;
let chunk_size = u32::from_be_bytes(size_buf) as usize;
let mut encrypted = vec![0u8; chunk_size];
infile.read_exact(&mut encrypted)?;
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)?;
let nonce = Nonce::from_slice(&nonce_bytes);
let decrypted = cipher
.decrypt(nonce, encrypted.as_ref())
.map_err(|_| CryptoError::DecryptionFailed(chunk_index))?;
outfile.write_all(&decrypted)?;
}
Ok(chunk_count)
}
const GCM_TAG_LEN: usize = 16;
pub fn decrypt_stream_chunked_range(
input_path: &Path,
output_path: &Path,
key: &[u8],
base_nonce: &[u8],
chunk_plain_size: usize,
plaintext_size: u64,
plain_start: u64,
plain_end_inclusive: u64,
) -> Result<u64, CryptoError> {
if key.len() != 32 {
return Err(CryptoError::InvalidKeySize(key.len()));
}
if base_nonce.len() != 12 {
return Err(CryptoError::InvalidNonceSize(base_nonce.len()));
}
if chunk_plain_size == 0 {
return Err(CryptoError::EncryptionFailed(
"chunk_plain_size must be > 0".into(),
));
}
if plaintext_size == 0 {
let _ = File::create(output_path)?;
return Ok(0);
}
if plain_start > plain_end_inclusive || plain_end_inclusive >= plaintext_size {
return Err(CryptoError::EncryptionFailed(format!(
"range [{}, {}] invalid for plaintext size {}",
plain_start, plain_end_inclusive, plaintext_size
)));
}
let key_arr: [u8; 32] = key.try_into().unwrap();
let nonce_arr: [u8; 12] = base_nonce.try_into().unwrap();
let cipher = Aes256Gcm::new(&key_arr.into());
let n = chunk_plain_size as u64;
let first_chunk = (plain_start / n) as u32;
let last_chunk = (plain_end_inclusive / n) as u32;
let total_chunks = plaintext_size.div_ceil(n) as u32;
let final_chunk_plain = plaintext_size - (total_chunks as u64 - 1) * n;
let mut infile = File::open(input_path)?;
let mut header = [0u8; HEADER_SIZE];
infile.read_exact(&mut header)?;
let stored_chunk_count = u32::from_be_bytes(header);
if stored_chunk_count != total_chunks {
return Err(CryptoError::EncryptionFailed(format!(
"chunk count mismatch: header says {}, plaintext_size implies {}",
stored_chunk_count, total_chunks
)));
}
let mut outfile = File::create(output_path)?;
let stride = n + GCM_TAG_LEN as u64 + HEADER_SIZE as u64;
let first_offset = HEADER_SIZE as u64 + first_chunk as u64 * stride;
infile.seek(SeekFrom::Start(first_offset))?;
let mut size_buf = [0u8; HEADER_SIZE];
let mut bytes_written: u64 = 0;
for chunk_index in first_chunk..=last_chunk {
infile.read_exact(&mut size_buf)?;
let ct_len = u32::from_be_bytes(size_buf) as usize;
let expected_plain = if chunk_index + 1 == total_chunks {
final_chunk_plain as usize
} else {
chunk_plain_size
};
let expected_ct = expected_plain + GCM_TAG_LEN;
if ct_len != expected_ct {
return Err(CryptoError::EncryptionFailed(format!(
"chunk {} stored length {} != expected {} (corrupt file or chunk_size mismatch)",
chunk_index, ct_len, expected_ct
)));
}
let mut encrypted = vec![0u8; ct_len];
infile.read_exact(&mut encrypted)?;
let nonce_bytes = derive_chunk_nonce(&nonce_arr, chunk_index)?;
let nonce = Nonce::from_slice(&nonce_bytes);
let decrypted = cipher
.decrypt(nonce, encrypted.as_ref())
.map_err(|_| CryptoError::DecryptionFailed(chunk_index))?;
let chunk_plain_start = chunk_index as u64 * n;
let chunk_plain_end_exclusive = chunk_plain_start + decrypted.len() as u64;
let slice_start = plain_start.saturating_sub(chunk_plain_start) as usize;
let slice_end = (plain_end_inclusive + 1).min(chunk_plain_end_exclusive);
let slice_end_local = (slice_end - chunk_plain_start) as usize;
if slice_end_local > slice_start {
outfile.write_all(&decrypted[slice_start..slice_end_local])?;
bytes_written += (slice_end_local - slice_start) as u64;
}
}
Ok(bytes_written)
}
pub async fn encrypt_stream_chunked_async(
input_path: &Path,
output_path: &Path,
key: &[u8],
base_nonce: &[u8],
chunk_size: Option<usize>,
) -> Result<u32, CryptoError> {
let input_path = input_path.to_owned();
let output_path = output_path.to_owned();
let key = key.to_vec();
let base_nonce = base_nonce.to_vec();
tokio::task::spawn_blocking(move || {
encrypt_stream_chunked(&input_path, &output_path, &key, &base_nonce, chunk_size)
})
.await
.map_err(|e| CryptoError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?
}
pub async fn decrypt_stream_chunked_async(
input_path: &Path,
output_path: &Path,
key: &[u8],
base_nonce: &[u8],
) -> Result<u32, CryptoError> {
let input_path = input_path.to_owned();
let output_path = output_path.to_owned();
let key = key.to_vec();
let base_nonce = base_nonce.to_vec();
tokio::task::spawn_blocking(move || {
decrypt_stream_chunked(&input_path, &output_path, &key, &base_nonce)
})
.await
.map_err(|e| CryptoError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))?
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write as IoWrite;
#[test]
fn test_encrypt_decrypt_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let input = dir.path().join("input.bin");
let encrypted = dir.path().join("encrypted.bin");
let decrypted = dir.path().join("decrypted.bin");
let data = b"Hello, this is a test of AES-256-GCM chunked encryption!";
std::fs::File::create(&input)
.unwrap()
.write_all(data)
.unwrap();
let key = [0x42u8; 32];
let nonce = [0x01u8; 12];
let chunks = encrypt_stream_chunked(&input, &encrypted, &key, &nonce, Some(16)).unwrap();
assert!(chunks > 0);
let chunks2 = decrypt_stream_chunked(&encrypted, &decrypted, &key, &nonce).unwrap();
assert_eq!(chunks, chunks2);
let result = std::fs::read(&decrypted).unwrap();
assert_eq!(result, data);
}
#[test]
fn test_invalid_key_size() {
let dir = tempfile::tempdir().unwrap();
let input = dir.path().join("input.bin");
std::fs::File::create(&input)
.unwrap()
.write_all(b"test")
.unwrap();
let result = encrypt_stream_chunked(
&input,
&dir.path().join("out"),
&[0u8; 16],
&[0u8; 12],
None,
);
assert!(matches!(result, Err(CryptoError::InvalidKeySize(16))));
}
fn write_file(path: &Path, data: &[u8]) {
std::fs::File::create(path).unwrap().write_all(data).unwrap();
}
fn make_encrypted_file(
dir: &Path,
data: &[u8],
key: &[u8; 32],
nonce: &[u8; 12],
chunk: usize,
) -> std::path::PathBuf {
let input = dir.join("input.bin");
let encrypted = dir.join("encrypted.bin");
write_file(&input, data);
encrypt_stream_chunked(&input, &encrypted, key, nonce, Some(chunk)).unwrap();
encrypted
}
#[test]
fn test_range_within_single_chunk() {
let dir = tempfile::tempdir().unwrap();
let data: Vec<u8> = (0u8..=255).cycle().take(4096).collect();
let key = [0x33u8; 32];
let nonce = [0x07u8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 1024);
let out = dir.path().join("range.bin");
let n = decrypt_stream_chunked_range(
&encrypted,
&out,
&key,
&nonce,
1024,
data.len() as u64,
200,
399,
)
.unwrap();
assert_eq!(n, 200);
let got = std::fs::read(&out).unwrap();
assert_eq!(got, &data[200..400]);
}
#[test]
fn test_range_spanning_multiple_chunks() {
let dir = tempfile::tempdir().unwrap();
let data: Vec<u8> = (0..5000u32).map(|i| (i % 251) as u8).collect();
let key = [0x44u8; 32];
let nonce = [0x02u8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 512);
let out = dir.path().join("range.bin");
let n = decrypt_stream_chunked_range(
&encrypted,
&out,
&key,
&nonce,
512,
data.len() as u64,
100,
2999,
)
.unwrap();
assert_eq!(n, 2900);
let got = std::fs::read(&out).unwrap();
assert_eq!(got, &data[100..3000]);
}
#[test]
fn test_range_covers_final_partial_chunk() {
let dir = tempfile::tempdir().unwrap();
let data: Vec<u8> = (0..1300u32).map(|i| (i % 71) as u8).collect();
let key = [0x55u8; 32];
let nonce = [0x0au8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 512);
let out = dir.path().join("range.bin");
let n = decrypt_stream_chunked_range(
&encrypted,
&out,
&key,
&nonce,
512,
data.len() as u64,
900,
1299,
)
.unwrap();
assert_eq!(n, 400);
let got = std::fs::read(&out).unwrap();
assert_eq!(got, &data[900..1300]);
}
#[test]
fn test_range_full_object() {
let dir = tempfile::tempdir().unwrap();
let data: Vec<u8> = (0..2048u32).map(|i| (i % 13) as u8).collect();
let key = [0x11u8; 32];
let nonce = [0x33u8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 512);
let out = dir.path().join("range.bin");
let n = decrypt_stream_chunked_range(
&encrypted,
&out,
&key,
&nonce,
512,
data.len() as u64,
0,
data.len() as u64 - 1,
)
.unwrap();
assert_eq!(n, data.len() as u64);
let got = std::fs::read(&out).unwrap();
assert_eq!(got, data);
}
#[test]
fn test_range_wrong_key_fails() {
let dir = tempfile::tempdir().unwrap();
let data = b"range-auth-check".repeat(100);
let key = [0x66u8; 32];
let nonce = [0x09u8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 256);
let out = dir.path().join("range.bin");
let wrong = [0x67u8; 32];
let r = decrypt_stream_chunked_range(
&encrypted,
&out,
&wrong,
&nonce,
256,
data.len() as u64,
0,
data.len() as u64 - 1,
);
assert!(matches!(r, Err(CryptoError::DecryptionFailed(_))));
}
#[test]
fn test_range_out_of_bounds_rejected() {
let dir = tempfile::tempdir().unwrap();
let data = vec![0u8; 100];
let key = [0x22u8; 32];
let nonce = [0x44u8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 64);
let out = dir.path().join("range.bin");
let r = decrypt_stream_chunked_range(
&encrypted,
&out,
&key,
&nonce,
64,
data.len() as u64,
50,
200,
);
assert!(r.is_err());
}
#[test]
fn test_range_mismatched_chunk_size_detected() {
let dir = tempfile::tempdir().unwrap();
let data: Vec<u8> = (0..2048u32).map(|i| i as u8).collect();
let key = [0x77u8; 32];
let nonce = [0x88u8; 12];
let encrypted = make_encrypted_file(dir.path(), &data, &key, &nonce, 512);
let out = dir.path().join("range.bin");
let r = decrypt_stream_chunked_range(
&encrypted,
&out,
&key,
&nonce,
1024,
data.len() as u64,
0,
1023,
);
assert!(r.is_err());
}
#[test]
fn test_wrong_key_fails_decrypt() {
let dir = tempfile::tempdir().unwrap();
let input = dir.path().join("input.bin");
let encrypted = dir.path().join("encrypted.bin");
let decrypted = dir.path().join("decrypted.bin");
std::fs::File::create(&input)
.unwrap()
.write_all(b"secret data")
.unwrap();
let key = [0x42u8; 32];
let nonce = [0x01u8; 12];
encrypt_stream_chunked(&input, &encrypted, &key, &nonce, None).unwrap();
let wrong_key = [0x43u8; 32];
let result = decrypt_stream_chunked(&encrypted, &decrypted, &wrong_key, &nonce);
assert!(matches!(result, Err(CryptoError::DecryptionFailed(_))));
}
}