use aes_gcm_stream::{Aes256GcmStreamDecryptor, Aes256GcmStreamEncryptor};
use aes_kw::Kek;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use josekit::jwe;
use josekit::jwe::alg::aeskw::AeskwJweAlgorithm;
use josekit::jwe::alg::rsaes::RsaesJweAlgorithm;
use josekit::jwe::JweHeader;
use josekit::jwk::alg::rsa::RsaKeyPair;
use josekit::jwk::Jwk;
use rand::random;
use rand::rngs::{OsRng, ThreadRng};
use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
use rust_util::{opt_result, simple_error, XResult};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha1::Sha1;

/// Marker prepended to the compact JWE by `serialize_jwe_rsa` and `serialize_jwe_aes`;
/// `get_jwe` removes it again before the token is handed to josekit for decoding.
const LOCAL_KMS_PREFIX: &str = "LKMS:";

// JWE format:
// BASE64URL(UTF8(JWE Protected Header)) || '.' ||
// BASE64URL(JWE Encrypted Key) || '.' || BASE64URL(JWE Initialization Vector)
// || '.' || BASE64URL(JWE Ciphertext) || '.' || BASE64URL(JWE Authentication Tag).
//
// RSA JWE Header:
// {"enc":"A256GCM","vendor":"local-mini-kms","alg":"RSA-OAEP"}
// eyJlbmMiOiJBMjU2R0NNIiwidmVuZG9yIjoibG9jYWwtbWluaS1rbXMiLCJhbGciOiJSU0EtT0FFUCJ9.VQ_R
// yGjXqQlUIbRIMgaYRSaX5FMRBzZ6ApfdZ2yAwiG70hjNfR3ss7x4PYqMm6QtITm1O4_fp7I3bY8iUz5Njyth_
// Min7Xm2-WsQ6gq9yN58btkUBFm60c7SC5XLaqE1pEtBAz7786bJk6M4NeOtDAOFAmIb2j1EwnS5vweBtmNv7N
// UFIgvx806T3WkCFDOkMSJ10_6LSa0z-lIac-s68svsU5WW8CXVKxHAbxaHyX_otu2HxXzDZlF5Goamh5ZJtr0
// 0yW_bzDCx3hZ2nMK3Ve7IJ2ZLAMmvhj9LKWkPtoqH0dGHaPHWff5P3rZ4vtKywt_h5b6SYII_mEoJcpByMyGw
// TXCtZymDt82Tyv_FesW2721JgyGxnukuOxQRTw4MfGYIO5bldL3uGGI_H4HXlXhM_kp3wuPAZ0vH4Jj2KD6MV
// DDTJQaEBQIEF07i7WiNynr57kbahYwextRXYP7LgoUHfFwA5GGGpN-UkuWLlKkYLTmXGrPYnL6Cf9D3euP7nF
// ml2oA2hjig-UuYf9A_QSEqNsMxYDuG-rggn3H_iXNl4ooYcxSVOXhTKfoV578MkNwG75BdHN5FeRYIKq0HCTM
// lGqqBWmDibPtMd7Uq1JrDd8774lnA8JcZcCMSia4m6WJSbG0kOuJ4NJPOUrYtNEJXgWKU3FQzDB-apLMQdac.
// WYJgsdZRLk310KWd.P333-S2VYg.PCfruTdk8vh3a8wcjJCe-g // // RSA-OAEP RSA using Optimal Asymmetric Encryption Padding (OAEP), as defined in RFC 3447 [RFC3447] // A256GCM Advanced Encryption Standard (AES) using 256 bit keys in Galois/Counter Mode, as defined in [FIPS‑197] and [NIST‑800‑38D] // // AES JWE Header: // {"enc":"A256GCM","vendor":"local-mini-kms","version":"5b90f66a1c6a918d","alg":"A256KW"} // eyJlbmMiOiJBMjU2R0NNIiwidmVuZG9yIjoibG9jYWwtbWluaS1rbXMiLCJ2ZXJzaW9uIjoiNWI5MGY2NmExYz // ZhOTE4ZCIsImFsZyI6IkEyNTZLVyJ9.K2_P-b_Gq9wbrssbcS5AmiUwcnNTnnZSe7rBI1SixVrC7TfFK0fruw. // ez3OKjOHAIIYnfM0.wSO3aXo.-vGJwk8JQKhi3voIlAA9gQ // // A256GCM Advanced Encryption Standard (AES) using 256 bit keys in Galois/Counter Mode, as defined in [FIPS‑197] and [NIST‑800‑38D] // A256KW Advanced Encryption Standard (AES) Key Wrap Algorithm using 256 bit keys, as defined in RFC 3394 [RFC3394] // JWE Header: {"alg":"dir","enc":"A256GCM"} // Encrypted key (CEK): (blank) // IV: Vlf_WdLm-spHbfJe // Ciphertext: RxMPrw // Authentication Tag: 5VC8Y_qSPdSubbGNGyfn6A // // JWE Header: {"alg":"A256KW","enc":"A256GCM"} // Encrypted key (CEK): 66xZoxFI18zfvLMO6WU1zzqqX1tT8xu_qZzMQyPcfVuajPNkOJUXQA // IV: X5ZL8yaOektXmfny // Ciphertext: brz-Lg // Authentication Tag: xG-EvM-9hrw0XRiuRW7HrA // // https://security.stackexchange.com/questions/80966/what-is-the-point-of-aes-key-wrap-with-json-web-encryption pub fn generate_rsa_key_2(bits: u32) -> XResult { let mut rng = OsRng::default(); Ok(RsaPrivateKey::new(&mut rng, bits as usize)?) } pub fn generate_rsa_key(bits: u32) -> XResult { Ok(RsaKeyPair::generate(bits)?) 
} pub fn serialize_jwe_rsa_2(payload: &[u8], rsa_public_key: &RsaPublicKey) -> XResult { let header = JweHeader2 { enc: "A256GCM".to_string(), alg: "RSA-OAEP".to_string(), vendor: "local-mini-kms".to_string(), }; serialize_jwe_fn(&header, payload, |data_key| -> XResult> { let mut r = ThreadRng::default(); Ok(opt_result!(rsa_public_key.encrypt(&mut r, Oaep::new::(), data_key), "Wrap key failed: {}")) }) } pub fn serialize_jwe_rsa(payload: &[u8], jwk: &Jwk) -> XResult { let mut header = JweHeader::new(); header.set_content_encryption("A256GCM"); header.set_claim("vendor", Some(Value::String("local-mini-kms".to_string())))?; let encrypter = RsaesJweAlgorithm::RsaOaep.encrypter_from_jwk(jwk)?; Ok(format!("{}{}", LOCAL_KMS_PREFIX, jwe::serialize_compact(payload, &header, &encrypter)?)) } pub fn deserialize_jwe_rsa_2(jwe: &str, rsa: &RsaPrivateKey) -> XResult<(Vec, JweHeader2)> { deserialize_jwe_fn(jwe, |key_wrap| -> XResult<(Vec)> { Ok(opt_result!(rsa.decrypt(Oaep::new::(), &key_wrap), "Unwrap key failed: {}")) }) } pub fn deserialize_jwe_rsa(jwe: &str, jwk: &Jwk) -> XResult<(Vec, JweHeader)> { let decrypter = RsaesJweAlgorithm::RsaOaep.decrypter_from_jwk(jwk)?; Ok(jwe::deserialize_compact(&get_jwe(jwe), &decrypter)?) 
} #[derive(Debug, Serialize, Deserialize)] pub struct JweHeader2 { pub enc: String, pub alg: String, pub vendor: String, } pub fn serialize_jwe_aes_2(payload: &[u8], key: [u8; 32]) -> XResult { let header = JweHeader2 { enc: "A256GCM".to_string(), alg: "A256KW".to_string(), vendor: "local-mini-kms".to_string(), }; serialize_jwe_fn(&header, payload, |data_key| -> XResult> { let kek = Kek::from(key); Ok(opt_result!(kek.wrap_vec(&data_key[..]), "Wrap key failed: {}")) }) } pub fn serialize_jwe_aes(payload: &[u8], key: &[u8]) -> XResult { let mut header = JweHeader::new(); header.set_content_encryption("A256GCM"); header.set_claim("vendor", Some(Value::String("local-mini-kms".to_string())))?; // header.set_claim("version", Some(Value::String(get_master_key_checksum(key))))?; let encrypter = AeskwJweAlgorithm::A256kw.encrypter_from_bytes(key)?; Ok(format!("{}{}", LOCAL_KMS_PREFIX, jwe::serialize_compact(payload, &header, &encrypter)?)) } pub fn deserialize_jwe_aes_2(jwe: &str, key: [u8; 32]) -> XResult<(Vec, JweHeader2)> { deserialize_jwe_fn(jwe, |key_wrap| -> XResult<(Vec)> { let kek = Kek::from(key); Ok(opt_result!(kek.unwrap_vec(&key_wrap), "Unwrap key failed: {}")) }) } pub fn deserialize_jwe_aes(jwe: &str, key: &[u8]) -> XResult<(Vec, JweHeader)> { let decrypter = AeskwJweAlgorithm::A256kw.decrypter_from_bytes(key)?; Ok(jwe::deserialize_compact(&get_jwe(jwe), &decrypter)?) 
} fn serialize_jwe_fn(header: &JweHeader2, payload: &[u8], key_wrap_fn: F) -> XResult where F: Fn(&[u8]) -> XResult>, { let header_str = serde_json::to_string(&header).unwrap(); let header_b64 = URL_SAFE_NO_PAD.encode(header_str.as_bytes()); let data_key: [u8; 32] = random(); let nonce: [u8; 12] = random(); let mut encryptor = Aes256GcmStreamEncryptor::new(data_key, &nonce); encryptor.init_adata(header_b64.as_bytes()); let mut e = encryptor.update(payload); let (f, t) = encryptor.finalize(); e.extend_from_slice(&f); let wrap_key = key_wrap_fn(&data_key)?; let mut jwe = String::new(); jwe.push_str(&header_b64); jwe.push_str("."); jwe.push_str(&URL_SAFE_NO_PAD.encode(&wrap_key)); jwe.push_str("."); jwe.push_str(&URL_SAFE_NO_PAD.encode(&nonce)); jwe.push_str("."); jwe.push_str(&URL_SAFE_NO_PAD.encode(&e)); jwe.push_str("."); jwe.push_str(&URL_SAFE_NO_PAD.encode(&t)); Ok(jwe) } fn deserialize_jwe_fn(jwe: &str, key_unwrap_fn: F) -> XResult<(Vec, JweHeader2)> where F: Fn(&[u8]) -> XResult>, { let jwe_parts = jwe.split(".").collect::>(); if jwe_parts.len() != 5 { return simple_error!("Invalid JWE: {}", jwe); } let header_bytes = opt_result!(decode_url_safe_no_pad(jwe_parts[0]), "Invalid JWE header: {}, JWE: {}", jwe); let header: JweHeader2 = opt_result!(serde_json::from_slice(&header_bytes), "Invalid JWE header: {}, JWE: {}", jwe); let cek = opt_result!(decode_url_safe_no_pad(jwe_parts[1]), "Invalid JWE CEK: {}, JWE: {}", jwe); let iv = opt_result!(decode_url_safe_no_pad(jwe_parts[2]), "Invalid JWE IV: {}, JWE: {}", jwe); let ciphertext = opt_result!(decode_url_safe_no_pad(jwe_parts[3]), "Invalid JWE ciphertext: {}, JWE: {}", jwe); let tag = opt_result!(decode_url_safe_no_pad(jwe_parts[4]), "Invalid JWE tag: {}, JWE: {}", jwe); let data_key = key_unwrap_fn(&cek)?; let data_key_b32 = opt_result!(to_bytes32(&data_key), "Invalid JWE CEK: {}, JWE: {}", jwe); let mut decryptor = Aes256GcmStreamDecryptor::new(data_key_b32, &iv); decryptor.init_adata(jwe_parts[0].as_bytes()); 
let mut plaintext = decryptor.update(&ciphertext); let plaintext_2 = decryptor.update(&tag); let plaintext_final = opt_result!(decryptor.finalize(), "Invalid JWE: {}, JWE: {}", jwe); plaintext.extend_from_slice(&plaintext_2); plaintext.extend_from_slice(&plaintext_final); Ok((plaintext, header)) } #[inline] fn decode_url_safe_no_pad(s: &str) -> XResult> { Ok(URL_SAFE_NO_PAD.decode(s.as_bytes())?) } #[inline] fn to_bytes32(bytes: &[u8]) -> XResult<[u8; 32]> { if bytes.len() != 32 { return simple_error!("Not valid 32 bytes"); } let mut ret = [0; 32]; for i in 0..32 { ret[i] = bytes[i]; } Ok(ret) } #[inline] fn get_jwe(jwe: &str) -> String { if jwe.starts_with(LOCAL_KMS_PREFIX) { jwe.chars().skip(LOCAL_KMS_PREFIX.len()).collect() } else { jwe.to_string() } }