mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
test: string type conversion in pandas 3.0+ (#2928)
Pandas 3.0+ string now converts to Arrow large_utf8. This PR mainly makes sure our test accounts for the difference across the pandas versions when constructing schema.
This commit is contained in:
@@ -2,12 +2,27 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from lancedb.db import AsyncConnection, DBConnection
|
||||
import lancedb
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
def pandas_string_type():
|
||||
"""Return the PyArrow string type that pandas uses for string columns.
|
||||
|
||||
pandas 3.0+ uses large_string for string columns, pandas 2.x uses string.
|
||||
"""
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
version = tuple(int(x) for x in pd.__version__.split(".")[:2])
|
||||
if version >= (3, 0):
|
||||
return pa.large_utf8()
|
||||
return pa.utf8()
|
||||
|
||||
|
||||
# Use an in-memory database for most tests.
|
||||
@pytest.fixture
|
||||
def mem_db() -> DBConnection:
|
||||
|
||||
@@ -268,6 +268,8 @@ async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConne
|
||||
|
||||
|
||||
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
from conftest import pandas_string_type
|
||||
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -286,10 +288,11 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
assert tbl.schema == tbl2.schema
|
||||
assert len(tbl) == len(tbl2)
|
||||
|
||||
# pandas 3.0+ uses large_string, pandas 2.x uses string
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("item", pandas_string_type()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
@@ -299,7 +302,7 @@ def test_create_exist_ok(tmp_db: lancedb.DBConnection):
|
||||
bad_schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("item", pandas_string_type()),
|
||||
pa.field("price", pa.float64()),
|
||||
pa.field("extra", pa.float32()),
|
||||
]
|
||||
@@ -365,6 +368,8 @@ async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
from conftest import pandas_string_type
|
||||
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
@@ -382,10 +387,11 @@ async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
|
||||
assert tbl.name == tbl2.name
|
||||
assert await tbl.schema() == await tbl2.schema()
|
||||
|
||||
# pandas 3.0+ uses large_string, pandas 2.x uses string
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("item", pandas_string_type()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
@@ -595,6 +601,8 @@ def test_open_table_sync(tmp_db: lancedb.DBConnection):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_table(tmp_path):
|
||||
from conftest import pandas_string_type
|
||||
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
@@ -614,10 +622,11 @@ async def test_open_table(tmp_path):
|
||||
)
|
||||
is not None
|
||||
)
|
||||
# pandas 3.0+ uses large_string, pandas 2.x uses string
|
||||
assert await tbl.schema() == pa.schema(
|
||||
{
|
||||
"vector": pa.list_(pa.float32(), list_size=2),
|
||||
"item": pa.utf8(),
|
||||
"item": pandas_string_type(),
|
||||
"price": pa.float64(),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -528,12 +528,19 @@ def test_sanitize_data(
|
||||
else:
|
||||
expected_schema = schema
|
||||
else:
|
||||
from conftest import pandas_string_type
|
||||
|
||||
# polars uses large_string, pandas 3.0+ uses large_string, others use string
|
||||
if isinstance(data, pl.DataFrame):
|
||||
text_type = pa.large_utf8()
|
||||
elif isinstance(data, pd.DataFrame):
|
||||
text_type = pandas_string_type()
|
||||
else:
|
||||
text_type = pa.string()
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"text": pa.large_utf8()
|
||||
if isinstance(data, pl.DataFrame)
|
||||
else pa.string(),
|
||||
"text": text_type,
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user