IT・技術研修ならCTC教育サービス

サイト内検索 企業情報 サイトマップ

研修コース検索

コラム

グーグルのクラウドを支えるテクノロジー

CTC 教育サービス

 [IT研修]注目キーワード   Python  UiPath(RPA)  最新技術動向  Microsoft Azure  Docker  Kubernetes 

第140回 オリジナル論文から学ぶ「JAX」の特徴とその役割 (中井悦司) 2022年11月

はじめに

 今回は、2018年に公開された論文「Compiling machine learning programs via high-level tracing」を元にして、機械学習ライブラリーJAXの役割と、その基礎となる考え方を紹介します。

JAXのオリジナル論文

 最近、機械学習に関連するオープソースとして、JAXの名前を耳にすることが増えてきました。たとえば、2022年6月に公開されたGoogle Cloudの公式ブログ「EvoJAX: あなたの課題をNeuroevolutionの力で解く」の冒頭には、次のような一節があります。

『JAXはユーザーコードの簡略化や大規模な並列化・何桁もの高速化を可能にする、最近のGoogleで最も重要な機械学習(ML)フレームワークの一つです。このフレームワークは、言語理解におけるPathways Language Model(PaLM) 、物理学・分子動力学シミュレーションにおけるBraxやJAX MDなどをはじめとして、近年に最先端(State-of-the-Art)の成果を示した研究でも利用されています。』

 実際の所、JAXにはどのような特徴があり、何ができるのでしょうか? 冒頭の論文は、JAXのオリジナルの開発者が執筆したもので、JAXの開発に取り組み始めた初期段階でのアイデアやベンチマーク結果が示されています。開発の初期段階に公開されたものですので、現在のJAXとは細かな違いもありますが、JAXの主要なアイデアを理解する参考となるでしょう。この論文では特に、「PSC関数単位でのコンパイル機能を提供する事」がJAXの主要な役割として強調されています。

JITコンパイラとしてのJAX

 TensorFlowをはじめとするディープラーニング向けの機械学習ライブラリでは、構築したモデルをGPUやTPUなどのアクセラレーターを用いて高速に実行する機能を提供しています。たとえば、TensorFlowでは、XLAと呼ばれる独自のコンパイラを用いて、これらのアクセラレーターに最適化されたバイナリーコードを生成することができます。ただし、TensorFlowからXLAを使用する場合は、あくまでもTensorFlowで定義した機械学習モデルが前提となります。任意の数値計算処理をアクセラレーターで実行するといった汎用的な用途は想定されていません。
 一方、既存の機械学習モデルをそのまま利用するのではなく、新しいタイプのモデルを研究・開発する際は、独自に定義した関数や数値計算処理を組み合わせていく必要があります。このような場合、独自の処理をXLAでコンパイルするには、XLAの機能を直接に呼び出すコードを記述する必要があり、一般の開発者には少し敷居が高いものとなっていました。そこで、JAXの開発者は、一般的なPythonのコードで書かれた数値計算処理の関数をそのままの形でコンパイルするためのライブラリーとして、JAXを開発したということです。
 図1は、論文に掲載されている初期のJAXのサンプルコードですが、表面的にはNumPyを用いた普通のPythonのコードのように見えます。JAXを用いるとこのようなコードをXLAでコンパイルして、GPUやTPUなどのアクセラレーターで高速に実行することができるのです。

fig01

図1 初期のJAXのサンプルコード(論文より抜粋)

 ただし、GPUやTPUで実行できるのは、あくまでも数値計算処理ですので、任意のPythonのコードをコンパイルできるわけではありません。画面出力などの副作用を持たず、アクセラレーター上で実行可能な基本的な計算処理を組み合わせた関数が対象となります。論文内では、このような関数を「pure-and-statically-comosed(PSC)関数」と表現していますが、論文によれば、機械学習の計算処理では、PSC関数を組み合わせた処理が大部分を占めているということです。そのため、PSC関数の単位でコンパイルできれば、機械学習モデルを一般的なPythonのコードで柔軟に記述しながら、同時にアクセラレーターによる高速実行が実現できる可能性があります。この実現を目指して、「PSC関数単位でのコンパイル機能」を提供するライブラリーとして、JAXを開発したという事です。
 なお、JAXでPSC関数をコンパイルする際は、コード内でコンパイル対象の関数を指定しておけば、コードの実行時に自動的にコンパイルの処理が行われます。このように、実行時にコンパイル処理を行うコンパイラは、一般に「JIT(Just In Time)コンパイラ」と呼ばれます。

ベンチマーク結果

 冒頭の論文では、当時のJAXによるベンチマーク結果として、図2の2つの結果が紹介されています。図2の上の表には、基本的な学習処理について、コンパイル前のPythonコードによる実行時間と、JAXによるコンパイルを適用した場合の実行時間が示されています。コンパイル後のコードはGPUで実行されるため、当然ながら、かなりの高速化が実現されています。

fig02

図2 JAXのベンチマーク結果(論文より抜粋)

 一方、図2の下の表は、同一のニューラルネットワークをTenserFlowで実装した場合とJAXで実装した場合の比較結果になります。いずれもXLAによるコンパイル処理が行われますが、TensorFlowの場合はモデル全体をまとめてコンパイルするのに対して、JAXの場合は、先ほど説明したように、PSC関数単位でのコンパイルになります。そのため、JAXの方がオーバーヘッドが大きくなる可能性もありますが、結果としては、ほぼ同等の実行速度が得られています。PSC関数単位でコンパイルするというJAXの考え方は、コード記述の柔軟性とコンパイルによる実行速度の向上を確かに両立している事がわかります。

次回予告

 今回は、2018年に公開された論文「Compiling machine learning programs via high-level tracing」を元にして、機械学習ライブラリーJAXの役割と、その基礎となる考え方を紹介しました。本文の説明からわかるように、JAXは、機械学習モデルを構成する個々の関数を高速に実行するためのライブラリーであり、機械学習モデルそのものを記述するフレームワークではありません。最近は、JAXをベースとして、機械学習モデルを記述するための上位のフレームワークが開発されており、これらを組み合わせて利用することができます。
 次回は、フラッシュディスクによるファイルキャッシュシステムを数理最適化のアルゴリズムで最適化するという話題をお届けします。

Disclaimer:この記事は個人的なものです。ここで述べられていることは私の個人的な意見に基づくものであり、私の雇用者には関係はありません。

 


 

 [IT研修]注目キーワード   Python  UiPath(RPA)  最新技術動向  Microsoft Azure  Docker  Kubernetes