diff --git a/.github/workflows/run_tests/action.yml b/.github/workflows/run_tests/action.yml index 9fd65c70..6140cb5c 100644 --- a/.github/workflows/run_tests/action.yml +++ b/.github/workflows/run_tests/action.yml @@ -11,7 +11,7 @@ runs: - name: Install lancedb shell: bash run: | - pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev,embeddings] + pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev] - name: pytest shell: bash run: pytest -m "not slow" -x -v --durations=30 python/python/tests diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 0064ba7f..840b8830 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -336,10 +336,8 @@ class LanceQueryBuilder(ABC): LanceQueryBuilder The LanceQueryBuilder object. """ - if isinstance(columns, list): + if isinstance(columns, list) or isinstance(columns, dict): self._columns = columns - elif isinstance(columns, dict): - self._columns = list(columns.items()) else: raise ValueError("columns must be a list or a dictionary") return self diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index ed88f9c7..d1a08666 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -94,6 +94,17 @@ def test_query_builder(table): assert all(np.array(rs[0]["vector"]) == [1, 2]) +def test_dynamic_projection(table): + rs = ( + LanceVectorQueryBuilder(table, [0, 0], "vector") + .limit(1) + .select({"id": "id", "id2": "id * 2"}) + .to_list() + ) + assert rs[0]["id"] == 1 + assert rs[0]["id2"] == 2 + + def test_query_builder_with_filter(table): rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list() assert rs[0]["id"] == 2