// Copyright 2024 Lance Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::sync::Mutex; use lancedb::index::scalar::FtsIndexBuilder; use lancedb::{ index::{ scalar::BTreeIndexBuilder, vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, Index as LanceDbIndex, }, DistanceType, }; use pyo3::{ exceptions::{PyKeyError, PyRuntimeError, PyValueError}, pyclass, pymethods, IntoPy, PyObject, PyResult, Python, }; use crate::util::parse_distance_type; #[pyclass] pub struct Index { inner: Mutex>, } impl Index { pub fn consume(&self) -> PyResult { self.inner .lock() .unwrap() .take() .ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once")) } } #[pymethods] impl Index { #[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None))] #[staticmethod] pub fn ivf_pq( distance_type: Option, num_partitions: Option, num_sub_vectors: Option, num_bits: Option, max_iterations: Option, sample_rate: Option, ) -> PyResult { let mut ivf_pq_builder = IvfPqIndexBuilder::default(); if let Some(distance_type) = distance_type { let distance_type = match distance_type.as_str() { "l2" => Ok(DistanceType::L2), "cosine" => Ok(DistanceType::Cosine), "dot" => Ok(DistanceType::Dot), _ => Err(PyValueError::new_err(format!( "Invalid distance type '{}'. Must be one of l2, cosine, or dot", distance_type ))), }?; ivf_pq_builder = ivf_pq_builder.distance_type(distance_type); } if let Some(num_partitions) = num_partitions { ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions); } if let Some(num_sub_vectors) = num_sub_vectors { ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors); } if let Some(num_bits) = num_bits { ivf_pq_builder = ivf_pq_builder.num_bits(num_bits); } if let Some(max_iterations) = max_iterations { ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations); } if let Some(sample_rate) = sample_rate { ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate); } Ok(Self { inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))), }) } #[staticmethod] pub fn btree() -> PyResult { Ok(Self { inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))), }) } #[staticmethod] pub fn bitmap() -> PyResult { Ok(Self { inner: Mutex::new(Some(LanceDbIndex::Bitmap(Default::default()))), }) } #[staticmethod] pub fn label_list() -> PyResult { Ok(Self { inner: Mutex::new(Some(LanceDbIndex::LabelList(Default::default()))), }) } #[pyo3(signature = (with_position=None, base_tokenizer=None, language=None, max_token_length=None, lower_case=None, stem=None, remove_stop_words=None, ascii_folding=None))] #[allow(clippy::too_many_arguments)] #[staticmethod] pub fn fts( with_position: Option, base_tokenizer: Option, language: Option, max_token_length: Option, lower_case: Option, stem: Option, remove_stop_words: Option, ascii_folding: Option, ) -> Self { let mut opts = FtsIndexBuilder::default(); if let Some(with_position) = with_position { opts = opts.with_position(with_position); } if let Some(base_tokenizer) = base_tokenizer { opts.tokenizer_configs = opts.tokenizer_configs.base_tokenizer(base_tokenizer); } if let Some(language) = language { opts.tokenizer_configs = opts.tokenizer_configs.language(&language).unwrap(); } opts.tokenizer_configs = opts.tokenizer_configs.max_token_length(max_token_length); if let Some(lower_case) = lower_case { opts.tokenizer_configs = opts.tokenizer_configs.lower_case(lower_case); } if let Some(stem) = stem { opts.tokenizer_configs = opts.tokenizer_configs.stem(stem); } if let Some(remove_stop_words) = remove_stop_words { opts.tokenizer_configs = opts.tokenizer_configs.remove_stop_words(remove_stop_words); } if let Some(ascii_folding) = ascii_folding { opts.tokenizer_configs = opts.tokenizer_configs.ascii_folding(ascii_folding); } Self { inner: Mutex::new(Some(LanceDbIndex::FTS(opts))), } } #[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))] #[staticmethod] #[allow(clippy::too_many_arguments)] pub fn hnsw_pq( distance_type: Option, num_partitions: Option, num_sub_vectors: Option, num_bits: Option, max_iterations: Option, sample_rate: Option, m: Option, ef_construction: Option, ) -> PyResult { let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default(); if let Some(distance_type) = distance_type { let distance_type = parse_distance_type(distance_type)?; hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type); } if let Some(num_partitions) = num_partitions { hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions); } if let Some(num_sub_vectors) = num_sub_vectors { hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors); } if let Some(num_bits) = num_bits { hnsw_pq_builder = hnsw_pq_builder.num_bits(num_bits); } if let Some(max_iterations) = max_iterations { hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations); } if let Some(sample_rate) = sample_rate { hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate); } if let Some(m) = m { hnsw_pq_builder = hnsw_pq_builder.num_edges(m); } if let Some(ef_construction) = ef_construction { hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction); } Ok(Self { inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))), }) } #[pyo3(signature = (distance_type=None, num_partitions=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))] #[staticmethod] pub fn hnsw_sq( distance_type: Option, num_partitions: Option, max_iterations: Option, sample_rate: Option, m: Option, ef_construction: Option, ) -> PyResult { let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default(); if let Some(distance_type) = distance_type { let distance_type = parse_distance_type(distance_type)?; hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type); } if let Some(num_partitions) = num_partitions { hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions); } if let Some(max_iterations) = max_iterations { hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations); } if let Some(sample_rate) = sample_rate { hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate); } if let Some(m) = m { hnsw_sq_builder = hnsw_sq_builder.num_edges(m); } if let Some(ef_construction) = ef_construction { hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction); } Ok(Self { inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))), }) } } #[pyclass(get_all)] /// A description of an index currently configured on a column pub struct IndexConfig { /// The type of the index pub index_type: String, /// The columns in the index /// /// Currently this is always a list of size 1. In the future there may /// be more columns to represent composite indices. pub columns: Vec, /// Name of the index. pub name: String, } #[pymethods] impl IndexConfig { pub fn __repr__(&self) -> String { format!( "Index({}, columns={:?}, name=\"{}\")", self.index_type, self.columns, self.name ) } // For backwards-compatibility with the old sync SDK, we also support getting // attributes via __getitem__. pub fn __getitem__(&self, key: String, py: Python<'_>) -> PyResult { match key.as_str() { "index_type" => Ok(self.index_type.clone().into_py(py)), "columns" => Ok(self.columns.clone().into_py(py)), "name" | "index_name" => Ok(self.name.clone().into_py(py)), _ => Err(PyKeyError::new_err(format!("Invalid key: {}", key))), } } } impl From for IndexConfig { fn from(value: lancedb::index::IndexConfig) -> Self { let index_type = format!("{:?}", value.index_type); Self { index_type, columns: value.columns, name: value.name, } } }