|
import torch |
|
import numpy as np |
|
import jax.numpy as jnp |
|
|
|
from transformers import AutoTokenizer |
|
from transformers import FlaxT5ForConditionalGeneration |
|
from transformers import T5ForConditionalGeneration |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("./") |
|
model_fx = FlaxT5ForConditionalGeneration.from_pretrained("./") |
|
model_pt = T5ForConditionalGeneration.from_pretrained("./", from_flax=True) |
|
model_pt.save_pretrained("./") |
|
|
|
text = """Het is nog niet duidelijk |
|
welke hoogte het water nabij Venlo heeft bereikt. De hoogwaterpiek is vermoedelijk iets vlakker dan verwacht, maar blijft langer aanhouden, tot zondag 19.00 uur. Vooralsnog zijn er weinig meldingen over schade of overlast, meldt een woordvoerder van Veiligheidsregio Limburg-Noord zaterdag aan NU.nl. Via het Nationaal Rampenfonds is binnen één etmaal al 1 miljoen euro opgehaald voor gedupeerden. |
|
""" |
|
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True) |
|
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id |
|
|
|
print(e_input_ids_fx) |
|
print(d_input_ids_fx) |
|
|
|
e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True) |
|
d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id |
|
|
|
|
|
print(e_input_ids_pt) |
|
print(d_input_ids_pt) |
|
|
|
print() |
|
|
|
encoder_pt = model_fx.encode(**e_input_ids_pt) |
|
decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt) |
|
logits_pt = decoder_pt.logits |
|
print(f"Pytorch output: {logits_pt}") |
|
|
|
encoder_fx = model_fx.encode(**e_input_ids_fx) |
|
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx) |
|
logits_fx = decoder_fx.logits |
|
print(f"Flax output: {logits_fx}") |
|
|