テキストから SQL へのタスクで広く使用されているベンチマークである Spider のトップ 10 に LLM をランクアップしたいですか? Spider は、LLM がテキスト クエリを SQL コードに変換できるかどうかを評価します。
text-to-SQL に馴染みのない方のために説明すると、その重要性は、企業がデータと対話する方法を変革することにあります。 クエリの作成について SQL の専門家に頼る代わりに、ユーザーは平易な英語でデータについて質問するだけで、正確な回答を得ることができます。 これにより、データへのアクセスが民主化され、ビジネスインテリジェンスが強化され、より情報に基づいた意思決定が可能になります。
Spider ベンチマークは、text-to-SQL システムのパフォーマンスを評価するための広く認識されている標準です。LLM は自然言語クエリを正確な SQL ステートメントに変換することが求められ、データベース スキーマに関する深い理解と、構文的および意味的に正しい SQL コードを生成する能力が求められます。
この投稿では、オープンソースの Llama3 8B Instruct モデルを使用して、1 日未満の作業で Spider 開発データセットで 79.9% 、テスト データセットで 78.9% のスコアを達成した方法について詳しく説明します。これは、ベースラインに対して 19 ポイントという驚異的な改善です。 このパフォーマンスは、 Databricksでの戦略的なプロンプトとファインチューニングのおかげで、現 在は凍結されている Spider リーダーボードでトップ 10 にランクインすることになります。
ベースライン性能のためのゼロショットプロンプト
まず、テーブルを作成した CREATE TABLE
ステートメントと、それらのテーブルを使用して回答したい質問で構成される非常にシンプルなプロンプト形式を使用して、Spider dev データセットでの Meta Llama 3 8B Instruct のパフォーマンスを評価してみましょう。
このタイプのプロンプトは、プロンプトに他の例がないため、しばしば「ゼロショット」と呼ばれます。 Spider dev データセットの最初の質問では、このプロンプト形式によって次が生成されます。
この形式を使用して開発データセットで Spider ベンチマークを実行すると、実行精度と貪欲なデコードを使用して測定した場合、全体のスコアは 60.9 になります。 これは、モデルが 60.9% の確率で SQL を生成し、それを実行すると、正しいソリューションを表す「ゴールド」クエリと同じ結果が生成されることを意味します。
Easy | Medium | Hard | Extra | All | |
---|---|---|---|---|---|
ゼロショット | 78.6 | 69.3 | 42.5 | 31.3 | 60.9 |
ベースライン スコアが確立されたので、ファインチューニングに入る前に、Spider 開発ベンチマーク データセットでベース モデルのスコアを上げるために、さまざまなプロンプト戦略を試してみましょう。
サンプル行によるプロンプト
最初に使用したプロンプトの欠点の 1 つは、データ型以外の列のデータに関する情報が含まれていないことです。 Spiderを使用したモデルのテキストからSQLへの機能の評価に関する論文では、サンプリングされた行をプロンプトに追加するとスコアが高くなることがわかったので、それを試してみましょう。
上記のプロンプト形式を更新して、テーブル作成クエリに各テーブルの最初の数行も含まれるようにすることができます。 以前の同じ質問に対して、更新されたプロンプトはありません。
各テーブルのサンプル行を含めると、全体のスコアが約 6 パーセント ポイント上昇して 67.0 になります。
Easy | Medium | Hard | Extra | All | |
---|---|---|---|---|---|
サンプル行によるゼロショット | 80.6 | 75.3 | 51.1 | 41.0 | 67.0 |
少数のプロンプト(Few-shot Prompting)
少数のプロンプトは、LLM で使用されるよく知られた戦略であり、実行するタスクを示すいくつかの例を含めることで、正しい SQL を生成するなど のタスクのパフォーマンスを向上させることができます。 ゼロショットのプロンプトで、スキーマを提供し、質問をしました。 数回のプロンプトで、いくつかのスキーマ、質問、その質問に答える SQL を提供し、そのシーケンスを数回繰り返してから、実際に尋ねたい質問に進みます。 これにより、通常、ゼロショットプロンプトよりもパフォーマンスが向上します。
SQL生成タスクを示す例の良いソースは、実際には Spider トレーニング データセットそのものです。 このデータセットから、対応するテーブルを含むいくつかの質問をランダムにサンプルとして取得し、これらの質問のそれぞれに回答できる SQL を示す数回のプロンプトを作成できます。 前のプロンプトの時点でサンプル行を使用しているため、これらの例の 1 つにサンプル行も含まれていることを確認して、使用法を示す必要があります。
以前のゼロショットプロンプトから改善できるもう1つの改善点は、最初に「システムプロンプト」を含めることです。 システム プロンプトは通常、実行されるタスクの概要を示す詳細なガイダンスをモデルに提供するために使用されます。 ユーザーはモデルとのチャットの過程で複数の質問をすることができますが、システムプロンプトはユーザーが質問する前に一度だけ提供されるため、基本的にチャット中に「システム」がどのように機能すべきかについての期待が確立されます。
これらの戦略を念頭に置いて、上部に大きな SQL コメント ブロックとして表されるシステム メッセージで始まり、その後に 3 つの例が続く、数回のプロンプトを作成できます。
この新しいプロンプトの結果、スコアは 70.8 となり、前回のスコアからさらに 3.8 ポイント改善されました。 私たちは、単純なプロンプト戦略だけで、開始したところからスコアを 10%近く 引き上げました。
Easy | Medium | Hard | Extra | All | |
---|---|---|---|---|---|
サンプル行による少数ショット | 83.9 | 79.1 | 55.7 | 44.6 | 70.8 |
おそらく、プロンプトを微調整することによる収穫逓減のポイントに達しているのでしょう。 モデルを微調整して、さらにどのようなメリットが得られるかを確認しましょう。
LoRAとの連携
モデルをファインチューニングする場合、最初の質問はどのトレーニング データを使用するかということです。 Spider にはトレーニング データセットが含まれているため、ここから始めるのが良いと思われます。 モデルを微調整するには、 Databricks の Standard_NC24ads_A100_v4 などの単一の A100 80GB Databricks GPU クラスターでモデルを効率的にトレーニングできるように QLoRA を使用します。 これは、Spider トレーニング データセットの 7,000 件のレコードを使用して約 4 時間で完了できます。 以前のブログ記事で、LoRA によるファインチューニングについて説明しました。 興味のある読者は、その投稿で詳細を参照してください。 trl、peft、bitsandbytes ライブラリを使用して、標準的なトレーニング レシピに従うことができます。
Spider からトレーニング レコードを取得していますが、モデルが学習できるような形式でフォーマットする必要があります。 目標は、スキーマ (サンプル行を含む)、質問、SQL で構成される各レコードを 1 つのテキスト文字列にマップすることです。 まず、生の Spider データセットに対していくつかの処理を実行します。 生データから、各レコードがschema_with_rows、question
、query
の 3 つのフィールドで構成されるデータセットを生成します。 schema_with_rows
フィールドは、以前の数回のプロンプトで使用されたCREATE TABLE
ステートメントと行のフォーマットに従って、質問に対応するテーブルから取得されます。
次に、トークナイザーをロードします。
処理済みの Spider トレーニング データセットの各レコードをテキスト文字列に変換するマッピング関数を定義します。 トークナイザーから apply_chat_template
を使用して、テキストを Instruct モデルが期待するチャット形式に便利にフォーマットできます。 これは、few-shot プロンプトに使用している形式とまったく同じではありませんが、プロンプトの定型文形式がわずかに異なっていても、モデルは十分に一般化されています。
SYSTEM_PROMPTには、前のfew-shotプロンプトで使用したのと同じシステムプロンプトを使用します。 USER_MESSAGE_FORMAT の場合も同様に使用します。
この関数を定義したら、あとは処理された Spider データセットをそれを使って変換し、JSONL ファイルとして保存するだけです。
トレーニングする準備が整いました。 数時間後、微調整されたLlama3 8B Instructが手に入りました。 この新しいモデルで数ショットのプロンプトを再実行すると、スコアは 79.9になり、以前のスコアからさらに 9ポイント 改善されました。 これで、合計スコアが単純なゼロショットのベースラインよりも~19パーセントポイント 上昇しました。
Easy | Medium | Hard | Extra | All | |
---|---|---|---|---|---|
サンプル行による少数ショット (微調整されたLlama3 8B指示) |
91.1 | 85.9 | 72.4 | 54.8 | 79.9 |
サンプル行による少数ショット (Llama3 8B指示) |
83.9 | 79.1 | 55.7 | 44.6 | 70.8 |
サンプル行によるゼロショット (Llama3 8B指示) |
80.6 | 75.3 | 51.1 | 41.0 | 67.0 |
ゼロショット (Llama3 8B指示) |
78.6 | 69.3 | 42.5 | 31.3 | 60.9 |
Llama3 8B Instructモデルと微調整バージョンが、Llama3 70B Instructなどの大型モデルとどのように比較されるのか疑問に思われるかもしれません。 8 つの A100 40 GB GPU を搭載した開発データセットで既製の 70B モデルを使用して評価プロセスを繰り返し、以下の結果を記録しました。
サンプル行による少数ショット (Llama3 70B指示) |
89.5 | 83.0 | 64.9 | 53.0 | 76.7 |
サンプル行によるゼロショット (Llama3 70B指示) |
83.1 | 81.8 | 59.2 | 36.7 | 71.1 |
ゼロショット (Llama3 70B指示) |
82.3 | 80.5 | 57.5 | 31.9 | 69.2 |
予想通り、既製のモデルを比較すると、同じプロンプト形式で測定すると、70Bモデルが8Bモデルを上回っています。 しかし、驚くべきことに、微調整されたLlama3 8B Instructモデルは、Llama3 70B Instructモデルよりも 3%高いスコアを獲得しています。 テキストから SQL への変換などの特定のタスクに重点を置くと、ファインチューニングによって、サイズがはるかに大きいモデルとパフォーマンスが同等の小さなモデルを作成できます。
モデルサービングエンドポイントにデプロイする
Llama3 は Mosaic AI Model Serving でサポートされているため、微調整された Llama3 モデルをエンドポイントにデプロイし、アプリケーションを強化するために使用することもできます。 必要なのは、微調整されたモデルをUnity Catalogに記録し、UI を使用してエンドポイントを作成することだけです。 デプロイされると、共通ライブラリを使用してクエリを実行できます。
まとめ
私たちは、ゼロショットプロンプトを使用して Spider 開発データセットで Llama3 8B Instruct の取り組みを開始し、60.9 という控えめなスコアを達成しました。 システムメッセージ、複数の例、サンプル行を含む数発のプロンプトでこれを強化することで、スコアを70.8に向上させました。 Spider トレーニング データセットでモデルをファインチューニングすることで、Spider dev で79.9 、Spider test で78.9 という素晴らしいスコアを達成し、さらなる向上が実現しました。 スタート地点から19ポイントの大幅な上昇 と、ベースのLlama3 70B Instructに対する3ポイントのリードは、私たちのモデルの実力を示すだけでなく、Spiderのトップ10リザルトで切望されていた場所を確保することにもなります。
オープンソース LLM とデータ インテリジェンス プラットフォームのパワーを活用する方法の詳細については、Data+ AI Summitに登録してください。
Appendix
評価セットアップ
生成は、vLLM、貪欲デコード(温度 0)、2 つの A100 80 GB GPU、および最大 1024 個の新しいトークンを使用して実行されました。 世代を評価するために、Github のtaoyds/test-suite-sql-evalリポジトリのテスト スイートを使用しました。
トレーニングのセットアップ
ファインチューニング設定に関する具体的な詳細は次のとおりです。
ベースモデル | Llama3 8B Instruct |
GPUs | シングルA100 80GB |
最大ステップ | 100 |
Spider トレーニング データセット | 7000 |
Lora R | 16 |
Lora Alpha | 32 |
Lora Dropout | 0.1 |
学習率 | 1.5e-4 |
学習率スケジューラ | 定数 |
勾配累積ステップ | 8 |
グラジエントチェックポイント | True |
トレーニングするバッチサイズ | 12 |
LoRAターゲットモジュール | q_proj、v_proj、k_proj、o_proj、gate_proj、up_proj、down_proj |
Data Collator 応答テンプレート | <|start_header_id|>assistant<|end_header_id|> |
ゼロショット プロンプトの例
これは、テーブル スキーマを含むゼロ ショット プロンプトとしてフォーマットされた、評価に使用した開発データセットの最初のレコードです。 質問の対象となっているテーブルは、そのテーブルを作成したCREATE TABLE
ステートメントを使用して表されます。
サンプル行を使用したゼロショットプロンプトの例
これは、テーブル スキーマとサンプル行を含むゼロ ショット プロンプトとしてフォーマットされた、評価に使用した開発データセットの最初のレコードです。 質問の対象となっているテーブルは、そのテーブルを作成した CREATE TABLE ステートメントを使用して表されます。 行は、各テーブルから "SELECT * {table_name} LIMIT 3" を使用して選択され、列名がヘッダーとして表示されます。
サンプル行を使用した数ショットのプロンプトの例
これは、テーブル スキーマとサンプル行を含む、数回のプロンプトとしてフォーマットされた、評価に使用した開発データセットの最初のレコードです。 質問の対象となっているテーブルは、そのテーブルを作成した CREATE TABLE ステートメントを使用して表されます。 行は、各テーブルから "SELECT * {table_name} LIMIT 3" を使用して選択され、列名がヘッダーとして表示されます。