t5-base-dutch-demo / flax_to_pt.py
yhavinga's picture
Add pytorch model
522b344
import torch
import numpy as np
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("./")
model_fx = FlaxT5ForConditionalGeneration.from_pretrained("./")
model_pt = T5ForConditionalGeneration.from_pretrained("./", from_flax=True)
model_pt.save_pretrained("./")
text = """Het is nog niet duidelijk
welke hoogte het water nabij Venlo heeft bereikt. De hoogwaterpiek is vermoedelijk iets vlakker dan verwacht, maar blijft langer aanhouden, tot zondag 19.00 uur. Vooralsnog zijn er weinig meldingen over schade of overlast, meldt een woordvoerder van Veiligheidsregio Limburg-Noord zaterdag aan NU.nl. Via het Nationaal Rampenfonds is binnen één etmaal al 1 miljoen euro opgehaald voor gedupeerden.
"""
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id
print(e_input_ids_fx)
print(d_input_ids_fx)
e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id
print(e_input_ids_pt)
print(d_input_ids_pt)
print()
encoder_pt = model_fx.encode(**e_input_ids_pt)
decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
logits_pt = decoder_pt.logits
print(f"Pytorch output: {logits_pt}")
encoder_fx = model_fx.encode(**e_input_ids_fx)
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
logits_fx = decoder_fx.logits
print(f"Flax output: {logits_fx}")