feat: add sm2 encryption mode
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
|||||||
|
.idea/
|
||||||
# ---> Rust
|
# ---> Rust
|
||||||
# Generated by Cargo
|
# Generated by Cargo
|
||||||
# will have compiled files and executables
|
# will have compiled files and executables
|
||||||
|
|||||||
18
examples/encrypt_and_decrypt.rs
Normal file
18
examples/encrypt_and_decrypt.rs
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
use libsm::sm2::encrypt::{DecryptCtx, EncryptCtx, Sm2EncryptionMode};
|
||||||
|
use libsm::sm2::signature::SigCtx;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let msg = "hello world".as_bytes();
|
||||||
|
let klen = msg.len();
|
||||||
|
let ctx = SigCtx::new();
|
||||||
|
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
|
||||||
|
|
||||||
|
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
|
||||||
|
let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C3C2).unwrap();
|
||||||
|
|
||||||
|
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
|
||||||
|
let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C3C2).unwrap();
|
||||||
|
|
||||||
|
println!("{}", hex::encode(&cipher));
|
||||||
|
println!("{}", String::from_utf8_lossy(&plain));
|
||||||
|
}
|
||||||
@@ -1,9 +1,16 @@
|
|||||||
use num_bigint::BigUint;
|
use num_bigint::BigUint;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
|
||||||
use super::ecc::{EccCtx, Point};
|
|
||||||
use crate::sm2::error::{Sm2Error, Sm2Result};
|
|
||||||
use crate::{sm2::util::kdf, sm3::hash::Sm3Hash};
|
use crate::{sm2::util::kdf, sm3::hash::Sm3Hash};
|
||||||
|
use crate::sm2::error::{Sm2Error, Sm2Result};
|
||||||
|
|
||||||
|
use super::ecc::{EccCtx, Point};
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||||
|
pub enum Sm2EncryptionMode {
|
||||||
|
C1C2C3,
|
||||||
|
C1C3C2,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct EncryptCtx {
|
pub struct EncryptCtx {
|
||||||
klen: usize,
|
klen: usize,
|
||||||
@@ -27,7 +34,7 @@ impl EncryptCtx {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// klen bytes, result: C1+C2+C3
|
// klen bytes, result: C1+C2+C3
|
||||||
pub fn encrypt(&self, msg: &[u8]) -> Sm2Result<Vec<u8>> {
|
pub fn encrypt(&self, msg: &[u8], mode: Sm2EncryptionMode) -> Sm2Result<Vec<u8>> {
|
||||||
loop {
|
loop {
|
||||||
let k = self.curve.random_uint();
|
let k = self.curve.random_uint();
|
||||||
let c_1_point = self.curve.g_mul(&k)?;
|
let c_1_point = self.curve.g_mul(&k)?;
|
||||||
@@ -66,7 +73,11 @@ impl EncryptCtx {
|
|||||||
let c_3 = Sm3Hash::new(&prepend).get_hash();
|
let c_3 = Sm3Hash::new(&prepend).get_hash();
|
||||||
let c_1_bytes = self.curve.point_to_bytes(&c_1_point, false)?;
|
let c_1_bytes = self.curve.point_to_bytes(&c_1_point, false)?;
|
||||||
|
|
||||||
let a = [c_1_bytes, t, c_3.to_vec()].concat();
|
let a = if mode == Sm2EncryptionMode::C1C2C3 {
|
||||||
|
[c_1_bytes, t, c_3.to_vec()].concat()
|
||||||
|
} else {
|
||||||
|
[c_1_bytes, c_3.to_vec(), t].concat()
|
||||||
|
};
|
||||||
return Ok(a);
|
return Ok(a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -82,7 +93,7 @@ impl DecryptCtx {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrypt(&self, cipher: &[u8]) -> Sm2Result<Vec<u8>> {
|
pub fn decrypt(&self, cipher: &[u8], mode: Sm2EncryptionMode) -> Sm2Result<Vec<u8>> {
|
||||||
let c_1_bytes = &cipher[0..65];
|
let c_1_bytes = &cipher[0..65];
|
||||||
let c_1_point = self.curve.bytes_to_point(c_1_bytes)?;
|
let c_1_point = self.curve.bytes_to_point(c_1_bytes)?;
|
||||||
// if c_1_point not in curve, return error, todo return error
|
// if c_1_point not in curve, return error, todo return error
|
||||||
@@ -115,7 +126,13 @@ impl DecryptCtx {
|
|||||||
if flag {
|
if flag {
|
||||||
return Err(Sm2Error::ZeroData);
|
return Err(Sm2Error::ZeroData);
|
||||||
}
|
}
|
||||||
let mut c_2 = cipher[65..(65 + self.klen)].to_vec();
|
|
||||||
|
let (mut c_2, c_3) = if mode == Sm2EncryptionMode::C1C2C3 {
|
||||||
|
(cipher[65..(65 + self.klen)].to_vec(), &cipher[(65 + self.klen)..])
|
||||||
|
} else {
|
||||||
|
(cipher[(65 + 32)..].to_vec(), &cipher[65..(65 + 32)])
|
||||||
|
};
|
||||||
|
|
||||||
for i in 0..self.klen {
|
for i in 0..self.klen {
|
||||||
c_2[i] ^= t[i];
|
c_2[i] ^= t[i];
|
||||||
}
|
}
|
||||||
@@ -123,7 +140,6 @@ impl DecryptCtx {
|
|||||||
prepend.extend_from_slice(&x_2_bytes);
|
prepend.extend_from_slice(&x_2_bytes);
|
||||||
prepend.extend_from_slice(&c_2);
|
prepend.extend_from_slice(&c_2);
|
||||||
prepend.extend_from_slice(&y_2_bytes);
|
prepend.extend_from_slice(&y_2_bytes);
|
||||||
let c_3 = &cipher[(65 + self.klen)..];
|
|
||||||
let u = Sm3Hash::new(&prepend).get_hash();
|
let u = Sm3Hash::new(&prepend).get_hash();
|
||||||
if c_3 != u {
|
if c_3 != u {
|
||||||
return Err(Sm2Error::HashNotEqual);
|
return Err(Sm2Error::HashNotEqual);
|
||||||
@@ -146,10 +162,10 @@ mod tests {
|
|||||||
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
|
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
|
||||||
|
|
||||||
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
|
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
|
||||||
let cipher = encrypt_ctx.encrypt(msg).unwrap();
|
let cipher = encrypt_ctx.encrypt(msg, Sm2EncryptionMode::C1C2C3).unwrap();
|
||||||
|
|
||||||
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
|
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
|
||||||
let plain = decrypt_ctx.decrypt(&cipher).unwrap();
|
let plain = decrypt_ctx.decrypt(&cipher, Sm2EncryptionMode::C1C2C3).unwrap();
|
||||||
assert_eq!(msg, plain);
|
assert_eq!(msg, plain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user