# -*- coding: utf-8 -*-
"""llama_custom_generation.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/19cyGY1famffEISeonhndnsvjX_zX9LTq
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the model and tokenizer
model_name = "Upstage/SOLAR-10.7B-Instruct-v1.0"
print("Chargement du tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("")
print("Chargement du modèle")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
# Prepare the input
question = "What is the capital of France?"
prompt = "### User:\n" + question + "\n\n### Assistant:\n"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
print(f"input_ids: {input_ids}")
# Generation parameters
max_new_tokens = 4096
temperature = 0.7
top_p = 0.9
eos_token_id = tokenizer.eos_token_id
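# NOTE: temperature and top_p are defined above, but the loop below decodes
# greedily with argmax, so they are never applied. Below is a minimal sketch
# of temperature + nucleus (top-p) sampling that could replace the argmax
# step; the helper name sample_next_token is ours, not part of transformers.
def sample_next_token(logits, temperature=0.7, top_p=0.9):
    # Soften or sharpen the distribution with the temperature
    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort tokens by probability and keep the smallest set whose
    # cumulative mass exceeds top_p (the "nucleus")
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    # Sample within the nucleus and map back to vocabulary ids
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice)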
# Evaluation mode (disables dropout)
model.eval()
# Generated tokens (start from the input tokens)
generated = input_ids
print("Generation loop")
for _ in range(max_new_tokens):
    with torch.no_grad():
        # Forward pass over the whole sequence so far; only the logits
        # for the last position are needed to predict the next token
        # print(f"*** generate *** generated:{generated.shape} = {generated}")
        outputs = model(input_ids=generated)
        # print(f"*** generate *** logits:{outputs.logits.shape} = {outputs.logits}")
        next_token_logits = outputs.logits[:, -1, :]  # [batch, vocab]
        # print(f"*** generate *** next_token_logits:{next_token_logits.shape} = {next_token_logits}")

        # Greedy decoding: take the highest-probability token
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        print(f"*** generate *** Next token: {next_token}")

        # Append the generated token to the sequence
        generated = torch.cat([generated, next_token], dim=-1)

        # Stop once the EOS token is produced
        if next_token.item() == eos_token_id:
            break
        # break  # Uncomment to generate only the first token
print(generated[0])
# Decode the result
output_text = tokenizer.decode(generated[0], skip_special_tokens=True)
print(output_text)
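# The loop above re-runs the entire sequence through the model at every step,
# which is quadratic in the output length. Below is a minimal sketch of the
# same greedy loop using the KV cache (past_key_values) that transformers
# models return when use_cache=True; after the first step only the newest
# token is fed in. The function name generate_with_cache is ours.
def generate_with_cache(model, input_ids, max_new_tokens, eos_token_id):
    generated = input_ids
    past_key_values = None
    next_input = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=next_input,
                            past_key_values=past_key_values,
                            use_cache=True)
            past_key_values = outputs.past_key_values
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            if next_token.item() == eos_token_id:
                break
            next_input = next_token  # only the new token; the cache holds the rest
    return generated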