From f2fa7594dcf9d7dd2c029d907f602854e85c9ffb Mon Sep 17 00:00:00 2001 From: Hatter Jiang Date: Fri, 25 Jul 2025 00:57:02 +0800 Subject: [PATCH] feat: works --- src/lib.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5799c7e..a89a235 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ -use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; -use rust_util::{opt_result, simple_error, SimpleError, XResult}; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use rust_util::{SimpleError, XResult, opt_result, simple_error}; use serde::{Deserialize, Serialize}; -use sm4_gcm::{sm4_gcm_encrypt, Sm4GcmStreamEncryptor, Sm4Key}; +use sm4_gcm::{ + Sm4GcmStreamDecryptor, Sm4GcmStreamEncryptor, Sm4Key, sm4_gcm_decrypt, sm4_gcm_encrypt, +}; use zeroize::Zeroize; const SM4GCM: &str = "SM4GCM"; @@ -70,6 +72,14 @@ fn test() { .key_id("k001"); println!("{}", k.encrypt(JweAlg::Dir, b"hello world").unwrap()); println!("{}", k.encrypt(JweAlg::Sm4Skw, b"hello world").unwrap()); + + let key_finder = |_key_id| -> XResult { Ok(Sm4Key(k.sm4key.0)) }; + let plaintext = Sm4Jwk::decrypt(key_finder, + "eyJraWQiOiJrMDAxIiwiZW5jIjoiU000R0NNIiwiYWxnIjoiZGlyIn0..R0vzs8r5gTzz72ve.K6iaSJpyel_IZUw.iLPlxGV-xpBe9O5vZ8oz9A").unwrap(); + println!("1:{}", String::from_utf8_lossy(&plaintext)); + let plaintext = Sm4Jwk::decrypt(key_finder, + "eyJraWQiOiJrMDAxIiwiZW5jIjoiU000R0NNIiwiYWxnIjoiU000U0tXIn0.Yf2NSJSr-coYZP2IhzjtcgYfRq1Nalt7MyrRMgUm8hqD4vZ8QPMyp2brhGY.tIhwFoCSnzWcycVY.KaMzkCzGSmylUdY.du5GLF6lMz3vGJuSO614aw").unwrap(); + println!("2:{}", String::from_utf8_lossy(&plaintext)); } impl Sm4Jwk { @@ -78,6 +88,65 @@ impl Sm4Jwk { self } + pub fn decrypt( + key_finder: impl Fn(Option) -> XResult, + jwe: &str, + ) -> XResult> { + let jwe_parts: Vec<&str> = jwe.split(".").collect(); + if jwe_parts.len() != 5 { + return simple_error!( + "invalid JWE format, expect 5 parts, actual {} part(s)", + jwe_parts.len() + ); + } + let header_base64 = jwe_parts[0]; + let header_bytes = decode_url_safe_no_pad("header", header_base64)?; + let header: JweHeader = opt_result!( + serde_json::from_slice(&header_bytes), + "invalid JWE header: {}" + ); + let alg = match header.alg.as_str() { + "dir" => JweAlg::Dir, + "SM4SKW" => JweAlg::Sm4Skw, + _ => return simple_error!("invalid JWE alg: {}", header.alg), + }; + + let key_id = header.kid.as_deref(); + let sm4_key = opt_result!( + key_finder(key_id.map(|s| s.to_string())), + "find key: {:?} failed: {}", + key_id + ); + + let mut temp_key = vec![]; + if alg == JweAlg::Sm4Skw { + let encrypted_key_bytes = decode_url_safe_no_pad("encrypted wrap key", jwe_parts[1])?; + temp_key = decrypt_sm4ske(&sm4_key, &encrypted_key_bytes)?; + } + + let nonce_bytes = decode_url_safe_no_pad("nonce", jwe_parts[2])?; + + let mut decryptor; + if alg == JweAlg::Sm4Skw { + let temp_sm4_key = Sm4Key::from_slice(&temp_key)?; + decryptor = Sm4GcmStreamDecryptor::new(&temp_sm4_key, &nonce_bytes); + } else { + decryptor = Sm4GcmStreamDecryptor::new(&sm4_key, &nonce_bytes); + } + + let ciphertext = decode_url_safe_no_pad("ciphertext", jwe_parts[3])?; + let tag = decode_url_safe_no_pad("tag", jwe_parts[4])?; + + decryptor.init_adata(header_base64.as_bytes()); + let mut plaintext = decryptor.update(&ciphertext); + let plaintext2 = decryptor.update(&tag); + let plaintext3 = opt_result!(decryptor.finalize(), "decrypt JWE failed: {}"); + plaintext.extend_from_slice(&plaintext2); + plaintext.extend_from_slice(&plaintext3); + + Ok(plaintext) + } + pub fn encrypt(&self, alg: JweAlg, message: &[u8]) -> XResult { if alg == JweAlg::Sm2Pke { return simple_error!("SM2PKE is not supported"); @@ -131,6 +200,16 @@ impl Sm4Jwk { } } +fn decrypt_sm4ske(sm4jwk: &Sm4Key, wrap_key: &[u8]) -> XResult> { + let nonce = &wrap_key[0..12]; + let ciphertext = &wrap_key[12..]; + let temp_key = opt_result!( + sm4_gcm_decrypt(sm4jwk, nonce, ciphertext), + "unwrap temp key failed: {}" + ); + Ok(temp_key) +} + fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult> { let nonce: [u8; 12] = rand::random(); let encrypted_key = sm4_gcm_encrypt(sm4key, &nonce, key); @@ -139,6 +218,14 @@ fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult> { Ok(ske) } +fn decode_url_safe_no_pad(tag: &str, s: &str) -> XResult> { + Ok(opt_result!( + URL_SAFE_NO_PAD.decode(s), + "decode {} failed: {}", + tag + )) +} + fn encode_url_safe_no_pad(bytes: &[u8]) -> String { URL_SAFE_NO_PAD.encode(bytes) }