1pub(crate) fn decode_secret<'a>(input: &[u8], output: &'a mut [u8]) -> Result<&'a [u8], Error> {
17 decode(input, output, CodePoint::decode_secret)
18}
19
20pub(crate) fn decode_public<'a>(input: &[u8], output: &'a mut [u8]) -> Result<&'a [u8], Error> {
27 decode(input, output, CodePoint::decode_public)
28}
29
30pub(crate) const fn decoded_length(base64_len: usize) -> usize {
33 ((base64_len + 3) / 4) * 3
34}
35
/// Shared base64 decoding loop behind `decode_secret` and `decode_public`.
///
/// `decode_byte` classifies each input byte either as a sextet value
/// (`0..=63`) or as one of the `CodePoint::WHITESPACE`, `CodePoint::PAD` or
/// `CodePoint::INVALID` markers.
///
/// - Whitespace is skipped wherever it appears.
/// - Other symbols are packed into a 48-bit buffer, eight symbols (six output
///   bytes) at a time.
/// - The last group of up to eight symbols is written after the loop.
/// - Both padded (`AAA=`, `AA==`) and unpadded (`AAA`, `AA`) final quantums are
///   accepted.
///
/// Returns the prefix of `output` that was written.
///
/// # Errors
///
/// - `Error::InvalidCharacter(byte)` when a byte is classified as invalid.
/// - `Error::PrematurePadding` when `=` appears before the final four-symbol
///   quantum.
/// - `Error::InvalidTrailingPadding` when the final quantum's length/padding
///   combination is not one of the accepted forms.
/// - `Error::InsufficientOutputSpace` when `output` is too short. An upper
///   bound on the space needed is `decoded_length(input.len())`.
fn decode<'a>(
    input: &[u8],
    output: &'a mut [u8],
    decode_byte: impl Fn(u8) -> CodePoint,
) -> Result<&'a [u8], Error> {
    // Holds up to eight sextets, the first at bits 47..=42, the next at 41..=36, etc.
    let mut buffer = 0u64;
    // Number of symbols (sextets or `=`) currently held in `buffer`.
    let mut used = 0;
    // Bit position at which the next sextet is placed.
    let mut shift = SHIFT_INITIAL;
    // Bit `i` is set when the `i`-th symbol in `buffer` was `=`.
    let mut pad_mask = 0;

    let mut output_offset = 0;

    // Shift for the first of eight sextets, so that eight of them fill bits 47..=0.
    const SHIFT_INITIAL: i32 = (8 - 1) * 6;

    for byte in input.iter().copied() {
        let (item, pad) = match decode_byte(byte) {
            CodePoint::WHITESPACE => continue,
            CodePoint::INVALID => return Err(Error::InvalidCharacter(byte)),
            CodePoint::PAD => (0, 1),
            CodePoint(n) => (n, 0),
        };

        // A full group is only flushed once another symbol arrives. This
        // guarantees the final group, which may be padded, is always left for
        // the tail handling after the loop.
        if used == 8 {
            // More data follows this group, so any `=` inside it is misplaced.
            if pad_mask != 0b0000_0000 {
                return Err(Error::PrematurePadding);
            }

            let chunk = output
                .get_mut(output_offset..output_offset + 6)
                .ok_or(Error::InsufficientOutputSpace)?;

            chunk[0] = (buffer >> 40) as u8;
            chunk[1] = (buffer >> 32) as u8;
            chunk[2] = (buffer >> 24) as u8;
            chunk[3] = (buffer >> 16) as u8;
            chunk[4] = (buffer >> 8) as u8;
            chunk[5] = buffer as u8;

            output_offset += 6;
            buffer = 0;
            used = 0;
            pad_mask = 0;
            shift = SHIFT_INITIAL;
        }

        buffer |= (item as u64) << shift;
        shift -= 6;
        pad_mask |= pad << used;
        used += 1;
    }

    // More than one quantum is left. The first four symbols must be unpadded:
    // emit their three bytes, then move the remaining symbols (and their
    // padding bits) into the first-quantum position so the match below only
    // has to consider a single quantum of at most four symbols.
    if used > 4 {
        if pad_mask & 0b0000_1111 != 0 {
            return Err(Error::PrematurePadding);
        }
        let chunk = output
            .get_mut(output_offset..output_offset + 3)
            .ok_or(Error::InsufficientOutputSpace)?;
        chunk[0] = (buffer >> 40) as u8;
        chunk[1] = (buffer >> 32) as u8;
        chunk[2] = (buffer >> 24) as u8;

        buffer <<= 24;
        pad_mask >>= 4;
        used -= 4;
        output_offset += 3;
    }

    // Final quantum: (symbol count, padding bits). Bit 3 set means the 4th
    // symbol was `=`; bits 2 and 3 set mean the last two symbols were `=`.
    match (used, pad_mask) {
        // No trailing symbols at all.
        (0, 0b0000) => {}

        // `xxxx`: three bytes.
        (4, 0b0000) => {
            let chunk = output
                .get_mut(output_offset..output_offset + 3)
                .ok_or(Error::InsufficientOutputSpace)?;
            chunk[0] = (buffer >> 40) as u8;
            chunk[1] = (buffer >> 32) as u8;
            chunk[2] = (buffer >> 24) as u8;
            output_offset += 3;
        }

        // `xxx=` or unpadded `xxx`: two bytes.
        (4, 0b1000) | (3, 0b0000) => {
            let chunk = output
                .get_mut(output_offset..output_offset + 2)
                .ok_or(Error::InsufficientOutputSpace)?;

            chunk[0] = (buffer >> 40) as u8;
            chunk[1] = (buffer >> 32) as u8;
            output_offset += 2;
        }

        // `xx==` or unpadded `xx`: one byte.
        (4, 0b1100) | (2, 0b0000) => {
            let chunk = output
                .get_mut(output_offset..output_offset + 1)
                .ok_or(Error::InsufficientOutputSpace)?;
            chunk[0] = (buffer >> 40) as u8;
            output_offset += 1;
        }

        // Any other combination, e.g. a single symbol, `x===`, or `xx=`.
        _ => return Err(Error::InvalidTrailingPadding),
    }

    Ok(&output[..output_offset])
}
148
/// Reasons base64 decoding can fail.
///
/// The enum is plain data (it holds at most a `u8`), so it is cheaply
/// `Copy` and has full `Eq` equality.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Error {
    /// The given byte is not in the base64 alphabet, is not `=`, and is not
    /// skippable whitespace.
    InvalidCharacter(u8),

    /// A padding character (`=`) appeared before the final four-character
    /// quantum of the input.
    PrematurePadding,

    /// The final group of characters has a length/padding combination the
    /// decoder does not accept (for example `A===`, `AA=`, or a single
    /// trailing character).
    InvalidTrailingPadding,

    /// The output buffer is too small for the decoded data.
    ///
    /// `decoded_length` gives an upper bound on the space required.
    InsufficientOutputSpace,
}
166
/// Classification of a single input byte: either a 6-bit base64 value
/// (`0..=63`) or one of the sentinel markers defined as associated constants.
///
/// `PartialEq` and `Eq` are derived (rather than hand-written) so that the
/// associated constants can be used as patterns in `match`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct CodePoint(u8);
169
170impl CodePoint {
171 const WHITESPACE: Self = Self(0xf0);
172 const PAD: Self = Self(0xf1);
173 const INVALID: Self = Self(0xf2);
174}
175
176impl CodePoint {
177 fn decode_secret(b: u8) -> Self {
184 let is_upper = u8_in_range(b, b'A', b'Z');
185 let is_lower = u8_in_range(b, b'a', b'z');
186 let is_digit = u8_in_range(b, b'0', b'9');
187 let is_plus = u8_equals(b, b'+');
188 let is_slash = u8_equals(b, b'/');
189 let is_pad = u8_equals(b, b'=');
190 let is_space = u8_in_range(b, b'\t', b'\r') | u8_equals(b, b' ');
191
192 let is_invalid = !(is_lower | is_upper | is_digit | is_plus | is_slash | is_pad | is_space);
193
194 Self(
195 (is_upper & b.wrapping_sub(b'A'))
196 | (is_lower & (b.wrapping_sub(b'a').wrapping_add(26)))
197 | (is_digit & (b.wrapping_sub(b'0').wrapping_add(52)))
198 | (is_plus & 62)
199 | (is_slash & 63)
200 | (is_space & Self::WHITESPACE.0)
201 | (is_pad & Self::PAD.0)
202 | (is_invalid & Self::INVALID.0),
203 )
204 }
205
206 const fn decode_public(a: u8) -> Self {
207 const TABLE: [CodePoint; 256] = [
208 CodePoint::INVALID,
210 CodePoint::INVALID,
211 CodePoint::INVALID,
212 CodePoint::INVALID,
213 CodePoint::INVALID,
214 CodePoint::INVALID,
215 CodePoint::INVALID,
216 CodePoint::INVALID,
217 CodePoint::INVALID,
218 CodePoint::WHITESPACE,
219 CodePoint::WHITESPACE,
220 CodePoint::WHITESPACE,
221 CodePoint::WHITESPACE,
222 CodePoint::WHITESPACE,
223 CodePoint::INVALID,
224 CodePoint::INVALID,
225 CodePoint::INVALID,
227 CodePoint::INVALID,
228 CodePoint::INVALID,
229 CodePoint::INVALID,
230 CodePoint::INVALID,
231 CodePoint::INVALID,
232 CodePoint::INVALID,
233 CodePoint::INVALID,
234 CodePoint::INVALID,
235 CodePoint::INVALID,
236 CodePoint::INVALID,
237 CodePoint::INVALID,
238 CodePoint::INVALID,
239 CodePoint::INVALID,
240 CodePoint::INVALID,
241 CodePoint::INVALID,
242 CodePoint::WHITESPACE,
244 CodePoint::INVALID,
245 CodePoint::INVALID,
246 CodePoint::INVALID,
247 CodePoint::INVALID,
248 CodePoint::INVALID,
249 CodePoint::INVALID,
250 CodePoint::INVALID,
251 CodePoint::INVALID,
252 CodePoint::INVALID,
253 CodePoint::INVALID,
254 CodePoint(62),
255 CodePoint::INVALID,
256 CodePoint::INVALID,
257 CodePoint::INVALID,
258 CodePoint(63),
259 CodePoint(52),
261 CodePoint(53),
262 CodePoint(54),
263 CodePoint(55),
264 CodePoint(56),
265 CodePoint(57),
266 CodePoint(58),
267 CodePoint(59),
268 CodePoint(60),
269 CodePoint(61),
270 CodePoint::INVALID,
271 CodePoint::INVALID,
272 CodePoint::INVALID,
273 CodePoint::PAD,
274 CodePoint::INVALID,
275 CodePoint::INVALID,
276 CodePoint::INVALID,
278 CodePoint(0),
279 CodePoint(1),
280 CodePoint(2),
281 CodePoint(3),
282 CodePoint(4),
283 CodePoint(5),
284 CodePoint(6),
285 CodePoint(7),
286 CodePoint(8),
287 CodePoint(9),
288 CodePoint(10),
289 CodePoint(11),
290 CodePoint(12),
291 CodePoint(13),
292 CodePoint(14),
293 CodePoint(15),
295 CodePoint(16),
296 CodePoint(17),
297 CodePoint(18),
298 CodePoint(19),
299 CodePoint(20),
300 CodePoint(21),
301 CodePoint(22),
302 CodePoint(23),
303 CodePoint(24),
304 CodePoint(25),
305 CodePoint::INVALID,
306 CodePoint::INVALID,
307 CodePoint::INVALID,
308 CodePoint::INVALID,
309 CodePoint::INVALID,
310 CodePoint::INVALID,
312 CodePoint(26),
313 CodePoint(27),
314 CodePoint(28),
315 CodePoint(29),
316 CodePoint(30),
317 CodePoint(31),
318 CodePoint(32),
319 CodePoint(33),
320 CodePoint(34),
321 CodePoint(35),
322 CodePoint(36),
323 CodePoint(37),
324 CodePoint(38),
325 CodePoint(39),
326 CodePoint(40),
327 CodePoint(41),
329 CodePoint(42),
330 CodePoint(43),
331 CodePoint(44),
332 CodePoint(45),
333 CodePoint(46),
334 CodePoint(47),
335 CodePoint(48),
336 CodePoint(49),
337 CodePoint(50),
338 CodePoint(51),
339 CodePoint::INVALID,
340 CodePoint::INVALID,
341 CodePoint::INVALID,
342 CodePoint::INVALID,
343 CodePoint::INVALID,
344 CodePoint::INVALID,
346 CodePoint::INVALID,
347 CodePoint::INVALID,
348 CodePoint::INVALID,
349 CodePoint::INVALID,
350 CodePoint::INVALID,
351 CodePoint::INVALID,
352 CodePoint::INVALID,
353 CodePoint::INVALID,
354 CodePoint::INVALID,
355 CodePoint::INVALID,
356 CodePoint::INVALID,
357 CodePoint::INVALID,
358 CodePoint::INVALID,
359 CodePoint::INVALID,
360 CodePoint::INVALID,
361 CodePoint::INVALID,
363 CodePoint::INVALID,
364 CodePoint::INVALID,
365 CodePoint::INVALID,
366 CodePoint::INVALID,
367 CodePoint::INVALID,
368 CodePoint::INVALID,
369 CodePoint::INVALID,
370 CodePoint::INVALID,
371 CodePoint::INVALID,
372 CodePoint::INVALID,
373 CodePoint::INVALID,
374 CodePoint::INVALID,
375 CodePoint::INVALID,
376 CodePoint::INVALID,
377 CodePoint::INVALID,
378 CodePoint::INVALID,
380 CodePoint::INVALID,
381 CodePoint::INVALID,
382 CodePoint::INVALID,
383 CodePoint::INVALID,
384 CodePoint::INVALID,
385 CodePoint::INVALID,
386 CodePoint::INVALID,
387 CodePoint::INVALID,
388 CodePoint::INVALID,
389 CodePoint::INVALID,
390 CodePoint::INVALID,
391 CodePoint::INVALID,
392 CodePoint::INVALID,
393 CodePoint::INVALID,
394 CodePoint::INVALID,
395 CodePoint::INVALID,
397 CodePoint::INVALID,
398 CodePoint::INVALID,
399 CodePoint::INVALID,
400 CodePoint::INVALID,
401 CodePoint::INVALID,
402 CodePoint::INVALID,
403 CodePoint::INVALID,
404 CodePoint::INVALID,
405 CodePoint::INVALID,
406 CodePoint::INVALID,
407 CodePoint::INVALID,
408 CodePoint::INVALID,
409 CodePoint::INVALID,
410 CodePoint::INVALID,
411 CodePoint::INVALID,
412 CodePoint::INVALID,
414 CodePoint::INVALID,
415 CodePoint::INVALID,
416 CodePoint::INVALID,
417 CodePoint::INVALID,
418 CodePoint::INVALID,
419 CodePoint::INVALID,
420 CodePoint::INVALID,
421 CodePoint::INVALID,
422 CodePoint::INVALID,
423 CodePoint::INVALID,
424 CodePoint::INVALID,
425 CodePoint::INVALID,
426 CodePoint::INVALID,
427 CodePoint::INVALID,
428 CodePoint::INVALID,
429 CodePoint::INVALID,
431 CodePoint::INVALID,
432 CodePoint::INVALID,
433 CodePoint::INVALID,
434 CodePoint::INVALID,
435 CodePoint::INVALID,
436 CodePoint::INVALID,
437 CodePoint::INVALID,
438 CodePoint::INVALID,
439 CodePoint::INVALID,
440 CodePoint::INVALID,
441 CodePoint::INVALID,
442 CodePoint::INVALID,
443 CodePoint::INVALID,
444 CodePoint::INVALID,
445 CodePoint::INVALID,
446 CodePoint::INVALID,
448 CodePoint::INVALID,
449 CodePoint::INVALID,
450 CodePoint::INVALID,
451 CodePoint::INVALID,
452 CodePoint::INVALID,
453 CodePoint::INVALID,
454 CodePoint::INVALID,
455 CodePoint::INVALID,
456 CodePoint::INVALID,
457 CodePoint::INVALID,
458 CodePoint::INVALID,
459 CodePoint::INVALID,
460 CodePoint::INVALID,
461 CodePoint::INVALID,
462 CodePoint::INVALID,
463 CodePoint::INVALID,
465 CodePoint::INVALID,
466 CodePoint::INVALID,
467 CodePoint::INVALID,
468 CodePoint::INVALID,
469 CodePoint::INVALID,
470 CodePoint::INVALID,
471 CodePoint::INVALID,
472 CodePoint::INVALID,
473 CodePoint::INVALID,
474 CodePoint::INVALID,
475 CodePoint::INVALID,
476 CodePoint::INVALID,
477 CodePoint::INVALID,
478 CodePoint::INVALID,
479 CodePoint::INVALID,
480 ];
481
482 TABLE[a as usize]
483 }
484}
485
486fn u8_in_range(a: u8, lo: u8, hi: u8) -> u8 {
491 debug_assert!(lo <= hi);
492 debug_assert!(hi - lo != 255);
493 let a = a.wrapping_sub(lo);
494 u8_less_than(a, (hi - lo).wrapping_add(1))
495}
496
497fn u8_less_than(a: u8, b: u8) -> u8 {
499 let a = u16::from(a);
500 let b = u16::from(b);
501 u8_broadcast16(a.wrapping_sub(b))
502}
503
504const fn u8_equals(a: u8, b: u8) -> u8 {
506 let diff = a ^ b;
507 u8_nonzero(diff)
508}
509
510const fn u8_nonzero(x: u8) -> u8 {
512 u8_broadcast8(!x & x.wrapping_sub(1))
513}
514
/// Copies the top bit of `x` into every bit: returns `0xff` if bit 7 is set,
/// `0x00` otherwise, without branching.
const fn u8_broadcast8(x: u8) -> u8 {
    // `as i8` reinterprets the same bits as a signed value. Right-shifting a
    // signed integer is arithmetic, so shifting by 7 smears the sign bit into
    // all positions (-1 or 0). `as u8` then reinterprets that as 0xff or 0x00.
    ((x as i8) >> 7) as u8
}
523
/// Copies the top bit of the 16-bit `x` into a whole byte: returns `0xff` if
/// bit 15 is set, `0x00` otherwise, without branching.
const fn u8_broadcast16(x: u16) -> u8 {
    // Arithmetic right shift of the signed reinterpretation yields -1 or 0
    // (0xffff or 0x0000). Truncating to `u8` keeps the low byte, 0xff or 0x00.
    ((x as i16) >> 15) as u8
}
532
// Requires the "alloc" feature because the `decode` helper below collects
// into a `Vec`.
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use alloc::vec::Vec;

    use super::*;

    #[test]
    fn decode_test() {
        assert_eq!(
            decode(b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"),
            b"\x00\x10\x83\x10\x51\x87\x20\x92\x8b\x30\xd3\x8f\x41\x14\x93\x51\x55\x97\
            \x61\x96\x9b\x71\xd7\x9f\x82\x18\xa3\x92\x59\xa7\xa2\x9a\xab\xb2\xdb\xaf\
            \xc3\x1c\xb3\xd3\x5d\xb7\xe3\x9e\xbb\xf3\xdf\xbf"
        );
        assert_eq!(decode(b"aGVsbG8="), b"hello");
        assert_eq!(decode(b"aGVsbG8gd29ybGQ="), b"hello world");
        assert_eq!(decode(b"aGVsbG8gd29ybGQh"), b"hello world!");
        assert_eq!(decode(b"////"), b"\xff\xff\xff");
        assert_eq!(decode(b"++++"), b"\xfb\xef\xbe");
        assert_eq!(decode(b"AAAA"), b"\x00\x00\x00");
        assert_eq!(decode(b"AAA="), b"\x00\x00");
        assert_eq!(decode(b"AA=="), b"\x00");

        // Unpadded final quantums are accepted too.
        assert_eq!(decode(b"AAA"), b"\x00\x00");
        assert_eq!(decode(b"AA"), b"\x00");

        assert_eq!(decode(b""), b"");
    }

    #[test]
    fn decode_skips_whitespace() {
        // Whitespace is ignored wherever it appears: inside and between
        // groups, inside padding, and at either end.
        assert_eq!(decode(b"aGVs bG8="), b"hello");
        assert_eq!(decode(b"aGVs\nbG8g\r\nd29y\tbGQh "), b"hello world!");
        assert_eq!(decode(b"\x0bAA= =\x0c"), b"\x00");
        assert_eq!(decode(b" \t\r\n"), b"");
    }

    #[test]
    fn decode_errors() {
        let mut buf = [0u8; 6];

        // Malformed final quantums.
        assert_eq!(
            decode_both(b"A===", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"====", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"A==", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"AA=", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"A", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );

        // Padding before the final quantum.
        assert_eq!(
            decode_both(b"=AAAAA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        assert_eq!(
            decode_both(b"A=AAAA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        assert_eq!(
            decode_both(b"AA=AAA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        assert_eq!(
            decode_both(b"AAA=AA==", &mut buf),
            Err(Error::PrematurePadding)
        );

        // Invalid character in every position of a quantum.
        assert_eq!(
            decode_both(b"%AAA", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );
        assert_eq!(
            decode_both(b"A%AA", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );
        assert_eq!(
            decode_both(b"AA%A", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );
        assert_eq!(
            decode_both(b"AAA%", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );

        // Output buffer sizing for full quantums.
        assert_eq!(decode_both(b"am9lIGJw", &mut [0u8; 7]), Ok(&b"joe bp"[..]));
        assert_eq!(decode_both(b"am9lIGJw", &mut [0u8; 6]), Ok(&b"joe bp"[..]));
        assert_eq!(
            decode_both(b"am9lIGJw", &mut [0u8; 5]),
            Err(Error::InsufficientOutputSpace)
        );
        assert_eq!(
            decode_both(b"am9lIGJw", &mut [0u8; 4]),
            Err(Error::InsufficientOutputSpace)
        );
        assert_eq!(
            decode_both(b"am9lIGJw", &mut [0u8; 3]),
            Err(Error::InsufficientOutputSpace)
        );

        // Output buffer sizing for short final quantums.
        assert_eq!(decode_both(b"am9=", &mut [0u8; 2]), Ok(&b"jo"[..]));
        assert_eq!(decode_both(b"am==", &mut [0u8; 1]), Ok(&b"j"[..]));
        assert_eq!(decode_both(b"am9", &mut [0u8; 2]), Ok(&b"jo"[..]));
        assert_eq!(decode_both(b"am", &mut [0u8; 1]), Ok(&b"j"[..]));
    }

    #[test]
    fn decode_errors_spanning_groups() {
        let mut buf = [0u8; 12];

        // Padding inside a completed 8-symbol group that is followed by more data.
        assert_eq!(
            decode_both(b"AAAAAA==AAAA", &mut buf),
            Err(Error::PrematurePadding)
        );
        // A single leftover symbol after a full quantum.
        assert_eq!(
            decode_both(b"AAAAA", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        // Non-ASCII bytes are invalid characters.
        assert_eq!(
            decode_both(b"AA\xffA", &mut buf),
            Err(Error::InvalidCharacter(0xff))
        );
        // Output exhausted while flushing a full group in the middle of the input.
        assert_eq!(
            decode_both(b"AAAAAAAAA", &mut [0u8; 5]),
            Err(Error::InsufficientOutputSpace)
        );
    }

    #[test]
    fn check_models() {
        fn u8_broadcast8_model(x: u8) -> u8 {
            match x & 0x80 {
                0x80 => 0xff,
                _ => 0x00,
            }
        }

        fn u8_broadcast16_model(x: u16) -> u8 {
            match x & 0x8000 {
                0x8000 => 0xff,
                _ => 0x00,
            }
        }

        fn u8_nonzero_model(x: u8) -> u8 {
            match x {
                0 => 0xff,
                _ => 0x00,
            }
        }

        fn u8_equals_model(x: u8, y: u8) -> u8 {
            match x == y {
                true => 0xff,
                false => 0x00,
            }
        }

        fn u8_in_range_model(x: u8, y: u8, z: u8) -> u8 {
            match (y..=z).contains(&x) {
                true => 0xff,
                false => 0x00,
            }
        }

        for x in u8::MIN..=u8::MAX {
            assert_eq!(u8_broadcast8(x), u8_broadcast8_model(x));
            assert_eq!(u8_nonzero(x), u8_nonzero_model(x));
            // The constant-time and table-driven classifiers must agree on every byte.
            assert_eq!(CodePoint::decode_secret(x), CodePoint::decode_public(x));

            for y in u8::MIN..=u8::MAX {
                assert_eq!(u8_equals(x, y), u8_equals_model(x, y));

                let v = (x as u16) | ((y as u16) << 8);
                assert_eq!(u8_broadcast16(v), u8_broadcast16_model(v));

                for z in y..=u8::MAX {
                    // u8_in_range does not support a range covering all 256 values.
                    if z - y == 255 {
                        continue;
                    }
                    assert_eq!(u8_in_range(x, y, z), u8_in_range_model(x, y, z));
                }
            }
        }
    }

    #[cfg(all(feature = "std", target_os = "linux", target_arch = "x86_64"))]
    #[test]
    fn codepoint_decode_secret_does_not_branch_or_index_on_secret_input() {
        // Marks the input byte as undefined memory for Valgrind. Valgrind then
        // reports any conditional branch or memory access whose outcome
        // depends on it.
        use crabgrind as cg;

        if matches!(cg::run_mode(), cg::RunMode::Native) {
            std::println!("SKIPPED: must be run under valgrind");
            return;
        }

        let input = [b'a'];
        cg::monitor_command(format!(
            "make_memory undefined {:p} {}",
            input.as_ptr(),
            input.len()
        ))
        .unwrap();

        core::hint::black_box(CodePoint::decode_secret(input[0]));
    }

    /// Decodes `input` with both classifiers into a buffer sized by
    /// `decoded_length`, panicking on any error.
    #[track_caller]
    fn decode(input: &[u8]) -> Vec<u8> {
        let length = decoded_length(input.len());

        let mut v = alloc::vec![0u8; length];
        let used = decode_both(input, &mut v).unwrap().len();
        v.truncate(used);

        v
    }

    /// Runs `decode_public` (on a copy of `output`) and `decode_secret` on the
    /// same input, asserts they agree, and returns the `decode_secret` result.
    fn decode_both<'a>(input: &'_ [u8], output: &'a mut [u8]) -> Result<&'a [u8], Error> {
        let mut output_copy = output.to_vec();
        let r_pub = decode_public(input, &mut output_copy);

        let r_sec = decode_secret(input, output);

        assert_eq!(r_pub, r_sec);

        r_sec
    }
}