mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-15 02:50:44 +00:00
I'm working on a LanceDB version of PyTorch data loading (and hopefully addressing https://github.com/lancedb/lance/issues/3727). However, rather than relying on PyTorch for everything, I'm moving some of the things that PyTorch does into Rust. This gives us more control over data loading (e.g. using shards or a hash-based split), and it allows permutations to be persistent. In particular, I hope to be able to:
* Create a persistent permutation
* Have this permutation handle splits, filtering, shuffling, and sharding
* Create a Rust data loader that can read a permutation (one or more splits) or a subset of a permutation (for DDP)
* Create a Python data loader that delegates to the Rust data loader
* Eventually, create integrations for other data loading libraries, including Rust and Node
61 lines
1.8 KiB
Rust
61 lines
1.8 KiB
Rust
// SPDX-License-Identifier: Apache-2.0
|
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
use arrow::RecordBatchStream;
|
|
use connection::{connect, Connection};
|
|
use env_logger::Env;
|
|
use index::IndexConfig;
|
|
use permutation::PyAsyncPermutationBuilder;
|
|
use pyo3::{
|
|
pymodule,
|
|
types::{PyModule, PyModuleMethods},
|
|
wrap_pyfunction, Bound, PyResult, Python,
|
|
};
|
|
use query::{FTSQuery, HybridQuery, Query, VectorQuery};
|
|
use session::Session;
|
|
use table::{
|
|
AddColumnsResult, AddResult, AlterColumnsResult, DeleteResult, DropColumnsResult, MergeResult,
|
|
Table, UpdateResult,
|
|
};
|
|
|
|
pub mod arrow;
|
|
pub mod connection;
|
|
pub mod error;
|
|
pub mod header;
|
|
pub mod index;
|
|
pub mod permutation;
|
|
pub mod query;
|
|
pub mod session;
|
|
pub mod table;
|
|
pub mod util;
|
|
|
|
#[pymodule]
|
|
pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
let env = Env::new()
|
|
.filter_or("LANCEDB_LOG", "warn")
|
|
.write_style("LANCEDB_LOG_STYLE");
|
|
env_logger::init_from_env(env);
|
|
m.add_class::<Connection>()?;
|
|
m.add_class::<Session>()?;
|
|
m.add_class::<Table>()?;
|
|
m.add_class::<IndexConfig>()?;
|
|
m.add_class::<Query>()?;
|
|
m.add_class::<FTSQuery>()?;
|
|
m.add_class::<HybridQuery>()?;
|
|
m.add_class::<VectorQuery>()?;
|
|
m.add_class::<RecordBatchStream>()?;
|
|
m.add_class::<AddColumnsResult>()?;
|
|
m.add_class::<AlterColumnsResult>()?;
|
|
m.add_class::<AddResult>()?;
|
|
m.add_class::<MergeResult>()?;
|
|
m.add_class::<DeleteResult>()?;
|
|
m.add_class::<DropColumnsResult>()?;
|
|
m.add_class::<UpdateResult>()?;
|
|
m.add_class::<PyAsyncPermutationBuilder>()?;
|
|
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
|
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
|
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
|
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
|
Ok(())
|
|
}
|