登録トークン数 (固定) :
NT (Number of Tokens)。
ID が n のトークン:
T(n)。
トークベクトルの次元 (固定):
D
ID が n のトークン ( T( n ) ) のトークンベクトル:,
TV( n )
「TV(n)」は,つぎのように読める:
1. トークンベクトル行列を,TV で表す。
2. これのn行ベクトル TV_n を,特に TV( n ) で表す。
入力テクストS のトークン数:
LEN(S)
整数 i, j に対し,pos_ij をつぎのように定義する:
pos_ij :
j が偶数のとき:pos_ij = sin( i / 10000^(j/D) )
j が奇数のとき:pos_ij = cos( i / 10000^((j-1)/D) )
L × D 行列の ( pos_ij ) を,つぎのように表す:
POS(L)
重み行列
D×D 行列 W_Q, W_K, W_V
重み行列
NT×D 行列 W_O
入力テクストS のトークン化
[ T(ID_1), ‥‥, T(ID_LEN(S)) ]
を,
[ t_1, ‥‥, t_m ] ( m = LEN(S) )
と簡略表記する。。
[ T(ID_1), ‥‥, T(ID_LEN(S)) ] に対応するトークンベクトルレル列
[ TV(ID_1), ‥‥, TV(ID_LEN(S)) ]
を
[ e_1, ‥‥, e_m ] ( m = LEN(S) )
と簡略表記する。。
e_1, ‥‥, e_m を上から下に並べるてできる LEN(S) × D 行列を,
E(S)
で表す。
入力層の出力
LEN(S)×D 行列
X(S) = ( x_ij ) = E(S) + POS(LEN(S))
x_i = e_i + pos_i
糸造り:レイヤー構造
レイヤー数を
NL (Number of Layers) で表す。
最終出力の LEN(S)×D 行列:
O = ( o_ij )
レイヤー:
x_i^(1) = x_i
↓
x_i^(2)
↓
:
↓
x_i^(NL)_
↓
o_i
x_i^(ℓ) → x_i^(ℓ+1) (ℓ< LEN(S) ), x_i^(NL) → o_i
のそれぞれが,レイヤー1つにあたる。
Xの糸 P = [ p_1, ‥‥, p_NT ] 生成の流れる:
テクストS
↓トークン分割
T = [ t_1, ‥‥, t_m ], m = LEN(S)
↓対応するトークンベクトル
E = [ e_1, ‥‥, e_m ]
↓位置エンコーディングを加算
X = [ x_1, ‥‥, x_m ]
│
├───────────┐
│ ( Self-Attention )
│ ┌──────┿──────┐
│ ↓線型変換 ↓ ↓
│ Q_i = x_i W_Q K_i = x_i W_K V_i = x_i W_V
│ └──┬───┘ │
│ ↓ │
│ α_i = sim( Q_i ; K_1, ‥‥, K_m ) │
│ │ │
│ └───┬──────┘
│ z_ij =α_i (V_j)^T
│ │
│← Residual ─────┘
│
│← LayerNorm (正規化)
│
├───────────┐
│ ( FFN )
│ ↓
│ z'_i^(ℓ)
│ │
│← Residual ─────┘
│
│← LayerNorm
↓
X^(2) = [ x^(2)_1, ‥‥, x^(2)_m ]
↓
:
↓
X^(n+1) = O = [ o_1, ‥‥, o_m ]
↓
logis
↓
P = [ p_1, ‥‥, p_NT ] ( Xの糸)
↓(誤差逆伝播)
Self-Attention
x_i の 3つの線形変換:
Q_i = x_i W_Q」
K_i = x_i W_Kけ」
V_i = x_i W_V」
α_i = sim( Q_i ; K_1, ‥‥, K_m )
= softmax( Q_i K_1/√D, ‥‥ , Q_i K_m/√D ) )
m×D 行列 Z = ( z_ij ) を,つぎのように定義する:
z_ij = α_i (V_j)^T
= α_i1 V_j1 + ‥‥ + α_im V_jm
(「確率分布α_i を用いた V_j の加重平均」)
FFN (FeedForward Network)
NT次元確率分布ベクトル p_i の導出
logits_i = W o_i^T
p_i = softmax( logits_i )
p_i の「正解ラベル」:
true_i
誤差(2つの確率分布 p_i, true_(ID_i) の比較)
loss_i = cross_entropy ( p_i, true_i )
誤差逆伝播
∇loss_i/∇x_i^(ℓ)
∇loss_i/∇TV(ID_i)
∇loss_i/∇W
など
順伝播から逆伝播へ ── 最急降下法の場合
TV(ID_i)
│
▼
x_i^(1) = TV(ID_i)
│
▼
x_i^(2), ‥‥ , x_i^(NL), o_i
│
▼
損失 loss_i が計算される
▲
│
∇loss_i/∇x_i^(NL), ‥‥ , ∇loss_i/∇x_i^(1) を計算
▲
│
∇loss_i/∇TV(ID_i) を計算(連鎖律)
▲
│
TV(ID_i) := TV(ID_i) - η・∇loss_i / ∇TV(ID_i)
η : 学習率 (learning rate)
|