
# -*- coding: utf-8 -*-
"""llama_custom_generation.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/19cyGY1famffEISeonhndnsvjX_zX9LTq
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and the tokenizer
model_name = "Upstage/SOLAR-10.7B-Instruct-v1.0"
print("Chargement du tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("")

print("Chargement du modèle")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
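# Note: device_map="auto" requires the accelerate package (pip install accelerate).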

# Prepare the input
question = "Quelle est la capitale de la France ?"
prompt = "### User:\n" + question + "\n\n### Assistant:\n"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
print(f"input_ids: {input_ids}")

# Generation parameters
max_new_tokens = 4096
temperature = 0.7
top_p = 0.9
eos_token_id = tokenizer.eos_token_id
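
# The loop below decodes greedily with argmax, so temperature and top_p are
# not used as written. A minimal top-p (nucleus) sampling helper that would
# put them to work; a sketch, not the exact strategy used by model.generate():
def sample_top_p(logits, temperature=1.0, top_p=1.0):
    # Scale the logits by the temperature, then sample from the smallest set
    # of tokens whose cumulative probability exceeds top_p.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mask tokens outside the nucleus (the top-1 token is always kept).
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, next_sorted)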

# Evaluation mode (disables dropout)
model.eval()

# Tensor of generated tokens (starts with the input tokens)
generated = input_ids

print("Generation loop")
for _ in range(max_new_tokens):
    with torch.no_grad():
        # print("")
        # print("Next step")

        # Run a forward pass; only the logits at the last position are used
        # print(f"*** generate *** generated:{generated.shape} = {generated}")
        outputs = model(input_ids=generated)
        # print(f"*** generate *** logits:{outputs.logits.shape} = {outputs.logits}")
        next_token_logits = outputs.logits[:, -1, :]  # [batch, vocab]
        # print(f"*** generate *** next_token_logits:{next_token_logits.shape} = {next_token_logits}")

        # Greedy decoding: take the single most likely next token
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
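        # To actually apply temperature / top_p, swap the argmax for the
        # sample_top_p sketch defined above:
        # next_token = sample_top_p(next_token_logits, temperature, top_p)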

        print(f"*** generate *** Next token: {next_token}")

        # Append the generated token
        generated = torch.cat([generated, next_token], dim=-1)

        # Stop on the EOS token
        if next_token.item() == eos_token_id:
            break

        # break  # Uncomment to generate only the first token

print(generated[0])

# Decode the result
output_text = tokenizer.decode(generated[0], skip_special_tokens=True)
print(output_text)
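
# The loop above re-feeds the entire sequence on every forward pass, which is
# quadratic in the sequence length. The same greedy loop with the KV cache
# (past_key_values) feeds only the newest token after the first pass. A
# sketch; it relies on the model returning past_key_values when
# use_cache=True, which causal LM models in transformers do by default:
cached = input_ids
past_key_values = None
next_input = input_ids
with torch.no_grad():
    for _ in range(max_new_tokens):
        outputs = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
        cached = torch.cat([cached, next_token], dim=-1)
        if next_token.item() == eos_token_id:
            break
        next_input = next_token  # the cache already holds all earlier tokens

print(tokenizer.decode(cached[0], skip_special_tokens=True))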