use aes_gcm_stream::{Aes256GcmStreamDecryptor, Aes256GcmStreamEncryptor}; use aes_kw::Kek; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use jose_jwk::{Key, Rsa}; use rand::{random, thread_rng}; use rsa::pkcs1::LineEnding; use rsa::pkcs8::EncodePublicKey; use rsa::{Oaep, RsaPrivateKey, RsaPublicKey}; use rust_util::{iff, opt_result, simple_error, XResult}; use serde_derive::{Deserialize, Serialize}; use sha1::Sha1; use sha2::{Digest, Sha256}; const LOCAL_KMS_PREFIX: &str = "LKMS:"; const JWE_ENC_A256GCM: &str = "A256GCM"; const JWE_ALG_A256KW: &str = "A256KW"; const JWE_ALG_RSA_OAEP: &str = "RSA-OAEP"; const JWE_DOT: &str = "."; #[derive(Default, Debug, Serialize, Deserialize)] pub struct JweHeader { pub enc: String, pub alg: String, pub vendor: String, #[serde(skip_serializing_if = "Option::is_none")] pub version: Option, #[serde(skip_serializing_if = "Option::is_none")] pub data_type: Option, #[serde(skip_serializing_if = "Option::is_none")] pub exportable: Option, } pub fn generate_rsa_key(bits: u32) -> XResult { let mut rng = thread_rng(); Ok(RsaPrivateKey::new(&mut rng, bits as usize)?) } pub fn rsa_key_to_jwk(rsa_private_key: &RsaPrivateKey) -> XResult { let rsa_public_key = rsa_private_key.as_ref(); let public_rsa: Rsa = rsa_public_key.into(); Ok(jose_jwk::Jwk { key: Key::Rsa(public_rsa), prm: Default::default(), }) } pub fn rsa_key_to_pem(rsa_private_key: &RsaPrivateKey) -> XResult { Ok(rsa_private_key.to_public_key().to_public_key_pem(LineEnding::LF)?) } pub fn jwk_to_rsa_pubic_key(rsa_jwk: &str) -> XResult { let rsa: Rsa = opt_result!(serde_json::from_str(&rsa_jwk), "Bad RSA JWK: {}, error: {}", rsa_jwk); let rsa_public_key = opt_result!(RsaPublicKey::try_from(rsa), "Bad RSA JWK: {}, error: {:?}", rsa_jwk); Ok(rsa_public_key) } pub fn serialize_jwe_rsa(payload: &[u8], rsa_public_key: &RsaPublicKey) -> XResult { let header = JweHeader { enc: JWE_ENC_A256GCM.to_string(), alg: JWE_ALG_RSA_OAEP.to_string(), vendor: "local-mini-kms".to_string(), ..Default::default() }; serialize_jwe_fn(&header, payload, |data_key| -> XResult> { let mut r = thread_rng(); Ok(opt_result!(rsa_public_key.encrypt(&mut r, Oaep::new::(), data_key), "Wrap key failed: {}")) }) } pub fn deserialize_jwe_rsa(jwe: &str, rsa: &RsaPrivateKey) -> XResult<(Vec, JweHeader)> { deserialize_jwe_fn(jwe, |alg, key_wrap| -> XResult> { if alg != JWE_ALG_RSA_OAEP { return simple_error!("Invalid JWE header alg: {}", alg); } Ok(opt_result!(rsa.decrypt(Oaep::new::(), &key_wrap), "Unwrap key failed: {}")) }) } pub fn serialize_jwe_aes(payload: &[u8], key: &[u8]) -> XResult { serialize_jwe_aes_32(None, None, payload, to_bytes32(key)?) } pub fn serialize_jwe_aes_with_data_type(data_type: &str, exportable: bool, payload: &[u8], key: &[u8]) -> XResult { serialize_jwe_aes_32(Some(data_type.to_string()), iff!(exportable, Some(false), None), payload, to_bytes32(key)?) } pub fn serialize_jwe_aes_32(data_type: Option, exportable: Option, payload: &[u8], key: [u8; 32]) -> XResult { let header = JweHeader { enc: JWE_ENC_A256GCM.to_string(), alg: JWE_ALG_A256KW.to_string(), vendor: "local-mini-kms".to_string(), version: Some(get_master_key_checksum(&key)), data_type, exportable, ..Default::default() }; serialize_jwe_fn(&header, payload, |data_key| -> XResult> { let kek = Kek::from(key); Ok(opt_result!(kek.wrap_vec(&data_key[..]), "Wrap key failed: {}")) }) } pub fn deserialize_jwe_aes(jwe: &str, key: &[u8]) -> XResult<(Vec, JweHeader)> { deserialize_jwe_aes_32(jwe, to_bytes32(key)?) } pub fn deserialize_jwe_aes_32(jwe: &str, key: [u8; 32]) -> XResult<(Vec, JweHeader)> { deserialize_jwe_fn(jwe, |alg, key_wrap| -> XResult> { if alg != JWE_ALG_A256KW { return simple_error!("Invalid JWE header alg: {}", alg); } let kek = Kek::from(key); Ok(opt_result!(kek.unwrap_vec(&key_wrap), "Unwrap key failed: {}")) }) } fn serialize_jwe_fn(header: &JweHeader, payload: &[u8], key_wrap_fn: F) -> XResult where F: Fn(&[u8]) -> XResult>, { let header_str = opt_result!(serde_json::to_string(&header), "Invalid JWE header: {}"); let header_b64 = URL_SAFE_NO_PAD.encode(header_str.as_bytes()); let data_key: [u8; 32] = random(); let iv: [u8; 12] = random(); let mut encryptor = Aes256GcmStreamEncryptor::new(data_key, &iv); encryptor.init_adata(header_b64.as_bytes()); let mut ciphertext = encryptor.update(payload); let (ciphertext_final, tag) = encryptor.finalize(); ciphertext.extend_from_slice(&ciphertext_final); let cek = key_wrap_fn(&data_key)?; Ok(format!( "{}{}.{}.{}.{}.{}", LOCAL_KMS_PREFIX, header_b64, URL_SAFE_NO_PAD.encode(&cek), URL_SAFE_NO_PAD.encode(&iv), URL_SAFE_NO_PAD.encode(&ciphertext), URL_SAFE_NO_PAD.encode(&tag) )) } fn deserialize_jwe_fn(jwe: &str, key_unwrap_fn: F) -> XResult<(Vec, JweHeader)> where F: Fn(&str, &[u8]) -> XResult>, { let jwe = get_jwe(jwe); let jwe_parts = jwe.split(JWE_DOT).collect::>(); if jwe_parts.len() != 5 { return simple_error!("Invalid JWE: {}", jwe); } let header_bytes = opt_result!(decode_url_safe_no_pad(jwe_parts[0]), "Invalid JWE header: {}, JWE: {}", jwe); let header: JweHeader = opt_result!(serde_json::from_slice(&header_bytes), "Invalid JWE header: {}, JWE: {}", jwe); if header.enc != JWE_ENC_A256GCM { return simple_error!("Invalid JWE header enc: {}", header.enc); } let cek = opt_result!(decode_url_safe_no_pad(jwe_parts[1]), "Invalid JWE CEK: {}, JWE: {}", jwe); let iv = opt_result!(decode_url_safe_no_pad(jwe_parts[2]), "Invalid JWE IV: {}, JWE: {}", jwe); let ciphertext = opt_result!(decode_url_safe_no_pad(jwe_parts[3]), "Invalid JWE ciphertext: {}, JWE: {}", jwe); let tag = opt_result!(decode_url_safe_no_pad(jwe_parts[4]), "Invalid JWE tag: {}, JWE: {}", jwe); let data_key = key_unwrap_fn(&header.alg, &cek)?; let data_key_b32 = opt_result!(to_bytes32(&data_key), "Invalid JWE CEK: {}, JWE: {}", jwe); let mut decryptor = Aes256GcmStreamDecryptor::new(data_key_b32, &iv); decryptor.init_adata(jwe_parts[0].as_bytes()); let mut plaintext = decryptor.update(&ciphertext); let plaintext_2 = decryptor.update(&tag); let plaintext_final = opt_result!(decryptor.finalize(), "Invalid JWE: {}, JWE: {}", jwe); plaintext.extend_from_slice(&plaintext_2); plaintext.extend_from_slice(&plaintext_final); Ok((plaintext, header)) } #[inline] fn decode_url_safe_no_pad(s: &str) -> XResult> { Ok(URL_SAFE_NO_PAD.decode(s.as_bytes())?) } #[inline] fn to_bytes32(bytes: &[u8]) -> XResult<[u8; 32]> { if bytes.len() != 32 { return simple_error!("Not valid 32 bytes"); } let mut ret = [0; 32]; for i in 0..32 { ret[i] = bytes[i]; } Ok(ret) } fn get_master_key_checksum(key: &[u8]) -> String { let digest = Sha256::digest(key); let digest = Sha256::digest(digest.as_slice()); let digest = Sha256::digest(digest.as_slice()); let digest = Sha256::digest(digest.as_slice()); let digest = Sha256::digest(digest.as_slice()); let digest = Sha256::digest(digest.as_slice()); hex::encode(&digest[0..8]) } fn get_jwe(jwe: &str) -> String { if jwe.starts_with(LOCAL_KMS_PREFIX) { jwe.chars().skip(LOCAL_KMS_PREFIX.len()).collect() } else { jwe.to_string() } }