diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 71073a28..62be2042 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -732,7 +732,7 @@ class AsyncConnection(object): fill_value = 0.0 if data is not None: - data = _sanitize_data( + data, schema = _sanitize_data( data, schema, metadata=metadata, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 66e01360..6f51f79e 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -245,7 +245,7 @@ class RemoteDBConnection(DBConnection): schema = schema.to_arrow_schema() if data is not None: - data = _sanitize_data( + data, schema = _sanitize_data( data, schema, metadata=None, diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 68fb0ed4..8e4ff5d5 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -210,7 +210,7 @@ class RemoteTable(Table): The value to use when filling vectors. Only used if on_bad_vectors="fill". """ - data = _sanitize_data( + data, _ = _sanitize_data( data, self.schema, metadata=None, @@ -345,7 +345,7 @@ class RemoteTable(Table): on_bad_vectors: str, fill_value: float, ): - data = _sanitize_data( + data, _ = _sanitize_data( new_data, self.schema, metadata=None, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 61a8f7e8..4a31e340 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -103,7 +103,8 @@ def _sanitize_data( if isinstance(data, list): # convert to list of dict if data is a bunch of LanceModels if isinstance(data[0], LanceModel): - schema = data[0].__class__.to_arrow_schema() + if schema is None: + schema = data[0].__class__.to_arrow_schema() data = [model_to_dict(d) for d in data] data = pa.Table.from_pylist(data, schema=schema) else: @@ -133,7 +134,7 @@ def _sanitize_data( ) else: raise TypeError(f"Unsupported data type: {type(data)}") - return data + return data, schema def _schema_from_hf(data, schema): @@ -205,7 +206,7 @@ def _to_record_batch_generator( # and do things like add the vector column etc if isinstance(batch, pa.RecordBatch): batch = pa.Table.from_batches([batch]) - batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value) + batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value) for b in batch.to_batches(): yield b @@ -1295,7 +1296,7 @@ class LanceTable(Table): The number of vectors in the table. """ # TODO: manage table listing and metadata separately - data = _sanitize_data( + data, _ = _sanitize_data( data, self.schema, metadata=self.schema.metadata, @@ -1547,7 +1548,7 @@ class LanceTable(Table): metadata = registry.get_table_metadata(embedding_functions) if data is not None: - data = _sanitize_data( + data, schema = _sanitize_data( data, schema, metadata=metadata, @@ -1675,7 +1676,7 @@ class LanceTable(Table): on_bad_vectors: str, fill_value: float, ): - new_data = _sanitize_data( + new_data, _ = _sanitize_data( new_data, self.schema, metadata=self.schema.metadata, @@ -2153,7 +2154,7 @@ class AsyncTable: on_bad_vectors = "error" if fill_value is None: fill_value = 0.0 - data = _sanitize_data( + data, _ = _sanitize_data( data, schema, metadata=schema.metadata, diff --git a/python/python/tests/test_huggingface.py b/python/python/tests/test_huggingface.py index b5377149..ea4d39b2 100644 --- a/python/python/tests/test_huggingface.py +++ b/python/python/tests/test_huggingface.py @@ -124,3 +124,17 @@ def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with # this should still work because we don't add the split column # if it already exists train_table.add(hf_dataset_with_split) + + +def test_generator(tmp_path: Path): + db = lancedb.connect(tmp_path) + + def gen(): + yield {"pokemon": "bulbasaur", "type": "grass"} + yield {"pokemon": "squirtle", "type": "water"} + + ds = datasets.Dataset.from_generator(gen) + tbl = db.create_table("pokemon", ds) + + assert len(tbl) == 2 + assert tbl.schema == ds.features.arrow_schema