feat(python): Hybrid search & Reranker API (#824)

based on https://github.com/lancedb/lancedb/pull/713
- The Reranker api can be plugged into vector only or fts only search
but this PR doesn't do that (see example -
https://txt.cohere.com/rerank/)


### Default reranker -- `LinearCombinationReranker(weight=0.7,
fill=1.0)`

```
table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas()
```
### Available rerankers
LinearCombinationReranker
```
from lancedb.rerankers import LinearCombinationReranker

# Same as default 
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=LinearCombinationReranker()
                                     ).to_pandas()

# with custom params
reranker = LinearCombinationReranker(weight=0.3, fill=1.0)
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=reranker
                                     ).to_pandas()
```

Cohere Reranker
```
from lancedb.rerankers import CohereReranker

# default model.. English and multi-lingual supported. See docstring for available custom params
table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank",  # score or rank
                                      reranker=CohereReranker()
                                     ).to_pandas()

```

CrossEncoderReranker

```
from lancedb.rerankers import CrossEncoderReranker

table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank", 
                                      reranker=CrossEncoderReranker()
                                     ).to_pandas()

```

## Using custom Reranker
```
from lancedb.reranker import Reranker

class CustomReranker(Reranker):
    def rerank_hybrid(self, vector_result, fts_result):
           combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic
           # Custom rerank logic here
           
           return combined_res
```

- [x] Expand testing
- [x] Make sure usage makes sense
- [x] Run simple benchmarks for correctness (Seeing weird result from
cohere reranker in the toy example)
- Support diverse rerankers by default:
- [x] Cross encoding
- [x] Cohere
- [x] Reciprocal Rank Fusion

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2024-01-30 19:10:33 +05:30
committed by GitHub
parent f150768739
commit 3ffed89793
16 changed files with 1136 additions and 41 deletions

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
import lance
import numpy as np
@@ -33,8 +33,7 @@ from .query import LanceQueryBuilder, Query
from .util import (
fs_from_uri,
join_uri,
safe_import_pandas,
safe_import_polars,
safe_import,
value_to_sql,
)
from .utils.events import register_event
@@ -48,8 +47,8 @@ if TYPE_CHECKING:
from .db import LanceDBConnection
pd = safe_import_pandas()
pl = safe_import_polars()
pd = safe_import("pandas")
pl = safe_import("polars")
def _sanitize_data(
@@ -338,7 +337,7 @@ class Table(ABC):
@abstractmethod
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -924,7 +923,7 @@ class LanceTable(Table):
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -1194,6 +1193,7 @@ class LanceTable(Table):
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
with_row_id=query.with_row_id,
)
def cleanup_old_versions(