Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| from typing import Callable, List, Union, Dict | |
| # Set placeholder credentials before importing kaggle.api.kaggle_api_extended; real credentials are supplied per account via KaggleApiWrapper | |
| os.environ['KAGGLE_USERNAME']='' | |
| os.environ['KAGGLE_KEY']='' | |
| from kaggle.api.kaggle_api_extended import KaggleApi | |
| from kaggle.rest import ApiException | |
| import shutil | |
| import time | |
| import threading | |
| import copy | |
| from logger import sheet_logger | |
def get_api():
    """
    Build a ``KaggleApi`` client and authenticate it.

    :return: the ``KaggleApi`` instance after ``authenticate()`` has been called on it
    """
    client = KaggleApi()
    client.authenticate()
    return client
class KaggleApiWrapper(KaggleApi):
    """
    KaggleApi that authenticates with an explicit username/secret pair.

    Overrides ``KaggleApi.read_config_environment`` so the credentials given
    to the constructor are used instead of environment variables.
    """

    def __init__(self, username, secret):
        """
        :param username: Kaggle account name written to the ``username`` config entry
        :param secret: Kaggle API key written to the ``key`` config entry
        """
        super().__init__()
        self.username = username
        self.secret = secret

    def read_config_environment(self, config_data=None, quiet=False):
        """
        Return the parent's environment config with this wrapper's credentials applied.

        :param config_data: optional config dict forwarded to the parent implementation
        :param quiet: forwarded to the parent implementation
        :return: the config dict produced by the parent, with ``username`` and
            ``key`` overwritten by the values stored on this instance
        """
        config = super().read_config_environment(config_data, quiet)
        config['username'] = self.username
        config['key'] = self.secret
        # only works on pythonanywhere
        # config['proxy'] = "http://proxy.server:3128"
        # Return the dict that was actually populated: `config_data` may be
        # None (its default), in which case returning it would drop the
        # credentials set above.
        return config

    def __del__(self):
        # todo: fix bug when delete api
        pass
| # def get_accelerator_quota_with_http_info(self): # noqa: E501 | |
| # """ | |
| # | |
| # This method makes a synchronous HTTP request by default. To make an | |
| # asynchronous HTTP request, please pass async_req=True | |
| # >>> thread = api.competitions_list_with_http_info(async_req=True) | |
| # >>> result = thread.get() | |
| # | |
| # :param async_req bool | |
| # :param str group: Filter competitions by a particular group | |
| # :param str category: Filter competitions by a particular category | |
| # :param str sort_by: Sort the results | |
| # :param int page: Page number | |
| # :param str search: Search terms | |
| # :return: Result | |
| # If the method is called asynchronously, | |
| # returns the request thread. | |
| # """ | |
| # | |
| # all_params = [] # noqa: E501 | |
| # all_params.append('async_req') | |
| # all_params.append('_return_http_data_only') | |
| # all_params.append('_preload_content') | |
| # all_params.append('_request_timeout') | |
| # | |
| # params = locals() | |
| # | |
| # collection_formats = {} | |
| # | |
| # path_params = {} | |
| # | |
| # query_params = [] | |
| # # if 'group' in params: | |
| # # query_params.append(('group', params['group'])) # noqa: E501 | |
| # # if 'category' in params: | |
| # # query_params.append(('category', params['category'])) # noqa: E501 | |
| # # if 'sort_by' in params: | |
| # # query_params.append(('sortBy', params['sort_by'])) # noqa: E501 | |
| # # if 'page' in params: | |
| # # query_params.append(('page', params['page'])) # noqa: E501 | |
| # # if 'search' in params: | |
| # # query_params.append(('search', params['search'])) # noqa: E501 | |
| # | |
| # header_params = {} | |
| # | |
| # form_params = [] | |
| # local_var_files = {} | |
| # | |
| # body_params = None | |
| # # HTTP header `Accept` | |
| # header_params['Accept'] = self.api_client.select_header_accept( | |
| # ['application/json']) # noqa: E501 | |
| # | |
| # # Authentication setting | |
| # auth_settings = ['basicAuth'] # noqa: E501 | |
| # | |
| # return self.api_client.call_api( | |
| # 'i/kernels.KernelsService/GetAcceleratorQuotaStatistics', 'GET', | |
| # # '/competitions/list', 'GET', | |
| # path_params, | |
| # query_params, | |
| # header_params, | |
| # body=body_params, | |
| # post_params=form_params, | |
| # files=local_var_files, | |
| # response_type='Result', # noqa: E501 | |
| # auth_settings=auth_settings, | |
| # async_req=params.get('async_req'), | |
| # _return_http_data_only=params.get('_return_http_data_only'), | |
| # _preload_content=params.get('_preload_content', True), | |
| # _request_timeout=params.get('_request_timeout'), | |
| # collection_formats=collection_formats) | |
| # | |
| # if __name__ == "__main__": | |
| # api = KaggleApiWrapper('[email protected]', "<KAGGLE_API_KEY>") | |
| # api.authenticate() | |
| # print(api.get_accelerator_quota_with_http_info()) | |
class ValidateException(Exception):
    """Error raised when a notebook (or its datasets) fails validation."""

    def __init__(self, message: str):
        super().__init__(message)

    @staticmethod
    def from_api_exception(e: "ApiException", kernel_slug: str):
        """
        Build a ValidateException from a single API error.

        :param e: object exposing ``status`` and ``reason`` (an ``ApiException``)
        :param kernel_slug: slug of the notebook the error relates to
        :return: ValidateException describing the error
        """
        return ValidateException(f"Error: {e.status} {e.reason} with notebook {kernel_slug}")

    @staticmethod
    def from_api_exception_list(el: "List[ApiException]", kernel_slug_list: List[str]):
        """
        Build one ValidateException summarising several API errors.

        :param el: API errors, each exposing ``status`` and ``reason``
        :param kernel_slug_list: notebook slugs, paired with ``el`` by position
        :return: ValidateException whose message lists one error per line
        """
        # One tab-indented entry per line; previously the entries were
        # concatenated without any separator.
        entries = [
            f"\t{e.status} {e.reason} with notebook {k}"
            for e, k in zip(el, kernel_slug_list)
        ]
        return ValidateException("Error: \n" + "\n".join(entries))
class KaggleNotebook:
    """A Kaggle notebook (kernel) together with its local working copy."""

    def __init__(self, api: "KaggleApi", kernel_slug: str, container_path: str = "./tmp", id=None):
        """
        :param api: KaggleApi
        :param kernel_slug: Notebook id, you can find it in the url of the notebook.
            For example, `username/notebook-name-123456`
        :param container_path: Path to the local folder where the notebook will be downloaded
        :param id: optional job id; when set, ``status()`` also reports to ``sheet_logger``
        """
        self.api = api
        self.kernel_slug = kernel_slug
        self.container_path = container_path
        self.id = id
        if self.id is None:
            print(f"Warn: {self.__class__.__name__}.id is None")

    def status(self) -> Union[str, None]:
        """
        Query the notebook status and, when ``self.id`` is set, log it.

        :return: the ``status`` entry of the API response, e.g.
            "running"
            "cancelAcknowledged"
            "queued": waiting for run
            "error": when raise exception in notebook
            or None when the API returns no response object.
            Throw exception when failed
        """
        res = self.api.kernels_status(self.kernel_slug)
        print(f"Status: {res}")
        if res is None:
            if self.id is not None:
                sheet_logger.update_job_status(self.id, notebook_status='None')
            return None
        if self.id is not None:
            sheet_logger.update_job_status(self.id, notebook_status=res['status'])
        return res['status']

    def _get_local_nb_path(self) -> str:
        """Return the default local folder of this notebook: ``container_path/kernel_slug``."""
        return os.path.join(self.container_path, self.kernel_slug)

    def pull(self, path=None) -> Union[str, None]:
        """
        Download the notebook and its metadata.

        :param path: target folder; defaults to ``_get_local_nb_path()``
        :return: result of ``api.kernels_pull``, or None when no
            ``kernel-metadata.json`` exists in ``path`` after the pull
        :raises: ApiException if notebook not found or not share to user
        """
        # Start from a clean default folder so stale files are not reused.
        self._clean()
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        res = self.api.kernels_pull(self.kernel_slug, path=path, metadata=True, quiet=False)
        if not os.path.exists(metadata_path):
            print(f"Warn: Not found {metadata_path}. Clean {path}")
            self._clean()
            return None
        return res

    def push(self, path=None) -> Union[str, None]:
        """
        Push the local notebook unless it is already queued or running.

        :param path: folder to push; defaults to ``_get_local_nb_path()``
        :return: None when the push is skipped, otherwise the status read
            one second after the push
        """
        status = self.status()
        if status in ['queued', 'running']:
            print("Warn: Notebook is " + status + ". Skip push notebook!")
            return None

        self.api.kernels_push(path or self._get_local_nb_path())
        time.sleep(1)
        status = self.status()
        return status

    def _clean(self) -> None:
        """Delete the default local notebook folder if it exists."""
        local_path = self._get_local_nb_path()
        if os.path.exists(local_path):
            shutil.rmtree(local_path)

    def get_metadata(self, path=None):
        """
        Load ``kernel-metadata.json`` from the notebook folder.

        :param path: folder to read from; defaults to ``_get_local_nb_path()``
        :return: parsed JSON content, or None when the file does not exist
        """
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        if not os.path.exists(metadata_path):
            return None
        # Context manager so the file handle is closed deterministically.
        with open(metadata_path) as f:
            return json.loads(f.read())

    def check_nb_permission(self) -> "tuple[bool, Union[str, None]]":
        """
        :return: ``(True, status)`` when a status could be read,
            ``(False, None)`` when ``status()`` returned None
        :raises: ApiException propagated from ``status()``
        """
        status = self.status()  # raise ApiException
        if status is None:
            return False, status
        return True, status

    def check_datasets_permission(self) -> bool:
        """
        Check that every dataset listed in the local metadata is accessible.

        :return: True when ``api.dataset_status`` succeeds for all datasets
        :raises: ValidateException when the metadata file is missing or when
            at least one dataset request raises ApiException
        """
        meta = self.get_metadata()
        if meta is None:
            print("Warn: cannot get metadata. Pull and try again?")
            # Previously execution continued and failed with a TypeError on
            # `meta['dataset_sources']`; report a validation error instead.
            raise ValidateException(
                f"Error: cannot get metadata of notebook {self.kernel_slug}. Pull and try again?"
            )
        dataset_sources = meta.get('dataset_sources') or []
        ex_list = []
        slugs = []
        for dataset in dataset_sources:
            try:
                self.api.dataset_status(dataset)
            except ApiException as e:
                print(f"Error: {e.status} {e.reason} with dataset {dataset} in notebook {self.kernel_slug}")
                # Keep going so all failing datasets are reported together.
                ex_list.append(e)
                slugs.append(self.kernel_slug)
        if len(ex_list) > 0:
            raise ValidateException.from_api_exception_list(ex_list, slugs)
        return True
class AccountTransactionManager:
    """
    Hands out one API client per account, allowing a single user of each
    account at a time.

    ``lock_dict`` maps username -> bool "in use" flag; ``state_lock`` guards
    every read/write of that flag and of the account registry.
    """

    def __init__(self, acc_secret_dict: dict = None):
        """
        :param acc_secret_dict: {username: secret}
        """
        self.acc_secret_dict = acc_secret_dict
        if self.acc_secret_dict is None:
            self.acc_secret_dict = {}
        # self.api_dict = {username: KaggleApiWrapper(username, secret) for username, secret in acc_secret_dict.items()}
        # lock for each account to avoid concurrent use api
        self.lock_dict = {username: False for username in self.acc_secret_dict}
        self.state_lock = threading.Lock()

    def _get_api(self, username: str) -> "KaggleApiWrapper":
        """Create a new KaggleApiWrapper for ``username`` using its stored secret."""
        # return self.api_dict[username]
        return KaggleApiWrapper(username, self.acc_secret_dict[username])

    def _get_lock(self, username: str) -> bool:
        """Return the "in use" flag of ``username`` (KeyError if unknown)."""
        return self.lock_dict[username]

    def _set_lock(self, username: str, lock: bool) -> None:
        """Set the "in use" flag of ``username``."""
        self.lock_dict[username] = lock

    def add_account(self, username, secret):
        """Register ``username`` with ``secret``; existing accounts are left untouched."""
        # Check and insert under the same lock so the decision cannot be
        # invalidated by a concurrent add/remove.
        with self.state_lock:
            if username not in self.acc_secret_dict:
                self.acc_secret_dict[username] = secret
                self.lock_dict[username] = False

    def remove_account(self, username):
        """Unregister ``username``; prints a warning when it is not registered."""
        with self.state_lock:
            if username in self.acc_secret_dict:
                del self.acc_secret_dict[username]
                del self.lock_dict[username]
            else:
                print(f"Warn: try to remove account not in the list: {username}, list: {self.acc_secret_dict.keys()}")

    def get_unlocked_api_unblocking(self, username_list: List[str]) -> "tuple[KaggleApiWrapper, Callable[[], None]]":
        """
        Wait until one of the given accounts is free, mark it as in use and
        return a client for it. Polls once per second while all are busy.

        :param username_list: list of username
        :return: (api, release) where release is a function to release api
        :raises ValueError: if ``username_list`` is empty (nothing could ever be returned)
        :raises KeyError: if a username is not registered
        """
        if not username_list:
            raise ValueError("username_list must contain at least one username")
        while True:
            print("get_unlocked_api_unblocking" + str(username_list))
            for username in username_list:
                # `with` guarantees state_lock is released even when a lookup
                # raises (e.g. KeyError for an unregistered username).
                with self.state_lock:
                    if self._get_lock(username):
                        continue
                    # Build the client before flagging the account so a
                    # failure here cannot leave the account marked as in use.
                    api = self._get_api(username)
                    self._set_lock(username, True)

                    def release():
                        with self.state_lock:
                            self._set_lock(username, False)
                            api.__del__()

                    return api, release
            time.sleep(1)
class NbJob:
    """A notebook to keep running, plus the accounts allowed to run it."""

    def __init__(self, acc_dict: dict, nb_slug: str, rerun_stt: List[str] = None, not_rerun_stt: List[str] = None, id=None):
        """
        :param acc_dict: {username: secret} of the accounts that may run the notebook
        :param nb_slug: notebook slug
        :param rerun_stt: statuses that trigger a rerun; defaults to ['complete']
        :param not_rerun_stt: If notebook status in this list, do not rerun it.
            Must contain "queued" and "running";
            defaults to ['queued', 'running', 'cancelAcknowledged']
        :param id: optional job id passed on to KaggleNotebook
        :raises AssertionError: if ``not_rerun_stt`` lacks "queued" or "running"
        """
        self.rerun_stt = rerun_stt
        if self.rerun_stt is None:
            self.rerun_stt = ['complete']
        self.not_rerun_stt = not_rerun_stt
        if self.not_rerun_stt is None:
            self.not_rerun_stt = ['queued', 'running', 'cancelAcknowledged']
        # Explicit raise (same exception type as the former `assert`) so the
        # validation is not stripped under `python -O`.
        for required in ("queued", "running"):
            if required not in self.not_rerun_stt:
                raise AssertionError(f'not_rerun_stt must contain "{required}"')

        self.acc_dict = acc_dict
        self.nb_slug = nb_slug
        self.id = id

    def get_acc_dict(self):
        """Return the {username: secret} mapping of this job."""
        return self.acc_dict

    def get_username_list(self):
        """Return a new list with the usernames of this job."""
        return list(self.acc_dict.keys())

    def is_valid_with_acc(self, api):
        """
        Pull the notebook with ``api`` and check notebook and dataset access.

        :param api: authenticated API client
        :return: True when no check raised
        :raise: ValidateException (ApiException is converted to it)
        """
        notebook = KaggleNotebook(api, self.nb_slug, id=self.id)

        try:
            notebook.pull()  # raise ApiException
            notebook.check_nb_permission()  # note: raise ApiException
            notebook.check_datasets_permission()  # raise ValidateException
        except ApiException as e:
            raise ValidateException.from_api_exception(e, self.nb_slug)
        # if not stt:
        #     return False
        return True

    def is_valid(self):
        """Validate the notebook with every account of this job; see ``is_valid_with_acc``."""
        for username in self.acc_dict.keys():
            secrets = self.acc_dict[username]
            api = KaggleApiWrapper(username=username, secret=secrets)
            api.authenticate()
            if not self.is_valid_with_acc(api):
                return False
        return True

    def acc_check_and_rerun_if_need(self, api: "KaggleApi") -> bool:
        """
        Check the notebook with one account and push it again when its status
        is in ``rerun_stt``.

        :return:
            True if rerun success or notebook is running
            False user does not have enough gpu quotas
        :raises
            Exception if setup error
        """
        notebook = KaggleNotebook(api, self.nb_slug, "./tmp", id=self.id)  # todo: change hardcode container_path here

        notebook.pull()
        # The checks run outside `assert` so they still execute under
        # `python -O`; AssertionError is kept as the exception type.
        if not notebook.check_datasets_permission():
            raise AssertionError(f"User {api} does not have permission on datasets of notebook {self.nb_slug}")
        success, status1 = notebook.check_nb_permission()
        if not success:
            raise AssertionError(f"User {api} does not have permission on notebook {self.nb_slug}")  # todo: using api.username

        if status1 in self.rerun_stt:
            status2 = notebook.push()
            time.sleep(10)
            status3 = notebook.status()

            # if 3 times same stt -> acc out of quota
            if status1 == status2 == status3:
                sheet_logger.log(username=api.username, nb=self.nb_slug, log="Try but no effect. Seem account to be out of quota")
                return False
            if status3 in self.not_rerun_stt:
                # sheet_logger.log(username=api.username, nb=self.nb_slug, log=f"Notebook status is {status3} is in ignore status list {self.not_rerun_stt}, do nothing!")
                sheet_logger.log(username=api.username, nb=self.nb_slug,
                                 log=f"Schedule notebook successfully. Current status is '{status3}'")
                return True
            if status3 not in ["queued", "running"]:
                # return False # todo: check when user is out of quota
                print(f"Error: status is {status3}")

                raise Exception("Setup exception")
            return True

        sheet_logger.log(username=api.username, nb=self.nb_slug, log=f"Notebook status is '{status1}' is not in {self.rerun_stt}, do nothing!")
        return True

    @staticmethod
    def from_dict(obj: dict, id=None):
        """
        Build an NbJob from a config dict.

        Reads ``obj['accounts']`` and ``obj['slug']`` (required) and the
        optional ``rerun_status`` / ``not_rerun_stt`` entries.
        """
        return NbJob(acc_dict=obj['accounts'], nb_slug=obj['slug'], rerun_stt=obj.get('rerun_status'), not_rerun_stt=obj.get('not_rerun_stt'), id=id)
class KernelRerunService:
    """
    Registry of NbJob objects plus the accounts they use.

    ``username2jobid`` maps username -> list of notebook slugs using it;
    ``jobid2username`` maps notebook slug -> list of usernames.
    """

    def __init__(self):
        self.jobs: Dict[str, NbJob] = {}
        self.acc_manager = AccountTransactionManager()
        self.username2jobid = {}
        self.jobid2username = {}

    def add_job(self, nb_job: "NbJob"):
        """Register ``nb_job`` and its accounts; warns and does nothing if its slug is already known."""
        if nb_job.nb_slug in self.jobs:
            print("Warn: nb_job already in job list")
            return
        self.jobs[nb_job.nb_slug] = nb_job
        self.jobid2username[nb_job.nb_slug] = nb_job.get_username_list()
        for username in nb_job.get_username_list():
            if username not in self.username2jobid:
                self.username2jobid[username] = []
                self.acc_manager.add_account(username, nb_job.acc_dict[username])
            self.username2jobid[username].append(nb_job.nb_slug)

    def remove_job(self, nb_job):
        """
        Unregister ``nb_job``; accounts no longer used by any job are removed
        from the account manager.
        """
        slug = nb_job.nb_slug
        if slug not in self.jobs:
            print("Warn: try to remove nb_job not in list")
            return
        for username in self.jobid2username[slug]:
            user_jobs = self.username2jobid.get(username, [])
            # Always drop the slug from the account's job list; previously it
            # was left behind for shared accounts, so they were never freed.
            if slug in user_jobs:
                user_jobs.remove(slug)
            if not user_jobs:
                self.username2jobid.pop(username, None)
                self.acc_manager.remove_account(username)
        del self.jobs[slug]
        del self.jobid2username[slug]

    def validate_all(self):
        """
        Validate every job with every account registered for it and write
        one validate status per job to ``sheet_logger``.

        :return: True
        :raises Exception: when validation fails for a job that has no id
        """
        # Messages are collected per job across all accounts so that a later
        # account's success cannot overwrite an earlier account's failure.
        errors = {slug: [] for slug in self.jobs}
        checked = set()
        for username in list(self.acc_manager.acc_secret_dict.keys()):
            api, release = self.acc_manager.get_unlocked_api_unblocking([username])
            try:
                api.authenticate()
                print(f"Using username: {api.username}")

                for job in self.jobs.values():
                    if username not in job.get_username_list():
                        continue
                    print(f"Validate user: {username}, job: {job.nb_slug}")
                    checked.add(job.nb_slug)
                    try:
                        job.is_valid_with_acc(api)
                    except ValidateException as e:
                        print(f"Error: not valid")
                        a = f"Setup error: {username} does not have permission on notebook {job.nb_slug} or related datasets"
                        if job.id is not None:  # if have id, write log
                            errors[job.nb_slug].append(f"Account {username}\n" + str(e) + "\n")
                        else:  # if not have id, raise
                            raise Exception(a) from e
            finally:
                # Always free the account, also when validation raised.
                release()

        for job in self.jobs.values():
            if job.nb_slug not in checked:
                continue
            ex_msg_list = errors[job.nb_slug]
            if len(ex_msg_list) > 0:
                sheet_logger.update_job_status(job.id, validate_status="\n".join(ex_msg_list))
            else:
                sheet_logger.update_job_status(job.id, validate_status="success")
        return True

    def status_all(self):
        """Print the current status of every registered job's notebook."""
        for job in self.jobs.values():
            print(f"Job: {job.nb_slug}")
            api, release = self.acc_manager.get_unlocked_api_unblocking(job.get_username_list())
            try:
                api.authenticate()
                print(f"Using username: {api.username}")
                notebook = KaggleNotebook(api, job.nb_slug, id=job.id)
                print(f"Notebook: {notebook.kernel_slug}")
                print(notebook.status())
            finally:
                release()

    def run(self, nb_job: "NbJob"):
        """
        Try the job's accounts one after another until one reruns the notebook.

        :return: True as soon as ``acc_check_and_rerun_if_need`` returns True
            for an account; False when every account returned False or a
            check raised (the exception is printed)
        :raises Exception: if the account manager returns an account that is
            not among the ones still to try
        """
        username_list = copy.copy(nb_job.get_username_list())
        while len(username_list) > 0:
            api, release = self.acc_manager.get_unlocked_api_unblocking(username_list)
            # `finally` frees the account on every exit path; previously the
            # success path returned without releasing it.
            try:
                api.authenticate()
                print(f"Using username: {api.username}")

                try:
                    result = nb_job.acc_check_and_rerun_if_need(api)
                    if result:
                        return True
                except Exception as e:
                    print(e)
                    break

                if api.username not in username_list:
                    raise Exception(
                        f"Account manager returned unexpected account {api.username} for notebook {nb_job.nb_slug}"
                    )
                # This account could not run the notebook; try the next one.
                username_list.remove(api.username)
            finally:
                release()
        return False

    def run_all(self):
        """Run every registered job in turn and print the outcome."""
        for job in self.jobs.values():
            success = self.run(job)
            print(f"Job: {job.nb_slug} {success}")
| # if __name__ == "__main__": | |
| # service = KernelRerunService() | |
| # files = os.listdir("./config") | |
| # for file in files: | |
| # if '.example' not in file: | |
| # with open(os.path.join("./config", file), "r") as f: | |
| # obj = json.loads(f.read()) | |
| # print(obj) | |
| # service.add_job(NbJob.from_dict(obj)) | |
| # service.run_all() | |
| # try: | |
| # acc_secret_dict = { | |
| # "hahunavth": "secret", | |
| # "hahunavth2": "secret", | |
| # "hahunavth3": "secret", | |
| # "hahunavth4": "secret", | |
| # "hahunavth5": "secret", | |
| # } | |
| # acc_manager = AccountTransactionManager(acc_secret_dict) | |
| # | |
| # | |
| # def test1(): | |
| # username_list = ["hahunavth", "hahunavth2", "hahunavth3", "hahunavth4", "hahunavth5"] | |
| # while len(username_list) > 0: | |
| # api, release = acc_manager.get_unlocked_api_unblocking(username_list) | |
| # print("test1 is using " + api.username) | |
| # time.sleep(1) | |
| # release() | |
| # if api.username in username_list: | |
| # username_list.remove(api.username) | |
| # else: | |
| # raise Exception("") | |
| # print("test1 release " + api.username) | |
| # | |
| # | |
| # def test2(): | |
| # username_list = ["hahunavth2", "hahunavth3", "hahunavth5"] | |
| # while len(username_list) > 0: | |
| # api, release = acc_manager.get_unlocked_api_unblocking(username_list) | |
| # print("test2 is using " + api.username) | |
| # time.sleep(3) | |
| # release() | |
| # if api.username in username_list: | |
| # username_list.remove(api.username) | |
| # else: | |
| # raise Exception("") | |
| # print("test2 release " + api.username) | |
| # | |
| # | |
| # t1 = threading.Thread(target=test1) | |
| # t2 = threading.Thread(target=test2) | |
| # t1.start() | |
| # t2.start() | |
| # t1.join() | |
| # t2.join() | |
| # | |
| # # kgapi = KaggleApiWrapper("hahunavth", "<KAGGLE_API_KEY>") | |
| # # kgapi.authenticate() | |
| # # # kgapi = get_api() | |
| # # notebook = KaggleNotebook(kgapi, "hahunavth/ess-vlsp2023-denoising", "./tmp") | |
| # # # print(notebook.pull()) | |
| # # # print(notebook.check_datasets_permission()) | |
| # # print(notebook.check_nb_permission()) | |
| # # # print(notebook.status()) | |
| # # # notebook.push() | |
| # # # print(notebook.status()) | |
| # except ApiException as e: | |
| # print(e.status) | |
| # print(e.reason) | |
| # raise e | |
| # # 403 when the notebook does not exist or is not shared with the account | |
| # # 404 when pushing to an unknown kernel_slug.username | |
| # # 401 when the username or password (key) is invalid |