Goal: Semantic Search
- Set up pgvector
- Set up sentence transformers for embeddings
- Set up semantic search with an index for speed
Set up pgvector
You have a model with a field you want to make semantically searchable.
Python
class Todo(models.Model):
    """A to-do item whose ``title`` will be made semantically searchable."""

    # Optional free-text title; this is the field that gets embedded.
    title = models.TextField(null=True, blank=True)
    # Filled in automatically when the row is first created.
    created_at = models.DateTimeField(auto_now_add=True)
    # Refreshed automatically on every save.
    updated_at = models.DateTimeField(auto_now=True)
Installation instructions come from: https://github.com/pgvector/pgvector-python
Step: install the pgvector PostgreSQL extension on your machine (Homebrew shown below; see the pgvector README for other platforms)
Bash
brew install pgvector
Step: create a migration that activates the pgvector extension in your database
Python
from django.db import migrations
from pgvector.django import VectorExtension
class Migration(migrations.Migration):
    """Run pgvector's ``VectorExtension`` operation so vector columns can be used."""

    # Placeholder values: point this at your app label and its latest migration.
    dependencies = [("some app", "the_last_migration")]

    operations = [VectorExtension()]
Step: add a VectorField to your model
Python
from pgvector.django import VectorField
class Todo(models.Model):
    # ...

    # 768 dimensions to match the vectors produced by all-mpnet-base-v2.
    # Nullable/blank so a row can exist before its embedding is computed.
    embedding = VectorField(dimensions=768, null=True, blank=True)
Step: make migrations and migrate
Bash
python manage.py makemigrations
python manage.py migrate
Set up sentence transformers for embeddings
Bash
pip install sentence-transformers
We will use the “all-mpnet-base-v2” model (hosted on Hugging Face):
- “This is a sentence-transformers model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.”
Python
# This is how you create an embedding: load the model, then encode some text.
# The import is required; without it this snippet fails with a NameError.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# Produces a 768-dimensional vector for the sentence.
model.encode("kale is great!")
Step: override save() to compute the embedding, and cache the loaded model in a module-level variable so it is not reloaded on every save
Python
from sentence_transformers import SentenceTransformer
# Module-level cache: the model is loaded on the first call to get_model()
# and the same instance is returned on every later call.
model = None


def get_model():
    """Return the shared SentenceTransformer, loading it on first use."""
    global model
    # Check for None explicitly instead of relying on the model object's
    # truthiness (which may be defined by its __len__, not by "is loaded").
    if model is None:
        model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    return model
class Todo(models.Model):
    # ...

    def save(self, *args, **kwargs):
        """Save the todo, refreshing ``embedding`` whenever ``title`` changes.

        The embedding is (re)computed when the row is new, when it has no
        embedding yet, or when ``title`` differs from the value stored in the
        database. An empty or ``None`` title yields a ``None`` embedding
        instead of being passed to the encoder.
        """
        needs_embedding = True
        # Use `is None` rather than truthiness: the stored vector may be an
        # array-like value whose truth value is ambiguous.
        if self.pk is not None and self.embedding is not None:
            # filter().first() instead of get(): an instance whose pk was set
            # by hand may not exist in the database yet.
            stored = Todo.objects.filter(pk=self.pk).only("title").first()
            needs_embedding = stored is None or stored.title != self.title
        if needs_embedding:
            # Embed the title (the original code referenced a non-existent
            # `question` attribute).
            self.embedding = get_model().encode(self.title) if self.title else None
        super().save(*args, **kwargs)
Set up semantic search with an index for speed
Step: add an index to the model
Python
from pgvector.django import IvfflatIndex
class Todo(models.Model):
    # ...

    class Meta:
        indexes = [
            # Approximate nearest-neighbour (IVFFlat) index on `embedding`.
            # Per the pgvector README, the opclass must match the distance
            # used in queries: vector_l2_ops pairs with L2Distance ordering.
            # NOTE(review): pgvector suggests roughly rows / 1000 for `lists`
            # (up to ~1M rows) and building the index after data is loaded.
            IvfflatIndex(
                name="my_index",
                fields=["embedding"],
                lists=100,
                opclasses=["vector_l2_ops"],
            ),
        ]
Step: run a semantic search wherever you need it (for example, in a view)
Python
from pgvector.django import L2Distance
from sentence_transformers import SentenceTransformer
# Module-level cache: the model is loaded on the first call to get_model()
# and the same instance is returned on every later call.
model = None


def get_model():
    """Return the shared SentenceTransformer, loading it on first use."""
    global model
    # Check for None explicitly instead of relying on the model object's
    # truthiness (which may be defined by its __len__, not by "is loaded").
    if model is None:
        model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    return model
def some_view(request):
    """Return the ``n`` todos whose embeddings are closest to the ``q`` query.

    If the ``q`` GET parameter is missing or blank, no search is run and an
    empty queryset is returned. The result is a plain context dict; a real
    Django view must return an ``HttpResponse``, so pass this to ``render()``.
    """
    n = 3
    # Default value so the return below never hits an unbound name when
    # no query was given.
    todos = Todo.objects.none()
    query = request.GET.get("q", "").strip()
    if query:
        q_embedding = get_model().encode(query)
        todos = (
            # Rows that were never embedded have no distance to rank by.
            Todo.objects.filter(embedding__isnull=False)
            .order_by(L2Distance("embedding", q_embedding))[:n]
        )
    return {"todos": todos}
Update 11/9/23: Use the get_model() function to avoid loading the model multiple times.