// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors use std::sync::Arc; use arrow::{ datatypes::SchemaRef, pyarrow::{IntoPyArrow, ToPyArrow}, }; use futures::stream::StreamExt; use lancedb::arrow::SendableRecordBatchStream; use pyo3::{ exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, Py, PyAny, PyRef, PyResult, Python, }; use pyo3_async_runtimes::tokio::future_into_py; use crate::error::PythonErrorExt; #[pyclass] pub struct RecordBatchStream { schema: SchemaRef, inner: Arc>, } impl RecordBatchStream { pub fn new(inner: SendableRecordBatchStream) -> Self { let schema = inner.schema().clone(); Self { schema, inner: Arc::new(tokio::sync::Mutex::new(inner)), } } } #[pymethods] impl RecordBatchStream { #[getter] pub fn schema(&self, py: Python) -> PyResult> { (*self.schema) .clone() .into_pyarrow(py) .map(|obj| obj.unbind()) } pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> { self_ } pub fn __anext__(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner.clone(); future_into_py(self_.py(), async move { let inner_next = inner .lock() .await .next() .await .ok_or_else(|| PyStopAsyncIteration::new_err(""))?; Python::attach(|py| { inner_next .infer_error()? .to_pyarrow(py) .map(|obj| obj.unbind()) }) }) } }