diff --git a/src/decryptor.rs b/src/decryptor.rs index d5ef603..0271221 100644 --- a/src/decryptor.rs +++ b/src/decryptor.rs @@ -17,11 +17,11 @@ pub struct Aes128GcmStreamDecryptor { } impl Aes128GcmStreamDecryptor { - pub fn new(key: [u8; 16]) -> Self { + pub fn new(key: [u8; 16], nonce: &[u8]) -> Self { let key = GenericArray::from(key); let aes = Aes128::new(&key); - Self { + let mut s = Self { crypto: aes, message_buffer: vec![], integrality_buffer: vec![], @@ -31,14 +31,12 @@ impl Aes128GcmStreamDecryptor { encryption_nonce: 0, adata_len: 0, message_len: 0, - } - } - - pub fn init_nonce(&mut self, nonce: &[u8]) { - let (ghash_key, normalized_nonce) = self.normalize_nonce(nonce); - self.ghash_key = ghash_key; - self.init_nonce = normalized_nonce; - self.encryption_nonce = normalized_nonce; + }; + let (ghash_key, normalized_nonce) = s.normalize_nonce(nonce); + s.ghash_key = ghash_key; + s.init_nonce = normalized_nonce; + s.encryption_nonce = normalized_nonce; + s } pub fn init_adata(&mut self, adata: &[u8]) { @@ -78,7 +76,7 @@ impl Aes128GcmStreamDecryptor { } pub fn finalize(&mut self) -> Result, String> { - let mut plaintext_message = vec![]; + let mut plaintext_message = Vec::with_capacity(16); let message_buffer_len = self.message_buffer.len(); if message_buffer_len > 16 { // last block and this block len is less than 128 bits diff --git a/src/encryptor.rs b/src/encryptor.rs index 5d1a1cd..355ed43 100644 --- a/src/encryptor.rs +++ b/src/encryptor.rs @@ -17,11 +17,11 @@ pub struct Aes128GcmStreamEncryptor { } impl Aes128GcmStreamEncryptor { - pub fn new(key: [u8; 16]) -> Self { + pub fn new(key: [u8; 16], nonce: &[u8]) -> Self { let key = GenericArray::from(key); let aes = Aes128::new(&key); - Self { + let mut s = Self { crypto: aes, message_buffer: vec![], integrality_buffer: vec![], @@ -31,14 +31,12 @@ impl Aes128GcmStreamEncryptor { encryption_nonce: 0, adata_len: 0, message_len: 0, - } - } - - pub fn init_nonce(&mut self, nonce: &[u8]) { - let (ghash_key, normalized_nonce) = self.normalize_nonce(nonce); - self.ghash_key = ghash_key; - self.init_nonce = normalized_nonce; - self.encryption_nonce = normalized_nonce; + }; + let (ghash_key, normalized_nonce) = s.normalize_nonce(nonce); + s.ghash_key = ghash_key; + s.init_nonce = normalized_nonce; + s.encryption_nonce = normalized_nonce; + s } pub fn init_adata(&mut self, adata: &[u8]) { @@ -78,7 +76,7 @@ impl Aes128GcmStreamEncryptor { } pub fn finalize(&mut self) -> (Vec, Vec) { - let mut encrypted_message = vec![]; + let mut encrypted_message = Vec::with_capacity(16); if !self.message_buffer.is_empty() { // last block and this block len is less than 128 bits self.encryption_nonce = inc_32(self.encryption_nonce); diff --git a/src/main.rs b/src/main.rs index 051fa14..81413f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,8 +42,7 @@ fn main() { cipher.encrypt_in_place(&nonce, &[], &mut ciphertext).unwrap(); println!("{}", hex::encode(ciphertext.as_slice())); - let mut aes128_gcm_stream_encryptor = Aes128GcmStreamEncryptor::new([0; 16]); - aes128_gcm_stream_encryptor.init_nonce(&[0u8; 12]); + let mut aes128_gcm_stream_encryptor = Aes128GcmStreamEncryptor::new([0; 16], &[0u8; 12]); aes128_gcm_stream_encryptor.init_adata(&[]); let o1 = aes128_gcm_stream_encryptor.next(&plaintext[0..21]); let o2 = aes128_gcm_stream_encryptor.next(&plaintext[21..64]); @@ -55,8 +54,7 @@ fn main() { println!("{}: E4", hex::encode(&o4)); println!("{} : TAG", hex::encode(&t)); - let mut aes128_gcm_stream_decryptor = Aes128GcmStreamDecryptor::new([0; 16]); - aes128_gcm_stream_decryptor.init_nonce(&[0u8; 12]); + let mut aes128_gcm_stream_decryptor = Aes128GcmStreamDecryptor::new([0; 16], &[0u8; 12]); let o1 = aes128_gcm_stream_decryptor.next(&hex::decode("0388dace60b6a392f328c2b971b2fe78f795aaab494b5923f7fd89ff948bc1e0200211214e7394da2089b6acd093abe0c94da219118e297d7b7ebcbcc9c388f28ade7d85a8c992f32a52151e1c2adceb7c6138e042").unwrap()); let o2_result = aes128_gcm_stream_decryptor.finalize(); println!("{}", hex::encode(&o1));