Spaces:
Runtime error
Runtime error
| from .constants import ( | |
| BOX_END_TOKEN, | |
| BOX_START_TOKEN, | |
| IMG_CONTEXT_TOKEN, | |
| IMG_END_TOKEN, | |
| IMG_START_TOKEN, | |
| IMG_TAG_TOKEN, | |
| PATCH_CONTEXT_TOKEN, | |
| PATCH_END_TOKEN, | |
| PATCH_START_TOKEN, | |
| QUAD_END_TOKEN, | |
| QUAD_START_TOKEN, | |
| REF_END_TOKEN, | |
| REF_START_TOKEN, | |
| VID_CONTEXT_TOKEN, | |
| VID_END_TOKEN, | |
| VID_START_TOKEN, | |
| VID_TAG_TOKEN, | |
| ) | |
def update_tokenizer(tokenizer):
    """Add the image/video/patch/grounding marker tokens to *tokenizer*.

    Every marker constant is passed to ``tokenizer.add_tokens`` with
    ``special_tokens=True``. The count returned by ``add_tokens`` is not
    needed here and is discarded.

    Args:
        tokenizer: Object exposing ``add_tokens(tokens, special_tokens=...)``
            (presumably a Hugging Face tokenizer -- confirm against callers).

    Returns:
        The same ``tokenizer`` object that was passed in.
    """
    # NOTE(review): keep this order unchanged -- presumably the ids given to
    # newly added tokens follow list order; confirm before reordering.
    token_list = [
        IMG_START_TOKEN,
        IMG_END_TOKEN,
        IMG_CONTEXT_TOKEN,
        VID_START_TOKEN,
        VID_END_TOKEN,
        VID_CONTEXT_TOKEN,
        PATCH_START_TOKEN,
        PATCH_END_TOKEN,
        PATCH_CONTEXT_TOKEN,
        QUAD_START_TOKEN,
        QUAD_END_TOKEN,
        REF_START_TOKEN,
        REF_END_TOKEN,
        BOX_START_TOKEN,
        BOX_END_TOKEN,
        IMG_TAG_TOKEN,
        VID_TAG_TOKEN,
    ]
    tokenizer.add_tokens(token_list, special_tokens=True)
    return tokenizer
def update_tokenizer_for_s2s(tokenizer, model_type):
    """Extend *tokenizer* with the tokens required by the given audio backend.

    The backend-specific updater is imported only inside the branch that
    needs it, so the other backends' modules are never imported.

    Args:
        tokenizer: Tokenizer object to extend; forwarded unchanged to the
            selected updater.
        model_type: One of ``None``, ``"glm4voice"``, ``"cosyvoice2"``,
            ``"snac24khz"``, ``"sensevoice_sparktts"`` or
            ``"sensevoice_glm4voice"``. ``None`` delegates to
            ``update_tokenizer``.

    Returns:
        Whatever the selected updater returns.

    Raises:
        NotImplementedError: If ``model_type`` is not one of the values above.
    """
    if model_type is None:
        return update_tokenizer(tokenizer)

    if model_type == "glm4voice":
        from .tokenizer_glm4voice import update_tokenizer_for_glm4voice

        return update_tokenizer_for_glm4voice(tokenizer)

    if model_type == "cosyvoice2":
        from .tokenizer_cosyvoice2 import update_tokenizer_for_cosyvoice2

        return update_tokenizer_for_cosyvoice2(tokenizer)

    if model_type == "snac24khz":
        from .tokenizer_snac import update_tokenizer_for_snac

        return update_tokenizer_for_snac(tokenizer)

    if model_type == "sensevoice_sparktts":
        from .tokenizer_sensevoice_sparktts import (
            update_tokenizer_for_sensevoice_sparktts,
        )

        return update_tokenizer_for_sensevoice_sparktts(tokenizer)

    if model_type == "sensevoice_glm4voice":
        from .tokenizer_sensevoice_glm4voice import (
            update_tokenizer_for_sensevoice_glm4voice,
        )

        return update_tokenizer_for_sensevoice_glm4voice(tokenizer)

    # Name the rejected value so a misconfigured model_type is diagnosable.
    raise NotImplementedError(f"Unsupported model_type: {model_type!r}")
def get_audio_tokenizer(model_name_or_path, model_type, flow_path=None, rank=None):
    """Construct the audio tokenizer that matches *model_type*.

    The tokenizer class is imported only inside the branch that needs it, so
    the other backends' modules are never imported.

    Args:
        model_name_or_path: Passed as the first positional argument to the
            selected tokenizer class.
        model_type: One of ``None``, ``"glm4voice"``, ``"cosyvoice2"``,
            ``"snac24khz"``, ``"sensevoice_sparktts"`` or
            ``"sensevoice_glm4voice"``.
        flow_path: Forwarded as ``flow_path=`` to the ``"glm4voice"`` and
            ``"sensevoice_glm4voice"`` tokenizers only; ignored otherwise.
        rank: Forwarded as ``rank=`` to every tokenizer class.

    Returns:
        ``None`` when ``model_type`` is ``None``; otherwise a new instance of
        the selected tokenizer class.

    Raises:
        NotImplementedError: If ``model_type`` is not one of the values above.
    """
    if model_type is None:
        return None

    if model_type == "glm4voice":
        from .tokenizer_glm4voice import GLM4VoiceTokenizer

        return GLM4VoiceTokenizer(model_name_or_path, flow_path=flow_path, rank=rank)

    if model_type == "cosyvoice2":
        from .tokenizer_cosyvoice2 import CosyVoice2Tokenizer

        return CosyVoice2Tokenizer(model_name_or_path, rank=rank)

    if model_type == "snac24khz":
        from .tokenizer_snac import SNACTokenizer

        return SNACTokenizer(model_name_or_path, rank=rank)

    if model_type == "sensevoice_sparktts":
        from .tokenizer_sensevoice_sparktts import SenseVoiceSparkTTSTokenizer

        return SenseVoiceSparkTTSTokenizer(model_name_or_path, rank=rank)

    if model_type == "sensevoice_glm4voice":
        from .tokenizer_sensevoice_glm4voice import SenseVoiceGLM4VoiceTokenizer

        return SenseVoiceGLM4VoiceTokenizer(
            model_name_or_path, flow_path=flow_path, rank=rank
        )

    # Name the rejected value so a misconfigured model_type is diagnosable.
    raise NotImplementedError(f"Unsupported model_type: {model_type!r}")