from datetime import datetime
import json
import time
from typing import Dict, List, Optional, Union
from anylearn.utils.api import (
    url_base,
    get_with_token,
    post_with_token,
    post_with_secret_key,
)
from anylearn.utils.errors import (
    AnyLearnException,
    AnyLearnMissingParamException,
)
from anylearn.utils.func import logs_beautify
from anylearn.interfaces.base import BaseObject
from anylearn.interfaces.resource import (
    AsyncResourceDownloader,
    File,
    Model,
    Resource,
    ResourceDownloader,
    ResourceState,
    SyncResourceDownloader,
)
from anylearn.applications.train_profile import TrainProfile
from anylearn.storage.db.models import SqlLocalTrainTask
[文档]class TrainTaskState:
    """
    训练任务状态标识:
    - 0(CREATED)表示已创建
    - 1(RUNNING)表示运行中
    - 2(SUCCESS)表示已完成
    - -1(DELETED)表示已删除
    - -2(FAIL)表示失败
    - -3(ABORT)表示中断
    """
    CREATED = 0
    RUNNING = 1
    SUCCESS = 2
    DELETED = -1
    FAIL = -2
    ABORT = -3 
[文档]class TrainTask(BaseObject):
    """
    AnyLearn训练任务类,以方法映射训练任务CRUD相关接口
    Attributes
    ----------
    id
        训练任务的唯一标识符,自动生成,由TRAI+uuid1生成的编码中后28个有效位(小写字母和数字)组成
    name
        训练任务的名称(长度1~50))
    description
        训练任务描述(长度最大255)
    state
        训练任务状态
    creator_id
        创建者ID
    project_id
        训练项目ID
    algorithm_id
        算法ID
    algorithm_git_ref
        算法Gitea代码仓库的版本号(可以是commit号、分支名、tag名),非必填。
    train_params
        训练任务参数列表
    train_params_prefix
        训练超参数键前标识
    train_params_delimeter
         训练超参数键值间的分隔符
    files
        训练相关的文件(默认'' 多个的话用逗号','隔开)
    results_id
        训练结果文件ID
    secret_key
        密钥
    create_time
        创建时间
    finish_time
        结束时间
    envs
        训练任务环境变量(默认'')
    hpo
        是否开启超参数自动调优
    hpo_search_space
        开启超参数自动调优时不能为空
    final_metric
        最终结果指标
    load_detail
        初始化时是否加载详情
    resource_request : :obj:`List[Dict[str, Dict[str, int]]]`, optional
        训练所需计算资源的请求。
        如未填,则使用Anylearn后端的:obj:`default`资源组中的默认资源套餐。
    entrypoint
        算法训练的启动命令,非标准算法必填
    output
        算法训练结果(模型)存储目录路径,非标准算法必填
    mirror_id
        训练使用的镜像ID,默认为空,即使用算法绑定的镜像ID
    """
    _fields = {
        # 资源创建/更新请求包体中必须包含且不能为空的字段
        'required': {
            'create': ['name', 'project_id', 'algorithm_id', 'train_params'],
            'update': [],
        },
        # 资源创建/更新请求包体中包含的所有字段
        'payload': {
            'create': ['name', 'description', 'project_id', 'algorithm_id',
                       'algorithm_git_ref', 'train_params', 'envs',
                       'train_params_prefix', 'train_params_delimeter',
                       'files', 'resource_request', 'num_nodes', 'nproc_per_node',
                       'entrypoint', 'output', 'mirror_id'],
            'update': [],
        },
    }
    """
    创建/更新对象时:
    - 必须包含且不能为空的字段 :obj:`_fields['required']`
    - 所有字段 :obj:`_fields['payload']`
    """
[文档]    def __init__(self,
                 id: Optional[str]=None,
                 name: Optional[str]=None,
                 description: Optional[str]=None,
                 state: Optional[int]=None,
                 creator_id: Optional[str]=None,
                 project_id: Optional[str]=None,
                 algorithm_id: Optional[str]=None,
                 algorithm_git_ref: Optional[str]=None,
                 train_params: Optional[str]=None,
                 train_params_prefix: str="--",
                 train_params_delimeter: str=" ",
                 files: Optional[list]=None,
                 results_id: Optional[str]=None,
                 secret_key: Optional[str]=None,
                 create_time: Optional[datetime]=None,
                 finish_time: Optional[datetime]=None,
                 envs: Optional[str]=None,
                 hpo=False,
                 hpo_search_space: Optional[str]=None,
                 final_metric: Optional[float]=None,
                 resource_request: Optional[List[Dict[str, Dict[str, int]]]]=None,
                 load_detail=False,
                 entrypoint: Optional[str]=None,
                 output: Optional[str]=None,
                 mirror_id: Optional[str]=None,
                 num_nodes=1,
                 nproc_per_node=1):
        """
        Parameters
        ----------
        id
            训练任务的唯一标识符,自动生成,由TRAI+uuid1生成的编码中后28个有效位(小写字母和数字)组成
        name
            训练任务的名称(长度1~50))
        description
            训练任务描述(长度最大255)
        state
            训练任务状态
        creator_id
            创建者ID
        project_id
            训练项目ID
        algorithm_id
            算法ID
        algorithm_git_ref
            算法Gitea代码仓库的版本号(可以是commit号、分支名、tag名),非必填。
        train_params
            训练任务参数列表
        train_params_prefix
            训练超参数键前标识
        train_params_delimeter
            训练超参数键值间的分隔符
        files
            训练相关的文件(默认'' 多个的话用逗号','隔开)
        results_id
            训练结果文件ID
        secret_key
            密钥
        create_time
            创建时间
        finish_time
            结束时间
        envs
            训练任务环境变量(默认'')
        hpo
            是否开启超参数自动调优
        hpo_search_space
            开启超参数自动调优时不能为空
        final_metric
            最终结果指标
        resource_request : :obj:`List[Dict[str, Dict[str, int]]]`, optional
            训练所需计算资源的请求。
            如未填,则使用Anylearn后端的:obj:`default`资源组中的默认资源套餐。
        load_detail
            初始化时是否加载详情
        num_nodes
            训练任务请求的节点数
        entrypoint
            算法训练的启动命令,非标准算法必填
        output
            算法训练结果(模型)存储目录路径,非标准算法必填
        mirror_id
            训练使用的镜像ID,默认为空,即使用算法绑定的镜像ID
        """
        self.name = name
        self.description = description
        self.state = state
        self.creator_id = creator_id
        self.project_id = project_id
        self.algorithm_id = algorithm_id
        self.algorithm_git_ref = algorithm_git_ref
        self.train_params = train_params
        self.train_params_prefix = train_params_prefix
        self.train_params_delimeter = train_params_delimeter
        self.files = files
        self.results_id = results_id
        self.secret_key = secret_key
        self.create_time = create_time
        self.finish_time = finish_time
        self.envs = envs
        self.hpo = hpo
        self.hpo_search_space = hpo_search_space
        self.final_metric = final_metric
        self.resource_request = resource_request
        self.num_nodes = num_nodes
        self.nproc_per_node = nproc_per_node
        self.entrypoint = entrypoint
        self.output = output
        self.mirror_id = mirror_id
        super().__init__(id=id, load_detail=load_detail) 
[文档]    def finished(self):
        """
        检查训练任务是否完成
        Returns
        -------
        bool
            True or False
        """
        return self.state in [
            TrainTaskState.SUCCESS,
            TrainTaskState.FAIL,
            TrainTaskState.ABORT,
            TrainTaskState.DELETED,
        ] 
[文档]    @classmethod
    def get_list(cls):
        """
        Listing is not supported for TrainTask
        """
        raise AnyLearnException("Listing is not supported for TrainTask") 
[文档]    def get_detail(self):
        """
        获取训练任务详细信息
        - 对象属性 :obj:`id` 应为非空
        Returns
        -------
        TrainTask
            训练任务对象。
        """
        self._check_fields(required=['id'])
        res = get_with_token(f"{url_base()}/train_task/query",
                             params={'id': self.id})
        if not res or not isinstance(res, list):
            raise AnyLearnException("请求未能得到有效响应")
        res = res[0]
        self.__init__(id=res['id'], name=res['name'],
                      description=res['description'], state=res['state'],
                      creator_id=res['creator_id'],
                      project_id=res['project_id'],
                      algorithm_id=res['algorithm_id'],
                      algorithm_git_ref=res['algorithm_git_ref'],
                      train_params=res['args'], files=res['files'],
                      results_id=res['results_id'],
                      secret_key=res['secret_key'],
                      create_time=res['create_time'],
                      finish_time=res['finish_time'],
                      entrypoint=res['entrypoint'],
                      output=res['output'],
                      mirror_id=res['mirror_id'],
                      envs=res['envs'],
                      resource_request=json.loads(res['resource_request'])) 
[文档]    def _create(self):
        data = self._payload_create()
        if data['files'] and isinstance(data['files'], list):
            data['files'] = ','.join(data['files'])
        if self.hpo:
            if not self.hpo_search_space:
                msg = f"{self.__class__.__name__}缺少必要字段:hpo_search_space" + \
                    
"——当开启超参数自动调优时(hpo==True),hpo_search_space为必填字段"
                raise AnyLearnMissingParamException(msg)
            data['hpo'] = 1
            data['hpo_search_space'] = json.dumps(self.hpo_search_space)
        if data['resource_request']:
            data['resources'] = json.dumps(data['resource_request'])
        del(data['resource_request'])
        res = post_with_token(self._url_create(), data=data)
        if not res or 'data' not in res:
            raise AnyLearnException("请求未能得到有效响应")
        self.id = res['data']
        return True 
[文档]    def _update(self):
        # No update for train task
        pass 
[文档]    def get_log(self, limit=100, direction="init", offset=0, offset_index=-1):
        """
        训练任务日志查询接口
        - 对象属性 :obj:`id` 应为非空
        :param limit: :obj:`int`
            日志条数上限(默认值100)。
        :param direction: :obj:`str`
            日志查询方向。
        :param offset: :obj:`int`
            日志查询偏移量。
        :param offset_index :obj:`int`
            日志查询偏移量索引,搭配偏移量使用作为分页基准。
        :return: 
            .. code-block:: json
                [
                  {
                    "offset": 164324567,
                    "offset_index": 1234,
                    "text": "Task TRAId123 started."
                  },
                  {
                    "offset": 164324590,
                    "offset_index": 1238,
                    "text": "Task TRAId123 finished."
                  }
                ]
        """
        self._check_fields(required=['id'])
        params = {
            'id': self.id,
            'limit': limit,
            'direction': direction,
            'index': offset,
            'offset_index': offset_index,
        }
        res = get_with_token(f"{url_base()}/train_task/log", params=params)
        if not res or type(res) != list:
            raise AnyLearnException("请求未能得到有效响应")
        return [r for r in res if r['text'].strip()] 
[文档]    def get_last_log(self, limit: int=100, debug: bool=False):
        """
        训练任务日志最近n行查询接口,返回日志文本列表。
        - 对象属性 :obj:`id` 应为非空
        :param limit :obj:`int`
            需要查询的行数(默认100)。
        :param debug :obj:`bool`
            是否显示更全面的debug信息(默认False)。
        :return: 
            .. code-block:: json
                [
                  "log text1",
                  "log text2"
                ]
        """
        logs = logs_beautify(logs=self.get_log(limit=limit), debug=debug)
        return list(reversed(logs)) 
[文档]    def get_full_log(self, debug: bool=False):
        """
        训练任务日志全量查询接口,返回日志文本列表。
        - 对象属性 :obj:`id` 应为非空
        :param debug :obj:`bool`
            是否显示更全面的debug信息(默认False)。
        :return: 
            .. code-block:: json
                [
                  "log text1",
                  "log text2"
                ]
        """
        logs = []
        offset = 0
        offset_index = -1
        while True:
            try:
                log_parts = self.get_log(
                    offset=offset,
                    offset_index=offset_index,
                    direction="back"
                )
                last = log_parts[-1]
                offset = last['offset']
                offset_index = last['offset_index']
                logs.extend(log_parts)
            except:
                break
        return logs_beautify(logs=logs, debug=debug) 
[文档]    def stream_log(self,
                   init_limit: int=100,
                   polling: int=2,
                   debug: bool=False):
        """
        实时训练任务日志流式生成接口,每次迭代返回日志文本的一行。
        - 对象属性 :obj:`id` 应为非空
        :param init_limit :obj:`bool`
            起始日志需要查询的行数(默认100)。
        :param polling :obj:`int`
            轮询间隔时间(单位:秒,默认2)。
        :param debug :obj:`bool`
            是否显示更全面的debug信息(默认False)。
        :return: :obj:`Iterator`
        """
        # Some initial logs (latest)
        logs = []
        while not logs:
            logs = self.get_log(limit=init_limit)
        for l in reversed(logs_beautify(logs, debug=debug)):
            yield l
        # Keep fetching new logs
        offset = logs[0]['offset']
        offset_index = logs[0]['offset_index']
        while not self.finished():
            try:
                self.get_detail()
                logs = self.get_log(
                    offset=offset,
                    offset_index=offset_index,
                    direction="back",
                )
                if not logs:
                    continue
                for l in logs_beautify(logs, debug=debug):
                    yield l
                offset = logs[-1]['offset']
                offset_index = logs[-1]['offset_index']
            except:
                time.sleep(polling) 
[文档]    def get_status(self):
        """
        训练任务状态查询接口
        - 对象属性 :obj:`id` 、 :obj:`secret_key` 应为非空
        :return: 
            .. code-block:: json
                {
                  "current_epoch": "2",
                  "current_train_loss": "2.169192314147949",
                  "current_train_step": "1288",
                  "ip": "10.244.2.124",
                  "process": "1.0",
                  "secret_key": "TKEY123",
                  "state": "success"
                }
        """
        self._check_fields(required=['id', 'secret_key'])
        params = {
            'id': self.id,
            'secret_key': self.secret_key,
        }
        res = get_with_token(f"{url_base()}/train_task/status", params=params)
        if not res or type(res) != dict:
            raise AnyLearnException("请求未能得到有效响应")
        return res 
[文档]    def download_results(self,
                         save_path: str,
                         async_download: bool=True,
                         downloader: Optional[ResourceDownloader]=None,
                         polling: Union[float, int]=5,
                        ):
        """
        下载训练任务结果
        Parameters
        ----------
        save_path : :obj:`str`
            文件保存路径。
        downloader : :obj:`ResourceDownloader`
            可以使用SDK中的AsyncResourceDownloader,也可以自定义实现ResourceDownloader。
        polling : :obj:`float, int`
            下载前要先压缩文件,轮询查看文件有没有压缩完的时间间隔,单位:秒。默认值5
        Returns
        -------
        str
            文件名。
        """
        self._check_fields(required=['results_id'])
        if self.state == TrainTaskState.FAIL:
            raise AnyLearnException("训练失败!")
        elif self.state == TrainTaskState.DELETED:
            raise AnyLearnException("训练已删除!")
        elif self.state == TrainTaskState.ABORT:
            raise AnyLearnException("训练中断!")
        elif self.state == TrainTaskState.CREATED:
            raise AnyLearnException("训练未开始!")
        elif self.state == TrainTaskState.RUNNING:
            raise AnyLearnException("正在训练中,请耐心等待...")
        if not downloader:
            if async_download:
                downloader = AsyncResourceDownloader()
            else:
                downloader = SyncResourceDownloader()
        return Resource.download_file(resource_id=self.results_id, # type: ignore
                                      save_path=save_path,
                                      downloader=downloader,
                                      polling=polling,
                                     ) 
[文档]    def report_final_metric(self, metric: float):
        """
        训练任务汇报最终结果指标
        - 对象属性 :obj:`id` 、 :obj:`secret_key` 应为非空
        :param metric: :obj:`float`
            最终结果指标。
        :return:
            .. code-block:: json
                {
                  "msg": "任务TRAId123结果指标保存成功"
                }
        """
        self._check_fields(required=['id', 'secret_key'])
        data = {
            'id': self.id,
            'metric': metric,
        }
        res = post_with_secret_key(f"{url_base()}/train_task/final_metric",
                                   data=data,
                                   secret_key=self.secret_key)
        if not res or type(res) != dict:
            raise AnyLearnException("请求未能得到有效响应")
        self.final_metric = metric
        return res 
[文档]    def get_final_metric(self):
        """
        获取训练任务最终结果指标
        - 对象属性 :obj:`id` 应为非空
        :return: 
            .. code-block:: json
                {
                  "final_metric": 662.8,
                  "id": "TRAI1d3",
                  "name": "test"
                }
        """
        self._check_fields(required=['id'])
        res = get_with_token(f"{url_base()}/train_task/final_metric",
                             params={'id': self.id})
        if not res or 'final_metric' not in res:
            raise AnyLearnException("请求未能得到有效响应")
        self.final_metric = res['final_metric']
        return res 
[文档]    def get_train_profile(self):
        """
        获取训练任务描述
        - 对象属性 :obj:`id` 应为非空
        Returns
        -------
        TrainProfile
            训练任务描述对象。
        """
        self._check_fields(required=['id'])
        return TrainProfile.get(train_task_id=self.id) 
    def get_results_file(self) -> File:
        self._check_fields(required=['results_id'])
        return File(id=self.results_id, load_detail=True)
[文档]    @classmethod
    def from_sql(cls, sql_local_train_task: SqlLocalTrainTask):
        """
        把本地保存的训练任务SqlLocalTrainTask转化为TrainTask
        Parameters
        ----------
        sql_local_train_task : :obj:`SqlLocalTrainTask`
            本地训练任务
        Returns
        -------
        TrainTask
        """
        return TrainTask(id=sql_local_train_task.id, state=int(sql_local_train_task.remote_state_sofar),
                         secret_key=sql_local_train_task.secret_key,
                         project_id=sql_local_train_task.project_id,
                         final_metric=sql_local_train_task.hpo_final_metric) # type: ignore 
[文档]    def _namespace(self):
        return "train_task"