mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 04:42:57 +00:00
fix: can't do structured FTS in python (#2300)
missed to support it in `search()` API and there were some pydantic errors <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced full-text search capabilities by incorporating additional parameters, enabling more flexible query definitions. - Extended table search functionality to support full-text queries alongside existing search types. - **Tests** - Introduced new tests that validate both structured and conditional full-text search behaviors. - Expanded test coverage for various query types, including MatchQuery, BoostQuery, MultiMatchQuery, and PhraseQuery. - **Bug Fixes** - Fixed a logic issue in query processing to ensure correct handling of full-text search queries. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -117,6 +117,12 @@ class FullTextQuery(abc.ABC, pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class MatchQuery(FullTextQuery):
|
class MatchQuery(FullTextQuery):
|
||||||
|
query: str
|
||||||
|
column: str
|
||||||
|
boost: float = 1.0
|
||||||
|
fuzziness: int = 0
|
||||||
|
max_expansions: int = 50
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -149,11 +155,13 @@ class MatchQuery(FullTextQuery):
|
|||||||
The maximum number of terms to consider for fuzzy matching.
|
The maximum number of terms to consider for fuzzy matching.
|
||||||
Defaults to 50.
|
Defaults to 50.
|
||||||
"""
|
"""
|
||||||
self.column = column
|
super().__init__(
|
||||||
self.query = query
|
query=query,
|
||||||
self.boost = boost
|
column=column,
|
||||||
self.fuzziness = fuzziness
|
boost=boost,
|
||||||
self.max_expansions = max_expansions
|
fuzziness=fuzziness,
|
||||||
|
max_expansions=max_expansions,
|
||||||
|
)
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.MATCH
|
return FullTextQueryType.MATCH
|
||||||
@@ -172,6 +180,9 @@ class MatchQuery(FullTextQuery):
|
|||||||
|
|
||||||
|
|
||||||
class PhraseQuery(FullTextQuery):
|
class PhraseQuery(FullTextQuery):
|
||||||
|
query: str
|
||||||
|
column: str
|
||||||
|
|
||||||
def __init__(self, query: str, column: str):
|
def __init__(self, query: str, column: str):
|
||||||
"""
|
"""
|
||||||
Phrase query for full-text search.
|
Phrase query for full-text search.
|
||||||
@@ -183,8 +194,7 @@ class PhraseQuery(FullTextQuery):
|
|||||||
column : str
|
column : str
|
||||||
The name of the column to match against.
|
The name of the column to match against.
|
||||||
"""
|
"""
|
||||||
self.column = column
|
super().__init__(query=query, column=column)
|
||||||
self.query = query
|
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.MATCH_PHRASE
|
return FullTextQueryType.MATCH_PHRASE
|
||||||
@@ -198,11 +208,16 @@ class PhraseQuery(FullTextQuery):
|
|||||||
|
|
||||||
|
|
||||||
class BoostQuery(FullTextQuery):
|
class BoostQuery(FullTextQuery):
|
||||||
|
positive: FullTextQuery
|
||||||
|
negative: FullTextQuery
|
||||||
|
negative_boost: float = 0.5
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
positive: FullTextQuery,
|
positive: FullTextQuery,
|
||||||
negative: FullTextQuery,
|
negative: FullTextQuery,
|
||||||
negative_boost: float,
|
*,
|
||||||
|
negative_boost: float = 0.5,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Boost query for full-text search.
|
Boost query for full-text search.
|
||||||
@@ -216,9 +231,9 @@ class BoostQuery(FullTextQuery):
|
|||||||
negative_boost : float
|
negative_boost : float
|
||||||
The boost factor for the negative query.
|
The boost factor for the negative query.
|
||||||
"""
|
"""
|
||||||
self.positive = positive
|
super().__init__(
|
||||||
self.negative = negative
|
positive=positive, negative=negative, negative_boost=negative_boost
|
||||||
self.negative_boost = negative_boost
|
)
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.BOOST
|
return FullTextQueryType.BOOST
|
||||||
@@ -234,6 +249,10 @@ class BoostQuery(FullTextQuery):
|
|||||||
|
|
||||||
|
|
||||||
class MultiMatchQuery(FullTextQuery):
|
class MultiMatchQuery(FullTextQuery):
|
||||||
|
query: str
|
||||||
|
columns: list[str]
|
||||||
|
boosts: list[float]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -256,11 +275,9 @@ class MultiMatchQuery(FullTextQuery):
|
|||||||
The list of boost factors for each column. If not provided,
|
The list of boost factors for each column. If not provided,
|
||||||
all columns will have the same boost factor.
|
all columns will have the same boost factor.
|
||||||
"""
|
"""
|
||||||
self.query = query
|
|
||||||
self.columns = columns
|
|
||||||
if boosts is None:
|
if boosts is None:
|
||||||
boosts = [1.0] * len(columns)
|
boosts = [1.0] * len(columns)
|
||||||
self.boosts = boosts
|
super().__init__(query=query, columns=columns, boosts=boosts)
|
||||||
|
|
||||||
def query_type(self) -> FullTextQueryType:
|
def query_type(self) -> FullTextQueryType:
|
||||||
return FullTextQueryType.MULTI_MATCH
|
return FullTextQueryType.MULTI_MATCH
|
||||||
@@ -544,7 +561,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
table, query, vector_column_name, fts_columns=fts_columns
|
table, query, vector_column_name, fts_columns=fts_columns
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(query, str):
|
if isinstance(query, (str, FullTextQuery)):
|
||||||
# fts
|
# fts
|
||||||
return LanceFtsQueryBuilder(
|
return LanceFtsQueryBuilder(
|
||||||
table,
|
table,
|
||||||
@@ -569,8 +586,10 @@ class LanceQueryBuilder(ABC):
|
|||||||
# If query_type is fts, then query must be a string.
|
# If query_type is fts, then query must be a string.
|
||||||
# otherwise raise TypeError
|
# otherwise raise TypeError
|
||||||
if query_type == "fts":
|
if query_type == "fts":
|
||||||
if not isinstance(query, str):
|
if not isinstance(query, (str, FullTextQuery)):
|
||||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
raise TypeError(
|
||||||
|
f"'fts' query must be a string or FullTextQuery: {type(query)}"
|
||||||
|
)
|
||||||
return query, query_type
|
return query, query_type
|
||||||
elif query_type == "vector":
|
elif query_type == "vector":
|
||||||
query = cls._query_to_vector(table, query, vector_column_name)
|
query = cls._query_to_vector(table, query, vector_column_name)
|
||||||
@@ -1486,7 +1505,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
table: "Table",
|
table: "Table",
|
||||||
query: Optional[str] = None,
|
query: Optional[Union[str, FullTextQuery]] = None,
|
||||||
vector_column: Optional[str] = None,
|
vector_column: Optional[str] = None,
|
||||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||||
):
|
):
|
||||||
@@ -1516,8 +1535,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
|||||||
text_query = text or query
|
text_query = text or query
|
||||||
if text_query is None:
|
if text_query is None:
|
||||||
raise ValueError("Text query must be provided for hybrid search.")
|
raise ValueError("Text query must be provided for hybrid search.")
|
||||||
if not isinstance(text_query, str):
|
if not isinstance(text_query, (str, FullTextQuery)):
|
||||||
raise ValueError("Text query must be a string")
|
raise ValueError("Text query must be a string or FullTextQuery")
|
||||||
|
|
||||||
return vector_query, text_query
|
return vector_query, text_query
|
||||||
|
|
||||||
@@ -2308,7 +2327,7 @@ class AsyncQuery(AsyncQueryBase):
|
|||||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||||
)
|
)
|
||||||
# FullTextQuery object
|
# FullTextQuery object
|
||||||
return AsyncFTSQuery(self._inner.nearest_to_text(query.to_dict()))
|
return AsyncFTSQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||||
|
|
||||||
|
|
||||||
class AsyncFTSQuery(AsyncQueryBase):
|
class AsyncFTSQuery(AsyncQueryBase):
|
||||||
@@ -2627,7 +2646,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
|
|||||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||||
)
|
)
|
||||||
# FullTextQuery object
|
# FullTextQuery object
|
||||||
return AsyncHybridQuery(self._inner.nearest_to_text(query.to_dict()))
|
return AsyncHybridQuery(self._inner.nearest_to_text({"query": query.to_dict()}))
|
||||||
|
|
||||||
async def to_batches(
|
async def to_batches(
|
||||||
self, *, max_batch_length: Optional[int] = None
|
self, *, max_batch_length: Optional[int] = None
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from .query import (
|
|||||||
AsyncHybridQuery,
|
AsyncHybridQuery,
|
||||||
AsyncQuery,
|
AsyncQuery,
|
||||||
AsyncVectorQuery,
|
AsyncVectorQuery,
|
||||||
|
FullTextQuery,
|
||||||
LanceEmptyQueryBuilder,
|
LanceEmptyQueryBuilder,
|
||||||
LanceFtsQueryBuilder,
|
LanceFtsQueryBuilder,
|
||||||
LanceHybridQueryBuilder,
|
LanceHybridQueryBuilder,
|
||||||
@@ -919,7 +920,9 @@ class Table(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[
|
||||||
|
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||||
|
] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: QueryType = "auto",
|
query_type: QueryType = "auto",
|
||||||
ordering_field_name: Optional[str] = None,
|
ordering_field_name: Optional[str] = None,
|
||||||
@@ -2039,7 +2042,9 @@ class LanceTable(Table):
|
|||||||
@overload
|
@overload
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[
|
||||||
|
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||||
|
] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: Literal["hybrid"] = "hybrid",
|
query_type: Literal["hybrid"] = "hybrid",
|
||||||
ordering_field_name: Optional[str] = None,
|
ordering_field_name: Optional[str] = None,
|
||||||
@@ -2058,7 +2063,9 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[
|
||||||
|
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||||
|
] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: QueryType = "auto",
|
query_type: QueryType = "auto",
|
||||||
ordering_field_name: Optional[str] = None,
|
ordering_field_name: Optional[str] = None,
|
||||||
@@ -3134,7 +3141,9 @@ class AsyncTable:
|
|||||||
@overload
|
@overload
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[
|
||||||
|
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||||
|
] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: Literal["vector"] = ...,
|
query_type: Literal["vector"] = ...,
|
||||||
ordering_field_name: Optional[str] = None,
|
ordering_field_name: Optional[str] = None,
|
||||||
@@ -3143,7 +3152,9 @@ class AsyncTable:
|
|||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[
|
||||||
|
Union[VEC, str, "PIL.Image.Image", Tuple, FullTextQuery]
|
||||||
|
] = None,
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: QueryType = "auto",
|
query_type: QueryType = "auto",
|
||||||
ordering_field_name: Optional[str] = None,
|
ordering_field_name: Optional[str] = None,
|
||||||
@@ -3253,6 +3264,8 @@ class AsyncTable:
|
|||||||
if is_embedding(query):
|
if is_embedding(query):
|
||||||
vector_query = query
|
vector_query = query
|
||||||
query_type = "vector"
|
query_type = "vector"
|
||||||
|
elif isinstance(query, FullTextQuery):
|
||||||
|
query_type = "fts"
|
||||||
elif isinstance(query, str):
|
elif isinstance(query, str):
|
||||||
try:
|
try:
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from unittest import mock
|
|||||||
import lancedb as ldb
|
import lancedb as ldb
|
||||||
from lancedb.db import DBConnection
|
from lancedb.db import DBConnection
|
||||||
from lancedb.index import FTS
|
from lancedb.index import FTS
|
||||||
|
from lancedb.query import BoostQuery, MatchQuery, MultiMatchQuery, PhraseQuery
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
@@ -178,11 +179,47 @@ def test_search_fts(table, use_tantivy):
|
|||||||
results = table.search("puppy").select(["id", "text"]).to_list()
|
results = table.search("puppy").select(["id", "text"]).to_list()
|
||||||
assert len(results) == 10
|
assert len(results) == 10
|
||||||
|
|
||||||
|
if not use_tantivy:
|
||||||
|
# Test with a query
|
||||||
|
results = (
|
||||||
|
table.search(MatchQuery("puppy", "text"))
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
|
||||||
|
# Test boost query
|
||||||
|
results = (
|
||||||
|
table.search(
|
||||||
|
BoostQuery(
|
||||||
|
MatchQuery("puppy", "text"),
|
||||||
|
MatchQuery("runs", "text"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
|
||||||
|
# Test multi match query
|
||||||
|
table.create_fts_index("text2", use_tantivy=use_tantivy)
|
||||||
|
results = (
|
||||||
|
table.search(MultiMatchQuery("puppy", ["text", "text2"]))
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fts_select_async(async_table):
|
async def test_fts_select_async(async_table):
|
||||||
tbl = await async_table
|
tbl = await async_table
|
||||||
await tbl.create_index("text", config=FTS())
|
await tbl.create_index("text", config=FTS())
|
||||||
|
await tbl.create_index("text2", config=FTS())
|
||||||
results = (
|
results = (
|
||||||
await tbl.query()
|
await tbl.query()
|
||||||
.nearest_to_text("puppy")
|
.nearest_to_text("puppy")
|
||||||
@@ -193,6 +230,54 @@ async def test_fts_select_async(async_table):
|
|||||||
assert len(results) == 5
|
assert len(results) == 5
|
||||||
assert len(results[0]) == 3 # id, text, _score
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
# Test with FullTextQuery
|
||||||
|
results = (
|
||||||
|
await tbl.query()
|
||||||
|
.nearest_to_text(MatchQuery("puppy", "text"))
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
# Test with BoostQuery
|
||||||
|
results = (
|
||||||
|
await tbl.query()
|
||||||
|
.nearest_to_text(
|
||||||
|
BoostQuery(
|
||||||
|
MatchQuery("puppy", "text"),
|
||||||
|
MatchQuery("runs", "text"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
# Test with MultiMatchQuery
|
||||||
|
results = (
|
||||||
|
await tbl.query()
|
||||||
|
.nearest_to_text(MultiMatchQuery("puppy", ["text", "text2"]))
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
# Test with search() API
|
||||||
|
results = (
|
||||||
|
await (await tbl.search(MatchQuery("puppy", "text")))
|
||||||
|
.select(["id", "text"])
|
||||||
|
.limit(5)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(results[0]) == 3 # id, text, _score
|
||||||
|
|
||||||
|
|
||||||
def test_search_fts_phrase_query(table):
|
def test_search_fts_phrase_query(table):
|
||||||
table.create_fts_index("text", use_tantivy=False, with_position=False)
|
table.create_fts_index("text", use_tantivy=False, with_position=False)
|
||||||
@@ -207,6 +292,13 @@ def test_search_fts_phrase_query(table):
|
|||||||
assert len(results) > len(phrase_results)
|
assert len(results) > len(phrase_results)
|
||||||
assert len(phrase_results) > 0
|
assert len(phrase_results) > 0
|
||||||
|
|
||||||
|
# Test with a query
|
||||||
|
phrase_results = (
|
||||||
|
table.search(PhraseQuery("puppy runs", "text")).limit(100).to_list()
|
||||||
|
)
|
||||||
|
assert len(results) > len(phrase_results)
|
||||||
|
assert len(phrase_results) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_fts_phrase_query_async(async_table):
|
async def test_search_fts_phrase_query_async(async_table):
|
||||||
@@ -227,6 +319,16 @@ async def test_search_fts_phrase_query_async(async_table):
|
|||||||
assert len(results) > len(phrase_results)
|
assert len(results) > len(phrase_results)
|
||||||
assert len(phrase_results) > 0
|
assert len(phrase_results) > 0
|
||||||
|
|
||||||
|
# Test with a query
|
||||||
|
phrase_results = (
|
||||||
|
await async_table.query()
|
||||||
|
.nearest_to_text(PhraseQuery("puppy runs", "text"))
|
||||||
|
.limit(100)
|
||||||
|
.to_list()
|
||||||
|
)
|
||||||
|
assert len(results) > len(phrase_results)
|
||||||
|
assert len(phrase_results) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_search_fts_specify_column(table):
|
def test_search_fts_specify_column(table):
|
||||||
table.create_fts_index("text", use_tantivy=False)
|
table.create_fts_index("text", use_tantivy=False)
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ impl Query {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
query
|
query
|
||||||
} else if let Ok(query) = query.downcast::<PyDict>() {
|
} else if let Ok(query) = fts_query.downcast::<PyDict>() {
|
||||||
let query = parse_fts_query(query)?;
|
let query = parse_fts_query(query)?;
|
||||||
FullTextSearchQuery::new_query(query)
|
FullTextSearchQuery::new_query(query)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1048,6 +1048,7 @@ mod tests {
|
|||||||
use arrow_schema::{DataType, Field, Schema};
|
use arrow_schema::{DataType, Field, Schema};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
|
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
|
||||||
|
use lance_index::scalar::inverted::query::MatchQuery;
|
||||||
use lance_index::scalar::FullTextSearchQuery;
|
use lance_index::scalar::FullTextSearchQuery;
|
||||||
use reqwest::Body;
|
use reqwest::Body;
|
||||||
use rstest::rstest;
|
use rstest::rstest;
|
||||||
@@ -1734,6 +1735,66 @@ mod tests {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_query_structured_fts() {
|
||||||
|
let table =
|
||||||
|
Table::new_with_handler_version("my_table", semver::Version::new(0, 3, 0), |request| {
|
||||||
|
assert_eq!(request.method(), "POST");
|
||||||
|
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
|
||||||
|
assert_eq!(
|
||||||
|
request.headers().get("Content-Type").unwrap(),
|
||||||
|
JSON_CONTENT_TYPE
|
||||||
|
);
|
||||||
|
|
||||||
|
let body = request.body().unwrap().as_bytes().unwrap();
|
||||||
|
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||||
|
let expected_body = serde_json::json!({
|
||||||
|
"full_text_query": {
|
||||||
|
"query": {
|
||||||
|
"match": {
|
||||||
|
"terms": "hello world",
|
||||||
|
"column": "a",
|
||||||
|
"boost": 1.0,
|
||||||
|
"fuzziness": 0,
|
||||||
|
"max_expansions": 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"k": 10,
|
||||||
|
"vector": [],
|
||||||
|
"with_row_id": true,
|
||||||
|
"prefilter": true,
|
||||||
|
"version": null
|
||||||
|
});
|
||||||
|
assert_eq!(body, expected_body);
|
||||||
|
|
||||||
|
let data = RecordBatch::try_new(
|
||||||
|
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||||
|
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let response_body = write_ipc_file(&data);
|
||||||
|
http::Response::builder()
|
||||||
|
.status(200)
|
||||||
|
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||||
|
.body(response_body)
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
let _ = table
|
||||||
|
.query()
|
||||||
|
.full_text_search(FullTextSearchQuery::new_query(
|
||||||
|
MatchQuery::new("hello world".to_owned())
|
||||||
|
.with_column(Some("a".to_owned()))
|
||||||
|
.into(),
|
||||||
|
))
|
||||||
|
.with_row_id()
|
||||||
|
.limit(10)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[rstest]
|
#[rstest]
|
||||||
#[case(DEFAULT_SERVER_VERSION.clone())]
|
#[case(DEFAULT_SERVER_VERSION.clone())]
|
||||||
#[case(semver::Version::new(0, 2, 0))]
|
#[case(semver::Version::new(0, 2, 0))]
|
||||||
|
|||||||
Reference in New Issue
Block a user