diff --git a/python/lancedb/db.py b/python/lancedb/db.py index e0ab2fea..3db1c583 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -53,8 +53,9 @@ class LanceDBConnection: def __getitem__(self, name: str) -> LanceTable: return self.open_table(name) - def create_table(self, name: str, data: DATA = None, - schema: pa.Schema = None) -> LanceTable: + def create_table( + self, name: str, data: DATA = None, schema: pa.Schema = None + ) -> LanceTable: """Create a table in the database. Parameters diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 07c022f8..14ac2083 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -76,17 +76,12 @@ class LanceQueryBuilder: return self def to_df(self) -> pd.DataFrame: - """Execute the query and return the results as a pandas DataFrame. - """ + """Execute the query and return the results as a pandas DataFrame.""" ds = self._table.to_lance() # TODO indexed search tbl = ds.to_table( columns=self._columns, filter=self._where, - nearest={ - "column": VECTOR_COLUMN_NAME, - "q": self._query, - "k": self._limit - } + nearest={"column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit}, ) return tbl.to_pandas() diff --git a/python/lancedb/table.py b/python/lancedb/table.py index ce0d5fb5..7840f396 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -131,8 +131,9 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table: return data # cast the columns to the expected types data = data.combine_chunks() - return pa.Table.from_arrays([data[name] for name in schema.names], - schema=schema) + return pa.Table.from_arrays( + [data[name] for name in schema.names], schema=schema + ) # just check the vector column return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME) diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 7ce51e71..956ce505 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -20,9 +20,13 @@ def test_basic(tmp_path): assert db.uri == str(tmp_path) assert db.table_names() == [] - table = db.create_table("test", - data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, - {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]) + table = db.create_table( + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], + ) rs = table.search([100, 100]).limit(1).to_df() assert len(rs) == 1 assert rs["item"].iloc[0] == "bar" diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 692debcc..c08cdd8f 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -21,7 +21,6 @@ import pytest class MockTable: - def __init__(self, tmp_path): self.uri = tmp_path @@ -31,16 +30,22 @@ class MockTable: @pytest.fixture def table(tmp_path) -> MockTable: - df = pd.DataFrame({ - "vector": [[1, 2], [3, 4]], - "id": [1, 2], - "str_field": ["a", "b"], - "float_field": [1.0, 2.0] - }) - schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2)), - pa.field("id", pa.int32()), - pa.field("str_field", pa.string()), - pa.field("float_field", pa.float64())]) + df = pd.DataFrame( + { + "vector": [[1, 2], [3, 4]], + "id": [1, 2], + "str_field": ["a", "b"], + "float_field": [1.0, 2.0], + } + ) + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), list_size=2)), + pa.field("id", pa.int32()), + pa.field("str_field", pa.string()), + pa.field("float_field", pa.float64()), + ] + ) lance.write_dataset(df, tmp_path, schema) return MockTable(tmp_path) @@ -55,5 +60,3 @@ def test_query_builder_with_filter(table): df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df() assert df["id"].values[0] == 2 assert all(df["vector"].values[0] == [3, 4]) - - diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 3050d69b..e0a93f06 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -21,7 +21,6 @@ from lancedb.table import LanceTable class MockDB: - def __init__(self, uri: Path): self.uri = uri @@ -33,9 +32,12 @@ def db(tmp_path) -> MockDB: def test_basic(db): ds = LanceTable.create( - db, "test", - data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, - {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}] + db, + "test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ], ).to_lance() table = LanceTable(db, "test") @@ -45,21 +47,35 @@ def test_basic(db): def test_add(db): - schema = pa.schema([pa.field("vector", pa.list_(pa.float32())), - pa.field("item", pa.string()), - pa.field("price", pa.float32())]) - expected = pa.Table.from_arrays([ - pa.array([[3.1, 4.1], [5.9, 26.5]]), - pa.array(["foo", "bar"]), - pa.array([10.0, 20.0]) - ], schema=schema) - data = [[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, - {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]] + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32())), + pa.field("item", pa.string()), + pa.field("price", pa.float32()), + ] + ) + expected = pa.Table.from_arrays( + [ + pa.array([[3.1, 4.1], [5.9, 26.5]]), + pa.array(["foo", "bar"]), + pa.array([10.0, 20.0]), + ], + schema=schema, + ) + data = [ + [ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, + ] + ] df = pd.DataFrame(data[0]) data.append(df) data.append(pa.Table.from_pandas(df, schema=schema)) for i, d in enumerate(data): - tbl = (LanceTable.create(db, f"test_{i}", data=d, schema=schema) - .to_lance().to_table()) + tbl = ( + LanceTable.create(db, f"test_{i}", data=d, schema=schema) + .to_lance() + .to_table() + ) assert expected == tbl