やりたいこと
前回、対象の配列と同じ次元・大きさを持つゼロ配列を作成する方法について述べた。今回は全く同じサイズの配列ではなく、指定した次元については欠落した配列を作成する。つまり、Shape=[3, 4, 5]の配列と、欠落させる次元(軸)番号 1 が与えられた際に、Shape=[3, 5]のゼロ配列を取得したい。
解決策
特定の次元(軸)を落とす場合には、remove_axis()メソッドを用いる。
let a = /* 任意の配列 */;
let axis = /* 任意の次元(軸) */;
let z = Array::zeros(a.dim().clone()).remove_axis(axis);
その他、こうした「もともとの配列より1次元小さいものを返す」系の操作は、ndarray::ArrayBase::mean_axis, ndarray::ArrayBase::std_axis等の実装が参考になる。
活用事例
配列の特定の次元(軸)についての代表値を求め、その結果として1次元小さい配列を返したい場合などに重宝する。
先ほどのndarray::ArrayBase::mean_axis, ndarray::ArrayBase::std_axisあたりが良い例で、代表値の計算対象となった次元(軸)の情報は平均値や標準偏差としてまとめられるので、その次元(軸)が欠落した1次元小さい配列ができる。
平均(1次モーメント)・標準偏差(分散→2次モーメント)についてはすでに実装があるので、3次モーメントの歪度skew_axisについて実装してみた。この関数が行うのは以下の作業である。
- 答えを格納する場所として、selfから次元(軸)axisを落としたゼロ配列ansを作成する。
- selfのaxisに沿った平均muを求める。
- selfのaxisに沿った標準偏差sigmaを求める。
- selfのaxisに沿って、歪度の総和部分を計算し、ansに格納する。
- ansの各要素に歪度の定数をかける。
- ansを返す。
ただし、歪度(不偏歪度)は以下の式で求められる。
$$\mathrm{skew}=\frac{n}{(n-1)(n-2)}\sum_{i=1}^{n}\left(\frac{x_{i}-\mu}{\sigma}\right)^{3}$$
ここで、 \(\mu, \sigma\) はそれぞれ平均と標準偏差である。
なお、この関数はOptionを返し、指定されたAxisが存在しない場合や、 \(n\leq 2\) のときにNoneとなる。
use ndarray::*;
trait SkewArr<D: RemoveAxis> {
fn skew_axis(&self, axis: Axis) -> Option<Array<f64, D::Smaller>>;
}
impl<D: RemoveAxis> SkewArr<D> for Array<f64, D> {
fn skew_axis(&self, axis: Axis) -> Option<Array<f64, D::Smaller>> {
let index = axis.index();
if self.ndim() < index {
return None;
}
let n = self.shape()[index] as f64;
if n <= 2.0 {
return None;
}
// 1.
let mut ans = Array::zeros(self.dim().clone()).remove_axis(axis);
// 2.
let mu = self.mean_axis(axis).unwrap();
// 3.
let sigma = self.std_axis(axis, 1.0);
// 4.
for self_sub in self.axis_iter(axis) {
Zip::from(&mut ans)
.and(self_sub)
.and(&mu)
.and(&sigma)
.for_each(|a, &b, &c, &d| {
*a += ((b - c) / d).powi(3);
})
}
// 5.
let c = n / ((n - 1.0) * (n - 2.0));
Zip::from(&mut ans)
.for_each(|a| *a *= c);
// 6.
Some(ans)
}
}
fn main() {
let a = arr2(&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 0.0, 1.0],
]);
println!("a.skew_axis(Axis(0)).unwrap() = ");
println!("{:?}", a.skew_axis(Axis(0)).unwrap());
println!("********************************");
println!("a.skew_axis(Axis(1)).unwrap() = ");
println!("{:?}", a.skew_axis(Axis(1)).unwrap());
}
コードについて
このシリーズで取り扱ったコードは、
にて公開されている。
コメント