bug: add a test for fp16 (#837)

Add test to ingest fp16 to a database
This commit is contained in:
Lei Xu
2024-01-20 16:23:28 -08:00
committed by Weston Pace
parent 0c580abd70
commit 97d033dfd6
4 changed files with 36 additions and 9 deletions

View File

@@ -23,8 +23,8 @@ from pydantic import BaseModel
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.errors import LanceDBClientError
from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory
from lancedb.remote.errors import LanceDBClientError
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"

View File

@@ -20,13 +20,13 @@
# https://github.com/urllib3/urllib3/pull/3275
import datetime
import os
import logging
import os
from requests.adapters import HTTPAdapter
from urllib3.poolmanager import PoolManager
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.connection import HTTPSConnection
from urllib3.connectionpool import HTTPSConnectionPool
from urllib3.poolmanager import PoolManager
def get_client_connection_timeout() -> int:

View File

@@ -1327,13 +1327,15 @@ def _sanitize_vector_column(
elif not pa.types.is_fixed_size_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
vec_arr = ensure_fixed_size_list_of_f32(vec_arr)
vec_arr = ensure_fixed_size_list(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
if has_nans:
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
# kernel over f16 types.
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
if np.isnan(values_np).any():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
@@ -1341,9 +1343,9 @@ def _sanitize_vector_column(
return data
def ensure_fixed_size_list_of_f32(vec_arr):
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
values = vec_arr.values
if not pa.types.is_float32(values.type):
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
values = values.cast(pa.float32())
if pa.types.is_fixed_size_list(vec_arr.type):
list_size = vec_arr.type.list_size

View File

@@ -521,6 +521,31 @@ def test_create_with_embedding_function(db):
assert actual == expected
def test_create_f16_table(db):
class MyTable(LanceModel):
text: str
vector: Vector(128, value_type=pa.float16())
df = pd.DataFrame(
{
"text": [f"s-{i}" for i in range(10000)],
"vector": [np.random.randn(128).astype(np.float16) for _ in range(10000)],
}
)
table = LanceTable.create(
db,
"f16_tbl",
schema=MyTable,
)
table.add(df)
table.create_index(num_partitions=2, num_sub_vectors=8)
query = df.vector.iloc[2]
expected = table.search(query).limit(2).to_arrow()
assert "s-2" in expected["text"].to_pylist()
def test_add_with_embedding_function(db):
emb = EmbeddingFunctionRegistry.get_instance().get("test")()