mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-14 15:52:57 +00:00
@@ -1335,13 +1335,15 @@ def _sanitize_vector_column(
|
||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
||||
|
||||
vec_arr = ensure_fixed_size_list_of_f32(vec_arr)
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
|
||||
if has_nans:
|
||||
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
|
||||
# kernel over f16 types.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
@@ -1349,9 +1351,9 @@ def _sanitize_vector_column(
|
||||
return data
|
||||
|
||||
|
||||
def ensure_fixed_size_list_of_f32(vec_arr):
|
||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
values = vec_arr.values
|
||||
if not pa.types.is_float32(values.type):
|
||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
||||
values = values.cast(pa.float32())
|
||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||
list_size = vec_arr.type.list_size
|
||||
|
||||
Reference in New Issue
Block a user