「PyTorchでのSoft Nearest Neighbor Lossの実装方法」

「PyTorchでのSoft Nearest Neighbor Lossの実装方法を解説」

データセットのクラス近傍は、ソフト最近傍損失を使用して学習することができます

この記事では、ソフト最近傍損失の実装方法について説明します。こちらでも話しています。

表現学習は、ディープニューラルネットワークによって与えられたデータセット内の最も顕著な特徴を学習するタスクです。通常は教師あり学習の枠組みで行われる暗黙のタスクであり、深層学習の成功において重要な要素です(Krizhevsky et al., 2012; He et al., 2016; Simonyan et al., 2014)。言い換えれば、表現学習は特徴抽出のプロセスを自動化します。これにより、学習された表現を分類、回帰、合成などの下流タスクで使用することができます。

図1. SNNL(Frosst et al.、2019)からのイラスト。ソフト最近傍損失を最小化することで、クラス類似のデータ点間の距離(色で示されています)が最小化され、クラスの異なるデータ点間の距離が最大化されます。

また、学習された表現が特定のユースケースに適合するように形成することもできます。分類の場合、同じクラスのデータポイントは集まりやすくなります。一方、生成(例:GAN)の場合、学習された表現は実データと合成データのポイントが集まるようになります。

同様に、Principal Components Analysis(PCA)を使用して特徴をエンコードすることもあります。ただし、PCAによるエンコードされた表現にはクラスまたはラベル情報がありません。そのため、下流タスクのパフォーマンスをさらに向上させることができます。データセットの近隣構造、つまりどの特徴がクラスター化されているか、そしてそのクラスターがセミスーパーバイズド学習文献のクラスタリング仮説に従って同じクラスに属する特徴であることを学習することで、エンコードされた表現を改善することができます(Chapelle et al., 2009)。

表現に近隣構造を統合するために、局所線形埋め込み(LLE)、またはLLE(Roweis & Saul, 2000)や近隣成分分析(NCA)(Hinton et al., 2004)やt分散近傍埋め込み(t-SNE)(Maaten & Hinton, 2008)などの手法が導入されています。

しかし、前述のような埋め込みの手法には欠点があります。例えば、LLEとNCAは非線形な埋め込みではなく、線形な埋め込みをエンコードします。一方、t-SNEの埋め込みは、使用されるハイパーパラメータに依存して異なる構造を生成します。

このような欠点を回避するために、改良されたNCAアルゴリズムであるソフト最近傍損失(SNNL)(Salakhutdinov & Hinton, 2007; Frosst et al., 2019)を使用することができます。SNNLは非線形性を導入し、ニューラルネットワークの各隠れ層で計算されます。この損失関数はデータセット内のポイントの絡み合いを最適化するために使用されます。

この文脈では、絡み合い (entanglement) とは、類似するクラスのデータ点同士の距離が、異なるクラスのデータ点と比べてどれだけ近いかを表します。低い絡み合いは、類似するクラスのデータ点同士が異なるクラスのデータ点よりもずっと近いことを意味します(図1を参照)。このようなデータ点のセットを持つことは、後続のタスクをより簡単に達成し、さらに優れたパフォーマンスを実現します。Frosst et al. (2019) は、温度係数Tを導入することでSNNL目的関数を拡張しました。したがって、最終的な損失関数は次のようになります:

図2. ソフトニアレストネイバー損失関数。図は著者によるもの。

ここで、dはニューラルネットワークの生入力特徴量または隠れ層表現における距離尺度であり、Tは隠れ層のデータ点間の距離に比例する温度係数です。この実装では、より安定した計算のためにコサイン距離を距離尺度として使用しています。

図3. コサイン距離の式。図は著者によるもの。

この記事の目的は、読者がソフトニアレストネイバー損失を理解し、実装できるようにすることです。そのために、損失関数を解析して理解を深めましょう。

距離尺度

まず計算すべきことは、データ点間の距離です。これは、ネットワークの生入力特徴量または隠れ層表現になります。

図4. SNNLの計算の最初のステップは、入力データ点の距離尺度を計算することです。図は著者によるもの。

この実装では、より安定した計算のためにコサイン距離尺度(図3)を使用しています。上記の図では、ijとikという指定されたサブセットを無視し、単に入力データ点のコサイン距離を計算することに焦点を当てています。これは、以下のPyTorchコードを使用して実現します:

normalized_a = torch.nn.functional.normalize(features, dim=1, p=2)normalized_b = torch.nn.functional.normalize(features, dim=1, p=2)normalized_b = torch.conj(normalized_b).Tproduct = torch.matmul(normalized_a, normalized_b)distance_matrix = torch.sub(torch.tensor(1.0), product)

上記のコードスニペットでは、まず行1と2で入力特徴量をユークリッドノルムで正規化します。そして、行3で正規化された入力特徴量の第2セットの共役転置を取得します。共役転置は複素ベクトルを考慮するためです。そして、行4と5で入力特徴量のコサイン類似度と距離を計算します。

具体的には、次の特徴量セットを考えましょう:

tensor([[ 1.0999, -0.9438,  0.7996, -0.4247],        [ 1.2150, -0.2953,  0.0417, -1.2913],        [ 1.3218,  0.4214, -0.1541,  0.0961],        [-0.7253,  1.1685, -0.1070,  1.3683]])

上記で定義した距離尺度を使用すると、次の距離行列が得られます:

tensor([[ 0.0000e+00,  2.8502e-01,  6.2687e-01,  1.7732e+00],        [ 2.8502e-01,  0.0000e+00,  4.6293e-01,  1.8581e+00],        [ 6.2687e-01,  4.6293e-01, -1.1921e-07,  1.1171e+00],        [ 1.7732e+00,  1.8581e+00,  1.1171e+00, -1.1921e-07]])

サンプリング確率

私たちは、各特徴が他のすべての特徴とのペアワイズ距離に基づいて選ばれる確率を表す行列を計算することができます。これは、iとjまたはkの点の間の距離に基づいてi点を選ぶ確率です。

図5. 2番目のステップは、距離に基づいてポイントを選ぶサンプリング確率を計算することです。作成者による図。

次のコードを使用してこれを計算することができます:

pairwise_distance_matrix = torch.exp(    -(distance_matrix / temperature)) - torch.eye(features.shape[0]).to(model.device)

このコードはまず、距離行列を温度係数で割った後、その値を負の指数で計算し、値を正の値にスケーリングします。温度係数は、ポイントの対の距離に与えられる重要性を制御する方法を決定します。例えば、低温では、損失は小さな距離によって支配されますが、実際の距離が離れた表現の間の距離はより関係がなくなります。

torch.eye(features.shape[0])(対角行列)を引く前のテンソルの値は次のようになります:

tensor([[1.0000, 0.7520, 0.5343, 0.1698],        [0.7520, 1.0000, 0.6294, 0.1560],        [0.5343, 0.6294, 1.0000, 0.3272],        [0.1698, 0.1560, 0.3272, 1.0000]])

距離行列から対角行列を引くことで、自己類似性の項(つまり、各点自体への距離または類似性)をすべて除去します。

次に、次のコードを使用してデータポイントの各ペアに対するサンプリング確率を計算できます:

pick_probability = pairwise_distance_matrix / (    torch.sum(pairwise_distance_matrix, 1).view(-1, 1)    + stability_epsilon)

マスクされたサンプリング確率

これまでに計算したサンプリング確率には、まだラベル情報が含まれていません。ラベル情報をサンプリング確率に統合するために、データセットのラベルでマスクします。

図6. 同じクラスに属するポイントの確率を分離するために、ラベル情報を使用します。作成者による図。

まず、ラベルベクトルからペアワイズ行列を導出する必要があります:

masking_matrix = torch.squeeze(    torch.eq(labels, labels.unsqueeze(1)).float())

ラベル情報を使用して同じクラスに属するポイントの確率を分離するために、マスキング行列を適用します:

masked_pick_probability = pick_probability * masking_matrix

次に、特定の特徴のサンプリング確率の合計確率を計算するために、マスクされたサンプリング確率の行ごとの合計を計算します:

summed_masked_pick_probability = torch.sum(masked_pick_probability, dim=1)

最後に、計算上の便宜のためにサンプリング確率の合計の対数を計算し、平均を取ってネットワークの最近傍損失として機能するようにします:

snnl = torch.mean(    -torch.log(summed_masked_pick_probability + stability_epsilon)

これらのコンポーネントを組み合わせて、ディープニューラルネットワークのすべてのレイヤーにわたってソフト最近傍損失を計算するためのフォワードパス関数を作成することができます:

def forward(    self,    model: torch.nn.Module,    features: torch.Tensor,    labels: torch.Tensor,    outputs: torch.Tensor,    epoch: int,) -> Tuple:    if self.use_annealing:        self.temperature = 1.0 / ((1.0 + epoch) ** 0.55)    primary_loss = self.primary_criterion(        outputs, features if self.unsupervised else labels    )    activations = self.compute_activations(model=model, features=features)    layers_snnl = []    for key, value in activations.items():        value = value[:, : self.code_units]        distance_matrix = self.pairwise_cosine_distance(features=value)        pairwise_distance_matrix = self.normalize_distance_matrix(            features=value, distance_matrix=distance_matrix        )        pick_probability = self.compute_sampling_probability(            pairwise_distance_matrix        )        summed_masked_pick_probability = self.mask_sampling_probability(            labels, pick_probability        )        snnl = torch.mean(            -torch.log(self.stability_epsilon + summed_masked_pick_probability)        )        layers_snnl.append(snnl)        snn_loss = torch.stack(layers_snnl).sum()    train_loss = torch.add(primary_loss, torch.mul(self.factor, snn_loss))    return train_loss, primary_loss, snn_loss

ディセンタングルされた表現の可視化

ソフト最近傍損失を持つオートエンコーダを訓練し、学習したディセンタングルされた表現を可視化します。オートエンコーダは(x-500–500–2000-d-2000–500–500-x)ユニットを持ち、MNIST、Fashion-MNIST、およびEMNIST-Balancedの小規模なラベル付きサブセットでトレーニングされました。これは、オートエンコーダが教師なしモデルであるため、ラベル付きの例の希少性をシミュレートするためです。

Figure 7. 3D visualization comparing the original representation...

EMNIST-Balancedデータセットの簡単でクリーンな可視化のために、任意に選ばれた10のクラスターのみを可視化しました。上記の図では、クラスターの分散とクラスターの色による正しいクラスター割り当てによって、潜在コード表現がクラスタリングに適したものになったことがわかります。

結論

本記事では、PyTorchでソフト最近傍損失関数をどのように実装するかを詳しく説明しました。

ソフト最近傍損失は、最初にSalakhutdinov&Hinton(2007)によって導入され、その後、オートエンコーダの潜在コード(ボトルネック)表現で損失を計算し、それからダウンストリームのkNN分類タスクに使用されました。

Frosst、Papernot、&Hinton(2019)は、ソフト最近傍損失を拡張し、温度係数を導入し、ニューラルネットワークのすべての層で損失を計算しました。

最後に、オートエンコーダのディセンタングルされた表現をさらに改善し、ディセンタングルのプロセスを高速化するために、ソフト最近傍損失にアニーリング温度係数を適用しました(Agarap & Azcarraga、2020)。

完全なコード実装はGitLabで利用できます。

参考文献

  • Agarap, Abien Fred、およびArnulfo P. Azcarraga。「Improving k-means clustering performance with disentangled internal representations.」2020 International Joint Conference on Neural Networks(IJCNN)。IEEE、2020。
  • Chapelle、Olivier、Bernhard Scholkopf、およびAlexander Zien。「Semi-supervised learning(chapelle、o. et al.、eds。; 2006)[book reviews]。」IEEE Transactions on Neural Networks 20.3(2009):542–542。
  • Frosst、Nicholas、Nicolas Papernot、およびGeoffrey Hinton。「Analyzing and improving representations with the soft nearest neighbor loss.」International conference on machine learning。PMLR、2019。
  • Goldberger、Jacob、他。「Neighbourhood components analysis.」Advances in neural information processing systems。2005。
  • He、Kaiming、他。「Deep residual learning for image recognition.」Proceedings of the IEEE conference on computer vision and pattern recognition。2016。
  • Hinton、G.、他。「Neighborhood components analysis.」Proc. NIPS. 2004。
  • Krizhevsky、Alex、Ilya Sutskever、およびGeoffrey E. Hinton。「Imagenet classification with deep convolutional neural networks.」Advances in neural information processing systems 25(2012)。
  • Roweis、Sam T.、およびLawrence K. Saul。「Nonlinear dimensionality reduction by locally linear embedding.」science 290.5500(2000):2323–2326。
  • Salakhutdinov、Ruslan、およびGeoff Hinton。「Learning a nonlinear embedding by preserving class neighbourhood structure.」Artificial Intelligence and Statistics。2007。
  • Simonyan、Karen、およびAndrew Zisserman。「Very deep convolutional networks for large-scale image recognition.」arXiv preprint arXiv:1409.1556(2014)。
  • Van der Maaten、Laurens、およびGeoffrey Hinton。「Visualizing data using t-SNE.」Journal of machine learning research 9.11(2008)。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more