最適化ストーリー:ブルーム推論

最適化ストーリー ブルーム推論

この記事では、bloomをパワーアップする効率的な推論サーバーの裏側について説明します。

数週間にわたり、レイテンシーを5倍削減し(スループットを50倍に増やしました)、このような速度向上を達成するために私たちが経験した苦労やエピックな勝利を共有したかったです。

さまざまな人々が多くの段階で関与していたため、ここではすべてをカバーすることはできません。また、最新のハードウェア機能やコンテンツが定期的に登場するため、一部の内容は古くなっているか、まったく間違っている可能性があることをご了承ください。

もし、お好みの最適化手法が議論されていなかったり、正しく表現されていなかったりした場合は、お詫び申し上げます。新しいことを試してみたり、間違いを修正するために、ぜひお知らせください。

言うまでもなく、まず大きなモデルが最初にアクセス可能でなければ、それを最適化する理由はありません。これは、多くの異なる人々によってリードされた信じられないほどの取り組みでした。

トレーニング中にGPUを最大限に活用するために、いくつかの解決策が検討され、結果としてMegatron-Deepspeedが最終的なモデルのトレーニングに選ばれました。これは、コードがそのままではtransformersライブラリと互換性がない可能性があることを意味します。

元のトレーニングコードのため、通常行っていることの1つである既存のモデルをtransformersに移植することに取り組みました。目標は、トレーニングコードから関連する部分を抽出し、transformers内に実装することでした。この取り組みには「Younes」が取り組みました。これは、1ヶ月近くかかり、200のコミットが必要でした。

後で戻ってくるいくつかの注意点があります:

小さなモデルbigscience/bigscience-small-testingとbigscience/bloom-560mを用意する必要があります。これは非常に重要です。なぜなら、それらと一緒に作業するとすべてが高速化されるからです。

まず、最後のログがバイトまで完全に同じになることを望むことをあきらめる必要があります。PyTorchのバージョンがカーネルを変更し、微妙な違いを導入する可能性があり、異なるハードウェアでは異なるアーキテクチャのため異なる結果が得られる場合があります(コストの理由から常にA100 GPUで開発したくはないでしょう)。

すべてのモデルにとって、良い厳格なテストスイートを作ることは非常に重要です

私たちが見つけた最高のテストは、固定された一連のプロンプトを持つことでした。プロンプトを知っており、決定論的な結果が得られる必要があります。2つの生成物が同じであれば、小さなログの違いは無視できます。ドリフトが見られるたびに調査する必要があります。それは、あなたのコードがやるべきことをしていないか、または実際にそのモデルがドメイン外であるためにノイズに対してより敏感であるかのいずれかです。いくつかのプロンプトと十分に長いプロンプトがあれば、すべてのプロンプトを誤ってトリガーする可能性は低くなります。プロンプトが多ければ多いほど良く、プロンプトが長ければ長いほど良いです。

最初のモデル(small-testing)は、bloomと同じようにbfloat16であり、すべてが非常に似ているはずですが、それほどトレーニングされていないか、うまく機能しないため、出力が大きく変動します。そのため、これらの生成テストに問題がありました。2番目のモデルはより安定していましたが、bfloat16ではなくfloat16でトレーニングおよび保存されていました。そのため、2つの間にはエラーの余地があります。

完全に公平を期すために言えば、bfloat16float16への変換は推論モードでは問題なさそうです(bfloat16は主に大きな勾配を扱うために存在しません)。

このステップでは、重要なトレードオフが発見され、実装されました。bloomは分散環境でトレーニングされたため、一部のコードはLinearレイヤー上でテンソル並列処理を行っており、単一のGPU上で同じ操作を実行すると異なる結果が得られていました。これを特定するのにかなりの時間がかかり、100%の準拠を選択した場合、モデルの速度が遅くなりましたが、少しの差がある場合は実行が速く、コードがシンプルになりました。設定可能なフラグを選択しました。

注:この文脈でのパイプライン並列処理(PP)は、各GPUがいくつかのレイヤーを所有し、各GPUがデータの一部を処理してから次のGPUに渡すことを意味します。

これで、動作可能なtransformersのクリーンなバージョンがあり、これに取り組むことができます。

Bloomは352GB(176Bパラメーターのbf16)のモデルであり、それに合わせるために少なくともそれだけのGPU RAMが必要です。一時的に小さなマシンでCPUにオフロードすることを検討しましたが、推論速度が桁違いに遅くなるため、それを取り下げました。

次に、基本的にはパイプラインを使用したかったのです。つまり、ドッグフーディングであり、これがAPIが常に裏で使用しているものです。

ただし、pipelinesは分散意識がありません(それがその目的ではありません)。オプションを簡単に話し合った後、新しく作成されたdevice_map="auto"を使用してモデルのシャーディングを管理するためにaccelerateを使用することにしました。いくつかのバグを修正し、transformersのコードを修正してaccelerateが正しい仕事をするのを助ける必要がありました。

これは、transformersのさまざまなレイヤーを分割し、各GPUにモデルの一部を与えて動作させることで機能します。つまり、GPU0が作業を行い、次にGPU1に引き渡し、それ以降同様に行います。

最終的には、上に小さなHTTPサーバーを置くことで、bloom(大規模なモデル)を提供できるようになりました!

しかし、まだ最適化については議論していません!

実際には、かなり多くの作業があります。このプロセス全体はカードの城です。最適化中に基になるコードを変更するため、モデルをいかなる方法で破壊しないようにすることは非常に重要であり、思っている以上に簡単に行うことができます。

したがって、最適化の最初のステップにいるため、パフォーマンスを測定し続ける必要があります。したがって、私たちは何に関心があるか考える必要があります。多くのオプションをサポートするオープンな推論サーバーでは、ユーザーがさまざまなパラメータで多くのクエリを送信することを予想しており、私たちが関心を持っているのは次の点です:

同時に提供できるユーザーの数(スループット) 平均ユーザーの処理時間はどれくらいか(レイテンシ)?

私たちは、これを正確に実行するテストスクリプトをlocustで作成しました。それは次のようなものです:

from locust import HttpUser, between, task
from random import randrange, random


class QuickstartUser(HttpUser):
    wait_time = between(1, 5)

    @task
    def bloom_small(self):
        sentence = "Translate to chinese. EN: I like soup. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {"max_new_tokens": 20, "seed": random()},
            },
        )

    @task
    def bloom_small(self):
        sentence = "Translate to chinese. EN: I like soup. CN: "
        self.client.post(
            "/generate",
            json={
                "inputs": sentence[: randrange(1, len(sentence))],
                "parameters": {
                    "max_new_tokens": 20,
                    "do_sample": True,
                    "top_p": 0.9,
                    "seed": random(),
                },
            },
        )

**注意:これは私たちが使用した最高の負荷テストではありませんが、常に最初に実行されるため、アプローチ間で公平に比較できます。このベンチマークで最も優れているということは、最良の解決策であるという意味ではありません。実際の現実のパフォーマンスに加えて、他のより複雑なシナリオも使用する必要があります。**

私たちは、さまざまな実装の立ち上がりを観察したり、サーバーが適切に回路を切断するかどうかを確認したりすることを望んでいました。回路切断とは、サーバーが(高速に)クエリに回答しないことを意味します。なぜなら、同時に使用しようとする人が多すぎるからです。これは死のハグを回避するために非常に重要です。

このベンチマークでは、初期のパフォーマンスは次のようでした(16xA100 40Go on GCP、これは使用されたマシンです):

リクエスト/秒:0.3(スループット) レイテンシ:350ms/トークン(レイテンシ)

これらの数値はあまり素晴らしくありません。作業に取りかかる前に、想像できる最良の結果を推定してみましょう。操作の量の計算式は24Bsh^2 + 4𝐵s^2h24Bsh^2 + 4𝐵s^2hです。ここで、Bはバッチサイズ、sはシーケンスの長さ、hは隠れた次元です。

計算してみましょう。単一のフォワードパスに対して17 TFlopが得られます。A100のスペックを見ると、単一のカードで312 TFLOPSとされています。つまり、単一のGPUは理論上は17 / 312 = 54ms/トークンで動作する可能性があります。私たちはそれを16個使用しているので、全体のマシンで3ms/トークンになります。これらの数値はあくまで参考までに、これらの数値に到達することは決して不可能ではなく、実際のパフォーマンスは稀にスペックと一致することはありません。また、計算が制限要素でない場合は、これ以上低いレイテンシを実現できます。ただし、目標からどれだけ遠いかを把握するためには、これは良い練習です。この場合、2桁のオーダーなので、かなり遠いです。また、この推定では、すべてのFLOPSがレイテンシに役立つようになっており、そのため一度に1つのリクエストしか処理できません(問題ありません、マシンを最大限に活用しているので、他にはあまりすることはありませんが、バッチ処理を通じてスループットを取り戻すことも容易です)。

注意:Tensor Parallelism(TP)は、この文脈では各GPUが重みの一部を所有することを意味します。したがって、すべてのGPUは常にアクティブであり、より少ない作業を行います。通常、これにはわずかなオーバーヘッドが伴いますが、一部の作業が重複して実行され、さらに重要なことは、GPUが定期的に互いに結果を通信して計算を継続する必要があるということです。

私たちがどのような状況にあるのかをよく理解したところで、仕事に取り掛かる時が来ました。

私たちは、人々と私たちのさまざまな知識に基づいて、さまざまなことを試しました。

すべての試みはそれぞれ独自のブログ投稿に値しますので、それらを一覧にして、最終的な学びを説明し、現在のサーバーに組み込まれたものの詳細についてのみ説明します。パイプライン並列処理(PP)からテンソル並列処理(TP)への移行は、レイテンシにとって大きな興味深い変化です。各GPUがパラメータの一部を所有し、すべてのGPUが同時に動作します。したがって、レイテンシは劇的に減少するはずですが、結果に関して定期的に互いに通信する必要があるため、通信のオーバーヘッドが発生します。

これは非常に幅広いアプローチの範囲であることに注意しておくべきであり、意図的に各ツールについてさらに学び、後の試みにどのように適合するかを理解するためのものでした。

コードをTPUで実行するためのJAX/Flaxへの移植:

  • 並列処理のタイプを選ぶのが簡単になるはずです。ですので、TPをテストするのが簡単になります。これはJaxの設計の利点の一つです。
  • ハードウェアに制約があり、TPUのパフォーマンスはGPUよりも優れており、TPUのベンダーの選択肢が少なくなります。
  • デメリットとして、別のポートが必要です。ただし、私たちのライブラリでは歓迎されるでしょう。

結果:

  • 移植は簡単な作業ではありませんでした。いくつかの条件とカーネルが正しく再現されるのが難しかったですが、それでも管理可能でした。
  • 移植後の並列処理は非常に簡単に取得できました。Jaxに感謝します。
  • Ray/TPUワーカーとの通信は本当に困難でした。ツールの問題なのか、ネットワークの問題なのか、単に私たちの知識不足なのかはわかりませんが、実験や作業が予想以上に遅くなりました。実行に5分かかる実験を開始し、5分待っても何も起こらず、さらに10分後もまだ何も起こらず、結果的には何らかのワーカーがダウンして応答しないことが判明し、手動で問題を解決し、何かを再起動し、再起動して半時間失ってしまったことがありました。これを何度も繰り返していると、失われる日数がすぐに積み重なってしまいます。使用したツールについての批判とは必ずしも言えませんが、私たちが経験した主観的な経験は変わりません。
  • コンパイルには制御がありません。実行できるようになった後、意図した推論に最適な設定をいくつか試しましたが、設定からレイテンシ/スループットがどのようになるかを予測するのは非常に難しかったです。例えば、バッチサイズ=1(つまり、各リクエスト/ユーザーが独自のものを持っている)で0.3 rps、トークンあたりのレイテンシが15msだった(この記事の他の数字とあまり比較しないでください、異なるプロファイルを持つ異なるマシン上で行われたものです)これは素晴らしいですが、全体のスループットは古いコードとほとんど変わりませんでした。したがって、バッチ処理を追加することにしました。BS=2でレイテンシが5倍になり、スループットは2倍になりました…さらなる調査の結果、バッチサイズ=16までの各バッチサイズが同じレイテンシのプロファイルを持っていることが判明しました。つまり、コストのかかるレイテンシを5倍にすることなく、16倍のスループットを得ることができました。悪くはないですが、私たちはもっと細かい制御を望んでいました。目指していた数字は100ms、1s、10s、1mnのルールです。

ONNX/TRTまたは他のコンパイルアプローチを使用する

  • 最適化のほとんどはこれらのアプローチで処理できるはずです
  • デメリットとして、通常は並列処理を手動で処理する必要があります。

結果:

  • トレース/ジット/エクスポートするためには、PyTorchの一部を再構築する必要があったため、純粋なPyTorchアプローチと容易に結合することができ、求めていた最適化のほとんどをPyTorchの世界内で実現することができることがわかりました。これにより、コーディングの努力を最小限に抑えながら柔軟性を保持することができます。また、GPU上で実行しており、テキスト生成には多くのフォワードパスが行われるため、テンソルをGPU上に保持する必要がありますが、テンソルをあるライブラリに送信し、結果を返され、ロジット計算(argmaxまたはサンプリングなど)を実行し、再度フィードバックするのは時々難しいことがあります。ループを外部ライブラリ内に配置すると、Jaxと同様に柔軟性が失われるため、私たちのユースケースでは考慮されていませんでした。

DeepSpeed

  • これはトレーニングに使用された技術で、推論にも使用することは公平だと思われました。
  • 欠点は、これまでに推論のために使用されることはなかった/準備されていなかったことです。

結果:

  • 私たちは非常に印象的な結果を素早く得ました。これは現在実行中の前回の反復とほぼ同じです。
  • DeepSpeedの上にウェブサーバー(つまり並行処理)を配置する方法を発明する必要がありました。DeepSpeed自体も複数のプロセス(各GPUに1つずつ)を持っています。優れたライブラリMiiが存在するため、それに基づいて作業を開始する可能性もありましたが、私たちが考えていた非常に柔軟な目標には合わないと思われます(現在の解決策については後で説明します)。
  • DeepSpeedで遭遇した最大の注意点は、安定性の欠如でした。CUDA 11.4で実行する際に問題が発生しました。コードは11.6向けにビルドされていました。そして、長い間解決できなかった問題は、定期的なカーネルクラッシュ(Cudaの不正なアクセス、次元の不一致など)が発生することです。これらの問題のいくつかは修正しましたが、ウェブサーバーのストレス下での安定性を実現することはできませんでした。それにもかかわらず、私たちを助けてくれたMicrosoftの皆さんに感謝したいと思います。私たちは非常に良い会話を持ち、何が起こっているのかをより良く理解することができ、その後の作業のための実質的な洞察を得ることができました。
  • 私が感じる痛みの一つは、私たちのチームがほとんどヨーロッパにいるのに対して、Microsoftがカリフォルニアにあるため、協力が時間的に難しかったことです。これは技術的な部分とは関係ありませんが、一緒に作業する組織の部分も非常に重要であることを認識することは良いことです。
  • もう一つ注意すべきことは、DeepSpeedは最適化を注入するためにtransformersに依存していることです。そして、私たちはコードをほぼ一貫して更新していたため、DeepSpeedチームが私たちのmainブランチで問題なく動作させることが難しくなりました。私たちはそれを難しくしたことをお詫び申し上げます。これがなぜ「最先端」と呼ばれるのかもしれません。

ウェブサーバーのアイデア

  • ユーザーが長いテキスト、短いテキスト、いくつかのトークン、または完全なレシピを送信する無料のサーバーを実行することを考えると、ここで何か対策を講じる必要がありました。

結果:

  • 私たちはすべてをRustで再コーディングしました。優れたバインディングtch-rsを使用しました。Rustはパフォーマンスの向上を目指しているわけではありませんが、並行性(スレッド/プロセス)に対するより詳細な制御と、ウェブサーバーの同時処理とPyTorchの両方に対するより詳細な制御を得ることができます。PythonはGILのため、低レベルの詳細を扱うのが難しいと広く知られています。
  • 結果として、大半の苦労はポートから発生し、その後の実験は楽でした。そして、ループに対する十分な制御を持っていれば、さまざまな特性を持つリクエストの非常に広範な配列の文脈でも、みんなにとって素晴らしいパフォーマンスを実現することができるとわかりました。興味のある方のためのコードですが、サポートや素敵なドキュメントは付属していません。
  • それは数週間にわたって本番環境になりました。並行性がより寛容であったため、GPUをより効率的に使用することができました(リクエスト1にはGPU0を使用し、リクエスト0を処理する間にGPU1を使用)。そして、RPSは0.3から〜2.5に向上しましたが、レイテンシは同じです。最適な場合はスループットを16倍に増やすことができるはずですが、ここに表示されている数値は実際のワークロードの測定値ですので、それほど悪くはありません。

純粋なPyTorch

  • 既存のコードを変更して、reshapeのような操作を削除し、より最適化されたカーネルなどを使用して速度を向上させる。
  • 欠点は、TPのコーディングを自分たちで行う必要があり、コードが私たちのライブラリに合致するという制約があることです。

結果:

  • 次の章へ。

より効率的なPyTorchの記述

リストの最初の項目は、最初の実装で不要な操作を削除することでした。コードを見て明らかな欠陥を見つけることでいくつかの操作が分かります:

  • AlibiはBloomで位置エンベディングを追加するために使用され、それが多くの場所で計算されていましたが、一度だけ計算すればより効率的になります。

古いコード:リンク 新しいコード:リンク

これは10倍の高速化であり、最新バージョンにはパディングも含まれています!このステップは一度だけ計算されるため、実際の速度は重要ではありませんが、オペレーション数やテンソルの作成数を減らすことは良い方向性です。

他の部分は、プロファイリングを開始し、テンソルボード拡張をかなり広範に使用した場合にはっきりと出てきます。

これによって次のようなイメージが提供され、洞察が得られます:

注意がたくさんかかるので、これはCPUビューです。長いバーは長いことを意味するのではなく、CPUが前のステップのGPUの結果を待っていることを意味します。 `baddbmm`の前にはたくさんの`cat`操作が見られます。

たとえば、reshape/transposeを大幅に削除することで、次のことがわかりました: – アテンションがホットパスであること(予想されていますが、常に確認するのは良いことです)。 – アテンションでは、reshapeの数が非常に多いため、多くのカーネルが実際のコピーでした – リシェイプを削除することで、重み自体と過去を再構成することができました。これは破壊的な変更ですが、パフォーマンスをかなり向上させました!

TPのサポート

さて、私たちはほとんどの簡単な問題を解決しました。PPでは、トークンごとのレイテンシが350msから300msに大幅に改善されました。これはレイテンシの15%の削減ですが、実際にはそれ以上の改善がありますが、最初の測定では非常に厳格ではありませんでしたので、その数字に固執しましょう。

次に、TPの実装を提供しました。実装は予想よりもはるかに速かったです。経験豊富な1人の開発者が半日かかりました。結果はここにあります。他のプロジェクトからのコードの再利用もできました。

レイテンシは直接300ms/tokenから91ms/tokenに改善されました。これはユーザーエクスペリエンスの大幅な改善です。単純な20トークンのリクエストは6秒から2秒に変わり、「遅い」体験からわずかに遅延した体験になりました。

また、スループットも非常に向上し、10RPSになりました。スループットは、バッチサイズ=1でクエリを実行する時間が、バッチサイズ=32と同じ時間をかけるため、この時点でレイテンシのコストがほぼゼロになるためです。

簡単な問題

TPの実装ができたので、再びプロファイリングと最適化を開始することができました。これはかなり大きな変化なので、最初からやり直さなければなりませんでした。

最初に目立ったことは、同期(ncclAllReduce)が負荷の中で支配的な部分になり始めていることです。これは予想されることで、これが同期部分であるため、ある程度の時間がかかります。私たちはすでに`nccl`を使用しているので、これを最適化するためには多くの余地はないと考えていましたが、まだ改善の余地があるかもしれません。

2番目のことは、Gelu演算子が多くの要素ごとのカーネルを起動しており、全体として予想以上に計算リソースを使用していたことです。

変更前:

def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

変更後:

@torch.jit.script
def bloom_gelu_forward(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

これにより、複数の小さな要素ごとのカーネル(およびテンソルのコピー)から単一のカーネル操作に変換されます!

これにより、レイテンシが91ms/tokenから81ms/tokenに10%改善されました!

ただし、これは魔法の黒箱ではありませんので、どこでも簡単に使用できるわけではありません。カーネルの融合が必ずしも発生しない場合や、以前に使用されていた操作がすでに非常に効率的である場合は、カーネルの融合は行われません。

うまく機能する場所:

  • 多くの小さな/要素ごとの操作がある場合
  • いくつかの削除が難しいreshape、一般的なコピーがあるホットスポットがある場合
  • 融合が発生する場合

大失敗

また、テスト期間中に、RustサーバーのレイテンシがPythonサーバーよりも一貫して25%低かったという点がありました。これは非常に奇妙でしたが、一貫して測定されていたため、カーネルを削除するとスピードアップが見られたため、Pythonのオーバーヘッドを削減することで素晴らしいブーストを提供できるかもしれないという印象を持っていました。

私たちは3日間の仕事を始め、torch.distributedの必要な部分を再実装しました。Rustの世界でnccl-rsを使って実行できるようにしました。バージョンは動作していましたが、Pythonの対応部分と比較して何かがうまくいかなかったです。問題の調査中に、私たちは…PyTorchの計測からプロファイラを削除するのを忘れていたことがわかりました

それは大失敗でした。削除すると25%が戻り、両方のコードが同じくらい速く実行されました。これは最初に予想していたことで、Pythonは主にtorch cppのコードを実行しているため、パフォーマンスに影響を与えるべきではありません。結局、3日間は世界の終わりではなく、将来的に役立つかもしれませんが、それでもかなり悪い状況です。最適化を行う際には、誤ったまたは誤った測定を行い、失望や全体的な製品への悪影響につながることが非常に一般的です。そのため、小さなステップで行い、可能な限り早く結果についての期待を持つことは、そのリスクを抑えるのに役立ちます。

もう一つ注意が必要だったのは、初期の順方向パス(過去を含まない)と後の順方向パス(過去を含む)の部分です。最初のパスを最適化すると、実行時間の大部分を占める後のパスが遅くなる可能性が非常に高いです。また、よくある問題の一つは、CPUの時間ではなく実際のCUDAの時間を測定していることです。したがって、実行する際にはtorch.cuda.synchronize()を使用してカーネルが完了することを確認する必要があります。

カスタムカーネル

これまで、PyTorchの外部でカスタムコードを使用せずにDeepSpeedのパフォーマンスにかなり近づくことができました!かなり素晴らしいです。また、実行時のバッチサイズの柔軟性に妥協する必要もありませんでした!

ただし、DeepSpeedの経験を考慮すると、torch.jit.scriptでは私たちが行いたかったいくつかの操作を結合するためのカスタムカーネルを作成してみたいと思いました。具体的には、次の2行です:

attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

最初のmasked fillは新しいテンソルを作成しており、これはsoftmax演算子に対してこれらの値を無視するように伝えるためだけです。また、softmaxはfloat32で計算する必要があります(安定性のため)。ただし、カスタムカーネル内では、必要なアップキャストの量を制限することができるため、実際の合計値と蓄積値に制限することができます。

コードはこちらで見つけることができます。注意点として、ターゲットとする単一のGPUアーキテクチャがあったため、それに焦点を当てることができました。また、私たちは(まだ)カーネルの書き方には詳しくないため、これを行うためのより良い方法があるかもしれません。

このカスタムカーネルにより、レイテンシがさらに10%増加し、81ms/tokenから71ms/tokenのレイテンシになりました。柔軟性も保ちながらです。

その後、他のことを調査し、より多くの演算子を結合したり、他の場所に配置したりしましたが、十分な影響を与えることはありませんでした。

ウェブサーバーパート

Rustの対応部分と同様に、異なるパラメータでのリクエストのバッチ処理を実装する必要がありました。PyTorchの世界にいるので、起こっていることにほぼ完全な制御があります。Pythonであるため、torch.distributedはスレッドではなく複数のプロセスで実行する必要があり、プロセス間の通信がやや難しくなります。最終的に、Redisのpub/subを介して生の文字列を通信することで、リクエストを一度にすべてのプロセスに配布する方法を選択しました。異なるプロセスにいるため、テンソルを通信するよりも(テンソルの方がはるかに大きいため)この方法を選択する方が簡単です。

その後、パラメータをバッチのすべてのメンバーに適用する代わりに、generateの使用をやめる必要がありました。幸いなことに、LogitsProcessorのような下位レベルのアイテムを再利用することで、多くの作業を省くことができます。

したがって、バッチの各メンバーにパラメータを適用するgenerate関数を再構築しました。

最終的なUXの非常に重要な側面は、レイテンシです。異なるリクエスト用の異なるパラメータセットがあるため、20トークンのリクエストと250トークンのリクエストがあるかもしれません。レイテンシが1トークンあたり75msかかる場合、1つのリクエストは1.5秒かかり、もう1つのリクエストは18秒かかります。バッチ処理を行っている場合、18秒待たなければならないユーザーにとって、実行時間が900ms/tokenであるかのように見えるため、かなり遅いと思われるでしょう!

極めて柔軟性のあるPyTorchの世界にいるため、私たちができることは、バッチから最初のリクエストを抽出し、最初の20トークンを生成したら、要求された1.5秒以内にそのユーザーに返すことです!また、計算において230トークン分の節約も行うことができます。

したがって、柔軟性は最高のレイテンシーを得るために重要です。

最適化は終わりのない仕事であり、他のプロジェクトと同様に、作業の20%が通常80%の結果をもたらすでしょう。ある時点で、私たちはいくつかのアイデアの潜在的な収益を把握するための小規模なテスト戦略を始めました。テストが有意な結果をもたらさなかった場合、私たちはそのアイデアを破棄しました。10%の増加に1日を費やすことは十分に価値がありますが、10倍の結果を得るために2週間を費やすことも十分に価値があります。10%のために2週間を費やすことはあまり興味深くありません。

試したことはありますか…?

さまざまな理由で使用していないが存在することを知っているものです。それは私たちのユースケースに適応されないように感じた、あまりにも多くの作業、収益が十分に見込めなかった、または単に試すためのオプションが多すぎて時間が足りなかったためかもしれません。以下は特定の順序ではありません:

  • Cudaグラフ
  • nvFuser(これがtorch.jit.scriptの動力源ですので、使用しました。)
  • FasterTransformer
  • NvidiaのTriton
  • XLA(Jaxもxlaを使用しています!)
  • torch.fx
  • TensorRT

お気に入りのツールがここにない場合や、有用と思われる重要な要素を見逃していると思われる場合は、お気軽にご連絡ください!

フラッシュアテンション

フラッシュアテンションの統合を簡単に見てみましたが、past_key_valuesを使用する場合には大きな改善が見られませんでした。計算にalibiテンソルを含めるために適応する必要があったため、この作業を行うことはしませんでした(少なくともまだは)。

OpenAI Triton

TritonはPythonでカスタムカーネルを構築するための素晴らしいフレームワークです。これをもっと使いたいと思っていますが、現時点では使用していません。私たちのCudaカーネルよりも優れたパフォーマンスを発揮するかどうか、興味津々です。その部分のオプションを考慮する際、Cudaで直接書くことが最短の道のりに思えました。

パディングとリシェイプ

この記事全体で述べたように、テンソルのコピーごとにコストがかかり、プロダクションの実行にはパディングの隠れたコストもあります。2つのクエリが非常に異なる長さで入ってきた場合、四角形に合わせるためにパディング(ダミートークンの使用)を行う必要があります。これにより、不必要な計算が多く発生する可能性があります。詳細はこちら。

理想的には、これらの計算を行わずに、リシェイプも行わずにすべての推論をCUDAまたは純粋なGPU実装で行うことができるでしょう。操作を融合できたときのパフォーマンスの向上を考えると、これは望ましいです。しかし、どの程度効果があるかはわかりません。より賢いGPUの専門家のアイデアがあれば、ぜひ聞かせてください!

このすべての作業は、多くのHFチームメンバーの協力の結果です。特定の順序ではありませんが、@ThomasWang @stas @Nouamane @Suraj @Sanchit @Patrick @Younes @Sylvain @Jeff(Microsoft) @Reza およびすべてのBigScience組織のメンバーです。

We will continue to update VoAGI; if you have any questions or suggestions, please contact us!

Share:

Was this article helpful?

93 out of 132 found this helpful

Discover more

人工知能

「クリス・サレンス氏、CentralReachのCEO - インタビューシリーズ」

クリス・サレンズはCentralReachの最高経営責任者であり、同社を率いて、自閉症や関連する障害を持つ人々のために優れたクラ...

AIテクノロジー

アンソニー・グーネティレケ氏は、Amdocsのグループ社長であり、テクノロジー部門および戦略部門の責任者です- インタビューシリーズ

アンソニー・グーネティレーケは、Amdocsでグループ社長、テクノロジーと戦略担当です彼と企業戦略チームは、会社の戦略を策...

AIニュース

Q&A:ブラジルの政治、アマゾンの人権、AIについてのGabriela Sá Pessoaの見解

ブラジルの社会正義のジャーナリストは、MIT国際研究センターのフェローです

人工知能

「Zenの共同創設者兼CTO、イオン・アレクサンドル・セカラ氏によるインタビューシリーズ」

創業者兼CTOであるIon-Alexandru Secaraは、Zen(PostureHealth Inc.)の開発を牽引しており、画期的な姿勢矯正ソフトウェア...

AIテクノロジー

「LXTのテクノロジーバイスプレジデント、アムル・ヌール・エルディン - インタビューシリーズ」

アムル・ヌール・エルディンは、LXTのテクノロジー担当副社長ですアムルは、自動音声認識(ASR)の文脈での音声/音響処理と機...

データサイエンス

「Adam Ross Nelsonによる自信のあるデータサイエンスについて」

データサイエンスの中で新たな分野が現れ、研究内容が理解しにくい場合は、専門家や先駆者と話すのが最善です最近、私たちは...