37 lines
1.2 KiB
Rust
37 lines
1.2 KiB
Rust
use rusty_machine::linalg::{Matrix, Vector};
|
||
use rusty_machine::learning::gp::{GaussianProcess, ConstMean};
|
||
use rusty_machine::learning::toolkit::kernel::SquaredExp;
|
||
use rusty_machine::learning::SupModel;
|
||
|
||
fn main() {
|
||
// 首先获取一些数据。
|
||
|
||
// 一些示例训练数据。
|
||
let inputs = Matrix::new(3,3,vec![1.,1.,1.,2.,2.,2.,3.,3.,3.]);
|
||
let targets = Vector::new(vec![0.,1.,0.]);
|
||
|
||
// 一些示例测试数据。
|
||
let test_inputs = Matrix::new(2,3, vec![1.5,1.5,1.5,2.5,2.5,2.5]);
|
||
|
||
// 现在设置好我们的模组
|
||
// 这几乎是rusty-machine 中最复杂的模组了!
|
||
|
||
// 设置平方指数核函数,长度参数 2,宽度参数 1。
|
||
let ker = SquaredExp::new(2., 1.);
|
||
|
||
// 零函数
|
||
let zero_mean = ConstMean::default();
|
||
|
||
// 用核函数,平均值, 噪声0.5来构建一个高斯过程。
|
||
let mut gp = GaussianProcess::new(ker, zero_mean, 0.5);
|
||
|
||
// 现在我们可以训练并且用这个模组进行预测了。
|
||
|
||
// 训练模组!
|
||
gp.train(&inputs, &targets).unwrap();
|
||
|
||
// 使用测试数据来测试预测。
|
||
let outputs = gp.predict(&test_inputs).unwrap();
|
||
|
||
println!("{:?}", outputs);
|
||
} |