follow up on the fix of multiply with overflow

This commit is contained in:
Paul Masurel
2025-02-24 11:49:28 +09:00
parent 371dba9414
commit df2d52a84e

View File

@@ -65,7 +65,7 @@ impl BitPacker {
#[derive(Clone, Debug, Default, Copy)]
pub struct BitUnpacker {
num_bits: u32,
num_bits: usize,
mask: u64,
}
@@ -83,7 +83,7 @@ impl BitUnpacker {
(1u64 << num_bits) - 1u64
};
BitUnpacker {
num_bits: u32::from(num_bits),
num_bits: usize::from(num_bits),
mask,
}
}
@@ -94,7 +94,7 @@ impl BitUnpacker {
#[inline]
pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
let addr_in_bits = idx as usize * self.num_bits as usize;
let addr_in_bits = idx as usize * self.num_bits;
let addr = addr_in_bits >> 3;
if addr + 8 > data.len() {
if self.num_bits == 0 {
@@ -129,24 +129,25 @@ impl BitUnpacker {
//
// This methods panics if `num_bits` is > 32.
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
let start_idx = start_idx as usize;
assert!(
self.bit_width() <= 32,
"Bitwidth must be <= 32 to use this method."
);
let end_idx = start_idx + output.len() as u32;
let end_idx = start_idx + output.len();
let end_bit_read = end_idx * self.num_bits;
let end_byte_read = (end_bit_read + 7) / 8;
assert!(
end_byte_read as usize <= data.len(),
end_byte_read <= data.len(),
"Requested index is out of bounds."
);
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
let get_batch_ramp = |start_idx: usize, output: &mut [u32]| {
for (out, idx) in output.iter_mut().zip(start_idx..) {
*out = self.get(idx, data) as u32;
*out = self.get(idx as u32, data) as u32;
}
};
@@ -161,23 +162,23 @@ impl BitUnpacker {
// so highway start is the closest multiple of 8 that is >= start_idx.
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
let highway_start: u32 = start_idx + entrance_ramp_len;
let highway_start: usize = start_idx + entrance_ramp_len;
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
if highway_start + BitPacker1x::BLOCK_LEN > end_idx {
// We don't have enough values to have even a single block of highway.
// Let's just supply the values the simple way.
get_batch_ramp(start_idx, output);
return;
}
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
let num_blocks: usize = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN;
// Entrance ramp
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len]);
// Highway
let mut offset = (highway_start * self.num_bits) as usize / 8;
let mut output_cursor = (highway_start - start_idx) as usize;
let mut offset = (highway_start * self.num_bits) / 8;
let mut output_cursor = highway_start - start_idx;
for _ in 0..num_blocks {
offset += BitPacker1x.decompress(
&data[offset..],
@@ -188,7 +189,7 @@ impl BitUnpacker {
}
// Exit ramp
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN;
get_batch_ramp(highway_end, &mut output[output_cursor..]);
}