# Provenance (Hugging Face file-view metadata, kept as a comment so the file parses):
# author: wbrooks — "added scripts for testing inference" — commit c795cd4 — 399 Bytes
import torch
#
def encode(sentences, tokenizer, model, device="mps"):
    """Embed a batch of sentences with a Hugging Face encoder model.

    Tokenizes `sentences` (with padding/truncation), runs the model without
    gradient tracking, and mean-pools the final hidden states over the token
    dimension to get one fixed-size vector per sentence.

    Args:
        sentences: a string or list of strings to embed.
        tokenizer: a callable (e.g. a HF tokenizer) returning a batch that
            supports `.to(device)` and `**`-unpacking into the model.
        model: a callable whose output exposes `last_hidden_state` of shape
            [batch, tokens, hidden_dim].
        device: torch device string; defaults to "mps" (Apple Silicon).

    Returns:
        Tensor of shape [batch, hidden_dim] with the pooled embeddings.
    """
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # outputs.last_hidden_state: [batch, tokens, hidden_dim]
    # NOTE(review): this mean-pools over ALL token positions, including padding
    # tokens — for padded batches an attention-mask-weighted mean is usually
    # intended; confirm against how these embeddings are consumed.
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings