Skip to content

Commit cf3276a

Browse files
committed
Fix error with short writes
1 parent 41cadaa commit cf3276a

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

src/enc/test.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ fn test_roundtrip_empty() {
553553
#[cfg(feature="std")]
554554
#[test]
555555
fn test_compress_into_short_buffer() {
556-
use std::io::{Cursor, Write};
556+
use std::io::{Cursor, Write, ErrorKind};
557557

558558
// this plaintext should compress to 11 bytes
559559
let plaintext = [0u8; 2048];
@@ -564,7 +564,8 @@ fn test_compress_into_short_buffer() {
564564

565565
let mut w = crate::CompressorWriter::new(&mut output_cursor,
566566
4096, 4, 22);
567-
w.write(&plaintext).unwrap_err();
567+
assert_eq!(w.write(&plaintext).unwrap(), 2048);
568+
assert_eq!(w.flush().unwrap_err().kind(), ErrorKind::WriteZero);
568569
w.into_inner();
569570

570571
println!("{output_buffer:?}");

src/enc/writer.rs

+35-3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ impl<W: Write, BufferType: SliceWrapperMut<u8>, Alloc: BrotliAlloc>
4141
buffer,
4242
alloc,
4343
Error::new(ErrorKind::InvalidData, "Invalid Data"),
44+
Error::new(ErrorKind::WriteZero, "No room in output."),
4445
q,
4546
lgwin,
4647
))
@@ -127,14 +128,24 @@ pub struct CompressorWriterCustomIo<
127128
output: Option<W>,
128129
error_if_invalid_data: Option<ErrType>,
129130
state: BrotliEncoderStateStruct<Alloc>,
131+
error_if_zero_bytes_written: Option<ErrType>,
130132
}
131-
pub fn write_all<ErrType, W: CustomWrite<ErrType>>(
133+
pub fn write_all<ErrType, W: CustomWrite<ErrType>, ErrMaker: FnMut() -> Option<ErrType>>(
132134
writer: &mut W,
133135
mut buf: &[u8],
136+
mut error_to_return_if_zero_bytes_written: ErrMaker,
134137
) -> Result<(), ErrType> {
135138
while !buf.is_empty() {
136139
match writer.write(buf) {
137-
Ok(bytes_written) => buf = &buf[bytes_written..],
140+
Ok(bytes_written) => if bytes_written != 0 {
141+
buf = &buf[bytes_written..]
142+
} else {
143+
if let Some(err) = error_to_return_if_zero_bytes_written() {
144+
return Err(err);
145+
} else {
146+
return Ok(());
147+
}
148+
},
138149
Err(e) => return Err(e),
139150
}
140151
}
@@ -148,6 +159,7 @@ impl<ErrType, W: CustomWrite<ErrType>, BufferType: SliceWrapperMut<u8>, Alloc: B
148159
buffer: BufferType,
149160
alloc: Alloc,
150161
invalid_data_error_type: ErrType,
162+
error_if_zero_bytes_written: ErrType,
151163
q: u32,
152164
lgwin: u32,
153165
) -> Self {
@@ -157,6 +169,7 @@ impl<ErrType, W: CustomWrite<ErrType>, BufferType: SliceWrapperMut<u8>, Alloc: B
157169
output: Some(w),
158170
state: BrotliEncoderStateStruct::new(alloc),
159171
error_if_invalid_data: Some(invalid_data_error_type),
172+
error_if_zero_bytes_written: Some(error_if_zero_bytes_written),
160173
};
161174
ret.state
162175
.set_parameter(BrotliEncoderParameter::BROTLI_PARAM_QUALITY, q);
@@ -189,9 +202,17 @@ impl<ErrType, W: CustomWrite<ErrType>, BufferType: SliceWrapperMut<u8>, Alloc: B
189202
&mut nop_callback,
190203
);
191204
if output_offset > 0 {
205+
let zero_err = &mut self.error_if_zero_bytes_written;
206+
let fallback = &mut self.error_if_invalid_data;
192207
match write_all(
193208
self.output.as_mut().unwrap(),
194209
&self.output_buffer.slice_mut()[..output_offset],
210+
|| {
211+
if let Some(err) = zero_err.take() {
212+
return Some(err);
213+
}
214+
fallback.take()
215+
},
195216
) {
196217
Ok(_) => {}
197218
Err(e) => return Err(e),
@@ -266,12 +287,23 @@ impl<ErrType, W: CustomWrite<ErrType>, BufferType: SliceWrapperMut<u8>, Alloc: B
266287
&mut nop_callback,
267288
);
268289
if output_offset > 0 {
290+
let zero_err = &mut self.error_if_zero_bytes_written;
291+
let fallback = &mut self.error_if_invalid_data;
269292
match write_all(
270293
self.output.as_mut().unwrap(),
271294
&self.output_buffer.slice_mut()[..output_offset],
295+
|| {
296+
if let Some(err) = zero_err.take() {
297+
return Some(err);
298+
}
299+
fallback.take()
300+
},
301+
272302
) {
273303
Ok(_) => {}
274-
Err(e) => return Err(e),
304+
Err(e) => {
305+
return Err(e)
306+
},
275307
}
276308
}
277309
if !ret {

0 commit comments

Comments
 (0)