From cbb5a841b105c9e9b994531f874165bdcb443938 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 19 Jun 2025 10:32:32 +0800 Subject: [PATCH] feat: support prefix matching and must_not clause (#2441) --- nodejs/__test__/table.test.ts | 38 ++++++++++++--- nodejs/lancedb/query.ts | 7 ++- nodejs/src/query.rs | 2 + python/python/lancedb/query.py | 7 ++- python/python/tests/docs/test_search.py | 65 ++++++++++++++++++++++++- python/src/query.rs | 38 ++++++--------- 6 files changed, 124 insertions(+), 33 deletions(-) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 23fe67dd..fd091a67 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -1650,13 +1650,25 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( expect(resultSet.has("fob")).toBe(true); expect(resultSet.has("fo")).toBe(true); expect(resultSet.has("food")).toBe(true); + + const prefixResults = await table + .search( + new MatchQuery("foo", "text", { fuzziness: 3, prefixLength: 3 }), + ) + .toArray(); + expect(prefixResults.length).toBe(2); + const resultSet2 = new Set(prefixResults.map((r) => r.text)); + expect(resultSet2.has("foo")).toBe(true); + expect(resultSet2.has("food")).toBe(true); }); test("full text search boolean query", async () => { const db = await connect(tmpDir.name); const data = [ - { text: "hello world", vector: [0.1, 0.2, 0.3] }, - { text: "goodbye world", vector: [0.4, 0.5, 0.6] }, + { text: "The cat and dog are playing" }, + { text: "The cat is sleeping" }, + { text: "The dog is barking" }, + { text: "The dog chases the cat" }, ]; const table = await db.createTable("test", data); await table.createIndex("text", { @@ -1666,22 +1678,32 @@ describe.each([arrow15, arrow16, arrow17, arrow18])( const shouldResults = await table .search( new BooleanQuery([ - [Occur.Should, new MatchQuery("hello", "text")], - [Occur.Should, new MatchQuery("goodbye", "text")], + [Occur.Should, new MatchQuery("cat", "text")], + [Occur.Should, new MatchQuery("dog", "text")], ]), ) .toArray(); - expect(shouldResults.length).toBe(2); + expect(shouldResults.length).toBe(4); const mustResults = await table .search( new BooleanQuery([ - [Occur.Must, new MatchQuery("hello", "text")], - [Occur.Must, new MatchQuery("world", "text")], + [Occur.Must, new MatchQuery("cat", "text")], + [Occur.Must, new MatchQuery("dog", "text")], ]), ) .toArray(); - expect(mustResults.length).toBe(1); + expect(mustResults.length).toBe(2); + + const mustNotResults = await table + .search( + new BooleanQuery([ + [Occur.Must, new MatchQuery("cat", "text")], + [Occur.MustNot, new MatchQuery("dog", "text")], + ]), + ) + .toArray(); + expect(mustNotResults.length).toBe(1); }); test.each([ diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index c9fc97d9..9ea82145 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -812,10 +812,12 @@ export enum Operator { * * - `Must`: The term must be present in the document. * - `Should`: The term should contribute to the document score, but is not required. + * - `MustNot`: The term must not be present in the document. */ export enum Occur { - Must = "MUST", Should = "SHOULD", + Must = "MUST", + MustNot = "MUST_NOT", } /** @@ -856,6 +858,7 @@ export class MatchQuery implements FullTextQuery { * - `fuzziness`: The fuzziness level for the query (default is 0). * - `maxExpansions`: The maximum number of terms to consider for fuzzy matching (default is 50). * - `operator`: The logical operator to use for combining terms in the query (default is "OR"). + * - `prefixLength`: The number of beginning characters being unchanged for fuzzy matching. */ constructor( query: string, @@ -865,6 +868,7 @@ export class MatchQuery implements FullTextQuery { fuzziness?: number; maxExpansions?: number; operator?: Operator; + prefixLength?: number; }, ) { let fuzziness = options?.fuzziness; @@ -878,6 +882,7 @@ export class MatchQuery implements FullTextQuery { fuzziness, options?.maxExpansions ?? 50, options?.operator ?? Operator.Or, + options?.prefixLength ?? 0, ); } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 00630c3c..aa28aa05 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -335,6 +335,7 @@ impl JsFullTextQuery { fuzziness: Option, max_expansions: u32, operator: String, + prefix_length: u32, ) -> napi::Result { Ok(Self { inner: MatchQuery::new(query) @@ -347,6 +348,7 @@ impl JsFullTextQuery { napi::Error::from_reason(format!("Invalid operator: {}", e)) })?, ) + .with_prefix_length(prefix_length) .into(), }) } diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 4025b1a2..338a7f16 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -101,8 +101,9 @@ class FullTextOperator(str, Enum): class Occur(str, Enum): - MUST = "MUST" SHOULD = "SHOULD" + MUST = "MUST" + MUST_NOT = "MUST_NOT" @pydantic.dataclasses.dataclass @@ -181,6 +182,9 @@ class MatchQuery(FullTextQuery): Can be either `AND` or `OR`. If `AND`, all terms in the query must match. If `OR`, at least one term in the query must match. + prefix_length : int, optional + The number of beginning characters being unchanged for fuzzy matching. + This is useful to achieve prefix matching. """ query: str @@ -189,6 +193,7 @@ class MatchQuery(FullTextQuery): fuzziness: int = pydantic.Field(0, kw_only=True) max_expansions: int = pydantic.Field(50, kw_only=True) operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True) + prefix_length: int = pydantic.Field(0, kw_only=True) def query_type(self) -> FullTextQueryType: return FullTextQueryType.MATCH diff --git a/python/python/tests/docs/test_search.py b/python/python/tests/docs/test_search.py index 913b2138..d9065163 100644 --- a/python/python/tests/docs/test_search.py +++ b/python/python/tests/docs/test_search.py @@ -6,7 +6,7 @@ import lancedb # --8<-- [end:import-lancedb] # --8<-- [start:import-numpy] -from lancedb.query import BoostQuery, MatchQuery +from lancedb.query import BooleanQuery, BoostQuery, MatchQuery, Occur import numpy as np import pyarrow as pa @@ -191,6 +191,15 @@ def test_fts_fuzzy_query(): "food", # 1 insertion } + results = table.search( + MatchQuery("foo", "text", fuzziness=1, prefix_length=3) + ).to_pandas() + assert len(results) == 2 + assert set(results["text"].to_list()) == { + "foo", + "food", + } + @pytest.mark.skipif( os.name == "nt", reason="Need to fix https://github.com/lancedb/lance/issues/3905" @@ -240,6 +249,60 @@ def test_fts_boost_query(): ) +@pytest.mark.skipif( + os.name == "nt", reason="Need to fix https://github.com/lancedb/lance/issues/3905" +) +def test_fts_boolean_query(tmp_path): + uri = tmp_path / "boolean-example" + db = lancedb.connect(uri) + table = db.create_table( + "my_table_fts_boolean", + data=[ + {"text": "The cat and dog are playing"}, + {"text": "The cat is sleeping"}, + {"text": "The dog is barking"}, + {"text": "The dog chases the cat"}, + ], + mode="overwrite", + ) + table.create_fts_index("text", use_tantivy=False, replace=True) + + # SHOULD + results = table.search( + MatchQuery("cat", "text") | MatchQuery("dog", "text") + ).to_pandas() + assert len(results) == 4 + assert set(results["text"].to_list()) == { + "The cat and dog are playing", + "The cat is sleeping", + "The dog is barking", + "The dog chases the cat", + } + # MUST + results = table.search( + MatchQuery("cat", "text") & MatchQuery("dog", "text") + ).to_pandas() + assert len(results) == 2 + assert set(results["text"].to_list()) == { + "The cat and dog are playing", + "The dog chases the cat", + } + + # MUST NOT + results = table.search( + BooleanQuery( + [ + (Occur.MUST, MatchQuery("cat", "text")), + (Occur.MUST_NOT, MatchQuery("dog", "text")), + ] + ) + ).to_pandas() + assert len(results) == 1 + assert set(results["text"].to_list()) == { + "The cat is sleeping", + } + + @pytest.mark.skipif( os.name == "nt", reason="Need to fix https://github.com/lancedb/lance/issues/3905" ) diff --git a/python/src/query.rs b/python/src/query.rs index 1e58960d..ad0309ca 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -50,6 +50,7 @@ impl FromPyObject<'_> for PyLanceDB { let fuzziness = ob.getattr("fuzziness")?.extract()?; let max_expansions = ob.getattr("max_expansions")?.extract()?; let operator = ob.getattr("operator")?.extract::()?; + let prefix_length = ob.getattr("prefix_length")?.extract()?; Ok(PyLanceDB( MatchQuery::new(query) @@ -60,6 +61,7 @@ impl FromPyObject<'_> for PyLanceDB { .with_operator(Operator::try_from(operator.as_str()).map_err(|e| { PyValueError::new_err(format!("Invalid operator: {}", e)) })?) + .with_prefix_length(prefix_length) .into(), )) } @@ -139,7 +141,8 @@ impl<'py> IntoPyObject<'py> for PyLanceDB { kwargs.set_item("boost", query.boost)?; kwargs.set_item("fuzziness", query.fuzziness)?; kwargs.set_item("max_expansions", query.max_expansions)?; - kwargs.set_item("operator", operator_to_str(query.operator))?; + kwargs.set_item::<_, &str>("operator", query.operator.into())?; + kwargs.set_item("prefix_length", query.prefix_length)?; namespace .getattr(intern!(py, "MatchQuery"))? .call((query.terms, query.column.unwrap()), Some(&kwargs)) @@ -169,19 +172,25 @@ impl<'py> IntoPyObject<'py> for PyLanceDB { .unzip(); let kwargs = PyDict::new(py); kwargs.set_item("boosts", boosts)?; - kwargs.set_item("operator", operator_to_str(first.operator))?; + kwargs.set_item::<_, &str>("operator", first.operator.into())?; namespace .getattr(intern!(py, "MultiMatchQuery"))? .call((first.terms.clone(), columns), Some(&kwargs)) } FtsQuery::Boolean(query) => { - let mut queries = Vec::with_capacity(query.must.len() + query.should.len()); - for q in query.must { - queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?)); - } + let mut queries: Vec<(&str, Bound<'py, PyAny>)> = Vec::with_capacity( + query.should.len() + query.must.len() + query.must_not.len(), + ); for q in query.should { - queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?)); + queries.push((Occur::Should.into(), PyLanceDB(q).into_pyobject(py)?)); } + for q in query.must { + queries.push((Occur::Must.into(), PyLanceDB(q).into_pyobject(py)?)); + } + for q in query.must_not { + queries.push((Occur::MustNot.into(), PyLanceDB(q).into_pyobject(py)?)); + } + namespace .getattr(intern!(py, "BooleanQuery"))? .call1((queries,)) @@ -190,21 +199,6 @@ impl<'py> IntoPyObject<'py> for PyLanceDB { } } -fn operator_to_str(op: Operator) -> &'static str { - match op { - Operator::And => "AND", - Operator::Or => "OR", - } -} - -fn occur_to_str(occur: Occur) -> &'static str { - match occur { - Occur::Must => "MUST", - Occur::Should => "SHOULD", - Occur::MustNot => "MUST NOT", - } -} - // Python representation of query vector(s) #[derive(Clone)] pub struct PyQueryVectors(Vec>);