Spaces:
Runtime error
Runtime error
| from functools import lru_cache | |
| from typing import Mapping | |
| from huggingface_hub import hf_hub_download | |
| from imgutils.data import ImageTyping, load_image | |
| from onnx_ import _open_onnx_model | |
| from preprocess import _img_encode | |
# Output labels. Their order matters: they are zipped position-by-position
# with the scores in the model's ``output`` tensor.
_LABELS = [
    '3d',
    'bangumi',
    'comic',
    'illustration',
]

# Classifier variants. Each name is used as the sub-directory of
# ``model.onnx`` inside the ``deepghs/anime_classification`` hub repo.
# NOTE(review): the list is assumed to mirror the repo contents — verify.
_CLS_MODELS = [
    'caformer_s36', 'caformer_s36_plus',
    'mobilenetv3', 'mobilenetv3_dist',
    'mobilenetv3_sce', 'mobilenetv3_sce_dist',
    'mobilevitv2_150',
]

# Default entry from ``_CLS_MODELS``. It is not referenced in this excerpt;
# presumably a caller (e.g. the UI layer) uses it. Verify.
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
@lru_cache()
def _open_anime_classify_model(model_name):
    """Fetch and open the ONNX classifier stored under ``model_name``.

    ``hf_hub_download`` fetches ``<model_name>/model.onnx`` from the
    ``deepghs/anime_classification`` repo, and the resulting local path is
    handed to ``_open_onnx_model``.

    The result is memoized per ``model_name`` with ``lru_cache``. Repeated
    classification calls therefore reuse the already-opened model instead of
    re-creating an inference session on every request.

    :param model_name: Sub-directory name of the model in the hub repo
        (presumably one of ``_CLS_MODELS``).
    :return: Whatever ``_open_onnx_model`` returns. Callers in this module
        use its ``run(output_names, feed_dict)`` method.
    """
    return _open_onnx_model(hf_hub_download(
        'deepghs/anime_classification',
        f'{model_name}/model.onnx',
    ))
def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """Score ``image`` against each label in ``_LABELS``.

    :param image: Anything accepted by ``imgutils.data.load_image``. It is
        loaded in RGB mode.
    :param model_name: Which classifier to run. It is forwarded to
        ``_open_anime_classify_model``.
    :param size: Side length forwarded to ``_img_encode`` as ``(size, size)``.
    :return: Dict mapping each label to the matching Python scalar from the
        first row of the model's ``output`` tensor.
    """
    rgb_image = load_image(image, mode='RGB')
    # Prepend a leading axis (presumably the batch dimension) before inference.
    batch = _img_encode(rgb_image, size=(size, size))[None, ...]

    session = _open_anime_classify_model(model_name)
    (scores,) = session.run(['output'], {'input': batch})

    # Pair labels with scores by position; .item() turns each into a Python scalar.
    return {label: score.item() for label, score in zip(_LABELS, scores[0])}