mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
fix: infer schema from huggingface dataset (#1444)
Closes #1383 When creating a table from a HuggingFace dataset, infer the arrow schema directly
This commit is contained in:
@@ -732,7 +732,7 @@ class AsyncConnection(object):
|
||||
fill_value = 0.0
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=metadata,
|
||||
|
||||
@@ -245,7 +245,7 @@ class RemoteDBConnection(DBConnection):
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=None,
|
||||
|
||||
@@ -210,7 +210,7 @@ class RemoteTable(Table):
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
"""
|
||||
data = _sanitize_data(
|
||||
data, _ = _sanitize_data(
|
||||
data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
@@ -345,7 +345,7 @@ class RemoteTable(Table):
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
data = _sanitize_data(
|
||||
data, _ = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
|
||||
@@ -103,7 +103,8 @@ def _sanitize_data(
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
if schema is None:
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
data = [model_to_dict(d) for d in data]
|
||||
data = pa.Table.from_pylist(data, schema=schema)
|
||||
else:
|
||||
@@ -133,7 +134,7 @@ def _sanitize_data(
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Unsupported data type: {type(data)}")
|
||||
return data
|
||||
return data, schema
|
||||
|
||||
|
||||
def _schema_from_hf(data, schema):
|
||||
@@ -205,7 +206,7 @@ def _to_record_batch_generator(
|
||||
# and do things like add the vector column etc
|
||||
if isinstance(batch, pa.RecordBatch):
|
||||
batch = pa.Table.from_batches([batch])
|
||||
batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for b in batch.to_batches():
|
||||
yield b
|
||||
|
||||
@@ -1295,7 +1296,7 @@ class LanceTable(Table):
|
||||
The number of vectors in the table.
|
||||
"""
|
||||
# TODO: manage table listing and metadata separately
|
||||
data = _sanitize_data(
|
||||
data, _ = _sanitize_data(
|
||||
data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
@@ -1547,7 +1548,7 @@ class LanceTable(Table):
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=metadata,
|
||||
@@ -1675,7 +1676,7 @@ class LanceTable(Table):
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
new_data = _sanitize_data(
|
||||
new_data, _ = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
@@ -2153,7 +2154,7 @@ class AsyncTable:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
data = _sanitize_data(
|
||||
data, _ = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
|
||||
@@ -124,3 +124,17 @@ def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with
|
||||
# this should still work because we don't add the split column
|
||||
# if it already exists
|
||||
train_table.add(hf_dataset_with_split)
|
||||
|
||||
|
||||
def test_generator(tmp_path: Path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
def gen():
|
||||
yield {"pokemon": "bulbasaur", "type": "grass"}
|
||||
yield {"pokemon": "squirtle", "type": "water"}
|
||||
|
||||
ds = datasets.Dataset.from_generator(gen)
|
||||
tbl = db.create_table("pokemon", ds)
|
||||
|
||||
assert len(tbl) == 2
|
||||
assert tbl.schema == ds.features.arrow_schema
|
||||
|
||||
Reference in New Issue
Block a user