import copy
import json
from pathlib import Path
from threading import Thread
import time
from typing import Dict, List, Optional, Union
from anylearn.utils.func import generate_primary_key
from anylearn.applications.train_profile import TrainProfile
from .algorithm_manager import sync_algorithm
from .utils import (
    _check_resource_input,
    _get_archive_checksum,
    _get_or_create_resource_archive,
    generate_random_name,
    get_mirror_by_name,
    make_name_by_path,
)
from ..interfaces import (
    Project,
    QuotaGroup,
    TrainTask,
)
from ..interfaces.resource import (
    Algorithm,
    Dataset,
    Model,
    Resource,
    ResourceState,
    ResourceUploader,
    SyncResourceUploader,
)
from ..storage.db import DB
from ..utils import logger
from ..utils.errors import (
    AnyLearnException,
    AnyLearnMissingParamException,
    AnylearnRequiredLocalCommitException,
)
def _get_or_create_dataset(id: Optional[str]=None,
                           dir_path: Optional[Union[str, Path]]=None,
                           archive_path: Optional[str]=None):
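    """
    Resolve a dataset from an existing remote ID, a local directory, or a
    local archive. Returns a tuple (dataset, archive_path_to_upload,
    checksum); the last two are None when no upload is needed.
    """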
    if not any([id, dir_path, archive_path]):
        return None, None, None
    try:
        dset = Dataset(id=id, load_detail=True)
        return dset, None, None
    except Exception:
        if not any([dir_path, archive_path]):
            raise AnyLearnMissingParamException((
                "ID provided does not exist and none of "
                "['dir_path', 'archive_path'] "
                "is specified."
            ))
        name = make_name_by_path(dir_path or archive_path)
        archive_path = _get_or_create_resource_archive(
            name=name,
            dir_path=dir_path,
            archive_path=archive_path
        )
        checksum = _get_archive_checksum(archive_path)
        local_id = DB().find_local_dataset_by_checksum(checksum=checksum)
        if local_id:
            try:
                return Dataset(id=local_id, load_detail=True), None, None
            except Exception:
                logger.warning(
                    f"Local dataset ({local_id}) "
                    "has been deleted remotely, "
                    "forcing re-registration of the dataset."
                )
                DB().delete_local_dataset(id=local_id)
        dset = Dataset(name=name, description="SDK_QUICKSTART",
                       public=False,
                       filename=f"{name}.zip",
                       is_zipfile=True)
        dset.save()
        return dset, archive_path, checksum
def _get_or_create_model(id: Optional[str]=None,
                         dir_path: Optional[Union[str,Path]]=None,
                         archive_path: Optional[str]=None,
                         algorithm: Optional[Algorithm]=None):
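    """
    Resolve a model from an existing remote ID, a local directory, or a
    local archive. Local models must be attached to an algorithm.
    Returns a tuple (model, archive_path_to_upload, checksum); the last
    two are None when no upload is needed.
    """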
    _check_resource_input(id, dir_path, archive_path)
    try:
        model = Model(id=id, load_detail=True)
        return model, None, None
    except Exception:
        if not any([dir_path, archive_path]):
            raise AnyLearnMissingParamException((
                "ID provided does not exist and none of "
                "['dir_path', 'archive_path'] "
                "is specified."
            ))
        if not algorithm or not algorithm.id:
            raise AnyLearnMissingParamException(
                "Parameter 'algorithm' should be specified "
                "when using local models."
            )
        name = make_name_by_path(dir_path or archive_path)
        archive_path = _get_or_create_resource_archive(
            name=name,
            dir_path=dir_path,
            archive_path=archive_path
        )
        checksum = _get_archive_checksum(archive_path)
        local_id = DB().find_local_model_by_checksum(checksum=checksum)
        if local_id:
            try:
                # Fetch the remote model and update its associated algorithm if needed
                model = Model(id=local_id, load_detail=True)
                model.algorithm_id = algorithm.id
                model.save()
                return model, None, None
            except Exception:
                logger.warning(
                    f"Local model ({local_id}) "
                    "has been deleted remotely, "
                    "forcing re-registration of the model."
                )
                DB().delete_local_model(id=local_id)
        model = Model(name=name, description="SDK_QUICKSTART",
                      public=False,
                      filename=f"{name}.zip",
                      is_zipfile=True,
                      algorithm_id=algorithm.id)
        model.save()
        return model, archive_path, checksum
def _upload_dataset(dataset: Dataset,
                    dataset_archive: str,
                    uploader: Optional[ResourceUploader]=None,
                    polling: Union[float, int]=5):
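    """
    Upload a dataset archive (in a worker thread), then poll the resource
    state every `polling` seconds until it is READY, raising on ERROR.
    """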
    if not uploader:
        uploader = SyncResourceUploader()
    t_dataset = Thread(target=Resource.upload_file,
                       kwargs={
                           'resource_id': dataset.id,
                           'file_path': dataset_archive,
                           'uploader': uploader,
                       })
    logger.info(f"Uploading dataset {dataset.name}...")
    t_dataset.start()
    t_dataset.join()
    finished = [ResourceState.ERROR, ResourceState.READY]
    while dataset.state not in finished:
        time.sleep(polling)
        dataset.get_detail()
    if dataset.state == ResourceState.ERROR:
        raise AnyLearnException("Error occurred when uploading dataset")
    logger.info("Successfully uploaded dataset")
def _upload_model(model: Model,
                  model_archive: str,
                  uploader: Optional[ResourceUploader]=None,
                  polling: Union[float, int]=5):
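    """
    Upload a model archive (in a worker thread), then poll the resource
    state every `polling` seconds until it is READY, raising on ERROR.
    """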
    if not uploader:
        uploader = SyncResourceUploader()
    t_model = Thread(target=Resource.upload_file,
                     kwargs={
                         'resource_id': model.id,
                         'file_path': model_archive,
                         'uploader': uploader,
                     })
    logger.info(f"Uploading model {model.name}...")
    t_model.start()
    t_model.join()
    finished = [ResourceState.ERROR, ResourceState.READY]
    while model.state not in finished:
        time.sleep(polling)
        model.get_detail()
    if model.state == ResourceState.ERROR:
        raise AnyLearnException("Error occurred when uploading model")
    logger.info("Successfully uploaded model")
def _get_or_create_default_project():
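    """Fetch the user's default project, creating it if it does not exist."""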
    try:
        return Project.get_my_default_project()
    except Exception:
        return Project.create_my_default_project()
def _get_or_create_project(project_id: Optional[str]=None,
                           project_name: Optional[str]=None):
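    """
    Fetch a project by ID, or create a new project named `project_name`
    (or a random PROJ_* name) when the ID is missing or unknown.
    """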
    try:
        return Project(id=project_id, load_detail=True)
    except Exception:
        name = project_name or f"PROJ_{generate_random_name()}"
        description = project_name or "SDK_QUICKSTART"
        project = Project(name=name, description=description)
        project.save()
        return project
def _format_resource_request(resource_request: List[Dict[str, Dict[str, int]]]):
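    """
    Normalize a resource request: quota groups referenced by name are
    resolved to their IDs, while 'default', 'besteffort' and QGRP* keys
    are kept as-is.
    """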
    if not resource_request:
        return None
    # Resource group by ID or name
    for i, req in enumerate(copy.deepcopy(resource_request)):
        for key, val in req.items():
            if key in ['default', 'besteffort'] or key.startswith('QGRP'):
                # Already ID
                continue
            try:
                qid = QuotaGroup(name=key, load_detail=True).id
            except AnyLearnException:
                raise AnyLearnException(f"Failed to find QuotaGroup: {key}")
            resource_request[i][qid] = val
            del resource_request[i][key]
    return resource_request
def _create_train_task(name: str,
                       algorithm: Algorithm,
                       project: Project,
                       hyperparams: dict,
                       hyperparams_prefix: str="--",
                       hyperparams_delimeter: str=" ",
                       mounts: Optional[Union[Dict[str, List[Resource]], Dict[str, Resource]]]=None,
                       algorithm_git_ref: Optional[str]=None,
                       resource_request: Optional[List[Dict[str, Dict[str, int]]]]=None,
                       description: Optional[str]=None,
                       entrypoint: Optional[str]=None,
                       output: Optional[str]=None,
                       algorithm_dir: Optional[Union[str, Path]]=None,
                       mirror_name: Optional[str]="QUICKSTART",
                       num_nodes: Optional[int]=1,
                       nproc_per_node: Optional[int]=1):
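    """
    Build and save a TrainTask: mounted resources are injected into the
    train params as quoted "$<resource_id>" placeholders, the output path
    (if given) is validated to be a relative path pointing at nothing or
    at an empty directory inside the algorithm directory, and the task is
    created remotely before its detail is reloaded.
    """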
    resource_ids = []
    train_params = dict(hyperparams)  # copy to avoid mutating the caller's dict
    for k, v in (mounts or {}).items():
        if isinstance(v, list):
            resource_ids.extend([v_item.id for v_item in v])
            train_params[k] = [f"\"${v_item.id}\"" for v_item in v]
        else:
            resource_ids.append(v.id)
            train_params[k] = f"\"${v.id}\""
    train_task = TrainTask(
        name=name,
        project_id=project.id,
        algorithm_id=algorithm.id,
        algorithm_git_ref=algorithm_git_ref,
        files=resource_ids,
        train_params=json.dumps(train_params),
        train_params_prefix=hyperparams_prefix,
        train_params_delimeter=hyperparams_delimeter,
        resource_request=resource_request,
        description=description,
        num_nodes=num_nodes,
        nproc_per_node=nproc_per_node,
        mirror_id=get_mirror_by_name(mirror_name).id,
    )
    if entrypoint is not None:
        train_task.entrypoint = entrypoint
    if output is not None:
        output_path = Path(output)
        output_path_ok = (
            not output_path.is_absolute()
            and '..' not in output_path.parts
        )
        if output_path_ok and algorithm_dir is not None:
            output_path_joined = Path(algorithm_dir) / output_path
            output_path_ok &= (
                not output_path_joined.exists()
                or (
                    not output_path_joined.is_symlink()
                    and output_path_joined.is_dir()
                    and len([*output_path_joined.iterdir()]) == 0
                )
            )
        if not output_path_ok:
            raise AnyLearnException(
                'Invalid output path: a relative path without ".." is '
                'required, and it must point at nothing or at an empty '
                'directory (symlinks are not allowed). Got '
                f'"{output}".'
            )
        train_task.output = output
    train_task.save()
    train_task.get_detail()
    return train_task
def quick_train(algorithm_id: Optional[str]=None,
                algorithm_name: Optional[str]=None,
                algorithm_dir: Optional[Union[str, Path]]=None,
                algorithm_force_update: bool=False,
                algorithm_git_ref: Optional[str]=None,
                dataset_hyperparam_name: str="dataset",
                dataset_id: Optional[Union[List[str], str]]=None,
                dataset_dir: Optional[Union[str, Path]]=None,
                dataset_archive: Optional[str]=None,
                model_hyperparam_name: str="model",
                model_id: Optional[Union[List[str], str]]=None,
                pretrain_hyperparam_name: str="pretrain",
                pretrain_task_id: Optional[Union[List[str], str]]=None,
                project_id: Optional[str]=None,
                project_name: Optional[str]=None,
                entrypoint: Optional[str]=None,
                output: Optional[str]=None,
                mirror_name: Optional[str]="QUICKSTART",
                resource_uploader: Optional[ResourceUploader]=None,
                resource_polling: Union[float, int]=5,
                hyperparams: dict={},
                hyperparams_prefix: str="--",
                hyperparams_delimeter: str=" ",
                resource_request: Optional[List[Dict[str, Dict[str, int]]]]=None,
                quota_group_name: Optional[str]=None,
                quota_group_request: Optional[Dict[str, int]]=None,
                task_description: Optional[str]=None,
                num_nodes: Optional[int]=1,
                nproc_per_node: Optional[int]=1):
    """
    Quick-train interface for local algorithms.
    By providing only local resources and training-related information,
    a training run of a custom algorithm/dataset can be launched on the
    Anylearn backend engine:

    - algorithm path (directory or archive)
    - dataset path (directory or archive)
    - training entrypoint command
    - training output path
    - training hyperparameters

    This interface wraps the full workflow of starting a training run on
    Anylearn from scratch:

    - algorithm registration and upload
    - dataset registration and upload
    - training project creation
    - training task creation

    When a local resource is registered and uploaded to Anylearn for the
    first time, its checksum is recorded locally. On later calls to the
    quick-train or quick-evaluation interfaces with the same resource,
    registration and upload are skipped and the remote resource is
    reused automatically.

    If needed, the IDs of algorithms or datasets already registered on
    the Anylearn remote can be passed to this interface to skip resource
    creation.

    Parameters
    ----------
    algorithm_id : :obj:`str`, optional
        ID of an algorithm already registered on the Anylearn remote.
    algorithm_name : :obj:`str`, optional
        Name of the algorithm.
        Note: a user's custom algorithms must have unique names.
        If the name already exists, the existing algorithm of the same
        name is reused: its files are overwritten and its version is
        bumped, while previous versions remain traceable.
    algorithm_dir : :obj:`str`, optional
        Path to the local algorithm directory.
    algorithm_git_ref : :obj:`str`, optional
        Revision of the algorithm's Gitea repository
        (a commit SHA, a branch name or a tag name).
        When using a local algorithm, defaults to the current local
        branch name if not provided.
    algorithm_force_update : :obj:`bool`, optional
        Whether to force-update the algorithm during synchronization.
        If True, Anylearn auto-commits uncommitted local code changes.
        Defaults to False.
    dataset_hyperparam_name : :obj:`str`, optional
        Name of the command-line argument used to pass the dataset path
        to the algorithm when training starts.
        Provide the long option name, e.g. :obj:`--data`, with the
        :obj:`--` part omitted.
        The dataset path itself is managed by the Anylearn backend engine.
        Defaults to :obj:`dataset`.
    dataset_id : :obj:`str`, optional
        ID of a dataset already registered on the Anylearn remote.
    dataset_dir : :obj:`str`, optional
        Path to the local dataset directory.
    dataset_archive : :obj:`str`, optional
        Path to the local dataset archive.
    model_hyperparam_name : :obj:`str`, optional
        Name of the command-line argument used to pass the model path
        to the algorithm when training starts.
        Provide the long option name, e.g. :obj:`--model`, with the
        :obj:`--` part omitted.
        The model path itself is managed by the Anylearn backend engine.
        Defaults to :obj:`model`.
    model_id : :obj:`str`, optional
        ID of a model already registered/saved on the Anylearn remote.
    pretrain_hyperparam_name : :obj:`str`, optional
        Name of the command-line argument used to pass the path of a
        previous training's results (abstracted as "pretrain") to the
        algorithm when training starts.
        Provide the long option name, e.g. :obj:`--pretrain`, with the
        :obj:`--` part omitted.
        The pretrain result path is managed by the Anylearn backend engine.
        Defaults to :obj:`pretrain`.
    pretrain_task_id : :obj:`List[str]` | :obj:`str`, optional
        ID(s) of training runs previously executed on Anylearn,
        usually 32-character strings prefixed with TRAI.
        Anylearn extracts the results of these runs and mounts them
        into the new training run.
    project_id : :obj:`str`, optional
        ID of a training project already created on the Anylearn remote.
    entrypoint : :obj:`str`, optional
        Entrypoint command that launches the training.
    output : :obj:`str`, optional
        Path of the model output by the training,
        relative to the algorithm directory.
    resource_uploader : :obj:`ResourceUploader`, optional
        Resource upload implementation.
        Defaults to the built-in synchronous uploader
        :obj:`SyncResourceUploader`.
    resource_polling : :obj:`float|int`, optional
        Interval (in seconds) at which the resource state is polled
        during upload.
        Defaults to 5 seconds.
    hyperparams : :obj:`dict`, optional
        Dictionary of training hyperparameters.
        Hyperparameters are passed to the algorithm as arguments of the
        training command.
        Keys should be long option names, e.g. :obj:`--param`, with the
        :obj:`--` part omitted.
        For flag-style parameters, set the value to an empty string:
        :obj:`{'my-flag': ''}` is passed to the training command as
        :obj:`--my-flag`.
        Defaults to an empty dict.
    hyperparams_prefix : :obj:`str`, optional
        Prefix placed before each hyperparameter key, supporting
        hydra-style command-line formats such as :obj:`+key1`,
        :obj:`++key2` or a bare :obj:`key3`.
        Defaults to :obj:`--`.
    hyperparams_delimeter : :obj:`str`, optional
        Delimiter between hyperparameter keys and values.
        Defaults to a space :obj:` `.
    resource_request : :obj:`List[Dict[str, Dict[str, int]]]`, optional
        Computation resources requested for the training.
        If omitted, the default resource plan of the Anylearn backend's
        :obj:`default` quota group is used.
        Deprecated since 0.13.1 and scheduled for removal in 0.14.0;
        use :obj:`quota_group_name` and :obj:`quota_group_request`
        instead.

        .. deprecated:: 0.13.1

            Use :obj:`quota_group_name` and :obj:`quota_group_request`
            instead. Will be removed in 0.14.0.
    quota_group_name : :obj:`str`, optional
        Name or ID of the quota group providing the training resources.
    quota_group_request : :obj:`dict`, optional
        Amount of resources requested from the quota group.
        If either :obj:`quota_group_name` or :obj:`quota_group_request`
        is missing, the default resource plan of the Anylearn backend's
        :obj:`default` quota group is used.
    task_description : :obj:`str`, optional
        Detailed description of the training task.
        If non-empty and :obj:`algorithm_force_update` is :obj:`True`,
        Anylearn uses this value as the commit message when
        auto-committing local algorithm changes.
    num_nodes : :obj:`int`, optional
        Number of nodes required for distributed training.
    nproc_per_node : :obj:`int`, optional
        Number of processes to run on each node in distributed training.
    Returns
    -------
    TrainTask
        The created training task object.
    Algorithm
        The algorithm object created or fetched during quick-train.
    Dataset
        The dataset object(s) created or fetched during quick-train.
    Project
        The created training project object.
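
    Examples
    --------
    A minimal usage sketch; the paths, argument names and quota request
    below are illustrative placeholders rather than values shipped with
    the SDK::

        train_task, algo, datasets, project = quick_train(
            algorithm_dir="./my_algo",        # local algorithm directory
            entrypoint="python train.py",     # training launch command
            output="output",                  # relative, empty output dir
            dataset_dir="./my_dataset",       # local dataset directory
            dataset_hyperparam_name="data",   # dataset mounted as --data <path>
            hyperparams={"epochs": 10},       # passed as --epochs 10
            quota_group_name="default",
            quota_group_request={"CPU": 1},
        )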
    """
    # Resource request
    if quota_group_name and quota_group_request:
        resource_request = [{quota_group_name: quota_group_request}]
    resource_request = _format_resource_request(resource_request)
    # Remote model
    if model_id is None:
        model_id = []
    elif not isinstance(model_id, list):
        model_id = [model_id]
    models = [
        Model(id=_id, load_detail=True)
        for _id in model_id
    ]
    # Remote dataset
    if dataset_id is None:
        dataset_id = []
    elif not isinstance(dataset_id, list):
        dataset_id = [dataset_id]
    datasets = [
        Dataset(id=_id, load_detail=True)
        for _id in dataset_id
    ]
    # Remote pretrain task results
    if pretrain_task_id is None:
        pretrain_task_id = []
    elif not isinstance(pretrain_task_id, list):
        pretrain_task_id = [pretrain_task_id]
    pretrain_tasks = [
        TrainTask(id=_id, load_detail=True)
        for _id in pretrain_task_id
    ]
    # Algorithm
    try:
        algo, current_sha = sync_algorithm(
            id=algorithm_id,
            name=algorithm_name,
            dir_path=algorithm_dir,
            mirror_name=mirror_name,
            uploader=resource_uploader,
            polling=resource_polling,
            force=algorithm_force_update,
            commit_msg=task_description,
        )
    except AnylearnRequiredLocalCommitException:
        # Notify possible usage of algorithm_force_update=True
        raise AnylearnRequiredLocalCommitException(
            "Local algorithm code has uncommitted changes. "
            "Please commit your changes or "
            "specify `algorithm_force_update=True` "
            "to let Anylearn make an auto-commit."
        )
    # Dataset
    dset, dataset_archive, dataset_checksum = _get_or_create_dataset(
        dir_path=dataset_dir,
        archive_path=dataset_archive
    )
    if dataset_archive:
        # Local dataset registration
        _upload_dataset(dataset=dset,
                        dataset_archive=dataset_archive,
                        uploader=resource_uploader,
                        polling=resource_polling)
        DB().create_local_dataset(id=dset.id, checksum=dataset_checksum)
    if not datasets and dset:
        datasets.append(dset)
    mounts = {}
    if datasets and dataset_hyperparam_name:
        mounts[dataset_hyperparam_name] = datasets
    if models and model_hyperparam_name:
        mounts[model_hyperparam_name] = models
    if pretrain_tasks and pretrain_hyperparam_name:
        mounts[pretrain_hyperparam_name] = [t.get_results_file() for t in pretrain_tasks]
    # Project
    if project_id or project_name:
        project = _get_or_create_project(project_id=project_id,
                                         project_name=project_name)
    else:
        try:
            project = _get_or_create_default_project()
        except Exception:
            # Backward compatibility when default projects not supported
            project = _get_or_create_project()
    # Train task
    train_task_name = generate_random_name()
    train_task = _create_train_task(
        name=train_task_name,
        algorithm=algo,
        algorithm_git_ref=algorithm_git_ref or current_sha,
        project=project,
        hyperparams=hyperparams,
        hyperparams_prefix=hyperparams_prefix,
        hyperparams_delimeter=hyperparams_delimeter,
        mounts=mounts,
        resource_request=resource_request,
        description=task_description,
        num_nodes=num_nodes,
        nproc_per_node=nproc_per_node,
        entrypoint=entrypoint,
        output=output,
        algorithm_dir=algorithm_dir,
        mirror_name=mirror_name,
    )
    DB().create_or_update_train_task(train_task=train_task)
    train_profile = TrainProfile(id=generate_primary_key("DESC"),
                                 train_task_id=train_task.id,
                                 entrypoint=train_task.entrypoint,
                                 algorithm_id=algo.id,
                                 dataset_id=','.join([dset.id for dset in datasets]) if datasets else None,
                                 train_params=train_task.train_params,
                                 algorithm_dir=str(algorithm_dir) if algorithm_dir else None,
                                 dataset_dir=str(dataset_dir) if dataset_dir else None,
                                 dataset_archive=dataset_archive,)
    train_profile.create_in_db()
    return train_task, algo, datasets, project
def resume_unfinished_local_train_tasks():
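    """
    Reload every locally recorded unfinished train task from the remote
    and write the refreshed state back to the local DB.
    """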
    db = DB()
    local_list = db.get_unfinished_train_tasks()
    task_list = [TrainTask(id=local_train_task.id,
                           secret_key=local_train_task.secret_key,
                           project_id=local_train_task.project_id,
                           state=local_train_task.remote_state_sofar,
                           load_detail=True)
                 for local_train_task in local_list]
    for train_task in task_list:
        db.update_train_task(train_task)
    return task_list