mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 10:30:40 +00:00
feat(python): support model-backed native FTS tokenizers
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import Literal, Optional
|
||||
from ._lancedb import (
|
||||
IndexConfig,
|
||||
)
|
||||
from .types import BaseTokenizerType
|
||||
|
||||
lang_mapping = {
|
||||
"ar": "Arabic",
|
||||
@@ -111,8 +112,12 @@ class FTS:
|
||||
- "simple": Splits text by whitespace and punctuation.
|
||||
- "whitespace": Split text by whitespace, but not punctuation.
|
||||
- "raw": No tokenization. The entire text is treated as a single token.
|
||||
- "ngram": N-gram tokenizer for substring-style matching.
|
||||
- "jieba/*": Jieba tokenizer loaded from ``LANCE_LANGUAGE_MODEL_HOME``.
|
||||
- "lindera/*": Lindera tokenizer loaded from ``LANCE_LANGUAGE_MODEL_HOME``.
|
||||
language : str, default "English"
|
||||
The language to use for tokenization.
|
||||
The language to use for stemming and stop-word removal. This is not the
|
||||
primary way to enable CJK tokenization.
|
||||
max_token_length : int, default 40
|
||||
The maximum token length to index. Tokens longer than this length will be
|
||||
ignored.
|
||||
@@ -127,10 +132,15 @@ class FTS:
|
||||
ascii_folding : bool, default True
|
||||
Whether to fold ASCII characters. This converts accented characters to
|
||||
their ASCII equivalent. For example, "café" would be converted to "cafe".
|
||||
|
||||
Notes
|
||||
-----
|
||||
Model-backed tokenizers such as ``jieba/default`` and ``lindera/ipadic``
|
||||
require tokenizer models under ``LANCE_LANGUAGE_MODEL_HOME``.
|
||||
"""
|
||||
|
||||
with_position: bool = False
|
||||
base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple"
|
||||
base_tokenizer: BaseTokenizerType = "simple"
|
||||
language: str = "English"
|
||||
max_token_length: Optional[int] = 40
|
||||
lower_case: bool = True
|
||||
|
||||
@@ -39,6 +39,7 @@ from lancedb.table import _normalize_progress
|
||||
|
||||
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
|
||||
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
|
||||
from ..types import BaseTokenizerType
|
||||
|
||||
|
||||
class RemoteTable(Table):
|
||||
@@ -167,7 +168,7 @@ class RemoteTable(Table):
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
with_position: bool = False,
|
||||
# tokenizer configs:
|
||||
base_tokenizer: str = "simple",
|
||||
base_tokenizer: BaseTokenizerType = "simple",
|
||||
language: str = "English",
|
||||
max_token_length: Optional[int] = 40,
|
||||
lower_case: bool = True,
|
||||
|
||||
@@ -86,6 +86,52 @@ from .util import (
|
||||
)
|
||||
from .index import lang_mapping
|
||||
|
||||
_MODEL_BACKED_TOKENIZER_PREFIXES = ("jieba", "lindera")
|
||||
_MODEL_BACKED_TOKENIZER_ERRORS = (
|
||||
"unknown base tokenizer",
|
||||
"Invalid directory path:",
|
||||
"Failed to load Jieba",
|
||||
"Failed to load tokenizer config",
|
||||
"Failed to initialize default tokenizer",
|
||||
)
|
||||
|
||||
|
||||
def _add_unique_note(exception: BaseException, note: str) -> None:
|
||||
existing_notes = getattr(exception, "__notes__", ()) or ()
|
||||
if note not in existing_notes:
|
||||
add_note(exception, note)
|
||||
|
||||
|
||||
def _is_model_backed_tokenizer(base_tokenizer: str) -> bool:
|
||||
return any(
|
||||
base_tokenizer == prefix or base_tokenizer.startswith(f"{prefix}/")
|
||||
for prefix in _MODEL_BACKED_TOKENIZER_PREFIXES
|
||||
)
|
||||
|
||||
|
||||
def _maybe_add_fts_error_note(
|
||||
exception: BaseException, *, base_tokenizer: str, language: Optional[str] = None
|
||||
) -> None:
|
||||
message = str(exception)
|
||||
if language is not None and "not support the requested language" in message:
|
||||
supported_langs = ", ".join(lang_mapping.values())
|
||||
_add_unique_note(exception, f"Supported languages: {supported_langs}")
|
||||
return
|
||||
|
||||
if not _is_model_backed_tokenizer(base_tokenizer):
|
||||
return
|
||||
|
||||
if not any(marker in message for marker in _MODEL_BACKED_TOKENIZER_ERRORS):
|
||||
return
|
||||
|
||||
_add_unique_note(
|
||||
exception,
|
||||
"Model-backed tokenizers such as 'jieba/default' and 'lindera/ipadic' "
|
||||
"require tokenizer models under LANCE_LANGUAGE_MODEL_HOME. Expected "
|
||||
"layouts include '$LANCE_LANGUAGE_MODEL_HOME/jieba/default/...' and "
|
||||
"'$LANCE_LANGUAGE_MODEL_HOME/lindera/ipadic/...'.",
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .db import LanceDBConnection
|
||||
@@ -958,7 +1004,8 @@ class Table(ABC):
|
||||
tokenizer_name: str, default "default"
|
||||
A compatibility alias for native tokenizer configs. Can be "raw",
|
||||
"default" or the 2 letter language code followed by "_stem". So
|
||||
for english it would be "en_stem".
|
||||
for english it would be "en_stem". Prefer ``base_tokenizer`` for
|
||||
new code.
|
||||
use_tantivy: bool, default False
|
||||
Deprecated legacy Tantivy parameter. Setting this to True raises an
|
||||
error.
|
||||
@@ -972,8 +1019,11 @@ class Table(ABC):
|
||||
- "whitespace": Split text by whitespace, but not punctuation.
|
||||
- "raw": No tokenization. The entire text is treated as a single token.
|
||||
- "ngram": N-Gram tokenizer.
|
||||
- "jieba/*": Jieba tokenizer loaded from ``LANCE_LANGUAGE_MODEL_HOME``.
|
||||
- "lindera/*": Lindera tokenizer loaded from ``LANCE_LANGUAGE_MODEL_HOME``.
|
||||
language : str, default "English"
|
||||
The language to use for tokenization.
|
||||
The language to use for stemming and stop-word removal. This is not
|
||||
the primary way to enable CJK tokenization.
|
||||
max_token_length : int, default 40
|
||||
The maximum token length to index. Tokens longer than this length will be
|
||||
ignored.
|
||||
@@ -999,6 +1049,11 @@ class Table(ABC):
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Model-backed tokenizers such as ``jieba/default`` and ``lindera/ipadic``
|
||||
require tokenizer models under ``LANCE_LANGUAGE_MODEL_HOME``.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2462,14 +2517,22 @@ class LanceTable(Table):
|
||||
**tokenizer_configs,
|
||||
)
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
field_names,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
try:
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
field_names,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
)
|
||||
except (ValueError, RuntimeError) as e:
|
||||
_maybe_add_fts_error_note(
|
||||
e,
|
||||
base_tokenizer=config.base_tokenizer,
|
||||
language=config.language,
|
||||
)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def infer_tokenizer_configs(tokenizer_name: str) -> dict:
|
||||
@@ -3865,11 +3928,13 @@ class AsyncTable:
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
except ValueError as e:
|
||||
if "not support the requested language" in str(e):
|
||||
supported_langs = ", ".join(lang_mapping.values())
|
||||
help_msg = f"Supported languages: {supported_langs}"
|
||||
add_note(e, help_msg)
|
||||
except (ValueError, RuntimeError) as e:
|
||||
if isinstance(config, FTS):
|
||||
_maybe_add_fts_error_note(
|
||||
e,
|
||||
base_tokenizer=config.base_tokenizer,
|
||||
language=config.language,
|
||||
)
|
||||
raise e
|
||||
|
||||
async def drop_index(self, name: str) -> None:
|
||||
|
||||
@@ -40,4 +40,5 @@ IndexType = Literal[
|
||||
]
|
||||
|
||||
# Tokenizer literals
|
||||
BaseTokenizerType = Literal["simple", "raw", "whitespace", "ngram"]
|
||||
BuiltinTokenizerType = Literal["simple", "raw", "whitespace", "ngram"]
|
||||
BaseTokenizerType = BuiltinTokenizerType | str
|
||||
|
||||
8
python/python/tests/models/jieba/default/dict.txt
Normal file
8
python/python/tests/models/jieba/default/dict.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
我们 98740 r
|
||||
都 202780 d
|
||||
有 423765 v
|
||||
光明 1219 n
|
||||
的 318825 uj
|
||||
前途 1263 n
|
||||
前 62779 f
|
||||
途 857 n
|
||||
4
python/python/tests/models/lindera/ipadic/config.yml
Normal file
4
python/python/tests/models/lindera/ipadic/config.yml
Normal file
@@ -0,0 +1,4 @@
|
||||
segmenter:
|
||||
mode: "normal"
|
||||
dictionary:
|
||||
path: "./python/tests/models/lindera/ipadic/main"
|
||||
BIN
python/python/tests/models/lindera/ipadic/main.zip
Normal file
BIN
python/python/tests/models/lindera/ipadic/main.zip
Normal file
Binary file not shown.
@@ -15,7 +15,10 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from unittest import mock
|
||||
from pathlib import Path
|
||||
import zipfile
|
||||
|
||||
import lancedb as ldb
|
||||
from lancedb.db import DBConnection
|
||||
@@ -36,6 +39,8 @@ import pytest
|
||||
import pytest_asyncio
|
||||
from utils import exception_output
|
||||
|
||||
TEST_LANGUAGE_MODEL_HOME = Path(__file__).parent / "models"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table(tmp_path) -> ldb.table.LanceTable:
|
||||
@@ -89,6 +94,30 @@ def table(tmp_path) -> ldb.table.LanceTable:
|
||||
return table
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def language_model_home(monkeypatch):
|
||||
monkeypatch.setenv("LANCE_LANGUAGE_MODEL_HOME", str(TEST_LANGUAGE_MODEL_HOME))
|
||||
return TEST_LANGUAGE_MODEL_HOME
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lindera_ipadic(language_model_home):
|
||||
model_path = language_model_home / "lindera" / "ipadic"
|
||||
extracted_model = model_path / "main"
|
||||
|
||||
if extracted_model.exists():
|
||||
shutil.rmtree(extracted_model)
|
||||
|
||||
with zipfile.ZipFile(model_path / "main.zip", "r") as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if extracted_model.exists():
|
||||
shutil.rmtree(extracted_model)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_table(tmp_path) -> ldb.table.AsyncTable:
|
||||
# Use local random state to avoid affecting other tests
|
||||
@@ -684,6 +713,90 @@ def test_fts_ngram(mem_db: DBConnection):
|
||||
assert set(r["text"] for r in results) == {"lance database", "lance is cool"}
|
||||
|
||||
|
||||
def test_fts_jieba_tokenizer(mem_db: DBConnection, language_model_home):
|
||||
data = pa.table({"text": ["我们都有光明的前途", "光明的前途"]})
|
||||
table = mem_db.create_table("test_jieba", data=data)
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
base_tokenizer="jieba/default",
|
||||
stem=False,
|
||||
remove_stop_words=False,
|
||||
ascii_folding=False,
|
||||
)
|
||||
|
||||
results = table.search("我们", query_type="fts").limit(10).to_list()
|
||||
assert [row["text"] for row in results] == ["我们都有光明的前途"]
|
||||
|
||||
|
||||
def test_fts_jieba_missing_language_model_note(
|
||||
mem_db: DBConnection, monkeypatch, tmp_path
|
||||
):
|
||||
missing_root = tmp_path / "missing-language-models"
|
||||
monkeypatch.setenv("LANCE_LANGUAGE_MODEL_HOME", str(missing_root))
|
||||
table = mem_db.create_table(
|
||||
"test_missing_jieba_model",
|
||||
data=pa.table({"text": ["我们都有光明的前途"]}),
|
||||
)
|
||||
|
||||
with pytest.raises((ValueError, RuntimeError)) as e:
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
base_tokenizer="jieba/default",
|
||||
stem=False,
|
||||
remove_stop_words=False,
|
||||
ascii_folding=False,
|
||||
)
|
||||
|
||||
output = exception_output(e)
|
||||
assert "Invalid directory path:" in output
|
||||
assert "LANCE_LANGUAGE_MODEL_HOME" in output
|
||||
assert "jieba/default" in output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fts_jieba_missing_language_model_note_async(monkeypatch, tmp_path):
|
||||
missing_root = tmp_path / "missing-language-models"
|
||||
monkeypatch.setenv("LANCE_LANGUAGE_MODEL_HOME", str(missing_root))
|
||||
db = await ldb.connect_async(tmp_path / "async-db")
|
||||
table = await db.create_table(
|
||||
"test_missing_jieba_model_async",
|
||||
data=pa.table({"text": ["我们都有光明的前途"]}),
|
||||
)
|
||||
|
||||
with pytest.raises((ValueError, RuntimeError)) as e:
|
||||
await table.create_index(
|
||||
"text",
|
||||
config=FTS(
|
||||
base_tokenizer="jieba/default",
|
||||
stem=False,
|
||||
remove_stop_words=False,
|
||||
ascii_folding=False,
|
||||
),
|
||||
)
|
||||
|
||||
output = exception_output(e)
|
||||
assert "Invalid directory path:" in output
|
||||
assert "LANCE_LANGUAGE_MODEL_HOME" in output
|
||||
assert "jieba/default" in output
|
||||
|
||||
|
||||
def test_fts_lindera_tokenizer(
|
||||
mem_db: DBConnection, language_model_home, lindera_ipadic
|
||||
):
|
||||
data = pa.table({"text": ["成田国際空港", "東京国際空港", "羽田空港"]})
|
||||
table = mem_db.create_table("test_lindera", data=data)
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
base_tokenizer="lindera/ipadic",
|
||||
stem=False,
|
||||
remove_stop_words=False,
|
||||
ascii_folding=False,
|
||||
)
|
||||
|
||||
results = table.search("成田", query_type="fts").limit(10).to_list()
|
||||
assert [row["text"] for row in results] == ["成田国際空港"]
|
||||
|
||||
|
||||
def test_fts_query_to_json():
|
||||
"""Test that FTS query to_json() produces valid JSON strings with exact format."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user