fix: infer schema from huggingface dataset (#1444)

Closes #1383

When creating a table from a HuggingFace dataset, infer the Arrow schema
directly from the dataset's features.
This commit is contained in:
Chang She
2024-07-23 13:12:34 -07:00
committed by GitHub
parent 30047a5566
commit 374c1e7aba
5 changed files with 26 additions and 11 deletions

View File

@@ -732,7 +732,7 @@ class AsyncConnection(object):
fill_value = 0.0
if data is not None:
data = _sanitize_data(
data, schema = _sanitize_data(
data,
schema,
metadata=metadata,

View File

@@ -245,7 +245,7 @@ class RemoteDBConnection(DBConnection):
schema = schema.to_arrow_schema()
if data is not None:
data = _sanitize_data(
data, schema = _sanitize_data(
data,
schema,
metadata=None,

View File

@@ -210,7 +210,7 @@ class RemoteTable(Table):
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
data = _sanitize_data(
data, _ = _sanitize_data(
data,
self.schema,
metadata=None,
@@ -345,7 +345,7 @@ class RemoteTable(Table):
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
data, _ = _sanitize_data(
new_data,
self.schema,
metadata=None,

View File

@@ -103,7 +103,8 @@ def _sanitize_data(
if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema()
if schema is None:
schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data]
data = pa.Table.from_pylist(data, schema=schema)
else:
@@ -133,7 +134,7 @@ def _sanitize_data(
)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
return data
return data, schema
def _schema_from_hf(data, schema):
@@ -205,7 +206,7 @@ def _to_record_batch_generator(
# and do things like add the vector column etc
if isinstance(batch, pa.RecordBatch):
batch = pa.Table.from_batches([batch])
batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
batch, _ = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
for b in batch.to_batches():
yield b
@@ -1295,7 +1296,7 @@ class LanceTable(Table):
The number of vectors in the table.
"""
# TODO: manage table listing and metadata separately
data = _sanitize_data(
data, _ = _sanitize_data(
data,
self.schema,
metadata=self.schema.metadata,
@@ -1547,7 +1548,7 @@ class LanceTable(Table):
metadata = registry.get_table_metadata(embedding_functions)
if data is not None:
data = _sanitize_data(
data, schema = _sanitize_data(
data,
schema,
metadata=metadata,
@@ -1675,7 +1676,7 @@ class LanceTable(Table):
on_bad_vectors: str,
fill_value: float,
):
new_data = _sanitize_data(
new_data, _ = _sanitize_data(
new_data,
self.schema,
metadata=self.schema.metadata,
@@ -2153,7 +2154,7 @@ class AsyncTable:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
data = _sanitize_data(
data, _ = _sanitize_data(
data,
schema,
metadata=schema.metadata,

View File

@@ -124,3 +124,17 @@ def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with
# this should still work because we don't add the split column
# if it already exists
train_table.add(hf_dataset_with_split)
def test_generator(tmp_path: Path):
    """Creating a table from a generator-backed HF dataset should infer its Arrow schema."""
    conn = lancedb.connect(tmp_path)

    def rows():
        # Two fixed records; order is part of the expected table contents.
        yield {"pokemon": "bulbasaur", "type": "grass"}
        yield {"pokemon": "squirtle", "type": "water"}

    dataset = datasets.Dataset.from_generator(rows)
    table = conn.create_table("pokemon", dataset)

    assert len(table) == 2
    # The table schema must match the schema inferred from the dataset features.
    assert table.schema == dataset.features.arrow_schema