安装:
pip install wandb
登录
wanbd login
在terminal中操作查看你的API key并粘贴回车进行授权(https://wandb.ai/authorize)
设置离线模式/在线模式
设置为offline会在无网络(内网)的时候使用,常用于debug的时候使用,因为这样启动速度快
注意:设置offline要在wandb init之前使用,否则不起作用
# 设置为离线模式,常用于测试、debug的时候,因为在线模式启动速度慢
import os
os.environ["WANDB_MODE"]="offline"
代码中初始化wandb
# 初始化wandb
wandb.init(project="config_example")
如果记录参数和日志?
1)记录运行args参数
有很多方式,但是最常用的方式是直接通过一行代码上传args:
# 储存运行参数:将参数值转为dict,然后再储存
wandb.config.update(vars(args))
上传后的参数储存在overview中:
2)记录运行log + images
wandb.log()记录这些值
然后image要通过一步wandb.Image(image)转换才可以存储
# 储存运行过程中的图像:随机生成一个图像作为示例data = np.random.rand(256, 256, 3) * 255data = data.astype(np.uint8)image = Image.fromarray(data, 'RGB')# 储存运行过程中的loss等日志wandb.log({"epoch": epoch,"train_acc": train_acc,"train_loss": train_loss,"val_acc": val_acc,"val_loss": val_loss,'images': wandb.Image(image),})
3)记录某些文本
记录总结性的文本
例如:参数量
wandb.run.summary['Trainable parameters'] = f"{n_params / 1e6}M"
记录带有格式的文本
某些时候可能需要记录model的构造等等,我们需要使用:
wandb.log({"Model_architecture": wandb.Table(columns=["Model_architecture"], data=[[str(model_without_ddp)]])})
结果查看:
运行代码后,会出现日志,直接点击本次运行结果的连接即可
完整示例代码执行:
import wandb
import argparse
import numpy as np
import random
from PIL import Image# 初始化wandb
wandb.init(project="config_example")def train_one_epoch(epoch, lr, bs):acc = 0.25 + ((epoch / 30) + (random.random() / 10))loss = 0.2 + (1 - ((epoch - 1) / 10 + random.random() / 5))return acc, lossdef evaluate_one_epoch(epoch):acc = 0.1 + ((epoch / 20) + (random.random() / 10))loss = 0.25 + (1 - ((epoch - 1) / 10 + random.random() / 6))return acc, lossdef main(args):# 储存运行参数:将参数值转为dict,然后再储存wandb.config.update(vars(args))for epoch in np.arange(1, args.epochs):train_acc, train_loss = train_one_epoch(epoch, args.learning_rate, args.batch_size)val_acc, val_loss = evaluate_one_epoch(epoch)# 储存运行过程中的图像:随机生成一个图像作为示例data = np.random.rand(256, 256, 3) * 255data = data.astype(np.uint8)image = Image.fromarray(data, 'RGB')# 储存运行过程中的loss等日志wandb.log({"epoch": epoch,"train_acc": train_acc,"train_loss": train_loss,"val_acc": val_acc,"val_loss": val_loss,'images': wandb.Image(image),})if __name__ == "__main__":parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument("--batch_size", type=int, default=32, help="Batch size")parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")parser.add_argument("--learning_rate", type=int, default=0.001, help="Learning rate")args = parser.parse_args()main(args)