mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
@@ -23,8 +23,8 @@ from pydantic import BaseModel
|
||||
|
||||
from lancedb.common import Credential
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
|
||||
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
|
||||
|
||||
|
||||
@@ -20,13 +20,13 @@
|
||||
# https://github.com/urllib3/urllib3/pull/3275
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.poolmanager import PoolManager
|
||||
from urllib3.connectionpool import HTTPSConnectionPool
|
||||
from urllib3.connection import HTTPSConnection
|
||||
from urllib3.connectionpool import HTTPSConnectionPool
|
||||
from urllib3.poolmanager import PoolManager
|
||||
|
||||
|
||||
def get_client_connection_timeout() -> int:
|
||||
|
||||
@@ -1327,13 +1327,15 @@ def _sanitize_vector_column(
|
||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
||||
|
||||
vec_arr = ensure_fixed_size_list_of_f32(vec_arr)
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
|
||||
if has_nans:
|
||||
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
|
||||
# kernel over f16 types.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
@@ -1341,9 +1343,9 @@ def _sanitize_vector_column(
|
||||
return data
|
||||
|
||||
|
||||
def ensure_fixed_size_list_of_f32(vec_arr):
|
||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
values = vec_arr.values
|
||||
if not pa.types.is_float32(values.type):
|
||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
||||
values = values.cast(pa.float32())
|
||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||
list_size = vec_arr.type.list_size
|
||||
|
||||
@@ -521,6 +521,31 @@ def test_create_with_embedding_function(db):
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_create_f16_table(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: Vector(128, value_type=pa.float16())
|
||||
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"text": [f"s-{i}" for i in range(10000)],
|
||||
"vector": [np.random.randn(128).astype(np.float16) for _ in range(10000)],
|
||||
}
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"f16_tbl",
|
||||
schema=MyTable,
|
||||
)
|
||||
table.add(df)
|
||||
table.create_index(num_partitions=2, num_sub_vectors=8)
|
||||
|
||||
query = df.vector.iloc[2]
|
||||
expected = table.search(query).limit(2).to_arrow()
|
||||
|
||||
assert "s-2" in expected["text"].to_pylist()
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user