import re
import time

import keras
import keras_nlp
import tensorflow as tf
from sacrebleu.metrics import CHRF

MAX_SEQUENCE_LENGTH = 64
EVAL_SAMPLES = 100

transformer = keras.models.load_model(
    "models_europarl/en_cs_translator_saved_20231209_0046.keras"
)
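# Note: keras.models.load_model can only rebuild the custom transformer layers
# if they were registered as Keras serializables when the model was saved;
# otherwise they would have to be supplied via the custom_objects argument.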


def read_files(path, lowercase=False):
    """Read a newline-delimited text file into a list of lines."""
    with open(path, "r", encoding="utf-8") as f:
        dataset_split = f.read().split("\n")[:-1]
    if lowercase:
        dataset_split = [line.lower() for line in dataset_split]
    return dataset_split


en_vocab = read_files("tokenizers/en_europarl_vocab")
cs_vocab = read_files("tokenizers/cs_europarl_vocab")

en_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=en_vocab,
    lowercase=False,
)
cs_tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=cs_vocab,
    lowercase=False,
)
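
# Illustrative sanity check (the exact ids depend on the vocabulary files):
#   ids = en_tokenizer(tf.constant("European Parliament"))  # e.g. [137, 942, ...]
#   en_tokenizer.detokenize(ids)                            # round-trips the text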


def compute_probabilities(logits):
    return keras.activations.softmax(logits)


def next_token_logits(encoder_input_tokens, prompt, predicted_token_index):
    # Run the transformer on (source, prompt) batches of size 1 and keep only
    # the logits at the position before the token being predicted,
    # shape (1, vocab_size).
    logits = transformer(
        [tf.expand_dims(encoder_input_tokens, axis=0), tf.expand_dims(prompt, axis=0)]
    )[:, predicted_token_index - 1, :]
    return logits


def greedy_decode(encoder_input_tokens, prompt, end_token_id):
    start_index = 1
    current_prompt = prompt
    for predicted_token_index in range(start_index, MAX_SEQUENCE_LENGTH):
        next_logits = next_token_logits(
            encoder_input_tokens, current_prompt, predicted_token_index
        )
        next_probabilities = compute_probabilities(next_logits)
        max_probability_token_id = tf.argmax(next_probabilities, axis=-1)
        # Write the most probable token into the prompt at this position;
        # tf.argmax returns int64, so cast explicitly to the prompt's dtype.
        indices = tf.constant([[predicted_token_index]])
        data = tf.constant([max_probability_token_id.numpy()[0]], dtype=current_prompt.dtype)
        current_prompt = tf.tensor_scatter_nd_update(current_prompt, indices, data)

        if max_probability_token_id == end_token_id:
            break
    return current_prompt
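
# Note: every step above re-runs the full transformer on the whole prompt
# (there is no key/value caching), so decoding one sentence costs up to
# MAX_SEQUENCE_LENGTH forward passes.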


def beam_decode(encoder_input_tokens, prompt, end_token_id, beam_size):
    start_index = 1

    # Expand the first decoding step into `beam_size` candidate subsequences.
    next_logits = next_token_logits(encoder_input_tokens, prompt, start_index)
    next_probabilities = compute_probabilities(next_logits)
    top_k_probabilities, top_k_token_indices = tf.math.top_k(next_probabilities, k=beam_size)
    current_subsequencies = []
    for index, value in enumerate(top_k_token_indices.numpy()[0]):
        indices = tf.constant([[start_index]])
        data = tf.constant([value])
        current_prompt = tf.tensor_scatter_nd_update(prompt, indices, data)

        # Each beam entry is (prompt, cumulative log-probability,
        # length-normalised log-probability used for ranking).
        log_prob = tf.math.log(top_k_probabilities.numpy()[0][index])
        current_subsequencies.append((current_prompt, log_prob, log_prob))

    final_potential_solutions = []
    for predicted_token_index in range(start_index + 1, MAX_SEQUENCE_LENGTH):
        if len(final_potential_solutions) == beam_size:
            break

        # Extend every live beam by its top-k continuations.
        tmp_subsequencies = []
        for subseq_prompt, subseq_log_probability, _ in current_subsequencies:
            next_logits = next_token_logits(encoder_input_tokens, subseq_prompt, predicted_token_index)
            next_probabilities = compute_probabilities(next_logits)
            top_k_probabilities, top_k_token_indices = tf.math.top_k(
                next_probabilities, k=beam_size - len(final_potential_solutions)
            )
            for index, value in enumerate(top_k_token_indices.numpy()[0]):
                indices = tf.constant([[predicted_token_index]])
                data = tf.constant([value])
                updated_subseq_prompt = tf.tensor_scatter_nd_update(subseq_prompt, indices, data)

                next_log_probability = tf.math.log(top_k_probabilities.numpy()[0][index])
                tmp_subsequencies.append((
                    updated_subseq_prompt,
                    subseq_log_probability + next_log_probability,
                    (subseq_log_probability + next_log_probability) / (predicted_token_index + 1),
                ))

        # Keep the best candidates; beams that just emitted [END] are finished.
        current_subsequencies = []
        current_sequences_to_find = beam_size - len(final_potential_solutions)
        tmp_subsequencies = sorted(tmp_subsequencies, key=lambda x: x[2], reverse=True)
        for i in range(current_sequences_to_find):
            if tmp_subsequencies[i][0][predicted_token_index] == end_token_id:
                final_potential_solutions.append(tmp_subsequencies[i])
            else:
                current_subsequencies.append(tmp_subsequencies[i])

    final_potential_solutions = sorted(final_potential_solutions, key=lambda x: x[2], reverse=True)
    if len(final_potential_solutions) > 0:
        return final_potential_solutions[0][0]
    else:
        # No beam finished within MAX_SEQUENCE_LENGTH; fall back to the best
        # unfinished candidate.
        sorted_subs = sorted(current_subsequencies, key=lambda x: x[2], reverse=True)
        return sorted_subs[0][0]
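
# The ranking key divides the cumulative log-probability by the current prompt
# length: e.g. (illustrative numbers) a beam with cumulative log-probability
# -1.2 at predicted_token_index 2 is ranked by -1.2 / 3 = -0.4. Without this
# normalisation the beam would systematically favour shorter candidates.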


def decode_sequences(input_sentence):
    encoder_input_tokens = en_tokenizer(input_sentence)

    # Pad the source tokens to MAX_SEQUENCE_LENGTH; sentences that are still
    # too long after tokenization get a fixed placeholder "translation".
    if len(encoder_input_tokens) < MAX_SEQUENCE_LENGTH:
        pads = tf.fill([MAX_SEQUENCE_LENGTH - len(encoder_input_tokens)], 0)
        encoder_input_tokens = tf.concat([encoder_input_tokens, pads], 0)
    if len(encoder_input_tokens) > MAX_SEQUENCE_LENGTH:
        tensor_content = "[START] Exceeded. [END] [PAD] [PAD] [PAD] [PAD]"
        return tf.constant([tensor_content], dtype=tf.string)

    # The decoder prompt starts as [START] followed by [PAD] tokens that
    # decoding fills in one position at a time.
    start = tf.fill([1], cs_tokenizer.token_to_id("[START]"))
    pads = tf.fill([MAX_SEQUENCE_LENGTH - 1], cs_tokenizer.token_to_id("[PAD]"))
    prompt = tf.concat((start, pads), axis=-1)

    end_token_id = cs_tokenizer.token_to_id("[END]")

    # beam_decode(encoder_input_tokens, prompt, end_token_id, beam_size) is a
    # drop-in alternative to greedy decoding here.
    generated_tokens = greedy_decode(encoder_input_tokens, prompt, end_token_id)

    generated_sentences = cs_tokenizer.detokenize(tf.expand_dims(generated_tokens, axis=0))
    return generated_sentences
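
# Illustrative usage, translating a single sentence end to end:
#   print(decode_sequences("The committee approved the proposal."))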


test_en = read_files("datasets/europarl/test-cs-en.en")
test_cs = read_files("datasets/europarl/test-cs-en.cs")

bleu_metrics = keras_nlp.metrics.Bleu(
    name="bleu",
    tokenizer=cs_tokenizer,
)
chrf = CHRF()

refs = test_cs[:EVAL_SAMPLES]
translations = []
start_time = time.time()

for i in range(len(refs)):
    cs_translated = decode_sequences(test_en[i])
    cs_translated = cs_translated.numpy()[0].decode("utf-8")
    # Strip the special tokens from the detokenized output.
    cs_translated = (
        cs_translated.replace("[PAD]", "")
        .replace("[START]", "")
        .replace("[END]", "")
        .strip()
    )
    # Remove the space the detokenizer leaves before punctuation.
    cs_translated = re.sub(r"\s+([.,;!?:])", r"\1", cs_translated)
    print(cs_translated, flush=True)
    translations.append(cs_translated)

end_time = time.time()

# keras_nlp's Bleu expects one list of references per hypothesis.
refs_twodim = [[ref] for ref in refs]
bleu_metrics(refs_twodim, translations)

print("evaluating chrf", flush=True)
# sacrebleu's corpus_score takes a list of reference streams (each stream
# spans the whole corpus), so the single reference set is passed as [refs].
chrf2_result = chrf.corpus_score(translations, [refs])

print("chrf2")
print(chrf2_result)
print("bleu")
print(bleu_metrics.result().numpy())
print("elapsed time")
elapsed_time = end_time - start_time
print(elapsed_time)
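
# Optional: average decoding time per sentence (assumes a non-empty run).
# print(f"avg seconds per sentence: {elapsed_time / len(translations):.2f}")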