bug: add a test for fp16 (#837)

Add test to ingest fp16 to a database
This commit is contained in:
Lei Xu
2024-01-20 16:23:28 -08:00
committed by GitHub
parent a1ab549457
commit 83ed8d1e49
4 changed files with 36 additions and 9 deletions

View File

@@ -1335,13 +1335,15 @@ def _sanitize_vector_column(
elif not pa.types.is_fixed_size_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
vec_arr = ensure_fixed_size_list_of_f32(vec_arr)
vec_arr = ensure_fixed_size_list(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
if has_nans:
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
# kernel over f16 types.
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
if np.isnan(values_np).any():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
@@ -1349,9 +1351,9 @@ def _sanitize_vector_column(
return data
def ensure_fixed_size_list_of_f32(vec_arr):
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
values = vec_arr.values
if not pa.types.is_float32(values.type):
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
values = values.cast(pa.float32())
if pa.types.is_fixed_size_list(vec_arr.type):
list_size = vec_arr.type.list_size