
To use your own custom embedding function, you can follow these two simple steps:

  1. Create your embedding function by implementing the `EmbeddingFunction` interface
  2. Register your embedding function in the global `EmbeddingFunctionRegistry`

Let's see what this looks like in action.

`EmbeddingFunction` and `EmbeddingFunctionRegistry` handle the low-level details of serializing schema and model information as metadata. To build a custom embedding function, you don't have to worry about the finer details: simply focus on setting up the model and leave the rest to LanceDB.

## `TextEmbeddingFunction` interface

There is another, optional layer of abstraction available: `TextEmbeddingFunction`. You can use this abstraction if your model isn't multi-modal in nature and only needs to operate on text. In that case, both the source and vector fields share the same vectorization workflow, so you only need to set up the model; the rest is handled by `TextEmbeddingFunction`. You can read more about the class and its attributes in the class reference.

Let's implement a `SentenceTransformerEmbeddings` class. All you need to do is implement the `generate_embeddings()` and `ndims()` methods to handle the input types you expect, then register the class in the global `EmbeddingFunctionRegistry`:

```python
from cachetools import cached

import sentence_transformers

from lancedb.embeddings import TextEmbeddingFunction, register


@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
    name: str = "all-MiniLM-L6-v2"
    # set more default instance vars like device, etc.

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None

    def generate_embeddings(self, texts):
        # the "..." elides extra encode() arguments in the full implementation
        return self._embedding_model().encode(list(texts), ...).tolist()

    def ndims(self):
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings("foo")[0])
        return self._ndims

    @cached(cache={})
    def _embedding_model(self):
        return sentence_transformers.SentenceTransformer(self.name)
```

This is a stripped-down version of our `SentenceTransformerEmbeddings` implementation that omits certain optimizations and default settings.

Now you can use this embedding function to define your table schema, and that's it! You can then ingest data and run queries without manually vectorizing the inputs:

```python
import lancedb
import pandas as pd

from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

registry = EmbeddingFunctionRegistry.get_instance()
stransformer = registry.get("sentence-transformers").create()


class TextModelSchema(LanceModel):
    vector: Vector(stransformer.ndims()) = stransformer.VectorField()
    text: str = stransformer.SourceField()


db = lancedb.connect("~/lancedb")
tbl = db.create_table("table", schema=TextModelSchema)

tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
result = tbl.search("world").limit(5).to_pandas()
```
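
Because the embedding function's configuration is serialized into the table's schema as metadata (as noted above), you can reopen the table later without re-declaring the model. A minimal sketch, assuming the table created above:

```python
db = lancedb.connect("~/lancedb")
tbl = db.open_table("table")
# The embedding function is restored from the schema metadata, so
# plain-text queries are still vectorized automatically
result = tbl.search("greetings").limit(5).to_pandas()
```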

NOTE:

You can always implement the `EmbeddingFunction` interface directly if you want or need to. `TextEmbeddingFunction` just makes it much simpler and faster to do so by setting up the boilerplate for the text-specific use case.
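
For reference, here is a minimal, hypothetical skeleton of what implementing `EmbeddingFunction` directly involves. The method names match the interface used by the OpenCLIP example below; `MyEmbeddings` and its dummy model are placeholders for your own logic:

```python
from typing import List

import numpy as np

from lancedb.embeddings import EmbeddingFunction, register


@register("my-embeddings")  # hypothetical name, for illustration only
class MyEmbeddings(EmbeddingFunction):
    dim: int = 8  # placeholder dimensionality

    def ndims(self) -> int:
        return self.dim

    def compute_query_embeddings(self, query, *args, **kwargs) -> List[np.ndarray]:
        # Embed a single user query (any modality your model supports)
        return [self._embed(query)]

    def compute_source_embeddings(self, inputs, *args, **kwargs) -> List[np.ndarray]:
        # Embed a batch of source-column values at ingestion time
        return [self._embed(x) for x in inputs]

    def _embed(self, item) -> np.ndarray:
        # Dummy vector derived from the input, so the skeleton runs
        # end to end; replace with a real model call
        rng = np.random.default_rng(abs(hash(str(item))) % (2**32))
        return rng.random(self.dim)
```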

## Multi-modal embedding function example

You can also use the `EmbeddingFunction` interface to implement more complex workflows, such as multi-modal embedding functions. LanceDB implements the `OpenClipEmbeddings` class, which supports multi-modal search. Here's the implementation, which you can use as a reference to build your own multi-modal embedding functions.

@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
    name: str = "ViT-B-32"
    pretrained: str = "laion2b_s34b_b79k"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True
    _model = PrivateAttr()
    _preprocess = PrivateAttr()
    _tokenizer = PrivateAttr()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # safe_import is an EmbeddingFunction utility that imports
        # external libraries and raises an error if not found
        open_clip = self.safe_import("open_clip", "open-clip")
        model, _, preprocess = open_clip.create_model_and_transforms(
            self.name, pretrained=self.pretrained
        )
        model.to(self.device)
        self._model, self._preprocess = model, preprocess
        self._tokenizer = open_clip.get_tokenizer(self.name)
        self._ndims = None

    def ndims(self):
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL = self.safe_import("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("OpenClip supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        torch = self.safe_import("torch")
        text = self.sanitize_input(text)
        text = self._tokenizer(text)
        text = text.to(self.device)
        with torch.no_grad():
            text_features = self._model.encode_text(text)
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Get the embeddings for the given images
        """
        images = self.sanitize_input(images)
        embeddings = []
        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        """
        Compute the image embeddings concurrently, fetching remote
        image data over HTTP where needed
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [future.result() for future in futures]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.
        """
        torch = self.safe_import("torch")
        # TODO handle retry and errors for https
        image = self._to_pil(image)
        image = self._preprocess(image).unsqueeze(0)
        with torch.no_grad():
            return self._encode_and_normalize_image(image)

    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
        PIL = self.safe_import("PIL", "pillow")
        if isinstance(image, bytes):
            return PIL.Image.open(io.BytesIO(image))
        elif isinstance(image, PIL.Image.Image):
            return image
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            # TODO handle drive letter on windows.
            if parsed.scheme == "file":
                return PIL.Image.open(parsed.path)
            elif parsed.scheme == "":
                return PIL.Image.open(image if os.name == "nt" else parsed.path)
            elif parsed.scheme.startswith("http"):
                return PIL.Image.open(io.BytesIO(url_retrieve(image)))
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
        """
        Encode a single image tensor and optionally normalize the output
        """
        image_features = self._model.encode_image(image_tensor)
        if self.normalize:
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().squeeze()
```
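
As a usage sketch (with hypothetical file paths, and assuming the schema conventions shown in the text example above), you can wire `OpenClipEmbeddings` into a table and then search with either text or a PIL image, since `compute_query_embeddings` accepts both:

```python
import lancedb
import PIL.Image

from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

registry = EmbeddingFunctionRegistry.get_instance()
clip = registry.get("open-clip").create()


class ImageSchema(LanceModel):
    image_uri: str = clip.SourceField()  # URIs are embedded at ingestion time
    vector: Vector(clip.ndims()) = clip.VectorField()


db = lancedb.connect("~/lancedb")
tbl = db.create_table("images", schema=ImageSchema)
tbl.add([{"image_uri": "path/to/cat.png"}])  # hypothetical local image

# Query with text...
text_hits = tbl.search("a photo of a cat").limit(3).to_pandas()
# ...or with an image
img_hits = tbl.search(PIL.Image.open("path/to/query.png")).limit(3).to_pandas()
```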