Files
local-mini-kms/src/jose.rs

213 lines
7.8 KiB
Rust

use aes_gcm_stream::{Aes256GcmStreamDecryptor, Aes256GcmStreamEncryptor};
use aes_kw::Kek;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use jose_jwk::{Key, Rsa};
use rand::{random, thread_rng};
use rsa::pkcs1::LineEnding;
use rsa::pkcs8::EncodePublicKey;
use rsa::{Oaep, RsaPrivateKey, RsaPublicKey};
use rust_util::{iff, opt_result, simple_error, XResult};
use serde_derive::{Deserialize, Serialize};
use sha1::Sha1;
use sha2::{Digest, Sha256};
/// Prefix prepended to every JWE this module emits; `get_jwe` strips it before parsing.
const LOCAL_KMS_PREFIX: &str = "LKMS:";
/// JWE `enc` value for AES-256-GCM content encryption (the only `enc` accepted when decrypting).
const JWE_ENC_A256GCM: &str = "A256GCM";
/// JWE `alg` value for wrapping the content key with AES-256 key wrap under a 32-byte master key.
const JWE_ALG_A256KW: &str = "A256KW";
/// JWE `alg` value for wrapping the content key with RSA-OAEP (SHA-1).
const JWE_ALG_RSA_OAEP: &str = "RSA-OAEP";
/// Separator used to split a compact JWE into its five segments.
const JWE_DOT: &str = ".";
/// Protected header of the JWEs produced and parsed by this module.
///
/// The header is serialized to JSON and base64url-encoded without padding. The encoded
/// form is both the first compact segment and the AES-GCM additional authenticated data.
/// Serde emits fields in declaration order, so reordering fields changes the header
/// bytes of newly produced tokens.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct JweHeader {
/// Content-encryption algorithm; decryption rejects anything other than `"A256GCM"`.
pub enc: String,
/// Key-management algorithm (`"A256KW"` or `"RSA-OAEP"` in this module).
pub alg: String,
/// Producer identifier; this module always writes `"local-mini-kms"`.
pub vendor: String,
/// Set by `serialize_jwe_aes_32` to a short checksum of the master key; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub version: Option<String>,
/// Optional caller-supplied type tag for the payload; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub data_type: Option<String>,
/// Optional export flag; omitted when `None`.
/// NOTE(review): presumably `Some(false)` marks a non-exportable payload; confirm against consumers.
#[serde(skip_serializing_if = "Option::is_none")]
pub exportable: Option<bool>,
}
/// Generates a new RSA private key with a `bits`-bit modulus using the thread-local RNG.
///
/// # Errors
/// Propagates any key-generation failure from the `rsa` crate.
pub fn generate_rsa_key(bits: u32) -> XResult<RsaPrivateKey> {
    let modulus_bits = bits as usize;
    let private_key = RsaPrivateKey::new(&mut thread_rng(), modulus_bits)?;
    Ok(private_key)
}
pub fn rsa_key_to_jwk(rsa_private_key: &RsaPrivateKey) -> XResult<jose_jwk::Jwk> {
let rsa_public_key = rsa_private_key.as_ref();
let public_rsa: Rsa = rsa_public_key.into();
Ok(jose_jwk::Jwk {
key: Key::Rsa(public_rsa),
prm: Default::default(),
})
}
/// Exports the public half of `rsa_private_key` as a PEM-encoded public key with LF line endings.
///
/// # Errors
/// Propagates any encoding error from the PEM encoder.
pub fn rsa_key_to_pem(rsa_private_key: &RsaPrivateKey) -> XResult<String> {
    let public_key = rsa_private_key.to_public_key();
    let pem = public_key.to_public_key_pem(LineEnding::LF)?;
    Ok(pem)
}
pub fn jwk_to_rsa_pubic_key(rsa_jwk: &str) -> XResult<RsaPublicKey> {
let rsa: Rsa = opt_result!(serde_json::from_str(&rsa_jwk), "Bad RSA JWK: {}, error: {}", rsa_jwk);
let rsa_public_key = opt_result!(RsaPublicKey::try_from(rsa), "Bad RSA JWK: {}, error: {:?}", rsa_jwk);
Ok(rsa_public_key)
}
pub fn serialize_jwe_rsa(payload: &[u8], rsa_public_key: &RsaPublicKey) -> XResult<String> {
let header = JweHeader {
enc: JWE_ENC_A256GCM.to_string(),
alg: JWE_ALG_RSA_OAEP.to_string(),
vendor: "local-mini-kms".to_string(),
..Default::default()
};
serialize_jwe_fn(&header, payload, |data_key| -> XResult<Vec<u8>> {
let mut r = thread_rng();
Ok(opt_result!(rsa_public_key.encrypt(&mut r, Oaep::new::<Sha1>(), data_key), "Wrap key failed: {}"))
})
}
/// Decrypts a compact JWE whose content key was wrapped with RSA-OAEP (SHA-1) for `rsa`.
///
/// Returns the plaintext together with the parsed header.
///
/// # Errors
/// Fails if the header `alg` is not `RSA-OAEP`, RSA decryption of the wrapped key fails,
/// or the JWE itself is malformed or fails authentication.
pub fn deserialize_jwe_rsa(jwe: &str, rsa: &RsaPrivateKey) -> XResult<(Vec<u8>, JweHeader)> {
    deserialize_jwe_fn(jwe, |alg, wrapped_key| -> XResult<Vec<u8>> {
        match alg {
            JWE_ALG_RSA_OAEP => Ok(opt_result!(
                rsa.decrypt(Oaep::new::<Sha1>(), wrapped_key),
                "Unwrap key failed: {}"
            )),
            _ => simple_error!("Invalid JWE header alg: {}", alg),
        }
    })
}
/// Encrypts `payload` with A256GCM/A256KW under the 32-byte master key `key`.
///
/// No `data_type` or `exportable` header fields are written.
///
/// # Errors
/// Fails if `key` is not exactly 32 bytes, or if serialization or key wrapping fails.
pub fn serialize_jwe_aes(payload: &[u8], key: &[u8]) -> XResult<String> {
    let master_key = to_bytes32(key)?;
    serialize_jwe_aes_32(None, None, payload, master_key)
}
/// Encrypts `payload` with A256GCM/A256KW under the 32-byte master key `key`.
///
/// The header records `data_type`. Non-exportable payloads are marked with
/// `"exportable": false`; exportable payloads leave the field absent.
///
/// # Errors
/// Fails if `key` is not exactly 32 bytes, or if serialization or key wrapping fails.
pub fn serialize_jwe_aes_with_data_type(data_type: &str, exportable: bool, payload: &[u8], key: &[u8]) -> XResult<String> {
    // Only the restrictive case is recorded. The branches were previously swapped,
    // which marked exportable payloads as `exportable: false` and left
    // non-exportable ones unflagged.
    let exportable_flag = iff!(exportable, None, Some(false));
    serialize_jwe_aes_32(Some(data_type.to_string()), exportable_flag, payload, to_bytes32(key)?)
}
/// Encrypts `payload` into a prefixed compact JWE using A256GCM.
///
/// The random content key is wrapped with AES-256 key wrap (A256KW) under `key`.
/// The header's `version` is set to `get_master_key_checksum(&key)`. `data_type` and
/// `exportable` are written only when they are `Some`.
///
/// # Errors
/// Fails if the header cannot be serialized or the key wrap fails.
pub fn serialize_jwe_aes_32(data_type: Option<String>, exportable: Option<bool>, payload: &[u8], key: [u8; 32]) -> XResult<String> {
    // Every field is set explicitly, so no `..Default::default()` is needed.
    let header = JweHeader {
        enc: JWE_ENC_A256GCM.to_string(),
        alg: JWE_ALG_A256KW.to_string(),
        vendor: "local-mini-kms".to_string(),
        version: Some(get_master_key_checksum(&key)),
        data_type,
        exportable,
    };
    serialize_jwe_fn(&header, payload, |data_key| -> XResult<Vec<u8>> {
        let kek = Kek::from(key);
        Ok(opt_result!(kek.wrap_vec(data_key), "Wrap key failed: {}"))
    })
}
/// Decrypts an A256KW-wrapped JWE with a master key supplied as a byte slice.
///
/// # Errors
/// Fails if `key` is not exactly 32 bytes, or whenever `deserialize_jwe_aes_32` fails.
pub fn deserialize_jwe_aes(jwe: &str, key: &[u8]) -> XResult<(Vec<u8>, JweHeader)> {
    let master_key = to_bytes32(key)?;
    deserialize_jwe_aes_32(jwe, master_key)
}
/// Decrypts a compact JWE whose content key was wrapped with AES-256 key wrap under `key`.
///
/// Returns the plaintext together with the parsed header.
///
/// # Errors
/// Fails if the header `alg` is not `A256KW`, unwrapping fails (for example, wrong key),
/// or the JWE itself is malformed or fails authentication.
pub fn deserialize_jwe_aes_32(jwe: &str, key: [u8; 32]) -> XResult<(Vec<u8>, JweHeader)> {
    deserialize_jwe_fn(jwe, move |alg, wrapped_key| -> XResult<Vec<u8>> {
        if alg == JWE_ALG_A256KW {
            let kek = Kek::from(key);
            return Ok(opt_result!(kek.unwrap_vec(wrapped_key), "Unwrap key failed: {}"));
        }
        simple_error!("Invalid JWE header alg: {}", alg)
    })
}
fn serialize_jwe_fn<F>(header: &JweHeader, payload: &[u8], key_wrap_fn: F) -> XResult<String>
where
F: Fn(&[u8]) -> XResult<Vec<u8>>,
{
let header_str = opt_result!(serde_json::to_string(&header), "Invalid JWE header: {}");
let header_b64 = URL_SAFE_NO_PAD.encode(header_str.as_bytes());
let data_key: [u8; 32] = random();
let iv: [u8; 12] = random();
let mut encryptor = Aes256GcmStreamEncryptor::new(data_key, &iv);
encryptor.init_adata(header_b64.as_bytes());
let mut ciphertext = encryptor.update(payload);
let (ciphertext_final, tag) = encryptor.finalize();
ciphertext.extend_from_slice(&ciphertext_final);
let cek = key_wrap_fn(&data_key)?;
Ok(format!(
"{}{}.{}.{}.{}.{}",
LOCAL_KMS_PREFIX,
header_b64,
URL_SAFE_NO_PAD.encode(&cek),
URL_SAFE_NO_PAD.encode(&iv),
URL_SAFE_NO_PAD.encode(&ciphertext),
URL_SAFE_NO_PAD.encode(&tag)
))
}
/// Shared JWE reader for compact JWEs, with or without the `LKMS:` prefix.
///
/// Unwraps the content key with `key_unwrap_fn`, then decrypts and authenticates
/// the payload with AES-256-GCM. `key_unwrap_fn` receives the header's `alg` and the
/// decoded wrapped key, and must return the raw 32-byte content key.
///
/// # Errors
/// Fails if any of the following holds:
/// - the token does not have five segments;
/// - a segment is not valid unpadded base64url;
/// - the header is not valid JSON, or its `enc` is not `A256GCM`;
/// - the IV is not 12 bytes, or the tag is not 16 bytes;
/// - key unwrapping fails, or the unwrapped key is not 32 bytes;
/// - GCM authentication fails.
fn deserialize_jwe_fn<F>(jwe: &str, key_unwrap_fn: F) -> XResult<(Vec<u8>, JweHeader)>
where
    F: FnOnce(&str, &[u8]) -> XResult<Vec<u8>>,
{
    // RFC 7518 §5.3: AES-GCM in JWE uses a 96-bit IV and a 128-bit authentication tag.
    const GCM_IV_LEN: usize = 12;
    const GCM_TAG_LEN: usize = 16;

    let jwe = get_jwe(jwe);
    let jwe_parts = jwe.split(JWE_DOT).collect::<Vec<&str>>();
    if jwe_parts.len() != 5 {
        return simple_error!("Invalid JWE: {}", jwe);
    }
    let header_bytes = opt_result!(decode_url_safe_no_pad(jwe_parts[0]), "Invalid JWE header: {}, JWE: {}", jwe);
    let header: JweHeader = opt_result!(serde_json::from_slice(&header_bytes), "Invalid JWE header: {}, JWE: {}", jwe);
    if header.enc != JWE_ENC_A256GCM {
        return simple_error!("Invalid JWE header enc: {}", header.enc);
    }
    let cek = opt_result!(decode_url_safe_no_pad(jwe_parts[1]), "Invalid JWE CEK: {}, JWE: {}", jwe);
    let iv = opt_result!(decode_url_safe_no_pad(jwe_parts[2]), "Invalid JWE IV: {}, JWE: {}", jwe);
    if iv.len() != GCM_IV_LEN {
        return simple_error!("Invalid JWE IV length: {}, JWE: {}", iv.len(), jwe);
    }
    let ciphertext = opt_result!(decode_url_safe_no_pad(jwe_parts[3]), "Invalid JWE ciphertext: {}, JWE: {}", jwe);
    let tag = opt_result!(decode_url_safe_no_pad(jwe_parts[4]), "Invalid JWE tag: {}, JWE: {}", jwe);
    // The tag is streamed into the decryptor right after the ciphertext rather than
    // passed separately, so its length must be checked here explicitly.
    if tag.len() != GCM_TAG_LEN {
        return simple_error!("Invalid JWE tag length: {}, JWE: {}", tag.len(), jwe);
    }
    let data_key = key_unwrap_fn(&header.alg, &cek)?;
    let data_key_b32 = opt_result!(to_bytes32(&data_key), "Invalid JWE CEK: {}, JWE: {}", jwe);
    let mut decryptor = Aes256GcmStreamDecryptor::new(data_key_b32, &iv);
    // AAD is the header segment exactly as received, matching what the writer authenticated.
    decryptor.init_adata(jwe_parts[0].as_bytes());
    let mut plaintext = decryptor.update(&ciphertext);
    let plaintext_2 = decryptor.update(&tag);
    // Plaintext is returned only if `finalize` verifies the tag.
    let plaintext_final = opt_result!(decryptor.finalize(), "Invalid JWE: {}, JWE: {}", jwe);
    plaintext.extend_from_slice(&plaintext_2);
    plaintext.extend_from_slice(&plaintext_final);
    Ok((plaintext, header))
}
/// Decodes unpadded base64url text, the encoding used by every compact-JWE segment.
#[inline]
fn decode_url_safe_no_pad(s: &str) -> XResult<Vec<u8>> {
    let decoded = URL_SAFE_NO_PAD.decode(s)?;
    Ok(decoded)
}
/// Copies a slice that must be exactly 32 bytes long into a fixed-size array.
///
/// # Errors
/// Returns "Not valid 32 bytes" when `bytes.len() != 32`.
#[inline]
fn to_bytes32(bytes: &[u8]) -> XResult<[u8; 32]> {
    if bytes.len() != 32 {
        return simple_error!("Not valid 32 bytes");
    }
    let mut ret = [0u8; 32];
    // Lengths are equal here, so this cannot panic; it replaces the manual index loop.
    ret.copy_from_slice(bytes);
    Ok(ret)
}
/// Derives a short fingerprint of `key` for the JWE header's `version` field.
///
/// SHA-256 is applied six times in succession. The first 8 bytes of the final digest
/// are hex-encoded, giving 16 hex characters.
fn get_master_key_checksum(key: &[u8]) -> String {
    // One initial hash plus five re-hashes of the previous digest.
    const EXTRA_ROUNDS: usize = 5;
    let mut digest = Sha256::digest(key);
    for _ in 0..EXTRA_ROUNDS {
        digest = Sha256::digest(digest.as_slice());
    }
    hex::encode(&digest[..8])
}
fn get_jwe(jwe: &str) -> String {
if jwe.starts_with(LOCAL_KMS_PREFIX) {
jwe.chars().skip(LOCAL_KMS_PREFIX.len()).collect()
} else {
jwe.to_string()
}
}