# File: src/colab/soclover.py

# -*- coding: utf-8 -*-
"""soclover.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/18afku_gh0gFqR830aDgUNW4gz3BthFLI

Find the vocabulary token whose CLIP text embedding is jointly closest to two
given words, i.e. the token minimizing the larger of its two embedding distances.
"""

import functools

import torch

# Pick the device for every text-encoder forward pass in this file: the GPU
# when CUDA is available, otherwise the CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {torch_device}")

# Chooses between the two install lines below: True pins diffusers to 0.2.4,
# False installs the latest release.
# NOTE(review): nothing visible in this file imports `diffusers`; the install
# may be a leftover from the notebook this was derived from -- confirm before
# dropping it.
diffusers_old_version = True

# `!pip ...` is IPython/Colab shell-escape syntax, not Python: this file only
# runs as a sequence of notebook cells, not under a plain `python` interpreter.
# The install has to run before the `transformers` import that follows.
if diffusers_old_version:
  !pip install -q --upgrade transformers diffusers==0.2.4 ftfy
else:
  !pip install -q --upgrade transformers diffusers ftfy

from transformers import CLIPTextModel, CLIPTokenizer, logging

# The tokenizer and the text encoder must come from the same CLIP checkpoint so
# that token ids index the encoder's own embedding table.
_clip_checkpoint = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(_clip_checkpoint)
# Move the encoder's weights to the selected device; every forward pass below
# sends its ids to the same `torch_device`.
text_encoder = CLIPTextModel.from_pretrained(_clip_checkpoint).to(torch_device)

def embedtoken(t, bos_id=49406, eos_id=49407, seq_len=77):
  """Return the text encoder's hidden state for token id `t` in context.

  Builds the single id sequence `[bos_id, t, eos_id, eos_id, ...]` of total
  length `seq_len`, runs it through `text_encoder` without gradients, and
  returns the last hidden state at position 1 (the slot holding `t`).

  Args:
    t: token id, as a Python int or a 0-d integer tensor.
    bos_id: id placed at position 0 (presumably CLIP's start-of-text id).
    eos_id: id used to fill every position after `t` (presumably CLIP's
      end-of-text id).
    seq_len: total sequence length.

  Returns:
    The hidden-state vector at position 1, on `torch_device`.

  Raises:
    ValueError: if `seq_len` < 2, since there would be no slot for `t`.
  """
  if seq_len < 2:
    raise ValueError(f"seq_len must be >= 2, got {seq_len}")
  # Assemble the sequence from its three parts instead of a hand-written
  # 77-element literal; the defaults reproduce that literal exactly.
  id_list = [bos_id, int(t)] + [eos_id] * (seq_len - 2)
  ids = torch.tensor([id_list], dtype=torch.long, device=torch_device)
  with torch.no_grad():
    hidden = text_encoder(ids)[0]
  return hidden[0, 1]

def embedtokens(tokens, batch_size=256):
  """Embed many token ids at once; row i is the embedding of the i-th id.

  Each id `t` is laid out as `[49406, t, 49407, ..., 49407]` (77 ids), the
  same layout `embedtoken` uses, and the hidden state at position 1 is kept.
  The sequences go through `text_encoder` `batch_size` at a time instead of
  one forward pass per id, which makes embedding tens of thousands of ids
  practical. The result matches stacking `embedtoken(t)` for every `t` up to
  floating-point rounding.

  Args:
    tokens: iterable of token ids (Python ints or an integer tensor).
    batch_size: sequences per forward pass; lower it if the device runs out
      of memory.

  Returns:
    Tensor with one row per id, on `torch_device`. When `tokens` is empty,
    returns an empty `(0, hidden_size)` tensor instead of failing.

  Raises:
    ValueError: if `batch_size` < 1.
  """
  if batch_size < 1:
    raise ValueError(f"batch_size must be >= 1, got {batch_size}")
  bos_id, eos_id, seq_len = 49406, 49407, 77

  token_ids = torch.tensor([int(t) for t in tokens], dtype=torch.long)
  n = token_ids.numel()
  if n == 0:
    return torch.empty(0, text_encoder.config.hidden_size,
                       dtype=text_encoder.dtype, device=torch_device)

  # One row per token: [bos, t, eos, eos, ...].
  sequences = torch.full((n, seq_len), eos_id, dtype=torch.long)
  sequences[:, 0] = bos_id
  sequences[:, 1] = token_ids

  rows = []
  with torch.no_grad():
    for start in range(0, n, batch_size):
      batch = sequences[start:start + batch_size].to(torch_device)
      # Keep only position 1 (the token's own slot) of each sequence.
      rows.append(text_encoder(batch)[0][:, 1])
  return torch.cat(rows)

def embed(text):
  """Embed `text` and return the encoder's hidden state at position 1.

  The text is tokenized with max-length padding and truncation to
  `tokenizer.model_max_length` and encoded without gradients. Only the vector
  at position 1 (presumably the first token after the start-of-text marker)
  is returned.

  NOTE(review): a word the tokenizer splits into several tokens is therefore
  represented only by the slot of its first sub-token -- confirm this is the
  intended representation.
  """
  encoded = tokenizer(
      [text],
      padding="max_length",
      max_length=tokenizer.model_max_length,
      truncation=True,
      return_tensors="pt",
  )
  input_ids = encoded.input_ids.to(torch_device)
  with torch.no_grad():
    last_hidden = text_encoder(input_ids)[0]
  return last_hidden[0, 1]

def word(tok):
  """Decode a single token id back to text.

  Args:
    tok: token id, as a Python int or a 0-d / one-element integer tensor
      (e.g. the result of `torch.argmin`).

  Returns:
    The decoded string.
  """
  # Token ids are integers. `torch.Tensor([...])` would build a float32
  # tensor, so convert explicitly to a plain int before decoding.
  return tokenizer.decode([int(tok)])

def dists(e1, e2):
  """Squared Euclidean distance from every row of `e1` to the vector `e2`.

  Args:
    e1: tensor of shape (..., D), typically (num_candidates, D).
    e2: tensor of shape (D,), broadcast against each row of `e1`.

  Returns:
    Tensor of shape `e1.shape[:-1]` holding `sum((e1 - e2) ** 2)` over the
    last dimension.
  """
  # Broadcasting replaces the hard-coded 49406-row repeat/transpose and the
  # Python-level `sum`: any number of rows works, and no (N, D) copy of `e2`
  # is materialized.
  diff = e1 - e2
  return (diff ** 2).sum(dim=-1)

@functools.lru_cache(maxsize=None)
def _vocab_embeddings():
  """Embed every candidate token id in [0, 49406) once and memoize the result.

  The candidate set does not depend on the query words, so recomputing it on
  every `soclover` call (up to five per notebook run) is pure waste. Ids
  49406 and 49407 are excluded because `embedtoken` uses them as the start
  and fill markers. The cached tensor stays on `torch_device`; call
  `_vocab_embeddings.cache_clear()` after reloading `text_encoder`.
  """
  return embedtokens(torch.arange(49406))


def soclover(w1, w2):
  """Return the vocabulary token that is jointly closest to two words.

  Both words are embedded with `embed`, and every candidate token id in
  [0, 49406) is embedded once and reused across calls. The candidate that
  minimizes the larger of its two squared Euclidean distances is decoded
  back to text.

  Args:
    w1: first query word.
    w2: second query word.

  Returns:
    The decoded string of the winning token.
  """
  e1 = embed(w1)
  e2 = embed(w2)
  e = _vocab_embeddings()
  d1 = dists(e, e1)
  d2 = dists(e, e2)
  # Minimax: a good answer must be close to *both* words, so score each
  # candidate by its worse (larger) distance and take the smallest score.
  worst = torch.maximum(d1, d2)
  return word(torch.argmin(worst))

# Demo on a fixed word pair; this runs every time, independent of the form
# fields below.
print(soclover("tail", "air"))

# Colab form fields: up to four (wa_i, wb_i) word pairs. Leave a pair blank to
# skip it. Colab reads the `# @param` annotations to render the form, so keep
# those lines in exactly this shape.
wa1 = "" # @param {type: "string"}
wb1 = "" # @param {type: "string"}
wa2 = "" # @param {type: "string"}
wb2 = "" # @param {type: "string"}
wa3 = "" # @param {type: "string"}
wb3 = "" # @param {type: "string"}
wa4 = "" # @param {type: "string"}
wb4 = "" # @param {type: "string"}

# Solve and print each pair whose two words are both non-empty. The answers
# are also kept in the globals w1..w4.
if wa1 != "" and wb1 != "":
  w1 = soclover(wa1, wb1)
  print(f"{wa1} {wb1} : {w1}")

if wa2 != "" and wb2 != "":
  w2 = soclover(wa2, wb2)
  print(f"{wa2} {wb2} : {w2}")

if wa3 != "" and wb3 != "":
  w3 = soclover(wa3, wb3)
  print(f"{wa3} {wb3} : {w3}")

if wa4 != "" and wb4 != "":
  w4 = soclover(wa4, wb4)
  print(f"{wa4} {wb4} : {w4}")