From 9f544e3cb72924aa8673cf1fa02e18e71b1dc9b8 Mon Sep 17 00:00:00 2001 From: Hatter Jiang Date: Sat, 27 Sep 2025 11:32:42 +0800 Subject: [PATCH] feat: ml-kem generate --- src/cmd_keypair_generate.rs | 20 +++- src/mlkemutil.rs | 181 +++++++++++++++++++++++++++++------- src/util.rs | 25 ++++- 3 files changed, 184 insertions(+), 42 deletions(-) diff --git a/src/cmd_keypair_generate.rs b/src/cmd_keypair_generate.rs index d12d93d..591824e 100644 --- a/src/cmd_keypair_generate.rs +++ b/src/cmd_keypair_generate.rs @@ -3,7 +3,7 @@ use crate::keychain::{KeychainKey, KeychainKeyValue}; use crate::keyutil::{KeyAlgorithmId, KeyUri, YubikeyHmacEncSoftKey}; use crate::pivutil::FromStr; use crate::util::base64_encode; -use crate::{cmd_hmac_encrypt, cmdutil, ecdsautil, hmacutil, pbeutil, rsautil, util, yubikeyutil}; +use crate::{cmd_hmac_encrypt, cmdutil, ecdsautil, hmacutil, mlkemutil, pbeutil, rsautil, util, yubikeyutil}; use clap::{App, Arg, ArgMatches, SubCommand}; use rust_util::util_clap::{Command, CommandError}; use std::collections::BTreeMap; @@ -23,7 +23,7 @@ impl Command for CommandImpl { .long("type") .required(true) .takes_value(true) - .help("Key type (e.g. p256, p384, p521, rsa1024, rsa2048, rsa3072, rsa4096)"), + .help("Key type (e.g. p256, p384, p521, rsa1024, rsa2048, rsa3072, rsa4096, mlkem512, mlkem768, mlkem1024)"), ) .arg(cmdutil::build_with_hmac_encrypt_arg()) .arg(cmdutil::build_with_pbe_encrypt_arg()) @@ -59,12 +59,20 @@ impl Command for CommandImpl { "rsa4096" => Some(4096), _ => None, }; + let mlkem_len: Option = match key_type.as_str() { + "mlkem512" => Some(512), + "mlkem768" => Some(768), + "mlkem1024" => Some(1024), + _ => None, + }; let (pkcs8_base64, secret_key_pem, public_key_pem, public_key_der, jwk_key) = if let Some(ecdsa_algorithm) = ecdsa_algorithm { ecdsautil::generate_ecdsa_keypair(ecdsa_algorithm)? } else if let Some(rsa_bit_size) = rsa_bit_size { rsautil::generate_rsa_keypair(rsa_bit_size)? + } else if let Some(mlkem_len) = mlkem_len { + mlkemutil::generate_mlkem_keypair(mlkem_len)? } else { return simple_error!("Unsupported key type: {}", key_type); }; @@ -130,7 +138,9 @@ impl Command for CommandImpl { } json.insert("public_key_pem", public_key_pem); json.insert("public_key_base64", public_key_base64); - json.insert("public_key_jwk", jwk_key); + if !jwk_key.is_empty() { + json.insert("public_key_jwk", jwk_key); + } util::print_pretty_json(&json); } else { @@ -149,7 +159,9 @@ impl Command for CommandImpl { } information!("Public key PEM:\n{}", public_key_pem); information!("Public key Base64:\n{}\n", public_key_base64); - information!("Public key JWK:\n{}", jwk_key); + if !jwk_key.is_empty() { + information!("Public key JWK:\n{}", jwk_key); + } } Ok(None) diff --git a/src/mlkemutil.rs b/src/mlkemutil.rs index 84110ec..112c8e9 100644 --- a/src/mlkemutil.rs +++ b/src/mlkemutil.rs @@ -1,40 +1,151 @@ -use crate::util::base64_encode; +use crate::rsautil::rsa_public_key_to_jwk; +use crate::util::{base64_encode, to_pem}; +use ecdsa::elliptic_curve::pkcs8::LineEnding; use ml_kem::kem::{Decapsulate, Encapsulate}; -use ml_kem::{EncodedSizeUser, KemCore, MlKem768}; +use ml_kem::{EncodedSizeUser, KemCore, MlKem1024, MlKem512, MlKem768}; use rust_util::XResult; use std::convert::TryInto; -// #[test] -pub fn generate_ml_kem_768() -> XResult<()> { - let mut rng = rand::thread_rng(); - let (dk, ek) = ::generate(&mut rng); - println!("dk: {}", base64_encode(&dk.as_bytes().0.to_vec())); - println!("ek: {}", base64_encode(ek.as_bytes().0.to_vec())); - - let ek_bytes = dk.as_bytes().0.to_vec(); - let dk = ::DecapsulationKey::from_bytes(&opt_result!( - ek_bytes.as_slice().try_into(), - "Parse decapsulation key failed: {}" - )); - - let (encoded_ciphertext, shared_key) = opt_result!( - ek.encapsulate(&mut rng), - "Encapsulation key encapsulate failed: {:?}" - ); - println!( - "encoded_ciphertext: {}", - base64_encode(&encoded_ciphertext.0.to_vec()) - ); - println!("shared_key: {}", base64_encode(&shared_key.0.to_vec())); - - let k_bytes = encoded_ciphertext.0.to_vec(); - let shared_key_2 = opt_result!( - dk.decapsulate(opt_result!( - &k_bytes.as_slice().try_into(), - "Parse encoded ciphertext failed: {}" - )), - "Decapsulation key decapsulate failed: {:?}" - ); - println!("shared_key2: {}", base64_encode(&shared_key_2.0.to_vec())); - Ok(()) +pub fn generate_mlkem_keypair(len: usize) -> XResult<(String, String, String, Vec, String)> { + let (dk_private, ek_public) = match len { + 512 => generate_ml_kem_512(), + 768 => generate_ml_kem_768(), + 1024 => generate_ml_kem_1024(), + _ => return simple_error!("Invalid ML-KEM={}", len), + }; + let secret_key_der_base64 = base64_encode(&dk_private); + let secret_key_pem = to_pem(&format!("ML-KEM-{} PRIVATE KEY", len), &dk_private); + let public_key_pem = to_pem(&format!("ML-KEM-{} PUBLIC KEY", len), &ek_public); + let public_key_der = ek_public; + let jwk_ec_key = "".to_string(); + Ok(( + secret_key_der_base64, + secret_key_pem, + public_key_pem, + public_key_der, + jwk_ec_key, + )) +} + +pub fn generate_ml_kem_512() -> (Vec, Vec) { + let mut rng = rand::thread_rng(); + let (dk_private, ek_public) = ::generate(&mut rng); + ( + dk_private.as_bytes().0.to_vec(), + ek_public.as_bytes().0.to_vec(), + ) +} + +pub fn generate_ml_kem_768() -> (Vec, Vec) { + let mut rng = rand::thread_rng(); + let (dk_private, ek_public) = ::generate(&mut rng); + ( + dk_private.as_bytes().0.to_vec(), + ek_public.as_bytes().0.to_vec(), + ) +} + +pub fn generate_ml_kem_1024() -> (Vec, Vec) { + let mut rng = rand::thread_rng(); + let (dk_private, ek_public) = ::generate(&mut rng); + ( + dk_private.as_bytes().0.to_vec(), + ek_public.as_bytes().0.to_vec(), + ) +} + +pub fn parse_encapsulation_key_512_public_then_encapsulate( + bytes: &[u8], +) -> XResult<(Vec, Vec)> { + let ek = ::EncapsulationKey::from_bytes(&opt_result!( + bytes.try_into(), + "Parse encapsulation key 512 failed: {}" + )); + let (ciphertext, shared_key) = opt_result!( + ek.encapsulate(&mut rand::thread_rng()), + "Encapsulation key 512 encapsulate failed: {:?}" + ); + Ok((ciphertext.0.to_vec(), shared_key.0.to_vec())) +} + +pub fn parse_encapsulation_key_768_public_then_encapsulate( + bytes: &[u8], +) -> XResult<(Vec, Vec)> { + let ek = ::EncapsulationKey::from_bytes(&opt_result!( + bytes.try_into(), + "Parse encapsulation key 768 failed: {}" + )); + let (ciphertext, shared_key) = opt_result!( + ek.encapsulate(&mut rand::thread_rng()), + "Encapsulation key 768 encapsulate failed: {:?}" + ); + Ok((ciphertext.0.to_vec(), shared_key.0.to_vec())) +} + +pub fn parse_encapsulation_key_1024_public_then_encapsulate( + bytes: &[u8], +) -> XResult<(Vec, Vec)> { + let ek = ::EncapsulationKey::from_bytes(&opt_result!( + bytes.try_into(), + "Parse encapsulation key 1024 failed: {}" + )); + let (ciphertext, shared_key) = opt_result!( + ek.encapsulate(&mut rand::thread_rng()), + "Encapsulation key 1024 encapsulate failed: {:?}" + ); + Ok((ciphertext.0.to_vec(), shared_key.0.to_vec())) +} + +pub fn parse_decapsulate_key_512_private_then_decapsulate( + key_bytes: &[u8], + ciphertext_bytes: &[u8], +) -> XResult> { + let dk = ::DecapsulationKey::from_bytes(&opt_result!( + key_bytes.try_into(), + "Parse decapsulation key 512 failed: {}" + )); + let shared_key = opt_result!( + dk.decapsulate(opt_result!( + ciphertext_bytes.try_into(), + "Parse encoded ciphertext 512 failed: {}" + )), + "Decapsulation key 512 decapsulate failed: {:?}" + ); + Ok(shared_key.0.to_vec()) +} + +pub fn parse_decapsulate_key_768_private_then_decapsulate( + key_bytes: &[u8], + ciphertext_bytes: &[u8], +) -> XResult> { + let dk = ::DecapsulationKey::from_bytes(&opt_result!( + key_bytes.try_into(), + "Parse decapsulation key 768 failed: {}" + )); + let shared_key = opt_result!( + dk.decapsulate(opt_result!( + ciphertext_bytes.try_into(), + "Parse encoded ciphertext 768 failed: {}" + )), + "Decapsulation key 768 decapsulate failed: {:?}" + ); + Ok(shared_key.0.to_vec()) +} + +pub fn parse_decapsulate_key_1024_private_then_decapsulate( + key_bytes: &[u8], + ciphertext_bytes: &[u8], +) -> XResult> { + let dk = ::DecapsulationKey::from_bytes(&opt_result!( + key_bytes.try_into(), + "Parse decapsulation key 1024 failed: {}" + )); + let shared_key = opt_result!( + dk.decapsulate(opt_result!( + ciphertext_bytes.try_into(), + "Parse encoded ciphertext 1024 failed: {}" + )), + "Decapsulation key 1024 decapsulate failed: {:?}" + ); + Ok(shared_key.0.to_vec()) } diff --git a/src/util.rs b/src/util.rs index 9567fa3..aaa05b0 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,8 +1,8 @@ use std::fs; use std::io::Read; -use base64::{DecodeError, Engine}; use base64::engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}; +use base64::{DecodeError, Engine}; use rust_util::XResult; use serde::Serialize; @@ -22,6 +22,21 @@ pub fn base64_uri_decode>(input: T) -> Result, DecodeErro URL_SAFE_NO_PAD.decode(input) } +pub fn to_pem(header: &str, bytes: &[u8]) -> String { + let mut buf = String::new(); + buf.push_str(&format!("-----BEGIN {}-----\n", header)); + let bas64ed = base64_encode(bytes); + let len = bas64ed.len(); + for (i, c) in bas64ed.chars().enumerate() { + buf.push(c); + if i > 0 && i < len && i % 64 == 0 { + buf.push('\n'); + } + } + buf.push_str(&format!("\n-----END {}-----\n", header)); + buf +} + pub fn try_decode(input: &str) -> XResult> { match hex::decode(input) { Ok(v) => Ok(v), @@ -31,7 +46,7 @@ pub fn try_decode(input: &str) -> XResult> { Ok(v) => Ok(v), Err(e) => simple_error!("decode hex or base64 error: {}", e), }, - } + }, } } @@ -46,7 +61,11 @@ pub fn read_file_or_stdin(file: &str) -> XResult> { if file == "-" { read_stdin() } else { - Ok(opt_result!(fs::read(file), "Read file: {} failed: {}", file)) + Ok(opt_result!( + fs::read(file), + "Read file: {} failed: {}", + file + )) } }