.. _quickstart_example_training: 运行示例训练 ============ 本示例将一个hello world级的本地CNN算法 在本地的Fashion MNIST数据集上进行训练。 通过使用Anylearn SDK, 该训练将运行在Anylearn后端引擎上, 用户无需配置和管理模型训练的执行环境。 .. Note:: 用户需知悉想要连接的Anylearn后端引擎的环境, 并拥有该环境下的Anylearn账号, 如未注册,请先移步相应的Anylearn前端创建账号。 本示例的文件名为: ``project_quick_training_local_resources.ipynb`` , 下载方式详见 :ref:`quickstart_example_downloading` 。 示例主要包括3部分: #. `初始化SDK`_ #. `创建快速训练任务`_ #. `追踪训练进度`_ 初始化SDK --------- .. code-block:: python from anylearn.config import init_sdk init_sdk('http://anylearn.backend', 'username', 'password') .. Note:: 上述代码中的Anylearn远端地址( ``http://anylearn.backend`` )、用户名( ``username``)和密码( ``password`` )均非真实值, 请按需重新填入相应的值。 用户需调用 ``init_sdk`` 接口来连接Anylearn后端引擎并初始化SDK, 入参分别为:Anylearn引擎远端地址、用户名、密码。 获取资源组 ---------- .. Note:: 版本 ``0.11.0`` 新特性:用户配给计算资源组 .. code-block:: python from anylearn.interfaces import QuotaGroup groups = QuotaGroup.get_list() groups 当用户不确定当前有哪些资源组可用时, 可以通过调用 ``QuotaGroup.get_list()`` 接口来确认。 创建快速训练任务 ---------------- .. Note:: 版本 ``0.11.0`` 新特性:创建训练时,需要额外增加计算资源请求 .. code-block:: python from anylearn.applications.quickstart import quick_train train_task, algo, dset, project = quick_train( algorithm_dir="./cnn", dataset_dir="./fashion_mnist", entrypoint="python fashion_mnist.py", output="model", dataset_hyperparam_name="data-path", hyperparams={'batch-size': 256, 'epochs': 12}, resource_request=[{ 'QGRP03fe160211ecb6119ef94103bf12': { 'A-100-unique': 1, 'CPU': 4, 'Memory': 4, } }], ) 调用 ``quick_train`` 接口启动训练。 这里传入的参数包括: 本地算法路径、 本地数据集路径、 训练入口命令、 训练输出路径、 数据集文件传入算法时的参数名、 训练超参数、 资源请求。 .. Note:: 如需要标识类参数(flag),可将参数的值设为空字符串,如 ``{'param': '123', 'my-flag': ''}`` ,等价于 ``--param 123 --my-flag`` 传入训练命令。 打印返回的 ``TrainTask`` 对象可查询创建的训练任务的元信息,如: .. code-block:: python TrainTask(name='e2bfz4i1', description='', state=0, visibility=1, creator_id='USERfb6c6d2111eaadda13fd17feeac7', owner=['USERfb6c6d2111eaadda13fd17feeac7'], project_id='PROJbe16b2f511eb8f3a022cb4375d6a', algorithm_id='ALGOb34ab2f511eb8f3a022cb4375d6a', train_params='{"data-path": "$DSET0c2ab2f511eb8f3a022cb4375d6a", "batch-size": 256, "epochs": 12}', files='DSET0c2ab2f511eb8f3a022cb4375d6a', results_id='FILEbe78b2f511eb8f3a022cb4375d6a', secret_key='TKEY899cb2f511eb8f3a022cb4375d6a', create_time='2021-05-12 15:44:22', finish_time='', envs='', hpo=False, hpo_search_space=None, final_metric=None, id='TRAI62dab2f511eb8f3a022cb4375d6a') 追踪训练进度 ------------ .. code-block:: python import time status = train_task.get_status() while 'state' not in status: print("Waiting...") time.sleep(10) status = train_task.get_status() while status['state'] not in ["success", "fail"]: if 'process' in status: print(f"Progress: {int(100 * float(status['process']))}%") else: print(status['state']) time.sleep(30) status = train_task.get_status() print(status) status['state'] 上一小节返回的 ``TrainTask`` 对象中的 ``get_status`` 方法可供用户获取训练的当前进度和状态。 用户也可以通过 ``TrainTask`` 对象的 ``get_detail`` 方法与后端引擎同步训练的完整元信息, 以获取当前的训练状态码( ``TrainTask`` 对象的 ``state`` 属性)。 这段代码粗略地展示了如何通过轮询的方式持续跟踪训练的进度和状态。