213 lines
7.8 KiB
Rust
213 lines
7.8 KiB
Rust
use aes_gcm_stream::{Aes256GcmStreamDecryptor, Aes256GcmStreamEncryptor};
|
|
use aes_kw::Kek;
|
|
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
|
use base64::Engine;
|
|
use jose_jwk::{Key, Rsa};
|
|
use rand::{random, thread_rng};
|
|
use rsa::pkcs1::LineEnding;
|
|
use rsa::pkcs8::EncodePublicKey;
|
|
use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
|
|
use rust_util::{iff, opt_result, simple_error, XResult};
|
|
use serde_derive::{Deserialize, Serialize};
|
|
use sha1::Sha1;
|
|
use sha2::{Digest, Sha256};
|
|
|
|
/// Marker prepended to every JWE string produced here; stripped (if present) before parsing.
const LOCAL_KMS_PREFIX: &str = "LKMS:";
/// Separator between the five parts of a compact JWE.
const JWE_DOT: &str = ".";

/// Content-encryption algorithm (`enc` header value): AES-256-GCM.
const JWE_ENC_A256GCM: &str = "A256GCM";
/// Key-management algorithm (`alg` header value): AES-256 key wrap.
const JWE_ALG_A256KW: &str = "A256KW";
/// Key-management algorithm (`alg` header value): RSA-OAEP, used here with SHA-1.
const JWE_ALG_RSA_OAEP: &str = "RSA-OAEP";
|
|
|
|
#[derive(Default, Debug, Serialize, Deserialize)]
|
|
pub struct JweHeader {
|
|
pub enc: String,
|
|
pub alg: String,
|
|
pub vendor: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub version: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub data_type: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub exportable: Option<bool>,
|
|
}
|
|
|
|
pub fn generate_rsa_key(bits: u32) -> XResult<RsaPrivateKey> {
|
|
let mut rng = thread_rng();
|
|
Ok(RsaPrivateKey::new(&mut rng, bits as usize)?)
|
|
}
|
|
|
|
pub fn rsa_key_to_jwk(rsa_private_key: &RsaPrivateKey) -> XResult<jose_jwk::Jwk> {
|
|
let rsa_public_key = rsa_private_key.as_ref();
|
|
let public_rsa: Rsa = rsa_public_key.into();
|
|
Ok(jose_jwk::Jwk {
|
|
key: Key::Rsa(public_rsa),
|
|
prm: Default::default(),
|
|
})
|
|
}
|
|
|
|
pub fn rsa_key_to_pem(rsa_private_key: &RsaPrivateKey) -> XResult<String> {
|
|
Ok(rsa_private_key.to_public_key().to_public_key_pem(LineEnding::LF)?)
|
|
}
|
|
|
|
pub fn jwk_to_rsa_pubic_key(rsa_jwk: &str) -> XResult<RsaPublicKey> {
|
|
let rsa: Rsa = opt_result!(serde_json::from_str(&rsa_jwk), "Bad RSA JWK: {}, error: {}", rsa_jwk);
|
|
let rsa_public_key = opt_result!(RsaPublicKey::try_from(rsa), "Bad RSA JWK: {}, error: {:?}", rsa_jwk);
|
|
Ok(rsa_public_key)
|
|
}
|
|
|
|
pub fn serialize_jwe_rsa(payload: &[u8], rsa_public_key: &RsaPublicKey) -> XResult<String> {
|
|
let header = JweHeader {
|
|
enc: JWE_ENC_A256GCM.to_string(),
|
|
alg: JWE_ALG_RSA_OAEP.to_string(),
|
|
vendor: "local-mini-kms".to_string(),
|
|
..Default::default()
|
|
};
|
|
serialize_jwe_fn(&header, payload, |data_key| -> XResult<Vec<u8>> {
|
|
let mut r = thread_rng();
|
|
Ok(opt_result!(rsa_public_key.encrypt(&mut r, Oaep::new::<Sha1>(), data_key), "Wrap key failed: {}"))
|
|
})
|
|
}
|
|
|
|
pub fn deserialize_jwe_rsa(jwe: &str, rsa: &RsaPrivateKey) -> XResult<(Vec<u8>, JweHeader)> {
|
|
deserialize_jwe_fn(jwe, |alg, key_wrap| -> XResult<Vec<u8>> {
|
|
if alg != JWE_ALG_RSA_OAEP {
|
|
return simple_error!("Invalid JWE header alg: {}", alg);
|
|
}
|
|
Ok(opt_result!(rsa.decrypt(Oaep::new::<Sha1>(), &key_wrap), "Unwrap key failed: {}"))
|
|
})
|
|
}
|
|
|
|
pub fn serialize_jwe_aes(payload: &[u8], key: &[u8]) -> XResult<String> {
|
|
serialize_jwe_aes_32(None, None, payload, to_bytes32(key)?)
|
|
}
|
|
|
|
pub fn serialize_jwe_aes_with_data_type(data_type: &str, exportable: bool, payload: &[u8], key: &[u8]) -> XResult<String> {
|
|
serialize_jwe_aes_32(Some(data_type.to_string()), iff!(exportable, Some(false), None), payload, to_bytes32(key)?)
|
|
}
|
|
|
|
pub fn serialize_jwe_aes_32(data_type: Option<String>, exportable: Option<bool>, payload: &[u8], key: [u8; 32]) -> XResult<String> {
|
|
let header = JweHeader {
|
|
enc: JWE_ENC_A256GCM.to_string(),
|
|
alg: JWE_ALG_A256KW.to_string(),
|
|
vendor: "local-mini-kms".to_string(),
|
|
version: Some(get_master_key_checksum(&key)),
|
|
data_type,
|
|
exportable,
|
|
..Default::default()
|
|
};
|
|
serialize_jwe_fn(&header, payload, |data_key| -> XResult<Vec<u8>> {
|
|
let kek = Kek::from(key);
|
|
Ok(opt_result!(kek.wrap_vec(&data_key[..]), "Wrap key failed: {}"))
|
|
})
|
|
}
|
|
|
|
pub fn deserialize_jwe_aes(jwe: &str, key: &[u8]) -> XResult<(Vec<u8>, JweHeader)> {
|
|
deserialize_jwe_aes_32(jwe, to_bytes32(key)?)
|
|
}
|
|
|
|
pub fn deserialize_jwe_aes_32(jwe: &str, key: [u8; 32]) -> XResult<(Vec<u8>, JweHeader)> {
|
|
deserialize_jwe_fn(jwe, |alg, key_wrap| -> XResult<Vec<u8>> {
|
|
if alg != JWE_ALG_A256KW {
|
|
return simple_error!("Invalid JWE header alg: {}", alg);
|
|
}
|
|
let kek = Kek::from(key);
|
|
Ok(opt_result!(kek.unwrap_vec(&key_wrap), "Unwrap key failed: {}"))
|
|
})
|
|
}
|
|
|
|
fn serialize_jwe_fn<F>(header: &JweHeader, payload: &[u8], key_wrap_fn: F) -> XResult<String>
|
|
where
|
|
F: Fn(&[u8]) -> XResult<Vec<u8>>,
|
|
{
|
|
let header_str = opt_result!(serde_json::to_string(&header), "Invalid JWE header: {}");
|
|
let header_b64 = URL_SAFE_NO_PAD.encode(header_str.as_bytes());
|
|
|
|
let data_key: [u8; 32] = random();
|
|
let iv: [u8; 12] = random();
|
|
let mut encryptor = Aes256GcmStreamEncryptor::new(data_key, &iv);
|
|
encryptor.init_adata(header_b64.as_bytes());
|
|
let mut ciphertext = encryptor.update(payload);
|
|
let (ciphertext_final, tag) = encryptor.finalize();
|
|
ciphertext.extend_from_slice(&ciphertext_final);
|
|
|
|
let cek = key_wrap_fn(&data_key)?;
|
|
|
|
Ok(format!(
|
|
"{}{}.{}.{}.{}.{}",
|
|
LOCAL_KMS_PREFIX,
|
|
header_b64,
|
|
URL_SAFE_NO_PAD.encode(&cek),
|
|
URL_SAFE_NO_PAD.encode(&iv),
|
|
URL_SAFE_NO_PAD.encode(&ciphertext),
|
|
URL_SAFE_NO_PAD.encode(&tag)
|
|
))
|
|
}
|
|
|
|
fn deserialize_jwe_fn<F>(jwe: &str, key_unwrap_fn: F) -> XResult<(Vec<u8>, JweHeader)>
|
|
where
|
|
F: Fn(&str, &[u8]) -> XResult<Vec<u8>>,
|
|
{
|
|
let jwe = get_jwe(jwe);
|
|
let jwe_parts = jwe.split(JWE_DOT).collect::<Vec<&str>>();
|
|
if jwe_parts.len() != 5 {
|
|
return simple_error!("Invalid JWE: {}", jwe);
|
|
}
|
|
let header_bytes = opt_result!(decode_url_safe_no_pad(jwe_parts[0]), "Invalid JWE header: {}, JWE: {}", jwe);
|
|
let header: JweHeader = opt_result!(serde_json::from_slice(&header_bytes), "Invalid JWE header: {}, JWE: {}", jwe);
|
|
if header.enc != JWE_ENC_A256GCM {
|
|
return simple_error!("Invalid JWE header enc: {}", header.enc);
|
|
}
|
|
|
|
let cek = opt_result!(decode_url_safe_no_pad(jwe_parts[1]), "Invalid JWE CEK: {}, JWE: {}", jwe);
|
|
let iv = opt_result!(decode_url_safe_no_pad(jwe_parts[2]), "Invalid JWE IV: {}, JWE: {}", jwe);
|
|
let ciphertext = opt_result!(decode_url_safe_no_pad(jwe_parts[3]), "Invalid JWE ciphertext: {}, JWE: {}", jwe);
|
|
let tag = opt_result!(decode_url_safe_no_pad(jwe_parts[4]), "Invalid JWE tag: {}, JWE: {}", jwe);
|
|
|
|
let data_key = key_unwrap_fn(&header.alg, &cek)?;
|
|
let data_key_b32 = opt_result!(to_bytes32(&data_key), "Invalid JWE CEK: {}, JWE: {}", jwe);
|
|
|
|
let mut decryptor = Aes256GcmStreamDecryptor::new(data_key_b32, &iv);
|
|
decryptor.init_adata(jwe_parts[0].as_bytes());
|
|
let mut plaintext = decryptor.update(&ciphertext);
|
|
let plaintext_2 = decryptor.update(&tag);
|
|
let plaintext_final = opt_result!(decryptor.finalize(), "Invalid JWE: {}, JWE: {}", jwe);
|
|
plaintext.extend_from_slice(&plaintext_2);
|
|
plaintext.extend_from_slice(&plaintext_final);
|
|
Ok((plaintext, header))
|
|
}
|
|
|
|
#[inline]
|
|
fn decode_url_safe_no_pad(s: &str) -> XResult<Vec<u8>> {
|
|
Ok(URL_SAFE_NO_PAD.decode(s.as_bytes())?)
|
|
}
|
|
|
|
#[inline]
|
|
fn to_bytes32(bytes: &[u8]) -> XResult<[u8; 32]> {
|
|
if bytes.len() != 32 {
|
|
return simple_error!("Not valid 32 bytes");
|
|
}
|
|
let mut ret = [0; 32];
|
|
for i in 0..32 {
|
|
ret[i] = bytes[i];
|
|
}
|
|
Ok(ret)
|
|
}
|
|
|
|
fn get_master_key_checksum(key: &[u8]) -> String {
|
|
let digest = Sha256::digest(key);
|
|
let digest = Sha256::digest(digest.as_slice());
|
|
let digest = Sha256::digest(digest.as_slice());
|
|
let digest = Sha256::digest(digest.as_slice());
|
|
let digest = Sha256::digest(digest.as_slice());
|
|
let digest = Sha256::digest(digest.as_slice());
|
|
hex::encode(&digest[0..8])
|
|
}
|
|
|
|
fn get_jwe(jwe: &str) -> String {
|
|
if jwe.starts_with(LOCAL_KMS_PREFIX) {
|
|
jwe.chars().skip(LOCAL_KMS_PREFIX.len()).collect()
|
|
} else {
|
|
jwe.to_string()
|
|
}
|
|
}
|