mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 14:29:56 +00:00
fix(python): make sure pandas is optional (#2346)
Fixes #2344 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Tests** - Updated tests to use PyArrow Tables instead of pandas DataFrames where possible, reducing reliance on pandas. - Tests that require pandas are now automatically skipped if pandas is not installed. - **Chores** - Improved workflow to uninstall both pylance and pandas in a specific test step. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
4
.github/workflows/python.yml
vendored
4
.github/workflows/python.yml
vendored
@@ -136,9 +136,9 @@ jobs:
|
||||
- uses: ./.github/workflows/run_tests
|
||||
with:
|
||||
integration: true
|
||||
- name: Test without pylance
|
||||
- name: Test without pylance or pandas
|
||||
run: |
|
||||
pip uninstall -y pylance
|
||||
pip uninstall -y pylance pandas
|
||||
pytest -vv python/tests/test_table.py
|
||||
# Make sure wheels are not included in the Rust cache
|
||||
- name: Delete wheels
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset
|
||||
|
||||
from .dependencies import pandas as pd
|
||||
from .dependencies import _check_for_pandas, pandas as pd
|
||||
|
||||
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||
@@ -63,7 +63,7 @@ def data_to_reader(
|
||||
data: DATA, schema: Optional[pa.Schema] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
"""Convert various types of input into a RecordBatchReader"""
|
||||
if pd is not None and isinstance(data, pd.DataFrame):
|
||||
if _check_for_pandas(data) and isinstance(data, pd.DataFrame):
|
||||
return pa.Table.from_pandas(data, schema=schema).to_reader()
|
||||
elif isinstance(data, pa.Table):
|
||||
return data.to_reader()
|
||||
|
||||
@@ -9,9 +9,9 @@ from typing import List
|
||||
from unittest.mock import patch
|
||||
|
||||
import lancedb
|
||||
from lancedb.dependencies import _PANDAS_AVAILABLE
|
||||
from lancedb.index import HnswPq, HnswSq, IvfPq
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset
|
||||
@@ -138,13 +138,16 @@ def test_create_table(mem_db: DBConnection):
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
]
|
||||
df = pd.DataFrame(rows)
|
||||
pa_table = pa.Table.from_pandas(df, schema=schema)
|
||||
pa_table = pa.Table.from_pylist(rows, schema=schema)
|
||||
data = [
|
||||
("Rows", rows),
|
||||
("pd_DataFrame", df),
|
||||
("pa_Table", pa_table),
|
||||
]
|
||||
if _PANDAS_AVAILABLE:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
data.append(("pd_DataFrame", df))
|
||||
|
||||
for name, d in data:
|
||||
tbl = mem_db.create_table(name, data=d, schema=schema).to_arrow()
|
||||
@@ -296,7 +299,7 @@ def test_add_subschema(mem_db: DBConnection):
|
||||
|
||||
data = {"price": 10.0, "item": "foo"}
|
||||
table.add([data])
|
||||
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
|
||||
data = pa.Table.from_pydict({"price": [2.0], "vector": [[3.1, 4.1]]})
|
||||
table.add(data)
|
||||
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
|
||||
table.add([data])
|
||||
@@ -405,6 +408,7 @@ def test_add_nullability(mem_db: DBConnection):
|
||||
|
||||
|
||||
def test_add_pydantic_model(mem_db: DBConnection):
|
||||
pytest.importorskip("pandas")
|
||||
# https://github.com/lancedb/lancedb/issues/562
|
||||
|
||||
class Metadata(BaseModel):
|
||||
@@ -473,10 +477,10 @@ def test_polars(mem_db: DBConnection):
|
||||
table = mem_db.create_table("test", data=pl.DataFrame(data))
|
||||
assert len(table) == 2
|
||||
|
||||
result = table.to_pandas()
|
||||
assert np.allclose(result["vector"].tolist(), data["vector"])
|
||||
assert result["item"].tolist() == data["item"]
|
||||
assert np.allclose(result["price"].tolist(), data["price"])
|
||||
result = table.to_arrow()
|
||||
assert np.allclose(result["vector"].to_pylist(), data["vector"])
|
||||
assert result["item"].to_pylist() == data["item"]
|
||||
assert np.allclose(result["price"].to_pylist(), data["price"])
|
||||
|
||||
schema = pa.schema(
|
||||
[
|
||||
@@ -688,7 +692,7 @@ def test_delete(mem_db: DBConnection):
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
assert len(table) == 1
|
||||
assert table.to_pandas()["id"].tolist() == [1]
|
||||
assert table.to_arrow()["id"].to_pylist() == [1]
|
||||
|
||||
|
||||
def test_update(mem_db: DBConnection):
|
||||
@@ -852,6 +856,7 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
ids=["pa.Table", "pd.DataFrame", "rows"],
|
||||
)
|
||||
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
|
||||
pytest.importorskip("pandas")
|
||||
initial_data = pa.table(
|
||||
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
|
||||
)
|
||||
@@ -948,7 +953,7 @@ def test_create_with_embedding_function(mem_db: DBConnection):
|
||||
|
||||
func = MockTextEmbeddingFunction.create()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
df = pa.table({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text", vector_column="vector", function=func
|
||||
@@ -973,7 +978,7 @@ def test_create_f16_table(mem_db: DBConnection):
|
||||
text: str
|
||||
vector: Vector(32, value_type=pa.float16())
|
||||
|
||||
df = pd.DataFrame(
|
||||
df = pa.table(
|
||||
{
|
||||
"text": [f"s-{i}" for i in range(512)],
|
||||
"vector": [np.random.randn(32).astype(np.float16) for _ in range(512)],
|
||||
@@ -986,7 +991,7 @@ def test_create_f16_table(mem_db: DBConnection):
|
||||
table.add(df)
|
||||
table.create_index(num_partitions=2, num_sub_vectors=2)
|
||||
|
||||
query = df.vector.iloc[2]
|
||||
query = df["vector"][2].as_py()
|
||||
expected = table.search(query).limit(2).to_arrow()
|
||||
|
||||
assert "s-2" in expected["text"].to_pylist()
|
||||
@@ -1002,7 +1007,7 @@ def test_add_with_embedding_function(mem_db: DBConnection):
|
||||
table = mem_db.create_table("my_table", schema=MyTable)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
df = pa.table({"text": texts})
|
||||
table.add(df)
|
||||
|
||||
texts = ["the quick brown fox", "jumped over the lazy dog"]
|
||||
@@ -1033,14 +1038,14 @@ def test_multiple_vector_columns(mem_db: DBConnection):
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
df = pa.Table.from_pylist(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
|
||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
|
||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_arrow()
|
||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_arrow()
|
||||
|
||||
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
||||
assert result1["text"][0] != result2["text"][0]
|
||||
|
||||
|
||||
def test_create_scalar_index(mem_db: DBConnection):
|
||||
@@ -1078,22 +1083,22 @@ def test_empty_query(mem_db: DBConnection):
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
)
|
||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
|
||||
val = df.id.iloc[0]
|
||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_arrow()
|
||||
val = df["id"][0].as_py()
|
||||
assert val == 1
|
||||
|
||||
table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)])
|
||||
df = table.search().select(["id"]).to_pandas()
|
||||
assert len(df) == 100
|
||||
df = table.search().select(["id"]).to_arrow()
|
||||
assert df.num_rows == 100
|
||||
# None is the same as default
|
||||
df = table.search().select(["id"]).limit(None).to_pandas()
|
||||
assert len(df) == 100
|
||||
df = table.search().select(["id"]).limit(None).to_arrow()
|
||||
assert df.num_rows == 100
|
||||
# invalid limist is the same as None, wihch is the same as default
|
||||
df = table.search().select(["id"]).limit(-1).to_pandas()
|
||||
assert len(df) == 100
|
||||
df = table.search().select(["id"]).limit(-1).to_arrow()
|
||||
assert df.num_rows == 100
|
||||
# valid limit should work
|
||||
df = table.search().select(["id"]).limit(42).to_pandas()
|
||||
assert len(df) == 42
|
||||
df = table.search().select(["id"]).limit(42).to_arrow()
|
||||
assert df.num_rows == 42
|
||||
|
||||
|
||||
def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
|
||||
@@ -1112,14 +1117,14 @@ def test_search_with_schema_inf_single_vector(mem_db: DBConnection):
|
||||
{"vector_col": v1, "text": "foo"},
|
||||
{"vector_col": v2, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
df = pa.Table.from_pylist(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
|
||||
result2 = table.search(q).limit(1).to_pandas()
|
||||
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_arrow()
|
||||
result2 = table.search(q).limit(1).to_arrow()
|
||||
|
||||
assert result1["text"].iloc[0] == result2["text"].iloc[0]
|
||||
assert result1["text"][0].as_py() == result2["text"][0].as_py()
|
||||
|
||||
|
||||
def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
|
||||
@@ -1139,12 +1144,12 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
df = pa.Table.from_pylist(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
with pytest.raises(ValueError):
|
||||
table.search(q).limit(1).to_pandas()
|
||||
table.search(q).limit(1).to_arrow()
|
||||
|
||||
|
||||
def test_compact_cleanup(tmp_db: DBConnection):
|
||||
|
||||
Reference in New Issue
Block a user