1use std::collections::HashMap;
16use std::io;
17use std::str::FromStr;
18use std::sync::Arc;
19use std::task::Poll;
20
21use arrow::csv::reader::Format;
22use arrow::csv::{self, WriterBuilder};
23use arrow::error::ArrowError;
24use arrow::record_batch::RecordBatch;
25use arrow_schema::{Schema, SchemaRef};
26use async_trait::async_trait;
27use bytes::{Buf, Bytes};
28use common_runtime;
29use common_telemetry::warn;
30use datafusion::physical_plan::SendableRecordBatchStream;
31use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
32use futures::StreamExt;
33use futures::stream::BoxStream;
34use object_store::ObjectStore;
35use snafu::ResultExt;
36use tokio_util::compat::FuturesAsyncReadCompatExt;
37use tokio_util::io::SyncIoBridge;
38
39use crate::buffered_writer::DfRecordBatchEncoder;
40use crate::compression::CompressionType;
41use crate::error::{self, Result};
42use crate::file_format::{self, FileFormat, stream_to_file};
43use crate::share_buffer::SharedBuffer;
44use crate::util::normalize_infer_schema;
45
/// Batch size handed to the CSV decoder built by `tolerant_csv_stream_from_reader`.
///
/// Decoding a single row per batch means that a failing `flush` only ever
/// covers one record, which is what allows that record to be skipped.
const SKIP_BAD_RECORDS_BATCH_SIZE: usize = 1;
47
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct CsvFormat {
50 pub has_header: bool,
51 pub skip_bad_records: bool,
52 pub delimiter: u8,
53 pub schema_infer_max_record: Option<usize>,
54 pub compression_type: CompressionType,
55 pub timestamp_format: Option<String>,
56 pub time_format: Option<String>,
57 pub date_format: Option<String>,
58}
59
60impl TryFrom<&HashMap<String, String>> for CsvFormat {
61 type Error = error::Error;
62
63 fn try_from(value: &HashMap<String, String>) -> Result<Self> {
64 let mut format = CsvFormat::default();
65 if let Some(delimiter) = value.get(file_format::FORMAT_DELIMITER) {
66 format.delimiter = u8::from_str(delimiter).map_err(|_| {
68 error::ParseFormatSnafu {
69 key: file_format::FORMAT_DELIMITER,
70 value: delimiter,
71 }
72 .build()
73 })?;
74 };
75 if let Some(compression_type) = value.get(file_format::FORMAT_COMPRESSION_TYPE) {
76 format.compression_type = CompressionType::from_str(compression_type)?;
77 };
78 if let Some(schema_infer_max_record) =
79 value.get(file_format::FORMAT_SCHEMA_INFER_MAX_RECORD)
80 {
81 format.schema_infer_max_record =
82 Some(schema_infer_max_record.parse::<usize>().map_err(|_| {
83 error::ParseFormatSnafu {
84 key: file_format::FORMAT_SCHEMA_INFER_MAX_RECORD,
85 value: schema_infer_max_record,
86 }
87 .build()
88 })?);
89 };
90 if let Some(has_header) = value.get(file_format::FORMAT_HAS_HEADER) {
91 format.has_header = parse_bool(file_format::FORMAT_HAS_HEADER, has_header)?;
92 };
93 if let Some(skip_bad_records) = value.get(file_format::FORMAT_SKIP_BAD_RECORDS) {
94 format.skip_bad_records =
95 parse_bool(file_format::FORMAT_SKIP_BAD_RECORDS, skip_bad_records)?;
96 };
97 if let Some(timestamp_format) = value.get(file_format::TIMESTAMP_FORMAT) {
98 format.timestamp_format = Some(timestamp_format.clone());
99 }
100 if let Some(time_format) = value.get(file_format::TIME_FORMAT) {
101 format.time_format = Some(time_format.clone());
102 }
103 if let Some(date_format) = value.get(file_format::DATE_FORMAT) {
104 format.date_format = Some(date_format.clone());
105 }
106 Ok(format)
107 }
108}
109
110fn parse_bool(key: &'static str, value: &str) -> Result<bool> {
111 value
112 .parse()
113 .map_err(|_| error::ParseFormatSnafu { key, value }.build())
114}
115
116impl Default for CsvFormat {
117 fn default() -> Self {
118 Self {
119 has_header: true,
120 skip_bad_records: false,
121 delimiter: b',',
122 schema_infer_max_record: Some(file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD),
123 compression_type: CompressionType::Uncompressed,
124 timestamp_format: None,
125 time_format: None,
126 date_format: None,
127 }
128 }
129}
130
131#[async_trait]
132impl FileFormat for CsvFormat {
133 async fn infer_schema(&self, store: &ObjectStore, path: &str) -> Result<Schema> {
134 let meta = store
135 .stat(path)
136 .await
137 .context(error::ReadObjectSnafu { path })?;
138
139 let reader = store
140 .reader(path)
141 .await
142 .context(error::ReadObjectSnafu { path })?
143 .into_futures_async_read(0..meta.content_length())
144 .await
145 .context(error::ReadObjectSnafu { path })?
146 .compat();
147
148 let decoded = self.compression_type.convert_async_read(reader);
149
150 let delimiter = self.delimiter;
151 let schema_infer_max_record = self.schema_infer_max_record;
152 let has_header = self.has_header;
153
154 common_runtime::spawn_blocking_global(move || {
155 let reader = SyncIoBridge::new(decoded);
156
157 let format = Format::default()
158 .with_delimiter(delimiter)
159 .with_header(has_header);
160 let (schema, _records_read) = format
161 .infer_schema(reader, schema_infer_max_record)
162 .context(error::InferSchemaSnafu)?;
163
164 Ok(normalize_infer_schema(schema))
165 })
166 .await
167 .context(error::JoinHandleSnafu)?
168 }
169}
170
171pub async fn stream_to_csv(
172 stream: SendableRecordBatchStream,
173 store: ObjectStore,
174 path: &str,
175 threshold: usize,
176 concurrency: usize,
177 format: &CsvFormat,
178) -> Result<usize> {
179 stream_to_file(
180 stream,
181 store,
182 path,
183 threshold,
184 concurrency,
185 format.compression_type,
186 |buffer| {
187 let mut builder = WriterBuilder::new();
188 if let Some(timestamp_format) = &format.timestamp_format {
189 builder = builder.with_timestamp_format(timestamp_format.to_owned())
190 }
191 if let Some(date_format) = &format.date_format {
192 builder = builder.with_date_format(date_format.to_owned())
193 }
194 if let Some(time_format) = &format.time_format {
195 builder = builder.with_time_format(time_format.to_owned())
196 }
197 builder.build(buffer)
198 },
199 )
200 .await
201}
202
203impl DfRecordBatchEncoder for csv::Writer<SharedBuffer> {
204 fn write(&mut self, batch: &RecordBatch) -> Result<()> {
205 self.write(batch).context(error::WriteRecordBatchSnafu)
206 }
207}
208
209pub async fn tolerant_csv_stream(
217 store: &ObjectStore,
218 path: &str,
219 schema: SchemaRef,
220 projection: Vec<usize>,
221 format: &CsvFormat,
222) -> Result<SendableRecordBatchStream> {
223 let meta = store
224 .stat(path)
225 .await
226 .context(error::ReadObjectSnafu { path })?;
227
228 let reader = store
229 .reader(path)
230 .await
231 .context(error::ReadObjectSnafu { path })?
232 .into_bytes_stream(0..meta.content_length())
233 .await
234 .context(error::ReadObjectSnafu { path })?;
235
236 let reader = format.compression_type.convert_stream(reader).boxed();
237 tolerant_csv_stream_from_reader(
238 reader,
239 path,
240 schema,
241 projection,
242 format.has_header,
243 format.delimiter,
244 )
245}
246
247fn tolerant_csv_stream_from_reader(
248 reader: BoxStream<'static, io::Result<Bytes>>,
249 path: &str,
250 schema: SchemaRef,
251 projection: Vec<usize>,
252 has_header: bool,
253 delimiter: u8,
254) -> Result<SendableRecordBatchStream> {
255 let projected_schema = Arc::new(
256 schema
257 .project(&projection)
258 .context(error::InferSchemaSnafu)?,
259 );
260 let mut decoder = csv::ReaderBuilder::new(schema)
261 .with_header(has_header)
262 .with_delimiter(delimiter)
263 .with_batch_size(SKIP_BAD_RECORDS_BATCH_SIZE)
264 .with_projection(projection)
265 .build_decoder();
266
267 let path = path.to_string();
268 let mut upstream = reader.fuse();
269 let mut buffered = Bytes::new();
270 let mut input_finished = false;
271 let stream = futures::stream::poll_fn(move |cx| {
272 loop {
273 while !input_finished {
274 if buffered.is_empty() {
275 match futures::ready!(upstream.poll_next_unpin(cx)) {
276 Some(Ok(bytes)) if bytes.is_empty() => continue,
277 Some(Ok(bytes)) => buffered = bytes,
278 Some(Err(error)) => return Poll::Ready(Some(Err(error.into()))),
279 None => input_finished = true,
280 }
281 }
282
283 let decoded = decoder.decode(buffered.as_ref())?;
284 if decoded > 0 {
285 buffered.advance(decoded);
286 continue;
287 }
288
289 if decoder.capacity() == 0 || input_finished {
290 break;
291 }
292
293 if buffered.is_empty() {
294 continue;
295 }
296
297 return Poll::Ready(Some(Err(ArrowError::ParseError(
298 "CSV decoder made no progress while input bytes remain".to_string(),
299 ))));
300 }
301
302 match decoder.flush() {
303 Ok(Some(batch)) => return Poll::Ready(Some(Ok(batch))),
304 Ok(None) if input_finished => return Poll::Ready(None),
305 Ok(None) => continue,
306 Err(error) if is_skippable_arrow_error(&error) => {
307 warn!(
308 "Skipping bad CSV record while copying from {}: {}",
309 path, error
310 );
311 }
312 Err(error) => return Poll::Ready(Some(Err(error))),
313 }
314 }
315 })
316 .map(|result: std::result::Result<RecordBatch, ArrowError>| result.map_err(Into::into));
317
318 Ok(Box::pin(RecordBatchStreamAdapter::new(
319 projected_schema,
320 stream,
321 )))
322}
323
324pub fn is_skippable_arrow_error(error: &ArrowError) -> bool {
325 matches!(
326 error,
327 ArrowError::ParseError(_)
328 | ArrowError::CastError(_)
329 | ArrowError::ComputeError(_)
330 | ArrowError::InvalidArgumentError(_)
331 )
332}
333
334#[cfg(test)]
335mod tests {
336 use std::sync::Arc;
337
338 use arrow_schema::{DataType, Field};
339 use common_recordbatch::adapter::DfRecordBatchStreamAdapter;
340 use common_recordbatch::{RecordBatch, RecordBatches};
341 use common_test_util::find_workspace_path;
342 use datafusion::datasource::physical_plan::{CsvSource, FileSource};
343 use datatypes::prelude::ConcreteDataType;
344 use datatypes::schema::{ColumnSchema, Schema};
345 use datatypes::vectors::{Float64Vector, StringVector, UInt32Vector, VectorRef};
346 use futures::TryStreamExt;
347
348 use super::*;
349 use crate::file_format::{
350 FORMAT_COMPRESSION_TYPE, FORMAT_DELIMITER, FORMAT_HAS_HEADER,
351 FORMAT_SCHEMA_INFER_MAX_RECORD, FORMAT_SKIP_BAD_RECORDS, FileFormat, file_to_stream,
352 };
353 use crate::test_util::{format_schema, test_store};
354
355 fn test_data_root() -> String {
356 find_workspace_path("/src/common/datasource/tests/csv")
357 .display()
358 .to_string()
359 }
360
361 #[tokio::test]
362 async fn infer_schema_basic() {
363 let csv = CsvFormat::default();
364 let store = test_store(&test_data_root());
365 let schema = csv.infer_schema(&store, "simple.csv").await.unwrap();
366 let formatted: Vec<_> = format_schema(schema);
367
368 assert_eq!(
369 vec![
370 "c1: Utf8: NULL",
371 "c2: Int64: NULL",
372 "c3: Int64: NULL",
373 "c4: Int64: NULL",
374 "c5: Int64: NULL",
375 "c6: Int64: NULL",
376 "c7: Int64: NULL",
377 "c8: Int64: NULL",
378 "c9: Int64: NULL",
379 "c10: Utf8: NULL",
380 "c11: Float64: NULL",
381 "c12: Float64: NULL",
382 "c13: Utf8: NULL"
383 ],
384 formatted,
385 );
386 }
387
388 #[tokio::test]
389 async fn normalize_infer_schema() {
390 let csv = CsvFormat {
391 schema_infer_max_record: Some(3),
392 ..CsvFormat::default()
393 };
394 let store = test_store(&test_data_root());
395 let schema = csv.infer_schema(&store, "max_infer.csv").await.unwrap();
396 let formatted: Vec<_> = format_schema(schema);
397
398 assert_eq!(
399 vec![
400 "num: Int64: NULL",
401 "str: Utf8: NULL",
402 "ts: Utf8: NULL",
403 "t: Utf8: NULL",
404 "date: Date32: NULL"
405 ],
406 formatted,
407 );
408 }
409
410 #[tokio::test]
411 async fn infer_schema_with_limit() {
412 let csv = CsvFormat {
413 schema_infer_max_record: Some(3),
414 ..CsvFormat::default()
415 };
416 let store = test_store(&test_data_root());
417 let schema = csv
418 .infer_schema(&store, "schema_infer_limit.csv")
419 .await
420 .unwrap();
421 let formatted: Vec<_> = format_schema(schema);
422
423 assert_eq!(
424 vec![
425 "a: Int64: NULL",
426 "b: Float64: NULL",
427 "c: Int64: NULL",
428 "d: Int64: NULL"
429 ],
430 formatted
431 );
432
433 let csv = CsvFormat::default();
434 let store = test_store(&test_data_root());
435 let schema = csv
436 .infer_schema(&store, "schema_infer_limit.csv")
437 .await
438 .unwrap();
439 let formatted: Vec<_> = format_schema(schema);
440
441 assert_eq!(
442 vec![
443 "a: Int64: NULL",
444 "b: Float64: NULL",
445 "c: Int64: NULL",
446 "d: Utf8: NULL"
447 ],
448 formatted
449 );
450 }
451
452 #[test]
453 fn test_try_from() {
454 let map = HashMap::new();
455 let format: CsvFormat = CsvFormat::try_from(&map).unwrap();
456
457 assert_eq!(format, CsvFormat::default());
458
459 let map = HashMap::from([
460 (
461 FORMAT_SCHEMA_INFER_MAX_RECORD.to_string(),
462 "2000".to_string(),
463 ),
464 (FORMAT_COMPRESSION_TYPE.to_string(), "zstd".to_string()),
465 (FORMAT_DELIMITER.to_string(), b'\t'.to_string()),
466 (FORMAT_HAS_HEADER.to_string(), "false".to_string()),
467 ]);
468 let format = CsvFormat::try_from(&map).unwrap();
469
470 assert_eq!(
471 format,
472 CsvFormat {
473 compression_type: CompressionType::Zstd,
474 schema_infer_max_record: Some(2000),
475 delimiter: b'\t',
476 has_header: false,
477 skip_bad_records: false,
478 timestamp_format: None,
479 time_format: None,
480 date_format: None
481 }
482 );
483
484 let map = HashMap::from([(FORMAT_SKIP_BAD_RECORDS.to_string(), "true".to_string())]);
485 let format = CsvFormat::try_from(&map).unwrap();
486
487 assert_eq!(
488 format,
489 CsvFormat {
490 skip_bad_records: true,
491 ..CsvFormat::default()
492 }
493 );
494 }
495
496 #[test]
497 fn test_try_from_rejects_invalid_bool_options() {
498 let map = HashMap::from([(FORMAT_SKIP_BAD_RECORDS.to_string(), "yes".to_string())]);
499 assert!(CsvFormat::try_from(&map).is_err());
500 }
501
502 #[tokio::test]
503 async fn test_compressed_csv() {
504 let column_schemas = vec![
506 ColumnSchema::new("id", ConcreteDataType::uint32_datatype(), false),
507 ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
508 ColumnSchema::new("value", ConcreteDataType::float64_datatype(), false),
509 ];
510 let schema = Arc::new(Schema::new(column_schemas));
511
512 let batch1_columns: Vec<VectorRef> = vec![
514 Arc::new(UInt32Vector::from_slice(vec![1, 2, 3])),
515 Arc::new(StringVector::from(vec!["Alice", "Bob", "Charlie"])),
516 Arc::new(Float64Vector::from_slice(vec![10.5, 20.3, 30.7])),
517 ];
518 let batch1 = RecordBatch::new(schema.clone(), batch1_columns).unwrap();
519
520 let batch2_columns: Vec<VectorRef> = vec![
521 Arc::new(UInt32Vector::from_slice(vec![4, 5, 6])),
522 Arc::new(StringVector::from(vec!["David", "Eva", "Frank"])),
523 Arc::new(Float64Vector::from_slice(vec![40.1, 50.2, 60.3])),
524 ];
525 let batch2 = RecordBatch::new(schema.clone(), batch2_columns).unwrap();
526
527 let batch3_columns: Vec<VectorRef> = vec![
528 Arc::new(UInt32Vector::from_slice(vec![7, 8, 9])),
529 Arc::new(StringVector::from(vec!["Grace", "Henry", "Ivy"])),
530 Arc::new(Float64Vector::from_slice(vec![70.4, 80.5, 90.6])),
531 ];
532 let batch3 = RecordBatch::new(schema.clone(), batch3_columns).unwrap();
533
534 let recordbatches = RecordBatches::try_new(schema, vec![batch1, batch2, batch3]).unwrap();
536
537 let compression_types = vec![
539 CompressionType::Gzip,
540 CompressionType::Bzip2,
541 CompressionType::Xz,
542 CompressionType::Zstd,
543 ];
544
545 let temp_dir = common_test_util::temp_dir::create_temp_dir("test_compressed_csv");
547 for compression_type in compression_types {
548 let format = CsvFormat {
549 compression_type,
550 ..CsvFormat::default()
551 };
552
553 let compressed_file_name =
555 format!("test_compressed_csv.{}", compression_type.file_extension());
556 let compressed_file_path = temp_dir.path().join(&compressed_file_name);
557 let compressed_file_path_str = compressed_file_path.to_str().unwrap();
558
559 let store = test_store("/");
561
562 let rows = stream_to_csv(
564 Box::pin(DfRecordBatchStreamAdapter::new(recordbatches.as_stream())),
565 store,
566 compressed_file_path_str,
567 1024,
568 1,
569 &format,
570 )
571 .await
572 .unwrap();
573
574 assert_eq!(rows, 9);
575
576 assert!(compressed_file_path.exists());
578 let file_size = std::fs::metadata(&compressed_file_path).unwrap().len();
579 assert!(file_size > 0);
580
581 let file_content = std::fs::read(&compressed_file_path).unwrap();
583 match compression_type {
586 CompressionType::Gzip => {
587 assert_eq!(file_content[0], 0x1f, "Gzip file should start with 0x1f");
589 assert_eq!(
590 file_content[1], 0x8b,
591 "Gzip file should have 0x8b as second byte"
592 );
593 }
594 CompressionType::Bzip2 => {
595 assert_eq!(file_content[0], b'B', "Bzip2 file should start with 'B'");
597 assert_eq!(
598 file_content[1], b'Z',
599 "Bzip2 file should have 'Z' as second byte"
600 );
601 }
602 CompressionType::Xz => {
603 assert_eq!(file_content[0], 0xFD, "XZ file should start with 0xFD");
605 }
606 CompressionType::Zstd => {
607 assert_eq!(file_content[0], 0x28, "Zstd file should start with 0x28");
609 assert_eq!(
610 file_content[1], 0xB5,
611 "Zstd file should have 0xB5 as second byte"
612 );
613 }
614 _ => {}
615 }
616
617 let store = test_store("/");
619 let schema = Arc::new(
620 CsvFormat {
621 compression_type,
622 ..Default::default()
623 }
624 .infer_schema(&store, compressed_file_path_str)
625 .await
626 .unwrap(),
627 );
628 let csv_source = CsvSource::new(schema).with_batch_size(8192);
629
630 let stream = file_to_stream(
631 &store,
632 compressed_file_path_str,
633 csv_source.clone(),
634 None,
635 compression_type,
636 )
637 .await
638 .unwrap();
639
640 let batches = stream.try_collect::<Vec<_>>().await.unwrap();
641 let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
642 .unwrap()
643 .to_string();
644 let expected = r#"+----+---------+-------+
645| id | name | value |
646+----+---------+-------+
647| 1 | Alice | 10.5 |
648| 2 | Bob | 20.3 |
649| 3 | Charlie | 30.7 |
650| 4 | David | 40.1 |
651| 5 | Eva | 50.2 |
652| 6 | Frank | 60.3 |
653| 7 | Grace | 70.4 |
654| 8 | Henry | 80.5 |
655| 9 | Ivy | 90.6 |
656+----+---------+-------+"#;
657 assert_eq!(expected, pretty_print);
658 }
659 }
660
661 #[tokio::test]
662 async fn test_tolerant_csv_stream_continues_after_parse_error() {
663 let temp_dir = common_test_util::temp_dir::create_temp_dir("test_tolerant_csv_stream");
664 let csv_file_path = temp_dir.path().join("input.csv");
665 std::fs::write(
666 &csv_file_path,
667 "id,name,value\n1,Alice,10.5\nbad,Bad,20.0\nworse,Bad,21.0\n2,Bob,30.5",
668 )
669 .unwrap();
670
671 let store = test_store("/");
672 let schema = Arc::new(arrow_schema::Schema::new(vec![
673 Field::new("id", DataType::UInt32, false),
674 Field::new("name", DataType::Utf8, false),
675 Field::new("value", DataType::Float64, false),
676 ]));
677 let path = csv_file_path.to_str().unwrap();
678
679 let stream =
680 tolerant_csv_stream(&store, path, schema, vec![0, 1, 2], &CsvFormat::default())
681 .await
682 .unwrap();
683 let batches = stream.try_collect::<Vec<_>>().await.unwrap();
684 let pretty_print = arrow::util::pretty::pretty_format_batches(&batches)
685 .unwrap()
686 .to_string();
687 let expected = r#"+----+-------+-------+
688| id | name | value |
689+----+-------+-------+
690| 1 | Alice | 10.5 |
691| 2 | Bob | 30.5 |
692+----+-------+-------+"#;
693 assert_eq!(expected, pretty_print);
694 }
695
696 #[tokio::test]
697 async fn test_tolerant_csv_stream_fails_on_structural_csv_error() {
698 let temp_dir =
699 common_test_util::temp_dir::create_temp_dir("test_tolerant_csv_stream_csv_error");
700 let csv_file_path = temp_dir.path().join("input.csv");
701 std::fs::write(&csv_file_path, "id,name,value\n1,Alice,10.5\n2,Bob\n").unwrap();
702
703 let store = test_store("/");
704 let schema = Arc::new(arrow_schema::Schema::new(vec![
705 Field::new("id", DataType::UInt32, false),
706 Field::new("name", DataType::Utf8, false),
707 Field::new("value", DataType::Float64, false),
708 ]));
709 let path = csv_file_path.to_str().unwrap();
710
711 let stream =
712 tolerant_csv_stream(&store, path, schema, vec![0, 1, 2], &CsvFormat::default())
713 .await
714 .unwrap();
715 let error = stream.try_collect::<Vec<_>>().await.unwrap_err();
716
717 assert!(error.to_string().contains("incorrect number of fields"));
718 }
719}