diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index a143e308..c4642637 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -74,6 +74,7 @@ class Query: def select(self, columns: Tuple[str, str]): ... def limit(self, limit: int): ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... + def nearest_to_text(self, query: dict) -> Query: ... async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ... class VectorQuery: diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 50046080..1c77b299 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -276,6 +276,10 @@ class DBConnection(EnforceOverrides): """ raise NotImplementedError + @property + def uri(self) -> str: + return self._uri + class LanceDBConnection(DBConnection): """ @@ -340,10 +344,6 @@ class LanceDBConnection(DBConnection): val += ")" return val - @property - def uri(self) -> str: - return self._uri - async def _async_get_table_names(self, start_after: Optional[str], limit: int): conn = AsyncConnection(await lancedb_connect(self.uri)) return await conn.table_names(start_after=start_after, limit=limit) diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index f9dd7900..2e0c7b95 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -70,6 +70,18 @@ class LabelList: self._inner = LanceDbIndex.label_list() +class FTS: + """Describe a FTS index configuration. + + `FTS` is a full-text search index that can be used on `String` columns + + For example, it works with `title`, `description`, `content`, etc. + """ + + def __init__(self): + self._inner = LanceDbIndex.fts() + + class IvfPq: """Describes an IVF PQ Index diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 8564575f..874a606a 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -15,7 +15,6 @@ from __future__ import annotations from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from pathlib import Path from typing import ( TYPE_CHECKING, Dict, @@ -38,7 +37,7 @@ from .arrow import AsyncRecordBatchReader from .common import VEC from .rerankers.base import Reranker from .rerankers.linear_combination import LinearCombinationReranker -from .util import fs_from_uri, safe_import_pandas +from .util import safe_import_pandas if TYPE_CHECKING: import PIL @@ -174,7 +173,9 @@ class LanceQueryBuilder(ABC): if isinstance(query, str): # fts return LanceFtsQueryBuilder( - table, query, ordering_field_name=ordering_field_name + table, + query, + ordering_field_name=ordering_field_name, ) if isinstance(query, list): @@ -681,6 +682,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): self._phrase_query = False self.ordering_field_name = ordering_field_name self._reranker = None + if isinstance(fts_columns, str): + fts_columns = [fts_columns] self._fts_columns = fts_columns def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder: @@ -701,8 +704,8 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): return self def to_arrow(self) -> pa.Table: - tantivy_index_path = self._table._get_fts_index_path() - if Path(tantivy_index_path).exists(): + path, fs, exist = self._table._get_fts_index_path() + if exist: return self.tantivy_to_arrow() query = self._query @@ -711,23 +714,20 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): "Phrase query is not yet supported in Lance FTS. 
" "Use tantivy-based index instead for now." ) - if self._reranker: - raise NotImplementedError( - "Reranking is not yet supported in Lance FTS. " - "Use tantivy-based index instead for now." - ) - ds = self._table.to_lance() - return ds.to_table( + query = Query( columns=self._columns, filter=self._where, - limit=self._limit, + k=self._limit, prefilter=self._prefilter, with_row_id=self._with_row_id, full_text_query={ "query": query, "columns": self._fts_columns, }, + vector=[], ) + results = self._table._execute_query(query) + return results.read_all() def tantivy_to_arrow(self) -> pa.Table: try: @@ -740,24 +740,24 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): from .fts import search_index # get the index path - index_path = self._table._get_fts_index_path() - - # Check that we are on local filesystem - fs, _path = fs_from_uri(index_path) - if not isinstance(fs, pa_fs.LocalFileSystem): - raise NotImplementedError( - "Full-text search is only supported on the local filesystem" - ) + path, fs, exist = self._table._get_fts_index_path() # check if the index exist - if not Path(index_path).exists(): + if not exist: raise FileNotFoundError( "Fts index does not exist. " "Please first call table.create_fts_index(['']) to " "create the fts index." ) + + # Check that we are on local filesystem + if not isinstance(fs, pa_fs.LocalFileSystem): + raise NotImplementedError( + "Tantivy-based full text search " + "is only supported on the local filesystem" + ) # open the index - index = tantivy.Index.open(index_path) + index = tantivy.Index.open(path) # get the scores and doc ids query = self._query if self._phrase_query: @@ -851,7 +851,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): def __init__(self, table: "Table", query: str, vector_column: str): super().__init__(table) - self._validate_fts_index() vector_query, fts_query = self._validate_query(query) self._fts_query = LanceFtsQueryBuilder(table, fts_query) vector_query = self._query_to_vector(table, vector_query, vector_column) @@ -859,12 +858,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._norm = "score" self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0) - def _validate_fts_index(self): - if self._table._get_fts_index_path() is None: - raise ValueError( - "Please create a full-text search index " "to perform hybrid search." - ) - def _validate_query(self, query): # Temp hack to support vectorized queries for hybrid search if isinstance(query, str): @@ -1354,6 +1347,35 @@ class AsyncQuery(AsyncQueryBase): self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)) ) + def nearest_to_text( + self, query: str, columns: Union[str, List[str]] = None + ) -> AsyncQuery: + """ + Find the documents that are most relevant to the given text query. + + This method will perform a full text search on the table and return + the most relevant documents. The relevance is determined by BM25. + + The columns to search must be with native FTS index + (Tantivy-based can't work with this method). + + By default, all indexed columns are searched, + now only one column can be searched at a time. + + Parameters + ---------- + query: str + The text query to search for. + columns: str or list of str, default None + The columns to search in. If None, all indexed columns are searched. + For now only one column can be searched at a time. 
+ """ + if isinstance(columns, str): + columns = [columns] + return AsyncQuery( + self._inner.nearest_to_text({"query": query, "columns": columns}) + ) + class AsyncVectorQuery(AsyncQueryBase): def __init__(self, inner: LanceVectorQuery): diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 6f51f79e..0dd6bb6d 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -49,6 +49,7 @@ class RemoteDBConnection(DBConnection): parsed = urlparse(db_url) if parsed.scheme != "db": raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") + self._uri = str(db_url) self.db_name = parsed.netloc self.api_key = api_key self._client = RestfulLanceDBClient( diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 3d1669ab..596e7b81 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -35,10 +35,10 @@ from .db import RemoteDBConnection class RemoteTable(Table): def __init__(self, conn: RemoteDBConnection, name: str): self._conn = conn - self._name = name + self.name = name def __repr__(self) -> str: - return f"RemoteTable({self._conn.db_name}.{self._name})" + return f"RemoteTable({self._conn.db_name}.{self.name})" def __len__(self) -> int: self.count_rows(None) @@ -49,14 +49,14 @@ class RemoteTable(Table): of this Table """ - resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") + resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") schema = json_to_schema(resp["schema"]) return schema @property def version(self) -> int: """Get the current version of the table""" - resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") + resp = self._conn._client.post(f"/v1/table/{self.name}/describe/") return resp["version"] @cached_property @@ -84,13 +84,13 @@ class RemoteTable(Table): def list_indices(self): """List all the indices on the table""" - resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/") + resp = self._conn._client.post(f"/v1/table/{self.name}/index/list/") return resp def index_stats(self, index_uuid: str): """List all the stats of a specified index""" resp = self._conn._client.post( - f"/v1/table/{self._name}/index/{index_uuid}/stats/" + f"/v1/table/{self.name}/index/{index_uuid}/stats/" ) return resp @@ -116,11 +116,27 @@ class RemoteTable(Table): "replace": True, } resp = self._conn._client.post( - f"/v1/table/{self._name}/create_scalar_index/", data=data + f"/v1/table/{self.name}/create_scalar_index/", data=data ) return resp + def create_fts_index( + self, + column: str, + *, + replace: bool = False, + ): + data = { + "column": column, + "index_type": "FTS", + "replace": replace, + } + resp = self._conn._client.post( + f"/v1/table/{self.name}/create_index/", data=data + ) + return resp + def create_index( self, metric="L2", @@ -194,7 +210,7 @@ class RemoteTable(Table): "index_cache_size": index_cache_size, } resp = self._conn._client.post( - f"/v1/table/{self._name}/create_index/", data=data + f"/v1/table/{self.name}/create_index/", data=data ) return resp @@ -241,7 +257,7 @@ class RemoteTable(Table): request_id = uuid.uuid4().hex self._conn._client.post( - f"/v1/table/{self._name}/insert/", + f"/v1/table/{self.name}/insert/", data=payload, params={"request_id": request_id, "mode": mode}, content_type=ARROW_STREAM_CONTENT_TYPE, @@ -251,6 +267,7 @@ class RemoteTable(Table): self, query: Union[VEC, str], vector_column_name: Optional[str] = None, + query_type="auto", ) -> 
LanceVectorQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -310,10 +327,18 @@ class RemoteTable(Table): - and also the "_distance" column which is the distance between the query vector and the returned vector. """ - if vector_column_name is None: - vector_column_name = inf_vector_column_query(self.schema) - query = LanceQueryBuilder._query_to_vector(self, query, vector_column_name) - return LanceVectorQueryBuilder(self, query, vector_column_name) + if vector_column_name is None and query is not None and query_type != "fts": + try: + vector_column_name = inf_vector_column_query(self.schema) + except Exception as e: + raise e + + return LanceQueryBuilder.create( + self, + query, + query_type, + vector_column_name=vector_column_name, + ) def _execute_query( self, query: Query, batch_size: Optional[int] = None @@ -342,12 +367,12 @@ class RemoteTable(Table): v = list(v) q = query.copy() q.vector = v - results.append(submit(self._name, q)) + results.append(submit(self.name, q)) return pa.concat_tables( [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)] ).to_reader() else: - result = self._conn._client.query(self._name, query) + result = self._conn._client.query(self.name, query) return result.to_arrow().to_reader() def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: @@ -397,7 +422,7 @@ class RemoteTable(Table): ) self._conn._client.post( - f"/v1/table/{self._name}/merge_insert/", + f"/v1/table/{self.name}/merge_insert/", data=payload, params=params, content_type=ARROW_STREAM_CONTENT_TYPE, @@ -451,7 +476,7 @@ class RemoteTable(Table): 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ payload = {"predicate": predicate} - self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload) + self._conn._client.post(f"/v1/table/{self.name}/delete/", data=payload) def update( self, @@ -512,7 +537,7 @@ class RemoteTable(Table): updates = [[k, v] for k, v in values_sql.items()] payload = {"predicate": where, "updates": updates} - self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) + self._conn._client.post(f"/v1/table/{self.name}/update/", data=payload) def cleanup_old_versions(self, *_): """cleanup_old_versions() is not supported on the LanceDB cloud""" @@ -529,7 +554,7 @@ class RemoteTable(Table): def count_rows(self, filter: Optional[str] = None) -> int: payload = {"predicate": filter} resp = self._conn._client.post( - f"/v1/table/{self._name}/count_rows/", data=payload + f"/v1/table/{self.name}/count_rows/", data=payload ) return resp diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 18f6f90b..6f89e0f7 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -51,7 +51,7 @@ if TYPE_CHECKING: from lance.dataset import CleanupStats, ReaderLike from ._lancedb import Table as LanceDBTable, OptimizeStats from .db import LanceDBConnection - from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList + from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS pd = safe_import_pandas() @@ -840,6 +840,18 @@ class Table(ABC): The names of the columns to drop. 
""" + @cached_property + def _dataset_uri(self) -> str: + return _table_uri(self._conn.uri, self.name) + + def _get_fts_index_path(self) -> Tuple[str, pa_fs.FileSystem, bool]: + if get_uri_scheme(self._dataset_uri) != "file": + return ("", None, False) + path = join_uri(self._dataset_uri, "_indices", "fts") + fs, path = fs_from_uri(path) + index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound + return (path, fs, index_exists) + class _LanceDatasetRef(ABC): @property @@ -979,10 +991,6 @@ class LanceTable(Table): # Cacheable since it's deterministic return _table_path(self._conn.uri, self.name) - @cached_property - def _dataset_uri(self) -> str: - return _table_uri(self._conn.uri, self.name) - @property def _dataset(self) -> LanceDataset: return self._ref.dataset @@ -1247,9 +1255,8 @@ class LanceTable(Table): raise ValueError("field_names must be a string when use_tantivy=False") # delete the existing legacy index if it exists if replace: - fs, path = fs_from_uri(self._get_fts_index_path()) - index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound - if index_exists: + path, fs, exist = self._get_fts_index_path() + if exist: fs.delete_dir(path) self._dataset_mut.create_scalar_index( field_names, index_type="INVERTED", replace=replace @@ -1264,9 +1271,8 @@ class LanceTable(Table): if isinstance(ordering_field_names, str): ordering_field_names = [ordering_field_names] - fs, path = fs_from_uri(self._get_fts_index_path()) - index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound - if index_exists: + path, fs, exist = self._get_fts_index_path() + if exist: if not replace: raise ValueError("Index already exists. Use replace=True to overwrite.") fs.delete_dir(path) @@ -1277,7 +1283,7 @@ class LanceTable(Table): ) index = create_index( - self._get_fts_index_path(), + path, field_names, ordering_fields=ordering_field_names, tokenizer_name=tokenizer_name, @@ -1290,13 +1296,6 @@ class LanceTable(Table): writer_heap_size=writer_heap_size, ) - def _get_fts_index_path(self): - if get_uri_scheme(self._dataset_uri) != "file": - raise NotImplementedError( - "Full-text search is not supported on object stores." - ) - return join_uri(self._dataset_uri, "_indices", "tantivy") - def add( self, data: DATA, @@ -1492,14 +1491,11 @@ class LanceTable(Table): and also the "_distance" column which is the distance between the query vector and the returned vector. 
""" - if vector_column_name is None and query is not None: + if vector_column_name is None and query is not None and query_type != "fts": try: vector_column_name = inf_vector_column_query(self.schema) except Exception as e: - if query_type == "fts": - vector_column_name = "" - else: - raise e + raise e return LanceQueryBuilder.create( self, @@ -1690,18 +1686,22 @@ class LanceTable(Table): self, query: Query, batch_size: Optional[int] = None ) -> pa.RecordBatchReader: ds = self.to_lance() - return ds.scanner( - columns=query.columns, - filter=query.filter, - prefilter=query.prefilter, - nearest={ + nearest = None + if len(query.vector) > 0: + nearest = { "column": query.vector_column, "q": query.vector, "k": query.k, "metric": query.metric, "nprobes": query.nprobes, "refine_factor": query.refine_factor, - }, + } + return ds.scanner( + columns=query.columns, + limit=query.k, + filter=query.filter, + prefilter=query.prefilter, + nearest=nearest, full_text_query=query.full_text_query, with_row_id=query.with_row_id, batch_size=batch_size, @@ -2126,7 +2126,7 @@ class AsyncTable: column: str, *, replace: Optional[bool] = None, - config: Optional[Union[IvfPq, BTree, Bitmap, LabelList]] = None, + config: Optional[Union[IvfPq, BTree, Bitmap, LabelList, FTS]] = None, ): """Create an index to speed up queries diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index f4c7cd1c..9cfda85a 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -15,6 +15,7 @@ import random from unittest import mock import lancedb as ldb +from lancedb.index import FTS import numpy as np import pandas as pd import pytest @@ -60,6 +61,43 @@ def table(tmp_path) -> ldb.table.LanceTable: return table +@pytest.fixture +async def async_table(tmp_path) -> ldb.table.AsyncTable: + db = await ldb.connect_async(tmp_path) + vectors = [np.random.randn(128) for _ in range(100)] + + nouns = ("puppy", "car", "rabbit", "girl", "monkey") + verbs = ("runs", "hits", "jumps", "drives", "barfs") + adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.") + adj = ("adorable", "clueless", "dirty", "odd", "stupid") + text = [ + " ".join( + [ + nouns[random.randrange(0, 5)], + verbs[random.randrange(0, 5)], + adv[random.randrange(0, 5)], + adj[random.randrange(0, 5)], + ] + ) + for _ in range(100) + ] + count = [random.randint(1, 10000) for _ in range(100)] + table = await db.create_table( + "test", + data=pd.DataFrame( + { + "vector": vectors, + "id": [i % 2 for i in range(100)], + "text": text, + "text2": text, + "nested": [{"text": t} for t in text], + "count": count, + } + ), + ) + return table + + def test_create_index(tmp_path): index = ldb.fts.create_index(str(tmp_path / "index"), ["text"]) assert isinstance(index, tantivy.Index) @@ -91,17 +129,23 @@ def test_search_index(tmp_path, table): index = ldb.fts.create_index(str(tmp_path / "index"), ["text"]) ldb.fts.populate_index(index, table, ["text"]) index.reload() - results = ldb.fts.search_index(index, query="puppy", limit=10) + results = ldb.fts.search_index(index, query="puppy", limit=5) assert len(results) == 2 - assert len(results[0]) == 10 # row_ids - assert len(results[1]) == 10 # _distance + assert len(results[0]) == 5 # row_ids + assert len(results[1]) == 5 # _score @pytest.mark.parametrize("use_tantivy", [True, False]) def test_search_fts(table, use_tantivy): table.create_fts_index("text", use_tantivy=use_tantivy) - results = table.search("puppy").limit(10).to_list() - assert len(results) == 10 + results = 
table.search("puppy").limit(5).to_list() + assert len(results) == 5 + + +async def test_search_fts_async(async_table): + await async_table.create_index("text", config=FTS()) + results = await async_table.query().nearest_to_text("puppy").limit(5).to_list() + assert len(results) == 5 def test_search_ordering_field_index_table(tmp_path, table): @@ -125,11 +169,11 @@ def test_search_ordering_field_index(tmp_path, table): ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"]) index.reload() results = ldb.fts.search_index( - index, query="puppy", limit=10, ordering_field="count" + index, query="puppy", limit=5, ordering_field="count" ) assert len(results) == 2 - assert len(results[0]) == 10 # row_ids - assert len(results[1]) == 10 # _distance + assert len(results[0]) == 5 # row_ids + assert len(results[1]) == 5 # _distance rows = table.to_lance().take(results[0]).to_pylist() for r in rows: @@ -140,8 +184,8 @@ def test_search_ordering_field_index(tmp_path, table): @pytest.mark.parametrize("use_tantivy", [True, False]) def test_create_index_from_table(tmp_path, table, use_tantivy): table.create_fts_index("text", use_tantivy=use_tantivy) - df = table.search("puppy").limit(10).select(["text"]).to_pandas() - assert len(df) <= 10 + df = table.search("puppy").limit(5).select(["text"]).to_pandas() + assert len(df) <= 5 assert "text" in df.columns # Check whether it can be updated @@ -167,8 +211,8 @@ def test_create_index_from_table(tmp_path, table, use_tantivy): def test_create_index_multiple_columns(tmp_path, table): table.create_fts_index(["text", "text2"], use_tantivy=True) - df = table.search("puppy").limit(10).to_pandas() - assert len(df) == 10 + df = table.search("puppy").limit(5).to_pandas() + assert len(df) == 5 assert "text" in df.columns assert "text2" in df.columns @@ -176,14 +220,14 @@ def test_create_index_multiple_columns(tmp_path, table): def test_empty_rs(tmp_path, table, mocker): table.create_fts_index(["text", "text2"], use_tantivy=True) mocker.patch("lancedb.fts.search_index", return_value=([], [])) - df = table.search("puppy").limit(10).to_pandas() + df = table.search("puppy").limit(5).to_pandas() assert len(df) == 0 def test_nested_schema(tmp_path, table): table.create_fts_index("nested.text", use_tantivy=True) - rs = table.search("puppy").limit(10).to_list() - assert len(rs) == 10 + rs = table.search("puppy").limit(5).to_list() + assert len(rs) == 5 @pytest.mark.parametrize("use_tantivy", [True, False]) diff --git a/python/python/tests/test_s3.py b/python/python/tests/test_s3.py index 2b6ed38a..85b72749 100644 --- a/python/python/tests/test_s3.py +++ b/python/python/tests/test_s3.py @@ -251,7 +251,8 @@ def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch): # FTS indices should error since they are not supported yet. with pytest.raises( - NotImplementedError, match="Full-text search is not supported on object stores." 
+ NotImplementedError, + match="Full-text search is only supported on the local filesystem", ): table.create_fts_index("x") diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 0d6beeb4..6ca2f5f1 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -28,7 +28,7 @@ from pydantic import BaseModel class MockDB: def __init__(self, uri: Path): - self.uri = uri + self.uri = str(uri) self.read_consistency_interval = None @functools.cached_property diff --git a/python/src/query.rs b/python/src/query.rs index 471f686b..f88e60b4 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -15,17 +15,20 @@ use arrow::array::make_array; use arrow::array::ArrayData; use arrow::pyarrow::FromPyArrow; +use lancedb::index::scalar::FullTextSearchQuery; use lancedb::query::QueryExecutionOptions; use lancedb::query::{ ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, }; use pyo3::exceptions::PyRuntimeError; -use pyo3::pyclass; +use pyo3::prelude::{PyAnyMethods, PyDictMethods}; use pyo3::pymethods; +use pyo3::types::PyDict; use pyo3::Bound; use pyo3::PyAny; use pyo3::PyRef; use pyo3::PyResult; +use pyo3::{pyclass, PyErr}; use pyo3_asyncio_0_21::tokio::future_into_py; use crate::arrow::RecordBatchStream; @@ -68,6 +71,24 @@ impl Query { Ok(VectorQuery { inner }) } + pub fn nearest_to_text(&mut self, query: Bound<'_, PyDict>) -> PyResult<()> { + let query_text = query + .get_item("query")? + .ok_or(PyErr::new::( + "Query text is required for nearest_to_text", + ))? + .extract::()?; + let columns = query + .get_item("columns")? + .map(|columns| columns.extract::>()) + .transpose()?; + + let fts_query = FullTextSearchQuery::new(query_text).columns(columns); + self.inner = self.inner.clone().full_text_search(fts_query); + + Ok(()) + } + pub fn execute( self_: PyRef<'_, Self>, max_batch_length: Option,
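
For reference, a minimal end-to-end sketch of the two full-text search paths
exercised by the new tests; the table names, sample rows, and local URI below
are illustrative placeholders:

# Sketch of the sync path (create_fts_index + search) and the async path
# (FTS() index config + nearest_to_text); names and data are illustrative.
import asyncio

import lancedb
from lancedb.index import FTS

DATA = [
    {"vector": [0.0, 1.0], "text": "a puppy runs merrily"},
    {"vector": [1.0, 0.0], "text": "a car drives crazily"},
]


def sync_fts(uri: str) -> list:
    # Sync API: build a native (Lance) FTS index, then search by text.
    db = lancedb.connect(uri)
    table = db.create_table("docs", data=DATA, mode="overwrite")
    table.create_fts_index("text", use_tantivy=False)
    return table.search("puppy").limit(5).to_list()


async def async_fts(uri: str) -> list:
    # Async API: create the index via the FTS() config, then query with
    # the new nearest_to_text() builder method.
    db = await lancedb.connect_async(uri)
    table = await db.create_table("docs_async", data=DATA, mode="overwrite")
    await table.create_index("text", config=FTS())
    return await table.query().nearest_to_text("puppy").limit(5).to_list()


if __name__ == "__main__":
    print(sync_fts("/tmp/lancedb_fts_demo"))
    print(asyncio.run(async_fts("/tmp/lancedb_fts_demo")))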