# View file src/colab/hf_translation_pt.py - Download

# -*- coding: utf-8 -*-
"""hf_translation_pt.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1PGFkeBtmD-LsEIwBptwODyyb-Pz998rO

https://huggingface.co/docs/transformers/tasks/translation
"""

# Notebook cell: install pinned versions of the HF stack.
# NOTE: `!pip` is Colab/IPython shell magic — not valid outside a notebook.
!pip install transformers==4.28.0 datasets evaluate sacrebleu

import torch

# Set device: prefer GPU when the Colab runtime provides one, else fall back to CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {torch_device}")


from transformers import AutoTokenizer

# Base checkpoint shared by both translation directions; t5-small expects
# task-prefixed inputs such as "translate English to French: ...".
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# accelerate pinned alongside transformers 4.28 (notebook shell magic).
!pip install accelerate==0.20.1

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Mount Google Drive so the previously fine-tuned weights can be read.
from google.colab import drive
drive.mount("/content/drive")

# English -> French: t5-small architecture, weights overwritten from a
# fine-tuned state dict saved on Drive.
model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(torch_device)
model_en_fr.load_state_dict(torch.load("/content/drive/My Drive/hf_translation_en_fr.pth"))

# French -> English: same base checkpoint, separate fine-tuned state dict.
model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(torch_device)
model_fr_en.load_state_dict(torch.load("/content/drive/My Drive/hf_translation_fr_en.pth"))

# --- Demo: English -> French translation --------------------------------
text_en = "translate English to French: Legumes share resources with nitrogen-fixing bacteria."

print("Tokenize...")
inputs_en = tokenizer(text_en, return_tensors="pt").input_ids.to(torch_device)

# Sampled decoding (top-k / nucleus), capped at 40 new tokens — output is
# non-deterministic by design here.
outputs_fr = model_en_fr.generate(inputs_en, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
print(tokenizer.decode(outputs_fr[0], skip_special_tokens=True))

# --- Demo: French -> English translation --------------------------------
text_fr = "translate French to English: Les légumineuses partagent des ressources avec les bactéries fixatrices d'azote."

print("Tokenize...")
inputs_fr = tokenizer(text_fr, return_tensors="pt").input_ids.to(torch_device)

# BUG FIX: the original called model_en_fr here. French -> English must use
# the fr->en fine-tuned model (consistent with the fr_to_en helper below).
outputs_en = model_fr_en.generate(inputs_fr, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
print(tokenizer.decode(outputs_en[0], skip_special_tokens=True))

def en_to_fr(en):
  """Translate an English sentence into French with the fine-tuned en->fr model.

  Uses sampled decoding (top-k/top-p), so repeated calls may differ.
  """
  prompt = "translate English to French: " + en
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
  generated = model_en_fr.generate(
      input_ids, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95
  )
  return tokenizer.decode(generated[0], skip_special_tokens=True)

def fr_to_en(fr):
  """Translate a French sentence into English with the fine-tuned fr->en model.

  Uses sampled decoding (top-k/top-p), so repeated calls may differ.
  """
  prompt = "translate French to English: " + fr
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
  generated = model_fr_en.generate(
      input_ids, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95
  )
  return tokenizer.decode(generated[0], skip_special_tokens=True)

# Quick sanity checks of the helper functions (sampled output, so results vary).
print(en_to_fr("Legumes share resource with nitrogen-fixing bacteria."))
print(fr_to_en("Les légumineuses partagent des ressources avec les bactéries fixatrices d'azote."))
# Short/ambiguous single-word French inputs to probe model behavior.
print(fr_to_en("queue"))
print(fr_to_en("air"))