使用自定义大模型来部署Wren AI(开源的文本生成SQL方案)
关于
- 首次发表日期:2024-07-15
- Wren AI官方文档: https://docs.getwren.ai/overview/introduction
- Wren AI Github仓库: https://github.com/Canner/WrenAI
关于Wren AI
Wren AI 是一个开源的文本生成SQL解决方案。
前提准备
由于之后会使用 Docker(docker-compose)来启动服务,所以首先要确保 Docker 和 docker-compose 都已经安装好,并且网络通畅(需要拉取镜像)。
先克隆仓库:
git clone https://github.com/Canner/WrenAI.git
关于在Wren AI中使用自定义大模型和Embedding模型
Wren AI目前是支持自定义LLM和Embedding模型的,其官方文档 https://docs.getwren.ai/installation/custom_llm 中有提及,需要创建自己的provider类。
其中,Wren AI 本身已经支持兼容 OpenAI 接口的大模型;但在使用自定义的 Embedding 模型时可能会报错,问题出在 `wren-ai-service/src/providers/embedder/openai.py` 中的以下代码:
if self.dimensions is not None:
    response = await self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=text_to_embed)
else:
    response = await self.client.embeddings.create(model=self.model, input=text_to_embed)
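其中 `if self.dimensions is not None` 这个条件分支会报错,而且默认就会进入这个分支。原因在于 OpenAI 的 `dimensions` 参数只有 text-embedding-3 及更新的模型才支持,很多兼容 OpenAI 接口的 Embedding 服务并不接受它;而 provider 默认总会传入一个非空的维度值。所以我的临时解决方案是注释掉这个分支,只保留不带 `dimensions` 的调用。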
具体做法是在 `wren-ai-service/src/providers/embedder` 文件夹中创建一个 `openai_like.py` 文件,定义一个与 OpenAI 类似的 embedding provider,取名为 `openai_like_embedder`,完整代码见本文附录。
配置docker环境变量等并启动服务
首先,进入 `docker` 文件夹,将 `.env.example` 拷贝一份并重命名为 `.env.local`。
然后将 `.env.ai.example` 拷贝一份并重命名为 `.env.ai`,修改其中 LLM 和 Embedding 的配置,相关部分如下:
LLM_PROVIDER=openai_llm
LLM_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxx
LLM_OPENAI_API_BASE=https://api.siliconflow.cn/v1
GENERATION_MODEL=meta-llama/Meta-Llama-3-70B
# GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 32768, "response_format": {"type": "json_object"}}
EMBEDDER_PROVIDER=openai_like_embedder
EMBEDDING_MODEL=bge-m3
EMBEDDING_MODEL_DIMENSION=1024
EMBEDDER_OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxx
EMBEDDER_OPENAI_API_BASE=https://xxxxxxxxxxxxxxxx/v1
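由于我们创建了一个自定义的 embedding provider,需要把代码文件映射到 docker 容器中:在 `docker-compose.yaml` 的 `wren-ai-service` 服务下添加 `volumes` 属性。其中 `/root/WrenAI` 是仓库的克隆路径,请按实际位置修改。注意这会用本地仓库的 `src` 目录覆盖镜像中的代码,因此本地代码版本应与 `WREN_AI_SERVICE_VERSION` 对应: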
wren-ai-service:
  image: ghcr.io/canner/wren-ai-service:${WREN_AI_SERVICE_VERSION}
  volumes:
    - /root/WrenAI/wren-ai-service/src:/src
最后,启动服务:
docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai up -d
或者停止服务:
docker-compose -f docker-compose.yaml -f docker-compose.llm.yaml --env-file .env.local --env-file .env.ai down
附录
`openai_like.py` 文件(提供自定义 embedding 服务):
"""Embedder provider ``openai_like_embedder`` for OpenAI-compatible embedding APIs.

It mirrors an OpenAI embedder with one deliberate difference: the
``dimensions`` argument is never forwarded to ``embeddings.create``.
OpenAI only accepts that argument for its text-embedding-3 (and later)
models, and OpenAI-compatible endpoints serving other models may reject
requests that carry it.
"""

import logging
import os
import sys
from typing import Any, Dict, List, Optional, Tuple

import backoff
import openai
from haystack import Document, component
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.utils import Secret
from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm

from src.core.provider import EmbedderProvider
from src.providers.loader import provider

# Emit log records on stdout at DEBUG level. basicConfig itself attaches a
# stdout StreamHandler to the root logger; adding a second handler on top of
# it (as an earlier version did) printed every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

logger = logging.getLogger("wren-ai-service")

# Fallbacks used by OpenAIEmbedderProvider when the matching environment
# variables are not set.
EMBEDDER_OPENAI_API_BASE = "https://api.openai.com/v1"
EMBEDDING_MODEL = "text-embedding-3-large"
EMBEDDING_MODEL_DIMENSION = 3072


@component
class AsyncTextEmbedder(OpenAITextEmbedder):
    """Asynchronous single-string embedder for OpenAI-compatible endpoints.

    ``dimensions`` is still handed to the haystack base class, but it is never
    sent with the embeddings request (see the module docstring).
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
    ):
        super().__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
        )
        # Async client so that ``run`` can await the embeddings request.
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, text: str):
        """Embed one string.

        ``prefix`` and ``suffix`` are added around ``text`` and newlines are
        replaced by spaces before the request. Rate-limit errors are retried
        with exponential backoff (at most 3 tries / 60 seconds).

        :param text: the string to embed.
        :returns: ``{"embedding": <list of floats>, "meta": {"model": ..., "usage": ...}}``.
        :raises TypeError: if ``text`` is not a ``str``.
        """
        if not isinstance(text, str):
            raise TypeError(
                "OpenAITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the OpenAIDocumentEmbedder."
            )

        logger.debug("Running Async OpenAI text embedder with text: %s", text)

        text_to_embed = self.prefix + text + self.suffix

        # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
        # replace newlines, which can negatively affect performance.
        text_to_embed = text_to_embed.replace("\n", " ")

        # ``dimensions`` is intentionally omitted; see the module docstring.
        response = await self.client.embeddings.create(model=self.model, input=text_to_embed)

        meta = {"model": response.model, "usage": dict(response.usage)}
        return {"embedding": response.data[0].embedding, "meta": meta}


@component
class AsyncDocumentEmbedder(OpenAIDocumentEmbedder):
    """Asynchronous batch embedder for haystack ``Document`` lists.

    Like ``AsyncTextEmbedder``, it never sends ``dimensions`` with the
    embeddings request.
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        super().__init__(
            api_key,
            model,
            dimensions,
            api_base_url,
            organization,
            prefix,
            suffix,
            batch_size,
            progress_bar,
            meta_fields_to_embed,
            embedding_separator,
        )
        # Async client so that the batch requests can be awaited.
        self.client = AsyncOpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
        )

    async def _embed_batch(
        self, texts_to_embed: List[str], batch_size: int
    ) -> Tuple[List[List[float]], Dict[str, Any]]:
        """Embed ``texts_to_embed`` in chunks of ``batch_size``.

        :returns: the embeddings in input order, and a meta dict holding the
            model name of the first response and token usage summed over
            all batches (``prompt_tokens`` and ``total_tokens``).
        """
        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size),
            disable=not self.progress_bar,
            desc="Calculating embeddings",
        ):
            batch = texts_to_embed[i : i + batch_size]
            # ``dimensions`` is intentionally omitted; see the module docstring.
            response = await self.client.embeddings.create(model=self.model, input=batch)
            all_embeddings.extend(el.embedding for el in response.data)

            if "model" not in meta:
                meta["model"] = response.model
            if "usage" not in meta:
                meta["usage"] = dict(response.usage)
            else:
                # Later batches only add to the running token counts.
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
    async def run(self, documents: List[Document]):
        """Embed ``documents`` and store each vector on ``doc.embedding`` in place.

        The texts are built by ``_prepare_texts_to_embed`` (inherited from
        ``OpenAIDocumentEmbedder``). Rate-limit errors retry the whole call
        with exponential backoff (at most 3 tries / 60 seconds).

        :param documents: list of haystack ``Document`` objects (may be empty).
        :returns: ``{"documents": <the same list, now embedded>, "meta": ...}``.
        :raises TypeError: if ``documents`` is not a list of ``Document``.
        """
        if not isinstance(documents, list) or (
            documents and not isinstance(documents[0], Document)
        ):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        logger.debug("Running Async OpenAI document embedder with documents: %s", documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = await self._embed_batch(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}


@provider("openai_like_embedder")
class OpenAIEmbedderProvider(EmbedderProvider):
    """Embedder provider registered under the name ``openai_like_embedder``.

    ``api_base``, ``embedding_model`` and ``embedding_model_dim`` default to
    the ``EMBEDDER_OPENAI_API_BASE``, ``EMBEDDING_MODEL`` and
    ``EMBEDDING_MODEL_DIMENSION`` environment variables, with the module
    constants as fallback. The defaults are evaluated once, when this module
    is imported. ``api_key`` is a haystack ``Secret`` bound to
    ``EMBEDDER_OPENAI_API_KEY``.
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("EMBEDDER_OPENAI_API_KEY"),
        api_base: str = os.getenv("EMBEDDER_OPENAI_API_BASE") or EMBEDDER_OPENAI_API_BASE,
        embedding_model: str = os.getenv("EMBEDDING_MODEL") or EMBEDDING_MODEL,
        # An unset or "0" EMBEDDING_MODEL_DIMENSION falls back to the constant.
        embedding_model_dim: int = (
            int(os.getenv("EMBEDDING_MODEL_DIMENSION"))
            if os.getenv("EMBEDDING_MODEL_DIMENSION")
            else 0
        )
        or EMBEDDING_MODEL_DIMENSION,
    ):
        def _verify_api_key(api_key: str, api_base: str) -> None:
            """Fail fast on a bad key by listing models (temporary check)."""
            OpenAI(api_key=api_key, base_url=api_base).models.list()

        logger.info("Initializing OpenAIEmbedder provider with API base: %s", api_base)
        # TODO: currently only OpenAI api key can be verified
        if api_base == EMBEDDER_OPENAI_API_BASE:
            _verify_api_key(api_key.resolve_value(), api_base)
            logger.info("Using OpenAI Embedding Model: %s", embedding_model)
        else:
            logger.info("Using OpenAI API-compatible Embedding Model: %s", embedding_model)

        self._api_key = api_key
        self._api_base = api_base
        self._embedding_model = embedding_model
        self._embedding_model_dim = embedding_model_dim

    def get_text_embedder(self):
        """Return an ``AsyncTextEmbedder`` configured with this provider's settings."""
        return AsyncTextEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )

    def get_document_embedder(self):
        """Return an ``AsyncDocumentEmbedder`` configured with this provider's settings."""
        return AsyncDocumentEmbedder(
            api_key=self._api_key,
            api_base_url=self._api_base,
            model=self._embedding_model,
            dimensions=self._embedding_model_dim,
        )