GCPにGPUを借りてJAXを動かすまでにやったこと

GCPにGPUを借りてJAXを動かすまでにやったこと

目的

GPUを利用した実験を行いたかったがGPUを持っていなかったためGoogle Cloud Platform上でインスタンスを借りてやってみることにした。 セットアップも初めてだったのでまとめました。

セットアップ手順

今回はus-central1-fn1-standard-81 x NVIDIA V100を借りた。OSとしてはubuntu-1804を指定。 ディスクサイズは10GBがデフォルトだったがインストール時にストレージがなくなることが多々あったので40GBとしている。

基本的なシステム更新

まずは、updateとupgradeを実行:

sudo apt update
sudo apt upgrade -y

Python 3.10のセットアップ

実験にpython3.10が必要なのでpyenvから用意する。

pyenvに必要なpackageをまずインストール:

sudo apt install -y \
libncurses-dev \
build-essential \
libffi-dev \
libssl-dev \
zlib1g-dev \
liblzma-dev \
libbz2-dev \
libreadline-dev \
libsqlite3-dev \
libopencv-dev \
tk-dev \
git

pyenvのインストール:

git clone https://github.com/pyenv/pyenv.git ~/.pyenv
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo 'eval "$(pyenv init --path)"' >> ~/.bashrc
source ~/.bashrc

pyenvを利用してpython3.10のインストール:

pyenv install 3.10.0
pyenv global 3.10.0
pyenv --version

CUDAのインストール

次にnvidiaからcudaを落としてくる。CUDA 12はJAXが現在対応中だったので11.8を使用。ubuntu以外のインストール方法は公式サイトが詳しい。

wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run

注意: GCPではもともとCUDAなどをインストールしてくれているイメージもあるのでそちらを利用した方が簡単かもしれません。

インストールが終了したらパスを通す。nvcc -Vコマンドが通る+バージョンがあっているか確認:

# Linuxアーキテクチャを確認
uname -a

echo 'export PATH="/usr/local/cuda/bin:$PATH"'  >> ~/.bashrc
echo 'export LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"' >> ~/.bashrc
source ~/.bashrc
nvcc -V

cuDNNのインストール

NVIDIA cuDNNも必要なのでダウンロード。こちらはサイトにログインが必要。 GCP上からコマンドで落としてくることができないので、MacにUbuntu用の**Local Installer for Ubuntu18.04 x86_64 (Deb)**を選択してダウンロード。gcloudコマンドを利用してscpで転送:

gcloud compute scp --zone "asia-east1-a" --project "sampleprj" /path/to/cudnn-local-repo-ubuntu1804-8.8.0.121_1.0-1_amd64.deb instance-name:.

転送後の手順:

sudo dpkg -i libcudnn7_7.6.5.32-1+cuda10.1_amd64.deb
sudo dpkg -i libcudnn7-dev_7.6.5.32+cuda10.1_amd64.deb
sudo dpkg -i libcudnn7-dec_7.6.5.32+cuda10.1_amd64deb

GitHub SSH設定

GitHubに登録する用の鍵を作成:

cd ~/.ssh
ssh-keygen -t rsa
# 公開鍵をGitHubに登録

JAXのインストール

ここまで完了したらPythonのpackageのダウンロードを行う。 JAXはGPU(CUDA)用を使用。PoetryでGPU用のJAXを落としてくるのに癖があるので素直にpipで実行:

pip install --upgrade pip
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Poetryでrequirements.txtを出力する方法

poetry export -f requirements.txt -o requirements.txt

動作確認

以上でセットアップは完了。JAXがGPUに接続できているかどうかは以下のコマンドで確認:

import jax
print(jax.default_backend())
print(jax.local_devices())

まとめ

GCP上でのGPU環境構築は初回セットアップが複雑ですが、一度設定すれば強力な計算環境を利用できます。特にJAXを使った機械学習実験において、GPUの恩恵は非常に大きいです。

参考文献

  1. google/jax
  2. ubuntu 20.04 / 18.04 に pyenv をインストールする話
  3. WARNING - No GPU/TPU found, falling back to CPU. #10323
  4. Please provide PEP 503 compliant indices for CUDA versions of packages #5410
  5. Poetryでrequirements.txtを作成
  6. UbuntuでCUDAの削除から再インストールまでのメモ
  7. gcloud compute scp
  8. CUDAとcuDNN(GPU付きUbuntuデスクトップ)
  9. cuDNNがインストールされていることを確認する方法
  10. GitHubでssh接続する手順~公開鍵・秘密鍵の生成から~
  11. Add support for CUDA 12 #13637