| import torch | |
| from transformers import PreTrainedTokenizerBase | |
def safe_decode(tokenizer: PreTrainedTokenizerBase, outputs: torch.Tensor) -> str:
    """Decode ``outputs`` to text without losing the first token's leading space.

    A throwaway ``"^"`` marker is decoded in front of the real token ids and
    then cut off again, so that the first real token is no longer at the
    start of the decoded sequence.

    Args:
        tokenizer: Tokenizer used both to encode the marker and to decode.
        outputs: Token ids to decode; ``outputs.tolist()`` must yield a flat
            list of ids (i.e. a 1-D tensor).

    Returns:
        The decoded text of ``outputs`` with the marker removed.
    """
    # Workaround to make SentencePiece .decode() keep leading spaces in a token.
    # Special tokens are disabled so that a tokenizer which prepends BOS does
    # not hand us the BOS id instead of the marker: BOS decodes to several
    # characters (e.g. "<s>") and the single-character cut below would then
    # leave garbage in the result.
    marker_ids = tokenizer("^", add_special_tokens=False)["input_ids"]

    # The whole marker encoding is prepended (not just its first id) so that a
    # vocabulary which splits "^" into several pieces still decodes to "^".
    decoded = tokenizer.decode(list(marker_ids) + outputs.tolist())

    # We use .lstrip() since SentencePiece may add leading spaces, e.g. if the
    # outputs are "</s>"; after that the first character is the marker itself.
    return decoded.lstrip()[1:]