"""Distributed ASR evaluation for Qwen2-Audio.

Each rank transcribes a shard of the selected dataset with `model.generate`,
results are gathered to rank 0, and WER (CER for zh/yue) is reported per source.
"""
import argparse
import itertools
import json
import os
import random
import re
import time
from functools import partial

import editdistance as ed
import requests
import torch
import zhconv
from cn_tn import TextNorm
from evaluate_tokenizer import EvaluationTokenizer
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer

english_normalizer = EnglishTextNormalizer()
chinese_normalizer = TextNorm(
    to_banjiao=False,
    to_upper=False,
    to_lower=False,
    remove_fillers=False,
    remove_erhua=False,
    check_chars=False,
    remove_space=False,
    cc_mode='',
)
basic_normalizer = BasicTextNormalizer()

PUNCS = '!,.?;:'

# Each JSONL file holds one example per line with 'audio', 'prompt', 'source' and 'gt' fields.
ds_collections = {
    'librispeech': {'path': 'asr/librispeech_eval.jsonl', 'language': 'en'},
    'aishell2': {'path': 'asr/aishell2_eval.jsonl', 'language': 'zh'},
    'cv15_en': {'path': 'asr/cv15_asr_en_eval.jsonl', 'language': 'en'},
    'cv15_zh': {'path': 'asr/cv15_asr_zh_eval.jsonl', 'language': 'zh'},
    'cv15_yue': {'path': 'asr/cv15_asr_yue_eval.jsonl', 'language': 'yue'},
    'cv15_fr': {'path': 'asr/cv15_asr_fr_eval.jsonl', 'language': 'fr'},
    'fluers_zh': {'path': 'asr/fleurs_asr_zh_eval.jsonl', 'language': 'zh'},
}


class AudioDataset(torch.utils.data.Dataset):

    def __init__(self, ds):
        path = ds['path']
        self.datas = open(path).readlines()

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        data = json.loads(self.datas[idx].strip())
        audio = data['audio']
        source = data['source']
        # Prepend the audio placeholder tokens expected by the Qwen2-Audio processor.
        prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>" + data['prompt']
        gt = data['gt']

        return {
            'audio': audio,
            'prompt': prompt,
            'source': source,
            'gt': gt,
        }


def read_audio(audio_path):
    if audio_path.startswith("http://") or audio_path.startswith("https://"):
        # We need to actually check for a real protocol, otherwise it's impossible
        # to use a local file like http_huggingface_co.png
        inputs = requests.get(audio_path).content
    else:
        with open(audio_path, "rb") as f:
            inputs = f.read()
    return inputs


def collate_fn(inputs, processor):
    input_texts = [_['prompt'] for _ in inputs]
    source = [_['source'] for _ in inputs]
    gt = [_['gt'] for _ in inputs]
    audio_path = [_['audio'] for _ in inputs]
    input_audios = [
        ffmpeg_read(read_audio(_['audio']),
                    sampling_rate=processor.feature_extractor.sampling_rate)
        for _ in inputs
    ]
    inputs = processor(text=input_texts,
                       audios=input_audios,
                       sampling_rate=processor.feature_extractor.sampling_rate,
                       return_tensors="pt",
                       padding=True)
    return inputs, audio_path, source, gt


class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Shards dataset indices contiguously across ranks so each process decodes a disjoint slice."""

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size,
                                                      self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        # The first `left` ranks take one extra example each.
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)


def remove_sp(text, language):
    """Strip special tokens, collapse whitespace, and remove space before punctuation."""
    gt = re.sub(r"<\|.*?\|>", " ", text)
    gt = re.sub(r"\s+", " ", gt)  # collapse consecutive whitespace into a single space
    gt = re.sub(f" ?([{PUNCS}])", r"\1", gt)
    gt = gt.lstrip(" ")
    if language == "zh":
        # Chinese is scored at character level, so drop whitespace entirely.
        gt = re.sub(r"\s+", "", gt)
    return gt


def compute_wer(refs, hyps, language):
    """Corpus-level WER (CER for zh/yue): total edit distance over total reference length."""
    distance = 0
    ref_length = 0
    tokenizer = EvaluationTokenizer(
        tokenizer_type="none",
        lowercase=True,
        punctuation_removal=True,
        character_tokenization=False,
    )
    for i in range(len(refs)):
        ref = refs[i]
        pred = hyps[i]
        if language in ["yue"]:
            # Convert Cantonese references/hypotheses to simplified Chinese before scoring.
            ref = zhconv.convert(ref, 'zh-cn')
            pred = zhconv.convert(pred, 'zh-cn')
        if language in ["en"]:
            ref = english_normalizer(ref)
            pred = english_normalizer(pred)
        if language in ["zh"]:
            ref = chinese_normalizer(ref)
            pred = chinese_normalizer(pred)
        else:
            ref = basic_normalizer(ref)
            pred = basic_normalizer(pred)
        ref_items = tokenizer.tokenize(ref).split()
        pred_items = tokenizer.tokenize(pred).split()
        if language in ["zh", "yue"]:
            # Character-level tokenization for Chinese and Cantonese.
            ref_items = [x for x in "".join(ref_items)]
            pred_items = [x for x in "".join(pred_items)]
        if i == 0:
            print(f"ref: {ref}")
            print(f"pred: {pred}")
            print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{ref_items[0]}")
            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{pred_items[0]}")
        distance += ed.eval(ref_items, pred_items)
        ref_length += len(ref_items)
    return distance / ref_length


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio')
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )

    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        args.checkpoint, device_map='cuda', torch_dtype='auto',
        trust_remote_code=True).eval()
    processor = AutoProcessor.from_pretrained(args.checkpoint)
    # Left-pad so that generated tokens follow the prompt for every item in the batch.
    processor.tokenizer.padding_side = 'left'

    random.seed(args.seed)
    dataset = AudioDataset(
        ds=ds_collections[args.dataset],
    )
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(collate_fn, processor=processor),
    )

    gts = []
    sources = []
    rets = []
    audio_paths = []
    for _, (inputs, audio_path, source, gt) in tqdm(enumerate(data_loader)):
        inputs['input_ids'] = inputs['input_ids'].to('cuda')
        output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
        # Keep only the newly generated tokens.
        output_ids = output_ids[:, inputs.input_ids.size(1):]
        output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        gts.extend(gt)
        rets.extend(output)
        sources.extend(source)
        audio_paths.extend(audio_path)

    torch.distributed.barrier()

    # Gather per-rank results onto every rank, then flatten.
    world_size = torch.distributed.get_world_size()
    merged_gts = [None for _ in range(world_size)]
    merged_sources = [None for _ in range(world_size)]
    merged_responses = [None for _ in range(world_size)]
    merged_audio_paths = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_gts, gts)
    torch.distributed.all_gather_object(merged_sources, sources)
    torch.distributed.all_gather_object(merged_responses, rets)
    torch.distributed.all_gather_object(merged_audio_paths, audio_paths)

    merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)]
    merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)]
    merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)]
    merged_responses = [
        _ for _ in itertools.chain.from_iterable(merged_responses)
    ]

    if torch.distributed.get_rank() == 0:
        print(f"Evaluating {args.dataset} ...")

        results = []
        for gt, response, source, audio_path in zip(merged_gts, merged_responses,
                                                    merged_sources, merged_audio_paths):
            results.append({
                'gt': gt,
                'response': response,
                'source': source,
                'audio_path': audio_path,
            })
        time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
        results_file = f'{args.dataset}_{time_prefix}.json'
        json.dump(results, open(results_file, 'w'))

        results_dict = {}
        for item in tqdm(results):
            source = item["source"]
            results_dict.setdefault(source, []).append(item)
        lan = ds_collections[args.dataset]['language']
        for source in results_dict:
            refs, hyps = [], []
            results_list = results_dict[source]
            for result in results_list:
                gt = result["gt"]
                response = result["response"]
                gt = remove_sp(gt, lan)
                response = remove_sp(response, lan)
                refs.append(gt)
                hyps.append(response)
            wer = compute_wer(refs, hyps, lan)
            print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")

    torch.distributed.barrier()
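
# Example launch (a sketch, not part of the original script): the script relies on
# torch.distributed with the nccl backend, so it is normally started via torchrun.
# The file name "evaluate_asr.py", GPU count, and batch size below are assumptions.
#
#   torchrun --nproc_per_node=8 evaluate_asr.py \
#       --checkpoint Qwen/Qwen2-Audio \
#       --dataset librispeech \
#       --batch-size 8 \
#       --num-workers 4
#
# Each line of the referenced asr/*.jsonl files is expected to carry the fields read
# in AudioDataset.__getitem__; a hypothetical record might look like:
#   {"audio": "path/to/utt.wav", "prompt": "<transcription instruction>",
#    "source": "librispeech_test_clean", "gt": "reference transcript"}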