fix: use import stubs to prevent MLX doctest collection failures (#2536)

## Summary
- Add `create_import_stub()` helper to `embeddings/utils.py` for
handling optional dependencies
- Fix MLX doctest collection failures by using import stubs in
`gte_mlx_model.py`
- Module now imports successfully for doctest collection even when MLX
is not installed

## Changes
- **New utility function**: `create_import_stub()` creates placeholder
objects that allow class inheritance but raise a helpful error when
used (see the sketch after this list)
- **Updated MLX model**: Uses import stubs instead of direct imports
that fail immediately
- **Graceful degradation**: Clear error messages when MLX functionality
is accessed without MLX installed
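
For illustration, a minimal sketch of those two behaviors (the `GTEModel` name is illustrative, and the import path assumes the helper lives in `embeddings/utils.py` as described above):

```python
from lancedb.embeddings.utils import create_import_stub

# Without mlx installed, a stub stands in for the real module.
nn = create_import_stub("mlx.nn", "mlx")

class GTEModel(nn.Module):  # inheritance works: attribute access on the
    pass                    # stub returns a placeholder class

nn()  # actually using the stub raises:
# ImportError: You need to install mlx to use this functionality
```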

## Test Results
- `pytest --doctest-modules python/lancedb` now passes (with and
without MLX installed)
- All existing tests continue to pass
- MLX functionality works normally when MLX is installed
- Helpful error messages appear when MLX functionality is used without
MLX installed

Fixes #2538

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
Tristan Zajonc, 2025-07-23 16:25:33 -07:00 (committed by GitHub)
parent 94fb9f364a, commit d2c6759e7f
5 changed files with 60 additions and 26 deletions

**`embeddings/gte_mlx_model.py`**

@@ -9,11 +9,14 @@ from huggingface_hub import snapshot_download
 from pydantic import BaseModel
 from transformers import BertTokenizer
+from .utils import create_import_stub

 try:
     import mlx.core as mx
     import mlx.nn as nn
 except ImportError:
-    raise ImportError("You need to install MLX to use this model use - pip install mlx")
+    mx = create_import_stub("mlx.core", "mlx")
+    nn = create_import_stub("mlx.nn", "mlx")

 def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
@@ -72,7 +75,7 @@ class TransformerEncoder(nn.Module):
         super().__init__()
         self.layers = [
             TransformerEncoderLayer(dims, num_heads, mlp_dims)
-            for i in range(num_layers)
+            for _ in range(num_layers)
         ]

     def __call__(self, x, mask):
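
With the stubs in place, importing the model module no longer requires MLX; the failure is deferred until a stubbed name is actually used. A quick check of that behavior (a sketch, assuming MLX is absent but the module's other dependencies such as `transformers` and `huggingface_hub` are installed):

```python
# Import succeeds without MLX because mx/nn are stubs.
from lancedb.embeddings import gte_mlx_model

try:
    gte_mlx_model.mx()  # first real use of the stub fails loudly
except ImportError as e:
    print(e)  # You need to install mlx to use this functionality
```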

**`embeddings/utils.py`**

@@ -21,6 +21,36 @@ from ..dependencies import pandas as pd
 from ..util import attempt_import_or_raise

+
+def create_import_stub(module_name: str, package_name: str = None):
+    """
+    Create a stub module that allows class definition but fails when used.
+
+    This allows modules to be imported for doctest collection even when
+    optional dependencies are not available.
+
+    Parameters
+    ----------
+    module_name : str
+        The name of the module to create a stub for
+    package_name : str, optional
+        The package name to suggest in the error message
+
+    Returns
+    -------
+    object
+        A stub object that can be used in place of the module
+    """
+
+    class _ImportStub:
+        def __getattr__(self, name):
+            return _ImportStub  # Return stub for chained access like nn.Module
+
+        def __call__(self, *args, **kwargs):
+            pkg = package_name or module_name
+            raise ImportError(f"You need to install {pkg} to use this functionality")
+
+    return _ImportStub()
+
+
 # ruff: noqa: PERF203
 def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
     def wrapper(fn):
**`query.py`**

@@ -906,11 +906,11 @@ class LanceQueryBuilder(ABC):
         >>> plan = table.search(query).explain_plan(True)
         >>> print(plan)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
         ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
-          GlobalLimitExec: skip=0, fetch=10
-            FilterExec: _distance@2 IS NOT NULL
-              SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
-                KNNVectorDistance: metric=l2
-                  LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
+          GlobalLimitExec: skip=0, fetch=10
+            FilterExec: _distance@2 IS NOT NULL
+              SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
+                KNNVectorDistance: metric=l2
+                  LanceRead: uri=..., projection=[vector], ...

         Parameters
         ----------
@@ -940,19 +940,19 @@
         >>> plan = table.search(query).analyze_plan()
         >>> print(plan)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
         AnalyzeExec verbose=true, metrics=[]
-          ProjectionExec: expr=[...], metrics=[...]
-            GlobalLimitExec: skip=0, fetch=10, metrics=[...]
-              FilterExec: _distance@2 IS NOT NULL,
-              metrics=[output_rows=..., elapsed_compute=...]
-                SortExec: TopK(fetch=10), expr=[...],
-                preserve_partitioning=[...],
-                metrics=[output_rows=..., elapsed_compute=..., row_replacements=...]
-                  KNNVectorDistance: metric=l2,
-                  metrics=[output_rows=..., elapsed_compute=..., output_batches=...]
-                    LanceScan: uri=..., projection=[vector], row_id=true,
-                    row_addr=false, ordered=false,
-                    metrics=[output_rows=..., elapsed_compute=...,
-                    bytes_read=..., iops=..., requests=...]
+          TracedExec, metrics=[]
+            ProjectionExec: expr=[...], metrics=[...]
+              GlobalLimitExec: skip=0, fetch=10, metrics=[...]
+                FilterExec: _distance@2 IS NOT NULL,
+                metrics=[output_rows=..., elapsed_compute=...]
+                  SortExec: TopK(fetch=10), expr=[...],
+                  preserve_partitioning=[...],
+                  metrics=[output_rows=..., elapsed_compute=..., row_replacements=...]
+                    KNNVectorDistance: metric=l2,
+                    metrics=[output_rows=..., elapsed_compute=..., output_batches=...]
+                      LanceRead: uri=..., projection=[vector], ...
+                      metrics=[output_rows=..., elapsed_compute=...,
+                      bytes_read=..., iops=..., requests=...]

         Returns
         -------
@@ -2043,7 +2043,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
             FilterExec: _distance@2 IS NOT NULL
               SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
                 KNNVectorDistance: metric=l2
-                  LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
+                  LanceRead: uri=..., projection=[vector], ...

         Parameters
         ----------
@@ -2429,7 +2429,7 @@ class AsyncQueryBase(object):
             FilterExec: _distance@2 IS NOT NULL
               SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
                 KNNVectorDistance: metric=l2
-                  LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
+                  LanceRead: uri=..., projection=[vector], ...

         Parameters
         ----------
@@ -3054,7 +3054,7 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
             FilterExec: _distance@2 IS NOT NULL
               SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
                 KNNVectorDistance: metric=l2
-                  LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
+                  LanceRead: uri=..., projection=[vector], ...
         <BLANKLINE>
         FTS Search Plan:
         ProjectionExec: expr=[vector@2 as vector, text@3 as text, _score@1 as _score]

**tests: `test_explain_plan`**

@@ -166,7 +166,7 @@ async def test_explain_plan(table: AsyncTable):
     assert "Vector Search Plan" in plan
     assert "KNNVectorDistance" in plan
     assert "FTS Search Plan" in plan
-    assert "LanceScan" in plan
+    assert "LanceRead" in plan


 @pytest.mark.asyncio

**tests: `test_explain_plan_with_filters`**

@@ -839,7 +839,7 @@ async def test_explain_plan_with_filters(table_async: AsyncTable):
         table_async.query().nearest_to(pa.array([1, 2])).where("id = 1").explain_plan()
     )
     assert "KNN" in plan_with_filter
-    assert "FilterExec" in plan_with_filter
+    assert "LanceRead" in plan_with_filter

     # Test FTS query with filter
     from lancedb.index import FTS
@@ -850,7 +850,8 @@
     )
     plan_fts_filter = await query_fts_filter.where("id = 1").explain_plan()
     assert "MatchQuery: query=dog" in plan_fts_filter
-    assert "FilterExec: id@" in plan_fts_filter  # Should show filter details
+    assert "LanceRead" in plan_fts_filter
+    assert "full_filter=id = Int64(1)" in plan_fts_filter  # Should show filter details


 @pytest.mark.asyncio