Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: python-v0. ... v0.3.4 (37 commits)
Commits:
7732f7d41c
5ca98c326f
b55db397eb
c04d72ac8a
28b02fb72a
f3cf986777
c73fcc8898
cd9debc3b7
26a97ba997
ce19fedb08
14e8e48de2
c30faf6083
64a4f025bb
6dc968e7d3
06b5b69f1e
6bd3a838fc
f36fea8f20
0a30591729
0ed39b6146
a8c7f80073
0293bbe142
7372656369
d46bc5dd6e
86efb11572
bb01ad5290
1b8cda0941
bc85a749a3
02c35d3457
345c136cfb
043e388254
fe64fc4671
6d66404506
eff94ecea8
7dfb555fea
f762a669e7
0bdc7140dd
8f6e955b24
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.3.0
+current_version = 0.3.4
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True
28  Cargo.toml
@@ -5,23 +5,23 @@ exclude = ["python"]
 resolver = "2"
 
 [workspace.dependencies]
-lance = { "version" = "=0.8.3", "features" = ["dynamodb"] }
-lance-linalg = { "version" = "=0.8.3" }
-lance-testing = { "version" = "=0.8.3" }
+lance = { "version" = "=0.8.8", "features" = ["dynamodb"] }
+lance-linalg = { "version" = "=0.8.8" }
+lance-testing = { "version" = "=0.8.8" }
 # Note that this one does not include pyarrow
-arrow = { version = "43.0.0", optional = false }
-arrow-array = "43.0"
-arrow-data = "43.0"
-arrow-ipc = "43.0"
-arrow-ord = "43.0"
-arrow-schema = "43.0"
-arrow-arith = "43.0"
-arrow-cast = "43.0"
+arrow = { version = "47.0.0", optional = false }
+arrow-array = "47.0"
+arrow-data = "47.0"
+arrow-ipc = "47.0"
+arrow-ord = "47.0"
+arrow-schema = "47.0"
+arrow-arith = "47.0"
+arrow-cast = "47.0"
 chrono = "0.4.23"
-half = { "version" = "=2.2.1", default-features = false, features = [
-    "num-traits"
+half = { "version" = "=2.3.1", default-features = false, features = [
+    "num-traits",
 ] }
 log = "0.4"
-object_store = "0.6.1"
+object_store = "0.7.1"
 snafu = "0.7.4"
 url = "2"
26  docs/README.md  (new file)
@@ -0,0 +1,26 @@
# LanceDB Documentation

LanceDB docs are deployed to https://lancedb.github.io/lancedb/.

Docs are built and deployed automatically by [GitHub Actions](.github/workflows/docs.yml)
whenever a commit is pushed to the `main` branch, so the published docs may show
unreleased features.

## Building the docs

### Setup
1. Install LanceDB. From the LanceDB repo root: `pip install -e python`
2. Install dependencies. From the LanceDB repo root: `pip install -r docs/requirements.txt`
3. Make sure you have node and npm set up
4. Make sure protobuf and libssl are installed

### Building the node module and creating markdown files

See the [Javascript docs README](docs/src/javascript/README.md)

### Build docs
From the LanceDB repo root:

Run: `PYTHONPATH=. mkdocs build -f docs/mkdocs.yml`

If successful, you should see a `docs/site` directory that you can verify locally.
@@ -37,7 +37,7 @@ plugins:
       docstring_style: numpy
       rendering:
         heading_level: 4
-        show_source: false
+        show_source: true
         show_symbol_type_in_heading: true
         show_signature_annotations: true
         show_root_heading: true
@@ -73,7 +73,14 @@ nav:
     - Vector Search: search.md
     - SQL filters: sql.md
     - Indexing: ann_indexes.md
-    - 🧬 Embeddings: embedding.md
     - Versioning & Reproducibility: notebooks/reproducibility.ipynb
+    - 🧬 Embeddings:
+      - embeddings/index.md
+      - Ingest Embedding Functions: embeddings/embedding_functions.md
+      - Available Functions: embeddings/default_embedding_functions.md
+      - Create Custom Embedding Functions: embeddings/api.md
+      - Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
+      - Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
     - 🔍 Python full-text search: fts.md
     - 🔌 Integrations:
       - integrations/index.md
@@ -105,7 +112,14 @@ nav:
     - Vector Search: search.md
     - SQL filters: sql.md
     - Indexing: ann_indexes.md
-    - Embeddings: embedding.md
     - Versioning & Reproducibility: notebooks/reproducibility.ipynb
+    - Embeddings:
+      - embeddings/index.md
+      - Ingest Embedding Functions: embeddings/embedding_functions.md
+      - Available Functions: embeddings/default_embedding_functions.md
+      - Create Custom Embedding Functions: embeddings/api.md
+      - Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
+      - Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
     - Python full-text search: fts.md
     - Integrations:
       - integrations/index.md
@@ -136,6 +150,8 @@ nav:
 
 extra_css:
   - styles/global.css
+extra_javascript:
+  - scripts/posthog.js
 
 extra:
   analytics:
BIN  docs/src/assets/dog_clip_output.png  (new file, 342 KiB; binary not shown)
BIN  docs/src/assets/embedding_intro.png  (new file, 245 KiB; binary not shown)
BIN  docs/src/assets/embeddings_api.png  (new file, 83 KiB; binary not shown)
213  docs/src/embeddings/api.md  (new file)
@@ -0,0 +1,213 @@
To use your own custom embedding function, you need to follow two simple steps:

1. Create your embedding function by implementing the `EmbeddingFunction` interface
2. Register your embedding function in the global `EmbeddingFunctionRegistry`

Let's see how this looks in action.

![](../assets/embeddings_api.png)

`EmbeddingFunction` and `EmbeddingFunctionRegistry` handle the low-level details of serializing schema and model information as metadata. To build a custom embedding function, you don't need to worry about those details; simply focus on setting up the model.

## `TextEmbeddingFunction` Interface

There is another optional layer of abstraction provided in the form of `TextEmbeddingFunction`. You can use this if your model isn't multi-modal and only operates on text. In that case both the source and vector fields take the same pathway for vectorization, so you just need to set up the model and the rest is handled by `TextEmbeddingFunction`. You can read more about the class and its attributes in the class reference.

Let's implement a `SentenceTransformerEmbeddings` class. All you need to do is implement the `generate_embeddings()` and `ndims()` functions to handle the input types you expect, and register the class in the global `EmbeddingFunctionRegistry`:
```python
import sentence_transformers
from cachetools import cached

from lancedb.embeddings import TextEmbeddingFunction, register


@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
    name: str = "all-MiniLM-L6-v2"
    # set more default instance vars like device, etc.

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None

    def generate_embeddings(self, texts):
        # the `...` stands for elided encode() options
        return self._embedding_model().encode(list(texts), ...).tolist()

    def ndims(self):
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings("foo")[0])
        return self._ndims

    @cached(cache={})
    def _embedding_model(self):
        return sentence_transformers.SentenceTransformer(self.name)
```

This is a stripped-down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and default settings.
Now you can use this embedding function to create your table schema, and that's it! You can then ingest data and run queries without manually vectorizing the inputs.

```python
import lancedb
import pandas as pd

from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
stransformer = registry.get("sentence-transformers").create()

class TextModelSchema(LanceModel):
    vector: Vector(stransformer.ndims()) = stransformer.VectorField()
    text: str = stransformer.SourceField()

tbl = db.create_table("table", schema=TextModelSchema)

tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
result = tbl.search("world").limit(5)
```

NOTE:

You can always implement the `EmbeddingFunction` interface directly if you want or need to; `TextEmbeddingFunction` just makes it simpler and faster by setting up the boilerplate for the text-specific use case.
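For illustration, implementing `EmbeddingFunction` directly can look like the following minimal sketch. The method names mirror the `OpenClipEmbeddings` reference in the next section; the registry key and the toy character-count projection are hypothetical stand-ins, not part of LanceDB:

```python
import numpy as np

from lancedb.embeddings import EmbeddingFunction, register


@register("random-projection")  # hypothetical registry key for this sketch
class RandomProjectionEmbeddings(EmbeddingFunction):
    """Toy embedding: a fixed random projection of character counts."""

    def ndims(self):
        return 64

    def compute_query_embeddings(self, query, *args, **kwargs):
        # queries and source rows may take different paths for multi-modal
        # models; this text-only sketch routes both through one method
        return self.compute_source_embeddings([query])

    def compute_source_embeddings(self, texts, *args, **kwargs):
        rng = np.random.default_rng(42)  # fixed seed -> stable projection
        proj = rng.standard_normal((256, self.ndims()))
        out = []
        for t in texts:
            counts = np.bincount(
                np.frombuffer(t.encode(), dtype=np.uint8), minlength=256
            )
            out.append((counts @ proj).tolist())
        return out
```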
## Multi-modal embedding function example
You can also use the `EmbeddingFunction` interface to implement more complex workflows, such as multi-modal embedding support. LanceDB implements an `OpenClipEmbeddings` class that supports multi-modal search. Here's the implementation, which you can use as a reference for building your own multi-modal embedding functions.

```python
# Reference implementation; imports are elided here, see
# lancedb.embeddings.open_clip for the full source.
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
    name: str = "ViT-B-32"
    pretrained: str = "laion2b_s34b_b79k"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True
    _model = PrivateAttr()
    _preprocess = PrivateAttr()
    _tokenizer = PrivateAttr()
    _ndims = PrivateAttr(default=None)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # EmbeddingFunction util to import external libs and raise if not found
        open_clip = self.safe_import("open_clip", "open-clip")
        model, _, preprocess = open_clip.create_model_and_transforms(
            self.name, pretrained=self.pretrained
        )
        model.to(self.device)
        self._model, self._preprocess = model, preprocess
        self._tokenizer = open_clip.get_tokenizer(self.name)

    def ndims(self):
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL = self.safe_import("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("OpenClip supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        torch = self.safe_import("torch")
        text = self.sanitize_input(text)
        text = self._tokenizer(text)
        with torch.no_grad():
            text_features = self._model.encode_text(text.to(self.device))
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Get the embeddings for the given images
        """
        images = self.sanitize_input(images)
        embeddings = []
        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        """
        Issue concurrent requests to retrieve the image data
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [future.result() for future in futures]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.
        """
        torch = self.safe_import("torch")
        # TODO handle retry and errors for https
        image = self._to_pil(image)
        image = self._preprocess(image).unsqueeze(0)
        with torch.no_grad():
            return self._encode_and_normalize_image(image)

    def _to_pil(self, image: Union[str, bytes]):
        PIL = self.safe_import("PIL", "pillow")
        if isinstance(image, bytes):
            return PIL.Image.open(io.BytesIO(image))
        if isinstance(image, PIL.Image.Image):
            return image
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            # TODO handle drive letter on windows.
            if parsed.scheme == "file":
                return PIL.Image.open(parsed.path)
            elif parsed.scheme == "":
                return PIL.Image.open(image if os.name == "nt" else parsed.path)
            elif parsed.scheme.startswith("http"):
                return PIL.Image.open(io.BytesIO(url_retrieve(image)))
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
        """
        encode a single image tensor and optionally normalize the output
        """
        image_features = self._model.encode_image(image_tensor)
        if self.normalize:
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().squeeze()
```
156  docs/src/embeddings/default_embedding_functions.md  (new file)
@@ -0,0 +1,156 @@
There are various embedding functions available out of the box with LanceDB, and we're working on supporting other popular embedding APIs.

## Text Embedding Functions
Here are the text embedding functions registered by default.

### Sentence Transformers
Here are the parameters that you can set when registering a `sentence-transformers` object, and their default values:

| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"all-MiniLM-L6-v2"` | The name of the model. |
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
| `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model. |
```python
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("sentence-transformers").create(device="cpu")

class Words(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

table = db.create_table("words", schema=Words)
table.add(
    [
        {"text": "hello world"},
        {"text": "goodbye world"}
    ]
)

query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```
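Because the function's configuration is persisted in the table metadata, a fresh handle on the same table should pick it up again. A sketch of that round trip (assuming a new `db.open_table` on the table created above):

```python
# The embedding function travels with the table, so a reopened handle
# still auto-vectorizes new rows and queries.
tbl = db.open_table("words")
tbl.add([{"text": "greetings, earthling"}])  # embedded automatically
print(tbl.search("hello").limit(1).to_pydantic(Words)[0].text)
```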
### OpenAIEmbeddings
LanceDB includes the OpenAI embeddings function in the registry by default. It is registered as `openai`, and here are the parameters you can customize when creating instances:

| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"text-embedding-ada-002"` | The name of the model. |

```python
import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

# assumes the OPENAI_API_KEY environment variable is set
db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("openai").create()

class Words(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

table = db.create_table("words", schema=Words)
table.add(
    [
        {"text": "hello world"},
        {"text": "goodbye world"}
    ]
)

query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```
## Multi-modal embedding functions
Multi-modal embedding functions allow you to query your table using both images and text.

### OpenClipEmbeddings
We support CLIP model embeddings via the open source alternative, open-clip, which supports various customizations. It is registered as `open-clip` and supports the following customizations:

| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"ViT-B-32"` | The name of the model. |
| `pretrained` | `str` | `"laion2b_s34b_b79k"` | The name of the pretrained model to load. |
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
| `batch_size` | `int` | `64` | The number of images to process in a batch. |
| `normalize` | `bool` | `True` | Whether to normalize the input images before feeding them to the model. |

This embedding function supports ingesting images as both bytes and URLs. You can query using both text and other images.

NOTE:
LanceDB supports ingesting images directly from accessible links.

```python
import lancedb
import requests
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("/tmp/db")
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("open-clip").create()

class Images(LanceModel):
    label: str
    image_uri: str = func.SourceField()      # image uri as the source
    image_bytes: bytes = func.SourceField()  # image bytes as the source
    vector: Vector(func.ndims()) = func.VectorField()          # vector column
    vec_from_bytes: Vector(func.ndims()) = func.VectorField()  # another vector column

table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
    "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
    "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
    "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
    "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
    "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
    "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
    [
        {"label": label, "image_uri": uri, "image_bytes": img}
        for label, uri, img in zip(labels, uris, image_bytes)
    ]
)
```
Now we can search using text against both the default vector column and the custom vector column:
```python
# text search
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
print(actual.label)  # prints "dog"

frombytes = (
    table.search("man's best friend", vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(frombytes.label)
```

Because we're using a multi-modal embedding function, we can also search using images:

```python
import io

import requests
from PIL import Image

# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(io.BytesIO(image_bytes))
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
print(actual.label == "dog")

# image search using a custom vector column
other = (
    table.search(query_image, vector_column_name="vec_from_bytes")
    .limit(1)
    .to_pydantic(Images)[0]
)
print(other.label)
```

If you have any questions about the embeddings API or supported models, or see a relevant model missing, please raise an issue.
82  docs/src/embeddings/embedding_functions.md  (new file)
@@ -0,0 +1,82 @@
Representing multi-modal data as vector embeddings is becoming standard practice. Embedding functions can themselves be thought of as part of the processing pipeline that each request (input) has to pass through. After the initial setup, these components are not expected to change for a particular project.

This is the main motivation behind our new embedding functions API: you set a function up once and the table remembers it, effectively making the **embedding functions disappear in the background** so you don't have to worry about modelling and can simply focus on the DB aspects of a VectorDB.

You can follow these steps and then forget about the details of your embedding function, as long as you don't intend to change it.

### Step 1 - Define the embedding function
We have some pre-defined embedding functions in the global registry, with more coming soon. Here we use an implementation of CLIP as the example:
```python
registry = EmbeddingFunctionRegistry.get_instance()
clip = registry.get("open-clip").create()
```
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses a Pydantic model, which can be utilized to write complex schemas simply, as we'll see next!

### Step 2 - Define the Data Model or Schema
The embedding function from the previous step abstracts away all the details about the model and dimensions required to define the schema. You can simply mark a field as the **source** or **vector** column. Here's how:

```python
class Pets(LanceModel):
    vector: Vector(clip.ndims()) = clip.VectorField()
    image_uri: str = clip.SourceField()
```

`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column, and `SourceField` tells LanceDB that, when adding data, it should automatically use the embedding function to encode `image_uri`.
### Step 3 - Create LanceDB Table
Now that we have chosen/defined our embedding function and the schema, we can create the table:

```python
db = lancedb.connect("~/lancedb")
table = db.create_table("pets", schema=Pets)
```

That's it! The table now carries all the information needed to embed source and query inputs. We can forget about the model and dimension details and start building our VectorDB.

### Step 4 - Ingest lots of data and run vector search!
Now you can just add the data, and it'll be vectorized automatically:

```python
table.add([{"image_uri": u} for u in uris])
```

Our OpenCLIP query embedding function supports querying via both text and images:

```python
result = table.search("dog")
```

Let's query with an image:

```python
p = Path("path/to/images/samoyed_100.jpg")
query_image = Image.open(p)
table.search(query_image)
```

### A little fun with Pydantic
LanceDB is integrated with Pydantic. In fact, we used the integration in the example above to define the schema. It is also used behind the scenes by the embedding function API to ingest useful information as table metadata.

You can also use it to add utility operations to the schema. For example, in our multi-modal example you can search images using text or another image. Let's define a utility property to load the image:

```python
class Pets(LanceModel):
    vector: Vector(clip.ndims()) = clip.VectorField()
    image_uri: str = clip.SourceField()

    @property
    def image(self):
        return Image.open(self.image_uri)
```

Now you can convert your search results to the Pydantic model and use this property:

```python
rs = table.search(query_image).limit(3).to_pydantic(Pets)
rs[2].image
```

![](../assets/dog_clip_output.png)

Now that you have the basic idea of LanceDB embedding functions, let's dive deeper into the API you can use to implement your own!
@@ -1,13 +1,20 @@
-# Embedding Functions
+# Embedding
 
-Embeddings are high dimensional floating-point vector representations of your data or query.
-Anything can be embedded using some embedding model or function.
-For a given embedding function, the output will always have the same number of dimensions.
+Embeddings are high dimensional floating-point vector representations of your data or query. Anything can be embedded using some embedding model or function. The position of an embedding in the high dimensional vector space carries semantic significance, to a degree that depends on the type of model and its training; projected into 2-D space, embeddings of similar entities generally land close together, forming groups.
 
-## Creating an embedding function
+![](assets/embedding_intro.png)
 
-Any function that takes as input a batch (list) of data and outputs a batch (list) of embeddings
-can be used by LanceDB as an embedding function. The input and output batch sizes should be the same.
+# Creating an embedding function
+
+LanceDB supports two major ways of vectorizing your data: explicit and implicit.
+
+1. By manually embedding the data before ingesting it into the table
+2. By automatically embedding the data and queries as they come, by ingesting the embedding function information into the table itself! Covered in the [next section](embedding_functions.md)
+
+Whatever workflow you prefer, we have the tools to support you.
+## Explicit Vectorization
 
+In this workflow, you create your embedding function and vectorize your data using lancedb's `with_embedding` function. Let's look at some examples.
 
 ### HuggingFace example
 
@@ -134,9 +141,9 @@ belong in the same latent space and your results will be nonsensical.
 The above snippet returns an array of records with the 10 closest vectors to the query.
 
-## Roadmap
+## Implicit vectorization / Ingesting embedding functions
+Representing multi-modal data as vector embeddings is becoming standard practice. Embedding functions can themselves be thought of as part of the processing pipeline that each request (input) has to pass through. After the initial setup, these components are not expected to change for a particular project.
 
-In the near future, we'll be integrating the embedding functions deeper into LanceDB<br/>.
-The goal is that you just have to configure the function once when you create the table,
-and then you'll never have to deal with embeddings / vectors after that unless you want to.
-We'll also integrate more popular models and APIs.
+This is the main motivation behind our new embedding functions API: you set a function up once and the table remembers it, effectively making the **embedding functions disappear in the background** so you don't have to worry about modelling and can simply focus on the DB aspects of a VectorDB.
 
+Learn more in the next section.
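For contrast with the implicit path introduced above, the explicit workflow amounts to embedding before ingesting. A minimal sketch (the `model.encode` embedder is a hypothetical stand-in; the page's HuggingFace example shows a concrete one):

```python
import lancedb

texts = ["hello world", "goodbye world"]
vectors = [model.encode(t) for t in texts]  # `model` is a hypothetical embedder

db = lancedb.connect("/tmp/db")
table = db.create_table(
    "docs",
    [{"text": t, "vector": v} for t, v in zip(texts, vectors)],
)
# queries must be embedded with the same model, by hand
results = table.search(model.encode("greetings")).limit(1)
```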
@@ -251,8 +251,9 @@ After a table has been created, you can always add more data to it using
 ### Adding Pandas DataFrame
 
 ```python
-df = pd.DataFrame([{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
-                   {"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}])
+df = pd.DataFrame({
+    "vector": [[1.3, 1.4], [9.5, 56.2]], "item": ["fizz", "buzz"], "price": [100.0, 200.0]
+})
 tbl.add(df)
 ```
@@ -261,17 +262,12 @@ After a table has been created, you can always add more data to it using
 ### Adding to table using Iterator
 
 ```python
-import pandas as pd
-
 def make_batches():
     for i in range(5):
-        yield pd.DataFrame(
-            {
-                "vector": [[3.1, 4.1], [1, 1]],
-                "item": ["foo", "bar"],
-                "price": [10.0, 20.0],
-            })
+        yield [
+            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
+        ]
 tbl.add(make_batches())
 ```
@@ -306,9 +302,10 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
 ```python
 import lancedb
-import pandas as pd
 
-data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
+data = [{"x": 1, "vector": [1, 2]},
+        {"x": 2, "vector": [3, 4]},
+        {"x": 3, "vector": [5, 6]}]
 db = lancedb.connect("./.lancedb")
 table = db.create_table("my_table", data)
 table.to_pandas()
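The hunk above only sets up the data; the `delete()` call it documents takes a SQL predicate, for example:

```python
table.delete("x = 2")     # remove the row where x == 2
print(table.to_pandas())  # rows with x == 1 and x == 3 remain
```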
@@ -67,7 +67,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
 
 ## Documentation Quick Links
 * [`Basic Operations`](basic.md) - basic functionality of LanceDB.
-* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
+* [`Embedding Functions`](embeddings/index.md) - functions for working with embeddings.
 * [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
 * [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
 * [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
764  docs/src/notebooks/DisappearingEmbeddingFunction.ipynb  (new file; diff suppressed: lines too long)
604  docs/src/notebooks/multi_lingual_example.ipynb  (new file; diff suppressed: lines too long)
1189  docs/src/notebooks/reproducibility.ipynb  (new file; diff suppressed: too large)
@@ -114,13 +114,10 @@
     }
    ],
    "source": [
-    "import pandas as pd\n",
-    "\n",
-    "data = pd.DataFrame({\n",
-    "    \"vector\": [[1.1, 1.2], [0.2, 1.8]],\n",
-    "    \"lat\": [45.5, 40.1],\n",
-    "    \"long\": [-122.7, -74.1]\n",
-    "})\n",
+    "data = [\n",
+    "    {\"vector\": [1.1, 1.2], \"lat\": 45.5, \"long\": -122.7},\n",
+    "    {\"vector\": [0.2, 1.8], \"lat\": 40.1, \"long\": -74.1},\n",
+    "]\n",
     "\n",
     "db.create_table(\"table2\", data)\n",
     "\n",
@@ -366,11 +363,11 @@
     "def make_batches():\n",
     "    for i in range(5):\n",
     "        yield pd.DataFrame(\n",
-    "        {\n",
-    "            \"vector\": [[3.1, 4.1], [1, 1]],\n",
-    "            \"item\": [\"foo\", \"bar\"],\n",
-    "            \"price\": [10.0, 20.0],\n",
-    "        })\n",
+    "            {\n",
+    "                \"vector\": [[3.1, 4.1], [1, 1]],\n",
+    "                \"item\": [\"foo\", \"bar\"],\n",
+    "                \"price\": [10.0, 20.0],\n",
+    "            })\n",
     "\n",
     "tbl = db.create_table(\"table5\", make_batches(), schema=PydanticSchema)\n",
     "tbl.schema"
@@ -572,9 +569,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "df = pd.DataFrame([{\"vector\": [1.3, 1.4], \"item\": \"fizz\", \"price\": 100.0},\n",
-    "                   {\"vector\": [9.5, 56.2], \"item\": \"buzz\", \"price\": 200.0}])\n",
-    "tbl.add(df)"
+    "data = [\n",
+    "    {\"vector\": [1.3, 1.4], \"item\": \"fizz\", \"price\": 100.0},\n",
+    "    {\"vector\": [9.5, 56.2], \"item\": \"buzz\", \"price\": 200.0}\n",
+    "]\n",
+    "tbl.add(data)"
   ]
  },
@@ -596,17 +595,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
-    "import pandas as pd\n",
-    "\n",
     "def make_batches():\n",
     "    for i in range(5):\n",
-    "        yield pd.DataFrame(\n",
-    "        {\n",
-    "            \"vector\": [[3.1, 4.1], [1, 1]],\n",
-    "            \"item\": [\"foo\", \"bar\"],\n",
-    "            \"price\": [10.0, 20.0],\n",
-    "        })\n",
+    "        yield [\n",
+    "            {\"vector\": [3.1, 4.1], \"item\": \"foo\", \"price\": 10.0},\n",
+    "            {\"vector\": [1, 1], \"item\": \"bar\", \"price\": 20.0},\n",
+    "        ]\n",
     "tbl.add(make_batches())"
   ]
  },
@@ -39,7 +39,6 @@ to lazily generate data:
 
 from typing import Iterable
 import pyarrow as pa
-import lancedb
 
 def make_batches() -> Iterable[pa.RecordBatch]:
     for i in range(5):
@@ -11,15 +11,13 @@ pip install duckdb lancedb
 We will re-use [the dataset created previously](./arrow.md):
 
 ```python
-import pandas as pd
 import lancedb
 
 db = lancedb.connect("data/sample-lancedb")
-data = pd.DataFrame({
-    "vector": [[3.1, 4.1], [5.9, 26.5]],
-    "item": ["foo", "bar"],
-    "price": [10.0, 20.0]
-})
+data = [
+    {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+    {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
+]
 table = db.create_table("pd_table", data=data)
 arrow_table = table.to_arrow()
 ```
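The page this hunk belongs to goes on to query the Arrow table with DuckDB; for context, that step looks roughly like this (DuckDB's replacement scans resolve `arrow_table` by variable name):

```python
import duckdb

# DuckDB can query the pyarrow Table produced by table.to_arrow() directly
print(duckdb.query("SELECT item, price FROM arrow_table").to_df())
```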
@@ -26,17 +26,17 @@ pip install lancedb
 
 ## Embeddings
 
-::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
+::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
 
-::: lancedb.embeddings.functions.EmbeddingFunction
+::: lancedb.embeddings.base.EmbeddingFunction
 
-::: lancedb.embeddings.functions.TextEmbeddingFunction
+::: lancedb.embeddings.base.TextEmbeddingFunction
 
-::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
+::: lancedb.embeddings.sentence_transformers.SentenceTransformerEmbeddings
 
-::: lancedb.embeddings.functions.OpenAIEmbeddings
+::: lancedb.embeddings.openai.OpenAIEmbeddings
 
-::: lancedb.embeddings.functions.OpenClipEmbeddings
+::: lancedb.embeddings.open_clip.OpenClipEmbeddings
 
 ::: lancedb.embeddings.with_embeddings
 
4  docs/src/scripts/posthog.js  (new file)
@@ -0,0 +1,4 @@
window.addEventListener("DOMContentLoaded", (event) => {
  !function(t,e){var o,n,p,r;e.__SV||(window.posthog=e,e._i=[],e.init=function(i,s,a){function g(t,e){var o=e.split(".");2==o.length&&(t=t[o[0]],e=o[1]),t[e]=function(){t.push([e].concat(Array.prototype.slice.call(arguments,0)))}}(p=t.createElement("script")).type="text/javascript",p.async=!0,p.src=s.api_host+"/static/array.js",(r=t.getElementsByTagName("script")[0]).parentNode.insertBefore(p,r);var u=e;for(void 0!==a?u=e[a]=[]:a="posthog",u.people=u.people||[],u.toString=function(t){var e="posthog";return"posthog"!==a&&(e+="."+a),t||(e+=" (stub)"),e},u.people.toString=function(){return u.toString(1)+".people (stub)"},o="capture identify alias people.set people.set_once set_config register register_once unregister opt_out_capturing has_opted_out_capturing opt_in_capturing reset isFeatureEnabled onFeatureFlags getFeatureFlag getFeatureFlagPayload reloadFeatureFlags group updateEarlyAccessFeatureEnrollment getEarlyAccessFeatures getActiveMatchingSurveys getSurveys".split(" "),n=0;n<o.length;n++)g(u,o[n]);e._i.push([i,s,a])},e.__SV=1)}(document,window.posthog||[]);
  posthog.init('phc_oENDjGgHtmIDrV6puUiFem2RB4JA8gGWulfdulmMdZP',{api_host:'https://app.posthog.com'})
});
@@ -4,7 +4,7 @@
 In a recommendation system or search engine, you can find similar products from
 the one you searched.
 In LLM and other AI applications,
-each data point can be [presented by the embeddings generated from some models](embedding.md),
+each data point can be [presented by the embeddings generated from some models](embeddings/index.md),
 it returns the most relevant features.
 
 A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.
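As a reminder of what the context lines describe, brute-force KNN (what an ANN index approximates) is just a distance sort; a small numpy sketch:

```python
import numpy as np

def knn(query: np.ndarray, vectors: np.ndarray, k: int) -> np.ndarray:
    """Return indices of the k nearest rows of `vectors` to `query` (L2)."""
    dists = np.linalg.norm(vectors - query, axis=1)  # distance to each row
    return np.argsort(dists)[:k]                     # indices of k nearest

# toy usage
vecs = np.array([[3.1, 4.1], [5.9, 26.5], [0.0, 0.1]])
print(knn(np.array([3.0, 4.0]), vecs, k=2))
```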
@@ -8,6 +8,7 @@ const excludedGlobs = [
   "../src/embedding.md",
   "../src/examples/*.md",
   "../src/guides/tables.md",
+  "../src/embeddings/*.md",
 ];
 
 const nodePrefix = "javascript";
@@ -10,6 +10,7 @@ excluded_globs = [
     "../src/integrations/voxel51.md",
     "../src/guides/tables.md",
     "../src/python/duckdb.md",
+    "../src/embeddings/*.md",
 ]
 
 python_prefix = "py"
74  node/package-lock.json  (generated)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.3.0",
+  "version": "0.3.3",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.3.0",
+      "version": "0.3.3",
       "cpu": [
         "x64",
         "arm64"
@@ -53,11 +53,11 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.3.0",
-        "@lancedb/vectordb-darwin-x64": "0.3.0",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.3.0",
-        "@lancedb/vectordb-linux-x64-gnu": "0.3.0",
-        "@lancedb/vectordb-win32-x64-msvc": "0.3.0"
+        "@lancedb/vectordb-darwin-arm64": "0.3.3",
+        "@lancedb/vectordb-darwin-x64": "0.3.3",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.3.3",
+        "@lancedb/vectordb-linux-x64-gnu": "0.3.3",
+        "@lancedb/vectordb-win32-x64-msvc": "0.3.3"
       }
     },
     "node_modules/@apache-arrow/ts": {
@@ -317,9 +317,9 @@
       }
     },
     "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.0.tgz",
-      "integrity": "sha512-Fg+k/cSnqmNQlSWyDp0PpaAJ67kAISfZAD+zZ3mcE8/3ml2I/wM/GVjPy2zeiQX9aR93lG1mZXFSNTDUc74tWQ==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.3.tgz",
+      "integrity": "sha512-nvyj7xNX2/wb/PH5TjyhLR/NQ1jVuoBw2B5UaSg7qf8Tnm5SSXWQ7F25RVKcKwh72fz1qB+CWW24ftZnRzbT/Q==",
       "cpu": [
         "arm64"
       ],
@@ -329,9 +329,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.0.tgz",
-      "integrity": "sha512-CXp4b/brMbnBPZuGzKIOskd9uD90R73rWubaJ0du/Kt6fcyQX1dM1wEhWTLxI6eKf8IDL/R9QLL2cIahm1J86w==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.3.tgz",
+      "integrity": "sha512-7CW+nILyPHp6cua0Rl0xaTDWw/vajEn/jCsEjFYgDmE+rtf5Z5Fum41FxR9C2TtIAvUK+nWb5mkYeOLqU6vRvg==",
       "cpu": [
         "x64"
       ],
@@ -341,9 +341,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.0.tgz",
-      "integrity": "sha512-1bjaRzYcDsWIRUbO2K/f+ohNmNvCgKcrrOhmiXSHVlYY8kH1LUMFZj+BhqBC0Ea0Stt7/1rsRLMRXRtaeVOEHw==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.3.tgz",
+      "integrity": "sha512-MmhwbacKxZPkLwwOqysVY8mUb8lFoyFIPlYhSLV4xS1C8X4HWALljIul1qMl1RYudp9Uc3PsOzRexl+OvCGfUw==",
       "cpu": [
         "arm64"
      ],
@@ -353,9 +353,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.0.tgz",
-      "integrity": "sha512-BEDIJ6ReGAi+tLTS/RzxIw621yo1UUUiVNTzPGV2didyiJCr1chIGbES+39d/wiFQM43Xs3CBZLNzp+jKkv0/w==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.3.tgz",
+      "integrity": "sha512-OrNlsKi/QPw59Po040oRKn8IuqFEk4upc/4FaFKqVkcmQjjZrMg5Kgy9ZfWIhHdAnWXXggZZIPArpt0X1B0ceA==",
       "cpu": [
         "x64"
       ],
@@ -365,9 +365,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.0.tgz",
-      "integrity": "sha512-7K2kbWbShuifQF/6L/tWSz2DhKfIreHKlBdVOuBTYYOReQMHn5cJxgwuFgQHqMubZ9zcagtHpmo+Wtqd034OKQ==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.3.tgz",
+      "integrity": "sha512-lIT0A7a6eqX51IfGyhECtpXXgsr//kgbd+HZbcCdPy2GMmNezSch/7V22zExDSpF32hX8WfgcTLYCVWVilggDQ==",
       "cpu": [
         "x64"
       ],
@@ -4869,33 +4869,33 @@
       }
     },
     "@lancedb/vectordb-darwin-arm64": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.0.tgz",
-      "integrity": "sha512-Fg+k/cSnqmNQlSWyDp0PpaAJ67kAISfZAD+zZ3mcE8/3ml2I/wM/GVjPy2zeiQX9aR93lG1mZXFSNTDUc74tWQ==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.3.tgz",
+      "integrity": "sha512-nvyj7xNX2/wb/PH5TjyhLR/NQ1jVuoBw2B5UaSg7qf8Tnm5SSXWQ7F25RVKcKwh72fz1qB+CWW24ftZnRzbT/Q==",
       "optional": true
     },
     "@lancedb/vectordb-darwin-x64": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.0.tgz",
-      "integrity": "sha512-CXp4b/brMbnBPZuGzKIOskd9uD90R73rWubaJ0du/Kt6fcyQX1dM1wEhWTLxI6eKf8IDL/R9QLL2cIahm1J86w==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.3.tgz",
+      "integrity": "sha512-7CW+nILyPHp6cua0Rl0xaTDWw/vajEn/jCsEjFYgDmE+rtf5Z5Fum41FxR9C2TtIAvUK+nWb5mkYeOLqU6vRvg==",
       "optional": true
     },
     "@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.0.tgz",
-      "integrity": "sha512-1bjaRzYcDsWIRUbO2K/f+ohNmNvCgKcrrOhmiXSHVlYY8kH1LUMFZj+BhqBC0Ea0Stt7/1rsRLMRXRtaeVOEHw==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.3.tgz",
+      "integrity": "sha512-MmhwbacKxZPkLwwOqysVY8mUb8lFoyFIPlYhSLV4xS1C8X4HWALljIul1qMl1RYudp9Uc3PsOzRexl+OvCGfUw==",
       "optional": true
     },
     "@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.0.tgz",
-      "integrity": "sha512-BEDIJ6ReGAi+tLTS/RzxIw621yo1UUUiVNTzPGV2didyiJCr1chIGbES+39d/wiFQM43Xs3CBZLNzp+jKkv0/w==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.3.tgz",
+      "integrity": "sha512-OrNlsKi/QPw59Po040oRKn8IuqFEk4upc/4FaFKqVkcmQjjZrMg5Kgy9ZfWIhHdAnWXXggZZIPArpt0X1B0ceA==",
      "optional": true
     },
     "@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.0.tgz",
-      "integrity": "sha512-7K2kbWbShuifQF/6L/tWSz2DhKfIreHKlBdVOuBTYYOReQMHn5cJxgwuFgQHqMubZ9zcagtHpmo+Wtqd034OKQ==",
+      "version": "0.3.3",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.3.tgz",
+      "integrity": "sha512-lIT0A7a6eqX51IfGyhECtpXXgsr//kgbd+HZbcCdPy2GMmNezSch/7V22zExDSpF32hX8WfgcTLYCVWVilggDQ==",
       "optional": true
     },
     "@neon-rs/cli": {
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.3.0",
+  "version": "0.3.4",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -81,10 +81,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.3.0",
-    "@lancedb/vectordb-darwin-x64": "0.3.0",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.3.0",
-    "@lancedb/vectordb-linux-x64-gnu": "0.3.0",
-    "@lancedb/vectordb-win32-x64-msvc": "0.3.0"
+    "@lancedb/vectordb-darwin-arm64": "0.3.4",
+    "@lancedb/vectordb-darwin-x64": "0.3.4",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.3.4",
+    "@lancedb/vectordb-linux-x64-gnu": "0.3.4",
+    "@lancedb/vectordb-win32-x64-msvc": "0.3.4"
   }
 }
@@ -20,7 +20,7 @@ import {
   Utf8,
   type Vector,
   FixedSizeList,
-  vectorFromArray, type Schema, Table as ArrowTable
+  vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter
 } from 'apache-arrow'
 import { type EmbeddingFunction } from './index'
@@ -77,7 +77,9 @@ function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
 
 // Creates the Arrow Type for a Vector column with dimension `dim`
 function newVectorType (dim: number): FixedSizeList<Float32> {
-  const children = new Field<Float32>('item', new Float32())
+  // Somewhere we always default to have the elements nullable, so we need to set it to true
+  // otherwise we often get schema mismatches because the stored data always has schema with nullable elements
+  const children = new Field<Float32>('item', new Float32(), true)
   return new FixedSizeList(dim, children)
 }
@@ -88,6 +90,13 @@ export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown
   return Buffer.from(await writer.toUint8Array())
 }
 
+// Converts an Array of records into Arrow IPC stream format
+export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
+  const table = await convertToTable(data, embeddings)
+  const writer = RecordBatchStreamWriter.writeAll(table)
+  return Buffer.from(await writer.toUint8Array())
+}
+
 // Converts an Arrow Table into Arrow IPC format
 export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
   if (embeddings !== undefined) {
@@ -105,6 +114,23 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
   return Buffer.from(await writer.toUint8Array())
 }
 
+// Converts an Arrow Table into Arrow IPC stream format
+export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
+  if (embeddings !== undefined) {
+    const source = table.getChild(embeddings.sourceColumn)
+
+    if (source === null) {
+      throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
+    }
+
+    const vectors = await embeddings.embed(source.toArray() as T[])
+    const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
+    table = table.assign(new ArrowTable({ vector: column }))
+  }
+  const writer = RecordBatchStreamWriter.writeAll(table)
+  return Buffer.from(await writer.toUint8Array())
+}
+
 // Creates an empty Arrow Table
 export function createEmptyTable (schema: Schema): ArrowTable {
   return new ArrowTable(schema)
@@ -23,7 +23,7 @@ import { Query } from './query'
 import { isEmbeddingFunction } from './embedding/embedding_function'
 
 // eslint-disable-next-line @typescript-eslint/no-var-requires
-const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles } = require('../native.js')
+const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
 
 export { Query }
 export type { EmbeddingFunction }
@@ -260,6 +260,27 @@ export interface Table<T = number[]> {
    * ```
    */
   delete: (filter: string) => Promise<void>
+
+  /**
+   * List the indices on this table.
+   */
+  listIndices: () => Promise<VectorIndex[]>
+
+  /**
+   * Get statistics about an index.
+   */
+  indexStats: (indexUuid: string) => Promise<IndexStats>
 }
+
+export interface VectorIndex {
+  columns: string[]
+  name: string
+  uuid: string
+}
+
+export interface IndexStats {
+  numIndexedRows: number | null
+  numUnindexedRows: number | null
+}
 
 /**
@@ -502,6 +523,14 @@ export class LocalTable<T = number[]> implements Table<T> {
       return res.metrics
     })
   }
+
+  async listIndices (): Promise<VectorIndex[]> {
+    return tableListIndices.call(this._tbl)
+  }
+
+  async indexStats (indexUuid: string): Promise<IndexStats> {
+    return tableIndexStats.call(this._tbl, indexUuid)
+  }
 }
 
 export interface CleanupStats {
@@ -65,8 +65,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
     const mirroredPath = path.join(dir, `${tableName}.lance`)
     fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
       if (err != null) throw err
-      // there should be two dirs
-      assert.equal(files.length, 2)
+      // there should be three dirs
+      assert.equal(files.length, 3)
       assert.isTrue(files[0].isDirectory())
       assert.isTrue(files[1].isDirectory())
@@ -76,6 +76,12 @@ describe('LanceDB Mirrored Store Integration test', function () {
         assert.isTrue(files[0].name.endsWith('.txn'))
       })
 
+      fs.readdir(path.join(mirroredPath, '_versions'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].name.endsWith('.manifest'))
+      })
+
       fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
         if (err != null) throw err
         assert.equal(files.length, 1)
@@ -88,8 +94,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
 
     fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
       if (err != null) throw err
-      // there should be two dirs
-      assert.equal(files.length, 3)
+      // there should be four dirs
+      assert.equal(files.length, 4)
       assert.isTrue(files[0].isDirectory())
       assert.isTrue(files[1].isDirectory())
       assert.isTrue(files[2].isDirectory())
@@ -128,12 +134,13 @@ describe('LanceDB Mirrored Store Integration test', function () {
 
     fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
       if (err != null) throw err
-      // there should be two dirs
-      assert.equal(files.length, 4)
+      // there should be five dirs
+      assert.equal(files.length, 5)
       assert.isTrue(files[0].isDirectory())
       assert.isTrue(files[1].isDirectory())
       assert.isTrue(files[2].isDirectory())
       assert.isTrue(files[3].isDirectory())
+      assert.isTrue(files[4].isDirectory())
 
       // Three TXs now
       fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
@@ -108,13 +108,18 @@ export class HttpLancedbClient {
   /**
    * Sent POST request.
    */
-  public async post (path: string, data?: any, params?: Record<string, string | number>): Promise<AxiosResponse> {
+  public async post (
+    path: string,
+    data?: any,
+    params?: Record<string, string | number>,
+    content?: string | undefined
+  ): Promise<AxiosResponse> {
     const response = await axios.post(
       `${this._url}${path}`,
       data,
       {
         headers: {
-          'Content-Type': 'application/json',
+          'Content-Type': content ?? 'application/json',
           'x-api-key': this._apiKey(),
           ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
         },
|
||||
|
||||
import {
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions, type WriteOptions
|
||||
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||
type WriteOptions,
|
||||
type IndexStats
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
import { Vector } from 'apache-arrow'
|
||||
import { Vector, Table as ArrowTable } from 'apache-arrow'
|
||||
import { HttpLancedbClient } from './client'
|
||||
import { isEmbeddingFunction } from '../embedding/embedding_function'
|
||||
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
|
||||
|
||||
/**
 * Remote connection.
@@ -66,8 +70,60 @@ export class RemoteConnection implements Connection {
    }
  }

  async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
    throw new Error('Not implemented')
  async createTable<T> (nameOrOpts: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
    // Logic copied from LocalConnection; refactor these to a base class + connectionImpl pattern
    let schema
    let embeddings: undefined | EmbeddingFunction<T>
    let tableName: string
    if (typeof nameOrOpts === 'string') {
      if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
        embeddings = optsOrEmbedding
      }
      tableName = nameOrOpts
    } else {
      schema = nameOrOpts.schema
      embeddings = nameOrOpts.embeddingFunction
      tableName = nameOrOpts.name
    }

    let buffer: Buffer

    function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
      if (data instanceof ArrowTable) {
        return data.data.length === 0
      }
      return data.length === 0
    }

    if ((data === undefined) || isEmpty(data)) {
      if (schema === undefined) {
        throw new Error('Either data or schema needs to be defined')
      }
      buffer = await fromTableToStreamBuffer(createEmptyTable(schema))
    } else if (data instanceof ArrowTable) {
      buffer = await fromTableToStreamBuffer(data, embeddings)
    } else {
      // data is Array<Record<...>>
      buffer = await fromRecordsToStreamBuffer(data, embeddings)
    }

    const res = await this._client.post(
      `/v1/table/${tableName}/create/`,
      buffer,
      undefined,
      'application/vnd.apache.arrow.stream'
    )
    if (res.status !== 200) {
      throw new Error(`Server Error, status: ${res.status}, ` +
        // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
        `message: ${res.statusText}: ${res.data}`)
    }

    if (embeddings === undefined) {
      return new RemoteTable(this._client, tableName)
    } else {
      return new RemoteTable(this._client, tableName, embeddings)
    }
  }

  async dropTable (name: string): Promise<void> {
@@ -141,11 +197,39 @@ export class RemoteTable<T = number[]> implements Table<T> {
  }

  async add (data: Array<Record<string, unknown>>): Promise<number> {
    throw new Error('Not implemented')
    const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
    const res = await this._client.post(
      `/v1/table/${this._name}/insert/`,
      buffer,
      {
        mode: 'append'
      },
      'application/vnd.apache.arrow.stream'
    )
    if (res.status !== 200) {
      throw new Error(`Server Error, status: ${res.status}, ` +
        // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
        `message: ${res.statusText}: ${res.data}`)
    }
    return data.length
  }

  async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
    throw new Error('Not implemented')
    const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
    const res = await this._client.post(
      `/v1/table/${this._name}/insert/`,
      buffer,
      {
        mode: 'overwrite'
      },
      'application/vnd.apache.arrow.stream'
    )
    if (res.status !== 200) {
      throw new Error(`Server Error, status: ${res.status}, ` +
        // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
        `message: ${res.statusText}: ${res.data}`)
    }
    return data.length
  }
  async createIndex (indexParams: VectorIndexParams): Promise<any> {
@@ -157,6 +241,23 @@ export class RemoteTable<T = number[]> implements Table<T> {
  }

  async delete (filter: string): Promise<void> {
    throw new Error('Not implemented')
    await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
  }

  async listIndices (): Promise<VectorIndex[]> {
    const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
    return results.data.indexes?.map((index: any) => ({
      columns: index.columns,
      name: index.index_name,
      uuid: index.index_uuid
    }))
  }

  async indexStats (indexUuid: string): Promise<IndexStats> {
    const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
    return {
      numIndexedRows: results.data.num_indexed_rows,
      numUnindexedRows: results.data.num_unindexed_rows
    }
  }
}
@@ -328,6 +328,24 @@ describe('LanceDB client', function () {
    const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
    await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
  })

  it('should be able to list index and stats', async function () {
    const uri = await createTestDB(32, 300)
    const con = await lancedb.connect(uri)
    const table = await con.openTable('vectors')
    await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })

    const indices = await table.listIndices()
    expect(indices).to.have.lengthOf(1)
    expect(indices[0].name).to.equal('vector_idx')
    expect(indices[0].uuid).to.not.be.equal(undefined)
    expect(indices[0].columns).to.have.lengthOf(1)
    expect(indices[0].columns[0]).to.equal('vector')

    const stats = await table.indexStats(indices[0].uuid)
    expect(stats.numIndexedRows).to.equal(300)
    expect(stats.numUnindexedRows).to.equal(0)
  }).timeout(50_000)
})

describe('when using a custom embedding function', function () {
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.1
current_version = 0.3.2
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True
1
python/LICENSE
Symbolic link
@@ -0,0 +1 @@
../LICENSE
@@ -11,15 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .cohere import CohereEmbeddingFunction
from .functions import (
    EmbeddingFunction,
    EmbeddingFunctionConfig,
    EmbeddingFunctionRegistry,
    OpenAIEmbeddings,
    OpenClipEmbeddings,
    SentenceTransformerEmbeddings,
    TextEmbeddingFunction,
)
from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry
from .sentence_transformers import SentenceTransformerEmbeddings
from .utils import with_embeddings
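
To make the refactor concrete, here is a minimal sketch of the new import surface (illustrative only; it assumes an environment with this version of lancedb installed):

from lancedb.embeddings import EmbeddingFunctionRegistry, get_registry

registry = get_registry()
# get_registry() returns the same global singleton as the classmethod accessor
assert registry is EmbeddingFunctionRegistry.get_instance()
# Resolve a built-in function class by its registered alias, then instantiate it
func = registry.get("sentence-transformers").create(name="all-MiniLM-L6-v2", device="cpu")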

138
python/lancedb/embeddings/base.py
Normal file
@@ -0,0 +1,138 @@
import importlib
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr

from .utils import TEXT


class EmbeddingFunction(BaseModel, ABC):
    """
    An ABC for embedding functions.

    All concrete embedding functions must implement the following:
    1. compute_query_embeddings() which takes a query and returns a list of embeddings
    2. compute_source_embeddings() which returns a list of embeddings for the source column.
       For text data, the two will be the same. For multi-modal data, the source column
       might be images and the vector column might be text.
    3. ndims method which returns the number of dimensions of the vector column
    """

    _ndims: int = PrivateAttr()

    @classmethod
    def create(cls, **kwargs):
        """
        Create an instance of the embedding function
        """
        return cls(**kwargs)

    @abstractmethod
    def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for a given user query
        """
        pass

    @abstractmethod
    def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for the source column in the database
        """
        pass

    def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, pa.Array):
            texts = texts.to_pylist()
        elif isinstance(texts, pa.ChunkedArray):
            texts = texts.combine_chunks().to_pylist()
        return texts

    @classmethod
    def safe_import(cls, module: str, mitigation=None):
        """
        Import the specified module. If the module is not installed,
        raise an ImportError with a helpful message.

        Parameters
        ----------
        module : str
            The name of the module to import
        mitigation : Optional[str]
            The package(s) to install to mitigate the error.
            If not provided then the module name will be used.
        """
        try:
            return importlib.import_module(module)
        except ImportError:
            raise ImportError(f"Please install {mitigation or module}")

    def safe_model_dump(self):
        from ..pydantic import PYDANTIC_VERSION

        if PYDANTIC_VERSION.major < 2:
            return dict(self)
        return self.model_dump()

    @abstractmethod
    def ndims(self):
        """
        Return the dimensions of the vector column
        """
        pass

    def SourceField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the source column for this embedding function
        """
        return Field(json_schema_extra={"source_column_for": self}, **kwargs)

    def VectorField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the target vector column for this embedding function
        """
        return Field(json_schema_extra={"vector_column_for": self}, **kwargs)


class EmbeddingFunctionConfig(BaseModel):
    """
    This model encapsulates the configuration for an embedding function
    in a lancedb table. It holds the embedding function, the source column,
    and the vector column
    """

    vector_column: str
    source_column: str
    function: EmbeddingFunction


class TextEmbeddingFunction(EmbeddingFunction):
    """
    A callable ABC for embedding functions that take text as input
    """

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        return self.compute_source_embeddings(query, *args, **kwargs)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    @abstractmethod
    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Generate the embeddings for the given texts
        """
        pass
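
As a sketch of how this ABC is meant to be extended (the class name, alias, and hashing scheme below are illustrative, not part of the library):

from typing import List, Union

import numpy as np

from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register


@register("toy-hash")  # illustrative alias, not a function shipped with lancedb
class ToyHashEmbeddings(TextEmbeddingFunction):
    """Deterministic toy embeddings; a real subclass would call a model here."""

    dims: int = 8  # pydantic field, so it is serialized into table metadata

    def ndims(self):
        return self.dims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        # One stable pseudo-embedding per input text
        return [
            np.array([((hash(t) >> i) % 1000) / 1000.0 for i in range(self.dims)])
            for t in texts
        ]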

@@ -16,7 +16,8 @@ from typing import ClassVar, List, Union

import numpy as np

from .functions import TextEmbeddingFunction, register
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
@@ -1,578 +0,0 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import importlib
import io
import json
import os
import socket
import urllib.error
import urllib.parse as urlparse
import urllib.request
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union

import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm


class EmbeddingFunctionRegistry:
    """
    This is a singleton class used to register embedding functions
    and fetch them by name. It also handles serializing and deserializing.
    You can implement your own embedding function by subclassing EmbeddingFunction
    or TextEmbeddingFunction and registering it with the registry.

    Examples
    --------
    >>> registry = EmbeddingFunctionRegistry.get_instance()
    >>> @registry.register("my-embedding-function")
    ... class MyEmbeddingFunction(EmbeddingFunction):
    ...     def ndims(self) -> int:
    ...         return 128
    ...
    ...     def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
    ...         return self.compute_source_embeddings(query, *args, **kwargs)
    ...
    ...     def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
    ...         return [np.random.rand(self.ndims()) for _ in range(len(texts))]
    ...
    >>> registry.get("my-embedding-function")
    <class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
    """

    @classmethod
    def get_instance(cls):
        return __REGISTRY__

    def __init__(self):
        self._functions = {}

    def register(self, alias: str = None):
        """
        This creates a decorator that can be used to register
        an EmbeddingFunction.

        Parameters
        ----------
        alias : Optional[str]
            a human friendly name for the embedding function. If not
            provided, the class name will be used.
        """

        # This is a decorator for a class that inherits from BaseModel
        # It adds the class to the registry
        def decorator(cls):
            if not issubclass(cls, EmbeddingFunction):
                raise TypeError("Must be a subclass of EmbeddingFunction")
            if cls.__name__ in self._functions:
                raise KeyError(f"{cls.__name__} was already registered")
            key = alias or cls.__name__
            self._functions[key] = cls
            cls.__embedding_function_registry_alias__ = alias
            return cls

        return decorator

    def reset(self):
        """
        Reset the registry to its initial state
        """
        self._functions = {}

    def get(self, name: str):
        """
        Fetch an embedding function class by name

        Parameters
        ----------
        name : str
            The name of the embedding function to fetch
            Either the alias or the class name if no alias was provided
            during registration
        """
        return self._functions[name]

    def parse_functions(
        self, metadata: Optional[Dict[bytes, bytes]]
    ) -> Dict[str, "EmbeddingFunctionConfig"]:
        """
        Parse the metadata from an arrow table and
        return a mapping of the vector column to the
        embedding function and source column

        Parameters
        ----------
        metadata : Optional[Dict[bytes, bytes]]
            The metadata from an arrow table. Note that
            the keys and values are bytes (pyarrow api)

        Returns
        -------
        functions : dict
            A mapping of vector column name to embedding function.
            An empty dict is returned if input is None or does not
            contain b"embedding_functions".
        """
        if metadata is None or b"embedding_functions" not in metadata:
            return {}
        serialized = metadata[b"embedding_functions"]
        raw_list = json.loads(serialized.decode("utf-8"))
        return {
            obj["vector_column"]: EmbeddingFunctionConfig(
                vector_column=obj["vector_column"],
                source_column=obj["source_column"],
                function=self.get(obj["name"])(**obj["model"]),
            )
            for obj in raw_list
        }

    def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
        """
        Convert the given embedding function and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata
        """
        func = conf.function
        name = getattr(
            func, "__embedding_function_registry_alias__", func.__class__.__name__
        )
        json_data = func.safe_model_dump()
        return {
            "name": name,
            "model": json_data,
            "source_column": conf.source_column,
            "vector_column": conf.vector_column,
        }

    def get_table_metadata(self, func_list):
        """
        Convert a list of embedding functions and source / vector configs
        into a config dictionary that can be serialized into arrow metadata
        """
        if func_list is None or len(func_list) == 0:
            return None
        json_data = [self.function_to_metadata(func) for func in func_list]
        # Note that metadata dictionary values must be bytes
        # so we need to json dump then utf8 encode
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}


# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
    str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]


class EmbeddingFunction(BaseModel, ABC):
    """
    An ABC for embedding functions.

    All concrete embedding functions must implement the following:
    1. compute_query_embeddings() which takes a query and returns a list of embeddings
    2. compute_source_embeddings() which returns a list of embeddings for the source column.
       For text data, the two will be the same. For multi-modal data, the source column
       might be images and the vector column might be text.
    3. ndims method which returns the number of dimensions of the vector column
    """

    _ndims: int = PrivateAttr()

    @classmethod
    def create(cls, **kwargs):
        """
        Create an instance of the embedding function
        """
        return cls(**kwargs)

    @abstractmethod
    def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for a given user query
        """
        pass

    @abstractmethod
    def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for the source column in the database
        """
        pass

    def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, pa.Array):
            texts = texts.to_pylist()
        elif isinstance(texts, pa.ChunkedArray):
            texts = texts.combine_chunks().to_pylist()
        return texts

    @classmethod
    def safe_import(cls, module: str, mitigation=None):
        """
        Import the specified module. If the module is not installed,
        raise an ImportError with a helpful message.

        Parameters
        ----------
        module : str
            The name of the module to import
        mitigation : Optional[str]
            The package(s) to install to mitigate the error.
            If not provided then the module name will be used.
        """
        try:
            return importlib.import_module(module)
        except ImportError:
            raise ImportError(f"Please install {mitigation or module}")

    def safe_model_dump(self):
        from ..pydantic import PYDANTIC_VERSION

        if PYDANTIC_VERSION.major < 2:
            return dict(self)
        return self.model_dump()

    @abstractmethod
    def ndims(self):
        """
        Return the dimensions of the vector column
        """
        pass

    def SourceField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the source column for this embedding function
        """
        return Field(json_schema_extra={"source_column_for": self}, **kwargs)

    def VectorField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the target vector column for this embedding function
        """
        return Field(json_schema_extra={"vector_column_for": self}, **kwargs)


class EmbeddingFunctionConfig(BaseModel):
    """
    This model encapsulates the configuration for an embedding function
    in a lancedb table. It holds the embedding function, the source column,
    and the vector column
    """

    vector_column: str
    source_column: str
    function: EmbeddingFunction


class TextEmbeddingFunction(EmbeddingFunction):
    """
    A callable ABC for embedding functions that take text as input
    """

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        return self.compute_source_embeddings(query, *args, **kwargs)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    @abstractmethod
    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Generate the embeddings for the given texts
        """
        pass


# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
return self.embedding_model.encode(
|
||||
list(texts),
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the model to load
|
||||
device : str
|
||||
The device to load the model on
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = cls.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
|
||||
|
||||
@register("openai")
|
||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenAI API
|
||||
|
||||
https://platform.openai.com/docs/guides/embeddings
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
openai = self.safe_import("openai")
|
||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||
return [v["embedding"] for v in rs]
|
||||
|
||||
|
||||
@register("open-clip")
|
||||
class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenClip API
|
||||
For multi-modal text-to-image search
|
||||
|
||||
https://github.com/mlfoundations/open_clip
|
||||
"""
|
||||
|
||||
name: str = "ViT-B-32"
|
||||
pretrained: str = "laion2b_s34b_b79k"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
_model = PrivateAttr()
|
||||
_preprocess = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
self._ndims = None
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
with torch.no_grad():
|
||||
text_features = self._model.encode_text(text.to(self.device))
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given images
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in tqdm(futures)]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
URL to download from
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url) as conn:
|
||||
return conn.read()
|
||||
except (socket.gaierror, urllib.error.URLError) as err:
|
||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||
163
python/lancedb/embeddings/open_clip.py
Normal file
@@ -0,0 +1,163 @@
import concurrent.futures
import io
import os
import urllib.parse as urlparse
from typing import List, Union

import numpy as np
import pyarrow as pa
from pydantic import PrivateAttr
from tqdm import tqdm

from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve


@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the OpenClip API
    for multi-modal text-to-image search

    https://github.com/mlfoundations/open_clip
    """

    name: str = "ViT-B-32"
    pretrained: str = "laion2b_s34b_b79k"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True
    _model = PrivateAttr()
    _preprocess = PrivateAttr()
    _tokenizer = PrivateAttr()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        open_clip = self.safe_import("open_clip", "open-clip")
        model, _, preprocess = open_clip.create_model_and_transforms(
            self.name, pretrained=self.pretrained
        )
        model.to(self.device)
        self._model, self._preprocess = model, preprocess
        self._tokenizer = open_clip.get_tokenizer(self.name)
        self._ndims = None

    def ndims(self):
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL = self.safe_import("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("OpenClip supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        torch = self.safe_import("torch")
        text = self.sanitize_input(text)
        text = self._tokenizer(text)
        text.to(self.device)
        with torch.no_grad():
            text_features = self._model.encode_text(text.to(self.device))
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.array]:
        """
        Get the embeddings for the given images
        """
        images = self.sanitize_input(images)
        embeddings = []
        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        """
        Issue concurrent requests to retrieve the image data
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [future.result() for future in tqdm(futures)]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.
        """
        torch = self.safe_import("torch")
        # TODO handle retry and errors for https
        image = self._to_pil(image)
        image = self._preprocess(image).unsqueeze(0)
        with torch.no_grad():
            return self._encode_and_normalize_image(image)

    def _to_pil(self, image: Union[str, bytes]):
        PIL = self.safe_import("PIL", "pillow")
        if isinstance(image, bytes):
            return PIL.Image.open(io.BytesIO(image))
        if isinstance(image, PIL.Image.Image):
            return image
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            # TODO handle drive letter on windows.
            if parsed.scheme == "file":
                return PIL.Image.open(parsed.path)
            elif parsed.scheme == "":
                return PIL.Image.open(image if os.name == "nt" else parsed.path)
            elif parsed.scheme.startswith("http"):
                return PIL.Image.open(io.BytesIO(url_retrieve(image)))
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
        """
        encode a single image tensor and optionally normalize the output
        """
        image_features = self._model.encode_image(image_tensor.to(self.device))
        if self.normalize:
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().squeeze()
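
For orientation, a minimal end-to-end sketch of using this class for text-to-image search (a sketch only: it assumes the torch/pillow/open-clip extras are installed, and the database path and image file below are hypothetical):

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

clip = get_registry().get("open-clip").create()

class Images(LanceModel):
    image_uri: str = clip.SourceField()               # embedded via compute_source_embeddings
    vector: Vector(clip.ndims()) = clip.VectorField()

db = lancedb.connect("/tmp/clip_demo")                # hypothetical path
table = db.create_table("images", schema=Images)
table.add([{"image_uri": "/data/cats/tabby.png"}])    # hypothetical local file
results = table.search("a photo of a cat").limit(1).to_pydantic(Images)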

37
python/lancedb/embeddings/openai.py
Normal file
@@ -0,0 +1,37 @@
from typing import List, Union

import numpy as np

from .base import TextEmbeddingFunction
from .registry import register


@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the OpenAI API

    https://platform.openai.com/docs/guides/embeddings
    """

    name: str = "text-embedding-ada-002"

    def ndims(self):
        # TODO don't hardcode this
        return 1536

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO retry, rate limit, token limit
        openai = self.safe_import("openai")
        rs = openai.Embedding.create(input=texts, model=self.name)["data"]
        return [v["embedding"] for v in rs]
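
Because ndims() is hardcoded for text-embedding-ada-002, a quick consistency check is possible; a sketch, assuming OPENAI_API_KEY is set and the pre-1.0 openai SDK that exposes openai.Embedding (as used above):

from lancedb.embeddings import get_registry

func = get_registry().get("openai").create()
[vec] = func.compute_query_embeddings("hello world")  # one query -> one embedding
assert len(vec) == func.ndims() == 1536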

186
python/lancedb/embeddings/registry.py
Normal file
@@ -0,0 +1,186 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, Optional

from .base import EmbeddingFunction, EmbeddingFunctionConfig


class EmbeddingFunctionRegistry:
    """
    This is a singleton class used to register embedding functions
    and fetch them by name. It also handles serializing and deserializing.
    You can implement your own embedding function by subclassing EmbeddingFunction
    or TextEmbeddingFunction and registering it with the registry.

    NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]

    Examples
    --------
    >>> registry = EmbeddingFunctionRegistry.get_instance()
    >>> @registry.register("my-embedding-function")
    ... class MyEmbeddingFunction(EmbeddingFunction):
    ...     def ndims(self) -> int:
    ...         return 128
    ...
    ...     def compute_query_embeddings(self, query: str, *args, **kwargs):
    ...         return self.compute_source_embeddings(query, *args, **kwargs)
    ...
    ...     def compute_source_embeddings(self, texts, *args, **kwargs):
    ...         return [np.random.rand(self.ndims()) for _ in range(len(texts))]
    ...
    >>> registry.get("my-embedding-function")
    <class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
    """

    @classmethod
    def get_instance(cls):
        return __REGISTRY__

    def __init__(self):
        self._functions = {}

    def register(self, alias: str = None):
        """
        This creates a decorator that can be used to register
        an EmbeddingFunction.

        Parameters
        ----------
        alias : Optional[str]
            a human friendly name for the embedding function. If not
            provided, the class name will be used.
        """

        # This is a decorator for a class that inherits from BaseModel
        # It adds the class to the registry
        def decorator(cls):
            if not issubclass(cls, EmbeddingFunction):
                raise TypeError("Must be a subclass of EmbeddingFunction")
            if cls.__name__ in self._functions:
                raise KeyError(f"{cls.__name__} was already registered")
            key = alias or cls.__name__
            self._functions[key] = cls
            cls.__embedding_function_registry_alias__ = alias
            return cls

        return decorator

    def reset(self):
        """
        Reset the registry to its initial state
        """
        self._functions = {}

    def get(self, name: str):
        """
        Fetch an embedding function class by name

        Parameters
        ----------
        name : str
            The name of the embedding function to fetch
            Either the alias or the class name if no alias was provided
            during registration
        """
        return self._functions[name]

    def parse_functions(
        self, metadata: Optional[Dict[bytes, bytes]]
    ) -> Dict[str, "EmbeddingFunctionConfig"]:
        """
        Parse the metadata from an arrow table and
        return a mapping of the vector column to the
        embedding function and source column

        Parameters
        ----------
        metadata : Optional[Dict[bytes, bytes]]
            The metadata from an arrow table. Note that
            the keys and values are bytes (pyarrow api)

        Returns
        -------
        functions : dict
            A mapping of vector column name to embedding function.
            An empty dict is returned if input is None or does not
            contain b"embedding_functions".
        """
        if metadata is None or b"embedding_functions" not in metadata:
            return {}
        serialized = metadata[b"embedding_functions"]
        raw_list = json.loads(serialized.decode("utf-8"))
        return {
            obj["vector_column"]: EmbeddingFunctionConfig(
                vector_column=obj["vector_column"],
                source_column=obj["source_column"],
                function=self.get(obj["name"])(**obj["model"]),
            )
            for obj in raw_list
        }

    def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
        """
        Convert the given embedding function and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata
        """
        func = conf.function
        name = getattr(
            func, "__embedding_function_registry_alias__", func.__class__.__name__
        )
        json_data = func.safe_model_dump()
        return {
            "name": name,
            "model": json_data,
            "source_column": conf.source_column,
            "vector_column": conf.vector_column,
        }

    def get_table_metadata(self, func_list):
        """
        Convert a list of embedding functions and source / vector configs
        into a config dictionary that can be serialized into arrow metadata
        """
        if func_list is None or len(func_list) == 0:
            return None
        json_data = [self.function_to_metadata(func) for func in func_list]
        # Note that metadata dictionary values must be bytes
        # so we need to json dump then utf8 encode
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}


# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()


# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)


def get_registry():
    """
    Utility function to get the global instance of the registry

    Returns
    -------
    EmbeddingFunctionRegistry
        The global registry instance

    Examples
    --------
    from lancedb.embeddings import get_registry

    registry = get_registry()
    openai = registry.get("openai").create()
    """
    return __REGISTRY__.get_instance()
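
The serialization half of the registry can be exercised without a table. A minimal round-trip sketch (it assumes the "sentence-transformers" alias is registered, which the package __init__ does on import):

from lancedb.embeddings import EmbeddingFunctionConfig, get_registry

registry = get_registry()
func = registry.get("sentence-transformers").create()
conf = EmbeddingFunctionConfig(
    source_column="text", vector_column="vector", function=func
)

metadata = registry.get_table_metadata([conf])  # {"embedding_functions": <json bytes>}
# Parse it back the way a reopened table would; pyarrow hands back bytes keys
parsed = registry.parse_functions(
    {k.encode("utf-8"): v for k, v in metadata.items()}
)
assert parsed["vector"].source_column == "text"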

77
python/lancedb/embeddings/sentence_transformers.py
Normal file
@@ -0,0 +1,77 @@
from typing import List, Union

import numpy as np
from cachetools import cached

from .base import TextEmbeddingFunction
from .registry import register


@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the sentence-transformers library

    https://huggingface.co/sentence-transformers
    """

    name: str = "all-MiniLM-L6-v2"
    device: str = "cpu"
    normalize: bool = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None

    @property
    def embedding_model(self):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.
        """
        return self.__class__.get_embedding_model(self.name, self.device)

    def ndims(self):
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings("foo")[0])
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        ).tolist()

    @classmethod
    @cached(cache={})
    def get_embedding_model(cls, name, device):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.

        Parameters
        ----------
        name : str
            The name of the model to load
        device : str
            The device to load the model on

        TODO: use lru_cache instead with a reasonable/configurable maxsize
        """
        sentence_transformers = cls.safe_import(
            "sentence_transformers", "sentence-transformers"
        )
        return sentence_transformers.SentenceTransformer(name, device=device)
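
A short usage sketch for this class (assumes the sentence-transformers package is installed; the first call downloads the default model):

from lancedb.embeddings import get_registry

st = get_registry().get("sentence-transformers").create(device="cpu")
vectors = st.compute_source_embeddings(["hello world", "goodbye world"])
# ndims() is computed lazily from a probe embedding
assert len(vectors) == 2 and len(vectors[0]) == st.ndims()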

@@ -12,8 +12,10 @@
# limitations under the License.

import math
import socket
import sys
from typing import Callable, Union
import urllib.error
from typing import Callable, List, Union

import numpy as np
import pyarrow as pa
@@ -24,7 +26,12 @@ from ..util import safe_import_pandas
from ..utils.general import LOGGER

pd = safe_import_pandas()

DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
    str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]


def with_embeddings(
@@ -155,6 +162,20 @@ class FunctionWrapper:
        yield from _chunker(arr)


def url_retrieve(url: str):
    """
    Parameters
    ----------
    url: str
        URL to download from
    """
    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError("could not download {} due to {}".format(url, err))


def api_key_not_found_help(provider):
    LOGGER.error(f"Could not find API key for {provider}.")
    raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
@@ -19,6 +19,7 @@ import inspect
import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias

import numpy as np
@@ -159,6 +160,10 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
        return pa.bool_()
    elif py_type == bytes:
        return pa.binary()
    elif py_type == date:
        return pa.date32()
    elif py_type == datetime:
        return pa.timestamp("us")
    raise TypeError(
        f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
    )
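
The new mappings can be checked directly against pydantic_to_schema; this sketch mirrors the tests shown further below:

from datetime import date, datetime

import pyarrow as pa
from pydantic import BaseModel

from lancedb.pydantic import pydantic_to_schema

class Event(BaseModel):
    day: date   # mapped to pa.date32()
    at: datetime  # mapped to pa.timestamp("us")

schema = pydantic_to_schema(Event)
assert schema.field("day").type == pa.date32()
assert schema.field("at").type == pa.timestamp("us")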

@@ -322,7 +327,12 @@ class LanceModel(pydantic.BaseModel):
        for vec, func in vec_and_function:
            for source, field_info in cls.safe_get_fields().items():
                src_func = get_extras(field_info, "source_column_for")
                if src_func == func:
                if src_func is func:
                    # note we can't use == here since the function is a pydantic
                    # model, so two instances of the same function compare equal; if you
                    # have multiple vector columns from multiple sources, both would
                    # be mapped to the same source column
                    # GH594
                    configs.append(
                        EmbeddingFunctionConfig(
                            source_column=source, vector_column=vec, function=func
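
The comment's point about pydantic equality is easy to demonstrate in isolation (a standalone sketch, not lancedb code):

from pydantic import BaseModel

class Func(BaseModel):
    name: str = "all-MiniLM-L6-v2"

f1, f2 = Func(), Func()
assert f1 == f2      # two instances with equal fields compare equal under ==
assert f1 is not f2  # only identity tells them apart, hence `is` in the fix (GH594)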
@@ -151,10 +151,15 @@ class RestfulLanceDBClient:
        return await deserialize(resp)

    @_check_not_closed
    async def list_tables(self):
    async def list_tables(self, limit: int, page_token: str):
        """List all tables in the database."""
        json = await self.get("/v1/table/", {})
        return json["tables"]
        try:
            json = await self.get(
                "/v1/table/", {"limit": limit, "page_token": page_token}
            )
            return json["tables"]
        except StopAsyncIteration:
            return []

    @_check_not_closed
    async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
@@ -13,7 +13,7 @@

import asyncio
import uuid
from typing import List, Optional
from typing import Iterator, Optional
from urllib.parse import urlparse

import pyarrow as pa
@@ -52,10 +52,27 @@ class RemoteDBConnection(DBConnection):
    def __repr__(self) -> str:
        return f"RemoteDBConnection(name={self.db_name})"

    def table_names(self) -> List[str]:
        """List the names of all tables in the database."""
        result = self._loop.run_until_complete(self._client.list_tables())
        return result
    def table_names(self, last_token: str, limit=10) -> Iterator[str]:
        """List the names of all tables in the database.

        Parameters
        ----------
        last_token: str
            The last token to start the new page.
        limit: int, default 10
            The maximum number of table names fetched per page.

        Returns
        -------
        An iterator of table names.
        """
        while True:
            result = self._loop.run_until_complete(
                self._client.list_tables(limit, last_token)
            )
            if len(result) > 0:
                last_token = result[len(result) - 1]
            else:
                break
            for item in result:
                yield item
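
Calling this generator from a remote connection then looks roughly like the sketch below (the URI, key, and start token are placeholders; the exact connect arguments depend on the remote deployment):

import lancedb

db = lancedb.connect("db://example-db", api_key="...")  # hypothetical remote database
for name in db.table_names(last_token="", limit=100):
    print(name)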

    def open_table(self, name: str) -> Table:
        """Open a Lance Table in the database.
@@ -122,3 +139,8 @@ class RemoteDBConnection(DBConnection):
                f"/v1/table/{name}/drop/",
            )
        )

    async def close(self):
        """Close the connection to the database."""
        self._loop.close()
        await self._client.close()
@@ -105,4 +105,8 @@ class RemoteTable(Table):
        return self._conn._loop.run_until_complete(result).to_arrow()

    def delete(self, predicate: str):
        raise NotImplementedError
        """Delete rows from the table."""
        payload = {"predicate": predicate}
        self._conn._loop.run_until_complete(
            self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
        )
@@ -29,8 +29,7 @@ from lance.dataset import CleanupStats, ReaderLike
from lance.vector import vec_to_table

from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionRegistry
from .embeddings.functions import EmbeddingFunctionConfig
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
@@ -151,7 +150,7 @@ class Table(ABC):
    @abstractmethod
    def schema(self) -> pa.Schema:
        """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
        this [Table](Table)
        this Table

        """
        raise NotImplementedError
@@ -292,8 +291,9 @@ class Table(ABC):
        Examples
        --------
        >>> import lancedb
        >>> import pandas as pd
        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
        >>> data = [
        ...     {"x": 1, "vector": [1, 2]}, {"x": 2, "vector": [3, 4]}, {"x": 3, "vector": [5, 6]}
        ... ]
        >>> db = lancedb.connect("./.lancedb")
        >>> table = db.create_table("my_table", data)
        >>> table.to_pandas()
@@ -719,8 +719,9 @@ class LanceTable(Table):
        Examples
        --------
        >>> import lancedb
        >>> import pandas as pd
        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
        >>> data = [
        ...     {"x": 1, "vector": [1, 2]}, {"x": 2, "vector": [3, 4]}, {"x": 3, "vector": [5, 6]}
        ... ]
        >>> db = lancedb.connect("./.lancedb")
        >>> table = db.create_table("my_table", data)
        >>> table.to_pandas()
@@ -836,8 +837,9 @@ class LanceTable(Table):
        Examples
        --------
        >>> import lancedb
        >>> import pandas as pd
        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
        >>> data = [
        ...     {"x": 1, "vector": [1, 2]}, {"x": 2, "vector": [3, 4]}, {"x": 3, "vector": [5, 6]}
        ... ]
        >>> db = lancedb.connect("./.lancedb")
        >>> table = db.create_table("my_table", data)
        >>> table.to_pandas()
@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.3.1"
version = "0.3.2"
dependencies = [
    "deprecation",
    "pylance==0.8.3",
    "pylance==0.8.7",
    "ratelimiter~=1.0",
    "retry>=0.9.2",
    "tqdm>=4.1.0",
@@ -52,7 +52,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip", "cohere"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"]

[project.scripts]
lancedb = "lancedb.cli.cli:cli"
@@ -19,7 +19,7 @@ import pytest
import requests

import lancedb
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# These are integration tests for embedding functions.
@@ -31,12 +31,15 @@ from lancedb.pydantic import LanceModel, Vector
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
    db = lancedb.connect(tmp_path)
    registry = EmbeddingFunctionRegistry.get_instance()
    registry = get_registry()
    func = registry.get(alias).create()
    func2 = registry.get(alias).create()

    class Words(LanceModel):
        text: str = func.SourceField()
        text2: str = func2.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()
        vector2: Vector(func2.ndims()) = func2.VectorField()

    table = db.create_table("words", schema=Words)
    table.add(
@@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path):
                "foo",
                "bar",
                "baz",
            ]
            ],
            "text2": [
                "to be or not to be",
                "that is the question",
                "for whether tis nobler",
                "in the mind to suffer",
                "the slings and arrows",
                "of outrageous fortune",
                "or to take arms",
            ],
        }
    )
)
@@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path):
    expected = table.search(vec).limit(1).to_pydantic(Words)[0]
    assert actual.text == expected.text
    assert actual.text == "hello world"
    assert not np.allclose(actual.vector, actual.vector2)

    actual = (
        table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
    )
    assert actual.text != "hello world"
    assert not np.allclose(actual.vector, actual.vector2)
@pytest.mark.slow
|
||||
@@ -69,7 +88,7 @@ def test_openclip(tmp_path):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry = get_registry()
|
||||
func = registry.get("open-clip").create()
|
||||
|
||||
class Images(LanceModel):
|
||||
@@ -131,11 +150,7 @@ def test_openclip(tmp_path):
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
) # also skip if cohere not installed
|
||||
def test_cohere_embedding_function():
|
||||
cohere = (
|
||||
EmbeddingFunctionRegistry.get_instance()
|
||||
.get("cohere")
|
||||
.create(name="embed-multilingual-v2.0")
|
||||
)
|
||||
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = cohere.SourceField()
|
||||
|
||||
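Outside the test suite, the `get_registry` entry point used above reads like this. A minimal sketch, assuming the sentence-transformers extra is installed; the database path and table contents are illustrative:

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

func = get_registry().get("sentence-transformers").create()

class Words(LanceModel):
    text: str = func.SourceField()                      # embedded on ingest
    vector: Vector(func.ndims()) = func.VectorField()   # filled automatically

db = lancedb.connect("/tmp/words-demo")
table = db.create_table("words", schema=Words)
table.add([{"text": "hello world"}, {"text": "goodbye world"}])

# string queries are embedded with the same function; when a model defines
# several vector columns, pass vector_column_name=... as the tests above do
hit = table.search("greetings").limit(1).to_pydantic(Words)[0]
print(hit.text)
```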
@@ -14,6 +14,7 @@

 import json
 import sys
+from datetime import date, datetime
 from typing import List, Optional

 import pyarrow as pa

@@ -40,10 +41,18 @@ def test_pydantic_to_arrow():
         li: List[int]
         opt: Optional[str] = None
         st: StructModel
+        dt: date
+        dtt: datetime
         # d: dict

     m = TestModel(
-        id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
+        id=1,
+        s="hello",
+        vec=[1.0, 2.0, 3.0],
+        li=[2, 3, 4],
+        st=StructModel(a="a", b=1.0),
+        dt=date.today(),
+        dtt=datetime.now(),
     )

     schema = pydantic_to_schema(TestModel)

@@ -62,6 +71,8 @@ def test_pydantic_to_arrow():
                 ),
                 False,
             ),
+            pa.field("dt", pa.date32(), False),
+            pa.field("dtt", pa.timestamp("us"), False),
         ]
     )
     assert schema == expect_schema

@@ -79,10 +90,18 @@ def test_pydantic_to_arrow_py38():
         li: List[int]
         opt: Optional[str] = None
         st: StructModel
+        dt: date
+        dtt: datetime
         # d: dict

     m = TestModel(
-        id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
+        id=1,
+        s="hello",
+        vec=[1.0, 2.0, 3.0],
+        li=[2, 3, 4],
+        st=StructModel(a="a", b=1.0),
+        dt=date.today(),
+        dtt=datetime.now(),
     )

     schema = pydantic_to_schema(TestModel)

@@ -101,6 +120,8 @@ def test_pydantic_to_arrow_py38():
                 ),
                 False,
             ),
+            pa.field("dt", pa.date32(), False),
+            pa.field("dtt", pa.timestamp("us"), False),
         ]
     )
     assert schema == expect_schema
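The hunks above add `date`/`datetime` support to the Pydantic-to-Arrow conversion. A minimal sketch of the resulting mapping, assuming `pydantic_to_schema` is imported from `lancedb.pydantic` as these tests do (the model and field names are illustrative):

```python
from datetime import date, datetime

import pyarrow as pa
from pydantic import BaseModel

from lancedb.pydantic import pydantic_to_schema

class Event(BaseModel):
    day: date        # mapped to pa.date32()
    stamp: datetime  # mapped to pa.timestamp("us")

schema = pydantic_to_schema(Event)
assert schema.field("day").type == pa.date32()
assert schema.field("stamp").type == pa.timestamp("us")
```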
@@ -458,7 +458,8 @@ def test_compact_cleanup(db):

     stats = table.compact_files()
     assert len(table) == 3
-    assert table.version == 4
+    # compact_files bumps 2 versions.
+    assert table.version == 5
     assert stats.fragments_removed > 0
     assert stats.fragments_added == 1

@@ -467,7 +468,7 @@ def test_compact_cleanup(db):

     stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True)
     assert stats.bytes_removed > 0
-    assert table.version == 4
+    assert table.version == 5

     with pytest.raises(Exception, match="Version 3 no longer exists"):
         table.checkout(3)
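The version arithmetic under test — compaction advancing the table by two versions, cleanup making the old ones unreachable — reads as follows in application code. A sketch assuming an existing Python `table` object:

```python
from datetime import timedelta

v = table.version                 # e.g. 3
stats = table.compact_files()     # rewrite + commit: version is now v + 2
assert table.version == v + 2
print(stats.fragments_removed, stats.fragments_added)

removal = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True)
assert removal.bytes_removed > 0  # pre-compaction versions are gone

try:
    table.checkout(v)             # the old version no longer exists
except Exception as e:
    print(e)
```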
@@ -1,6 +1,6 @@
 [package]
 name = "vectordb-node"
-version = "0.3.0"
+version = "0.3.4"
 description = "Serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
 edition = "2018"

@@ -14,7 +14,7 @@ arrow-array = { workspace = true }
 arrow-ipc = { workspace = true }
 arrow-schema = { workspace = true }
 chrono = { workspace = true }
-conv = "0.3.3"
+conv = "0.3.4"
 once_cell = "1"
 futures = "0.3"
 half = { workspace = true }
@@ -74,7 +74,7 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
     static RUNTIME: OnceCell<Runtime> = OnceCell::new();
     static LOG: OnceCell<()> = OnceCell::new();

-    LOG.get_or_init(|| env_logger::init());
+    LOG.get_or_init(env_logger::init);

     RUNTIME.get_or_try_init(|| Runtime::new().or_throw(cx))
 }
@@ -148,7 +148,7 @@ fn get_aws_creds(
     match (secret_key_id, secret_key, temp_token) {
         (Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
             StaticCredentialProvider::new(AwsCredential {
-                key_id: key_id,
+                key_id,
                 secret_key: key,
                 token: optional_token,
             }),
@@ -239,6 +239,8 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("tableDelete", JsTable::js_delete)?;
     cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
     cx.export_function("tableCompactFiles", JsTable::js_compact)?;
+    cx.export_function("tableListIndices", JsTable::js_list_indices)?;
+    cx.export_function("tableIndexStats", JsTable::js_index_stats)?;
     cx.export_function(
         "tableCreateVectorIndex",
         index::vector::table_create_vector_index,
@@ -70,7 +70,7 @@ impl JsTable {
             store_params: Some(ObjectStoreParams::with_aws_credentials(
                 aws_creds, aws_region,
             )),
-            mode: mode,
+            mode,
             ..WriteParams::default()
         };
@@ -121,7 +121,7 @@ impl JsTable {
         let add_result = table.add(batch_reader, Some(params)).await;

         deferred.settle_with(&channel, move |mut cx| {
-            let _added = add_result.or_throw(&mut cx)?;
+            add_result.or_throw(&mut cx)?;
             Ok(cx.boxed(JsTable::from(table)))
         });
     });
@@ -247,7 +247,7 @@ impl JsTable {
         }

         rt.spawn(async move {
-            let stats = table.compact_files(options).await;
+            let stats = table.compact_files(options, None).await;

             deferred.settle_with(&channel, move |mut cx| {
                 let stats = stats.or_throw(&mut cx)?;
@@ -276,4 +276,91 @@ impl JsTable {
         });
         Ok(promise)
     }
+
+    pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
+        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let rt = runtime(&mut cx)?;
+        let (deferred, promise) = cx.promise();
+        // let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
+        let channel = cx.channel();
+        let table = js_table.table.clone();
+
+        rt.spawn(async move {
+            let indices = table.load_indices().await;
+
+            deferred.settle_with(&channel, move |mut cx| {
+                let indices = indices.or_throw(&mut cx)?;
+
+                let output = JsArray::new(&mut cx, indices.len() as u32);
+                for (i, index) in indices.iter().enumerate() {
+                    let js_index = JsObject::new(&mut cx);
+                    let index_name = cx.string(index.index_name.clone());
+                    js_index.set(&mut cx, "name", index_name)?;
+
+                    let index_uuid = cx.string(index.index_uuid.clone());
+                    js_index.set(&mut cx, "uuid", index_uuid)?;
+
+                    let js_index_columns = JsArray::new(&mut cx, index.columns.len() as u32);
+                    for (j, column) in index.columns.iter().enumerate() {
+                        let js_column = cx.string(column.clone());
+                        js_index_columns.set(&mut cx, j as u32, js_column)?;
+                    }
+                    js_index.set(&mut cx, "columns", js_index_columns)?;
+
+                    output.set(&mut cx, i as u32, js_index)?;
+                }
+
+                Ok(output)
+            })
+        });
+        Ok(promise)
+    }
+
+    pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
+        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let rt = runtime(&mut cx)?;
+        let (deferred, promise) = cx.promise();
+        let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
+        let channel = cx.channel();
+        let table = js_table.table.clone();
+
+        rt.spawn(async move {
+            let load_stats = futures::try_join!(
+                table.count_indexed_rows(&index_uuid),
+                table.count_unindexed_rows(&index_uuid)
+            );
+
+            deferred.settle_with(&channel, move |mut cx| {
+                let (indexed_rows, unindexed_rows) = load_stats.or_throw(&mut cx)?;
+
+                let output = JsObject::new(&mut cx);
+
+                match indexed_rows {
+                    Some(x) => {
+                        let i = cx.number(x as f64);
+                        output.set(&mut cx, "numIndexedRows", i)?;
+                    }
+                    None => {
+                        let null = cx.null();
+                        output.set(&mut cx, "numIndexedRows", null)?;
+                    }
+                };
+
+                match unindexed_rows {
+                    Some(x) => {
+                        let i = cx.number(x as f64);
+                        output.set(&mut cx, "numUnindexedRows", i)?;
+                    }
+                    None => {
+                        let null = cx.null();
+                        output.set(&mut cx, "numUnindexedRows", null)?;
+                    }
+                };
+
+                Ok(output)
+            })
+        });
+
+        Ok(promise)
+    }
 }
@@ -1,6 +1,6 @@
 [package]
 name = "vectordb"
-version = "0.3.0"
+version = "0.3.4"
 edition = "2021"
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
@@ -18,9 +18,9 @@ use arrow::compute::kernels::{aggregate::bool_and, length::length};
 use arrow_array::{
     cast::AsArray,
     types::{ArrowPrimitiveType, Int32Type, Int64Type},
-    Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
+    Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
 };
-use arrow_ord::comparison::eq_dyn_scalar;
+use arrow_ord::cmp::eq;
 use arrow_schema::DataType;
 use num_traits::{ToPrimitive, Zero};

@@ -38,7 +38,8 @@ where
     }

     let dim = len_arr.as_primitive::<T>().value(0);
-    if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
+    let datum = PrimitiveArray::<T>::new_scalar(dim);
+    if bool_and(&eq(len_arr.as_primitive::<T>(), &datum)?) != Some(true) {
         Ok(None)
     } else {
         Ok(Some(dim))
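The change above swaps the removed `eq_dyn_scalar` kernel for the newer `arrow_ord::cmp::eq`, which compares against a scalar `Datum`; the logic itself is unchanged: a list column only gets a fixed vector dimension if every row has the same length. The same check, mirrored in Python with pyarrow for clarity (not part of the codebase):

```python
import pyarrow as pa
import pyarrow.compute as pc

col = pa.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
lengths = pc.list_value_length(col)
dim = lengths[0].as_py()

# equivalent of bool_and(eq(lengths, dim)): all rows must share one length
if pc.all(pc.equal(lengths, dim)).as_py():
    print(f"fixed vector dimension: {dim}")
else:
    print("ragged lists: no fixed dimension")
```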
@@ -135,7 +135,7 @@ impl Database {
     async fn open_path(path: &str) -> Result<Database> {
         let (object_store, base_path) = ObjectStore::from_uri(path).await?;
         if object_store.is_local() {
-            Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
+            Self::try_create_dir(path).context(CreateDirSnafu { path })?;
         }
         Ok(Self {
             uri: path.to_string(),
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use lance::format::{Index, Manifest};
 use lance::index::vector::ivf::IvfBuildParams;
 use lance::index::vector::pq::PQBuildParams;
 use lance::index::vector::VectorIndexParams;

@@ -95,8 +96,8 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
     }

     fn build(&self) -> VectorIndexParams {
-        let ivf_params = self.ivf_params.clone().unwrap_or(IvfBuildParams::default());
-        let pq_params = self.pq_params.clone().unwrap_or(PQBuildParams::default());
+        let ivf_params = self.ivf_params.clone().unwrap_or_default();
+        let pq_params = self.pq_params.clone().unwrap_or_default();

         VectorIndexParams::with_ivf_pq_params(pq_params.metric_type, ivf_params, pq_params)
     }
@@ -106,6 +107,27 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
     }
 }
+
+pub struct VectorIndex {
+    pub columns: Vec<String>,
+    pub index_name: String,
+    pub index_uuid: String,
+}
+
+impl VectorIndex {
+    pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
+        let fields = index
+            .fields
+            .iter()
+            .map(|i| manifest.schema.fields[*i as usize].name.clone())
+            .collect();
+        VectorIndex {
+            columns: fields,
+            index_name: index.name.clone(),
+            index_uuid: index.uuid.to_string(),
+        }
+    }
+}

 #[cfg(test)]
 mod tests {
     use super::*;
@@ -57,7 +57,7 @@ trait PrimaryOnly {

 impl PrimaryOnly for Path {
     fn primary_only(&self) -> bool {
-        self.to_string().contains("manifest")
+        self.filename().unwrap_or("") == "_latest.manifest"
     }
 }
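The predicate tightens from "any path containing `manifest`" to exactly the top-level `_latest.manifest` pointer, so the versioned manifests under `_versions/` now get mirrored to the secondary store as well. In Python terms, a sketch of the same rule (not part of the codebase; paths are illustrative):

```python
def primary_only(path: str) -> bool:
    # only the latest-manifest pointer stays primary-only
    return path.rsplit("/", 1)[-1] == "_latest.manifest"

assert primary_only("table.lance/_latest.manifest")
assert not primary_only("table.lance/_versions/1.manifest")  # mirrored now
assert not primary_only("table.lance/data/0.lance")
```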
@@ -118,8 +118,10 @@ impl ObjectStore for MirroringObjectStore {
         self.primary.head(location).await
     }

+    // garbage collection on the secondary happens asynchronously by other means
     async fn delete(&self, location: &Path) -> Result<()> {
+        if !location.primary_only() {
+            self.secondary.delete(location).await?;
+        }
         self.primary.delete(location).await
     }

@@ -132,7 +134,7 @@ impl ObjectStore for MirroringObjectStore {
     }

     async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
-        if from.primary_only() {
+        if to.primary_only() {
             self.primary.copy(from, to).await
         } else {
             self.secondary.copy(from, to).await?;

@@ -142,6 +144,9 @@ impl ObjectStore for MirroringObjectStore {
     }

     async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
+        if !to.primary_only() {
+            self.secondary.copy(from, to).await?;
+        }
         self.primary.copy_if_not_exists(from, to).await
     }
 }
@@ -379,7 +384,7 @@ mod test {
             let primary_f = primary_elem.unwrap().unwrap();
             // skip manifests: _versions contains all the manifests and should not exist on secondary
             let primary_raw_path = primary_f.file_name().to_str().unwrap();
-            if primary_raw_path.contains("manifest") || primary_raw_path.contains("_versions") {
+            if primary_raw_path.contains("_latest.manifest") {
                 primary_elem = primary_iter.next();
                 continue;
             }
@@ -18,14 +18,16 @@ use std::sync::Arc;
 use arrow_array::{Float32Array, RecordBatchReader};
 use arrow_schema::SchemaRef;
 use lance::dataset::cleanup::RemovalStats;
-use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
+use lance::dataset::optimize::{
+    compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
+};
 use lance::dataset::{Dataset, WriteParams};
-use lance::index::IndexType;
+use lance::index::{DatasetIndexExt, IndexType};
 use lance::io::object_store::WrappingObjectStore;
 use std::path::Path;

 use crate::error::{Error, Result};
-use crate::index::vector::VectorIndexBuilder;
+use crate::index::vector::{VectorIndexBuilder, VectorIndex};
 use crate::query::Query;
 use crate::utils::{PatchReadParam, PatchWriteParam};
 use crate::WriteMode;
@@ -153,6 +155,22 @@ impl Table {
         })
     }

+    pub async fn checkout_latest(&self) -> Result<Self> {
+        let latest_version_id = self.dataset.latest_version_id().await?;
+        let dataset = if latest_version_id == self.dataset.version().version {
+            self.dataset.clone()
+        } else {
+            Arc::new(self.dataset.checkout_version(latest_version_id).await?)
+        };
+
+        Ok(Table {
+            name: self.name.clone(),
+            uri: self.uri.clone(),
+            dataset,
+            store_wrapper: self.store_wrapper.clone(),
+        })
+    }
+
     fn get_table_name(uri: &str) -> Result<String> {
         let path = Path::new(uri);
         let name = path
@@ -222,8 +240,6 @@ impl Table {

     /// Create index on the table.
     pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
-        use lance::index::DatasetIndexExt;
-
         let mut dataset = self.dataset.as_ref().clone();
         dataset
             .create_index(
@@ -241,6 +257,14 @@ impl Table {
         Ok(())
     }

+    pub async fn optimize_indices(&mut self) -> Result<()> {
+        let mut dataset = self.dataset.as_ref().clone();
+
+        dataset.optimize_indices().await?;
+
+        Ok(())
+    }
+
     /// Insert records into this Table
     ///
     /// # Arguments
@@ -337,12 +361,44 @@ impl Table {
     /// for faster reads.
     ///
     /// This calls into [lance::dataset::optimize::compact_files].
-    pub async fn compact_files(&mut self, options: CompactionOptions) -> Result<CompactionMetrics> {
+    pub async fn compact_files(
+        &mut self,
+        options: CompactionOptions,
+        remap_options: Option<Arc<dyn IndexRemapperOptions>>,
+    ) -> Result<CompactionMetrics> {
         let mut dataset = self.dataset.as_ref().clone();
-        let metrics = compact_files(&mut dataset, options).await?;
+        let metrics = compact_files(&mut dataset, options, remap_options).await?;
         self.dataset = Arc::new(dataset);
         Ok(metrics)
     }
+
+    pub fn count_fragments(&self) -> usize {
+        self.dataset.count_fragments()
+    }
+
+    pub fn count_deleted_rows(&self) -> usize {
+        self.dataset.count_deleted_rows()
+    }
+
+    pub fn num_small_files(&self, max_rows_per_group: usize) -> usize {
+        self.dataset.num_small_files(max_rows_per_group)
+    }
+
+    pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
+        Ok(self.dataset.count_indexed_rows(index_uuid).await?)
+    }
+
+    pub async fn count_unindexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
+        Ok(self.dataset.count_unindexed_rows(index_uuid).await?)
+    }
+
+    pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
+        let (indices, mf) = futures::try_join!(
+            self.dataset.load_indices(),
+            self.dataset.latest_manifest()
+        )?;
+        Ok(indices.iter().map(|i| VectorIndex::new_from_format(&mf, i)).collect())
+    }
 }

 #[cfg(test)]