ripgrep-all/src/caching_writer.rs

78 lines
2.7 KiB
Rust
Raw Normal View History

use anyhow::Result;
2019-06-07 21:17:33 +00:00
use log::*;
2020-09-10 15:18:11 +00:00
use std::io::{Read, Write};
2019-06-05 14:43:40 +00:00
/**
* wrap a writer so that it is passthrough,
2019-06-07 17:00:24 +00:00
* but also the written data is compressed and written into a buffer,
* unless more than max_cache_size bytes is written, then the cache is dropped and it is pure passthrough.
2019-06-05 14:43:40 +00:00
*/
2020-06-17 09:43:47 +00:00
pub struct CachingReader<R: Read> {
2019-06-05 14:43:40 +00:00
max_cache_size: usize,
zstd_writer: Option<zstd::stream::write::Encoder<Vec<u8>>>,
2020-06-17 09:43:47 +00:00
inp: R,
2020-06-09 10:47:34 +00:00
bytes_written: u64,
2020-06-17 09:43:47 +00:00
on_finish: Box<dyn FnOnce((u64, Option<Vec<u8>>)) -> Result<()> + Send>,
2019-06-05 14:43:40 +00:00
}
2020-06-17 09:43:47 +00:00
impl<R: Read> CachingReader<R> {
pub fn new(
inp: R,
max_cache_size: usize,
compression_level: i32,
on_finish: Box<dyn FnOnce((u64, Option<Vec<u8>>)) -> Result<()> + Send>,
) -> Result<CachingReader<R>> {
Ok(CachingReader {
inp,
2019-06-05 14:43:40 +00:00
max_cache_size,
zstd_writer: Some(zstd::stream::write::Encoder::new(
Vec::new(),
compression_level,
)?),
2020-06-09 10:47:34 +00:00
bytes_written: 0,
2020-06-17 09:43:47 +00:00
on_finish,
2019-06-05 14:43:40 +00:00
})
}
2020-06-17 09:43:47 +00:00
pub fn finish(&mut self) -> std::io::Result<(u64, Option<Vec<u8>>)> {
if let Some(writer) = self.zstd_writer.take() {
2019-06-05 14:43:40 +00:00
let res = writer.finish()?;
if res.len() <= self.max_cache_size {
2020-06-09 10:47:34 +00:00
return Ok((self.bytes_written, Some(res)));
2019-06-05 14:43:40 +00:00
}
}
2020-06-09 10:47:34 +00:00
Ok((self.bytes_written, None))
2019-06-05 14:43:40 +00:00
}
2020-06-17 09:43:47 +00:00
fn write_to_compressed(&mut self, buf: &[u8]) -> std::io::Result<()> {
if let Some(writer) = self.zstd_writer.as_mut() {
let wrote = writer.write(buf)?;
let compressed_len = writer.get_ref().len();
trace!("wrote {} to zstd, len now {}", wrote, compressed_len);
if compressed_len > self.max_cache_size {
debug!("cache longer than max, dropping");
//writer.finish();
self.zstd_writer.take().unwrap().finish()?;
2019-06-05 14:43:40 +00:00
}
2020-06-17 09:43:47 +00:00
}
Ok(())
2019-06-05 14:43:40 +00:00
}
2020-06-17 09:43:47 +00:00
}
impl<R: Read> Read for CachingReader<R> {
fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
match self.inp.read(&mut buf) {
Ok(0) => {
// move out of box, replace with noop lambda
let on_finish = std::mem::replace(&mut self.on_finish, Box::new(|_| Ok(())));
// EOF, finish!
(on_finish)(self.finish()?)
.map(|()| 0)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
}
Ok(read_bytes) => {
self.write_to_compressed(&buf[0..read_bytes])?;
self.bytes_written += read_bytes as u64;
Ok(read_bytes)
}
Err(e) => Err(e),
2019-06-05 14:43:40 +00:00
}
}
}