CTC 教育サービス
[IT研修]注目キーワード Python UiPath(RPA) 最新技術動向 Microsoft Azure Docker Kubernetes
今回からは、2019年に公開された論文「Applied Federated Learning: Improving Google Keyboard Query Suggestions」を元にして、モバイルデバイスを用いた分散学習技術である「Federated Learning」の適用事例を紹介します。モバイルデバイス上で生成されるデータをサーバーに送信せず、ユーザー情報を適切に保護した状態で、デバイス上で学習処理を行う技術になります。
冒頭の論文では、Federated Learningの具体的な適用事例として、Google Keyboard(Gboard)の検索キーワード予測機能が取り上げられています。Gboardは、モバイルデバイス用の文字入力システムで、2018年の時点で、10億台以上のデバイスにインストールされています。テキストの入力中に、関連する複数のWeb検索キーワードを表示する機能があり、ここで表示するキーワードの選択に機械学習モデルが用いられています。
このモデルを学習するには、それぞれのユーザーが実際に候補から選択したキーワードを学習データとして使用する必要がありますが、このような情報をサーバーに送信するのは、ユーザー情報保護の観点で問題となり得ます。そこで、デバイス上の一時的な記憶領域に学習データを保存しておき、デバイス上でモデルのアップデートを行った上で、モデルの更新内容をサーバーに送信します。ここで送信される情報は、モデルに含まれるパラメーターがどのように変化したかという最小限の差分データとなります。
図1は、この事例で用いられる学習処理の全体像を表します。図の左下にある「Training Cache」は、先に説明した、学習データを保存する一時的な記憶領域で、このデータを用いて、Android版のTensorFlowを用いた学習処理が行われます。
図1 GboardにおけるFederated Learningのアーキテクチャー(論文より抜粋)
また、この仕組みでは、「Baseline Model」と「Triggering Model」の2種類の機械学習モデルを組み合わせた処理が行われます。Baseline Modelは、一般的なコーパス(既存の文書を大量に集めたデータセット)とナレッジグラフ(単語のカテゴリー情報などを保存したデータベース)を用いて、ある単語列が与えられた時に、関連する検索キーワードを生成する事前学習済みのモデルです。まずは、これを用いて、表示するキーワードの候補を生成します。ただし、これは、あくまで既存の文書を学習データとしたものであり、モバイルデバイスで入力される文書の特性や、あるいは、利用するユーザーの地域性といった特性は反映されていません。そこで、もう一つのTriggering Modelを用いて、得られた候補の中から、実際に選択肢として表示するキーワードを絞り込むという処理を行います。こちらのモデルは、Training Cacheに保存されたリアルなユーザーデータを用いて学習されるので、これにより、より精度の高い予測ができることになります。
このようにBaseline ModelとTriggering Modelを分離するのには、機械学習における技術的な理由があります。一般に、単語予測といった自然言語に関わる機械学習モデルには、複雑なディープ・ニューラルネットワークが使用されます。しかしながら、このような複雑なモデルをモバイルデバイス上で学習するのは、学習処理の負荷が高くなりすぎるため現実的ではありません。そこで、事前学習済みのBaseline Modelで候補を生成した上で、そこから絞り込みを行うという、比較的簡単な処理をTriggering Modelとして用意して、こちらをFederated Learningの対象とします。実際、論文の中では、Baseline Modelには、LSTM(時系列データのためのニューラルネットワーク)を使用しており、一方、Triggering Modelには、ロジスティック回帰(シンプルな線形モデル)を使用していることが説明されています(図2)。この図には、学習に用いる具体的なデータ項目が示されていますが、これらについては、次回に改めて解説します。
図2 Base ModelとTriggering Modelの組み合わせ(論文より抜粋)
ここで、先ほどの図1の流れをもう少し詳しく説明しておきます。まず、学習対象のモデルそのものは、クラウド上のサーバーで管理されており、学習処理のタスクがクライアント(モバイルデバイス)に割り当てられます。具体的には、モバイルデバイスが一定の条件(電源に接続されている、WiFiネットワークに接続されている、スリープ状態であるなど)を満たした際に、学習ジョブを受け付けられる状態であるという通知をサーバーに対して行います。サーバー側では、このようなクライアントが一定数存在する状態になると、それらに対して学習処理を依頼します。この際、学習対象のモデルのバイナリーと学習内容を指示したメタデータが送信されます。
続いて、それぞれのクライアントは、ローカルに保存された学習データを用いて学習処理を行い、モデルに含まれるパラメーターを更新します。その後、更新されたパラメーターの差分情報をサーバーに送信します。サーバーは、それぞれのクライアントから得られた更新情報を総合して、サーバー上に保存されたマスターとなるモデルを更新します。最後に、この更新されたモデルは、すべてのGboardのクライアントにアプリケーションのアップデートして配信されます。このように、学習処理はそれぞれのデバイスで行われますが、その結果はサーバーで集約されて、すべてのクライアントに共通のアップデートとして配信されるという仕組みになります。
今回は、2019年に公開された論文2019年に公開された論文「Applied Federated Learning: Improving Google Keyboard Query Suggestions」を元にして、Federated Learningの事例を紹介しました。次回は、学習対象となるモデルの詳細、そして、学習後のモデルの評価結果などを紹介したいと思います。
Disclaimer:この記事は個人的なものです。ここで述べられていることは私の個人的な意見に基づくものであり、私の雇用者には関係はありません。
[IT研修]注目キーワード Python UiPath(RPA) 最新技術動向 Microsoft Azure Docker Kubernetes