diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 00a507f7..c0fa4712 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -165,17 +165,14 @@ class HybridQuery: def get_with_row_id(self) -> bool: ... def to_query_request(self) -> PyQueryRequest: ... -class PyFullTextSearchQuery: - columns: Optional[List[str]] - query: str - limit: Optional[int] - wand_factor: Optional[float] +class FullTextQuery: + pass class PyQueryRequest: limit: Optional[int] offset: Optional[int] filter: Optional[Union[str, bytes]] - full_text_search: Optional[PyFullTextSearchQuery] + full_text_search: Optional[FullTextQuery] select: Optional[Union[str, List[str]]] fast_search: Optional[bool] with_row_id: Optional[bool] diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 0b870dac..2c248699 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -4,7 +4,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -import abc from concurrent.futures import ThreadPoolExecutor from enum import Enum from datetime import timedelta @@ -88,15 +87,27 @@ def ensure_vector_query( return val -class FullTextQueryType(Enum): +class FullTextQueryType(str, Enum): MATCH = "match" MATCH_PHRASE = "match_phrase" BOOST = "boost" MULTI_MATCH = "multi_match" + BOOLEAN = "boolean" -class FullTextQuery(abc.ABC, pydantic.BaseModel): - @abc.abstractmethod +class FullTextOperator(str, Enum): + AND = "AND" + OR = "OR" + + +class Occur(str, Enum): + MUST = "MUST" + SHOULD = "SHOULD" + + +@pydantic.dataclasses.dataclass +class FullTextQuery(ABC): + @abstractmethod def query_type(self) -> FullTextQueryType: """ Get the query type of the query. @@ -106,193 +117,174 @@ class FullTextQuery(abc.ABC, pydantic.BaseModel): str The type of the query. """ + pass - @abc.abstractmethod - def to_dict(self) -> dict: + def __and__(self, other: "FullTextQuery") -> "FullTextQuery": """ - Convert the query to a dictionary. - - Returns - ------- - dict - The query as a dictionary. - """ - - -class MatchQuery(FullTextQuery): - query: str - column: str - boost: float = 1.0 - fuzziness: int = 0 - max_expansions: int = 50 - - def __init__( - self, - query: str, - column: str, - *, - boost: float = 1.0, - fuzziness: int = 0, - max_expansions: int = 50, - ): - """ - Match query for full-text search. + Combine two queries with a logical AND operation. Parameters ---------- - query : str - The query string to match against. - column : str - The name of the column to match against. - boost : float, default 1.0 - The boost factor for the query. - The score of each matching document is multiplied by this value. - fuzziness : int, optional - The maximum edit distance for each term in the match query. - Defaults to 0 (exact match). - If None, fuzziness is applied automatically by the rules: - - 0 for terms with length <= 2 - - 1 for terms with length <= 5 - - 2 for terms with length > 5 - max_expansions : int, optional - The maximum number of terms to consider for fuzzy matching. - Defaults to 50. + other : FullTextQuery + The other query to combine with. + + Returns + ------- + FullTextQuery + A new query that combines both queries with AND. """ - super().__init__( - query=query, - column=column, - boost=boost, - fuzziness=fuzziness, - max_expansions=max_expansions, - ) + return BooleanQuery([(Occur.MUST, self), (Occur.MUST, other)]) + + def __or__(self, other: "FullTextQuery") -> "FullTextQuery": + """ + Combine two queries with a logical OR operation. + + Parameters + ---------- + other : FullTextQuery + The other query to combine with. + + Returns + ------- + FullTextQuery + A new query that combines both queries with OR. + """ + return BooleanQuery([(Occur.SHOULD, self), (Occur.SHOULD, other)]) + + +@pydantic.dataclasses.dataclass +class MatchQuery(FullTextQuery): + """ + Match query for full-text search. + + Parameters + ---------- + query : str + The query string to match against. + column : str + The name of the column to match against. + boost : float, default 1.0 + The boost factor for the query. + The score of each matching document is multiplied by this value. + fuzziness : int, optional + The maximum edit distance for each term in the match query. + Defaults to 0 (exact match). + If None, fuzziness is applied automatically by the rules: + - 0 for terms with length <= 2 + - 1 for terms with length <= 5 + - 2 for terms with length > 5 + max_expansions : int, optional + The maximum number of terms to consider for fuzzy matching. + Defaults to 50. + operator : FullTextOperator, default OR + The operator to use for combining the query results. + Can be either `AND` or `OR`. + If `AND`, all terms in the query must match. + If `OR`, at least one term in the query must match. + """ + + query: str + column: str + boost: float = pydantic.Field(1.0, kw_only=True) + fuzziness: int = pydantic.Field(0, kw_only=True) + max_expansions: int = pydantic.Field(50, kw_only=True) + operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True) def query_type(self) -> FullTextQueryType: return FullTextQueryType.MATCH - def to_dict(self) -> dict: - return { - "match": { - self.column: { - "query": self.query, - "boost": self.boost, - "fuzziness": self.fuzziness, - "max_expansions": self.max_expansions, - } - } - } - +@pydantic.dataclasses.dataclass class PhraseQuery(FullTextQuery): + """ + Phrase query for full-text search. + + Parameters + ---------- + query : str + The query string to match against. + column : str + The name of the column to match against. + """ + query: str column: str - - def __init__(self, query: str, column: str): - """ - Phrase query for full-text search. - - Parameters - ---------- - query : str - The query string to match against. - column : str - The name of the column to match against. - """ - super().__init__(query=query, column=column) + slop: int = pydantic.Field(0, kw_only=True) def query_type(self) -> FullTextQueryType: return FullTextQueryType.MATCH_PHRASE - def to_dict(self) -> dict: - return { - "match_phrase": { - self.column: self.query, - } - } - +@pydantic.dataclasses.dataclass class BoostQuery(FullTextQuery): + """ + Boost query for full-text search. + + Parameters + ---------- + positive : dict + The positive query object. + negative : dict + The negative query object. + negative_boost : float, default 0.5 + The boost factor for the negative query. + """ + positive: FullTextQuery negative: FullTextQuery - negative_boost: float = 0.5 - - def __init__( - self, - positive: FullTextQuery, - negative: FullTextQuery, - *, - negative_boost: float = 0.5, - ): - """ - Boost query for full-text search. - - Parameters - ---------- - positive : dict - The positive query object. - negative : dict - The negative query object. - negative_boost : float - The boost factor for the negative query. - """ - super().__init__( - positive=positive, negative=negative, negative_boost=negative_boost - ) + negative_boost: float = pydantic.Field(0.5, kw_only=True) def query_type(self) -> FullTextQueryType: return FullTextQueryType.BOOST - def to_dict(self) -> dict: - return { - "boost": { - "positive": self.positive.to_dict(), - "negative": self.negative.to_dict(), - "negative_boost": self.negative_boost, - } - } - +@pydantic.dataclasses.dataclass class MultiMatchQuery(FullTextQuery): + """ + Multi-match query for full-text search. + + Parameters + ---------- + query : str | list[Query] + If a string, the query string to match against. + columns : list[str] + The list of columns to match against. + boosts : list[float], optional + The list of boost factors for each column. If not provided, + all columns will have the same boost factor. + operator : FullTextOperator, default OR + The operator to use for combining the query results. + Can be either `AND` or `OR`. + It would be applied to all columns individually. + For example, if the operator is `AND`, + then the query "hello world" is equal to + `match("hello AND world", column1) OR match("hello AND world", column2)`. + """ + query: str columns: list[str] - boosts: list[float] - - def __init__( - self, - query: str, - columns: list[str], - *, - boosts: Optional[list[float]] = None, - ): - """ - Multi-match query for full-text search. - - Parameters - ---------- - query : str - The query string to match against. - - columns : list[str] - The list of columns to match against. - - boosts : list[float], optional - The list of boost factors for each column. If not provided, - all columns will have the same boost factor. - """ - if boosts is None: - boosts = [1.0] * len(columns) - super().__init__(query=query, columns=columns, boosts=boosts) + boosts: Optional[list[float]] = pydantic.Field(None, kw_only=True) + operator: FullTextOperator = pydantic.Field(FullTextOperator.OR, kw_only=True) def query_type(self) -> FullTextQueryType: return FullTextQueryType.MULTI_MATCH - def to_dict(self) -> dict: - return { - "multi_match": { - "query": self.query, - "columns": self.columns, - "boost": self.boosts, - } - } + +@pydantic.dataclasses.dataclass +class BooleanQuery(FullTextQuery): + """ + Boolean query for full-text search. + + Parameters + ---------- + queries : list[tuple(Occur, FullTextQuery)] + The list of queries with their occurrence requirements. + """ + + queries: list[tuple[Occur, FullTextQuery]] + + def query_type(self) -> FullTextQueryType: + return FullTextQueryType.BOOLEAN class FullTextSearchQuery(pydantic.BaseModel): @@ -493,10 +485,8 @@ class Query(pydantic.BaseModel): query.postfilter = req.postfilter if req.full_text_search is not None: query.full_text_query = FullTextSearchQuery( - columns=req.full_text_search.columns, - query=req.full_text_search.query, - limit=req.full_text_search.limit, - wand_factor=req.full_text_search.wand_factor, + columns=None, + query=req.full_text_search, ) return query @@ -2513,7 +2503,7 @@ class AsyncQuery(AsyncQueryBase): self._inner.nearest_to_text({"query": query, "columns": columns}) ) # FullTextQuery object - return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()})) + return AsyncFTSQuery(self._inner.nearest_to_text({"query": query})) class AsyncFTSQuery(AsyncQueryBase): @@ -2835,7 +2825,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase): self._inner.nearest_to_text({"query": query, "columns": columns}) ) # FullTextQuery object - return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()})) + return AsyncHybridQuery(self._inner.nearest_to_text({"query": query})) async def to_batches( self, diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index b2d31c92..2171d4d2 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -215,6 +215,19 @@ def test_search_fts(table, use_tantivy): assert len(results) == 5 assert len(results[0]) == 3 # id, text, _score + # Test boolean query + results = ( + table.search(MatchQuery("puppy", "text") & MatchQuery("runs", "text")) + .select(["id", "text"]) + .limit(5) + .to_list() + ) + assert len(results) == 5 + assert len(results[0]) == 3 # id, text, _score + for r in results: + assert "puppy" in r["text"] + assert "runs" in r["text"] + @pytest.mark.asyncio async def test_fts_select_async(async_table): diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 23d0e35f..c50642fb 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -25,6 +25,8 @@ from lancedb.query import ( AsyncQueryBase, AsyncVectorQuery, LanceVectorQueryBuilder, + MatchQuery, + PhraseQuery, Query, FullTextSearchQuery, ) @@ -1065,18 +1067,27 @@ async def test_query_serialization_async(table_async: AsyncTable): ) # FTS queries - q = (await table_async.search("foo")).limit(10).to_query_object() + match_query = MatchQuery("foo", "text") + q = (await table_async.search(match_query)).limit(10).to_query_object() check_set_props( q, limit=10, - full_text_query=FullTextSearchQuery(columns=[], query="foo"), + full_text_query=FullTextSearchQuery(columns=None, query=match_query), with_row_id=False, ) - q = (await table_async.search("foo", query_type="fts")).to_query_object() + q = (await table_async.search(match_query)).to_query_object() check_set_props( q, - full_text_query=FullTextSearchQuery(columns=[], query="foo"), + full_text_query=FullTextSearchQuery(columns=None, query=match_query), + with_row_id=False, + ) + + phrase_query = PhraseQuery("foo", "text", slop=1) + q = (await table_async.search(phrase_query)).to_query_object() + check_set_props( + q, + full_text_query=FullTextSearchQuery(columns=None, query=phrase_query), with_row_id=False, ) diff --git a/python/src/query.rs b/python/src/query.rs index f2cb2d62..49c2b392 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -9,15 +9,16 @@ use arrow::array::Array; use arrow::array::ArrayData; use arrow::pyarrow::FromPyArrow; use arrow::pyarrow::IntoPyArrow; -use lancedb::index::scalar::{FtsQuery, FullTextSearchQuery, MatchQuery, PhraseQuery}; +use lancedb::index::scalar::{ + BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur, + Operator, PhraseQuery, +}; use lancedb::query::QueryExecutionOptions; use lancedb::query::QueryFilter; use lancedb::query::{ ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, }; use lancedb::table::AnyQuery; -use pyo3::exceptions::PyRuntimeError; -use pyo3::exceptions::{PyNotImplementedError, PyValueError}; use pyo3::prelude::{PyAnyMethods, PyDictMethods}; use pyo3::pymethods; use pyo3::types::PyList; @@ -27,34 +28,182 @@ use pyo3::IntoPyObject; use pyo3::PyAny; use pyo3::PyRef; use pyo3::PyResult; +use pyo3::{exceptions::PyRuntimeError, FromPyObject}; +use pyo3::{ + exceptions::{PyNotImplementedError, PyValueError}, + intern, +}; use pyo3::{pyclass, PyErr}; use pyo3_async_runtimes::tokio::future_into_py; -use crate::arrow::RecordBatchStream; -use crate::error::PythonErrorExt; -use crate::util::{parse_distance_type, parse_fts_query}; +use crate::util::parse_distance_type; +use crate::{arrow::RecordBatchStream, util::PyLanceDB}; +use crate::{error::PythonErrorExt, index::class_name}; -// Python representation of full text search parameters -#[derive(Clone)] -#[pyclass(get_all)] -pub struct PyFullTextSearchQuery { - pub columns: Vec, - pub query: String, - pub limit: Option, - pub wand_factor: Option, +impl FromPyObject<'_> for PyLanceDB { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + match class_name(ob)?.as_str() { + "MatchQuery" => { + let query = ob.getattr("query")?.extract()?; + let column = ob.getattr("column")?.extract()?; + let boost = ob.getattr("boost")?.extract()?; + let fuzziness = ob.getattr("fuzziness")?.extract()?; + let max_expansions = ob.getattr("max_expansions")?.extract()?; + let operator = ob.getattr("operator")?.extract::()?; + + Ok(PyLanceDB( + MatchQuery::new(query) + .with_column(Some(column)) + .with_boost(boost) + .with_fuzziness(fuzziness) + .with_max_expansions(max_expansions) + .with_operator(Operator::try_from(operator.as_str()).map_err(|e| { + PyValueError::new_err(format!("Invalid operator: {}", e)) + })?) + .into(), + )) + } + "PhraseQuery" => { + let query = ob.getattr("query")?.extract()?; + let column = ob.getattr("column")?.extract()?; + let slop = ob.getattr("slop")?.extract()?; + + Ok(PyLanceDB( + PhraseQuery::new(query) + .with_column(Some(column)) + .with_slop(slop) + .into(), + )) + } + "BoostQuery" => { + let positive: PyLanceDB = ob.getattr("positive")?.extract()?; + let negative: PyLanceDB = ob.getattr("negative")?.extract()?; + let negative_boost = ob.getattr("negative_boost")?.extract()?; + Ok(PyLanceDB( + BoostQuery::new(positive.0, negative.0, negative_boost).into(), + )) + } + "MultiMatchQuery" => { + let query = ob.getattr("query")?.extract()?; + let columns = ob.getattr("columns")?.extract()?; + let boosts: Option> = ob.getattr("boosts")?.extract()?; + let operator: String = ob.getattr("operator")?.extract()?; + + let q = MultiMatchQuery::try_new(query, columns) + .map_err(|e| PyValueError::new_err(format!("Invalid query: {}", e)))?; + let q = if let Some(boosts) = boosts { + q.try_with_boosts(boosts) + .map_err(|e| PyValueError::new_err(format!("Invalid boosts: {}", e)))? + } else { + q + }; + + let op = Operator::try_from(operator.as_str()) + .map_err(|e| PyValueError::new_err(format!("Invalid operator: {}", e)))?; + + Ok(PyLanceDB(q.with_operator(op).into())) + } + "BooleanQuery" => { + let queries: Vec<(String, PyLanceDB)> = + ob.getattr("queries")?.extract()?; + let mut sub_queries = Vec::with_capacity(queries.len()); + for (occur, q) in queries { + let occur = Occur::try_from(occur.as_str()) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + sub_queries.push((occur, q.0)); + } + Ok(PyLanceDB(BooleanQuery::new(sub_queries).into())) + } + name => Err(PyValueError::new_err(format!( + "Unsupported FTS query type: {}", + name + ))), + } + } } -impl From for PyFullTextSearchQuery { - fn from(query: FullTextSearchQuery) -> Self { - Self { - columns: query.columns().into_iter().collect(), - query: query.query.query().to_owned(), - limit: query.limit, - wand_factor: query.wand_factor, +impl<'py> IntoPyObject<'py> for PyLanceDB { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult { + let namespace = py + .import(intern!(py, "lancedb")) + .and_then(|m| m.getattr(intern!(py, "query"))) + .expect("Failed to import namespace"); + + match self.0 { + FtsQuery::Match(query) => { + let kwargs = PyDict::new(py); + kwargs.set_item("boost", query.boost)?; + kwargs.set_item("fuzziness", query.fuzziness)?; + kwargs.set_item("max_expansions", query.max_expansions)?; + kwargs.set_item("operator", operator_to_str(query.operator))?; + namespace + .getattr(intern!(py, "MatchQuery"))? + .call((query.terms, query.column.unwrap()), Some(&kwargs)) + } + FtsQuery::Phrase(query) => { + let kwargs = PyDict::new(py); + kwargs.set_item("slop", query.slop)?; + namespace + .getattr(intern!(py, "PhraseQuery"))? + .call((query.terms, query.column.unwrap()), Some(&kwargs)) + } + FtsQuery::Boost(query) => { + let positive = PyLanceDB(query.positive.as_ref().clone()).into_pyobject(py)?; + let negative = PyLanceDB(query.negative.as_ref().clone()).into_pyobject(py)?; + let kwargs = PyDict::new(py); + kwargs.set_item("negative_boost", query.negative_boost)?; + namespace + .getattr(intern!(py, "BoostQuery"))? + .call((positive, negative), Some(&kwargs)) + } + FtsQuery::MultiMatch(query) => { + let first = &query.match_queries[0]; + let (columns, boosts): (Vec<_>, Vec<_>) = query + .match_queries + .iter() + .map(|q| (q.column.as_ref().unwrap().clone(), q.boost)) + .unzip(); + let kwargs = PyDict::new(py); + kwargs.set_item("boosts", boosts)?; + kwargs.set_item("operator", operator_to_str(first.operator))?; + namespace + .getattr(intern!(py, "MultiMatchQuery"))? + .call((first.terms.clone(), columns), Some(&kwargs)) + } + FtsQuery::Boolean(query) => { + let mut queries = Vec::with_capacity(query.must.len() + query.should.len()); + for q in query.must { + queries.push((occur_to_str(Occur::Must), PyLanceDB(q).into_pyobject(py)?)); + } + for q in query.should { + queries.push((occur_to_str(Occur::Should), PyLanceDB(q).into_pyobject(py)?)); + } + namespace + .getattr(intern!(py, "BooleanQuery"))? + .call1((queries,)) + } } } } +fn operator_to_str(op: Operator) -> &'static str { + match op { + Operator::And => "AND", + Operator::Or => "OR", + } +} + +fn occur_to_str(occur: Occur) -> &'static str { + match occur { + Occur::Must => "MUST", + Occur::Should => "SHOULD", + } +} + // Python representation of query vector(s) #[derive(Clone)] pub struct PyQueryVectors(Vec>); @@ -80,7 +229,7 @@ pub struct PyQueryRequest { pub limit: Option, pub offset: Option, pub filter: Option, - pub full_text_search: Option, + pub full_text_search: Option>, pub select: PySelect, pub fast_search: Option, pub with_row_id: Option, @@ -106,7 +255,7 @@ impl From for PyQueryRequest { filter: query_request.filter.map(PyQueryFilter), full_text_search: query_request .full_text_search - .map(PyFullTextSearchQuery::from), + .map(|fts| PyLanceDB(fts.query)), select: PySelect(query_request.select), fast_search: Some(query_request.fast_search), with_row_id: Some(query_request.with_row_id), @@ -269,8 +418,8 @@ impl Query { } }; let mut query = FullTextSearchQuery::new_query(query); - if let Some(cols) = columns { - if !cols.is_empty() { + match columns { + Some(cols) if !cols.is_empty() => { query = query.with_columns(&cols).map_err(|e| { PyValueError::new_err(format!( "Failed to set full text search columns: {}", @@ -278,15 +427,12 @@ impl Query { )) })?; } + _ => {} } query - } else if let Ok(query) = fts_query.downcast::() { - let query = parse_fts_query(query)?; - FullTextSearchQuery::new_query(query) } else { - return Err(PyValueError::new_err( - "query must be a string or a Query object", - )); + let query = fts_query.extract::>()?; + FullTextSearchQuery::new_query(query.0) }; Ok(FTSQuery { diff --git a/python/src/util.rs b/python/src/util.rs index e438cd6a..8ec8f40c 100644 --- a/python/src/util.rs +++ b/python/src/util.rs @@ -3,15 +3,11 @@ use std::sync::Mutex; -use lancedb::index::scalar::{BoostQuery, FtsQuery, MatchQuery, MultiMatchQuery, PhraseQuery}; use lancedb::DistanceType; -use pyo3::prelude::{PyAnyMethods, PyDictMethods, PyListMethods}; -use pyo3::types::PyDict; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, pyfunction, PyResult, }; -use pyo3::{Bound, PyAny}; /// A wrapper around a rust builder /// @@ -64,116 +60,6 @@ pub fn validate_table_name(table_name: &str) -> PyResult<()> { .map_err(|e| PyValueError::new_err(e.to_string())) } -pub fn parse_fts_query(query: &Bound<'_, PyDict>) -> PyResult { - let query_type = query.keys().get_item(0)?.extract::()?; - let query_value = query - .get_item(&query_type)? - .ok_or(PyValueError::new_err(format!( - "Query type {} not found", - query_type - )))?; - let query_value = query_value.downcast::()?; - - match query_type.as_str() { - "match" => { - let column = query_value.keys().get_item(0)?.extract::()?; - let params = query_value - .get_item(&column)? - .ok_or(PyValueError::new_err(format!( - "column {} not found", - column - )))?; - let params = params.downcast::()?; - - let query = params - .get_item("query")? - .ok_or(PyValueError::new_err("query not found"))? - .extract::()?; - let boost = params - .get_item("boost")? - .ok_or(PyValueError::new_err("boost not found"))? - .extract::()?; - let fuzziness = params - .get_item("fuzziness")? - .ok_or(PyValueError::new_err("fuzziness not found"))? - .extract::>()?; - let max_expansions = params - .get_item("max_expansions")? - .ok_or(PyValueError::new_err("max_expansions not found"))? - .extract::()?; - - let query = MatchQuery::new(query) - .with_column(Some(column)) - .with_boost(boost) - .with_fuzziness(fuzziness) - .with_max_expansions(max_expansions); - Ok(query.into()) - } - - "match_phrase" => { - let column = query_value.keys().get_item(0)?.extract::()?; - let query = query_value - .get_item(&column)? - .ok_or(PyValueError::new_err(format!( - "column {} not found", - column - )))? - .extract::()?; - - let query = PhraseQuery::new(query).with_column(Some(column)); - Ok(query.into()) - } - - "boost" => { - let positive: Bound<'_, PyAny> = query_value - .get_item("positive")? - .ok_or(PyValueError::new_err("positive not found"))?; - let positive = positive.downcast::()?; - - let negative = query_value - .get_item("negative")? - .ok_or(PyValueError::new_err("negative not found"))?; - let negative = negative.downcast::()?; - - let negative_boost = query_value - .get_item("negative_boost")? - .ok_or(PyValueError::new_err("negative_boost not found"))? - .extract::()?; - - let positive_query = parse_fts_query(positive)?; - let negative_query = parse_fts_query(negative)?; - let query = BoostQuery::new(positive_query, negative_query, Some(negative_boost)); - - Ok(query.into()) - } - - "multi_match" => { - let query = query_value - .get_item("query")? - .ok_or(PyValueError::new_err("query not found"))? - .extract::()?; - - let columns = query_value - .get_item("columns")? - .ok_or(PyValueError::new_err("columns not found"))? - .extract::>()?; - - let boost = query_value - .get_item("boost")? - .ok_or(PyValueError::new_err("boost not found"))? - .extract::>()?; - - let query = MultiMatchQuery::try_new(query, columns) - .and_then(|q| q.try_with_boosts(boost)) - .map_err(|e| { - PyValueError::new_err(format!("Error creating MultiMatchQuery: {}", e)) - })?; - Ok(query.into()) - } - - _ => Err(PyValueError::new_err(format!( - "Unsupported query type: {}", - query_type - ))), - } -} +/// A wrapper around a LanceDB type to allow it to be used in Python +#[derive(Debug, Clone)] +pub struct PyLanceDB(pub T);