burnはRust製の深層学習(Deep Learning)フレームワークです。
現在活発に開発が進められており、最新のコードやサンプルには、Githubリポジトリからアクセスできます。
この記事では、burnの特徴や、Rustの記法を活用した面白さ、速攻でサンプルコードを動かす方法について解説します。
burnの特徴とメリット
3+1種類のバックエンドで、no_stdにも対応
burnは3種類のバックエンドを用意しており、feature
の設定により切り替えることができます。
- Torch
- CPU・GPUをともにサポートする
- Ndarray
- 最もシンプルであり、
no_std
での実装もサポートする
- 最もシンプルであり、
- WebGPU
- ブラウザ環境も考慮したGPUベースの計算をサポート
個人的には、Ndarrayバックエンドがno_std
開発をサポートしているのが嬉しくて、マイコンへの学習済みモデルの移植が捗りそうです。
さらに、これらのバックエンドとAutodiff
バックエンドを併用することで、(サポートされている場合)自動微分に対応できるようになります。
公式サンプルによると、Autodiff
は以下のようにADBackendDecorator
を使用して設定できます。
type Backend = NdArrayBackend<f32>;
let y = linear::<ADBackendDecorator<Backend>>(...);
データはTensorとして扱う
他の一般的な深層学習フレームワークと同様に、burnでもデータはTensor
として扱います。
burnに特徴的な記法として、Tensor
が格納するデータの型をBackend
を介して指定します。
// NdArrayをBackendに使用した、f32の2次元Tensor
Tensor<NdArrayBackend<f32>, 2>
deriveによりモデル構造やパラメータを指定
機械学習モデルのネットワーク構造や、学習パラメータという、フレームワークとしての軸となる部分を、Rustのderive
機能を活用して実装できるのが面白いです。
Module
deriveは以下のようにモデル構造を指定します。
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: GELU,
}
また、ハイパーパラメータの種類やデフォルト値は、Config
deriveを用いて指定できます。
#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
pub d_model: usize,
pub d_ff: usize,
#[config(default = 0.1)]
pub dropout: f64,
}
データセットと学習済みモデルが付いてくる
burn内のDataset
クレートから種々のデータセットを利用できます。
また、学習済みの機械学習モデルが複数アップデートされており、Import
クレートは、これらをあなたのプロジェクトに組み込む際に有用です。
サンプルコードを実行してみる
burnのリポジトリには、各バックエンドに対応したサンプルコードがexamples
で管理されています。
burnの機能を手元のPCで簡単に体験できるため、実際に動かしてみましょう。
Rustのアップデート
サンプルコードを実行する前に、Rustをアップデートしておきます。
...>rustup update
万が一アップデートが原因で以下のサンプルコードがエラーとなった場合は、以下の記事を参考にRustのバージョンを調整してください。
リポジトリを手元のPCにクローンする
適当なディレクトリに移動し、burnのリポジトリをclone
します。
...>git clone https://github.com/burn-rs/burn
MNISTのサンプルを実行する
今回はもっとも簡単そうな、MNISTデータセットの学習を行うサンプルコードを動かします。
(注!)サンプルコードの実行にはファイル生成等も伴うため、アクセス権確保のために管理者権限でコマンドを実行してください。
MNISTのサンプルはburn/examples/mnist
にあるので、ここに移動します。
...>cd burn/examples/mnist
そして、もっとも制限の少ないBackend
として、今回はNdArray
を使用してサンプルを実行します。
.../burn/examples/mnist>cargo run --example mnist --release --features ndarray
すると、以下のように学習が始まり、予測精度の変化をグラフとして表示します。
実行を中断したい場合はCtrl+C
を入力してください。
Comments