anylearn.interfaces.resource.algorithm 源代码

from __future__ import annotations
from datetime import datetime
import json
from typing import Optional, Union
import re

from anylearn.utils.api import url_base, get_with_token, post_with_token
from anylearn.utils.errors import AnyLearnException
from anylearn.interfaces.resource.resource import Resource, ResourceState

[文档]class Algorithm(Resource): """ AnyLearn算法类,以方法映射算法CRUD相关接口 Attributes ---------- id 算法的唯一标识符,自动生成,由ALGO+uuid1生成的编码中后28个有效位(小写字母和数字)组成)(自动生成) name 算法的名称 description 算法的描述 state 算法状态 public 算法是否公开(默认为False) upload_time 算法上传时间 filename 下一步中会被分片上传的文件的完整文件名(包括扩展名)(非空 长度1~128) is_zipfile 是否为zip文件 file_path 算法文件路径 size 算法文件大小 creator_id 算法的创建者 node_id 算法节点ID tags 算法的标签 mirror_id 算法使用的基础镜像的id train_params 算法的训练参数,包括数据集参数 follows_anylearn_norm 算法是否符合Anylearn的算法规范(默认为True) git_address 算法的Anylearn Gitea远端代码仓库地址 git_migrated 算法是否由传统模式自动迁移至Anylearn Gitea load_detail 初始化时是否加载详情 """ """具体资源信息配置""" _fields = { # 资源创建/更新请求包体中必须包含且不能为空的字段 'required': { 'create': ['name', 'filename', 'mirror_id'], 'update': ['id', 'name'], }, # 资源创建/更新请求包体中包含的所有字段 'payload': { 'create': ['name', 'description', 'public', 'filename', 'tags', 'mirror_id', 'train_params', 'follows_anylearn_norm'], 'update': ['id', 'name', 'description', 'public', 'tags', 'mirror_id', 'train_params', 'follows_anylearn_norm'], }, } """ 创建/更新对象时: - 必须包含且不能为空的字段 :obj:`_fields['required']` - 所有字段 :obj:`_fields['payload']` """ __train_params = None required_train_params = [] default_train_params = {}
[文档] def __init__(self, id: Optional[str]=None, name: Optional[str]=None, description: Optional[str]=None, state: Optional[int]=None, public=False, upload_time: Optional[Union[datetime, str]]=None, filename: Optional[str]=None, is_zipfile: Optional[int]=None, file_path: Optional[str]=None, size: Optional[str]=None, creator_id: Optional[str]=None, node_id: Optional[str]=None, tags: Optional[str]=None, mirror_id: Optional[str]=None, train_params: Optional[str]=None, follows_anylearn_norm=True, git_address: Optional[str]=None, git_migrated: Optional[bool]=False, load_detail=False): """ Parameters ---------- id 算法的唯一标识符,自动生成,由ALGO+uuid1生成的编码中后28个有效位(小写字母和数字)组成)(自动生成) name 算法的名称 description 算法的描述 state 算法状态 public 算法是否公开(默认为False) upload_time 算法上传时间 filename 下一步中会被分片上传的文件的完整文件名(包括扩展名)(非空 长度1~128) is_zipfile 是否为zip文件 file_path 算法文件路径 size 算法文件大小 creator_id 算法的创建者 node_id 算法节点ID tags 算法的标签 mirror_id 算法使用的基础镜像的id train_params 算法的训练参数,包括数据集参数 follows_anylearn_norm 算法是否符合Anylearn的算法规范(默认为True) git_address 算法的Anylearn Gitea远端代码仓库地址 git_migrated 算法是否由传统模式自动迁移至Anylearn Gitea load_detail 初始化时是否加载详情 """ self.tags = tags self.mirror_id = mirror_id self.train_params = train_params self.follows_anylearn_norm = follows_anylearn_norm self.git_address = git_address self.git_migrated = git_migrated super().__init__(id=id, name=name, description=description, state=state, public=public, upload_time=upload_time, filename=filename, is_zipfile=is_zipfile, file_path=file_path, size=size, creator_id=creator_id, node_id=node_id, load_detail=load_detail)
def __eq__(self, other: Algorithm) -> bool: if not isinstance(other, Algorithm): return NotImplemented return all([ self.id == other.id, self.name == other.name, self.description == other.description, self.public == other.public, # self.state == other.state, # self.upload_time == other.upload_time, # self.filename == other.filename, # self.file_path == other.file_path, # self.is_zipfile == other.is_zipfile, # self.size == other.size, # self.creator_id == other.creator_id, # self.node_id == other.node_id, # self.tags == other.tags, # self.mirror_id == other.mirror_id, # self.train_params == other.train_params, self.follows_anylearn_norm == other.follows_anylearn_norm, self.git_address == other.git_address, ])
[文档] @classmethod def get_list(cls) -> list: """ 获取算法列表 Returns ------- List [Algorithm] 算法对象的集合。 """ res = get_with_token(f"{url_base()}/algorithm/list") if res is None or not isinstance(res, list): raise AnyLearnException("请求未能得到有效响应") return [ Algorithm(id=a['id'], name=a['name'], description=a['description'], state=a['state'], public=a['public'], upload_time=a['upload_time'], tags=a['tags'], follows_anylearn_norm=a['follows_anylearn_norm'], git_address=a['git_address'], git_migrated=a['git_migrated']) for a in res ]
[文档] def get_detail(self): """ 获取算法详细信息 - 对象属性 :obj:`id` 应为非空 Returns ------- Algorithm 算法对象。 """ self._check_fields(required=['id']) res = get_with_token(f"{url_base()}/algorithm/query", params={'id': self.id}) if not res or not isinstance(res, list): raise AnyLearnException("请求未能得到有效响应") res = res[0] self.__init__(id=res['id'], name=res['name'], description=res['description'], state=res['state'], public=res['public'], upload_time=res['upload_time'], filename=res['filename'], is_zipfile=True if res['is_zipfile'] == 1 else False, file_path=res['file'], size=res['size'], creator_id=res['creator_id'], node_id=res['node_id'], mirror_id=res['mirror_id'], tags=res['tags'], train_params=res['train_params'], follows_anylearn_norm=res['follows_anylearn_norm'], git_address=res['git_address'], git_migrated=res['git_migrated'])
[文档] def _namespace(self): return "algorithm"
# def _create(self): # if not self.follows_anylearn_norm: # self._check_fields(required=['entrypoint_training', # 'output_training']) # return super()._create() def _payload_create(self): payload = super()._payload_create() # Algorithm name verification self.__validate_algo_name(payload) payload['follows_anylearn_norm'] = int(payload['follows_anylearn_norm']) return self.__payload_algo_params_to_str(payload) def _payload_update(self): payload = super()._payload_update() # Algorithm name verification self.__validate_algo_name(payload) payload['follows_anylearn_norm'] = int(payload['follows_anylearn_norm']) return self.__payload_algo_params_to_str(payload) def __validate_algo_name(self, payload): name_pattern = re.compile(r'[a-zA-Z0-9_.-]+$') if not name_pattern.match(payload['name']): raise AnyLearnException( "Algorithm name should contain only alphanumeric, dash ('-'), " "underscore ('_') and dot ('.') characters" ) def __payload_algo_params_to_str(self, payload): if 'train_params' in payload: payload['train_params'] = json.dumps(payload['train_params']) if 'owner' in payload and isinstance(payload['owner'], list): payload['owner'] = ",".join(payload['owner']) return payload @property def train_params(self): """ 获取训练参数 """ return self.__train_params @train_params.setter def train_params(self, train_params): """ 设置训练参数 """ if not train_params: return params = json.loads(train_params) if params == None: return self.__train_params = params (self.required_train_params, self.default_train_params) = self.__parse_params(params) def __parse_params(self, params): required_params = [p for p in params if 'default' not in p] default_params = {p['name']: p['default'] for p in params if 'default' in p} return required_params, default_params
[文档] @classmethod def get_user_custom_algorithm_by_name(cls, name: str): """ 根据算法名称获取当前用户的自定义算法 Parameters ---------- name : :obj:`str` 算法名称。 Returns ------- Algorithm 算法对象。 """ res = get_with_token(f"{url_base()}/algorithm/custom", params={'name': name}) if not res or not isinstance(res, dict): raise AnyLearnException("请求未能得到有效响应") return Algorithm( id=res['id'], name=res['name'], description=res['description'], state=res['state'], public=res['public'], upload_time=res['upload_time'], filename=res['filename'], is_zipfile=True if res['is_zipfile'] == 1 else False, file_path=res['file'], size=res['size'], creator_id=res['creator_id'], node_id=res['node_id'], tags=res['tags'], mirror_id=res['mirror_id'], train_params=res['train_params'], follows_anylearn_norm=res['follows_anylearn_norm'], git_address=res['git_address'], git_migrated=res['git_migrated'], )
[文档] def sync_finish(self, checkout_sha:str=None): """ 算法代码仓库同步完成接口 - 对象属性 :obj:`id` 应为非空 Returns ------- None """ self._check_fields(required=['id']) post_with_token(f"{url_base()}/resource/sync_finish", data={ 'algorithm_id': self.id, 'checkout_sha': checkout_sha, }) self.get_detail() if self.state != ResourceState.READY: raise AnyLearnException("Algorithm is not ready after sync finish")