mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 22:09:58 +00:00
Compare commits
9 Commits
python-v0.
...
changhiskh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96a7c1ab42 | ||
|
|
e0277383a5 | ||
|
|
d6b408e26f | ||
|
|
2447372c1f | ||
|
|
f0298d8372 | ||
|
|
54693e6bec | ||
|
|
73b2977bff | ||
|
|
aec85f7875 | ||
|
|
51f92ecb3d |
@@ -33,3 +33,8 @@ rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
|
||||
|
||||
# Not all Windows systems have the C runtime installed, so this avoids library
|
||||
# not found errors on systems that are missing it.
|
||||
[target.x86_64-pc-windows-msvc]
|
||||
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||
|
||||
17
.github/workflows/npm-publish.yml
vendored
17
.github/workflows/npm-publish.yml
vendored
@@ -80,10 +80,25 @@ jobs:
|
||||
- arch: x86_64
|
||||
runner: ubuntu-latest
|
||||
- arch: aarch64
|
||||
runner: buildjet-8vcpu-ubuntu-2204-arm
|
||||
# For successful fat LTO builds, we need a large runner to avoid OOM errors.
|
||||
runner: buildjet-16vcpu-ubuntu-2204-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
# Buildjet aarch64 runners have only 1.5 GB RAM per core, vs 3.5 GB per core for
|
||||
# x86_64 runners. To avoid OOM errors on ARM, we create a swap file.
|
||||
- name: Configure aarch64 build
|
||||
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
run: |
|
||||
free -h
|
||||
sudo fallocate -l 16G /swapfile
|
||||
sudo chmod 600 /swapfile
|
||||
sudo mkswap /swapfile
|
||||
sudo swapon /swapfile
|
||||
echo "/swapfile swap swap defaults 0 0" >> sudo /etc/fstab
|
||||
# print info
|
||||
swapon --show
|
||||
free -h
|
||||
- name: Build Linux Artifacts
|
||||
run: |
|
||||
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
|
||||
|
||||
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.9.15", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.15" }
|
||||
lance-linalg = { "version" = "=0.9.15" }
|
||||
lance-testing = { "version" = "=0.9.15" }
|
||||
lance = { "version" = "=0.9.16", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.16" }
|
||||
lance-linalg = { "version" = "=0.9.16" }
|
||||
lance-testing = { "version" = "=0.9.16" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "50.0", optional = false }
|
||||
arrow-array = "50.0"
|
||||
|
||||
@@ -13,7 +13,9 @@ docker build \
|
||||
.
|
||||
popd
|
||||
|
||||
# We turn on memory swap to avoid OOM killer
|
||||
docker run \
|
||||
-v $(pwd):/io -w /io \
|
||||
--memory-swap=-1 \
|
||||
lancedb-node-manylinux \
|
||||
bash ci/manylinux_node/build.sh $ARCH
|
||||
|
||||
@@ -92,6 +92,7 @@ nav:
|
||||
- Full-text search: fts.md
|
||||
- Hybrid search:
|
||||
- Overview: hybrid_search/hybrid_search.md
|
||||
- Comparing Rerankers: hybrid_search/eval.md
|
||||
- Airbnb financial data example: notebooks/hybrid_search.ipynb
|
||||
- Filtering: sql.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
@@ -156,6 +157,7 @@ nav:
|
||||
- Full-text search: fts.md
|
||||
- Hybrid search:
|
||||
- Overview: hybrid_search/hybrid_search.md
|
||||
- Comparing Rerankers: hybrid_search/eval.md
|
||||
- Airbnb financial data example: notebooks/hybrid_search.ipynb
|
||||
- Filtering: sql.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
|
||||
49
docs/src/hybrid_search/eval.md
Normal file
49
docs/src/hybrid_search/eval.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Hybrid Search
|
||||
|
||||
Hybrid Search is a broad (often misused) term. It can mean anything from combining multiple methods for searching, to applying ranking methods to better sort the results. In this blog, we use the definition of "hybrid search" to mean using a combination of keyword-based and vector search.
|
||||
|
||||
## The challenge of (re)ranking search results
|
||||
Once you have a group of the most relevant search results from multiple search sources, you'd likely standardize the score and rank them accordingly. This process can also be seen as another independent step - reranking.
|
||||
There are two approaches for reranking search results from multiple sources.
|
||||
* <b>Score-based</b>: Calculate final relevance scores based on a weighted linear combination of individual search algorithm scores. Example - Weighted linear combination of semantic search & keyword-based search results.
|
||||
* <b>Relevance-based</b>: Discards the existing scores and calculates the relevance of each search result - query pair. Example - Cross Encoder models
|
||||
|
||||
Even though there are many strategies for reranking search results, none works for all cases. Moreover, evaluating them itself is a challenge. Also, reranking can be dataset, application specific so it's hard to generalize.
|
||||
|
||||
### Example evaluation of hybrid search with Reranking
|
||||
|
||||
Here's some evaluation numbers from experiment comparing these re-rankers on about 800 queries. It is modified version of an evaluation script from [llama-index](https://github.com/run-llama/finetune-embedding/blob/main/evaluate.ipynb) that measures hit-rate at top-k.
|
||||
|
||||
<b> With OpenAI ada2 embedding </b>
|
||||
|
||||
Vector Search baseline - `0.64`
|
||||
|
||||
| Reranker | Top-3 | Top-5 | Top-10 |
|
||||
| --- | --- | --- | --- |
|
||||
| Linear Combination | `0.73` | `0.74` | `0.85` |
|
||||
| Cross Encoder | `0.71` | `0.70` | `0.77` |
|
||||
| Cohere | `0.81` | `0.81` | `0.85` |
|
||||
| ColBERT | `0.68` | `0.68` | `0.73` |
|
||||
|
||||
<p>
|
||||
<img src="https://github.com/AyushExel/assets/assets/15766192/d57b1780-ef27-414c-a5c3-73bee7808a45">
|
||||
</p>
|
||||
|
||||
<b> With OpenAI embedding-v3-small </b>
|
||||
|
||||
Vector Search baseline - `0.59`
|
||||
|
||||
| Reranker | Top-3 | Top-5 | Top-10 |
|
||||
| --- | --- | --- | --- |
|
||||
| Linear Combination | `0.68` | `0.70` | `0.84` |
|
||||
| Cross Encoder | `0.72` | `0.72` | `0.79` |
|
||||
| Cohere | `0.79` | `0.79` | `0.84` |
|
||||
| ColBERT | `0.70` | `0.70` | `0.76` |
|
||||
|
||||
<p>
|
||||
<img src="https://github.com/AyushExel/assets/assets/15766192/259adfd2-6ec6-4df6-a77d-1456598970dd">
|
||||
</p>
|
||||
|
||||
### Conclusion
|
||||
|
||||
The results show that the reranking methods are able to improve the search results. However, the improvement is not consistent across all rerankers. The choice of reranker depends on the dataset and the application. It is also important to note that the reranking methods are not a replacement for the search methods. They are complementary and should be used together to get the best results. The speed to recall tradeoff is also an important factor to consider when choosing the reranker.
|
||||
@@ -1,6 +1,9 @@
|
||||
# DuckDB
|
||||
|
||||
LanceDB is very well-integrated with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This integration is done via [Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow) .
|
||||
In Python, LanceDB tables can also be queried with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This means you can write complex SQL queries to analyze your data in LanceDB.
|
||||
|
||||
This integration is done via [Apache Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow), which provides zero-copy data sharing between LanceDB and DuckDB. DuckDB is capable of passing down column selections and basic filters to LanceDB, reducing the amount of data that needs to be scanned to perform your query. Finally, the integration allows streaming data from LanceDB tables, allowing you to aggregate tables that won't fit into memory. All of this uses the same mechanism described in DuckDB's blog post *[DuckDB quacks Arrow](https://duckdb.org/2021/12/03/duck-arrow.html)*.
|
||||
|
||||
|
||||
We can demonstrate this by first installing `duckdb` and `lancedb`.
|
||||
|
||||
@@ -19,14 +22,15 @@ data = [
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
|
||||
]
|
||||
table = db.create_table("pd_table", data=data)
|
||||
arrow_table = table.to_arrow()
|
||||
```
|
||||
|
||||
DuckDB can directly query the `pyarrow.Table` object:
|
||||
To query the table, first call `to_lance` to convert the table to a "dataset", which is an object that can be queried by DuckDB. Then all you need to do is reference that dataset by the same name in your SQL query.
|
||||
|
||||
```python
|
||||
import duckdb
|
||||
|
||||
arrow_table = table.to_lance()
|
||||
|
||||
duckdb.query("SELECT * FROM arrow_table")
|
||||
```
|
||||
|
||||
|
||||
60
node/package-lock.json
generated
60
node/package-lock.json
generated
@@ -328,6 +328,66 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.10.tgz",
|
||||
"integrity": "sha512-y/uHOGb0g15pvqv5tdTyZ6oN+0QVpBmZDzKFWW6pPbuSZjB2uPqcs+ti0RB+AUdmS21kavVQqaNsw/HLKEGrHA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.10.tgz",
|
||||
"integrity": "sha512-XbfR58OkQpAe0xMSTrwJh9ZjGSzG9EZ7zwO6HfYem8PxcLYAcC6eWRWoSG/T0uObyrPTcYYyvHsp0eNQWYBFAQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.10.tgz",
|
||||
"integrity": "sha512-x40WKH9b+KxorRmKr9G7fv8p5mMj8QJQvRMA0v6v+nbZHr2FLlAZV+9mvhHOnm4AGIkPP5335cUgv6Qz6hgwkQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.10.tgz",
|
||||
"integrity": "sha512-CTGPpuzlqq2nVjUxI9gAJOT1oBANIovtIaFsOmBSnEAHgX7oeAxKy2b6L/kJzsgqSzvR5vfLwYcWFrr6ZmBxSA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.10.tgz",
|
||||
"integrity": "sha512-Fd7r74coZyrKzkfXg4WthqOL+uKyJyPTia6imcrMNqKOlTGdKmHf02Qi2QxWZrFaabkRYo4Tpn5FeRJ3yYX8CA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
]
|
||||
},
|
||||
"node_modules/@neon-rs/cli": {
|
||||
"version": "0.0.160",
|
||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||
|
||||
2
nodejs/vectordb/native.d.ts
vendored
2
nodejs/vectordb/native.d.ts
vendored
@@ -73,7 +73,7 @@ export class Table {
|
||||
/** Return Schema as empty Arrow IPC file. */
|
||||
schema(): Buffer
|
||||
add(buf: Buffer): Promise<void>
|
||||
countRows(filter?: string): Promise<bigint>
|
||||
countRows(filter?: string | undefined | null): Promise<bigint>
|
||||
delete(predicate: string): Promise<void>
|
||||
createIndex(): IndexBuilder
|
||||
query(): Query
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
@@ -32,6 +33,7 @@ def connect(
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -58,7 +60,14 @@ def connect(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -86,5 +95,9 @@ def connect(
|
||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
||||
if isinstance(request_thread_pool, int):
|
||||
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
|
||||
return RemoteDBConnection(
|
||||
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
|
||||
)
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing import (
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias,
|
||||
@@ -37,6 +38,11 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
|
||||
from lancedb.util import safe_import_tf, safe_import_torch
|
||||
|
||||
torch = safe_import_torch()
|
||||
tf = safe_import_tf()
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
@@ -79,9 +85,6 @@ def Vector(
|
||||
) -> Type[FixedSizeListMixin]:
|
||||
"""Pydantic Vector Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dim : int
|
||||
@@ -155,6 +158,142 @@ def Vector(
|
||||
return FixedSizeList
|
||||
|
||||
|
||||
class FixedShapeTensorMixin(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def shape() -> Tuple[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def Tensor(
|
||||
shape: Tuple[int], value_type: pa.DataType = pa.float32()
|
||||
) -> Type[FixedShapeTensorMixin]:
|
||||
"""Pydantic Tensor Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape : tuple of int
|
||||
The shape of the tensor
|
||||
value_type : pyarrow.DataType, optional
|
||||
The value type of the vector, by default pa.float32()
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import LanceModel, Tensor, Vector
|
||||
...
|
||||
>>> class MyModel(LanceModel):
|
||||
... id: int
|
||||
... url: str
|
||||
... tensor: Tensor((3, 3))
|
||||
... embedding: Vector(768)
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
... pa.field("url", pa.utf8(), False),
|
||||
... pa.field("tensor", pa.fixed_shape_tensor(pa.float32(), (3, 3)), False),
|
||||
... pa.field("embeddings", pa.list_(pa.float32(), 768), False)
|
||||
... ])
|
||||
"""
|
||||
|
||||
# TODO: make a public parameterized type.
|
||||
class FixedShapeTensor(FixedShapeTensorMixin):
|
||||
def __repr__(self):
|
||||
return f"FixedShapeTensor(shape={shape})"
|
||||
|
||||
@staticmethod
|
||||
def shape() -> Tuple[int]:
|
||||
return shape
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return value_type
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
np.asarray,
|
||||
nested_schema(shape, core_schema.float_schema()),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if isinstance(v, list):
|
||||
v = cls._validate_list(v, shape)
|
||||
elif isinstance(v, np.ndarray):
|
||||
v = cls._validate_ndarray(v, shape)
|
||||
elif torch is not None and isinstance(v, torch.Tensor):
|
||||
v = cls._validate_torch(v, shape)
|
||||
elif tf is not None and isinstance(v, tf.Tensor):
|
||||
v = cls._validate_tf(v, shape)
|
||||
else:
|
||||
raise TypeError(
|
||||
"A list of numbers, numpy.ndarray, torch.Tensor, "
|
||||
f"or tf.Tensor is needed but got {type(v)} instead."
|
||||
)
|
||||
return np.asarray(v)
|
||||
|
||||
@classmethod
|
||||
def _validate_list(cls, v, shape):
|
||||
v = np.asarray(v)
|
||||
return cls._validate_ndarray(v, shape)
|
||||
|
||||
@classmethod
|
||||
def _validate_ndarray(cls, v, shape):
|
||||
if v.shape != shape:
|
||||
raise ValueError(f"Invalid shape {v.shape}, expected {shape}")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def _validate_torch(cls, v, shape):
|
||||
v = v.detach().cpu().numpy()
|
||||
return cls._validate_ndarray(v, shape)
|
||||
|
||||
@classmethod
|
||||
def _validate_tf(cls, v, shape):
|
||||
v = v.numpy()
|
||||
return cls._validate_ndarray(v, shape)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any], field):
|
||||
if field and field.sub_fields:
|
||||
type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]"
|
||||
else:
|
||||
type_with_potential_subtype = "np.ndarray"
|
||||
field_schema.update({"type": type_with_potential_subtype})
|
||||
|
||||
return FixedShapeTensor
|
||||
|
||||
|
||||
def nested_schema(shape, items_schema):
|
||||
if len(shape) == 0:
|
||||
return items_schema
|
||||
else:
|
||||
return core_schema.list_schema(
|
||||
min_length=shape[0],
|
||||
max_length=shape[0],
|
||||
items_schema=nested_schema(shape[1:], items_schema),
|
||||
)
|
||||
|
||||
|
||||
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
"""Convert a field with native Python type to Arrow data type.
|
||||
|
||||
@@ -230,6 +369,10 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
return pa.struct(fields)
|
||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||
elif issubclass(field.annotation, FixedShapeTensorMixin):
|
||||
return pa.fixed_shape_tensor(
|
||||
field.annotation.value_arrow_type(), field.annotation.shape()
|
||||
)
|
||||
return _py_type_to_arrow_type(field.annotation, field)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -39,6 +40,7 @@ class RemoteDBConnection(DBConnection):
|
||||
api_key: str,
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
@@ -49,6 +51,7 @@ class RemoteDBConnection(DBConnection):
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
self._request_thread_pool = request_thread_pool
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
@@ -270,15 +271,28 @@ class RemoteTable(Table):
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
if self._conn._request_thread_pool is None:
|
||||
|
||||
def submit(name, q):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
return self._conn._request_thread_pool.submit(
|
||||
self._conn._client.query, name, q
|
||||
)
|
||||
|
||||
results = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(self._conn._client.query(self._name, q))
|
||||
results.append(submit(self._name, q))
|
||||
|
||||
return pa.concat_tables(
|
||||
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
)
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
|
||||
@@ -1568,7 +1568,7 @@ def _sanitize_schema(
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_float32(field.type.value_type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and field.type.list_size >= 10
|
||||
)
|
||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
||||
@@ -1581,6 +1581,11 @@ def _sanitize_schema(
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
is_tensor_type = isinstance(field.type, pa.FixedShapeTensorType)
|
||||
if is_tensor_type and field.name in data.column_names:
|
||||
data = _sanitize_tensor_column(data, column_name=field.name)
|
||||
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
@@ -1649,6 +1654,31 @@ def _sanitize_vector_column(
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_tensor_column(data: pa.Table, column_name: str) -> pa.Table:
|
||||
"""
|
||||
Ensure that the tensor column exists and has type tensor(float32)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: pa.Table
|
||||
The table to sanitize.
|
||||
column_name: str
|
||||
The name of the tensor column.
|
||||
"""
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
tensor_arr = data[column_name].combine_chunks()
|
||||
typ = data[column_name].type
|
||||
if not isinstance(typ, pa.FixedShapeTensorType):
|
||||
raise TypeError(f"Unsupported tensor column type: {tensor_arr.type}")
|
||||
|
||||
tensor_arr = ensure_tensor(tensor_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(column_name), column_name, tensor_arr
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
values = vec_arr.values
|
||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
||||
@@ -1661,6 +1691,11 @@ def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
return vec_arr
|
||||
|
||||
|
||||
def ensure_tensor(tensor_arr) -> pa.TensorArray:
|
||||
assert 0 == 1
|
||||
return tensor_arr
|
||||
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "error":
|
||||
|
||||
@@ -153,6 +153,24 @@ def safe_import_polars():
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_torch():
|
||||
try:
|
||||
import torch
|
||||
|
||||
return torch
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_tf():
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
return tf
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
@@ -3,7 +3,7 @@ name = "lancedb"
|
||||
version = "0.5.5"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.9.15",
|
||||
"pylance==0.9.16",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
|
||||
@@ -22,7 +22,13 @@ import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Tensor,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -244,3 +250,37 @@ def test_lance_model():
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_tensor():
|
||||
class TestModel(LanceModel):
|
||||
tensor: Tensor((3, 3))
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["tensor"]
|
||||
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
json_schema = TestModel.model_json_schema()
|
||||
else:
|
||||
json_schema = TestModel.schema()
|
||||
|
||||
assert json_schema == {
|
||||
"properties": {
|
||||
"tensor": {
|
||||
"items": {
|
||||
"items": {"type": "number"},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"type": "array",
|
||||
},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"title": "Tensor",
|
||||
"type": "array",
|
||||
}
|
||||
},
|
||||
"required": ["tensor"],
|
||||
"title": "TestModel",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Tensor, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -803,10 +803,8 @@ def test_count_rows(db):
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db):
|
||||
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
|
||||
# Probably not being parsed right by the fts
|
||||
db = MockDB("~/lancedb_")
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
@@ -900,3 +898,18 @@ def test_restore_consistency(tmp_path):
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
|
||||
def test_tensor_type(tmp_path):
|
||||
# create a model with a tensor column
|
||||
class MyTable(LanceModel):
|
||||
tensor: Tensor((256, 256, 3))
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
tensor = np.random.rand(256, 256, 3)
|
||||
table.add([{"tensor": tensor}, {"tensor": tensor.tolist()}])
|
||||
|
||||
result = table.search().limit(2).to_pandas()
|
||||
assert np.allclose(result.tensor[0], result.tensor[1])
|
||||
|
||||
Reference in New Issue
Block a user