損失 loss_i が計算されたところから,つぎのプロセスが開始される:
《損失を小さくする方向に,
トークンベクトルと各種重み行列を更新する》
特に,トークンベクトルと重み行列の更新は,別々のものではない。
このプロセスは,学習材テクストS の入力から o_i の出力までの間に昇ってきたレイヤーを逆に降るものなので,逆伝播 (backpropagation) と呼ぶ。
──翻って,レイヤー上昇のプロセスを,順伝播 (forward pass) と呼ぶ。
順伝播は,計算アルゴリズムの内容が線型代数なので,本テクストでもこの内容の概略を扱えた。
しかし逆伝播の方は,実際値から勾配を計算することが内容であり,計算処理の式表現も,つぎのようなラベル的なものにしかならない:
∇loss_i/∇x_i^(ℓ)
∇loss_i/∇TV(ID_i)
∇loss_i/∇W
など
これらの式は,関数計算のように見えるが,実際数値に対し差分を適当な区間でとることがベースであり,内容は離散的である。
こうして,「ラベル的」以上の計算式にはならない。
というわけで,本テクストは,「逆伝播」の概念を示すにとどめるとする。
- 「損失」
「こういう文脈ではこのトークンはこういう意味で使われるべきだった」
- 「勾配」
- 「勾配降下」
- パラメータの更新では、損失を最小化する必要がある。
この手法の一つが,勾配降下法。
- 損失の (各パラメータに対する) 勾配を計算する
- 「損失ベクトルが,少しずつが変形されて,Embedding 層 (トークンベクトル,各種重みベクトルの所在) まで行く。」
- 「逆伝播」
- 勾配を効率的に計算するためのアルゴリズム。
- Chain Rule (連鎖律) で勾配を展開
- 誤差逆伝播」による勾配の流れ
ℓ層の W_Q, W_K, W_V を例に:
出力トークン o_i の誤差 → 誤差が伝播
↓
LayerNorm, 残差接続, FeedForward
↓
z_i^(ℓ), y_i^(ℓ), z'_i^(ℓ) など中間表現を経由し
↓
Self-Attention の中へ
↓
W_Q^(ℓ), W_K^(ℓ), W_V^(ℓ) に勾配が到達
- PyTorchにおける逆伝播
- PyTorchでは、自動微分と呼ばれる技術を用いて、逆伝播を自動的に計算することができる。
- 自動微分とは、計算グラフを構築し、各ノードの勾配を自動的に計算する技術。
- PyTorchでは、テンソル演算がグラフとして表現される。
そして、backward() 関数を呼び出すことで、グラフ上の各ノードの勾配を自動的に計算することができる。
- 損失関数の勾配の計算を,backward() 関数で,自動的に行う
- 計算された勾配を用いて、パラメータを更新
- コード
python
import torch
# トークン埋め込み表の定義 (embedding)
embedding = nn.Embedding(vocab_size, d_model)
# トークン表:語彙数 × ベクトル次元
# 入力テクストのトークン分割
input_ids = [ ‥‥ ]
# e.g., input_ids = [3, 7, 128]
# 順伝播 (forward)
output = embedding(input_ids)
# input_ids に対応するトークンベクトル表の作成
# 損失計算 (loss)
loss = loss_fn(output, targets)
# 誤差逆伝播 (backwaord)
loss.backward()
# パラメータの更新
optimizer.step()
- embedding.weight
- 勾配を受け取るパラメータ
- 更新対象のトークンベクトルは,embedding.weight の一部になっている。
- optimizer.step() は,embedding.weight を更新する。
- これにより,勾配を送ったところ (ここでは,input_ids のトークンベクトル) が更新される。
- スパース更新
- 「使われたIDの行だけ」が更新
- PyTorchの nn.Embedding は,内部的に,スパースな更新をする
- optimizer.step() を行うと、つぎのコードで,使われたトークンIDに対応するベクトルだけが更新される:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer.step()
- スパース更新は、巨大語彙(数万〜数百万)でも効率的に学習できる鍵となっている。
|