Files
simple-rust-tests/__machinelearning/rusty-machine/src/main.rs

37 lines
1.2 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use rusty_machine::linalg::{Matrix, Vector};
use rusty_machine::learning::gp::{GaussianProcess, ConstMean};
use rusty_machine::learning::toolkit::kernel::SquaredExp;
use rusty_machine::learning::SupModel;
fn main() {
// 首先获取一些数据。
// 一些示例训练数据。
let inputs = Matrix::new(3,3,vec![1.,1.,1.,2.,2.,2.,3.,3.,3.]);
let targets = Vector::new(vec![0.,1.,0.]);
// 一些示例测试数据。
let test_inputs = Matrix::new(2,3, vec![1.5,1.5,1.5,2.5,2.5,2.5]);
// 现在设置好我们的模组
// 这几乎是rusty-machine 中最复杂的模组了!
// 设置平方指数核函数,长度参数 2宽度参数 1。
let ker = SquaredExp::new(2., 1.);
// 零函数
let zero_mean = ConstMean::default();
// 用核函数,平均值, 噪声0.5来构建一个高斯过程。
let mut gp = GaussianProcess::new(ker, zero_mean, 0.5);
// 现在我们可以训练并且用这个模组进行预测了。
// 训练模组!
gp.train(&inputs, &targets).unwrap();
// 使用测试数据来测试预测。
let outputs = gp.predict(&test_inputs).unwrap();
println!("{:?}", outputs);
}