use aes_gcm::aead::Aead; use aes_gcm::{Aes256Gcm, KeyInit, Nonce}; use base64::engine::general_purpose::STANDARD as B64; use base64::Engine; use chrono::{DateTime, Utc}; use rand::RngCore; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::sync::RwLock; use crate::aes_gcm::CryptoError; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KmsKey { #[serde(rename = "KeyId")] pub key_id: String, #[serde(rename = "Arn")] pub arn: String, #[serde(rename = "Description")] pub description: String, #[serde(rename = "CreationDate")] pub creation_date: DateTime, #[serde(rename = "Enabled")] pub enabled: bool, #[serde(rename = "KeyState")] pub key_state: String, #[serde(rename = "KeyUsage")] pub key_usage: String, #[serde(rename = "KeySpec")] pub key_spec: String, #[serde(rename = "EncryptedKeyMaterial")] pub encrypted_key_material: String, } #[derive(Debug, Clone, Serialize, Deserialize)] struct KmsStore { keys: Vec, } pub struct KmsService { keys_path: PathBuf, master_key: Arc>, keys: Arc>>, } impl KmsService { pub async fn new(keys_dir: &Path) -> Result { std::fs::create_dir_all(keys_dir).map_err(CryptoError::Io)?; let keys_path = keys_dir.join("kms_keys.json"); let master_key = Self::load_or_create_master_key(&keys_dir.join("kms_master.key"))?; let keys = if keys_path.exists() { let data = std::fs::read_to_string(&keys_path).map_err(CryptoError::Io)?; let store: KmsStore = serde_json::from_str(&data) .map_err(|e| CryptoError::EncryptionFailed(format!("Bad KMS store: {}", e)))?; store.keys } else { Vec::new() }; Ok(Self { keys_path, master_key: Arc::new(RwLock::new(master_key)), keys: Arc::new(RwLock::new(keys)), }) } fn load_or_create_master_key(path: &Path) -> Result<[u8; 32], CryptoError> { if path.exists() { let encoded = std::fs::read_to_string(path).map_err(CryptoError::Io)?; let decoded = B64.decode(encoded.trim()).map_err(|e| { CryptoError::EncryptionFailed(format!("Bad master key encoding: {}", e)) })?; if decoded.len() != 32 { return Err(CryptoError::InvalidKeySize(decoded.len())); } let mut key = [0u8; 32]; key.copy_from_slice(&decoded); Ok(key) } else { let mut key = [0u8; 32]; rand::thread_rng().fill_bytes(&mut key); let encoded = B64.encode(key); std::fs::write(path, &encoded).map_err(CryptoError::Io)?; Ok(key) } } fn encrypt_key_material( master_key: &[u8; 32], plaintext_key: &[u8], ) -> Result { let cipher = Aes256Gcm::new(master_key.into()); let mut nonce_bytes = [0u8; 12]; rand::thread_rng().fill_bytes(&mut nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes); let ciphertext = cipher .encrypt(nonce, plaintext_key) .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?; let mut combined = Vec::with_capacity(12 + ciphertext.len()); combined.extend_from_slice(&nonce_bytes); combined.extend_from_slice(&ciphertext); Ok(B64.encode(&combined)) } fn decrypt_key_material( master_key: &[u8; 32], encrypted_b64: &str, ) -> Result, CryptoError> { let combined = B64.decode(encrypted_b64).map_err(|e| { CryptoError::EncryptionFailed(format!("Bad key material encoding: {}", e)) })?; if combined.len() < 12 { return Err(CryptoError::EncryptionFailed( "Encrypted key material too short".to_string(), )); } let (nonce_bytes, ciphertext) = combined.split_at(12); let cipher = Aes256Gcm::new(master_key.into()); let nonce = Nonce::from_slice(nonce_bytes); cipher .decrypt(nonce, ciphertext) .map_err(|_| CryptoError::DecryptionFailed(0)) } async fn save(&self) -> Result<(), CryptoError> { let keys = self.keys.read().await; let store = KmsStore { keys: keys.clone() }; let json = serde_json::to_string_pretty(&store) .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?; std::fs::write(&self.keys_path, json).map_err(CryptoError::Io)?; Ok(()) } pub async fn create_key(&self, description: &str) -> Result { let key_id = uuid::Uuid::new_v4().to_string(); let arn = format!("arn:aws:kms:local:000000000000:key/{}", key_id); let mut plaintext_key = [0u8; 32]; rand::thread_rng().fill_bytes(&mut plaintext_key); let master = self.master_key.read().await; let encrypted = Self::encrypt_key_material(&master, &plaintext_key)?; let kms_key = KmsKey { key_id: key_id.clone(), arn, description: description.to_string(), creation_date: Utc::now(), enabled: true, key_state: "Enabled".to_string(), key_usage: "ENCRYPT_DECRYPT".to_string(), key_spec: "SYMMETRIC_DEFAULT".to_string(), encrypted_key_material: encrypted, }; self.keys.write().await.push(kms_key.clone()); self.save().await?; Ok(kms_key) } pub async fn list_keys(&self) -> Vec { self.keys.read().await.clone() } pub async fn get_key(&self, key_id: &str) -> Option { let keys = self.keys.read().await; keys.iter() .find(|k| k.key_id == key_id || k.arn == key_id) .cloned() } pub async fn delete_key(&self, key_id: &str) -> Result { let mut keys = self.keys.write().await; let len_before = keys.len(); keys.retain(|k| k.key_id != key_id && k.arn != key_id); let removed = keys.len() < len_before; drop(keys); if removed { self.save().await?; } Ok(removed) } pub async fn enable_key(&self, key_id: &str) -> Result { let mut keys = self.keys.write().await; if let Some(key) = keys.iter_mut().find(|k| k.key_id == key_id) { key.enabled = true; key.key_state = "Enabled".to_string(); drop(keys); self.save().await?; Ok(true) } else { Ok(false) } } pub async fn disable_key(&self, key_id: &str) -> Result { let mut keys = self.keys.write().await; if let Some(key) = keys.iter_mut().find(|k| k.key_id == key_id) { key.enabled = false; key.key_state = "Disabled".to_string(); drop(keys); self.save().await?; Ok(true) } else { Ok(false) } } pub async fn decrypt_data_key(&self, key_id: &str) -> Result, CryptoError> { let keys = self.keys.read().await; let key = keys .iter() .find(|k| k.key_id == key_id || k.arn == key_id) .ok_or_else(|| CryptoError::EncryptionFailed("KMS key not found".to_string()))?; if !key.enabled { return Err(CryptoError::EncryptionFailed( "KMS key is disabled".to_string(), )); } let master = self.master_key.read().await; Self::decrypt_key_material(&master, &key.encrypted_key_material) } pub async fn encrypt_data( &self, key_id: &str, plaintext: &[u8], ) -> Result, CryptoError> { let data_key = self.decrypt_data_key(key_id).await?; if data_key.len() != 32 { return Err(CryptoError::InvalidKeySize(data_key.len())); } let key_arr: [u8; 32] = data_key.try_into().unwrap(); let cipher = Aes256Gcm::new(&key_arr.into()); let mut nonce_bytes = [0u8; 12]; rand::thread_rng().fill_bytes(&mut nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes); let ciphertext = cipher .encrypt(nonce, plaintext) .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?; let mut result = Vec::with_capacity(12 + ciphertext.len()); result.extend_from_slice(&nonce_bytes); result.extend_from_slice(&ciphertext); Ok(result) } pub async fn decrypt_data( &self, key_id: &str, ciphertext: &[u8], ) -> Result, CryptoError> { if ciphertext.len() < 12 { return Err(CryptoError::EncryptionFailed( "Ciphertext too short".to_string(), )); } let data_key = self.decrypt_data_key(key_id).await?; if data_key.len() != 32 { return Err(CryptoError::InvalidKeySize(data_key.len())); } let key_arr: [u8; 32] = data_key.try_into().unwrap(); let (nonce_bytes, ct) = ciphertext.split_at(12); let cipher = Aes256Gcm::new(&key_arr.into()); let nonce = Nonce::from_slice(nonce_bytes); cipher .decrypt(nonce, ct) .map_err(|_| CryptoError::DecryptionFailed(0)) } pub async fn generate_data_key( &self, key_id: &str, num_bytes: usize, ) -> Result<(Vec, Vec), CryptoError> { let kms_key = self.decrypt_data_key(key_id).await?; if kms_key.len() != 32 { return Err(CryptoError::InvalidKeySize(kms_key.len())); } let mut plaintext_key = vec![0u8; num_bytes]; rand::thread_rng().fill_bytes(&mut plaintext_key); let key_arr: [u8; 32] = kms_key.try_into().unwrap(); let cipher = Aes256Gcm::new(&key_arr.into()); let mut nonce_bytes = [0u8; 12]; rand::thread_rng().fill_bytes(&mut nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes); let encrypted = cipher .encrypt(nonce, plaintext_key.as_slice()) .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?; let mut wrapped = Vec::with_capacity(12 + encrypted.len()); wrapped.extend_from_slice(&nonce_bytes); wrapped.extend_from_slice(&encrypted); Ok((plaintext_key, wrapped)) } } pub async fn load_or_create_master_key(keys_dir: &Path) -> Result<[u8; 32], CryptoError> { std::fs::create_dir_all(keys_dir).map_err(CryptoError::Io)?; let path = keys_dir.join("master.key"); if path.exists() { let encoded = std::fs::read_to_string(&path).map_err(CryptoError::Io)?; let decoded = B64.decode(encoded.trim()).map_err(|e| { CryptoError::EncryptionFailed(format!("Bad master key encoding: {}", e)) })?; if decoded.len() != 32 { return Err(CryptoError::InvalidKeySize(decoded.len())); } let mut key = [0u8; 32]; key.copy_from_slice(&decoded); Ok(key) } else { let mut key = [0u8; 32]; rand::thread_rng().fill_bytes(&mut key); let encoded = B64.encode(key); std::fs::write(&path, &encoded).map_err(CryptoError::Io)?; Ok(key) } } #[cfg(test)] mod tests { use super::*; #[tokio::test] async fn test_create_and_list_keys() { let dir = tempfile::tempdir().unwrap(); let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("test key").await.unwrap(); assert!(key.enabled); assert_eq!(key.description, "test key"); assert!(key.key_id.len() > 0); let keys = kms.list_keys().await; assert_eq!(keys.len(), 1); assert_eq!(keys[0].key_id, key.key_id); } #[tokio::test] async fn test_enable_disable_key() { let dir = tempfile::tempdir().unwrap(); let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("toggle").await.unwrap(); assert!(key.enabled); kms.disable_key(&key.key_id).await.unwrap(); let k = kms.get_key(&key.key_id).await.unwrap(); assert!(!k.enabled); kms.enable_key(&key.key_id).await.unwrap(); let k = kms.get_key(&key.key_id).await.unwrap(); assert!(k.enabled); } #[tokio::test] async fn test_delete_key() { let dir = tempfile::tempdir().unwrap(); let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("doomed").await.unwrap(); assert!(kms.delete_key(&key.key_id).await.unwrap()); assert!(kms.get_key(&key.key_id).await.is_none()); assert_eq!(kms.list_keys().await.len(), 0); } #[tokio::test] async fn test_encrypt_decrypt_data() { let dir = tempfile::tempdir().unwrap(); let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("enc-key").await.unwrap(); let plaintext = b"Hello, KMS!"; let ciphertext = kms.encrypt_data(&key.key_id, plaintext).await.unwrap(); assert_ne!(&ciphertext, plaintext); let decrypted = kms.decrypt_data(&key.key_id, &ciphertext).await.unwrap(); assert_eq!(decrypted, plaintext); } #[tokio::test] async fn test_generate_data_key() { let dir = tempfile::tempdir().unwrap(); let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("data-key-gen").await.unwrap(); let (plaintext, wrapped) = kms.generate_data_key(&key.key_id, 32).await.unwrap(); assert_eq!(plaintext.len(), 32); assert!(wrapped.len() > 32); } #[tokio::test] async fn test_disabled_key_cannot_encrypt() { let dir = tempfile::tempdir().unwrap(); let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("disabled").await.unwrap(); kms.disable_key(&key.key_id).await.unwrap(); let result = kms.encrypt_data(&key.key_id, b"test").await; assert!(result.is_err()); } #[tokio::test] async fn test_persistence_across_reload() { let dir = tempfile::tempdir().unwrap(); let key_id = { let kms = KmsService::new(dir.path()).await.unwrap(); let key = kms.create_key("persistent").await.unwrap(); key.key_id }; let kms2 = KmsService::new(dir.path()).await.unwrap(); let key = kms2.get_key(&key_id).await; assert!(key.is_some()); assert_eq!(key.unwrap().description, "persistent"); } #[tokio::test] async fn test_master_key_roundtrip() { let dir = tempfile::tempdir().unwrap(); let key1 = load_or_create_master_key(dir.path()).await.unwrap(); let key2 = load_or_create_master_key(dir.path()).await.unwrap(); assert_eq!(key1, key2); } }