burn – Rust製DeepLearningフレームワークの紹介

ndarray
Sponsored

burnはRust製の深層学習(Deep Learning)フレームワークです。

現在活発に開発が進められており、最新のコードやサンプルには、Githubリポジトリからアクセスできます。

GitHub - tracel-ai/burn: Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals. ...

この記事では、burnの特徴や、Rustの記法を活用した面白さ、速攻でサンプルコードを動かす方法について解説します

burnの特徴とメリット

3+1種類のバックエンドで、no_stdにも対応

burnは3種類のバックエンドを用意しており、featureの設定により切り替えることができます。

  1. Torch
    • CPU・GPUをともにサポートする
  2. Ndarray
    • 最もシンプルであり、no_stdでの実装もサポートする
  3. WebGPU
    • ブラウザ環境も考慮したGPUベースの計算をサポート

個人的には、Ndarrayバックエンドがno_std開発をサポートしているのが嬉しくて、マイコンへの学習済みモデルの移植が捗りそうです。

さらに、これらのバックエンドとAutodiffバックエンドを併用することで、(サポートされている場合)自動微分に対応できるようになります。

公式サンプルによると、Autodiffは以下のようにADBackendDecoratorを使用して設定できます。

type Backend = NdArrayBackend<f32>;
let y = linear::<ADBackendDecorator<Backend>>(...);

データはTensorとして扱う

他の一般的な深層学習フレームワークと同様に、burnでもデータはTensorとして扱います。

burnに特徴的な記法として、Tensorが格納するデータの型をBackendを介して指定します。

// NdArrayをBackendに使用した、f32の2次元Tensor
Tensor<NdArrayBackend<f32>, 2>

deriveによりモデル構造やパラメータを指定

機械学習モデルのネットワーク構造や、学習パラメータという、フレームワークとしての軸となる部分を、Rustのderive機能を活用して実装できるのが面白いです。

Module deriveは以下のようにモデル構造を指定します。

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}

また、ハイパーパラメータの種類やデフォルト値は、Config deriveを用いて指定できます。

#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    pub d_model: usize,
    pub d_ff: usize,
    #[config(default = 0.1)]
    pub dropout: f64,
}

データセットと学習済みモデルが付いてくる

burn内のDatasetクレートから種々のデータセットを利用できます。

また、学習済みの機械学習モデルが複数アップデートされており、Importクレートは、これらをあなたのプロジェクトに組み込む際に有用です。

サンプルコードを実行してみる

burnのリポジトリには、各バックエンドに対応したサンプルコードがexamplesで管理されています。

burnの機能を手元のPCで簡単に体験できるため、実際に動かしてみましょう。

Rustのアップデート

サンプルコードを実行する前に、Rustをアップデートしておきます。

...>rustup update

万が一アップデートが原因で以下のサンプルコードがエラーとなった場合は、以下の記事を参考にRustのバージョンを調整してください。

Rustのバージョンを変更する方法
概要 現在利用されているRustのバージョンを調べ、異なるバージョンのRustをインストールして利用する方法について解説する。 この記事を読むことで、チームメイトや技術書の著者とRustのバージョンを合わせたり、ツールが正常に動作するように...

リポジトリを手元のPCにクローンする

適当なディレクトリに移動し、burnのリポジトリをcloneします。

...>git clone https://github.com/burn-rs/burn

MNISTのサンプルを実行する

今回はもっとも簡単そうな、MNISTデータセットの学習を行うサンプルコードを動かします。

注!)サンプルコードの実行にはファイル生成等も伴うため、アクセス権確保のために管理者権限でコマンドを実行してください。

MNISTのサンプルはburn/examples/mnistにあるので、ここに移動します。

...>cd burn/examples/mnist

そして、もっとも制限の少ないBackendとして、今回はNdArrayを使用してサンプルを実行します。

.../burn/examples/mnist>cargo run --example mnist --release --features ndarray

すると、以下のように学習が始まり、予測精度の変化をグラフとして表示します。

実行を中断したい場合はCtrl+Cを入力してください。

Comments