やりたいこと
前回は、対象の配列から任意の次元(軸)をまるごと落とす方法について述べた。今回は次元(軸)の数はそのままに、任意の次元の長さ(大きさ)を短くする。つまり、Shape=[3, 4, 5]の配列と、対象とする次元(軸)1、短くする長さ(大きさ)3が与えられた際に、Shape=[3, 1, 5]の配列を取得したい。
なお、今回のテーマは「対象の配列を、任意の1次元(軸)に沿ってスライスとして切り出す」方法として一般化されるため、主にその観点から述べる。
解決策
特定の次元(軸)に沿って、配列の一部をスライスとして切り出す場合には、slice_axis()メソッドを用いる。
今回は、
- 対象の配列 a と同じ次元・大きさのゼロ配列を作成
- 任意の次元(軸)axis に沿って、先頭から数えて l 個目の要素からスライスを取得
- 取得したスライスに to_owned() メソッドを使って、新たな配列として保存
という手順を踏んだ。
let a = /* 任意の配列 */;
let axis = /* 任意の次元(軸) */;
let l = /* 任意の長さ */
let z = Array::zeros(a.dim().clone()).slice_axis(axis, Slice::from(l..)).to_owned();
これは以下のケースの一般化である。配列 a が3次元配列と決まっている場合、
let a = arr3(&[
[[0, 1], [2, 3], [4, 5], [6, 7]],
[[8, 9], [0, 1], [2, 3], [4, 5]],
[[6, 7], [8, 9], [0, 1], [2, 3]],
]);
let axis = Axis(1);
let l = 1;
println!("{:?}", a.slice(s![.., l.., ..]));
// -> [[[2, 3], [4, 5], [6, 7]],
[[0, 1], [2, 3], [4, 5]],
[[8, 9], [0, 1], [2, 3]]]
というように、slice()メソッドを使って配列の一部を取得できる。しかし、このメソッドの引数内ではすべての次元(軸)について切り出す範囲を指定してやる必要があるので、「切り出したい軸は決まっているが、配列全体の次元がいくつになるかはわからない/決まっていない」場合には使うことができない。そこで、
println!("{:?}", a.slice_axis(axis, Slice::from(l..)));
を用いると、切り出したい次元のみの情報から、同じことが実現できる。
この手法ではスライスの技術を用いているので、slice_axis() メソッドでは、
let m = /* 任意の usize */;
let n = /* 任意の usize */;
println!("{:?}", a.slice_axis(axis, Slice::from(m..n)));
とすれば、次元(軸)axisに沿ってインデックス m ~ n-1 の範囲を切り出すことができる。
活用事例
画像認識系の機械学習でよく使われるような、畳込みを実装したい場合に用いることがあるかもしれない。
畳み込みは、隣接した複数個の要素の和を求める演算である。例えば、 [1, 2, 3, 4, 5] という長さ 5 の配列に対して幅 3 の畳み込みを行うと、 [6, 9, 12] という 5 - 3 + 1 = 3 個の要素からなる配列が得られる。
ここでは任意の次元の配列に対して、特定の次元(軸)に沿って畳み込みを計算する conv_skew 関数を実装してみた。この関数が行うのは以下の作業である。
- 答えを格納する場所として、selfの次元(軸)axisの要素数が 畳み込みの幅 size -1 だけ少ないゼロ配列ansを作成する。
- selfのaxisに沿って、先頭から size-1 個の要素の和を計算する。これを配列 s に格納する。
- 以下を、selfの size 番目の要素から、最後の要素まで繰り返す。
- (i 回目の繰り返し ( i = 0, 1, 2, ... ))
- self の axis に沿って i+size 番目の各要素を s に足す。
- s の各要素を、 ans の axis に沿って i 番目の各要素として格納する。
- self の axis に沿って i 番目の各要素を s から引く。
- ansを返す。
なお、この関数はOptionを返し、size が 0 の場合、指定されたAxisが存在しない場合、 axis方向の大きさが size よりも小さい場合にNoneとなる。
use ndarray::*;
trait ConvArr: Sized {
fn conv_axis(&self, axis: Axis, size: usize) -> Option<Self>;
}
impl<D: RemoveAxis> ConvArr for Array<f64, D> {
fn conv_axis(&self, axis: Axis, size: usize) -> Option<Self> {
if size == 0 {
return None;
}
let index = axis.index();
if self.ndim() < index {
return None;
}
let n = self.shape()[index];
if n < size {
return None;
}
// 1.
let mut ans = Array::zeros(self.dim().clone())
.slice_axis(axis, Slice::from(size-1..))
.to_owned();
// 2.
let mut s = self.slice_axis(axis, Slice::from(..size-1))
.sum_axis(axis);
// 3.
for (mut ans_sub, (tail_sub, head_sub)) in ans.axis_iter_mut(axis)
.zip(
self.axis_iter(axis).skip(size-1)
.zip(self.axis_iter(axis))
)
{
// 3.1
// 3.2
Zip::from(&mut s)
.and(tail_sub)
.for_each(|a, &b| *a += b);
// 3.3
Zip::from(&mut ans_sub)
.and(&s)
.for_each(|a, &b| *a = b);
// 3.4
Zip::from(&mut s)
.and(head_sub)
.for_each(|a, &b| *a -= b);
}
// 3.5
Some(ans)
}
}
fn main() {
let a = arr2(&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 0.0, 1.0],
]);
println!(" a.conv_axis(Axis(0), 3).unwrap() = ");
println!("{:?}", a.conv_axis(Axis(0), 3).unwrap());
println!("************************************");
println!(" a.conv_axis(Axis(1), 2).unwrap() = ");
println!("{:?}", a.conv_axis(Axis(1), 2).unwrap());
}
コードについて
このシリーズで取り扱ったコードは、
にて公開されている。
Comments