feat: ml-kem generate

This commit is contained in:
2025-09-27 11:32:42 +08:00
parent 3d29fe6a6d
commit 9f544e3cb7
3 changed files with 184 additions and 42 deletions

View File

@@ -3,7 +3,7 @@ use crate::keychain::{KeychainKey, KeychainKeyValue};
use crate::keyutil::{KeyAlgorithmId, KeyUri, YubikeyHmacEncSoftKey}; use crate::keyutil::{KeyAlgorithmId, KeyUri, YubikeyHmacEncSoftKey};
use crate::pivutil::FromStr; use crate::pivutil::FromStr;
use crate::util::base64_encode; use crate::util::base64_encode;
use crate::{cmd_hmac_encrypt, cmdutil, ecdsautil, hmacutil, pbeutil, rsautil, util, yubikeyutil}; use crate::{cmd_hmac_encrypt, cmdutil, ecdsautil, hmacutil, mlkemutil, pbeutil, rsautil, util, yubikeyutil};
use clap::{App, Arg, ArgMatches, SubCommand}; use clap::{App, Arg, ArgMatches, SubCommand};
use rust_util::util_clap::{Command, CommandError}; use rust_util::util_clap::{Command, CommandError};
use std::collections::BTreeMap; use std::collections::BTreeMap;
@@ -23,7 +23,7 @@ impl Command for CommandImpl {
.long("type") .long("type")
.required(true) .required(true)
.takes_value(true) .takes_value(true)
.help("Key type (e.g. p256, p384, p521, rsa1024, rsa2048, rsa3072, rsa4096)"), .help("Key type (e.g. p256, p384, p521, rsa1024, rsa2048, rsa3072, rsa4096, mlkem512, mlkem768, mlkem1024)"),
) )
.arg(cmdutil::build_with_hmac_encrypt_arg()) .arg(cmdutil::build_with_hmac_encrypt_arg())
.arg(cmdutil::build_with_pbe_encrypt_arg()) .arg(cmdutil::build_with_pbe_encrypt_arg())
@@ -59,12 +59,20 @@ impl Command for CommandImpl {
"rsa4096" => Some(4096), "rsa4096" => Some(4096),
_ => None, _ => None,
}; };
let mlkem_len: Option<usize> = match key_type.as_str() {
"mlkem512" => Some(512),
"mlkem768" => Some(768),
"mlkem1024" => Some(1024),
_ => None,
};
let (pkcs8_base64, secret_key_pem, public_key_pem, public_key_der, jwk_key) = let (pkcs8_base64, secret_key_pem, public_key_pem, public_key_der, jwk_key) =
if let Some(ecdsa_algorithm) = ecdsa_algorithm { if let Some(ecdsa_algorithm) = ecdsa_algorithm {
ecdsautil::generate_ecdsa_keypair(ecdsa_algorithm)? ecdsautil::generate_ecdsa_keypair(ecdsa_algorithm)?
} else if let Some(rsa_bit_size) = rsa_bit_size { } else if let Some(rsa_bit_size) = rsa_bit_size {
rsautil::generate_rsa_keypair(rsa_bit_size)? rsautil::generate_rsa_keypair(rsa_bit_size)?
} else if let Some(mlkem_len) = mlkem_len {
mlkemutil::generate_mlkem_keypair(mlkem_len)?
} else { } else {
return simple_error!("Unsupported key type: {}", key_type); return simple_error!("Unsupported key type: {}", key_type);
}; };
@@ -130,7 +138,9 @@ impl Command for CommandImpl {
} }
json.insert("public_key_pem", public_key_pem); json.insert("public_key_pem", public_key_pem);
json.insert("public_key_base64", public_key_base64); json.insert("public_key_base64", public_key_base64);
if !jwk_key.is_empty() {
json.insert("public_key_jwk", jwk_key); json.insert("public_key_jwk", jwk_key);
}
util::print_pretty_json(&json); util::print_pretty_json(&json);
} else { } else {
@@ -149,8 +159,10 @@ impl Command for CommandImpl {
} }
information!("Public key PEM:\n{}", public_key_pem); information!("Public key PEM:\n{}", public_key_pem);
information!("Public key Base64:\n{}\n", public_key_base64); information!("Public key Base64:\n{}\n", public_key_base64);
if !jwk_key.is_empty() {
information!("Public key JWK:\n{}", jwk_key); information!("Public key JWK:\n{}", jwk_key);
} }
}
Ok(None) Ok(None)
} }

View File

@@ -1,40 +1,151 @@
use crate::util::base64_encode; use crate::rsautil::rsa_public_key_to_jwk;
use crate::util::{base64_encode, to_pem};
use ecdsa::elliptic_curve::pkcs8::LineEnding;
use ml_kem::kem::{Decapsulate, Encapsulate}; use ml_kem::kem::{Decapsulate, Encapsulate};
use ml_kem::{EncodedSizeUser, KemCore, MlKem768}; use ml_kem::{EncodedSizeUser, KemCore, MlKem1024, MlKem512, MlKem768};
use rust_util::XResult; use rust_util::XResult;
use std::convert::TryInto; use std::convert::TryInto;
// #[test] pub fn generate_mlkem_keypair(len: usize) -> XResult<(String, String, String, Vec<u8>, String)> {
pub fn generate_ml_kem_768() -> XResult<()> { let (dk_private, ek_public) = match len {
let mut rng = rand::thread_rng(); 512 => generate_ml_kem_512(),
let (dk, ek) = <MlKem768 as KemCore>::generate(&mut rng); 768 => generate_ml_kem_768(),
println!("dk: {}", base64_encode(&dk.as_bytes().0.to_vec())); 1024 => generate_ml_kem_1024(),
println!("ek: {}", base64_encode(ek.as_bytes().0.to_vec())); _ => return simple_error!("Invalid ML-KEM={}", len),
};
let ek_bytes = dk.as_bytes().0.to_vec(); let secret_key_der_base64 = base64_encode(&dk_private);
let dk = <MlKem768 as KemCore>::DecapsulationKey::from_bytes(&opt_result!( let secret_key_pem = to_pem(&format!("ML-KEM-{} PRIVATE KEY", len), &dk_private);
ek_bytes.as_slice().try_into(), let public_key_pem = to_pem(&format!("ML-KEM-{} PUBLIC KEY", len), &ek_public);
"Parse decapsulation key failed: {}" let public_key_der = ek_public;
)); let jwk_ec_key = "".to_string();
Ok((
let (encoded_ciphertext, shared_key) = opt_result!( secret_key_der_base64,
ek.encapsulate(&mut rng), secret_key_pem,
"Encapsulation key encapsulate failed: {:?}" public_key_pem,
); public_key_der,
println!( jwk_ec_key,
"encoded_ciphertext: {}", ))
base64_encode(&encoded_ciphertext.0.to_vec()) }
);
println!("shared_key: {}", base64_encode(&shared_key.0.to_vec())); pub fn generate_ml_kem_512() -> (Vec<u8>, Vec<u8>) {
let mut rng = rand::thread_rng();
let k_bytes = encoded_ciphertext.0.to_vec(); let (dk_private, ek_public) = <MlKem512 as KemCore>::generate(&mut rng);
let shared_key_2 = opt_result!( (
dk.decapsulate(opt_result!( dk_private.as_bytes().0.to_vec(),
&k_bytes.as_slice().try_into(), ek_public.as_bytes().0.to_vec(),
"Parse encoded ciphertext failed: {}" )
)), }
"Decapsulation key decapsulate failed: {:?}"
); pub fn generate_ml_kem_768() -> (Vec<u8>, Vec<u8>) {
println!("shared_key2: {}", base64_encode(&shared_key_2.0.to_vec())); let mut rng = rand::thread_rng();
Ok(()) let (dk_private, ek_public) = <MlKem768 as KemCore>::generate(&mut rng);
(
dk_private.as_bytes().0.to_vec(),
ek_public.as_bytes().0.to_vec(),
)
}
pub fn generate_ml_kem_1024() -> (Vec<u8>, Vec<u8>) {
let mut rng = rand::thread_rng();
let (dk_private, ek_public) = <MlKem1024 as KemCore>::generate(&mut rng);
(
dk_private.as_bytes().0.to_vec(),
ek_public.as_bytes().0.to_vec(),
)
}
pub fn parse_encapsulation_key_512_public_then_encapsulate(
bytes: &[u8],
) -> XResult<(Vec<u8>, Vec<u8>)> {
let ek = <MlKem512 as KemCore>::EncapsulationKey::from_bytes(&opt_result!(
bytes.try_into(),
"Parse encapsulation key 512 failed: {}"
));
let (ciphertext, shared_key) = opt_result!(
ek.encapsulate(&mut rand::thread_rng()),
"Encapsulation key 512 encapsulate failed: {:?}"
);
Ok((ciphertext.0.to_vec(), shared_key.0.to_vec()))
}
pub fn parse_encapsulation_key_768_public_then_encapsulate(
bytes: &[u8],
) -> XResult<(Vec<u8>, Vec<u8>)> {
let ek = <MlKem768 as KemCore>::EncapsulationKey::from_bytes(&opt_result!(
bytes.try_into(),
"Parse encapsulation key 768 failed: {}"
));
let (ciphertext, shared_key) = opt_result!(
ek.encapsulate(&mut rand::thread_rng()),
"Encapsulation key 768 encapsulate failed: {:?}"
);
Ok((ciphertext.0.to_vec(), shared_key.0.to_vec()))
}
pub fn parse_encapsulation_key_1024_public_then_encapsulate(
bytes: &[u8],
) -> XResult<(Vec<u8>, Vec<u8>)> {
let ek = <MlKem1024 as KemCore>::EncapsulationKey::from_bytes(&opt_result!(
bytes.try_into(),
"Parse encapsulation key 1024 failed: {}"
));
let (ciphertext, shared_key) = opt_result!(
ek.encapsulate(&mut rand::thread_rng()),
"Encapsulation key 1024 encapsulate failed: {:?}"
);
Ok((ciphertext.0.to_vec(), shared_key.0.to_vec()))
}
pub fn parse_decapsulate_key_512_private_then_decapsulate(
key_bytes: &[u8],
ciphertext_bytes: &[u8],
) -> XResult<Vec<u8>> {
let dk = <MlKem512 as KemCore>::DecapsulationKey::from_bytes(&opt_result!(
key_bytes.try_into(),
"Parse decapsulation key 512 failed: {}"
));
let shared_key = opt_result!(
dk.decapsulate(opt_result!(
ciphertext_bytes.try_into(),
"Parse encoded ciphertext 512 failed: {}"
)),
"Decapsulation key 512 decapsulate failed: {:?}"
);
Ok(shared_key.0.to_vec())
}
pub fn parse_decapsulate_key_768_private_then_decapsulate(
key_bytes: &[u8],
ciphertext_bytes: &[u8],
) -> XResult<Vec<u8>> {
let dk = <MlKem768 as KemCore>::DecapsulationKey::from_bytes(&opt_result!(
key_bytes.try_into(),
"Parse decapsulation key 768 failed: {}"
));
let shared_key = opt_result!(
dk.decapsulate(opt_result!(
ciphertext_bytes.try_into(),
"Parse encoded ciphertext 768 failed: {}"
)),
"Decapsulation key 768 decapsulate failed: {:?}"
);
Ok(shared_key.0.to_vec())
}
pub fn parse_decapsulate_key_1024_private_then_decapsulate(
key_bytes: &[u8],
ciphertext_bytes: &[u8],
) -> XResult<Vec<u8>> {
let dk = <MlKem768 as KemCore>::DecapsulationKey::from_bytes(&opt_result!(
key_bytes.try_into(),
"Parse decapsulation key 1024 failed: {}"
));
let shared_key = opt_result!(
dk.decapsulate(opt_result!(
ciphertext_bytes.try_into(),
"Parse encoded ciphertext 1024 failed: {}"
)),
"Decapsulation key 1024 decapsulate failed: {:?}"
);
Ok(shared_key.0.to_vec())
} }

View File

@@ -1,8 +1,8 @@
use std::fs; use std::fs;
use std::io::Read; use std::io::Read;
use base64::{DecodeError, Engine};
use base64::engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD}; use base64::engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD};
use base64::{DecodeError, Engine};
use rust_util::XResult; use rust_util::XResult;
use serde::Serialize; use serde::Serialize;
@@ -22,6 +22,21 @@ pub fn base64_uri_decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeErro
URL_SAFE_NO_PAD.decode(input) URL_SAFE_NO_PAD.decode(input)
} }
pub fn to_pem(header: &str, bytes: &[u8]) -> String {
let mut buf = String::new();
buf.push_str(&format!("-----BEGIN {}-----\n", header));
let bas64ed = base64_encode(bytes);
let len = bas64ed.len();
for (i, c) in bas64ed.chars().enumerate() {
buf.push(c);
if i > 0 && i < len && i % 64 == 0 {
buf.push('\n');
}
}
buf.push_str(&format!("\n-----END {}-----\n", header));
buf
}
pub fn try_decode(input: &str) -> XResult<Vec<u8>> { pub fn try_decode(input: &str) -> XResult<Vec<u8>> {
match hex::decode(input) { match hex::decode(input) {
Ok(v) => Ok(v), Ok(v) => Ok(v),
@@ -31,7 +46,7 @@ pub fn try_decode(input: &str) -> XResult<Vec<u8>> {
Ok(v) => Ok(v), Ok(v) => Ok(v),
Err(e) => simple_error!("decode hex or base64 error: {}", e), Err(e) => simple_error!("decode hex or base64 error: {}", e),
}, },
} },
} }
} }
@@ -46,7 +61,11 @@ pub fn read_file_or_stdin(file: &str) -> XResult<Vec<u8>> {
if file == "-" { if file == "-" {
read_stdin() read_stdin()
} else { } else {
Ok(opt_result!(fs::read(file), "Read file: {} failed: {}", file)) Ok(opt_result!(
fs::read(file),
"Read file: {} failed: {}",
file
))
} }
} }