运行示例训练¶
本示例将一个hello world级的本地CNN算法 在本地的Fashion MNIST数据集上进行训练。 通过使用Anylearn SDK, 该训练将运行在Anylearn后端引擎上, 用户无需配置和管理模型训练的执行环境。
注解
用户需知悉想要连接的Anylearn后端引擎的环境, 并拥有该环境下的Anylearn账号, 如未注册,请先移步相应的Anylearn前端创建账号。
本示例的文件名为: project_quick_training_local_resources.ipynb
,
下载方式详见 下载快速上手示例 。
示例主要包括3部分:
初始化SDK¶
from anylearn.config import init_sdk
init_sdk('http://anylearn.backend', 'username', 'password')
注解
上述代码中的Anylearn远端地址( http://anylearn.backend
)、用户名( username
)和密码( password
)均非真实值,
请按需重新填入相应的值。
用户需调用 init_sdk
接口来连接Anylearn后端引擎并初始化SDK,
入参分别为:Anylearn引擎远端地址、用户名、密码。
获取资源组¶
注解
版本 0.11.0
新特性:用户配给计算资源组
from anylearn.interfaces import QuotaGroup
groups = QuotaGroup.get_list()
groups
当用户不确定当前有哪些资源组可用时,
可以通过调用 QuotaGroup.get_list()
接口来确认。
创建快速训练任务¶
注解
版本 0.11.0
新特性:创建训练时,需要额外增加计算资源请求
from anylearn.applications.quickstart import quick_train
train_task, algo, dset, project = quick_train(
algorithm_dir="./cnn",
dataset_dir="./fashion_mnist",
entrypoint="python fashion_mnist.py",
output="model",
dataset_hyperparam_name="data-path",
hyperparams={'batch-size': 256, 'epochs': 12},
resource_request=[{
'QGRP03fe160211ecb6119ef94103bf12': {
'A-100-unique': 1,
'CPU': 4,
'Memory': 4,
}
}],
)
调用 quick_train
接口启动训练。
这里传入的参数包括:
本地算法路径、
本地数据集路径、
训练入口命令、
训练输出路径、
数据集文件传入算法时的参数名、
训练超参数、
资源请求。
注解
如需要标识类参数(flag),可将参数的值设为空字符串,如 {'param': '123', 'my-flag': ''}
,等价于 --param 123 --my-flag
传入训练命令。
打印返回的 TrainTask
对象可查询创建的训练任务的元信息,如:
TrainTask(name='e2bfz4i1', description='', state=0, visibility=1, creator_id='USERfb6c6d2111eaadda13fd17feeac7', owner=['USERfb6c6d2111eaadda13fd17feeac7'], project_id='PROJbe16b2f511eb8f3a022cb4375d6a', algorithm_id='ALGOb34ab2f511eb8f3a022cb4375d6a', train_params='{"data-path": "$DSET0c2ab2f511eb8f3a022cb4375d6a", "batch-size": 256, "epochs": 12}', files='DSET0c2ab2f511eb8f3a022cb4375d6a', results_id='FILEbe78b2f511eb8f3a022cb4375d6a', secret_key='TKEY899cb2f511eb8f3a022cb4375d6a', create_time='2021-05-12 15:44:22', finish_time='', envs='', hpo=False, hpo_search_space=None, final_metric=None, id='TRAI62dab2f511eb8f3a022cb4375d6a')
追踪训练进度¶
import time
status = train_task.get_status()
while 'state' not in status:
print("Waiting...")
time.sleep(10)
status = train_task.get_status()
while status['state'] not in ["success", "fail"]:
if 'process' in status:
print(f"Progress: {int(100 * float(status['process']))}%")
else:
print(status['state'])
time.sleep(30)
status = train_task.get_status()
print(status)
status['state']
上一小节返回的 TrainTask
对象中的 get_status
方法可供用户获取训练的当前进度和状态。
用户也可以通过 TrainTask
对象的 get_detail
方法与后端引擎同步训练的完整元信息,
以获取当前的训练状态码( TrainTask
对象的 state
属性)。
这段代码粗略地展示了如何通过轮询的方式持续跟踪训练的进度和状态。