mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
feat: add ColPali embedding support with MultiVector type (#2170)
This PR adds ColPali support with ColPaliEmbeddings class (tagged "colpali") using ColQwen2.5 for multi-vector text/image embeddings. Also added MultiVector Pydantic type to handle the vector lists. I've added some integration test for the embedding model and some unit test for the new Pydantic type. Could be a template for other ColPali variants as well. or until transformers🤗 starts supporting it. Still `TODO`: - [ ] Documentation - [ ] Add an example _Could also allow Image as query, but didn't work well when testing it._ [ColPali-Engine](https://github.com/illuin-tech/colpali) version: 0.3.9.dev17+g3faee24 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced support for ColPali-based multimodal multi-vector embeddings for both text and images. - Added a new embedding class for generating multi-vector embeddings, configurable for various model and processing options. - Added a new Pydantic type for multi-vector embeddings, supporting validation and schema generation for lists of fixed-dimension vectors. - **Bug Fixes** - Ensured proper asynchronous index creation in query tests for improved reliability. - **Tests** - Added integration tests for ColPali embeddings, including text-to-image search and validation of multi-vector fields. - Added comprehensive tests for the new multi-vector Pydantic type, covering schema, validation, and default value behavior. - **Chores** - Updated optional dependencies to include the ColPali engine. - Added utility to check for availability of flash attention support. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -77,6 +77,7 @@ embeddings = [
|
||||
"pillow",
|
||||
"open-clip-torch",
|
||||
"cohere",
|
||||
"colpali-engine>=0.3.10",
|
||||
"huggingface_hub",
|
||||
"InstructorEmbedding",
|
||||
"google.generativeai",
|
||||
|
||||
@@ -19,3 +19,4 @@ from .imagebind import ImageBindEmbeddings
|
||||
from .jinaai import JinaEmbeddings
|
||||
from .watsonx import WatsonxEmbeddings
|
||||
from .voyageai import VoyageAIEmbeddingFunction
|
||||
from .colpali import ColPaliEmbeddings
|
||||
|
||||
255
python/python/lancedb/embeddings/colpali.py
Normal file
255
python/python/lancedb/embeddings/colpali.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Union, Optional, Any
|
||||
import numpy as np
|
||||
import io
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, IMAGES, is_flash_attn_2_available
|
||||
|
||||
|
||||
@register("colpali")
|
||||
class ColPaliEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the ColPali engine for
|
||||
multimodal multi-vector embeddings.
|
||||
|
||||
This embedding function supports ColQwen2.5 models, producing multivector outputs
|
||||
for both text and image inputs. The output embeddings are lists of vectors, each
|
||||
vector being 128-dimensional by default, represented as List[List[float]].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
The name of the model to use (e.g., "Metric-AI/ColQwen2.5-3b-multilingual-v1.0")
|
||||
device : str
|
||||
The device for inference (default "cuda:0").
|
||||
dtype : str
|
||||
Data type for model weights (default "bfloat16").
|
||||
use_token_pooling : bool
|
||||
Whether to use token pooling to reduce embedding size (default True).
|
||||
pool_factor : int
|
||||
Factor to reduce sequence length if token pooling is enabled (default 2).
|
||||
quantization_config : Optional[BitsAndBytesConfig]
|
||||
Quantization configuration for the model. (default None, bitsandbytes needed)
|
||||
batch_size : int
|
||||
Batch size for processing inputs (default 2).
|
||||
"""
|
||||
|
||||
model_name: str = "Metric-AI/ColQwen2.5-3b-multilingual-v1.0"
|
||||
device: str = "auto"
|
||||
dtype: str = "bfloat16"
|
||||
use_token_pooling: bool = True
|
||||
pool_factor: int = 2
|
||||
quantization_config: Optional[Any] = None
|
||||
batch_size: int = 2
|
||||
|
||||
_model = None
|
||||
_processor = None
|
||||
_token_pooler = None
|
||||
_vector_dim = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
(
|
||||
self._model,
|
||||
self._processor,
|
||||
self._token_pooler,
|
||||
) = self._load_model(
|
||||
self.model_name,
|
||||
self.dtype,
|
||||
self.device,
|
||||
self.use_token_pooling,
|
||||
self.quantization_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def _load_model(
|
||||
model_name: str,
|
||||
dtype: str,
|
||||
device: str,
|
||||
use_token_pooling: bool,
|
||||
quantization_config: Optional[Any],
|
||||
):
|
||||
"""
|
||||
Initialize and cache the ColPali model, processor, and token pooler.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
transformers = attempt_import_or_raise("transformers", "transformers")
|
||||
colpali_engine = attempt_import_or_raise("colpali_engine", "colpali_engine")
|
||||
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
|
||||
|
||||
if quantization_config is not None:
|
||||
if not isinstance(quantization_config, transformers.BitsAndBytesConfig):
|
||||
raise ValueError("quantization_config must be a BitsAndBytesConfig")
|
||||
|
||||
if dtype == "bfloat16":
|
||||
torch_dtype = torch.bfloat16
|
||||
elif dtype == "float16":
|
||||
torch_dtype = torch.float16
|
||||
elif dtype == "float64":
|
||||
torch_dtype = torch.float64
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
|
||||
model = colpali_engine.models.ColQwen2_5.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device,
|
||||
quantization_config=quantization_config
|
||||
if quantization_config is not None
|
||||
else None,
|
||||
attn_implementation="flash_attention_2"
|
||||
if is_flash_attn_2_available()
|
||||
else None,
|
||||
).eval()
|
||||
processor = colpali_engine.models.ColQwen2_5_Processor.from_pretrained(
|
||||
model_name
|
||||
)
|
||||
token_pooler = HierarchicalTokenPooler() if use_token_pooling else None
|
||||
return model, processor, token_pooler
|
||||
|
||||
def ndims(self):
|
||||
"""
|
||||
Return the dimension of a vector in the multivector output (e.g., 128).
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
if self._vector_dim is None:
|
||||
dummy_query = "test"
|
||||
batch_queries = self._processor.process_queries([dummy_query]).to(
|
||||
self._model.device
|
||||
)
|
||||
with torch.no_grad():
|
||||
query_embeddings = self._model(**batch_queries)
|
||||
|
||||
if self.use_token_pooling and self._token_pooler is not None:
|
||||
query_embeddings = self._token_pooler.pool_embeddings(
|
||||
query_embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
|
||||
self._vector_dim = query_embeddings[0].shape[-1]
|
||||
return self._vector_dim
|
||||
|
||||
def _process_embeddings(self, embeddings):
|
||||
"""
|
||||
Format model embeddings into List[List[float]].
|
||||
Use token pooling if enabled.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
if self.use_token_pooling and self._token_pooler is not None:
|
||||
embeddings = self._token_pooler.pool_embeddings(
|
||||
embeddings,
|
||||
pool_factor=self.pool_factor,
|
||||
padding=True,
|
||||
padding_side=self._processor.tokenizer.padding_side,
|
||||
)
|
||||
|
||||
if isinstance(embeddings, torch.Tensor):
|
||||
tensors = embeddings.detach().cpu()
|
||||
if tensors.dtype == torch.bfloat16:
|
||||
tensors = tensors.to(torch.float32)
|
||||
return (
|
||||
tensors.numpy()
|
||||
.astype(np.float64 if self.dtype == "float64" else np.float32)
|
||||
.tolist()
|
||||
)
|
||||
return []
|
||||
|
||||
def generate_text_embeddings(self, text: TEXT) -> List[List[List[float]]]:
|
||||
"""
|
||||
Generate embeddings for text input.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
text = self.sanitize_input(text)
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(text), self.batch_size):
|
||||
batch_text = text[i : i + self.batch_size]
|
||||
batch_queries = self._processor.process_queries(batch_text).to(
|
||||
self._model.device
|
||||
)
|
||||
with torch.no_grad():
|
||||
query_embeddings = self._model(**batch_queries)
|
||||
all_embeddings.extend(self._process_embeddings(query_embeddings))
|
||||
return all_embeddings
|
||||
|
||||
def _prepare_images(self, images: IMAGES) -> List:
|
||||
"""
|
||||
Convert image inputs to PIL Images.
|
||||
"""
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
requests = attempt_import_or_raise("requests", "requests")
|
||||
images = self.sanitize_input(images)
|
||||
pil_images = []
|
||||
try:
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
if image.startswith(("http://", "https://")):
|
||||
response = requests.get(image, timeout=10)
|
||||
response.raise_for_status()
|
||||
pil_images.append(PIL.Image.open(io.BytesIO(response.content)))
|
||||
else:
|
||||
with PIL.Image.open(image) as im:
|
||||
pil_images.append(im.copy())
|
||||
elif isinstance(image, bytes):
|
||||
pil_images.append(PIL.Image.open(io.BytesIO(image)))
|
||||
else:
|
||||
# Assume it's a PIL Image; will raise if invalid
|
||||
pil_images.append(image)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to process image: {e}")
|
||||
|
||||
return pil_images
|
||||
|
||||
def generate_image_embeddings(self, images: IMAGES) -> List[List[List[float]]]:
|
||||
"""
|
||||
Generate embeddings for a batch of images.
|
||||
"""
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
pil_images = self._prepare_images(images)
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(pil_images), self.batch_size):
|
||||
batch_images = pil_images[i : i + self.batch_size]
|
||||
batch_images = self._processor.process_images(batch_images).to(
|
||||
self._model.device
|
||||
)
|
||||
with torch.no_grad():
|
||||
image_embeddings = self._model(**batch_images)
|
||||
all_embeddings.extend(self._process_embeddings(image_embeddings))
|
||||
return all_embeddings
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, IMAGES], *args, **kwargs
|
||||
) -> List[List[List[float]]]:
|
||||
"""
|
||||
Compute embeddings for a single user query (text only).
|
||||
"""
|
||||
if not isinstance(query, str):
|
||||
raise ValueError(
|
||||
"Query must be a string, image to image search is not supported"
|
||||
)
|
||||
return self.generate_text_embeddings([query])
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[List[List[float]]]:
|
||||
"""
|
||||
Compute embeddings for a batch of source images.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : Union[str, bytes, List, pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
Batch of images (paths, URLs, bytes, or PIL Images).
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
return self.generate_image_embeddings(images)
|
||||
@@ -18,6 +18,7 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from ..dependencies import pandas as pd
|
||||
from ..util import attempt_import_or_raise
|
||||
|
||||
|
||||
# ruff: noqa: PERF203
|
||||
@@ -275,3 +276,12 @@ def url_retrieve(url: str):
|
||||
def api_key_not_found_help(provider):
|
||||
logging.error("Could not find API key for %s", provider)
|
||||
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
|
||||
|
||||
|
||||
def is_flash_attn_2_available():
|
||||
try:
|
||||
attempt_import_or_raise("flash_attn", "flash_attn")
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@@ -152,6 +152,104 @@ def Vector(
|
||||
return FixedSizeList
|
||||
|
||||
|
||||
def MultiVector(
|
||||
dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
|
||||
) -> Type:
|
||||
"""Pydantic MultiVector Type for multi-vector embeddings.
|
||||
|
||||
This type represents a list of vectors, each with the same dimension.
|
||||
Useful for models that produce multiple embeddings per input, like ColPali.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dim : int
|
||||
The dimension of each vector in the multi-vector.
|
||||
value_type : pyarrow.DataType, optional
|
||||
The value type of the vectors, by default pa.float32()
|
||||
nullable : bool, optional
|
||||
Whether the multi-vector is nullable, by default it is True.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import MultiVector
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... id: int
|
||||
... text: str
|
||||
... embeddings: MultiVector(128) # List of 128-dimensional vectors
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
... pa.field("text", pa.utf8(), False),
|
||||
... pa.field("embeddings", pa.list_(pa.list_(pa.float32(), 128)))
|
||||
... ])
|
||||
"""
|
||||
|
||||
class MultiVectorList(list, FixedSizeListMixin):
|
||||
def __repr__(self):
|
||||
return f"MultiVector(dim={dim})"
|
||||
|
||||
@staticmethod
|
||||
def nullable() -> bool:
|
||||
return nullable
|
||||
|
||||
@staticmethod
|
||||
def dim() -> int:
|
||||
return dim
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return value_type
|
||||
|
||||
@staticmethod
|
||||
def is_multi_vector() -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.list_schema(
|
||||
items_schema=core_schema.list_schema(
|
||||
min_length=dim,
|
||||
max_length=dim,
|
||||
items_schema=core_schema.float_schema(),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v1
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if not isinstance(v, (list, range)):
|
||||
raise TypeError("A list of vectors is needed")
|
||||
for vec in v:
|
||||
if not isinstance(vec, (list, range, np.ndarray)) or len(vec) != dim:
|
||||
raise TypeError(f"Each vector must be a list of {dim} numbers")
|
||||
return cls(v)
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["items"] = {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"minItems": dim,
|
||||
"maxItems": dim,
|
||||
}
|
||||
|
||||
return MultiVectorList
|
||||
|
||||
|
||||
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
"""Convert a field with native Python type to Arrow data type.
|
||||
|
||||
@@ -206,6 +304,9 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
||||
fields = _pydantic_model_to_fields(tp)
|
||||
return pa.struct(fields)
|
||||
if issubclass(tp, FixedSizeListMixin):
|
||||
if getattr(tp, "is_multi_vector", lambda: False)():
|
||||
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
|
||||
# For regular Vector
|
||||
return pa.list_(tp.value_arrow_type(), tp.dim())
|
||||
return _py_type_to_arrow_type(tp, field)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Vector, MultiVector
|
||||
import requests
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
@@ -575,3 +575,67 @@ def test_voyageai_multimodal_embedding_text_function():
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("colpali_engine") is None,
|
||||
reason="colpali_engine not installed",
|
||||
)
|
||||
def test_colpali(tmp_path):
|
||||
import requests
|
||||
from lancedb.pydantic import LanceModel
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get("colpali").create()
|
||||
|
||||
class MediaItems(LanceModel):
|
||||
text: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
image_vectors: MultiVector(func.ndims()) = (
|
||||
func.VectorField()
|
||||
) # Multivector image embeddings
|
||||
|
||||
table = db.create_table("media", schema=MediaItems)
|
||||
|
||||
texts = [
|
||||
"a cute cat playing with yarn",
|
||||
"a puppy in a flower field",
|
||||
"a red sports car on the highway",
|
||||
"a vintage bicycle leaning against a wall",
|
||||
"a plate of delicious pasta",
|
||||
"fresh fruit salad in a bowl",
|
||||
]
|
||||
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
|
||||
# Get images as bytes
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
|
||||
table.add(
|
||||
pd.DataFrame({"text": texts, "image_uri": uris, "image_bytes": image_bytes})
|
||||
)
|
||||
|
||||
# Test text-to-image search
|
||||
image_results = (
|
||||
table.search("fluffy companion", vector_column_name="image_vectors")
|
||||
.limit(1)
|
||||
.to_pydantic(MediaItems)[0]
|
||||
)
|
||||
assert "cat" in image_results.text.lower() or "puppy" in image_results.text.lower()
|
||||
|
||||
# Verify multivector dimensions
|
||||
first_row = table.to_arrow().to_pylist()[0]
|
||||
assert len(first_row["image_vectors"]) > 1, "Should have multiple image vectors"
|
||||
assert len(first_row["image_vectors"][0]) == func.ndims(), (
|
||||
"Vector dimension mismatch"
|
||||
)
|
||||
|
||||
@@ -9,7 +9,13 @@ from typing import List, Optional, Tuple
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
MultiVector,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
@@ -354,3 +360,55 @@ def test_optional_nested_model():
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_multi_vector():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: MultiVector(8)
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == pa.schema(
|
||||
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
|
||||
)
|
||||
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
TestModel(vec=[[1.0] * 7])
|
||||
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
TestModel(vec=[[1.0] * 9])
|
||||
|
||||
TestModel(vec=[[1.0] * 8])
|
||||
TestModel(vec=[[1.0] * 8, [2.0] * 8])
|
||||
|
||||
TestModel(vec=[])
|
||||
|
||||
|
||||
def test_multi_vector_nullable():
|
||||
class NullableModel(pydantic.BaseModel):
|
||||
vec: MultiVector(16, nullable=False)
|
||||
|
||||
schema = pydantic_to_schema(NullableModel)
|
||||
assert schema == pa.schema(
|
||||
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
|
||||
)
|
||||
|
||||
class DefaultModel(pydantic.BaseModel):
|
||||
vec: MultiVector(16)
|
||||
|
||||
schema = pydantic_to_schema(DefaultModel)
|
||||
assert schema == pa.schema(
|
||||
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
|
||||
)
|
||||
|
||||
|
||||
def test_multi_vector_in_lance_model():
|
||||
class TestModel(LanceModel):
|
||||
id: int
|
||||
vectors: MultiVector(16) = Field(default=[[0.0] * 16])
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["id", "vectors"]
|
||||
|
||||
t = TestModel(id=1)
|
||||
assert t.vectors == [[0.0] * 16]
|
||||
|
||||
@@ -257,7 +257,9 @@ async def test_distance_range_with_new_rows_async():
|
||||
}
|
||||
)
|
||||
table = await conn.create_table("test", data)
|
||||
table.create_index("vector", config=IvfPq(num_partitions=1, num_sub_vectors=2))
|
||||
await table.create_index(
|
||||
"vector", config=IvfPq(num_partitions=1, num_sub_vectors=2)
|
||||
)
|
||||
|
||||
q = [0, 0]
|
||||
rs = await table.query().nearest_to(q).to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user