From ab2c5adf5e3bea0f89744dd600edbd3f2e0444a6 Mon Sep 17 00:00:00 2001 From: Xin Sun Date: Sun, 17 May 2026 13:49:08 +0800 Subject: [PATCH] feat(nodejs): add order_by method to Query (#3123) --- docs/src/js/classes/Query.md | 24 +++ docs/src/js/classes/VectorQuery.md | 24 +++ docs/src/js/globals.md | 1 + docs/src/js/interfaces/ColumnOrdering.md | 31 ++++ nodejs/__test__/query.test.ts | 206 +++++++++++++++++++++++ nodejs/examples/filtering.test.ts | 9 + nodejs/lancedb/index.ts | 1 + nodejs/lancedb/query.ts | 21 +++ nodejs/src/query.rs | 55 +++++- python/python/lancedb/_lancedb.pyi | 10 ++ python/python/lancedb/query.py | 70 +++++++- python/python/lancedb/table.py | 2 + python/python/tests/test_fts.py | 31 ++++ python/python/tests/test_query.py | 66 ++++++++ python/python/tests/test_remote_db.py | 21 +++ python/src/query.rs | 79 ++++++++- rust/lancedb/src/query.rs | 18 ++ rust/lancedb/src/remote/table.rs | 33 +++- rust/lancedb/src/table/query.rs | 4 + 19 files changed, 696 insertions(+), 10 deletions(-) create mode 100644 docs/src/js/interfaces/ColumnOrdering.md diff --git a/docs/src/js/classes/Query.md b/docs/src/js/classes/Query.md index b69069069..bdf4764b7 100644 --- a/docs/src/js/classes/Query.md +++ b/docs/src/js/classes/Query.md @@ -343,6 +343,30 @@ This is useful for pagination. *** +### orderBy() + +```ts +orderBy(ordering): this +``` + +Sort the results by the specified column(s). + +#### Parameters + +* **ordering**: [`ColumnOrdering`](../interfaces/ColumnOrdering.md) \| [`ColumnOrdering`](../interfaces/ColumnOrdering.md)[] + +#### Returns + +`this` + +This query builder. + +#### Inherited from + +`StandardQueryBase.orderBy` + +*** + ### outputSchema() ```ts diff --git a/docs/src/js/classes/VectorQuery.md b/docs/src/js/classes/VectorQuery.md index 646c65cb6..fb010f65f 100644 --- a/docs/src/js/classes/VectorQuery.md +++ b/docs/src/js/classes/VectorQuery.md @@ -498,6 +498,30 @@ This is useful for pagination. *** +### orderBy() + +```ts +orderBy(ordering): this +``` + +Sort the results by the specified column(s). + +#### Parameters + +* **ordering**: [`ColumnOrdering`](../interfaces/ColumnOrdering.md) \| [`ColumnOrdering`](../interfaces/ColumnOrdering.md)[] + +#### Returns + +`this` + +This query builder. + +#### Inherited from + +`StandardQueryBase.orderBy` + +*** + ### outputSchema() ```ts diff --git a/docs/src/js/globals.md b/docs/src/js/globals.md index 5786afb88..fd2def7ab 100644 --- a/docs/src/js/globals.md +++ b/docs/src/js/globals.md @@ -51,6 +51,7 @@ - [AlterColumnsResult](interfaces/AlterColumnsResult.md) - [ClientConfig](interfaces/ClientConfig.md) - [ColumnAlteration](interfaces/ColumnAlteration.md) +- [ColumnOrdering](interfaces/ColumnOrdering.md) - [CompactionStats](interfaces/CompactionStats.md) - [ConnectNamespaceOptions](interfaces/ConnectNamespaceOptions.md) - [ConnectionOptions](interfaces/ConnectionOptions.md) diff --git a/docs/src/js/interfaces/ColumnOrdering.md b/docs/src/js/interfaces/ColumnOrdering.md new file mode 100644 index 000000000..550a24e49 --- /dev/null +++ b/docs/src/js/interfaces/ColumnOrdering.md @@ -0,0 +1,31 @@ +[**@lancedb/lancedb**](../README.md) • **Docs** + +*** + +[@lancedb/lancedb](../globals.md) / ColumnOrdering + +# Interface: ColumnOrdering + +## Properties + +### ascending? + +```ts +optional ascending: boolean; +``` + +*** + +### columnName + +```ts +columnName: string; +``` + +*** + +### nullsFirst? + +```ts +optional nullsFirst: boolean; +``` diff --git a/nodejs/__test__/query.test.ts b/nodejs/__test__/query.test.ts index a3d85dc31..a2975c133 100644 --- a/nodejs/__test__/query.test.ts +++ b/nodejs/__test__/query.test.ts @@ -109,3 +109,209 @@ describe("Query outputSchema", () => { expect(schema.fields.length).toBe(3); }); }); + +describe("Query orderBy", () => { + let tmpDir: tmp.DirResult; + let table: Table; + + beforeEach(async () => { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const db = await connect(tmpDir.name); + + // Create table with numeric data for sorting + const schema = new Schema([ + new Field("id", new Int64(), true), + new Field("score", new Float32(), true), + new Field("name", new Utf8(), true), + ]); + + const data = makeArrowTable( + [ + { id: 1n, score: 3.5, name: "charlie" }, + { id: 2n, score: 1.2, name: "alice" }, + { id: 3n, score: 2.8, name: "bob" }, + { id: 4n, score: 0.5, name: "david" }, + { id: 5n, score: 4.1, name: "eve" }, + ], + { schema }, + ); + table = await db.createTable("test", data); + }); + + afterEach(() => { + tmpDir.removeCallback(); + }); + + it("should sort by single column ascending", async () => { + const results = await table + .query() + .orderBy({ columnName: "score", ascending: true, nullsFirst: false }) + .toArray(); + + expect(results.length).toBe(5); + // Verify ascending order + expect(results[0].score).toBeCloseTo(0.5, 0.001); + expect(results[1].score).toBeCloseTo(1.2, 0.001); + expect(results[2].score).toBeCloseTo(2.8, 0.001); + expect(results[3].score).toBeCloseTo(3.5, 0.001); + expect(results[4].score).toBeCloseTo(4.1, 0.001); + }); + + it("should sort by single column descending", async () => { + const results = await table + .query() + .orderBy({ columnName: "score", ascending: false, nullsFirst: false }) + .toArray(); + + expect(results.length).toBe(5); + // Verify descending order + expect(results[0].score).toBeCloseTo(4.1, 0.001); + expect(results[1].score).toBeCloseTo(3.5, 0.001); + expect(results[2].score).toBeCloseTo(2.8, 0.001); + expect(results[3].score).toBeCloseTo(1.2, 0.001); + expect(results[4].score).toBeCloseTo(0.5, 0.001); + }); + + it("should use ascending as default direction", async () => { + const results = await table + .query() + .orderBy({ columnName: "score" }) + .toArray(); + + expect(results.length).toBe(5); + // Verify ascending order (default) + expect(results[0].score).toBeCloseTo(0.5, 0.001); + expect(results[1].score).toBeCloseTo(1.2, 0.001); + expect(results[2].score).toBeCloseTo(2.8, 0.001); + expect(results[3].score).toBeCloseTo(3.5, 0.001); + expect(results[4].score).toBeCloseTo(4.1, 0.001); + }); + + it("should sort by string column", async () => { + const results = await table + .query() + .orderBy({ columnName: "name" }) + .toArray(); + + expect(results.length).toBe(5); + // Verify alphabetical order + expect(results[0].name).toBe("alice"); + expect(results[1].name).toBe("bob"); + expect(results[2].name).toBe("charlie"); + expect(results[3].name).toBe("david"); + expect(results[4].name).toBe("eve"); + }); + + it("should support method chaining with where", async () => { + const results = await table + .query() + .where("score > 2.0") + .orderBy({ columnName: "score" }) + .toArray(); + expect(results.length).toBe(3); + // Verify filtered and sorted + expect(results[0].score).toBeCloseTo(2.8, 0.001); + expect(results[1].score).toBeCloseTo(3.5, 0.001); + expect(results[2].score).toBeCloseTo(4.1, 0.001); + }); + + it("should support method chaining with limit", async () => { + const results = await table + .query() + .orderBy({ columnName: "score", ascending: false }) + .limit(3) + .toArray(); + + expect(results.length).toBe(3); + // Verify top 3 in descending order + expect(results[0].score).toBeCloseTo(4.1, 0.001); + expect(results[1].score).toBeCloseTo(3.5, 0.001); + expect(results[2].score).toBeCloseTo(2.8, 0.001); + }); + + it("should support method chaining with offset", async () => { + const results = await table + .query() + .orderBy({ columnName: "score" }) + .offset(2) + .limit(2) + .toArray(); + + expect(results.length).toBe(2); + // Verify results skip first 2 and take next 2 + expect(results[0].score).toBeCloseTo(2.8, 0.001); + expect(results[1].score).toBeCloseTo(3.5, 0.001); + }); + + it("should support method chaining with select", async () => { + const results = await table + .query() + .orderBy({ columnName: "name" }) + .select(["name", "score"]) + .toArray(); + + expect(results.length).toBe(5); + // Verify only selected columns are present + expect(Object.keys(results[0])).toEqual(["name", "score"]); + expect(Object.keys(results[4])).toEqual(["name", "score"]); + // Verify sorted by name + expect(results[0].name).toBe("alice"); + expect(results[4].name).toBe("eve"); + }); + + it("should support complex method chaining", async () => { + const results = await table + .query() + .where("score > 1.0") + .orderBy({ columnName: "score", ascending: false }) + .limit(3) + .select(["id", "score", "name"]) + .toArray(); + + expect(results.length).toBe(3); + // Verify filtered, sorted, limited, and projected + expect(results[0].score).toBeCloseTo(4.1, 0.001); + expect(results[1].score).toBeCloseTo(3.5, 0.001); + expect(results[2].score).toBeCloseTo(2.8, 0.001); + expect(Object.keys(results[0])).toEqual(["id", "score", "name"]); + }); + + it("should support multi-column ordering and null placement", async () => { + const schema = new Schema([ + new Field("group", new Int64(), true), + new Field("score", new Float32(), true), + new Field("name", new Utf8(), true), + ]); + + const data = makeArrowTable( + [ + { group: 1n, score: null, name: "z" }, + { group: 1n, score: 1.0, name: "b" }, + { group: 1n, score: 1.0, name: "a" }, + { group: 2n, score: 0.5, name: "c" }, + ], + { schema }, + ); + const nullTable = await (await connect(tmpDir.name)).createTable( + "test_multi_order", + data, + { mode: "overwrite" }, + ); + + const results = await nullTable + .query() + .orderBy([ + { columnName: "group", ascending: true, nullsFirst: false }, + { columnName: "score", ascending: true, nullsFirst: true }, + { columnName: "name", ascending: true, nullsFirst: false }, + ]) + .toArray(); + + expect(results.map((r) => [r.group, r.score, r.name])).toEqual([ + [1n, null, "z"], + [1n, 1.0, "a"], + [1n, 1.0, "b"], + [2n, 0.5, "c"], + ]); + }); +}); diff --git a/nodejs/examples/filtering.test.ts b/nodejs/examples/filtering.test.ts index 1c35d5120..ab77a5851 100644 --- a/nodejs/examples/filtering.test.ts +++ b/nodejs/examples/filtering.test.ts @@ -38,5 +38,14 @@ test("filtering examples", async () => { // --8<-- [start:sql_search] await tbl.query().where("id = 10").limit(10).toArray(); // --8<-- [end:sql_search] + + // --8<-- [start:orderby_search] + await tbl + .query() + .where("id > 10") + .orderBy({ columnName: "id", ascending: false }) + .limit(5) + .toArray(); + // --8<-- [end:orderby_search] }); }); diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index e1c08b7b5..f1a36722f 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -82,6 +82,7 @@ export { VectorQuery, TakeQuery, QueryExecutionOptions, + ColumnOrdering, FullTextSearchOptions, RecordBatchIterator, FullTextQuery, diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index c077234ec..f985eaf83 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -79,6 +79,12 @@ export interface QueryExecutionOptions { timeoutMs?: number; } +export interface ColumnOrdering { + columnName: string; + ascending?: boolean; + nullsFirst?: boolean; +} + /** * Options that control the behavior of a full text search */ @@ -417,6 +423,21 @@ export class StandardQueryBase< return this; } + /** + * Sort the results by the specified column(s). + * @returns This query builder. + */ + orderBy(ordering: ColumnOrdering | ColumnOrdering[]): this { + const orderings = Array.isArray(ordering) ? ordering : [ordering]; + const normalized = orderings.map((o) => ({ + columnName: o.columnName, + ascending: o.ascending ?? true, + nullsFirst: o.nullsFirst ?? false, + })); + this.doCall((inner) => inner.orderBy(normalized)); + return this; + } + /** * Skip searching un-indexed data. This can make search faster, but will miss * any data that is not yet indexed. diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 4516385d5..6032619e3 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -3,6 +3,12 @@ use std::sync::Arc; +use crate::error::NapiErrorExt; +use crate::error::convert_error; +use crate::iterator::RecordBatchIterator; +use crate::rerankers::RerankHybridCallbackArgs; +use crate::rerankers::Reranker; +use crate::util::{parse_distance_type, schema_to_buffer}; use arrow_array::{ Array, Float16Array as ArrowFloat16Array, Float32Array as ArrowFloat32Array, Float64Array as ArrowFloat64Array, UInt8Array as ArrowUInt8Array, @@ -19,16 +25,27 @@ use lancedb::query::QueryBase; use lancedb::query::QueryExecutionOptions; use lancedb::query::Select; use lancedb::query::TakeQuery as LanceDbTakeQuery; -use lancedb::query::VectorQuery as LanceDbVectorQuery; +use lancedb::query::{ColumnOrdering as LanceDbColumnOrdering, VectorQuery as LanceDbVectorQuery}; use napi::bindgen_prelude::*; use napi_derive::napi; -use crate::error::NapiErrorExt; -use crate::error::convert_error; -use crate::iterator::RecordBatchIterator; -use crate::rerankers::RerankHybridCallbackArgs; -use crate::rerankers::Reranker; -use crate::util::{parse_distance_type, schema_to_buffer}; +#[napi(object)] +pub struct ColumnOrdering { + pub ascending: bool, + pub nulls_first: bool, + pub column_name: String, +} + +impl From for LanceDbColumnOrdering { + fn from(value: ColumnOrdering) -> Self { + match (value.ascending, value.nulls_first) { + (true, true) => Self::asc_nulls_first(value.column_name), + (true, false) => Self::asc_nulls_last(value.column_name), + (false, true) => Self::desc_nulls_first(value.column_name), + (false, false) => Self::desc_nulls_last(value.column_name), + } + } +} fn bytes_to_arrow_array(data: Uint8Array, dtype: String) -> napi::Result> { let buf = arrow_buffer::Buffer::from(data.to_vec()); @@ -128,6 +145,18 @@ impl Query { self.inner = self.inner.clone().with_row_id(); } + #[napi] + pub fn order_by(&mut self, ordering: Option>) -> napi::Result<()> { + let ordering = ordering.map(|ordering| { + ordering + .into_iter() + .map(LanceDbColumnOrdering::from) + .collect() + }); + self.inner = self.inner.clone().order_by(ordering); + Ok(()) + } + #[napi(catch_unwind)] pub async fn output_schema(&self) -> napi::Result { let schema = self.inner.output_schema().await.default_error()?; @@ -328,6 +357,18 @@ impl VectorQuery { Ok(()) } + #[napi] + pub fn order_by(&mut self, ordering: Option>) -> napi::Result<()> { + let ordering = ordering.map(|ordering| { + ordering + .into_iter() + .map(LanceDbColumnOrdering::from) + .collect() + }); + self.inner = self.inner.clone().order_by(ordering); + Ok(()) + } + #[napi(catch_unwind)] pub async fn output_schema(&self) -> napi::Result { let schema = self.inner.output_schema().await.default_error()?; diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index d6a8d71d6..8839af156 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -255,6 +255,11 @@ class RecordBatchStream: def __aiter__(self) -> "RecordBatchStream": ... async def __anext__(self) -> pa.RecordBatch: ... +class ColumnOrdering(TypedDict): + column_name: str + ascending: bool + nulls_first: bool + class Query: def where(self, filter: str): ... def where_expr(self, expr: PyExpr): ... @@ -268,6 +273,7 @@ class Query: def postfilter(self): ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to_text(self, query: dict) -> FTSQuery: ... + def order_by(self, ordering: Optional[List[ColumnOrdering]]): ... async def output_schema(self) -> pa.Schema: ... async def execute( self, max_batch_length: Optional[int], timeout: Optional[timedelta] @@ -296,6 +302,7 @@ class FTSQuery: def get_query(self) -> str: ... def add_query_vector(self, query_vec: pa.Array) -> None: ... def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ... + def order_by(self, ordering: Optional[List[ColumnOrdering]]): ... async def output_schema(self) -> pa.Schema: ... async def execute( self, max_batch_length: Optional[int], timeout: Optional[timedelta] @@ -321,6 +328,7 @@ class VectorQuery: def maximum_nprobes(self, maximum_nprobes: int): ... def bypass_vector_index(self): ... def nearest_to_text(self, query: dict) -> HybridQuery: ... + def order_by(self, ordering: Optional[List[ColumnOrdering]]): ... def to_query_request(self) -> PyQueryRequest: ... class HybridQuery: @@ -339,6 +347,7 @@ class HybridQuery: def minimum_nprobes(self, minimum_nprobes: int): ... def maximum_nprobes(self, maximum_nprobes: int): ... def bypass_vector_index(self): ... + def order_by(self, ordering: Optional[List[ColumnOrdering]]): ... def to_vector_query(self) -> VectorQuery: ... def to_fts_query(self) -> FTSQuery: ... def get_limit(self) -> int: ... @@ -368,6 +377,7 @@ class PyQueryRequest: bypass_vector_index: Optional[bool] postfilter: Optional[bool] norm: Optional[str] + order_by: Optional[List[ColumnOrdering]] class CompactionStats: fragments_removed: int diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 0a9473a0a..f472f8bb7 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -92,6 +92,12 @@ def ensure_vector_query( return val +class ColumnOrdering(pydantic.BaseModel): + column_name: str + ascending: bool = True + nulls_first: bool = False + + class FullTextQueryType(str, Enum): MATCH = "match" MATCH_PHRASE = "match_phrase" @@ -504,6 +510,8 @@ class Query(pydantic.BaseModel): # Bypass the vector index and use a brute force search bypass_vector_index: Optional[bool] = None + order_by: Optional[List[ColumnOrdering]] = None + @classmethod def from_inner(cls, req: PyQueryRequest) -> Self: query = cls() @@ -524,6 +532,8 @@ class Query(pydantic.BaseModel): query.refine_factor = req.refine_factor query.bypass_vector_index = req.bypass_vector_index query.postfilter = req.postfilter + if req.order_by is not None: + query.order_by = [ColumnOrdering(**o) for o in req.order_by] if req.full_text_search is not None: query.full_text_query = FullTextSearchQuery( columns=None, @@ -572,9 +582,22 @@ class LanceQueryBuilder(ABC): If "auto", the query type is inferred based on the query. vector_column_name: str The name of the vector column to use for vector search. + ordering_field_name: Optional[str] + .. deprecated:: 0.27.0 + Use ``order_by()`` method instead. + fts_columns: Optional[Union[str, List[str]]] + The columns to search in for full text search. fast_search: bool Skip flat search of unindexed data. """ + if ordering_field_name is not None: + import warnings + + warnings.warn( + "ordering_field_name is deprecated, use .order_by() method instead.", + DeprecationWarning, + stacklevel=2, + ) # Check hybrid search first as it supports empty query pattern if query_type == "hybrid": # hybrid fts and vector query @@ -671,6 +694,7 @@ class LanceQueryBuilder(ABC): self._text = None self._ef = None self._bypass_vector_index = None + self._order_by = None @deprecation.deprecated( deprecated_in="0.3.1", @@ -947,6 +971,24 @@ class LanceQueryBuilder(ABC): """ # noqa: E501 return self._table._explain_plan(self.to_query_object(), verbose=verbose) + def order_by(self, ordering: Optional[List[ColumnOrdering]]) -> Self: + """ + Set the ordering for the results. + + Parameters + ---------- + ordering: Optional[List[ColumnOrdering]] + The ordering to use for the results. If None, then the default ordering + will be used. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + self._order_by = ordering + return self + def analyze_plan(self) -> str: """ Run the query and return its execution plan with runtime metrics. @@ -1314,6 +1356,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): fast_search=self._fast_search, ef=self._ef, bypass_vector_index=self._bypass_vector_index, + order_by=self._order_by, ) def to_batches( @@ -1465,7 +1508,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): super().__init__(table) self._query = query self._phrase_query = False - self.ordering_field_name = ordering_field_name + # Deprecated compatibility parameter. Native FTS ordering is now + # configured through order_by(); LanceQueryBuilder.create emits the warning. + _ = ordering_field_name self._reranker = None self._fast_search = fast_search if isinstance(fts_columns, str): @@ -1514,6 +1559,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): ), offset=self._offset, fast_search=self._fast_search, + order_by=self._order_by, ) def output_schema(self) -> pa.Schema: @@ -1579,6 +1625,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): limit=self._limit, with_row_id=self._with_row_id, offset=self._offset, + order_by=self._order_by, ) def output_schema(self) -> pa.Schema: @@ -2502,6 +2549,27 @@ class AsyncStandardQuery(AsyncQueryBase): self._inner.offset(offset) return self + def order_by(self, ordering: Optional[List[ColumnOrdering]]) -> Self: + """ + Set the ordering for the results. + + Parameters + ---------- + ordering: Optional[List[ColumnOrdering]] + The ordering to use for the results. If None, then the default ordering + will be used. + """ + if ordering is None: + self._inner.order_by(None) + else: + self._inner.order_by( + [ + o.model_dump() if hasattr(o, "model_dump") else o.dict() + for o in ordering + ] + ) + return self + def fast_search(self) -> Self: """ Skip searching un-indexed data. diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 87b14e434..c00c14f9c 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -4512,6 +4512,8 @@ class AsyncTable: async_query = async_query.fast_search() if query.with_row_id: async_query = async_query.with_row_id() + if query.order_by: + async_query = async_query.order_by(query.order_by) if query.vector: async_query = async_query.nearest_to(query.vector).distance_range( diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index 614d81185..acd362a09 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -29,6 +29,7 @@ from lancedb.query import ( MultiMatchQuery, PhraseQuery, BooleanQuery, + ColumnOrdering, Occur, LanceFtsQueryBuilder, ) @@ -499,6 +500,36 @@ async def test_search_fts_specify_column_async(async_table): pass +def test_search_order_by_descending(table): + table.create_fts_index("text") + rows = ( + table.search("puppy") + .order_by([ColumnOrdering(column_name="count", ascending=False)]) + .limit(20) + .select(["text", "count"]) + .to_list() + ) + + for r in rows: + assert "puppy" in r["text"] + assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows + + +def test_search_order_by_ascending(table): + table.create_fts_index("text") + rows = ( + table.search("puppy") + .order_by([ColumnOrdering(column_name="count", ascending=True)]) + .limit(20) + .select(["text", "count"]) + .to_list() + ) + + for r in rows: + assert "puppy" in r["text"] + assert sorted(rows, key=lambda x: x["count"]) == rows + + def test_create_index_from_table(tmp_path, table): table.create_fts_index("text") df = table.search("puppy").limit(5).select(["text"]).to_pandas() diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 9ac585df5..e5ba9e5ae 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -25,6 +25,7 @@ from lancedb.query import ( AsyncHybridQuery, AsyncQueryBase, AsyncVectorQuery, + ColumnOrdering, LanceVectorQueryBuilder, MatchQuery, PhraseQuery, @@ -164,6 +165,71 @@ def test_offset(table): assert len(results_with_offset.to_pandas()) == 1 +def test_order_by_plain_query(mem_db): + table = mem_db.create_table( + "test_order_by", + pa.table( + { + "group": [1, 1, 1, 2], + "score": [None, 1.0, 1.0, 0.5], + "name": ["z", "b", "a", "c"], + } + ), + ) + + res = ( + table.search() + .order_by( + [ + ColumnOrdering(column_name="group", ascending=True, nulls_first=False), + ColumnOrdering(column_name="score", ascending=True, nulls_first=True), + ColumnOrdering(column_name="name", ascending=True, nulls_first=False), + ] + ) + .to_arrow() + ) + + assert res.select(["group", "score", "name"]).to_pylist() == [ + {"group": 1, "score": None, "name": "z"}, + {"group": 1, "score": 1.0, "name": "a"}, + {"group": 1, "score": 1.0, "name": "b"}, + {"group": 2, "score": 0.5, "name": "c"}, + ] + + +@pytest.mark.asyncio +async def test_order_by_async_query(mem_db_async: AsyncConnection): + table = await mem_db_async.create_table( + "test_order_by_async", + pa.table( + { + "group": [1, 1, 1, 2], + "score": [None, 1.0, 1.0, 0.5], + "name": ["z", "b", "a", "c"], + } + ), + ) + + res = await ( + table.query() + .order_by( + [ + ColumnOrdering(column_name="group", ascending=True, nulls_first=False), + ColumnOrdering(column_name="score", ascending=True, nulls_first=True), + ColumnOrdering(column_name="name", ascending=True, nulls_first=False), + ] + ) + .to_arrow() + ) + + assert res.select(["group", "score", "name"]).to_pylist() == [ + {"group": 1, "score": None, "name": "z"}, + {"group": 1, "score": 1.0, "name": "a"}, + {"group": 1, "score": 1.0, "name": "b"}, + {"group": 2, "score": 0.5, "name": "c"}, + ] + + def test_query_builder(table): rs = ( LanceVectorQueryBuilder(table, [0, 0], "vector") diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index a499275c5..bc69a0410 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -16,6 +16,7 @@ from packaging.version import Version import lancedb from lancedb.conftest import MockTextEmbeddingFunction +from lancedb.query import ColumnOrdering from lancedb.remote import ClientConfig from lancedb.remote.errors import HttpError, RetryError import pytest @@ -660,6 +661,18 @@ def test_query_sync_maximal(): "ef": None, "filter": "id > 0", "columns": ["id", "name"], + "order_by": [ + { + "column_name": "score", + "ascending": False, + "nulls_first": True, + }, + { + "column_name": "id", + "ascending": True, + "nulls_first": False, + }, + ], "vector_column": "vector2", "fast_search": True, "with_row_id": True, @@ -677,6 +690,14 @@ def test_query_sync_maximal(): .refine_factor(10) .nprobes(5) .where("id > 0", prefilter=True) + .order_by( + [ + ColumnOrdering( + column_name="score", ascending=False, nulls_first=True + ), + ColumnOrdering(column_name="id", ascending=True, nulls_first=False), + ] + ) .with_row_id(True) .select(["id", "name"]) .to_list() diff --git a/python/src/query.rs b/python/src/query.rs index 1dc4f08db..2c682e38a 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -23,7 +23,7 @@ use lancedb::query::QueryBase; use lancedb::query::QueryExecutionOptions; use lancedb::query::QueryFilter; use lancedb::query::{ - ExecutableQuery, Query as LanceDbQuery, Select, TakeQuery as LanceDbTakeQuery, + ColumnOrdering, ExecutableQuery, Query as LanceDbQuery, Select, TakeQuery as LanceDbTakeQuery, VectorQuery as LanceDbVectorQuery, }; use lancedb::table::AnyQuery; @@ -207,6 +207,48 @@ impl<'py> IntoPyObject<'py> for PyLanceDB { #[derive(Clone)] pub struct PyQueryVectors(Vec>); +#[derive(Clone, FromPyObject)] +#[pyo3(from_item_all)] +pub struct PyColumnOrdering { + pub column_name: String, + pub ascending: bool, + pub nulls_first: bool, +} + +impl From for PyColumnOrdering { + fn from(ordering: ColumnOrdering) -> Self { + Self { + column_name: ordering.column_name, + ascending: ordering.ascending, + nulls_first: ordering.nulls_first, + } + } +} + +impl From for ColumnOrdering { + fn from(ordering: PyColumnOrdering) -> Self { + Self { + column_name: ordering.column_name, + ascending: ordering.ascending, + nulls_first: ordering.nulls_first, + } + } +} + +impl<'py> IntoPyObject<'py> for PyColumnOrdering { + type Target = PyDict; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult { + let dict = PyDict::new(py); + dict.set_item("column_name", self.column_name)?; + dict.set_item("ascending", self.ascending)?; + dict.set_item("nulls_first", self.nulls_first)?; + Ok(dict) + } +} + impl<'py> IntoPyObject<'py> for PyQueryVectors { type Target = PyList; type Output = Bound<'py, Self::Target>; @@ -246,6 +288,7 @@ pub struct PyQueryRequest { pub bypass_vector_index: Option, pub postfilter: Option, pub norm: Option, + pub order_by: Option>, } impl From for PyQueryRequest { @@ -273,6 +316,9 @@ impl From for PyQueryRequest { bypass_vector_index: None, postfilter: None, norm: None, + order_by: query_request + .order_by + .map(|order_by| order_by.into_iter().map(PyColumnOrdering::from).collect()), }, AnyQuery::VectorQuery(vector_query) => Self { limit: vector_query.base.limit, @@ -297,6 +343,10 @@ impl From for PyQueryRequest { bypass_vector_index: Some(!vector_query.use_index), postfilter: Some(!vector_query.base.prefilter), norm: vector_query.base.norm.map(|n| n.to_string()), + order_by: vector_query + .base + .order_by + .map(|order_by| order_by.into_iter().map(PyColumnOrdering::from).collect()), }, } } @@ -475,6 +525,13 @@ impl Query { }) } + pub fn order_by(&mut self, ordering: Option>) -> PyResult<()> { + let ordering = + ordering.map(|ordering| ordering.into_iter().map(ColumnOrdering::from).collect()); + self.inner = self.inner.clone().order_by(ordering); + Ok(()) + } + #[pyo3(signature = ())] pub fn output_schema(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner.clone(); @@ -647,6 +704,13 @@ impl FTSQuery { self.inner = self.inner.clone().offset(offset as usize); } + pub fn order_by(&mut self, ordering: Option>) -> PyResult<()> { + let ordering = + ordering.map(|ordering| ordering.into_iter().map(ColumnOrdering::from).collect()); + self.inner = self.inner.clone().order_by(ordering); + Ok(()) + } + pub fn fast_search(&mut self) { self.inner = self.inner.clone().fast_search(); } @@ -782,6 +846,13 @@ impl VectorQuery { self.inner = self.inner.clone().offset(offset as usize); } + pub fn order_by(&mut self, ordering: Option>) -> PyResult<()> { + let ordering = + ordering.map(|ordering| ordering.into_iter().map(ColumnOrdering::from).collect()); + self.inner = self.inner.clone().order_by(ordering); + Ok(()) + } + pub fn fast_search(&mut self) { self.inner = self.inner.clone().fast_search(); } @@ -954,6 +1025,12 @@ impl HybridQuery { self.inner_fts.offset(offset); } + pub fn order_by(&mut self, ordering: Option>) -> PyResult<()> { + self.inner_vec.order_by(ordering.clone())?; + self.inner_fts.order_by(ordering)?; + Ok(()) + } + pub fn fast_search(&mut self) { self.inner_vec.fast_search(); self.inner_fts.fast_search(); diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 8f60230b0..7f82f5517 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -11,6 +11,8 @@ use datafusion_expr::Expr; use datafusion_physical_plan::ExecutionPlan; use futures::{FutureExt, TryFutureExt, TryStreamExt, stream, try_join}; use half::f16; +/// Re-export Lance ColumnOrdering type for use in query ordering +pub use lance::dataset::scanner::ColumnOrdering; use lance::dataset::{ROW_ID, scanner::DatasetRecordBatchStream}; use lance_arrow::RecordBatchExt; use lance_datafusion::exec::execute_plan; @@ -510,6 +512,11 @@ pub trait QueryBase { /// the scores are converted to ranks and then normalized. If "Score", the /// scores are normalized directly. fn norm(self, norm: NormalizeMethod) -> Self; + + /// Sort the results by the specified column(s). + /// + /// This allows ordering query results by one or more columns in either ascending or descending order. + fn order_by(self, ordering: Option>) -> Self; } pub trait HasQuery { @@ -574,6 +581,11 @@ impl QueryBase for T { self.mut_query().norm = Some(norm); self } + + fn order_by(mut self, ordering: Option>) -> Self { + self.mut_query().order_by = ordering; + self + } } /// Options for controlling the execution of a query @@ -750,6 +762,11 @@ pub struct QueryRequest { /// /// By default, this is false (scoring columns are auto-projected for backward compatibility). pub disable_scoring_autoprojection: bool, + + /// Sort the results by the specified column(s). + /// + /// This allows ordering query results by one or more columns in either ascending or descending order. + pub order_by: Option>, } impl Default for QueryRequest { @@ -766,6 +783,7 @@ impl Default for QueryRequest { reranker: None, norm: None, disable_scoring_autoprojection: false, + order_by: None, } } } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index b991ed335..34c513bcd 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -518,6 +518,21 @@ impl RemoteTable { } } + if let Some(order_by) = ¶ms.order_by { + body["order_by"] = serde_json::Value::Array( + order_by + .iter() + .map(|o| { + serde_json::json!({ + "column_name": o.column_name, + "ascending": o.ascending, + "nulls_first": o.nulls_first, + }) + }) + .collect(), + ); + } + Ok(()) } @@ -2078,7 +2093,7 @@ mod tests { use crate::{ DistanceType, Error, Table, index::{Index, IndexStatistics, IndexType, vector::IvfPqIndexBuilder}, - query::{ExecutableQuery, QueryBase}, + query::{ColumnOrdering, ExecutableQuery, QueryBase}, remote::ARROW_FILE_CONTENT_TYPE, }; @@ -2988,6 +3003,18 @@ mod tests { "distance_type": "cosine", "bypass_vector_index": true, "columns": ["a", "b"], + "order_by": [ + { + "column_name": "score", + "ascending": false, + "nulls_first": true, + }, + { + "column_name": "id", + "ascending": true, + "nulls_first": false, + } + ], "nprobes": 12, "minimum_nprobes": 12, "maximum_nprobes": 12, @@ -3019,6 +3046,10 @@ mod tests { .limit(42) .offset(10) .select(Select::columns(&["a", "b"])) + .order_by(Some(vec![ + ColumnOrdering::desc_nulls_first("score".to_string()), + ColumnOrdering::asc_nulls_last("id".to_string()), + ])) .nearest_to(vec![0.1, 0.2, 0.3]) .unwrap() .column("my_vector") diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs index e7e66a901..cc9312a0f 100644 --- a/rust/lancedb/src/table/query.rs +++ b/rust/lancedb/src/table/query.rs @@ -242,6 +242,10 @@ pub async fn create_plan( scanner.disable_scoring_autoprojection(); } + if let Some(order_by) = &query.base.order_by { + scanner.order_by(Some(order_by.clone()))?; + } + Ok(scanner.create_plan().await?) }