# -*- coding: utf-8 -*-
"""chat.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1y4io4hQ-A4j0n9QTvjzNMAvPRIBKKCEj
https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0
"""
!pip install accelerate  # needed for device_map="auto" below; transformers is preinstalled on Colab
# import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
model_name = "Upstage/SOLAR-10.7B-Instruct-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",        # needs accelerate; places layers on the available GPU(s)/CPU
    torch_dtype=torch.float16,
)
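# With device_map="auto", accelerate decides the placement. If you want to see
# where the layers ended up, the model exposes the computed map:
print(model.hf_device_map)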
# --- Create a GenerationConfig here to avoid the warning ---
print("GenerationConfig...")
gen_config = GenerationConfig(
    max_new_tokens=4096,  # replaces max_length
    use_cache=True,       # same option as use_cache in generate()
    do_sample=False,      # or True if you want to sample
)
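# If greedy decoding feels too repetitive, a sampling config is a drop-in
# replacement; the temperature/top_p values below are illustrative, not tuned
# for SOLAR. Pass it to model.generate() in place of gen_config to use it.
sampling_config = GenerationConfig(
    max_new_tokens=4096,
    use_cache=True,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)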
import textwrap
assistant_prefix = "### Assistant:\n"
user_prefix = "### User:\n"
def last_answer(text):
    """Return everything after the last assistant marker ("" if none is found)."""
    index = text.rfind(assistant_prefix)
    if index == -1:
        return ""
    return text[index + len(assistant_prefix):]
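# Example: last_answer("### User:\nHi\n\n### Assistant:\nHello!") -> "Hello!"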
def wrap_preserve_newlines(text, width=80):
    # Split the text into lines at the original line breaks
    lines = text.splitlines()
    # Apply textwrap.wrap to each line separately
    wrapped_lines = [wrapped_line
                     for line in lines
                     for wrapped_line in textwrap.wrap(line, width=width)
                     or ['']]  # for an empty line, force an empty line through
    # Rejoin the lines
    return '\n'.join(wrapped_lines)
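# Example: wrap_preserve_newlines("one two three\n\nfour", width=8)
# returns "one two\nthree\n\nfour" -- the blank line survives the wrapping.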
def chat():
    prompt = ""
    while True:
        prompt += user_prefix
        # print("")
        # print(datetime.datetime.now())
        # print("")
        user_input = input(">>> ")
        if user_input == "bye":
            print("Good bye!\n")
            break
        print("")
        print("You said: " + user_input)
        # print("")
        # print(datetime.datetime.now())
        # print("")
        prompt += user_input + "\n\n" + assistant_prefix
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        print("")
        outputs = model.generate(**inputs, generation_config=gen_config)
        output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Carry the full decoded transcript forward as the next turn's context
        prompt = output_text + "\n"
        answer = last_answer(output_text)
        answer_wrapped = wrap_preserve_newlines(answer, 80)
        print(answer_wrapped)
        print("")
chat()
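# The "### User:" / "### Assistant:" markers above follow SOLAR's instruction
# format. Recent transformers versions can also build the prompt through the
# tokenizer's chat template; a minimal sketch, assuming the model's tokenizer
# ships a chat template (left commented out since chat() above blocks on input):
#
# messages = [{"role": "user", "content": "Hello, who are you?"}]
# input_ids = tokenizer.apply_chat_template(
#     messages, add_generation_prompt=True, return_tensors="pt"
# ).to(model.device)
# outputs = model.generate(input_ids, generation_config=gen_config)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))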