From 645e1572188fd4a21ab9807302d6b72b1605e618 Mon Sep 17 00:00:00 2001 From: Hatter Jiang Date: Sat, 26 Aug 2023 22:30:44 +0800 Subject: [PATCH] feat: aes gcm stream works --- src/lib.rs | 212 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 25 +++++-- 2 files changed, 232 insertions(+), 5 deletions(-) create mode 100644 src/lib.rs diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..3e6de01 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,212 @@ +use aes::Aes128; +use aes::cipher::{Block, BlockEncrypt, KeyInit}; +use aes::cipher::generic_array::GenericArray; + +pub struct Aes128GcmStream { + crypto: Aes128, + message_buffer: Vec, + integrality_buffer: Vec, + ghash_key: u128, + ghash_val: u128, + init_nonce: u128, + encryption_nonce: u128, + adata_len: usize, + message_len: usize, +} + +impl Aes128GcmStream { + pub fn new(key: [u8; 16]) -> Self { + let key = GenericArray::from(key); + let aes = Aes128::new(&key); + + Self { + crypto: aes, + message_buffer: vec![], + integrality_buffer: vec![], + ghash_key: 0, + ghash_val: 0, + init_nonce: 0, + encryption_nonce: 0, + adata_len: 0, + message_len: 0, + } + } + + pub fn init_nonce(&mut self, nonce: &[u8]) { + let (ghash_key, normalized_nonce) = self.normalize_nonce(nonce); + self.ghash_key = ghash_key; + self.init_nonce = normalized_nonce; + self.encryption_nonce = normalized_nonce; + } + + pub fn init_adata(&mut self, adata: &[u8]) { + self.integrality_buffer.extend_from_slice(adata); + self.adata_len += adata.len(); + + let adata_bit_len = self.adata_len * 8; + let v = 128 * ((adata_bit_len + 128 - 1) / 128) - adata_bit_len; + self.integrality_buffer.extend_from_slice(&vec![0x00; v / 8]); + } + + pub fn next(&mut self, bytes: &[u8]) -> Vec { + self.message_buffer.extend_from_slice(bytes); + let message_buffer_slice = self.message_buffer.as_slice(); + let message_buffer_len = message_buffer_slice.len(); + if message_buffer_len < 16 { + return vec![]; + } + let blocks_count = message_buffer_len / 16; + let mut encrypted_message = vec![]; + for i in 0..blocks_count { + self.encryption_nonce = inc_32(self.encryption_nonce); + let mut ctr = self.encryption_nonce.to_be_bytes(); + let block = Block::::from_mut_slice(&mut ctr); + self.crypto.encrypt_block(block); + let chunk = &message_buffer_slice[i * 16..(i + 1) * 16]; + let y = u8to128(chunk) ^ u8to128(&block.as_slice()); + encrypted_message.extend_from_slice(&y.to_be_bytes()); + } + self.message_buffer = message_buffer_slice[blocks_count * 16..].to_vec(); + self.integrality_buffer.extend_from_slice(&encrypted_message); + self.message_len += encrypted_message.len(); + + self.update_integrality_buffer(); + + encrypted_message + } + + pub fn finalize(&mut self) -> (Vec, Vec) { + let mut encrypted_message = vec![]; + if !self.message_buffer.is_empty() { + // last block and this block len is less than 128 bits + self.encryption_nonce = inc_32(self.encryption_nonce); + let mut ctr = self.encryption_nonce.to_be_bytes(); + let block = Block::::from_mut_slice(&mut ctr); + self.crypto.encrypt_block(block); + + let chunk = self.message_buffer.as_slice(); + let msb = msb_s(chunk.len() * 8, block.as_slice()); + let y = u8to128(chunk) ^ u8to128(&msb); + encrypted_message.extend_from_slice(&y.to_be_bytes()[16 - chunk.len()..16]); + self.integrality_buffer.extend_from_slice(&encrypted_message); + self.message_len += encrypted_message.len(); + } + let adata_bit_len = self.adata_len * 8; + let message_bit_len = self.message_len * 8; + let u = 128 * ((message_bit_len + 128 - 1) / 128) - message_bit_len; + self.integrality_buffer.extend_from_slice(&vec![0x00; u / 8]); + self.integrality_buffer.extend_from_slice(&(adata_bit_len as u64).to_be_bytes()); + self.integrality_buffer.extend_from_slice(&(message_bit_len as u64).to_be_bytes()); + + self.update_integrality_buffer(); + assert!(self.integrality_buffer.is_empty()); + + let tag = self.calculate_tag(); + + (encrypted_message, tag) + } + + fn calculate_tag(&mut self) -> Vec { + let mut bs = self.init_nonce.to_be_bytes().clone(); + let block = Block::::from_mut_slice(&mut bs); + self.crypto.encrypt_block(block); + let tag_trunk = self.ghash_val.to_be_bytes(); + let y = u8to128(&tag_trunk) ^ u8to128(&block.as_slice()); + y.to_be_bytes().to_vec() + } + + fn update_integrality_buffer(&mut self) { + let integrality_buffer_slice = self.integrality_buffer.as_slice(); + let integrality_buffer_slice_len = integrality_buffer_slice.len(); + if integrality_buffer_slice_len >= 16 { + let i_blocks_count = integrality_buffer_slice_len / 16; + for i in 0..i_blocks_count { + let buf = &integrality_buffer_slice[i * 16..(i + 1) * 16]; + self.ghash_val = gmul_128(self.ghash_val ^ u8to128(buf), self.ghash_key) + } + self.integrality_buffer = integrality_buffer_slice[i_blocks_count * 16..].to_vec(); + } + } + + fn ghash_key(&mut self) -> u128 { + let mut block = [0u8; 16]; + let block = Block::::from_mut_slice(&mut block); + self.crypto.encrypt_block(block); + u8to128(&block.as_slice()) + } + + fn normalize_nonce(&mut self, nonce_bytes: &[u8]) -> (u128, u128) { + let ghash_key = self.ghash_key(); + let nonce = u8to128(nonce_bytes); + let normalized_nonce = match nonce_bytes.len() == 12 { + true => { + nonce << 32 | 0x00000001 + } + false => { + let mut iv_padding = vec![]; + // s = 128[len(iv) / 128] - len(iv) + let s = 128 * (((nonce_bytes.len() * 8) + 128 - 1) / 128) - (nonce_bytes.len() * 8); + iv_padding.push(nonce << s); + iv_padding.push((nonce_bytes.len() * 8) as u128); + ghash(ghash_key, &iv_padding) + } + }; + (ghash_key, normalized_nonce) + } +} + + +// R = 11100001 || 0(120) +const R: u128 = 0b11100001 << 120; + +fn gmul_128(x: u128, y: u128) -> u128 { + let mut z = 0u128; + let mut v = y; + for i in (0..128).rev() { + let xi = (x >> i) & 1; + if xi != 0 { + z ^= v; + } + v = match v & 1 == 0 { + true => { v >> 1 } + false => { (v >> 1) ^ R } + }; + } + z +} + +fn ghash(key: u128, messages: &[u128]) -> u128 { + let mut y = 0u128; + for i in 0..messages.len() { + let yi = gmul_128(y ^ messages[i], key); + y = yi; + } + y +} + +fn u8to128(bytes: &[u8]) -> u128 { + bytes.iter().rev().enumerate().fold(0, |acc, (i, &byte)| { + acc | (byte as u128) << (i * 8) + }) +} + +fn msb_s(s: usize, bytes: &[u8]) -> Vec { + let mut result = vec![]; + let n = s / 8; + let remain = s % 8; + for i in 0..n { + result.push(bytes[i]); + } + if remain > 0 { + result.push(bytes[n] >> (8 - remain)); + } + result +} + +// incs(X)=MSBlen(X)-s(X) || [int(LSBs(X))+1 mod 2^s]s +fn inc_32(bits: u128) -> u128 { + let msb = bits >> 32; + let mut lsb = (bits & 0xffffffff) as u32; + lsb = lsb.wrapping_add(1); + msb << 32 | lsb as u128 +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 90d22ee..5212229 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,8 @@ use aes::Aes128; use aes::cipher::{Block, BlockEncrypt, KeyInit}; use aes::cipher::generic_array::GenericArray; use aes_gcm::{AeadInPlace, Aes128Gcm, Key}; -use aes_gcm::aead::{Aead, AeadMutInPlace, Nonce}; +use aes_gcm::aead::{Aead, Nonce}; +use aes_gcm_stream::Aes128GcmStream; pub struct GCM { aes: Aes128, @@ -132,6 +133,7 @@ impl GCM { let adata_len = adata.len() * 8; let u = 128 * ((message_len + 128 - 1) / 128) - message_len; let v = 128 * ((adata_len + 128 - 1) / 128) - adata_len; + // println!("message_len, adata_len: {}, {}", message.len(), adata.len()); // println!("u, v: {}, {}", u, v); // println!("j0 = {:02x?}", j0); let enc = self.gctr(inc_32(j0), &message); @@ -145,7 +147,7 @@ impl GCM { bit_string.extend_from_slice(&(message_len as u64).to_be_bytes()); // println!("len = {}, bit_string[u8] = {:02x?}", bit_string.len(), bit_string); let bit_string: Vec = bit_string.chunks(16).map(|it| u8to128(it)).collect(); - // println!("bit_string[u128] = {:02x?}", bit_string); + // println!("bit_string[u128] = {:02x?}", bit_string); let s = ghash(ghash_key, &bit_string).to_be_bytes(); //println!("{:02x?}", s); let tag = self.gctr(j0, &s); @@ -159,12 +161,12 @@ impl GCM { fn main() { let key = [0u8; 16]; let nonce = [0u8; 12]; - let plaintext = [0u8; 64]; + let plaintext = [0u8; 69]; let mut gcm = GCM::new(key); let (tag, enc) = gcm.ae(&nonce, &[], &plaintext); println!("{}", hex::encode(&enc)); - println!("{}", hex::encode(&tag)); + println!("{} : TAG", hex::encode(&tag)); // --------------------------------------------------------------------------------------- @@ -184,9 +186,22 @@ fn main() { let mut ciphertext = vec![0u8; plaintext.len()]; let tag = cipher.encrypt_in_place_detached(&nonce, &[], ciphertext.as_mut_slice()).unwrap(); println!("{}", hex::encode(&ciphertext)); - println!("{}", hex::encode(tag.as_slice())); + println!("{} : TAG", hex::encode(tag.as_slice())); let mut ciphertext = plaintext.to_vec(); cipher.encrypt_in_place(&nonce, &[], &mut ciphertext).unwrap(); println!("{}", hex::encode(ciphertext.as_slice())); + + let mut aes128_gcm_stream = Aes128GcmStream::new([0; 16]); + aes128_gcm_stream.init_nonce(&[0u8; 12]); + aes128_gcm_stream.init_adata(&[]); + let o1 = aes128_gcm_stream.next(&plaintext[0..21]); + let o2 = aes128_gcm_stream.next(&plaintext[21..64]); + let o3 = aes128_gcm_stream.next(&[0; 5]); + let (o4, t) = aes128_gcm_stream.finalize(); + println!("{}: E1", hex::encode(&o1)); + println!("{}: E2", hex::encode(&o2)); + println!("{}: E3", hex::encode(&o3)); + println!("{}: E4", hex::encode(&o4)); + println!("{} : TAG", hex::encode(&t)); } \ No newline at end of file