fix: handle camelCase column names in select (#1460)

Fixes #1385
This commit is contained in:
Will Jones
2024-07-22 12:53:17 -07:00
committed by GitHub
parent 391fa26175
commit 4f601a2d4c
6 changed files with 73 additions and 17 deletions

View File

@@ -1127,14 +1127,14 @@ class AsyncQueryBase(object):
Columns will always be returned in the order given, even if that order is
different than the order used when adding the data.
"""
if isinstance(columns, dict):
column_tuples = list(columns.items())
if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
self._inner.select_columns(columns)
elif isinstance(columns, dict) and all(
isinstance(k, str) and isinstance(v, str) for k, v in columns.items()
):
self._inner.select(list(columns.items()))
else:
try:
column_tuples = [(c, c) for c in columns]
except TypeError:
raise TypeError("columns must be a list of column names or a dict")
self._inner.select(column_tuples)
raise TypeError("columns must be a list of column names or a dict")
return self
def limit(self, limit: int) -> AsyncQuery:

View File

@@ -345,3 +345,12 @@ def test_explain_plan(table):
async def test_explain_plan_async(table_async: AsyncTable):
plan = await table_async.query().nearest_to(pa.array([1, 2])).explain_plan(True)
assert "KNN" in plan
@pytest.mark.asyncio
async def test_query_camelcase_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
table = await db.create_table("test", pa.table({"camelCase": pa.array([1, 2])}))
result = await table.query().select(["camelCase"]).to_arrow()
assert result == pa.table({"camelCase": pa.array([1, 2])})