# j
# first commit
# 2e37cc0
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import re
from evaluate_tokenizer import EvaluationTokenizer
import editdistance as ed
import torch
from transformers.pipelines.audio_utils import ffmpeg_read
import requests
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
from cn_tn import TextNorm
import zhconv
# Text normalizers shared by the scoring code; built once at import time.
english_normalizer = EnglishTextNormalizer()

# Every boolean switch of the Chinese normalizer is passed as False and
# cc_mode is passed as the empty string.
chinese_normalizer = TextNorm(
    cc_mode='',
    **dict.fromkeys(
        (
            'to_banjiao',
            'to_upper',
            'to_lower',
            'remove_fillers',
            'remove_erhua',
            'check_chars',
            'remove_space',
        ),
        False,
    ),
)

basic_normalizer = BasicTextNormalizer()
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
# Punctuation marks that must not be preceded by a space after clean-up.
PUNCS = '!,.?;:'

# Evaluation sets selectable through --dataset: name -> manifest path and
# language code.
# NOTE(review): the key 'fluers_zh' is spelled differently from its manifest
# file ('fleurs_...'); it is kept because it is the name callers pass in.
ds_collections = {
    name: {'path': path, 'language': language}
    for name, path, language in (
        ('librispeech', 'asr/librispeech_eval.jsonl', 'en'),
        ('aishell2', 'asr/aishell2_eval.jsonl', 'zh'),
        ('cv15_en', 'asr/cv15_asr_en_eval.jsonl', 'en'),
        ('cv15_zh', 'asr/cv15_asr_zh_eval.jsonl', 'zh'),
        ('cv15_yue', 'asr/cv15_asr_yue_eval.jsonl', 'yue'),
        ('cv15_fr', 'asr/cv15_asr_fr_eval.jsonl', 'fr'),
        ('fluers_zh', 'asr/fleurs_asr_zh_eval.jsonl', 'zh'),
    )
}
class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSON-lines evaluation manifest.

    Every non-blank line of the manifest is a JSON object providing the keys
    ``audio``, ``source``, ``prompt`` and ``gt``.

    Args:
        ds: Dataset descriptor; only its ``'path'`` entry (the manifest
            location) is read here.
    """

    def __init__(self, ds):
        path = ds['path']
        # Read the manifest eagerly and close the handle right away. UTF-8 is
        # requested explicitly so non-ASCII transcripts do not depend on the
        # platform's default encoding.
        with open(path, encoding='utf-8') as manifest:
            # Blank lines (e.g. a trailing empty line) are not records and
            # would otherwise make json.loads fail in __getitem__.
            self.datas = [line for line in manifest if line.strip()]

    def __len__(self):
        """Return the number of records in the manifest."""
        return len(self.datas)

    def __getitem__(self, idx):
        """Parse record ``idx`` and return it as a dict.

        Returns:
            dict: ``audio``, ``source`` and ``gt`` copied from the record, and
            ``prompt`` prefixed with the audio placeholder tokens.

        Raises:
            KeyError: The record lacks one of the expected keys.
        """
        # Records are parsed lazily, one line per access.
        data = json.loads(self.datas[idx].strip())
        audio = data['audio']
        source = data['source']
        prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>" + data['prompt']
        gt = data['gt']
        return {
            'audio': audio,
            'prompt': prompt,
            'source': source,
            'gt': gt,
        }
def read_audio(audio_path):
    """Return the raw bytes of an audio file given a URL or a local path.

    Args:
        audio_path: An ``http://`` / ``https://`` URL, or a filesystem path.

    Returns:
        bytes: The undecoded file content.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.Timeout: The server did not respond within the timeout.
        OSError: The local file cannot be opened or read.
    """
    # Only a real protocol prefix counts as remote; otherwise it would be
    # impossible to use a local file named like http_huggingface_co.png.
    if audio_path.startswith(("http://", "https://")):
        # Without a timeout a stalled server blocks the worker forever.
        response = requests.get(audio_path, timeout=60)
        # Fail here instead of handing an error page to the audio decoder.
        response.raise_for_status()
        return response.content
    with open(audio_path, "rb") as f:
        return f.read()
def collate_fn(inputs, processor):
    """Turn a list of dataset samples into one model-ready batch.

    Args:
        inputs: Samples, each a dict with the keys ``prompt``, ``source``,
            ``gt`` and ``audio``.
        processor: Processor used to build the batch; its
            ``feature_extractor.sampling_rate`` is the rate the audio is
            decoded at.

    Returns:
        tuple: ``(batch, audio_paths, sources, references)`` where ``batch`` is
        the processor output and the other three are lists aligned with
        ``inputs``.
    """
    rate = processor.feature_extractor.sampling_rate

    # Collect the per-sample metadata first.
    prompts, sources, references, paths = [], [], [], []
    for sample in inputs:
        prompts.append(sample['prompt'])
        sources.append(sample['source'])
        references.append(sample['gt'])
        paths.append(sample['audio'])

    # Fetch and decode every clip at the rate the feature extractor expects.
    waveforms = []
    for sample in inputs:
        raw_bytes = read_audio(sample['audio'])
        waveforms.append(ffmpeg_read(raw_bytes, sampling_rate=rate))

    batch = processor(
        text=prompts,
        audios=waveforms,
        sampling_rate=rate,
        return_tensors="pt",
        padding=True,
    )
    return batch, paths, sources, references
class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Sampler that gives each distributed rank one contiguous index range.

    The rank and world size are queried from ``torch.distributed`` when the
    sampler is constructed.

    Args:
        size: Total number of samples; must be positive.
    """

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(
            size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        """Return the ``range`` of indices assigned to ``rank``.

        Shard sizes differ by at most one; the first
        ``total_size % world_size`` ranks receive the extra sample.
        """
        base, extra = divmod(total_size, world_size)
        sizes = [base + 1 if r < extra else base for r in range(world_size)]
        start = sum(sizes[:rank])
        stop = min(sum(sizes[:rank + 1]), total_size)
        return range(start, stop)

    def __iter__(self):
        for index in self._local_indices:
            yield index

    def __len__(self):
        return len(self._local_indices)
def remove_sp(text, language):
    """Strip ``<|...|>`` markers from ``text`` and tidy its whitespace.

    Args:
        text: Raw transcript, possibly containing ``<|...|>`` markers.
        language: Language code; for ``"zh"`` all whitespace is removed.

    Returns:
        str: The cleaned transcript.
    """
    # Each <|...|> marker becomes a single space.
    cleaned = re.sub(r"<\|.*?\|>", " ", text)
    # Collapse every run of whitespace into a single space.
    cleaned = re.sub(r"\s+", " ", cleaned)
    # Drop one space directly in front of any punctuation mark from PUNCS.
    cleaned = re.sub(f" ?([{PUNCS}])", r"\1", cleaned)
    cleaned = cleaned.lstrip(" ")
    if language != "zh":
        return cleaned
    return re.sub(r"\s+", "", cleaned)
def compute_wer(refs, hyps, language):
    """Return the corpus-level error rate of ``hyps`` against ``refs``.

    The edit distances of all pairs are summed and divided by the summed
    number of reference tokens. Tokens are single characters for ``"zh"`` and
    ``"yue"`` and whitespace-separated items otherwise. The first pair is
    printed as a debugging aid.

    Args:
        refs: Reference transcripts.
        hyps: Hypothesis transcripts, aligned with ``refs`` by index.
        language: Language code selecting the normalization and token unit.

    Returns:
        float: Total edit distance divided by total reference length. When
        there are no reference tokens at all, ``0.0`` if the distance is also
        zero and ``inf`` otherwise.

    Raises:
        IndexError: ``hyps`` is shorter than ``refs``.
    """
    distance = 0
    ref_length = 0
    tokenizer = EvaluationTokenizer(
        tokenizer_type="none",
        lowercase=True,
        punctuation_removal=True,
        character_tokenization=False,
    )
    char_level = language in ("zh", "yue")
    for i, ref in enumerate(refs):
        pred = hyps[i]
        if language == "yue":
            # Bring both sides to the same 'zh-cn' script variant.
            ref = zhconv.convert(ref, 'zh-cn')
            pred = zhconv.convert(pred, 'zh-cn')
        if language == "en":
            ref = english_normalizer(ref)
            pred = english_normalizer(pred)
        # NOTE(review): the `else` below pairs only with the "zh" test, so
        # "en" and "yue" text also goes through basic_normalizer after the
        # step above. Left unchanged so scores stay comparable with earlier
        # runs; confirm whether `elif` was intended.
        if language == "zh":
            ref = chinese_normalizer(ref)
            pred = chinese_normalizer(pred)
        else:
            ref = basic_normalizer(ref)
            pred = basic_normalizer(pred)
        ref_items = tokenizer.tokenize(ref).split()
        pred_items = tokenizer.tokenize(pred).split()
        if char_level:
            # Score per character: re-join the tokens and split into chars.
            ref_items = list("".join(ref_items))
            pred_items = list("".join(pred_items))
        if i == 0:
            # Debug dump of the first pair. An empty side is reported as None
            # instead of raising IndexError, and the hypothesis line shows the
            # hypothesis' own length and first token.
            first_ref = ref_items[0] if ref_items else None
            first_pred = pred_items[0] if pred_items else None
            print(f"ref: {ref}")
            print(f"pred: {pred}")
            print(f"ref_items:\n{ref_items}\n{len(ref_items)}\n{first_ref}")
            print(f"pred_items:\n{pred_items}\n{len(pred_items)}\n{first_pred}")
        distance += ed.eval(ref_items, pred_items)
        ref_length += len(ref_items)
    if ref_length == 0:
        # The rate is undefined without reference tokens; avoid a
        # ZeroDivisionError after the whole evaluation has already run.
        return 0.0 if distance == 0 else float("inf")
    return distance / ref_length
# Script entry point: distributed ASR evaluation.
# Every rank transcribes its own shard of the chosen dataset, the outputs are
# gathered on all ranks, and rank 0 writes them to a JSON file and prints the
# error rate per `source`.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='Qwen/Qwen2-Audio')
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Reject an unknown dataset before the process group and the model are
    # set up, instead of failing later with a bare KeyError.
    if args.dataset not in ds_collections:
        parser.error(
            f"--dataset must be one of {sorted(ds_collections)}, "
            f"got {args.dataset!r}")

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        args.checkpoint, device_map='cuda', torch_dtype='auto',
        trust_remote_code=True).eval()
    processor = AutoProcessor.from_pretrained(args.checkpoint)
    # Left padding keeps every prompt flush against the generated tokens, so
    # the prompt can be cut off below with a single column slice.
    processor.tokenizer.padding_side = 'left'

    random.seed(args.seed)
    dataset = AudioDataset(
        ds=ds_collections[args.dataset],
    )
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(collate_fn, processor=processor),
    )

    gts = []
    sources = []
    rets = []
    audio_paths = []
    # Iterating the loader directly lets tqdm show the number of batches.
    for inputs, audio_path, source, gt in tqdm(data_loader):
        # Move every tensor of the batch to the GPU, not only input_ids, so
        # generation does not depend on implicit device handling.
        inputs = inputs.to('cuda')
        output_ids = model.generate(**inputs, max_new_tokens=256, min_new_tokens=1, do_sample=False)
        # Keep only the newly generated tokens (drop the prompt columns).
        output_ids = output_ids[:, inputs.input_ids.size(1):]
        output = processor.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        gts.extend(gt)
        rets.extend(output)
        sources.extend(source)
        audio_paths.extend(audio_path)

    # Wait until every rank has finished its shard before gathering.
    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_gts = [None for _ in range(world_size)]
    merged_sources = [None for _ in range(world_size)]
    merged_responses = [None for _ in range(world_size)]
    merged_audio_paths = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_gts, gts)
    torch.distributed.all_gather_object(merged_sources, sources)
    torch.distributed.all_gather_object(merged_responses, rets)
    torch.distributed.all_gather_object(merged_audio_paths, audio_paths)

    # Flatten the per-rank lists into one list per field.
    merged_gts = list(itertools.chain.from_iterable(merged_gts))
    merged_sources = list(itertools.chain.from_iterable(merged_sources))
    merged_audio_paths = list(itertools.chain.from_iterable(merged_audio_paths))
    merged_responses = list(itertools.chain.from_iterable(merged_responses))

    if torch.distributed.get_rank() == 0:
        print(f"Evaluating {args.dataset} ...")

        results = [
            {
                'gt': gt,
                'response': response,
                'source': source,
                'audio_path': audio_path,
            }
            for gt, response, source, audio_path in zip(
                merged_gts, merged_responses, merged_sources, merged_audio_paths)
        ]
        time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
        results_file = f'{args.dataset}_{time_prefix}.json'
        # The context manager flushes and closes the file deterministically.
        with open(results_file, 'w', encoding='utf-8') as results_fp:
            json.dump(results, results_fp)

        # Group the results by their `source` field and score each group.
        results_dict = {}
        for item in tqdm(results):
            source = item["source"]
            results_dict.setdefault(source, []).append(item)
        lan = ds_collections[args.dataset]['language']
        for source, results_list in results_dict.items():
            refs = [remove_sp(result["gt"], lan) for result in results_list]
            hyps = [remove_sp(result["response"], lan) for result in results_list]
            wer = compute_wer(refs, hyps, lan)
            print(f"source: {source} cnt: {len(refs)} wer: {wer:.4f}")

    # All ranks wait here so none exits while rank 0 is still scoring.
    torch.distributed.barrier()