欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > 通过 PromptTemplate 生成干净的 SQL 查询语句并执行SQL查询语句

通过 PromptTemplate 生成干净的 SQL 查询语句并执行SQL查询语句

2025/3/10 18:11:50 来源:https://blog.csdn.net/u013565133/article/details/145907208  浏览:    关键词:通过 PromptTemplate 生成干净的 SQL 查询语句并执行SQL查询语句

问题描述

在使用 LangChain 和 Llama 模型生成 SQL 查询时,遇到了 sqlite3.OperationalError 错误。错误信息如下:

OperationalError: (sqlite3.OperationalError) near "```sql
SELECT Name 
FROM MediaType 
LIMIT 5;
```": syntax error
[SQL: ```sql
SELECT Name 
FROM MediaType 
LIMIT 5;
```]

错误发生的原因是生成的 SQL 查询包含了不必要的 Markdown 代码块标记 ```,也就是在生成SQL语句的过程中,产生了其他的不干净文本,导致 SQL 语法错误。

最终解决方案

通过修改 PromptTemplate 来生成干净的 SQL 查询,确保生成的查询不包含任何 Markdown 代码块标记或附加评论。以下是解决方案的详细步骤和代码实现:

1. 初始化环境

首先,初始化所需的环境变量和模型:

import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-versatile", model_provider="groq", temperature=0)

2. 定义自定义提示模板

定义一个自定义的 PromptTemplate,用于生成干净的 SQL 查询:

custom_prompt = PromptTemplate(input_variables=["dialect", "input", "table_info", "top_k"],template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

3. 创建 SQL 查询链

创建一个 SQL 查询链,并使用自定义提示模板:

write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

4. 构造输入数据字典

构造输入数据字典,其中包含方言、表结构、问题和行数限制:

input_data = {"dialect": db.dialect,                    # 数据库方言,如 "sqlite""table_info": db.get_table_info(),        # 表结构信息"input": "What name of MediaType is?",    # 问题"top_k": 5                                # 行数限制
}

5. 调用链生成并执行 SQL 查询

调用链生成 SQL 查询,确保生成的查询不包含 Markdown 代码块标记,然后执行查询并打印结果:

response = write_query.invoke(input_data)
query = response["query"]# 执行 SQL 查询并打印结果
execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke({"query": query})
print(result)

总结

通过修改 PromptTemplate 来生成 SQL 查询时,明确要求返回的 SQL 查询不包含任何附加评论或 Markdown 格式,确保生成的 SQL 查询是干净的、可执行的。这样可以避免由多余的标记导致的 SQL 语法错误。

最后提供完整代码:

import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabaseload_dotenv()# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"])  # 注意需要传递列表
print(f"\n Original table info: {table_info}")#  初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
# 定义自定义提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(input_variables=["dialect", "input", "table_info", "top_k"],template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)write_query  = create_sql_query_chain(llm, db,prompt=custom_prompt)
# 构造输入数据字典,其中包含方言、表结构、问题和行数限制
input_data = {"dialect": db.dialect,                    # 数据库方言,如 "sqlite""table_info": db.get_table_info(),          # 表结构信息"question": "What name of MediaType is?","top_k": 5
}# 调用链生成 SQL 查询,返回结果为一个字典,包含键 "query"
write_query_response = write_query.invoke(input_data)
print('\n write_query result:',write_query_response)#执行SQL语句
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result:',execute_response)#两个动作合起来搞成链
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==',result_chain)

输出:
在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词