mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
Compare commits
6 Commits
600
...
saas-searc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de14120bbe | ||
|
|
fa342e7df4 | ||
|
|
6689192cee | ||
|
|
dbec598610 | ||
|
|
8f6e7ce4f3 | ||
|
|
b482f41bf4 |
@@ -224,7 +224,6 @@ This embedding function supports ingesting images as both bytes and urls. You ca
|
||||
!!! info
|
||||
LanceDB supports ingesting images directly from accessible links.
|
||||
|
||||
|
||||
```python
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -290,4 +289,67 @@ print(actual.label)
|
||||
|
||||
```
|
||||
|
||||
### Imagebind embeddings
|
||||
LanceDB supports [ImageBind](https://github.com/facebookresearch/ImageBind) model embeddings. You can install our packaged version of the model with `pip install imagebind-packaged==0.1.2`.
|
||||
|
||||
This function is registered as `imagebind` and supports Audio, Video and Text modalities (extending to Thermal, Depth and IMU data):
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|---|---|---|---|
|
||||
| `name` | `str` | `"imagebind_huge"` | Name of the model. |
|
||||
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
|
||||
| `normalize` | `bool` | `False` | Set to `True` to normalize your inputs before model ingestion. |
|
||||
|
||||
Below is an example demonstrating how the API works:
|
||||
|
||||
```python
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get("imagebind").create()
|
||||
|
||||
class ImageBindModel(LanceModel):
|
||||
text: str
|
||||
image_uri: str = func.SourceField()
|
||||
audio_path: str
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
# add locally accessible image paths
|
||||
text_list=["A dog.", "A car", "A bird"]
|
||||
image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
|
||||
audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]
|
||||
|
||||
# Load data
|
||||
inputs = [
|
||||
{"text": a, "audio_path": b, "image_uri": c}
|
||||
for a, b, c in zip(text_list, audio_paths, image_paths)
|
||||
]
|
||||
|
||||
#create table and add data
|
||||
table = db.create_table("img_bind", schema=ImageBindModel)
|
||||
table.add(inputs)
|
||||
```
|
||||
|
||||
Now, we can search using any modality:
|
||||
|
||||
#### Image search
|
||||
```python
|
||||
query_image = "./assets/dog_image2.jpg" #download an image and enter that path here
|
||||
actual = table.search(query_image).limit(1).to_pydantic(ImageBindModel)[0]
|
||||
print(actual.text == "dog")
|
||||
```
|
||||
#### Audio search
|
||||
|
||||
```python
|
||||
query_audio = "./assets/car_audio2.wav" #download an audio clip and enter path here
|
||||
actual = table.search(query_audio).limit(1).to_pydantic(ImageBindModel)[0]
|
||||
print(actual.text == "car")
|
||||
```
|
||||
#### Text search
|
||||
You can add any input query and fetch the result as follows:
|
||||
```python
|
||||
query = "an animal which flies and tweets"
|
||||
actual = table.search(query).limit(1).to_pydantic(ImageBindModel)[0]
|
||||
print(actual.text == "bird")
|
||||
```
|
||||
|
||||
If you have any questions about the embeddings API or supported models, or if you see a relevant model missing, please raise an issue [on GitHub](https://github.com/lancedb/lancedb/issues).
|
||||
|
||||
569
docs/src/notebooks/multi_modal_video_RAG.ipynb
Normal file
569
docs/src/notebooks/multi_modal_video_RAG.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -57,7 +57,6 @@ tests = [
|
||||
"duckdb",
|
||||
"pytz",
|
||||
"polars>=0.19",
|
||||
"pillow",
|
||||
]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = [
|
||||
|
||||
@@ -31,7 +31,7 @@ class ImageBindEmbeddings(EmbeddingFunction):
|
||||
six different modalities: images, text, audio, depth, thermal, and IMU data
|
||||
|
||||
to download package, run :
|
||||
`pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
|
||||
`pip install imagebind-packaged==0.1.2`
|
||||
"""
|
||||
|
||||
name: str = "imagebind_huge"
|
||||
|
||||
@@ -126,10 +126,6 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
return [
|
||||
self.generate_image_embedding(image) for image in tqdm(images)
|
||||
]
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
|
||||
@@ -113,5 +113,5 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
if self.organization:
|
||||
kwargs["organization"] = self.organization
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self
|
||||
kwargs["api_key"] = self.api_key
|
||||
return openai.OpenAI(**kwargs)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import io
|
||||
import sys
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -27,9 +26,7 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias,
|
||||
@@ -39,30 +36,19 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
from lance.arrow import (
|
||||
EncodedImageType,
|
||||
)
|
||||
from lance.util import _check_huggingface
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from .util import attempt_import_or_raise
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
except ImportError:
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
raise
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
try:
|
||||
from pydantic import GetJsonSchemaHandler
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import CoreSchema
|
||||
except ImportError:
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
raise
|
||||
|
||||
|
||||
class FixedSizeListMixin(ABC):
|
||||
@staticmethod
|
||||
@@ -137,7 +123,7 @@ def Vector(
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> "CoreSchema":
|
||||
) -> CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls,
|
||||
core_schema.list_schema(
|
||||
@@ -195,118 +181,25 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||
child = py_type.__args__[0]
|
||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||
elif _safe_is_huggingface_image():
|
||||
import datasets
|
||||
|
||||
raise TypeError(
|
||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||
)
|
||||
|
||||
|
||||
class ImageMixin(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def EncodedImage():
|
||||
attempt_import_or_raise("PIL", "pillow or pip install lancedb[embeddings]")
|
||||
import PIL.Image
|
||||
|
||||
class EncodedImage(bytes, ImageMixin):
|
||||
"""Pydantic type for inlined images.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import EncodedImage
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... image: EncodedImage()
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("image", pa.binary(), False)
|
||||
... ])
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
return "EncodedImage()"
|
||||
|
||||
@staticmethod
|
||||
def value_arrow_type() -> pa.DataType:
|
||||
return EncodedImageType()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
|
||||
) -> "CoreSchema":
|
||||
from_bytes_schema = core_schema.bytes_schema()
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=from_bytes_schema,
|
||||
python_schema=core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(PIL.Image.Image),
|
||||
from_bytes_schema,
|
||||
]
|
||||
),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda instance: cls.validate(instance)
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, _core_schema: "CoreSchema", handler: "GetJsonSchemaHandler"
|
||||
) -> "JsonSchemaValue":
|
||||
return handler(core_schema.bytes_schema())
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable, None, None]:
|
||||
yield cls.validate
|
||||
|
||||
# For pydantic v2
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
if isinstance(v, PIL.Image.Image):
|
||||
with io.BytesIO() as output:
|
||||
v.save(output, format=v.format)
|
||||
return output.getvalue()
|
||||
raise TypeError(
|
||||
"EncodedImage can take bytes or PIL.Image.Image "
|
||||
f"as input but got {type(v)}"
|
||||
)
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0):
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]):
|
||||
field_schema["type"] = "string"
|
||||
field_schema["format"] = "binary"
|
||||
|
||||
return EncodedImage
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
def _safe_get_fields(model: pydantic.BaseModel):
|
||||
return model.__fields__
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
|
||||
]
|
||||
|
||||
else:
|
||||
def _safe_get_fields(model: pydantic.BaseModel):
|
||||
return model.model_fields
|
||||
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field)
|
||||
for name, field in _safe_get_fields(model).items()
|
||||
]
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field)
|
||||
for name, field in model.model_fields.items()
|
||||
]
|
||||
|
||||
|
||||
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
@@ -337,9 +230,6 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
return pa.struct(fields)
|
||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||
elif issubclass(field.annotation, ImageMixin):
|
||||
return field.annotation.value_arrow_type()
|
||||
|
||||
return _py_type_to_arrow_type(field.annotation, field)
|
||||
|
||||
|
||||
@@ -445,7 +335,13 @@ class LanceModel(pydantic.BaseModel):
|
||||
"""
|
||||
Get the field names of this model.
|
||||
"""
|
||||
return list(_safe_get_fields(cls).keys())
|
||||
return list(cls.safe_get_fields().keys())
|
||||
|
||||
@classmethod
|
||||
def safe_get_fields(cls):
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return cls.__fields__
|
||||
return cls.model_fields
|
||||
|
||||
@classmethod
|
||||
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||
@@ -455,16 +351,14 @@ class LanceModel(pydantic.BaseModel):
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
vec_and_function = []
|
||||
def get_vector_column(name, field_info):
|
||||
fun = get_extras(field_info, "vector_column_for")
|
||||
for name, field_info in cls.safe_get_fields().items():
|
||||
func = get_extras(field_info, "vector_column_for")
|
||||
if func is not None:
|
||||
vec_and_function.append([name, func])
|
||||
visit_fields(_safe_get_fields(cls).items(), get_vector_column)
|
||||
|
||||
configs = []
|
||||
# find the source columns for each one
|
||||
for vec, func in vec_and_function:
|
||||
def get_source_column(source, field_info):
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
src_func = get_extras(field_info, "source_column_for")
|
||||
if src_func is func:
|
||||
# note we can't use == here since the function is a pydantic
|
||||
@@ -477,48 +371,20 @@ class LanceModel(pydantic.BaseModel):
|
||||
source_column=source, vector_column=vec, function=func
|
||||
)
|
||||
)
|
||||
visit_fields(_safe_get_fields(cls).items(), get_source_column)
|
||||
return configs
|
||||
|
||||
|
||||
def visit_fields(fields: Iterable[Tuple[str, FieldInfo]],
|
||||
visitor: Callable[[str, FieldInfo], Any]):
|
||||
|
||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Visit all the leaf fields in a Pydantic model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fields : Iterable[Tuple(str, FieldInfo)]
|
||||
The fields to visit.
|
||||
visitor : Callable[[str, FieldInfo], Any]
|
||||
The visitor function.
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
for name, field_info in fields:
|
||||
# if the field is a pydantic model then
|
||||
# visit all subfields
|
||||
if (isinstance(getattr(field_info, "annotation"), type)
|
||||
and issubclass(field_info.annotation, pydantic.BaseModel)):
|
||||
visit_fields(_safe_get_fields(field_info.annotation).items(),
|
||||
_add_prefix(visitor, name))
|
||||
else:
|
||||
visitor(name, field_info)
|
||||
|
||||
|
||||
def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]:
|
||||
def prefixed_visitor(name: str, field: FieldInfo):
|
||||
return visitor(f"{prefix}.{name}", field)
|
||||
return prefixed_visitor
|
||||
if PYDANTIC_VERSION.major >= 2:
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
|
||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
@@ -527,13 +393,6 @@ if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
else:
|
||||
|
||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
|
||||
|
||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
|
||||
@@ -271,7 +271,8 @@ class LanceQueryBuilder(ABC):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vectors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# raise NotImplementedError
|
||||
self.to_arrow()
|
||||
|
||||
def to_list(self) -> List[dict]:
|
||||
"""
|
||||
@@ -434,12 +435,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_column = vector_column
|
||||
self._prefilter = False
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine"
|
||||
metric: "L2" or "cosine" or "dot"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
|
||||
@@ -68,10 +68,16 @@ class RemoteTable(Table):
|
||||
|
||||
def list_indices(self):
|
||||
"""List all the indices on the table"""
|
||||
print(self._name)
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
|
||||
return resp
|
||||
|
||||
def index_stats(self, index_uuid: str):
|
||||
"""List all the indices on the table"""
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
|
||||
)
|
||||
return resp
|
||||
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
@@ -290,6 +296,7 @@ class RemoteTable(Table):
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
print("query metric", query.metric)
|
||||
if (
|
||||
query.vector is not None
|
||||
and len(query.vector) > 0
|
||||
|
||||
@@ -37,7 +37,6 @@ import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
from lance import LanceDataset
|
||||
from lance.dependencies import _check_for_hugging_face
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
@@ -75,16 +74,7 @@ def _sanitize_data(
|
||||
on_bad_vectors: str,
|
||||
fill_value: Any,
|
||||
):
|
||||
import pdb; pdb.set_trace()
|
||||
if _check_for_hugging_face(data):
|
||||
# Huggingface datasets
|
||||
import datasets
|
||||
|
||||
if isinstance(data, datasets.Dataset):
|
||||
if schema is None:
|
||||
schema = data.features.arrow_schema
|
||||
data = data.data.to_batches()
|
||||
elif isinstance(data, list):
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
@@ -128,7 +118,8 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
no_vector_column = vector_column not in data.column_names
|
||||
if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
|
||||
col_data = func.compute_source_embeddings_with_retry(
|
||||
data[conf.source_column]
|
||||
)
|
||||
@@ -136,9 +127,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
data = data.append_column(
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
if no_vector_column:
|
||||
data = data.append_column(
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
else:
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column),
|
||||
pa.field(vector_column, type=dtype),
|
||||
pa.array(col_data, type=dtype),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -146,10 +144,11 @@ def _to_record_batch_generator(
|
||||
data: Iterable, schema, metadata, on_bad_vectors, fill_value
|
||||
):
|
||||
for batch in data:
|
||||
if isinstance(batch, pa.RecordBatch):
|
||||
batch = pa.Table.from_batches([batch])
|
||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for batch in table.to_batches():
|
||||
if not isinstance(batch, pa.RecordBatch):
|
||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for batch in table.to_batches():
|
||||
yield batch
|
||||
else:
|
||||
yield batch
|
||||
|
||||
|
||||
@@ -1523,7 +1522,7 @@ class LanceTable(Table):
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
|
||||
print("metric:", query.metric)
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 83 B |
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
from typing import List, Union
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
@@ -23,6 +24,8 @@ from lancedb.embeddings import (
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
from lancedb.embeddings.base import TextEmbeddingFunction
|
||||
from lancedb.embeddings.registry import get_registry, register
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
@@ -112,3 +115,34 @@ def test_embedding_function_rate_limit(tmp_path):
|
||||
table.add([{"text": "hello world"}])
|
||||
table.add([{"text": "hello world"}])
|
||||
assert len(table) == 2
|
||||
|
||||
|
||||
def test_add_optional_vector(tmp_path):
|
||||
@register("mock-embedding")
|
||||
class MockEmbeddingFunction(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 128
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
|
||||
|
||||
registry = get_registry()
|
||||
model = registry.get("mock-embedding").create()
|
||||
|
||||
class LanceSchema(LanceModel):
|
||||
id: str
|
||||
vector: Vector(model.ndims()) = model.VectorField(default=None)
|
||||
text: str = model.SourceField()
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("optional_vector", schema=LanceSchema)
|
||||
|
||||
# add works
|
||||
expected = LanceSchema(id="id", text="text")
|
||||
tbl.add([expected])
|
||||
assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()
|
||||
|
||||
@@ -12,27 +12,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
EncodedImage,
|
||||
LanceModel,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9),
|
||||
@@ -253,23 +243,3 @@ def test_lance_model():
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_schema_with_images():
|
||||
pytest.importorskip("PIL")
|
||||
import PIL.Image
|
||||
|
||||
class TestModel(LanceModel):
|
||||
img: EncodedImage()
|
||||
|
||||
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
|
||||
m1 = TestModel(img=PIL.Image.open(img_path))
|
||||
m2 = TestModel(img=img_bytes)
|
||||
|
||||
def tobytes(m):
|
||||
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
|
||||
|
||||
assert tobytes(m1) == tobytes(m2)
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
from copy import copy
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
@@ -22,21 +20,19 @@ from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lance.arrow import EncodedImageType
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import AsyncConnection, LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import EncodedImage, LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MockDB:
|
||||
@@ -112,7 +108,6 @@ def test_create_table(db):
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.string()),
|
||||
pa.field("price", pa.float32()),
|
||||
pa.field("encoded_image", EncodedImageType()),
|
||||
]
|
||||
)
|
||||
expected = pa.Table.from_arrays(
|
||||
@@ -120,26 +115,13 @@ def test_create_table(db):
|
||||
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
pa.ExtensionArray.from_storage(
|
||||
EncodedImageType(), pa.array([b"foo", b"bar"], pa.binary())
|
||||
),
|
||||
],
|
||||
schema=schema,
|
||||
)
|
||||
data = [
|
||||
[
|
||||
{
|
||||
"vector": [3.1, 4.1],
|
||||
"item": "foo",
|
||||
"price": 10.0,
|
||||
"encoded_image": b"foo",
|
||||
},
|
||||
{
|
||||
"vector": [5.9, 26.5],
|
||||
"item": "bar",
|
||||
"price": 20.0,
|
||||
"encoded_image": b"bar",
|
||||
},
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
]
|
||||
]
|
||||
df = pd.DataFrame(data[0])
|
||||
@@ -1043,22 +1025,3 @@ async def test_time_travel(db_async: AsyncConnection):
|
||||
# Can't use restore if not checked out
|
||||
with pytest.raises(ValueError, match="checkout before running restore"):
|
||||
await table.restore()
|
||||
|
||||
|
||||
def test_add_image(tmp_path):
|
||||
pytest.importorskip("PIL")
|
||||
import PIL.Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
class TestModel(LanceModel):
|
||||
img: EncodedImage()
|
||||
|
||||
img_path = Path(os.path.dirname(__file__)) / "images/1.png"
|
||||
m1 = TestModel(img=PIL.Image.open(img_path))
|
||||
|
||||
def tobytes(m):
|
||||
return PIL.Image.open(io.BytesIO(m.model_dump()["img"])).tobytes()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=TestModel)
|
||||
table.add([m1])
|
||||
|
||||
Reference in New Issue
Block a user