Documentation
搜索文档…
Import/Export API Guide
我们的公共API的最佳实践和常见用例——用于导出数据和更新现有的运行
使用公共API导出或更新已保存到W&B的数据。在使用API之前,您需要记录脚本中的数据——有关详情,请参见快速入门
公共API的用例
  • 导出数据:下拉一个数据帧,在Jupyter Notebook中进行自定义分析。研究完数据之后,您可以通过创建一个新的分析运行和记录结果来同步您的发现,例如:wandb.init(job_type="analysis")
  • 更新现有的运行:您可以更新与W&B运行相关联的记录数据。例如,您可能想要更新一组运行的配置,来纳入架构或最初时没有记录的超参数等额外信息。
有关可用函数的详细信息,请参阅生成的参考文档

验证

通过以下两种方式之一,使用API密钥对计算机进行身份验证:
  1. 1.
    在命令行上运行wandb login并贴入您的API密钥。
  2. 2.
    设置WANDB_API_KEY环境变量为您的API密钥。

导出运行数据

从已完成或活动的运行中下载数据。常见的用法包括下载数据帧,在jupiter笔记本中进行自定义分析,或者在自动化环境中使用自定义逻辑。
1
import wandb
2
api = wandb.Api()
3
run = api.run("<entity>/<project>/<run_id>")
Copied!
运行对象的最常用的属性是:
属性
定义
run.config
用于模型输入(例如超参数)的字典
run.history()
字典列表,用于存储模型训练时发生变化的值,例如损失。wandb.log()命令追加到该对象。
run.summary
输出内容的字典。可以是精度和损失等标量,也可以是大文件。默认情况下,wandb.log()将汇总设置为已记录时间序列的最终值。这也可以直接设置。
您还可以修改或更新过去运行的数据。默认情况下,api对象的单个实例将缓存所有的网络请求。如果您的用例需要提供在运行脚本中的实时信息,则可以调用api.flush()来获取更新的值。

采样

默认的历史方法将指标抽样到固定数量的样本中(默认是500,您也可以使用样本参数来更改此值)。如果要一次大批量导出所有数据,则可以使用run.scan_history()方法。更多详情,请参见API参考。

查询多次运行

数据帧和CSV
MongoDB Style
该示例脚本查找一个项目,然后输出运行的CSV文件,其中包含了名称、配置和汇总统计信息。
1
import wandb
2
api = wandb.Api()
3
4
# Change oreilly-class/cifar to <entity/project-name>
5
runs = api.runs("<entity>/<project>")
6
summary_list = []
7
config_list = []
8
name_list = []
9
for run in runs:
10
# run.summary are the output key/values like accuracy. We call ._json_dict to omit large files
11
summary_list.append(run.summary._json_dict)
12
13
# run.config is the input metrics. We remove special values that start with _.
14
config_list.append({k:v for k,v in run.config.items() if not k.startswith('_')})
15
16
# run.name is the name of the run.
17
name_list.append(run.name)
18
19
import pandas as pd
20
summary_df = pd.DataFrame.from_records(summary_list)
21
config_df = pd.DataFrame.from_records(config_list)
22
name_df = pd.DataFrame({'name': name_list})
23
all_df = pd.concat([name_df, config_df,summary_df], axis=1)
24
25
all_df.to_csv("project.csv")
Copied!
W&B API还提供了一种在项目中使用api.runs()进行跨运行查询的方法。最常见的用例是导出运行数据以进行自定义分析。查询界面与MongoDB使用的界面相同。
1
runs = api.runs("username/project", {"$or": [{"config.experiment_name": "foo"}, {"config.experiment_name": "bar"}]})
2
print("Found %i" % len(runs))
Copied!
调用api.runs(...)返回一个可迭代的运行对象,其作用类似于列表。对象按需要每次依次加载50次运行,您也可以使用per_page关键词参数来更改每页加载的次数。
api.runs(...)还接受order(顺序)关键词参数。默认顺序为-created_at,指定+created_at以按升序获取结果。您还可以按配置或汇总值(即summary.val_accconfig.experiment_name)进行排序。

错误处理

如果与W&B服务器通信时发生错误,将引发wandb.CommError。此时可以通过exc属性对原始异常进行自省。

通过API获取最新的git commit

在UI中,单击运行,然后单击运行页面上的“概述”选项卡以查看最新的git commit。它也在wandb-metadata.json文件中。使用公共API,您可以通过run.commit获得git哈希。

常见问题

导出数据,在matplotlib或seaborn中进行可视化

查看我们的API示例,了解一些常见的导出模式。您还可以在自定义图或展开的运行表上单击下载按钮,从浏览器下载CSV。
从脚本中获取随机的运行ID和运行名称
在调用wandb.init()之后,您可以从脚本中读取随机的运行ID或人类可读的运行名称,如下所示:
  • 唯一的运行ID(8个字符的哈希): wandb.run.id
  • 随机的运行名称(人类可读): wandb.run.name
如果您想为运行设置一些识别符,可以采用以下方法:
  • 运行ID:保留为生成的哈希。这在项目的各运行之间是独一无二的。
  • 运行名称:名称应该简短、易读,并且最好是唯一的,方便您分辨图表上不同线条之间的区别。
  • 运行注解:提供有关于您运行时做了什么的简单描述。可以使用wandb.init(notes="your notes here")进行设置。
  • 运行标签:在运行标签中对内容进行动态跟踪,并在UI中使用过滤器对表进行过滤,只保留您关心的运行。您可以从脚本中设置标签,然后在UI中,在运行表和运行页面的“概述”选项卡中对其进行编辑。

公共API的例子

查找运行路径

要使用公共API,通常需要使用运行路径“<entity>/<project>/<run_id”。在应用程序中,打开一个运行,然后单击概述选项卡,可以查看任何运行的运行路径。

从运行中读取指标

本示例输出用wandb.log({"accuracy": acc})保存的时间戳和精度,该运行保存到//。
1
import wandb
2
api = wandb.Api()
3
4
run = api.run("<entity>/<project>/<run_id>")
5
if run.state == "finished":
6
for i, row in run.history().iterrows():
7
print(row["_timestamp"], row["accuracy"])
Copied!

从运行中读取特定的指标

要从运行中提取特定的指标,请使用keys参数。使用run.history()时,默认样本数为500。不包含特定指标的已记录步骤将在输出数据帧中显示为NaNkeys参数将使api更频繁地对那些包含所列出指标键的步骤进行采样。
1
import wandb
2
api = wandb.Api()
3
4
run = api.run("<entity>/<project>/<run_id>")
5
if run.state == "finished":
6
for i, row in run.history(keys=["accuracy"]).iterrows():
7
print(row["_timestamp"], row["accuracy"])
Copied!

比较两次运行

这将输出run1和run2之间不同的配置参数。
1
import wandb
2
api = wandb.Api()
3
4
# replace with your <entity_name>/<project_name>/<run_id>
5
run1 = api.run("<entity>/<project>/<run_id>")
6
run2 = api.run("<entity>/<project>/<run_id>")
7
8
import pandas as pd
9
df = pd.DataFrame([run1.config, run2.config]).transpose()
10
11
df.columns = [run1.name, run2.name]
12
print(df[df[run1.name] != df[run2.name]])
Copied!
输出:
1
c_10_sgd_0.025_0.01_long_switch base_adam_4_conv_2fc
2
batch_size 32 16
3
n_conv_layers 5 4
4
optimizer rmsprop adam
Copied!

完成运行后,更新运行的指标

本示例将前一次运行的精度设置为0.9。它还将前一次运行的精度直方图修改为numpy_array的直方图。
1
import wandb
2
api = wandb.Api()
3
4
run = api.run("<entity>/<project>/<run_id>")
5
run.summary["accuracy"] = 0.9
6
run.summary["accuracy_histogram"] = wandb.Histogram(numpy_array)
7
run.summary.update()
Copied!

更新现有运行的配置

本示例将更新您的一个配置设置
1
import wandb
2
api = wandb.Api()
3
run = api.run("<entity>/<project>/<run_id>")
4
run.config["key"] = updated_value
5
run.update()
Copied!

从单次运行中导出指标到CSV文件

该脚本查找单次运行保存的所有指标,并将 其保存到CSV中。
1
import wandb
2
api = wandb.Api()
3
4
# run is specified by <entity>/<project>/<run id>
5
run = api.run("<entity>/<project>/<run_id>")
6
7
# save the metrics for the run to a csv file
8
metrics_dataframe = run.history()
9
metrics_dataframe.to_csv("metrics.csv")
Copied!

从大型单次运行中导出指标而无需采样

默认的历史方法将指标抽样到固定数量的样本中(默认是500,您也可以使用样本参数来更改此值)。如果要一次大批量导出所有数据,则可以使用run.scan_history()方法。该脚本将所有的损失指标加载到可变损失中,以实现更长的运行时间。
1
import wandb
2
api = wandb.Api()
3
4
run = api.run("<entity>/<project>/<run_id>")
5
history = run.scan_history()
6
losses = [row["Loss"] for row in history]
Copied!

将项目中所有运行的指标导出到CSV文件

该脚本查找一个项目,然后输出运行的CSV文件,其中包含了名称、配置和汇总统计信息。
1
import wandb
2
api = wandb.Api()
3
4
runs = api.runs("<entity>/<project>")
5
summary_list = []
6
config_list = []
7
name_list = []
8
for run in runs:
9
# run.summary are the output key/values like accuracy. We call ._json_dict to omit large files
10
summary_list.append(run.summary._json_dict)
11
12
# run.config is the input metrics. We remove special values that start with _.
13
config_list.append({k:v for k,v in run.config.items() if not k.startswith('_')})
14
15
# run.name is the name of the run.
16
name_list.append(run.name)
17
18
import pandas as pd
19
summary_df = pd.DataFrame.from_records(summary_list)
20
config_df = pd.DataFrame.from_records(config_list)
21
name_df = pd.DataFrame({'name': name_list})
22
all_df = pd.concat([name_df, config_df,summary_df], axis=1)
23
24
all_df.to_csv("project.csv")
Copied!

上传文件到已完成的运行

下面的代码片段将所选的文件上载到一个已完成的运行中。
1
import wandb
2
api = wandb.Api()
3
run = api.run("entity/project/run_id")
4
run.upload_file("file_name.extension")
Copied!

从运行中下载文件

此处在cifar项目中找到与运行ID uxte44z7关联的文件“model-best.h5”,并将其保存在本地
1
import wandb
2
api = wandb.Api()
3
run = api.run("<entity>/<project>/<run_id>")
4
run.file("model-best.h5").download()
Copied!

从运行中下载所有文件

此处查找与运行ID uxte44z7关联的所有文件,并将其保存在本地。(注意:您也可以通过从命令行运行wandb restore 来完成此操作。)
1
import wandb
2
api = wandb.Api()
3
run = api.run("<entity>/<project>/<run_id>")
4
for file in run.files():
5
file.download()
Copied!

下载最佳模型文件

1
import wandb
2
api = wandb.Api()
3
sweep = api.sweep("<entity>/<project>/<sweep_id>")
4
runs = sorted(sweep.runs, key=lambda run: run.summary.get("val_acc", 0), reverse=True)
5
val_acc = runs[0].summary.get("val_acc", 0)
6
print(f"Best run {runs[0].name} with {val_acc}% validation accuracy")
7
runs[0].file("model-best.h5").download(replace=True)
8
print("Best model saved to model-best.h5")
Copied!
从运行中删除所有具有给定扩展名的文件
1
import wandb
2
3
api = wandb.Api()
4
run = api.run("<entity>/<project>/<run_id>")
5
6
files = run.files()
7
for file in files:
8
if file.name.endswith(".png"):
9
file.delete()
Copied!

从特定扫描中获取运行

1
import wandb
2
api = wandb.Api()
3
sweep = api.sweep("<entity>/<project>/<sweep_id>")
4
print(sweep.runs)
Copied!

下载系统指标数据

这里为您提供了一个运行时包含所有系统指标的数据帧。
1
import wandb
2
api = wandb.Api()
3
run = api.run("<entity>/<project>/<run_id>")
4
system_metrics = run.history(stream = 'events')
Copied!

更新汇总指标

您可以通过字典来更新汇总指标。
1
summary.update({“key”: val})
Copied!

获取运行run的命令

每次运行都会在运行概述页面上捕获启动它的命令。要从API下拉此命令,您可以运行:
1
api = wandb.Api()
2
run = api.run("username/project/run_id")
3
meta = json.load(run.file("wandb-metadata.json").download())
4
program = ["python"] + [meta["program"]] + meta["args"]
Copied!

从历史中获取分页数据

如果在我们的后端获取指标很慢,或者API请求超时,您可以尝试降低scan_history中的页面大小,以使单个请求不会超时。默认的页面大小为1000,因此您可以尝试不同的大小,看看哪种效果最好:
1
api = wandb.Api()
2
run = api.run("username/project/run_id")
3
run.scan_history(keys=sorted(cols), page_size=100)
Copied!
最近更新 9mo ago