From 97d033dfd6b6868232359839f79d6081d313e1eb Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sat, 20 Jan 2024 16:23:28 -0800 Subject: [PATCH] bug: add a test for fp16 (#837) Add test to ingest fp16 to a database --- python/lancedb/remote/client.py | 2 +- python/lancedb/remote/connection_timeout.py | 6 ++--- python/lancedb/table.py | 12 +++++----- python/tests/test_table.py | 25 +++++++++++++++++++++ 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 9c65ed31..3bafec4b 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -23,8 +23,8 @@ from pydantic import BaseModel from lancedb.common import Credential from lancedb.remote import VectorQuery, VectorQueryResult -from lancedb.remote.errors import LanceDBClientError from lancedb.remote.connection_timeout import LanceDBClientHTTPAdapterFactory +from lancedb.remote.errors import LanceDBClientError ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream" diff --git a/python/lancedb/remote/connection_timeout.py b/python/lancedb/remote/connection_timeout.py index f4fb8235..f9d18e56 100644 --- a/python/lancedb/remote/connection_timeout.py +++ b/python/lancedb/remote/connection_timeout.py @@ -20,13 +20,13 @@ # https://github.com/urllib3/urllib3/pull/3275 import datetime -import os import logging +import os from requests.adapters import HTTPAdapter -from urllib3.poolmanager import PoolManager -from urllib3.connectionpool import HTTPSConnectionPool from urllib3.connection import HTTPSConnection +from urllib3.connectionpool import HTTPSConnectionPool +from urllib3.poolmanager import PoolManager def get_client_connection_timeout() -> int: diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 930b65fc..6f81c40f 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -1327,13 +1327,15 @@ def _sanitize_vector_column( elif not pa.types.is_fixed_size_list(vec_arr.type): raise 
TypeError(f"Unsupported vector column type: {vec_arr.type}") - vec_arr = ensure_fixed_size_list_of_f32(vec_arr) + vec_arr = ensure_fixed_size_list(vec_arr) data = data.set_column( data.column_names.index(vector_column_name), vector_column_name, vec_arr ) - has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py() - if has_nans: + # Use numpy to check for NaNs, because as of pyarrow 14.0.2 there is no `is_nan` + # kernel for f16 types. + values_np = vec_arr.values.to_numpy(zero_copy_only=False) + if np.isnan(values_np).any(): data = _sanitize_nans( data, fill_value, on_bad_vectors, vec_arr, vector_column_name ) @@ -1341,9 +1343,9 @@ def _sanitize_vector_column( return data -def ensure_fixed_size_list_of_f32(vec_arr): +def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray: values = vec_arr.values - if not pa.types.is_float32(values.type): + if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)): values = values.cast(pa.float32()) if pa.types.is_fixed_size_list(vec_arr.type): list_size = vec_arr.type.list_size diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 339f8668..0dadbbbd 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -521,6 +521,31 @@ def test_create_with_embedding_function(db): assert actual == expected +def test_create_f16_table(db): + class MyTable(LanceModel): + text: str + vector: Vector(128, value_type=pa.float16()) + + df = pd.DataFrame( + { + "text": [f"s-{i}" for i in range(10000)], + "vector": [np.random.randn(128).astype(np.float16) for _ in range(10000)], + } + ) + table = LanceTable.create( + db, + "f16_tbl", + schema=MyTable, + ) + table.add(df) + table.create_index(num_partitions=2, num_sub_vectors=8) + + query = df.vector.iloc[2] + expected = table.search(query).limit(2).to_arrow() + + assert "s-2" in expected["text"].to_pylist() + + def test_add_with_embedding_function(db): emb = EmbeddingFunctionRegistry.get_instance().get("test")()