wiwi/encoding/
base32.rs

1use crate::prelude::*;
2use super::{ ChunkedSlice, UnsafeBufWriteGuard };
3
// // tables unused, kept for reference only, because they can be calculated
// pub const TABLE_ENCODER_LEN: usize = 32;
// pub static TABLE_ENCODER: &[u8; TABLE_ENCODER_LEN] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
// pub static TABLE_ENCODER_BASE32HEX: &[u8; TABLE_ENCODER_LEN] = b"0123456789ABCDEFGHIJKLMNOPQRSTUV";
8
/// Number of input bytes consumed per encoding frame (one frame is 40 bits).
pub const BINARY_FRAME_LEN: usize = 40 / 8;
/// Number of base32 chars produced per encoding frame (40 bits at 5 bits per char).
pub const STRING_FRAME_LEN: usize = 40 / 5;
11
12/// Encodes the given bytes into a base32 [`String`], as specified in
13/// [RFC 4648].
14///
15/// [RFC 4648]: https://datatracker.ietf.org/doc/html/rfc4648#section-6
16#[inline]
17pub fn encode_base32(bytes: &[u8]) -> String {
18	_encode::<25, b'A', { b'2' - 26 }>(bytes)
19}
20
21/// Encodes the given bytes into a base32 [`String`], using
22/// the [hex encoding alphabet variant as defined in RFC 4648].
23///
24/// [hex encoding alphabet variant as defined in RFC 4648]: https://datatracker.ietf.org/doc/html/rfc4648#section-7
25#[inline]
26pub fn encode_base32hex(bytes: &[u8]) -> String {
27	_encode::<9, b'0', { b'A' - 10 }>(bytes)
28}
29
30/// - `BREAKPOINT`: the gt comparison against this number to determin when to use
31///   LOWER or UPPER_ADJUSTED
32/// - `LOWER`: the amount to add to a section when it is lt than `BREAKPOINT`.
33///   (ie. the lowest in the range)
34/// - `UPPER_ADJUSTED`: the amount to add to a section when it is gte `BREAKPOINT`.
35///   "Adjusted" means that `BREAKPOINT` should be subtracted from the UPPER
36///   char value, so no subtraction needs to be done in runtime.
37fn _encode<
38	const BREAKPOINT: u8,
39	const LOWER: u8,
40	const UPPER_ADJUSTED: u8
41>(bytes: &[u8]) -> String {
42	// 5 bytes per group of 8 output chars
43	let frames = bytes.len() / BINARY_FRAME_LEN;
44	let remainder = bytes.len() % BINARY_FRAME_LEN;
45
46	let capacity = if remainder == 0 {
47		frames * STRING_FRAME_LEN
48	} else {
49		(frames + 1) * STRING_FRAME_LEN
50	};
51
52	let mut frames_iter = ChunkedSlice::<BINARY_FRAME_LEN>::new(bytes);
53	let mut dest = UnsafeBufWriteGuard::with_capacity(capacity);
54
55	for _ in 0..frames {
56		// SAFETY: calculated
57		let frame = unsafe { frames_iter.next_frame_unchecked() };
58
59		// SAFETY: calculated
60		unsafe { encode_frame::<BREAKPOINT, LOWER, UPPER_ADJUSTED>(frame, &mut dest) }
61	}
62
63	if remainder > 0 {
64		// determine padding amount
65		let padding_amount = match remainder {
66			1 => { 6 }
67			2 => { 4 }
68			3 => { 3 }
69			4 => { 1 }
70			_ => {
71				// SAFETY: `remainder` is calculated by mod 5, so it cannot be 5 or
72				// more. and we just checked in an if statement that `remainder` is
73				// greater than 0. therefore, `remainder` can only be 1, 2, 3, or 4,
74				// all of which are covered by match branches.
75				unsafe { hint::unreachable_unchecked() }
76			}
77		};
78
79		let f = |frame: &[u8; 5]| {
80			static PADDING: &[u8; 6] = b"======";
81
82			// SAFETY: calculated
83			unsafe { encode_frame::<BREAKPOINT, LOWER, UPPER_ADJUSTED>(frame, &mut dest) }
84
85			// SAFETY: calculated
86			let ptr = unsafe { dest.as_mut_ptr().sub(padding_amount) };
87
88			// SAFETY: calculated
89			unsafe { ptr::copy_nonoverlapping(PADDING.as_ptr(), ptr, padding_amount) }
90		};
91
92		// SAFETY: calculated
93		unsafe { frames_iter.with_remainder_unchecked(f) }
94	}
95
96	// SAFETY: calculated
97	let vec = unsafe { dest.into_full_vec() };
98
99	// SAFETY: we only write valid ASCII as part of encoding process
100	unsafe {
101		debug_assert!(str::from_utf8(&vec).is_ok(), "output bytes valid utf-8");
102		String::from_utf8_unchecked(vec)
103	}
104}
105
106unsafe fn encode_frame<
107	const BREAKPOINT: u8,
108	const LOWER: u8,
109	const UPPER_ADJUSTED: u8
110>(frame: &[u8; BINARY_FRAME_LEN], dest: &mut UnsafeBufWriteGuard) {
111	// keep first 5 bits from byte 0, leaving 3 bits left
112	let byte1 = frame[0] >> 3;
113
114	// take remaining 3 from byte 0, then 2 from byte 1, leaving 6 bits left
115	let byte2 = ((frame[0] << 2) & 0b11100) | (frame[1] >> 6);
116
117	// take 5 in middle of byte 1, leaving 1 bit left
118	let byte3 = (frame[1] >> 1) & 0b11111;
119
120	// take last bit from byte 1, then 4 from byte 2, leaving 4 bits left
121	let byte4 = ((frame[1] << 4) & 0b10000) | (frame[2] >> 4);
122
123	// take last 4 bits from byte 2, then 1 from byte 3, leaving 7 bits left
124	let byte5 = ((frame[2] << 1) & 0b11110) | (frame[3] >> 7);
125
126	// take 5 from byte 3, leaving 2 bits left
127	let byte6 = (frame[3] >> 2) & 0b11111;
128
129	// take remaining 2 bits from byte 3, then 3 bits from byte 4, leaving 5 bits left
130	let byte7 = ((frame[3] << 3) & 0b11000) | (frame[4] >> 5);
131
132	// take remaining 5 bits
133	let byte8 = frame[4] & 0b11111;
134
135	let bytes = [
136		// multi cursor editing is great
137		if byte1 > BREAKPOINT { byte1 + UPPER_ADJUSTED } else { byte1 + LOWER },
138		if byte2 > BREAKPOINT { byte2 + UPPER_ADJUSTED } else { byte2 + LOWER },
139		if byte3 > BREAKPOINT { byte3 + UPPER_ADJUSTED } else { byte3 + LOWER },
140		if byte4 > BREAKPOINT { byte4 + UPPER_ADJUSTED } else { byte4 + LOWER },
141		if byte5 > BREAKPOINT { byte5 + UPPER_ADJUSTED } else { byte5 + LOWER },
142		if byte6 > BREAKPOINT { byte6 + UPPER_ADJUSTED } else { byte6 + LOWER },
143		if byte7 > BREAKPOINT { byte7 + UPPER_ADJUSTED } else { byte7 + LOWER },
144		if byte8 > BREAKPOINT { byte8 + UPPER_ADJUSTED } else { byte8 + LOWER }
145	];
146
147	// SAFETY: caller promises we can call this once
148	unsafe { dest.write_bytes_const::<8>(bytes.as_ptr()) }
149}
150
#[cfg(test)]
mod tests {
	use super::*;

	/// 20 bytes whose 5-bit groups are exactly 0, 1, 2, ..., 31 in order.
	/// Encoding them yields each alphabet in full, which exercises both sides
	/// of every `BREAKPOINT` (25/26 for base32, 9/10 for base32hex).
	const EVERY_SYMBOL: [u8; 20] = [
		0x00, 0x44, 0x32, 0x14, 0xc7,
		0x42, 0x54, 0xb6, 0x35, 0xcf,
		0x84, 0x65, 0x3a, 0x56, 0xd7,
		0xc6, 0x75, 0xbe, 0x77, 0xdf
	];

	#[test]
	fn rfc_provided_examples() {
		let examples = [
			("", ""),
			("f", "MY======"),
			("fo", "MZXQ===="),
			("foo", "MZXW6==="),
			("foob", "MZXW6YQ="),
			("fooba", "MZXW6YTB"),
			("foobar", "MZXW6YTBOI======")
		];

		for (bytes, encoded) in examples {
			assert_eq!(encoded, encode_base32(bytes.as_bytes()));
		}
	}

	#[test]
	fn rfc_provided_examples_base32hex() {
		let examples = [
			("", ""),
			("f", "CO======"),
			("fo", "CPNG===="),
			("foo", "CPNMU==="),
			("foob", "CPNMUOG="),
			("fooba", "CPNMUOJ1"),
			("foobar", "CPNMUOJ1E8======")
		];

		for (bytes, encoded) in examples {
			assert_eq!(encoded, encode_base32hex(bytes.as_bytes()));
		}
	}

	#[test]
	fn full_alphabet() {
		assert_eq!("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", encode_base32(&EVERY_SYMBOL));
	}

	#[test]
	fn full_alphabet_base32hex() {
		assert_eq!("0123456789ABCDEFGHIJKLMNOPQRSTUV", encode_base32hex(&EVERY_SYMBOL));
	}

	#[test]
	fn high_bits_in_partial_frame() {
		assert_eq!("74======", encode_base32(&[0xff]));
		assert_eq!("777777Y=", encode_base32(&[0xff; 4]));
		assert_eq!("VS======", encode_base32hex(&[0xff]));
		assert_eq!("VVVVVVO=", encode_base32hex(&[0xff; 4]));
	}
}