feat(python): Hybrid search & Reranker API (#824)

based on https://github.com/lancedb/lancedb/pull/713
- The Reranker API can also be plugged into vector-only or FTS-only search,
but this PR doesn't do that (see example:
https://txt.cohere.com/rerank/)


### Default reranker -- `LinearCombinationReranker(weight=0.7, fill=1.0)`

```
table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas()
```
### Available rerankers
LinearCombinationReranker
```
from lancedb.rerankers import LinearCombinationReranker

# Same as default 
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=LinearCombinationReranker()
                                     ).to_pandas()

# with custom params
reranker = LinearCombinationReranker(weight=0.3, fill=1.0)
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=reranker
                                     ).to_pandas()
```

Cohere Reranker
```
from lancedb.rerankers import CohereReranker

# default model.. English and multi-lingual supported. See docstring for available custom params
table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank",  # score or rank
                                      reranker=CohereReranker()
                                     ).to_pandas()

```

CrossEncoderReranker

```
from lancedb.rerankers import CrossEncoderReranker

table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank", 
                                      reranker=CrossEncoderReranker()
                                     ).to_pandas()

```

## Using custom Reranker
```
from lancedb.rerankers import Reranker

class CustomReranker(Reranker):
    def rerank_hybrid(self, vector_results, fts_results):
           combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic
           # Custom rerank logic here
           
           return combined_res
```

- [x] Expand testing
- [x] Make sure usage makes sense
- [x] Run simple benchmarks for correctness (Seeing weird result from
cohere reranker in the toy example)
- Support diverse rerankers by default:
- [x] Cross encoding
- [x] Cohere
- [x] Reciprocal Rank Fusion

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2024-01-30 19:10:33 +05:30
committed by Weston Pace
parent ecbbe185c7
commit a41f7be88d
16 changed files with 1136 additions and 41 deletions

View File

@@ -682,3 +682,57 @@ def test_count_rows(db):
assert len(table) == 2
assert table.count_rows() == 2
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db):
    """Run hybrid search with and without an explicit rerank and compare.

    Creates a table of ten phrases whose schema is built from the registry's
    "test" embedding function, creates a full-text index on ``text``, and
    asserts that reranking with ``normalize="score"`` yields the same rows as
    a hybrid search with no ``rerank`` call. The ``normalize="rank"`` query is
    executed but its result is not compared.
    """
    # The ``db`` fixture is overridden with a fixed path for now: this test
    # fails with the tmp_path mock db, probably because the path is not being
    # parsed right by the fts.
    db = MockDB("~/lancedb_")

    # Table schema: a text source column and its vector column.
    emb = EmbeddingFunctionRegistry.get_instance().get("test")()

    class MyTable(LanceModel):
        text: str = emb.SourceField()
        vector: Vector(emb.ndims()) = emb.VectorField()

    table = LanceTable.create(db, "my_table", schema=MyTable)

    # Ten unique English phrases; only the text is supplied on insert.
    quotes = [
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
    ]
    table.add([{"text": quote} for quote in quotes])

    # The full-text index must exist before a hybrid query is issued.
    table.create_fts_index("text")

    query = "Our father who art in heaven"

    def hybrid():
        # A fresh query builder per call, so the three searches are independent.
        return table.search(query, query_type="hybrid")

    result1 = hybrid().rerank(normalize="score").to_pydantic(MyTable)
    result2 = hybrid().rerank(normalize="rank").to_pydantic(MyTable)  # noqa
    result3 = hybrid().to_pydantic(MyTable)

    assert result1 == result3