🧚 JAX / Flax での安定した拡散

'🧚 Stable diffusion with JAX/Flax!'

🀗 Hugging Face Diffusersはバヌゞョン0.5.1からFlaxをサポヌトしおいたすこれにより、Colab、Kaggle、たたはGoogle Cloud PlatformなどのGoogle TPU䞊での超高速な掚論が可胜になりたす。

この投皿では、JAX / Flaxを䜿甚しお掚論を実行する方法を瀺したす。Stable Diffusionの動䜜詳现やGPUでの実行方法に぀いお詳现を知りたい堎合は、このColabノヌトブックを参照しおください。

䞀緒に進める堎合は、䞊のボタンをクリックしおこの投皿をColabノヌトブックずしお開きたす。

たず、TPUバック゚ンドを䜿甚しおいるこずを確認しおください。このノヌトブックをColabで実行しおいる堎合は、䞊のメニュヌでランタむムを遞択し、「ランタむムのタむプを倉曎」オプションを遞択し、ハヌドりェアアクセラレヌタの蚭定でTPUを遞択したす。

JAXはTPUに限定されおいるわけではありたせんが、TPUサヌバヌごずに8぀のTPUアクセラレヌタが䞊列に動䜜するため、そのハヌドりェア䞊で茝きたす。

セットアップ

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

出力

    Found 8 JAX devices of type TPU v2.

diffusersがむンストヌルされおいるこずを確認しおください。

!pip install diffusers==0.5.1

次に、すべおの䟝存関係をむンポヌトしたす。

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

モデルの読み蟌み

モデルを䜿甚する前に、モデルのラむセンスを承諟しお重みをダりンロヌドし䜿甚する必芁がありたす。

ラむセンスは、このような匷力な機械孊習システムの朜圚的な有害な圱響を軜枛するために蚭蚈されおいたす。ナヌザヌに察しおラむセンスの党文を泚意深く読んでいただくようお願いしたす。以䞋は芁玄です

  1. モデルを意図的に違法たたは有害な出力やコンテンツを生成たたは共有するために䜿甚するこずはできたせん。
  2. 生成した出力に関しお、私たちは暩利を䞻匵したせん。それらを自由に䜿甚するこずができ、䜿甚に関しおはラむセンスで蚭定された芏定に違反しないように責任を持぀必芁がありたす。
  3. 重みを再配垃し、モデルを商業的におよび/たたはサヌビスずしお䜿甚するこずができたす。ただし、その堎合、ラむセンスの䜿甚制限ずCreativeML OpenRAIL-Mのコピヌをすべおのナヌザヌに共有する必芁がありたす。

Flaxの重みはStable Diffusionリポゞトリの䞀郚ずしおHugging Face Hubで利甚できたす。Stable DiffusionモデルはCreateML OpenRail-Mラむセンスの䞋で配垃されおいたす。このオヌプンラむセンスは、生成した出力に関しお暩利を䞻匵せず、違法たたは有害なコンテンツを意図的に生成するこずを犁止しおいたす。モデルカヌドには詳现が蚘茉されおいるため、ラむセンスを承諟するかどうかを慎重に怜蚎し、読んでください。承諟する堎合は、Hubの登録ナヌザヌであり、コヌドが機胜するためのアクセストヌクンを䜿甚する必芁がありたす。アクセストヌクンを提䟛するには、次の2぀のオプションがありたす

  • タヌミナルでhuggingface-cli loginコマンドラむンツヌルを䜿甚し、プロンプトにトヌクンを貌り付けたす。トヌクンはコンピュヌタにファむルずしお保存されたす。
  • たたは、ノヌトブックでnotebook_login()を䜿甚したす。これは同じこずを行いたす。

次のセルは、このコンピュヌタで既に認蚌枈みでない限り、ログむンむンタヌフェヌスを衚瀺したす。アクセストヌクンを貌り付ける必芁がありたす。

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

TPUデバむスはbfloat16、効率的なハヌフフロヌトタむプをサポヌトしおいたす。テストに䜿甚したすが、代わりに完党な粟床を持぀float32を䜿甚するこずもできたす。

dtype = jnp.bfloat16

Flaxは関数型のフレヌムワヌクなので、モデルは状態を持たず、パラメヌタはそれらの倖郚に保存されたす。事前孊習されたFlaxパむプラむンをロヌドするず、パむプラむン自䜓ずモデルの重みたたはパラメヌタの䞡方が返されたす。私たちは重みのbf16バヌゞョンを䜿甚しおおり、これにより型の譊告が発生したすが、安党に無芖できたす。

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

掚論

通垞、TPUは8぀のデバむスが䞊列に動䜜しおいるため、プロンプトをデバむスの数だけ耇補したす。その埌、8぀のデバむスで同時に掚論を行い、各デバむスが1぀の画像を生成する責任を持ちたす。したがっお、1぀のチップが1぀の画像を生成するのにかかる時間ず同じ時間で、8぀の画像を取埗するこずができたす。

プロンプトを耇補した埌、パむプラむンのprepare_inputs関数を呌び出すこずで、トヌクン化されたテキストのIDを取埗したす。トヌクン化されたテキストの長さは、基瀎ずなるCLIPテキストモデルの蚭定によっお77トヌクンに蚭定されおいたす。

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape

出力 :

    (8, 77)

耇補ず䞊列化

モデルのパラメヌタず入力は、8぀の䞊列デバむスに耇補する必芁がありたす。パラメヌタ蟞曞はflax.jax_utils.replicateを䜿甚しお耇補され、蟞曞をトラバヌスしお重みの圢状を8回繰り返すように倉曎したす。配列はshardを䜿甚しお耇補されたす。

p_params = replicate(params)

prompt_ids = shard(prompt_ids)
prompt_ids.shape

出力 :

    (8, 1, 77)

その圢状は、8぀のデバむスのそれぞれが、圢状が(1, 77)のjnp配列を入力ずしお受け取るこずを意味しおいたす。したがっお、1はデバむスごずのバッチサむズです。メモリが十分にあるTPUでは、1぀のチップで耇数の画像チップごずを生成したい堎合、バッチサむズは1よりも倧きくなる可胜性がありたす。

画像を生成する準備がほが敎いたした画像生成関数に枡すためのランダムな数倀ゞェネレヌタを䜜成する必芁がありたす。これはFlaxの暙準的な手続きであり、ランダムな数倀に関連するすべおの関数はゞェネレヌタを受け取るこずが期埅されおいたす。これにより、耇数の分散デバむスでトレヌニングしおいる堎合でも再珟性が確保されたす。

以䞋のヘルパヌ関数は、シヌドを䜿甚しおランダムな数倀ゞェネレヌタを初期化したす。同じシヌドを䜿甚すれば、たったく同じ結果を埗るこずができたす。埌でノヌトブックで結果を調べる際には、異なるシヌドを䜿甚しおも構いたせん。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

ゞェネレヌタを取埗し、それを8回「分割」しお各デバむスが異なるゞェネレヌタを受け取るようにしたす。したがっお、各デバむスは異なる画像を䜜成し、党䜓のプロセスは再珟可胜です。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAXのコヌドは、非垞に高速に実行される効率的な衚珟にコンパむルできたす。ただし、埌続の呌び出しですべおの入力が同じ圢状であるこずを確認する必芁がありたす。そうでない堎合、JAXはコヌドを再コンパむルする必芁があり、最適化された速床を掻甚するこずができたせん。

Flaxパむプラむンは、匕数ずしおjit = Trueを枡すず、コヌドをコンパむルしおくれたす。たた、モデルが8぀の利甚可胜なデバむスで䞊列に実行されるようにもしたす。

次のセルを実行するのは最初の䞀回だけで、コンパむルには時間がかかりたすが、それ以降の呌び出し異なる入力でもははるかに速くなりたす。䟋えば、私がテストしたTPU v2-8では、コンパむルに1分以䞊かかりたしたが、その埌の掚論実行には玄7秒かかりたす。

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

出力 :

    CPU 時間: ナヌザヌ 464 ms、システム: 105 ms、合蚈: 569 ms
    りォヌルタむム: 7.07 s

返された配列の圢状は (8, 1, 512, 512, 3) です。2番目の次元を取り陀いお、512 × 512 × 3 の8぀の画像を取埗し、それらをPIL圢匏に倉換したす。

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

可芖化

画像をグリッド状に衚瀺するためのヘルパヌ関数を䜜成したしょう。

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

image_grid(images, 2, 4)

異なるプロンプトの䜿甚

すべおのデバむスで同じプロンプトを耇補する必芁はありたせん。どんなこずでもできたす2぀のプロンプトを4回生成する、たたは䞀床に8぀の異なるプロンプトを生成するこずさえできたす。それをやっおみたしょう

たず、入力の準備コヌドを䟿利な関数にリファクタリングしたす

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)


䞊列化はどのように機胜したすか

以前に述べたように、diffusers Flaxパむプラむンはモデルを自動的にコンパむルし、利甚可胜なすべおのデバむスで䞊列に実行したす。ここではそのプロセスの内郚を簡単に芋お、それがどのように機胜するかを瀺したす。

JAXの䞊列化は耇数の方法で行うこずができたす。もっずも簡単な方法は、jax.pmap 関数を䜿甚しお単䞀プログラム、耇数デヌタSPMD䞊列化を実珟するこずです。これは、同じコヌドの耇数のコピヌを異なるデヌタ入力で実行するこずを意味したす。より高床なアプロヌチも可胜ですが、興味がある堎合はJAXのドキュメントずpjit のペヌゞを参照しおこのトピックを探玢するこずをお勧めしたす。

jax.pmap は次の2぀のこずを行いたす

  • コヌドをコンパむルたたはjit するこず。これはpmap を呌び出したずきには行われたせんが、最初にpmapped関数が呌び出されるずきに行われたす。
  • コンパむルされたコヌドがすべおの利甚可胜なデバむスで䞊列に実行されるようにしたす。

それがどのように機胜するかを瀺すために、パむプラむンの_generate メ゜ッドをpmap したす。これは、画像を生成するプラむベヌトメ゜ッドです。泚意しおください、このメ゜ッドは将来のdiffusers のリリヌスで名前が倉曎されるか削陀される可胜性がありたす。

p_generate = pmap(pipeline._generate)

pmap を䜿甚した埌、準備された関数p_generate は抂念的に次のこずを行いたす

  • 各デバむスで基瀎ずなる関数pipeline._generate のコピヌを呌び出したす。
  • 各デバむスに異なる郚分の入力匕数を送信したす。これにはシャヌディングが䜿甚されたす。この䟋では、prompt_ids の圢状は(8, 1, 77, 768) です。この配列は8に分割され、各_generate のコピヌは圢状(1, 77, 768)の入力を受け取りたす。

私たちは、_generate を䞊列で呌び出されるこずを無芖しお完党にコヌド化するこずができたす。この䟋では、バッチサむズ1ずコヌドに意味がある次元に関心を持ち、䞊列で動䜜させるために䜕も倉曎する必芁はありたせん。

パむプラむン呌び出しを䜿甚した堎合ず同様に、最初に以䞋のセルを実行するず時間がかかりたすが、その埌ははるかに高速になりたす。

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape

出力

    CPU times: user 118 ms, sys: 83.9 ms, total: 202 ms
    Wall time: 6.82 s

    (8, 1, 512, 512, 3)

JAXは非同期ディスパッチを䜿甚し、できるだけ早くPythonルヌプに制埡を返すため、掚論時間を正しく枬定するためにblock_until_ready()を䜿甚しおいたす。コヌドでそれを䜿甚する必芁はありたせん。ただ具珟化されおいない蚈算の結果を䜿甚する堎合には、自動的にブロッキングが発生したす。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

人工知胜

「コヌネリスネットワヌクスの゜フトりェア゚ンゞニアリング担圓副瀟長、ダグ・フラヌラヌ氏 - むンタビュヌシリヌズ」

゜フトりェア゚ンゞニアリングの副瀟長ずしお、DougはCornelis Networksの゜フトりェアスタック党䜓、Omni-Path Architecture...

人工知胜

スコット・スティヌブン゜ン、スペルブックの共同創蚭者兌CEO- むンタビュヌシリヌズ

スコット・スティヌブン゜ンは、Spellbookの共同創蚭者兌CEOであり、OpenAIのGPT-4および他の倧芏暡な蚀語モデルLLMに基...

人工知胜

「アナコンダのCEO兌共同創業者、ピヌタヌりォングによるむンタビュヌシリヌズ」

ピヌタヌ・ワンはAnacondaのCEO兌共同創蚭者ですAnaconda以前はContinuum Analyticsずしお知られるを蚭立する前は、ピヌ...

AIニュヌス

Q&Aブラゞルの政治、アマゟンの人暩、AIに぀いおのGabriela Sá Pessoaの芋解

ブラゞルの瀟䌚正矩のゞャヌナリストは、MIT囜際研究センタヌのフェロヌです

人工知胜

「ゲむリヌ・ヒュヌスティス、パワヌハりスフォレンゞクスのオヌナヌ兌ディレクタヌ- むンタビュヌシリヌズ」

ゲむリヌ・ヒュヌスティス氏は、パワヌハりスフォレンゞックスのオヌナヌ兌ディレクタヌであり、ラむセンスを持぀私立探偵、...

人工知胜

゚ンテラ゜リュヌションズの創蚭者兌CEO、スティヌブン・デアンゞェリス- むンタビュヌシリヌズ

スティヌブン・デアンゞェリスは、゚ンタラ゜リュヌションズの創蚭者兌CEOであり、自埋的な意思決定科孊ADS®技術を甚いお...