欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > Rhymes AI发布首款开源多模态AI模型Aria 性能超越GPT-4o mini等多家知名AI模型

Rhymes AI发布首款开源多模态AI模型Aria 性能超越GPT-4o mini等多家知名AI模型

2024/10/23 15:21:08 来源:https://blog.csdn.net/weixin_41446370/article/details/142891057  浏览:    关键词:Rhymes AI发布首款开源多模态AI模型Aria 性能超越GPT-4o mini等多家知名AI模型

最近,日本东京的初创公司 Rhymes AI 推出了他们的首款人工智能模型 ——Aria。该公司自称,Aria 是全球首个开源的多模态混合专家(MoE)模型。这个模型不仅具有处理多种输入模态的能力,还声称在能力上与一些知名的商业模型不相上下,甚至更胜一筹。

Aria 的设计理念是希望能够在文本、代码、图像和视频等多种输入形式上,提供卓越的理解和处理能力。与传统的 Transformer 模型不同,MoE 模型通过多个专业的专家来替代其前馈层。当处理每个输入令牌时,一个路由模块会选择一部分专家进行激活,从而提高计算效率,减少每个令牌的激活参数数量。

在这里插入图片描述
Aria 的解码器每个文本令牌可以激活35亿个参数,整个模型拥有249亿个参数。为了处理视觉输入,Aria 还设计了一款轻量级的视觉编码器,拥有4.38亿个参数,可以将各种长度、大小和纵横比的视觉输入转换为视觉令牌。此外,Aria 的多模态上下文窗口达到64,000个令牌,意味着它能处理更长的输入数据。

在这里插入图片描述
在训练方面,Rhymes AI 共分为四个阶段,先用文本数据进行预训练,再引入多模态数据,接着是长序列的训练,最后进行微调。

在此过程中,Aria 总共使用了6.4万亿个文本令牌和4000亿个多模态令牌进行预训练,数据来自 Common Crawl 和 LAION 等知名数据集,并进行了部分合成增强。

根据相关基准测试,Aria 在多个多模态、语言和编程任务中表现优于 Pixtral-12B 和 Llama-3.2-11B 等模型,并且因激活参数较少,推理成本也较低。

此外,Aria 在处理带有字幕的视频或多页文档时表现良好,其理解长视频和文档的能力超过了 GPT-4o mini 和 Gemini1.5Flash 等其他开源模型。

在这里插入图片描述
为便于使用,Rhymes AI 将 Aria 的源代码以 Apache2.0许可证形式发布在 GitHub 上,支持学术和商业使用。同时,他们还提供了一个训练框架,可以在单个 GPU 上对 Aria 进行多种数据源和格式的微调。值得一提的是,Rhymes AI 与 AMD 达成了合作,以优化模型性能,展示了一款名为 BeaGo 的搜索应用,该应用能够在 AMD 硬件上运行,为用户提供更全面的文本和图像 AI 搜索结果。

Quick Start

pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torchvision requests torch Pillow
pip install flash-attn --no-build-isolation# For better performance, you can install grouped-gemm, which may take 3-5 minutes to install
pip install grouped_gemm==0.1.6
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessormodel_id_or_path = "rhymes-ai/Aria"model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"image = Image.open(requests.get(image_path, stream=True).raw)messages = [{"role": "user","content": [{"text": None, "type": "image"},{"text": "what is the image?", "type": "text"},],}
]text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.bfloat16):output = model.generate(**inputs,max_new_tokens=500,stop_strings=["<|im_end|>"],tokenizer=processor.tokenizer,do_sample=True,temperature=0.9,)output_ids = output[0][inputs["input_ids"].shape[1]:]result = processor.decode(output_ids, skip_special_tokens=True)print(result)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com