Rust で主成分分析
Rust で重回帰に続き、今日は 主成分分析をやりたい。
- Rust(Nightly) rustc 1.6.0-nightly (d5fde83ae 2015-11-12)
- rust-csv 0.14.3 : CSV の読み込みに利用
- nalgebra 0.3.2 : 行列/ベクトルの処理に利用
rust-csv
での CSV ファイルの読み込み
今回は ローカルのCSV ファイル (iris.csv
) を読みとる処理を加える。CSV が単一の型しか含まない場合、 前の記事 のように読み取った値をそのまま Vec
に追加していけばよい。
iris
のように複数の型を含むデータから特定の型のみを抜き出すには以下のような処理を書く必要がある。
まず Reader.byte_records()
で各行のエントリをバイト列として読み取る。
let mut reader = csv::Reader::from_file(path).unwrap().has_headers(false); for record in reader.byte_records().map(|r| r.unwrap()) { println!("{:?}", record); } // [[53, 46, 49], [51, 46, 53], [49, 46, 52], [48, 46, 50], [73, 114, 105, 115, 45, 115, 101, 116, 111, 115, 97]] // [[52, 46, 57], [51, 46, 48], [49, 46, 52], [48, 46, 50], [73, 114, 105, 115, 45, 115, 101, 116, 111, 115, 97]] // ...
読み取ったバイト列を文字列に変換。
for record in reader.byte_records().map(|r| r.unwrap()) { for item in record.iter().map(|i| str::from_utf8(i).unwrap()) { println!("{:?}", item); } } // "5.1" // "3.5" // "1.4" // "0.2" // "Iris-setosa" // "4.9" // ...
文字列を f64
に変換し、成功したもののみ残す。
for record in reader.byte_records().map(|r| r.unwrap()) { // f64 に変換できる列のみ読み込み for item in record.iter().map(|i| str::from_utf8(i).unwrap()) { match f64::from_str(item) { Ok(v) => println!("{}", v), Err(e) => {} }; } } // 5.1 // 3.5 // 1.4 // 0.2 // 4.9 // ...
主成分分析
nalgebra
では 行列の固有ベクトル/固有値を nalgebra::eigen_qr
で計算できるが、この関数は 動的サイズの行列 DMat
では利用できない。そのため、データの入力次元にあわせた行列 Mat4
を使った。
現状の nalgebra::DMat
で任意の入力次元に対応した処理を書くためには固有値計算を自力でやる必要がある...と思う。
extern crate csv; extern crate nalgebra; extern crate num; use std::f64; use std::str; use std::str::FromStr; use std::vec::Vec; use nalgebra::{DVec, DMat, Mat4, Mean, ColSlice, Iterable, Transpose}; fn main() { // "iris" データを使用 // http://aima.cs.berkeley.edu/data/iris.csv let path = "./data/iris.csv"; // http://burntsushi.net/rustdoc/csv/ let mut reader = csv::Reader::from_file(path).unwrap().has_headers(false); let mut x:Vec<f64> = vec![]; let mut nrows: usize = 0; for record in reader.byte_records().map(|r| r.unwrap()) { // f64 に変換できる列のみ読み込み for item in record.iter().map(|i| str::from_utf8(i).unwrap()) { match f64::from_str(item) { Ok(v) => x.push(v), Err(e) => {} }; } nrows += 1; } let ncols = x.len() / nrows; // http://nalgebra.org/doc/nalgebra/struct.DMat.html let dx = DMat::from_row_vec(nrows, ncols, &x); // 主成分分析 let mut pca = PCA::new(ncols, true); pca.fit(&dx); // 結果は小数点以下 5 桁に丸めて表示 println!("主成分 (center=true)\n{:?}", &round(&mut pca.rotation, 5)); println!("主成分得点\n{:?}", &round(&mut pca.transform(&dx), 5)); } struct PCA { center: bool, // センタリングするかどうか nfeatures: usize, // 入力の次元 rotation: DMat<f64> // 主成分 } impl PCA { fn new(nfeatures: usize, center: bool) -> PCA { PCA { center: center, // 固有値計算に Mat4 を使うため、次元は 4 に固定 nfeatures: match nfeatures { 4 => nfeatures, _ => panic!("Number of features must be 4") }, rotation: DMat::new_ones(nfeatures, nfeatures) } } fn get_centers(&self, data: &DMat<f64>) -> DVec<f64> { // センタリングに用いるベクトルを計算 return match self.center { true => data.mean(), false => DVec::from_elem(self.nfeatures, 0.) }; } fn fit(&mut self, data: &DMat<f64>) { let nrows = data.nrows(); let centers = self.get_centers(&data); // 偏差平方和積和行列 // Dmat::from_fn で、行番号 i, 列番号 j を引数とする関数から各要素の値を生成できる let smx = DMat::from_fn(self.nfeatures, self.nfeatures, |i, j| sum_square(&data.col_slice(i, 0, nrows), &data.col_slice(j, 0, nrows), centers[i], centers[j])); // DMat では eigen_qr が使えないため Mat4 に変換 let smxv = smx.to_vec(); let smx4 = Mat4::new(smxv[0], smxv[1], smxv[2], smxv[3], smxv[4], smxv[5], smxv[6], smxv[7], smxv[8], smxv[9], smxv[10], smxv[11], smxv[12], smxv[13], smxv[14], smxv[15]); // 固有ベクトル、固有値を計算 let (evec, eval) = nalgebra::eigen_qr(&smx4, &1e-8, 10000); let mut vals: Vec<f64> = vec![]; for row in evec.transpose().as_array() { vals.extend(row); } self.rotation = DMat::from_row_vec(self.nfeatures, self.nfeatures, &vals); } fn transform(&mut self, data: &DMat<f64>) -> DMat<f64> { // 主成分得点を計算 let centers = self.get_centers(&data); let cdata = DMat::from_fn(data.nrows(), data.ncols(), |i, j| data[(i, j)] - centers[j]); return cdata * &self.rotation; } } fn sum_square(vec1: &DVec<f64>, vec2: &DVec<f64>, m1: f64, m2: f64) -> f64 { // 平方和 let mut val: f64 = 0.; for (v1, v2) in vec1.iter().zip(vec2.iter()) { val += (v1 - m1) * (v2 - m2); } return val; } fn round(data: &mut DMat<f64>, decimals: usize) -> DMat<f64> { // DMat の各要素を指定された有効数字で丸め let nrows = data.nrows(); let ncols = data.ncols(); let d: f64 = num::pow(10., decimals); // round() は常に整数で丸めるため、有効数字の処理を別途行う let vals: Vec<f64> = data.as_mut_vec().iter().map(|x| (x * d).round() / d).collect(); return DMat::from_col_vec(nrows, ncols, &vals); }
実行結果。
# 主成分 (center=true) # 0.36159 -0.65654 0.581 0.31725 # -0.08227 -0.72971 -0.59642 -0.32409 # 0.85657 0.17577 -0.07252 -0.47972 # 0.35884 0.07471 -0.54906 0.75112 # 主成分得点 # -2.68421 -0.32661 0.02151 0.00101 # -2.71539 0.16956 0.20352 0.0996 # -2.88982 0.13735 -0.02471 0.0193 # -2.74644 0.31112 -0.03767 -0.07596 # -2.72859 -0.33392 -0.09623 -0.06313 # ...
R の結果。
biris = read.csv("iris.csv", header=FALSE) # 主成分 prcomp(biris[-5]) # Standard deviations: # [1] 2.0554417 0.4921825 0.2802212 0.1538929 # # Rotation: # PC1 PC2 PC3 PC4 # V1 0.36158968 -0.65653988 0.58099728 0.3172545 # V2 -0.08226889 -0.72971237 -0.59641809 -0.3240944 # V3 0.85657211 0.17576740 -0.07252408 -0.4797190 # V4 0.35884393 0.07470647 -0.54906091 0.7511206 # 主成分得点 head(prcomp(biris[-5])$x, n = 5) # PC1 PC2 PC3 PC4 # [1,] -2.684207 -0.3266073 0.02151184 0.001006157 # [2,] -2.715391 0.1695568 0.20352143 0.099602424 # [3,] -2.889820 0.1373456 -0.02470924 0.019304543 # [4,] -2.746437 0.3111243 -0.03767198 -0.075955274 # [5,] -2.728593 -0.3339246 -0.09622970 -0.063128733
補足 R 組み込みの iris データと berkeley.edu からダウンロードできる iris データは一部のレコードが異なる。そのため、R の iris データを利用した場合は上と同じ結果にはならない。
iris[apply(iris[-5] != biris[-5], 1, any), ] # Sepal.Length Sepal.Width Petal.Length Petal.Width Species # 35 4.9 3.1 1.5 0.2 setosa # 38 4.9 3.6 1.4 0.1 setosa biris[apply(iris[-5] != biris[-5], 1, any), ] # V1 V2 V3 V4 V5 # 35 4.9 3.1 1.5 0.1 Iris-setosa # 38 4.9 3.1 1.5 0.1 Iris-setosa
2016/12/20追記 rulinalg
の SVD を使って書き直した。また、上のプログラムでは .transform
時のセンタリングを予測データから行なってしまっているが、これもおかしいので修正。