本文翻译整理自:https://github.com/replicate/replicate-python
文章目录
- 一、关于 Replicate Python 客户端
- 相关链接资源
- 关键功能特性
- 二、1.0.0 版本的重大变更
- 三、安装与配置
- 1、系统要求
- 2、安装
- 3、认证配置
- 四、核心功能
- 1、运行模型
- 2、异步IO支持
- 3、流式输出模型
- 4、后台运行模型
- 5、后台运行模型并获取Webhook
- 6、组合模型管道
- 7、获取运行中模型的输出
- 8、取消预测
- 9、列出预测
- 10、加载输出文件
- FileOutput 对象
- 11、列出模型
- 12、创建模型
- 13、微调模型
- 14、自定义客户端行为
- 五、开发
一、关于 Replicate Python 客户端
这是一个用于 Replicate 的 Python 客户端库,允许您从 Python 代码或 Jupyter Notebook 中运行模型,并在 Replicate 平台上执行各种操作。
相关链接资源
- github : https://github.com/replicate/replicate-python
- 官网:https://replicate.com
- 官方文档:https://replicate.com/docs
- 训练API文档:https://replicate.com/docs/fine-tuning
- Webhooks文档:https://replicate.com/docs/webhooks
- 流式输出文档:https://replicate.com/docs/streaming
- Colab教程:https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c
关键功能特性
- 运行模型预测
- 流式输出处理
- 后台模型执行
- 模型管道组合
- 训练自定义模型
- 预测管理(取消/列表)
- 异步IO支持
- Webhook集成
二、1.0.0 版本的重大变更
1.0.0 版本包含以下破坏性变更:
- 对于输出文件的模型,
replicate.run()
方法现在默认返回FileOutput
对象而非 URL 字符串。FileOutput
实现了类似httpx.Response
的可迭代接口,使文件处理更高效。
如需恢复旧行为,可通过传递 use_file_output=False
参数禁用 FileOutput
:
output = replicate.run("acmecorp/acme-model", use_file_output=False)
在大多数情况下,更新现有应用程序以调用 output.url
即可解决问题。
但我们建议直接使用 FileOutput
对象,因为我们计划对该 API 进行进一步改进,这种方法能确保获得最快的处理结果。
[!TIP]
👋 查看本教程的交互式版本:Google Colabhttps://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c
三、安装与配置
1、系统要求
- Python 3.8+
2、安装
pip install replicate
3、认证配置
在使用 API 运行任何 Python 脚本前,需设置环境变量中的 Replicate API 令牌。
从 replicate.com/account 获取令牌并设置为环境变量:
export REPLICATE_API_TOKEN=<your token>
我们建议不要直接将令牌添加到源代码中,因为您不希望将凭证提交到版本控制系统。如果任何人使用您的 API 密钥,其使用量将计入您的账户。
四、核心功能
1、运行模型
创建新的 Python 文件并添加以下代码,替换为您自己的模型标识符和输入:
>>> import replicate
>>> outputs = replicate.run("black-forest-labs/flux-schnell", input={"prompt": "astronaut riding a rocket like a horse"})
[<replicate.helpers.FileOutput object at 0x107179b50>]
>>> for index, output in enumerate(outputs):with open(f"output_{index}.webp", "wb") as file:file.write(output.read())
如果预测失败,replicate.run
会抛出 ModelError
异常。您可以通过异常的 prediction
属性获取更多失败信息。
import replicate
from replicate.exceptions import ModelErrortry:output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
except ModelError as eif "(some known issue)" in e.prediction.logs:passprint("Failed prediction: " + e.prediction.id)
[!NOTE]
默认情况下,Replicate 客户端会保持连接打开最多 60 秒,等待预测完成。这种设计是为了优化模型输出返回客户端的速度。可通过传递
wait=x
给replicate.run()
来配置超时,其中x
是 1 到 60 秒之间的超时值。要禁用同步模式,可传递wait=False
。
2、异步IO支持
通过在方法名前添加 async_
前缀,您也可以异步使用 Replicate 客户端。
以下是并发运行多个预测并等待它们全部完成的示例:
import asyncio
import replicate# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [f"A chariot pulled by a team of {count} rainbow unicorns"for count in ["two", "four", "six", "eight"]
]async with asyncio.TaskGroup() as tg:tasks = [tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))for prompt in prompts]results = await asyncio.gather(*tasks)
print(results)
对于需要文件输入的模型,您可以传递互联网上可公开访问文件的 URL,或本地设备上的文件句柄:
>>> output = replicate.run("andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9", input={ "image": open("path/to/mystery.jpg") })"an astronaut riding a horse"
3、流式输出模型
Replicate 的 API 支持语言模型的服务器发送事件流(SSEs)。使用 stream
方法可以实时消费模型生成的标记。
import replicatefor event in replicate.stream("meta/meta-llama-3-70b-instruct", input={"prompt": "Please write a haiku about llamas.", }, ):print(str(event), end="")
[!TIP]
某些模型如 meta/meta-llama-3-70b-instruct 不需要版本字符串。您始终可以参考模型页面上的 API 文档了解具体细节。
您也可以流式传输已创建预测的输出。这在您希望将预测 ID 与其输出分开时很有用。
prediction = replicate.predictions.create(model="meta/meta-llama-3-70b-instruct", input={"prompt": "Please write a haiku about llamas."}, stream=True, )for event in prediction.stream():print(str(event), end="")
更多信息请参阅 Replicate 文档中的"流式输出"。
4、后台运行模型
您可以使用异步模式在后台启动并运行模型:
>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(version=version, input={"prompt":"Watercolor painting of an underwater submarine"})>>> prediction
Prediction(...)>>> prediction.status
'starting'>>> dict(prediction)
{"id": "...", "status": "starting", ...}>>> prediction.reload()
>>> prediction.status
'processing'>>> print(prediction.logs)
iteration: 0, render:loss: -0.6171875
iteration: 10, render:loss: -0.92236328125
iteration: 20, render:loss: -1.197265625
iteration: 30, render:loss: -1.3994140625>>> prediction.wait()>>> prediction.status
'succeeded'>>> prediction.output
<replicate.helpers.FileOutput object at 0x107179b50>>>> with open("output.png", "wb") as file:file.write(prediction.output.read())
5、后台运行模型并获取Webhook
您可以运行模型并在完成时获取 webhook,而不是等待它完成:
model = replicate.models.get("ai-forever/kandinsky-2.2")
version = model.versions.get("ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463")
prediction = replicate.predictions.create(version=version, input={"prompt":"Watercolor painting of an underwater submarine"}, webhook="https://example.com/your-webhook", webhook_events_filter=["completed"]
)
有关接收 webhook 的详细信息,请参阅 replicate.com/docs/webhooks。
6、组合模型管道
您可以运行一个模型并将其输出作为另一个模型的输入:
laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05")
swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a")
image = laionide.predict(prompt="avocado armchair")
upscaled_image = swinir.predict(image=image)
7、获取运行中模型的输出
在模型运行时获取其输出:
iterator = replicate.run("pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf", input={"prompts": "san francisco sunset"}
)for index, image in enumerate(iterator):with open(f"file_{index}.png", "wb") as file:file.write(image.read())
8、取消预测
您可以取消正在运行的预测:
>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(version=version, input={"prompt":"Watercolor painting of an underwater submarine"})>>> prediction.status
'starting'>>> prediction.cancel()>>> prediction.reload()
>>> prediction.status
'canceled'
9、列出预测
您可以列出所有运行过的预测:
replicate.predictions.list()
# [<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]
预测列表是分页的。您可以通过将 next
属性作为参数传递给 list
方法来获取下一页预测:
page1 = replicate.predictions.list()if page1.next:page2 = replicate.predictions.list(page1.next)
10、加载输出文件
输出文件作为 FileOutput
对象返回:
import replicate
from PIL import Image # pip install pillowoutput = replicate.run("stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", input={"prompt": "wavy colorful abstract patterns, oceans"})# 具有返回二进制数据的.read()方法
with open("my_output.png", "wb") as file:file.write(output[0].read())# 也实现了迭代器协议以流式传输数据
background = Image.open(output[0])
FileOutput 对象
FileOutput
是从 replicate.run()
方法返回的类文件对象,使处理输出文件的模型更容易使用。它实现了 Iterator
和 AsyncIterator
用于分块读取文件数据,以及 read()
和 aread()
方法将整个文件读入内存。
[!NOTE]
值得注意的是,目前read()
和aread()
不接受size
参数来读取最多size
字节。
最后,底层数据源的 URL 可通过 url
属性获得,但我们建议您将对象用作迭代器或使用其 read()
或 aread()
方法,因为 url
属性在未来可能不总是返回 HTTP URL。
print(output.url) #=> "data:image/png;base64,xyz123..." or "https://delivery.replicate.com/..."
要直接消费文件:
with open('output.bin', 'wb') as file:file.write(output.read())
对于非常大的文件,可以流式传输:
with open(file_path, 'wb') as file:for chunk in output:file.write(chunk)
每种方法都有对应的 asyncio
API:
async with aiofiles.open(filename, 'w') as file:await file.write(await output.aread())async with aiofiles.open(filename, 'w') as file:await for chunk in output:await file.write(chunk)
对于来自常见框架的流式响应,都支持接受 Iterator
类型:
Django
@condition(etag_func=None)
def stream_response(request):output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output =True)return HttpResponse(output, content_type='image/webp')
FastAPI
@app.get("/")
async def main():output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output =True)return StreamingResponse(output)
Flask
@app.route('/stream')
def streamed_response():output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output =True)return app.response_class(stream_with_context(output))
您可以通过向 replicate.run()
方法传递 use_file_output=False
来禁用 FileOutput
:
const replicate = replicate.run("acmecorp/acme-model", use_file_output=False);
11、列出模型
您可以列出您创建的模型:
replicate.models.list()
模型列表是分页的。您可以通过将 next
属性作为参数传递给 list
方法来获取下一页模型,或者使用 paginate
方法自动获取页面。
# 使用 `replicate.paginate` 自动分页(推荐)
models = []
for page in replicate.paginate(replicate.models.list):models.extend(page.results)if len(models) > 100:break# 使用 `next` 游标手动分页
page = replicate.models.list()
while page:models.extend(page.results)if len(models) > 100:breakpage = replicate.models.list(page.next) if page.next else None
您还可以在 Replicate 上找到精选模型集合:
>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
>>> collections[0].slug
"vision-models"
>>> collections[0].description
"Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)">>> replicate.collections.get("text-to-image").models
[<Model: stability-ai/sdxl>, ...]
12、创建模型
您可以为用户或组织创建具有给定名称、可见性和硬件 SKU 的模型:
import replicatemodel = replicate.models.create(owner="your-username", name="my-model", visibility="public", hardware="gpu-a40-large"
)
以下是列出 Replicate 上可用于运行模型的所有可用硬件的方法:
>>> [hw.sku for hw in replicate.hardware.list()]
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']
13、微调模型
使用训练API微调模型,使其在特定任务上表现更好。要查看当前支持微调的语言模型,请查看 Replicate 的可训练语言模型集合。
如果您想微调图像模型,请查看 Replicate 的图像模型微调指南。
以下是在 Replicate 上微调模型的方法:
training = replicate.trainings.create(model="stability-ai/sdxl", version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={"input_images": "https://my-domain/training-images.zip", "token_string": "TOK", "caption_prefix": "a photo of TOK", "max_train_steps": 1000, "use_face_detection_instead": False}, # 您需要在 Replicate 上创建一个模型作为训练版本的接收方destination="your-username/model-name"
)
14、自定义客户端行为
replicate
包导出一个默认的共享客户端。此客户端使用 REPLICATE_API_TOKEN
环境变量设置的 API 令牌初始化。
您可以创建自己的客户端实例以传递不同的 API 令牌值,向请求添加自定义标头,或控制底层 HTTPX 客户端的行为:
import os
from replicate.client import Clientreplicate = Client(api_token=os.environ["SOME_OTHER_REPLICATE_API_TOKEN"]headers={"User-Agent": "my-app/1.0"}
)
[!WARNING]
切勿将 API 令牌等认证凭证硬编码到代码中。
相反,在运行程序时将它们作为环境变量传递。
五、开发
参见 CONTRIBUTING.md
伊织 xAI 2024-04-19(六)