注釈付き拡散モデル

'annotated diffusion model' in English.

このブログ記事では、Denoising Diffusion Probabilistic Models(DDPM、拡散モデル、スコアベースの生成モデル、または単にオートエンコーダーとも呼ばれる)について詳しく見ていきます。これらのモデルは、(非)条件付きの画像/音声/ビデオの生成において、驚くべき結果が得られています。具体的な例としては、OpenAIのGLIDEやDALL-E 2、University of HeidelbergのLatent Diffusion、Google BrainのImageGenなどがあります。

この記事では、(Hoら、2020)による元のDDPMの論文を取り上げ、Phil Wangの実装をベースにPyTorchでステップバイステップで実装します。なお、このアイデアは実際には(Sohl-Dicksteinら、2015)で既に導入されていました。ただし、改善が行われるまでには(Stanford大学のSongら、2019)を経て、Google BrainのHoら、2020)が独自にアプローチを改良しました。

拡散モデルにはいくつかの視点がありますので、ここでは離散時間(潜在変数モデル)の視点を採用していますが、他の視点もチェックしてください。

さあ、始めましょう!

from IPython.display import Image
Image(filename='assets/78_annotated-diffusion/ddpm_paper.png')

まず必要なライブラリをインストールしてインポートします(PyTorchがインストールされていることを前提としています)。

!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

拡散モデルとは何ですか?

(ノイズ除去)拡散モデルは、Normalizing Flows、GAN、VAEなどの他の生成モデルと比較してそれほど複雑ではありません。これらのモデルはすべて、ノイズをいくつかの単純な分布からデータサンプルに変換します。ここでも、ネットワークが純粋なノイズから徐々にデータを除去することを学習します。

画像の場合、セットアップは2つのプロセスで構成されています:

  • 選択した固定(または事前定義された)前方拡散プロセス q q q が、イメージに徐々にガウスノイズを追加し、純粋なノイズになるまで進行します
  • 学習された逆ノイズ除去拡散プロセス p θ p_\theta p θ ​ があります。ここでは、ニューラルネットワークが純粋なノイズから始まり、実際のイメージになるまで徐々に画像をノイズ除去するようにトレーニングされます。

t t t でインデックス付けられた前方および逆プロセスは、いくつかの有限時間ステップ T T T(DDPMの著者は T = 1000 T=1000 T = 1 0 0 0 としています)で行われます。 t = 0 t=0 t = 0 から始め、実データ分布から実際のイメージ x 0 \mathbf{x}_0 x 0 ​ をサンプリングします(たとえば、ImageNetの猫の画像とします)。前方プロセスは各時間ステップ t t t でガウス分布からノイズをサンプリングし、前の時間ステップの画像に追加します。十分に大きな T T T と各時間ステップでノイズを追加するための適切なスケジュールが与えられると、徐々に等方性ガウス分布が t = T t=T t = T で得られます。

より数学的な形で

これをより形式的に書いてみましょう。最終的には、ニューラルネットワークが最適化する必要のある扱いやすい損失関数が必要です。

実データ分布 q ( x 0 ) q(\mathbf{x}_0) q ( x 0 ​ ) からの実際のデータサンプルを x 0 ∼ q ( x 0 ) \mathbf{x}_0 \sim q(\mathbf{x}_0) x 0 ​ ∼ q ( x 0 ​ ) とします。各時間ステップ t t t で既知の分散スケジュール 0 < β 1 < β 2 < . . . < β T < 1 0 < \beta_1 < \beta_2 < … < \beta_T < 1 0 < β 1 ​ < β 2 ​ < . . . < β T ​ < 1 に従って、前方拡散プロセス q ( x t ∣ x t − 1 ) q(\mathbf{x}_t | \mathbf{x}_{t-1}) q ( x t ​ ∣ x t − 1 ​ ) を定義します。これにより、各時間ステップ t t t でガウスノイズが追加されます。拡散モデルの著者は、q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) とします。

正規分布(またはガウス分布とも呼ばれる)は、2つのパラメータによって定義されます:平均 μ \mu μ と分散 σ 2 ≥ 0 \sigma^2 \geq 0 σ 2 ≥ 0 。基本的に、各新しい(わずかにノイズのある)時刻 t t t の画像は、μ t = 1 − β t x t − 1 \mathbf{\mu}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} μ t ​ = 1 − β t ​ ​ x t − 1 ​ と σ t 2 = β t \sigma^2_t = \beta_t σ t 2 ​ = β t ​ の条件付きガウス分布から描かれます。これは、ϵ ∼ N ( 0 , I ) \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) ϵ ∼ N ( 0 , I ) をサンプリングして x t = 1 − β t x t − 1 + β t ϵ \mathbf{x}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon} x t ​ = 1 − β t ​ ​ x t − 1 ​ + β t ​ ​ ϵ を設定することで行うことができます。

注意点として、β t \beta_t β t ​ は各時刻 t t t で一定ではない(したがって下付き文字がある)ことに留意してください。実際、これは「分散スケジュール」と呼ばれるものを定義します。そのスケジュールは線形、二次、余弦などであることがあります(学習率スケジュールのようなものです)。

したがって、x 0 \mathbf{x}_0 x 0 ​ から始めると、x 1 , . . . , x t , . . . , x T \mathbf{x}_1, …, \mathbf{x}_t, …, \mathbf{x}_T x 1 ​ , . . . , x t ​ , . . . , x T ​ が得られます。ただし、適切なスケジュールを設定すれば、x T \mathbf{x}_T x T ​ は純粋なガウスノイズです。

次に、p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ​ ∣ x t ​ ) の条件付き分布がわかっていれば、逆のプロセスを実行できます。つまり、ランダムなガウスノイズ x T \mathbf{x}_T x T ​ をサンプリングし、それを徐々に「ノイズ除去」して実際の分布 x 0 \mathbf{x}_0 x 0 ​ のサンプルを得ることができます。

ただし、p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1} | \mathbf{x}_t) p ( x t − 1 ​ ∣ x t ​ ) はわかりません。この条件付き確率を計算するには、すべての可能な画像の分布を知る必要があるため、扱いにくいです。したがって、ニューラルネットワークを利用してこの条件付き確率分布を「近似(学習)すること」を活用することにします。これを p θ ( x t − 1 ∣ x t ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) p θ ​ ( x t − 1 ​ ∣ x t ​ ) と呼び、θ \theta θ は勾配降下法によって更新されるニューラルネットワークのパラメータです。

では、逆プロセスの(条件付き)確率分布を表現するためにニューラルネットワークが必要です。この逆プロセスもまたガウス分布であると仮定すると、ガウス分布は次の2つのパラメータによって定義されます:

  • μ θ \mu_\theta μ θ ​ によってパラメータ化された平均;
  • Σ θ \Sigma_\theta Σ θ ​ によってパラメータ化された分散;

そのため、プロセスを p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t)) p θ ​ ( x t − 1 ​ ∣ x t ​ ) = N ( x t − 1 ​ ; μ θ ​ ( x t ​ , t ) , Σ θ ​ ( x t ​ , t ) ) とパラメータ化することができます。ここで、平均値と分散はノイズレベル t t t にも依存しています。

したがって、私たちのニューラルネットワークは、平均と分散を学習/表現する必要があります。しかし、DDPMの著者たちは、分散を固定し、ニューラルネットワークがこの条件付き確率分布の平均 μ θ \mu_\theta μ θ ​ のみを学習(表現)するようにすることを決定しました。論文から:

まず、未学習の時間依存定数 Σ θ ( x t , t ) = σ t 2 I \Sigma_\theta ( \mathbf{x}_t, t) = \sigma^2_t \mathbf{I} Σ θ ​ ( x t ​ , t ) = σ t 2 ​ I を設定します。実験的には、σ t 2 = β t \sigma^2_t = \beta_t σ t 2 ​ = β t ​ および σ t 2 = β ~ t \sigma^2_t = \tilde{\beta}_t σ t 2 ​ = β ~ ​ t ​(論文参照)の両方が似た結果を示しました。

これは後に改良され、改良拡散モデルの論文で、ニューラルネットワークは平均に加えて、この逆過程の分散も学習します。

したがって、私たちは、ニューラルネットワークがこの条件付き確率分布の平均のみを学習/表現する必要があると仮定して続行します。

平均を再パラメータ化して目的関数を定義する

逆過程の平均を学習するための目的関数を導くために、著者たちは、q q q と p θ p_\theta p θ ​ の組み合わせが変分オートエンコーダ(VAE)(Kingma et al.、2013)と見なすことができると観察しました。したがって、変分下限(またはELBOとも呼ばれる)は、真のデータサンプル x 0 \mathbf{x}_0 x 0 ​ に関して負の対数尤度を最小化するために使用できます(ELBOの詳細については、VAEの論文を参照してください)。このプロセスのELBOは、各時間ステップ t t t での損失の合計であり、L = L 0 + L 1 + . . . + L T L = L_0 + L_1 + … + L_T L = L 0 ​ + L 1 ​ + . . . + L T ​ です。前向きの q q q プロセスと逆プロセスの構築により、損失の各項(L 0 L_0 L 0 ​ を除く)は、実際には2つのガウス分布間のKLダイバージェンスであり、それは平均に関するL2損失として明示的に書くことができます!

Sohl-Dickstein et al.によって示されるように、構築された前向きプロセス q q q の直接の結果は、x t \mathbf{x}_t x t ​ を x 0 \mathbf{x}_0 x 0 ​ に条件付けた任意のノイズレベルでサンプリングできることです(ガウスの和もガウスですから)。これは非常に便利です:x t \mathbf{x}_t x t ​ をサンプリングするために繰り返し q q q を適用する必要はありません。q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(\mathbf{x}_t | \mathbf{x}_0) = \cal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1- \bar{\alpha}_t) \mathbf{I}) q ( x t ​ ∣ x 0 ​ ) = N ( x t ​ ; α ˉ t ​ ​ x 0 ​ , ( 1 − α ˉ t ​ ) I )

ここで、α t : = 1 − β t \alpha_t := 1 – \beta_t α t ​ : = 1 − β t ​ および α ˉ t : = Π s = 1 t α s \bar{\alpha}_t := \Pi_{s=1}^{t} \alpha_s α ˉ t ​ : = Π s = 1 t ​ α s ​ とします。この式を「いい性質」と呼びましょう。これは、ガウスノイズをサンプリングし、適切にスケーリングして x 0 \mathbf{x}_0 x 0 ​ に追加することで、x t \mathbf{x}_t x t ​ を直接得ることができます。α ˉ t \bar{\alpha}_t α ˉ t ​ は既知の β t \beta_t β t ​ 分散スケジュールの関数であり、したがって既知で事前計算可能です。したがって、訓練中に損失関数 L L L のランダムな項を最適化することができます(つまり、訓練中にランダムに t t t をサンプリングして L t L_t L t ​ を最適化することができます)。

この性質のもう一つの美しさは、Ho et al.に示されているように、(数学的な計算が必要ですが、その詳細はこの優れたブログ投稿を参照してください)平均を再パラメータ化することで、ニューラルネットワークが追加されたノイズ(ネットワークϵθ(xt, t)を介して予測されるノイズ)を学習することができることです。KL項でのノイズレベルtに関する損失を構成するために、平均を再パラメータ化することができます。つまり、私たちのニューラルネットワークは平均予測ではなく、ノイズ予測になります。平均は以下のように計算できます:

μθ(xt, t) = 1/αt(xt – βt/(1-ᾱt)ϵθ(xt, t))

最終的な目的関数Ltは次のようになります(ランダムな時間ステップtにおけるϵ∼N(0, I)):

∥ϵ – ϵθ(xt, t)∥2 = ∥ϵ – ϵθ(ᾱt x0 + (1-ᾱt)ϵ, t)∥2

ここで、x0は初期の(実際の、破損していない)画像であり、tによって与えられる直接のノイズレベルは固定された前方プロセスによって与えられます。ϵは時間ステップtでサンプリングされる純粋なノイズであり、ϵθ(xt, t)は私たちのニューラルネットワークです。ニューラルネットワークは、真のガウスノイズと予測されたノイズとの間の単純な平均二乗誤差(MSE)を使用して最適化されます。

トレーニングアルゴリズムは次のようになります:

要するに:

  • 実際の未知の複雑なデータ分布q(x0)からランダムなサンプルx0を取得する
  • 1からTまでの間でノイズレベルtを一様にサンプリングする(つまり、ランダムな時間ステップ)
  • ガウス分布からノイズをサンプリングし、既知のスケジュールβtに基づいてレベルtで入力を破損させる(前述の性質を使用)
  • ニューラルネットワークは、破損した画像xtに基づいてこのノイズを予測するようにトレーニングされる(つまり、既知のスケジュールβtに基づいてx0に適用されるノイズ)

実際には、これらのすべてはデータのバッチ上で行われます。ニューラルネットワークを最適化するために確率的勾配降下法が使用されるためです。

ニューラルネットワーク

ニューラルネットワークは、特定の時間ステップでノイズの混入した画像を受け取り、予測されたノイズを返す必要があります。予測されたノイズは、入力画像と同じサイズ/解像度のテンソルです。したがって、ネットワークはテンソルの形状が同じである入力と出力のテンソルを受け取ります。このためにどのようなタイプのニューラルネットワークを使用できるのでしょうか?

ここで通常使用されるのは、一般的な「深層学習入門」チュートリアルで紹介されるオートエンコーダに非常に似ています。オートエンコーダには、エンコーダとデコーダの間にいわゆる「ボトルネック」層があります。エンコーダはまず画像をより小さい隠れた表現である「ボトルネック」に符号化し、デコーダはその隠れた表現を元の画像に復元します。これにより、ネットワークはボトルネック層で最も重要な情報のみを保持するようになります。

アーキテクチャの観点では、DDPMの著者は( Ronneberger et al., 2015 )によって紹介されたU-Netを採用しました(当時、医療画像セグメンテーションの分野で最先端の結果を達成しました)。このネットワークは、他のオートエンコーダと同様に、ネットワークが最も重要な情報のみを学習するようにするボトルネックを持っています。重要なのは、エンコーダとデコーダの間に残差接続を導入し、勾配フローを大幅に改善したことです(He et al., 2015のResNetに触発されました)。

U-Netモデルは、まず入力をダウンサンプリング(つまり、空間解像度に基づいて入力を小さくする)し、その後にアップサンプリングを行います。

以下では、このネットワークをステップバイステップで実装します。

ネットワークの補助機能

まず、ニューラルネットワークの実装時に使用されるいくつかのヘルパー関数とクラスを定義します。重要なのは、特定の関数の入力と出力に入力を単純に追加する「Residual」モジュールを定義することです(つまり、特定の関数に対して残差接続を追加します)。

また、アップサンプリングおよびダウンサンプリング操作のエイリアスも定義します。

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    # ストライド畳み込みやプーリングは行わない
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )

位置エンベディング

ニューラルネットワークのパラメータは時間(ノイズレベル)を共有しているため、著者はトランスフォーマー( Vaswani et al., 2017 )に触発されて、正弦波位置エンベディングを使用して t t t をエンコードします。これにより、ニューラルネットワークはバッチ内のすべての画像について、どの特定の時間ステップ(ノイズレベル)で動作しているかを「知る」ことができます。

SinusoidalPositionEmbeddingsモジュールは、形状が(batch_size, 1)のテンソル(つまり、バッチ内の複数のノイズ画像のノイズレベル)を入力とし、これを形状が(batch_size, dim)であるテンソルに変換します。ここで、dimは位置エンベディングの次元です。これは、後ほど見るように、各残差ブロックに追加されます。

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

ResNetブロック

次に、U-Netモデルの中核となるビルディングブロックを定義します。DDPMの著者はWide ResNetブロック(Zagoruyko et al., 2016)を使用しましたが、Phil Wangは標準の畳み込み層を「重み標準化」バージョンに置き換えました。これはグループ正規化との組み合わせでより良い結果を得るためです(詳細については、Kolesnikov et al., 2019を参照)。

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    重み標準化は、グループ正規化と相乗効果があるとされています
    """

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)

アテンションモジュール

次に、DDPMの著者が畳み込みブロックの間に追加したアテンションモジュールを定義します。アテンションは有名なTransformerアーキテクチャ(Vaswani et al., 2017)のビルディングブロックであり、NLPやビジョンからタンパク質の折りたたみまで、さまざまなAIの領域で大きな成功を収めています。Phil Wangは2つのバリアントのアテンションを使用しています。1つは通常のマルチヘッドセルフアテンション(Transformerで使用されるもの)、もう1つは線形アテンションバリアント(Shen et al., 2018)です。後者は、時間とメモリの要件がシーケンスの長さに対して二次的ではなく線形にスケーリングされます。

アテンションメカニズムの詳しい説明については、Jay Allamarの素晴らしいブログ記事を参照してください。

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

グループ正規化

DDPMの著者はU-Netの畳み込み/アテンションレイヤーにグループ正規化を交互に適用しています(Wu et al., 2018)。以下では、さらに見ていくように、アテンションレイヤーの前にグループ正規化を適用するために使用されるPreNormクラスを定義します。ただし、Transformersのアテンションの前後に正規化を適用するかどうかについては議論がありました。

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

条件付きU-Net

位置埋め込み、ResNetブロック、アテンション、およびグループ正規化のすべての構築ブロックを定義したので、全体のニューラルネットワークを定義する時が来ました。ネットワーク ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵ θ ​ ( x t ​ , t ) の役割は、一括のノイズ画像とそれに対応するノイズレベルを入力として受け取り、入力に追加されたノイズを出力することです。より正確には、以下のようになります:

  • ネットワークは、形状が(バッチサイズ、チャネル数、高さ、幅)の一括のノイズ画像と形状が(バッチサイズ、1)のノイズレベルの一括を入力とし、形状が(バッチサイズ、チャネル数、高さ、幅)のテンソルを返します

ネットワークは次のように構築されます:

  • まず、一括のノイズ画像に畳み込み層を適用し、ノイズレベルのために位置埋め込みを計算します
  • 次に、ダウンサンプリングステージのシーケンスが適用されます。各ダウンサンプリングステージは、2つのResNetブロック + グループ正規化 + アテンション + 残差接続 + ダウンサンプル操作で構成されます
  • ネットワークの中央部では、再びResNetブロックが適用され、アテンションと交互になります
  • 次に、アップサンプリングステージのシーケンスが適用されます。各アップサンプリングステージは、2つのResNetブロック + グループ正規化 + アテンション + 残差接続 + アップサンプル操作で構成されます
  • 最後に、ResNetブロックの後に畳み込み層が適用されます。

最終的に、ニューラルネットワークはレゴブロックのようにレイヤーを積み重ねます(ただし、それらがどのように機能するかを理解することが重要です)。

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # 寸法を決定する
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # 7,3から1,0に変更

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # 時間の埋め込み
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # レイヤー
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

前方拡散プロセスの定義

前方拡散プロセスは、実際の分布から画像にノイズを徐々に追加します。ノイズは、時間ステップ T T T の数だけ追加されます。これは「分散スケジュール」に従って行われます。元の DDPM の著者たちは、線形スケジュールを使用しました:

前方プロセスの分散を β 1 = 10^{−4} から β T = 0.02 まで線形に増加させます。

しかし、(Nichol et al., 2021) で示されたように、コサインスケジュールを使用するとより良い結果が得られることがわかりました。

以下では、T T T のタイムステップのためのさまざまなスケジュールを定義します(後で 1 つを選びます)。

def cosine_beta_schedule(timesteps, s=0.008):
    """
    https://arxiv.org/abs/2102.09672 で提案されたコサインスケジュール
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

まず、T = 300 の線形スケジュールを使用して、T T T の時間ステップにおける β t \beta_t β t ​ から必要なさまざまな変数を定義します。これらの変数は、単に t t t から T T T までの値を格納する 1 次元テンソルであり、次元変換を行うための extract 関数も定義します。

timesteps = 300

# ベータスケジュールの定義
betas = linear_beta_schedule(timesteps=timesteps)

# アルファスケジュールの定義
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# 拡散 q(x_t | x_{t-1}) およびその他の計算
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# 後部 q(x_{t-1} | x_t, x_0) の計算
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

次に、拡散プロセスの各時間ステップでノイズがどのように追加されるかを、猫の画像を使って説明します。

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # PIL 画像(形状 HWC)
image

ノイズはPillowの画像ではなく、PyTorchのテンソルに追加されます。まず、PIL画像からPyTorchテンソルに変換するための画像変換を定義します(ノイズを追加できるようにします)、逆も同様です。

これらの変換は非常にシンプルです。まず、画像を255で割って正規化します([0, 1]の範囲になるように)その後、[-1, 1]の範囲になるようにします。DPPM論文から:

画像データは、{0, 1, …, 255}の整数で構成されていると仮定し、これを[−1, 1]に線形スケールします。これにより、ニューラルネットワークの逆プロセスは、標準正規事前分布p(xT)からの一貫したスケーリング入力で作業することが保証されます。

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(), # torch Tensorの形状をCHWに変換し、255で割る
    Lambda(lambda t: (t * 2) - 1),
    
])

x_start = transform(image).unsqueeze(0)
x_start.shape

また、逆変換も定義します。これは[-1, 1]の値を含むPyTorchテンソルを受け取り、それらをPIL画像に戻します:

import numpy as np

reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHWからHWCへ
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])

これを検証しましょう:

reverse_transform(x_start.squeeze())

これで、論文のように順方向拡散プロセスを定義できます:

# 順方向拡散(良い特性を使用)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

特定のタイムステップでテストしてみましょう:

def get_noisy_image(x_start, t):
  # ノイズを追加
  x_noisy = q_sample(x_start, t=t)

  # PIL画像に戻す
  noisy_image = reverse_transform(x_noisy.squeeze())

  return noisy_image

# タイムステップを取得
t = torch.tensor([40])

get_noisy_image(x_start, t)

さまざまなタイムステップでこれを可視化しましょう:

import matplotlib.pyplot as plt

# 再現性のためにシードを使用
torch.manual_seed(0)

# ソース:https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # 行が1つしかない場合でも、2Dグリッドを作成します
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='元の画像')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

これにより、モデルを使用して損失関数を次のように定義できます:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

denoise_model は上記で定義した U-Net です。真のノイズと予測されたノイズの間の Huber 損失を使用します。

PyTorch データセット + DataLoader の定義

ここでは通常の PyTorch データセットを定義します。データセットは、Fashion-MNIST、CIFAR-10、または ImageNet のような実データセットの画像で構成され、[ − 1 , 1 ] の範囲に線形スケーリングされます。

各画像は同じサイズにリサイズされます。興味深いことに、画像はランダムに水平方向に反転されます。論文から引用:

CIFAR10 のトレーニング中にランダムな水平方向の反転を使用しました。反転ありとなしの両方でトレーニングを試しましたが、反転はサンプルの品質をわずかに改善することがわかりました。

ここでは、🤗 Datasets ライブラリを使用してハブから Fashion MNIST データセットを簡単にロードします。このデータセットは、すでに解像度が同じである 28×28 の画像で構成されています。

from datasets import load_dataset

# ハブからデータセットをロード
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128

次に、データセット全体に対してオンザフライで適用する関数を定義します。そのために、with_transform 機能を使用します。この関数は、基本的な画像前処理を適用します:ランダムな水平反転、スケーリング、そして最後に値を [ − 1 , 1 ] の範囲に設定します。

from torchvision import transforms
from torch.utils.data import DataLoader

# 画像の変換を定義する(たとえば torchvision を使用して)
transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# 関数を定義する
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# DataLoader を作成する
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)

batch = next(iter(dataloader))
print(batch.keys())

サンプリング

トレーニング中にモデルからサンプリングするために(進行状況を追跡するために)、以下のコードを定義します。サンプリングは、論文のアルゴリズム 2 で要約されています:

拡散モデルから新しい画像を生成するには、拡散プロセスを逆にします:ガウス分布から純粋なノイズをサンプリングし、ニューラルネットワークを使用して徐々にノイズを除去します(学習した条件付き確率を使用)。最終的には、時間ステップ t = 0 に到達します。上記のように、ノイズ予測器を使用して平均の再パラメータ化をプラグインすることで、わずかに除去された画像 x t − 1 を導くことができます。分散は事前に既知です。

理想的には、実データ分布から生成されたような画像が得られます。

以下のコードはこれを実装しています。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # 論文の式 11
    # モデル(ノイズ予測器)を使用して平均を予測する
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # アルゴリズム 2 の 4 行目:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# アルゴリズム 2(すべての画像を返す)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # 純粋なノイズから開始(バッチ内の各例ごとに)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

上記のコードは、元の実装の簡略化バージョンであることに注意してください。私たちは、クリッピングを使用する元のより複雑な実装と同様に、私たちの簡略化バージョンが同じくらいうまく機能することを見つけました(これは論文のアルゴリズム2と一致しています)。

モデルのトレーニング

次に、通常のPyTorchの方法でモデルをトレーニングします。また、上記で定義したsampleメソッドを使用して、定期的に生成された画像を保存するためのいくつかのロジックも定義します。

from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

以下では、モデルを定義し、GPUに移動します。また、標準のオプティマイザ(Adam)も定義します。

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

トレーニングを開始しましょう!

from torchvision.utils import save_image

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: バッチ内の各例について一様にtをサンプリングする
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # 生成された画像を保存する
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)

サンプリング(推論)

モデルからサンプリングするには、上記で定義したsample関数を使用するだけです。

# 64枚の画像をサンプリングする
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# ランダムな画像を表示する
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

モデルは素敵なTシャツを生成することができるようです!トレーニングに使用したデータセットはかなり低解像度(28×28)であることに注意してください。

また、ノイズ除去プロセスのGIFも作成できます。

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

DDPMの論文では、拡散モデルは(非)条件付き画像生成に向けた有望な方向性であることが示されています。その後、特にテキスト条件付き画像生成において大幅に改良されています。以下に、重要な(しかし完全ではない)後続の作品のいくつかをリストアップしています:

  • Improved Denoising Diffusion Probabilistic Models(Nichol et al., 2021):条件付き分布の分散(平均だけでなく)を学習することがパフォーマンスの向上に役立つことを発見
  • Cascaded Diffusion Models for High Fidelity Image Generation(Ho et al., 2021):高品質な画像合成のために解像度が増加する複数の拡散モデルのパイプラインを導入
  • Diffusion Models Beat GANs on Image Synthesis(Dhariwal et al., 2021):U-Netアーキテクチャの改善とクラシファイアのガイダンスの導入により、拡散モデルは現在の最先端の生成モデルよりも優れた画像サンプルの品質を達成できることを示す
  • Classifier-Free Diffusion Guidance(Ho et al., 2021):条件付きおよび無条件の拡散モデルを単一のニューラルネットワークで共同トレーニングすることにより、拡散モデルをガイドするためのクラシファイアは不要であることを示す
  • Hierarchical Text-Conditional Image Generation with CLIP Latents(DALL-E 2)(Ramesh et al., 2022):テキストキャプションをCLIP画像埋め込みに変換し、その後、拡散モデルで画像をデコードするための事前知識を使用
  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding(ImageGen)(Saharia et al., 2022):大規模な事前訓練言語モデル(例:T5)との組み合わせによるcascaded diffusionは、テキストから画像への合成に適していることを示す

書き込み時点での重要な作品のみを含んでいることに注意してください。これは2022年6月7日です。

現時点では、拡散モデルの主な(おそらく唯一の)欠点は、画像を生成するために複数のフォワードパスが必要であることです(GANなどの生成モデルではそうではありません)。ただし、10回のノイズ除去ステップで高品質の生成が可能になる研究が進行中です。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

機械学習

もし芸術が私たちの人間性を表現する方法であるなら、人工知能はどこに適合するのでしょうか?

MITのポストドクターであるジヴ・エプスタイン氏(SM '19、PhD '23)は、芸術やその他のメディアを作成するために生成的AIを...

データサイエンス

「2023年にデータサイエンスFAANGの仕事をゲットする方法は?」

データサイエンスは非常に求められる分野となり、FAANG(Facebook、Amazon、Apple、Netflix、Google)企業での就職は大きな成...

人工知能

「クリス・サレンス氏、CentralReachのCEO - インタビューシリーズ」

クリス・サレンズはCentralReachの最高経営責任者であり、同社を率いて、自閉症や関連する障害を持つ人々のために優れたクラ...

人工知能

ベイリー・カクスマー、ウォータールー大学の博士課程候補 - インタビューシリーズ

カツマー・ベイリーは、ウォータールー大学のコンピュータ科学学部の博士課程の候補者であり、アルバータ大学の新入教員です...

人工知能

「LeanTaaSの創設者兼CEO、モハン・ギリダラダスによるインタビューシリーズ」

モーハン・ギリダラダスは、AIを活用したSaaSベースのキャパシティ管理、スタッフ配置、患者フローのソフトウェアを提供する...

人工知能

「15Rockの共同創業者兼CEO、ガウタム・バクシ氏によるインタビューシリーズ」

「ガウタム・バクシは、気候リスク管理とアドバイザリーサービスのグローバルリーダーである15Rockの共同創設者兼CEOですガウ...