対象の配列の任意の1次元の長さが、任意の大きさだけ短いゼロ配列を作成する ー Rust-Ndarray例文集(第3回)

ndarray
Sponsored

やりたいこと

前回は、対象の配列から任意の次元(軸)をまるごと落とす方法について述べた。今回は次元(軸)の数はそのままに、任意の次元の長さ(大きさ)を短くする。つまり、Shape=[3, 4, 5]の配列と、対象とする次元(軸)1、短くする長さ(大きさ)3が与えられた際に、Shape=[3, 1, 5]の配列を取得したい。

なお、今回のテーマは「対象の配列を、任意の1次元(軸)に沿ってスライスとして切り出す」方法として一般化されるため、主にその観点から述べる。

解決策

特定の次元(軸)に沿って、配列の一部をスライスとして切り出す場合には、slice_axis()メソッドを用いる。

今回は、

  1. 対象の配列 a と同じ次元・大きさのゼロ配列を作成
  2. 任意の次元(軸)axis に沿って、先頭から数えて l 個目の要素からスライスを取得
  3. 取得したスライスに to_owned() メソッドを使って、新たな配列として保存

という手順を踏んだ。

let a = /* 任意の配列 */;
let axis = /* 任意の次元(軸) */;
let l = /* 任意の長さ */
let z = Array::zeros(a.dim().clone()).slice_axis(axis, Slice::from(l..)).to_owned();

これは以下のケースの一般化である。配列 a が3次元配列と決まっている場合、

let a = arr3(&[
        [[0, 1], [2, 3], [4, 5], [6, 7]],
        [[8, 9], [0, 1], [2, 3], [4, 5]],
        [[6, 7], [8, 9], [0, 1], [2, 3]],
]);
let axis = Axis(1);
let l = 1;
println!("{:?}", a.slice(s![.., l.., ..]));
// -> [[[2, 3], [4, 5], [6, 7]],
       [[0, 1], [2, 3], [4, 5]],
       [[8, 9], [0, 1], [2, 3]]]

というように、slice()メソッドを使って配列の一部を取得できる。しかし、このメソッドの引数内ではすべての次元(軸)について切り出す範囲を指定してやる必要があるので、「切り出したい軸は決まっているが、配列全体の次元がいくつになるかはわからない/決まっていない」場合には使うことができない。そこで、

println!("{:?}", a.slice_axis(axis, Slice::from(l..)));

を用いると、切り出したい次元のみの情報から、同じことが実現できる。

この手法ではスライスの技術を用いているので、slice_axis() メソッドでは、

let m = /* 任意の usize */;
let n = /* 任意の usize */;
println!("{:?}", a.slice_axis(axis, Slice::from(m..n)));

とすれば、次元(軸)axisに沿ってインデックス m ~ n-1 の範囲を切り出すことができる。

活用事例

画像認識系の機械学習でよく使われるような、畳込みを実装したい場合に用いることがあるかもしれない。

畳み込みは、隣接した複数個の要素の和を求める演算である。例えば、 [1, 2, 3, 4, 5] という長さ 5 の配列に対して幅 3 の畳み込みを行うと、 [6, 9, 12] という 5 - 3 + 1 = 3 個の要素からなる配列が得られる。

ここでは任意の次元の配列に対して、特定の次元(軸)に沿って畳み込みを計算する conv_skew 関数を実装してみた。この関数が行うのは以下の作業である。

  1. 答えを格納する場所として、selfの次元(軸)axisの要素数が 畳み込みの幅 size -1 だけ少ないゼロ配列ansを作成する。
  2. selfのaxisに沿って、先頭から size-1 個の要素の和を計算する。これを配列 s に格納する。
  3. 以下を、selfの size 番目の要素から、最後の要素まで繰り返す。
    1. (i 回目の繰り返し ( i = 0, 1, 2, ... ))
    2. self の axis に沿って i+size 番目の各要素を s に足す。
    3. s の各要素を、 ans の axis に沿って i 番目の各要素として格納する。
    4. self の axis に沿って i 番目の各要素を s から引く。
  4. ansを返す。

なお、この関数はOptionを返し、size が 0 の場合、指定されたAxisが存在しない場合、 axis方向の大きさが size よりも小さい場合にNoneとなる。

use ndarray::*;

trait ConvArr: Sized {
    fn conv_axis(&self, axis: Axis, size: usize) -> Option<Self>;
}

impl<D: RemoveAxis> ConvArr for Array<f64, D> {
    fn conv_axis(&self, axis: Axis, size: usize) -> Option<Self> {
        if size == 0 {
            return None;
        }

        let index = axis.index();
        if self.ndim() < index {
            return None;
        }

        let n = self.shape()[index];
        if n < size {
            return None;
        }

        // 1.
        let mut ans = Array::zeros(self.dim().clone())
                            .slice_axis(axis, Slice::from(size-1..))
                            .to_owned();

        // 2.
        let mut s = self.slice_axis(axis, Slice::from(..size-1))
                        .sum_axis(axis);

        // 3.
        for (mut ans_sub, (tail_sub, head_sub)) in ans.axis_iter_mut(axis)
                                                .zip(
                                                    self.axis_iter(axis).skip(size-1)
                                                        .zip(self.axis_iter(axis))
                                                )
        {
            // 3.1

            // 3.2
            Zip::from(&mut s)
                .and(tail_sub)
                .for_each(|a, &b| *a += b);

            // 3.3
            Zip::from(&mut ans_sub)
                .and(&s)
                .for_each(|a, &b| *a = b);

            // 3.4
            Zip::from(&mut s)
                .and(head_sub)
                .for_each(|a, &b| *a -= b);
        }

        // 3.5
        Some(ans)
    }
}

fn main() {
    let a = arr2(&[
        [0.0, 1.0, 2.0],
        [3.0, 4.0, 5.0],
        [6.0, 7.0, 8.0],
        [9.0, 0.0, 1.0],
    ]);

    println!(" a.conv_axis(Axis(0), 3).unwrap() = ");
    println!("{:?}", a.conv_axis(Axis(0), 3).unwrap());
    println!("************************************");
    println!(" a.conv_axis(Axis(1), 2).unwrap() = ");
    println!("{:?}", a.conv_axis(Axis(1), 2).unwrap());
}

コードについて

このシリーズで取り扱ったコードは、

GitHub - doraneko94/ndarray-tutorial: RustのNdarrayクレートの実用例について書いていく。
RustのNdarrayクレートの実用例について書いていく。. Contribute to doraneko94/ndarray-tutorial development by creating an account on GitHub.

にて公開されている。

Comments