use crate::prelude_std::*;
use super::{ ChunkedSlice, UnsafeBufWriteGuard };
/// Number of raw input bytes consumed per encoding frame (5 bytes = 40 bits).
pub const BINARY_FRAME_LEN: usize = 5;
/// Number of output characters produced per frame: 40 bits at 5 bits per character.
pub const STRING_FRAME_LEN: usize = BINARY_FRAME_LEN * 8 / 5;
/// Encodes `bytes` with the RFC 4648 standard base32 alphabet (`A`–`Z`, then `2`–`7`).
///
/// The output is padded with `=` up to a whole number of 8-character frames;
/// an empty input yields an empty string.
#[inline]
pub fn encode_base32(bytes: &[u8]) -> String {
	// 5-bit values 0..=25 map onto 'A'..='Z'; values 26..=31 are shifted onto '2'..='7'.
	const LAST_LETTER: u8 = 25;
	const DIGIT_SHIFT: u8 = b'2' - (LAST_LETTER + 1);
	_encode::<LAST_LETTER, b'A', DIGIT_SHIFT>(bytes)
}
/// Encodes `bytes` with the RFC 4648 "extended hex" base32 alphabet (`0`–`9`, then `A`–`V`).
///
/// Framing and `=` padding are the same as for the standard alphabet; only the
/// character mapping differs.
#[inline]
pub fn encode_base32hex(bytes: &[u8]) -> String {
	// 5-bit values 0..=9 map onto '0'..='9'; values 10..=31 are shifted onto 'A'..='V'.
	const LAST_DIGIT: u8 = 9;
	const LETTER_SHIFT: u8 = b'A' - (LAST_DIGIT + 1);
	_encode::<LAST_DIGIT, b'0', LETTER_SHIFT>(bytes)
}
fn _encode<
const BREAKPOINT: u8,
const LOWER: u8,
const UPPER_ADJUSTED: u8
>(bytes: &[u8]) -> String {
let frames = bytes.len() / BINARY_FRAME_LEN;
let remainder = bytes.len() % BINARY_FRAME_LEN;
let capacity = if remainder == 0 {
frames * STRING_FRAME_LEN
} else {
(frames + 1) * STRING_FRAME_LEN
};
let mut frames_iter = ChunkedSlice::<BINARY_FRAME_LEN>::new(bytes);
let mut dest = UnsafeBufWriteGuard::with_capacity(capacity);
for _ in 0..frames {
let frame = unsafe { frames_iter.next_frame_unchecked() };
unsafe { encode_frame::<BREAKPOINT, LOWER, UPPER_ADJUSTED>(frame, &mut dest) }
}
if remainder > 0 {
let padding_amount = match remainder {
1 => { 6 }
2 => { 4 }
3 => { 3 }
4 => { 1 }
_ => {
unsafe { hint::unreachable_unchecked() }
}
};
let f = |frame: &[u8; 5]| {
static PADDING: &[u8; 6] = b"======";
unsafe { encode_frame::<BREAKPOINT, LOWER, UPPER_ADJUSTED>(frame, &mut dest) }
let ptr = unsafe { dest.as_mut_ptr().sub(padding_amount) };
unsafe { ptr::copy_nonoverlapping(PADDING.as_ptr(), ptr, padding_amount) }
};
unsafe { frames_iter.with_remainder_unchecked(f) }
}
let vec = unsafe { dest.into_full_vec() };
unsafe {
debug_assert!(str::from_utf8(&vec).is_ok(), "output bytes valid utf-8");
String::from_utf8_unchecked(vec)
}
}
unsafe fn encode_frame<
const BREAKPOINT: u8,
const LOWER: u8,
const UPPER_ADJUSTED: u8
>(frame: &[u8; BINARY_FRAME_LEN], dest: &mut UnsafeBufWriteGuard) {
let byte1 = frame[0] >> 3;
let byte2 = ((frame[0] << 2) & 0b11100) | (frame[1] >> 6);
let byte3 = (frame[1] >> 1) & 0b11111;
let byte4 = ((frame[1] << 4) & 0b10000) | (frame[2] >> 4);
let byte5 = ((frame[2] << 1) & 0b11110) | (frame[3] >> 7);
let byte6 = (frame[3] >> 2) & 0b11111;
let byte7 = ((frame[3] << 3) & 0b11000) | (frame[4] >> 5);
let byte8 = frame[4] & 0b11111;
let bytes = [
if byte1 > BREAKPOINT { byte1 + UPPER_ADJUSTED } else { byte1 + LOWER },
if byte2 > BREAKPOINT { byte2 + UPPER_ADJUSTED } else { byte2 + LOWER },
if byte3 > BREAKPOINT { byte3 + UPPER_ADJUSTED } else { byte3 + LOWER },
if byte4 > BREAKPOINT { byte4 + UPPER_ADJUSTED } else { byte4 + LOWER },
if byte5 > BREAKPOINT { byte5 + UPPER_ADJUSTED } else { byte5 + LOWER },
if byte6 > BREAKPOINT { byte6 + UPPER_ADJUSTED } else { byte6 + LOWER },
if byte7 > BREAKPOINT { byte7 + UPPER_ADJUSTED } else { byte7 + LOWER },
if byte8 > BREAKPOINT { byte8 + UPPER_ADJUSTED } else { byte8 + LOWER }
];
unsafe { dest.write_bytes_const::<8>(bytes.as_ptr()) }
}
#[cfg(test)]
mod tests {
	use super::*;

	/// Test vectors from RFC 4648 section 10 (standard alphabet).
	#[test]
	fn rfc_provided_examples() {
		let examples = [
			("", ""),
			("f", "MY======"),
			("fo", "MZXQ===="),
			("foo", "MZXW6==="),
			("foob", "MZXW6YQ="),
			("fooba", "MZXW6YTB"),
			("foobar", "MZXW6YTBOI======")
		];
		for (bytes, encoded) in examples {
			assert_eq!(encoded, encode_base32(bytes.as_bytes()));
		}
	}

	/// Test vectors from RFC 4648 section 10 (extended hex alphabet).
	#[test]
	fn rfc_provided_examples_base32hex() {
		let examples = [
			("", ""),
			("f", "CO======"),
			("fo", "CPNG===="),
			("foo", "CPNMU==="),
			("foob", "CPNMUOG="),
			("fooba", "CPNMUOJ1"),
			("foobar", "CPNMUOJ1E8======")
		];
		for (bytes, encoded) in examples {
			assert_eq!(encoded, encode_base32hex(bytes.as_bytes()));
		}
	}

	/// Every 5-bit value 0..=31 in order, so each alphabet is spelled out once.
	///
	/// The RFC vectors never produce some characters (e.g. '2'..='5' in the
	/// standard alphabet). They also never test both sides of the
	/// letter/digit breakpoint, which is where an off-by-one in the alphabet
	/// parameters would show up.
	#[test]
	fn full_alphabet_in_order() {
		let bytes: [u8; 20] = [
			0x00, 0x44, 0x32, 0x14, 0xc7,
			0x42, 0x54, 0xb6, 0x35, 0xcf,
			0x84, 0x65, 0x3a, 0x56, 0xd7,
			0xc6, 0x75, 0xbe, 0x77, 0xdf
		];
		assert_eq!("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", encode_base32(&bytes));
		assert_eq!("0123456789ABCDEFGHIJKLMNOPQRSTUV", encode_base32hex(&bytes));
	}

	/// All-ones input for every remainder length 1..=4, plus one full frame.
	///
	/// This checks that bits past the end of a partial frame are zero and that the
	/// padding length is right when the high bits are set.
	#[test]
	fn all_ones_every_remainder() {
		let ones = [0xff_u8; 5];
		let examples = [
			(1, "74======", "VS======"),
			(2, "777Q====", "VVVG===="),
			(3, "77776===", "VVVVU==="),
			(4, "777777Y=", "VVVVVVO="),
			(5, "77777777", "VVVVVVVV")
		];
		for (len, base32, base32hex) in examples {
			assert_eq!(base32, encode_base32(&ones[..len]));
			assert_eq!(base32hex, encode_base32hex(&ones[..len]));
		}
	}
}