diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 8d495d64..1171415c 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -666,11 +666,11 @@ describe("When creating an index", () => { expect(fs.readdirSync(indexDir)).toHaveLength(1); for await (const r of tbl.query().where("id > 1").select(["id"])) { - expect(r.numRows).toBe(10); + expect(r.numRows).toBe(298); } // should also work with 'filter' alias for await (const r of tbl.query().filter("id > 1").select(["id"])) { - expect(r.numRows).toBe(10); + expect(r.numRows).toBe(298); } }); diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 087a8a6b..cf7db0ce 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -110,7 +110,7 @@ class Query(pydantic.BaseModel): full_text_query: Optional[Union[str, dict]] = None # top k results to return - k: int + k: Optional[int] = None # # metrics metric: str = "L2" @@ -257,7 +257,7 @@ class LanceQueryBuilder(ABC): def __init__(self, table: "Table"): self._table = table - self._limit = 10 + self._limit = None self._offset = 0 self._columns = None self._where = None @@ -370,8 +370,7 @@ class LanceQueryBuilder(ABC): The maximum number of results to return. The default query limit is 10 results. For ANN/KNN queries, you must specify a limit. - Entering 0, a negative number, or None will reset - the limit to the default value of 10. + For plain searches, all records are returned if limit not set. *WARNING* if you have a large dataset, setting the limit to a large number, e.g. 
the table size, can potentially result in reading a @@ -595,6 +594,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): fast_search: bool = False, ): super().__init__(table) + if self._limit is None: + self._limit = 10 self._query = query self._distance_type = "L2" self._nprobes = 20 @@ -888,6 +889,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): fts_columns: Union[str, List[str]] = [], ): super().__init__(table) + if self._limit is None: + self._limit = 10 self._query = query self._phrase_query = False self.ordering_field_name = ordering_field_name @@ -1055,7 +1058,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): query = Query( columns=self._columns, filter=self._where, - k=self._limit or 10, + k=self._limit, with_row_id=self._with_row_id, vector=[], # not actually respected in remote query diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index fce9c262..ce084b1d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3195,7 +3195,9 @@ class AsyncTable: # The sync remote table calls into this method, so we need to map the # query to the async version of the query and run that here. This is only # used for that code path right now. 
- async_query = self.query().limit(query.k) + async_query = self.query() + if query.k is not None: + async_query = async_query.limit(query.k) if query.offset > 0: async_query = async_query.offset(query.offset) if query.columns: diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index c08976f9..2e5a9fa9 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -174,6 +174,10 @@ def test_search_fts(table, use_tantivy): assert len(results) == 5 assert len(results[0]) == 3 # id, text, _score + # Default limit of 10 + results = table.search("puppy").select(["id", "text"]).to_list() + assert len(results) == 10 + @pytest.mark.asyncio async def test_fts_select_async(async_table): diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index b3117e83..82e11cf5 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1025,13 +1025,13 @@ def test_empty_query(mem_db: DBConnection): table = mem_db.create_table("my_table2", data=[{"id": i} for i in range(100)]) df = table.search().select(["id"]).to_pandas() - assert len(df) == 10 + assert len(df) == 100 # None is the same as default df = table.search().select(["id"]).limit(None).to_pandas() - assert len(df) == 10 + assert len(df) == 100 # invalid limist is the same as None, wihch is the same as default df = table.search().select(["id"]).limit(-1).to_pandas() - assert len(df) == 10 + assert len(df) == 100 # valid limit should work df = table.search().select(["id"]).limit(42).to_pandas() assert len(df) == 42 diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index afc47222..d3ccfcf4 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -470,6 +470,9 @@ impl QueryBase for T { } fn full_text_search(mut self, query: FullTextSearchQuery) -> Self { + if self.mut_query().limit.is_none() { + self.mut_query().limit = Some(DEFAULT_TOP_K); + } self.mut_query().full_text_search = 
Some(query); self } @@ -634,7 +637,7 @@ pub struct QueryRequest { impl Default for QueryRequest { fn default() -> Self { Self { - limit: Some(DEFAULT_TOP_K), + limit: None, offset: None, filter: None, full_text_search: None, @@ -719,6 +722,11 @@ impl Query { let mut vector_query = self.into_vector(); let query_vector = vector.to_query_vector(&DataType::Float32, "default")?; vector_query.request.query_vector.push(query_vector); + + if vector_query.request.base.limit.is_none() { + vector_query.request.base.limit = Some(DEFAULT_TOP_K); + } + Ok(vector_query) } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index e7c182dd..88ee3f45 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -154,9 +154,9 @@ impl RemoteTable { body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset)); } - if let Some(limit) = params.limit { - body["k"] = serde_json::Value::Number(serde_json::Number::from(limit)); - } + // Server requires k. 
+ let limit = params.limit.unwrap_or(usize::MAX); + body["k"] = serde_json::Value::Number(serde_json::Number::from(limit)); if let Some(filter) = &params.filter { if let QueryFilter::Sql(filter) = filter { @@ -1293,6 +1293,52 @@ mod tests { table.delete("id in (1, 2, 3)").await.unwrap(); } + #[tokio::test] + async fn test_query_plain() { + let expected_data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let expected_data_ref = expected_data.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let expected_body = serde_json::json!({ + "k": usize::MAX, + "prefilter": true, + "vector": [], // Empty vector means no vector query. + "version": null, + }); + assert_eq!(body, expected_body); + + let response_body = write_ipc_file(&expected_data_ref); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let data = table + .query() + .execute() + .await + .unwrap() + .collect::<Vec<_>>() + .await; + assert_eq!(data.len(), 1); + assert_eq!(data[0].as_ref().unwrap(), &expected_data); + } + #[tokio::test] async fn test_query_vector_default_values() { let expected_data = RecordBatch::try_new(