1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12 arena::HandleVec,
13 proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
14 valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
15 Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
16 Span, Statement, TypeInner, WithSpan,
17};
18
19#[derive(Error, Debug, Clone)]
20#[cfg_attr(test, derive(PartialEq))]
21pub enum PipelineConstantError {
22 #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
23 MissingValue(String),
24 #[error(
25 "Source f64 value needs to be finite ({}) for number destinations",
26 "NaNs and Inifinites are not allowed"
27 )]
28 SrcNeedsToBeFinite,
29 #[error("Source f64 value doesn't fit in destination")]
30 DstRangeTooSmall,
31 #[error(transparent)]
32 ConstantEvaluatorError(#[from] ConstantEvaluatorError),
33 #[error(transparent)]
34 ValidationError(#[from] WithSpan<ValidationError>),
35 #[error("workgroup_size override isn't strictly positive")]
36 NegativeWorkgroupSize,
37}
38
/// Replace every override in `module` with a constant.
///
/// If `module` has no overrides, the inputs are returned borrowed and unchanged.
///
/// Otherwise a clone of `module` is rewritten:
/// - each override becomes a new `Constant` whose initializer is either the
///   value supplied in `pipeline_constants` (keyed by the override's `id` if
///   present, else by its `name`) or the override's own initializer;
/// - every `Expression::Override` in the global expressions, functions and
///   entry points is replaced by an `Expression::Constant` for that constant;
/// - global and function expressions are re-appended through a
///   `ConstantEvaluator`, with all handles remapped to the rebuilt arenas;
/// - entry-point `workgroup_size_overrides` are evaluated into `workgroup_size`.
///
/// The rewritten module is revalidated (all validation flags and capabilities
/// enabled), and the returned `ModuleInfo` comes from that validation.
///
/// # Errors
///
/// Returns [`PipelineConstantError`] if an override has neither a pipeline
/// value nor an initializer, if a supplied value cannot be represented in the
/// override's type, if constant evaluation fails, if a workgroup size override
/// is zero or cannot be evaluated, or if validation of the result fails.
pub fn process_overrides<'a>(
    module: &'a Module,
    module_info: &'a ModuleInfo,
    pipeline_constants: &PipelineConstants,
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
    // Nothing to resolve: hand the inputs back without cloning.
    if module.overrides.is_empty() {
        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
    }

    // Rewrite a private copy; the caller's module is left untouched.
    let mut module = module.clone();

    // Original override handle -> constant that replaces it.
    let mut override_map = HandleVec::with_capacity(module.overrides.len());

    // Original global expression handle -> handle in the rebuilt arena.
    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

    // Constants whose `init` already points into the rebuilt arena. Used to
    // avoid remapping the same initializer twice.
    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());

    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
    let mut layouter = crate::proc::Layouter::default();

    // Take the overrides out of the module so they can be mutated while
    // `module` itself is mutably borrowed; they are put back near the end.
    let mut overrides = mem::take(&mut module.overrides);
    // Consumed in step with the global expressions below: an override is
    // processed no later than the first `Expression::Override` naming it.
    let mut override_iter = overrides.iter_mut_span();

    // Walk the original global expressions in order. Each one is remapped and
    // re-appended via the evaluator, and its new handle is recorded.
    for (old_h, expr, span) in module.global_expressions.drain() {
        let mut expr = match expr {
            Expression::Override(h) => {
                let c_h = if let Some(new_h) = override_map.get(h) {
                    *new_h
                } else {
                    // First reference to `h`. Every processed override is
                    // recorded in `override_map`, so `h` is still ahead in
                    // `override_iter`. Process all overrides up to and
                    // including it, in arena order.
                    // NOTE(review): this relies on those overrides'
                    // initializers preceding the current expression, because
                    // `process_override` indexes `adjusted_global_expressions`
                    // with them.
                    let mut new_h = None;
                    for entry in override_iter.by_ref() {
                        let stop = entry.0 == h;
                        new_h = Some(process_override(
                            entry,
                            pipeline_constants,
                            &mut module,
                            &mut override_map,
                            &adjusted_global_expressions,
                            &mut adjusted_constant_initializers,
                            &mut global_expression_kind_tracker,
                        )?);
                        if stop {
                            break;
                        }
                    }
                    new_h.unwrap()
                };
                Expression::Constant(c_h)
            }
            Expression::Constant(c_h) => {
                // Remap the constant's initializer the first time it is seen.
                // `insert` returns false if it was already adjusted, which
                // includes constants created by `process_override`.
                if adjusted_constant_initializers.insert(c_h) {
                    let init = &mut module.constants[c_h].init;
                    *init = adjusted_global_expressions[*init];
                }
                expr
            }
            expr => expr,
        };
        let mut evaluator = ConstantEvaluator::for_wgsl_module(
            &mut module,
            &mut global_expression_kind_tracker,
            &mut layouter,
            false,
        );
        // Point operands at their new handles, then hand the expression to
        // the evaluator, which appends it and returns the new handle.
        adjust_expr(&adjusted_global_expressions, &mut expr);
        let h = evaluator.try_eval_and_append(expr, span)?;
        adjusted_global_expressions.insert(old_h, h);
    }

    // Overrides that no global expression referenced.
    for entry in override_iter {
        match *entry.1 {
            // Named or `@id` overrides are still materialized, so a missing
            // pipeline value without an initializer is reported.
            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
                process_override(
                    entry,
                    pipeline_constants,
                    &mut module,
                    &mut override_map,
                    &adjusted_global_expressions,
                    &mut adjusted_constant_initializers,
                    &mut global_expression_kind_tracker,
                )?;
            }
            // Anonymous overrides only get their initializer remapped.
            Override {
                init: Some(ref mut init),
                ..
            } => {
                *init = adjusted_global_expressions[*init];
            }
            _ => {}
        }
    }

    // Constants never referenced from a global expression still need their
    // initializers remapped.
    for (_, c) in module
        .constants
        .iter_mut()
        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
    {
        c.init = adjusted_global_expressions[c.init];
    }

    for (_, v) in module.global_variables.iter_mut() {
        if let Some(ref mut init) = v.init {
            *init = adjusted_global_expressions[*init];
        }
    }

    // Functions are taken out temporarily because `process_function` needs
    // `&mut module`.
    let mut functions = mem::take(&mut module.functions);
    for (_, function) in functions.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, function)?;
    }
    module.functions = functions;

    let mut entry_points = mem::take(&mut module.entry_points);
    for ep in entry_points.iter_mut() {
        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
    }
    module.entry_points = entry_points;
    // The overrides are kept, with their `init` fields now pointing into the
    // rebuilt global expression arena.
    module.overrides = overrides;

    // The incoming `module_info` describes the original module, so the
    // rewritten one is validated afresh.
    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
    let module_info = validator.validate_resolved_overrides(&module)?;

    Ok((Cow::Owned(module), Cow::Owned(module_info)))
}
229
230fn process_workgroup_size_override(
231 module: &mut Module,
232 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
233 ep: &mut crate::EntryPoint,
234) -> Result<(), PipelineConstantError> {
235 match ep.workgroup_size_overrides {
236 None => {}
237 Some(overrides) => {
238 overrides.iter().enumerate().try_for_each(
239 |(i, overridden)| -> Result<(), PipelineConstantError> {
240 match *overridden {
241 None => Ok(()),
242 Some(h) => {
243 ep.workgroup_size[i] = module
244 .to_ctx()
245 .eval_expr_to_u32(adjusted_global_expressions[h])
246 .map(|n| {
247 if n == 0 {
248 Err(PipelineConstantError::NegativeWorkgroupSize)
249 } else {
250 Ok(n)
251 }
252 })
253 .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
254 Ok(())
255 }
256 }
257 },
258 )?;
259 ep.workgroup_size_overrides = None;
260 }
261 }
262 Ok(())
263}
264
265fn process_override(
269 (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
270 pipeline_constants: &PipelineConstants,
271 module: &mut Module,
272 override_map: &mut HandleVec<Override, Handle<Constant>>,
273 adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
274 adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
275 global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
276) -> Result<Handle<Constant>, PipelineConstantError> {
277 let key = if let Some(id) = r#override.id {
279 Cow::Owned(id.to_string())
280 } else if let Some(ref name) = r#override.name {
281 Cow::Borrowed(name)
282 } else {
283 unreachable!();
284 };
285
286 let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
290 let literal = match module.types[r#override.ty].inner {
291 TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
292 _ => unreachable!(),
293 };
294 let expr = module
295 .global_expressions
296 .append(Expression::Literal(literal), Span::UNDEFINED);
297 global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
298 expr
299 } else if let Some(init) = r#override.init {
300 adjusted_global_expressions[init]
301 } else {
302 return Err(PipelineConstantError::MissingValue(key.to_string()));
303 };
304
305 let constant = Constant {
307 name: r#override.name.clone(),
308 ty: r#override.ty,
309 init,
310 };
311 let h = module.constants.append(constant, *span);
312 override_map.insert(old_h, h);
313 adjusted_constant_initializers.insert(h);
314 r#override.init = Some(init);
315 Ok(h)
316}
317
318fn process_function(
328 module: &mut Module,
329 override_map: &HandleVec<Override, Handle<Constant>>,
330 layouter: &mut crate::proc::Layouter,
331 function: &mut Function,
332) -> Result<(), ConstantEvaluatorError> {
333 let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
336
337 let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
338
339 let mut expressions = mem::take(&mut function.expressions);
340
341 let mut emitter = Emitter::default();
349 let mut block = Block::new();
350
351 let mut evaluator = ConstantEvaluator::for_wgsl_function(
352 module,
353 &mut function.expressions,
354 &mut local_expression_kind_tracker,
355 layouter,
356 &mut emitter,
357 &mut block,
358 false,
359 );
360
361 for (old_h, mut expr, span) in expressions.drain() {
362 if let Expression::Override(h) = expr {
363 expr = Expression::Constant(override_map[h]);
364 }
365 adjust_expr(&adjusted_local_expressions, &mut expr);
366 let h = evaluator.try_eval_and_append(expr, span)?;
367 adjusted_local_expressions.insert(old_h, h);
368 }
369
370 adjust_block(&adjusted_local_expressions, &mut function.body);
371
372 filter_emits_in_block(&mut function.body, &function.expressions);
373
374 for (_, local) in function.local_variables.iter_mut() {
376 if let &mut Some(ref mut init) = &mut local.init {
377 *init = adjusted_local_expressions[*init];
378 }
379 }
380
381 let named_expressions = mem::take(&mut function.named_expressions);
384 for (expr_h, name) in named_expressions {
385 function
386 .named_expressions
387 .insert(adjusted_local_expressions[expr_h], name);
388 }
389
390 Ok(())
391}
392
393fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
396 let adjust = |expr: &mut Handle<Expression>| {
397 *expr = new_pos[*expr];
398 };
399 match *expr {
400 Expression::Compose {
401 ref mut components,
402 ty: _,
403 } => {
404 for c in components.iter_mut() {
405 adjust(c);
406 }
407 }
408 Expression::Access {
409 ref mut base,
410 ref mut index,
411 } => {
412 adjust(base);
413 adjust(index);
414 }
415 Expression::AccessIndex {
416 ref mut base,
417 index: _,
418 } => {
419 adjust(base);
420 }
421 Expression::Splat {
422 ref mut value,
423 size: _,
424 } => {
425 adjust(value);
426 }
427 Expression::Swizzle {
428 ref mut vector,
429 size: _,
430 pattern: _,
431 } => {
432 adjust(vector);
433 }
434 Expression::Load { ref mut pointer } => {
435 adjust(pointer);
436 }
437 Expression::ImageSample {
438 ref mut image,
439 ref mut sampler,
440 ref mut coordinate,
441 ref mut array_index,
442 ref mut offset,
443 ref mut level,
444 ref mut depth_ref,
445 gather: _,
446 } => {
447 adjust(image);
448 adjust(sampler);
449 adjust(coordinate);
450 if let Some(e) = array_index.as_mut() {
451 adjust(e);
452 }
453 if let Some(e) = offset.as_mut() {
454 adjust(e);
455 }
456 match *level {
457 crate::SampleLevel::Exact(ref mut expr)
458 | crate::SampleLevel::Bias(ref mut expr) => {
459 adjust(expr);
460 }
461 crate::SampleLevel::Gradient {
462 ref mut x,
463 ref mut y,
464 } => {
465 adjust(x);
466 adjust(y);
467 }
468 _ => {}
469 }
470 if let Some(e) = depth_ref.as_mut() {
471 adjust(e);
472 }
473 }
474 Expression::ImageLoad {
475 ref mut image,
476 ref mut coordinate,
477 ref mut array_index,
478 ref mut sample,
479 ref mut level,
480 } => {
481 adjust(image);
482 adjust(coordinate);
483 if let Some(e) = array_index.as_mut() {
484 adjust(e);
485 }
486 if let Some(e) = sample.as_mut() {
487 adjust(e);
488 }
489 if let Some(e) = level.as_mut() {
490 adjust(e);
491 }
492 }
493 Expression::ImageQuery {
494 ref mut image,
495 ref mut query,
496 } => {
497 adjust(image);
498 match *query {
499 crate::ImageQuery::Size { ref mut level } => {
500 if let Some(e) = level.as_mut() {
501 adjust(e);
502 }
503 }
504 crate::ImageQuery::NumLevels
505 | crate::ImageQuery::NumLayers
506 | crate::ImageQuery::NumSamples => {}
507 }
508 }
509 Expression::Unary {
510 ref mut expr,
511 op: _,
512 } => {
513 adjust(expr);
514 }
515 Expression::Binary {
516 ref mut left,
517 ref mut right,
518 op: _,
519 } => {
520 adjust(left);
521 adjust(right);
522 }
523 Expression::Select {
524 ref mut condition,
525 ref mut accept,
526 ref mut reject,
527 } => {
528 adjust(condition);
529 adjust(accept);
530 adjust(reject);
531 }
532 Expression::Derivative {
533 ref mut expr,
534 axis: _,
535 ctrl: _,
536 } => {
537 adjust(expr);
538 }
539 Expression::Relational {
540 ref mut argument,
541 fun: _,
542 } => {
543 adjust(argument);
544 }
545 Expression::Math {
546 ref mut arg,
547 ref mut arg1,
548 ref mut arg2,
549 ref mut arg3,
550 fun: _,
551 } => {
552 adjust(arg);
553 if let Some(e) = arg1.as_mut() {
554 adjust(e);
555 }
556 if let Some(e) = arg2.as_mut() {
557 adjust(e);
558 }
559 if let Some(e) = arg3.as_mut() {
560 adjust(e);
561 }
562 }
563 Expression::As {
564 ref mut expr,
565 kind: _,
566 convert: _,
567 } => {
568 adjust(expr);
569 }
570 Expression::ArrayLength(ref mut expr) => {
571 adjust(expr);
572 }
573 Expression::RayQueryGetIntersection {
574 ref mut query,
575 committed: _,
576 } => {
577 adjust(query);
578 }
579 Expression::Literal(_)
580 | Expression::FunctionArgument(_)
581 | Expression::GlobalVariable(_)
582 | Expression::LocalVariable(_)
583 | Expression::CallResult(_)
584 | Expression::RayQueryProceedResult
585 | Expression::Constant(_)
586 | Expression::Override(_)
587 | Expression::ZeroValue(_)
588 | Expression::AtomicResult {
589 ty: _,
590 comparison: _,
591 }
592 | Expression::WorkGroupUniformLoadResult { ty: _ }
593 | Expression::SubgroupBallotResult
594 | Expression::SubgroupOperationResult { .. } => {}
595 Expression::RayQueryVertexPositions {
596 ref mut query,
597 committed: _,
598 } => {
599 adjust(query);
600 }
601 }
602}
603
604fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
607 for stmt in block.iter_mut() {
608 adjust_stmt(new_pos, stmt);
609 }
610}
611
612fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
615 let adjust = |expr: &mut Handle<Expression>| {
616 *expr = new_pos[*expr];
617 };
618 match *stmt {
619 Statement::Emit(ref mut range) => {
620 if let Some((mut first, mut last)) = range.first_and_last() {
621 adjust(&mut first);
622 adjust(&mut last);
623 *range = Range::new_from_bounds(first, last);
624 }
625 }
626 Statement::Block(ref mut block) => {
627 adjust_block(new_pos, block);
628 }
629 Statement::If {
630 ref mut condition,
631 ref mut accept,
632 ref mut reject,
633 } => {
634 adjust(condition);
635 adjust_block(new_pos, accept);
636 adjust_block(new_pos, reject);
637 }
638 Statement::Switch {
639 ref mut selector,
640 ref mut cases,
641 } => {
642 adjust(selector);
643 for case in cases.iter_mut() {
644 adjust_block(new_pos, &mut case.body);
645 }
646 }
647 Statement::Loop {
648 ref mut body,
649 ref mut continuing,
650 ref mut break_if,
651 } => {
652 adjust_block(new_pos, body);
653 adjust_block(new_pos, continuing);
654 if let Some(e) = break_if.as_mut() {
655 adjust(e);
656 }
657 }
658 Statement::Return { ref mut value } => {
659 if let Some(e) = value.as_mut() {
660 adjust(e);
661 }
662 }
663 Statement::Store {
664 ref mut pointer,
665 ref mut value,
666 } => {
667 adjust(pointer);
668 adjust(value);
669 }
670 Statement::ImageStore {
671 ref mut image,
672 ref mut coordinate,
673 ref mut array_index,
674 ref mut value,
675 } => {
676 adjust(image);
677 adjust(coordinate);
678 if let Some(e) = array_index.as_mut() {
679 adjust(e);
680 }
681 adjust(value);
682 }
683 Statement::Atomic {
684 ref mut pointer,
685 ref mut value,
686 ref mut result,
687 ref mut fun,
688 } => {
689 adjust(pointer);
690 adjust(value);
691 if let Some(ref mut result) = *result {
692 adjust(result);
693 }
694 match *fun {
695 crate::AtomicFunction::Exchange {
696 compare: Some(ref mut compare),
697 } => {
698 adjust(compare);
699 }
700 crate::AtomicFunction::Add
701 | crate::AtomicFunction::Subtract
702 | crate::AtomicFunction::And
703 | crate::AtomicFunction::ExclusiveOr
704 | crate::AtomicFunction::InclusiveOr
705 | crate::AtomicFunction::Min
706 | crate::AtomicFunction::Max
707 | crate::AtomicFunction::Exchange { compare: None } => {}
708 }
709 }
710 Statement::ImageAtomic {
711 ref mut image,
712 ref mut coordinate,
713 ref mut array_index,
714 fun: _,
715 ref mut value,
716 } => {
717 adjust(image);
718 adjust(coordinate);
719 if let Some(ref mut array_index) = *array_index {
720 adjust(array_index);
721 }
722 adjust(value);
723 }
724 Statement::WorkGroupUniformLoad {
725 ref mut pointer,
726 ref mut result,
727 } => {
728 adjust(pointer);
729 adjust(result);
730 }
731 Statement::SubgroupBallot {
732 ref mut result,
733 ref mut predicate,
734 } => {
735 if let Some(ref mut predicate) = *predicate {
736 adjust(predicate);
737 }
738 adjust(result);
739 }
740 Statement::SubgroupCollectiveOperation {
741 ref mut argument,
742 ref mut result,
743 ..
744 } => {
745 adjust(argument);
746 adjust(result);
747 }
748 Statement::SubgroupGather {
749 ref mut mode,
750 ref mut argument,
751 ref mut result,
752 } => {
753 match *mode {
754 crate::GatherMode::BroadcastFirst => {}
755 crate::GatherMode::Broadcast(ref mut index)
756 | crate::GatherMode::Shuffle(ref mut index)
757 | crate::GatherMode::ShuffleDown(ref mut index)
758 | crate::GatherMode::ShuffleUp(ref mut index)
759 | crate::GatherMode::ShuffleXor(ref mut index) => {
760 adjust(index);
761 }
762 }
763 adjust(argument);
764 adjust(result)
765 }
766 Statement::Call {
767 ref mut arguments,
768 ref mut result,
769 function: _,
770 } => {
771 for argument in arguments.iter_mut() {
772 adjust(argument);
773 }
774 if let Some(e) = result.as_mut() {
775 adjust(e);
776 }
777 }
778 Statement::RayQuery {
779 ref mut query,
780 ref mut fun,
781 } => {
782 adjust(query);
783 match *fun {
784 crate::RayQueryFunction::Initialize {
785 ref mut acceleration_structure,
786 ref mut descriptor,
787 } => {
788 adjust(acceleration_structure);
789 adjust(descriptor);
790 }
791 crate::RayQueryFunction::Proceed { ref mut result } => {
792 adjust(result);
793 }
794 crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
795 adjust(hit_t);
796 }
797 crate::RayQueryFunction::ConfirmIntersection => {}
798 crate::RayQueryFunction::Terminate => {}
799 }
800 }
801 Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
802 }
803}
804
805fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
822 let original = mem::replace(block, Block::with_capacity(block.len()));
823 for (stmt, span) in original.span_into_iter() {
824 match stmt {
825 Statement::Emit(range) => {
826 let mut current = None;
827 for expr_h in range {
828 if expressions[expr_h].needs_pre_emit() {
829 if let Some((first, last)) = current {
830 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
831 }
832
833 current = None;
834 } else if let Some((_, ref mut last)) = current {
835 *last = expr_h;
836 } else {
837 current = Some((expr_h, expr_h));
838 }
839 }
840 if let Some((first, last)) = current {
841 block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
842 }
843 }
844 Statement::Block(mut child) => {
845 filter_emits_in_block(&mut child, expressions);
846 block.push(Statement::Block(child), span);
847 }
848 Statement::If {
849 condition,
850 mut accept,
851 mut reject,
852 } => {
853 filter_emits_in_block(&mut accept, expressions);
854 filter_emits_in_block(&mut reject, expressions);
855 block.push(
856 Statement::If {
857 condition,
858 accept,
859 reject,
860 },
861 span,
862 );
863 }
864 Statement::Switch {
865 selector,
866 mut cases,
867 } => {
868 for case in &mut cases {
869 filter_emits_in_block(&mut case.body, expressions);
870 }
871 block.push(Statement::Switch { selector, cases }, span);
872 }
873 Statement::Loop {
874 mut body,
875 mut continuing,
876 break_if,
877 } => {
878 filter_emits_in_block(&mut body, expressions);
879 filter_emits_in_block(&mut continuing, expressions);
880 block.push(
881 Statement::Loop {
882 body,
883 continuing,
884 break_if,
885 },
886 span,
887 );
888 }
889 stmt => block.push(stmt.clone(), span),
890 }
891 }
892}
893
894fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
895 match scalar {
897 Scalar::BOOL => {
898 let value = value != 0.0 && !value.is_nan();
900 Ok(Literal::Bool(value))
901 }
902 Scalar::I32 => {
903 if !value.is_finite() {
905 return Err(PipelineConstantError::SrcNeedsToBeFinite);
906 }
907
908 let value = value.trunc();
909 if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
910 return Err(PipelineConstantError::DstRangeTooSmall);
911 }
912
913 let value = value as i32;
914 Ok(Literal::I32(value))
915 }
916 Scalar::U32 => {
917 if !value.is_finite() {
919 return Err(PipelineConstantError::SrcNeedsToBeFinite);
920 }
921
922 let value = value.trunc();
923 if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
924 return Err(PipelineConstantError::DstRangeTooSmall);
925 }
926
927 let value = value as u32;
928 Ok(Literal::U32(value))
929 }
930 Scalar::F32 => {
931 if !value.is_finite() {
933 return Err(PipelineConstantError::SrcNeedsToBeFinite);
934 }
935
936 let value = value as f32;
937 if !value.is_finite() {
938 return Err(PipelineConstantError::DstRangeTooSmall);
939 }
940
941 Ok(Literal::F32(value))
942 }
943 Scalar::F64 => {
944 if !value.is_finite() {
946 return Err(PipelineConstantError::SrcNeedsToBeFinite);
947 }
948
949 Ok(Literal::F64(value))
950 }
951 _ => unreachable!(),
952 }
953}
954
/// Exercises `map_value_to_literal` across every supported destination type,
/// covering non-finite inputs, range bounds, truncation and `f32` rounding.
#[test]
fn test_map_value_to_literal() {
    // bool: zero (of either sign) and NaN are false; any other value,
    // including fractions and infinities, is true.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (1.0, true),
        (0.5, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Truncation toward zero happens before the range check.
    assert_eq!(map_value_to_literal(2.9, Scalar::I32), Ok(Literal::I32(2)));
    assert_eq!(map_value_to_literal(-2.9, Scalar::I32), Ok(Literal::I32(-2)));
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Negative fractions truncate to -0.0, which is in range and maps to 0.
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Ordinary values round to the nearest f32.
    assert_eq!(map_value_to_literal(0.1, Scalar::F32), Ok(Literal::F32(0.1)));

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
}