feat: add sm2 encryption mode
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
.idea/
|
||||
# ---> Rust
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
|
||||
18
examples/encrypt_and_decrypt.rs
Normal file
18
examples/encrypt_and_decrypt.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use libsm::sm2::encrypt::{DecryptCtx, EncryptCtx, Sm2EncryptionMode};
|
||||
use libsm::sm2::signature::SigCtx;
|
||||
|
||||
fn main() {
|
||||
let msg = "hello world".as_bytes();
|
||||
let klen = msg.len();
|
||||
let ctx = SigCtx::new();
|
||||
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
|
||||
|
||||
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
|
||||
let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C3C2).unwrap();
|
||||
|
||||
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
|
||||
let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C3C2).unwrap();
|
||||
|
||||
println!("{}", hex::encode(&cipher));
|
||||
println!("{}", String::from_utf8_lossy(&plain));
|
||||
}
|
||||
@@ -1,9 +1,16 @@
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::One;
|
||||
|
||||
use super::ecc::{EccCtx, Point};
|
||||
use crate::sm2::error::{Sm2Error, Sm2Result};
|
||||
use crate::{sm2::util::kdf, sm3::hash::Sm3Hash};
|
||||
use crate::sm2::error::{Sm2Error, Sm2Result};
|
||||
|
||||
use super::ecc::{EccCtx, Point};
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub enum Sm2EncryptionMode {
|
||||
C1C2C3,
|
||||
C1C3C2,
|
||||
}
|
||||
|
||||
pub struct EncryptCtx {
|
||||
klen: usize,
|
||||
@@ -27,7 +34,7 @@ impl EncryptCtx {
|
||||
}
|
||||
|
||||
// klen bytes, result: C1+C2+C3
|
||||
pub fn encrypt(&self, msg: &[u8]) -> Sm2Result<Vec<u8>> {
|
||||
pub fn encrypt(&self, msg: &[u8], mode: Sm2EncryptionMode) -> Sm2Result<Vec<u8>> {
|
||||
loop {
|
||||
let k = self.curve.random_uint();
|
||||
let c_1_point = self.curve.g_mul(&k)?;
|
||||
@@ -66,7 +73,11 @@ impl EncryptCtx {
|
||||
let c_3 = Sm3Hash::new(&prepend).get_hash();
|
||||
let c_1_bytes = self.curve.point_to_bytes(&c_1_point, false)?;
|
||||
|
||||
let a = [c_1_bytes, t, c_3.to_vec()].concat();
|
||||
let a = if mode == Sm2EncryptionMode::C1C2C3 {
|
||||
[c_1_bytes, t, c_3.to_vec()].concat()
|
||||
} else {
|
||||
[c_1_bytes, c_3.to_vec(), t].concat()
|
||||
};
|
||||
return Ok(a);
|
||||
}
|
||||
}
|
||||
@@ -82,7 +93,7 @@ impl DecryptCtx {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt(&self, cipher: &[u8]) -> Sm2Result<Vec<u8>> {
|
||||
pub fn decrypt(&self, cipher: &[u8], mode: Sm2EncryptionMode) -> Sm2Result<Vec<u8>> {
|
||||
let c_1_bytes = &cipher[0..65];
|
||||
let c_1_point = self.curve.bytes_to_point(c_1_bytes)?;
|
||||
// if c_1_point not in curve, return error, todo return error
|
||||
@@ -115,7 +126,13 @@ impl DecryptCtx {
|
||||
if flag {
|
||||
return Err(Sm2Error::ZeroData);
|
||||
}
|
||||
let mut c_2 = cipher[65..(65 + self.klen)].to_vec();
|
||||
|
||||
let (mut c_2, c_3) = if mode == Sm2EncryptionMode::C1C2C3 {
|
||||
(cipher[65..(65 + self.klen)].to_vec(), &cipher[(65 + self.klen)..])
|
||||
} else {
|
||||
(cipher[(65 + 32)..].to_vec(), &cipher[65..(65 + 32)])
|
||||
};
|
||||
|
||||
for i in 0..self.klen {
|
||||
c_2[i] ^= t[i];
|
||||
}
|
||||
@@ -123,7 +140,6 @@ impl DecryptCtx {
|
||||
prepend.extend_from_slice(&x_2_bytes);
|
||||
prepend.extend_from_slice(&c_2);
|
||||
prepend.extend_from_slice(&y_2_bytes);
|
||||
let c_3 = &cipher[(65 + self.klen)..];
|
||||
let u = Sm3Hash::new(&prepend).get_hash();
|
||||
if c_3 != u {
|
||||
return Err(Sm2Error::HashNotEqual);
|
||||
@@ -146,10 +162,10 @@ mod tests {
|
||||
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
|
||||
|
||||
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
|
||||
let cipher = encrypt_ctx.encrypt(msg).unwrap();
|
||||
let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C2C3).unwrap();
|
||||
|
||||
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
|
||||
let plain = decrypt_ctx.decrypt(&cipher).unwrap();
|
||||
let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C2C3).unwrap();
|
||||
assert_eq!(msg, plain);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user