From a062a92f6b73665ef83072f6238bc37909cfc072 Mon Sep 17 00:00:00 2001
From: Cory Grinstead
Date: Tue, 30 Jul 2024 18:19:55 -0500
Subject: [PATCH] docs: custom embedding function for ts (#1479)

---
 .../embeddings/custom_embedding_function.md   | 366 ++++++++++--------
 nodejs/examples/custom_embedding_function.ts  |  64 +++
 .../lancedb/embedding/embedding_function.ts   |  33 ++
 nodejs/lancedb/embedding/index.ts             |   2 +-
 4 files changed, 295 insertions(+), 170 deletions(-)
 create mode 100644 nodejs/examples/custom_embedding_function.ts

diff --git a/docs/src/embeddings/custom_embedding_function.md b/docs/src/embeddings/custom_embedding_function.md
index b306640d..4dbd19a1 100644
--- a/docs/src/embeddings/custom_embedding_function.md
+++ b/docs/src/embeddings/custom_embedding_function.md
@@ -15,198 +15,226 @@ There is another optional layer of abstraction available: `TextEmbeddingFunction

 Let's implement the `SentenceTransformerEmbeddings` class. All you need to do is implement the `generate_embeddings()` and `ndims()` functions to handle the input types you expect, and register the class in the global `EmbeddingFunctionRegistry`.

-```python
-from lancedb.embeddings import register
-from lancedb.util import attempt_import_or_raise

-@register("sentence-transformers")
-class SentenceTransformerEmbeddings(TextEmbeddingFunction):
-    name: str = "all-MiniLM-L6-v2"
-    # set more default instance vars like device, etc.
+=== "Python"

-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self._ndims = None
-
-    def generate_embeddings(self, texts):
-        return self._embedding_model().encode(list(texts), ...).tolist()
+    ```python
+    from cachetools import cached
+    from lancedb.embeddings import TextEmbeddingFunction, register
+    from lancedb.util import attempt_import_or_raise

-    def ndims(self):
-        if self._ndims is None:
-            self._ndims = len(self.generate_embeddings("foo")[0])
-        return self._ndims
+    @register("sentence-transformers")
+    class SentenceTransformerEmbeddings(TextEmbeddingFunction):
+        name: str = "all-MiniLM-L6-v2"
+        # set more default instance vars like device, etc.

-    @cached(cache={})
-    def _embedding_model(self):
-        return sentence_transformers.SentenceTransformer(name)
-```
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            self._ndims = None

-This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and defaul settings.
+        def generate_embeddings(self, texts):
+            return self._embedding_model().encode(list(texts), ...).tolist()
+
+        def ndims(self):
+            if self._ndims is None:
+                self._ndims = len(self.generate_embeddings("foo")[0])
+            return self._ndims
+
+        @cached(cache={})
+        def _embedding_model(self):
+            sentence_transformers = attempt_import_or_raise("sentence_transformers")
+            return sentence_transformers.SentenceTransformer(self.name)
+    ```

+=== "TypeScript"
+
+    ```ts
+    --8<-- "nodejs/examples/custom_embedding_function.ts:imports"
+
+    --8<-- "nodejs/examples/custom_embedding_function.ts:embedding_impl"
+    ```
+
+
+This is a stripped-down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings.

 Now you can use this embedding function to create your table schema, and that's it! You can then ingest data and run queries without manually vectorizing the inputs.
-```python
-from lancedb.pydantic import LanceModel, Vector
+=== "Python"

-registry = EmbeddingFunctionRegistry.get_instance()
-stransformer = registry.get("sentence-transformers").create()
+    ```python
+    from lancedb.pydantic import LanceModel, Vector

-class TextModelSchema(LanceModel):
-    vector: Vector(stransformer.ndims) = stransformer.VectorField()
-    text: str = stransformer.SourceField()
+    registry = EmbeddingFunctionRegistry.get_instance()
+    stransformer = registry.get("sentence-transformers").create()

-tbl = db.create_table("table", schema=TextModelSchema)
+    class TextModelSchema(LanceModel):
+        vector: Vector(stransformer.ndims) = stransformer.VectorField()
+        text: str = stransformer.SourceField()

-tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
-result = tbl.search("world").limit(5)
-```
+    tbl = db.create_table("table", schema=TextModelSchema)

-NOTE:
+    tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
+    result = tbl.search("world").limit(5)
+    ```

-You can always implement the `EmbeddingFunction` interface directly if you want or need to, `TextEmbeddingFunction` just makes it much simpler and faster for you to do so, by setting up the boiler plat for text-specific use case
+=== "TypeScript"
+
+    ```ts
+    --8<-- "nodejs/examples/custom_embedding_function.ts:call_custom_function"
+    ```
+
+!!! note

+    You can always implement the `EmbeddingFunction` interface directly if you want or need to; `TextEmbeddingFunction` just makes it much simpler and faster by setting up the boilerplate for the text-specific use case.

 ## Multi-modal embedding function example

-You can also use the `EmbeddingFunction` interface to implement more complex workflows such as multi-modal embedding function support. LanceDB implements `OpenClipEmeddingFunction` class that suppports multi-modal seach. Here's the implementation that you can use as a reference to build your own multi-modal embedding functions.
+You can also use the `EmbeddingFunction` interface to implement more complex workflows such as multi-modal embedding function support.

-```python
-@register("open-clip")
-class OpenClipEmbeddings(EmbeddingFunction):
-    name: str = "ViT-B-32"
-    pretrained: str = "laion2b_s34b_b79k"
-    device: str = "cpu"
-    batch_size: int = 64
-    normalize: bool = True
-    _model = PrivateAttr()
-    _preprocess = PrivateAttr()
-    _tokenizer = PrivateAttr()
+=== "Python"

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        open_clip = attempt_import_or_raise("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
-        model, _, preprocess = open_clip.create_model_and_transforms(
-            self.name, pretrained=self.pretrained
-        )
-        model.to(self.device)
-        self._model, self._preprocess = model, preprocess
-        self._tokenizer = open_clip.get_tokenizer(self.name)
-        self._ndims = None
+    LanceDB implements the `OpenClipEmbeddings` class, which supports multi-modal search. Here's the implementation that you can use as a reference to build your own multi-modal embedding functions.
- def ndims(self): - if self._ndims is None: - self._ndims = self.generate_text_embeddings("foo").shape[0] - return self._ndims + ```python + @register("open-clip") + class OpenClipEmbeddings(EmbeddingFunction): + name: str = "ViT-B-32" + pretrained: str = "laion2b_s34b_b79k" + device: str = "cpu" + batch_size: int = 64 + normalize: bool = True + _model = PrivateAttr() + _preprocess = PrivateAttr() + _tokenizer = PrivateAttr() - def compute_query_embeddings( - self, query: Union[str, "PIL.Image.Image"], *args, **kwargs - ) -> List[np.ndarray]: - """ - Compute the embeddings for a given user query + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + open_clip = attempt_import_or_raise("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found + model, _, preprocess = open_clip.create_model_and_transforms( + self.name, pretrained=self.pretrained + ) + model.to(self.device) + self._model, self._preprocess = model, preprocess + self._tokenizer = open_clip.get_tokenizer(self.name) + self._ndims = None - Parameters - ---------- - query : Union[str, PIL.Image.Image] - The query to embed. A query can be either text or an image. - """ - if isinstance(query, str): - return [self.generate_text_embeddings(query)] - else: + def ndims(self): + if self._ndims is None: + self._ndims = self.generate_text_embeddings("foo").shape[0] + return self._ndims + + def compute_query_embeddings( + self, query: Union[str, "PIL.Image.Image"], *args, **kwargs + ) -> List[np.ndarray]: + """ + Compute the embeddings for a given user query + + Parameters + ---------- + query : Union[str, PIL.Image.Image] + The query to embed. A query can be either text or an image. + """ + if isinstance(query, str): + return [self.generate_text_embeddings(query)] + else: + PIL = attempt_import_or_raise("PIL", "pillow") + if isinstance(query, PIL.Image.Image): + return [self.generate_image_embedding(query)] + else: + raise TypeError("OpenClip supports str or PIL Image as query") + + def generate_text_embeddings(self, text: str) -> np.ndarray: + torch = attempt_import_or_raise("torch") + text = self.sanitize_input(text) + text = self._tokenizer(text) + text.to(self.device) + with torch.no_grad(): + text_features = self._model.encode_text(text.to(self.device)) + if self.normalize: + text_features /= text_features.norm(dim=-1, keepdim=True) + return text_features.cpu().numpy().squeeze() + + def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]: + """ + Sanitize the input to the embedding function. 
+ """ + if isinstance(images, (str, bytes)): + images = [images] + elif isinstance(images, pa.Array): + images = images.to_pylist() + elif isinstance(images, pa.ChunkedArray): + images = images.combine_chunks().to_pylist() + return images + + def compute_source_embeddings( + self, images: IMAGES, *args, **kwargs + ) -> List[np.array]: + """ + Get the embeddings for the given images + """ + images = self.sanitize_input(images) + embeddings = [] + for i in range(0, len(images), self.batch_size): + j = min(i + self.batch_size, len(images)) + batch = images[i:j] + embeddings.extend(self._parallel_get(batch)) + return embeddings + + def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]: + """ + Issue concurrent requests to retrieve the image data + """ + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self.generate_image_embedding, image) + for image in images + ] + return [future.result() for future in futures] + + def generate_image_embedding( + self, image: Union[str, bytes, "PIL.Image.Image"] + ) -> np.ndarray: + """ + Generate the embedding for a single image + + Parameters + ---------- + image : Union[str, bytes, PIL.Image.Image] + The image to embed. If the image is a str, it is treated as a uri. + If the image is bytes, it is treated as the raw image bytes. + """ + torch = attempt_import_or_raise("torch") + # TODO handle retry and errors for https + image = self._to_pil(image) + image = self._preprocess(image).unsqueeze(0) + with torch.no_grad(): + return self._encode_and_normalize_image(image) + + def _to_pil(self, image: Union[str, bytes]): PIL = attempt_import_or_raise("PIL", "pillow") - if isinstance(query, PIL.Image.Image): - return [self.generate_image_embedding(query)] - else: - raise TypeError("OpenClip supports str or PIL Image as query") + if isinstance(image, bytes): + return PIL.Image.open(io.BytesIO(image)) + if isinstance(image, PIL.Image.Image): + return image + elif isinstance(image, str): + parsed = urlparse.urlparse(image) + # TODO handle drive letter on windows. + if parsed.scheme == "file": + return PIL.Image.open(parsed.path) + elif parsed.scheme == "": + return PIL.Image.open(image if os.name == "nt" else parsed.path) + elif parsed.scheme.startswith("http"): + return PIL.Image.open(io.BytesIO(url_retrieve(image))) + else: + raise NotImplementedError("Only local and http(s) urls are supported") - def generate_text_embeddings(self, text: str) -> np.ndarray: - torch = attempt_import_or_raise("torch") - text = self.sanitize_input(text) - text = self._tokenizer(text) - text.to(self.device) - with torch.no_grad(): - text_features = self._model.encode_text(text.to(self.device)) + def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"): + """ + encode a single image tensor and optionally normalize the output + """ + image_features = self._model.encode_image(image_tensor) if self.normalize: - text_features /= text_features.norm(dim=-1, keepdim=True) - return text_features.cpu().numpy().squeeze() + image_features /= image_features.norm(dim=-1, keepdim=True) + return image_features.cpu().numpy().squeeze() + ``` - def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]: - """ - Sanitize the input to the embedding function. 
- """ - if isinstance(images, (str, bytes)): - images = [images] - elif isinstance(images, pa.Array): - images = images.to_pylist() - elif isinstance(images, pa.ChunkedArray): - images = images.combine_chunks().to_pylist() - return images +=== "TypeScript" - def compute_source_embeddings( - self, images: IMAGES, *args, **kwargs - ) -> List[np.array]: - """ - Get the embeddings for the given images - """ - images = self.sanitize_input(images) - embeddings = [] - for i in range(0, len(images), self.batch_size): - j = min(i + self.batch_size, len(images)) - batch = images[i:j] - embeddings.extend(self._parallel_get(batch)) - return embeddings - - def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]: - """ - Issue concurrent requests to retrieve the image data - """ - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self.generate_image_embedding, image) - for image in images - ] - return [future.result() for future in futures] - - def generate_image_embedding( - self, image: Union[str, bytes, "PIL.Image.Image"] - ) -> np.ndarray: - """ - Generate the embedding for a single image - - Parameters - ---------- - image : Union[str, bytes, PIL.Image.Image] - The image to embed. If the image is a str, it is treated as a uri. - If the image is bytes, it is treated as the raw image bytes. - """ - torch = attempt_import_or_raise("torch") - # TODO handle retry and errors for https - image = self._to_pil(image) - image = self._preprocess(image).unsqueeze(0) - with torch.no_grad(): - return self._encode_and_normalize_image(image) - - def _to_pil(self, image: Union[str, bytes]): - PIL = attempt_import_or_raise("PIL", "pillow") - if isinstance(image, bytes): - return PIL.Image.open(io.BytesIO(image)) - if isinstance(image, PIL.Image.Image): - return image - elif isinstance(image, str): - parsed = urlparse.urlparse(image) - # TODO handle drive letter on windows. - if parsed.scheme == "file": - return PIL.Image.open(parsed.path) - elif parsed.scheme == "": - return PIL.Image.open(image if os.name == "nt" else parsed.path) - elif parsed.scheme.startswith("http"): - return PIL.Image.open(io.BytesIO(url_retrieve(image))) - else: - raise NotImplementedError("Only local and http(s) urls are supported") - - def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"): - """ - encode a single image tensor and optionally normalize the output - """ - image_features = self._model.encode_image(image_tensor) - if self.normalize: - image_features /= image_features.norm(dim=-1, keepdim=True) - return image_features.cpu().numpy().squeeze() -``` + Coming Soon! See this [issue](https://github.com/lancedb/lancedb/issues/1482) to track the status! 
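The TypeScript multi-modal path is still open (see the issue linked in the tab above). In the meantime, here is a minimal sketch of what a custom multi-modal function could look like, built directly on the `EmbeddingFunction` base class this patch exports. The `open-clip-sketch` registry key, the 512-dimension width, the `ImageSource` type, and the `embedText`/`embedImage` helpers are illustrative assumptions, not shipped API; only the overridden method names come from the patch itself.

```ts
import { Float32 } from "apache-arrow"; // assumption: the arrow dependency lancedb uses
import { EmbeddingFunction, register } from "@lancedb/lancedb/embedding";

// Assumption: queries are text, while source rows hold image URIs or raw bytes.
type ImageSource = string | Buffer;

@register("open-clip-sketch") // hypothetical registry key
class OpenClipSketch extends EmbeddingFunction<ImageSource> {
  ndims() {
    return 512; // ViT-B-32 output width
  }

  toJSON() {
    return { name: "ViT-B-32" };
  }

  embeddingDataType() {
    return new Float32();
  }

  // CLIP-style models place text and images in one vector space,
  // so a text query can be matched against image embeddings.
  async computeQueryEmbeddings(query: ImageSource) {
    return typeof query === "string"
      ? this.embedText(query)
      : this.embedImage(query);
  }

  async computeSourceEmbeddings(images: ImageSource[]) {
    return Promise.all(images.map((image) => this.embedImage(image)));
  }

  // Placeholders: wire these up to real CLIP text/image encoders.
  private async embedText(text: string): Promise<number[]> {
    throw new Error("plug in a CLIP text encoder here");
  }

  private async embedImage(image: ImageSource): Promise<number[]> {
    throw new Error("plug in a CLIP image encoder here");
  }
}
```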
diff --git a/nodejs/examples/custom_embedding_function.ts b/nodejs/examples/custom_embedding_function.ts
new file mode 100644
index 00000000..5307cd6d
--- /dev/null
+++ b/nodejs/examples/custom_embedding_function.ts
@@ -0,0 +1,64 @@
+// --8<-- [start:imports]
+import * as lancedb from "@lancedb/lancedb";
+import {
+  LanceSchema,
+  TextEmbeddingFunction,
+  getRegistry,
+  register,
+} from "@lancedb/lancedb/embedding";
+import { pipeline } from "@xenova/transformers";
+// --8<-- [end:imports]
+
+// --8<-- [start:embedding_impl]
+@register("sentence-transformers")
+class SentenceTransformersEmbeddings extends TextEmbeddingFunction {
+  name = "Xenova/all-MiniLM-L6-v2";
+  #ndims!: number;
+  extractor: any;
+
+  async init() {
+    this.extractor = await pipeline("feature-extraction", this.name);
+    this.#ndims = await this.generateEmbeddings(["hello"]).then(
+      (e) => e[0].length,
+    );
+  }
+
+  ndims() {
+    return this.#ndims;
+  }
+
+  toJSON() {
+    return {
+      name: this.name,
+    };
+  }
+
+  async generateEmbeddings(texts: string[]) {
+    const output = await this.extractor(texts, {
+      pooling: "mean",
+      normalize: true,
+    });
+    return output.tolist();
+  }
+}
+// --8<-- [end:embedding_impl]
+
+// --8<-- [start:call_custom_function]
+const registry = getRegistry();
+
+const sentenceTransformer = await registry
+  .get("sentence-transformers")!
+  .create();
+
+const schema = LanceSchema({
+  vector: sentenceTransformer.vectorField(),
+  text: sentenceTransformer.sourceField(),
+});
+
+const db = await lancedb.connect("/tmp/db");
+const table = await db.createEmptyTable("table", schema, { mode: "overwrite" });
+
+await table.add([{ text: "hello" }, { text: "world" }]);
+
+const results = await table.search("greeting").limit(1).toArray();
+console.log(results[0].text);
+// --8<-- [end:call_custom_function]
diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts
index c42951c2..11e6d153 100644
--- a/nodejs/lancedb/embedding/embedding_function.ts
+++ b/nodejs/lancedb/embedding/embedding_function.ts
@@ -21,6 +21,7 @@ import {
   Float32,
   FloatLike,
   type IntoVector,
+  Utf8,
   isDataType,
   isFixedSizeList,
   isFloat,
@@ -192,6 +193,38 @@ export abstract class EmbeddingFunction<
   }
 }
 
+/**
+ * An abstract class for implementing embedding functions that take text as input.
+ */
+export abstract class TextEmbeddingFunction<
+  M extends FunctionOptions = FunctionOptions,
+> extends EmbeddingFunction<string, M> {
+  /** Generate the embeddings for the given texts */
+  abstract generateEmbeddings(
+    texts: string[],
+    // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
+    ...args: any[]
+  ): Promise<number[][]>;
+
+  async computeQueryEmbeddings(data: string): Promise<Awaited<IntoVector>> {
+    return this.generateEmbeddings([data]).then((data) => data[0]);
+  }
+
+  embeddingDataType(): FloatLike {
+    return new Float32();
+  }
+
+  override sourceField(): [DataTypeLike, Map<string, EmbeddingFunction>] {
+    return super.sourceField(new Utf8());
+  }
+
+  computeSourceEmbeddings(
+    data: string[],
+  ): Promise<number[][]> {
+    return this.generateEmbeddings(data);
+  }
+}
+
 export interface FieldOptions<T extends DataTypeLike = DataTypeLike> {
   datatype: T;
   dims?: number;
diff --git a/nodejs/lancedb/embedding/index.ts b/nodejs/lancedb/embedding/index.ts
index 509eddc5..8045b0af 100644
--- a/nodejs/lancedb/embedding/index.ts
+++ b/nodejs/lancedb/embedding/index.ts
@@ -18,7 +18,7 @@ import { sanitizeType } from "../sanitize";
 import { EmbeddingFunction } from "./embedding_function";
 import { EmbeddingFunctionConfig, getRegistry } from "./registry";
 
-export { EmbeddingFunction } from 
"./embedding_function"; +export { EmbeddingFunction, TextEmbeddingFunction } from "./embedding_function"; // We need to explicitly export '*' so that the `register` decorator actually registers the class. export * from "./openai";