diff --git a/nodejs/__test__/arrow.test.ts b/nodejs/__test__/arrow.test.ts
index 478c766b..9e2ef66f 100644
--- a/nodejs/__test__/arrow.test.ts
+++ b/nodejs/__test__/arrow.test.ts
@@ -1,7 +1,16 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
 
-import { Bool, Field, Int32, List, Schema, Struct, Utf8 } from "apache-arrow";
+import {
+  Bool,
+  Field,
+  Int32,
+  List,
+  Schema,
+  Struct,
+  Uint8,
+  Utf8,
+} from "apache-arrow";
 import * as arrow15 from "apache-arrow-15";
 import * as arrow16 from "apache-arrow-16";
@@ -255,6 +264,98 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
       expect(actualSchema).toEqual(schema);
     });
 
+    it("will detect vector columns when name contains 'vector' or 'embedding'", async function () {
+      // Test various naming patterns that should be detected as vector columns
+      const floatVectorTable = makeArrowTable([
+        {
+          // Float vectors (use decimal values to ensure they're treated as floats)
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          user_vector: [1.1, 2.2],
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          text_embedding: [3.3, 4.4],
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          doc_embeddings: [5.5, 6.6],
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          my_vector_field: [7.7, 8.8],
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          embedding_model: [9.9, 10.1],
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          VECTOR_COL: [11.1, 12.2], // uppercase
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          Vector_Mixed: [13.3, 14.4], // mixed case
+        },
+      ]);
+
+      // Check that columns with 'vector' or 'embedding' in name are converted to FixedSizeList
+      const floatVectorColumns = [
+        "user_vector",
+        "text_embedding",
+        "doc_embeddings",
+        "my_vector_field",
+        "embedding_model",
+        "VECTOR_COL",
+        "Vector_Mixed",
+      ];
+
+      for (const columnName of floatVectorColumns) {
+        expect(
+          DataType.isFixedSizeList(
+            floatVectorTable.getChild(columnName)?.type,
+          ),
+        ).toBe(true);
+        // Check that float vectors use Float32 by default
+        expect(
+          floatVectorTable
+            .getChild(columnName)
+            ?.type.children[0].type.toString(),
+        ).toEqual(new Float32().toString());
+      }
+
+      // Test that regular integer arrays still get treated as float vectors
+      // (since JavaScript doesn't distinguish integers from floats at runtime)
+      const integerArrayTable = makeArrowTable([
+        {
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          vector_int: [1, 2], // Regular array with integers - should be Float32
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          embedding_int: [3, 4], // Regular array with integers - should be Float32
+        },
+      ]);
+
+      const integerArrayColumns = ["vector_int", "embedding_int"];
+
+      for (const columnName of integerArrayColumns) {
+        expect(
+          DataType.isFixedSizeList(
+            integerArrayTable.getChild(columnName)?.type,
+          ),
+        ).toBe(true);
+        // Regular integer arrays should use Float32 (avoiding false positives)
+        expect(
+          integerArrayTable
+            .getChild(columnName)
+            ?.type.children[0].type.toString(),
+        ).toEqual(new Float32().toString());
+      }
+
+      // Test normal list should NOT be converted to FixedSizeList
+      const normalListTable = makeArrowTable([
+        {
+          // biome-ignore lint/style/useNamingConvention: Testing vector column detection patterns
+          normal_list: [15.5, 16.6], // should NOT be detected as vector
+        },
+      ]);
+
+      expect(
+        DataType.isFixedSizeList(
+          normalListTable.getChild("normal_list")?.type,
+        ),
+      ).toBe(false);
+      expect(
+        DataType.isList(normalListTable.getChild("normal_list")?.type),
+      ).toBe(true);
+    });
+
     it("will allow different vector column types", async function () {
       const table = makeArrowTable([{ fp16: [1], fp32: [1], fp64: [1] }], {
         vectorColumns: {
diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts
index 7bf2bb40..ed399742 100644
--- a/nodejs/lancedb/arrow.ts
+++ b/nodejs/lancedb/arrow.ts
@@ -34,6 +34,7 @@ import {
   Struct,
   Timestamp,
   Type,
+  Uint8,
   Utf8,
   Vector,
   makeVector as arrowMakeVector,
@@ -51,6 +52,15 @@ import {
   sanitizeTable,
   sanitizeType,
 } from "./sanitize";
+
+/**
+ * Check if a field name indicates a vector column.
+ */
+function nameSuggestsVectorColumn(fieldName: string): boolean {
+  const nameLower = fieldName.toLowerCase();
+  return nameLower.includes("vector") || nameLower.includes("embedding");
+}
+
 export * from "apache-arrow";
 export type SchemaLike =
   | Schema
@@ -591,10 +601,17 @@ function inferType(
     return undefined;
   }
   // Try to automatically detect embedding columns.
-  if (valueType instanceof Float && path[path.length - 1] === "vector") {
-    // We default to Float32 for vectors.
-    const child = new Field("item", new Float32(), true);
-    return new FixedSizeList(value.length, child);
+  if (nameSuggestsVectorColumn(path[path.length - 1])) {
+    // Check if value is a Uint8Array for integer vector type determination
+    if (value instanceof Uint8Array) {
+      // For integer vectors, we default to Uint8 (matching Python implementation)
+      const child = new Field("item", new Uint8(), true);
+      return new FixedSizeList(value.length, child);
+    } else {
+      // For float vectors, we default to Float32
+      const child = new Field("item", new Float32(), true);
+      return new FixedSizeList(value.length, child);
+    }
   } else {
     const child = new Field("item", valueType, true);
     return new List(child);
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index e6db7515..3088572a 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -2926,6 +2926,12 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
     return pc.is_in(indices, has_nan_indices)
 
 
+def _name_suggests_vector_column(field_name: str) -> bool:
+    """Check if a field name indicates a vector column."""
+    name_lower = field_name.lower()
+    return "vector" in name_lower or "embedding" in name_lower
+
+
 def _infer_target_schema(
     reader: pa.RecordBatchReader,
 ) -> Tuple[pa.Schema, pa.RecordBatchReader]:
@@ -2933,35 +2939,27 @@
     peeked = None
 
     for i, field in enumerate(schema):
-        if (
-            field.name == VECTOR_COLUMN_NAME
-            and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
-            and pa.types.is_floating(field.type.value_type)
-        ):
+        is_list_type = pa.types.is_list(field.type) or pa.types.is_large_list(
+            field.type
+        )
+
+        if _name_suggests_vector_column(field.name) and is_list_type:
             if peeked is None:
                 peeked, reader = peek_reader(reader)
             # Use the most common length of the list as the dimensions
             dim = _modal_list_size(peeked.column(i))
-            new_field = pa.field(
-                VECTOR_COLUMN_NAME,
-                pa.list_(pa.float32(), dim),
-                nullable=field.nullable,
-            )
+            # Determine target type based on value type
+            if pa.types.is_floating(field.type.value_type):
+                target_type = pa.list_(pa.float32(), dim)
+            elif pa.types.is_integer(field.type.value_type):
+                target_type = pa.list_(pa.uint8(), dim)
+            else:
+                continue  # Skip non-numeric types
 
-            schema = schema.set(i, new_field)
-        elif (
-            field.name == VECTOR_COLUMN_NAME
-            and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
-            and pa.types.is_integer(field.type.value_type)
-        ):
-            if peeked is None:
-                peeked, reader = peek_reader(reader)
-            # Use the most common length of the list as the dimensions
-            dim = _modal_list_size(peeked.column(i))
             new_field = pa.field(
-                VECTOR_COLUMN_NAME,
-                pa.list_(pa.uint8(), dim),
+                field.name,  # preserve original field name
+                target_type,
                 nullable=field.nullable,
             )
diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py
index 4a00f948..4ff120f4 100644
--- a/python/python/tests/test_util.py
+++ b/python/python/tests/test_util.py
@@ -390,6 +390,87 @@ def test_infer_target_schema():
     assert output == expected
 
 
+
+def test_infer_target_schema_with_vector_embedding_names():
+    """Test that _infer_target_schema detects vector columns with 'vector'/'embedding'.
+
+    This tests the enhanced column name detection for vector inference.
+    """
+
+    # Test float vectors with various naming patterns
+    example = pa.schema(
+        {
+            "user_vector": pa.list_(pa.float64()),
+            "text_embedding": pa.list_(pa.float64()),
+            "doc_embeddings": pa.list_(pa.float64()),
+            "my_vector_field": pa.list_(pa.float64()),
+            "embedding_model": pa.list_(pa.float64()),
+            "VECTOR_COL": pa.list_(pa.float64()),  # uppercase
+            "Vector_Mixed": pa.list_(pa.float64()),  # mixed case
+            "normal_list": pa.list_(pa.float64()),  # should not be converted
+        }
+    )
+    data = pa.table(
+        {
+            "user_vector": [[1.0, 2.0]],
+            "text_embedding": [[3.0, 4.0]],
+            "doc_embeddings": [[5.0, 6.0]],
+            "my_vector_field": [[7.0, 8.0]],
+            "embedding_model": [[9.0, 10.0]],
+            "VECTOR_COL": [[11.0, 12.0]],
+            "Vector_Mixed": [[13.0, 14.0]],
+            "normal_list": [[15.0, 16.0]],
+        },
+        schema=example,
+    )
+
+    expected = pa.schema(
+        {
+            "user_vector": pa.list_(pa.float32(), 2),  # converted
+            "text_embedding": pa.list_(pa.float32(), 2),  # converted
+            "doc_embeddings": pa.list_(pa.float32(), 2),  # converted
+            "my_vector_field": pa.list_(pa.float32(), 2),  # converted
+            "embedding_model": pa.list_(pa.float32(), 2),  # converted
+            "VECTOR_COL": pa.list_(pa.float32(), 2),  # converted
+            "Vector_Mixed": pa.list_(pa.float32(), 2),  # converted
+            "normal_list": pa.list_(pa.float64()),  # not converted
+        }
+    )
+
+    output, _ = _infer_target_schema(data.to_reader())
+    assert output == expected
+
+    # Test integer vectors with various naming patterns
+    example_int = pa.schema(
+        {
+            "user_vector": pa.list_(pa.int32()),
+            "text_embedding": pa.list_(pa.int64()),
+            "doc_embeddings": pa.list_(pa.int16()),
+            "normal_list": pa.list_(pa.int32()),  # should not be converted
+        }
+    )
+    data_int = pa.table(
+        {
+            "user_vector": [[1, 2]],
+            "text_embedding": [[3, 4]],
+            "doc_embeddings": [[5, 6]],
+            "normal_list": [[7, 8]],
+        },
+        schema=example_int,
+    )
+
+    expected_int = pa.schema(
+        {
+            "user_vector": pa.list_(pa.uint8(), 2),  # converted
+            "text_embedding": pa.list_(pa.uint8(), 2),  # converted
+            "doc_embeddings": pa.list_(pa.uint8(), 2),  # converted
+            "normal_list": pa.list_(pa.int32()),  # not converted
+        }
+    )
+
+    output_int, _ = _infer_target_schema(data_int.to_reader())
+    assert output_int == expected_int
+
+
 @pytest.mark.parametrize(
     "data",
     [