# src/colab/hf_translation_pt.py
# -*- coding: utf-8 -*-
"""hf_translation_pt.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1PGFkeBtmD-LsEIwBptwODyyb-Pz998rO
https://huggingface.co/docs/transformers/tasks/translation
"""
# Colab shell magic: install a pinned HF stack (remove "!" lines outside a notebook).
!pip install transformers==4.28.0 datasets evaluate sacrebleu
import torch
# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {torch_device}")
from transformers import AutoTokenizer
# t5-small selects its task via text prefixes, e.g. "translate English to French: ...".
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Pinned accelerate install — presumably needed by the Seq2SeqTrainer import below; verify.
!pip install accelerate==0.20.1
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from google.colab import drive
# Mount Google Drive to load fine-tuned weights saved by an earlier training run.
drive.mount("/content/drive")
# Two fine-tuned t5-small models, one per translation direction; weights come
# from state dicts stored on Drive, loaded on top of the base checkpoint.
model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(torch_device)
model_en_fr.load_state_dict(torch.load("/content/drive/My Drive/hf_translation_en_fr.pth"))
model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(torch_device)
model_fr_en.load_state_dict(torch.load("/content/drive/My Drive/hf_translation_fr_en.pth"))
# Smoke test: EN -> FR with the fine-tuned en->fr model. Sampling (do_sample,
# top_k, top_p) means the printed translation can vary between runs.
text_en = "translate English to French: Legumes share resources with nitrogen-fixing bacteria."
print("Tokenize...")
inputs_en = tokenizer(text_en, return_tensors="pt").input_ids.to(torch_device)
outputs_fr = model_en_fr.generate(inputs_en, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
print(tokenizer.decode(outputs_fr[0], skip_special_tokens=True))
# Smoke test: FR -> EN. Sampled generation, so output varies between runs.
text_fr = "translate French to English: Les légumineuses partagent des ressources avec les bactéries fixatrices d'azote."
print("Tokenize...")
inputs_fr = tokenizer(text_fr, return_tensors="pt").input_ids.to(torch_device)
# Bug fix: use the fr->en model here (the original called model_en_fr.generate,
# feeding the French prompt to the EN->FR model). This now matches fr_to_en() below.
outputs_en = model_fr_en.generate(inputs_fr, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
print(tokenizer.decode(outputs_en[0], skip_special_tokens=True))
def en_to_fr(en):
    """Translate an English sentence to French with the fine-tuned en->fr model.

    Uses sampled decoding (top-k/top-p), so repeated calls on the same input
    may produce different translations.
    """
    prompt = "translate English to French: " + en
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
    generated = model_en_fr.generate(
        input_ids, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def fr_to_en(fr):
    """Translate a French sentence to English with the fine-tuned fr->en model.

    Uses sampled decoding (top-k/top-p), so repeated calls on the same input
    may produce different translations.
    """
    prompt = "translate French to English: " + fr
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
    generated = model_fr_en.generate(
        input_ids, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Demo calls: one sentence per direction, then two single French words
# ("queue", "air") to probe the fr->en model with minimal context.
print(en_to_fr("Legumes share resource with nitrogen-fixing bacteria."))
print(fr_to_en("Les légumineuses partagent des ressources avec les bactéries fixatrices d'azote."))
print(fr_to_en("queue"))
print(fr_to_en("air"))