Compare commits

...

5 Commits

Author SHA1 Message Date
Chang She
96a7c1ab42 feat(python): add Tensor pydantic type
- [x] Can be used to declare data model
- [ ] Can be used to ingest data
2024-02-17 10:29:50 -08:00
Chang She
e0277383a5 feat(python): add optional threadpool for batch requests (#981)
Currently if a batch request is given to the remote API, each query is
sent sequentially. We should allow the user to specify a threadpool.
2024-02-16 20:22:22 -08:00
Will Jones
d6b408e26f fix: use static C runtime on Windows (#979)
We depend on C static runtime, but not all Windows machines have that.
So might be worth statically linking it.

https://github.com/reorproject/reor/issues/36#issuecomment-1948876463
2024-02-16 15:54:12 -08:00
Will Jones
2447372c1f docs: show DuckDB with dataset, not table (#974)
Using datasets is preferred way to allow filter and projection pushdown,
as well as aggregated larger-than-memory tables.
2024-02-16 09:18:18 -08:00
Ayush Chaurasia
f0298d8372 docs: Minimal reranking evaluation benchmarks (#977) 2024-02-15 22:16:53 +05:30
12 changed files with 357 additions and 18 deletions

View File

@@ -33,3 +33,8 @@ rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"
[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
# Not all Windows systems have the C runtime installed, so this avoids library
# not found errors on systems that are missing it.
[target.x86_64-pc-windows-msvc]
rustflags = ["-Ctarget-feature=+crt-static"]

View File

@@ -92,6 +92,7 @@ nav:
- Full-text search: fts.md
- Hybrid search:
- Overview: hybrid_search/hybrid_search.md
- Comparing Rerankers: hybrid_search/eval.md
- Airbnb financial data example: notebooks/hybrid_search.ipynb
- Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
@@ -156,6 +157,7 @@ nav:
- Full-text search: fts.md
- Hybrid search:
- Overview: hybrid_search/hybrid_search.md
- Comparing Rerankers: hybrid_search/eval.md
- Airbnb financial data example: notebooks/hybrid_search.ipynb
- Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb

View File

@@ -0,0 +1,49 @@
# Hybrid Search
Hybrid Search is a broad (often misused) term. It can mean anything from combining multiple methods for searching, to applying ranking methods to better sort the results. In this blog, we use the definition of "hybrid search" to mean using a combination of keyword-based and vector search.
## The challenge of (re)ranking search results
Once you have a group of the most relevant search results from multiple search sources, you'd likely standardize the score and rank them accordingly. This process can also be seen as another independent step-reranking.
There are two approaches for reranking search results from multiple sources.
* <b>Score-based</b>: Calculate final relevance scores based on a weighted linear combination of individual search algorithm scores. Example-Weighted linear combination of semantic search & keyword-based search results.
* <b>Relevance-based</b>: Discards the existing scores and calculates the relevance of each search result-query pair. Example-Cross Encoder models
Even though there are many strategies for reranking search results, none works for all cases. Moreover, evaluating them itself is a challenge. Also, reranking can be dataset, application specific so it's hard to generalize.
### Example evaluation of hybrid search with Reranking
Here's some evaluation numbers from experiment comparing these re-rankers on about 800 queries. It is modified version of an evaluation script from [llama-index](https://github.com/run-llama/finetune-embedding/blob/main/evaluate.ipynb) that measures hit-rate at top-k.
<b> With OpenAI ada2 embedding </b>
Vector Search baseline - `0.64`
| Reranker | Top-3 | Top-5 | Top-10 |
| --- | --- | --- | --- |
| Linear Combination | `0.73` | `0.74` | `0.85` |
| Cross Encoder | `0.71` | `0.70` | `0.77` |
| Cohere | `0.81` | `0.81` | `0.85` |
| ColBERT | `0.68` | `0.68` | `0.73` |
<p>
<img src="https://github.com/AyushExel/assets/assets/15766192/d57b1780-ef27-414c-a5c3-73bee7808a45">
</p>
<b> With OpenAI embedding-v3-small </b>
Vector Search baseline - `0.59`
| Reranker | Top-3 | Top-5 | Top-10 |
| --- | --- | --- | --- |
| Linear Combination | `0.68` | `0.70` | `0.84` |
| Cross Encoder | `0.72` | `0.72` | `0.79` |
| Cohere | `0.79` | `0.79` | `0.84` |
| ColBERT | `0.70` | `0.70` | `0.76` |
<p>
<img src="https://github.com/AyushExel/assets/assets/15766192/259adfd2-6ec6-4df6-a77d-1456598970dd">
</p>
### Conclusion
The results show that the reranking methods are able to improve the search results. However, the improvement is not consistent across all rerankers. The choice of reranker depends on the dataset and the application. It is also important to note that the reranking methods are not a replacement for the search methods. They are complementary and should be used together to get the best results. The speed to recall tradeoff is also an important factor to consider when choosing the reranker.

View File

@@ -1,6 +1,9 @@
# DuckDB
LanceDB is very well-integrated with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This integration is done via [Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow) .
In Python, LanceDB tables can also be queried with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This means you can write complex SQL queries to analyze your data in LanceDB.
This integration is done via [Apache Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow), which provides zero-copy data sharing between LanceDB and DuckDB. DuckDB is capable of passing down column selections and basic filters to LanceDB, reducing the amount of data that needs to be scanned to perform your query. Finally, the integration allows streaming data from LanceDB tables, allowing you to aggregate tables that won't fit into memory. All of this uses the same mechanism described in DuckDB's blog post *[DuckDB quacks Arrow](https://duckdb.org/2021/12/03/duck-arrow.html)*.
We can demonstrate this by first installing `duckdb` and `lancedb`.
@@ -19,14 +22,15 @@ data = [
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
]
table = db.create_table("pd_table", data=data)
arrow_table = table.to_arrow()
```
DuckDB can directly query the `pyarrow.Table` object:
To query the table, first call `to_lance` to convert the table to a "dataset", which is an object that can be queried by DuckDB. Then all you need to do is reference that dataset by the same name in your SQL query.
```python
import duckdb
arrow_table = table.to_lance()
duckdb.query("SELECT * FROM arrow_table")
```

View File

@@ -13,8 +13,9 @@
import importlib.metadata
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Optional
from typing import Optional, Union
__version__ = importlib.metadata.version("lancedb")
@@ -32,6 +33,7 @@ def connect(
region: str = "us-east-1",
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -58,7 +60,14 @@ def connect(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
request_thread_pool: int or ThreadPoolExecutor, optional
The thread pool to use for making batch requests to the LanceDB Cloud API.
If an integer, then a ThreadPoolExecutor will be created with that
number of threads. If None, then a ThreadPoolExecutor will be created
with the default number of threads. If a ThreadPoolExecutor, then that
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
Examples
--------
@@ -86,5 +95,9 @@ def connect(
api_key = os.environ.get("LANCEDB_API_KEY")
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region, host_override)
if isinstance(request_thread_pool, int):
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
return RemoteDBConnection(
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)

View File

@@ -27,6 +27,7 @@ from typing import (
Dict,
Generator,
List,
Tuple,
Type,
Union,
_GenericAlias,
@@ -37,6 +38,11 @@ import pyarrow as pa
import pydantic
import semver
from lancedb.util import safe_import_tf, safe_import_torch
torch = safe_import_torch()
tf = safe_import_tf()
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
@@ -79,9 +85,6 @@ def Vector(
) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type.
!!! warning
Experimental feature.
Parameters
----------
dim : int
@@ -155,6 +158,142 @@ def Vector(
return FixedSizeList
class FixedShapeTensorMixin(ABC):
@staticmethod
@abstractmethod
def shape() -> Tuple[int]:
raise NotImplementedError
@staticmethod
@abstractmethod
def value_arrow_type() -> pa.DataType:
raise NotImplementedError
def Tensor(
shape: Tuple[int], value_type: pa.DataType = pa.float32()
) -> Type[FixedShapeTensorMixin]:
"""Pydantic Tensor Type.
!!! warning
Experimental feature.
Parameters
----------
shape : tuple of int
The shape of the tensor
value_type : pyarrow.DataType, optional
The value type of the vector, by default pa.float32()
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import LanceModel, Tensor, Vector
...
>>> class MyModel(LanceModel):
... id: int
... url: str
... tensor: Tensor((3, 3))
... embedding: Vector(768)
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("url", pa.utf8(), False),
... pa.field("tensor", pa.fixed_shape_tensor(pa.float32(), (3, 3)), False),
... pa.field("embeddings", pa.list_(pa.float32(), 768), False)
... ])
"""
# TODO: make a public parameterized type.
class FixedShapeTensor(FixedShapeTensorMixin):
def __repr__(self):
return f"FixedShapeTensor(shape={shape})"
@staticmethod
def shape() -> Tuple[int]:
return shape
@staticmethod
def value_arrow_type() -> pa.DataType:
return value_type
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
np.asarray,
nested_schema(shape, core_schema.float_schema()),
)
@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
yield cls.validate
# For pydantic v1
@classmethod
def validate(cls, v):
if isinstance(v, list):
v = cls._validate_list(v, shape)
elif isinstance(v, np.ndarray):
v = cls._validate_ndarray(v, shape)
elif torch is not None and isinstance(v, torch.Tensor):
v = cls._validate_torch(v, shape)
elif tf is not None and isinstance(v, tf.Tensor):
v = cls._validate_tf(v, shape)
else:
raise TypeError(
"A list of numbers, numpy.ndarray, torch.Tensor, "
f"or tf.Tensor is needed but got {type(v)} instead."
)
return np.asarray(v)
@classmethod
def _validate_list(cls, v, shape):
v = np.asarray(v)
return cls._validate_ndarray(v, shape)
@classmethod
def _validate_ndarray(cls, v, shape):
if v.shape != shape:
raise ValueError(f"Invalid shape {v.shape}, expected {shape}")
return v
@classmethod
def _validate_torch(cls, v, shape):
v = v.detach().cpu().numpy()
return cls._validate_ndarray(v, shape)
@classmethod
def _validate_tf(cls, v, shape):
v = v.numpy()
return cls._validate_ndarray(v, shape)
if PYDANTIC_VERSION < (2, 0):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any], field):
if field and field.sub_fields:
type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]"
else:
type_with_potential_subtype = "np.ndarray"
field_schema.update({"type": type_with_potential_subtype})
return FixedShapeTensor
def nested_schema(shape, items_schema):
if len(shape) == 0:
return items_schema
else:
return core_schema.list_schema(
min_length=shape[0],
max_length=shape[0],
items_schema=nested_schema(shape[1:], items_schema),
)
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
"""Convert a field with native Python type to Arrow data type.
@@ -230,6 +369,10 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
elif issubclass(field.annotation, FixedShapeTensorMixin):
return pa.fixed_shape_tensor(
field.annotation.value_arrow_type(), field.annotation.shape()
)
return _py_type_to_arrow_type(field.annotation, field)

View File

@@ -14,6 +14,7 @@
import inspect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse
@@ -39,6 +40,7 @@ class RemoteDBConnection(DBConnection):
api_key: str,
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
@@ -49,6 +51,7 @@ class RemoteDBConnection(DBConnection):
self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override
)
self._request_thread_pool = request_thread_pool
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"

View File

@@ -13,6 +13,7 @@
import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Optional, Union
@@ -270,15 +271,28 @@ class RemoteTable(Table):
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
if self._conn._request_thread_pool is None:
def submit(name, q):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):
return self._conn._request_thread_pool.submit(
self._conn._client.query, name, q
)
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
results.append(self._conn._client.query(self._name, q))
results.append(submit(self._name, q))
return pa.concat_tables(
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
)
else:
result = self._conn._client.query(self._name, query)

View File

@@ -1568,7 +1568,7 @@ def _sanitize_schema(
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_float32(field.type.value_type)
and pa.types.is_floating(field.type.value_type)
and field.type.list_size >= 10
)
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
@@ -1581,6 +1581,11 @@ def _sanitize_schema(
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
is_tensor_type = isinstance(field.type, pa.FixedShapeTensorType)
if is_tensor_type and field.name in data.column_names:
data = _sanitize_tensor_column(data, column_name=field.name)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
@@ -1649,6 +1654,31 @@ def _sanitize_vector_column(
return data
def _sanitize_tensor_column(data: pa.Table, column_name: str) -> pa.Table:
"""
Ensure that the tensor column exists and has type tensor(float32)
Parameters
----------
data: pa.Table
The table to sanitize.
column_name: str
The name of the tensor column.
"""
# ChunkedArray is annoying to work with, so we combine chunks here
tensor_arr = data[column_name].combine_chunks()
typ = data[column_name].type
if not isinstance(typ, pa.FixedShapeTensorType):
raise TypeError(f"Unsupported tensor column type: {tensor_arr.type}")
tensor_arr = ensure_tensor(tensor_arr)
data = data.set_column(
data.column_names.index(column_name), column_name, tensor_arr
)
return data
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
values = vec_arr.values
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
@@ -1661,6 +1691,11 @@ def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
return vec_arr
def ensure_tensor(tensor_arr) -> pa.TensorArray:
assert 0 == 1
return tensor_arr
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize jagged vectors."""
if on_bad_vectors == "error":

View File

@@ -153,6 +153,24 @@ def safe_import_polars():
return None
def safe_import_torch():
try:
import torch
return torch
except ImportError:
return None
def safe_import_tf():
try:
import tensorflow as tf
return tf
except ImportError:
return None
def inf_vector_column_query(schema: pa.Schema) -> str:
"""
Get the vector column name

View File

@@ -22,7 +22,13 @@ import pydantic
import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from lancedb.pydantic import (
PYDANTIC_VERSION,
LanceModel,
Tensor,
Vector,
pydantic_to_schema,
)
@pytest.mark.skipif(
@@ -244,3 +250,37 @@ def test_lance_model():
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_tensor():
class TestModel(LanceModel):
tensor: Tensor((3, 3))
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["tensor"]
if PYDANTIC_VERSION >= (2,):
json_schema = TestModel.model_json_schema()
else:
json_schema = TestModel.schema()
assert json_schema == {
"properties": {
"tensor": {
"items": {
"items": {"type": "number"},
"maxItems": 3,
"minItems": 3,
"type": "array",
},
"maxItems": 3,
"minItems": 3,
"title": "Tensor",
"type": "array",
}
},
"required": ["tensor"],
"title": "TestModel",
"type": "object",
}

View File

@@ -31,7 +31,7 @@ import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, Tensor, Vector
from lancedb.table import LanceTable
@@ -803,10 +803,8 @@ def test_count_rows(db):
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db):
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
# Probably not being parsed right by the fts
db = MockDB("~/lancedb_")
def test_hybrid_search(db, tmp_path):
db = MockDB(str(tmp_path))
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
@@ -900,3 +898,18 @@ def test_restore_consistency(tmp_path):
table.add([{"id": 2}])
assert table_fixed.version == table.version - 1
assert table_ref_latest.version == table.version
def test_tensor_type(tmp_path):
# create a model with a tensor column
class MyTable(LanceModel):
tensor: Tensor((256, 256, 3))
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", schema=MyTable)
tensor = np.random.rand(256, 256, 3)
table.add([{"tensor": tensor}, {"tensor": tensor.tolist()}])
result = table.search().limit(2).to_pandas()
assert np.allclose(result.tensor[0], result.tensor[1])