use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use rust_util::{opt_result, simple_error, SimpleError, XResult}; use serde::{Deserialize, Serialize}; use sm4_gcm::{sm4_gcm_encrypt, Sm4GcmStreamEncryptor, Sm4Key}; use zeroize::Zeroize; const SM4GCM: &str = "SM4GCM"; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JweHeader { #[serde(skip_serializing_if = "Option::is_none")] kid: Option, enc: String, alg: String, } #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum JweAlg { Dir, Sm2Pke, Sm4Skw, } impl JweAlg { pub fn get_name(&self) -> &'static str { match self { JweAlg::Dir => "dir", JweAlg::Sm2Pke => "SM2PKE", JweAlg::Sm4Skw => "SM4SKW", } } } pub struct Sm4Jwk { key_id: Option, sm4key: Sm4Key, } impl TryFrom> for Sm4Jwk { type Error = SimpleError; fn try_from(mut value: Vec) -> Result { let sm4key = opt_result!(Sm4Key::from_slice(&value), "Invalid SM4 key: {}"); value.zeroize(); Ok(Sm4Jwk { key_id: None, sm4key, }) } } impl TryFrom<[u8; 16]> for Sm4Jwk { type Error = SimpleError; fn try_from(mut value: [u8; 16]) -> Result { let sm4key = opt_result!(Sm4Key::from_slice(&value), "Invalid SM4 key: {}"); value.zeroize(); Ok(Sm4Jwk { key_id: None, sm4key, }) } } #[test] fn test() { let k = Sm4Jwk::try_from(b"0123456789012345".to_vec()) .unwrap() .key_id("k001"); println!("{}", k.encrypt(JweAlg::Dir, b"hello world").unwrap()); println!("{}", k.encrypt(JweAlg::Sm4Skw, b"hello world").unwrap()); } impl Sm4Jwk { pub fn key_id(mut self, key_id: &str) -> Self { self.key_id = Some(key_id.to_string()); self } pub fn encrypt(&self, alg: JweAlg, message: &[u8]) -> XResult { if alg == JweAlg::Sm2Pke { return simple_error!("SM2PKE is not supported"); } let encrypted_temp_key; let temp_key: [u8; 16]; if alg == JweAlg::Sm4Skw { temp_key = rand::random(); encrypted_temp_key = encode_url_safe_no_pad(&encrypt_sm4ske(&self.sm4key, &temp_key)?); } else { temp_key = [0_u8; 16]; encrypted_temp_key = "".to_string(); } let jwe_header = JweHeader { kid: self.key_id.clone(), enc: SM4GCM.to_string(), alg: alg.get_name().to_string(), }; let header_str = opt_result!( serde_json::to_string(&jwe_header), "serialize header failed: {}" ); let header_base64 = encode_url_safe_no_pad(header_str.as_bytes()); let nonce: [u8; 12] = rand::random(); let mut encryptor; if alg == JweAlg::Sm4Skw { let temp_sm4_key = Sm4Key::from_slice(&temp_key)?; encryptor = Sm4GcmStreamEncryptor::new(&temp_sm4_key, &nonce); } else { encryptor = Sm4GcmStreamEncryptor::new(&self.sm4key, &nonce); } encryptor.init_adata(header_base64.as_bytes()); let mut ciphertext = encryptor.update(message); let (enc2, tag) = encryptor.finalize(); ciphertext.extend_from_slice(&enc2); Ok(format!( "{}.{}.{}.{}.{}", header_base64, encrypted_temp_key, encode_url_safe_no_pad(&nonce), encode_url_safe_no_pad(&ciphertext), encode_url_safe_no_pad(&tag) )) } } fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult> { let nonce: [u8; 12] = rand::random(); let encrypted_key = sm4_gcm_encrypt(sm4key, &nonce, key); let mut ske = nonce.to_vec(); ske.extend_from_slice(&encrypted_key); Ok(ske) } fn encode_url_safe_no_pad(bytes: &[u8]) -> String { URL_SAFE_NO_PAD.encode(bytes) }