145 lines
3.9 KiB
Rust
145 lines
3.9 KiB
Rust
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
|
use base64::Engine;
|
|
use rust_util::{opt_result, simple_error, SimpleError, XResult};
|
|
use serde::{Deserialize, Serialize};
|
|
use sm4_gcm::{sm4_gcm_encrypt, Sm4GcmStreamEncryptor, Sm4Key};
|
|
use zeroize::Zeroize;
|
|
|
|
/// Value written to the JWE `enc` header parameter (content encrypted with SM4 in GCM mode).
const SM4GCM: &str = "SM4GCM";
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct JweHeader {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
kid: Option<String>,
|
|
enc: String,
|
|
alg: String,
|
|
}
|
|
|
|
/// Key-management algorithm, emitted as the JWE `alg` header parameter.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JweAlg {
    /// The JWK's own SM4 key encrypts the content directly (`"dir"`).
    Dir,
    /// `"SM2PKE"`; currently rejected by `Sm4Jwk::encrypt`.
    Sm2Pke,
    /// A random content key is generated and wrapped with the JWK's SM4 key (`"SM4SKW"`).
    Sm4Skw,
}

impl JweAlg {
    /// Returns the literal used for this algorithm in the JWE `alg` header.
    pub fn get_name(&self) -> &'static str {
        match *self {
            Self::Dir => "dir",
            Self::Sm2Pke => "SM2PKE",
            Self::Sm4Skw => "SM4SKW",
        }
    }
}
|
|
|
|
pub struct Sm4Jwk {
|
|
key_id: Option<String>,
|
|
sm4key: Sm4Key,
|
|
}
|
|
|
|
impl TryFrom<Vec<u8>> for Sm4Jwk {
|
|
type Error = SimpleError;
|
|
|
|
fn try_from(mut value: Vec<u8>) -> Result<Self, Self::Error> {
|
|
let sm4key = opt_result!(Sm4Key::from_slice(&value), "Invalid SM4 key: {}");
|
|
value.zeroize();
|
|
Ok(Sm4Jwk {
|
|
key_id: None,
|
|
sm4key,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl TryFrom<[u8; 16]> for Sm4Jwk {
|
|
type Error = SimpleError;
|
|
|
|
fn try_from(mut value: [u8; 16]) -> Result<Self, Self::Error> {
|
|
let sm4key = opt_result!(Sm4Key::from_slice(&value), "Invalid SM4 key: {}");
|
|
value.zeroize();
|
|
Ok(Sm4Jwk {
|
|
key_id: None,
|
|
sm4key,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Encrypts a short message with every `JweAlg` variant and checks the structure of the
/// resulting compact JWE, rather than only printing it.
#[test]
fn test() {
    let message = b"hello world";
    let k = Sm4Jwk::try_from(b"0123456789012345".to_vec())
        .unwrap()
        .key_id("k001");

    let b64 = |segment: &str| URL_SAFE_NO_PAD.decode(segment).unwrap();

    for (alg, alg_name) in [(JweAlg::Dir, "dir"), (JweAlg::Sm4Skw, "SM4SKW")] {
        let jwe = k.encrypt(alg, message).unwrap();
        println!("{}", jwe);

        let segments: Vec<&str> = jwe.split('.').collect();
        assert_eq!(segments.len(), 5, "compact JWE must have five segments: {}", jwe);

        // Protected header: serde emits fields in declaration order (kid, enc, alg).
        let header = String::from_utf8(b64(segments[0])).unwrap();
        assert_eq!(
            header,
            format!(r#"{{"kid":"k001","enc":"SM4GCM","alg":"{}"}}"#, alg_name)
        );

        // Only SM4SKW carries a wrapped content key (12-byte nonce + wrapped bytes).
        let encrypted_key = b64(segments[1]);
        if alg == JweAlg::Dir {
            assert!(encrypted_key.is_empty(), "dir must not carry an encrypted key");
        } else {
            assert!(encrypted_key.len() > 12, "wrapped key must follow its nonce");
        }

        assert_eq!(b64(segments[2]).len(), 12, "nonce length");
        // GCM is a stream mode: ciphertext length equals plaintext length.
        assert_eq!(b64(segments[3]).len(), message.len(), "ciphertext length");
        assert!(!b64(segments[4]).is_empty(), "authentication tag");

        // A fresh random nonce (and key, for SM4SKW) is drawn on every call.
        assert_ne!(jwe, k.encrypt(alg, message).unwrap());
    }

    assert!(k.encrypt(JweAlg::Sm2Pke, message).is_err());
}
|
|
|
|
impl Sm4Jwk {
|
|
pub fn key_id(mut self, key_id: &str) -> Self {
|
|
self.key_id = Some(key_id.to_string());
|
|
self
|
|
}
|
|
|
|
pub fn encrypt(&self, alg: JweAlg, message: &[u8]) -> XResult<String> {
|
|
if alg == JweAlg::Sm2Pke {
|
|
return simple_error!("SM2PKE is not supported");
|
|
}
|
|
|
|
let encrypted_temp_key;
|
|
let temp_key: [u8; 16];
|
|
if alg == JweAlg::Sm4Skw {
|
|
temp_key = rand::random();
|
|
encrypted_temp_key = encode_url_safe_no_pad(&encrypt_sm4ske(&self.sm4key, &temp_key)?);
|
|
} else {
|
|
temp_key = [0_u8; 16];
|
|
encrypted_temp_key = "".to_string();
|
|
}
|
|
|
|
let jwe_header = JweHeader {
|
|
kid: self.key_id.clone(),
|
|
enc: SM4GCM.to_string(),
|
|
alg: alg.get_name().to_string(),
|
|
};
|
|
|
|
let header_str = opt_result!(
|
|
serde_json::to_string(&jwe_header),
|
|
"serialize header failed: {}"
|
|
);
|
|
let header_base64 = encode_url_safe_no_pad(header_str.as_bytes());
|
|
|
|
let nonce: [u8; 12] = rand::random();
|
|
|
|
let mut encryptor;
|
|
if alg == JweAlg::Sm4Skw {
|
|
let temp_sm4_key = Sm4Key::from_slice(&temp_key)?;
|
|
encryptor = Sm4GcmStreamEncryptor::new(&temp_sm4_key, &nonce);
|
|
} else {
|
|
encryptor = Sm4GcmStreamEncryptor::new(&self.sm4key, &nonce);
|
|
}
|
|
encryptor.init_adata(header_base64.as_bytes());
|
|
|
|
let mut ciphertext = encryptor.update(message);
|
|
let (enc2, tag) = encryptor.finalize();
|
|
ciphertext.extend_from_slice(&enc2);
|
|
|
|
Ok(format!(
|
|
"{}.{}.{}.{}.{}",
|
|
header_base64,
|
|
encrypted_temp_key,
|
|
encode_url_safe_no_pad(&nonce),
|
|
encode_url_safe_no_pad(&ciphertext),
|
|
encode_url_safe_no_pad(&tag)
|
|
))
|
|
}
|
|
}
|
|
|
|
fn encrypt_sm4ske(sm4key: &Sm4Key, key: &[u8]) -> XResult<Vec<u8>> {
|
|
let nonce: [u8; 12] = rand::random();
|
|
let encrypted_key = sm4_gcm_encrypt(sm4key, &nonce, key);
|
|
let mut ske = nonce.to_vec();
|
|
ske.extend_from_slice(&encrypted_key);
|
|
Ok(ske)
|
|
}
|
|
|
|
fn encode_url_safe_no_pad(bytes: &[u8]) -> String {
|
|
URL_SAFE_NO_PAD.encode(bytes)
|
|
}
|