add unit tests

This commit is contained in:
Chang She
2023-03-21 22:29:19 -07:00
parent 5c15e0ee86
commit 690141d357
5 changed files with 140 additions and 8 deletions

View File

@@ -108,7 +108,7 @@ class LanceTable:
return LanceQueryBuilder(self, query)
@classmethod
def create(cls, db, name, data, schema):
def create(cls, db, name, data, schema=None):
tbl = LanceTable(db, name)
data = _sanitize_data(data, schema)
lance.write_dataset(data, tbl._dataset_uri, mode="create")
@@ -131,10 +131,8 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
return data
# cast the columns to the expected types
data = data.combine_chunks()
return pa.Table.from_arrays([
data[name].cast(schema.field(name).type)
for name in schema.names
], schema=schema)
return pa.Table.from_arrays([data[name] for name in schema.names],
schema=schema)
# just check the vector column
return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)