機器之心分析師網絡
作者:周宇
編輯:H4O
本文重點探討分布式學習框架中針對隨機梯度下降(SGD)算法的拜占庭問題。
分布式學習(Distributed Learning)是一種廣泛應用的大規模模型訓練框架。在分布式學習框架中,服務器通過聚合在分布式設備中訓練的本地模型(local model)來利用各個設備的計算能力。分布式機器學習的典型架構——參數服務器架構中,包括一個服務器(稱為參數服務器 - Parameter Server,PS)和多個計算節點(workers,也稱為節點 nodes)[1]。其中,隨機梯度下降(Stochastic Gradient Descent,SGD)是一種廣泛使用的、效果較好的分布式優化算法。在每一輪中,每個計算節點根據不同的本地數據集在它的設備上訓練一個本地模型,并與服務器共享最終的參數。然后,服務器聚合不同計算節點的參數,并通過與計算節點共享得到的組合參數來啟動下一輪訓練。關于基于 SGD 優化的分布式框架的網絡結構(包括:層數、類型、大小等)在訓練開始之前由所有計算節點共同商定確認。
近年來,分布式學習的安全性越來越受到人們的關注,其中,最重要的就是拜占庭威脅模型。在拜占庭威脅模型中,計算節點可以任意和惡意地行事。機器之心在前期的文章中也探討過分布式學習中的拜占庭問題,主要針對聯邦學習中的拜占庭問題。在這篇文章中,我們重點探討的是分布式學習框架中針對隨機梯度下降(SGD)算法的拜占庭問題。如圖 1 所示,在 SGD 學習框架中,一些惡意節點(Malicious worker)向服務器發送拜占庭梯度(Byzantine Gradient),而不是計算得到的真實梯度,而拜占庭梯度可以是任意值。惡意節點可以控制計算節點設備本身,也可以控制節點和服務器之間的通信。以 Algorithm 1 中提出的同步 SGD(sync-SGD)協議為例 [4]。攻擊者(惡意節點)在使其效果最大化的時間內(即在 Algorithm 1 的第 6 行和第 7 行之間)干擾進程。在此期間,攻擊者可以將節點 i 中的參數(p_i)^(t+1) 替換為任意值,然后將此任意值發送到服務器中。攻擊方法在設置參數值的方式上有所不同,而防御方法則試圖識別損壞的參數并丟棄它們。Algorithm 1 使用平均值(第 8 行中的 AggregationRule( ))聚合計算節點參數。
圖 1. SGD 學習框架工作流程 [3]
本文所討論的分布式學習的核心是這樣一個假設:經過訓練的網絡參數是獨立同分布的(Independent and identically distributed,i.i.d.)