Documentation
搜索文档…
Keras
利用Keras的回调函数callback,自动保存所有指标和model.fit跟踪的损失值。
example.py
1
import wandb
2
from wandb.keras import WandbCallback
3
wandb.init(config={"hyper": "parameter"})
4
5
# Magic
6
7
model.fit(X_train, y_train, validation_data=(X_test, y_test),
8
callbacks=[WandbCallback()])
Copied!
colab笔记本中尝试我们的集成,并提供了完整的视频教程,或查看我们的示例项目以获得完整的脚本示例。

选项

Keras的类WandbCallback()支持许多选项:
关键字参数
默认值
说明
monitor
val_loss
用来评估性能的训练指标,以便于保存最佳模型,例如val_loss。
mode
auto
“min”,“max”或“auto”:如何在不同步(step)之间比较monitor指定的训练指标。
save_weights_only
False
只保存权重而不是整个模型
save_model
True
如果每一步(step)都有改进,就保存该模型。
log_weights
False
记录各层参数在每个周期(epoch)的参数值
log_gradients
False
记录每个周期中各层的参数梯度。
training_data
None
需要多元组(x,y)用于计算梯度。
data_type
None
我们正保存的数据类型,目前仅支持图像“image”。
labels
None
只有指定了data_type才会用到,如果你要做分类器,要把数字输出转化为标签列表。(支持二元分类器。
predictions
36
如果指定了data_type,预测的次数。最大为100
generator
None
如果用数据扩增和data_type,你可以指定一个生成器来做预测。

常见问题

通过wandb使用Keras的multiprocessing

如果你设置use_multiprocessing=True,然后收到错误Error('You must call wandb.init() before wandb.config.batch_size'),意思为错误(‘你必须在wandb.config.batch_size前调用wandb.init()’),那么可以尝试下列方法:
  1. 1.
    在Sequence类init中, 添加: wandb.init(group='...')
  2. 2.
    在主程序中,确保使用if __name__ == "__main__":然后把你剩下的脚本逻辑的部分放进去。

示例

我们已经为你创建了一些示例,以了解集成的工作原理:
  • Github上的示例:Python脚本中的Fashion MNIST示例
  • 在Google Colab中运行: 一个简单的笔记本示例让你入门
  • Wandb仪表盘:在W&B上查看结果
最近更新 7mo ago