mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
1. Support persistent embedding functions so users can search using just a query string.
2. Add fixed-size list conversion for multiple vector columns.
3. Add support for empty queries (just apply select/where/limit).
4. Refactor and simplify some of the data-prep code.
---------
Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
79 lines
2.6 KiB
Python
79 lines
2.6 KiB
Python
# Copyright 2023 LanceDB Developers
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import sys
|
|
|
|
import lance
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
|
|
from lancedb.conftest import MockEmbeddingFunction
|
|
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
|
|
|
|
|
|
def mock_embed_func(input_data):
    """Return one random 128-dimensional embedding per item of ``input_data``.

    The vectors are plain Python lists of floats drawn from NumPy's global
    random state; only ``len(input_data)`` is used, not the item contents.
    """
    # A single (n, 128) draw fills row by row, matching n separate draws.
    return np.random.randn(len(input_data), 128).tolist()
|
|
|
|
|
|
def test_with_embeddings():
    """Check that ``with_embeddings`` appends a ``vector`` column to a table.

    The check runs once with ``wrap_api=True`` and once with
    ``wrap_api=False``. The ``wrap_api=True`` case is skipped on
    Python 3.11 and later.
    """
    for wrap_api in [True, False]:
        # Compare the full version tuple rather than only ``minor`` so the
        # guard does not misfire if the major version ever changes.
        if wrap_api and sys.version_info >= (3, 11):
            # ratelimiter package doesn't work on 3.11
            continue
        data = pa.Table.from_arrays(
            [
                pa.array(["foo", "bar"]),
                pa.array([10.0, 20.0]),
            ],
            names=["text", "price"],
        )
        data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)
        assert data.num_columns == 3
        assert data.num_rows == 2
        assert data.column_names == ["text", "price", "vector"]
        assert data.column("text").to_pylist() == ["foo", "bar"]
        assert data.column("price").to_pylist() == [10.0, 20.0]
        # mock_embed_func yields one 128-dimensional embedding per row, so
        # the appended column should contain exactly that.
        vectors = data.column("vector").to_pylist()
        assert len(vectors) == 2
        assert all(len(vec) == 128 for vec in vectors)
|
|
|
|
|
|
def test_embedding_function(tmp_path):
    """Round-trip an embedding function through Lance schema metadata.

    The function's configuration is attached to a table as schema metadata,
    written to disk with Lance, read back, and re-parsed. The rebuilt function
    must produce the same embedding (within ``np.allclose`` tolerance) as a
    freshly constructed instance.
    """
    registry = EmbeddingFunctionRegistry.get_instance()
    dataset_uri = tmp_path / "test.lance"

    # Two-row table with a text column and a vector column.
    source_table = pa.table(
        {
            "text": pa.array(["hello world", "goodbye world"]),
            "vector": [np.random.randn(10), np.random.randn(10)],
        }
    )
    original_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
    source_table = source_table.replace_schema_metadata(
        registry.get_table_metadata([original_func])
    )

    # Persist to disk, then reopen the dataset.
    lance.write_dataset(source_table, dataset_uri)
    reloaded = lance.dataset(dataset_uri)

    # Rebuild the serialized embedding function from the stored metadata.
    restored_func = registry.parse_functions(reloaded.schema.metadata)["vector"]
    actual = restored_func("hello world")

    # A freshly constructed instance is the reference output.
    reference_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
    expected = reference_func("hello world")

    assert np.allclose(actual, expected)
|