From f9d5fa88a1cd2391c88d9f09123e056f181bec74 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 8 Aug 2024 15:33:15 +0800 Subject: [PATCH] feat!: migrate FTS from tantivy to lance-index (#1483) Lance now supports FTS, so add it into lancedb Python, TypeScript and Rust SDKs. For Python, we still use tantivy based FTS by default because the lance FTS index now misses some features of tantivy. For Python: - Support to create lance based FTS index - Support to specify columns for full text search (only available for lance based FTS index) For TypeScript: - Change the search method so that it can accept both string and vector - Support full text search For Rust - Support full text search The others: - Update the FTS doc BREAKING CHANGE: - for Python, this renames the attached score column of FTS from "score" to "_score", this could be a breaking change for users that rely the scores --------- Signed-off-by: BubbleCal --- docs/src/fts.md | 183 +++++++++++++----- nodejs/__test__/table.test.ts | 17 +- nodejs/examples/full_text_search.ts | 52 +++++ nodejs/examples/package-lock.json | 42 ++-- nodejs/examples/package.json | 13 +- nodejs/examples/search.ts | 1 + nodejs/lancedb/embedding/registry.ts | 7 + nodejs/lancedb/indices.ts | 16 ++ nodejs/lancedb/query.ts | 32 +++ nodejs/lancedb/table.ts | 86 +++++--- nodejs/package-lock.json | 9 +- nodejs/package.json | 2 +- nodejs/src/index.rs | 9 +- nodejs/src/query.rs | 13 ++ nodejs/tsconfig.json | 3 +- python/python/lancedb/query.py | 54 +++++- python/python/lancedb/rerankers/base.py | 4 +- python/python/lancedb/rerankers/cohere.py | 2 +- python/python/lancedb/rerankers/colbert.py | 2 +- .../python/lancedb/rerankers/cross_encoder.py | 2 +- python/python/lancedb/rerankers/jinaai.py | 2 +- .../lancedb/rerankers/linear_combination.py | 4 +- python/python/lancedb/rerankers/openai.py | 2 +- python/python/lancedb/table.py | 43 +++- python/python/tests/test_db.py | 5 +- python/python/tests/test_fts.py | 47 +++-- python/python/tests/test_rerankers.py | 18 +- rust/lancedb/Cargo.toml | 1 + rust/lancedb/examples/full_text_search.rs | 114 +++++++++++ rust/lancedb/src/connection.rs | 2 +- rust/lancedb/src/index.rs | 2 + rust/lancedb/src/index/scalar.rs | 10 + rust/lancedb/src/query.rs | 23 +++ rust/lancedb/src/table.rs | 36 ++++ 34 files changed, 713 insertions(+), 145 deletions(-) create mode 100644 nodejs/examples/full_text_search.ts create mode 100644 rust/lancedb/examples/full_text_search.rs diff --git a/docs/src/fts.md b/docs/src/fts.md index 2330e837..50e1cebb 100644 --- a/docs/src/fts.md +++ b/docs/src/fts.md @@ -1,9 +1,14 @@ # Full-text search -LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for Rust and JavaScript users as well. Follow along at [this Github issue](https://github.com/lancedb/lance/issues/1195) +LanceDB provides support for full-text search via Lance (before via [Tantivy](https://github.com/quickwit-oss/tantivy) (Python only)), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. + +Currently, the Lance full text search is missing some features that are in the Tantivy full text search. This includes phrase queries, re-ranking, and customizing the tokenizer. Thus, in Python, Tantivy is still the default way to do full text search and many of the instructions below apply just to Tantivy-based indices. -## Installation +## Installation (Only for Tantivy-based FTS) + +!!! note + No need to install the tantivy dependency if using native FTS To use full-text search, install the dependency [`tantivy-py`](https://github.com/quickwit-oss/tantivy-py): @@ -14,42 +19,83 @@ pip install tantivy==0.20.1 ## Example -Consider that we have a LanceDB table named `my_table`, whose string column `text` we want to index and query via keyword search. +Consider that we have a LanceDB table named `my_table`, whose string column `text` we want to index and query via keyword search, the FTS index must be created before you can search via keywords. -```python -import lancedb +=== "Python" -uri = "data/sample-lancedb" -db = lancedb.connect(uri) + ```python + import lancedb -table = db.create_table( - "my_table", - data=[ - {"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"}, - {"vector": [5.9, 26.5], "text": "There are several kittens playing"}, - ], -) -``` + uri = "data/sample-lancedb" + db = lancedb.connect(uri) -## Create FTS index on single column + table = db.create_table( + "my_table", + data=[ + {"vector": [3.1, 4.1], "text": "Frodo was a happy puppy"}, + {"vector": [5.9, 26.5], "text": "There are several kittens playing"}, + ], + ) -The FTS index must be created before you can search via keywords. + # passing `use_tantivy=False` to use lance FTS index + # `use_tantivy=True` by default + table.create_fts_index("text") + table.search("puppy").limit(10).select(["text"]).to_list() + # [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}] + # ... + ``` -```python -table.create_fts_index("text") -``` +=== "TypeScript" -To search an FTS index via keywords, LanceDB's `table.search` accepts a string as input: + ```typescript + import * as lancedb from "@lancedb/lancedb"; + const uri = "data/sample-lancedb" + const db = await lancedb.connect(uri); -```python -table.search("puppy").limit(10).select(["text"]).to_list() -``` + const data = [ + { vector: [3.1, 4.1], text: "Frodo was a happy puppy" }, + { vector: [5.9, 26.5], text: "There are several kittens playing" }, + ]; + const tbl = await db.createTable("my_table", data, { mode: "overwrite" }); + await tbl.createIndex("text", { + config: lancedb.Index.fts(), + }); -This returns the result as a list of dictionaries as follows. + await tbl + .search("puppy") + .select(["text"]) + .limit(10) + .toArray(); + ``` -```python -[{'text': 'Frodo was a happy puppy', 'score': 0.6931471824645996}] -``` +=== "Rust" + + ```rust + let uri = "data/sample-lancedb"; + let db = connect(uri).execute().await?; + let initial_data: Box = create_some_records()?; + let tbl = db + .create_table("my_table", initial_data) + .execute() + .await?; + tbl + .create_index(&["text"], Index::FTS(FtsIndexBuilder::default())) + .execute() + .await?; + + tbl + .query() + .full_text_search(FullTextSearchQuery::new("puppy".to_owned())) + .select(lancedb::query::Select::Columns(vec!["text".to_owned()])) + .limit(10) + .execute() + .await?; + ``` + +It would search on all indexed columns by default, so it's useful when there are multiple indexed columns. +For now, this is supported in tantivy way only. + +Passing `fts_columns="text"` if you want to specify the columns to search, but it's not available for Tantivy-based full text search. !!! note LanceDB automatically searches on the existing FTS index if the input to the search is of type `str`. If you provide a vector as input, LanceDB will search the ANN index instead. @@ -57,20 +103,33 @@ This returns the result as a list of dictionaries as follows. ## Tokenization By default the text is tokenized by splitting on punctuation and whitespaces and then removing tokens that are longer than 40 chars. For more language specific tokenization then provide the argument tokenizer_name with the 2 letter language code followed by "_stem". So for english it would be "en_stem". -```python -table.create_fts_index("text", tokenizer_name="en_stem") -``` +For now, only the Tantivy-based FTS index supports to specify the tokenizer, so it's only available in Python with `use_tantivy=True`. -The following [languages](https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html) are currently supported. +=== "use_tantivy=True" + ```python + table.create_fts_index("text", use_tantivy=True, tokenizer_name="en_stem") + ``` + +=== "use_tantivy=False" + + [**Not supported yet**](https://github.com/lancedb/lance/issues/1195) + +the following [languages](https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html) are currently supported. ## Index multiple columns If you have multiple string columns to index, there's no need to combine them manually -- simply pass them all as a list to `create_fts_index`: -```python -table.create_fts_index(["text1", "text2"]) -``` +=== "use_tantivy=True" + + ```python + table.create_fts_index(["text1", "text2"]) + ``` + +=== "use_tantivy=False" + + [**Not supported yet**](https://github.com/lancedb/lance/issues/1195) Note that the search API call does not change - you can search over all indexed columns at once. @@ -80,19 +139,48 @@ Currently the LanceDB full text search feature supports *post-filtering*, meanin applied on top of the full text search results. This can be invoked via the familiar `where` syntax: -```python -table.search("puppy").limit(10).where("meta='foo'").to_list() -``` +=== "Python" + + ```python + table.search("puppy").limit(10).where("meta='foo'").to_list() + ``` + +=== "TypeScript" + + ```typescript + await tbl + .search("apple") + .select(["id", "doc"]) + .limit(10) + .where("meta='foo'") + .toArray(); + ``` + +=== "Rust" + + ```rust + table + .query() + .full_text_search(FullTextSearchQuery::new(words[0].to_owned())) + .select(lancedb::query::Select::Columns(vec!["doc".to_owned()])) + .limit(10) + .only_if("meta='foo'") + .execute() + .await?; + ``` ## Sorting +!!! warning "Warn" + Sorting is available for only Tantivy-based FTS + You can pre-sort the documents by specifying `ordering_field_names` when creating the full-text search index. Once pre-sorted, you can then specify `ordering_field_name` while searching to return results sorted by the given -field. For example, +field. For example, -``` -table.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"]) +```python +table.create_fts_index(["text_field"], use_tantivy=True, ordering_field_names=["sort_by_field"]) (table.search("terms", ordering_field_name="sort_by_field") .limit(20) @@ -105,8 +193,8 @@ table.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"]) error will be raised that looks like `ValueError: The field does not exist: xxx` !!! note - The fields to sort on must be of typed unsigned integer, or else you will see - an error during indexing that looks like + The fields to sort on must be of typed unsigned integer, or else you will see + an error during indexing that looks like `TypeError: argument 'value': 'float' object cannot be interpreted as an integer`. !!! note @@ -116,6 +204,9 @@ table.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"]) ## Phrase queries vs. terms queries +!!! warning "Warn" + Phrase queries are available for only Tantivy-based FTS + For full-text search you can specify either a **phrase** query like `"the old man and the sea"`, or a **terms** search query like `"(Old AND Man) AND Sea"`. For more details on the terms query syntax, see Tantivy's [query parser rules](https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html). @@ -142,7 +233,7 @@ enforce it in one of two ways: 1. Place the double-quoted query inside single quotes. For example, `table.search('"they could have been dogs OR cats"')` is treated as a phrase query. -2. Explicitly declare the `phrase_query()` method. This is useful when you have a phrase query that +1. Explicitly declare the `phrase_query()` method. This is useful when you have a phrase query that itself contains double quotes. For example, `table.search('the cats OR dogs were not really "pets" at all').phrase_query()` is treated as a phrase query. @@ -150,7 +241,7 @@ In general, a query that's declared as a phrase query will be wrapped in double double quotes replaced by single quotes. -## Configurations +## Configurations (Only for Tantivy-based FTS) By default, LanceDB configures a 1GB heap size limit for creating the index. You can reduce this if running on a smaller node, or increase this for faster performance while @@ -164,6 +255,8 @@ table.create_fts_index(["text1", "text2"], writer_heap_size=heap, replace=True) ## Current limitations +For that Tantivy-based FTS: + 1. Currently we do not yet support incremental writes. If you add data after FTS index creation, it won't be reflected in search results until you do a full reindex. diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 5aba068b..f95bf121 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -785,11 +785,26 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( ]; const table = await db.createTable("test", data); - expect(table.search("hello").toArray()).rejects.toThrow( + expect(table.search("hello", "vector").toArray()).rejects.toThrow( "No embedding functions are defined in the table", ); }); + test("full text search if no embedding function provided", async () => { + const db = await connect(tmpDir.name); + const data = [ + { text: "hello world", vector: [0.1, 0.2, 0.3] }, + { text: "goodbye world", vector: [0.4, 0.5, 0.6] }, + ]; + const table = await db.createTable("test", data); + await table.createIndex("text", { + config: Index.fts(), + }); + + const results = await table.search("hello").toArray(); + expect(results[0].text).toBe(data[0].text); + }); + test.each([ [0.4, 0.5, 0.599], // number[] Float32Array.of(0.4, 0.5, 0.599), // Float32Array diff --git a/nodejs/examples/full_text_search.ts b/nodejs/examples/full_text_search.ts new file mode 100644 index 00000000..0fcc10a7 --- /dev/null +++ b/nodejs/examples/full_text_search.ts @@ -0,0 +1,52 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import * as lancedb from "@lancedb/lancedb"; + +const db = await lancedb.connect("data/sample-lancedb"); + +const words = [ + "apple", + "banana", + "cherry", + "date", + "elderberry", + "fig", + "grape", +]; + +const data = Array.from({ length: 10_000 }, (_, i) => ({ + vector: Array(1536).fill(i), + id: i, + item: `item ${i}`, + strId: `${i}`, + doc: words[i % words.length], +})); + +const tbl = await db.createTable("myVectors", data, { mode: "overwrite" }); + +await tbl.createIndex("doc", { + config: lancedb.Index.fts(), +}); + +// --8<-- [start:full_text_search] +let result = await tbl + .search("apple") + .select(["id", "doc"]) + .limit(10) + .toArray(); +console.log(result); +// --8<-- [end:full_text_search] + +console.log("SQL search: done"); diff --git a/nodejs/examples/package-lock.json b/nodejs/examples/package-lock.json index 8c2b2b26..3ff4d029 100644 --- a/nodejs/examples/package-lock.json +++ b/nodejs/examples/package-lock.json @@ -10,7 +10,11 @@ "license": "Apache-2.0", "dependencies": { "@lancedb/lancedb": "file:../", - "@xenova/transformers": "^2.17.2" + "@xenova/transformers": "^2.17.2", + "tsc": "^2.0.4" + }, + "devDependencies": { + "typescript": "^5.5.4" }, "peerDependencies": { "typescript": "^5.0.0" @@ -18,7 +22,7 @@ }, "..": { "name": "@lancedb/lancedb", - "version": "0.7.1", + "version": "0.8.0", "cpu": [ "x64", "arm64" @@ -43,26 +47,30 @@ "@types/axios": "^0.14.0", "@types/jest": "^29.1.2", "@types/tmp": "^0.2.6", - "apache-arrow-old": "npm:apache-arrow@13.0.0", + "apache-arrow-13": "npm:apache-arrow@13.0.0", + "apache-arrow-14": "npm:apache-arrow@14.0.0", + "apache-arrow-15": "npm:apache-arrow@15.0.0", + "apache-arrow-16": "npm:apache-arrow@16.0.0", + "apache-arrow-17": "npm:apache-arrow@17.0.0", "eslint": "^8.57.0", "jest": "^29.7.0", "shx": "^0.3.4", "tmp": "^0.2.3", "ts-jest": "^29.1.2", - "typedoc": "^0.25.7", - "typedoc-plugin-markdown": "^3.17.1", - "typescript": "^5.3.3", + "typedoc": "^0.26.4", + "typedoc-plugin-markdown": "^4.2.1", + "typescript": "^5.5.4", "typescript-eslint": "^7.1.0" }, "engines": { "node": ">= 18" }, "optionalDependencies": { - "@xenova/transformers": "^2.17.2", + "@xenova/transformers": ">=2.17 < 3", "openai": "^4.29.2" }, "peerDependencies": { - "apache-arrow": "^15.0.0" + "apache-arrow": ">=13.0.0 <=17.0.0" } }, "node_modules/@huggingface/jinja": { @@ -785,6 +793,15 @@ "b4a": "^1.6.4" } }, + "node_modules/tsc": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/tsc/-/tsc-2.0.4.tgz", + "integrity": "sha512-fzoSieZI5KKJVBYGvwbVZs/J5za84f2lSTLPYf6AGiIf43tZ3GNrI1QzTLcjtyDDP4aLxd46RTZq1nQxe7+k5Q==", + "license": "MIT", + "bin": { + "tsc": "bin/tsc" + } + }, "node_modules/tunnel-agent": { "version": "0.6.0", "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz", @@ -797,10 +814,11 @@ } }, "node_modules/typescript": { - "version": "5.5.2", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.5.2.tgz", - "integrity": "sha512-NcRtPEOsPFFWjobJEtfihkLCZCXZt/os3zf8nTxjVH3RvTSxjrCamJpbExGvYOF+tFHc3pA65qpdwPbzjohhew==", - "peer": true, + "version": "5.5.4", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.5.4.tgz", + "integrity": "sha512-Mtq29sKDAEYP7aljRgtPOpTvOfbwRWlS6dPRzwjdE+C0R4brX/GUyhHSecbHMFLNBLcJIPt9nl9yG5TZ1weH+Q==", + "dev": true, + "license": "Apache-2.0", "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/nodejs/examples/package.json b/nodejs/examples/package.json index ecc1b44f..9f89287a 100644 --- a/nodejs/examples/package.json +++ b/nodejs/examples/package.json @@ -13,7 +13,16 @@ "@lancedb/lancedb": "file:../", "@xenova/transformers": "^2.17.2" }, - "peerDependencies": { - "typescript": "^5.0.0" + "devDependencies": { + "typescript": "^5.5.4" + }, + "compilerOptions": { + "target": "ESNext", + "module": "ESNext", + "moduleResolution": "Node", + "strict": true, + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true } } diff --git a/nodejs/examples/search.ts b/nodejs/examples/search.ts index 07f4323a..80d64795 100644 --- a/nodejs/examples/search.ts +++ b/nodejs/examples/search.ts @@ -32,6 +32,7 @@ const _results2 = await tbl .distanceType("cosine") .limit(10) .toArray(); +console.log(_results2); // --8<-- [end:search2] console.log("search: done"); diff --git a/nodejs/lancedb/embedding/registry.ts b/nodejs/lancedb/embedding/registry.ts index b1f62be4..a5af3523 100644 --- a/nodejs/lancedb/embedding/registry.ts +++ b/nodejs/lancedb/embedding/registry.ts @@ -37,6 +37,13 @@ interface EmbeddingFunctionCreate { export class EmbeddingFunctionRegistry { #functions = new Map(); + /** + * Get the number of registered functions + */ + length() { + return this.#functions.size; + } + /** * Register an embedding function * @param name The name of the function diff --git a/nodejs/lancedb/indices.ts b/nodejs/lancedb/indices.ts index cf5fab63..37aa3c5d 100644 --- a/nodejs/lancedb/indices.ts +++ b/nodejs/lancedb/indices.ts @@ -175,6 +175,22 @@ export class Index { static btree() { return new Index(LanceDbIndex.btree()); } + + /** + * Create a full text search index + * + * A full text search index is an index on a string column, so that you can conduct full + * text searches on the column. + * + * The results of a full text search are ordered by relevance measured by BM25. + * + * You can combine filters with full text search. + * + * For now, the full text search index only supports English, and doesn't support phrase search. + */ + static fts() { + return new Index(LanceDbIndex.fts()); + } } export interface IndexOptions { diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 0f52acc9..c96d0c8e 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -88,6 +88,19 @@ export interface QueryExecutionOptions { maxBatchLength?: number; } +/** + * Options that control the behavior of a full text search + */ +export interface FullTextSearchOptions { + /** + * The columns to search + * + * If not specified, all indexed columns will be searched. + * For now, only one column can be searched. + */ + columns?: string | string[]; +} + /** Common methods supported by all query types */ export class QueryBase implements AsyncIterable @@ -134,6 +147,25 @@ export class QueryBase return this.where(predicate); } + fullTextSearch( + query: string, + options?: Partial, + ): this { + let columns = null; + if (options) { + if (typeof options.columns === "string") { + columns = [options.columns]; + } else if (Array.isArray(options.columns)) { + columns = options.columns; + } + } + + this.doCall((inner: NativeQueryType) => + inner.fullTextSearch(query, columns), + ); + return this; + } + /** * Return only the specified columns. * diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 768be93c..83758b39 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -270,22 +270,23 @@ export abstract class Table { * @returns {Query} A builder that can be used to parameterize the query */ abstract query(): Query; + /** * Create a search query to find the nearest neighbors - * of the given query vector - * @param {string} query - the query. This will be converted to a vector using the table's provided embedding function - * @note If no embedding functions are defined in the table, this will error when collecting the results. + * of the given query + * @param {string | IntoVector} query - the query, a vector or string + * @param {string} queryType - the type of the query, "vector", "fts", or "auto" + * @param {string | string[]} ftsColumns - the columns to search in for full text search + * for now, only one column can be searched at a time. * - * This is just a convenience method for calling `.query().nearestTo(await myEmbeddingFunction(query))` + * when "auto" is used, if the query is a string and an embedding function is defined, it will be treated as a vector query + * if the query is a string and no embedding function is defined, it will be treated as a full text search query */ - abstract search(query: string): VectorQuery; - /** - * Create a search query to find the nearest neighbors - * of the given query vector - * @param {IntoVector} query - the query vector - * This is just a convenience method for calling `.query().nearestTo(query)` - */ - abstract search(query: IntoVector): VectorQuery; + abstract search( + query: string | IntoVector, + queryType?: string, + ftsColumns?: string | string[], + ): VectorQuery | Query; /** * Search the table with a given query vector. * @@ -581,27 +582,50 @@ export class LocalTable extends Table { query(): Query { return new Query(this.inner); } - search(query: string | IntoVector): VectorQuery { - if (typeof query !== "string") { - return this.vectorSearch(query); - } else { - const queryPromise = this.getEmbeddingFunctions().then( - async (functions) => { - // TODO: Support multiple embedding functions - const embeddingFunc: EmbeddingFunctionConfig | undefined = functions - .values() - .next().value; - if (!embeddingFunc) { - return Promise.reject( - new Error("No embedding functions are defined in the table"), - ); - } - return await embeddingFunc.function.computeQueryEmbeddings(query); - }, - ); - return this.query().nearestTo(queryPromise); + search( + query: string | IntoVector, + queryType: string = "auto", + ftsColumns?: string | string[], + ): VectorQuery | Query { + if (typeof query !== "string") { + if (queryType === "fts") { + throw new Error("Cannot perform full text search on a vector query"); + } + return this.vectorSearch(query); } + + // If the query is a string, we need to determine if it is a vector query or a full text search query + if (queryType === "fts") { + return this.query().fullTextSearch(query, { + columns: ftsColumns, + }); + } + + // The query type is auto or vector + // fall back to full text search if no embedding functions are defined and the query is a string + if (queryType === "auto" && getRegistry().length() === 0) { + return this.query().fullTextSearch(query, { + columns: ftsColumns, + }); + } + + const queryPromise = this.getEmbeddingFunctions().then( + async (functions) => { + // TODO: Support multiple embedding functions + const embeddingFunc: EmbeddingFunctionConfig | undefined = functions + .values() + .next().value; + if (!embeddingFunc) { + return Promise.reject( + new Error("No embedding functions are defined in the table"), + ); + } + return await embeddingFunc.function.computeQueryEmbeddings(query); + }, + ); + + return this.query().nearestTo(queryPromise); } vectorSearch(vector: IntoVector): VectorQuery { diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index d2ba91a8..bbe9ab2f 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -43,7 +43,7 @@ "ts-jest": "^29.1.2", "typedoc": "^0.26.4", "typedoc-plugin-markdown": "^4.2.1", - "typescript": "^5.3.3", + "typescript": "^5.5.4", "typescript-eslint": "^7.1.0" }, "engines": { @@ -9292,10 +9292,11 @@ } }, "node_modules/typescript": { - "version": "5.3.3", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz", - "integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==", + "version": "5.5.4", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.5.4.tgz", + "integrity": "sha512-Mtq29sKDAEYP7aljRgtPOpTvOfbwRWlS6dPRzwjdE+C0R4brX/GUyhHSecbHMFLNBLcJIPt9nl9yG5TZ1weH+Q==", "dev": true, + "license": "Apache-2.0", "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/nodejs/package.json b/nodejs/package.json index d9e98579..baf47a6a 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -53,7 +53,7 @@ "ts-jest": "^29.1.2", "typedoc": "^0.26.4", "typedoc-plugin-markdown": "^4.2.1", - "typescript": "^5.3.3", + "typescript": "^5.5.4", "typescript-eslint": "^7.1.0" }, "ava": { diff --git a/nodejs/src/index.rs b/nodejs/src/index.rs index cef43505..461b0021 100644 --- a/nodejs/src/index.rs +++ b/nodejs/src/index.rs @@ -14,7 +14,7 @@ use std::sync::Mutex; -use lancedb::index::scalar::BTreeIndexBuilder; +use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder}; use lancedb::index::vector::IvfPqIndexBuilder; use lancedb::index::Index as LanceDbIndex; use napi_derive::napi; @@ -76,4 +76,11 @@ impl Index { inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))), } } + + #[napi(factory)] + pub fn fts() -> Self { + Self { + inner: Mutex::new(Some(LanceDbIndex::FTS(FtsIndexBuilder::default()))), + } + } } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 692dc56b..c0fd16d9 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use lancedb::index::scalar::FullTextSearchQuery; use lancedb::query::ExecutableQuery; use lancedb::query::Query as LanceDbQuery; use lancedb::query::QueryBase; @@ -42,6 +43,12 @@ impl Query { self.inner = self.inner.clone().only_if(predicate); } + #[napi] + pub fn full_text_search(&mut self, query: String, columns: Option>) { + let query = FullTextSearchQuery::new(query).columns(columns); + self.inner = self.inner.clone().full_text_search(query); + } + #[napi] pub fn select(&mut self, columns: Vec<(String, String)>) { self.inner = self.inner.clone().select(Select::dynamic(&columns)); @@ -138,6 +145,12 @@ impl VectorQuery { self.inner = self.inner.clone().only_if(predicate); } + #[napi] + pub fn full_text_search(&mut self, query: String, columns: Option>) { + let query = FullTextSearchQuery::new(query).columns(columns); + self.inner = self.inner.clone().full_text_search(query); + } + #[napi] pub fn select(&mut self, columns: Vec<(String, String)>) { self.inner = self.inner.clone().select(Select::dynamic(&columns)); diff --git a/nodejs/tsconfig.json b/nodejs/tsconfig.json index a7ddc1c1..0588bb84 100644 --- a/nodejs/tsconfig.json +++ b/nodejs/tsconfig.json @@ -9,7 +9,8 @@ "allowJs": true, "resolveJsonModule": true, "emitDecoratorMetadata": true, - "experimentalDecorators": true + "experimentalDecorators": true, + "moduleResolution": "Node" }, "exclude": ["./dist/*"], "typedocOptions": { diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 11a04995..7f76c78c 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -99,6 +99,9 @@ class Query(pydantic.BaseModel): # if True then apply the filter before vector search prefilter: bool = False + # full text search query + full_text_query: Optional[Union[str, dict]] = None + # top k results to return k: int @@ -131,6 +134,7 @@ class LanceQueryBuilder(ABC): query_type: str, vector_column_name: str, ordering_field_name: str = None, + fts_columns: Union[str, List[str]] = None, ) -> LanceQueryBuilder: """ Create a query builder based on the given query and query type. @@ -226,6 +230,7 @@ class LanceQueryBuilder(ABC): self._limit = 10 self._columns = None self._where = None + self._prefilter = False self._with_row_id = False @deprecation.deprecated( @@ -664,12 +669,19 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" - def __init__(self, table: "Table", query: str, ordering_field_name: str = None): + def __init__( + self, + table: "Table", + query: str, + ordering_field_name: str = None, + fts_columns: Union[str, List[str]] = None, + ): super().__init__(table) self._query = query self._phrase_query = False self.ordering_field_name = ordering_field_name self._reranker = None + self._fts_columns = fts_columns def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder: """Set whether to use phrase query. @@ -689,6 +701,35 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): return self def to_arrow(self) -> pa.Table: + tantivy_index_path = self._table._get_fts_index_path() + if Path(tantivy_index_path).exists(): + return self.tantivy_to_arrow() + + query = self._query + if self._phrase_query: + raise NotImplementedError( + "Phrase query is not yet supported in Lance FTS. " + "Use tantivy-based index instead for now." + ) + if self._reranker: + raise NotImplementedError( + "Reranking is not yet supported in Lance FTS. " + "Use tantivy-based index instead for now." + ) + ds = self._table.to_lance() + return ds.to_table( + columns=self._columns, + filter=self._where, + limit=self._limit, + prefilter=self._prefilter, + with_row_id=self._with_row_id, + full_text_query={ + "query": query, + "columns": self._fts_columns, + }, + ) + + def tantivy_to_arrow(self) -> pa.Table: try: import tantivy except ImportError: @@ -726,11 +767,11 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): index, query, self._limit, ordering_field=self.ordering_field_name ) if len(row_ids) == 0: - empty_schema = pa.schema([pa.field("score", pa.float32())]) + empty_schema = pa.schema([pa.field("_score", pa.float32())]) return pa.Table.from_pylist([], schema=empty_schema) scores = pa.array(scores) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) - output_tbl = output_tbl.append_column("score", scores) + output_tbl = output_tbl.append_column("_score", scores) # this needs to match vector search results which are uint64 row_ids = pa.array(row_ids, type=pa.uint64()) @@ -784,8 +825,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): LanceFtsQueryBuilder The LanceQueryBuilder object. """ - self._reranker = reranker - return self + raise NotImplementedError("Reranking is not yet supported for FTS queries.") class LanceEmptyQueryBuilder(LanceQueryBuilder): @@ -856,13 +896,13 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): # convert to ranks first if needed if self._norm == "rank": vector_results = self._rank(vector_results, "_distance") - fts_results = self._rank(fts_results, "score") + fts_results = self._rank(fts_results, "_score") # normalize the scores to be between 0 and 1, 0 being most relevant vector_results = self._normalize_scores(vector_results, "_distance") # In fts higher scores represent relevance. Not inverting them here as # rerankers might need to preserve this score to support `return_score="all"` - fts_results = self._normalize_scores(fts_results, "score") + fts_results = self._normalize_scores(fts_results, "_score") results = self._reranker.rerank_hybrid( self._fts_query._query, vector_results, fts_results diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index 85536e1c..8667ca9c 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -220,8 +220,8 @@ class Reranker(ABC): def _keep_relevance_score(self, combined_results: pa.Table): if self.score == "relevance": - if "score" in combined_results.column_names: - combined_results = combined_results.drop_columns(["score"]) + if "_score" in combined_results.column_names: + combined_results = combined_results.drop_columns(["_score"]) if "_distance" in combined_results.column_names: combined_results = combined_results.drop_columns(["_distance"]) return combined_results diff --git a/python/python/lancedb/rerankers/cohere.py b/python/python/lancedb/rerankers/cohere.py index c925f54f..e4a12dbf 100644 --- a/python/python/lancedb/rerankers/cohere.py +++ b/python/python/lancedb/rerankers/cohere.py @@ -113,6 +113,6 @@ class CohereReranker(Reranker): ): result_set = self._rerank(fts_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["score"]) + result_set = result_set.drop_columns(["_score"]) return result_set diff --git a/python/python/lancedb/rerankers/colbert.py b/python/python/lancedb/rerankers/colbert.py index e09f2029..77ef58a1 100644 --- a/python/python/lancedb/rerankers/colbert.py +++ b/python/python/lancedb/rerankers/colbert.py @@ -105,7 +105,7 @@ class ColbertReranker(Reranker): ): result_set = self._rerank(fts_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["score"]) + result_set = result_set.drop_columns(["_score"]) result_set = result_set.sort_by([("_relevance_score", "descending")]) diff --git a/python/python/lancedb/rerankers/cross_encoder.py b/python/python/lancedb/rerankers/cross_encoder.py index daf02f75..88396fc3 100644 --- a/python/python/lancedb/rerankers/cross_encoder.py +++ b/python/python/lancedb/rerankers/cross_encoder.py @@ -96,7 +96,7 @@ class CrossEncoderReranker(Reranker): ): fts_results = self._rerank(fts_results, query) if self.score == "relevance": - fts_results = fts_results.drop_columns(["score"]) + fts_results = fts_results.drop_columns(["_score"]) fts_results = fts_results.sort_by([("_relevance_score", "descending")]) return fts_results diff --git a/python/python/lancedb/rerankers/jinaai.py b/python/python/lancedb/rerankers/jinaai.py index d8f22b02..4d4edcfb 100644 --- a/python/python/lancedb/rerankers/jinaai.py +++ b/python/python/lancedb/rerankers/jinaai.py @@ -117,6 +117,6 @@ class JinaReranker(Reranker): ): result_set = self._rerank(fts_results, query) if self.score == "relevance": - result_set = result_set.drop_columns(["score"]) + result_set = result_set.drop_columns(["_score"]) return result_set diff --git a/python/python/lancedb/rerankers/linear_combination.py b/python/python/lancedb/rerankers/linear_combination.py index 983fa901..3d7dcc25 100644 --- a/python/python/lancedb/rerankers/linear_combination.py +++ b/python/python/lancedb/rerankers/linear_combination.py @@ -69,12 +69,12 @@ class LinearCombinationReranker(Reranker): vi = vector_list[i] fj = fts_list[j] # invert the fts score from relevance to distance - inverted_fts_score = self._invert_score(fj["score"]) + inverted_fts_score = self._invert_score(fj["_score"]) if vi["_rowid"] == fj["_rowid"]: vi["_relevance_score"] = self._combine_score( vi["_distance"], inverted_fts_score ) - vi["score"] = fj["score"] # keep the original score + vi["_score"] = fj["_score"] # keep the original score combined_list.append(vi) i += 1 j += 1 diff --git a/python/python/lancedb/rerankers/openai.py b/python/python/lancedb/rerankers/openai.py index d24a4bcc..7e6c19b2 100644 --- a/python/python/lancedb/rerankers/openai.py +++ b/python/python/lancedb/rerankers/openai.py @@ -108,7 +108,7 @@ class OpenaiReranker(Reranker): def rerank_fts(self, query: str, fts_results: pa.Table): fts_results = self._rerank(fts_results, query) if self.score == "relevance": - fts_results = fts_results.drop_columns(["score"]) + fts_results = fts_results.drop_columns(["_score"]) fts_results = fts_results.sort_by([("_relevance_score", "descending")]) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 4a31e340..4e60eb7d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -59,7 +59,6 @@ from .util import ( if TYPE_CHECKING: import PIL from lance.dataset import CleanupStats, ReaderLike - from ._lancedb import Table as LanceDBTable, OptimizeStats from .db import LanceDBConnection from .index import BTree, IndexConfig, IvfPq @@ -350,6 +349,7 @@ class Table(ABC): def create_scalar_index( self, column: str, + index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", *, replace: bool = True, ): @@ -511,6 +511,8 @@ class Table(ABC): query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: Optional[str] = None, query_type: str = "auto", + ordering_field_name: Optional[str] = None, + fts_columns: Union[str, List[str]] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -1188,9 +1190,15 @@ class LanceTable(Table): index_cache_size=index_cache_size, ) - def create_scalar_index(self, column: str, *, replace: bool = True): + def create_scalar_index( + self, + column: str, + index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", + *, + replace: bool = True, + ): self._dataset_mut.create_scalar_index( - column, index_type="BTREE", replace=replace + column, index_type=index_type, replace=replace ) def create_fts_index( @@ -1201,6 +1209,7 @@ class LanceTable(Table): replace: bool = False, writer_heap_size: Optional[int] = 1024 * 1024 * 1024, tokenizer_name: str = "default", + use_tantivy: bool = True, ): """Create a full-text search index on the table. @@ -1211,6 +1220,7 @@ class LanceTable(Table): ---------- field_names: str or list of str The name(s) of the field to index. + can be only str if use_tantivy=True for now. replace: bool, default False If True, replace the existing index if it exists. Note that this is not yet an atomic operation; the index will be temporarily @@ -1218,12 +1228,31 @@ class LanceTable(Table): writer_heap_size: int, default 1GB ordering_field_names: A list of unsigned type fields to index to optionally order - results on at search time + results on at search time. + only available with use_tantivy=True tokenizer_name: str, default "default" The tokenizer to use for the index. Can be "raw", "default" or the 2 letter language code followed by "_stem". So for english it would be "en_stem". For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html + only available with use_tantivy=True for now + use_tantivy: bool, default False + If True, use the legacy full-text search implementation based on tantivy. + If False, use the new full-text search implementation based on lance-index. """ + if not use_tantivy: + if not isinstance(field_names, str): + raise ValueError("field_names must be a string when use_tantivy=False") + # delete the existing legacy index if it exists + if replace: + fs, path = fs_from_uri(self._get_fts_index_path()) + index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound + if index_exists: + fs.delete_dir(path) + self._dataset_mut.create_scalar_index( + field_names, index_type="INVERTED", replace=replace + ) + return + from .fts import create_index, populate_index if isinstance(field_names, str): @@ -1392,6 +1421,7 @@ class LanceTable(Table): vector_column_name: Optional[str] = None, query_type: str = "auto", ordering_field_name: Optional[str] = None, + fts_columns: Union[str, List[str]] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -1446,6 +1476,10 @@ class LanceTable(Table): or raise an error if no corresponding embedding function is found. If the `query` is a string, then the query type is "vector" if the table has embedding functions, else the query type is "fts" + fts_columns: str or list of str, default None + The column(s) to search in for full-text search. + If None then the search is performed on all indexed columns. + For now, only one column can be searched at a time. Returns ------- @@ -1665,6 +1699,7 @@ class LanceTable(Table): "nprobes": query.nprobes, "refine_factor": query.refine_factor, }, + full_text_query=query.full_text_query, with_row_id=query.with_row_id, batch_size=batch_size, ).to_reader() diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index d9c7f53e..373ae2b6 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -22,7 +22,8 @@ import pytest from lancedb.pydantic import LanceModel, Vector -def test_basic(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_basic(tmp_path, use_tantivy): db = lancedb.connect(tmp_path) assert db.uri == str(tmp_path) @@ -55,7 +56,7 @@ def test_basic(tmp_path): assert len(rs) == 1 assert rs["item"].iloc[0] == "foo" - table.create_fts_index(["item"]) + table.create_fts_index("item", use_tantivy=use_tantivy) rs = table.search("bar", query_type="fts").to_pandas() assert len(rs) == 1 assert rs["item"].iloc[0] == "bar" diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index 77f07388..f4c7cd1c 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -74,7 +74,12 @@ def test_create_index_with_stemming(tmp_path, table): assert os.path.exists(str(tmp_path / "index")) # Check stemming by running tokenizer on non empty table - table.create_fts_index("text", tokenizer_name="en_stem") + table.create_fts_index("text", tokenizer_name="en_stem", use_tantivy=True) + + +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_create_inverted_index(table, use_tantivy): + table.create_fts_index("text", use_tantivy=use_tantivy) def test_populate_index(tmp_path, table): @@ -92,8 +97,15 @@ def test_search_index(tmp_path, table): assert len(results[1]) == 10 # _distance +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_search_fts(table, use_tantivy): + table.create_fts_index("text", use_tantivy=use_tantivy) + results = table.search("puppy").limit(10).to_list() + assert len(results) == 10 + + def test_search_ordering_field_index_table(tmp_path, table): - table.create_fts_index("text", ordering_field_names=["count"]) + table.create_fts_index("text", ordering_field_names=["count"], use_tantivy=True) rows = ( table.search("puppy", ordering_field_name="count") .limit(20) @@ -125,8 +137,9 @@ def test_search_ordering_field_index(tmp_path, table): assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows -def test_create_index_from_table(tmp_path, table): - table.create_fts_index("text") +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_create_index_from_table(tmp_path, table, use_tantivy): + table.create_fts_index("text", use_tantivy=use_tantivy) df = table.search("puppy").limit(10).select(["text"]).to_pandas() assert len(df) <= 10 assert "text" in df.columns @@ -145,15 +158,15 @@ def test_create_index_from_table(tmp_path, table): ] ) - with pytest.raises(ValueError, match="already exists"): - table.create_fts_index("text") + with pytest.raises(Exception, match="already exists"): + table.create_fts_index("text", use_tantivy=use_tantivy) - table.create_fts_index("text", replace=True) + table.create_fts_index("text", replace=True, use_tantivy=use_tantivy) assert len(table.search("gorilla").limit(1).to_pandas()) == 1 def test_create_index_multiple_columns(tmp_path, table): - table.create_fts_index(["text", "text2"]) + table.create_fts_index(["text", "text2"], use_tantivy=True) df = table.search("puppy").limit(10).to_pandas() assert len(df) == 10 assert "text" in df.columns @@ -161,20 +174,21 @@ def test_create_index_multiple_columns(tmp_path, table): def test_empty_rs(tmp_path, table, mocker): - table.create_fts_index(["text", "text2"]) + table.create_fts_index(["text", "text2"], use_tantivy=True) mocker.patch("lancedb.fts.search_index", return_value=([], [])) df = table.search("puppy").limit(10).to_pandas() assert len(df) == 0 def test_nested_schema(tmp_path, table): - table.create_fts_index("nested.text") + table.create_fts_index("nested.text", use_tantivy=True) rs = table.search("puppy").limit(10).to_list() assert len(rs) == 10 -def test_search_index_with_filter(table): - table.create_fts_index("text") +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_search_index_with_filter(table, use_tantivy): + table.create_fts_index("text", use_tantivy=use_tantivy) orig_import = __import__ def import_mock(name, *args): @@ -186,7 +200,7 @@ def test_search_index_with_filter(table): with mock.patch("builtins.__import__", side_effect=import_mock): rs = table.search("puppy").where("id=1").limit(10) # test schema - assert rs.to_arrow().drop("score").schema.equals(table.schema) + assert rs.to_arrow().drop("_score").schema.equals(table.schema) rs = rs.to_list() for r in rs: @@ -204,7 +218,8 @@ def test_search_index_with_filter(table): assert r["_rowid"] is not None -def test_null_input(table): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_null_input(table, use_tantivy): table.add( [ { @@ -217,12 +232,12 @@ def test_null_input(table): } ] ) - table.create_fts_index("text") + table.create_fts_index("text", use_tantivy=use_tantivy) def test_syntax(table): # https://github.com/lancedb/lancedb/issues/769 - table.create_fts_index("text") + table.create_fts_index("text", use_tantivy=True) with pytest.raises(ValueError, match="Syntax Error"): table.search("they could have been dogs OR").limit(10).to_list() diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index d2d90e42..2c27b61d 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -22,7 +22,7 @@ from lancedb.table import LanceTable pytest.importorskip("lancedb.fts") -def get_test_table(tmp_path): +def get_test_table(tmp_path, use_tantivy): db = lancedb.connect(tmp_path) # Create a LanceDB table schema with a vector and a text column emb = EmbeddingFunctionRegistry.get_instance().get("test")() @@ -89,7 +89,7 @@ def get_test_table(tmp_path): ) # Create a fts index - table.create_fts_index("text") + table.create_fts_index("text", use_tantivy=use_tantivy) return table, MyTable @@ -174,8 +174,8 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): assert len(result) == 20 and result == result_arrow -def _run_test_hybrid_reranker(reranker, tmp_path): - table, schema = get_test_table(tmp_path) +def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy): + table, schema = get_test_table(tmp_path, use_tantivy) # The default reranker result1 = ( table.search( @@ -221,14 +221,16 @@ def _run_test_hybrid_reranker(reranker, tmp_path): ) -def test_linear_combination(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_linear_combination(tmp_path, use_tantivy): reranker = LinearCombinationReranker() - _run_test_hybrid_reranker(reranker, tmp_path) + _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy) -def test_rrf_reranker(tmp_path): +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_rrf_reranker(tmp_path, use_tantivy): reranker = RRFReranker() - _run_test_hybrid_reranker(reranker, tmp_path) + _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy) @pytest.mark.skipif( diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 63de5ca7..187139f0 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -56,6 +56,7 @@ tokenizers = { version = "0.19.1", optional = true } [dev-dependencies] tempfile = "3.5.0" rand = { version = "0.8.3", features = ["small_rng"] } +random_word = { version = "0.4.3", features = ["en"] } uuid = { version = "1.7.0", features = ["v4"] } walkdir = "2" aws-sdk-dynamodb = { version = "1.38.0" } diff --git a/rust/lancedb/examples/full_text_search.rs b/rust/lancedb/examples/full_text_search.rs new file mode 100644 index 00000000..ad9a880f --- /dev/null +++ b/rust/lancedb/examples/full_text_search.rs @@ -0,0 +1,114 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray}; +use arrow_schema::{DataType, Field, Schema}; + +use futures::TryStreamExt; +use lance_index::scalar::FullTextSearchQuery; +use lancedb::connection::Connection; +use lancedb::index::scalar::FtsIndexBuilder; +use lancedb::index::Index; +use lancedb::query::{ExecutableQuery, QueryBase}; +use lancedb::{connect, Result, Table}; +use rand::random; + +#[tokio::main] +async fn main() -> Result<()> { + if std::path::Path::new("data").exists() { + std::fs::remove_dir_all("data").unwrap(); + } + let uri = "data/sample-lancedb"; + let db = connect(uri).execute().await?; + let tbl = create_table(&db).await?; + + create_index(&tbl).await?; + search_index(&tbl).await?; + Ok(()) +} + +fn create_some_records() -> Result> { + const TOTAL: usize = 1000; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("doc", DataType::Utf8, true), + ])); + + let words = random_word::all(random_word::Lang::En) + .iter() + .step_by(1024) + .take(500) + .map(|w| *w) + .collect::>(); + let n_terms = 3; + let batches = RecordBatchIterator::new( + vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)), + Arc::new(StringArray::from_iter_values((0..TOTAL).map(|_| { + (0..n_terms) + .map(|_| words[random::() % words.len()]) + .collect::>() + .join(" ") + }))), + ], + ) + .unwrap()] + .into_iter() + .map(Ok), + schema.clone(), + ); + Ok(Box::new(batches)) +} + +async fn create_table(db: &Connection) -> Result { + let initial_data: Box = create_some_records()?; + let tbl = db.create_table("my_table", initial_data).execute().await?; + Ok(tbl) +} + +async fn create_index(table: &Table) -> Result<()> { + table + .create_index(&["doc"], Index::FTS(FtsIndexBuilder::default())) + .execute() + .await?; + Ok(()) +} + +async fn search_index(table: &Table) -> Result<()> { + let words = random_word::all(random_word::Lang::En) + .iter() + .step_by(1024) + .take(500) + .map(|w| *w) + .collect::>(); + let query = words[0].to_owned(); + println!("Searching for: {}", query); + + let mut results = table + .query() + .full_text_search(FullTextSearchQuery::new(words[0].to_owned())) + .select(lancedb::query::Select::Columns(vec!["doc".to_owned()])) + .limit(10) + .execute() + .await?; + while let Some(batch) = results.try_next().await? { + println!("{:?}", batch); + } + Ok(()) +} diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 700446ec..57763d23 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -1217,7 +1217,7 @@ mod tests { let tbl = db .create_table("v2_test", make_data()) - .use_legacy_format(false) + .data_storage_version(LanceFileVersion::Stable) .execute() .await .unwrap(); diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index 4346e579..949af0be 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use scalar::FtsIndexBuilder; use serde::Deserialize; use serde_with::skip_serializing_none; @@ -30,6 +31,7 @@ pub mod vector; pub enum Index { Auto, BTree(BTreeIndexBuilder), + FTS(FtsIndexBuilder), IvfPq(IvfPqIndexBuilder), IvfHnswPq(IvfHnswPqIndexBuilder), IvfHnswSq(IvfHnswSqIndexBuilder), diff --git a/rust/lancedb/src/index/scalar.rs b/rust/lancedb/src/index/scalar.rs index 7d447cc7..9623efe3 100644 --- a/rust/lancedb/src/index/scalar.rs +++ b/rust/lancedb/src/index/scalar.rs @@ -28,3 +28,13 @@ pub struct BTreeIndexBuilder {} impl BTreeIndexBuilder {} + +/// Builder for a full text search index +/// +/// A full text search index is an index on a string column that allows for full text search +#[derive(Debug, Clone, Default)] +pub struct FtsIndexBuilder {} + +impl FtsIndexBuilder {} + +pub use lance_index::scalar::FullTextSearchQuery; diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index bcbc7980..714200ae 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -21,6 +21,7 @@ use datafusion_physical_plan::ExecutionPlan; use half::f16; use lance::dataset::scanner::DatasetRecordBatchStream; use lance_datafusion::exec::execute_plan; +use lance_index::scalar::FullTextSearchQuery; use crate::arrow::SendableRecordBatchStream; use crate::error::{Error, Result}; @@ -351,6 +352,17 @@ pub trait QueryBase { /// on the filter column(s). fn only_if(self, filter: impl AsRef) -> Self; + /// Perform a full text search on the table. + /// + /// The results will be returned in order of BM25 scores. + /// + /// This method is only valid on tables that have a full text search index. + /// + /// ```ignore + /// query.full_text_search(FullTextSearchQuery::new("hello world")) + /// ``` + fn full_text_search(self, query: FullTextSearchQuery) -> Self; + /// Return only the specified columns. /// /// By default a query will return all columns from the table. However, this can have @@ -401,6 +413,11 @@ impl QueryBase for T { self } + fn full_text_search(mut self, query: FullTextSearchQuery) -> Self { + self.mut_query().full_text_search = Some(query); + self + } + fn select(mut self, select: Select) -> Self { self.mut_query().select = select; self @@ -502,8 +519,13 @@ pub struct Query { /// limit the number of rows to return. pub(crate) limit: Option, + /// Apply filter to the returned rows. pub(crate) filter: Option, + + /// Perform a full text search on the table. + pub(crate) full_text_search: Option, + /// Select column projection. pub(crate) select: Select, @@ -520,6 +542,7 @@ impl Query { parent, limit: None, filter: None, + full_text_search: None, select: Select::All, fast_search: false, } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index ffb763bd..d716643c 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1054,6 +1054,10 @@ impl NativeTable { ) } + fn supported_fts_data_type(dtype: &DataType) -> bool { + matches!(dtype, DataType::Utf8 | DataType::LargeUtf8) + } + fn supported_vector_data_type(dtype: &DataType) -> bool { match dtype { DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()), @@ -1524,6 +1528,33 @@ impl NativeTable { Ok(()) } + async fn create_fts_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> { + if !Self::supported_fts_data_type(field.data_type()) { + return Err(Error::Schema { + message: format!( + "A FTS index cannot be created on the field `{}` which has data type {}", + field.name(), + field.data_type() + ), + }); + } + + let mut dataset = self.dataset.get_mut().await?; + let lance_idx_params = lance_index::scalar::ScalarIndexParams { + force_index_type: Some(lance_index::scalar::ScalarIndexType::Inverted), + }; + dataset + .create_index( + &[field.name()], + IndexType::Scalar, + None, + &lance_idx_params, + opts.replace, + ) + .await?; + Ok(()) + } + async fn generic_query( &self, query: &VectorQuery, @@ -1659,6 +1690,7 @@ impl TableInternal for NativeTable { match opts.index { Index::Auto => self.create_auto_index(field, opts).await, Index::BTree(_) => self.create_btree_index(field, opts).await, + Index::FTS(_) => self.create_fts_index(field, opts).await, Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await, Index::IvfHnswPq(ivf_hnsw_pq) => { self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace) @@ -1789,6 +1821,10 @@ impl TableInternal for NativeTable { scanner.filter(filter)?; } + if let Some(fts) = &query.base.full_text_search { + scanner.full_text_search(fts.clone())?; + } + if let Some(refine_factor) = query.refine_factor { scanner.refine(refine_factor); }