use std::{
mem::{self, MaybeUninit},
slice,
};
use crate::{bits_to_bytes, packet_priority::OrderingChannel, PacketReliability, Result, ID};
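/// A size or offset measured in bits.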
pub type BitSize = usize;
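/// Marker for types that [`BitStreamRead::read`] may materialize from raw
/// bytes.
///
/// # Safety
/// Implementors must be plain-old-data: every possible bit pattern of
/// `size_of::<Self>()` bytes must be a valid value, because `read` fills a
/// `MaybeUninit<Self>` byte by byte and then calls `assume_init`.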
pub unsafe trait ReadSafe {}
unsafe impl ReadSafe for u64 {}
unsafe impl ReadSafe for u32 {}
unsafe impl ReadSafe for u16 {}
unsafe impl ReadSafe for u8 {}
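/// Marker for types that [`BitStreamWrite::write`] may serialize by viewing
/// them as raw bytes.
///
/// # Safety
/// Implementors must contain no uninitialized (padding) bytes, because
/// `write` reads `size_of::<Self>()` bytes starting at `&self`. Note that
/// `ID` is write-safe but not read-safe, presumably because not every byte
/// value is a valid `ID`.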
pub unsafe trait WriteSafe {}
unsafe impl WriteSafe for u64 {}
unsafe impl WriteSafe for u32 {}
unsafe impl WriteSafe for u16 {}
unsafe impl WriteSafe for u8 {}
unsafe impl WriteSafe for ID {}
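/// Marker for field types that occupy a fixed number of bits on the wire.
///
/// # Safety
/// `NUM` must not exceed the bit width of `Self`, and every `NUM`-bit
/// pattern zero-extended to the full width must be a valid value, because
/// `read_bits` calls `assume_init` on the result.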
pub unsafe trait Bits {
const NUM: usize;
}
unsafe impl Bits for OrderingChannel {
const NUM: usize = 5;
}
unsafe impl Bits for PacketReliability {
const NUM: usize = 3;
}
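/// Views the storage of a `MaybeUninit<T>` as a slice of uninitialized
/// bytes, letting the bit readers fill it incrementally before
/// `assume_init`.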
fn bytes_from_maybe_uninit<T>(u: &mut MaybeUninit<T>) -> &mut [MaybeUninit<u8>] {
unsafe {
slice::from_raw_parts_mut(u.as_mut_ptr() as *mut MaybeUninit<u8>, mem::size_of::<T>())
}
}
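/// A zero-copy, bit-addressed reader over a borrowed byte buffer. Bits are
/// consumed most significant first within each byte, mirroring
/// [`BitStreamWrite`] below.
///
/// A minimal usage sketch (marked `ignore` since doctests run outside this
/// module):
///
/// ```ignore
/// let mut bs = BitStreamRead::new(&[0b1100_0001]);
/// assert!(bs.read_bool().unwrap()); // most significant bit first
/// assert!(bs.read_bool().unwrap());
/// assert!(!bs.read_bool().unwrap());
/// assert_eq!(bs.get_number_of_unread_bits(), 5);
/// ```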
pub struct BitStreamRead<'a> {
number_of_bits_used: BitSize,
read_offset: BitSize,
data: &'a [u8],
}
impl<'a> BitStreamRead<'a> {
pub fn with_size(data: &'a [u8], number_of_bits_used: BitSize) -> Self {
assert!(data.len() << 3 >= number_of_bits_used);
Self {
number_of_bits_used,
read_offset: 0,
data,
}
}
pub fn new(data: &'a [u8]) -> Self {
Self {
data,
number_of_bits_used: data.len() << 3,
read_offset: 0,
}
}
pub fn ignore_bits(&mut self, number_of_bits: BitSize) -> Result<()> {
if self.read_offset + number_of_bits > self.number_of_bits_used {
Err(crate::Error::BitStreamEos)
} else {
self.read_offset += number_of_bits;
Ok(())
}
}
pub const fn get_number_of_unread_bits(&self) -> BitSize {
self.number_of_bits_used - self.read_offset
}
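/// Reads a single bit without an end-of-stream check.
///
/// Unlike [`Self::read_bool`], this never reports `BitStreamEos`: it panics
/// once the offset runs past the underlying buffer, and with
/// [`Self::with_size`] it may return padding bits beyond the logical end.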
pub fn read_bit(&mut self) -> bool {
let result = (self.data[self.read_offset >> 3] & (0x80 >> (self.read_offset & 7))) != 0;
self.read_offset += 1;
result
}
pub fn read<T: ReadSafe>(&mut self) -> Result<T> {
let mut u = MaybeUninit::<T>::uninit();
let bytes = bytes_from_maybe_uninit(&mut u);
self._read_bits(bytes, mem::size_of::<T>() * 8, true)?;
Ok(unsafe { u.assume_init() })
}
pub fn read_bool(&mut self) -> Result<bool> {
if self.read_offset + 1 > self.number_of_bits_used {
return Err(crate::Error::BitStreamEos);
}
let var = (self.data[self.read_offset >> 3] & (0x80 >> (self.read_offset & 7))) != 0;
self.read_offset += 1;
Ok(var)
}
pub fn read_bits<T: Bits>(&mut self) -> Result<T> {
assert!(T::NUM <= mem::size_of::<T>() << 3);
let mut buffer = MaybeUninit::uninit();
self._read_bits(bytes_from_maybe_uninit(&mut buffer), T::NUM, true)?;
Ok(unsafe { buffer.assume_init() })
}
fn _read_bits(
&mut self,
output: &mut [MaybeUninit<u8>],
mut number_of_bits_to_read: usize,
align_bits_to_right: bool,
) -> Result<()> {
if self.read_offset + number_of_bits_to_read > self.number_of_bits_used {
return Err(crate::Error::BitStreamEos);
}
let read_offset_mod_8: BitSize = self.read_offset & 7;
let mut offset: BitSize = 0;
// Zero the whole output up front; `read_bits` relies on this to
// zero-extend values narrower than the output type.
output.fill(MaybeUninit::new(0));
while number_of_bits_to_read > 0 {
// OR the upper bits of the current source byte into the output byte.
unsafe {
*output[offset].as_mut_ptr() |=
self.data[self.read_offset >> 3] << read_offset_mod_8
};
// An unaligned read that spans two source bytes also needs the upper
// bits of the next byte, shifted down into the low end.
if read_offset_mod_8 > 0 && number_of_bits_to_read > 8 - read_offset_mod_8 {
unsafe {
*output[offset].as_mut_ptr() |=
self.data[(self.read_offset >> 3) + 1] >> (8 - read_offset_mod_8)
};
}
if number_of_bits_to_read >= 8 {
number_of_bits_to_read -= 8;
self.read_offset += 8;
offset += 1;
} else {
// Fewer than 8 bits remain; optionally right-align them, then
// consume exactly that many bits.
if align_bits_to_right {
unsafe { *output[offset].as_mut_ptr() >>= 8 - number_of_bits_to_read };
}
self.read_offset += number_of_bits_to_read;
number_of_bits_to_read = 0;
}
}
Ok(())
}
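/// Reads a value written with [`BitStreamWrite::write_compressed`].
///
/// The encoding walks the value's bytes from the most significant end
/// (assuming a little-endian target, as the tests do): a `1` flag means the
/// byte equalled the match byte (`0x00` for unsigned data) and was omitted;
/// a `0` flag means that byte and everything below it follows raw. If all
/// upper bytes matched, one final flag selects between a 4 bit nibble and a
/// full low byte.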
pub fn read_compressed<T: ReadSafe>(&mut self) -> Result<T> {
let mut buf = MaybeUninit::uninit();
self._read_compressed(
bytes_from_maybe_uninit(&mut buf),
mem::size_of::<T>() << 3,
true,
)?;
Ok(unsafe { buf.assume_init() })
}
fn _align_read_to_byte_boundary(&mut self) {
// Mirror of `align_write_to_byte_boundary`: pad the cursor up to the
// next multiple of 8.
let offset = self.read_offset & 7;
if offset > 0 {
self.read_offset += 8 - offset;
}
}
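/// Aligns the read cursor to the next byte boundary, then borrows
/// `number_of_bytes_to_read` bytes directly out of the underlying buffer.
/// Requesting zero bytes is reported as `BitStreamEos`.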
pub fn read_aligned_bytes(&mut self, number_of_bytes_to_read: usize) -> Result<&'a [u8]> {
if number_of_bytes_to_read == 0 {
return Err(crate::Error::BitStreamEos);
}
self._align_read_to_byte_boundary();
if self.read_offset + (number_of_bytes_to_read << 3) > self.number_of_bits_used {
return Err(crate::Error::BitStreamEos);
}
let slice = &self.data[self.read_offset >> 3..][..number_of_bytes_to_read];
self.read_offset += number_of_bytes_to_read << 3;
Ok(slice)
}
fn _read_compressed(
&mut self,
output: &mut [MaybeUninit<u8>],
size: usize,
unsigned_data: bool,
) -> Result<()> {
let mut current_byte = (size >> 3) - 1;
let (byte_match, half_byte_match): (u8, u8) = match unsigned_data {
true => (0, 0),
false => (0xFF, 0xF0),
};
while current_byte > 0 {
// A 1 flag means this upper byte equalled the match byte and was
// omitted; a 0 flag means this byte and everything below follows raw.
if self.read_bool()? {
unsafe { *output[current_byte].as_mut_ptr() = byte_match };
current_byte -= 1;
} else {
self._read_bits(output, (current_byte + 1) << 3, true)?;
return Ok(());
}
}
// All upper bytes matched: one last flag selects between a 4 bit nibble
// (upper half implied by the match) and a full low byte.
if self.read_bool()? {
self._read_bits(&mut output[current_byte..], 4, true)?;
unsafe { *output[current_byte].as_mut_ptr() |= half_byte_match };
} else {
self._read_bits(&mut output[current_byte..], 8, true)?;
}
Ok(())
}
}
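/// A growable, bit-addressed writer. The invariant checked by
/// `assert_bs_invariant!` is that `data.len()` always equals
/// `bits_to_bytes!(number_of_bits_used)`; trailing bits of a partial byte
/// are zero.
///
/// A minimal usage sketch (marked `ignore` since doctests run outside this
/// module):
///
/// ```ignore
/// let mut bs = BitStreamWrite::new();
/// bs.write_1();
/// bs.write(0xABu8); // 8 more bits, now spanning two bytes
/// assert_eq!(bs.num_bits(), 9);
/// ```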
#[derive(Default, Debug, Clone)]
pub struct BitStreamWrite {
number_of_bits_used: usize,
data: Vec<u8>,
}
macro_rules! assert_bs_invariant {
($self:ident) => {
debug_assert_eq!(bits_to_bytes!($self.number_of_bits_used), $self.data.len());
};
}
impl BitStreamWrite {
pub fn new() -> Self {
Self {
number_of_bits_used: 0,
data: Vec::new(),
}
}
pub fn with_capacity(bits: usize) -> Self {
Self {
data: Vec::with_capacity(bits_to_bytes!(bits)),
number_of_bits_used: 0,
}
}
pub fn num_bits(&self) -> usize {
self.number_of_bits_used
}
pub fn data(&self) -> &[u8] {
&self.data[..bits_to_bytes!(self.number_of_bits_used)]
}
fn add_bits_and_reallocate(&mut self, number_of_bits_to_write: usize) {
if number_of_bits_to_write == 0 {
return;
}
let new_number_of_bits_allocated: BitSize =
number_of_bits_to_write + self.number_of_bits_used;
if self.data.len() < bits_to_bytes!(new_number_of_bits_allocated) {
// Reserve with headroom: up to double the requested size, capping the
// extra at 0xFFFFF bits so large streams do not over-reserve.
let amount_to_reserve = bits_to_bytes!(
new_number_of_bits_allocated + new_number_of_bits_allocated.min(0xFFFFF)
);
self.data.reserve(amount_to_reserve - self.data.len());
// Grow the visible buffer to exactly the bytes the new bit count covers.
let amount_to_allocate: BitSize = bits_to_bytes!(new_number_of_bits_allocated);
self.data.resize(amount_to_allocate, 0);
}
}
pub fn write_0(&mut self) {
assert_bs_invariant!(self);
self.add_bits_and_reallocate(1);
if (self.number_of_bits_used & 7) == 0 {
self.data[self.number_of_bits_used >> 3] = 0;
}
self.number_of_bits_used += 1;
}
pub fn write_1(&mut self) {
assert_bs_invariant!(self);
self.add_bits_and_reallocate(1);
let number_of_bits_mod_8: BitSize = self.number_of_bits_used & 7;
if number_of_bits_mod_8 == 0 {
self.data[self.number_of_bits_used >> 3] = 0x80;
} else {
self.data[self.number_of_bits_used >> 3] |= 0x80 >> (number_of_bits_mod_8);
}
self.number_of_bits_used += 1;
}
pub fn write_bool(&mut self, value: bool) {
match value {
true => self.write_1(),
false => self.write_0(),
}
}
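/// Writes all `size_of::<T>()` bytes of `data` in native byte order;
/// [`WriteSafe`] guarantees the raw byte view contains no uninitialized
/// padding.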
pub fn write<T: WriteSafe>(&mut self, data: T) {
assert_bs_invariant!(self);
let input =
unsafe { slice::from_raw_parts((&data) as *const T as *const u8, mem::size_of::<T>()) };
let number_of_bits_to_write = mem::size_of::<T>() << 3;
self._write_bits(input, number_of_bits_to_write, true);
}
pub fn write_bits<T: Bits>(&mut self, value: T) {
assert_bs_invariant!(self);
assert!(T::NUM <= mem::size_of::<T>() << 3);
let input = unsafe {
slice::from_raw_parts((&value) as *const T as *const u8, mem::size_of::<T>())
};
self._write_bits(input, T::NUM, true);
}
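/// Writes `data` in the compressed form described on
/// [`BitStreamRead::read_compressed`]. For example, `write_compressed(1u32)`
/// emits the single byte `0b1111_0001`: three flags for the matching upper
/// bytes, a nibble flag, and the nibble `0001` (see `test_compressed`
/// below).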
pub fn write_compressed<T: WriteSafe>(&mut self, data: T) {
assert_bs_invariant!(self);
let input =
unsafe { slice::from_raw_parts((&data) as *const T as *const u8, mem::size_of::<T>()) };
self._write_compressed(input, true)
}
pub fn write_bytes(&mut self, input: &[u8], number_of_bytes: usize) {
assert_bs_invariant!(self);
if number_of_bytes == 0 {
return;
}
if (self.number_of_bits_used & 7) == 0 {
// Byte aligned: append directly. Slice to `number_of_bytes` so a longer
// `input` cannot desynchronize the bit count from `data.len()`.
self.data.extend_from_slice(&input[..number_of_bytes]);
self.number_of_bits_used += number_of_bytes << 3;
} else {
self._write_bits(input, number_of_bytes << 3, true);
}
}
fn align_write_to_byte_boundary(&mut self) {
assert_bs_invariant!(self);
let offset = self.number_of_bits_used % 8;
if offset > 0 {
self.number_of_bits_used += 8 - offset;
}
}
pub fn write_aligned_bytes(&mut self, bytes: &[u8]) {
assert_bs_invariant!(self);
self.align_write_to_byte_boundary();
self.data.extend_from_slice(bytes);
self.number_of_bits_used += bytes.len() << 3;
}
fn _write_compressed(&mut self, input: &[u8], unsigned_data: bool) {
assert_bs_invariant!(self);
let mut current_byte: BitSize = input.len() - 1;
let byte_match: u8 = match unsigned_data {
true => 0,
false => 0xFF,
};
while current_byte > 0 {
if input[current_byte] == byte_match {
self.write_1();
} else {
self.write_0();
self._write_bits(input, (current_byte + 1) << 3, true);
return;
}
current_byte -= 1;
}
let half_match = match unsigned_data {
true => 0x00,
false => 0xF0,
};
if (input[current_byte] & 0xF0) == half_match {
self.write_1();
self._write_bits(&input[current_byte..], 4, true);
} else {
self.write_0();
self._write_bits(&input[current_byte..], 8, true);
}
}
fn _write_bits(
&mut self,
input: &[u8],
mut number_of_bits_to_write: usize,
right_aligned_bits: bool,
) {
if number_of_bits_to_write == 0 {
return;
}
self.add_bits_and_reallocate(number_of_bits_to_write);
let mut offset: usize = 0;
let mut data_byte: u8;
// The loop below always advances in whole destination bytes, so this
// remains constant for the entire write.
let number_of_bits_used_mod_8: BitSize = self.number_of_bits_used & 7;
while number_of_bits_to_write > 0 {
data_byte = input[offset];
// A trailing partial byte arrives right aligned; shift its significant
// bits up to the top of the byte first.
if number_of_bits_to_write < 8 && right_aligned_bits {
data_byte <<= 8 - number_of_bits_to_write;
}
if number_of_bits_used_mod_8 == 0 {
self.data[self.number_of_bits_used >> 3] = data_byte;
} else {
// Unaligned: split the byte across the current destination byte
// and, if bits spill over, the next one.
self.data[self.number_of_bits_used >> 3] |=
data_byte >> number_of_bits_used_mod_8;
if 8 - number_of_bits_used_mod_8 < number_of_bits_to_write {
self.data[(self.number_of_bits_used >> 3) + 1] =
data_byte << (8 - number_of_bits_used_mod_8);
}
}
if number_of_bits_to_write >= 8 {
self.number_of_bits_used += 8;
number_of_bits_to_write -= 8;
} else {
self.number_of_bits_used += number_of_bits_to_write;
number_of_bits_to_write = 0;
}
offset += 1;
}
}
}
impl std::io::Write for BitStreamWrite {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let number_of_bytes = buf.len();
self.write_bytes(buf, number_of_bytes);
Ok(number_of_bytes)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::BitStreamWrite;
#[test]
fn test_read() {
let mut bit_stream = super::BitStreamRead::new(&[0b11100011, 0b00000100]);
assert!(bit_stream.read_bit());
assert!(bit_stream.read_bit());
assert!(bit_stream.read_bit());
assert!(!bit_stream.read_bit());
assert!(!bit_stream.read_bit());
assert!(!bit_stream.read_bit());
assert_eq!(bit_stream.read::<u8>(), Ok(0b11000001));
}
#[test]
fn test_write() {
let mut bs = BitStreamWrite::new();
bs.write(10u8);
assert_eq!(bs.data, &[10]);
}
#[test]
fn test_aligned() {
let mut bs = BitStreamWrite::new();
bs.write_1();
bs.write_1();
bs.write_aligned_bytes(&[20, 19, 18]);
assert_eq!(bs.data(), &[0b11000000, 20, 19, 18]);
bs = BitStreamWrite::new();
bs.write_aligned_bytes(&[20, 19, 18]);
assert_eq!(bs.data(), &[20, 19, 18]);
}
#[test]
#[allow(clippy::unusual_byte_groupings)]
fn test_compressed() {
let mut bs = BitStreamWrite::new();
bs.write_compressed(0u32);
assert_eq!(bs.data(), &[0b1111_0000]);
assert_eq!(bs.num_bits(), 8);
bs = BitStreamWrite::new();
bs.write_compressed(1u32);
assert_eq!(bs.data(), &[0b1111_0001]);
assert_eq!(bs.num_bits(), 8);
bs = BitStreamWrite::new();
bs.write_compressed(15u32);
assert_eq!(bs.data(), &[0b1111_1111]);
assert_eq!(bs.num_bits(), 8);
bs = BitStreamWrite::new();
bs.write_compressed(1u16);
assert_eq!(bs.data(), &[0b11_0001_00]);
assert_eq!(bs.num_bits(), 6);
bs = BitStreamWrite::new();
bs.write_compressed(4u16);
assert_eq!(bs.data(), &[0b11_0100_00]);
assert_eq!(bs.num_bits(), 6);
}
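// A round-trip check: feed the writer's output back through
// BitStreamRead::with_size and confirm write_compressed/read_compressed
// agree (the encoding itself is verified against the cases above).
#[test]
fn test_compressed_round_trip() {
let mut bs = BitStreamWrite::new();
bs.write_compressed(4u16);
let mut rd = super::BitStreamRead::with_size(bs.data(), bs.num_bits());
assert_eq!(rd.read_compressed::<u16>(), Ok(4));
}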
}