Mirror of https://github.com/lancedb/lancedb.git, synced 2025-12-27 23:12:58 +00:00
feat: add a basic async python client starting point (#1014)
This changes `lancedb` from a "pure Python" setuptools project to a maturin project and adds a Rust `lancedb` dependency. The async Python client is extremely minimal (only `connect` and `Connection.table_names` are supported). The purpose of this PR is to put the infrastructure in place for building out the rest of the async client. Although this is not technically a breaking change (no APIs are changing), it is still a considerable change in how the wheels are built, because they now include the native shared library.
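As a rough usage sketch (not part of this diff), the minimal async surface could be exercised like this, assuming a local database path:

import asyncio

import lancedb


async def main():
    # Only connect_async and Connection.table_names exist in this first cut.
    db = await lancedb.connect_async("~/.lancedb")
    print(await db.table_names())


asyncio.run(main())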
This commit is contained in:
174
python/python/lancedb/__init__.py
Normal file
@@ -0,0 +1,174 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.metadata
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Optional, Union

__version__ = importlib.metadata.version("lancedb")

from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri
from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector  # noqa: F401


def connect(
    uri: URI,
    *,
    api_key: Optional[str] = None,
    region: str = "us-east-1",
    host_override: Optional[str] = None,
    read_consistency_interval: Optional[timedelta] = None,
    request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
) -> DBConnection:
    """Connect to a LanceDB database.

    Parameters
    ----------
    uri: str or Path
        The uri of the database.
    api_key: str, optional
        If present, connect to LanceDB Cloud.
        Otherwise, connect to a database on file system or cloud storage.
        Can be set via environment variable `LANCEDB_API_KEY`.
    region: str, default "us-east-1"
        The region to use for LanceDB Cloud.
    host_override: str, optional
        The override url for LanceDB Cloud.
    read_consistency_interval: timedelta, default None
        (For LanceDB OSS only)
        The interval at which to check for updates to the table from other
        processes. If None, then consistency is not checked. For performance
        reasons, this is the default. For strong consistency, set this to
        zero seconds. Then every read will check for updates from other
        processes. As a compromise, you can set this to a non-zero timedelta
        for eventual consistency. If more than that interval has passed since
        the last check, then the table will be checked for updates. Note: this
        consistency only applies to read operations. Write operations are
        always consistent.
    request_thread_pool: int or ThreadPoolExecutor, optional
        The thread pool to use for making batch requests to the LanceDB Cloud API.
        If an integer, then a ThreadPoolExecutor will be created with that
        number of threads. If None, then a ThreadPoolExecutor will be created
        with the default number of threads. If a ThreadPoolExecutor, then that
        executor will be used for making requests. This is for LanceDB Cloud
        only and is only used when making batch requests (i.e., passing in
        multiple queries to the search method at once).

    Examples
    --------

    For a local directory, provide a path for the database:

    >>> import lancedb
    >>> db = lancedb.connect("~/.lancedb")

    For object storage, use a URI prefix:

    >>> db = lancedb.connect("s3://my-bucket/lancedb")

    Connect to LanceDB Cloud:

    >>> db = lancedb.connect("db://my_database", api_key="ldb_...")

    Returns
    -------
    conn : DBConnection
        A connection to a LanceDB database.
    """
    if isinstance(uri, str) and uri.startswith("db://"):
        if api_key is None:
            api_key = os.environ.get("LANCEDB_API_KEY")
        if api_key is None:
            raise ValueError(f"api_key is required to connect to LanceDB Cloud: {uri}")
        if isinstance(request_thread_pool, int):
            request_thread_pool = ThreadPoolExecutor(request_thread_pool)
        return RemoteDBConnection(
            uri, api_key, region, host_override, request_thread_pool=request_thread_pool
        )
    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)


async def connect_async(
    uri: URI,
    *,
    api_key: Optional[str] = None,
    region: str = "us-east-1",
    host_override: Optional[str] = None,
    read_consistency_interval: Optional[timedelta] = None,
    request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
) -> AsyncConnection:
    """Connect to a LanceDB database.

    Parameters
    ----------
    uri: str or Path
        The uri of the database.
    api_key: str, optional
        If present, connect to LanceDB Cloud.
        Otherwise, connect to a database on file system or cloud storage.
        Can be set via environment variable `LANCEDB_API_KEY`.
    region: str, default "us-east-1"
        The region to use for LanceDB Cloud.
    host_override: str, optional
        The override url for LanceDB Cloud.
    read_consistency_interval: timedelta, default None
        (For LanceDB OSS only)
        The interval at which to check for updates to the table from other
        processes. If None, then consistency is not checked. For performance
        reasons, this is the default. For strong consistency, set this to
        zero seconds. Then every read will check for updates from other
        processes. As a compromise, you can set this to a non-zero timedelta
        for eventual consistency. If more than that interval has passed since
        the last check, then the table will be checked for updates. Note: this
        consistency only applies to read operations. Write operations are
        always consistent.
    request_thread_pool: int or ThreadPoolExecutor, optional
        The thread pool to use for making batch requests to the LanceDB Cloud API.
        If an integer, then a ThreadPoolExecutor will be created with that
        number of threads. If None, then a ThreadPoolExecutor will be created
        with the default number of threads. If a ThreadPoolExecutor, then that
        executor will be used for making requests. This is for LanceDB Cloud
        only and is only used when making batch requests (i.e., passing in
        multiple queries to the search method at once).

    Examples
    --------

    For a local directory, provide a path for the database:

    >>> import lancedb
    >>> db = await lancedb.connect_async("~/.lancedb")  # doctest: +SKIP

    For object storage, use a URI prefix:

    >>> db = await lancedb.connect_async("s3://my-bucket/lancedb")  # doctest: +SKIP

    Connect to LanceDB Cloud:

    >>> db = await lancedb.connect_async("db://my_database", api_key="ldb_...")  # doctest: +SKIP

    Returns
    -------
    conn : AsyncConnection
        A connection to a LanceDB database.
    """
    return AsyncLanceDBConnection(
        await lancedb_connect(
            sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
        )
    )
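As an illustrative aside (not in the diff), the read_consistency_interval option documented above opts a local connection into strong consistency:

from datetime import timedelta

import lancedb

# timedelta(0) makes every read check for updates from other processes;
# the default (None) never checks, which is the fastest option.
db = lancedb.connect("~/.lancedb", read_consistency_interval=timedelta(seconds=0))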
12
python/python/lancedb/_lancedb.pyi
Normal file
@@ -0,0 +1,12 @@
from typing import Optional

class Connection(object):
    async def table_names(self) -> list[str]: ...

async def connect(
    uri: str,
    api_key: Optional[str],
    region: Optional[str],
    host_override: Optional[str],
    read_consistency_interval: Optional[float],
) -> Connection: ...
40
python/python/lancedb/common.py
Normal file
@@ -0,0 +1,40 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Iterable, List, Union

import numpy as np
import pyarrow as pa

from .util import safe_import_pandas

pd = safe_import_pandas()

DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
URI = Union[str, Path]
VECTOR_COLUMN_NAME = "vector"


class Credential(str):
    """Credential field"""

    def __repr__(self) -> str:
        return "********"

    def __str__(self) -> str:
        return "********"


def sanitize_uri(uri: URI) -> str:
    return str(uri)
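A quick sketch (not in the diff) of how the Credential wrapper keeps secrets out of logs while still behaving like a str:

from lancedb.common import Credential

key = Credential("ldb_super_secret")
print(repr(key))                  # ******** -- the raw value is masked
print(key == "ldb_super_secret")  # True -- it still compares as a plain str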
64
python/python/lancedb/conftest.py
Normal file
@@ -0,0 +1,64 @@
import os
import time

import numpy as np
import pytest

from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction

# import lancedb so we don't have to in every example


@pytest.fixture(autouse=True)
def doctest_setup(monkeypatch, tmpdir):
    # disable color for doctests so we don't have to include
    # escape codes in docstrings
    monkeypatch.setitem(os.environ, "NO_COLOR", "1")
    # Explicitly set the column width
    monkeypatch.setitem(os.environ, "COLUMNS", "80")
    # Work in a temporary directory
    monkeypatch.chdir(tmpdir)


registry = EmbeddingFunctionRegistry.get_instance()


@registry.register("test")
class MockTextEmbeddingFunction(TextEmbeddingFunction):
    """
    Return the hash of the first 10 characters
    """

    def generate_embeddings(self, texts):
        return [self._compute_one_embedding(row) for row in texts]

    def _compute_one_embedding(self, row):
        emb = np.array([float(hash(c)) for c in row[:10]])
        emb /= np.linalg.norm(emb)
        return emb

    def ndims(self):
        return 10


class RateLimitedAPI:
    rate_limit = 0.1  # 1 request per 0.1 second
    last_request_time = 0

    @staticmethod
    def make_request():
        current_time = time.time()

        if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit:
            raise Exception("Rate limit exceeded. Please try again later.")

        # Simulate a successful request
        RateLimitedAPI.last_request_time = current_time
        return "Request successful"


@registry.register("test-rate-limited")
class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction):
    def generate_embeddings(self, texts):
        RateLimitedAPI.make_request()
        return [self._compute_one_embedding(row) for row in texts]
246
python/python/lancedb/context.py
Normal file
@@ -0,0 +1,246 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import deprecation

from . import __version__
from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas

pd = safe_import_pandas()


def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
    """Create a Contextualizer object for the given DataFrame.

    Used to create context windows. Context windows are rolling subsets of text
    data.

    The input text column should already be separated into rows that will be the
    unit of the window. So to create a context window over tokens, start with
    a DataFrame with one token per row. To create a context window over sentences,
    start with a DataFrame with one sentence per row.

    Examples
    --------
    >>> from lancedb.context import contextualize
    >>> import pandas as pd
    >>> data = pd.DataFrame({
    ...    'token': ['The', 'quick', 'brown', 'fox', 'jumped', 'over',
    ...              'the', 'lazy', 'dog', 'I', 'love', 'sandwiches'],
    ...    'document_id': [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2]
    ... })

    ``window`` determines how many rows to include in each window. In our case
    this is how many tokens, but depending on the input data, it could be
    sentences, paragraphs, messages, etc.

    >>> contextualize(data).window(3).stride(1).text_col('token').to_pandas()
                    token  document_id
    0     The quick brown            1
    1     quick brown fox            1
    2    brown fox jumped            1
    3     fox jumped over            1
    4     jumped over the            1
    5       over the lazy            1
    6        the lazy dog            1
    7          lazy dog I            1
    8          dog I love            1
    9   I love sandwiches            2
    10    love sandwiches            2
    >>> (contextualize(data).window(7).stride(1).min_window_size(7)
    ...     .text_col('token').to_pandas())
                                      token  document_id
    0   The quick brown fox jumped over the            1
    1  quick brown fox jumped over the lazy            1
    2    brown fox jumped over the lazy dog            1
    3        fox jumped over the lazy dog I            1
    4       jumped over the lazy dog I love            1
    5   over the lazy dog I love sandwiches            1

    ``stride`` determines how many rows to skip between each window start. This can
    be used to reduce the total number of windows generated.

    >>> contextualize(data).window(4).stride(2).text_col('token').to_pandas()
                        token  document_id
    0     The quick brown fox            1
    2   brown fox jumped over            1
    4    jumped over the lazy            1
    6          the lazy dog I            1
    8   dog I love sandwiches            1
    10        love sandwiches            2

    ``groupby`` determines how to group the rows. For example, we would like to have
    context windows that don't cross document boundaries. In this case, we can
    pass ``document_id`` as the group by.

    >>> (contextualize(data)
    ...     .window(4).stride(2).text_col('token').groupby('document_id')
    ...     .to_pandas())
                       token  document_id
    0    The quick brown fox            1
    2  brown fox jumped over            1
    4   jumped over the lazy            1
    6           the lazy dog            1
    9      I love sandwiches            2

    ``min_window_size`` determines the minimum size of the context windows
    that are generated. This can be used to trim the last few context windows
    which have size less than ``min_window_size``.
    By default context windows of size 1 are skipped.

    >>> (contextualize(data)
    ...     .window(6).stride(3).text_col('token').groupby('document_id')
    ...     .to_pandas())
                                 token  document_id
    0  The quick brown fox jumped over            1
    3     fox jumped over the lazy dog            1
    6                     the lazy dog            1
    9                I love sandwiches            2

    >>> (contextualize(data)
    ...     .window(6).stride(3).min_window_size(4).text_col('token')
    ...     .groupby('document_id')
    ...     .to_pandas())
                                 token  document_id
    0  The quick brown fox jumped over            1
    3     fox jumped over the lazy dog            1

    """
    return Contextualizer(raw_df)


class Contextualizer:
    """Create context windows from a DataFrame.
    See [lancedb.context.contextualize][].
    """

    def __init__(self, raw_df):
        self._text_col = None
        self._groupby = None
        self._stride = None
        self._window = None
        self._min_window_size = 2
        self._raw_df = raw_df

    def window(self, window: int) -> Contextualizer:
        """Set the window size. i.e., how many rows to include in each window.

        Parameters
        ----------
        window: int
            The window size.
        """
        self._window = window
        return self

    def stride(self, stride: int) -> Contextualizer:
        """Set the stride. i.e., how many rows to skip between each window.

        Parameters
        ----------
        stride: int
            The stride.
        """
        self._stride = stride
        return self

    def groupby(self, groupby: str) -> Contextualizer:
        """Set the groupby column. i.e., how to group the rows.
        Windows don't cross groups.

        Parameters
        ----------
        groupby: str
            The groupby column.
        """
        self._groupby = groupby
        return self

    def text_col(self, text_col: str) -> Contextualizer:
        """Set the text column used to make the context window.

        Parameters
        ----------
        text_col: str
            The text column.
        """
        self._text_col = text_col
        return self

    def min_window_size(self, min_window_size: int) -> Contextualizer:
        """Set the (optional) min_window_size for the context window.

        Parameters
        ----------
        min_window_size: int
            The min_window_size.
        """
        self._min_window_size = min_window_size
        return self

    @deprecation.deprecated(
        deprecated_in="0.3.1",
        removed_in="0.4.0",
        current_version=__version__,
        details="Use to_pandas() instead",
    )
    def to_df(self) -> "pd.DataFrame":
        return self.to_pandas()

    def to_pandas(self) -> "pd.DataFrame":
        """Create the context windows and return a DataFrame."""
        if pd is None:
            raise ImportError(
                "pandas is required to create context windows using lancedb"
            )

        if self._text_col not in self._raw_df.columns.tolist():
            raise MissingColumnError(self._text_col)

        if self._window is None or self._window < 1:
            raise MissingValueError(
                "The value of window is None or less than 1. Specify the "
                "window size (number of rows to include in each window)"
            )

        if self._stride is None or self._stride < 1:
            raise MissingValueError(
                "The value of stride is None or less than 1. Specify the "
                "stride (number of rows to skip between each window)"
            )

        def process_group(grp):
            # For each group, create the text rolling window
            # with values of size >= min_window_size
            text = grp[self._text_col].values
            contexts = grp.iloc[:: self._stride, :].copy()
            windows = [
                " ".join(text[start_i : min(start_i + self._window, len(grp))])
                for start_i in range(0, len(grp), self._stride)
                if start_i + self._window <= len(grp)
                or len(grp) - start_i >= self._min_window_size
            ]
            # if last few rows dropped
            if len(windows) < len(contexts):
                contexts = contexts.iloc[: len(windows)]
            contexts[self._text_col] = windows
            return contexts

        if self._groupby is None:
            return process_group(self._raw_df)
        # concat result from all groups
        return pd.concat(
            [process_group(grp) for _, grp in self._raw_df.groupby(self._groupby)]
        )
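To make the window/stride arithmetic in process_group concrete, here is a small standalone sketch (a hypothetical helper, not in the diff) that lists the window start indices the same way the list comprehension above does:

def window_starts(n_rows, window, stride, min_window_size=2):
    # A start index qualifies if a full window fits, or if the remaining
    # tail is still at least min_window_size rows long.
    return [
        i
        for i in range(0, n_rows, stride)
        if i + window <= n_rows or n_rows - i >= min_window_size
    ]


# 12 tokens, window 3, stride 1: starts 0..9 fit fully, start 10 keeps a
# 2-row tail, start 11 is dropped -> 11 windows, matching the docstring.
print(window_starts(12, 3, 1))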
673
python/python/lancedb/db.py
Normal file
@@ -0,0 +1,673 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Optional, Union

import pyarrow as pa
from overrides import EnforceOverrides, override
from pyarrow import fs

from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri

if TYPE_CHECKING:
    from datetime import timedelta

    from ._lancedb import Connection as LanceDbConnection
    from .common import DATA, URI
    from .embeddings import EmbeddingFunctionConfig
    from .pydantic import LanceModel


class DBConnection(EnforceOverrides):
    """An active LanceDB connection interface."""

    @abstractmethod
    def table_names(
        self, page_token: Optional[str] = None, limit: int = 10
    ) -> Iterable[str]:
        """List all tables in this database, in sorted order.

        Parameters
        ----------
        page_token: str, optional
            The token to use for pagination. If not present, start from the beginning.
            Typically, this token is the last table name from the previous page.
            Only supported by LanceDB Cloud.
        limit: int, default 10
            The size of the page to return.
            Only supported by LanceDB Cloud.

        Returns
        -------
        Iterable of str
        """
        pass

    @abstractmethod
    def create_table(
        self,
        name: str,
        data: Optional[DATA] = None,
        schema: Optional[Union[pa.Schema, LanceModel]] = None,
        mode: str = "create",
        exist_ok: bool = False,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> Table:
        """Create a [Table][lancedb.table.Table] in the database.

        Parameters
        ----------
        name: str
            The name of the table.
        data: The data to initialize the table, *optional*
            User must provide at least one of `data` or `schema`.
            Acceptable types are:

            - dict or list-of-dict

            - pandas.DataFrame

            - pyarrow.Table or pyarrow.RecordBatch
        schema: The schema of the table, *optional*
            Acceptable types are:

            - pyarrow.Schema

            - [LanceModel][lancedb.pydantic.LanceModel]
        mode: str; default "create"
            The mode to use when creating the table.
            Can be either "create" or "overwrite".
            By default, if the table already exists, an exception is raised.
            If you want to overwrite the table, use mode="overwrite".
        exist_ok: bool, default False
            If a table by the same name already exists, then raise an exception
            if exist_ok=False. If exist_ok=True, then open the existing table;
            it will not add the provided data but will validate against any
            schema that's specified.
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contain NaNs.
            One of "error", "drop", "fill".
        fill_value: float
            The value to use when filling vectors. Only used if on_bad_vectors="fill".

        Returns
        -------
        LanceTable
            A reference to the newly created table.

        !!! note

            The vector index won't be created by default.
            To create the index, call the `create_index` method on the table.

        Examples
        --------

        Can create with list of tuples or dictionaries:

        >>> import lancedb
        >>> db = lancedb.connect("./.lancedb")
        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        >>> db.create_table("my_table", data)
        LanceTable(connection=..., name="my_table")
        >>> db["my_table"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
        lat: double
        long: double
        ----
        vector: [[[1.1,1.2],[0.2,1.8]]]
        lat: [[45.5,40.1]]
        long: [[-122.7,-74.1]]

        You can also pass a pandas DataFrame:

        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...    "vector": [[1.1, 1.2], [0.2, 1.8]],
        ...    "lat": [45.5, 40.1],
        ...    "long": [-122.7, -74.1]
        ... })
        >>> db.create_table("table2", data)
        LanceTable(connection=..., name="table2")
        >>> db["table2"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
        lat: double
        long: double
        ----
        vector: [[[1.1,1.2],[0.2,1.8]]]
        lat: [[45.5,40.1]]
        long: [[-122.7,-74.1]]

        Data is converted to Arrow before being written to disk. For maximum
        control over how data is saved, either provide the PyArrow schema to
        convert to or else provide a [PyArrow Table](pyarrow.Table) directly.

        >>> custom_schema = pa.schema([
        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
        ...   pa.field("lat", pa.float32()),
        ...   pa.field("long", pa.float32())
        ... ])
        >>> db.create_table("table3", data, schema = custom_schema)
        LanceTable(connection=..., name="table3")
        >>> db["table3"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
        lat: float
        long: float
        ----
        vector: [[[1.1,1.2],[0.2,1.8]]]
        lat: [[45.5,40.1]]
        long: [[-122.7,-74.1]]


        It is also possible to create a table from an `Iterable[pa.RecordBatch]`:


        >>> import pyarrow as pa
        >>> def make_batches():
        ...     for i in range(5):
        ...         yield pa.RecordBatch.from_arrays(
        ...             [
        ...                 pa.array([[3.1, 4.1], [5.9, 26.5]],
        ...                          pa.list_(pa.float32(), 2)),
        ...                 pa.array(["foo", "bar"]),
        ...                 pa.array([10.0, 20.0]),
        ...             ],
        ...             ["vector", "item", "price"],
        ...         )
        >>> schema=pa.schema([
        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
        ...   pa.field("item", pa.utf8()),
        ...   pa.field("price", pa.float32()),
        ... ])
        >>> db.create_table("table4", make_batches(), schema=schema)
        LanceTable(connection=..., name="table4")

        """
        raise NotImplementedError

    def __getitem__(self, name: str) -> LanceTable:
        return self.open_table(name)

    def open_table(self, name: str) -> Table:
        """Open a Lance Table in the database.

        Parameters
        ----------
        name: str
            The name of the table.

        Returns
        -------
        A LanceTable object representing the table.
        """
        raise NotImplementedError

    def drop_table(self, name: str):
        """Drop a table from the database.

        Parameters
        ----------
        name: str
            The name of the table.
        """
        raise NotImplementedError

    def drop_database(self):
        """
        Drop the database.
        This is the same thing as dropping all the tables.
        """
        raise NotImplementedError


class LanceDBConnection(DBConnection):
    """
    A connection to a LanceDB database.

    Parameters
    ----------
    uri: str or Path
        The root uri of the database.
    read_consistency_interval: timedelta, default None
        The interval at which to check for updates to the table from other
        processes. If None, then consistency is not checked. For performance
        reasons, this is the default. For strong consistency, set this to
        zero seconds. Then every read will check for updates from other
        processes. As a compromise, you can set this to a non-zero timedelta
        for eventual consistency. If more than that interval has passed since
        the last check, then the table will be checked for updates. Note: this
        consistency only applies to read operations. Write operations are
        always consistent.

    Examples
    --------
    >>> import lancedb
    >>> db = lancedb.connect("./.lancedb")
    >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
    ...                                   {"vector": [0.5, 1.3], "b": 4}])
    LanceTable(connection=..., name="my_table")
    >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
    LanceTable(connection=..., name="another_table")
    >>> sorted(db.table_names())
    ['another_table', 'my_table']
    >>> len(db)
    2
    >>> db["my_table"]
    LanceTable(connection=..., name="my_table")
    >>> "my_table" in db
    True
    >>> db.drop_table("my_table")
    >>> db.drop_table("another_table")
    """

    def __init__(
        self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
    ):
        if not isinstance(uri, Path):
            scheme = get_uri_scheme(uri)
        is_local = isinstance(uri, Path) or scheme == "file"
        if is_local:
            if isinstance(uri, str):
                uri = Path(uri)
            uri = uri.expanduser().absolute()
            Path(uri).mkdir(parents=True, exist_ok=True)
        self._uri = str(uri)

        self._entered = False
        self.read_consistency_interval = read_consistency_interval

    def __repr__(self) -> str:
        val = f"{self.__class__.__name__}({self._uri}"
        if self.read_consistency_interval is not None:
            val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
        val += ")"
        return val

    @property
    def uri(self) -> str:
        return self._uri

    @override
    def table_names(
        self, page_token: Optional[str] = None, limit: int = 10
    ) -> Iterable[str]:
        """Get the names of all tables in the database. The names are sorted.

        Returns
        -------
        Iterator of str.
            A list of table names.
        """
        try:
            filesystem = fs_from_uri(self.uri)[0]
        except pa.ArrowInvalid:
            raise NotImplementedError("Unsupported scheme: " + self.uri)

        try:
            loc = get_uri_location(self.uri)
            paths = filesystem.get_file_info(fs.FileSelector(loc))
        except FileNotFoundError:
            # It is ok if the file does not exist since it will be created
            paths = []
        tables = [
            os.path.splitext(file_info.base_name)[0]
            for file_info in paths
            if file_info.extension == "lance"
        ]
        tables.sort()
        return tables

    def __len__(self) -> int:
        return len(self.table_names())

    def __contains__(self, name: str) -> bool:
        return name in self.table_names()

    @override
    def create_table(
        self,
        name: str,
        data: Optional[DATA] = None,
        schema: Optional[Union[pa.Schema, LanceModel]] = None,
        mode: str = "create",
        exist_ok: bool = False,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> LanceTable:
        """Create a table in the database.

        See Also
        --------
        DBConnection.create_table
        """
        if mode.lower() not in ["create", "overwrite"]:
            raise ValueError("mode must be either 'create' or 'overwrite'")

        tbl = LanceTable.create(
            self,
            name,
            data,
            schema,
            mode=mode,
            exist_ok=exist_ok,
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
            embedding_functions=embedding_functions,
        )
        return tbl

    @override
    def open_table(self, name: str) -> LanceTable:
        """Open a table in the database.

        Parameters
        ----------
        name: str
            The name of the table.

        Returns
        -------
        A LanceTable object representing the table.
        """
        return LanceTable.open(self, name)

    @override
    def drop_table(self, name: str, ignore_missing: bool = False):
        """Drop a table from the database.

        Parameters
        ----------
        name: str
            The name of the table.
        ignore_missing: bool, default False
            If True, ignore if the table does not exist.
        """
        try:
            filesystem, path = fs_from_uri(self.uri)
            table_path = join_uri(path, name + ".lance")
            filesystem.delete_dir(table_path)
        except FileNotFoundError:
            if not ignore_missing:
                raise

    @override
    def drop_database(self):
        filesystem, path = fs_from_uri(self.uri)
        filesystem.delete_dir(path)


class AsyncConnection(EnforceOverrides):
    """An active LanceDB connection interface."""

    @abstractmethod
    async def table_names(
        self, *, page_token: Optional[str] = None, limit: int = 10
    ) -> Iterable[str]:
        """List all tables in this database, in sorted order.

        Parameters
        ----------
        page_token: str, optional
            The token to use for pagination. If not present, start from the beginning.
            Typically, this token is the last table name from the previous page.
            Only supported by LanceDB Cloud.
        limit: int, default 10
            The size of the page to return.
            Only supported by LanceDB Cloud.

        Returns
        -------
        Iterable of str
        """
        pass

    @abstractmethod
    async def create_table(
        self,
        name: str,
        data: Optional[DATA] = None,
        schema: Optional[Union[pa.Schema, LanceModel]] = None,
        mode: str = "create",
        exist_ok: bool = False,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> Table:
        """Create a [Table][lancedb.table.Table] in the database.

        Parameters
        ----------
        name: str
            The name of the table.
        data: The data to initialize the table, *optional*
            User must provide at least one of `data` or `schema`.
            Acceptable types are:

            - dict or list-of-dict

            - pandas.DataFrame

            - pyarrow.Table or pyarrow.RecordBatch
        schema: The schema of the table, *optional*
            Acceptable types are:

            - pyarrow.Schema

            - [LanceModel][lancedb.pydantic.LanceModel]
        mode: str; default "create"
            The mode to use when creating the table.
            Can be either "create" or "overwrite".
            By default, if the table already exists, an exception is raised.
            If you want to overwrite the table, use mode="overwrite".
        exist_ok: bool, default False
            If a table by the same name already exists, then raise an exception
            if exist_ok=False. If exist_ok=True, then open the existing table;
            it will not add the provided data but will validate against any
            schema that's specified.
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contain NaNs.
            One of "error", "drop", "fill".
        fill_value: float
            The value to use when filling vectors. Only used if on_bad_vectors="fill".

        Returns
        -------
        LanceTable
            A reference to the newly created table.

        !!! note

            The vector index won't be created by default.
            To create the index, call the `create_index` method on the table.

        Examples
        --------

        Can create with list of tuples or dictionaries:

        >>> import lancedb
        >>> db = lancedb.connect("./.lancedb")
        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        >>> db.create_table("my_table", data)
        LanceTable(connection=..., name="my_table")
        >>> db["my_table"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
        lat: double
        long: double
        ----
        vector: [[[1.1,1.2],[0.2,1.8]]]
        lat: [[45.5,40.1]]
        long: [[-122.7,-74.1]]

        You can also pass a pandas DataFrame:

        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...    "vector": [[1.1, 1.2], [0.2, 1.8]],
        ...    "lat": [45.5, 40.1],
        ...    "long": [-122.7, -74.1]
        ... })
        >>> db.create_table("table2", data)
        LanceTable(connection=..., name="table2")
        >>> db["table2"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
        lat: double
        long: double
        ----
        vector: [[[1.1,1.2],[0.2,1.8]]]
        lat: [[45.5,40.1]]
        long: [[-122.7,-74.1]]

        Data is converted to Arrow before being written to disk. For maximum
        control over how data is saved, either provide the PyArrow schema to
        convert to or else provide a [PyArrow Table](pyarrow.Table) directly.

        >>> custom_schema = pa.schema([
        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
        ...   pa.field("lat", pa.float32()),
        ...   pa.field("long", pa.float32())
        ... ])
        >>> db.create_table("table3", data, schema = custom_schema)
        LanceTable(connection=..., name="table3")
        >>> db["table3"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
          child 0, item: float
        lat: float
        long: float
        ----
        vector: [[[1.1,1.2],[0.2,1.8]]]
        lat: [[45.5,40.1]]
        long: [[-122.7,-74.1]]


        It is also possible to create a table from an `Iterable[pa.RecordBatch]`:


        >>> import pyarrow as pa
        >>> def make_batches():
        ...     for i in range(5):
        ...         yield pa.RecordBatch.from_arrays(
        ...             [
        ...                 pa.array([[3.1, 4.1], [5.9, 26.5]],
        ...                          pa.list_(pa.float32(), 2)),
        ...                 pa.array(["foo", "bar"]),
        ...                 pa.array([10.0, 20.0]),
        ...             ],
        ...             ["vector", "item", "price"],
        ...         )
        >>> schema=pa.schema([
        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
        ...   pa.field("item", pa.utf8()),
        ...   pa.field("price", pa.float32()),
        ... ])
        >>> db.create_table("table4", make_batches(), schema=schema)
        LanceTable(connection=..., name="table4")

        """
        raise NotImplementedError

    async def open_table(self, name: str) -> Table:
        """Open a Lance Table in the database.

        Parameters
        ----------
        name: str
            The name of the table.

        Returns
        -------
        A LanceTable object representing the table.
        """
        raise NotImplementedError

    async def drop_table(self, name: str):
        """Drop a table from the database.

        Parameters
        ----------
        name: str
            The name of the table.
        """
        raise NotImplementedError

    async def drop_database(self):
        """
        Drop the database.
        This is the same thing as dropping all the tables.
        """
        raise NotImplementedError


class AsyncLanceDBConnection(AsyncConnection):
    def __init__(self, connection: LanceDbConnection):
        self._inner = connection

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._inner!r})"

    @override
    async def table_names(
        self,
        *,
        page_token=None,
        limit=None,
    ) -> Iterable[str]:
        return await self._inner.table_names()

    @override
    async def create_table(
        self,
        name: str,
        data: Optional[DATA] = None,
        schema: Optional[Union[pa.Schema, LanceModel]] = None,
        mode: str = "create",
        exist_ok: bool = False,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> LanceTable:
        raise NotImplementedError

    @override
    async def open_table(self, name: str) -> LanceTable:
        raise NotImplementedError

    @override
    async def drop_table(self, name: str, ignore_missing: bool = False):
        raise NotImplementedError

    @override
    async def drop_database(self):
        raise NotImplementedError
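For intuition, table_names on a local database boils down to listing the *.lance entries under the database root; a simplified sketch (plain-os version, not the pyarrow.fs implementation above):

import os


def table_names_sketch(db_root: str) -> list:
    # Each table lives in a "<name>.lance" directory under the root;
    # the table name is the entry name without the extension.
    return sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(db_root)
        if entry.endswith(".lance")
    )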
24
python/python/lancedb/embeddings/__init__.py
Normal file
@@ -0,0 +1,24 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .bedrock import BedRockText
from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction
from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry
from .sentence_transformers import SentenceTransformerEmbeddings
from .utils import with_embeddings
161
python/python/lancedb/embeddings/base.py
Normal file
@@ -0,0 +1,161 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr

from .utils import TEXT, retry_with_exponential_backoff


class EmbeddingFunction(BaseModel, ABC):
    """
    An ABC for embedding functions.

    All concrete embedding functions must implement the following:
    1. compute_query_embeddings() which takes a query and returns a list of embeddings
    2. compute_source_embeddings() which returns a list of embeddings for the source
       column. For text data, the two will be the same. For multi-modal data, the
       source column might be images and the vector column might be text.
    3. ndims() which returns the number of dimensions of the vector column
    """

    __slots__ = ("__weakref__",)  # pydantic 1.x compatibility
    max_retries: int = (
        7  # Setting 0 disables retries. Maybe this should not be enabled by default.
    )
    _ndims: int = PrivateAttr()

    @classmethod
    def create(cls, **kwargs):
        """
        Create an instance of the embedding function
        """
        return cls(**kwargs)

    @abstractmethod
    def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for a given user query
        """
        pass

    @abstractmethod
    def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for the source column in the database
        """
        pass

    def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for a given user query with retries
        """
        return retry_with_exponential_backoff(
            self.compute_query_embeddings, max_retries=self.max_retries
        )(
            *args,
            **kwargs,
        )

    def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
        """
        Compute the embeddings for the source column in the database with retries
        """
        return retry_with_exponential_backoff(
            self.compute_source_embeddings, max_retries=self.max_retries
        )(*args, **kwargs)

    def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, pa.Array):
            texts = texts.to_pylist()
        elif isinstance(texts, pa.ChunkedArray):
            texts = texts.combine_chunks().to_pylist()
        return texts

    def safe_model_dump(self):
        from ..pydantic import PYDANTIC_VERSION

        if PYDANTIC_VERSION.major < 2:
            return dict(self)
        return self.model_dump()

    @abstractmethod
    def ndims(self):
        """
        Return the dimensions of the vector column
        """
        pass

    def SourceField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the source column for this embedding function
        """
        return Field(json_schema_extra={"source_column_for": self}, **kwargs)

    def VectorField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the target vector column for this embedding function
        """
        return Field(json_schema_extra={"vector_column_for": self}, **kwargs)

    def __eq__(self, __value: object) -> bool:
        if not hasattr(__value, "__dict__"):
            return False
        return vars(self) == vars(__value)

    def __hash__(self) -> int:
        return hash(frozenset(vars(self).items()))


class EmbeddingFunctionConfig(BaseModel):
    """
    This model encapsulates the configuration for an embedding function
    in a lancedb table. It holds the embedding function, the source column,
    and the vector column
    """

    vector_column: str
    source_column: str
    function: EmbeddingFunction


class TextEmbeddingFunction(EmbeddingFunction):
    """
    A callable ABC for embedding functions that take text as input
    """

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        return self.compute_source_embeddings(query, *args, **kwargs)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    @abstractmethod
    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Generate the embeddings for the given texts
        """
        pass
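As a minimal subclass sketch (mirroring the mock in conftest.py; the "char-hash" registry name is made up), a custom text embedding function only needs generate_embeddings and ndims:

import numpy as np

from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction

registry = EmbeddingFunctionRegistry.get_instance()


@registry.register("char-hash")  # hypothetical name
class CharHashEmbeddings(TextEmbeddingFunction):
    def generate_embeddings(self, texts):
        # One unit-normalized vector per input string.
        return [self._embed(t) for t in texts]

    def _embed(self, text):
        emb = np.array([float(hash(c)) for c in text[:10]])
        return emb / np.linalg.norm(emb)

    def ndims(self):
        return 10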
224
python/python/lancedb/embeddings/bedrock.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT
|
||||
|
||||
|
||||
@register("bedrock-text")
|
||||
class BedRockText(TextEmbeddingFunction):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "amazon.titan-embed-text-v1"
|
||||
The model ID of the bedrock model to use. Supported models for are:
|
||||
- amazon.titan-embed-text-v1
|
||||
- cohere.embed-english-v3
|
||||
- cohere.embed-multilingual-v3
|
||||
region: str, default "us-east-1"
|
||||
Optional name of the AWS Region in which the service should be called.
|
||||
profile_name: str, default None
|
||||
Optional name of the AWS profile to use for calling the Bedrock service.
|
||||
If not specified, the default profile will be used.
|
||||
assumed_role: str, default None
|
||||
Optional ARN of an AWS IAM role to assume for calling the Bedrock service.
|
||||
If not specified, the current active credentials will be used.
|
||||
role_session_name: str, default "lancedb-embeddings"
|
||||
Optional name of the AWS IAM role session to use for calling the Bedrock
|
||||
service. If not specified, "lancedb-embeddings" name will be used.
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
import pandas as pd
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
model = get_registry().get("bedrock-text").create()
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("tmp_path")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
|
||||
rs = tbl.search("hello").limit(1).to_pandas()
|
||||
"""
|
||||
|
||||
name: str = "amazon.titan-embed-text-v1"
|
||||
region: str = "us-east-1"
|
||||
assumed_role: Union[str, None] = None
|
||||
profile_name: Union[str, None] = None
|
||||
role_session_name: str = "lancedb-embeddings"
|
||||
|
||||
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
# return len(self._generate_embedding("test"))
|
||||
# TODO: fix hardcoding
|
||||
if self.name == "amazon.titan-embed-text-v1":
|
||||
return 1536
|
||||
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
|
||||
return 1024
|
||||
else:
|
||||
raise ValueError(f"Unknown model name: {self.name}")
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: str, *args, **kwargs
|
||||
) -> List[List[float]]:
|
||||
return self.compute_source_embeddings(query)
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, texts: TEXT, *args, **kwargs
|
||||
) -> List[List[float]]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[list[float]]
|
||||
The embeddings for the given texts
|
||||
"""
|
||||
results = []
|
||||
for text in texts:
|
||||
response = self._generate_embedding(text)
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
def _generate_embedding(self, text: str) -> List[float]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: str
|
||||
The texts to embed
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[float]
|
||||
The embeddings for the given texts
|
||||
"""
|
||||
# format input body for provider
|
||||
provider = self.name.split(".")[0]
|
||||
_model_kwargs = {}
|
||||
input_body = {**_model_kwargs}
|
||||
if provider == "cohere":
|
||||
if "input_type" not in input_body.keys():
|
||||
input_body["input_type"] = "search_document"
|
||||
input_body["texts"] = [text]
|
||||
else:
|
||||
# includes common provider == "amazon"
|
||||
input_body["inputText"] = text
|
||||
body = json.dumps(input_body)
|
||||
|
||||
try:
|
||||
# invoke bedrock API
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=self.name,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
# format output based on provider
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if provider == "cohere":
|
||||
return response_body.get("embeddings")[0]
|
||||
else:
|
||||
# includes common provider == "amazon"
|
||||
return response_body.get("embedding")
|
||||
except Exception as e:
|
||||
help_txt = """
|
||||
boto3 client failed to invoke the bedrock API. In case of
|
||||
AWS credentials error:
|
||||
- Please check your AWS credentials and ensure that you have access.
|
||||
You can set up aws credentials using `aws configure` command and
|
||||
verify by running `aws sts get-caller-identity` in your terminal.
|
||||
"""
|
||||
raise ValueError(f"Error raised by boto3 client: {e}. \n {help_txt}")
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
"""Create a boto3 client for Amazon Bedrock service
|
||||
|
||||
Returns
|
||||
-------
|
||||
boto3.client
|
||||
The boto3 client for Amazon Bedrock service
|
||||
"""
|
||||
botocore = attempt_import_or_raise("botocore")
|
||||
boto3 = attempt_import_or_raise("boto3")
|
||||
|
||||
session_kwargs = {"region_name": self.region}
|
||||
client_kwargs = {**session_kwargs}
|
||||
|
||||
if self.profile_name:
|
||||
session_kwargs["profile_name"] = self.profile_name
|
||||
|
||||
retry_config = botocore.config.Config(
|
||||
region_name=self.region,
|
||||
retries={
|
||||
"max_attempts": 0, # disable this as retries retries are handled
|
||||
"mode": "standard",
|
||||
},
|
||||
)
|
||||
session = (
|
||||
boto3.Session(**session_kwargs) if self.profile_name else boto3.Session()
|
||||
)
|
||||
if self.assumed_role: # if not using default credentials
|
||||
sts = session.client("sts")
|
||||
response = sts.assume_role(
|
||||
RoleArn=str(self.assumed_role),
|
||||
RoleSessionName=self.role_session_name,
|
||||
)
|
||||
client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
|
||||
client_kwargs["aws_secret_access_key"] = response["Credentials"][
|
||||
"SecretAccessKey"
|
||||
]
|
||||
client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
|
||||
|
||||
service_name = "bedrock-runtime"
|
||||
|
||||
bedrock_client = session.client(
|
||||
service_name=service_name, config=retry_config, **client_kwargs
|
||||
)
|
||||
|
||||
return bedrock_client
|
||||
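For orientation, a minimal end-to-end sketch of the Bedrock function above. The registry alias ("bedrock-text") and the model id ("amazon.titan-embed-text-v1") are assumptions, since the class declaration and its `@register(...)` line fall outside this excerpt; valid AWS credentials are also assumed.

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# "bedrock-text" and "amazon.titan-embed-text-v1" are assumed names here
model = get_registry().get("bedrock-text").create(name="amazon.titan-embed-text-v1")

class TextModel(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("bedrock_test", schema=TextModel, mode="overwrite")
tbl.add([{"text": "hello world"}, {"text": "goodbye world"}])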
92
python/python/lancedb/embeddings/cohere.py
Normal file
@@ -0,0 +1,92 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import ClassVar, List, Union

import numpy as np

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help


@register("cohere")
class CohereEmbeddingFunction(TextEmbeddingFunction):
    """
    An embedding function that uses the Cohere API

    https://docs.cohere.com/docs/multilingual-language-models

    Parameters
    ----------
    name: str, default "embed-multilingual-v2.0"
        The name of the model to use. See the Cohere documentation for
        a list of available models.

    Examples
    --------
    import lancedb
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import EmbeddingFunctionRegistry

    cohere = (
        EmbeddingFunctionRegistry
        .get_instance()
        .get("cohere")
        .create(name="embed-multilingual-v2.0")
    )

    class TextModel(LanceModel):
        text: str = cohere.SourceField()
        vector: Vector(cohere.ndims()) = cohere.VectorField()

    data = [{"text": "hello world"},
            {"text": "goodbye world"}]

    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(data)
    """

    name: str = "embed-multilingual-v2.0"
    client: ClassVar = None

    def ndims(self):
        # TODO: fix hardcoding
        return 768

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)

        return [emb for emb in rs.embeddings]

    def _init_client(self):
        cohere = attempt_import_or_raise("cohere")
        if CohereEmbeddingFunction.client is None:
            if os.environ.get("COHERE_API_KEY") is None:
                api_key_not_found_help("cohere")
            CohereEmbeddingFunction.client = cohere.Client(os.environ["COHERE_API_KEY"])
142
python/python/lancedb/embeddings/gemini_text.py
Normal file
@@ -0,0 +1,142 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import cached_property
from typing import List, Union

import numpy as np

from lancedb.pydantic import PYDANTIC_VERSION

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, api_key_not_found_help


@register("gemini-text")
class GeminiText(TextEmbeddingFunction):
    """
    An embedding function that uses Google's Gemini API. Requires GOOGLE_API_KEY to
    be set.

    https://ai.google.dev/docs/embeddings_guide

    Supports various task types:
    | Task Type               | Description                                            |
    |-------------------------|--------------------------------------------------------|
    | "`retrieval_query`"     | Specifies the given text is a query in a               |
    |                         | search/retrieval setting.                              |
    | "`retrieval_document`"  | Specifies the given text is a document in a            |
    |                         | search/retrieval setting. Using this task type         |
    |                         | requires a title, but it is automatically provided     |
    |                         | by the Embeddings API.                                 |
    | "`semantic_similarity`" | Specifies the given text will be used for Semantic     |
    |                         | Textual Similarity (STS).                              |
    | "`classification`"      | Specifies that the embeddings will be used for         |
    |                         | classification.                                        |
    | "`clustering`"          | Specifies that the embeddings will be used for         |
    |                         | clustering.                                            |

    Note: The supported task types might change in the Gemini API, but as long as a
    supported task type and its argument set is provided, those will be delegated
    to the API calls.

    Parameters
    ----------
    name: str, default "models/embedding-001"
        The name of the model to use. See the Gemini documentation for a list of
        available models.

    query_task_type: str, default "retrieval_query"
        Sets the task type for queries.
    source_task_type: str, default "retrieval_document"
        Sets the task type for ingestion.

    Examples
    --------
    import lancedb
    import pandas as pd
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import get_registry

    model = get_registry().get("gemini-text").create()

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    rs = tbl.search("hello").limit(1).to_pandas()
    """

    name: str = "models/embedding-001"
    query_task_type: str = "retrieval_query"
    source_task_type: str = "retrieval_document"

    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat

        class Config:
            keep_untouched = (cached_property,)

    def ndims(self):
        # TODO: fix hardcoding
        return 768

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        return self.compute_source_embeddings(query, task_type=self.query_task_type)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        texts = self.sanitize_input(texts)
        task_type = (
            kwargs.get("task_type") or self.source_task_type
        )  # assume the source task type if not passed by `compute_query_embeddings`
        return self.generate_embeddings(texts, task_type=task_type)

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        if (
            kwargs.get("task_type") == "retrieval_document"
        ):  # Provide a title to use the existing API design
            title = "Embedding of a document"
            kwargs["title"] = title

        return [
            self.client.embed_content(model=self.name, content=text, **kwargs)[
                "embedding"
            ]
            for text in texts
        ]

    @cached_property
    def client(self):
        genai = attempt_import_or_raise("google.generativeai", "google.generativeai")

        if not os.environ.get("GOOGLE_API_KEY"):
            api_key_not_found_help("google")
        return genai
131
python/python/lancedb/embeddings/gte.py
Normal file
@@ -0,0 +1,131 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru


@register("gte-text")
class GteEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses GTE-LARGE in MLX format (for Apple silicon
    devices only) as well as the standard CPU/GPU version from:
    https://huggingface.co/thenlper/gte-large.

    For Apple users, you will need the mlx package installed, which can be done with:
    pip install mlx

    Parameters
    ----------
    name: str, default "thenlper/gte-large"
        The name of the model to use.
    device: str, default "cpu"
        Sets the device type for the model.
    normalize: bool, default True
        Controls the normalize param in the encode function for the transformer.
    mlx: bool, default False
        Controls which model to use. False for gte-large, True for the MLX version.

    Examples
    --------
    import lancedb
    import lancedb.embeddings.gte
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector
    import pandas as pd

    model = get_registry().get("gte-text").create()  # mlx=True for Apple silicon

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    rs = tbl.search("hello").limit(1).to_pandas()
    """

    name: str = "thenlper/gte-large"
    device: str = "cpu"
    normalize: bool = True
    mlx: bool = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None
        if kwargs:
            self.mlx = kwargs.get("mlx", False)
            if self.mlx is True:
                self.name = "gte-mlx"

    @property
    def embedding_model(self):
        """
        Get the embedding model specified by the flag,
        name and device. This is cached so that the model is only loaded
        once per process.
        """
        return self.get_embedding_model()

    def ndims(self):
        if self.mlx is True:
            self._ndims = self.embedding_model.dims
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings("foo")[0])
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts.

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        if self.mlx is True:
            return self.embedding_model.run(list(texts)).tolist()

        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        ).tolist()

    @weak_lru(maxsize=1)
    def get_embedding_model(self):
        """
        Get the embedding model specified by the flag,
        name and device. This is cached so that the model is only loaded
        once per process.
        """
        if self.mlx is True:
            from .gte_mlx_model import Model

            return Model()
        else:
            sentence_transformers = attempt_import_or_raise(
                "sentence_transformers", "sentence-transformers"
            )
            return sentence_transformers.SentenceTransformer(
                self.name, device=self.device
            )
154
python/python/lancedb/embeddings/gte_mlx_model.py
Normal file
@@ -0,0 +1,154 @@
import json
from typing import List, Optional

import numpy as np
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import BertTokenizer

try:
    import mlx.core as mx
    import mlx.nn as nn
except ImportError:
    raise ImportError("You need to install MLX to use this model - pip install mlx")


def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
    last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
    return last_hidden.sum(axis=1) / attention_mask.sum(axis=1)[..., None]


class ModelConfig(BaseModel):
    dim: int = 1024
    num_attention_heads: int = 16
    num_hidden_layers: int = 24
    vocab_size: int = 30522
    attention_probs_dropout_prob: float = 0.1
    hidden_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12
    max_position_embeddings: int = 512


class TransformerEncoderLayer(nn.Module):
    """
    A transformer encoder layer with (the original BERT) post-normalization.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, mlp_dims)
        self.linear2 = nn.Linear(mlp_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        attention_out = self.attention(x, x, x, mask)
        add_and_norm = self.ln1(x + attention_out)

        ff = self.linear1(add_and_norm)
        ff_gelu = self.gelu(ff)
        ff_out = self.linear2(ff_gelu)
        x = self.ln2(ff_out + add_and_norm)

        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)

        return x


class BertEmbeddings(nn.Module):
    def __init__(self, config: ModelConfig):
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.token_type_embeddings = nn.Embedding(2, config.dim)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.dim
        )
        self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)

    def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
        words = self.word_embeddings(input_ids)
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )
        token_types = self.token_type_embeddings(token_type_ids)

        embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
    def __init__(self, config: ModelConfig):
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.dim,
            num_heads=config.num_attention_heads,
        )
        self.pooler = nn.Linear(config.dim, config.dim)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: mx.array,
        attention_mask: mx.array = None,
    ) -> tuple[mx.array, mx.array]:
        x = self.embeddings(input_ids, token_type_ids)

        if attention_mask is not None:
            # convert 0's to -infs, 1's to 0's, and make it broadcastable
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))

        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))


class Model:
    def __init__(self) -> None:
        # get the converted embedding model
        model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
        with open(f"{model_path}/config.json") as f:
            model_config = ModelConfig(**json.load(f))
        self.dims = model_config.dim
        self.model = Bert(model_config)
        self.model.load_weights(f"{model_path}/model.npz")
        self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
        self.embeddings = []

    def run(self, input_text: List[str]) -> mx.array:
        tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
        tokens = {key: mx.array(v) for key, v in tokens.items()}

        last_hidden_state, _ = self.model(**tokens)

        embeddings = average_pool(
            last_hidden_state, tokens["attention_mask"].astype(mx.float32)
        )
        # L2-normalize the pooled embeddings and return them
        self.embeddings = (
            embeddings / mx.linalg.norm(embeddings, ord=2, axis=1)[..., None]
        )

        return np.array(self.embeddings.astype(mx.float32))
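A minimal sketch of driving this wrapper directly (GteEmbeddings.generate_embeddings calls Model.run the same way). It assumes an Apple silicon machine with mlx installed and network access to download the converted weights:

from lancedb.embeddings.gte_mlx_model import Model

# downloads config.json and model.npz from the vegaluisjose/mlx-rag repo
model = Model()
vecs = model.run(["hello world", "goodbye world"])
print(vecs.shape)  # (2, model.dims), i.e. (2, 1024) for gte-large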
172
python/python/lancedb/embeddings/imagebind.py
Normal file
@@ -0,0 +1,172 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property
from typing import List, Union

import numpy as np
import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import AUDIO, IMAGES, TEXT


@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the ImageBind API
    for generating multi-modal embeddings across
    six different modalities: images, text, audio, depth, thermal, and IMU data.

    To download the package, run:
    `pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
    """

    name: str = "imagebind_huge"
    device: str = "cpu"
    normalize: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ndims = 1024
        self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
        self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")

    @cached_property
    def embedding_model(self):
        """
        Get the embedding model. This is cached so that the model is only loaded
        once per process.
        """
        return self.get_embedding_model()

    @cached_property
    def _data(self):
        """
        Get the data module from imagebind
        """
        data = attempt_import_or_raise("imagebind.data", "imagebind")
        return data

    @cached_property
    def _ModalityType(self):
        """
        Get the ModalityType from imagebind
        """
        imagebind = attempt_import_or_raise("imagebind", "imagebind")
        return imagebind.imagebind_model.ModalityType

    def ndims(self):
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str]
            The query to embed. A query can be either text, an image path or
            an audio path.
        """
        query = self.sanitize_input(query)
        if query[0].endswith(self._audio_extensions):
            return [self.generate_audio_embeddings(query)]
        elif query[0].endswith(self._image_extensions):
            return [self.generate_image_embeddings(query)]
        else:
            return [self.generate_text_embeddings(query)]

    def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        inputs = {
            self._ModalityType.VISION: self._data.load_and_transform_vision_data(
                image, self.device
            )
        }
        with torch.no_grad():
            image_features = self.embedding_model(inputs)[self._ModalityType.VISION]
            if self.normalize:
                image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().squeeze()

    def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        inputs = {
            self._ModalityType.AUDIO: self._data.load_and_transform_audio_data(
                audio, self.device
            )
        }
        with torch.no_grad():
            audio_features = self.embedding_model(inputs)[self._ModalityType.AUDIO]
            if self.normalize:
                audio_features /= audio_features.norm(dim=-1, keepdim=True)
            return audio_features.cpu().numpy().squeeze()

    def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        inputs = {
            self._ModalityType.TEXT: self._data.load_and_transform_text(
                text, self.device
            )
        }
        with torch.no_grad():
            text_features = self.embedding_model(inputs)[self._ModalityType.TEXT]
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def compute_source_embeddings(
        self, source: Union[IMAGES, AUDIO], *args, **kwargs
    ) -> List[np.array]:
        """
        Get the embeddings for the given source field column in the pydantic model.
        """
        source = self.sanitize_input(source)
        embeddings = []
        if source[0].endswith(self._audio_extensions):
            embeddings.extend(self.generate_audio_embeddings(source))
            return embeddings
        elif source[0].endswith(self._image_extensions):
            embeddings.extend(self.generate_image_embeddings(source))
            return embeddings
        else:
            embeddings.extend(self.generate_text_embeddings(source))
            return embeddings

    def sanitize_input(
        self, input: Union[IMAGES, AUDIO]
    ) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(input, (str, bytes)):
            input = [input]
        elif isinstance(input, pa.Array):
            input = input.to_pylist()
        elif isinstance(input, pa.ChunkedArray):
            input = input.combine_chunks().to_pylist()
        return input

    def get_embedding_model(self):
        """
        Fetches the imagebind embedding model
        """
        imagebind = attempt_import_or_raise("imagebind", "imagebind")
        model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
        model.eval()
        model.to(self.device)
        return model
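A hedged usage sketch for the "imagebind" function registered above; the media paths and table name are placeholders:

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

func = get_registry().get("imagebind").create()

class MediaModel(LanceModel):
    uri: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("imagebind_test", schema=MediaModel, mode="overwrite")
tbl.add([{"uri": "/path/to/cat.jpg"}, {"uri": "/path/to/dog.jpg"}])  # placeholder paths

# queries dispatch on file extension: audio paths, image paths, or plain text
rs = tbl.search("a photo of a cat").limit(1).to_pandas()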
149
python/python/lancedb/embeddings/instructor.py
Normal file
@@ -0,0 +1,149 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import numpy as np

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT, weak_lru


@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
    """
    An embedding function that uses the InstructorEmbedding library. Instructor models
    support multi-task learning, and can be used for a variety of tasks, including
    text classification, sentence similarity, and document retrieval. If you want to
    calculate customized embeddings for specific sentences, you may follow the unified
    template to write instructions:
    "Represent the `domain` `text_type` for `task_objective`":

    * domain is optional, and it specifies the domain of the text, e.g., science,
      finance, medicine, etc.
    * text_type is required, and it specifies the encoding unit, e.g., sentence,
      document, paragraph, etc.
    * task_objective is optional, and it specifies the objective of embedding,
      e.g., retrieve a document, classify the sentence, etc.

    For example, if you want to calculate embeddings for a document, you may write the
    instruction as follows:
    "Represent the document for retrieval"

    Parameters
    ----------
    name: str
        The name of the model to use. Available models are listed at
        https://github.com/xlang-ai/instructor-embedding#model-list;
        the default model is hkunlp/instructor-base
    batch_size: int, default 32
        The batch size to use when generating embeddings
    device: str, default "cpu"
        The device to use when generating embeddings
    show_progress_bar: bool, default True
        Whether to show a progress bar when generating embeddings
    normalize_embeddings: bool, default True
        Whether to normalize the embeddings
    quantize: bool, default False
        Whether to quantize the model
    source_instruction: str, default "represent the document for retrieval"
        The instruction for the source column
    query_instruction: str, default "represent the document for retrieving the most
        similar documents"
        The instruction for the query

    Examples
    --------

    import lancedb
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import get_registry, InstructorEmbeddingFunction

    instructor = get_registry().get("instructor").create(
        source_instruction="represent the document for retrieval",
        query_instruction="represent the document for retrieving the most "
                          "similar documents"
    )

    class Schema(LanceModel):
        vector: Vector(instructor.ndims()) = instructor.VectorField()
        text: str = instructor.SourceField()

    db = lancedb.connect("~/.lancedb")
    tbl = db.create_table("test", schema=Schema, mode="overwrite")

    texts = [{"text": "Capitalism has been dominant in the Western world since the "
                      "end of feudalism, but most feel[who?] that..."},
             {"text": "The disparate impact theory is especially controversial under "
                      "the Fair Housing Act because the Act..."},
             {"text": "Disparate impact in United States labor law refers to practices "
                      "in employment, housing, and other areas that.."}]

    tbl.add(texts)
    """

    name: str = "hkunlp/instructor-base"
    batch_size: int = 32
    device: str = "cpu"
    show_progress_bar: bool = True
    normalize_embeddings: bool = True
    quantize: bool = False
    # convert_to_numpy: bool = True  # Hardcoded as numpy can be ingested directly

    source_instruction: str = "represent the document for retrieval"
    query_instruction: str = (
        "represent the document for retrieving the most similar documents"
    )

    @weak_lru(maxsize=1)
    def ndims(self):
        model = self.get_model()
        return model.encode("foo").shape[0]

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        return self.generate_embeddings([[self.query_instruction, query]])

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        texts = self.sanitize_input(texts)
        texts_formatted = [[self.source_instruction, text] for text in texts]
        return self.generate_embeddings(texts_formatted)

    def generate_embeddings(self, texts: List) -> List:
        model = self.get_model()
        res = model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        ).tolist()
        return res

    @weak_lru(maxsize=1)
    def get_model(self):
        instructor_embedding = attempt_import_or_raise(
            "InstructorEmbedding", "InstructorEmbedding"
        )
        torch = attempt_import_or_raise("torch", "torch")

        model = instructor_embedding.INSTRUCTOR(self.name)
        if self.quantize:
            if (
                "qnnpack" in torch.backends.quantized.supported_engines
            ):  # fix for https://github.com/pytorch/pytorch/issues/29327
                torch.backends.quantized.engine = "qnnpack"
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        return model
180
python/python/lancedb/embeddings/open_clip.py
Normal file
@@ -0,0 +1,180 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import io
import os
import urllib.parse as urlparse
from typing import TYPE_CHECKING, List, Union

import numpy as np
import pyarrow as pa
from pydantic import PrivateAttr
from tqdm import tqdm

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve

if TYPE_CHECKING:
    import PIL
    import torch


@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the OpenClip API
    for multi-modal text-to-image search

    https://github.com/mlfoundations/open_clip
    """

    name: str = "ViT-B-32"
    pretrained: str = "laion2b_s34b_b79k"
    device: str = "cpu"
    batch_size: int = 64
    normalize: bool = True
    _model = PrivateAttr()
    _preprocess = PrivateAttr()
    _tokenizer = PrivateAttr()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        open_clip = attempt_import_or_raise("open_clip", "open-clip")
        model, _, preprocess = open_clip.create_model_and_transforms(
            self.name, pretrained=self.pretrained
        )
        model.to(self.device)
        self._model, self._preprocess = model, preprocess
        self._tokenizer = open_clip.get_tokenizer(self.name)
        self._ndims = None

    def ndims(self):
        if self._ndims is None:
            self._ndims = self.generate_text_embeddings("foo").shape[0]
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query

        Parameters
        ----------
        query : Union[str, PIL.Image.Image]
            The query to embed. A query can be either text or an image.
        """
        if isinstance(query, str):
            return [self.generate_text_embeddings(query)]
        else:
            PIL = attempt_import_or_raise("PIL", "pillow")
            if isinstance(query, PIL.Image.Image):
                return [self.generate_image_embedding(query)]
            else:
                raise TypeError("OpenClip supports str or PIL Image as query")

    def generate_text_embeddings(self, text: str) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        text = self.sanitize_input(text)
        text = self._tokenizer(text)
        with torch.no_grad():
            text_features = self._model.encode_text(text.to(self.device))
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(images, (str, bytes)):
            images = [images]
        elif isinstance(images, pa.Array):
            images = images.to_pylist()
        elif isinstance(images, pa.ChunkedArray):
            images = images.combine_chunks().to_pylist()
        return images

    def compute_source_embeddings(
        self, images: IMAGES, *args, **kwargs
    ) -> List[np.array]:
        """
        Get the embeddings for the given images
        """
        images = self.sanitize_input(images)
        embeddings = []
        for i in range(0, len(images), self.batch_size):
            j = min(i + self.batch_size, len(images))
            batch = images[i:j]
            embeddings.extend(self._parallel_get(batch))
        return embeddings

    def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
        """
        Issue concurrent requests to retrieve the image data
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.generate_image_embedding, image)
                for image in images
            ]
            return [future.result() for future in tqdm(futures)]

    def generate_image_embedding(
        self, image: Union[str, bytes, "PIL.Image.Image"]
    ) -> np.ndarray:
        """
        Generate the embedding for a single image

        Parameters
        ----------
        image : Union[str, bytes, PIL.Image.Image]
            The image to embed. If the image is a str, it is treated as a uri.
            If the image is bytes, it is treated as the raw image bytes.
        """
        torch = attempt_import_or_raise("torch")
        # TODO handle retry and errors for https
        image = self._to_pil(image)
        image = self._preprocess(image).unsqueeze(0)
        with torch.no_grad():
            return self._encode_and_normalize_image(image)

    def _to_pil(self, image: Union[str, bytes, "PIL.Image.Image"]):
        PIL = attempt_import_or_raise("PIL", "pillow")
        if isinstance(image, bytes):
            return PIL.Image.open(io.BytesIO(image))
        if isinstance(image, PIL.Image.Image):
            return image
        elif isinstance(image, str):
            parsed = urlparse.urlparse(image)
            # TODO handle drive letter on windows.
            if parsed.scheme == "file":
                return PIL.Image.open(parsed.path)
            elif parsed.scheme == "":
                return PIL.Image.open(image if os.name == "nt" else parsed.path)
            elif parsed.scheme.startswith("http"):
                return PIL.Image.open(io.BytesIO(url_retrieve(image)))
            else:
                raise NotImplementedError("Only local and http(s) urls are supported")

    def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
        """
        Encode a single image tensor and optionally normalize the output
        """
        image_features = self._model.encode_image(image_tensor.to(self.device))
        if self.normalize:
            image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().squeeze()
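A hedged usage sketch for the "open-clip" function registered above; the image URL is a placeholder, and open-clip-torch plus pillow are assumed to be installed:

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

clip = get_registry().get("open-clip").create()

class ImageModel(LanceModel):
    image_uri: str = clip.SourceField()
    vector: Vector(clip.ndims()) = clip.VectorField()

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("images", schema=ImageModel, mode="overwrite")
# http(s) URLs, file:// URIs and bare local paths are all accepted by _to_pil
tbl.add([{"image_uri": "http://example.com/bike.png"}])  # placeholder URL
rs = tbl.search("a red bicycle").limit(1).to_pandas()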
76
python/python/lancedb/embeddings/openai.py
Normal file
@@ -0,0 +1,76 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import cached_property
from typing import List, Optional, Union

import numpy as np

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help


@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the OpenAI API

    https://platform.openai.com/docs/guides/embeddings
    """

    name: str = "text-embedding-ada-002"
    dim: Optional[int] = None

    def ndims(self):
        return self._ndims

    @cached_property
    def _ndims(self):
        if self.name == "text-embedding-ada-002":
            return 1536
        elif self.name == "text-embedding-3-large":
            return self.dim or 3072
        elif self.name == "text-embedding-3-small":
            return self.dim or 1536
        else:
            raise ValueError(f"Unknown model name {self.name}")

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO retry, rate limit, token limit
        if self.name == "text-embedding-ada-002":
            rs = self._openai_client.embeddings.create(input=texts, model=self.name)
        else:
            rs = self._openai_client.embeddings.create(
                input=texts, model=self.name, dimensions=self.ndims()
            )
        return [v.embedding for v in rs.data]

    @cached_property
    def _openai_client(self):
        openai = attempt_import_or_raise("openai")

        if not os.environ.get("OPENAI_API_KEY"):
            api_key_not_found_help("openai")
        return openai.OpenAI()
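A hedged usage sketch for the "openai" function above; it assumes OPENAI_API_KEY is set, and uses the `dim` field to request truncated 512-dimensional vectors, which the code above only honors for the text-embedding-3 models:

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

openai_func = get_registry().get("openai").create(name="text-embedding-3-small", dim=512)

class TextModel(LanceModel):
    text: str = openai_func.SourceField()
    vector: Vector(openai_func.ndims()) = openai_func.VectorField()  # 512-dim here

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("openai_test", schema=TextModel, mode="overwrite")
tbl.add([{"text": "hello world"}])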
189
python/python/lancedb/embeddings/registry.py
Normal file
@@ -0,0 +1,189 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, Optional

from .base import EmbeddingFunction, EmbeddingFunctionConfig


class EmbeddingFunctionRegistry:
    """
    This is a singleton class used to register embedding functions
    and fetch them by name. It also handles serializing and deserializing.
    You can implement your own embedding function by subclassing EmbeddingFunction
    or TextEmbeddingFunction and registering it with the registry.

    NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array,
          pa.ChunkedArray, np.ndarray]

    Examples
    --------
    >>> registry = EmbeddingFunctionRegistry.get_instance()
    >>> @registry.register("my-embedding-function")
    ... class MyEmbeddingFunction(EmbeddingFunction):
    ...     def ndims(self) -> int:
    ...         return 128
    ...
    ...     def compute_query_embeddings(self, query: str, *args, **kwargs):
    ...         return self.compute_source_embeddings(query, *args, **kwargs)
    ...
    ...     def compute_source_embeddings(self, texts, *args, **kwargs):
    ...         return [np.random.rand(self.ndims()) for _ in range(len(texts))]
    ...
    >>> registry.get("my-embedding-function")
    <class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
    """

    @classmethod
    def get_instance(cls):
        return __REGISTRY__

    def __init__(self):
        self._functions = {}

    def register(self, alias: str = None):
        """
        This creates a decorator that can be used to register
        an EmbeddingFunction.

        Parameters
        ----------
        alias : Optional[str]
            A human friendly name for the embedding function. If not
            provided, the class name will be used.
        """

        # This is a decorator for a class that inherits from BaseModel
        # It adds the class to the registry
        def decorator(cls):
            if not issubclass(cls, EmbeddingFunction):
                raise TypeError("Must be a subclass of EmbeddingFunction")
            if cls.__name__ in self._functions:
                raise KeyError(f"{cls.__name__} was already registered")
            key = alias or cls.__name__
            self._functions[key] = cls
            cls.__embedding_function_registry_alias__ = alias
            return cls

        return decorator

    def reset(self):
        """
        Reset the registry to its initial state
        """
        self._functions = {}

    def get(self, name: str):
        """
        Fetch an embedding function class by name

        Parameters
        ----------
        name : str
            The name of the embedding function to fetch.
            Either the alias or the class name if no alias was provided
            during registration
        """
        return self._functions[name]

    def parse_functions(
        self, metadata: Optional[Dict[bytes, bytes]]
    ) -> Dict[str, "EmbeddingFunctionConfig"]:
        """
        Parse the metadata from an arrow table and
        return a mapping of the vector column to the
        embedding function and source column

        Parameters
        ----------
        metadata : Optional[Dict[bytes, bytes]]
            The metadata from an arrow table. Note that
            the keys and values are bytes (pyarrow api)

        Returns
        -------
        functions : dict
            A mapping of vector column name to embedding function.
            An empty dict is returned if input is None or does not
            contain b"embedding_functions".
        """
        if metadata is None or b"embedding_functions" not in metadata:
            return {}
        serialized = metadata[b"embedding_functions"]
        raw_list = json.loads(serialized.decode("utf-8"))
        return {
            obj["vector_column"]: EmbeddingFunctionConfig(
                vector_column=obj["vector_column"],
                source_column=obj["source_column"],
                function=self.get(obj["name"])(**obj["model"]),
            )
            for obj in raw_list
        }

    def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
        """
        Convert the given embedding function and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata
        """
        func = conf.function
        name = getattr(
            func, "__embedding_function_registry_alias__", func.__class__.__name__
        )
        json_data = func.safe_model_dump()
        return {
            "name": name,
            "model": json_data,
            "source_column": conf.source_column,
            "vector_column": conf.vector_column,
        }

    def get_table_metadata(self, func_list):
        """
        Convert a list of embedding functions and source / vector configs
        into a config dictionary that can be serialized into arrow metadata
        """
        if func_list is None or len(func_list) == 0:
            return None
        json_data = [self.function_to_metadata(func) for func in func_list]
        # Note that metadata dictionary values must be bytes
        # so we need to json dump then utf8 encode
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}


# Global instance
__REGISTRY__ = EmbeddingFunctionRegistry()


# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
def register(name):
    return __REGISTRY__.get_instance().register(name)


def get_registry():
    """
    Utility function to get the global instance of the registry

    Returns
    -------
    EmbeddingFunctionRegistry
        The global registry instance

    Examples
    --------
    from lancedb.embeddings import get_registry

    registry = get_registry()
    openai = registry.get("openai").create()
    """
    return __REGISTRY__.get_instance()
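A sketch of the serialize/parse round trip that `get_table_metadata` and `parse_functions` implement. The `EmbeddingFunctionConfig` import path is an assumption based on the relative `.base` import above, and note the str-key to bytes-key conversion, since pyarrow hands metadata back with bytes keys:

from lancedb.embeddings import get_registry
from lancedb.embeddings.base import EmbeddingFunctionConfig  # import path assumed

registry = get_registry()
func = registry.get("sentence-transformers").create()
conf = EmbeddingFunctionConfig(vector_column="vector", source_column="text", function=func)

# serialize for arrow table metadata, then parse back out
metadata = registry.get_table_metadata([conf])
parsed = registry.parse_functions({k.encode("utf-8"): v for k, v in metadata.items()})
assert parsed["vector"].source_column == "text"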
82
python/python/lancedb/embeddings/sentence_transformers.py
Normal file
@@ -0,0 +1,82 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import numpy as np

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru


@register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the sentence-transformers library

    https://huggingface.co/sentence-transformers
    """

    name: str = "all-MiniLM-L6-v2"
    device: str = "cpu"
    normalize: bool = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._ndims = None

    @property
    def embedding_model(self):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.
        """
        return self.get_embedding_model()

    def ndims(self):
        if self._ndims is None:
            self._ndims = len(self.generate_embeddings("foo")[0])
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=self.normalize,
        ).tolist()

    @weak_lru(maxsize=1)
    def get_embedding_model(self):
        """
        Get the sentence-transformers embedding model specified by the
        name and device. This is cached so that the model is only loaded
        once per process.

        TODO: use lru_cache instead with a reasonable/configurable maxsize
        """
        sentence_transformers = attempt_import_or_raise(
            "sentence_transformers", "sentence-transformers"
        )
        return sentence_transformers.SentenceTransformer(self.name, device=self.device)
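A hedged usage sketch for the "sentence-transformers" function above, mirroring the pattern the other embedding functions document; the table name is a placeholder and sentence-transformers is assumed to be installed:

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

st = get_registry().get("sentence-transformers").create(device="cpu")

class TextModel(LanceModel):
    text: str = st.SourceField()
    vector: Vector(st.ndims()) = st.VectorField()

db = lancedb.connect("~/.lancedb")
tbl = db.create_table("st_test", schema=TextModel, mode="overwrite")
tbl.add([{"text": "hello world"}, {"text": "goodbye world"}])
rs = tbl.search("greeting").limit(1).to_pandas()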
281
python/python/lancedb/embeddings/utils.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import weakref
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
from retry import retry
|
||||
|
||||
from ..util import deprecated, safe_import_pandas
|
||||
from ..utils.general import LOGGER
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
|
||||
|
||||
@deprecated
|
||||
def with_embeddings(
|
||||
func: Callable,
|
||||
data: DATA,
|
||||
column: str = "text",
|
||||
wrap_api: bool = True,
|
||||
show_progress: bool = False,
|
||||
batch_size: int = 1000,
|
||||
) -> pa.Table:
|
||||
"""Add a vector column to a table using the given embedding function.
|
||||
|
||||
The new columns will be called "vector".
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
A function that takes a list of strings and returns a list of vectors.
|
||||
data : pa.Table or pd.DataFrame
|
||||
The data to add an embedding column to.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the embedding function.
|
||||
wrap_api : bool, default True
|
||||
Whether to wrap the embedding function in a retry and rate limiter.
|
||||
show_progress : bool, default False
|
||||
Whether to show a progress bar.
|
||||
batch_size : int, default 1000
|
||||
The number of row values to pass to each call of the embedding function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
The input table with a new column called "vector" containing the embeddings.
|
||||
"""
|
||||
func = FunctionWrapper(func)
|
||||
if wrap_api:
|
||||
func = func.retry().rate_limit()
|
||||
func = func.batch_size(batch_size)
|
||||
if show_progress:
|
||||
func = func.show_progress()
|
||||
if pd is not None and isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data, preserve_index=False)
|
||||
embeddings = func(data[column].to_numpy())
|
||||
table = vec_to_table(np.array(embeddings))
|
||||
return data.append_column("vector", table["vector"])


class FunctionWrapper:
    """
    A wrapper for embedding functions that adds rate limiting, retries, and batching.
    """

    def __init__(self, func: Callable):
        self.func = func
        self.rate_limiter_kwargs = {}
        self.retry_kwargs = {}
        self._batch_size = None
        self._progress = False

    def __call__(self, text):
        # Get the embedding with retry
        if len(self.retry_kwargs) > 0:

            @retry(**self.retry_kwargs)
            def embed_func(c):
                return self.func(c.tolist())

        else:

            def embed_func(c):
                return self.func(c.tolist())

        if len(self.rate_limiter_kwargs) > 0:
            v = int(sys.version_info.minor)
            if v >= 11:
                print(
                    "WARNING: the rate limiter only supports Python 3.10 and "
                    "below, proceeding without rate limiter"
                )
            else:
                import ratelimiter

                max_calls = self.rate_limiter_kwargs["max_calls"]
                limiter = ratelimiter.RateLimiter(
                    max_calls, period=self.rate_limiter_kwargs["period"]
                )
                embed_func = limiter(embed_func)
        batches = self.to_batches(text)
        embeds = [emb for c in batches for emb in embed_func(c)]
        return embeds

    def __repr__(self):
        return f"EmbeddingFunction(func={self.func})"

    def rate_limit(self, max_calls=0.9, period=1.0):
        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
        return self

    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
        self.retry_kwargs = dict(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
        )
        return self

    def batch_size(self, batch_size):
        self._batch_size = batch_size
        return self

    def show_progress(self):
        self._progress = True
        return self

    def to_batches(self, arr):
        length = len(arr)

        def _chunker(arr):
            for start_i in range(0, len(arr), self._batch_size):
                yield arr[start_i : start_i + self._batch_size]

        if self._progress:
            from tqdm.auto import tqdm

            yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
        else:
            yield from _chunker(arr)
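
# Sketch of how the wrapper composes (hedged: `embed` is a placeholder
# callable, not a real model). Each builder method mutates and returns self,
# so the calls chain, mirroring what with_embeddings does internally:
#
#     wrapped = (
#         FunctionWrapper(embed)
#         .retry(tries=3)
#         .rate_limit(max_calls=0.9, period=1.0)
#         .batch_size(1000)
#     )
#     vectors = wrapped(np.array(["a", "b", "c"]))  # list of embeddings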


def weak_lru(maxsize=128):
    """
    LRU cache that keeps weak references to the objects it caches. Only caches the
    latest instance of the objects to make sure memory usage is bounded.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of objects to cache.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def wrapper(func):
        @functools.lru_cache(maxsize)
        def _func(_self, *args, **kwargs):
            return func(_self(), *args, **kwargs)

        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            return _func(weakref.ref(self), *args, **kwargs)

        return inner

    return wrapper


def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 7,
):
    """Retry a function with exponential backoff.

    Args:
        func (function): The function to be retried.
        initial_delay (float): Initial delay in seconds (default is 1).
        exponential_base (float): The base for exponential backoff (default is 2).
        jitter (bool): Whether to add jitter to the delay (default is True).
        max_retries (int): Maximum number of retries (default is 7).

    Returns:
        function: The decorated function.
    """

    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception
        # is raised
        while True:
            try:
                return func(*args, **kwargs)

            # Currently retrying on all exceptions as there is no way to know the
            # format of the error msgs used by different APIs. We log the error and
            # assume that if this portion errors out it's due to a rate limit, but
            # the user should check the error message to be sure.
            except Exception as e:  # noqa: PERF203
                num_retries += 1

                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded.", e
                    )

                delay *= exponential_base * (1 + jitter * random.random())
                LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}")
                time.sleep(delay)

    return wrapper
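
# Usage sketch (hedged: `flaky_api_call` is a hypothetical function used only
# for illustration). The helper is a plain decorator:
#
#     @retry_with_exponential_backoff
#     def flaky_api_call():
#         ...
#
#     # or wrap an existing callable:
#     safe_call = retry_with_exponential_backoff(flaky_api_call)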


def url_retrieve(url: str):
    """Download the contents of a URL.

    Parameters
    ----------
    url: str
        URL to download from
    """
    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError(f"could not download {url} due to {err}")


def api_key_not_found_help(provider):
    LOGGER.error(f"Could not find API key for {provider}.")
    raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
22
python/python/lancedb/exceptions.py
Normal file
@@ -0,0 +1,22 @@
"""Custom exception handling"""


class MissingValueError(ValueError):
    """Exception raised when a required value is missing."""

    pass


class MissingColumnError(KeyError):
    """
    Exception raised when a column name specified is not in
    the DataFrame object
    """

    def __init__(self, column_name):
        self.column_name = column_name

    def __str__(self):
        return (
            f"Error: Column '{self.column_name}' does not exist in the DataFrame object"
        )
184
python/python/lancedb/fts.py
Normal file
@@ -0,0 +1,184 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Full text search index using tantivy-py"""
import os
from typing import List, Tuple

import pyarrow as pa

try:
    import tantivy
except ImportError:
    raise ImportError(
        "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."  # noqa: E501
    )

from .table import LanceTable


def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
    """
    Create a new Index (not populated)

    Parameters
    ----------
    index_path : str
        Path to the index directory
    text_fields : List[str]
        List of text fields to index

    Returns
    -------
    index : tantivy.Index
        The index object (not yet populated)
    """
    # Declaring our schema.
    schema_builder = tantivy.SchemaBuilder()
    # special field that we'll populate with row_id
    schema_builder.add_integer_field("doc_id", stored=True)
    # data fields
    for name in text_fields:
        schema_builder.add_text_field(name, stored=True)
    schema = schema_builder.build()
    os.makedirs(index_path, exist_ok=True)
    index = tantivy.Index(schema, path=index_path)
    return index


def populate_index(
    index: tantivy.Index,
    table: LanceTable,
    fields: List[str],
    writer_heap_size: int = 1024 * 1024 * 1024,
) -> int:
    """
    Populate an index with data from a LanceTable

    Parameters
    ----------
    index : tantivy.Index
        The index object
    table : LanceTable
        The table to index
    fields : List[str]
        List of fields to index
    writer_heap_size : int
        The writer heap size in bytes, defaults to 1GB

    Returns
    -------
    int
        The number of rows indexed
    """
    # first check the fields exist and are string or large string type
    nested = []
    for name in fields:
        try:
            f = table.schema.field(name)  # raises KeyError if not found
        except KeyError:
            f = resolve_path(table.schema, name)
            nested.append(name)

        if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type):
            raise TypeError(f"Field {name} is not a string type")

    # create a tantivy writer
    writer = index.writer(heap_size=writer_heap_size)
    # write data into index
    dataset = table.to_lance()
    row_id = 0

    max_nested_level = 0
    if len(nested) > 0:
        max_nested_level = max([len(name.split(".")) for name in nested])

    for b in dataset.to_batches(columns=fields):
        if max_nested_level > 0:
            b = pa.Table.from_batches([b])
            for _ in range(max_nested_level - 1):
                b = b.flatten()
        for i in range(b.num_rows):
            doc = tantivy.Document()
            for name in fields:
                value = b[name][i].as_py()
                if value is not None:
                    doc.add_text(name, value)
            if not doc.is_empty:
                doc.add_integer("doc_id", row_id)
            writer.add_document(doc)
            row_id += 1
    # commit changes
    writer.commit()
    return row_id


def resolve_path(schema, field_name: str) -> pa.Field:
    """
    Resolve a (possibly nested, dot-separated) field name to a field in the schema.

    Parameters
    ----------
    schema : pa.Schema
        The schema to search
    field_name : str
        The field name to resolve

    Returns
    -------
    pa.Field
        The resolved field
    """
    path = field_name.split(".")
    field = schema.field(path.pop(0))
    for segment in path:
        if pa.types.is_struct(field.type):
            field = field.type.field(segment)
        else:
            raise KeyError(f"field {field_name} not found in schema {schema}")
    return field
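
# Example (a sketch, not part of the API surface): resolving a dotted path
# through a struct column. Given `metadata: struct<author: string>`,
# resolve_path walks one struct level and returns the inner field:
#
#     schema = pa.schema(
#         [pa.field("metadata", pa.struct([pa.field("author", pa.utf8())]))]
#     )
#     assert resolve_path(schema, "metadata.author").type == pa.utf8()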


def search_index(
    index: tantivy.Index, query: str, limit: int = 10
) -> Tuple[Tuple[int], Tuple[float]]:
    """
    Search an index for a query

    Parameters
    ----------
    index : tantivy.Index
        The index object
    query : str
        The query string
    limit : int
        The maximum number of results to return

    Returns
    -------
    ids_and_score: Tuple[Tuple[int], Tuple[float]]
        A tuple of two tuples, the first containing the document ids
        and the second containing the scores
    """
    searcher = index.searcher()
    query = index.parse_query(query)
    # get top results
    results = searcher.search(query, limit)
    if results.count == 0:
        return tuple(), tuple()
    return tuple(
        zip(
            *[
                (searcher.doc(doc_address)["doc_id"][0], score)
                for score, doc_address in results.hits
            ]
        )
    )
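
# End-to-end sketch of the three helpers above (hedged: `table` is assumed to
# be an existing LanceTable with a string column named "text", and the index
# path is arbitrary). tantivy requires a reload() after commit before the new
# documents become searchable:
#
#     index = create_index("/tmp/fts_index", ["text"])
#     populate_index(index, table, ["text"])
#     index.reload()
#     row_ids, scores = search_index(index, "puppy", limit=5)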
107
python/python/lancedb/merge.py
Normal file
@@ -0,0 +1,107 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
    from .common import DATA


class LanceMergeInsertBuilder(object):
    """Builder for a LanceDB merge insert operation

    See [`merge_insert`][lancedb.table.Table.merge_insert] for
    more context
    """

    def __init__(self, table: "Table", on: List[str]):  # noqa: F821
        # Do not put a docstring here. This method should be hidden
        # from API docs. Users should use merge_insert to create
        # this object.
        self._table = table
        self._on = on
        self._when_matched_update_all = False
        self._when_matched_update_all_condition = None
        self._when_not_matched_insert_all = False
        self._when_not_matched_by_source_delete = False
        self._when_not_matched_by_source_condition = None

    def when_matched_update_all(
        self, *, where: Optional[str] = None
    ) -> LanceMergeInsertBuilder:
        """
        Rows that exist in both the source table (new data) and
        the target table (old data) will be updated, replacing
        the old row with the corresponding matching row.

        If there are multiple matches then the behavior is undefined.
        Currently this causes multiple copies of the row to be created
        but that behavior is subject to change.
        """
        self._when_matched_update_all = True
        self._when_matched_update_all_condition = where
        return self

    def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
        """
        Rows that exist only in the source table (new data) will
        be inserted into the target table.
        """
        self._when_not_matched_insert_all = True
        return self

    def when_not_matched_by_source_delete(
        self, condition: Optional[str] = None
    ) -> LanceMergeInsertBuilder:
        """
        Rows that exist only in the target table (old data) will be
        deleted. An optional condition can be provided to limit what
        data is deleted.

        Parameters
        ----------
        condition: Optional[str], default None
            If None then all such rows will be deleted. Otherwise the
            condition will be used as an SQL filter to limit what rows
            are deleted.
        """
        self._when_not_matched_by_source_delete = True
        if condition is not None:
            self._when_not_matched_by_source_condition = condition
        return self

    def execute(
        self,
        new_data: DATA,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
    ):
        """
        Executes the merge insert operation

        Nothing is returned but the [`Table`][lancedb.table.Table] is updated

        Parameters
        ----------
        new_data: DATA
            New records which will be matched against the existing records
            to potentially insert or update into the table. This parameter
            can be anything you use for [`add`][lancedb.table.Table.add]
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contain NaNs.
            One of "error", "drop", "fill".
        fill_value: float, default 0.
            The value to use when filling vectors. Only used if on_bad_vectors="fill".
        """
        self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
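
# Usage sketch (hedged: `table` and `new_data` are assumed to exist and the
# key column name is illustrative). A classic upsert keyed on "id":
#
#     (
#         table.merge_insert("id")
#         .when_matched_update_all()
#         .when_not_matched_insert_all()
#         .execute(new_data)
#     )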
400
python/python/lancedb/pydantic.py
Normal file
@@ -0,0 +1,400 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pydantic (v1 / v2) adapter for LanceDB"""

from __future__ import annotations

import inspect
import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Type,
    Union,
    _GenericAlias,
)

import numpy as np
import pyarrow as pa
import pydantic
import semver

PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
    from pydantic_core import CoreSchema, core_schema
except ImportError:
    if PYDANTIC_VERSION >= (2,):
        raise

if TYPE_CHECKING:
    from pydantic.fields import FieldInfo

    from .embeddings import EmbeddingFunctionConfig


class FixedSizeListMixin(ABC):
    @staticmethod
    @abstractmethod
    def dim() -> int:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def value_arrow_type() -> pa.DataType:
        raise NotImplementedError


def vector(dim: int, value_type: pa.DataType = pa.float32()):
    # TODO: remove in future release
    from warnings import warn

    warn(
        "lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector "
        "instead. This function will be removed in a future release.",
        DeprecationWarning,
    )
    return Vector(dim, value_type)


def Vector(
    dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]:
    """Pydantic Vector Type.

    !!! warning
        Experimental feature.

    Parameters
    ----------
    dim : int
        The dimension of the vector.
    value_type : pyarrow.DataType, optional
        The value type of the vector, by default pa.float32()

    Examples
    --------

    >>> import pydantic
    >>> from lancedb.pydantic import Vector
    ...
    >>> class MyModel(pydantic.BaseModel):
    ...     id: int
    ...     url: str
    ...     embeddings: Vector(768)
    >>> schema = pydantic_to_schema(MyModel)
    >>> assert schema == pa.schema([
    ...     pa.field("id", pa.int64(), False),
    ...     pa.field("url", pa.utf8(), False),
    ...     pa.field("embeddings", pa.list_(pa.float32(), 768), False)
    ... ])
    """

    # TODO: make a public parameterized type.
    class FixedSizeList(list, FixedSizeListMixin):
        def __repr__(self):
            return f"FixedSizeList(dim={dim})"

        @staticmethod
        def dim() -> int:
            return dim

        @staticmethod
        def value_arrow_type() -> pa.DataType:
            return value_type

        @classmethod
        def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
        ) -> CoreSchema:
            return core_schema.no_info_after_validator_function(
                cls,
                core_schema.list_schema(
                    min_length=dim,
                    max_length=dim,
                    items_schema=core_schema.float_schema(),
                ),
            )

        @classmethod
        def __get_validators__(cls) -> Generator[Callable, None, None]:
            yield cls.validate

        # For pydantic v1
        @classmethod
        def validate(cls, v):
            if not isinstance(v, (list, range, np.ndarray)) or len(v) != dim:
                raise TypeError("A list of numbers or numpy.ndarray is needed")
            return cls(v)

        if PYDANTIC_VERSION < (2, 0):

            @classmethod
            def __modify_schema__(cls, field_schema: Dict[str, Any]):
                field_schema["items"] = {"type": "number"}
                field_schema["maxItems"] = dim
                field_schema["minItems"] = dim

    return FixedSizeList


def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
    """Convert a field with native Python type to Arrow data type.

    Raises
    ------
    TypeError
        If the type is not supported.
    """
    if py_type == int:
        return pa.int64()
    elif py_type == float:
        return pa.float64()
    elif py_type == str:
        return pa.utf8()
    elif py_type == bool:
        return pa.bool_()
    elif py_type == bytes:
        return pa.binary()
    elif py_type == date:
        return pa.date32()
    elif py_type == datetime:
        tz = get_extras(field, "tz")
        return pa.timestamp("us", tz=tz)
    elif getattr(py_type, "__origin__", None) in (list, tuple):
        child = py_type.__args__[0]
        return pa.list_(_py_type_to_arrow_type(child, field))
    raise TypeError(
        f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
    )


if PYDANTIC_VERSION.major < 2:

    def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
        return [
            _pydantic_to_field(name, field) for name, field in model.__fields__.items()
        ]

else:

    def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
        return [
            _pydantic_to_field(name, field)
            for name, field in model.model_fields.items()
        ]


def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
    """Convert a Pydantic FieldInfo to Arrow DataType"""

    if isinstance(field.annotation, _GenericAlias) or (
        sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias)
    ):
        origin = field.annotation.__origin__
        args = field.annotation.__args__
        if origin == list:
            child = args[0]
            return pa.list_(_py_type_to_arrow_type(child, field))
        elif origin == Union:
            if len(args) == 2 and args[1] == type(None):
                return _py_type_to_arrow_type(args[0], field)
    elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
        args = field.annotation.__args__
        if len(args) == 2:
            for typ in args:
                if typ == type(None):
                    continue
                return _py_type_to_arrow_type(typ, field)
    elif inspect.isclass(field.annotation):
        if issubclass(field.annotation, pydantic.BaseModel):
            # Struct
            fields = _pydantic_model_to_fields(field.annotation)
            return pa.struct(fields)
        elif issubclass(field.annotation, FixedSizeListMixin):
            return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
    return _py_type_to_arrow_type(field.annotation, field)


def is_nullable(field: FieldInfo) -> bool:
    """Check if a Pydantic FieldInfo is nullable."""
    if isinstance(field.annotation, _GenericAlias):
        origin = field.annotation.__origin__
        args = field.annotation.__args__
        if origin == Union:
            if len(args) == 2 and args[1] == type(None):
                return True
    elif sys.version_info >= (3, 10) and isinstance(field.annotation, types.UnionType):
        args = field.annotation.__args__
        for typ in args:
            if typ == type(None):
                return True
    return False


def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:
    """Convert a Pydantic field to a PyArrow Field."""
    dt = _pydantic_to_arrow_type(field)
    return pa.field(name, dt, is_nullable(field))


def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
    """Convert a Pydantic model to a PyArrow Schema.

    Parameters
    ----------
    model : Type[pydantic.BaseModel]
        The Pydantic BaseModel to convert to Arrow Schema.

    Returns
    -------
    pyarrow.Schema

    Examples
    --------

    >>> from typing import List, Optional
    >>> import pydantic
    >>> from lancedb.pydantic import pydantic_to_schema
    >>> class FooModel(pydantic.BaseModel):
    ...     id: int
    ...     s: str
    ...     vec: List[float]
    ...     li: List[int]
    ...
    >>> schema = pydantic_to_schema(FooModel)
    >>> assert schema == pa.schema([
    ...     pa.field("id", pa.int64(), False),
    ...     pa.field("s", pa.utf8(), False),
    ...     pa.field("vec", pa.list_(pa.float64()), False),
    ...     pa.field("li", pa.list_(pa.int64()), False),
    ... ])
    """
    fields = _pydantic_model_to_fields(model)
    return pa.schema(fields)


class LanceModel(pydantic.BaseModel):
    """
    A Pydantic Model base class that can be converted to a LanceDB Table.

    Examples
    --------
    >>> import lancedb
    >>> from lancedb.pydantic import LanceModel, Vector
    >>>
    >>> class TestModel(LanceModel):
    ...     name: str
    ...     vector: Vector(2)
    ...
    >>> db = lancedb.connect("./example")
    >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
    >>> table.add([
    ...     TestModel(name="test", vector=[1.0, 2.0])
    ... ])
    >>> table.search([0., 0.]).limit(1).to_pydantic(TestModel)
    [TestModel(name='test', vector=FixedSizeList(dim=2))]
    """

    @classmethod
    def to_arrow_schema(cls):
        """
        Get the Arrow Schema for this model.
        """
        schema = pydantic_to_schema(cls)
        functions = cls.parse_embedding_functions()
        if len(functions) > 0:
            # Prevent circular import
            from .embeddings import EmbeddingFunctionRegistry

            metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
                functions
            )
            schema = schema.with_metadata(metadata)
        return schema

    @classmethod
    def field_names(cls) -> List[str]:
        """
        Get the field names of this model.
        """
        return list(cls.safe_get_fields().keys())

    @classmethod
    def safe_get_fields(cls):
        if PYDANTIC_VERSION.major < 2:
            return cls.__fields__
        return cls.model_fields

    @classmethod
    def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
        """
        Parse the embedding functions from this model.
        """
        from .embeddings import EmbeddingFunctionConfig

        vec_and_function = []
        for name, field_info in cls.safe_get_fields().items():
            func = get_extras(field_info, "vector_column_for")
            if func is not None:
                vec_and_function.append([name, func])

        configs = []
        for vec, func in vec_and_function:
            for source, field_info in cls.safe_get_fields().items():
                src_func = get_extras(field_info, "source_column_for")
                if src_func is func:
                    # Note: we can't use == here. The function is a pydantic
                    # model, so two instances of the same function compare
                    # equal, and with multiple vector columns from multiple
                    # sources both would be mapped to the same source column.
                    # GH594
                    configs.append(
                        EmbeddingFunctionConfig(
                            source_column=source, vector_column=vec, function=func
                        )
                    )
        return configs


def get_extras(field_info: FieldInfo, key: str) -> Any:
    """
    Get the extra metadata from a Pydantic FieldInfo.
    """
    if PYDANTIC_VERSION.major >= 2:
        return (field_info.json_schema_extra or {}).get(key)
    return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)


if PYDANTIC_VERSION.major < 2:

    def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
        """
        Convert a Pydantic model to a dictionary.
        """
        return model.dict()

else:

    def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
        """
        Convert a Pydantic model to a dictionary.
        """
        return model.model_dump()
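
# Sketch: model_to_dict lets callers stay agnostic of the pydantic v1/v2
# split (hedged: `Item` is a throwaway model for illustration):
#
#     class Item(pydantic.BaseModel):
#         id: int
#         name: str
#
#     model_to_dict(Item(id=1, name="a"))  # -> {"id": 1, "name": "a"}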
884
python/python/lancedb/query.py
Normal file
@@ -0,0 +1,884 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union

import deprecation
import numpy as np
import pyarrow as pa
import pydantic

from . import __version__
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import_pandas

if TYPE_CHECKING:
    import PIL
    import polars as pl

    from .pydantic import LanceModel
    from .table import Table

pd = safe_import_pandas()


class Query(pydantic.BaseModel):
    """The LanceDB Query

    Attributes
    ----------
    vector : List[float]
        the vector to search for
    filter : Optional[str]
        sql filter to refine the query with, optional
    prefilter : bool
        if True then apply the filter before vector search
    k : int
        top k results to return
    metric : str
        the distance metric between a pair of vectors,

        can support L2 (default), Cosine and Dot.
        [metric definitions][search]
    columns : Optional[List[str]]
        which columns to return in the results
    nprobes : int
        The number of probes used - optional

        - A higher number makes search more accurate but also slower.

        - See discussion in [Querying an ANN Index][querying-an-ann-index] for
          tuning advice.
    refine_factor : Optional[int]
        Refine the results by reading extra elements and re-ranking them in memory.

        - A higher number makes search more accurate but also slower.

        - See discussion in [Querying an ANN Index][querying-an-ann-index] for
          tuning advice.
    """

    vector_column: Optional[str] = None

    # vector to search for
    vector: Union[List[float], List[List[float]]]

    # sql filter to refine the query with
    filter: Optional[str] = None

    # if True then apply the filter before vector search
    prefilter: bool = False

    # top k results to return
    k: int

    # distance metric
    metric: str = "L2"

    # which columns to return in the results
    columns: Optional[List[str]] = None

    # optional query parameters for tuning the results,
    # e.g. `{"nprobes": "10", "refine_factor": "10"}`
    nprobes: int = 10

    # Refine factor.
    refine_factor: Optional[int] = None

    with_row_id: bool = False


class LanceQueryBuilder(ABC):
    """Build a LanceDB query based on a specific query type:
    vector or full text search.
    """

    @classmethod
    def create(
        cls,
        table: "Table",
        query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
        query_type: str,
        vector_column_name: str,
    ) -> LanceQueryBuilder:
        if query is None:
            return LanceEmptyQueryBuilder(table)

        if query_type == "hybrid":
            # hybrid fts and vector query
            return LanceHybridQueryBuilder(table, query, vector_column_name)

        # convert "auto" query_type to "vector", "fts"
        # or "hybrid" and convert the query to vector if needed
        query, query_type = cls._resolve_query(
            table, query, query_type, vector_column_name
        )

        if query_type == "hybrid":
            return LanceHybridQueryBuilder(table, query, vector_column_name)

        if isinstance(query, str):
            # fts
            return LanceFtsQueryBuilder(table, query)

        if isinstance(query, list):
            query = np.array(query, dtype=np.float32)
        elif isinstance(query, np.ndarray):
            query = query.astype(np.float32)
        else:
            raise TypeError(f"Unsupported query type: {type(query)}")

        return LanceVectorQueryBuilder(table, query, vector_column_name)

    @classmethod
    def _resolve_query(cls, table, query, query_type, vector_column_name):
        # If query_type is fts, then query must be a string.
        # otherwise raise TypeError
        if query_type == "fts":
            if not isinstance(query, str):
                raise TypeError(f"'fts' queries must be a string: {type(query)}")
            return query, query_type
        elif query_type == "vector":
            query = cls._query_to_vector(table, query, vector_column_name)
            return query, query_type
        elif query_type == "auto":
            if isinstance(query, (list, np.ndarray)):
                return query, "vector"
            if isinstance(query, tuple):
                return query, "hybrid"
            else:
                conf = table.embedding_functions.get(vector_column_name)
                if conf is not None:
                    query = conf.function.compute_query_embeddings_with_retry(query)[0]
                    return query, "vector"
                else:
                    return query, "fts"
        else:
            raise ValueError(
                f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
            )

    @classmethod
    def _query_to_vector(cls, table, query, vector_column_name):
        if isinstance(query, (list, np.ndarray)):
            return query
        conf = table.embedding_functions.get(vector_column_name)
        if conf is not None:
            return conf.function.compute_query_embeddings_with_retry(query)[0]
        else:
            msg = f"No embedding function for {vector_column_name}"
            raise ValueError(msg)

    def __init__(self, table: "Table"):
        self._table = table
        self._limit = 10
        self._columns = None
        self._where = None
        self._with_row_id = False

    @deprecation.deprecated(
        deprecated_in="0.3.1",
        removed_in="0.4.0",
        current_version=__version__,
        details="Use to_pandas() instead",
    )
    def to_df(self) -> "pd.DataFrame":
        """
        *Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.*

        Execute the query and return the results as a pandas DataFrame.
        In addition to the selected columns, LanceDB also returns a vector
        and also the "_distance" column which is the distance between the query
        vector and the returned vector.
        """
        return self.to_pandas()

    def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
        """
        Execute the query and return the results as a pandas DataFrame.
        In addition to the selected columns, LanceDB also returns a vector
        and also the "_distance" column which is the distance between the query
        vector and the returned vector.

        Parameters
        ----------
        flatten: Optional[Union[int, bool]]
            If flatten is True, flatten all nested columns.
            If flatten is an integer, flatten the nested columns up to the
            specified depth.
            If unspecified, do not flatten the nested columns.
        """
        tbl = self.to_arrow()
        if flatten is True:
            while True:
                tbl = tbl.flatten()
                # loop through all columns to check if there is any struct column
                if any(pa.types.is_struct(col.type) for col in tbl.schema):
                    continue
                else:
                    break
        elif isinstance(flatten, int):
            if flatten <= 0:
                raise ValueError(
                    "Please specify a positive integer for flatten or the boolean "
                    "value `True`"
                )
            while flatten > 0:
                tbl = tbl.flatten()
                flatten -= 1
        return tbl.to_pandas()
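
    # Sketch of the flatten parameter (hedged: assumes results that contain a
    # struct column such as `meta: struct<a: int, b: struct<c: int>>`):
    #
    #     builder.to_pandas(flatten=1)     # one level: columns "meta.a", "meta.b"
    #     builder.to_pandas(flatten=True)  # keep flattening until no structs remain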

    @abstractmethod
    def to_arrow(self) -> pa.Table:
        """
        Execute the query and return the results as an
        [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).

        In addition to the selected columns, LanceDB also returns a vector
        and also the "_distance" column which is the distance between the query
        vector and the returned vectors.
        """
        raise NotImplementedError

    def to_list(self) -> List[dict]:
        """
        Execute the query and return the results as a list of dictionaries.

        Each list entry is a dictionary with the selected column names as keys,
        or all table columns if `select` is not called. The vector and the "_distance"
        fields are returned whether or not they're explicitly selected.
        """
        return self.to_arrow().to_pylist()

    def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
        """Return the table as a list of pydantic models.

        Parameters
        ----------
        model: Type[LanceModel]
            The pydantic model to use.

        Returns
        -------
        List[LanceModel]
        """
        return [
            model(**{k: v for k, v in row.items() if k in model.field_names()})
            for row in self.to_arrow().to_pylist()
        ]

    def to_polars(self) -> "pl.DataFrame":
        """
        Execute the query and return the results as a Polars DataFrame.
        In addition to the selected columns, LanceDB also returns a vector
        and also the "_distance" column which is the distance between the query
        vector and the returned vector.
        """
        import polars as pl

        return pl.from_arrow(self.to_arrow())

    def limit(self, limit: Union[int, None]) -> LanceQueryBuilder:
        """Set the maximum number of results to return.

        Parameters
        ----------
        limit: int
            The maximum number of results to return.
            By default the query is limited to the first 10.
            Call this method and pass 0, a negative value,
            or None to remove the limit.
            *WARNING* if you have a large dataset, removing
            the limit can potentially result in reading a
            large amount of data into memory and cause
            out of memory issues.

        Returns
        -------
        LanceQueryBuilder
            The LanceQueryBuilder object.
        """
        if limit is None or limit <= 0:
            self._limit = None
        else:
            self._limit = limit
        return self

    def select(self, columns: list) -> LanceQueryBuilder:
        """Set the columns to return.

        Parameters
        ----------
        columns: list
            The columns to return.

        Returns
        -------
        LanceQueryBuilder
            The LanceQueryBuilder object.
        """
        self._columns = columns
        return self

    def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
        """Set the where clause.

        Parameters
        ----------
        where: str
            The where clause which is a valid SQL where clause. See
            `Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
            for valid SQL expressions.
        prefilter: bool, default False
            If True, apply the filter before vector search, otherwise the
            filter is applied on the result of vector search.
            This feature is **EXPERIMENTAL** and may be removed and modified
            without warning in the future.

        Returns
        -------
        LanceQueryBuilder
            The LanceQueryBuilder object.
        """
        self._where = where
        self._prefilter = prefilter
        return self

    def with_row_id(self, with_row_id: bool) -> LanceQueryBuilder:
        """Set whether to return row ids.

        Parameters
        ----------
        with_row_id: bool
            If True, return _rowid column in the results.

        Returns
        -------
        LanceQueryBuilder
            The LanceQueryBuilder object.
        """
        self._with_row_id = with_row_id
        return self


class LanceVectorQueryBuilder(LanceQueryBuilder):
    """
    Examples
    --------
    >>> import lancedb
    >>> data = [{"vector": [1.1, 1.2], "b": 2},
    ...         {"vector": [0.5, 1.3], "b": 4},
    ...         {"vector": [0.4, 0.4], "b": 6},
    ...         {"vector": [0.4, 0.4], "b": 10}]
    >>> db = lancedb.connect("./.lancedb")
    >>> table = db.create_table("my_table", data=data)
    >>> (table.search([0.4, 0.4])
    ...        .metric("cosine")
    ...        .where("b < 10")
    ...        .select(["b"])
    ...        .limit(2)
    ...        .to_pandas())
       b      vector  _distance
    0  6  [0.4, 0.4]        0.0
    """

    def __init__(
        self,
        table: "Table",
        query: Union[np.ndarray, list, "PIL.Image.Image"],
        vector_column: str,
    ):
        super().__init__(table)
        self._query = query
        self._metric = "L2"
        self._nprobes = 20
        self._refine_factor = None
        self._vector_column = vector_column
        self._prefilter = False

    def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
        """Set the distance metric to use.

        Parameters
        ----------
        metric: "L2" or "cosine"
            The distance metric to use. By default "L2" is used.

        Returns
        -------
        LanceVectorQueryBuilder
            The LanceQueryBuilder object.
        """
        self._metric = metric
        return self

    def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
        """Set the number of probes to use.

        Higher values will yield better recall (more likely to find vectors if
        they exist) at the expense of latency.

        See discussion in [Querying an ANN Index][querying-an-ann-index] for
        tuning advice.

        Parameters
        ----------
        nprobes: int
            The number of probes to use.

        Returns
        -------
        LanceVectorQueryBuilder
            The LanceQueryBuilder object.
        """
        self._nprobes = nprobes
        return self

    def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
        """Set the refine factor to use, increasing the number of vectors sampled.

        As an example, a refine factor of 2 will sample 2x as many vectors as
        requested, re-rank them, and return the top half most relevant results.

        See discussion in [Querying an ANN Index][querying-an-ann-index] for
        tuning advice.

        Parameters
        ----------
        refine_factor: int
            The refine factor to use.

        Returns
        -------
        LanceVectorQueryBuilder
            The LanceQueryBuilder object.
        """
        self._refine_factor = refine_factor
        return self

    def to_arrow(self) -> pa.Table:
        """
        Execute the query and return the results as an
        [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).

        In addition to the selected columns, LanceDB also returns a vector
        and also the "_distance" column which is the distance between the query
        vector and the returned vectors.
        """
        vector = self._query if isinstance(self._query, list) else self._query.tolist()
        if isinstance(vector[0], np.ndarray):
            vector = [v.tolist() for v in vector]
        query = Query(
            vector=vector,
            filter=self._where,
            prefilter=self._prefilter,
            k=self._limit,
            metric=self._metric,
            columns=self._columns,
            nprobes=self._nprobes,
            refine_factor=self._refine_factor,
            vector_column=self._vector_column,
            with_row_id=self._with_row_id,
        )
        return self._table._execute_query(query)

    def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
        """Set the where clause.

        Parameters
        ----------
        where: str
            The where clause which is a valid SQL where clause. See
            `Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
            for valid SQL expressions.
        prefilter: bool, default False
            If True, apply the filter before vector search, otherwise the
            filter is applied on the result of vector search.
            This feature is **EXPERIMENTAL** and may be removed and modified
            without warning in the future.

        Returns
        -------
        LanceQueryBuilder
            The LanceQueryBuilder object.
        """
        self._where = where
        self._prefilter = prefilter
        return self


class LanceFtsQueryBuilder(LanceQueryBuilder):
    """A builder for full text search for LanceDB."""

    def __init__(self, table: "Table", query: str):
        super().__init__(table)
        self._query = query
        self._phrase_query = False

    def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
        """Set whether to use phrase query.

        Parameters
        ----------
        phrase_query: bool, default True
            If True, then the query will be wrapped in quotes and
            double quotes replaced by single quotes.

        Returns
        -------
        LanceFtsQueryBuilder
            The LanceFtsQueryBuilder object.
        """
        self._phrase_query = phrase_query
        return self

    def to_arrow(self) -> pa.Table:
        try:
            import tantivy
        except ImportError:
            raise ImportError(
                "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."  # noqa: E501
            )

        from .fts import search_index

        # get the index path
        index_path = self._table._get_fts_index_path()
        # check if the index exists
        if not Path(index_path).exists():
            raise FileNotFoundError(
                "Fts index does not exist. "
                "Please first call table.create_fts_index(['<field_names>']) to "
                "create the fts index."
            )
        # open the index
        index = tantivy.Index.open(index_path)
        # get the scores and doc ids
        query = self._query
        if self._phrase_query:
            query = query.replace('"', "'")
            query = f'"{query}"'
        row_ids, scores = search_index(index, query, self._limit)
        if len(row_ids) == 0:
            empty_schema = pa.schema([pa.field("score", pa.float32())])
            return pa.Table.from_pylist([], schema=empty_schema)
        scores = pa.array(scores)
        output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
        output_tbl = output_tbl.append_column("score", scores)

        if self._where is not None:
            try:
                # TODO: it would be great to have Substrait generate pyarrow compute
                # expressions, or conversely have pyarrow support SQL expressions
                # using Substrait
                import duckdb

                output_tbl = (
                    duckdb.sql("SELECT * FROM output_tbl")
                    .filter(self._where)
                    .to_arrow_table()
                )
            except ImportError:
                import tempfile

                import lance

                # TODO Use "memory://" instead once that's supported
                with tempfile.TemporaryDirectory() as tmp:
                    ds = lance.write_dataset(output_tbl, tmp)
                    output_tbl = ds.to_table(filter=self._where)

        if self._with_row_id:
            # Need to set this to uint explicitly as vector results are in uint64
            row_ids = pa.array(row_ids, type=pa.uint64())
            output_tbl = output_tbl.append_column("_rowid", row_ids)
        return output_tbl


class LanceEmptyQueryBuilder(LanceQueryBuilder):
    def to_arrow(self) -> pa.Table:
        ds = self._table.to_lance()
        return ds.to_table(
            columns=self._columns,
            filter=self._where,
            limit=self._limit,
        )


class LanceHybridQueryBuilder(LanceQueryBuilder):
    def __init__(self, table: "Table", query: str, vector_column: str):
        super().__init__(table)
        self._validate_fts_index()
        vector_query, fts_query = self._validate_query(query)
        self._fts_query = LanceFtsQueryBuilder(table, fts_query)
        vector_query = self._query_to_vector(table, vector_query, vector_column)
        self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
        self._norm = "score"
        self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)

    def _validate_fts_index(self):
        if self._table._get_fts_index_path() is None:
            raise ValueError(
                "Please create a full-text search index to perform hybrid search."
            )

    def _validate_query(self, query):
        # Temp hack to support vectorized queries for hybrid search
        if isinstance(query, str):
            return query, query
        elif isinstance(query, tuple):
            if len(query) != 2:
                raise ValueError(
                    "The query must be a tuple of (vector_query, fts_query)."
                )
            if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
                raise ValueError(f"The vector query must be one of {VEC}.")
            if not isinstance(query[1], str):
                raise ValueError("The fts query must be a string.")
            return query[0], query[1]
        else:
            raise ValueError(
                "The query must be either a string or a tuple of (vector, string)."
            )

    def to_arrow(self) -> pa.Table:
        with ThreadPoolExecutor() as executor:
            fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
            vector_future = executor.submit(
                self._vector_query.with_row_id(True).to_arrow
            )
            fts_results = fts_future.result()
            vector_results = vector_future.result()

        # convert to ranks first if needed
        if self._norm == "rank":
            vector_results = self._rank(vector_results, "_distance")
            fts_results = self._rank(fts_results, "score")
        # normalize the scores to be between 0 and 1, 0 being most relevant
        vector_results = self._normalize_scores(vector_results, "_distance")

        # In fts higher scores represent relevance. Not inverting them here as
        # rerankers might need to preserve this score to support `return_score="all"`
        fts_results = self._normalize_scores(fts_results, "score")

        results = self._reranker.rerank_hybrid(
            self._fts_query._query, vector_results, fts_results
        )

        if not isinstance(results, pa.Table):  # Enforce type
            raise TypeError(
                f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
            )

        # apply limit after reranking
        results = results.slice(length=self._limit)

        if not self._with_row_id:
            results = results.drop(["_rowid"])
        return results

    def _rank(self, results: pa.Table, column: str, ascending: bool = True):
        if len(results) == 0:
            return results
        # Get the score column from results
        scores = results.column(column).to_numpy()
        sort_indices = np.argsort(scores)
        if not ascending:
            sort_indices = sort_indices[::-1]
        ranks = np.empty_like(sort_indices)
        ranks[sort_indices] = np.arange(len(scores)) + 1
        # replace the score column with the ranks
        _score_idx = results.column_names.index(column)
        results = results.set_column(
            _score_idx, column, pa.array(ranks, type=pa.float32())
        )
        return results

    def _normalize_scores(self, results: pa.Table, column: str, invert=False):
        if len(results) == 0:
            return results
        # Get the score column from results
        scores = results.column(column).to_numpy()
        # normalize the scores by subtracting the min and dividing by the range
        max, min = np.max(scores), np.min(scores)
        if np.isclose(max, min):
            rng = max
        else:
            rng = max - min
        scores = (scores - min) / rng
        if invert:
            scores = 1 - scores
        # replace the score column with the normalized scores
        _score_idx = results.column_names.index(column)
        results = results.set_column(
            _score_idx, column, pa.array(scores, type=pa.float32())
        )
        return results
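
    # Worked example of the min-max normalization above (numbers are
    # illustrative): scores [2.0, 4.0, 6.0] become (x - 2.0) / (6.0 - 2.0)
    # -> [0.0, 0.5, 1.0]; with invert=True -> [1.0, 0.5, 0.0]. When all
    # scores are equal the range collapses, so the max itself is used as the
    # divisor to avoid division by zero.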

    def rerank(
        self,
        normalize="score",
        reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0),
    ) -> LanceHybridQueryBuilder:
        """
        Rerank the hybrid search results using the specified reranker. The reranker
        must be an instance of the Reranker class.

        Parameters
        ----------
        normalize: str, default "score"
            The method to normalize the scores. Can be "rank" or "score". If "rank",
            the scores are converted to ranks and then normalized. If "score", the
            scores are normalized directly.
        reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
            The reranker to use. Must be an instance of the Reranker class.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        if normalize not in ["rank", "score"]:
            raise ValueError("normalize must be 'rank' or 'score'.")
        if reranker and not isinstance(reranker, Reranker):
            raise ValueError("reranker must be an instance of the Reranker class.")

        self._norm = normalize
        self._reranker = reranker

        return self
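
    # Hybrid usage sketch (hedged: assumes a table that has an FTS index and
    # a registered embedding function, so a plain string query can be
    # embedded automatically):
    #
    #     results = (
    #         table.search("puppy", query_type="hybrid")
    #         .rerank(normalize="score")
    #         .limit(5)
    #         .to_pandas()
    #     )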
|
||||
|
||||
def limit(self, limit: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the maximum number of results to return for both vector and fts search
|
||||
components.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.limit(limit)
|
||||
self._fts_query.limit(limit)
|
||||
self._limit = limit
|
||||
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the columns to return for both vector and fts search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.select(columns)
|
||||
self._fts_query.select(columns)
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the where clause for both vector and fts search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
|
||||
self._vector_query.where(where, prefilter=prefilter)
|
||||
self._fts_query.where(where)
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the distance metric to use for vector search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.metric(metric)
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of probes to use for vector search.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nprobes: int
|
||||
The number of probes to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.nprobes(nprobes)
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Refine the vector search results by reading extra elements and
|
||||
re-ranking them in memory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
refine_factor: int
|
||||
The refine factor to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.refine_factor(refine_factor)
|
||||
return self
|
||||
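Editor's note, a minimal usage sketch (not part of this diff) of how the builder methods above chain together. The `table` handle, the "text"/"label" column names, and the `query_type="hybrid"` flag are assumptions for illustration; the `rerank(normalize=..., reranker=...)` signature is the one documented above.

# Hypothetical hybrid query; assumes a table with an FTS index on "text".
from lancedb.rerankers import LinearCombinationReranker

results = (
    table.search("puppy", query_type="hybrid")
    .rerank(normalize="score", reranker=LinearCombinationReranker(weight=0.7))
    .where("label = 'dog'")
    .select(["text", "label"])
    .limit(10)
    .to_pandas()
)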
64
python/python/lancedb/remote/__init__.py
Normal file
@@ -0,0 +1,64 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from typing import List, Optional

import attrs
import pyarrow as pa
from pydantic import BaseModel

from lancedb.common import VECTOR_COLUMN_NAME

__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]


class VectorQuery(BaseModel):
    # vector to search for
    vector: List[float]

    # sql filter to refine the query with
    filter: Optional[str] = None

    # top k results to return
    k: int

    # distance metric
    _metric: str = "L2"

    # which columns to return in the results
    columns: Optional[List[str]] = None

    # optional query parameters for tuning the results,
    # e.g. `{"nprobes": "10", "refine_factor": "10"}`
    nprobes: int = 10
    refine_factor: Optional[int] = None

    vector_column: str = VECTOR_COLUMN_NAME


@attrs.define
class VectorQueryResult:
    # for now the response is directly deserialized into an Arrow table
    tbl: pa.Table

    def to_arrow(self) -> pa.Table:
        return self.tbl


class LanceDBClient(abc.ABC):
    @abc.abstractmethod
    def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
        """Query the LanceDB server for the given table and query."""
        pass
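Editor's sketch (not part of this diff): `VectorQuery` is the JSON payload the client POSTs to the query endpoint. The values below are illustrative only; `dict(exclude_none=True)` mirrors how the client serializes it.

q = VectorQuery(
    vector=[0.1, 0.2, 0.3],
    filter="price > 10",
    k=5,
    columns=["item", "price"],
    nprobes=20,
    refine_factor=2,
)
print(q.dict(exclude_none=True))  # the request body sent to the server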
22
python/python/lancedb/remote/arrow.py
Normal file
@@ -0,0 +1,22 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pyarrow as pa


def to_ipc_binary(table: pa.Table) -> bytes:
    """Serialize a PyArrow Table to IPC binary."""
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)
    return sink.getvalue().to_pybytes()
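Editor's sketch (not part of this diff): a round trip through the IPC stream format, showing that the bytes produced above can be read back with PyArrow's matching stream reader.

import pyarrow as pa

table = pa.table({"vector": [[0.1, 0.2]], "item": ["foo"]})
buf = to_ipc_binary(table)
with pa.ipc.open_stream(pa.BufferReader(buf)) as reader:
    assert reader.read_all().equals(table)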
249
python/python/lancedb/remote/client.py
Normal file
@@ -0,0 +1,249 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin

import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
from lancedb.remote.errors import LanceDBClientError

ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"


def _check_not_closed(f):
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if self.closed:
            raise ValueError("Connection is closed")
        return f(self, *args, **kwargs)

    return wrapped


def _read_ipc(resp: requests.Response) -> pa.Table:
    resp_body = resp.content
    with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
        return reader.read_all()


@attrs.define(slots=False)
class RestfulLanceDBClient:
    db_name: str
    region: str
    api_key: Credential
    host_override: Optional[str] = attrs.field(default=None)

    closed: bool = attrs.field(default=False, init=False)

    @functools.cached_property
    def session(self) -> requests.Session:
        sess = requests.Session()

        retry_adapter_instance = retry_adapter(retry_adapter_options())
        sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)

        adapter_class = LanceDBClientHTTPAdapterFactory()
        sess.mount("https://", adapter_class())
        return sess

    @property
    def url(self) -> str:
        return (
            self.host_override
            or f"https://{self.db_name}.{self.region}.api.lancedb.com"
        )

    def close(self):
        self.session.close()
        self.closed = True

    @functools.cached_property
    def headers(self) -> Dict[str, str]:
        headers = {
            "x-api-key": self.api_key,
        }
        if self.region == "local":  # Local test mode
            headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
        if self.host_override:
            headers["x-lancedb-database"] = self.db_name
        return headers

    @staticmethod
    def _check_status(resp: requests.Response):
        if resp.status_code == 404:
            raise LanceDBClientError(f"Not found: {resp.text}")
        elif 400 <= resp.status_code < 500:
            raise LanceDBClientError(
                f"Bad Request: {resp.status_code}, error: {resp.text}"
            )
        elif 500 <= resp.status_code < 600:
            raise LanceDBClientError(
                f"Internal Server Error: {resp.status_code}, error: {resp.text}"
            )
        elif resp.status_code != 200:
            raise LanceDBClientError(
                f"Unknown Error: {resp.status_code}, error: {resp.text}"
            )

    @_check_not_closed
    def get(self, uri: str, params: Optional[Union[Dict[str, Any], BaseModel]] = None):
        """Send a GET request and return the deserialized response payload."""
        if isinstance(params, BaseModel):
            params: Dict[str, Any] = params.dict(exclude_none=True)
        with self.session.get(
            urljoin(self.url, uri),
            params=params,
            headers=self.headers,
            timeout=(120.0, 300.0),
        ) as resp:
            self._check_status(resp)
            return resp.json()

    @_check_not_closed
    def post(
        self,
        uri: str,
        data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
        params: Optional[Dict[str, Any]] = None,
        content_type: Optional[str] = None,
        deserialize: Callable = lambda resp: resp.json(),
        request_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Send a POST request and return the deserialized response payload.

        Parameters
        ----------
        uri : str
            The uri to send the POST request to.
        data: Union[Dict[str, Any], BaseModel, bytes]
            The request payload, sent as JSON unless it is raw bytes.
        request_id: Optional[str]
            Optional client side request id to be sent in the request headers.

        """
        if isinstance(data, BaseModel):
            data: Dict[str, Any] = data.dict(exclude_none=True)
        if isinstance(data, bytes):
            req_kwargs = {"data": data}
        else:
            req_kwargs = {"json": data}

        headers = self.headers.copy()
        if content_type is not None:
            headers["content-type"] = content_type
        if request_id is not None:
            headers["x-request-id"] = request_id
        with self.session.post(
            urljoin(self.url, uri),
            headers=headers,
            params=params,
            timeout=(120.0, 300.0),
            **req_kwargs,
        ) as resp:
            self._check_status(resp)
            return deserialize(resp)

    @_check_not_closed
    def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
        """List all tables in the database."""
        if page_token is None:
            page_token = ""
        json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
        return json["tables"]

    @_check_not_closed
    def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
        """Query a table."""
        tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
        return VectorQueryResult(tbl)

    def mount_retry_adapter_for_table(self, table_name: str) -> None:
        """
        Add an HTTP adapter to the session that will retry retryable requests
        to the table.
        """
        retry_options = retry_adapter_options(methods=["GET", "POST"])
        retry_adapter_instance = retry_adapter(retry_options)
        session = self.session

        session.mount(
            urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
        )
        session.mount(
            urljoin(self.url, f"/v1/table/{table_name}/describe/"),
            retry_adapter_instance,
        )
        session.mount(
            urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
            retry_adapter_instance,
        )


def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
    return {
        "retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
        "connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
        "read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
        "backoff_factor": float(
            os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
        ),
        "backoff_jitter": float(
            os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
        ),
        "statuses": [
            int(i.strip())
            for i in os.environ.get(
                "LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
            ).split(",")
        ],
        "methods": methods,
    }


def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
    total_retries = options["retries"]
    connect_retries = options["connect_retries"]
    read_retries = options["read_retries"]
    backoff_factor = options["backoff_factor"]
    backoff_jitter = options["backoff_jitter"]
    statuses = options["statuses"]
    methods = frozenset(options["methods"])
    logging.debug(
        f"Setting up retry adapter with {total_retries} retries, "  # noqa G003
        + f"connect retries {connect_retries}, read retries {read_retries}, "
        + f"backoff factor {backoff_factor}, statuses {statuses}, "
        + f"methods {methods}"
    )

    return HTTPAdapter(
        max_retries=Retry(
            total=total_retries,
            connect=connect_retries,
            read=read_retries,
            backoff_factor=backoff_factor,
            backoff_jitter=backoff_jitter,
            status_forcelist=statuses,
            allowed_methods=methods,
        )
    )
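Editor's sketch (not part of this diff): the retry policy is tuned entirely through the environment variables read above; the values below are hypothetical overrides and each one falls back to the default shown in `retry_adapter_options`.

import os

os.environ["LANCE_CLIENT_MAX_RETRIES"] = "5"
os.environ["LANCE_CLIENT_RETRY_BACKOFF_FACTOR"] = "0.5"
os.environ["LANCE_CLIENT_RETRY_STATUSES"] = "429, 503"

adapter = retry_adapter(retry_adapter_options(methods=["GET", "POST"]))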
115
python/python/lancedb/remote/connection_timeout.py
Normal file
@@ -0,0 +1,115 @@
# Copyright 2024 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This module contains an adapter that will close connections if they have not been
# used before a certain timeout. This is necessary because some load balancers will
# close connections after a certain amount of time, but the requests module may not
# yet have received the FIN/ACK and will try to reuse the connection.
#
# TODO some of the code here can be simplified if/when this PR is merged:
# https://github.com/urllib3/urllib3/pull/3275

import datetime
import logging
import os

from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPSConnection
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.poolmanager import PoolManager


def get_client_connection_timeout() -> int:
    return int(os.environ.get("LANCE_CLIENT_CONNECTION_TIMEOUT", "300"))


class LanceDBHTTPSConnection(HTTPSConnection):
    """
    HTTPSConnection that tracks the last time it was used.
    """

    idle_timeout: datetime.timedelta
    last_activity: datetime.datetime

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_activity = datetime.datetime.now()

    def request(self, *args, **kwargs):
        self.last_activity = datetime.datetime.now()
        super().request(*args, **kwargs)

    def is_expired(self):
        return datetime.datetime.now() - self.last_activity > self.idle_timeout


def LanceDBHTTPSConnectionPoolFactory(client_idle_timeout: int):
    """
    Creates a connection pool class that can be used to close idle connections.
    """

    class LanceDBHTTPSConnectionPool(HTTPSConnectionPool):
        # override the connection class
        ConnectionCls = LanceDBHTTPSConnection

        def _get_conn(self, timeout: float | None = None):
            logging.debug("Getting https connection")
            conn = super()._get_conn(timeout)
            if conn.is_expired():
                logging.debug("Closing expired connection")
                conn.close()

            return conn

        def _new_conn(self):
            conn = super()._new_conn()
            conn.idle_timeout = datetime.timedelta(seconds=client_idle_timeout)
            return conn

    return LanceDBHTTPSConnectionPool


class LanceDBClientPoolManager(PoolManager):
    def __init__(
        self, client_idle_timeout: int, num_pools: int, maxsize: int, **kwargs
    ):
        super().__init__(num_pools=num_pools, maxsize=maxsize, **kwargs)
        # inject our connection pool impl
        connection_pool_class = LanceDBHTTPSConnectionPoolFactory(
            client_idle_timeout=client_idle_timeout
        )
        self.pool_classes_by_scheme["https"] = connection_pool_class


def LanceDBClientHTTPAdapterFactory():
    """
    Creates an HTTPAdapter class that can be used to close idle connections
    """

    # closure over the timeout
    client_idle_timeout = get_client_connection_timeout()

    class LanceDBClientRequestHTTPAdapter(HTTPAdapter):
        def init_poolmanager(self, connections, maxsize, block=False):
            # inject our pool manager impl
            self.poolmanager = LanceDBClientPoolManager(
                client_idle_timeout=client_idle_timeout,
                num_pools=connections,
                maxsize=maxsize,
                block=block,
            )

    return LanceDBClientRequestHTTPAdapter
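Editor's sketch (not part of this diff): mounting the idle-timeout adapter on a plain `requests` session, exactly as `RestfulLanceDBClient.session` does above. Connections unused for longer than LANCE_CLIENT_CONNECTION_TIMEOUT seconds are closed instead of being reused.

import requests

adapter_class = LanceDBClientHTTPAdapterFactory()
sess = requests.Session()
sess.mount("https://", adapter_class())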
279
python/python/lancedb/remote/db.py
Normal file
@@ -0,0 +1,279 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse

import pyarrow as pa
from overrides import override

from ..common import DATA
from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
from .errors import LanceDBClientError


class RemoteDBConnection(DBConnection):
    """A connection to a remote LanceDB database."""

    def __init__(
        self,
        db_url: str,
        api_key: str,
        region: str,
        host_override: Optional[str] = None,
        request_thread_pool: Optional[ThreadPoolExecutor] = None,
    ):
        """Connect to a remote LanceDB database."""
        parsed = urlparse(db_url)
        if parsed.scheme != "db":
            raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
        self.db_name = parsed.netloc
        self.api_key = api_key
        self._client = RestfulLanceDBClient(
            self.db_name, region, api_key, host_override
        )
        self._request_thread_pool = request_thread_pool

    def __repr__(self) -> str:
        return f"RemoteConnect(name={self.db_name})"

    @override
    def table_names(
        self, page_token: Optional[str] = None, limit: int = 10
    ) -> Iterable[str]:
        """List the names of all tables in the database.

        Parameters
        ----------
        page_token: str
            The last token to start the new page.
        limit: int, default 10
            The maximum number of tables to return for each page.

        Returns
        -------
        An iterator of table names.
        """
        while True:
            result = self._client.list_tables(limit, page_token)

            if len(result) > 0:
                page_token = result[-1]
            else:
                break
            for item in result:
                yield item

    @override
    def open_table(self, name: str) -> Table:
        """Open a Lance Table in the database.

        Parameters
        ----------
        name: str
            The name of the table.

        Returns
        -------
        A LanceTable object representing the table.
        """
        from .table import RemoteTable

        self._client.mount_retry_adapter_for_table(name)

        # check if table exists
        try:
            self._client.post(f"/v1/table/{name}/describe/")
        except LanceDBClientError as err:
            if str(err).startswith("Not found"):
                logging.error(
                    "Table %s does not exist. Please first call "
                    "db.create_table(%s, data).",
                    name,
                    name,
                )
        return RemoteTable(self, name)

    @override
    def create_table(
        self,
        name: str,
        data: DATA = None,
        schema: Optional[Union[pa.Schema, LanceModel]] = None,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        mode: Optional[str] = None,
        embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
    ) -> Table:
        """Create a [Table][lancedb.table.Table] in the database.

        Parameters
        ----------
        name: str
            The name of the table.
        data: The data to initialize the table, *optional*
            User must provide at least one of `data` or `schema`.
            Acceptable types are:

            - dict or list-of-dict

            - pandas.DataFrame

            - pyarrow.Table or pyarrow.RecordBatch
        schema: The schema of the table, *optional*
            Acceptable types are:

            - pyarrow.Schema

            - [LanceModel][lancedb.pydantic.LanceModel]
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contain NaNs.
            One of "error", "drop", "fill".
        fill_value: float
            The value to use when filling vectors. Only used if on_bad_vectors="fill".

        Returns
        -------
        LanceTable
            A reference to the newly created table.

        !!! note

            The vector index won't be created by default.
            To create the index, call the `create_index` method on the table.

        Examples
        --------

        Can create with list of tuples or dictionaries:

        >>> import lancedb
        >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
        ...                      region="...") # doctest: +SKIP
        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        >>> db.create_table("my_table", data) # doctest: +SKIP
        LanceTable(my_table)

        You can also pass a pandas DataFrame:

        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...    "vector": [[1.1, 1.2], [0.2, 1.8]],
        ...    "lat": [45.5, 40.1],
        ...    "long": [-122.7, -74.1]
        ... })
        >>> db.create_table("table2", data) # doctest: +SKIP
        LanceTable(table2)

        >>> custom_schema = pa.schema([
        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
        ...   pa.field("lat", pa.float32()),
        ...   pa.field("long", pa.float32())
        ... ])
        >>> db.create_table("table3", data, schema = custom_schema) # doctest: +SKIP
        LanceTable(table3)

        It is also possible to create a table from `[Iterable[pa.RecordBatch]]`:

        >>> import pyarrow as pa
        >>> def make_batches():
        ...     for i in range(5):
        ...         yield pa.RecordBatch.from_arrays(
        ...             [
        ...                 pa.array([[3.1, 4.1], [5.9, 26.5]],
        ...                     pa.list_(pa.float32(), 2)),
        ...                 pa.array(["foo", "bar"]),
        ...                 pa.array([10.0, 20.0]),
        ...             ],
        ...             ["vector", "item", "price"],
        ...         )
        >>> schema=pa.schema([
        ...     pa.field("vector", pa.list_(pa.float32(), 2)),
        ...     pa.field("item", pa.utf8()),
        ...     pa.field("price", pa.float32()),
        ... ])
        >>> db.create_table("table4", make_batches(), schema=schema) # doctest: +SKIP
        LanceTable(table4)

        """
        if data is None and schema is None:
            raise ValueError("Either data or schema must be provided.")
        if embedding_functions is not None:
            logging.warning(
                "embedding_functions is not yet supported on LanceDB Cloud. "
                "Please vote https://github.com/lancedb/lancedb/issues/626 "
                "for this feature."
            )
        if mode is not None:
            logging.warning("mode is not yet supported on LanceDB Cloud.")

        if inspect.isclass(schema) and issubclass(schema, LanceModel):
            # convert LanceModel to pyarrow schema
            # note that it's possible this contains
            # embedding function metadata already
            schema = schema.to_arrow_schema()

        if data is not None:
            data = _sanitize_data(
                data,
                schema,
                metadata=None,
                on_bad_vectors=on_bad_vectors,
                fill_value=fill_value,
            )
        else:
            data = pa.Table.from_pylist([], schema=schema)

        from .table import RemoteTable

        data = to_ipc_binary(data)
        request_id = uuid.uuid4().hex

        self._client.post(
            f"/v1/table/{name}/create/",
            data=data,
            request_id=request_id,
            content_type=ARROW_STREAM_CONTENT_TYPE,
        )

        return RemoteTable(self, name)

    @override
    def drop_table(self, name: str):
        """Drop a table from the database.

        Parameters
        ----------
        name: str
            The name of the table.
        """

        self._client.post(
            f"/v1/table/{name}/drop/",
        )

    async def close(self):
        """Close the connection to the database."""
        self._client.close()
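Editor's sketch (not part of this diff): `table_names` is a generator that pages through the server's table list, so iterating it lazily issues one request per page. The connection values are placeholders and a live endpoint is assumed.

db = RemoteDBConnection("db://my-db", api_key="...", region="us-east-1")
for name in db.table_names(limit=100):
    print(name)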
16
python/python/lancedb/remote/errors.py
Normal file
@@ -0,0 +1,16 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class LanceDBClientError(RuntimeError):
    pass
497
python/python/lancedb/remote/table.py
Normal file
@@ -0,0 +1,497 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Iterable, Optional, Union

import pyarrow as pa
from lance import json_to_schema

from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder

from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
from ..util import inf_vector_column_query, value_to_sql
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection


class RemoteTable(Table):
    def __init__(self, conn: RemoteDBConnection, name: str):
        self._conn = conn
        self._name = name

    def __repr__(self) -> str:
        return f"RemoteTable({self._conn.db_name}.{self._name})"

    def __len__(self) -> int:
        return self.count_rows(None)

    @cached_property
    def schema(self) -> pa.Schema:
        """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
        of this Table

        """
        resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
        schema = json_to_schema(resp["schema"])
        return schema

    @property
    def version(self) -> int:
        """Get the current version of the table"""
        resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
        return resp["version"]

    def to_arrow(self) -> pa.Table:
        """to_arrow() is not yet supported on LanceDB cloud."""
        raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")

    def to_pandas(self):
        """to_pandas() is not yet supported on LanceDB cloud."""
        raise NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")

    def create_scalar_index(self, *args, **kwargs):
        """Creates a scalar index"""
        raise NotImplementedError(
            "create_scalar_index() is not yet supported on LanceDB cloud."
        )

    def create_index(
        self,
        metric="L2",
        vector_column_name: str = VECTOR_COLUMN_NAME,
        index_cache_size: Optional[int] = None,
        num_partitions: Optional[int] = None,
        num_sub_vectors: Optional[int] = None,
        replace: Optional[bool] = None,
        accelerator: Optional[str] = None,
    ):
        """Create an index on the table.
        Currently, the only parameters that matter are
        the metric and the vector column name.

        Parameters
        ----------
        metric : str
            The metric to use for the index. Default is "L2".
        vector_column_name : str
            The name of the vector column. Default is "vector".

        Examples
        --------
        >>> import lancedb
        >>> import uuid
        >>> from lancedb.schema import vector
        >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
        ...                      region="...") # doctest: +SKIP
        >>> table_name = uuid.uuid4().hex
        >>> schema = pa.schema(
        ...     [
        ...         pa.field("id", pa.uint32(), False),
        ...         pa.field("vector", vector(128), False),
        ...         pa.field("s", pa.string(), False),
        ...     ]
        ... )
        >>> table = db.create_table( # doctest: +SKIP
        ...     table_name, # doctest: +SKIP
        ...     schema=schema, # doctest: +SKIP
        ... )
        >>> table.create_index("L2", "vector") # doctest: +SKIP
        """

        if num_partitions is not None:
            logging.warning(
                "num_partitions is not supported on LanceDB cloud. "
                "This parameter will be tuned automatically."
            )
        if num_sub_vectors is not None:
            logging.warning(
                "num_sub_vectors is not supported on LanceDB cloud. "
                "This parameter will be tuned automatically."
            )
        if accelerator is not None:
            logging.warning(
                "GPU accelerator is not yet supported on LanceDB cloud. "
                "If you have 100M+ vectors to index, "
                "please contact us at contact@lancedb.com"
            )
        if replace is not None:
            logging.warning(
                "replace is not supported on LanceDB cloud. "
                "Existing indexes will always be replaced."
            )
        index_type = "vector"

        data = {
            "column": vector_column_name,
            "index_type": index_type,
            "metric_type": metric,
            "index_cache_size": index_cache_size,
        }
        resp = self._conn._client.post(
            f"/v1/table/{self._name}/create_index/", data=data
        )

        return resp

    def add(
        self,
        data: DATA,
        mode: str = "append",
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
    ) -> int:
        """Add more data to the [Table](Table). It has the same API signature as
        the OSS version.

        Parameters
        ----------
        data: DATA
            The data to insert into the table. Acceptable types are:

            - dict or list-of-dict

            - pandas.DataFrame

            - pyarrow.Table or pyarrow.RecordBatch
        mode: str
            The mode to use when writing the data. Valid values are
            "append" and "overwrite".
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contain NaNs.
            One of "error", "drop", "fill".
        fill_value: float, default 0.
            The value to use when filling vectors. Only used if on_bad_vectors="fill".

        """
        data = _sanitize_data(
            data,
            self.schema,
            metadata=None,
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
        )
        payload = to_ipc_binary(data)

        request_id = uuid.uuid4().hex

        self._conn._client.post(
            f"/v1/table/{self._name}/insert/",
            data=payload,
            params={"request_id": request_id, "mode": mode},
            content_type=ARROW_STREAM_CONTENT_TYPE,
        )

    def search(
        self,
        query: Union[VEC, str],
        vector_column_name: Optional[str] = None,
    ) -> LanceVectorQueryBuilder:
        """Create a search query to find the nearest neighbors
        of the given query vector. We currently support [vector search][search].

        All query options are defined in [Query][lancedb.query.Query].

        Examples
        --------
        >>> import lancedb
        >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
        ...                      region="...") # doctest: +SKIP
        >>> data = [
        ...    {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
        ...    {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
        ...    {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
        ... ]
        >>> table = db.create_table("my_table", data) # doctest: +SKIP
        >>> query = [0.4, 1.4, 2.4]
        >>> (table.search(query) # doctest: +SKIP
        ...     .where("original_width > 1000", prefilter=True) # doctest: +SKIP
        ...     .select(["caption", "original_width"]) # doctest: +SKIP
        ...     .limit(2) # doctest: +SKIP
        ...     .to_pandas()) # doctest: +SKIP
          caption  original_width           vector  _distance # doctest: +SKIP
        0     foo            2000  [0.5, 3.4, 1.3]   5.220000 # doctest: +SKIP
        1    test            3000  [0.3, 6.2, 2.6]  23.089996 # doctest: +SKIP

        Parameters
        ----------
        query: list/np.ndarray/str/PIL.Image.Image, default None
            The targeted vector to search for.

            - *default None*.
            Acceptable types are: list, np.ndarray, PIL.Image.Image

            - If None then the select/where/limit clauses are applied to filter
            the table
        vector_column_name: str, optional
            The name of the vector column to search.

            - If not specified then the vector column is inferred from
            the table schema

            - If the table has multiple vector columns then the *vector_column_name*
            needs to be specified. Otherwise, an error is raised.

        Returns
        -------
        LanceQueryBuilder
            A query builder object representing the query.
            Once executed, the query returns

            - selected columns

            - the vector

            - and also the "_distance" column which is the distance between the query
            vector and the returned vector.
        """
        if vector_column_name is None:
            vector_column_name = inf_vector_column_query(self.schema)
        return LanceVectorQueryBuilder(self, query, vector_column_name)

    def _execute_query(self, query: Query) -> pa.Table:
        if (
            query.vector is not None
            and len(query.vector) > 0
            and not isinstance(query.vector[0], float)
        ):
            if self._conn._request_thread_pool is None:

                def submit(name, q):
                    f = Future()
                    f.set_result(self._conn._client.query(name, q))
                    return f

            else:

                def submit(name, q):
                    return self._conn._request_thread_pool.submit(
                        self._conn._client.query, name, q
                    )

            results = []
            for v in query.vector:
                v = list(v)
                q = query.copy()
                q.vector = v
                results.append(submit(self._name, q))

            return pa.concat_tables(
                [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
            )
        else:
            result = self._conn._client.query(self._name, query)
            return result.to_arrow()

    def _do_merge(
        self,
        merge: LanceMergeInsertBuilder,
        new_data: DATA,
        on_bad_vectors: str,
        fill_value: float,
    ):
        data = _sanitize_data(
            new_data,
            self.schema,
            metadata=None,
            on_bad_vectors=on_bad_vectors,
            fill_value=fill_value,
        )
        payload = to_ipc_binary(data)

        params = {}
        if len(merge._on) != 1:
            raise ValueError(
                "RemoteTable only supports a single on key in merge_insert"
            )
        params["on"] = merge._on[0]
        params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
        if merge._when_matched_update_all_condition is not None:
            params[
                "when_matched_update_all_filt"
            ] = merge._when_matched_update_all_condition
        params["when_not_matched_insert_all"] = str(
            merge._when_not_matched_insert_all
        ).lower()
        params["when_not_matched_by_source_delete"] = str(
            merge._when_not_matched_by_source_delete
        ).lower()
        if merge._when_not_matched_by_source_condition is not None:
            params[
                "when_not_matched_by_source_delete_filt"
            ] = merge._when_not_matched_by_source_condition

        self._conn._client.post(
            f"/v1/table/{self._name}/merge_insert/",
            data=payload,
            params=params,
            content_type=ARROW_STREAM_CONTENT_TYPE,
        )

    def delete(self, predicate: str):
        """Delete rows from the table.

        This can be used to delete a single row, many rows, all rows, or
        sometimes no rows (if your predicate matches nothing).

        Parameters
        ----------
        predicate: str
            The SQL where clause to use when deleting rows.

            - For example, 'x = 2' or 'x IN (1, 2, 3)'.

            The filter must not be empty, or it will error.

        Examples
        --------
        >>> import lancedb
        >>> data = [
        ...    {"x": 1, "vector": [1, 2]},
        ...    {"x": 2, "vector": [3, 4]},
        ...    {"x": 3, "vector": [5, 6]}
        ... ]
        >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
        ...                      region="...") # doctest: +SKIP
        >>> table = db.create_table("my_table", data) # doctest: +SKIP
        >>> table.search([10,10]).to_pandas() # doctest: +SKIP
           x      vector  _distance # doctest: +SKIP
        0  3  [5.0, 6.0]       41.0 # doctest: +SKIP
        1  2  [3.0, 4.0]       85.0 # doctest: +SKIP
        2  1  [1.0, 2.0]      145.0 # doctest: +SKIP
        >>> table.delete("x = 2") # doctest: +SKIP
        >>> table.search([10,10]).to_pandas() # doctest: +SKIP
           x      vector  _distance # doctest: +SKIP
        0  3  [5.0, 6.0]       41.0 # doctest: +SKIP
        1  1  [1.0, 2.0]      145.0 # doctest: +SKIP

        If you have a list of values to delete, you can combine them into a
        stringified list and use the `IN` operator:

        >>> to_remove = [1, 3] # doctest: +SKIP
        >>> to_remove = ", ".join([str(v) for v in to_remove]) # doctest: +SKIP
        >>> table.delete(f"x IN ({to_remove})") # doctest: +SKIP
        >>> table.search([10,10]).to_pandas() # doctest: +SKIP
           x      vector  _distance # doctest: +SKIP
        0  2  [3.0, 4.0]       85.0 # doctest: +SKIP
        """
        payload = {"predicate": predicate}
        self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)

    def update(
        self,
        where: Optional[str] = None,
        values: Optional[dict] = None,
        *,
        values_sql: Optional[Dict[str, str]] = None,
    ):
        """
        This can be used to update zero to all rows depending on how many
        rows match the where clause.

        Parameters
        ----------
        where: str, optional
            The SQL where clause to use when updating rows. For example, 'x = 2'
            or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
        values: dict, optional
            The values to update. The keys are the column names and the values
            are the values to set.
        values_sql: dict, optional
            The values to update, expressed as SQL expression strings. These can
            reference existing columns. For example, {"x": "x + 1"} will increment
            the x column by 1.

        Examples
        --------
        >>> import lancedb
        >>> data = [
        ...    {"x": 1, "vector": [1, 2]},
        ...    {"x": 2, "vector": [3, 4]},
        ...    {"x": 3, "vector": [5, 6]}
        ... ]
        >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
        ...                      region="...") # doctest: +SKIP
        >>> table = db.create_table("my_table", data) # doctest: +SKIP
        >>> table.to_pandas() # doctest: +SKIP
           x      vector # doctest: +SKIP
        0  1  [1.0, 2.0] # doctest: +SKIP
        1  2  [3.0, 4.0] # doctest: +SKIP
        2  3  [5.0, 6.0] # doctest: +SKIP
        >>> table.update(where="x = 2", values={"vector": [10, 10]}) # doctest: +SKIP
        >>> table.to_pandas() # doctest: +SKIP
           x        vector # doctest: +SKIP
        0  1    [1.0, 2.0] # doctest: +SKIP
        1  3    [5.0, 6.0] # doctest: +SKIP
        2  2  [10.0, 10.0] # doctest: +SKIP

        """
        if values is not None and values_sql is not None:
            raise ValueError("Only one of values or values_sql can be provided")
        if values is None and values_sql is None:
            raise ValueError("Either values or values_sql must be provided")

        if values is not None:
            updates = [[k, value_to_sql(v)] for k, v in values.items()]
        else:
            updates = [[k, v] for k, v in values_sql.items()]

        payload = {"predicate": where, "updates": updates}
        self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)

    def cleanup_old_versions(self, *_):
        """cleanup_old_versions() is not supported on the LanceDB cloud"""
        raise NotImplementedError(
            "cleanup_old_versions() is not supported on the LanceDB cloud"
        )

    def compact_files(self, *_):
        """compact_files() is not supported on the LanceDB cloud"""
        raise NotImplementedError(
            "compact_files() is not supported on the LanceDB cloud"
        )

    def count_rows(self, filter: Optional[str] = None) -> int:
        # payload = {"filter": filter}
        # self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
        raise NotImplementedError(
            "count_rows() is not yet supported on the LanceDB cloud"
        )

    def add_columns(self, transforms: Dict[str, str]):
        raise NotImplementedError(
            "add_columns() is not yet supported on the LanceDB cloud"
        )

    def alter_columns(self, alterations: Iterable[Dict[str, str]]):
        raise NotImplementedError(
            "alter_columns() is not yet supported on the LanceDB cloud"
        )

    def drop_columns(self, columns: Iterable[str]):
        raise NotImplementedError(
            "drop_columns() is not yet supported on the LanceDB cloud"
        )


def add_index(tbl: pa.Table, i: int) -> pa.Table:
    return tbl.add_column(
        0,
        pa.field("query_index", pa.uint32()),
        pa.array([i] * len(tbl), pa.uint32()),
    )
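Editor's sketch (not part of this diff): `add_index` tags each per-vector result batch with the position of its originating query, so the fan-out in `_execute_query` can concatenate the batches and still attribute rows to queries.

import pyarrow as pa

batch = pa.table({"item": ["foo", "bar"]})
tagged = add_index(batch, 3)
print(tagged.column("query_index").to_pylist())  # [3, 3]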
15
python/python/lancedb/rerankers/__init__.py
Normal file
@@ -0,0 +1,15 @@
from .base import Reranker
from .cohere import CohereReranker
from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker

__all__ = [
    "Reranker",
    "CrossEncoderReranker",
    "CohereReranker",
    "LinearCombinationReranker",
    "OpenaiReranker",
    "ColbertReranker",
]
75
python/python/lancedb/rerankers/base.py
Normal file
@@ -0,0 +1,75 @@
from abc import ABC, abstractmethod

import numpy as np
import pyarrow as pa


class Reranker(ABC):
    def __init__(self, return_score: str = "relevance"):
        """
        Interface for a reranker. A reranker is used to rerank the results from a
        vector and FTS search. This is useful for combining the results from both
        search methods.

        Parameters
        ----------
        return_score : str, default "relevance"
            Options are "relevance" or "all".
            The type of score to return. If "relevance", will return only the relevance
            score. If "all", will return all scores from the vector and FTS search along
            with the relevance score.

        """
        if return_score not in ["relevance", "all"]:
            raise ValueError("score must be either 'relevance' or 'all'")
        self.score = return_score

    @abstractmethod
    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """
        The rerank function receives the individual results from the vector and FTS
        search. You can choose to use any of the results to generate the final
        results, allowing maximum flexibility. This is mandatory to implement.

        Parameters
        ----------
        query : str
            The input query
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search
        """
        pass

    def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
        """
        Merge the results from the vector and FTS search. This is a vanilla merging
        function that just concatenates the results and removes the duplicates.

        NOTE: This doesn't take score into account. It'll keep the instance that was
        encountered first. This is designed for rerankers that don't use the score.
        In case you want to use the score, or support `return_scores="all"`, you'll
        have to implement your own merging function.

        Parameters
        ----------
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search
        """
        combined = pa.concat_tables([vector_results, fts_results], promote=True)
        row_id = combined.column("_rowid")

        # deduplicate
        mask = np.full((combined.shape[0]), False)
        _, mask_indices = np.unique(np.array(row_id), return_index=True)
        mask[mask_indices] = True
        combined = combined.filter(mask=mask)

        return combined
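Editor's sketch (not part of this diff): a minimal custom reranker built on this interface. It scores by document length, purely for illustration; the "text" column is an assumption, and the input tables are assumed to carry the "_rowid" column that `merge_results` needs.

import pyarrow as pa

class LengthReranker(Reranker):
    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
        combined = self.merge_results(vector_results, fts_results)
        # toy relevance: longer documents score higher
        scores = [float(len(t)) for t in combined["text"].to_pylist()]
        combined = combined.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        return combined.sort_by([("_relevance_score", "descending")])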
81
python/python/lancedb/rerankers/cohere.py
Normal file
@@ -0,0 +1,81 @@
import os
from functools import cached_property
from typing import Union

import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import Reranker


class CohereReranker(Reranker):
    """
    Reranks the results using the Cohere Rerank API.
    https://docs.cohere.com/docs/rerank-guide

    Parameters
    ----------
    model_name : str, default "rerank-english-v2.0"
        The name of the cross encoder model to use. Available cohere models are:
        - rerank-english-v2.0
        - rerank-multilingual-v2.0
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    top_n : int, default None
        The number of results to return. If None, will return all results.
    """

    def __init__(
        self,
        model_name: str = "rerank-english-v2.0",
        column: str = "text",
        top_n: Union[int, None] = None,
        return_score="relevance",
        api_key: Union[str, None] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.top_n = top_n
        self.api_key = api_key

    @cached_property
    def _client(self):
        cohere = attempt_import_or_raise("cohere")
        if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
            raise ValueError(
                "COHERE_API_KEY not set. Either set it in your environment or "
                "pass it as the `api_key` argument to the CohereReranker."
            )
        return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()
        results = self._client.rerank(
            query=query,
            documents=docs,
            top_n=self.top_n,
            model=self.model_name,
        )  # returns a list of (text, index, relevance_score), sorted descending by score
        indices, scores = list(
            zip(*[(result.index, result.relevance_score) for result in results])
        )  # tuples
        combined_results = combined_results.take(list(indices))
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )

        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for the Cohere reranker"
            )
        return combined_results
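Editor's sketch (not part of this diff): hypothetical usage. `table` is a placeholder for a table with an FTS index, the `query_type="hybrid"` flag is assumed from the hybrid builder above, and the `cohere` package plus a COHERE_API_KEY are required.

reranker = CohereReranker(top_n=5)
results = (
    table.search("wireless headphones", query_type="hybrid")
    .rerank(reranker=reranker)
    .to_pandas()
)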
109
python/python/lancedb/rerankers/colbert.py
Normal file
@@ -0,0 +1,109 @@
from functools import cached_property

import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import Reranker


class ColbertReranker(Reranker):
    """
    Reranks the results using the ColBERT model.

    Parameters
    ----------
    model_name : str, default "colbert-ir/colbertv2.0"
        The name of the cross encoder model to use.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    return_score : str, default "relevance"
        Options are "relevance" or "all". Only "relevance" is supported for now.
    """

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        column: str = "text",
        return_score="relevance",
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.torch = attempt_import_or_raise("torch")  # import here for faster ops later

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()

        tokenizer, model = self._model

        # Encode the query
        query_encoding = tokenizer(query, return_tensors="pt")
        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
        scores = []
        # Get score for each document
        for document in docs:
            document_encoding = tokenizer(
                document, return_tensors="pt", truncation=True, max_length=512
            )
            document_embedding = model(**document_encoding).last_hidden_state
            # Calculate MaxSim score
            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
            scores.append(score.item())

        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for the ColBERT reranker"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _model(self):
        transformers = attempt_import_or_raise("transformers")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        model = transformers.AutoModel.from_pretrained(self.model_name)

        return tokenizer, model

    def maxsim(self, query_embedding, document_embedding):
        # Expand dimensions for broadcasting
        # Query: [batch, length, size] -> [batch, query, 1, size]
        # Document: [batch, length, size] -> [batch, 1, length, size]
        expanded_query = query_embedding.unsqueeze(2)
        expanded_doc = document_embedding.unsqueeze(1)

        # Compute cosine similarity across the embedding dimension
        sim_matrix = self.torch.nn.functional.cosine_similarity(
            expanded_query, expanded_doc, dim=-1
        )

        # Take the maximum similarity for each query token (across all document tokens)
        # sim_matrix shape: [batch_size, query_length, doc_length]
        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)

        # Average these maximum scores across all query tokens
        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
        return avg_max_sim
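Editor's sketch (not part of this diff): the MaxSim arithmetic above, reduced to toy numbers with NumPy. Each query token takes its best-matching document token, and the per-token maxima are averaged; the embeddings below are hypothetical and already unit-normalized so a dot product equals cosine similarity.

import numpy as np

q = np.array([[1.0, 0.0], [0.0, 1.0]])              # 2 query tokens
d = np.array([[0.8, 0.6], [0.0, 1.0], [1.0, 0.0]])  # 3 document tokens
sim = q @ d.T                   # cosine similarity matrix, shape (2, 3)
print(sim.max(axis=1).mean())   # MaxSim score: (1.0 + 1.0) / 2 = 1.0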
74
python/python/lancedb/rerankers/cross_encoder.py
Normal file
@@ -0,0 +1,74 @@
from functools import cached_property
from typing import Union

import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import Reranker


class CrossEncoderReranker(Reranker):
    """
    Reranks the results using a cross encoder model. The cross encoder model is
    used to score the query and each result. The results are then sorted by the score.

    Parameters
    ----------
    model_name : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
        The name of the cross encoder model to use. See the sentence transformers
        documentation for a list of available models.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    device : str, default None
        The device to use for the cross encoder model. If None, will use "cuda"
        if available, otherwise "cpu".
    return_score : str, default "relevance"
        Options are "relevance" or "all". Only "relevance" is supported for now.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6",
        column: str = "text",
        device: Union[str, None] = None,
        return_score="relevance",
    ):
        super().__init__(return_score)
        torch = attempt_import_or_raise("torch")
        self.model_name = model_name
        self.column = column
        self.device = device
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @cached_property
    def model(self):
        sbert = attempt_import_or_raise("sentence_transformers")
        cross_encoder = sbert.CrossEncoder(self.model_name)

        return cross_encoder

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        passages = combined_results[self.column].to_pylist()
        cross_inp = [[query, passage] for passage in passages]
        cross_scores = self.model.predict(cross_inp)
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(cross_scores, type=pa.float32())
        )

        # sort the results by _relevance_score
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for CrossEncoderReranker"
            )
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results
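A usage sketch for the cross encoder. The database path, table, and query are illustrative; it assumes a table with a `text` column, an existing FTS index, and the hybrid query builder's `.rerank(reranker=...)` hook:

```python
import lancedb
from lancedb.rerankers.cross_encoder import CrossEncoderReranker

db = lancedb.connect("/tmp/db")        # hypothetical path
table = db.open_table("documents")     # hypothetical table with a "text" column
reranker = CrossEncoderReranker(column="text")
results = (
    table.search("vintage sports cars", query_type="hybrid")
    .rerank(reranker=reranker)
    .limit(5)
    .to_pandas()
)
```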
117
python/python/lancedb/rerankers/linear_combination.py
Normal file
@@ -0,0 +1,117 @@
import pyarrow as pa

from .base import Reranker


class LinearCombinationReranker(Reranker):
    """
    Reranks the results using a linear combination of the scores from the
    vector and FTS search. For missing scores, fill with `fill` value.

    Parameters
    ----------
    weight : float, default 0.7
        The weight to give to the vector score. Must be between 0 and 1.
    fill : float, default 1.0
        The score to give to results that are only in one of the two result sets.
        This is treated as a penalty, so a higher value means a lower score.
        TODO: We should just hardcode this; it's pretty confusing as we invert
        scores to calculate the final score.
    return_score : str, default "relevance"
        Options are "relevance" or "all".
        The type of score to return. If "relevance", will return only the relevance
        score. If "all", will return all scores from the vector and FTS search along
        with the relevance score.
    """

    def __init__(
        self, weight: float = 0.7, fill: float = 1.0, return_score="relevance"
    ):
        if weight < 0 or weight > 1:
            raise ValueError("weight must be between 0 and 1.")
        super().__init__(return_score)
        self.weight = weight
        self.fill = fill

    def rerank_hybrid(
        self,
        query: str,  # noqa: F821
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results, self.fill)

        return combined_results

    def merge_results(
        self, vector_results: pa.Table, fts_results: pa.Table, fill: float
    ):
        # If both are empty then just return an empty table
        if len(vector_results) == 0 and len(fts_results) == 0:
            return vector_results
        # If one is empty then return the other
        if len(vector_results) == 0:
            return fts_results
        if len(fts_results) == 0:
            return vector_results

        # sort both input tables on _rowid
        combined_list = []
        vector_list = vector_results.sort_by("_rowid").to_pylist()
        fts_list = fts_results.sort_by("_rowid").to_pylist()
        i, j = 0, 0
        while i < len(vector_list):
            if j >= len(fts_list):
                for vi in vector_list[i:]:
                    vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
                    combined_list.append(vi)
                break

            vi = vector_list[i]
            fj = fts_list[j]
            # invert the fts score from relevance to distance
            inverted_fts_score = self._invert_score(fj["score"])
            if vi["_rowid"] == fj["_rowid"]:
                vi["_relevance_score"] = self._combine_score(
                    vi["_distance"], inverted_fts_score
                )
                vi["score"] = fj["score"]  # keep the original score
                combined_list.append(vi)
                i += 1
                j += 1
            elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]:
                vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
                combined_list.append(vi)
                i += 1
            else:
                fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
                combined_list.append(fj)
                j += 1
        if j < len(fts_list):
            for fj in fts_list[j:]:
                # re-invert the score for each remaining fts-only row rather
                # than reusing the last value computed inside the loop
                fj["_relevance_score"] = self._combine_score(
                    self._invert_score(fj["score"]), fill
                )
                combined_list.append(fj)

        relevance_score_schema = pa.schema(
            [
                pa.field("_relevance_score", pa.float32()),
            ]
        )
        combined_schema = pa.unify_schemas(
            [vector_results.schema, fts_results.schema, relevance_score_schema]
        )
        tbl = pa.Table.from_pylist(combined_list, schema=combined_schema).sort_by(
            [("_relevance_score", "descending")]
        )
        if self.score == "relevance":
            tbl = tbl.drop_columns(["score", "_distance"])
        return tbl

    def _combine_score(self, score1, score2):
        # these scores represent distance
        return 1 - (self.weight * score1 + (1 - self.weight) * score2)

    def _invert_score(self, score: float):
        # Invert the score between relevance and distance
        return 1 - score
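A worked example of the arithmetic, with illustrative numbers: take `weight=0.7`, a vector `_distance` of 0.2, and an FTS relevance `score` of 0.9. The FTS score is first inverted to a distance (1 - 0.9 = 0.1), then the weighted sum of the two distances is flipped back into a relevance score:

```python
from lancedb.rerankers.linear_combination import LinearCombinationReranker

reranker = LinearCombinationReranker(weight=0.7)
# 1 - (0.7 * 0.2 + 0.3 * 0.1) = 1 - 0.17 = 0.83
assert abs(reranker._combine_score(0.2, 0.1) - 0.83) < 1e-9
```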
104
python/python/lancedb/rerankers/openai.py
Normal file
@@ -0,0 +1,104 @@
import json
import os
from functools import cached_property
from typing import Optional

import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import Reranker


class OpenaiReranker(Reranker):
    """
    Reranks the results using the OpenAI API.
    WARNING: This is a prompt based reranker that uses a chat model that is
    not a dedicated reranker API. This should be treated as experimental.

    Parameters
    ----------
    model_name : str, default "gpt-4-turbo-preview"
        The name of the chat model to use.
    column : str, default "text"
        The name of the column to use as input to the model.
    return_score : str, default "relevance"
        options are "relevance" or "all". Only "relevance" is supported for now.
    api_key : str, default None
        The API key to use. If None, will use the OPENAI_API_KEY environment variable.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo-preview",
        column: str = "text",
        return_score="relevance",
        api_key: Optional[str] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.api_key = api_key

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()
        response = self._client.chat.completions.create(
            model=self.model_name,
            response_format={"type": "json_object"},
            temperature=0,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert relevance ranker. Given a list of\
                    documents and a query, your job is to determine how relevant\
                    each document is for answering the query. Your output is JSON,\
                    which is a list of documents. Each document has two fields,\
                    content and relevance_score. relevance_score is from 0.0 to\
                    1.0 indicating the relevance of the text to the given query.\
                    Make sure to include all documents in the response.",
                },
                {"role": "user", "content": f"Query: {query} Docs: {docs}"},
            ],
        )
        results = json.loads(response.choices[0].message.content)["documents"]
        docs, scores = list(
            zip(*[(result["content"], result["relevance_score"]) for result in results])
        )  # tuples
        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "OpenAI Reranker does not support score='all' yet"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _client(self):
        openai = attempt_import_or_raise(
            "openai"
        )  # TODO: force version or handle versions < 1.0
        if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
            raise ValueError(
                "OPENAI_API_KEY not set. Either set it in your environment or \
                pass it as `api_key` argument to the OpenaiReranker."
            )
        return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
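A usage sketch for the OpenAI reranker. The table is illustrative, `OPENAI_API_KEY` must be set in the environment, and as above it assumes the hybrid query builder exposes `.rerank(reranker=...)`:

```python
import lancedb
from lancedb.rerankers.openai import OpenaiReranker

db = lancedb.connect("/tmp/db")        # hypothetical path
table = db.open_table("documents")     # hypothetical table with a "text" column
reranker = OpenaiReranker(model_name="gpt-4-turbo-preview", column="text")
results = (
    table.search("best hiking trails", query_type="hybrid")
    .rerank(reranker=reranker)
    .limit(5)
    .to_pandas()
)
```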
41
python/python/lancedb/schema.py
Normal file
@@ -0,0 +1,41 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Schema related utilities."""
import pyarrow as pa


def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType:
    """A helper function to create a vector type.

    Parameters
    ----------
    dimension: int
        The dimension of the vector.
    value_type: pa.DataType, optional
        The type of the value in the vector.

    Returns
    -------
    A PyArrow DataType for vectors.

    Examples
    --------

    >>> import pyarrow as pa
    >>> import lancedb
    >>> schema = pa.schema([
    ...     pa.field("id", pa.int64()),
    ...     pa.field("vector", lancedb.vector(756)),
    ... ])
    """
    return pa.list_(value_type, dimension)
1774
python/python/lancedb/table.py
Normal file
File diff suppressed because it is too large
265
python/python/lancedb/util.py
Normal file
@@ -0,0 +1,265 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import importlib
import os
import pathlib
import warnings
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple, Union
from urllib.parse import urlparse

import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs


def get_uri_scheme(uri: str) -> str:
    """
    Get the scheme of a URI. If the URI does not have a scheme, assume it is a file URI.

    Parameters
    ----------
    uri : str
        The URI to parse.

    Returns
    -------
    str: The scheme of the URI.
    """
    parsed = urlparse(uri)
    scheme = parsed.scheme
    if not scheme:
        scheme = "file"
    elif scheme in ["s3a", "s3n"]:
        scheme = "s3"
    elif len(scheme) == 1:
        # Windows drive names are parsed as the scheme
        # e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="/path", ...)
        # So we add special handling here for schemes that are a single character
        scheme = "file"
    return scheme


def get_uri_location(uri: str) -> str:
    """
    Get the location of a URI. If the parameter is not a URL, assume it is just a path.

    Parameters
    ----------
    uri : str
        The URI to parse.

    Returns
    -------
    str: Location part of the URL, without scheme
    """
    parsed = urlparse(uri)
    if len(parsed.scheme) == 1:
        # Windows drive names are parsed as the scheme
        # e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="/path", ...)
        # So we add special handling here for schemes that are a single character
        return uri

    if not parsed.netloc:
        return parsed.path
    else:
        return parsed.netloc + parsed.path
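A few hand-checked examples of what these helpers return:

```python
from lancedb.util import get_uri_location, get_uri_scheme

assert get_uri_scheme("s3a://bucket/path") == "s3"   # s3a/s3n normalize to s3
assert get_uri_scheme("c:\\data\\db") == "file"      # Windows drive letter
assert get_uri_scheme("data/db") == "file"           # bare path, no scheme
assert get_uri_location("s3://bucket/path/db") == "bucket/path/db"
assert get_uri_location("file:///tmp/db") == "/tmp/db"
```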
def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
    """
    Get a PyArrow FileSystem from a URI, handling extra environment variables.
    """
    if get_uri_scheme(uri) == "s3":
        fs = pa_fs.S3FileSystem(
            endpoint_override=os.environ.get("AWS_ENDPOINT"),
            request_timeout=30,
            connect_timeout=30,
        )
        path = get_uri_location(uri)
        return fs, path

    return pa_fs.FileSystem.from_uri(uri)


def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
    """
    Join a URI with multiple parts, handles both local and remote paths

    Parameters
    ----------
    base : str
        The base URI
    parts : str
        The parts to join to the base URI, each separated by the
        appropriate path separator for the URI scheme and OS
    """
    if isinstance(base, pathlib.Path):
        return str(base.joinpath(*parts))  # return str to match the annotation
    base = str(base)
    if get_uri_scheme(base) == "file":
        # using pathlib for local paths makes this windows compatible
        # `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
        return str(pathlib.Path(base, *parts))
    # for remote paths, join with "/" after stripping any trailing separators
    return "/".join([p.rstrip("/") for p in [base, *parts]])
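Hand-checked joins showing both branches:

```python
from lancedb.util import join_uri

# remote URIs are joined with "/" after stripping trailing separators
assert join_uri("s3://bucket/base/", "tbl", "data") == "s3://bucket/base/tbl/data"
# local paths go through pathlib, so the separator is OS-appropriate
print(join_uri("/tmp/db", "my_table.lance"))  # /tmp/db/my_table.lance on POSIX
```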
def attempt_import_or_raise(module: str, mitigation=None):
    """
    Import the specified module. If the module is not installed,
    raise an ImportError with a helpful message.

    Parameters
    ----------
    module : str
        The name of the module to import
    mitigation : Optional[str]
        The package(s) to install to mitigate the error.
        If not provided then the module name will be used.
    """
    try:
        return importlib.import_module(module)
    except ImportError:
        raise ImportError(f"Please install {mitigation or module}")


def safe_import_pandas():
    try:
        import pandas as pd

        return pd
    except ImportError:
        return None


def safe_import_polars():
    try:
        import polars as pl

        return pl
    except ImportError:
        return None
def inf_vector_column_query(schema: pa.Schema) -> str:
    """
    Get the vector column name

    Parameters
    ----------
    schema : pa.Schema
        The schema of the table to search for a vector column.

    Returns
    -------
    str: the vector column name.
    """
    vector_col_name = ""
    vector_col_count = 0
    for field_name in schema.names:
        field = schema.field(field_name)
        if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
            field.type.value_type
        ):
            vector_col_count += 1
            if vector_col_count > 1:
                raise ValueError(
                    "Schema has more than one vector column. "
                    "Please specify the vector column name "
                    "for vector search"
                )
            elif vector_col_count == 1:
                vector_col_name = field_name
    if vector_col_count == 0:
        raise ValueError(
            "There is no vector column in the data. "
            "Please specify the vector column name for vector search"
        )
    return vector_col_name
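For example, inference succeeds when the schema has exactly one fixed-size list of floats:

```python
import pyarrow as pa

from lancedb.util import inf_vector_column_query

schema = pa.schema(
    [
        pa.field("id", pa.int64()),
        pa.field("vector", pa.list_(pa.float32(), 128)),  # fixed-size float list
    ]
)
assert inf_vector_column_query(schema) == "vector"
```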
@singledispatch
def value_to_sql(value):
    raise NotImplementedError("SQL conversion is not implemented for this type")


@value_to_sql.register(str)
def _(value: str):
    return f"'{value}'"


@value_to_sql.register(int)
def _(value: int):
    return str(value)


@value_to_sql.register(float)
def _(value: float):
    return str(value)


@value_to_sql.register(bool)
def _(value: bool):
    return str(value).upper()


@value_to_sql.register(type(None))
def _(value: type(None)):
    return "NULL"


@value_to_sql.register(datetime)
def _(value: datetime):
    return f"'{value.isoformat()}'"


@value_to_sql.register(date)
def _(value: date):
    return f"'{value.isoformat()}'"


@value_to_sql.register(list)
def _(value: list):
    return "[" + ", ".join(map(value_to_sql, value)) + "]"


@value_to_sql.register(np.ndarray)
def _(value: np.ndarray):
    return value_to_sql(value.tolist())
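What the dispatcher produces for a few literal types (hand-checked against the registrations above):

```python
from datetime import date

from lancedb.util import value_to_sql

assert value_to_sql("foo") == "'foo'"
assert value_to_sql(3.5) == "3.5"
assert value_to_sql(True) == "TRUE"   # bool dispatches before int
assert value_to_sql(None) == "NULL"
assert value_to_sql([1, 2]) == "[1, 2]"
assert value_to_sql(date(2023, 1, 2)) == "'2023-01-02'"
```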
def deprecated(func):
    """This is a decorator which can be used to mark functions
    as deprecated. It will result in a warning being emitted
    when the function is used."""

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        warnings.simplefilter("always", DeprecationWarning)  # turn off filter
        warnings.warn(
            (
                f"Function {func.__name__} is deprecated and will be "
                "removed in a future version"
            ),
            category=DeprecationWarning,
            stacklevel=2,
        )
        warnings.simplefilter("default", DeprecationWarning)  # reset filter
        return func(*args, **kwargs)

    return new_func
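Usage is the standard decorator pattern; calling the wrapped function emits the warning on each call:

```python
from lancedb.util import deprecated


@deprecated
def old_api():
    return 42


old_api()  # DeprecationWarning: Function old_api is deprecated and will be ...
```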
76
python/python/tests/test_context.py
Normal file
@@ -0,0 +1,76 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
import pytest
from lancedb.context import contextualize


@pytest.fixture
def raw_df() -> pd.DataFrame:
    return pd.DataFrame(
        {
            "token": [
                "The",
                "quick",
                "brown",
                "fox",
                "jumped",
                "over",
                "the",
                "lazy",
                "dog",
                "I",
                "love",
                "sandwiches",
            ],
            "document_id": [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2],
        }
    )


def test_contextualizer(raw_df: pd.DataFrame):
    result = (
        contextualize(raw_df)
        .window(6)
        .stride(3)
        .text_col("token")
        .groupby("document_id")
        .to_pandas()["token"]
        .to_list()
    )

    assert result == [
        "The quick brown fox jumped over",
        "fox jumped over the lazy dog",
        "the lazy dog",
        "I love sandwiches",
    ]


def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
    result = (
        contextualize(raw_df)
        .window(6)
        .stride(3)
        .text_col("token")
        .groupby("document_id")
        .min_window_size(4)
        .to_pandas()["token"]
        .to_list()
    )

    assert result == [
        "The quick brown fox jumped over",
        "fox jumped over the lazy dog",
    ]
389
python/python/tests/test_db.py
Normal file
@@ -0,0 +1,389 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.pydantic import LanceModel, Vector


def test_basic(tmp_path):
    db = lancedb.connect(tmp_path)

    assert db.uri == str(tmp_path)
    assert db.table_names() == []

    table = db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    rs = table.search([100, 100]).limit(1).to_pandas()
    assert len(rs) == 1
    assert rs["item"].iloc[0] == "bar"

    rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
    assert len(rs) == 1
    assert rs["item"].iloc[0] == "foo"

    assert db.table_names() == ["test"]
    assert "test" in db
    assert len(db) == 1

    assert db.open_table("test").name == db["test"].name


def test_ingest_pd(tmp_path):
    db = lancedb.connect(tmp_path)

    assert db.uri == str(tmp_path)
    assert db.table_names() == []

    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    table = db.create_table("test", data=data)
    rs = table.search([100, 100]).limit(1).to_pandas()
    assert len(rs) == 1
    assert rs["item"].iloc[0] == "bar"

    rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
    assert len(rs) == 1
    assert rs["item"].iloc[0] == "foo"

    assert db.table_names() == ["test"]
    assert "test" in db
    assert len(db) == 1

    assert db.open_table("test").name == db["test"].name


def test_ingest_iterator(tmp_path):
    class PydanticSchema(LanceModel):
        vector: Vector(2)
        item: str
        price: float

    arrow_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float32()),
        ]
    )

    def make_batches():
        for _ in range(5):
            yield from [
                # pandas
                pd.DataFrame(
                    {
                        "vector": [[3.1, 4.1], [1, 1]],
                        "item": ["foo", "bar"],
                        "price": [10.0, 20.0],
                    }
                ),
                # pylist
                [
                    {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
                    {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
                ],
                # recordbatch
                pa.RecordBatch.from_arrays(
                    [
                        pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
                        pa.array(["foo", "bar"]),
                        pa.array([10.0, 20.0]),
                    ],
                    ["vector", "item", "price"],
                ),
                # pa Table
                pa.Table.from_arrays(
                    [
                        pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
                        pa.array(["foo", "bar"]),
                        pa.array([10.0, 20.0]),
                    ],
                    ["vector", "item", "price"],
                ),
                # pydantic list
                [
                    PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
                    PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
                ],
                # TODO: test pydict separately. it is unique column number and
                # name constraints
            ]

    def run_tests(schema):
        db = lancedb.connect(tmp_path)
        tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
        tbl.to_pandas()
        assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
        assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
        tbl_len = len(tbl)
        tbl.add(make_batches())
        assert tbl_len == 50
        assert len(tbl) == tbl_len * 2
        assert len(tbl.list_versions()) == 3
        db.drop_database()

    run_tests(arrow_schema)
    run_tests(PydanticSchema)


def test_table_names(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test2", data=data)
    db.create_table("test1", data=data)
    db.create_table("test3", data=data)
    assert db.table_names() == ["test1", "test2", "test3"]


@pytest.mark.asyncio
async def test_table_names_async(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test2", data=data)
    db.create_table("test1", data=data)
    db.create_table("test3", data=data)

    db = await lancedb.connect_async(tmp_path)
    assert await db.table_names() == ["test1", "test2", "test3"]


def test_create_mode(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test", data=data)

    with pytest.raises(Exception):
        db.create_table("test", data=data)

    new_data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["fizz", "buzz"],
            "price": [10.0, 20.0],
        }
    )
    tbl = db.create_table("test", data=new_data, mode="overwrite")
    assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]


def test_create_exist_ok(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    tbl = db.create_table("test", data=data)

    with pytest.raises(OSError):
        db.create_table("test", data=data)

    # open the table but don't add more rows
    tbl2 = db.create_table("test", data=data, exist_ok=True)
    assert tbl.name == tbl2.name
    assert tbl.schema == tbl2.schema
    assert len(tbl) == len(tbl2)

    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), list_size=2)),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float64()),
        ]
    )
    tbl3 = db.create_table("test", schema=schema, exist_ok=True)
    assert tbl3.schema == schema

    bad_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), list_size=2)),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float64()),
            pa.field("extra", pa.float32()),
        ]
    )
    with pytest.raises(ValueError):
        db.create_table("test", schema=bad_schema, exist_ok=True)


def test_delete_table(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test", data=data)

    with pytest.raises(Exception):
        db.create_table("test", data=data)

    assert db.table_names() == ["test"]

    db.drop_table("test")
    assert db.table_names() == []

    db.create_table("test", data=data)
    assert db.table_names() == ["test"]

    # dropping a table that does not exist should pass
    # if ignore_missing=True
    db.drop_table("does_not_exist", ignore_missing=True)


def test_drop_database(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    new_data = pd.DataFrame(
        {
            "vector": [[5.1, 4.1], [5.9, 10.5]],
            "item": ["kiwi", "avocado"],
            "price": [12.0, 17.0],
        }
    )
    db.create_table("test", data=data)
    with pytest.raises(Exception):
        db.create_table("test", data=data)

    assert db.table_names() == ["test"]

    db.create_table("new_test", data=new_data)
    db.drop_database()
    assert db.table_names() == []

    # it should pass when no tables are present
    db.create_table("test", data=new_data)
    db.drop_table("test")
    assert db.table_names() == []
    db.drop_database()
    assert db.table_names() == []

    # creating an empty database with schema
    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
    db.create_table("empty_table", schema=schema)
    # dropping an empty database should pass
    db.drop_database()
    assert db.table_names() == []


def test_empty_or_nonexistent_table(tmp_path):
    db = lancedb.connect(tmp_path)
    with pytest.raises(Exception):
        db.create_table("test_with_no_data")

    with pytest.raises(Exception):
        db.open_table("does_not_exist")

    schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
    test = db.create_table("test", schema=schema)

    class TestModel(LanceModel):
        a: int

    test2 = db.create_table("test2", schema=TestModel)
    assert test.schema == test2.schema


def test_replace_index(tmp_path):
    db = lancedb.connect(uri=tmp_path)
    table = db.create_table(
        "test",
        [
            {"vector": np.random.rand(128), "item": "foo", "price": float(i)}
            for i in range(1000)
        ],
    )
    table.create_index(
        num_partitions=2,
        num_sub_vectors=4,
    )

    with pytest.raises(Exception):
        table.create_index(
            num_partitions=2,
            num_sub_vectors=4,
            replace=False,
        )

    table.create_index(
        num_partitions=2,
        num_sub_vectors=4,
        replace=True,
        index_cache_size=10,
    )


def test_prefilter_with_index(tmp_path):
    db = lancedb.connect(uri=tmp_path)
    data = [
        {"vector": np.random.rand(128), "item": "foo", "price": float(i)}
        for i in range(1000)
    ]
    sample_key = data[100]["vector"]
    table = db.create_table(
        "test",
        data,
    )
    table.create_index(
        num_partitions=2,
        num_sub_vectors=4,
    )
    table = (
        table.search(sample_key)
        .where("price == 500", prefilter=True)
        .limit(5)
        .to_arrow()
    )
    assert table.num_rows == 1
26
python/python/tests/test_e2e_remote_db.py
Normal file
@@ -0,0 +1,26 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
from lancedb import LanceDBConnection

# TODO: setup integ test mark and script


@pytest.mark.skip(reason="Need to set up a local server")
def test_against_local_server():
    conn = LanceDBConnection("lancedb+http://localhost:10024")
    table = conn.open_table("sift1m_ivf1024_pq16")
    df = table.search(np.random.rand(128)).to_pandas()
    assert len(df) == 10
114
python/python/tests/test_embeddings.py
Normal file
@@ -0,0 +1,114 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import lance
import lancedb
import numpy as np
import pyarrow as pa
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
    EmbeddingFunctionConfig,
    EmbeddingFunctionRegistry,
    with_embeddings,
)
from lancedb.pydantic import LanceModel, Vector


def mock_embed_func(input_data):
    return [np.random.randn(128).tolist() for _ in range(len(input_data))]


def test_with_embeddings():
    for wrap_api in [True, False]:
        if wrap_api and sys.version_info.minor >= 11:
            # ratelimiter package doesn't work on 3.11
            continue
        data = pa.Table.from_arrays(
            [
                pa.array(["foo", "bar"]),
                pa.array([10.0, 20.0]),
            ],
            names=["text", "price"],
        )
        data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)
        assert data.num_columns == 3
        assert data.num_rows == 2
        assert data.column_names == ["text", "price", "vector"]
        assert data.column("text").to_pylist() == ["foo", "bar"]
        assert data.column("price").to_pylist() == [10.0, 20.0]


def test_embedding_function(tmp_path):
    registry = EmbeddingFunctionRegistry.get_instance()

    # let's create a table
    table = pa.table(
        {
            "text": pa.array(["hello world", "goodbye world"]),
            "vector": [np.random.randn(10), np.random.randn(10)],
        }
    )
    conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="vector",
        function=MockTextEmbeddingFunction(),
    )
    metadata = registry.get_table_metadata([conf])
    table = table.replace_schema_metadata(metadata)

    # Write it to disk
    lance.write_dataset(table, tmp_path / "test.lance")

    # Load this back
    ds = lance.dataset(tmp_path / "test.lance")

    # can we get the serialized version back out?
    configs = registry.parse_functions(ds.schema.metadata)

    conf = configs["vector"]
    func = conf.function
    actual = func.compute_query_embeddings("hello world")

    # And we make sure we can call it
    expected = func.compute_query_embeddings("hello world")

    assert np.allclose(actual, expected)


@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
    def _get_schema_from_model(model):
        class Schema(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        return Schema

    db = lancedb.connect(tmp_path)
    registry = EmbeddingFunctionRegistry.get_instance()
    model = registry.get("test-rate-limited").create(max_retries=0)
    schema = _get_schema_from_model(model)
    table = db.create_table("test", schema=schema, mode="overwrite")
    table.add([{"text": "hello world"}])
    with pytest.raises(Exception):
        table.add([{"text": "hello world"}])
    assert len(table) == 1

    model = registry.get("test-rate-limited").create()
    schema = _get_schema_from_model(model)
    table = db.create_table("test", schema=schema, mode="overwrite")
    table.add([{"text": "hello world"}])
    table.add([{"text": "hello world"}])
    assert len(table) == 2
419
python/python/tests/test_embeddings_slow.py
Normal file
@@ -0,0 +1,419 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
import os

import lancedb
import numpy as np
import pandas as pd
import pytest
import requests
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or a connection to an external API.


try:
    if importlib.util.find_spec("mlx.core") is not None:
        _mlx = True
    else:
        _mlx = None
except Exception:
    _mlx = None

try:
    if importlib.util.find_spec("imagebind") is not None:
        _imagebind = True
    else:
        _imagebind = None
except Exception:
    _imagebind = None


@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_basic_text_embeddings(alias, tmp_path):
    db = lancedb.connect(tmp_path)
    registry = get_registry()
    func = registry.get(alias).create(max_retries=0)
    func2 = registry.get(alias).create(max_retries=0)

    class Words(LanceModel):
        text: str = func.SourceField()
        text2: str = func2.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()
        vector2: Vector(func2.ndims()) = func2.VectorField()

    table = db.create_table("words", schema=Words)
    table.add(
        pd.DataFrame(
            {
                "text": [
                    "hello world",
                    "goodbye world",
                    "fizz",
                    "buzz",
                    "foo",
                    "bar",
                    "baz",
                ],
                "text2": [
                    "to be or not to be",
                    "that is the question",
                    "for whether tis nobler",
                    "in the mind to suffer",
                    "the slings and arrows",
                    "of outrageous fortune",
                    "or to take arms",
                ],
            }
        )
    )

    query = "greetings"
    actual = (
        table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
    )

    vec = func.compute_query_embeddings(query)[0]
    expected = (
        table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
    )
    assert actual.text == expected.text
    assert actual.text == "hello world"
    assert not np.allclose(actual.vector, actual.vector2)

    actual = (
        table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
    )
    assert actual.text != "hello world"
    assert not np.allclose(actual.vector, actual.vector2)


@pytest.mark.slow
def test_openclip(tmp_path):
    from PIL import Image

    db = lancedb.connect(tmp_path)
    registry = get_registry()
    func = registry.get("open-clip").create(max_retries=0)

    class Images(LanceModel):
        label: str
        image_uri: str = func.SourceField()
        image_bytes: bytes = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()
        vec_from_bytes: Vector(func.ndims()) = func.VectorField()

    table = db.create_table("images", schema=Images)
    labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
    uris = [
        "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
        "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
        "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
        "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
        "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
        "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
    ]
    # get each uri as bytes
    image_bytes = [requests.get(uri).content for uri in uris]
    table.add(
        pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
    )

    # text search
    actual = (
        table.search("man's best friend", vector_column_name="vector")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == "dog"
    frombytes = (
        table.search("man's best friend", vector_column_name="vec_from_bytes")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == frombytes.label
    assert np.allclose(actual.vector, frombytes.vector)

    # image search
    query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
    image_bytes = requests.get(query_image_uri).content
    query_image = Image.open(io.BytesIO(image_bytes))
    actual = (
        table.search(query_image, vector_column_name="vector")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == "dog"
    other = (
        table.search(query_image, vector_column_name="vec_from_bytes")
        .limit(1)
        .to_pydantic(Images)[0]
    )
    assert actual.label == other.label

    arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
    assert np.allclose(
        arrow_table["vector"].combine_chunks().values.to_numpy(),
        arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
    )


@pytest.mark.skipif(
    _imagebind is None,
    reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
    import os
    import shutil
    import tempfile

    import lancedb.embeddings.imagebind
    import pandas as pd
    import requests
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Created temporary directory {temp_dir}")

        def download_images(image_uris):
            downloaded_image_paths = []
            for uri in image_uris:
                try:
                    response = requests.get(uri, stream=True)
                    if response.status_code == 200:
                        # Extract image name from URI
                        image_name = os.path.basename(uri)
                        image_path = os.path.join(temp_dir, image_name)
                        with open(image_path, "wb") as out_file:
                            shutil.copyfileobj(response.raw, out_file)
                        downloaded_image_paths.append(image_path)
                except Exception as e:  # noqa: PERF203
                    print(f"Failed to download {uri}. Error: {e}")
            return temp_dir, downloaded_image_paths

        db = lancedb.connect(tmp_path)
        registry = get_registry()
        func = registry.get("imagebind").create(max_retries=0)

        class Images(LanceModel):
            label: str
            image_uri: str = func.SourceField()
            vector: Vector(func.ndims()) = func.VectorField()

        table = db.create_table("images", schema=Images)
        labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
        uris = [
            "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
            "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
            "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
            "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
            "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
            "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
        ]
        temp_dir, downloaded_images = download_images(uris)
        table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
        # text search
        actual = (
            table.search("man's best friend", vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"

        # image search
        query_image_uri = [
            "https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
        ]
        temp_dir, downloaded_images = download_images(query_image_uri)
        query_image_uri = downloaded_images[0]
        actual = (
            table.search(query_image_uri, vector_column_name="vector")
            .limit(1)
            .to_pydantic(Images)[0]
        )
        assert actual.label == "dog"

        if os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir)
            print(f"Deleted temporary directory {temp_dir}")


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)  # also skip if cohere not installed
def test_cohere_embedding_function():
    cohere = (
        get_registry()
        .get("cohere")
        .create(name="embed-multilingual-v2.0", max_retries=0)
    )

    class TextModel(LanceModel):
        text: str = cohere.SourceField()
        vector: Vector(cohere.ndims()) = cohere.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("~/lancedb")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()


@pytest.mark.slow
def test_instructor_embedding(tmp_path):
    model = get_registry().get("instructor").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
)
def test_gemini_embedding(tmp_path):
    model = get_registry().get("gemini-text").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


@pytest.mark.skipif(
    _mlx is None,
    reason="mlx tests only required for apple users.",
)
@pytest.mark.slow
def test_gte_embedding(tmp_path):
    import lancedb.embeddings.gte

    model = get_registry().get("gte-text").create()

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


def aws_setup():
    try:
        import boto3

        sts = boto3.client("sts")
        sts.get_caller_identity()
        return True
    except Exception:
        return False


@pytest.mark.slow
@pytest.mark.skipif(
    not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
    for name in [
        "amazon.titan-embed-text-v1",
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3",
    ]:
        model = get_registry().get("bedrock-text").create(max_retries=0, name=name)

        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
        db = lancedb.connect(tmp_path)
        tbl = db.create_table("test", schema=TextModel, mode="overwrite")

        tbl.add(df)
        assert len(tbl.to_pandas()["vector"][0]) == model.ndims()


@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
    def _get_table(model):
        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        db = lancedb.connect(tmp_path)
        tbl = db.create_table("test", schema=TextModel, mode="overwrite")

        return tbl

    model = get_registry().get("openai").create(max_retries=0)
    tbl = _get_table(model)
    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

    model = (
        get_registry()
        .get("openai")
        .create(max_retries=0, name="text-embedding-3-large")
    )
    tbl = _get_table(model)

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

    model = (
        get_registry()
        .get("openai")
        .create(max_retries=0, name="text-embedding-3-large", dim=1024)
    )
    tbl = _get_table(model)

    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
184
python/python/tests/test_fts.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import random
|
||||
from unittest import mock
|
||||
|
||||
import lancedb as ldb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("lancedb.fts")
|
||||
tantivy = pytest.importorskip("tantivy")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table(tmp_path) -> ldb.table.LanceTable:
|
||||
db = ldb.connect(tmp_path)
|
||||
vectors = [np.random.randn(128) for _ in range(100)]
|
||||
|
||||
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
|
||||
verbs = ("runs", "hits", "jumps", "drives", "barfs")
|
||||
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
|
||||
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
|
||||
text = [
|
||||
" ".join(
|
||||
[
|
||||
nouns[random.randrange(0, 5)],
|
||||
verbs[random.randrange(0, 5)],
|
||||
adv[random.randrange(0, 5)],
|
||||
adj[random.randrange(0, 5)],
|
||||
]
|
||||
)
|
||||
for _ in range(100)
|
||||
]
|
||||
table = db.create_table(
|
||||
"test",
|
||||
data=pd.DataFrame(
|
||||
{
|
||||
"vector": vectors,
|
||||
"id": [i % 2 for i in range(100)],
|
||||
"text": text,
|
||||
"text2": text,
|
||||
"nested": [{"text": t} for t in text],
|
||||
}
|
||||
),
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def test_create_index(tmp_path):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
assert isinstance(index, tantivy.Index)
|
||||
assert os.path.exists(str(tmp_path / "index"))
|
||||
|
||||
|
||||
def test_populate_index(tmp_path, table):
|
||||
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
|
||||
assert ldb.fts.populate_index(index, table, ["text"]) == len(table)
|
||||
|
||||
|
||||
def test_search_index(tmp_path, table):
    index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
    ldb.fts.populate_index(index, table, ["text"])
    index.reload()
    results = ldb.fts.search_index(index, query="puppy", limit=10)
    assert len(results) == 2
    assert len(results[0]) == 10  # row_ids
    assert len(results[1]) == 10  # _distance


def test_create_index_from_table(tmp_path, table):
    table.create_fts_index("text")
    df = table.search("puppy").limit(10).select(["text"]).to_pandas()
    assert len(df) <= 10
    assert "text" in df.columns

    # Check whether it can be updated
    table.add(
        [
            {
                "vector": np.random.randn(128),
                "id": 101,
                "text": "gorilla",
                "text2": "gorilla",
                "nested": {"text": "gorilla"},
            }
        ]
    )

    with pytest.raises(ValueError, match="already exists"):
        table.create_fts_index("text")

    table.create_fts_index("text", replace=True)
    assert len(table.search("gorilla").limit(1).to_pandas()) == 1


def test_create_index_multiple_columns(tmp_path, table):
    table.create_fts_index(["text", "text2"])
    df = table.search("puppy").limit(10).to_pandas()
    assert len(df) == 10
    assert "text" in df.columns
    assert "text2" in df.columns


def test_empty_rs(tmp_path, table, mocker):
    table.create_fts_index(["text", "text2"])
    mocker.patch("lancedb.fts.search_index", return_value=([], []))
    df = table.search("puppy").limit(10).to_pandas()
    assert len(df) == 0


def test_nested_schema(tmp_path, table):
    table.create_fts_index("nested.text")
    rs = table.search("puppy").limit(10).to_list()
    assert len(rs) == 10


def test_search_index_with_filter(table):
    table.create_fts_index("text")
    orig_import = __import__

    def import_mock(name, *args):
        if name == "duckdb":
            raise ImportError
        return orig_import(name, *args)

    # no duckdb
    with mock.patch("builtins.__import__", side_effect=import_mock):
        rs = table.search("puppy").where("id=1").limit(10).to_list()
        for r in rs:
            assert r["id"] == 1

    # yes duckdb
    rs2 = table.search("puppy").where("id=1").limit(10).to_list()
    for r in rs2:
        assert r["id"] == 1

    assert rs == rs2


def test_null_input(table):
    table.add(
        [
            {
                "vector": np.random.randn(128),
                "id": 101,
                "text": None,
                "text2": None,
                "nested": {"text": None},
            }
        ]
    )
    table.create_fts_index("text")


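# tantivy parses a bare OR as a boolean operator, so an unquoted query
# containing it raises a syntax error; quoting the phrase or using
# phrase_query() escapes it.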
def test_syntax(table):
    # https://github.com/lancedb/lancedb/issues/769
    table.create_fts_index("text")
    with pytest.raises(ValueError, match="Syntax Error"):
        table.search("they could have been dogs OR cats").limit(10).to_list()
    table.search("they could have been dogs OR cats").phrase_query().limit(10).to_list()
    # this should work
    table.search('"they could have been dogs OR cats"').limit(10).to_list()
    # this should work too
    table.search('''"the cats OR dogs were not really 'pets' at all"''').limit(
        10
    ).to_list()
    table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
        10
    ).to_list()
50
python/python/tests/test_io.py
Normal file
@@ -0,0 +1,50 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import lancedb
import pytest

# You need to set up AWS credentials and a base path to run this test. Example:
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py


@pytest.mark.skipif(
    (os.environ.get("TEST_S3_BASE_URL") is None),
    reason="please set up the s3 base url",
)
def test_s3_io():
    db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL"))
    assert db.table_names() == []

    table = db.create_table(
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    rs = table.search([100, 100]).limit(1).to_pandas()
    assert len(rs) == 1
    assert rs["item"].iloc[0] == "bar"

    rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
    assert len(rs) == 1
    assert rs["item"].iloc[0] == "foo"

    assert db.table_names() == ["test"]
    assert "test" in db
    assert len(db) == 1

    assert db.open_table("test").name == db["test"].name
245
python/python/tests/test_pydantic.py
Normal file
@@ -0,0 +1,245 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple

import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from pydantic import Field


@pytest.mark.skipif(
    sys.version_info < (3, 9),
    reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: list[float]
        li: list[int]
        lili: list[list[float]]
        litu: list[tuple[float, float]]
        opt: Optional[str] = None
        st: StructModel
        dt: date
        dtt: datetime
        dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
        # d: dict

    # TODO: test we can actually convert the model into data.
    # m = TestModel(
    #     id=1,
    #     s="hello",
    #     vec=[1.0, 2.0, 3.0],
    #     li=[2, 3, 4],
    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
    #     st=StructModel(a="a", b=1.0),
    #     dt=date.today(),
    #     dtt=datetime.now(),
    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
    # )

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.int64(), False),
            pa.field("s", pa.utf8(), False),
            pa.field("vec", pa.list_(pa.float64()), False),
            pa.field("li", pa.list_(pa.int64()), False),
            pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
            pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
            pa.field("opt", pa.utf8(), True),
            pa.field(
                "st",
                pa.struct(
                    [pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
                ),
                False,
            ),
            pa.field("dt", pa.date32(), False),
            pa.field("dtt", pa.timestamp("us"), False),
            pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
        ]
    )
    assert schema == expect_schema


@pytest.mark.skipif(
    sys.version_info < (3, 10),
    reason="using | type syntax requires python3.10 or higher",
)
def test_optional_types_py310():
    class TestModel(pydantic.BaseModel):
        a: str | None
        b: None | str
        c: Optional[str]

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("a", pa.utf8(), True),
            pa.field("b", pa.utf8(), True),
            pa.field("c", pa.utf8(), True),
        ]
    )
    assert schema == expect_schema


@pytest.mark.skipif(
    sys.version_info >= (3, 9),
    reason="the typing.List/Tuple variant targets python3.8 or lower",
)
def test_pydantic_to_arrow_py38():
    class StructModel(pydantic.BaseModel):
        a: str
        b: Optional[float]

    class TestModel(pydantic.BaseModel):
        id: int
        s: str
        vec: List[float]
        li: List[int]
        lili: List[List[float]]
        litu: List[Tuple[float, float]]
        opt: Optional[str] = None
        st: StructModel
        dt: date
        dtt: datetime
        dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
        # d: dict

    # TODO: test we can actually convert the model to Arrow data.
    # m = TestModel(
    #     id=1,
    #     s="hello",
    #     vec=[1.0, 2.0, 3.0],
    #     li=[2, 3, 4],
    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
    #     st=StructModel(a="a", b=1.0),
    #     dt=date.today(),
    #     dtt=datetime.now(),
    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
    # )

    schema = pydantic_to_schema(TestModel)

    expect_schema = pa.schema(
        [
            pa.field("id", pa.int64(), False),
            pa.field("s", pa.utf8(), False),
            pa.field("vec", pa.list_(pa.float64()), False),
            pa.field("li", pa.list_(pa.int64()), False),
            pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
            pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
            pa.field("opt", pa.utf8(), True),
            pa.field(
                "st",
                pa.struct(
                    [pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
                ),
                False,
            ),
            pa.field("dt", pa.date32(), False),
            pa.field("dtt", pa.timestamp("us"), False),
            pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
        ]
    )
    assert schema == expect_schema


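# Vector(n) is a fixed-size-list annotation: it maps to pa.list_(pa.float32(), n)
# in Arrow and pins minItems/maxItems to n in the generated JSON schema.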
def test_fixed_size_list_field():
    class TestModel(pydantic.BaseModel):
        vec: Vector(16)
        li: List[int]

    data = TestModel(vec=list(range(16)), li=[1, 2, 3])
    if PYDANTIC_VERSION >= (2,):
        assert json.loads(data.model_dump_json()) == {
            "vec": list(range(16)),
            "li": [1, 2, 3],
        }
    else:
        assert data.dict() == {
            "vec": list(range(16)),
            "li": [1, 2, 3],
        }

    schema = pydantic_to_schema(TestModel)
    assert schema == pa.schema(
        [
            pa.field("vec", pa.list_(pa.float32(), 16), False),
            pa.field("li", pa.list_(pa.int64()), False),
        ]
    )

    if PYDANTIC_VERSION >= (2,):
        json_schema = TestModel.model_json_schema()
    else:
        json_schema = TestModel.schema()

    assert json_schema == {
        "properties": {
            "vec": {
                "items": {"type": "number"},
                "maxItems": 16,
                "minItems": 16,
                "title": "Vec",
                "type": "array",
            },
            "li": {"items": {"type": "integer"}, "title": "Li", "type": "array"},
        },
        "required": ["vec", "li"],
        "title": "TestModel",
        "type": "object",
    }


def test_fixed_size_list_validation():
    class TestModel(pydantic.BaseModel):
        vec: Vector(8)

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=range(9))

    with pytest.raises(pydantic.ValidationError):
        TestModel(vec=range(7))

    TestModel(vec=range(8))


def test_lance_model():
    class TestModel(LanceModel):
        vector: Vector(16) = Field(default=[0.0] * 16)
        li: List[int] = Field(default=[1, 2, 3])

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["vector", "li"]

    t = TestModel()
    assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
175
python/python/tests/test_query.py
Normal file
@@ -0,0 +1,175 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest.mock as mock

import lance
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable


class MockTable:
    def __init__(self, tmp_path):
        self.uri = tmp_path
        self._conn = LanceDBConnection(self.uri)

    def to_lance(self):
        return lance.dataset(self.uri)

    def _execute_query(self, query):
        ds = self.to_lance()
        return ds.to_table(
            columns=query.columns,
            filter=query.filter,
            prefilter=query.prefilter,
            nearest={
                "column": query.vector_column,
                "q": query.vector,
                "k": query.k,
                "metric": query.metric,
                "nprobes": query.nprobes,
                "refine_factor": query.refine_factor,
            },
        )


@pytest.fixture
def table(tmp_path) -> MockTable:
    df = pa.table(
        {
            "vector": pa.array(
                [[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
            ),
            "id": pa.array([1, 2]),
            "str_field": pa.array(["a", "b"]),
            "float_field": pa.array([1.0, 2.0]),
        }
    )
    lance.write_dataset(df, tmp_path)
    return MockTable(tmp_path)


def test_cast(table):
    class TestModel(LanceModel):
        vector: Vector(2)
        id: int
        str_field: str
        float_field: float

    q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
    results = q.to_pydantic(TestModel)
    assert len(results) == 1
    r0 = results[0]
    assert isinstance(r0, TestModel)
    assert r0.id == 1
    assert r0.vector == [1, 2]
    assert r0.str_field == "a"
    assert r0.float_field == 1.0


def test_query_builder(table):
    rs = (
        LanceVectorQueryBuilder(table, [0, 0], "vector")
        .limit(1)
        .select(["id"])
        .to_list()
    )
    assert rs[0]["id"] == 1
    assert all(np.array(rs[0]["vector"]) == [1, 2])


def test_query_builder_with_filter(table):
    rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
    assert rs[0]["id"] == 2
    assert all(np.array(rs[0]["vector"]) == [3, 4])


def test_query_builder_with_prefilter(table):
    df = (
        LanceVectorQueryBuilder(table, [0, 0], "vector")
        .where("id = 2")
        .limit(1)
        .to_pandas()
    )
    assert len(df) == 0

    df = (
        LanceVectorQueryBuilder(table, [0, 0], "vector")
        .where("id = 2", prefilter=True)
        .limit(1)
        .to_pandas()
    )
    assert df["id"].values[0] == 2
    assert all(df["vector"].values[0] == [3, 4])


def test_query_builder_with_metric(table):
    query = [4, 8]
    vector_column_name = "vector"
    df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
    df_l2 = (
        LanceVectorQueryBuilder(table, query, vector_column_name)
        .metric("L2")
        .to_pandas()
    )
    tm.assert_frame_equal(df_default, df_l2)

    df_cosine = (
        LanceVectorQueryBuilder(table, query, vector_column_name)
        .metric("cosine")
        .limit(1)
        .to_pandas()
    )
    assert df_cosine._distance[0] == pytest.approx(
        cosine_distance(query, df_cosine.vector[0]),
        abs=1e-6,
    )
    assert 0 <= df_cosine._distance[0] <= 1


def test_query_builder_with_different_vector_column():
    table = mock.MagicMock(spec=LanceTable)
    query = [4, 8]
    vector_column_name = "foo_vector"
    builder = (
        LanceVectorQueryBuilder(table, query, vector_column_name)
        .metric("cosine")
        .where("b < 10")
        .select(["b"])
        .limit(2)
    )
    ds = mock.Mock()
    table.to_lance.return_value = ds
    builder.to_arrow()
    table._execute_query.assert_called_once_with(
        Query(
            vector=query,
            filter="b < 10",
            k=2,
            metric="cosine",
            columns=["b"],
            nprobes=20,
            refine_factor=None,
            vector_column="foo_vector",
        )
    )


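# Cosine distance as LanceDB reports it: 1 - cosine similarity.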
def cosine_distance(vec1, vec2):
    return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
94
python/python/tests/test_remote_client.py
Normal file
@@ -0,0 +1,94 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery


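# Minimal aiohttp stand-in for the remote LanceDB service: every query against
# "test_table" is answered with ten random rows serialized as an Arrow IPC file.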
@attrs.define
class MockLanceDBServer:
    runner: web.AppRunner = attrs.field(init=False)
    site: web.TCPSite = attrs.field(init=False)

    async def query_handler(self, request: web.Request) -> web.Response:
        table_name = request.match_info["table_name"]
        assert table_name == "test_table"

        await request.json()
        # TODO: do some matching

        vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
        ids = pd.Series(range(10), name="id")
        df = pd.DataFrame([vecs, ids]).T

        batch = pa.RecordBatch.from_pandas(
            df,
            schema=pa.schema(
                [
                    pa.field("vector", pa.list_(pa.float32(), 128)),
                    pa.field("id", pa.int64()),
                ]
            ),
        )

        sink = pa.BufferOutputStream()
        with pa.ipc.new_file(sink, batch.schema) as writer:
            writer.write_batch(batch)

        return web.Response(body=sink.getvalue().to_pybytes())

    async def setup(self):
        app = web.Application()
        app.add_routes([web.post("/table/{table_name}", self.query_handler)])
        self.runner = web.AppRunner(app)
        await self.runner.setup()
        self.site = web.TCPSite(self.runner, "localhost", 8111)

    async def start(self):
        await self.site.start()

    async def stop(self):
        await self.runner.cleanup()


@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
    mock_server = MockLanceDBServer()
    await mock_server.setup()
    await mock_server.start()

    try:
        client = RestfulLanceDBClient("lancedb+http://localhost:8111")
        df = (
            await client.query(
                "test_table",
                VectorQuery(
                    vector=np.random.rand(128).tolist(),
                    k=10,
                    _metric="L2",
                    columns=["id", "vector"],
                ),
            )
        ).to_pandas()

        assert "vector" in df.columns
        assert "id" in df.columns
    finally:
        # make sure we don't leak resources
        await mock_server.stop()
41
python/python/tests/test_remote_db.py
Normal file
@@ -0,0 +1,41 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import lancedb
import pyarrow as pa
from lancedb.remote.client import VectorQuery, VectorQueryResult


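# Stub client injected into the remote connection below so table.search() can
# be exercised without a live LanceDB Cloud endpoint.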
class FakeLanceDBClient:
    def close(self):
        pass

    def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
        assert table_name == "test"
        t = pa.schema([]).empty_table()
        return VectorQueryResult(t)

    def post(self, path: str):
        pass

    def mount_retry_adapter_for_table(self, table_name: str):
        pass


def test_remote_db():
    conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
    setattr(conn, "_client", FakeLanceDBClient())

    table = conn["test"]
    table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
    table.search([1.0, 2.0]).to_pandas()
261
python/python/tests/test_rerankers.py
Normal file
@@ -0,0 +1,261 @@
import os

import lancedb
import numpy as np
import pytest
from lancedb.conftest import MockTextEmbeddingFunction  # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import (
    CohereReranker,
    ColbertReranker,
    CrossEncoderReranker,
    OpenaiReranker,
)
from lancedb.table import LanceTable

# Tests rely on FTS index
pytest.importorskip("lancedb.fts")


def get_test_table(tmp_path):
    db = lancedb.connect(tmp_path)
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
        vector: Vector(emb.ndims()) = emb.VectorField()

    # Initialize the table using the schema
    table = LanceTable.create(
        db,
        "my_table",
        schema=MyTable,
    )

    # Need to test with a bunch of phrases to make sure sorting is consistent
    phrases = [
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
        "I see a mansard roof through the trees",
        "I see a salty message written in the eves",
        "the ground beneath my feet",
        "the hot garbage and concrete",
        "and now the tops of buildings",
        "everybody with a worried mind could never forgive the sight",
        "of wicked snakes inside a place you thought was dignified",
        "I don't wanna live like this",
        "but I don't wanna die",
        "The templars want control",
        "the brotherhood of assassins want freedom",
        "if only they could both see the world as it really is",
        "there would be peace",
        "but the war goes on",
        "altair's legacy was a warning",
        "Kratos had a son",
        "he was a god",
        "the god of war",
        "but his son was mortal",
        "there hasn't been a good battlefield game since 2142",
        "I wish they would make another one",
        "campains are not as good as they used to be",
        "Multiplayer and open world games have destroyed the single player experience",
        "Maybe the future is console games",
        "I don't know",
    ]

    # Add the phrases and vectors to the table
    table.add([{"text": p} for p in phrases])

    # Create a fts index
    table.create_fts_index("text")

    return table, MyTable


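# The reranker tests below share a pattern: run a hybrid query, check that
# rerank() with and without normalize="score" agree (score normalization is the
# default), then confirm _relevance_score is sorted descending on an explicit
# (vector, text) query of 30 results.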
def test_linear_combination(tmp_path):
    table, schema = get_test_table(tmp_path)
    # The default reranker
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score")
        .to_pydantic(schema)
    )
    result2 = (  # noqa
        table.search("Our father who art in heaven.", query_type="hybrid")
        .rerank(normalize="rank")
        .to_pydantic(schema)
    )
    result3 = table.search(
        "Our father who art in heaven..", query_type="hybrid"
    ).to_pydantic(schema)

    assert result1 == result3  # 1 & 3 should match: both normalize by score (the default)

    query = "Our father who art in heaven"
    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search((query_vector, query))
        .limit(30)
        .rerank(normalize="score")
        .to_arrow()
    )

    assert len(result) == 30

    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )


@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
    pytest.importorskip("cohere")
    table, schema = get_test_table(tmp_path)
    # normalize="score" is the default, so these two should agree
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score", reranker=CohereReranker())
        .to_pydantic(schema)
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(reranker=CohereReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

    query = "Our father who art in heaven"
    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search((query_vector, query))
        .limit(30)
        .rerank(reranker=CohereReranker())
        .to_arrow()
    )

    assert len(result) == 30

    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )


def test_cross_encoder_reranker(tmp_path):
    pytest.importorskip("sentence_transformers")
    table, schema = get_test_table(tmp_path)
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score", reranker=CrossEncoderReranker())
        .to_pydantic(schema)
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(reranker=CrossEncoderReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

    # test explicit hybrid query
    query = "Our father who art in heaven"
    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search((query_vector, query), query_type="hybrid")
        .limit(30)
        .rerank(reranker=CrossEncoderReranker())
        .to_arrow()
    )

    assert len(result) == 30

    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )


def test_colbert_reranker(tmp_path):
    pytest.importorskip("transformers")
    table, schema = get_test_table(tmp_path)
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score", reranker=ColbertReranker())
        .to_pydantic(schema)
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(reranker=ColbertReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

    # test explicit hybrid query
    query = "Our father who art in heaven"
    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search((query_vector, query))
        .limit(30)
        .rerank(reranker=ColbertReranker())
        .to_arrow()
    )

    assert len(result) == 30

    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )


@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
    pytest.importorskip("openai")
    table, schema = get_test_table(tmp_path)
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score", reranker=OpenaiReranker())
        .to_pydantic(schema)
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(reranker=OpenaiReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

    # test explicit hybrid query
    query = "Our father who art in heaven"
    query_vector = table.to_pandas()["vector"][0]
    result = (
        table.search((query_vector, query))
        .limit(30)
        .rerank(reranker=OpenaiReranker())
        .to_arrow()
    )

    assert len(result) == 30

    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
928
python/python/tests/test_table.py
Normal file
@@ -0,0 +1,928 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path
from time import sleep
from typing import List
from unittest.mock import PropertyMock, patch

import lance
import lancedb
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
from pydantic import BaseModel


class MockDB:
    def __init__(self, uri: Path):
        self.uri = uri
        self.read_consistency_interval = None

    @functools.cached_property
    def is_managed_remote(self) -> bool:
        return False


@pytest.fixture
def db(tmp_path) -> MockDB:
    return MockDB(tmp_path)


def test_basic(db):
    ds = LanceTable.create(
        db,
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    ).to_lance()

    table = LanceTable(db, "test")
    assert table.name == "test"
    assert table.schema == ds.schema
    assert table.to_lance().to_table() == ds.to_table()


def test_create_table(db):
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
        ]
    )
    expected = pa.Table.from_arrays(
        [
            pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
            pa.array(["foo", "bar"]),
            pa.array([10.0, 20.0]),
        ],
        schema=schema,
    )
    data = [
        [
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ]
    ]
    df = pd.DataFrame(data[0])
    data.append(df)
    data.append(pa.Table.from_pandas(df, schema=schema))

    for i, d in enumerate(data):
        tbl = (
            LanceTable.create(db, f"test_{i}", data=d, schema=schema)
            .to_lance()
            .to_table()
        )
        assert expected == tbl


def test_empty_table(db):
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float32()),
        ]
    )
    tbl = LanceTable.create(db, "test", schema=schema)
    data = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    tbl.add(data=data)


def test_add(db):
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.string()),
            pa.field("price", pa.float64()),
        ]
    )

    table = LanceTable.create(
        db,
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    _add(table, schema)

    table = LanceTable.create(db, "test2", schema=schema)
    table.add(
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )
    _add(table, schema)


def test_add_pydantic_model(db):
    # https://github.com/lancedb/lancedb/issues/562

    class Metadata(BaseModel):
        source: str
        timestamp: datetime

    class Document(BaseModel):
        content: str
        meta: Metadata

    class LanceSchema(LanceModel):
        id: str
        vector: Vector(2)
        li: List[int]
        payload: Document

    tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite")
    assert tbl.schema == LanceSchema.to_arrow_schema()

    # add works
    expected = LanceSchema(
        id="id",
        vector=[0.0, 0.0],
        li=[1, 2, 3],
        payload=Document(
            content="foo", meta=Metadata(source="bar", timestamp=datetime.now())
        ),
    )
    tbl.add([expected])

    result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
    assert result == expected

    flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=1)
    assert len(flattened.columns) == 6  # _distance is automatically added

    really_flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=True)
    assert len(really_flattened.columns) == 7


def test_polars(db):
    data = {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
        "price": [10.0, 20.0],
    }
    # Ingest polars dataframe
    table = LanceTable.create(db, "test", data=pl.DataFrame(data))
    assert len(table) == 2

    result = table.to_pandas()
    assert np.allclose(result["vector"].tolist(), data["vector"])
    assert result["item"].tolist() == data["item"]
    assert np.allclose(result["price"].tolist(), data["price"])

    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.large_string()),
            pa.field("price", pa.float64()),
        ]
    )
    assert table.schema == schema

    # search results to polars dataframe
    q = [3.1, 4.1]
    result = table.search(q).limit(1).to_polars()
    assert np.allclose(result["vector"][0], q)
    assert result["item"][0] == "foo"
    assert np.allclose(result["price"][0], 10.0)

    # convert the whole table to a lazy polars dataframe
    result = table.to_polars()
    assert np.allclose(result.collect()["vector"].to_list(), data["vector"])

    # make sure filtering isn't broken
    filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
    assert len(filtered_result) == 2


def _add(table, schema):
    # table = LanceTable(db, "test")
    assert len(table) == 2

    table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
    assert len(table) == 3

    expected = pa.Table.from_arrays(
        [
            pa.FixedSizeListArray.from_arrays(
                pa.array([3.1, 4.1, 5.9, 26.5, 6.3, 100.5]), 2
            ),
            pa.array(["foo", "bar", "new"]),
            pa.array([10.0, 20.0, 30.0]),
        ],
        schema=schema,
    )
    assert expected == table.to_arrow()


def test_versioning(db):
    table = LanceTable.create(
        db,
        "test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
        ],
    )

    assert len(table.list_versions()) == 2
    assert table.version == 2

    table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
    assert len(table.list_versions()) == 3
    assert table.version == 3
    assert len(table) == 3

    table.checkout(2)
    assert table.version == 2
    assert len(table) == 2


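# The dataset property is mocked out below, so this only verifies the
# parameters that LanceTable.create_index forwards to lance's create_index.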
def test_create_index_method():
    with patch.object(
        LanceTable, "_dataset_mut", new_callable=PropertyMock
    ) as mock_dataset:
        # Setup mock responses
        mock_dataset.return_value.create_index.return_value = None

        # Create a LanceTable object
        connection = LanceDBConnection(uri="mock.uri")
        table = LanceTable(connection, "test_table")

        # Call the create_index method
        table.create_index(
            metric="L2",
            num_partitions=256,
            num_sub_vectors=96,
            vector_column_name="vector",
            replace=True,
            index_cache_size=256,
        )

        # Check that the _dataset.create_index method was called
        # with the right parameters
        mock_dataset.return_value.create_index.assert_called_once_with(
            column="vector",
            index_type="IVF_PQ",
            metric="L2",
            num_partitions=256,
            num_sub_vectors=96,
            replace=True,
            accelerator=None,
            index_cache_size=256,
        )


def test_add_with_nans(db):
    # by default we raise an error on bad input vectors
    bad_data = [
        {"vector": [np.nan], "item": "bar", "price": 20.0},
        {"vector": [5], "item": "bar", "price": 20.0},
        {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
        {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
    ]
    for row in bad_data:
        with pytest.raises(ValueError):
            LanceTable.create(
                db,
                "error_test",
                data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
            )

    table = LanceTable.create(
        db,
        "drop_test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [np.nan], "item": "bar", "price": 20.0},
            {"vector": [5], "item": "bar", "price": 20.0},
            {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
        ],
        on_bad_vectors="drop",
    )
    assert len(table) == 1

    # We can fill bad input with some value
    table = LanceTable.create(
        db,
        "fill_test",
        data=[
            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
            {"vector": [np.nan], "item": "bar", "price": 20.0},
            {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
        ],
        on_bad_vectors="fill",
        fill_value=0.0,
    )
    assert len(table) == 3
    arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
    v = arrow_tbl["vector"].to_pylist()[0]
    assert np.allclose(v, np.array([0.0, 0.0]))


def test_restore(db):
    table = LanceTable.create(
        db,
        "my_table",
        data=[{"vector": [1.1, 0.9], "type": "vector"}],
    )
    table.add([{"vector": [0.5, 0.2], "type": "vector"}])
    table.restore(2)
    assert len(table.list_versions()) == 4
    assert len(table) == 1

    expected = table.to_arrow()
    table.checkout(2)
    table.restore()
    assert len(table.list_versions()) == 5
    assert table.to_arrow() == expected

    table.restore(5)  # latest version should be no-op
    assert len(table.list_versions()) == 5

    with pytest.raises(ValueError):
        table.restore(6)

    with pytest.raises(ValueError):
        table.restore(0)


def test_merge(db, tmp_path):
    table = LanceTable.create(
        db,
        "my_table",
        data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
    )
    other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
    table.merge(other_table, left_on="id")
    assert len(table.list_versions()) == 3
    expected = pa.table(
        {"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
        schema=table.schema,
    )
    assert table.to_arrow() == expected

    other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
    table.restore(1)
    table.merge(other_dataset, left_on="id")


def test_delete(db):
    table = LanceTable.create(
        db,
        "my_table",
        data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
    )
    assert len(table) == 2
    assert len(table.list_versions()) == 2
    table.delete("id=0")
    assert len(table.list_versions()) == 3
    assert table.version == 3
    assert len(table) == 1
    assert table.to_pandas()["id"].tolist() == [1]


def test_update(db):
    table = LanceTable.create(
        db,
        "my_table",
        data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
    )
    assert len(table) == 2
    assert len(table.list_versions()) == 2
    table.update(where="id=0", values={"vector": [1.1, 1.1]})
    assert len(table.list_versions()) == 3
    assert table.version == 3
    assert len(table) == 2
    v = table.to_arrow()["vector"].combine_chunks()
    v = v.values.to_numpy().reshape(2, 2)
    assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))


def test_update_types(db):
    table = LanceTable.create(
        db,
        "my_table",
        data=[
            {
                "id": 0,
                "str": "foo",
                "float": 1.1,
                "timestamp": datetime(2021, 1, 1),
                "date": date(2021, 1, 1),
                "vector1": [1.0, 0.0],
                "vector2": [1.0, 1.0],
            }
        ],
    )
    # Update with SQL
    table.update(
        values_sql=dict(
            id="1",
            str="'bar'",
            float="2.2",
            timestamp="TIMESTAMP '2021-01-02 00:00:00'",
            date="DATE '2021-01-02'",
            vector1="[2.0, 2.0]",
            vector2="[3.0, 3.0]",
        )
    )
    actual = table.to_arrow().to_pylist()[0]
    expected = dict(
        id=1,
        str="bar",
        float=2.2,
        timestamp=datetime(2021, 1, 2),
        date=date(2021, 1, 2),
        vector1=[2.0, 2.0],
        vector2=[3.0, 3.0],
    )
    assert actual == expected

    # Update with values
    table.update(
        values=dict(
            id=2,
            str="baz",
            float=3.3,
            timestamp=datetime(2021, 1, 3),
            date=date(2021, 1, 3),
            vector1=[3.0, 3.0],
            vector2=np.array([4.0, 4.0]),
        )
    )
    actual = table.to_arrow().to_pylist()[0]
    expected = dict(
        id=2,
        str="baz",
        float=3.3,
        timestamp=datetime(2021, 1, 3),
        date=date(2021, 1, 3),
        vector1=[3.0, 3.0],
        vector2=[4.0, 4.0],
    )
    assert actual == expected


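# merge_insert is exercised in four modes below: plain upsert, conditional
# update, insert-if-not-exists, and a "replace range" that also deletes rows
# missing from the source (when_not_matched_by_source_delete).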
def test_merge_insert(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}),
|
||||
)
|
||||
assert len(table) == 3
|
||||
version = table.version
|
||||
|
||||
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
|
||||
# upsert
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# conditional update
|
||||
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
|
||||
new_data
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# insert-if-not-exists
|
||||
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
|
||||
# replace-range
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
|
||||
"a > 2"
|
||||
).execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# replace-range no condition
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute(
|
||||
new_data
|
||||
)
|
||||
|
||||
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockTextEmbeddingFunction()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text", vector_column="vector", function=func
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[conf],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_create_f16_table(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: Vector(128, value_type=pa.float16())
|
||||
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"text": [f"s-{i}" for i in range(10000)],
|
||||
"vector": [np.random.randn(128).astype(np.float16) for _ in range(10000)],
|
||||
}
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"f16_tbl",
|
||||
schema=MyTable,
|
||||
)
|
||||
table.add(df)
|
||||
table.create_index(num_partitions=2, num_sub_vectors=8)
|
||||
|
||||
query = df.vector.iloc[2]
|
||||
expected = table.search(query).limit(2).to_arrow()
|
||||
|
||||
assert "s-2" in expected["text"].to_pylist()
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
table.add(df)
|
||||
|
||||
texts = ["the quick brown fox", "jumped over the lazy dog"]
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = emb.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_multiple_vector_columns(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: Vector(10)
|
||||
vector2: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
|
||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
|
||||
|
||||
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
||||
|
||||
|
||||
def test_create_scalar_index(db):
|
||||
vec_array = pa.array(
|
||||
[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], pa.list_(pa.float32(), 2)
|
||||
)
|
||||
test_data = pa.Table.from_pydict(
|
||||
{"x": ["c", "b", "a", "e", "b"], "y": [1, 2, 3, 4, 5], "vector": vec_array}
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=test_data,
|
||||
)
|
||||
table.create_scalar_index("x")
|
||||
indices = table.to_lance().list_indices()
|
||||
assert len(indices) == 1
|
||||
scalar_index = indices[0]
|
||||
assert scalar_index["type"] == "Scalar"
|
||||
|
||||
# Confirm that prefiltering still works with the scalar index column
|
||||
results = table.search().where("x = 'c'").to_arrow()
|
||||
assert results == test_data.slice(0, 1)
|
||||
results = table.search([5, 5]).to_arrow()
|
||||
assert results["_distance"][0].as_py() == 0
|
||||
results = table.search([5, 5]).where("x != 'b'").to_arrow()
|
||||
assert results["_distance"][0].as_py() > 0
|
||||
|
||||
|
||||
def test_empty_query(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
|
||||
val = df.id.iloc[0]
|
||||
assert val == 1
|
||||
|
||||
table = LanceTable.create(db, "my_table2", data=[{"id": i} for i in range(100)])
|
||||
df = table.search().select(["id"]).to_pandas()
|
||||
assert len(df) == 10
|
||||
df = table.search().select(["id"]).limit(None).to_pandas()
|
||||
assert len(df) == 100
|
||||
df = table.search().select(["id"]).limit(-1).to_pandas()
|
||||
assert len(df) == 100
|
||||
|
||||
|
||||
def test_search_with_schema_inf_single_vector(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector_col: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector_col": v1, "text": "foo"},
|
||||
{"vector_col": v2, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
|
||||
result2 = table.search(q).limit(1).to_pandas()
|
||||
|
||||
assert result1["text"].iloc[0] == result2["text"].iloc[0]
|
||||
|
||||
|
||||
def test_search_with_schema_inf_multiple_vector(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: Vector(10)
|
||||
vector2: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
with pytest.raises(ValueError):
|
||||
table.search(q).limit(1).to_pandas()
|
||||
|
||||
|
||||
def test_compact_cleanup(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
|
||||
table.add([{"text": "baz", "id": 2}])
|
||||
assert len(table) == 3
|
||||
assert table.version == 3
|
||||
|
||||
stats = table.compact_files()
|
||||
assert len(table) == 3
|
||||
# Compact_files bump 2 versions.
|
||||
assert table.version == 5
|
||||
assert stats.fragments_removed > 0
|
||||
assert stats.fragments_added == 1
|
||||
|
||||
stats = table.cleanup_old_versions()
|
||||
assert stats.bytes_removed == 0
|
||||
|
||||
stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True)
|
||||
assert stats.bytes_removed > 0
|
||||
assert table.version == 5
|
||||
|
||||
with pytest.raises(Exception, match="Version 3 no longer exists"):
|
||||
table.checkout(3)
|
||||
|
||||
|
||||
def test_count_rows(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert table.count_rows() == 2
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db, tmp_path):
    # This test uses an FTS index
    pytest.importorskip("lancedb.fts")

    db = MockDB(str(tmp_path))
    # Create a LanceDB table schema with a vector and a text column
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
        vector: Vector(emb.ndims()) = emb.VectorField()

    # Initialize the table using the schema
    table = LanceTable.create(
        db,
        "my_table",
        schema=MyTable,
    )

    # Create a list of 10 unique English phrases
    phrases = [
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
    ]

    # Add the phrases and vectors to the table
    table.add([{"text": p} for p in phrases])

    # Create an FTS index
    table.create_fts_index("text")

    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score")
        .to_pydantic(MyTable)
    )
    result2 = (  # noqa
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="rank")
        .to_pydantic(MyTable)
    )
    result3 = table.search(
        "Our father who art in heaven", query_type="hybrid"
    ).to_pydantic(MyTable)
    assert result1 == result3


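# A minimal usage sketch (illustrative, not part of the original suite): a hybrid
# query runs both vector and full-text search, and rerank(normalize=...) controls
# how the two score sets are normalized before fusion ("score" or "rank", as
# exercised above). The query string and limit are made-up example values.
def example_hybrid_query(table):
    return (
        table.search("let the wookiee win", query_type="hybrid")
        .limit(3)
        .rerank(normalize="score")
        .to_pandas()
    )

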
@pytest.mark.parametrize(
    "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)
def test_consistency(tmp_path, consistency_interval):
    db = lancedb.connect(tmp_path)
    table = LanceTable.create(db, "my_table", data=[{"id": 0}])

    db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
    table2 = db2.open_table("my_table")
    assert table2.version == table.version

    table.add([{"id": 1}])

    if consistency_interval is None:
        assert table2.version == table.version - 1
        table2.checkout_latest()
        assert table2.version == table.version
    elif consistency_interval == timedelta(seconds=0):
        assert table2.version == table.version
    else:
        # consistency_interval == timedelta(seconds=0.1)
        assert table2.version == table.version - 1
        sleep(0.1)
        assert table2.version == table.version


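# A minimal sketch (illustrative, not part of the original suite) of the three
# behaviors exercised above: None never re-checks (call checkout_latest manually),
# zero re-checks on every read, and a positive interval re-checks at most that often.
def example_connect_with_consistency(path):
    strong = lancedb.connect(path, read_consistency_interval=timedelta(seconds=0))
    # 5 seconds is an assumed example value
    eventual = lancedb.connect(path, read_consistency_interval=timedelta(seconds=5))
    return strong, eventual

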
def test_restore_consistency(tmp_path):
    db = lancedb.connect(tmp_path)
    table = LanceTable.create(db, "my_table", data=[{"id": 0}])

    db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
    table2 = db2.open_table("my_table")
    assert table2.version == table.version

    # Checking out a fixed version pins the table, so it stops tracking new writes
    table_fixed = copy(table2)
    table_fixed.checkout(table.version)
    # But calling checkout_latest makes it consistent again
    table_ref_latest = copy(table_fixed)
    table_ref_latest.checkout_latest()
    table.add([{"id": 2}])
    assert table_fixed.version == table.version - 1
    assert table_ref_latest.version == table.version


# Schema evolution
def test_add_columns(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pa.table({"id": [0, 1]})
    table = LanceTable.create(db, "my_table", data=data)
    table.add_columns({"new_col": "id + 2"})
    assert table.to_arrow().column_names == ["id", "new_col"]
    assert table.to_arrow()["new_col"].to_pylist() == [2, 3]


def test_alter_columns(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pa.table({"id": [0, 1]})
    table = LanceTable.create(db, "my_table", data=data)
    table.alter_columns({"path": "id", "rename": "new_id"})
    assert table.to_arrow().column_names == ["new_id"]


def test_drop_columns(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pa.table({"id": [0, 1], "category": ["a", "b"]})
    table = LanceTable.create(db, "my_table", data=data)
    table.drop_columns(["category"])
    assert table.to_arrow().column_names == ["id"]


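# A minimal sketch (illustrative, not part of the original suite) chaining the
# three schema-evolution calls tested above; "doubled" and "id_x2" are made-up names.
def example_schema_evolution(table):
    table.add_columns({"doubled": "id * 2"})  # new column from a SQL expression
    table.alter_columns({"path": "doubled", "rename": "id_x2"})
    table.drop_columns(["id_x2"])

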
86
python/python/tests/test_util.py
Normal file
@@ -0,0 +1,86 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib

import pytest
from lancedb.util import get_uri_scheme, join_uri


def test_normalize_uri():
    uris = [
        "relative/path",
        "/absolute/path",
        "file:///absolute/path",
        "s3://bucket/path",
        "gs://bucket/path",
        "c:\\windows\\path",
    ]
    schemes = ["file", "file", "file", "s3", "gs", "file"]

    for uri, expected_scheme in zip(uris, schemes):
        parsed_scheme = get_uri_scheme(uri)
        assert parsed_scheme == expected_scheme


def test_join_uri_remote():
    schemes = ["s3", "az", "gs"]
    for scheme in schemes:
        expected = f"{scheme}://bucket/path/to/table.lance"
        base_uri = f"{scheme}://bucket/path/to/"
        parts = ["table.lance"]
        assert join_uri(base_uri, *parts) == expected

        base_uri = f"{scheme}://bucket"
        parts = ["path", "to", "table.lance"]
        assert join_uri(base_uri, *parts) == expected


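# A minimal sketch (illustrative, not part of the original suite): join_uri joins
# object-store URIs with "/" regardless of the local OS separator; the bucket and
# key names here are made up.
def example_remote_table_uri():
    # expected: "s3://bucket/datasets/prod/my_table.lance"
    return join_uri("s3://bucket/datasets", "prod", "my_table.lance")

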
# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
    for base in [
        # relative path
        "relative/path",
        "relative/path/",
        # an absolute path
        "/absolute/path",
        "/absolute/path/",
        # a file URI
        "file:///absolute/path",
        "file:///absolute/path/",
    ]:
        joined = join_uri(base, "table.lance")
        assert joined == str(pathlib.Path(base) / "table.lance")
        joined = join_uri(pathlib.Path(base), "table.lance")
        assert joined == pathlib.Path(base) / "table.lance"


# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX")
def test_local_join_uri_windows():
    # https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
    for base in [
        # windows relative path
        "relative\\path",
        "relative\\path\\",
        # windows absolute path from current drive
        "c:\\absolute\\path",
        # relative path from root of current drive
        "\\relative\\path",
    ]:
        joined = join_uri(base, "table.lance")
        assert joined == str(pathlib.Path(base) / "table.lance")
        joined = join_uri(pathlib.Path(base), "table.lance")
        assert joined == pathlib.Path(base) / "table.lance"