feat: works

This commit is contained in:
2025-07-25 00:57:02 +08:00
parent 75f9c87738
commit f2fa7594dc

View File

@@ -1,8 +1,10 @@
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use rust_util::{opt_result, simple_error, SimpleError, XResult};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use rust_util::{SimpleError, XResult, opt_result, simple_error};
use serde::{Deserialize, Serialize};
use sm4_gcm::{sm4_gcm_encrypt, Sm4GcmStreamEncryptor, Sm4Key};
use sm4_gcm::{
Sm4GcmStreamDecryptor, Sm4GcmStreamEncryptor, Sm4Key, sm4_gcm_decrypt, sm4_gcm_encrypt,
};
use zeroize::Zeroize;
const SM4GCM: &str = "SM4GCM";
@@ -70,6 +72,14 @@ fn test() {
.key_id("k001");
println!("{}", k.encrypt(JweAlg::Dir, b"hello world").unwrap());
println!("{}", k.encrypt(JweAlg::Sm4Skw, b"hello world").unwrap());
let key_finder = |_key_id| -> XResult<Sm4Key> { Ok(Sm4Key(k.sm4key.0)) };
let plaintext = Sm4Jwk::decrypt(key_finder,
"eyJraWQiOiJrMDAxIiwiZW5jIjoiU000R0NNIiwiYWxnIjoiZGlyIn0..R0vzs8r5gTzz72ve.K6iaSJpyel_IZUw.iLPlxGV-xpBe9O5vZ8oz9A").unwrap();
println!("1:{}", String::from_utf8_lossy(&plaintext));
let plaintext = Sm4Jwk::decrypt(key_finder,
"eyJraWQiOiJrMDAxIiwiZW5jIjoiU000R0NNIiwiYWxnIjoiU000U0tXIn0.Yf2NSJSr-coYZP2IhzjtcgYfRq1Nalt7MyrRMgUm8hqD4vZ8QPMyp2brhGY.tIhwFoCSnzWcycVY.KaMzkCzGSmylUdY.du5GLF6lMz3vGJuSO614aw").unwrap();
println!("2:{}", String::from_utf8_lossy(&plaintext));
}
impl Sm4Jwk {
@@ -78,6 +88,65 @@ impl Sm4Jwk {
self
}
pub fn decrypt(
key_finder: impl Fn(Option<String>) -> XResult<Sm4Key>,
jwe: &str,
) -> XResult<Vec<u8>> {
let jwe_parts: Vec<&str> = jwe.split(".").collect();
if jwe_parts.len() != 5 {
return simple_error!(
"invalid JWE format, expect 5 parts, actual {} part(s)",
jwe_parts.len()
);
}
let header_base64 = jwe_parts[0];
let header_bytes = decode_url_safe_no_pad("header", header_base64)?;
let header: JweHeader = opt_result!(
serde_json::from_slice(&header_bytes),
"invalid JWE header: {}"
);
let alg = match header.alg.as_str() {
"dir" => JweAlg::Dir,
"SM4SKW" => JweAlg::Sm4Skw,
_ => return simple_error!("invalid JWE alg: {}", header.alg),
};
let key_id = header.kid.as_deref();
let sm4_key = opt_result!(
key_finder(key_id.map(|s| s.to_string())),
"find key: {:?} failed: {}",
key_id
);
let mut temp_key = vec![];
if alg == JweAlg::Sm4Skw {
let encrypted_key_bytes = decode_url_safe_no_pad("encrypted wrap key", jwe_parts[1])?;
temp_key = decrypt_sm4ske(&sm4_key, &encrypted_key_bytes)?;
}
let nonce_bytes = decode_url_safe_no_pad("nonce", jwe_parts[2])?;
let mut decryptor;
if alg == JweAlg::Sm4Skw {
let temp_sm4_key = Sm4Key::from_slice(&temp_key)?;
decryptor = Sm4GcmStreamDecryptor::new(&temp_sm4_key, &nonce_bytes);
} else {
decryptor = Sm4GcmStreamDecryptor::new(&sm4_key, &nonce_bytes);
}
let ciphertext = decode_url_safe_no_pad("ciphertext", jwe_parts[3])?;
let tag = decode_url_safe_no_pad("tag", jwe_parts[4])?;
decryptor.init_adata(header_base64.as_bytes());
let mut plaintext = decryptor.update(&ciphertext);
let plaintext2 = decryptor.update(&tag);
let plaintext3 = opt_result!(decryptor.finalize(), "decrypt JWE failed: {}");
plaintext.extend_from_slice(&plaintext2);
plaintext.extend_from_slice(&plaintext3);
Ok(plaintext)
}
pub fn encrypt(&self, alg: JweAlg, message: &[u8]) -> XResult<String> {
if alg == JweAlg::Sm2Pke {
return simple_error!("SM2PKE is not supported");
@@ -131,6 +200,16 @@ impl Sm4Jwk {
}
}
fn decrypt_sm4ske(sm4jwk: &Sm4Key, wrap_key: &[u8]) -> XResult<Vec<u8>> {
let nonce = &wrap_key[0..12];
let ciphertext = &wrap_key[12..];
let temp_key = opt_result!(
sm4_gcm_decrypt(sm4jwk, nonce, ciphertext),
"unwrap temp key failed: {}"
);
Ok(temp_key)
}
fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult<Vec<u8>> {
let nonce: [u8; 12] = rand::random();
let encrypted_key = sm4_gcm_encrypt(sm4key, &nonce, key);
@@ -139,6 +218,14 @@ fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult<Vec<u8>> {
Ok(ske)
}
fn decode_url_safe_no_pad(tag: &str, s: &str) -> XResult<Vec<u8>> {
Ok(opt_result!(
URL_SAFE_NO_PAD.decode(s),
"decode {} failed: {}",
tag
))
}
fn encode_url_safe_no_pad(bytes: &[u8]) -> String {
URL_SAFE_NO_PAD.encode(bytes)
}