feat: support prefix matching and must_not clause (#2441)

This commit is contained in:
BubbleCal
2025-06-19 10:32:32 +08:00
committed by GitHub
parent c72f6770fd
commit cbb5a841b1
6 changed files with 124 additions and 33 deletions

View File

@@ -50,6 +50,7 @@ impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
let fuzziness = ob.getattr("fuzziness")?.extract()?;
let max_expansions = ob.getattr("max_expansions")?.extract()?;
let operator = ob.getattr("operator")?.extract::<String>()?;
let prefix_length = ob.getattr("prefix_length")?.extract()?;
Ok(PyLanceDB(
MatchQuery::new(query)
@@ -60,6 +61,7 @@ impl FromPyObject<'_> for PyLanceDB<FtsQuery> {
.with_operator(Operator::try_from(operator.as_str()).map_err(|e| {
PyValueError::new_err(format!("Invalid operator: {}", e))
})?)
.with_prefix_length(prefix_length)
.into(),
))
}
@@ -139,7 +141,8 @@ impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
kwargs.set_item("boost", query.boost)?;
kwargs.set_item("fuzziness", query.fuzziness)?;
kwargs.set_item("max_expansions", query.max_expansions)?;
kwargs.set_item("operator", operator_to_str(query.operator))?;
kwargs.set_item::<_, &str>("operator", query.operator.into())?;
kwargs.set_item("prefix_length", query.prefix_length)?;
namespace
.getattr(intern!(py, "MatchQuery"))?
.call((query.terms, query.column.unwrap()), Some(&kwargs))
@@ -169,19 +172,25 @@ impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
.unzip();
let kwargs = PyDict::new(py);
kwargs.set_item("boosts", boosts)?;
kwargs.set_item("operator", operator_to_str(first.operator))?;
kwargs.set_item::<_, &str>("operator", first.operator.into())?;
namespace
.getattr(intern!(py, "MultiMatchQuery"))?
.call((first.terms.clone(), columns), Some(&kwargs))
}
FtsQuery::Boolean(query) => {
let mut queries = Vec::with_capacity(query.must.len() + query.should.len());
for q in query.must {
queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?));
}
let mut queries: Vec<(&str, Bound<'py, PyAny>)> = Vec::with_capacity(
query.should.len() + query.must.len() + query.must_not.len(),
);
for q in query.should {
queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?));
queries.push((Occur::Should.into(), PyLanceDB(q).into_pyobject(py)?));
}
for q in query.must {
queries.push((Occur::Must.into(), PyLanceDB(q).into_pyobject(py)?));
}
for q in query.must_not {
queries.push((Occur::MustNot.into(), PyLanceDB(q).into_pyobject(py)?));
}
namespace
.getattr(intern!(py, "BooleanQuery"))?
.call1((queries,))
@@ -190,21 +199,6 @@ impl<'py> IntoPyObject<'py> for PyLanceDB<FtsQuery> {
}
}
fn operator_to_str(op: Operator) -> &'static str {
match op {
Operator::And => "AND",
Operator::Or => "OR",
}
}
fn occur_to_str(occur: Occur) -> &'static str {
match occur {
Occur::Must => "MUST",
Occur::Should => "SHOULD",
Occur::MustNot => "MUST NOT",
}
}
// Python representation of query vector(s)
#[derive(Clone)]
pub struct PyQueryVectors(Vec<Arc<dyn Array>>);