feat: init commit, fork from libsm

This commit is contained in:
2023-10-22 13:56:53 +08:00
parent 8a0abd0048
commit bcd06779ca
28 changed files with 4167 additions and 0 deletions

5
AUTHORS Normal file
View File

@@ -0,0 +1,5 @@
Cryptape Technology LLC.
BEIHANG KNOC LAB
Tang Zongxun and Niu Junxiang from BEIHANG KNOC LAB are original authors.
This project is maintained by Cryptape Technology LLC.

30
Cargo.toml Normal file
View File

@@ -0,0 +1,30 @@
[package]
name = "libsm"
version = "0.5.1"
authors = [
"Tang Zongxun <tangzongxun@hotmail.com>",
"Niu Junxiang <494521900@qq.com>",
"yieazy <yuitta@163.com>",
"Rivtower Technologies <contact@rivtower.com>"
]
description = "A Rust Library of China's Standards of Encryption Algorithms (SM2/3/4)"
keywords = ["cipher", "cryptography"]
license = "Apache-2.0"
edition = "2021"
[dependencies]
rand = "0.8"
byteorder = "1.0"
num-bigint = "0.4"
num-traits = "0.2"
num-integer = "0.1"
yasna = { version = "0.5", features = [ "num-bigint" ]}
lazy_static = "1.0"
getrandom = { version = "0.2", features = ["js"] }
[dev-dependencies]
hex = "0.4"
base64 = "0.21"
[features]
internal_benches = []

202
LICENSE.txt Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

21
README-original.md Normal file
View File

@@ -0,0 +1,21 @@
# Libsm
Libsm is an open source pure rust library of China Cryptographic Algorithm Standards. It is completed by a collaborative effort between the Cryptape Technology LLC. and BEIHANG KNOC LAB. And now this project is maintained by Cryptape Technology LLC.
## GM/T Algorithms
Libsm implements the following GM/T cryptographic algorithms:
* SM2 (GM/T 0003-2012): elliptic curve cryptographic schemes including digital signature scheme, public key encryption, (authenticated) key exchange protocol and one recommended 256-bit prime field curve sm2p256v1.
* SM3 (GM/T 0004-2012): cryptographic hash function with 256-bit digest length.
* SM4 (GM/T 0002-2012): block cipher with 128-bit key length and 128-bit block size, also named SMS4.
## Documents
* [SM2](/docs/sm2.md)
* [SM3](/docs/sm3.md)
* [SM4](/docs/sm4.md)
## License
Libsm is currently under the [Apache 2.0 license](LICENSE.txt).

BIN
docs/images/sm2.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

139
docs/sm2.md Normal file
View File

@@ -0,0 +1,139 @@
# SM2
SM2 can be used in digital signature.
Algorithms below are related:
- Key generation
- Sign
- Verify
- Encrypt
- Decrypt
- Key exchange
- Serialization and deserialization
## Create a New Contex
By creating a context, libsm will initialize all the parameters used in those algorithms, including ECC parameters.
```rust
use libsm::sm2::signature::{Pubkey, Seckey, Signature, SigCtx};
let ctx = SigCtx::new();
```
## Generate a Key pair
```rust
let (pk, sk) = ctx.new_keypair();
```
`pk` is a public key use for verifying. `sk` is a secret key used for signing.
The public key can be derived from the secret key.
```rust
let pk = ctx.pk_from_sk(&sk).unwrap();
```
## Sign and Verify
```rust
let signature = ctx.sign(msg, &sk, &pk);
let result: bool = ctx.verify(msg, &pk, &signature);
```
## Encrypt and Decrypt
```rust
let encrypt_ctx = EncryptCtx::new(klen, pk);
let cipher_text = encrypt_ctx.encrypt(msg);
let decrypt_ctx = DecryptCtx::new(klen, sk);
let plain_text = decrypt_ctx.decrypt(&cipher);
```
## Key Exchange
```rust
let mut ctx1 = ExchangeCtxA::new(klen, id_a, id_b, pk_a, pk_b, sk_a);
let mut ctx2 = ExchangeCtxB::new(klen, id_a, id_b, pk_a, pk_b, sk_b);
let r_a_point = ctx1.exchange1();
let (r_b_point, s_b) = ctx2.exchange2(&r_a_point);
let s_a = ctx1.exchange3(&r_b_point, s_b);
let succ: bool = ctx2.exchange4(s_a, &r_a_point);
```
## Serialization and Deserialization
Keys and Signatures can be serialized to ``Vec<u8>``.
### Public Key
```rust
let pk_raw = ctx.serialize_pubkey(&pk, true);
let new_pk = ctx.load_pubkey(&pk_raw[..])?;
```
if you want to compress the public key, set the second parameter of `serialize_pubkey()` to `true`. An uncompressed public key will be 65 bytes, and the compressed key is 33 bytes.
The return value of `load_pubkey()` is ``Result<Pubkey, bool>``. If the public key is invalid, an error will be returned.
### Secret Key
```rust
let sk_raw = ctx.serialize_seckey(&sk);
let new_sk = ctx.load_seckey(&sk_raw[..])?;
```
The output size of `serialize_seckey()` is 32 bytes.
The return value of `load_seckey()` is `Result<Seckey, bool>`. An error will be returned if the secret key is invalid.
### Signature
Signatures can be encoded to DER format.
```rust
let der = signature.der_encode();
let parsed_sig = Signature::der_decode(&der[..])?;
```
## Details of How the Signature is Generated
### 1. Calculate Z_A
First, an `ID` is needed. But in certification applications, no ID is given. So according to the standard, we use the default value, which is "1234567812345678". Then we calculate the length of ID in bits, which is 16 * 8 = 128, and name it as `ID_LEN`. `ID_LEN` should be a 16-bit number in big-endian.
Then, the parameters of the elliptic curve should be given, include `a` and `b` in the curve eqution, and the coordinate of EC group generator, which is `x_G` and `y_G`. For details, see [the standard](http://www.oscca.gov.cn/sca/xxgk/2010-12/17/1002386/files/b965ce832cc34bc191cb1cde446b860d.pdf). And the coordinate of the public key should be appended, which is `x_A` and `y_A`. All of them are 32-byte big-endian numbers.
Hash the concatenation using SM3, and we can get Z_A.
```
Z_A = SM3(ID_LEN || ID || a || b || x_G || y_G || x_A || y_A)
```
Z_A is a 32-byte big-endian number.
### 2. Calculate the Final Hash
Prepend Z_A before the message, and hash again using SM3, then we can get `e`.
```
e = SM3(Z_A || M)
```
### 3. Verify the Message
Finally, `e` is verified by SM2 algorithm. See section 6.2 and 7.2 of the second part of [this documentation](http://www.oscca.gov.cn/sca/xxgk/2010-12/17/1002386/files/b791a9f908bb4803875ab6aeeb7b4e03.pdf).
```
isValid = SM2_Verify(e, signature, publicKey)
```
Signatures generated by this library is compatible with [GmSSL](https://github.com/guanzhi/GmSSL).
Be careful, in SM2, we **cannot** recover the public key using the message and the signature, like what Ethereum did. Because before the verification, the public key must be provided to calculate `e`. To solve this, append the public key after the signature, and extract it before the verification.
![sm2 graph](./images/sm2.png)

20
docs/sm3.md Normal file
View File

@@ -0,0 +1,20 @@
# SM3
SM3 is a hash function. To use SM3 in libsm:
1. Make sure that your data is `&[u8]`.
2. Create a `SM3Hash`.
3. Get the digest.
Sample:
```rust
use libsm::sm3::Sm3Hash;
let string = String::from("sample");
let mut hash = Sm3Hash::new(string.as_bytes());
let digest: [u8;32] = hash.get_hash();
```

53
docs/sm4.md Normal file
View File

@@ -0,0 +1,53 @@
# SM4
SM4 is a block cipher. CFB mode, OFB mode, CTR mode and CBC mode are implemented in libsm.
Here are their definitions:
```rust
pub enum CipherMode {
Cfb,
Ofb,
Ctr,
Cbc,
}
```
## Create Cipher
Choose a mode when creating a cipher. Then call the creating function.
Sample:
```rust
use libsm::sm4::{Mode, Cipher};
use rand::RngCore;
fn rand_block() -> [u8; 16] {
let mut rng = rand::thread_rng();;
let mut block: [u8; 16] = [0; 16];
rng.fill_bytes(&mut block[..]);
block
}
let key = rand_block();
let cipher = Cipher::new(&key, Mode::Cfb);
```
## Encryption and Decryption
Initialize a random IV(Initial Vector), which can be generated using the `rand_block()` function above.
Sample:
```rust
let iv = rand_block();
let plain_text = String::from("plain text");
// Encryption
let cipher_text: Vec<u8> = cipher.encrypt(plain_text.to_bytes(), &iv);
// Decryption
let plain_text: Vec<u8> = cipher.decrypt(&cipher_text[..], &iv);
```

29
src/lib.rs Normal file
View File

@@ -0,0 +1,29 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![cfg_attr(feature = "internal_benches", allow(unstable_features), feature(test))]
pub mod sm2;
pub mod sm3;
pub mod sm4;
extern crate byteorder;
extern crate rand;
extern crate num_bigint;
extern crate num_integer;
extern crate num_traits;
extern crate yasna;
#[macro_use]
extern crate lazy_static;

816
src/sm2/ecc.rs Normal file
View File

@@ -0,0 +1,816 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::field::*;
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::*;
use rand::RngCore;
pub struct EccCtx {
fctx: FieldCtx,
a: FieldElem,
b: FieldElem,
n: BigUint,
}
#[derive(Clone, PartialEq, Eq, Copy)]
pub struct Point {
pub x: FieldElem,
pub y: FieldElem,
pub z: FieldElem,
}
fn g_table() -> Vec<Vec<Point>> {
let ctx = EccCtx::new();
let mut init = BigUint::one();
let radix = BigUint::from(256_u32);
let mut table: Vec<Vec<Point>> = Vec::new();
let mut num: Vec<BigUint> = Vec::new();
for i in 0..256 {
num.push(BigUint::from(i as u32));
}
for _i in 0..32 {
let mut table_row: Vec<Point> = Vec::new();
for item in num.iter().take(256) {
let t = item * &init;
let p1 = ctx.mul(&t, &ctx.generator().unwrap()).unwrap();
table_row.push(p1);
}
table.push(table_row);
init *= &radix;
}
table
}
lazy_static! {
static ref TABLE: Vec<Vec<Point>> = g_table();
}
impl EccCtx {
pub fn new() -> EccCtx {
EccCtx {
fctx: FieldCtx::new(),
a: FieldElem::new([
0xffff_fffe,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0x0000_0000,
0xffff_ffff,
0xffff_fffc,
]),
b: FieldElem::new([
0x28e9_fa9e,
0x9d9f_5e34,
0x4d5a_9e4b,
0xcf65_09a7,
0xf397_89f5,
0x15ab_8f92,
0xddbc_bd41,
0x4d94_0e93,
]),
n: BigUint::from_str_radix(
"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123",
16,
)
.unwrap(),
}
}
#[inline]
pub fn get_a(&self) -> &FieldElem {
&self.a
}
#[inline]
pub fn get_b(&self) -> &FieldElem {
&self.b
}
#[inline]
pub fn get_n(&self) -> &BigUint {
&self.n
}
pub fn inv_n(&self, x: &BigUint) -> Sm2Result<BigUint> {
if *x == BigUint::zero() {
return Err(Sm2Error::ZeroDivisor);
}
let mut ru = x.clone();
let mut rv = self.get_n().clone();
let mut ra = BigUint::one();
let mut rc = BigUint::zero();
let rn = self.get_n().clone();
while ru != BigUint::zero() {
if ru.is_even() {
ru >>= 1;
if ra.is_even() {
ra >>= 1;
} else {
ra = (ra + &rn) >> 1;
}
}
if rv.is_even() {
rv >>= 1;
if rc.is_even() {
rc >>= 1;
} else {
rc = (rc + &rn) >> 1;
}
}
if ru >= rv {
ru -= &rv;
if ra >= rc {
ra -= &rc;
} else {
ra = ra + &rn - &rc;
}
} else {
rv -= &ru;
if rc >= ra {
rc -= &ra;
} else {
rc = rc + &rn - &ra;
}
}
}
Ok(rc)
}
pub fn check_point(&self, p: &Point) -> Sm2Result<bool> {
let ctx = &self.fctx;
let (ref x, ref y) = self.to_affine(p)?;
// Check if (x, y) is a valid point on the curve(affine projection)
// y^2 = x^3 + a * x + b
let lhs = ctx.mul(y, y)?;
let x_cubic = ctx.mul(x, &ctx.mul(x, x)?)?;
let ax = ctx.mul(x, &self.a)?;
let rhs = ctx.add(&self.b, &ctx.add(&x_cubic, &ax)?)?;
Ok(lhs.eq(&rhs))
}
pub fn new_point(&self, x: &FieldElem, y: &FieldElem) -> Sm2Result<Point> {
let ctx = &self.fctx;
// Check if (x, y) is a valid point on the curve(affine projection)
// y^2 = x^3 + a * x + b
let lhs = ctx.mul(y, y)?;
let x_cubic = ctx.mul(x, &ctx.mul(x, x)?)?;
let ax = ctx.mul(x, &self.a)?;
let rhs = ctx.add(&self.b, &ctx.add(&x_cubic, &ax)?)?;
if !lhs.eq(&rhs) {
return Err(Sm2Error::NotOnCurve);
}
let p = Point {
x: *x,
y: *y,
z: FieldElem::from_num(1),
};
Ok(p)
}
// TODO: load point
// pub fn load_point(&self, buf: &[u8]) -> Result<Point, ()>
pub fn new_jacobian(&self, x: &FieldElem, y: &FieldElem, z: &FieldElem) -> Sm2Result<Point> {
let ctx = &self.fctx;
// Check if (x, y, z) is a valid point on the curve(in jacobian projection)
// y^2 = x^3 + a * x * z^4 + b * z^6
let lhs = ctx.square(y)?;
let r1 = ctx.cubic(x)?;
let r2 = ctx.mul(x, &self.a)?;
let r2 = ctx.mul(&r2, z)?;
let r2 = ctx.mul(&r2, &ctx.cubic(z)?)?;
let r3 = ctx.cubic(z)?;
let r3 = ctx.square(&r3)?;
let r3 = ctx.mul(&r3, &self.b)?;
let rhs = ctx.add(&r1, &ctx.add(&r2, &r3)?)?;
// Require lhs =rhs
if !lhs.eq(&rhs) {
return Err(Sm2Error::InvalidPoint);
}
let p = Point {
x: *x,
y: *y,
z: *z,
};
Ok(p)
}
pub fn generator(&self) -> Sm2Result<Point> {
let x = FieldElem::new([
0x32c4_ae2c,
0x1f19_8119,
0x5f99_0446,
0x6a39_c994,
0x8fe3_0bbf,
0xf266_0be1,
0x715a_4589,
0x334c_74c7,
]);
let y = FieldElem::new([
0xbc37_36a2,
0xf4f6_779c,
0x59bd_cee3,
0x6b69_2153,
0xd0a9_877c,
0xc62a_4740,
0x02df_32e5,
0x2139_f0a0,
]);
self.new_point(&x, &y)
}
pub fn zero(&self) -> Point {
let x = FieldElem::from_num(1);
let y = FieldElem::from_num(1);
let z = FieldElem::zero();
self.new_jacobian(&x, &y, &z).unwrap()
}
pub fn to_affine(&self, p: &Point) -> Sm2Result<(FieldElem, FieldElem)> {
let ctx = &self.fctx;
if p.is_zero() {
return Err(Sm2Error::ZeroPoint);
}
let zinv = ctx.inv(&p.z)?;
let x = ctx.mul(&p.x, &ctx.mul(&zinv, &zinv)?)?;
let y = ctx.mul(&p.y, &ctx.mul(&zinv, &ctx.mul(&zinv, &zinv)?)?)?;
Ok((x, y))
}
pub fn neg(&self, p: &Point) -> Sm2Result<Point> {
let neg_y = self.fctx.neg(&p.y)?;
self.new_jacobian(&p.x, &neg_y, &p.z)
}
//add-1998-cmo-2 curve_add 13m+4s
pub fn add(&self, p1: &Point, p2: &Point) -> Sm2Result<Point> {
if p1.is_zero() {
return Ok(*p2);
} else if p2.is_zero() {
return Ok(*p1);
}
if p1 == p2 {
return self.double(p1);
}
let ctx = &self.fctx;
let z1z1 = ctx.square(&p1.z)?;
let z2z2 = ctx.square(&p2.z)?;
let u1 = ctx.mul(&p1.x, &z2z2)?;
let u2 = ctx.mul(&p2.x, &z1z1)?;
let s1 = ctx.mul(&p1.y, &ctx.mul(&p2.z, &z2z2)?)?;
let s2 = ctx.mul(&p2.y, &ctx.mul(&p1.z, &z1z1)?)?;
let h = ctx.sub(&u2, &u1)?;
let hh = ctx.square(&h)?;
let hhh = ctx.mul(&h, &hh)?;
let r = ctx.sub(&s2, &s1)?;
let v = ctx.mul(&u1, &hh)?;
let x3 = ctx.sub(
&ctx.sub(&ctx.square(&r)?, &hhh)?,
&ctx.mul(&FieldElem::from_num(2), &v)?,
)?;
let rvx3 = ctx.mul(&r, &ctx.sub(&v, &x3)?)?;
let s1hhh = ctx.mul(&s1, &hhh)?;
let y3 = ctx.sub(&rvx3, &s1hhh)?;
let z3 = ctx.mul(&p1.z, &ctx.mul(&p2.z, &h)?)?;
Ok(Point {
x: x3,
y: y3,
z: z3,
})
}
//dbl-1998-cmo-2 9m+6s
// XX = X12
// YY = Y12
// ZZ = Z12
// S = 4*X1*YY
// M = 3*XX+a*ZZ2
// T = M2-2*S
// X3 = T
// Y3 = M*(S-T)-8*YY2
// Z3 = 2*Y1*Z1
pub fn double(&self, p: &Point) -> Sm2Result<Point> {
if p.is_zero() {
return Ok(*p);
}
let ctx = &self.fctx;
let xx = ctx.square(&p.x)?;
let yy = ctx.square(&p.y)?;
let zz = ctx.square(&p.z)?;
let yy8 = ctx.mul(&FieldElem::from_num(8), &ctx.square(&yy)?)?;
let s = ctx.mul(&FieldElem::from_num(4), &ctx.mul(&p.x, &yy)?)?;
let m = ctx.add(
&ctx.mul(&FieldElem::from_num(3), &xx)?,
&ctx.mul(&self.a, &ctx.square(&zz)?)?,
)?;
let x3 = ctx.sub(&ctx.square(&m)?, &ctx.mul(&FieldElem::from_num(2), &s)?)?;
let y3 = ctx.sub(&ctx.mul(&m, &ctx.sub(&s, &x3)?)?, &yy8)?;
let z3 = ctx.mul(&FieldElem::from_num(2), &ctx.mul(&p.y, &p.z)?)?;
Ok(Point {
x: x3,
y: y3,
z: z3,
})
}
pub fn mul(&self, m: &BigUint, p: &Point) -> Sm2Result<Point> {
let m = m % self.get_n();
let k = FieldElem::from_biguint(&m)?;
self.mul_raw_naf(&k.value, p)
}
//w-naf algorithm
//See https://crypto.stackexchange.com/questions/82013/simple-explanation-of-sliding-window-and-wnaf-methods-of-elliptic-curve-point-mu
pub fn w_naf(&self, m: &[u32], w: usize, lst: &mut usize) -> [i8; 257] {
let mut carry = 0;
let mut bit = 0;
let mut ret: [i8; 257] = [0; 257];
let mut n: [u32; 9] = [0; 9];
n[1..9].clone_from_slice(&m[..8]);
let window: u32 = (1 << w) - 1;
while bit < 256 {
let u32_idx = 8 - bit / 32;
let bit_idx = 31 - bit % 32;
if ((n[u32_idx] >> (31 - bit_idx)) & 1) == carry {
bit += 1;
continue;
}
let mut word: u32 = if bit_idx >= w - 1 {
(n[u32_idx] >> (31 - bit_idx)) & window
} else {
((n[u32_idx] >> (31 - bit_idx)) | (n[u32_idx - 1] << (bit_idx + 1))) & window
};
word += carry;
carry = (word >> (w - 1)) & 1;
ret[bit] = word as i8 - (carry << w) as i8;
*lst = bit;
bit += w;
}
if carry == 1 {
ret[256] = 1;
*lst = 256;
}
ret
}
pub fn mul_raw_naf(&self, m: &[u32], p: &Point) -> Sm2Result<Point> {
let mut i = 256;
let mut q = self.zero();
let naf = self.w_naf(m, 5, &mut i);
let offset = 16;
let mut table = [self.zero(); 32];
let double_p = self.double(p)?;
table[1 + offset] = *p;
table[offset - 1] = self.neg(&table[1 + offset])?;
for i in 1..8 {
table[2 * i + offset + 1] = self.add(&double_p, &table[2 * i + offset - 1])?;
table[offset - 2 * i - 1] = self.neg(&table[2 * i + offset + 1])?;
}
loop {
q = self.double(&q)?;
if naf[i] != 0 {
let index = (naf[i] + 16) as usize;
q = self.add(&q, &table[index])?;
}
if i == 0 {
break;
}
i -= 1;
}
Ok(q)
}
pub fn g_mul(&self, m: &BigUint) -> Sm2Result<Point> {
let m = m % self.get_n();
let k = FieldElem::from_biguint(&m).unwrap();
let mut q = self.zero();
for i in 0..8 {
for j in 0..4 {
let bits = ((k.value[i] >> (8 * (3 - j))) & 0xff) as usize;
let index = 31 - i * 4 - j;
q = self.add(&q, &TABLE[index][bits])?;
}
}
Ok(q)
}
pub fn eq(&self, p1: &Point, p2: &Point) -> Sm2Result<bool> {
let z1 = &p1.z;
let z2 = &p2.z;
if z1.eq(&FieldElem::zero()) {
return Ok(z2.eq(&FieldElem::zero()));
} else if z2.eq(&FieldElem::zero()) {
return Ok(false);
}
let (p1x, p1y) = self.to_affine(p1)?;
let (p2x, p2y) = self.to_affine(p2)?;
Ok(p1x.eq(&p2x) && p1y.eq(&p2y))
}
pub fn random_uint(&self) -> BigUint {
let mut rng = rand::thread_rng();
let mut buf: [u8; 32] = [0; 32];
let mut ret;
loop {
rng.fill_bytes(&mut buf[..]);
ret = BigUint::from_bytes_be(&buf[..]);
if ret < self.get_n() - BigUint::one() && ret != BigUint::zero() {
break;
}
}
ret
}
pub fn point_to_bytes(&self, p: &Point, compress: bool) -> Sm2Result<Vec<u8>> {
let (x, y) = self.to_affine(p)?;
let mut ret: Vec<u8> = Vec::new();
if compress {
if y.get_value(7) & 0x01 == 0 {
ret.push(0x02);
} else {
ret.push(0x03);
}
let mut x_vec = x.to_bytes();
ret.append(&mut x_vec);
} else {
ret.push(0x04);
let mut x_vec = x.to_bytes();
let mut y_vec = y.to_bytes();
ret.append(&mut x_vec);
ret.append(&mut y_vec);
}
Ok(ret)
}
pub fn bytes_to_point(&self, b: &[u8]) -> Sm2Result<Point> {
let ctx = &self.fctx;
if b.len() == 33 {
let y_q;
if b[0] == 0x02 {
y_q = 0;
} else if b[0] == 0x03 {
y_q = 1
} else {
return Err(Sm2Error::InvalidPublic);
}
let x = FieldElem::from_bytes(&b[1..])?;
let x_cubic = ctx.mul(&x, &ctx.mul(&x, &x)?)?;
let ax = ctx.mul(&x, &self.a)?;
let y_2 = ctx.add(&self.b, &ctx.add(&x_cubic, &ax)?)?;
let mut y = self.fctx.sqrt(&y_2)?;
if y.get_value(7) & 0x01 != y_q {
y = self.fctx.neg(&y)?;
}
self.new_point(&x, &y)
} else if b.len() == 65 {
if b[0] != 0x04 {
return Err(Sm2Error::InvalidPublic);
}
let x = FieldElem::from_bytes(&b[1..33])?;
let y = FieldElem::from_bytes(&b[33..65])?;
self.new_point(&x, &y)
} else {
Err(Sm2Error::InvalidPublic)
}
}
}
impl Default for EccCtx {
fn default() -> Self {
Self::new()
}
}
impl Point {
pub fn is_zero(&self) -> bool {
self.z.eq(&FieldElem::zero())
}
}
use crate::sm2::error::{Sm2Error, Sm2Result};
use std::fmt;
impl fmt::Display for Point {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let curve = EccCtx::new();
if self.is_zero() {
write!(f, "(O)")
} else {
let (x, y) = curve.to_affine(self).unwrap();
write!(
f,
"(x = 0x{:0>64}, y = 0x{:0>64})",
x.to_str(16),
y.to_str(16)
)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_add_double_neg() {
let curve = EccCtx::new();
let g = curve.generator().unwrap();
let neg_g = curve.neg(&g).unwrap();
let double_g = curve.double(&g).unwrap();
let new_g = curve.add(&double_g, &neg_g).unwrap();
let zero = curve.add(&g, &neg_g).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
assert!(zero.is_zero());
let double_g = curve.double(&g).unwrap(); // 2 * g
let add_g = curve.add(&g, &g).unwrap(); // g + g
assert!(curve.eq(&add_g, &double_g).unwrap());
}
#[test]
fn test_point_add() {
let ecctx = EccCtx::new();
let g = ecctx.generator().unwrap();
let g2 = ecctx.double(&g).unwrap();
println!("{}", ecctx.add(&g, &g2).unwrap());
}
#[test]
fn test_point_double() {
let ecctx = EccCtx::new();
let g = ecctx.generator().unwrap();
println!("{}", ecctx.double(&g).unwrap());
}
#[test]
fn test_multiplication() {
let curve = EccCtx::new();
let g = curve.generator().unwrap();
let double_g = curve.double(&g).unwrap();
let twice_g = curve.mul(&BigUint::from_u32(2).unwrap(), &g).unwrap();
assert!(curve.eq(&double_g, &twice_g).unwrap());
let n = curve.get_n() - BigUint::one();
let new_g = curve.mul(&n, &g).unwrap();
let new_g = curve.add(&new_g, &double_g).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
}
#[test]
fn test_g_multiplication() {
let curve = EccCtx::new();
let g = curve.generator().unwrap();
let twice_g = curve
.g_mul(&BigUint::from_u64(4_294_967_296).unwrap())
.unwrap();
let double_g = curve
.mul(&BigUint::from_u64(4_294_967_296).unwrap(), &g)
.unwrap();
assert!(curve.eq(&double_g, &twice_g).unwrap());
let n = curve.get_n() - BigUint::one();
let new_g = curve.g_mul(&n).unwrap();
let nn_g = curve.mul(&n, &g).unwrap();
assert!(curve.eq(&nn_g, &new_g).unwrap());
}
#[test]
fn test_w_naf() {
let curve = EccCtx::new();
let mut lst = 0;
let n = curve.get_n() - BigUint::one();
let _num = BigUint::from(1122334455_u32) - BigUint::one();
let k = FieldElem::from_biguint(&n).unwrap();
let ret = curve.w_naf(&k.value, 5, &mut lst);
let mut sum = BigUint::zero();
let mut init = BigUint::from_str_radix(
"10000000000000000000000000000000000000000000000000000000000000000",
16,
)
.unwrap();
for j in 0..257 {
let i = 256 - j;
if ret[i] != 0 {
if ret[i] > 0 {
sum += &init * BigUint::from(ret[i] as u8);
} else {
let neg = (0 - ret[i]) as u8;
sum -= &init * BigUint::from(neg);
}
}
init >>= 1;
}
assert_eq!(sum, n);
}
#[test]
fn test_inv_n() {
let curve = EccCtx::new();
for _ in 0..20 {
let r = curve.random_uint();
let r_inv = curve.inv_n(&r).unwrap();
let product = r * r_inv;
let product = product % curve.get_n();
assert_eq!(product, BigUint::one());
}
}
#[test]
fn test_point_bytes_conversion() {
let curve = EccCtx::new();
let g = curve.generator().unwrap();
let g_bytes_uncomp = curve.point_to_bytes(&g, false).unwrap();
let new_g = curve.bytes_to_point(&g_bytes_uncomp[..]).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
let g_bytes_comp = curve.point_to_bytes(&g, true).unwrap();
let new_g = curve.bytes_to_point(&g_bytes_comp[..]).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
let g = curve.double(&g).unwrap();
let g_bytes_uncomp = curve.point_to_bytes(&g, false).unwrap();
let new_g = curve.bytes_to_point(&g_bytes_uncomp[..]).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
let g_bytes_comp = curve.point_to_bytes(&g, true).unwrap();
let new_g = curve.bytes_to_point(&g_bytes_comp[..]).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
let g = curve.double(&g).unwrap();
let g_bytes_uncomp = curve.point_to_bytes(&g, false).unwrap();
let new_g = curve.bytes_to_point(&g_bytes_uncomp[..]).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
let g_bytes_comp = curve.point_to_bytes(&g, true).unwrap();
let new_g = curve.bytes_to_point(&g_bytes_comp[..]).unwrap();
assert!(curve.eq(&g, &new_g).unwrap());
}
}
#[cfg(feature = "internal_benches")]
mod internal_benches {
use crate::sm2::ecc::EccCtx;
use crate::sm2::field::FieldElem;
use num_bigint::BigUint;
use num_traits::Num;
extern crate test;
#[bench]
fn sm2_inv_bench(bench: &mut test::Bencher) {
let ecctx = EccCtx::new();
let fe = FieldElem::from_num(2);
bench.iter(|| {
let _ = ecctx.fctx.inv(&fe);
});
}
#[bench]
fn sm2_point_add_bench(bench: &mut test::Bencher) {
let ecctx = EccCtx::new();
let g = ecctx.generator().unwrap();
let g2 = ecctx.double(&g).unwrap();
bench.iter(|| {
let _ = ecctx.add(&g, &g2);
});
}
#[bench]
fn sm2_point_double_bench(bench: &mut test::Bencher) {
let ecctx = EccCtx::new();
let g = ecctx.generator().unwrap();
let g2 = ecctx.double(&g).unwrap();
bench.iter(|| {
let _ = ecctx.double(&g2);
});
}
#[bench]
fn bench_mul_raw_naf(bench: &mut test::Bencher) {
let curve = EccCtx::new();
let g = curve.generator().unwrap();
let m = BigUint::from_str_radix(
"76415405cbb177ebb37a835a2b5a022f66c250abf482e4cb343dcb2091bc1f2e",
16,
)
.unwrap()
% curve.get_n();
let k = FieldElem::from_biguint(&m).unwrap();
bench.iter(|| {
let _ = curve.mul_raw_naf(&k.value, &g);
});
}
#[bench]
fn bench_gmul(bench: &mut test::Bencher) {
let curve = EccCtx::new();
let m = BigUint::from_str_radix(
"76415405cbb177ebb37a835a2b5a022f66c250abf482e4cb343dcb2091bc1f2e",
16,
)
.unwrap()
% curve.get_n();
bench.iter(|| {
let _ = curve.g_mul(&m);
});
}
}

155
src/sm2/encrypt.rs Normal file
View File

@@ -0,0 +1,155 @@
use num_bigint::BigUint;
use num_traits::One;
use super::ecc::{EccCtx, Point};
use crate::sm2::error::{Sm2Error, Sm2Result};
use crate::{sm2::util::kdf, sm3::hash::Sm3Hash};
pub struct EncryptCtx {
klen: usize,
curve: EccCtx,
pk_b: Point,
}
pub struct DecryptCtx {
klen: usize,
curve: EccCtx,
sk_b: BigUint,
}
impl EncryptCtx {
pub fn new(klen: usize, pk_b: Point) -> EncryptCtx {
EncryptCtx {
klen,
curve: EccCtx::new(),
pk_b,
}
}
// klen bytes, result: C1+C2+C3
pub fn encrypt(&self, msg: &[u8]) -> Sm2Result<Vec<u8>> {
loop {
let k = self.curve.random_uint();
let c_1_point = self.curve.g_mul(&k)?;
let h = BigUint::one();
let s_point = self.curve.mul(&h, &self.pk_b)?;
if s_point.is_zero() {
return Err(Sm2Error::ZeroPoint);
}
let c_2_point = self.curve.mul(&k, &self.pk_b)?;
let (x_2, y_2) = self.curve.to_affine(&c_2_point)?;
let x_2_bytes = x_2.to_bytes();
let y_2_bytes = y_2.to_bytes();
let mut prepend: Vec<u8> = vec![];
prepend.extend_from_slice(&x_2_bytes);
prepend.extend_from_slice(&y_2_bytes);
let mut t = kdf(&prepend, self.klen);
let mut flag = true;
for elem in &t {
if elem != &0 {
flag = false;
break;
}
}
if !flag {
for i in 0..t.len() {
t[i] ^= msg[i];
}
let mut prepend: Vec<u8> = vec![];
prepend.extend_from_slice(&x_2_bytes);
prepend.extend_from_slice(msg);
prepend.extend_from_slice(&y_2_bytes);
let c_3 = Sm3Hash::new(&prepend).get_hash();
let c_1_bytes = self.curve.point_to_bytes(&c_1_point, false)?;
let a = [c_1_bytes, t, c_3.to_vec()].concat();
return Ok(a);
}
}
}
}
impl DecryptCtx {
pub fn new(klen: usize, sk_b: BigUint) -> DecryptCtx {
DecryptCtx {
klen,
curve: EccCtx::new(),
sk_b,
}
}
pub fn decrypt(&self, cipher: &[u8]) -> Sm2Result<Vec<u8>> {
let c_1_bytes = &cipher[0..65];
let c_1_point = self.curve.bytes_to_point(c_1_bytes)?;
// if c_1_point not in curve, return error, todo return error
if !self.curve.check_point(&c_1_point)? {
return Err(Sm2Error::CheckPointErr);
}
let h = BigUint::one();
let s_point = self.curve.mul(&h, &c_1_point)?;
// todo return error
if s_point.is_zero() {
return Err(Sm2Error::ZeroPoint);
}
let c_2_point = self.curve.mul(&self.sk_b, &c_1_point)?;
let (x_2, y_2) = self.curve.to_affine(&c_2_point)?;
let x_2_bytes = x_2.to_bytes();
let y_2_bytes = y_2.to_bytes();
let mut prepend: Vec<u8> = vec![];
prepend.extend_from_slice(&x_2_bytes);
prepend.extend_from_slice(&y_2_bytes);
let t = kdf(&prepend, self.klen);
let mut flag = true;
for elem in &t {
if elem != &0 {
flag = false;
break;
}
}
if flag {
return Err(Sm2Error::ZeroData);
}
let mut c_2 = cipher[65..(65 + self.klen)].to_vec();
for i in 0..self.klen {
c_2[i] ^= t[i];
}
let mut prepend: Vec<u8> = vec![];
prepend.extend_from_slice(&x_2_bytes);
prepend.extend_from_slice(&c_2);
prepend.extend_from_slice(&y_2_bytes);
let c_3 = &cipher[(65 + self.klen)..];
let u = Sm3Hash::new(&prepend).get_hash();
if c_3 != u {
return Err(Sm2Error::HashNotEqual);
}
Ok(c_2)
}
}
#[cfg(test)]
mod tests {
use crate::sm2::signature::SigCtx;
use super::*;
#[test]
fn sm2_encrypt_decrypt_test() {
let msg = "hello world".as_bytes();
let klen = msg.len();
let ctx = SigCtx::new();
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
let encrypt_ctx = EncryptCtx::new(klen, pk_b);
let cipher = encrypt_ctx.encrypt(msg).unwrap();
let decrypt_ctx = DecryptCtx::new(klen, sk_b);
let plain = decrypt_ctx.decrypt(&cipher).unwrap();
assert_eq!(msg, plain);
}
}

106
src/sm2/error.rs Normal file
View File

@@ -0,0 +1,106 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use std::fmt::Formatter;
pub type Sm2Result<T> = Result<T, Sm2Error>;
#[derive(PartialEq)]
pub enum Sm2Error {
NotOnCurve,
FieldSqrtError,
InvalidDer,
InvalidPublic,
InvalidPrivate,
ZeroDivisor,
ZeroPoint,
InvalidPoint,
CheckPointErr,
ZeroData,
HashNotEqual,
IdTooLong,
ZeroFiled,
InvalidFieldLen,
ZeroSig,
InvalidDigestLen,
InvalidSecretKey,
}
impl ::std::fmt::Debug for Sm2Error {
fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
write!(f, "{self}")
}
}
impl From<Sm2Error> for &str {
fn from(e: Sm2Error) -> Self {
match e {
Sm2Error::NotOnCurve => "the point not on curve",
Sm2Error::FieldSqrtError => "field elem sqrt error",
Sm2Error::InvalidDer => "invalid der",
Sm2Error::InvalidPublic => "invalid public key",
Sm2Error::InvalidPrivate => "invalid private key",
Sm2Error::ZeroDivisor => "zero has no inversion",
Sm2Error::ZeroPoint => "cannot convert the infinite point to affine",
Sm2Error::InvalidPoint => "invalid jacobian point",
Sm2Error::CheckPointErr => "check point error",
Sm2Error::ZeroData => "the vector is zero",
Sm2Error::HashNotEqual => "hash not equal",
Sm2Error::IdTooLong => "ID is too long",
Sm2Error::ZeroFiled => "zero has no inversion in filed",
Sm2Error::InvalidFieldLen => "a SCA-256 field element must be 32-byte long",
Sm2Error::ZeroSig => "the signature is zero, cannot sign",
Sm2Error::InvalidDigestLen => "the length of digest must be 32-bytes",
Sm2Error::InvalidSecretKey => "invalid secret key",
}
}
}
impl Display for Sm2Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let err_msg = match self {
Sm2Error::NotOnCurve => "the point not on curve",
Sm2Error::FieldSqrtError => "field elem sqrt error",
Sm2Error::InvalidDer => "invalid der",
Sm2Error::InvalidPublic => "invalid public key",
Sm2Error::InvalidPrivate => "invalid private key",
Sm2Error::ZeroDivisor => "zero has no inversion",
Sm2Error::ZeroPoint => "cannot convert the infinite point to affine",
Sm2Error::InvalidPoint => "invalid jacobian point",
Sm2Error::CheckPointErr => "check point error",
Sm2Error::ZeroData => "the vector is zero",
Sm2Error::HashNotEqual => "hash and cipher not equal",
Sm2Error::IdTooLong => "ID is too long",
Sm2Error::ZeroFiled => "zero has no inversion in filed",
Sm2Error::InvalidFieldLen => "a SCA-256 field element must be 32-byte long",
Sm2Error::ZeroSig => "the signature is zero, cannot sign",
Sm2Error::InvalidDigestLen => "the length of digest must be 32-bytes",
Sm2Error::InvalidSecretKey => "invalid secret key",
};
write!(f, "{err_msg}")
}
}
#[cfg(test)]
mod tests {
use super::Sm2Error;
#[test]
fn test_error_display() {
let e = Sm2Error::InvalidPublic;
assert_eq!(format!("{e}"), "invalid public key");
assert_eq!(format!("{e:?}"), "invalid public key");
}
}

339
src/sm2/exchange.rs Normal file
View File

@@ -0,0 +1,339 @@
use super::ecc::*;
use super::util::kdf;
use crate::sm2::error::{Sm2Error, Sm2Result};
use crate::sm3::hash::Sm3Hash;
use byteorder::{BigEndian, WriteBytesExt};
use num_bigint::BigUint;
use num_traits::*;
fn compute_z(id: &str, pk: &Point) -> Sm2Result<[u8; 32]> {
let curve = EccCtx::new();
let mut prepend: Vec<u8> = Vec::new();
if id.len() * 8 > 65535 {
return Err(Sm2Error::IdTooLong);
}
// ENTL_A
prepend
.write_u16::<BigEndian>((id.len() * 8) as u16)
.unwrap();
// ID_A
for c in id.bytes() {
prepend.push(c);
}
let mut a = curve.get_a().to_bytes();
let mut b = curve.get_b().to_bytes();
prepend.append(&mut a);
prepend.append(&mut b);
let (x_g, y_g) = curve.to_affine(&curve.generator()?)?;
let (mut x_g, mut y_g) = (x_g.to_bytes(), y_g.to_bytes());
prepend.append(&mut x_g);
prepend.append(&mut y_g);
let (x_a, y_a) = curve.to_affine(pk)?;
let (mut x_a, mut y_a) = (x_a.to_bytes(), y_a.to_bytes());
prepend.append(&mut x_a);
prepend.append(&mut y_a);
let mut hasher = Sm3Hash::new(&prepend[..]);
Ok(hasher.get_hash())
}
pub struct ExchangeCtxA {
klen: usize,
curve: EccCtx,
z_a: [u8; 32],
z_b: [u8; 32],
pk_b: Point,
sk_a: BigUint,
r_a: Option<BigUint>,
r_a_point: Option<Point>,
k_a: Option<Vec<u8>>,
}
pub struct ExchangeCtxB {
klen: usize,
curve: EccCtx,
z_a: [u8; 32],
z_b: [u8; 32],
pk_a: Point,
sk_b: BigUint,
v: Option<Point>,
r_b: Option<BigUint>,
r_b_point: Option<Point>,
k_b: Option<Vec<u8>>,
}
impl ExchangeCtxA {
pub fn new(
klen: usize,
id_a: &str,
id_b: &str,
pk_a: Point,
pk_b: Point,
sk_a: BigUint,
) -> Sm2Result<ExchangeCtxA> {
Ok(ExchangeCtxA {
klen,
curve: EccCtx::new(),
z_a: compute_z(id_a, &pk_a)?,
z_b: compute_z(id_b, &pk_b)?,
pk_b,
sk_a,
r_a: None,
r_a_point: None,
k_a: None,
})
}
pub fn exchange1(&mut self) -> Sm2Result<Point> {
let r_a = self.curve.random_uint();
let r_a_point = self.curve.g_mul(&r_a)?;
self.r_a = Some(r_a);
self.r_a_point = Some(r_a_point);
Ok(r_a_point)
}
pub fn exchange3(&mut self, r_b_point: &Point, s_b: [u8; 32]) -> Sm2Result<[u8; 32]> {
let (x_1, y_1) = self.curve.to_affine(&self.r_a_point.unwrap())?;
let w = ((self.curve.get_n().bits() as f64) / 2.0).ceil() - 1.0;
let pow_w = BigUint::from_u32(2).unwrap().pow(w as u32);
let x_1_bar = &pow_w + (x_1.to_biguint() & (&pow_w - BigUint::one()));
let t_a = (&self.sk_a + x_1_bar * self.r_a.as_ref().unwrap()) % self.curve.get_n();
if !self.curve.check_point(r_b_point)? {
return Err(Sm2Error::CheckPointErr);
}
let (x_2, y_2) = self.curve.to_affine(r_b_point)?;
let x_2_bar = &pow_w + (x_2.to_biguint() & (&pow_w - BigUint::one()));
let h = BigUint::one();
let coefficient = h * t_a;
let point = self
.curve
.add(&self.pk_b, &self.curve.mul(&x_2_bar, r_b_point)?)?;
let u = self.curve.mul(&coefficient, &point)?;
if u.is_zero() {
return Err(Sm2Error::ZeroPoint);
}
let (x_u, y_u) = self.curve.to_affine(&u)?;
let mut prepend = Vec::new();
let x_u_bytes = x_u.to_bytes();
let y_u_bytes = y_u.to_bytes();
prepend.extend_from_slice(&x_u_bytes);
prepend.extend_from_slice(&y_u_bytes);
prepend.extend_from_slice(&self.z_a);
prepend.extend_from_slice(&self.z_b);
let k_a = kdf(&prepend, self.klen);
self.k_a = Some(k_a);
let mut prepend: Vec<u8> = Vec::new();
prepend.write_u16::<BigEndian>(0x02_u16).unwrap();
prepend.extend_from_slice(&y_u_bytes);
let mut temp: Vec<u8> = Vec::new();
temp.extend_from_slice(&x_u_bytes);
temp.extend_from_slice(&self.z_a);
temp.extend_from_slice(&self.z_b);
temp.extend_from_slice(&x_1.to_bytes());
temp.extend_from_slice(&y_1.to_bytes());
temp.extend_from_slice(&x_2.to_bytes());
temp.extend_from_slice(&y_2.to_bytes());
let temp_hash = Sm3Hash::new(&temp).get_hash();
prepend.extend_from_slice(&temp_hash);
let s_1 = Sm3Hash::new(&prepend).get_hash();
if s_1 != s_b {
return Err(Sm2Error::HashNotEqual);
}
let mut prepend: Vec<u8> = Vec::new();
prepend.write_u16::<BigEndian>(0x03_u16).unwrap();
prepend.extend_from_slice(&y_u_bytes);
prepend.extend_from_slice(&temp_hash);
Ok(Sm3Hash::new(&prepend).get_hash())
}
}
impl ExchangeCtxB {
pub fn new(
klen: usize,
id_a: &str,
id_b: &str,
pk_a: Point,
pk_b: Point,
sk_b: BigUint,
) -> Sm2Result<ExchangeCtxB> {
Ok(ExchangeCtxB {
klen,
curve: EccCtx::new(),
z_a: compute_z(id_a, &pk_a)?,
z_b: compute_z(id_b, &pk_b)?,
pk_a,
sk_b,
v: None,
r_b: None,
r_b_point: None,
k_b: None,
})
}
pub fn exchange2(&mut self, r_a_point: &Point) -> Sm2Result<(Point, [u8; 32])> {
let r_b = self.curve.random_uint();
self.r_b = Some(r_b);
let r_b_point = self.curve.g_mul(self.r_b.as_ref().unwrap())?;
self.r_b_point = Some(r_b_point);
let (x_2, y_2) = self.curve.to_affine(&r_b_point)?;
let w = ((self.curve.get_n().bits() as f64) / 2.0).ceil() - 1.0;
let pow_w = BigUint::from_u32(2).unwrap().pow(w as u32);
let x_2_bar = &pow_w + (x_2.to_biguint() & (&pow_w - BigUint::one()));
let t_b = (&self.sk_b + x_2_bar * self.r_b.as_ref().unwrap()) % self.curve.get_n();
if !self.curve.check_point(r_a_point)? {
return Err(Sm2Error::CheckPointErr);
}
let (x_1, y_1) = self.curve.to_affine(r_a_point)?;
let x_1_bar = &pow_w + (x_1.to_biguint() & (&pow_w - BigUint::one()));
let h = BigUint::one();
let coefficient = h * t_b;
let point = self
.curve
.add(&self.pk_a, &self.curve.mul(&x_1_bar, r_a_point)?)?;
let v = self.curve.mul(&coefficient, &point)?;
if v.is_zero() {
return Err(Sm2Error::ZeroPoint);
}
self.v = Some(v);
let (x_v, y_v) = self.curve.to_affine(&v)?;
let mut prepend = Vec::new();
let x_v_bytes = x_v.to_bytes();
let y_v_bytes = y_v.to_bytes();
prepend.extend_from_slice(&x_v_bytes);
prepend.extend_from_slice(&y_v_bytes);
prepend.extend_from_slice(&self.z_a);
prepend.extend_from_slice(&self.z_b);
let k_b = kdf(&prepend, self.klen);
self.k_b = Some(k_b);
let mut prepend: Vec<u8> = Vec::new();
prepend.write_u16::<BigEndian>(0x02_u16).unwrap();
prepend.extend_from_slice(&y_v_bytes);
let mut temp: Vec<u8> = Vec::new();
temp.extend_from_slice(&x_v_bytes);
temp.extend_from_slice(&self.z_a);
temp.extend_from_slice(&self.z_b);
temp.extend_from_slice(&x_1.to_bytes());
temp.extend_from_slice(&y_1.to_bytes());
temp.extend_from_slice(&x_2.to_bytes());
temp.extend_from_slice(&y_2.to_bytes());
let temp_hash = Sm3Hash::new(&temp).get_hash();
prepend.extend_from_slice(&temp_hash);
let s_b = Sm3Hash::new(&prepend).get_hash();
Ok((r_b_point, s_b))
}
pub fn exchange4(&self, s_a: [u8; 32], r_a_point: &Point) -> Sm2Result<bool> {
let (x_1, y_1) = self.curve.to_affine(r_a_point)?;
let (x_2, y_2) = self.curve.to_affine(self.r_b_point.as_ref().unwrap())?;
let (x_v, y_v) = self.curve.to_affine(self.v.as_ref().unwrap())?;
let x_v_bytes = x_v.to_bytes();
let y_v_bytes = y_v.to_bytes();
let mut prepend: Vec<u8> = Vec::new();
prepend.write_u16::<BigEndian>(0x03_u16).unwrap();
prepend.extend_from_slice(&y_v_bytes);
let mut temp: Vec<u8> = Vec::new();
temp.extend_from_slice(&x_v_bytes);
temp.extend_from_slice(&self.z_a);
temp.extend_from_slice(&self.z_b);
temp.extend_from_slice(&x_1.to_bytes());
temp.extend_from_slice(&y_1.to_bytes());
temp.extend_from_slice(&x_2.to_bytes());
temp.extend_from_slice(&y_2.to_bytes());
let temp_hash = Sm3Hash::new(&temp).get_hash();
prepend.extend_from_slice(&temp_hash);
let s_2 = Sm3Hash::new(&prepend).get_hash();
if s_2 != s_a {
return Err(Sm2Error::HashNotEqual);
}
Ok(s_2 == s_a)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sm2::signature::SigCtx;
#[test]
fn sm2_compute_z_test() {
let ctx = SigCtx::new();
let (pk_a, _sk_a) = ctx.new_keypair().unwrap();
let (pk_b, _sk_b) = ctx.new_keypair().unwrap();
let id_a = "AAAAAAAAAAAAA";
let id_b = "BBBBBBBBBBBBB";
let za = compute_z(id_a, &pk_a);
let zb = compute_z(id_b, &pk_b);
println!("{za:x?}");
println!("{zb:x?}");
}
#[test]
fn sm2_key_exchange_user_test() {
let ctx = SigCtx::new();
let (pk_a, sk_a) = ctx.new_keypair().unwrap();
let (pk_b, sk_b) = ctx.new_keypair().unwrap();
let id_a = "AAAAAAAAAAAAA";
let id_b = "BBBBBBBBBBBBB";
let mut ctx1 = ExchangeCtxA::new(8, id_a, id_b, pk_a, pk_b, sk_a).unwrap();
let mut ctx2 = ExchangeCtxB::new(8, id_a, id_b, pk_a, pk_b, sk_b).unwrap();
let r_a_point = ctx1.exchange1().unwrap();
let (r_b_point, s_b) = ctx2.exchange2(&r_a_point).unwrap();
let s_a = ctx1.exchange3(&r_b_point, s_b).unwrap();
let succ = ctx2.exchange4(s_a, &r_a_point).unwrap();
assert!(succ);
assert_eq!(ctx1.k_a, ctx2.k_b);
}
}

616
src/sm2/field.rs Normal file
View File

@@ -0,0 +1,616 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Implementation of the prime field(SCA-256) used by SM2
use crate::sm2::error::{Sm2Error, Sm2Result};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use num_bigint::BigUint;
use num_traits::Num;
use std::io::Cursor;
pub struct FieldCtx {
modulus: FieldElem,
modulus_complete: FieldElem,
}
impl FieldCtx {
pub fn new() -> FieldCtx {
// p = FFFFFFFE FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF 00000000 FFFFFFFF FFFFFFFF
// = 2^256 - 2^224 - 2^96 + 2^64 -1
let modulus = FieldElem::new([
0xffff_fffe,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0x0000_0000,
0xffff_ffff,
0xffff_ffff,
]);
let (modulus_complete, _borrow) = raw_sub(&FieldElem::zero(), &modulus);
FieldCtx {
modulus,
modulus_complete,
}
}
pub fn add(&self, a: &FieldElem, b: &FieldElem) -> Sm2Result<FieldElem> {
let (raw_sum, carry) = raw_add(a, b);
if carry == 1 || raw_sum >= self.modulus {
let (sum, _borrow) = raw_sub(&raw_sum, &self.modulus);
Ok(sum)
} else {
Ok(raw_sum)
}
}
pub fn sub(&self, a: &FieldElem, b: &FieldElem) -> Sm2Result<FieldElem> {
let (raw_diff, borrow) = raw_sub(a, b);
if borrow == 1 {
let (diff, _borrow) = raw_sub(&raw_diff, &self.modulus_complete);
Ok(diff)
} else {
Ok(raw_diff)
}
}
// a quick algorithm to reduce elements on SCA-256 field
// Reference:
// http://ieeexplore.ieee.org/document/7285166/ for details
#[inline]
fn fast_reduction(&self, input: &[u32; 16]) -> Sm2Result<FieldElem> {
let mut rs: [FieldElem; 10] = [FieldElem::zero(); 10];
let mut rx: [u32; 16] = [0; 16];
let mut i = 0;
while i < 16 {
rx[i] = input[15 - i];
i += 1;
}
rs[0] = FieldElem::new([rx[7], rx[6], rx[5], rx[4], rx[3], rx[2], rx[1], rx[0]]);
rs[1] = FieldElem::new([rx[15], 0, 0, 0, 0, 0, rx[15], rx[14]]);
rs[2] = FieldElem::new([rx[14], 0, 0, 0, 0, 0, rx[14], rx[13]]);
rs[3] = FieldElem::new([rx[13], 0, 0, 0, 0, 0, 0, 0]);
rs[4] = FieldElem::new([rx[12], 0, rx[15], rx[14], rx[13], 0, 0, rx[15]]);
rs[5] = FieldElem::new([rx[15], rx[15], rx[14], rx[13], rx[12], 0, rx[11], rx[10]]);
rs[6] = FieldElem::new([rx[11], rx[14], rx[13], rx[12], rx[11], 0, rx[10], rx[9]]);
rs[7] = FieldElem::new([rx[10], rx[11], rx[10], rx[9], rx[8], 0, rx[13], rx[12]]);
rs[8] = FieldElem::new([rx[9], 0, 0, rx[15], rx[14], 0, rx[9], rx[8]]);
rs[9] = FieldElem::new([rx[8], 0, 0, 0, rx[15], 0, rx[12], rx[11]]);
let mut carry: i32 = 0;
let mut sum = FieldElem::zero();
let (rt, rc) = raw_add(&sum, &rs[1]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[2]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[3]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[4]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &sum);
sum = rt;
carry = carry * 2 + rc as i32;
let (rt, rc) = raw_add(&sum, &rs[5]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[6]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[7]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[8]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_add(&sum, &rs[9]);
sum = rt;
carry += rc as i32;
let mut part3 = FieldElem::zero();
let rt: u64 = u64::from(rx[8]) + u64::from(rx[9]) + u64::from(rx[13]) + u64::from(rx[14]);
part3.value[5] = (rt & 0xffff_ffff) as u32;
part3.value[4] = (rt >> 32) as u32;
let (rt, rc) = raw_add(&sum, &rs[0]);
sum = rt;
carry += rc as i32;
let (rt, rc) = raw_sub(&sum, &part3);
sum = rt;
carry -= rc as i32;
while carry > 0 || sum >= self.modulus {
let (rs, rb) = raw_sub(&sum, &self.modulus);
sum = rs;
carry -= rb as i32;
}
Ok(sum)
}
pub fn mul(&self, a: &FieldElem, b: &FieldElem) -> Sm2Result<FieldElem> {
let raw_prod = raw_mul(a, b);
self.fast_reduction(&raw_prod)
}
#[inline(always)]
pub fn square(&self, a: &FieldElem) -> Sm2Result<FieldElem> {
self.mul(a, a)
}
#[inline(always)]
pub fn cubic(&self, a: &FieldElem) -> Sm2Result<FieldElem> {
self.mul(a, &self.mul(a, a)?)
}
// Extended Eulidean Algorithm(EEA) to calculate x^(-1) mod p
// Reference:
// http://delta.cs.cinvestav.mx/~francisco/arith/julio.pdf
pub fn inv(&self, x: &FieldElem) -> Sm2Result<FieldElem> {
if x.is_zero() {
return Err(Sm2Error::ZeroFiled);
}
let mut ru = *x;
let mut rv = self.modulus;
let mut ra = FieldElem::from_num(1);
let mut rc = FieldElem::zero();
while !ru.is_zero() {
if ru.is_even() {
ru = ru.div2(0);
if ra.is_even() {
ra = ra.div2(0);
} else {
let (sum, car) = raw_add(&ra, &self.modulus);
ra = sum.div2(car);
}
}
if rv.is_even() {
rv = rv.div2(0);
if rc.is_even() {
rc = rc.div2(0);
} else {
let (sum, car) = raw_add(&rc, &self.modulus);
rc = sum.div2(car);
}
}
if ru >= rv {
ru = self.sub(&ru, &rv)?;
ra = self.sub(&ra, &rc)?;
} else {
rv = self.sub(&rv, &ru)?;
rc = self.sub(&rc, &ra)?;
}
}
Ok(rc)
}
pub fn neg(&self, x: &FieldElem) -> Sm2Result<FieldElem> {
self.sub(&self.modulus, x)
}
fn exp(&self, x: &FieldElem, n: &BigUint) -> Sm2Result<FieldElem> {
let u = FieldElem::from_biguint(n)?;
let mut q0 = FieldElem::from_num(1);
let mut q1 = *x;
let mut i = 0;
while i < 256 {
let index = i as usize / 32;
let bit = 31 - i as usize % 32;
let sum = self.mul(&q0, &q1)?;
if (u.get_value(index) >> bit) & 0x01 == 0 {
q1 = sum;
q0 = self.square(&q0)?;
} else {
q0 = sum;
q1 = self.square(&q1)?;
}
i += 1;
}
Ok(q0)
}
// Square root of a field element
pub fn sqrt(&self, g: &FieldElem) -> Sm2Result<FieldElem> {
// p = 4 * u + 3
// u = u + 1
let u = BigUint::from_str_radix(
"28948022302589062189105086303505223191562588497981047863605298483322421248000",
10,
)
.unwrap();
let y = self.exp(g, &u)?;
if self.square(&y)? == *g {
Ok(y)
} else {
Err(Sm2Error::FieldSqrtError)
}
}
}
impl Default for FieldCtx {
fn default() -> Self {
Self::new()
}
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct FieldElem {
pub value: [u32; 8],
}
fn raw_add(a: &FieldElem, b: &FieldElem) -> (FieldElem, u32) {
let mut sum = FieldElem::zero();
let mut carry: u32 = 0;
let t_sum: u64 = u64::from(a.value[7]) + u64::from(b.value[7]) + u64::from(carry);
sum.value[7] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[6]) + u64::from(b.value[6]) + u64::from(carry);
sum.value[6] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[5]) + u64::from(b.value[5]) + u64::from(carry);
sum.value[5] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[4]) + u64::from(b.value[4]) + u64::from(carry);
sum.value[4] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[3]) + u64::from(b.value[3]) + u64::from(carry);
sum.value[3] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[2]) + u64::from(b.value[2]) + u64::from(carry);
sum.value[2] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[1]) + u64::from(b.value[1]) + u64::from(carry);
sum.value[1] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
let t_sum: u64 = u64::from(a.value[0]) + u64::from(b.value[0]) + u64::from(carry);
sum.value[0] = (t_sum & 0xffff_ffff) as u32;
carry = (t_sum >> 32) as u32;
(sum, carry)
}
fn raw_sub(a: &FieldElem, b: &FieldElem) -> (FieldElem, u32) {
let mut sum = FieldElem::new([0; 8]);
let mut borrow: u32 = 0;
let mut j = 0;
while j < 8 {
let i = 7 - j;
let t_sum: i64 = i64::from(a.value[i]) - i64::from(b.value[i]) - i64::from(borrow);
if t_sum < 0 {
sum.value[i] = (t_sum + (1 << 32)) as u32;
borrow = 1;
} else {
sum.value[i] = t_sum as u32;
borrow = 0;
}
j += 1;
}
(sum, borrow)
}
#[inline(always)]
fn u32_mul(a: u32, b: u32) -> (u64, u64) {
let uv = u64::from(a) * u64::from(b);
let u = uv >> 32;
let v = uv & 0xffff_ffff;
(u, v)
}
fn raw_mul(a: &FieldElem, b: &FieldElem) -> [u32; 16] {
let mut local: u64 = 0;
let mut carry: u64 = 0;
let mut ret: [u32; 16] = [0; 16];
let mut ret_idx = 0;
while ret_idx < 15 {
let index = 15 - ret_idx;
let mut a_idx = 0;
while a_idx < 8 {
if a_idx > ret_idx {
break;
}
let b_idx = ret_idx - a_idx;
if b_idx < 8 {
let (hi, lo) = u32_mul(a.value[7 - a_idx], b.value[7 - b_idx]);
local += lo;
carry += hi;
}
a_idx += 1;
}
carry += local >> 32;
local &= 0xffff_ffff;
ret[index] = local as u32;
local = carry;
carry = 0;
ret_idx += 1;
}
ret[0] = local as u32;
ret
}
impl FieldElem {
pub fn new(x: [u32; 8]) -> FieldElem {
FieldElem { value: x }
}
pub fn from_slice(x: &[u32]) -> FieldElem {
let mut arr: [u32; 8] = [0; 8];
arr.copy_from_slice(&x[0..8]);
FieldElem::new(arr)
}
pub fn zero() -> FieldElem {
FieldElem::new([0; 8])
}
#[inline]
pub fn is_zero(&self) -> bool {
self.value == [0; 8]
}
pub fn div2(&self, carry: u32) -> FieldElem {
let mut ret = FieldElem::zero();
let mut carry = carry;
let mut i = 0;
while i < 8 {
ret.value[i] = (carry << 31) + (self.value[i] >> 1);
carry = self.value[i] & 0x01;
i += 1;
}
ret
}
pub fn is_even(&self) -> bool {
self.value[7] & 0x01 == 0
}
// Conversions
pub fn to_bytes(&self) -> Vec<u8> {
let mut ret: Vec<u8> = Vec::new();
for i in 0..8 {
ret.write_u32::<BigEndian>(self.value[i]).unwrap();
}
ret
}
pub fn from_bytes(x: &[u8]) -> Sm2Result<FieldElem> {
if x.len() != 32 {
return Err(Sm2Error::InvalidFieldLen);
}
let mut elem = FieldElem::zero();
let mut c = Cursor::new(x);
for i in 0..8 {
let x = c.read_u32::<BigEndian>().unwrap();
elem.value[i] = x;
}
Ok(elem)
}
pub fn to_biguint(&self) -> BigUint {
let v = self.to_bytes();
BigUint::from_bytes_be(&v[..])
}
pub fn from_biguint(bi: &BigUint) -> Sm2Result<FieldElem> {
let v = bi.to_bytes_be();
let mut num_v = [0u8; 32];
num_v[32 - v.len()..32].copy_from_slice(&v[..]);
FieldElem::from_bytes(&num_v[..])
}
pub fn from_num(x: u64) -> FieldElem {
let mut arr: [u32; 8] = [0; 8];
arr[7] = (x & 0xffff_ffff) as u32;
arr[6] = (x >> 32) as u32;
FieldElem::new(arr)
}
pub fn to_str(&self, radix: u32) -> String {
let b = self.to_biguint();
b.to_str_radix(radix)
}
pub fn get_value(&self, i: usize) -> u32 {
self.value[i]
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::RngCore;
#[test]
fn test_add() {
let ctx = FieldCtx::new();
let a = FieldElem::from_num(1);
let b = FieldElem::from_num(0xffff_ffff);
let c = ctx.add(&a, &b).unwrap();
let c1 = FieldElem::from_num(0x1_0000_0000);
assert!(c == c1);
let b1 = ctx.add(&ctx.modulus, &b).unwrap();
assert!(b1 == b);
}
#[test]
fn test_sub() {
let ctx = FieldCtx::new();
let a = FieldElem::from_num(0xffff_ffff);
let a1 = ctx.sub(&a, &ctx.modulus).unwrap();
assert!(a == a1);
}
fn rand_elem() -> FieldElem {
let mut rng = rand::thread_rng();
let mut buf: [u32; 8] = [0; 8];
for v in buf.iter_mut().take(8) {
*v = rng.next_u32();
}
let ret = FieldElem::new(buf);
let ctx = FieldCtx::new();
if ret >= ctx.modulus {
let (ret, _borrow) = raw_sub(&ret, &ctx.modulus);
return ret;
}
ret
}
#[test]
fn add_sub_rand_test() {
let ctx = FieldCtx::new();
for _i in 0..20 {
let a = rand_elem();
let b = rand_elem();
let c = ctx.add(&a, &b).unwrap();
let a1 = ctx.sub(&c, &b).unwrap();
assert!(a1 == a);
}
}
// test multiplilcations
#[test]
fn test_mul() {
let ctx = FieldCtx::new();
let x = raw_mul(&ctx.modulus, &ctx.modulus);
let y = ctx.fast_reduction(&x).unwrap();
assert!(y.is_zero());
}
#[test]
fn test_div2() {
let x = FieldElem::new([
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
]);
let y = FieldElem::new([
0x7fff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
0xffff_ffff,
]);
assert!(y == x.div2(0));
assert!(x == x.div2(1));
assert!(!x.is_even());
assert!(FieldElem::from_num(10).is_even());
}
#[test]
fn test_inv() {
let ctx = FieldCtx::new();
let one = FieldElem::from_num(1);
for _x in 1..100 {
let x = rand_elem();
let xinv = ctx.inv(&x).unwrap();
let y = ctx.mul(&x, &xinv).unwrap();
assert!(y == one);
}
}
#[test]
fn test_byte_conversion() {
for _x in 1..100 {
let x = rand_elem();
let y = x.to_bytes();
let newx = FieldElem::from_bytes(&y[..]).unwrap();
assert!(x == newx);
}
}
#[test]
fn test_bigint_conversion() {
for _x in 1..100 {
let x = rand_elem();
let y = x.to_biguint();
let newx = FieldElem::from_biguint(&y).unwrap();
assert!(x == newx);
}
}
#[test]
fn test_neg() {
let ctx = FieldCtx::new();
for _ in 0..100 {
let x = rand_elem();
let neg_x = ctx.neg(&x).unwrap();
let zero = ctx.add(&x, &neg_x).unwrap();
assert!(zero.is_zero());
}
}
#[test]
fn test_sqrt() {
let ctx = FieldCtx::new();
for _ in 0..10 {
let x = rand_elem();
let x_2 = ctx.square(&x).unwrap();
let new_x = ctx.sqrt(&x_2).unwrap();
assert!(x == new_x || ctx.add(&x, &new_x).unwrap().is_zero());
}
}
}

21
src/sm2/mod.rs Normal file
View File

@@ -0,0 +1,21 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod ecc;
pub mod encrypt;
mod error;
pub mod exchange;
pub mod field;
pub mod signature;
mod util;

510
src/sm2/signature.rs Normal file
View File

@@ -0,0 +1,510 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::sm2::error::{Sm2Error, Sm2Result};
use crate::sm3::hash::Sm3Hash;
use super::ecc::*;
use super::field::FieldElem;
use byteorder::{BigEndian, WriteBytesExt};
use num_bigint::BigUint;
use num_traits::*;
use std::fmt;
use yasna;
pub type Pubkey = Point;
pub type Seckey = BigUint;
pub struct Signature {
r: BigUint,
s: BigUint,
}
impl Signature {
pub fn new(r_bytes: &[u8], s_bytes: &[u8]) -> Self {
let r = BigUint::from_bytes_be(r_bytes);
let s = BigUint::from_bytes_be(s_bytes);
Signature { r, s }
}
pub fn der_decode(buf: &[u8]) -> Result<Signature, yasna::ASN1Error> {
let (r, s) = yasna::parse_der(buf, |reader| {
reader.read_sequence(|reader| {
let r = reader.next().read_biguint()?;
let s = reader.next().read_biguint()?;
Ok((r, s))
})
})?;
Ok(Signature { r, s })
}
pub fn der_decode_raw(buf: &[u8]) -> Sm2Result<Signature> {
if buf[0] != 0x02 {
return Err(Sm2Error::InvalidDer);
}
let r_len: usize = buf[1] as usize;
if buf.len() <= r_len + 4 {
return Err(Sm2Error::InvalidDer);
}
let r = BigUint::from_bytes_be(&buf[2..2 + r_len]);
let buf = &buf[2 + r_len..];
if buf[0] != 0x02 {
return Err(Sm2Error::InvalidDer);
}
let s_len: usize = buf[1] as usize;
if buf.len() < s_len + 2 {
return Err(Sm2Error::InvalidDer);
}
let s = BigUint::from_bytes_be(&buf[2..2 + s_len]);
Ok(Signature { r, s })
}
pub fn der_encode(&self) -> Vec<u8> {
yasna::construct_der(|writer| {
writer.write_sequence(|writer| {
writer.next().write_biguint(&self.r);
writer.next().write_biguint(&self.s);
})
})
}
#[inline]
pub fn get_r(&self) -> &BigUint {
&self.r
}
#[inline]
pub fn get_s(&self) -> &BigUint {
&self.s
}
}
impl fmt::Display for Signature {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"r = 0x{:0>64}, s = 0x{:0>64}",
self.r.to_str_radix(16),
self.s.to_str_radix(16)
)
}
}
pub struct SigCtx {
curve: EccCtx,
}
impl SigCtx {
pub fn new() -> SigCtx {
SigCtx {
curve: EccCtx::new(),
}
}
pub fn hash(&self, id: &str, pk: &Point, msg: &[u8]) -> Sm2Result<[u8; 32]> {
let curve = &self.curve;
let mut prepend: Vec<u8> = Vec::new();
if id.len() * 8 > 65535 {
return Err(Sm2Error::IdTooLong);
}
prepend
.write_u16::<BigEndian>((id.len() * 8) as u16)
.unwrap();
for c in id.bytes() {
prepend.push(c);
}
let mut a = curve.get_a().to_bytes();
let mut b = curve.get_b().to_bytes();
prepend.append(&mut a);
prepend.append(&mut b);
let (x_g, y_g) = curve.to_affine(&curve.generator()?)?;
let (mut x_g, mut y_g) = (x_g.to_bytes(), y_g.to_bytes());
prepend.append(&mut x_g);
prepend.append(&mut y_g);
let (x_a, y_a) = curve.to_affine(pk)?;
let (mut x_a, mut y_a) = (x_a.to_bytes(), y_a.to_bytes());
prepend.append(&mut x_a);
prepend.append(&mut y_a);
let mut hasher = Sm3Hash::new(&prepend[..]);
let z_a = hasher.get_hash();
// Z_A = HASH_256(ID_LEN || ID || x_G || y_G || x_A || y_A)
// e = HASH_256(Z_A || M)
let mut prepended_msg: Vec<u8> = Vec::new();
prepended_msg.extend_from_slice(&z_a[..]);
prepended_msg.extend_from_slice(msg);
let mut hasher = Sm3Hash::new(&prepended_msg[..]);
Ok(hasher.get_hash())
}
pub fn recid_combine(&self, id: &str, pk: &Point, msg: &[u8]) -> Sm2Result<Vec<u8>> {
let curve = &self.curve;
let mut prepend: Vec<u8> = Vec::new();
if id.len() * 8 > 65535 {
return Err(Sm2Error::IdTooLong);
}
prepend
.write_u16::<BigEndian>((id.len() * 8) as u16)
.unwrap();
for c in id.bytes() {
prepend.push(c);
}
let mut a = curve.get_a().to_bytes();
let mut b = curve.get_b().to_bytes();
prepend.append(&mut a);
prepend.append(&mut b);
let (x_g, y_g) = curve.to_affine(&curve.generator()?)?;
let (mut x_g, mut y_g) = (x_g.to_bytes(), y_g.to_bytes());
prepend.append(&mut x_g);
prepend.append(&mut y_g);
let (x_a, y_a) = curve.to_affine(pk)?;
let (mut x_a, mut y_a) = (x_a.to_bytes(), y_a.to_bytes());
prepend.append(&mut x_a);
prepend.append(&mut y_a);
let mut hasher = Sm3Hash::new(&prepend[..]);
let z_a = hasher.get_hash();
// Z_A = HASH_256(ID_LEN || ID || x_G || y_G || x_A || y_A)
// e = HASH_256(Z_A || M)
let mut prepended_msg: Vec<u8> = Vec::new();
prepended_msg.extend_from_slice(&z_a[..]);
prepended_msg.extend_from_slice(msg);
Ok(prepended_msg)
}
pub fn sign(&self, msg: &[u8], sk: &BigUint, pk: &Point) -> Sm2Result<Signature> {
// Get the value "e", which is the hash of message and ID, EC parameters and public key
let digest = self.hash("1234567812345678", pk, msg)?;
self.sign_raw(&digest[..], sk)
}
pub fn sign_raw(&self, digest: &[u8], sk: &BigUint) -> Sm2Result<Signature> {
let curve = &self.curve;
// Get the value "e", which is the hash of message and ID, EC parameters and public key
let e = BigUint::from_bytes_be(digest);
// two while loops
loop {
// k = rand()
// (x_1, y_1) = g^kg
let k = self.curve.random_uint();
let p_1 = curve.g_mul(&k)?;
let (x_1, _) = curve.to_affine(&p_1)?;
let x_1 = x_1.to_biguint();
// r = e + x_1
let r = (&e + x_1) % curve.get_n();
if r == BigUint::zero() || &r + &k == *curve.get_n() {
continue;
}
// s = (1 + sk)^-1 * (k - r * sk)
let s1 = curve.inv_n(&(sk + BigUint::one()))?;
let mut s2_1 = &r * sk;
if s2_1 < k {
s2_1 += curve.get_n();
}
let mut s2 = s2_1 - k;
s2 %= curve.get_n();
let s2 = curve.get_n() - s2;
let s = (s1 * s2) % curve.get_n();
if s != BigUint::zero() {
// Output the signature (r, s)
return Ok(Signature { r, s });
}
return Err(Sm2Error::ZeroSig);
}
}
pub fn verify(&self, msg: &[u8], pk: &Point, sig: &Signature) -> Sm2Result<bool> {
//Get hash value
let digest = self.hash("1234567812345678", pk, msg)?;
//println!("digest: {:?}", digest);
self.verify_raw(&digest[..], pk, sig)
}
pub fn verify_raw(&self, digest: &[u8], pk: &Point, sig: &Signature) -> Sm2Result<bool> {
if digest.len() != 32 {
return Err(Sm2Error::InvalidDigestLen);
}
let e = BigUint::from_bytes_be(digest);
let curve = &self.curve;
// check r and s
if *sig.get_r() == BigUint::zero() || *sig.get_s() == BigUint::zero() {
return Ok(false);
}
if *sig.get_r() >= *curve.get_n() || *sig.get_s() >= *curve.get_n() {
return Ok(false);
}
// calculate R
let t = (sig.get_s() + sig.get_r()) % curve.get_n();
if t == BigUint::zero() {
return Ok(false);
}
let p_1 = curve.add(&curve.g_mul(sig.get_s())?, &curve.mul(&t, pk)?)?;
let (x_1, _) = curve.to_affine(&p_1)?;
let x_1 = x_1.to_biguint();
let r_ = (e + x_1) % curve.get_n();
// check R == r?
Ok(r_ == *sig.get_r())
}
pub fn new_keypair(&self) -> Sm2Result<(Point, BigUint)> {
let curve = &self.curve;
let mut sk: BigUint = curve.random_uint();
let mut pk: Point = curve.g_mul(&sk)?;
loop {
if !pk.is_zero() {
break;
}
sk = curve.random_uint();
pk = curve.g_mul(&sk)?;
}
Ok((pk, sk))
}
pub fn pk_from_sk(&self, sk: &BigUint) -> Sm2Result<Point> {
let curve = &self.curve;
if *sk >= *curve.get_n() || *sk == BigUint::zero() {
return Err(Sm2Error::InvalidSecretKey);
}
curve.g_mul(sk)
}
pub fn load_pubkey(&self, buf: &[u8]) -> Sm2Result<Point> {
self.curve.bytes_to_point(buf)
}
pub fn serialize_pubkey(&self, p: &Point, compress: bool) -> Sm2Result<Vec<u8>> {
self.curve.point_to_bytes(p, compress)
}
pub fn load_seckey(&self, buf: &[u8]) -> Sm2Result<BigUint> {
if buf.len() != 32 {
return Err(Sm2Error::InvalidPrivate);
}
let sk = BigUint::from_bytes_be(buf);
if sk > *self.curve.get_n() {
Err(Sm2Error::InvalidPrivate)
} else {
Ok(sk)
}
}
pub fn serialize_seckey(&self, x: &BigUint) -> Sm2Result<Vec<u8>> {
if *x > *self.curve.get_n() {
return Err(Sm2Error::InvalidSecretKey);
}
let x = FieldElem::from_biguint(x)?;
Ok(x.to_bytes())
}
}
impl Default for SigCtx {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sign() {
let string = String::from("abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd");
let msg = string.as_bytes();
let ctx = SigCtx::new();
let (pk, sk) = ctx.new_keypair().unwrap();
let signature = ctx.sign(msg, &sk, &pk).unwrap();
println!("public key is {pk}, signature is {signature}");
}
#[test]
fn test_sign_and_verify() {
let string = String::from("abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd");
let msg = string.as_bytes();
let ctx = SigCtx::new();
let (pk, sk) = ctx.new_keypair().unwrap();
let signature = ctx.sign(msg, &sk, &pk).unwrap();
assert!(ctx.verify(msg, &pk, &signature).unwrap());
}
#[test]
fn test_sig_encode_and_decode() {
let string = String::from("abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd");
let msg = string.as_bytes();
let ctx = SigCtx::new();
let (pk, sk) = ctx.new_keypair().unwrap();
let signature = ctx.sign(msg, &sk, &pk).unwrap();
let der = signature.der_encode();
let sig = Signature::der_decode(&der[..]).unwrap();
assert!(ctx.verify(msg, &pk, &sig).unwrap());
let signature = ctx.sign(msg, &sk, &pk).unwrap();
let der = signature.der_encode();
let sig = Signature::der_decode_raw(&der[2..]).unwrap();
assert!(ctx.verify(msg, &pk, &sig).unwrap());
}
#[test]
fn test_key_serialization() {
let ctx = SigCtx::new();
let (pk, sk) = ctx.new_keypair().unwrap();
let pk_v = ctx.serialize_pubkey(&pk, true).unwrap();
let new_pk = ctx.load_pubkey(&pk_v[..]).unwrap();
assert!(ctx.curve.eq(&new_pk, &pk).unwrap());
let sk_v = ctx.serialize_seckey(&sk).unwrap();
let new_sk = ctx.load_seckey(&sk_v[..]).unwrap();
assert_eq!(new_sk, sk);
}
#[test]
fn test_gmssl() {
let msg: &[u8] = &[
0x66, 0xc7, 0xf0, 0xf4, 0x62, 0xee, 0xed, 0xd9, 0xd1, 0xf2, 0xd4, 0x6b, 0xdc, 0x10,
0xe4, 0xe2, 0x41, 0x67, 0xc4, 0x87, 0x5c, 0xf2, 0xf7, 0xa2, 0x29, 0x7d, 0xa0, 0x2b,
0x8f, 0x4b, 0xa8, 0xe0,
];
let pk: &[u8] = &[
4, 233, 185, 71, 125, 111, 174, 63, 105, 217, 19, 218, 72, 114, 185, 96, 243, 176, 1,
8, 239, 132, 114, 119, 216, 38, 21, 117, 142, 223, 42, 157, 170, 123, 219, 65, 50, 238,
191, 116, 238, 240, 197, 158, 1, 145, 177, 107, 112, 91, 101, 86, 50, 204, 218, 254,
172, 2, 250, 33, 56, 176, 121, 16, 215,
];
let sig: &[u8] = &[
48, 69, 2, 33, 0, 171, 111, 172, 181, 242, 159, 198, 106, 33, 229, 104, 147, 245, 97,
132, 141, 141, 17, 27, 97, 156, 159, 160, 188, 239, 78, 124, 17, 211, 124, 113, 26, 2,
32, 53, 21, 4, 195, 198, 42, 71, 17, 110, 157, 113, 185, 178, 74, 147, 87, 129, 179,
168, 163, 171, 126, 39, 156, 198, 29, 163, 199, 82, 25, 13, 112,
];
let curve = EccCtx::new();
let ctx = SigCtx::new();
let pk = curve.bytes_to_point(pk).unwrap();
let sig = Signature::der_decode(sig).unwrap();
assert!(ctx.verify_raw(msg, &pk, &sig).unwrap());
}
#[test]
fn verify_third_test() {
let ctx = SigCtx::new();
let msg = b"hello world";
let ecc_ctx = EccCtx::new();
let pk_bz = hex::decode("0420e9c9497bf151e33c3af9e7deb63e2133a27d21fa1647cee0afda049af1f664f81dc793ebab487ab51414081075e57a65b016da4087f491c04977a6397327b2").unwrap();
let pk = ecc_ctx.bytes_to_point(&pk_bz).unwrap();
let sig_r_bz =
hex::decode("76415405cbb177ebb37a835a2b5a022f66c250abf482e4cb343dcb2091bc1f2e")
.unwrap();
let sig_s_bz =
hex::decode("61f0665f805e78dd19073922992c671867a1dee839e8179d39b532eb66b9cd90")
.unwrap();
let sig = Signature::new(&sig_r_bz, &sig_s_bz);
assert!(ctx.verify(msg, &pk, &sig).unwrap());
}
#[test]
fn verify_third_der_test() {
let ctx = SigCtx::new();
let msg = "jonllen".to_string().into_bytes();
let ecc_ctx = EccCtx::new();
let pk_bz = hex::decode("044f954d8c4d7c0133e5f402c7e75623438c2dcee5ae5ee6c2f1fca51c60f7017e9cfad13514cd4e7faeca476a98eeb0b8a62c1f6add9794beead4a42291b94278").unwrap();
let pk = ecc_ctx.bytes_to_point(&pk_bz).unwrap();
let sig_bz = hex::decode("304402207e665a4d2781cb488bd374ccf1c8116e95ad0731c99e1dc36c189fd4daf0cb0202206a7ddd6483db176192b25aba9a92bc4de8b76e2c6d1559965ad06224d0725531").unwrap();
let sig = Signature::der_decode(&sig_bz).unwrap();
assert!(ctx.verify(&msg, &pk, &sig).unwrap());
}
}
#[cfg(feature = "internal_benches")]
mod signature_benches {
use crate::sm2::signature::SigCtx;
extern crate test;
#[bench]
fn sign_bench(bench: &mut test::Bencher) {
let test_word = b"hello world";
let ctx = SigCtx::new();
let (pk, sk) = ctx.new_keypair().unwrap();
bench.iter(|| {
let _ = ctx.sign(test_word, &sk, &pk);
});
}
#[bench]
fn verify_bench(bench: &mut test::Bencher) {
let test_word = b"hello world";
let ctx = SigCtx::new();
let (pk, sk) = ctx.new_keypair().unwrap();
let sig = ctx.sign(test_word, &sk, &pk).unwrap();
bench.iter(|| {
let _ = ctx.verify(test_word, &pk, &sig);
});
}
}

31
src/sm2/util.rs Normal file
View File

@@ -0,0 +1,31 @@
use crate::sm3::hash::Sm3Hash;
// DIFFERNCE: klen bytes, not klen bits
pub fn kdf(z: &[u8], klen: usize) -> Vec<u8> {
let mut ct = 0x0000_0001_u32;
let bound = ((klen as f64) / 32.0).ceil() as u32;
let mut h_a = Vec::new();
for _i in 1..bound {
let mut prepend = Vec::new();
prepend.extend_from_slice(z);
prepend.extend_from_slice(&ct.to_be_bytes());
let mut hasher = Sm3Hash::new(&prepend[..]);
let h_a_i = hasher.get_hash();
h_a.extend_from_slice(&h_a_i);
ct += 1;
}
let mut prepend = Vec::new();
prepend.extend_from_slice(z);
prepend.extend_from_slice(&ct.to_be_bytes());
let mut hasher = Sm3Hash::new(&prepend[..]);
let last = hasher.get_hash();
if klen % 32 == 0 {
h_a.extend_from_slice(&last);
} else {
h_a.extend_from_slice(&last[0..(klen % 32)]);
}
h_a
}

58
src/sm3/error.rs Normal file
View File

@@ -0,0 +1,58 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use std::fmt::Formatter;
pub type Sm3Result<T> = Result<T, Sm3Error>;
pub enum Sm3Error {
ErrorMsgLen,
}
impl std::fmt::Debug for Sm3Error {
fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
write!(f, "{self}")
}
}
impl From<Sm3Error> for &str {
fn from(e: Sm3Error) -> Self {
match e {
Sm3Error::ErrorMsgLen => "SM3 Pad: error msgLen",
}
}
}
impl Display for Sm3Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let err_msg = match self {
Sm3Error::ErrorMsgLen => "SM3 Pad: error msgLen",
};
write!(f, "{err_msg}")
}
}
#[cfg(test)]
mod tests {
use super::Sm3Error;
#[test]
fn test_error_display() {
assert_eq!(
format!("{}", Sm3Error::ErrorMsgLen),
"SM3 Pad: error msgLen"
)
}
}

294
src/sm3/hash.rs Normal file
View File

@@ -0,0 +1,294 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Sample 1
// Input:"abc"
// Output:66c7f0f4 62eeedd9 d1f2d46b dc10e4e2 4167c487 5cf2f7a2 297da02b 8f4ba8e0
use crate::sm3::error::{Sm3Error, Sm3Result};
// Sample 2
// Input:"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd"
// Outpuf:debe9ff9 2275b8a1 38604889 c18e5a4d 6fdb70e5 387e5765 293dcba3 9c0c5732
#[inline(always)]
fn ff0(x: u32, y: u32, z: u32) -> u32 {
x ^ y ^ z
}
#[inline(always)]
fn ff1(x: u32, y: u32, z: u32) -> u32 {
(x & y) | (x & z) | (y & z)
}
#[inline(always)]
fn gg0(x: u32, y: u32, z: u32) -> u32 {
x ^ y ^ z
}
#[inline(always)]
fn gg1(x: u32, y: u32, z: u32) -> u32 {
(x & y) | (!x & z)
}
#[inline(always)]
fn p0(x: u32) -> u32 {
x ^ x.rotate_left(9) ^ x.rotate_left(17)
}
#[inline(always)]
fn p1(x: u32) -> u32 {
x ^ x.rotate_left(15) ^ x.rotate_left(23)
}
#[inline(always)]
fn get_u32_be(b: &[u8; 64], i: usize) -> u32 {
u32::from(b[i]) << 24
| u32::from(b[i + 1]) << 16
| u32::from(b[i + 2]) << 8
| u32::from(b[i + 3])
}
pub struct Sm3Hash {
digest: [u32; 8],
length: u64,
unhandle_msg: Vec<u8>,
}
impl Sm3Hash {
pub fn new(data: &[u8]) -> Sm3Hash {
let mut hash = Sm3Hash {
digest: [
0x7380_166f,
0x4914_b2b9,
0x1724_42d7,
0xda8a_0600,
0xa96f_30bc,
0x1631_38aa,
0xe38d_ee4d,
0xb0fb_0e4e,
],
length: (data.len() << 3) as u64,
unhandle_msg: Vec::new(),
};
for i in data.iter() {
hash.unhandle_msg.push(*i);
}
hash
}
pub fn get_hash(&mut self) -> [u8; 32] {
let mut output: [u8; 32] = [0; 32];
self.pad().unwrap();
let len = self.unhandle_msg.len();
let mut count: usize = 0;
let mut buffer: [u8; 64] = [0; 64];
while count * 64 != len {
for i in (count * 64)..(count * 64 + 64) {
buffer[i - count * 64] = self.unhandle_msg[i];
}
self.update(&buffer);
count += 1;
}
let mut i = 0;
while i < 8 {
output[i * 4] = (self.digest[i] >> 24) as u8;
output[i * 4 + 1] = (self.digest[i] >> 16) as u8;
output[i * 4 + 2] = (self.digest[i] >> 8) as u8;
output[i * 4 + 3] = self.digest[i] as u8;
i += 1;
}
output
}
fn pad(&mut self) -> Sm3Result<()> {
self.unhandle_msg.push(0x80);
let blocksize = 64;
while self.unhandle_msg.len() % blocksize != 56 {
self.unhandle_msg.push(0x00);
}
self.unhandle_msg.push((self.length >> 56 & 0xff) as u8);
self.unhandle_msg.push((self.length >> 48 & 0xff) as u8);
self.unhandle_msg.push((self.length >> 40 & 0xff) as u8);
self.unhandle_msg.push((self.length >> 32 & 0xff) as u8);
self.unhandle_msg.push((self.length >> 24 & 0xff) as u8);
self.unhandle_msg.push((self.length >> 16 & 0xff) as u8);
self.unhandle_msg.push((self.length >> 8 & 0xff) as u8);
self.unhandle_msg.push((self.length & 0xff) as u8);
if self.unhandle_msg.len() % 64 != 0 {
return Err(Sm3Error::ErrorMsgLen);
}
Ok(())
}
fn update(&mut self, buffer: &[u8; 64]) {
//get expend
let mut w: [u32; 68] = [0; 68];
let mut w1: [u32; 64] = [0; 64];
let mut i = 0;
while i < 16 {
w[i] = get_u32_be(buffer, i * 4);
i += 1;
}
i = 16;
while i < 68 {
w[i] = p1(w[i - 16] ^ w[i - 9] ^ w[i - 3].rotate_left(15))
^ w[i - 13].rotate_left(7)
^ w[i - 6];
i += 1;
}
i = 0;
while i < 64 {
w1[i] = w[i] ^ w[i + 4];
i += 1;
}
let mut ra = self.digest[0];
let mut rb = self.digest[1];
let mut rc = self.digest[2];
let mut rd = self.digest[3];
let mut re = self.digest[4];
let mut rf = self.digest[5];
let mut rg = self.digest[6];
let mut rh = self.digest[7];
let mut ss1: u32;
let mut ss2: u32;
let mut tt1: u32;
let mut tt2: u32;
i = 0;
while i < 16 {
ss1 = ra
.rotate_left(12)
.wrapping_add(re)
.wrapping_add(0x79cc_4519u32.rotate_left(i as u32))
.rotate_left(7);
ss2 = ss1 ^ ra.rotate_left(12);
tt1 = ff0(ra, rb, rc)
.wrapping_add(rd)
.wrapping_add(ss2)
.wrapping_add(w1[i]);
tt2 = gg0(re, rf, rg)
.wrapping_add(rh)
.wrapping_add(ss1)
.wrapping_add(w[i]);
rd = rc;
rc = rb.rotate_left(9);
rb = ra;
ra = tt1;
rh = rg;
rg = rf.rotate_left(19);
rf = re;
re = p0(tt2);
i += 1;
}
i = 16;
while i < 64 {
ss1 = ra
.rotate_left(12)
.wrapping_add(re)
.wrapping_add(0x7a87_9d8au32.rotate_left(i as u32))
.rotate_left(7);
ss2 = ss1 ^ ra.rotate_left(12);
tt1 = ff1(ra, rb, rc)
.wrapping_add(rd)
.wrapping_add(ss2)
.wrapping_add(w1[i]);
tt2 = gg1(re, rf, rg)
.wrapping_add(rh)
.wrapping_add(ss1)
.wrapping_add(w[i]);
rd = rc;
rc = rb.rotate_left(9);
rb = ra;
ra = tt1;
rh = rg;
rg = rf.rotate_left(19);
rf = re;
re = p0(tt2);
i += 1;
}
self.digest[0] ^= ra;
self.digest[1] ^= rb;
self.digest[2] ^= rc;
self.digest[3] ^= rd;
self.digest[4] ^= re;
self.digest[5] ^= rf;
self.digest[6] ^= rg;
self.digest[7] ^= rh;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn lets_hash_1() {
let string = String::from("abc");
//let string = String::from("abcd");
let s = string.as_bytes();
let mut sm3 = Sm3Hash::new(s);
let hash = sm3.get_hash();
let standrad_hash: [u8; 32] = [
0x66, 0xc7, 0xf0, 0xf4, 0x62, 0xee, 0xed, 0xd9, 0xd1, 0xf2, 0xd4, 0x6b, 0xdc, 0x10,
0xe4, 0xe2, 0x41, 0x67, 0xc4, 0x87, 0x5c, 0xf2, 0xf7, 0xa2, 0x29, 0x7d, 0xa0, 0x2b,
0x8f, 0x4b, 0xa8, 0xe0,
];
for i in 0..32 {
assert_eq!(standrad_hash[i], hash[i]);
}
}
#[test]
fn lets_hash_2() {
let string =
String::from("abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd");
let s = string.as_bytes();
let mut sm3 = Sm3Hash::new(s);
let hash = sm3.get_hash();
let standrad_hash: [u8; 32] = [
0xde, 0xbe, 0x9f, 0xf9, 0x22, 0x75, 0xb8, 0xa1, 0x38, 0x60, 0x48, 0x89, 0xc1, 0x8e,
0x5a, 0x4d, 0x6f, 0xdb, 0x70, 0xe5, 0x38, 0x7e, 0x57, 0x65, 0x29, 0x3d, 0xcb, 0xa3,
0x9c, 0x0c, 0x57, 0x32,
];
for i in 0..32 {
assert_eq!(standrad_hash[i], hash[i]);
}
}
}

16
src/sm3/mod.rs Normal file
View File

@@ -0,0 +1,16 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod error;
pub mod hash;

244
src/sm4/cipher.rs Normal file
View File

@@ -0,0 +1,244 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::sm4::error::{Sm4Error, Sm4Result};
static SBOX: [u8; 256] = [
0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05,
0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99,
0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62,
0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6,
0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8,
0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35,
0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87,
0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e,
0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1,
0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3,
0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f,
0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51,
0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8,
0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0,
0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 0x84,
0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48,
];
fn split(input: u32) -> [u8; 4] {
let i4: u8 = input as u8;
let i3: u8 = (input >> 8) as u8;
let i2: u8 = (input >> 16) as u8;
let i1: u8 = (input >> 24) as u8;
[i1, i2, i3, i4]
}
fn combine(input: &[u8]) -> u32 {
let out: u32 = u32::from(input[3]);
let out = out | (u32::from(input[2]) << 8);
let out = out | (u32::from(input[1]) << 16);
out | (u32::from(input[0]) << 24)
}
fn split_block(input: &[u8]) -> Sm4Result<[u32; 4]> {
if input.len() != 16 {
return Err(Sm4Error::ErrorBlockSize);
}
let mut out: [u32; 4] = [0; 4];
for (i, v) in out.iter_mut().enumerate().take(4) {
let start = 4 * i;
let end = 4 * i + 4;
*v = combine(&input[start..end])
}
Ok(out)
}
fn combine_block(input: &[u32]) -> Sm4Result<[u8; 16]> {
let mut out: [u8; 16] = [0; 16];
for i in 0..4 {
let outi = split(input[i]);
for j in 0..4 {
out[i * 4 + j] = outi[j];
}
}
Ok(out)
}
fn tau_trans(input: u32) -> u32 {
let input = split(input);
let mut out: [u8; 4] = [0; 4];
for i in 0..4 {
out[i] = SBOX[input[i] as usize];
}
combine(&out)
}
fn l_rotate(x: u32, i: u32) -> u32 {
(x << (i % 32)) | (x >> (32 - (i % 32)))
}
fn l_trans(input: u32) -> u32 {
let b = input;
b ^ l_rotate(b, 2) ^ l_rotate(b, 10) ^ l_rotate(b, 18) ^ l_rotate(b, 24)
}
fn t_trans(input: u32) -> u32 {
l_trans(tau_trans(input))
}
fn l_prime_trans(input: u32) -> u32 {
let b = input;
b ^ l_rotate(b, 13) ^ l_rotate(b, 23)
}
fn t_prime_trans(input: u32) -> u32 {
l_prime_trans(tau_trans(input))
}
pub struct Sm4Cipher {
// round key
rk: Vec<u32>,
}
static FK: [u32; 4] = [0xa3b1_bac6, 0x56aa_3350, 0x677d_9197, 0xb270_22dc];
static CK: [u32; 32] = [
0x0007_0e15,
0x1c23_2a31,
0x383f_464d,
0x545b_6269,
0x7077_7e85,
0x8c93_9aa1,
0xa8af_b6bd,
0xc4cb_d2d9,
0xe0e7_eef5,
0xfc03_0a11,
0x181f_262d,
0x343b_4249,
0x5057_5e65,
0x6c73_7a81,
0x888f_969d,
0xa4ab_b2b9,
0xc0c7_ced5,
0xdce3_eaf1,
0xf8ff_060d,
0x141b_2229,
0x3037_3e45,
0x4c53_5a61,
0x686f_767d,
0x848b_9299,
0xa0a7_aeb5,
0xbcc3_cad1,
0xd8df_e6ed,
0xf4fb_0209,
0x1017_1e25,
0x2c33_3a41,
0x484f_565d,
0x646b_7279,
];
impl Sm4Cipher {
pub fn new(key: &[u8]) -> Result<Sm4Cipher, Sm4Error> {
let mut k: [u32; 4] = split_block(key)?;
let mut cipher = Sm4Cipher { rk: Vec::new() };
for i in 0..4 {
k[i] ^= FK[i];
}
for i in 0..8 {
k[0] ^= t_prime_trans(k[1] ^ k[2] ^ k[3] ^ CK[i * 4]);
k[1] ^= t_prime_trans(k[2] ^ k[3] ^ k[0] ^ CK[i * 4 + 1]);
k[2] ^= t_prime_trans(k[3] ^ k[0] ^ k[1] ^ CK[i * 4 + 2]);
k[3] ^= t_prime_trans(k[0] ^ k[1] ^ k[2] ^ CK[i * 4 + 3]);
cipher.rk.push(k[0]);
cipher.rk.push(k[1]);
cipher.rk.push(k[2]);
cipher.rk.push(k[3]);
}
Ok(cipher)
}
pub fn encrypt(&self, block_in: &[u8]) -> Result<[u8; 16], Sm4Error> {
let mut x: [u32; 4] = split_block(block_in)?;
let rk = &self.rk;
for i in 0..8 {
x[0] ^= t_trans(x[1] ^ x[2] ^ x[3] ^ rk[i * 4]);
x[1] ^= t_trans(x[2] ^ x[3] ^ x[0] ^ rk[i * 4 + 1]);
x[2] ^= t_trans(x[3] ^ x[0] ^ x[1] ^ rk[i * 4 + 2]);
x[3] ^= t_trans(x[0] ^ x[1] ^ x[2] ^ rk[i * 4 + 3]);
}
let y = [x[3], x[2], x[1], x[0]];
combine_block(&y)
}
pub fn decrypt(&self, block_in: &[u8]) -> Result<[u8; 16], Sm4Error> {
let mut x: [u32; 4] = split_block(block_in)?;
let rk = &self.rk;
for i in 0..8 {
x[0] ^= t_trans(x[1] ^ x[2] ^ x[3] ^ rk[31 - i * 4]);
x[1] ^= t_trans(x[2] ^ x[3] ^ x[0] ^ rk[31 - (i * 4 + 1)]);
x[2] ^= t_trans(x[3] ^ x[0] ^ x[1] ^ rk[31 - (i * 4 + 2)]);
x[3] ^= t_trans(x[0] ^ x[1] ^ x[2] ^ rk[31 - (i * 4 + 3)]);
}
let y = [x[3], x[2], x[1], x[0]];
combine_block(&y)
}
}
// Tests below
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn setup_cipher() {
let key: [u8; 16] = [
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
0x32, 0x10,
];
let cipher = Sm4Cipher::new(&key).unwrap();
let rk = &cipher.rk;
assert_eq!(rk[0], 0xf121_86f9);
assert_eq!(rk[31], 0x9124_a012);
}
#[test]
fn enc_and_dec() {
let key: [u8; 16] = [
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
0x32, 0x10,
];
let cipher = Sm4Cipher::new(&key).unwrap();
let data: [u8; 16] = [
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
0x32, 0x10,
];
let ct = cipher.encrypt(&data).unwrap();
let standard_ct: [u8; 16] = [
0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e,
0x42, 0x46,
];
// Check the example cipher text
for i in 0..16 {
assert_eq!(standard_ct[i], ct[i]);
}
// Check the result of decryption
let pt = cipher.decrypt(&ct).unwrap();
for i in 0..16 {
assert_eq!(pt[i], data[i]);
}
}
}

366
src/sm4/cipher_mode.rs Normal file
View File

@@ -0,0 +1,366 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::cipher::Sm4Cipher;
use crate::sm4::error::{Sm4Error, Sm4Result};
pub enum CipherMode {
Cfb,
Ofb,
Ctr,
Cbc,
}
pub struct Sm4CipherMode {
cipher: Sm4Cipher,
mode: CipherMode,
}
fn block_xor(a: &[u8], b: &[u8]) -> [u8; 16] {
let mut out: [u8; 16] = [0; 16];
for i in 0..16 {
out[i] = a[i] ^ b[i];
}
out
}
fn block_add_one(a: &mut [u8]) {
let mut carry = 1;
for i in 0..16 {
let (t, c) = a[15 - i].overflowing_add(carry);
a[15 - i] = t;
if !c {
return;
}
carry = c as u8;
}
}
impl Sm4CipherMode {
pub fn new(key: &[u8], mode: CipherMode) -> Sm4Result<Sm4CipherMode> {
let cipher = Sm4Cipher::new(key)?;
Ok(Sm4CipherMode { cipher, mode })
}
pub fn encrypt(&self, data: &[u8], iv: &[u8]) -> Sm4Result<Vec<u8>> {
if iv.len() != 16 {
return Err(Sm4Error::ErrorBlockSize);
}
match self.mode {
CipherMode::Cfb => self.cfb_encrypt(data, iv),
CipherMode::Ofb => self.ofb_encrypt(data, iv),
CipherMode::Ctr => self.ctr_encrypt(data, iv),
CipherMode::Cbc => self.cbc_encrypt(data, iv),
}
}
pub fn decrypt(&self, data: &[u8], iv: &[u8]) -> Sm4Result<Vec<u8>> {
if iv.len() != 16 {
return Err(Sm4Error::ErrorBlockSize);
}
match self.mode {
CipherMode::Cfb => self.cfb_decrypt(data, iv),
CipherMode::Ofb => self.ofb_encrypt(data, iv),
CipherMode::Ctr => self.ctr_encrypt(data, iv),
CipherMode::Cbc => self.cbc_decrypt(data, iv),
}
}
fn cfb_encrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
let block_num = data.len() / 16;
let tail_len = data.len() - block_num * 16;
let mut out: Vec<u8> = Vec::new();
let mut vec_buf: Vec<u8> = vec![0; 16];
vec_buf.clone_from_slice(iv);
// Normal
for i in 0..block_num {
let enc = self.cipher.encrypt(&vec_buf[..])?;
let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
for i in ct.iter() {
out.push(*i);
}
vec_buf.clone_from_slice(&ct);
}
// Last block
let enc = self.cipher.encrypt(&vec_buf[..])?;
for i in 0..tail_len {
let b = data[block_num * 16 + i] ^ enc[i];
out.push(b);
}
Ok(out)
}
fn cfb_decrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
let block_num = data.len() / 16;
let tail_len = data.len() - block_num * 16;
let mut out: Vec<u8> = Vec::new();
let mut vec_buf: Vec<u8> = vec![0; 16];
vec_buf.clone_from_slice(iv);
// Normal
for i in 0..block_num {
let enc = self.cipher.encrypt(&vec_buf[..])?;
let ct = &data[i * 16..i * 16 + 16];
let pt = block_xor(&enc, ct);
for i in pt.iter() {
out.push(*i);
}
vec_buf.clone_from_slice(ct);
}
// Last block
let enc = self.cipher.encrypt(&vec_buf[..])?;
for i in 0..tail_len {
let b = data[block_num * 16 + i] ^ enc[i];
out.push(b);
}
Ok(out)
}
fn ofb_encrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
let block_num = data.len() / 16;
let tail_len = data.len() - block_num * 16;
let mut out: Vec<u8> = Vec::new();
let mut vec_buf: Vec<u8> = vec![0; 16];
vec_buf.clone_from_slice(iv);
// Normal
for i in 0..block_num {
let enc = self.cipher.encrypt(&vec_buf[..])?;
let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
for i in ct.iter() {
out.push(*i);
}
vec_buf.clone_from_slice(&enc);
}
// Last block
let enc = self.cipher.encrypt(&vec_buf[..])?;
for i in 0..tail_len {
let b = data[block_num * 16 + i] ^ enc[i];
out.push(b);
}
Ok(out)
}
fn ctr_encrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
let block_num = data.len() / 16;
let tail_len = data.len() - block_num * 16;
let mut out: Vec<u8> = Vec::new();
let mut vec_buf: Vec<u8> = vec![0; 16];
vec_buf.clone_from_slice(iv);
// Normal
for i in 0..block_num {
let enc = self.cipher.encrypt(&vec_buf[..])?;
let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
for i in ct.iter() {
out.push(*i);
}
block_add_one(&mut vec_buf[..]);
}
// Last block
let enc = self.cipher.encrypt(&vec_buf[..])?;
for i in 0..tail_len {
let b = data[block_num * 16 + i] ^ enc[i];
out.push(b);
}
Ok(out)
}
fn cbc_encrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
let block_num = data.len() / 16;
let remind = data.len() % 16;
let mut out: Vec<u8> = Vec::new();
let mut vec_buf = [0; 16];
vec_buf.copy_from_slice(iv);
// Normal
for i in 0..block_num {
let ct = block_xor(&vec_buf, &data[i * 16..i * 16 + 16]);
let enc = self.cipher.encrypt(&ct)?;
out.extend_from_slice(&enc);
vec_buf = enc;
}
if remind != 0 {
let mut last_block = [16 - remind as u8; 16];
last_block[..remind].copy_from_slice(&data[block_num * 16..]);
let ct = block_xor(&vec_buf, &last_block);
let enc = self.cipher.encrypt(&ct)?;
out.extend_from_slice(&enc);
} else {
let ff_padding = block_xor(&vec_buf, &[0x10; 16]);
let enc = self.cipher.encrypt(&ff_padding)?;
out.extend_from_slice(&enc);
}
Ok(out)
}
fn cbc_decrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
let data_len = data.len();
let block_num = data_len / 16;
if data_len % 16 != 0 {
return Err(Sm4Error::ErrorDataLen);
}
let mut out: Vec<u8> = Vec::new();
let mut vec_buf = [0; 16];
vec_buf.copy_from_slice(iv);
// Normal
for i in 0..block_num {
let enc = self.cipher.decrypt(&data[i * 16..i * 16 + 16])?;
let ct = block_xor(&vec_buf, &enc);
for j in ct.iter() {
out.push(*j);
}
vec_buf.copy_from_slice(&data[i * 16..i * 16 + 16]);
}
let last_u8 = out[data_len - 1];
if last_u8 > 0x10 || last_u8 == 0 {
return Err(Sm4Error::InvalidLastU8);
}
out.resize(data_len - last_u8 as usize, 0);
Ok(out)
}
}
// TODO: AEAD in SM4
// pub struct SM4Gcm;
// Tests below
#[cfg(test)]
mod tests {
use super::*;
use rand::RngCore;
fn rand_block() -> [u8; 16] {
let mut rng = rand::thread_rng();
let mut block: [u8; 16] = [0; 16];
rng.fill_bytes(&mut block[..]);
block
}
fn rand_data(len: usize) -> Vec<u8> {
let mut rng = rand::thread_rng();
let mut dat: Vec<u8> = Vec::new();
dat.resize(len, 0);
rng.fill_bytes(&mut dat[..]);
dat
}
#[test]
fn test_driver() {
test_ciphermode(CipherMode::Ctr);
test_ciphermode(CipherMode::Cfb);
test_ciphermode(CipherMode::Ofb);
test_ciphermode(CipherMode::Cbc);
}
fn test_ciphermode(mode: CipherMode) {
let key = rand_block();
let iv = rand_block();
let cmode = Sm4CipherMode::new(&key, mode).unwrap();
let pt = rand_data(10);
let ct = cmode.encrypt(&pt[..], &iv).unwrap();
let new_pt = cmode.decrypt(&ct[..], &iv).unwrap();
assert_eq!(pt, new_pt);
let pt = rand_data(100);
let ct = cmode.encrypt(&pt[..], &iv).unwrap();
let new_pt = cmode.decrypt(&ct[..], &iv).unwrap();
assert_eq!(pt, new_pt);
let pt = rand_data(1000);
let ct = cmode.encrypt(&pt[..], &iv).unwrap();
let new_pt = cmode.decrypt(&ct[..], &iv).unwrap();
assert_eq!(pt, new_pt);
}
#[test]
fn ctr_enc_test() {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ctr).unwrap();
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv).unwrap();
let lhs: &[u8] = lhs.as_ref();
let rhs: &[u8] = include_bytes!("example/text.sms4-ctr");
assert_eq!(lhs, rhs);
}
#[test]
fn cfb_enc_test() {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Cfb).unwrap();
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref().unwrap();
let rhs: &[u8] = include_bytes!("example/text.sms4-cfb");
assert_eq!(lhs, rhs);
}
#[test]
fn ofb_enc_test() {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ofb).unwrap();
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref().unwrap();
let rhs: &[u8] = include_bytes!("example/text.sms4-ofb");
assert_eq!(lhs, rhs);
}
#[test]
fn cbc_enc_test() {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Cbc).unwrap();
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref().unwrap();
let rhs: &[u8] = include_bytes!("example/text.sms4-cbc");
assert_eq!(lhs, rhs);
}
}

71
src/sm4/error.rs Normal file
View File

@@ -0,0 +1,71 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::error;
use std::fmt::Display;
use std::fmt::Formatter;
pub type Sm4Result<T> = Result<T, Sm4Error>;
pub enum Sm4Error {
ErrorBlockSize,
ErrorDataLen,
InvalidLastU8,
}
impl ::std::fmt::Debug for Sm4Error {
fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
write!(f, "{self}")
}
}
impl From<Sm4Error> for &str {
fn from(e: Sm4Error) -> Self {
match e {
Sm4Error::ErrorBlockSize => "the block size of SM4 must be 16",
Sm4Error::ErrorDataLen => "the data len of SM4 must be 16",
Sm4Error::InvalidLastU8 => {
"the last u8 of cbc_decrypt out in SM4 must be positive which isn't greater than 16"
}
}
}
}
impl error::Error for Sm4Error {}
impl Display for Sm4Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let err_msg = match self {
Sm4Error::ErrorBlockSize => "the block size of SM4 must be 16",
Sm4Error::ErrorDataLen => "the data len of SM4 must be 16",
Sm4Error::InvalidLastU8 => {
"the last u8 of cbc_decrypt out in SM4 must be positive which isn't greater than 16"
}
};
write!(f, "{err_msg}")
}
}
#[cfg(test)]
mod tests {
use super::Sm4Error;
#[test]
fn test_error_display() {
assert_eq!(
format!("{}", Sm4Error::ErrorBlockSize),
"the block size of SM4 must be 16"
)
}
}

View File

@@ -0,0 +1 @@
vЬ3<цgYУ├╡>у≤Ч! ╡@┤и├÷S/э0║[╛∙.╢╨╙╤(,─╔

View File

@@ -0,0 +1 @@
|2Ø@

View File

@@ -0,0 +1,2 @@
|2@
ノアナ否ゥ=Oンnノマィ 5ユ4標禺|蒸O

View File

@@ -0,0 +1 @@
|2ь@

20
src/sm4/mod.rs Normal file
View File

@@ -0,0 +1,20 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod cipher;
pub mod cipher_mode;
pub mod error;
pub type Mode = self::cipher_mode::CipherMode;
pub type Cipher = self::cipher_mode::Sm4CipherMode;