Compare commits

...

5 Commits

Author SHA1 Message Date
Lance Release
38b0d91848 Bump version: 0.16.1-beta.0 → 0.17.0-beta.0 2024-11-25 22:05:49 +00:00
Will Jones
6826039575 fix(python): run remote SDK futures in background thread (#1856)
Users who call the remote SDK from code that uses futures (either
`ThreadPoolExecutor` or `asyncio`) can get odd errors like:

```
Traceback (most recent call last):
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
RuntimeError: cannot enter context: <_contextvars.Context object at 0x7cfe94cdc900> is already entered
```

This PR fixes that by executing all LanceDB futures in a dedicated
thread pool running on a background thread. That way, it doesn't
interact with their threadpool.
2024-11-25 13:12:47 -08:00
QianZhu
3e9321fc40 docs: improve scalar index and filtering (#1874)
improved the docs on build a scalar index and pre-/post-filtering

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
2024-11-25 11:30:57 -08:00
Lei Xu
2ded17452b fix(python)!: handle bad openai embeddings gracefully (#1873)
BREAKING-CHANGE: change Pydantic Vector field to be nullable by default.
Closes #1577
2024-11-23 13:33:52 -08:00
Mr. Doge
dfd9d2ac99 ci: musl missing node/package.json targets (#1870)
I missed targets when manually merging draft PR to updated main
I was copying from:
https://github.com/lancedb/lancedb/pull/1816/files#diff-d6e19f28e97cfeda63a9bd9426f10f1d2454eeed375ee1235e8ba842ceeb46a0

fixes:
error: Rust target x86_64-unknown-linux-musl not found in package.json.
2024-11-22 10:40:59 -08:00
15 changed files with 254 additions and 122 deletions

View File

@@ -1,23 +1,35 @@
# Building Scalar Index # Building a Scalar Index
Similar to many SQL databases, LanceDB supports several types of Scalar indices to accelerate search Scalar indices organize data by scalar attributes (e.g. numbers, categorical values), enabling fast filtering of vector data. In vector databases, scalar indices accelerate the retrieval of scalar data associated with vectors, thus enhancing the query performance when searching for vectors that meet certain scalar criteria.
Similar to many SQL databases, LanceDB supports several types of scalar indices to accelerate search
over scalar columns. over scalar columns.
- `BTREE`: The most common type is BTREE. This index is inspired by the btree data structure - `BTREE`: The most common type is BTREE. The index stores a copy of the
although only the first few layers of the btree are cached in memory. column in sorted order. This sorted copy allows a binary search to be used to
It will perform well on columns with a large number of unique values and few rows per value. satisfy queries.
- `BITMAP`: this index stores a bitmap for each unique value in the column. - `BITMAP`: this index stores a bitmap for each unique value in the column. It
This index is useful for columns with a finite number of unique values and many rows per value. uses a series of bits to indicate whether a value is present in a row of a table
For example, columns that represent "categories", "labels", or "tags" - `LABEL_LIST`: a special index that can be used on `List<T>` columns to
- `LABEL_LIST`: a special index that is used to index list columns whose values have a finite set of possibilities. support queries with `array_contains_all` and `array_contains_any`
using an underlying bitmap index.
For example, a column that contains lists of tags (e.g. `["tag1", "tag2", "tag3"]`) can be indexed with a `LABEL_LIST` index. For example, a column that contains lists of tags (e.g. `["tag1", "tag2", "tag3"]`) can be indexed with a `LABEL_LIST` index.
!!! tips "How to choose the right scalar index type"
`BTREE`: This index is good for scalar columns with mostly distinct values and does best when the query is highly selective.
`BITMAP`: This index works best for low-cardinality numeric or string columns, where the number of unique values is small (i.e., less than a few thousands).
`LABEL_LIST`: This index should be used for columns containing list-type data.
| Data Type | Filter | Index Type | | Data Type | Filter | Index Type |
| --------------------------------------------------------------- | ----------------------------------------- | ------------ | | --------------------------------------------------------------- | ----------------------------------------- | ------------ |
| Numeric, String, Temporal | `<`, `=`, `>`, `in`, `between`, `is null` | `BTREE` | | Numeric, String, Temporal | `<`, `=`, `>`, `in`, `between`, `is null` | `BTREE` |
| Boolean, numbers or strings with fewer than 1,000 unique values | `<`, `=`, `>`, `in`, `between`, `is null` | `BITMAP` | | Boolean, numbers or strings with fewer than 1,000 unique values | `<`, `=`, `>`, `in`, `between`, `is null` | `BITMAP` |
| List of low cardinality of numbers or strings | `array_has_any`, `array_has_all` | `LABEL_LIST` | | List of low cardinality of numbers or strings | `array_has_any`, `array_has_all` | `LABEL_LIST` |
### Create a scalar index
=== "Python" === "Python"
```python ```python
@@ -46,7 +58,7 @@ over scalar columns.
await tlb.create_index("publisher", { config: lancedb.Index.bitmap() }) await tlb.create_index("publisher", { config: lancedb.Index.bitmap() })
``` ```
For example, the following scan will be faster if the column `my_col` has a scalar index: The following scan will be faster if the column `book_id` has a scalar index:
=== "Python" === "Python"
@@ -106,3 +118,30 @@ Scalar indices can also speed up scans containing a vector search or full text s
.limit(10) .limit(10)
.toArray(); .toArray();
``` ```
### Update a scalar index
Updating the table data (adding, deleting, or modifying records) requires that you also update the scalar index. This can be done by calling `optimize`, which will trigger an update to the existing scalar index.
=== "Python"
```python
table.add([{"vector": [7, 8], "book_id": 4}])
table.optimize()
```
=== "TypeScript"
```typescript
await tbl.add([{ vector: [7, 8], book_id: 4 }]);
await tbl.optimize();
```
=== "Rust"
```rust
let more_data: Box<dyn RecordBatchReader + Send> = create_some_records()?;
tbl.add(more_data).execute().await?;
tbl.optimize(OptimizeAction::All).execute().await?;
```
!!! note
New data added after creating the scalar index will still appear in search results if optimize is not used, but with increased latency due to a flat search on the unindexed portion. LanceDB Cloud automates the optimize process, minimizing the impact on search speed.

View File

@@ -7,6 +7,10 @@ performed on the top-k results returned by the vector search. However, pre-filte
option that performs the filter prior to vector search. This can be useful to narrow down on option that performs the filter prior to vector search. This can be useful to narrow down on
the search space on a very large dataset to reduce query latency. the search space on a very large dataset to reduce query latency.
Note that both pre-filtering and post-filtering can yield false positives. For pre-filtering, if the filter is too selective, it might eliminate relevant items that the vector search would have otherwise identified as a good match. In this case, increasing `nprobes` parameter will help reduce such false positives. It is recommended to set `use_index=false` if you know that the filter is highly selective.
Similarly, a highly selective post-filter can lead to false positives. Increasing both `nprobes` and `refine_factor` can mitigate this issue. When deciding between pre-filtering and post-filtering, pre-filtering is generally the safer choice if you're uncertain.
<!-- Setup Code <!-- Setup Code
```python ```python
import lancedb import lancedb
@@ -57,6 +61,9 @@ const tbl = await db.createTable('myVectors', data)
```ts ```ts
--8<-- "docs/src/sql_legacy.ts:search" --8<-- "docs/src/sql_legacy.ts:search"
``` ```
!!! note
Creating a [scalar index](guides/scalar_index.md) accelerates filtering
## SQL filters ## SQL filters

View File

@@ -84,6 +84,8 @@
"aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64", "aarch64-apple-darwin": "@lancedb/vectordb-darwin-arm64",
"x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu", "x86_64-unknown-linux-gnu": "@lancedb/vectordb-linux-x64-gnu",
"aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu", "aarch64-unknown-linux-gnu": "@lancedb/vectordb-linux-arm64-gnu",
"x86_64-unknown-linux-musl": "@lancedb/vectordb-linux-x64-musl",
"aarch64-unknown-linux-musl": "@lancedb/vectordb-linux-arm64-musl",
"x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc", "x86_64-pc-windows-msvc": "@lancedb/vectordb-win32-x64-msvc",
"aarch64-pc-windows-msvc": "@lancedb/vectordb-win32-arm64-msvc" "aarch64-pc-windows-msvc": "@lancedb/vectordb-win32-arm64-msvc"
} }

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.16.1-beta.0" current_version = "0.17.0-beta.0"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.16.1-beta.0" version = "0.17.0-beta.0"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true
@@ -17,11 +17,17 @@ crate-type = ["cdylib"]
arrow = { version = "52.1", features = ["pyarrow"] } arrow = { version = "52.1", features = ["pyarrow"] }
lancedb = { path = "../rust/lancedb", default-features = false } lancedb = { path = "../rust/lancedb", default-features = false }
env_logger.workspace = true env_logger.workspace = true
pyo3 = { version = "0.21", features = ["extension-module", "abi3-py38", "gil-refs"] } pyo3 = { version = "0.21", features = [
"extension-module",
"abi3-py39",
"gil-refs"
] }
# Using this fork for now: https://github.com/awestlake87/pyo3-asyncio/issues/119 # Using this fork for now: https://github.com/awestlake87/pyo3-asyncio/issues/119
# pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] } # pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
pyo3-asyncio-0-21 = { version = "0.21.0", features = ["attributes", "tokio-runtime"] } pyo3-asyncio-0-21 = { version = "0.21.0", features = [
"attributes",
"tokio-runtime"
] }
pin-project = "1.1.5" pin-project = "1.1.5"
futures.workspace = true futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] } tokio = { version = "1.36.0", features = ["sync"] }
@@ -29,14 +35,13 @@ tokio = { version = "1.36.0", features = ["sync"] }
[build-dependencies] [build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [ pyo3-build-config = { version = "0.20.3", features = [
"extension-module", "extension-module",
"abi3-py38", "abi3-py39",
] } ] }
[features] [features]
default = ["default-tls", "remote"] default = ["default-tls", "remote"]
fp16kernels = ["lancedb/fp16kernels"] fp16kernels = ["lancedb/fp16kernels"]
remote = ["lancedb/remote"] remote = ["lancedb/remote"]
# TLS # TLS
default-tls = ["lancedb/default-tls"] default-tls = ["lancedb/default-tls"]
native-tls = ["lancedb/native-tls"] native-tls = ["lancedb/native-tls"]

View File

@@ -3,7 +3,6 @@ name = "lancedb"
# version in Cargo.toml # version in Cargo.toml
dependencies = [ dependencies = [
"deprecation", "deprecation",
"nest-asyncio~=1.0",
"pylance==0.20.0b2", "pylance==0.20.0b2",
"tqdm>=4.27.0", "tqdm>=4.27.0",
"pydantic>=1.10", "pydantic>=1.10",
@@ -31,7 +30,6 @@ classifiers = [
"Programming Language :: Python", "Programming Language :: Python",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",

View File

@@ -83,25 +83,33 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
""" """
openai = attempt_import_or_raise("openai") openai = attempt_import_or_raise("openai")
valid_texts = []
valid_indices = []
for idx, text in enumerate(texts):
if text:
valid_texts.append(text)
valid_indices.append(idx)
# TODO retry, rate limit, token limit # TODO retry, rate limit, token limit
try: try:
if self.name == "text-embedding-ada-002": kwargs = {
rs = self._openai_client.embeddings.create(input=texts, model=self.name) "input": valid_texts,
else: "model": self.name,
kwargs = { }
"input": texts, if self.name != "text-embedding-ada-002":
"model": self.name, kwargs["dimensions"] = self.dim
}
if self.dim: rs = self._openai_client.embeddings.create(**kwargs)
kwargs["dimensions"] = self.dim valid_embeddings = {
rs = self._openai_client.embeddings.create(**kwargs) idx: v.embedding for v, idx in zip(rs.data, valid_indices)
}
except openai.BadRequestError: except openai.BadRequestError:
logging.exception("Bad request: %s", texts) logging.exception("Bad request: %s", texts)
return [None] * len(texts) return [None] * len(texts)
except Exception: except Exception:
logging.exception("OpenAI embeddings error") logging.exception("OpenAI embeddings error")
raise raise
return [v.embedding for v in rs.data] return [valid_embeddings.get(idx, None) for idx in range(len(texts))]
@cached_property @cached_property
def _openai_client(self): def _openai_client(self):

View File

@@ -1,15 +1,5 @@
# Copyright 2023 LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pydantic (v1 / v2) adapter for LanceDB""" """Pydantic (v1 / v2) adapter for LanceDB"""
@@ -30,6 +20,7 @@ from typing import (
Type, Type,
Union, Union,
_GenericAlias, _GenericAlias,
GenericAlias,
) )
import numpy as np import numpy as np
@@ -75,7 +66,7 @@ def vector(dim: int, value_type: pa.DataType = pa.float32()):
def Vector( def Vector(
dim: int, value_type: pa.DataType = pa.float32() dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
) -> Type[FixedSizeListMixin]: ) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type. """Pydantic Vector Type.
@@ -88,6 +79,8 @@ def Vector(
The dimension of the vector. The dimension of the vector.
value_type : pyarrow.DataType, optional value_type : pyarrow.DataType, optional
The value type of the vector, by default pa.float32() The value type of the vector, by default pa.float32()
nullable : bool, optional
Whether the vector is nullable, by default it is True.
Examples Examples
-------- --------
@@ -103,7 +96,7 @@ def Vector(
>>> assert schema == pa.schema([ >>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False), ... pa.field("id", pa.int64(), False),
... pa.field("url", pa.utf8(), False), ... pa.field("url", pa.utf8(), False),
... pa.field("embeddings", pa.list_(pa.float32(), 768), False) ... pa.field("embeddings", pa.list_(pa.float32(), 768))
... ]) ... ])
""" """
@@ -112,6 +105,10 @@ def Vector(
def __repr__(self): def __repr__(self):
return f"FixedSizeList(dim={dim})" return f"FixedSizeList(dim={dim})"
@staticmethod
def nullable() -> bool:
return nullable
@staticmethod @staticmethod
def dim() -> int: def dim() -> int:
return dim return dim
@@ -205,9 +202,7 @@ else:
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType""" """Convert a Pydantic FieldInfo to Arrow DataType"""
if isinstance(field.annotation, _GenericAlias) or ( if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias)
):
origin = field.annotation.__origin__ origin = field.annotation.__origin__
args = field.annotation.__args__ args = field.annotation.__args__
if origin is list: if origin is list:
@@ -235,7 +230,7 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
def is_nullable(field: FieldInfo) -> bool: def is_nullable(field: FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable.""" """Check if a Pydantic FieldInfo is nullable."""
if isinstance(field.annotation, _GenericAlias): if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__ origin = field.annotation.__origin__
args = field.annotation.__args__ args = field.annotation.__args__
if origin == Union: if origin == Union:
@@ -246,6 +241,10 @@ def is_nullable(field: FieldInfo) -> bool:
for typ in args: for typ in args:
if typ is type(None): if typ is type(None):
return True return True
elif inspect.isclass(field.annotation) and issubclass(
field.annotation, FixedSizeListMixin
):
return field.annotation.nullable()
return False return False

View File

@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import asyncio
import threading
class BackgroundEventLoop:
"""
A background event loop that can run futures.
Used to bridge sync and async code, without messing with users event loops.
"""
def __init__(self):
self.loop = asyncio.new_event_loop()
self.thread = threading.Thread(
target=self.loop.run_forever,
name="LanceDBBackgroundEventLoop",
daemon=True,
)
self.thread.start()
def run(self, future):
return asyncio.run_coroutine_threadsafe(future, self.loop).result()

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@@ -21,6 +20,7 @@ import warnings
from lancedb import connect_async from lancedb import connect_async
from lancedb.remote import ClientConfig from lancedb.remote import ClientConfig
from lancedb.remote.background_loop import BackgroundEventLoop
import pyarrow as pa import pyarrow as pa
from overrides import override from overrides import override
@@ -31,6 +31,8 @@ from ..pydantic import LanceModel
from ..table import Table from ..table import Table
from ..util import validate_table_name from ..util import validate_table_name
LOOP = BackgroundEventLoop()
class RemoteDBConnection(DBConnection): class RemoteDBConnection(DBConnection):
"""A connection to a remote LanceDB database.""" """A connection to a remote LanceDB database."""
@@ -86,18 +88,9 @@ class RemoteDBConnection(DBConnection):
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc self.db_name = parsed.netloc
import nest_asyncio
nest_asyncio.apply()
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.client_config = client_config self.client_config = client_config
self._conn = self._loop.run_until_complete( self._conn = LOOP.run(
connect_async( connect_async(
db_url, db_url,
api_key=api_key, api_key=api_key,
@@ -127,9 +120,7 @@ class RemoteDBConnection(DBConnection):
------- -------
An iterator of table names. An iterator of table names.
""" """
return self._loop.run_until_complete( return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))
self._conn.table_names(start_after=page_token, limit=limit)
)
@override @override
def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table: def open_table(self, name: str, *, index_cache_size: Optional[int] = None) -> Table:
@@ -152,8 +143,8 @@ class RemoteDBConnection(DBConnection):
" (there is no local cache to configure)" " (there is no local cache to configure)"
) )
table = self._loop.run_until_complete(self._conn.open_table(name)) table = LOOP.run(self._conn.open_table(name))
return RemoteTable(table, self.db_name, self._loop) return RemoteTable(table, self.db_name)
@override @override
def create_table( def create_table(
@@ -268,7 +259,7 @@ class RemoteDBConnection(DBConnection):
from .table import RemoteTable from .table import RemoteTable
table = self._loop.run_until_complete( table = LOOP.run(
self._conn.create_table( self._conn.create_table(
name, name,
data, data,
@@ -278,7 +269,7 @@ class RemoteDBConnection(DBConnection):
fill_value=fill_value, fill_value=fill_value,
) )
) )
return RemoteTable(table, self.db_name, self._loop) return RemoteTable(table, self.db_name)
@override @override
def drop_table(self, name: str): def drop_table(self, name: str):
@@ -289,7 +280,7 @@ class RemoteDBConnection(DBConnection):
name: str name: str
The name of the table. The name of the table.
""" """
self._loop.run_until_complete(self._conn.drop_table(name)) LOOP.run(self._conn.drop_table(name))
@override @override
def rename_table(self, cur_name: str, new_name: str): def rename_table(self, cur_name: str, new_name: str):
@@ -302,7 +293,7 @@ class RemoteDBConnection(DBConnection):
new_name: str new_name: str
The new name of the table. The new name of the table.
""" """
self._loop.run_until_complete(self._conn.rename_table(cur_name, new_name)) LOOP.run(self._conn.rename_table(cur_name, new_name))
async def close(self): async def close(self):
"""Close the connection to the database.""" """Close the connection to the database."""

View File

@@ -12,12 +12,12 @@
# limitations under the License. # limitations under the License.
from datetime import timedelta from datetime import timedelta
import asyncio
import logging import logging
from functools import cached_property from functools import cached_property
from typing import Dict, Iterable, List, Optional, Union, Literal from typing import Dict, Iterable, List, Optional, Union, Literal
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -33,9 +33,7 @@ class RemoteTable(Table):
self, self,
table: AsyncTable, table: AsyncTable,
db_name: str, db_name: str,
loop: Optional[asyncio.AbstractEventLoop] = None,
): ):
self._loop = loop
self._table = table self._table = table
self.db_name = db_name self.db_name = db_name
@@ -56,12 +54,12 @@ class RemoteTable(Table):
of this Table of this Table
""" """
return self._loop.run_until_complete(self._table.schema()) return LOOP.run(self._table.schema())
@property @property
def version(self) -> int: def version(self) -> int:
"""Get the current version of the table""" """Get the current version of the table"""
return self._loop.run_until_complete(self._table.version()) return LOOP.run(self._table.version())
@cached_property @cached_property
def embedding_functions(self) -> dict: def embedding_functions(self) -> dict:
@@ -98,11 +96,11 @@ class RemoteTable(Table):
def list_indices(self): def list_indices(self):
"""List all the indices on the table""" """List all the indices on the table"""
return self._loop.run_until_complete(self._table.list_indices()) return LOOP.run(self._table.list_indices())
def index_stats(self, index_uuid: str): def index_stats(self, index_uuid: str):
"""List all the stats of a specified index""" """List all the stats of a specified index"""
return self._loop.run_until_complete(self._table.index_stats(index_uuid)) return LOOP.run(self._table.index_stats(index_uuid))
def create_scalar_index( def create_scalar_index(
self, self,
@@ -132,9 +130,7 @@ class RemoteTable(Table):
else: else:
raise ValueError(f"Unknown index type: {index_type}") raise ValueError(f"Unknown index type: {index_type}")
self._loop.run_until_complete( LOOP.run(self._table.create_index(column, config=config, replace=replace))
self._table.create_index(column, config=config, replace=replace)
)
def create_fts_index( def create_fts_index(
self, self,
@@ -144,9 +140,7 @@ class RemoteTable(Table):
with_position: bool = True, with_position: bool = True,
): ):
config = FTS(with_position=with_position) config = FTS(with_position=with_position)
self._loop.run_until_complete( LOOP.run(self._table.create_index(column, config=config, replace=replace))
self._table.create_index(column, config=config, replace=replace)
)
def create_index( def create_index(
self, self,
@@ -227,9 +221,7 @@ class RemoteTable(Table):
" 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" " 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
) )
self._loop.run_until_complete( LOOP.run(self._table.create_index(vector_column_name, config=config))
self._table.create_index(vector_column_name, config=config)
)
def add( def add(
self, self,
@@ -261,7 +253,7 @@ class RemoteTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
self._loop.run_until_complete( LOOP.run(
self._table.add( self._table.add(
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
) )
@@ -349,9 +341,7 @@ class RemoteTable(Table):
def _execute_query( def _execute_query(
self, query: Query, batch_size: Optional[int] = None self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader: ) -> pa.RecordBatchReader:
return self._loop.run_until_complete( return LOOP.run(self._table._execute_query(query, batch_size=batch_size))
self._table._execute_query(query, batch_size=batch_size)
)
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] """Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
@@ -368,9 +358,7 @@ class RemoteTable(Table):
on_bad_vectors: str, on_bad_vectors: str,
fill_value: float, fill_value: float,
): ):
self._loop.run_until_complete( LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
def delete(self, predicate: str): def delete(self, predicate: str):
"""Delete rows from the table. """Delete rows from the table.
@@ -419,7 +407,7 @@ class RemoteTable(Table):
x vector _distance # doctest: +SKIP x vector _distance # doctest: +SKIP
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
""" """
self._loop.run_until_complete(self._table.delete(predicate)) LOOP.run(self._table.delete(predicate))
def update( def update(
self, self,
@@ -469,7 +457,7 @@ class RemoteTable(Table):
2 2 [10.0, 10.0] # doctest: +SKIP 2 2 [10.0, 10.0] # doctest: +SKIP
""" """
self._loop.run_until_complete( LOOP.run(
self._table.update(where=where, updates=values, updates_sql=values_sql) self._table.update(where=where, updates=values, updates_sql=values_sql)
) )
@@ -499,7 +487,7 @@ class RemoteTable(Table):
) )
def count_rows(self, filter: Optional[str] = None) -> int: def count_rows(self, filter: Optional[str] = None) -> int:
return self._loop.run_until_complete(self._table.count_rows(filter)) return LOOP.run(self._table.count_rows(filter))
def add_columns(self, transforms: Dict[str, str]): def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError( raise NotImplementedError(

View File

@@ -90,10 +90,13 @@ def test_embedding_with_bad_results(tmp_path):
self, texts: Union[List[str], np.ndarray] self, texts: Union[List[str], np.ndarray]
) -> list[Union[np.array, None]]: ) -> list[Union[np.array, None]]:
# Return None, which is bad if field is non-nullable # Return None, which is bad if field is non-nullable
return [ a = [
None if i % 2 == 0 else np.random.randn(self.ndims()) np.full(self.ndims(), np.nan)
if i % 2 == 0
else np.random.randn(self.ndims())
for i in range(len(texts)) for i in range(len(texts))
] ]
return a
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance() registry = EmbeddingFunctionRegistry.get_instance()

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib import importlib
import io import io
import os import os
@@ -17,6 +8,7 @@ import os
import lancedb import lancedb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa
import pytest import pytest
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
@@ -444,6 +436,30 @@ def test_watsonx_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_with_empty_strs(tmp_path):
model = get_registry().get("openai").create(max_retries=0)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", ""]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df, on_bad_vectors="skip")
tb = tbl.to_arrow()
assert tb.schema.field_by_name("vector").type == pa.list_(
pa.float32(), model.ndims()
)
assert len(tb) == 2
assert tb["vector"].is_null().to_pylist() == [False, True]
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.skipif( @pytest.mark.skipif(
importlib.util.find_spec("ollama") is None, reason="Ollama not installed" importlib.util.find_spec("ollama") is None, reason="Ollama not installed"

View File

@@ -1,16 +1,5 @@
# Copyright 2023 LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
import sys import sys
@@ -172,6 +161,26 @@ def test_pydantic_to_arrow_py38():
assert schema == expect_schema assert schema == expect_schema
def test_nullable_vector():
class NullableModel(pydantic.BaseModel):
vec: Vector(16, nullable=False)
schema = pydantic_to_schema(NullableModel)
assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), False)])
class DefaultModel(pydantic.BaseModel):
vec: Vector(16)
schema = pydantic_to_schema(DefaultModel)
assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])
class NotNullableModel(pydantic.BaseModel):
vec: Vector(16)
schema = pydantic_to_schema(NotNullableModel)
assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])
def test_fixed_size_list_field(): def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel): class TestModel(pydantic.BaseModel):
vec: Vector(16) vec: Vector(16)
@@ -192,7 +201,7 @@ def test_fixed_size_list_field():
schema = pydantic_to_schema(TestModel) schema = pydantic_to_schema(TestModel)
assert schema == pa.schema( assert schema == pa.schema(
[ [
pa.field("vec", pa.list_(pa.float32(), 16), False), pa.field("vec", pa.list_(pa.float32(), 16)),
pa.field("li", pa.list_(pa.int64()), False), pa.field("li", pa.list_(pa.int64()), False),
] ]
) )

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors # SPDX-FileCopyrightText: Copyright The LanceDB Authors
from concurrent.futures import ThreadPoolExecutor
import contextlib import contextlib
from datetime import timedelta from datetime import timedelta
import http.server import http.server
@@ -187,6 +188,47 @@ async def test_retry_error():
assert cause.status_code == 429 assert cause.status_code == 429
def test_table_add_in_threadpool():
def handler(request):
if request.path == "/v1/table/test/insert/":
request.send_response(200)
request.end_headers()
elif request.path == "/v1/table/test/create/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(b"{}")
elif request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
payload = json.dumps(
dict(
version=1,
schema=dict(
fields=[
dict(name="id", type={"type": "int64"}, nullable=False),
]
),
)
)
request.wfile.write(payload.encode())
else:
request.send_response(404)
request.end_headers()
with mock_lancedb_connection(handler) as db:
table = db.create_table("test", [{"id": 1}])
with ThreadPoolExecutor(3) as executor:
futures = []
for _ in range(10):
future = executor.submit(table.add, [{"id": 1}])
futures.append(future)
for future in futures:
future.result()
@contextlib.contextmanager @contextlib.contextmanager
def query_test_table(query_handler): def query_test_table(query_handler):
def handler(request): def handler(request):