対象の配列から任意の1次元を削除したゼロ配列を作成する ー Rust-Ndarray例文集(第2回)

ndarray
Sponsored

やりたいこと

前回、対象の配列と同じ次元・大きさを持つゼロ配列を作成する方法について述べた。今回は全く同じサイズの配列ではなく、指定した次元については欠落した配列を作成する。つまり、Shape=[3, 4, 5]の配列と、欠落させる次元(軸)番号 1 が与えられた際に、Shape=[3, 5]のゼロ配列を取得したい。

解決策

特定の次元(軸)を落とす場合には、remove_axis()メソッドを用いる。

let a = /* 任意の配列 */;
let axis = /* 任意の次元(軸) */;
let z = Array::zeros(a.dim().clone()).remove_axis(axis);

その他、こうした「もともとの配列より1次元小さいものを返す」系の操作は、ndarray::ArrayBase::mean_axis, ndarray::ArrayBase::std_axis等の実装が参考になる。

活用事例

配列の特定の次元(軸)についての代表値を求め、その結果として1次元小さい配列を返したい場合などに重宝する。

先ほどのndarray::ArrayBase::mean_axis, ndarray::ArrayBase::std_axisあたりが良い例で、代表値の計算対象となった次元(軸)の情報は平均値や標準偏差としてまとめられるので、その次元(軸)が欠落した1次元小さい配列ができる。

平均(1次モーメント)・標準偏差(分散→2次モーメント)についてはすでに実装があるので、3次モーメントの歪度skew_axisについて実装してみた。この関数が行うのは以下の作業である。

  1. 答えを格納する場所として、selfから次元(軸)axisを落としたゼロ配列ansを作成する。
  2. selfのaxisに沿った平均muを求める。
  3. selfのaxisに沿った標準偏差sigmaを求める。
  4. selfのaxisに沿って、歪度の総和部分を計算し、ansに格納する。
  5. ansの各要素に歪度の定数をかける。
  6. ansを返す。

ただし、歪度(不偏歪度)は以下の式で求められる。

ここで、 \(\mu, \sigma\) はそれぞれ平均と標準偏差である。

なお、この関数はOptionを返し、指定されたAxisが存在しない場合や、 \(n\leq 2\) のときにNoneとなる。

use ndarray::*;

trait SkewArr<D: RemoveAxis> {
    fn skew_axis(&self, axis: Axis) -> Option<Array<f64, D::Smaller>>;
}

impl<D: RemoveAxis> SkewArr<D> for Array<f64, D> {
    fn skew_axis(&self, axis: Axis) -> Option<Array<f64, D::Smaller>> {
        let index = axis.index();
        if self.ndim() < index {
            return None;
        }

        let n = self.shape()[index] as f64;
        if n <= 2.0 {
            return None;
        }

        // 1.
        let mut ans = Array::zeros(self.dim().clone()).remove_axis(axis);
        // 2.
        let mu = self.mean_axis(axis).unwrap();
        // 3.
        let sigma = self.std_axis(axis, 1.0);

        // 4.
        for self_sub in self.axis_iter(axis) {
            Zip::from(&mut ans)
                .and(self_sub)
                .and(&mu)
                .and(&sigma)
                .for_each(|a, &b, &c, &d| {
                    *a += ((b - c) / d).powi(3);
                })
        }

        // 5.
        let c = n / ((n - 1.0) * (n - 2.0));
        Zip::from(&mut ans)
            .for_each(|a| *a *= c);

        // 6.
        Some(ans)
    }
}


fn main() {
    let a = arr2(&[
        [0.0, 1.0, 2.0],
        [3.0, 4.0, 5.0],
        [6.0, 7.0, 8.0],
        [9.0, 0.0, 1.0],
    ]);

    println!("a.skew_axis(Axis(0)).unwrap() = ");
    println!("{:?}", a.skew_axis(Axis(0)).unwrap());
    println!("********************************");
    println!("a.skew_axis(Axis(1)).unwrap() = ");
    println!("{:?}", a.skew_axis(Axis(1)).unwrap());
}

コードについて

このシリーズで取り扱ったコードは、

GitHub - doraneko94/ndarray-tutorial: RustのNdarrayクレートの実用例について書いていく。
RustのNdarrayクレートの実用例について書いていく。. Contribute to doraneko94/ndarray-tutorial development by creating an account on GitHub.

にて公開されている。

Comments