From 4f601a2d4c7e554ae7401c7387dbc418ff8de7a6 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 22 Jul 2024 12:53:17 -0700 Subject: [PATCH] fix: handle camelCase column names in select (#1460) Fixes #1385 --- nodejs/__test__/table.test.ts | 22 ++++++++++++++++++++++ nodejs/lancedb/query.ts | 27 +++++++++++++++++---------- nodejs/src/query.rs | 10 ++++++++++ python/python/lancedb/query.py | 14 +++++++------- python/python/tests/test_query.py | 9 +++++++++ python/src/query.rs | 8 ++++++++ 6 files changed, 73 insertions(+), 17 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index c503db8a..5aba068b 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -834,3 +834,25 @@ describe("when calling explainPlan", () => { expect(plan).toMatch("KNN"); }); }); + +describe("column name options", () => { + let tmpDir: tmp.DirResult; + let table: Table; + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const con = await connect(tmpDir.name); + table = await con.createTable("vectors", [ + { camelCase: 1, vector: [0.1, 0.2] }, + ]); + }); + + test("can select columns with different names", async () => { + const results = await table.query().select(["camelCase"]).toArray(); + expect(results[0].camelCase).toBe(1); + }); + + test("can filter on columns with different names", async () => { + const results = await table.query().where("`camelCase` = 1").toArray(); + expect(results[0].camelCase).toBe(1); + }); +}); diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 0f8670b7..0f52acc9 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -167,20 +167,27 @@ export class QueryBase select( columns: string[] | Map | Record | string, ): this { - let columnTuples: [string, string][]; + const selectColumns = (columnArray: string[]) => { + this.doCall((inner: NativeQueryType) => { + inner.selectColumns(columnArray); + }); + }; + const selectMapping = (columnTuples: [string, string][]) => { + this.doCall((inner: NativeQueryType) => { + inner.select(columnTuples); + }); + }; + if (typeof columns === "string") { - columns = [columns]; - } - if (Array.isArray(columns)) { - columnTuples = columns.map((c) => [c, c]); + selectColumns([columns]); + } else if (Array.isArray(columns)) { + selectColumns(columns); } else if (columns instanceof Map) { - columnTuples = Array.from(columns.entries()); + selectMapping(Array.from(columns.entries())); } else { - columnTuples = Object.entries(columns); + selectMapping(Object.entries(columns)); } - this.doCall((inner: NativeQueryType) => { - inner.select(columnTuples); - }); + return this; } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 68b6511d..692dc56b 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -47,6 +47,11 @@ impl Query { self.inner = self.inner.clone().select(Select::dynamic(&columns)); } + #[napi] + pub fn select_columns(&mut self, columns: Vec) { + self.inner = self.inner.clone().select(Select::columns(&columns)); + } + #[napi] pub fn limit(&mut self, limit: u32) { self.inner = self.inner.clone().limit(limit as usize); @@ -138,6 +143,11 @@ impl VectorQuery { self.inner = self.inner.clone().select(Select::dynamic(&columns)); } + #[napi] + pub fn select_columns(&mut self, columns: Vec) { + self.inner = self.inner.clone().select(Select::columns(&columns)); + } + #[napi] pub fn limit(&mut self, limit: u32) { self.inner = self.inner.clone().limit(limit as usize); diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 7b147abd..4eaaec03 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1127,14 +1127,14 @@ class AsyncQueryBase(object): Columns will always be returned in the order given, even if that order is different than the order used when adding the data. """ - if isinstance(columns, dict): - column_tuples = list(columns.items()) + if isinstance(columns, list) and all(isinstance(c, str) for c in columns): + self._inner.select_columns(columns) + elif isinstance(columns, dict) and all( + isinstance(k, str) and isinstance(v, str) for k, v in columns.items() + ): + self._inner.select(list(columns.items())) else: - try: - column_tuples = [(c, c) for c in columns] - except TypeError: - raise TypeError("columns must be a list of column names or a dict") - self._inner.select(column_tuples) + raise TypeError("columns must be a list of column names or a dict") return self def limit(self, limit: int) -> AsyncQuery: diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 89c5530e..c569cd49 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -345,3 +345,12 @@ def test_explain_plan(table): async def test_explain_plan_async(table_async: AsyncTable): plan = await table_async.query().nearest_to(pa.array([1, 2])).explain_plan(True) assert "KNN" in plan + + +@pytest.mark.asyncio +async def test_query_camelcase_async(tmp_path): + db = await lancedb.connect_async(tmp_path) + table = await db.create_table("test", pa.table({"camelCase": pa.array([1, 2])})) + + result = await table.query().select(["camelCase"]).to_arrow() + assert result == pa.table({"camelCase": pa.array([1, 2])}) diff --git a/python/src/query.rs b/python/src/query.rs index 1cdedc66..791b9ea1 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -52,6 +52,10 @@ impl Query { self.inner = self.inner.clone().select(Select::dynamic(&columns)); } + pub fn select_columns(&mut self, columns: Vec) { + self.inner = self.inner.clone().select(Select::columns(&columns)); + } + pub fn limit(&mut self, limit: u32) { self.inner = self.inner.clone().limit(limit as usize); } @@ -101,6 +105,10 @@ impl VectorQuery { self.inner = self.inner.clone().select(Select::dynamic(&columns)); } + pub fn select_columns(&mut self, columns: Vec) { + self.inner = self.inner.clone().select(Select::columns(&columns)); + } + pub fn limit(&mut self, limit: u32) { self.inner = self.inner.clone().limit(limit as usize); }