feat: add sm2 encryption mode

This commit is contained in:
2023-10-22 14:24:25 +08:00
parent bcd06779ca
commit 8657e5ffce
3 changed files with 44 additions and 9 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.idea/
# ---> Rust
# Generated by Cargo
# will have compiled files and executables

View File

@@ -0,0 +1,18 @@
use libsm::sm2::encrypt::{DecryptCtx, EncryptCtx, Sm2EncryptionMode};
use libsm::sm2::signature::SigCtx;
fn main() {
let msg = "hello world".as_bytes();
let klen = msg.len();
let ctx = SigCtx::new();
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C3C2).unwrap();
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C3C2).unwrap();
println!("{}", hex::encode(&cipher));
println!("{}", String::from_utf8_lossy(&plain));
}

View File

@@ -1,9 +1,16 @@
use num_bigint::BigUint;
use num_traits::One;
use super::ecc::{EccCtx, Point};
use crate::sm2::error::{Sm2Error, Sm2Result};
use crate::{sm2::util::kdf, sm3::hash::Sm3Hash};
use crate::sm2::error::{Sm2Error, Sm2Result};
use super::ecc::{EccCtx, Point};
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Sm2EncryptionMode {
C1C2C3,
C1C3C2,
}
pub struct EncryptCtx {
klen: usize,
@@ -27,7 +34,7 @@ impl EncryptCtx {
}
// klen bytes, result: C1+C2+C3
pub fn encrypt(&self, msg: &[u8]) -> Sm2Result<Vec<u8>> {
pub fn encrypt(&self, msg: &[u8], mode: Sm2EncryptionMode) -> Sm2Result<Vec<u8>> {
loop {
let k = self.curve.random_uint();
let c_1_point = self.curve.g_mul(&k)?;
@@ -66,7 +73,11 @@ impl EncryptCtx {
let c_3 = Sm3Hash::new(&prepend).get_hash();
let c_1_bytes = self.curve.point_to_bytes(&c_1_point, false)?;
let a = [c_1_bytes, t, c_3.to_vec()].concat();
let a = if mode == Sm2EncryptionMode::C1C2C3 {
[c_1_bytes, t, c_3.to_vec()].concat()
} else {
[c_1_bytes, c_3.to_vec(), t].concat()
};
return Ok(a);
}
}
@@ -82,7 +93,7 @@ impl DecryptCtx {
}
}
pub fn decrypt(&self, cipher: &[u8]) -> Sm2Result<Vec<u8>> {
pub fn decrypt(&self, cipher: &[u8], mode: Sm2EncryptionMode) -> Sm2Result<Vec<u8>> {
let c_1_bytes = &cipher[0..65];
let c_1_point = self.curve.bytes_to_point(c_1_bytes)?;
// if c_1_point not in curve, return error, todo return error
@@ -115,7 +126,13 @@ impl DecryptCtx {
if flag {
return Err(Sm2Error::ZeroData);
}
let mut c_2 = cipher[65..(65 + self.klen)].to_vec();
let (mut c_2, c_3) = if mode == Sm2EncryptionMode::C1C2C3 {
(cipher[65..(65 + self.klen)].to_vec(), &cipher[(65 + self.klen)..])
} else {
(cipher[(65 + 32)..].to_vec(), &cipher[65..(65 + 32)])
};
for i in 0..self.klen {
c_2[i] ^= t[i];
}
@@ -123,7 +140,6 @@ impl DecryptCtx {
prepend.extend_from_slice(&x_2_bytes);
prepend.extend_from_slice(&c_2);
prepend.extend_from_slice(&y_2_bytes);
let c_3 = &cipher[(65 + self.klen)..];
let u = Sm3Hash::new(&prepend).get_hash();
if c_3 != u {
return Err(Sm2Error::HashNotEqual);
@@ -146,10 +162,10 @@ mod tests {
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
let cipher = encrypt_ctx.encrypt(msg).unwrap();
let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C2C3).unwrap();
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
let plain = decrypt_ctx.decrypt(&cipher).unwrap();
let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C2C3).unwrap();
assert_eq!(msg, plain);
}
}