From 8657e5ffce4a4ef9a8232464e5cbc064e8928abd Mon Sep 17 00:00:00 2001 From: Hatter Jiang Date: Sun, 22 Oct 2023 14:24:25 +0800 Subject: [PATCH] feat: add sm2 encryption mode --- .gitignore | 1 + examples/encrypt_and_decrypt.rs | 18 +++++++++++++++++ src/sm2/encrypt.rs | 34 ++++++++++++++++++++++++--------- 3 files changed, 44 insertions(+), 9 deletions(-) create mode 100644 examples/encrypt_and_decrypt.rs diff --git a/.gitignore b/.gitignore index 3bf25c0..409abaa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea/ # ---> Rust # Generated by Cargo # will have compiled files and executables diff --git a/examples/encrypt_and_decrypt.rs b/examples/encrypt_and_decrypt.rs new file mode 100644 index 0000000..bfb313c --- /dev/null +++ b/examples/encrypt_and_decrypt.rs @@ -0,0 +1,18 @@ +use libsm::sm2::encrypt::{DecryptCtx, EncryptCtx, Sm2EncryptionMode}; +use libsm::sm2::signature::SigCtx; + +fn main() { + let msg = "hello world".as_bytes(); + let klen = msg.len(); + let ctx = SigCtx::new(); + let (pk_b, sk_b) = ctx.new_keypair().unwrap(); + + let encrypt_ctx = EncryptCtx::new(klen, pk_b); + let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C3C2).unwrap(); + + let decrypt_ctx = DecryptCtx::new(klen, sk_b); + let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C3C2).unwrap(); + + println!("{}", hex::encode(&cipher)); + println!("{}", String::from_utf8_lossy(&plain)); +} \ No newline at end of file diff --git a/src/sm2/encrypt.rs b/src/sm2/encrypt.rs index 33e404c..cd322f4 100644 --- a/src/sm2/encrypt.rs +++ b/src/sm2/encrypt.rs @@ -1,9 +1,16 @@ use num_bigint::BigUint; use num_traits::One; -use super::ecc::{EccCtx, Point}; -use crate::sm2::error::{Sm2Error, Sm2Result}; use crate::{sm2::util::kdf, sm3::hash::Sm3Hash}; +use crate::sm2::error::{Sm2Error, Sm2Result}; + +use super::ecc::{EccCtx, Point}; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum Sm2EncryptionMode { + C1C2C3, + C1C3C2, +} pub struct EncryptCtx { klen: usize, @@ -27,7 +34,7 @@ impl EncryptCtx { } // klen bytes, result: C1+C2+C3 - pub fn encrypt(&self, msg: &[u8]) -> Sm2Result> { + pub fn encrypt(&self, msg: &[u8], mode: Sm2EncryptionMode) -> Sm2Result> { loop { let k = self.curve.random_uint(); let c_1_point = self.curve.g_mul(&k)?; @@ -66,7 +73,11 @@ impl EncryptCtx { let c_3 = Sm3Hash::new(&prepend).get_hash(); let c_1_bytes = self.curve.point_to_bytes(&c_1_point, false)?; - let a = [c_1_bytes, t, c_3.to_vec()].concat(); + let a = if mode == Sm2EncryptionMode::C1C2C3 { + [c_1_bytes, t, c_3.to_vec()].concat() + } else { + [c_1_bytes, c_3.to_vec(), t].concat() + }; return Ok(a); } } @@ -82,7 +93,7 @@ impl DecryptCtx { } } - pub fn decrypt(&self, cipher: &[u8]) -> Sm2Result> { + pub fn decrypt(&self, cipher: &[u8], mode: Sm2EncryptionMode) -> Sm2Result> { let c_1_bytes = &cipher[0..65]; let c_1_point = self.curve.bytes_to_point(c_1_bytes)?; // if c_1_point not in curve, return error, todo return error @@ -115,7 +126,13 @@ impl DecryptCtx { if flag { return Err(Sm2Error::ZeroData); } - let mut c_2 = cipher[65..(65 + self.klen)].to_vec(); + + let (mut c_2, c_3) = if mode == Sm2EncryptionMode::C1C2C3 { + (cipher[65..(65 + self.klen)].to_vec(), &cipher[(65 + self.klen)..]) + } else { + (cipher[(65 + 32)..].to_vec(), &cipher[65..(65 + 32)]) + }; + for i in 0..self.klen { c_2[i] ^= t[i]; } @@ -123,7 +140,6 @@ impl DecryptCtx { prepend.extend_from_slice(&x_2_bytes); prepend.extend_from_slice(&c_2); prepend.extend_from_slice(&y_2_bytes); - let c_3 = &cipher[(65 + self.klen)..]; let u = Sm3Hash::new(&prepend).get_hash(); if c_3 != u { return Err(Sm2Error::HashNotEqual); @@ -146,10 +162,10 @@ mod tests { let (pk_b, sk_b) = ctx.new_keypair().unwrap(); let encrypt_ctx = EncryptCtx::new(klen, pk_b); - let cipher = encrypt_ctx.encrypt(msg).unwrap(); + let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C2C3).unwrap(); let decrypt_ctx = DecryptCtx::new(klen, sk_b); - let plain = decrypt_ctx.decrypt(&cipher).unwrap(); + let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C2C3).unwrap(); assert_eq!(msg, plain); } }