# -*- coding: utf-8 -*-
"""soclover.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/18afku_gh0gFqR830aDgUNW4gz3BthFLI
Find the word which is the most correlated to both two words with embedding
"""
import torch
# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {torch_device}")
# The notebook was written against diffusers 0.2.4; set this flag to
# False to install the latest version instead.
diffusers_old_version = True
if diffusers_old_version:
    !pip install -q --upgrade transformers diffusers==0.2.4 ftfy
else:
    !pip install -q --upgrade transformers diffusers ftfy
from transformers import CLIPTextModel, CLIPTokenizer
# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = text_encoder.to(torch_device)
def embedtoken(t):
    # Build a 77-token CLIP sequence: BOS (49406), the token t, then
    # EOS padding (49407) for all remaining positions.
    ids = torch.full((1, 77), 49407, dtype=torch.long)
    ids[0, 0] = 49406
    ids[0, 1] = t
    with torch.no_grad():
        text_embeddings = text_encoder(ids.to(torch_device))[0]
    # Return the contextual embedding at position 1, i.e. of token t.
    return text_embeddings[0][1]
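# Quick sanity check (illustrative, not part of the original notebook):
# look up the id of a single-token word and embed it. "cat" is assumed
# to be one BPE token; input_ids[1] skips the BOS token.
_cat_id = tokenizer("cat")["input_ids"][1]
print(embedtoken(_cat_id).shape)  # torch.Size([768]) for CLIP ViT-L/14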
def embedtokens(tokens):
    # Embed each token id independently and stack into an (N, 768) tensor.
    res = [embedtoken(t) for t in tokens]
    return torch.stack(res)
def embed(text):
    # Tokenize with padding to CLIP's max length (77) and encode.
    text_input = tokenizer([text], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    # Position 1 is the first token after BOS, so this matches the whole
    # word only when `text` tokenizes to a single BPE token.
    return text_embeddings[0][1]
def word(tok):
    # `tok` may be a Python int or a 0-dim tensor; decode expects ints.
    return tokenizer.decode([int(tok)])
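# Round-trip sanity check (illustrative): decoding the id looked up
# above should recover the original word.
print(word(torch.tensor(_cat_id)))  # expected to print "cat"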
def dists(e, v):
    # Squared Euclidean distance from each row of e (shape (N, 768)) to
    # a single embedding v (shape (768,)), using broadcasting.
    return ((e - v) ** 2).sum(dim=-1)
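# Shape check with dummy data (illustrative): dists maps (N, 768) and
# (768,) to a vector of N squared distances.
print(dists(torch.randn(5, 768), torch.randn(768)).shape)  # torch.Size([5])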
def soclover(w1, w2):
    # Embed both clue words, then every regular token id (0..49405;
    # 49406/49407 are the BOS/EOS specials).
    e1 = embed(w1)
    e2 = embed(w2)
    e = embedtokens(range(49406))
    d1 = dists(e, e1)
    d2 = dists(e, e2)
    # Pick the token whose *worse* distance to the two words is smallest,
    # i.e. the token closest to both. One forward pass per candidate is
    # slow; see the batched sketch after the example call below.
    m = torch.maximum(d1, d2)
    t = torch.argmin(m)
    return word(t)
print(soclover("tail", "air"))
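# Optional speed-up (a sketch, not part of the original notebook):
# soclover runs one encoder forward pass per candidate token. Encoding
# candidates in batches produces the same (N, 768) result far faster;
# batch_size is an illustrative choice.
def embedtokens_batched(token_ids, batch_size=256):
    token_ids = list(token_ids)
    out = []
    for start in range(0, len(token_ids), batch_size):
        batch = token_ids[start:start + batch_size]
        # Same layout as embedtoken, but one row per candidate token.
        ids = torch.full((len(batch), 77), 49407, dtype=torch.long)
        ids[:, 0] = 49406
        ids[:, 1] = torch.as_tensor(batch)
        with torch.no_grad():
            emb = text_encoder(ids.to(torch_device))[0]
        out.append(emb[:, 1])
    return torch.cat(out)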
wa1 = "" # @param {type: "string"}
wb1 = "" # @param {type: "string"}
wa2 = "" # @param {type: "string"}
wb2 = "" # @param {type: "string"}
wa3 = "" # @param {type: "string"}
wb3 = "" # @param {type: "string"}
wa4 = "" # @param {type: "string"}
wb4 = "" # @param {type: "string"}
if wa1 != "" and wb1 != "":
w1 = soclover(wa1, wb1)
print(f"{wa1} {wb1} : {w1}")
if wa2 != "" and wb2 != "":
w2 = soclover(wa2, wb2)
print(f"{wa2} {wb2} : {w2}")
if wa3 != "" and wb3 != "":
w3 = soclover(wa3, wb3)
print(f"{wa3} {wb3} : {w3}")
if wa4 != "" and wb4 != "":
w4 = soclover(wa4, wb4)
print(f"{wa4} {wb4} : {w4}")