# -*- coding: utf-8 -*-
"""chat.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1y4io4hQ-A4j0n9QTvjzNMAvPRIBKKCEj

Model card: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0
"""

!pip install accelerate

# import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "Upstage/SOLAR-10.7B-Instruct-v1.0"

tokenizer = AutoTokenizer.from_pretrained(model_name)
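# device_map="auto" relies on the accelerate package installed above and places the
# model's layers on the available devices automatically; torch_dtype=torch.float16
# loads the weights in half precision, roughly halving memory use versus float32.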
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

# --- Create a GenerationConfig here to avoid the warning ---
print("GenerationConfig...")
gen_config = GenerationConfig(
    max_new_tokens=4096,   # replaces max_length
    use_cache=True,        # same option as use_cache in generate()
    do_sample=False        # or True if you want to sample
)
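
# Not part of the original script: if you prefer sampling over greedy decoding,
# a config along these lines could be used instead (temperature/top_p values are
# only illustrative):
#
# gen_config = GenerationConfig(
#     max_new_tokens=4096,
#     use_cache=True,
#     do_sample=True,
#     temperature=0.7,
#     top_p=0.9,
# )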

import textwrap

assistant_prefix = "### Assistant:\n"
user_prefix = "### User:\n"
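
# These markers follow the "### User:" / "### Assistant:" conversation format used
# by SOLAR-10.7B-Instruct's chat template; chat() below alternates them to build
# the transcript.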

def last_answer(text):
  index = text.rfind(assistant_prefix)
  if index == -1:
    return ""
  return text[index+len(assistant_prefix):]
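
# Example: for a decoded transcript that ends with "### Assistant:" followed by
# "Hello!", last_answer returns "Hello!"; if the marker is missing it returns "".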

def wrap_preserve_newlines(text, width=80):
    # Split the text into lines at the original line breaks
    lines = text.splitlines()

    # Apply textwrap.wrap to each line separately
    wrapped_lines = [wrapped_line
                     for line in lines
                     for wrapped_line in textwrap.wrap(line, width=width)
                     or ['']]  # if a line is empty, keep an empty line

    # Rejoin the lines
    return '\n'.join(wrapped_lines)
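
# Example: wrap_preserve_newlines("a very long paragraph ...\n\nnext paragraph", 40)
# wraps each original line to at most 40 columns while keeping the blank line
# between the two paragraphs intact (plain textwrap.wrap would collapse it).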

def chat():
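  """Run a simple multi-turn chat loop in the console.

  The full transcript is kept in `prompt` and fed back to the model each turn,
  so earlier exchanges stay in context; typing "bye" ends the session.
  """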
  prompt = ""
  while True:
    prompt += user_prefix
    # print("")
    # print(datetime.datetime.now())
    # print("")
    user_input = input(">>> ")
    if user_input == "bye":
      print("Good bye!\n")
      break
    print("")
    print("You said: " + user_input)
    # print("")
    # print(datetime.datetime.now())
    # print("")
    prompt += user_input + "\n\n" + assistant_prefix
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    print("")
    outputs = model.generate(**inputs, generation_config=gen_config)
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    prompt = output_text + "\n"
    answer = last_answer(output_text)
    answer_wrapped = wrap_preserve_newlines(answer, 80)
    print(answer_wrapped)
    print("")

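# Start the interactive loop; it runs as soon as this cell (or script) executes.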
chat()