diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 53bcb434..9c33740a 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -132,8 +132,8 @@ class LanceQueryBuilder(ABC): query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query_type: str, vector_column_name: str, - ordering_field_name: str = None, - fts_columns: Union[str, List[str]] = None, + ordering_field_name: Optional[str] = None, + fts_columns: Union[str, List[str]] = [], ) -> LanceQueryBuilder: """ Create a query builder based on the given query and query type. @@ -156,7 +156,9 @@ class LanceQueryBuilder(ABC): if query_type == "hybrid": # hybrid fts and vector query - return LanceHybridQueryBuilder(table, query, vector_column_name) + return LanceHybridQueryBuilder( + table, query, vector_column_name, fts_columns=fts_columns + ) # remember the string query for reranking purpose str_query = query if isinstance(query, str) else None @@ -168,7 +170,9 @@ class LanceQueryBuilder(ABC): ) if query_type == "hybrid": - return LanceHybridQueryBuilder(table, query, vector_column_name) + return LanceHybridQueryBuilder( + table, query, vector_column_name, fts_columns=fts_columns + ) if isinstance(query, str): # fts @@ -176,6 +180,7 @@ class LanceQueryBuilder(ABC): table, query, ordering_field_name=ordering_field_name, + fts_columns=fts_columns, ) if isinstance(query, list): @@ -693,8 +698,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): self, table: "Table", query: str, - ordering_field_name: str = None, - fts_columns: Union[str, List[str]] = None, + ordering_field_name: Optional[str] = None, + fts_columns: Union[str, List[str]] = [], ): super().__init__(table) self._query = query @@ -887,10 +892,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): in the `rerank` method to convert the scores to ranks and then normalize them. """ - def __init__(self, table: "Table", query: str, vector_column: str): + def __init__( + self, + table: "Table", + query: str, + vector_column: str, + fts_columns: Union[str, List[str]] = [], + ): super().__init__(table) vector_query, fts_query = self._validate_query(query) - self._fts_query = LanceFtsQueryBuilder(table, fts_query) + self._fts_query = LanceFtsQueryBuilder( + table, fts_query, fts_columns=fts_columns + ) vector_query = self._query_to_vector(table, vector_query, vector_column) self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column) self._norm = "score" @@ -1386,7 +1399,7 @@ class AsyncQuery(AsyncQueryBase): ) def nearest_to_text( - self, query: str, columns: Union[str, List[str]] = None + self, query: str, columns: Union[str, List[str]] = [] ) -> AsyncQuery: """ Find the documents that are most relevant to the given text query. @@ -1410,9 +1423,8 @@ class AsyncQuery(AsyncQueryBase): """ if isinstance(columns, str): columns = [columns] - return AsyncQuery( - self._inner.nearest_to_text({"query": query, "columns": columns}) - ) + self._inner.nearest_to_text({"query": query, "columns": columns}) + return self class AsyncVectorQuery(AsyncQueryBase): diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 596e7b81..4f5f6a0c 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -15,7 +15,7 @@ import logging import uuid from concurrent.futures import Future from functools import cached_property -from typing import Dict, Iterable, Optional, Union, Literal +from typing import Dict, Iterable, List, Optional, Union, Literal import pyarrow as pa from lance import json_to_schema @@ -268,6 +268,7 @@ class RemoteTable(Table): query: Union[VEC, str], vector_column_name: Optional[str] = None, query_type="auto", + fts_columns: Optional[Union[str, List[str]]] = None, ) -> LanceVectorQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -338,6 +339,7 @@ class RemoteTable(Table): query, query_type, vector_column_name=vector_column_name, + fts_columns=fts_columns, ) def _execute_query( diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 26ab53a1..46df91c2 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -545,7 +545,7 @@ class Table(ABC): vector_column_name: Optional[str] = None, query_type: str = "auto", ordering_field_name: Optional[str] = None, - fts_columns: Union[str, List[str]] = None, + fts_columns: Optional[Union[str, List[str]]] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -1425,7 +1425,7 @@ class LanceTable(Table): vector_column_name: Optional[str] = None, query_type: str = "auto", ordering_field_name: Optional[str] = None, - fts_columns: Union[str, List[str]] = None, + fts_columns: Optional[Union[str, List[str]]] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -1505,6 +1505,7 @@ class LanceTable(Table): query_type, vector_column_name=vector_column_name, ordering_field_name=ordering_field_name, + fts_columns=fts_columns, ) @classmethod diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index 9cfda85a..54ba9cf4 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -29,14 +29,26 @@ def table(tmp_path) -> ldb.table.LanceTable: db = ldb.connect(tmp_path) vectors = [np.random.randn(128) for _ in range(100)] - nouns = ("puppy", "car", "rabbit", "girl", "monkey") + text_nouns = ("puppy", "car") + text2_nouns = ("rabbit", "girl", "monkey") verbs = ("runs", "hits", "jumps", "drives", "barfs") adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.") adj = ("adorable", "clueless", "dirty", "odd", "stupid") text = [ " ".join( [ - nouns[random.randrange(0, 5)], + text_nouns[random.randrange(0, len(text_nouns))], + verbs[random.randrange(0, 5)], + adv[random.randrange(0, 5)], + adj[random.randrange(0, 5)], + ] + ) + for _ in range(100) + ] + text2 = [ + " ".join( + [ + text2_nouns[random.randrange(0, len(text2_nouns))], verbs[random.randrange(0, 5)], adv[random.randrange(0, 5)], adj[random.randrange(0, 5)], @@ -52,7 +64,7 @@ def table(tmp_path) -> ldb.table.LanceTable: "vector": vectors, "id": [i % 2 for i in range(100)], "text": text, - "text2": text, + "text2": text2, "nested": [{"text": t} for t in text], "count": count, } @@ -66,14 +78,26 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable: db = await ldb.connect_async(tmp_path) vectors = [np.random.randn(128) for _ in range(100)] - nouns = ("puppy", "car", "rabbit", "girl", "monkey") + text_nouns = ("puppy", "car") + text2_nouns = ("rabbit", "girl", "monkey") verbs = ("runs", "hits", "jumps", "drives", "barfs") adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.") adj = ("adorable", "clueless", "dirty", "odd", "stupid") text = [ " ".join( [ - nouns[random.randrange(0, 5)], + text_nouns[random.randrange(0, len(text_nouns))], + verbs[random.randrange(0, 5)], + adv[random.randrange(0, 5)], + adj[random.randrange(0, 5)], + ] + ) + for _ in range(100) + ] + text2 = [ + " ".join( + [ + text2_nouns[random.randrange(0, len(text2_nouns))], verbs[random.randrange(0, 5)], adv[random.randrange(0, 5)], adj[random.randrange(0, 5)], @@ -89,7 +113,7 @@ async def async_table(tmp_path) -> ldb.table.AsyncTable: "vector": vectors, "id": [i % 2 for i in range(100)], "text": text, - "text2": text, + "text2": text2, "nested": [{"text": t} for t in text], "count": count, } @@ -142,12 +166,81 @@ def test_search_fts(table, use_tantivy): assert len(results) == 5 +def test_search_fts_specify_column(table): + table.create_fts_index("text", use_tantivy=False) + table.create_fts_index("text2", use_tantivy=False) + + results = table.search("puppy", fts_columns="text").limit(5).to_list() + assert len(results) == 5 + + results = table.search("rabbit", fts_columns="text2").limit(5).to_list() + assert len(results) == 5 + + try: + # we can only specify one column for now + table.search("puppy", fts_columns=["text", "text2"]).limit(5).to_list() + assert False + except Exception: + pass + + try: + # have to specify a column because we have two fts indices + table.search("puppy").limit(5).to_list() + assert False + except Exception: + pass + + +@pytest.mark.asyncio async def test_search_fts_async(async_table): + async_table = await async_table await async_table.create_index("text", config=FTS()) results = await async_table.query().nearest_to_text("puppy").limit(5).to_list() assert len(results) == 5 +@pytest.mark.asyncio +async def test_search_fts_specify_column_async(async_table): + async_table = await async_table + await async_table.create_index("text", config=FTS()) + await async_table.create_index("text2", config=FTS()) + + results = ( + await async_table.query() + .nearest_to_text("puppy", columns="text") + .limit(5) + .to_list() + ) + assert len(results) == 5 + + results = ( + await async_table.query() + .nearest_to_text("rabbit", columns="text2") + .limit(5) + .to_list() + ) + assert len(results) == 5 + + try: + # we can only specify one column for now + await ( + async_table.query() + .nearest_to_text("rabbit", columns="text2") + .limit(5) + .to_list() + ) + assert False + except Exception: + pass + + try: + # have to specify a column because we have two fts indices + await async_table.query().nearest_to_text("puppy").limit(5).to_list() + assert False + except Exception: + pass + + def test_search_ordering_field_index_table(tmp_path, table): table.create_fts_index("text", ordering_field_names=["count"], use_tantivy=True) rows = ( diff --git a/python/src/index.rs b/python/src/index.rs index 884b2987..5a857561 100644 --- a/python/src/index.rs +++ b/python/src/index.rs @@ -98,6 +98,13 @@ impl Index { inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))), }) } + + #[staticmethod] + pub fn fts() -> PyResult { + Ok(Self { + inner: Mutex::new(Some(LanceDbIndex::FTS(Default::default()))), + }) + } } #[pyclass(get_all)]