feat: add to_list and to_pandas api's (#556)

Add `to_list` to return query results as list of python dict (so we're
not too pandas-centric). Closes #555

Add `to_pandas` API and add deprecation warning on `to_df`. Closes #545

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-10-11 12:18:55 -07:00
committed by Weston Pace
parent a737bbff19
commit 8469d010f8
26 changed files with 125 additions and 71 deletions

View File

@@ -47,7 +47,7 @@ def test_contextualizer(raw_df: pd.DataFrame):
.stride(3)
.text_col("token")
.groupby("document_id")
.to_df()["token"]
.to_pandas()["token"]
.to_list()
)
@@ -67,7 +67,7 @@ def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
.text_col("token")
.groupby("document_id")
.min_window_size(4)
.to_df()["token"]
.to_pandas()["token"]
.to_list()
)

View File

@@ -33,11 +33,11 @@ def test_basic(tmp_path):
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_df()
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
@@ -62,11 +62,11 @@ def test_ingest_pd(tmp_path):
}
)
table = db.create_table("test", data=data)
rs = table.search([100, 100]).limit(1).to_df()
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
@@ -137,8 +137,8 @@ def test_ingest_iterator(tmp_path):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
tbl_len = len(tbl)
tbl.add(make_batches())
assert tbl_len == 50

View File

@@ -23,5 +23,5 @@ from lancedb import LanceDBConnection
def test_against_local_server():
conn = LanceDBConnection("lancedb+http://localhost:10024")
table = conn.open_table("sift1m_ivf1024_pq16")
df = table.search(np.random.rand(128)).to_df()
df = table.search(np.random.rand(128)).to_pandas()
assert len(df) == 10

View File

@@ -71,14 +71,14 @@ def test_search_index(tmp_path, table):
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_df()
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) == 10
assert "text" in df.columns
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"])
df = table.search("puppy").limit(10).to_df()
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 10
assert "text" in df.columns
assert "text2" in df.columns
@@ -87,5 +87,5 @@ def test_create_index_multiple_columns(tmp_path, table):
def test_empty_rs(tmp_path, table, mocker):
table.create_fts_index(["text", "text2"])
mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_df()
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 0

View File

@@ -36,11 +36,11 @@ def test_s3_io():
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_df()
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"

View File

@@ -85,17 +85,20 @@ def test_cast(table):
def test_query_builder(table):
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1)
.select(["id"])
.to_list()
)
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
assert rs[0]["id"] == 1
assert all(np.array(rs[0]["vector"]) == [1, 2])
def test_query_builder_with_filter(table):
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
assert rs[0]["id"] == 2
assert all(np.array(rs[0]["vector"]) == [3, 4])
def test_query_builder_with_prefilter(table):
@@ -103,7 +106,7 @@ def test_query_builder_with_prefilter(table):
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2")
.limit(1)
.to_df()
.to_pandas()
)
assert len(df) == 0
@@ -111,7 +114,7 @@ def test_query_builder_with_prefilter(table):
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2", prefilter=True)
.limit(1)
.to_df()
.to_pandas()
)
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
@@ -120,9 +123,11 @@ def test_query_builder_with_prefilter(table):
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("L2")
.to_pandas()
)
tm.assert_frame_equal(df_default, df_l2)
@@ -130,7 +135,7 @@ def test_query_builder_with_metric(table):
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
.to_pandas()
)
assert df_cosine._distance[0] == pytest.approx(
cosine_distance(query, df_cosine.vector[0]),

View File

@@ -86,7 +86,7 @@ async def test_e2e_with_mock_server():
columns=["id", "vector"],
),
)
).to_df()
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns

View File

@@ -32,4 +32,4 @@ def test_remote_db():
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.search([1.0, 2.0]).to_df()
table.search([1.0, 2.0]).to_pandas()

View File

@@ -427,8 +427,8 @@ def test_multiple_vector_columns(db):
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
@@ -439,6 +439,6 @@ def test_empty_query(db):
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
val = df.id.iloc[0]
assert val == 1