mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
feat: add create_index to the async python API (#1052)
This also refactors the rust lancedb index builder API (and, correspondingly, the nodejs API)
This commit is contained in:
@@ -23,8 +23,9 @@ from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector # noqa: F401
|
||||
from .utils import sentry_log # noqa: F401
|
||||
from .schema import vector
|
||||
from .table import AsyncTable
|
||||
from .utils import sentry_log
|
||||
|
||||
|
||||
def connect(
|
||||
@@ -188,3 +189,19 @@ async def connect_async(
|
||||
read_consistency_interval_secs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"connect",
|
||||
"connect_async",
|
||||
"AsyncConnection",
|
||||
"AsyncTable",
|
||||
"URI",
|
||||
"sanitize_uri",
|
||||
"sentry_log",
|
||||
"vector",
|
||||
"DBConnection",
|
||||
"LanceDBConnection",
|
||||
"RemoteDBConnection",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,18 @@ from typing import Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
class Index:
|
||||
@staticmethod
|
||||
def ivf_pq(
|
||||
distance_type: Optional[str],
|
||||
num_partitions: Optional[int],
|
||||
num_sub_vectors: Optional[int],
|
||||
max_iterations: Optional[int],
|
||||
sample_rate: Optional[int],
|
||||
) -> Index: ...
|
||||
@staticmethod
|
||||
def btree() -> Index: ...
|
||||
|
||||
class Connection(object):
|
||||
async def table_names(
|
||||
self, start_after: Optional[str], limit: Optional[int]
|
||||
@@ -13,10 +25,15 @@ class Connection(object):
|
||||
self, name: str, mode: str, schema: pa.Schema
|
||||
) -> Table: ...
|
||||
|
||||
class Table(object):
|
||||
class Table:
|
||||
def name(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
async def schema(self) -> pa.Schema: ...
|
||||
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
|
||||
async def count_rows(self, filter: Optional[str]) -> int: ...
|
||||
async def create_index(
|
||||
self, column: str, config: Optional[Index], replace: Optional[bool]
|
||||
): ...
|
||||
|
||||
async def connect(
|
||||
uri: str,
|
||||
|
||||
157
python/python/lancedb/index.py
Normal file
157
python/python/lancedb/index.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from typing import Optional
|
||||
|
||||
from ._lancedb import (
|
||||
Index as LanceDbIndex,
|
||||
)
|
||||
|
||||
|
||||
class BTree(object):
|
||||
"""Describes a btree index configuration
|
||||
|
||||
A btree index is an index on scalar columns. The index stores a copy of the
|
||||
column in sorted order. A header entry is created for each block of rows
|
||||
(currently the block size is fixed at 4096). These header entries are stored
|
||||
in a separate cacheable structure (a btree). To search for data the header is
|
||||
used to determine which blocks need to be read from disk.
|
||||
|
||||
For example, a btree index in a table with 1Bi rows requires
|
||||
sizeof(Scalar) * 256Ki bytes of memory and will generally need to read
|
||||
sizeof(Scalar) * 4096 bytes to find the correct row ids.
|
||||
|
||||
This index is good for scalar columns with mostly distinct values and does best
|
||||
when the query is highly selective.
|
||||
|
||||
The btree index does not currently have any parameters though parameters such as
|
||||
the block size may be added in the future.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inner = LanceDbIndex.btree()
|
||||
|
||||
|
||||
class IvfPq(object):
|
||||
"""Describes an IVF PQ Index
|
||||
|
||||
This index stores a compressed (quantized) copy of every vector. These vectors
|
||||
are grouped into partitions of similar vectors. Each partition keeps track of
|
||||
a centroid which is the average value of all vectors in the group.
|
||||
|
||||
During a query the centroids are compared with the query vector to find the
|
||||
closest partitions. The compressed vectors in these partitions are then
|
||||
searched to find the closest vectors.
|
||||
|
||||
The compression scheme is called product quantization. Each vector is divide
|
||||
into subvectors and then each subvector is quantized into a small number of
|
||||
bits. the parameters `num_bits` and `num_subvectors` control this process,
|
||||
providing a tradeoff between index size (and thus search speed) and index
|
||||
accuracy.
|
||||
|
||||
The partitioning process is called IVF and the `num_partitions` parameter
|
||||
controls how many groups to create.
|
||||
|
||||
Note that training an IVF PQ index on a large dataset is a slow operation and
|
||||
currently is also a memory intensive operation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Create an IVF PQ index config
|
||||
|
||||
Parameters
|
||||
----------
|
||||
distance_type: str, default "L2"
|
||||
The distance metric used to train the index
|
||||
|
||||
This is used when training the index to calculate the IVF partitions
|
||||
(vectors are grouped in partitions with similar vectors according to this
|
||||
distance type) and to calculate a subvector's code during quantization.
|
||||
|
||||
The distance type used to train an index MUST match the distance type used
|
||||
to search the index. Failure to do so will yield inaccurate results.
|
||||
|
||||
The following distance types are available:
|
||||
|
||||
"l2" - Euclidean distance. This is a very common distance metric that
|
||||
accounts for both magnitude and direction when determining the distance
|
||||
between vectors. L2 distance has a range of [0, ∞).
|
||||
|
||||
"cosine" - Cosine distance. Cosine distance is a distance metric
|
||||
calculated from the cosine similarity between two vectors. Cosine
|
||||
similarity is a measure of similarity between two non-zero vectors of an
|
||||
inner product space. It is defined to equal the cosine of the angle
|
||||
between them. Unlike L2, the cosine distance is not affected by the
|
||||
magnitude of the vectors. Cosine distance has a range of [0, 2].
|
||||
|
||||
Note: the cosine distance is undefined when one (or both) of the vectors
|
||||
are all zeros (there is no direction). These vectors are invalid and may
|
||||
never be returned from a vector search.
|
||||
|
||||
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
|
||||
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
|
||||
L2 norm is 1), then dot distance is equivalent to the cosine distance.
|
||||
num_partitions: int, default sqrt(num_rows)
|
||||
The number of IVF partitions to create.
|
||||
|
||||
This value should generally scale with the number of rows in the dataset.
|
||||
By default the number of partitions is the square root of the number of
|
||||
rows.
|
||||
|
||||
If this value is too large then the first part of the search (picking the
|
||||
right partition) will be slow. If this value is too small then the second
|
||||
part of the search (searching within a partition) will be slow.
|
||||
num_sub_vectors: int, default is vector dimension / 16
|
||||
Number of sub-vectors of PQ.
|
||||
|
||||
This value controls how much the vector is compressed during the
|
||||
quantization step. The more sub vectors there are the less the vector is
|
||||
compressed. The default is the dimension of the vector divided by 16. If
|
||||
the dimension is not evenly divisible by 16 we use the dimension divded by
|
||||
8.
|
||||
|
||||
The above two cases are highly preferred. Having 8 or 16 values per
|
||||
subvector allows us to use efficient SIMD instructions.
|
||||
|
||||
If the dimension is not visible by 8 then we use 1 subvector. This is not
|
||||
ideal and will likely result in poor performance.
|
||||
max_iterations: int, default 50
|
||||
Max iteration to train kmeans.
|
||||
|
||||
When training an IVF PQ index we use kmeans to calculate the partitions.
|
||||
This parameter controls how many iterations of kmeans to run.
|
||||
|
||||
Increasing this might improve the quality of the index but in most cases
|
||||
these extra iterations have diminishing returns.
|
||||
|
||||
The default value is 50.
|
||||
sample_rate: int, default 256
|
||||
The rate used to calculate the number of training vectors for kmeans.
|
||||
|
||||
When an IVF PQ index is trained, we need to calculate partitions. These
|
||||
are groups of vectors that are similar to each other. To do this we use an
|
||||
algorithm called kmeans.
|
||||
|
||||
Running kmeans on a large dataset can be slow. To speed this up we run
|
||||
kmeans on a random sample of the data. This parameter controls the size of
|
||||
the sample. The total number of vectors used to train the index is
|
||||
`sample_rate * num_partitions`.
|
||||
|
||||
Increasing this value might improve the quality of the index but in most
|
||||
cases the default should be sufficient.
|
||||
|
||||
The default value is 256.
|
||||
"""
|
||||
self._inner = LanceDbIndex.ivf_pq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
@@ -60,6 +60,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from ._lancedb import Table as LanceDBTable
|
||||
from .db import LanceDBConnection
|
||||
from .index import BTree, IvfPq
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
@@ -1917,112 +1918,48 @@ class AsyncTable:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: str, default "L2"
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "L2", "cosine", or "dot".
|
||||
L2 is euclidean distance.
|
||||
num_partitions: int, default 256
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int, default 96
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
- If True, replace the existing index if it exists.
|
||||
|
||||
- If False, raise an error if duplicate index exists.
|
||||
accelerator: str, default None
|
||||
If set, use the given accelerator to create the index.
|
||||
Only support "cuda" for now.
|
||||
index_cache_size : int, optional
|
||||
The size of the index cache in number of entries. Default value is 256.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = True,
|
||||
replace: Optional[bool] = None,
|
||||
config: Optional[Union[IvfPq, BTree]] = None,
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
"""Create an index to speed up queries
|
||||
|
||||
Scalar indices, like vector indices, can be used to speed up scans. A scalar
|
||||
index can speed up scans that contain filter expressions on the indexed column.
|
||||
For example, the following scan will be faster if the column ``my_col`` has
|
||||
a scalar index:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import lancedb
|
||||
|
||||
db = lancedb.connect("/data/lance")
|
||||
img_table = db.open_table("images")
|
||||
my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas()
|
||||
|
||||
Scalar indices can also speed up scans containing a vector search and a
|
||||
prefilter:
|
||||
|
||||
.. code-block::python
|
||||
|
||||
import lancedb
|
||||
|
||||
db = lancedb.connect("/data/lance")
|
||||
img_table = db.open_table("images")
|
||||
img_table.search([1, 2, 3, 4], vector_column_name="vector")
|
||||
.where("my_col != 7", prefilter=True)
|
||||
.to_pandas()
|
||||
|
||||
Scalar indices can only speed up scans for basic filters using
|
||||
equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set
|
||||
membership (e.g. `my_col IN (0, 1, 2)`)
|
||||
|
||||
Scalar indices can be used if the filter contains multiple indexed columns and
|
||||
the filter criteria are AND'd or OR'd together
|
||||
(e.g. ``my_col < 0 AND other_col> 100``)
|
||||
|
||||
Scalar indices may be used if the filter contains non-indexed columns but,
|
||||
depending on the structure of the filter, they may not be usable. For example,
|
||||
if the column ``not_indexed`` does not have a scalar index then the filter
|
||||
``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on
|
||||
``my_col``.
|
||||
|
||||
**Experimental API**
|
||||
Indices can be created on vector columns or scalar columns.
|
||||
Indices on vector columns will speed up vector searches.
|
||||
Indices on scalar columns will speed up filtering (in both
|
||||
vector and non-vector searches)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
column : str
|
||||
The column to be indexed. Must be a boolean, integer, float,
|
||||
or string column.
|
||||
replace : bool, default True
|
||||
Replace the existing index if it exists.
|
||||
index: Index
|
||||
The index to create.
|
||||
|
||||
Examples
|
||||
--------
|
||||
LanceDb supports multiple types of indices. See the static methods on
|
||||
the Index class for more details.
|
||||
column: str, default None
|
||||
The column to index.
|
||||
|
||||
.. code-block:: python
|
||||
When building a scalar index this must be set.
|
||||
|
||||
import lance
|
||||
When building a vector index, this is optional. The default will look
|
||||
for any columns of type fixed-size-list with floating point values. If
|
||||
there is only one column of this type then it will be used. Otherwise
|
||||
an error will be returned.
|
||||
replace: bool, default True
|
||||
Whether to replace the existing index
|
||||
|
||||
dataset = lance.dataset("./images.lance")
|
||||
dataset.create_scalar_index("category")
|
||||
If this is false, and another index already exists on the same columns
|
||||
and the same name, then an error will be returned. This is true even if
|
||||
that index is out of date.
|
||||
|
||||
The default is True
|
||||
"""
|
||||
raise NotImplementedError
|
||||
index = None
|
||||
if config is not None:
|
||||
index = config._inner
|
||||
await self._inner.create_index(column, index=index, replace=replace)
|
||||
|
||||
async def add(
|
||||
self,
|
||||
@@ -2066,6 +2003,8 @@ class AsyncTable:
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
if isinstance(data, pa.Table):
|
||||
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
|
||||
await self._inner.add(data, mode)
|
||||
register_event("add")
|
||||
|
||||
|
||||
61
python/python/tests/test_index.py
Normal file
61
python/python/tests/test_index.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from datetime import timedelta
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb import AsyncConnection, AsyncTable, connect_async
|
||||
from lancedb.index import BTree, IvfPq
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_async(tmp_path) -> AsyncConnection:
|
||||
return await connect_async(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
|
||||
|
||||
def sample_fixed_size_list_array(nrows, dim):
|
||||
vector_data = pa.array([float(i) for i in range(dim * nrows)], pa.float32())
|
||||
return pa.FixedSizeListArray.from_arrays(vector_data, dim)
|
||||
|
||||
|
||||
DIM = 8
|
||||
NROWS = 256
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def some_table(db_async):
|
||||
data = pa.Table.from_pydict(
|
||||
{
|
||||
"id": list(range(256)),
|
||||
"vector": sample_fixed_size_list_array(NROWS, DIM),
|
||||
}
|
||||
)
|
||||
return await db_async.create_table(
|
||||
"some_table",
|
||||
data,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_scalar_index(some_table: AsyncTable):
|
||||
# Can create
|
||||
await some_table.create_index("id")
|
||||
# Can recreate if replace=True
|
||||
await some_table.create_index("id", replace=True)
|
||||
# Can't recreate if replace=False
|
||||
with pytest.raises(RuntimeError, match="already exists"):
|
||||
await some_table.create_index("id", replace=False)
|
||||
# can also specify index type
|
||||
await some_table.create_index("id", config=BTree())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_vector_index(some_table: AsyncTable):
|
||||
# Can create
|
||||
await some_table.create_index("vector")
|
||||
# Can recreate if replace=True
|
||||
await some_table.create_index("vector", replace=True)
|
||||
# Can't recreate if replace=False
|
||||
with pytest.raises(RuntimeError, match="already exists"):
|
||||
await some_table.create_index("vector", replace=False)
|
||||
# Can also specify index type
|
||||
await some_table.create_index("vector", config=IvfPq(num_partitions=100))
|
||||
87
python/src/index.rs
Normal file
87
python/src/index.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use lancedb::{
|
||||
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
|
||||
DistanceType,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, PyResult,
|
||||
};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Index {
|
||||
inner: Mutex<Option<LanceDbIndex>>,
|
||||
}
|
||||
|
||||
impl Index {
|
||||
pub fn consume(&self) -> PyResult<LanceDbIndex> {
|
||||
self.inner
|
||||
.lock()
|
||||
.unwrap()
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Index {
|
||||
#[staticmethod]
|
||||
pub fn ivf_pq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = match distance_type.as_str() {
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
distance_type
|
||||
))),
|
||||
}?;
|
||||
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
|
||||
})
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn btree() -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,11 +14,15 @@
|
||||
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::Index;
|
||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||
use table::Table;
|
||||
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
|
||||
#[pymodule]
|
||||
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
@@ -27,6 +31,8 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_class::<Table>()?;
|
||||
m.add_class::<Index>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
|
||||
@@ -9,7 +9,7 @@ use pyo3::{
|
||||
};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::{error::PythonErrorExt, index::Index};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Table {
|
||||
@@ -81,6 +81,28 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_index<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
column: String,
|
||||
index: Option<&Index>,
|
||||
replace: Option<bool>,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
let index = if let Some(index) = index {
|
||||
index.consume()?
|
||||
} else {
|
||||
lancedb::index::Index::Auto
|
||||
};
|
||||
let mut op = self_.inner_ref()?.create_index(&[column], index);
|
||||
if let Some(replace) = replace {
|
||||
op = op.replace(replace);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
op.execute().await.infer_error()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn __repr__(&self) -> String {
|
||||
match &self.inner {
|
||||
None => format!("ClosedTable({})", self.name),
|
||||
|
||||
35
python/src/util.rs
Normal file
35
python/src/util.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use std::sync::Mutex;
|
||||
|
||||
use pyo3::{exceptions::PyRuntimeError, PyResult};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
///
|
||||
/// Rust builders are often implemented so that the builder methods
|
||||
/// consume the builder and return a new one. This is not compatible
|
||||
/// with the pyo3, which, being garbage collected, cannot easily obtain
|
||||
/// ownership of an object.
|
||||
///
|
||||
/// This wrapper converts the compile-time safety of rust into runtime
|
||||
/// errors if any attempt to use the builder happens after it is consumed.
|
||||
pub struct BuilderWrapper<T> {
|
||||
name: String,
|
||||
inner: Mutex<Option<T>>,
|
||||
}
|
||||
|
||||
impl<T> BuilderWrapper<T> {
|
||||
pub fn new(name: impl AsRef<str>, inner: T) -> Self {
|
||||
Self {
|
||||
name: name.as_ref().to_string(),
|
||||
inner: Mutex::new(Some(inner)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn consume<O>(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult<O> {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
let inner_builder = inner.take().ok_or_else(|| {
|
||||
PyRuntimeError::new_err(format!("{} has already been consumed", self.name))
|
||||
})?;
|
||||
let result = mod_fn(inner_builder);
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user