欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 教育 > 高考 > Wandb使用指南

Wandb使用指南

2025/2/26 15:21:44 来源:https://blog.csdn.net/weixin_43135178/article/details/141276464  浏览:    关键词:Wandb使用指南

安装:

pip install wandb

登录

wanbd login

在terminal中操作查看你的API key并粘贴回车进行授权(https://wandb.ai/authorize)

设置离线模式/在线模式

设置为offline会在无网络(内网)的时候使用,常用于debug的时候使用,因为这样启动速度快

注意:设置offline要在wandb init之前使用,否则不起作用

# 设置为离线模式,常用于测试、debug的时候,因为在线模式启动速度慢
import os
os.environ["WANDB_MODE"]="offline"

代码中初始化wandb

# 初始化wandb
wandb.init(project="config_example")

如果记录参数和日志?

1)记录运行args参数

有很多方式,但是最常用的方式是直接通过一行代码上传args:

# 储存运行参数:将参数值转为dict,然后再储存
wandb.config.update(vars(args))

上传后的参数储存在overview中: 

2)记录运行log + images

wandb.log()记录这些值

然后image要通过一步wandb.Image(image)转换才可以存储

        # 储存运行过程中的图像:随机生成一个图像作为示例data = np.random.rand(256, 256, 3) * 255data = data.astype(np.uint8)image = Image.fromarray(data, 'RGB')# 储存运行过程中的loss等日志wandb.log({"epoch": epoch,"train_acc": train_acc,"train_loss": train_loss,"val_acc": val_acc,"val_loss": val_loss,'images': wandb.Image(image),})

  3)记录某些文本

记录总结性的文本

例如:参数量

wandb.run.summary['Trainable parameters'] = f"{n_params / 1e6}M"

记录带有格式的文本 

某些时候可能需要记录model的构造等等,我们需要使用:

    wandb.log({"Model_architecture": wandb.Table(columns=["Model_architecture"], data=[[str(model_without_ddp)]])})

结果查看: 

运行代码后,会出现日志,直接点击本次运行结果的连接即可

完整示例代码执行:

import wandb
import argparse
import numpy as np
import random
from PIL import Image# 初始化wandb
wandb.init(project="config_example")def train_one_epoch(epoch, lr, bs):acc = 0.25 + ((epoch / 30) + (random.random() / 10))loss = 0.2 + (1 - ((epoch - 1) / 10 + random.random() / 5))return acc, lossdef evaluate_one_epoch(epoch):acc = 0.1 + ((epoch / 20) + (random.random() / 10))loss = 0.25 + (1 - ((epoch - 1) / 10 + random.random() / 6))return acc, lossdef main(args):# 储存运行参数:将参数值转为dict,然后再储存wandb.config.update(vars(args))for epoch in np.arange(1, args.epochs):train_acc, train_loss = train_one_epoch(epoch, args.learning_rate, args.batch_size)val_acc, val_loss = evaluate_one_epoch(epoch)# 储存运行过程中的图像:随机生成一个图像作为示例data = np.random.rand(256, 256, 3) * 255data = data.astype(np.uint8)image = Image.fromarray(data, 'RGB')# 储存运行过程中的loss等日志wandb.log({"epoch": epoch,"train_acc": train_acc,"train_loss": train_loss,"val_acc": val_acc,"val_loss": val_loss,'images': wandb.Image(image),})if __name__ == "__main__":parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument("--batch_size", type=int, default=32, help="Batch size")parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")parser.add_argument("--learning_rate", type=int, default=0.001, help="Learning rate")args = parser.parse_args()main(args)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词