// sm4-jwe/src/lib.rs
// Last modified: 2025-07-24 23:51:36 +08:00 (source viewer listed: 145 lines, 3.9 KiB, Rust)

use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use rust_util::{opt_result, simple_error, SimpleError, XResult};
use serde::{Deserialize, Serialize};
use sm4_gcm::{sm4_gcm_encrypt, Sm4GcmStreamEncryptor, Sm4Key};
use zeroize::Zeroize;
/// Value written to the JWE `enc` header member; the payload is encrypted with
/// `Sm4GcmStreamEncryptor` (SM4 in GCM mode).
const SM4GCM: &str = "SM4GCM";
/// JOSE header of the compact JWE produced by [`Sm4Jwk::encrypt`].
///
/// It is serialized to JSON and base64url-encoded as the first segment of the token.
/// That encoded string is also fed to the SM4-GCM encryptor as additional
/// authenticated data.
///
/// NOTE: serde emits members in field declaration order. Reordering the fields
/// changes the serialized header bytes, and therefore the AAD.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JweHeader {
    /// Key identifier; the `kid` member is omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    kid: Option<String>,
    /// Content-encryption algorithm name (set to `SM4GCM` by `Sm4Jwk::encrypt`).
    enc: String,
    /// Key-management algorithm name (taken from `JweAlg::get_name`).
    alg: String,
}
/// Key-management algorithm, written to the JWE `alg` header member.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum JweAlg {
    /// `dir`: the content is encrypted directly with the shared SM4 key.
    Dir,
    /// `SM2PKE`: currently rejected by `Sm4Jwk::encrypt`.
    Sm2Pke,
    /// `SM4SKW`: a random content key is wrapped with the shared SM4 key.
    Sm4Skw,
}

impl JweAlg {
    /// Returns the string placed in the JWE header's `alg` member for this variant.
    pub fn get_name(&self) -> &'static str {
        match *self {
            Self::Dir => "dir",
            Self::Sm2Pke => "SM2PKE",
            Self::Sm4Skw => "SM4SKW",
        }
    }
}
/// An SM4 symmetric key used to produce compact JWEs, with an optional key id.
pub struct Sm4Jwk {
    /// Written as the JWE header `kid` member when set.
    key_id: Option<String>,
    /// Shared SM4 key. It is used directly for `dir` and to wrap a random
    /// content key for `SM4SKW`.
    sm4key: Sm4Key,
}

impl Sm4Jwk {
    /// Parses raw key bytes into an `Sm4Jwk` with no key id.
    ///
    /// This does not wipe `key`. Callers are responsible for zeroizing their buffer.
    fn from_key_bytes(key: &[u8]) -> Result<Self, SimpleError> {
        let sm4key = opt_result!(Sm4Key::from_slice(key), "Invalid SM4 key: {}");
        Ok(Sm4Jwk {
            key_id: None,
            sm4key,
        })
    }
}

impl TryFrom<Vec<u8>> for Sm4Jwk {
    type Error = SimpleError;

    /// Builds a key from a byte vector.
    ///
    /// The vector is zeroized whether or not parsing succeeds. The original code
    /// skipped the wipe on the error path because `opt_result!` returned early.
    fn try_from(mut value: Vec<u8>) -> Result<Self, Self::Error> {
        let result = Sm4Jwk::from_key_bytes(&value);
        value.zeroize();
        result
    }
}

impl TryFrom<[u8; 16]> for Sm4Jwk {
    type Error = SimpleError;

    /// Builds a key from a 16-byte array.
    ///
    /// The array is `Copy`, so only this function's local copy is zeroized.
    /// The caller's array is untouched. The wipe happens before any error
    /// is propagated.
    fn try_from(mut value: [u8; 16]) -> Result<Self, Self::Error> {
        let result = Sm4Jwk::from_key_bytes(&value);
        value.zeroize();
        result
    }
}
/// End-to-end check of `Sm4Jwk::encrypt`.
///
/// Covers the compact-serialization shape, header contents, segment sizes,
/// nonce randomization, unsupported algorithm handling, and key validation.
#[test]
fn test() {
    let message: &[u8] = b"hello world";
    let k = Sm4Jwk::try_from(b"0123456789012345".to_vec())
        .unwrap()
        .key_id("k001");

    // `dir`: five segments, and the encrypted-key segment is empty.
    let dir_jwe = k.encrypt(JweAlg::Dir, message).unwrap();
    println!("{}", dir_jwe);
    let parts: Vec<&str> = dir_jwe.split('.').collect();
    assert_eq!(parts.len(), 5);
    assert!(parts[1].is_empty(), "dir must not carry an encrypted key");
    let header: JweHeader =
        serde_json::from_slice(&URL_SAFE_NO_PAD.decode(parts[0]).unwrap()).unwrap();
    assert_eq!(header.alg, "dir");
    assert_eq!(header.enc, SM4GCM);
    assert_eq!(header.kid.as_deref(), Some("k001"));
    // 12-byte nonce; GCM ciphertext is the same length as the plaintext.
    assert_eq!(URL_SAFE_NO_PAD.decode(parts[2]).unwrap().len(), 12);
    assert_eq!(URL_SAFE_NO_PAD.decode(parts[3]).unwrap().len(), message.len());
    // A fresh random nonce per call means identical input never repeats output.
    assert_ne!(dir_jwe, k.encrypt(JweAlg::Dir, message).unwrap());

    // `SM4SKW`: the encrypted-key segment holds nonce + wrapped 16-byte key (+ tag).
    let skw_jwe = k.encrypt(JweAlg::Sm4Skw, message).unwrap();
    println!("{}", skw_jwe);
    let parts: Vec<&str> = skw_jwe.split('.').collect();
    assert_eq!(parts.len(), 5);
    let header: JweHeader =
        serde_json::from_slice(&URL_SAFE_NO_PAD.decode(parts[0]).unwrap()).unwrap();
    assert_eq!(header.alg, "SM4SKW");
    assert!(URL_SAFE_NO_PAD.decode(parts[1]).unwrap().len() > 12 + 16);

    // SM2PKE is explicitly unsupported.
    assert!(k.encrypt(JweAlg::Sm2Pke, message).is_err());

    // Without a key id, `kid` is omitted from the header JSON.
    let anonymous = Sm4Jwk::try_from([7_u8; 16]).unwrap();
    let jwe = anonymous.encrypt(JweAlg::Dir, b"x").unwrap();
    let header_json = URL_SAFE_NO_PAD
        .decode(jwe.split('.').next().unwrap())
        .unwrap();
    assert!(!String::from_utf8(header_json).unwrap().contains("kid"));

    // Wrong key length is rejected.
    assert!(Sm4Jwk::try_from(vec![0_u8; 15]).is_err());
}
impl Sm4Jwk {
    /// Sets the key identifier emitted as the `kid` member of the JWE header.
    ///
    /// Consumes and returns `self`, so it can be chained after construction.
    pub fn key_id(mut self, key_id: &str) -> Self {
        self.key_id = Some(key_id.to_string());
        self
    }

    /// Encrypts `message` and returns a JWE in compact serialization.
    ///
    /// The output has five dot-separated, base64url (no padding) segments:
    /// `header.encrypted_key.nonce.ciphertext.tag`.
    ///
    /// * [`JweAlg::Dir`]: `message` is encrypted with this key directly, and the
    ///   encrypted-key segment is empty.
    /// * [`JweAlg::Sm4Skw`]: a random 16-byte content key is generated and
    ///   wrapped with this key via `encrypt_sm4ske` (random 12-byte nonce
    ///   followed by the SM4-GCM output). The content key then encrypts
    ///   `message`.
    ///
    /// In both cases the base64url-encoded header is used as additional
    /// authenticated data, with a fresh random 12-byte nonce.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// * `alg` is [`JweAlg::Sm2Pke`] (not supported).
    /// * The header cannot be serialized.
    /// * Wrapping the content key fails.
    /// * Building the content key fails.
    pub fn encrypt(&self, alg: JweAlg, message: &[u8]) -> XResult<String> {
        // Resolve the encrypted-key segment and, for SM4SKW, the per-message
        // content key. `None` means "encrypt with `self.sm4key` directly".
        let (encrypted_temp_key, temp_sm4_key) = match alg {
            JweAlg::Sm2Pke => return simple_error!("SM2PKE is not supported"),
            JweAlg::Dir => (String::new(), None),
            JweAlg::Sm4Skw => {
                let mut temp_key: [u8; 16] = rand::random();
                let wrapped = encrypt_sm4ske(&self.sm4key, &temp_key);
                let temp_sm4_key = Sm4Key::from_slice(&temp_key);
                // Wipe the raw content-key bytes before any `?` below can return early.
                // NOTE(review): whether `Sm4Key` itself wipes on drop is not visible here.
                temp_key.zeroize();
                (encode_url_safe_no_pad(&wrapped?), Some(temp_sm4_key?))
            }
        };
        let content_key = temp_sm4_key.as_ref().unwrap_or(&self.sm4key);

        let jwe_header = JweHeader {
            kid: self.key_id.clone(),
            enc: SM4GCM.to_string(),
            alg: alg.get_name().to_string(),
        };
        let header_str = opt_result!(
            serde_json::to_string(&jwe_header),
            "serialize header failed: {}"
        );
        let header_base64 = encode_url_safe_no_pad(header_str.as_bytes());

        let nonce: [u8; 12] = rand::random();
        let mut encryptor = Sm4GcmStreamEncryptor::new(content_key, &nonce);
        // The encoded header is authenticated (not encrypted), binding it to the ciphertext.
        encryptor.init_adata(header_base64.as_bytes());
        let mut ciphertext = encryptor.update(message);
        let (enc2, tag) = encryptor.finalize();
        ciphertext.extend_from_slice(&enc2);

        Ok(format!(
            "{}.{}.{}.{}.{}",
            header_base64,
            encrypted_temp_key,
            encode_url_safe_no_pad(&nonce),
            encode_url_safe_no_pad(&ciphertext),
            encode_url_safe_no_pad(&tag)
        ))
    }
}
/// Wraps `key` under `sm4key` with SM4-GCM using a freshly generated 12-byte nonce.
///
/// The returned buffer is the nonce followed by the `sm4_gcm_encrypt` output.
/// This never fails; the `XResult` return type is kept for callers that use `?`.
fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult<Vec<u8>> {
    let wrap_nonce: [u8; 12] = rand::random();
    let wrapped = sm4_gcm_encrypt(sm4key, &wrap_nonce, key);

    let mut framed = Vec::with_capacity(wrap_nonce.len() + wrapped.len());
    framed.extend_from_slice(&wrap_nonce);
    framed.extend_from_slice(&wrapped);
    Ok(framed)
}
fn encode_url_safe_no_pad(bytes: &[u8]) -> String {
URL_SAFE_NO_PAD.encode(bytes)
}