Flutter Engine
The Flutter Engine
SkSLConstantFolder.cpp
Go to the documentation of this file.
1/*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
9
29
30#include <cstdint>
31#include <float.h>
32#include <limits>
33#include <optional>
34#include <string>
35#include <utility>
36
37using namespace skia_private;
38
39namespace SkSL {
40
41static bool is_vec_or_mat(const Type& type) {
42 switch (type.typeKind()) {
45 return true;
46
47 default:
48 return false;
49 }
50}
51
52static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
53 const Expression& left,
54 Operator op,
55 const Expression& right) {
56 bool rightVal = right.as<Literal>().boolValue();
57
58 // Detect no-op Boolean expressions and optimize them away.
59 if ((op.kind() == Operator::Kind::LOGICALAND && rightVal) || // (expr && true) -> (expr)
60 (op.kind() == Operator::Kind::LOGICALOR && !rightVal) || // (expr || false) -> (expr)
61 (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) || // (expr ^^ false) -> (expr)
62 (op.kind() == Operator::Kind::EQEQ && rightVal) || // (expr == true) -> (expr)
63 (op.kind() == Operator::Kind::NEQ && !rightVal)) { // (expr != false) -> (expr)
64
65 return left.clone(pos);
66 }
67
68 return nullptr;
69}
70
71static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
72 const Expression& left,
73 Operator op,
74 const Expression& right) {
75 bool leftVal = left.as<Literal>().boolValue();
76
77 // When the literal is on the left, we can sometimes eliminate the other expression entirely.
78 if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) || // (false && expr) -> (false)
79 (op.kind() == Operator::Kind::LOGICALOR && leftVal)) { // (true || expr) -> (true)
80
81 return left.clone(pos);
82 }
83
84 // We can't eliminate the right-side expression via short-circuit, but we might still be able to
85 // simplify away a no-op expression.
86 return eliminate_no_op_boolean(pos, right, op, left);
87}
88
89static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
91 const Expression& left,
92 Operator op,
93 const Expression& right) {
94 if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
95 bool equality = (op.kind() == Operator::Kind::EQEQ);
96
97 switch (left.compareConstant(right)) {
99 equality = !equality;
100 [[fallthrough]];
101
103 return Literal::MakeBool(context, pos, equality);
104
106 break;
107 }
108 }
109 return nullptr;
110}
111
112static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
114 const Expression& left,
115 const Expression& right,
116 int leftColumns,
117 int leftRows,
118 int rightColumns,
119 int rightRows) {
120 const Type& componentType = left.type().componentType();
121 SkASSERT(componentType.matches(right.type().componentType()));
122
123 // Fetch the left matrix.
124 double leftVals[4][4];
125 for (int c = 0; c < leftColumns; ++c) {
126 for (int r = 0; r < leftRows; ++r) {
127 leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
128 }
129 }
130 // Fetch the right matrix.
131 double rightVals[4][4];
132 for (int c = 0; c < rightColumns; ++c) {
133 for (int r = 0; r < rightRows; ++r) {
134 rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
135 }
136 }
137
138 SkASSERT(leftColumns == rightRows);
139 int outColumns = rightColumns,
140 outRows = leftRows;
141
142 double args[16];
143 int argIndex = 0;
144 for (int c = 0; c < outColumns; ++c) {
145 for (int r = 0; r < outRows; ++r) {
146 // Compute a dot product for this position.
147 double val = 0;
148 for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
149 val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
150 }
151
152 if (val >= -FLT_MAX && val <= FLT_MAX) {
153 args[argIndex++] = val;
154 } else {
155 // The value is outside the 32-bit float range, or is NaN; do not optimize.
156 return nullptr;
157 }
158 }
159 }
160
161 if (outColumns == 1) {
162 // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
163 std::swap(outColumns, outRows);
164 }
165
166 const Type& resultType = componentType.toCompound(context, outColumns, outRows);
167 return ConstructorCompound::MakeFromConstants(context, pos, resultType, args);
168}
169
170static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
172 const Expression& left,
173 const Expression& right) {
174 const Type& leftType = left.type();
175 const Type& rightType = right.type();
176
177 SkASSERT(leftType.isMatrix());
178 SkASSERT(rightType.isMatrix());
179
180 return simplify_matrix_multiplication(context, pos, left, right,
181 leftType.columns(), leftType.rows(),
182 rightType.columns(), rightType.rows());
183}
184
185static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
187 const Expression& left,
188 const Expression& right) {
189 const Type& leftType = left.type();
190 const Type& rightType = right.type();
191
192 SkASSERT(leftType.isVector());
193 SkASSERT(rightType.isMatrix());
194
195 return simplify_matrix_multiplication(context, pos, left, right,
196 /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
197 rightType.columns(), rightType.rows());
198}
199
200static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
202 const Expression& left,
203 const Expression& right) {
204 const Type& leftType = left.type();
205 const Type& rightType = right.type();
206
207 SkASSERT(leftType.isMatrix());
208 SkASSERT(rightType.isVector());
209
210 return simplify_matrix_multiplication(context, pos, left, right,
211 leftType.columns(), leftType.rows(),
212 /*rightColumns=*/1, /*rightRows=*/rightType.columns());
213}
214
215static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
217 const Expression& left,
218 Operator op,
219 const Expression& right) {
220 SkASSERT(is_vec_or_mat(left.type()));
221 SkASSERT(left.type().matches(right.type()));
222 const Type& type = left.type();
223
224 // Handle equality operations: == !=
225 if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226 right)) {
227 return result;
228 }
229
230 // Handle floating-point arithmetic: + - * /
231 using FoldFn = double (*)(double, double);
232 FoldFn foldFn;
233 switch (op.kind()) {
234 case Operator::Kind::PLUS: foldFn = +[](double a, double b) { return a + b; }; break;
235 case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236 case Operator::Kind::STAR: foldFn = +[](double a, double b) { return a * b; }; break;
237 case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
238 default:
239 return nullptr;
240 }
241
242 const Type& componentType = type.componentType();
243 SkASSERT(componentType.isNumber());
244
245 double minimumValue = componentType.minimumValue();
246 double maximumValue = componentType.maximumValue();
247
248 double args[16];
249 int numSlots = type.slotCount();
250 for (int i = 0; i < numSlots; i++) {
251 double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
252 if (value < minimumValue || value > maximumValue) {
253 return nullptr;
254 }
255 args[i] = value;
256 }
258}
259
260static std::unique_ptr<Expression> splat_scalar(const Context& context,
261 const Expression& scalar,
262 const Type& type) {
263 if (type.isVector()) {
264 return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
265 }
266 if (type.isMatrix()) {
267 int numSlots = type.slotCount();
268 ExpressionArray splatMatrix;
269 splatMatrix.reserve_exact(numSlots);
270 for (int index = 0; index < numSlots; ++index) {
271 splatMatrix.push_back(scalar.clone());
272 }
273 return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
274 }
275 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
276 return nullptr;
277}
278
279static std::unique_ptr<Expression> cast_expression(const Context& context,
281 const Expression& expr,
282 const Type& type) {
283 SkASSERT(type.componentType().matches(expr.type().componentType()));
284 if (expr.type().isScalar()) {
285 if (type.isMatrix()) {
286 return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
287 }
288 if (type.isVector()) {
289 return ConstructorSplat::Make(context, pos, type, expr.clone());
290 }
291 }
292 if (type.matches(expr.type())) {
293 return expr.clone(pos);
294 }
295 // We can't cast matrices into vectors or vice-versa.
296 return nullptr;
297}
298
299static std::unique_ptr<Expression> zero_expression(const Context& context,
301 const Type& type) {
302 std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
303 if (type.isScalar()) {
304 return zero;
305 }
306 if (type.isVector()) {
307 return ConstructorSplat::Make(context, pos, type, std::move(zero));
308 }
309 if (type.isMatrix()) {
310 return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
311 }
312 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313 return nullptr;
314}
315
316static std::unique_ptr<Expression> negate_expression(const Context& context,
318 const Expression& expr,
319 const Type& type) {
320 std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
321 return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
322 : nullptr;
323}
324
327 if (!expr->isIntLiteral()) {
328 return false;
329 }
330 *out = expr->as<Literal>().intValue();
331 return true;
332}
333
336 if (!expr->is<Literal>()) {
337 return false;
338 }
339 *out = expr->as<Literal>().value();
340 return true;
341}
342
343static bool contains_constant_zero(const Expression& expr) {
344 int numSlots = expr.type().slotCount();
345 for (int index = 0; index < numSlots; ++index) {
346 std::optional<double> slotVal = expr.getConstantValue(index);
347 if (slotVal.has_value() && *slotVal == 0.0) {
348 return true;
349 }
350 }
351 return false;
352}
353
355 int numSlots = expr.type().slotCount();
356 for (int index = 0; index < numSlots; ++index) {
357 std::optional<double> slotVal = expr.getConstantValue(index);
358 if (!slotVal.has_value() || *slotVal != value) {
359 return false;
360 }
361 }
362 return true;
363}
364
365// Returns true if the expression is a square diagonal matrix containing `value`.
366static bool is_constant_diagonal(const Expression& expr, double value) {
367 SkASSERT(expr.type().isMatrix());
368 int columns = expr.type().columns();
369 int rows = expr.type().rows();
370 if (columns != rows) {
371 return false;
372 }
373 int slotIdx = 0;
374 for (int c = 0; c < columns; ++c) {
375 for (int r = 0; r < rows; ++r) {
376 double expectation = (c == r) ? value : 0;
377 std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
378 if (!slotVal.has_value() || *slotVal != expectation) {
379 return false;
380 }
381 }
382 }
383 return true;
384}
385
386// Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
387static bool is_constant_value(const Expression& expr, double value) {
388 return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
390}
391
392// The expression represents the right-hand side of a division op. If the division can be
393// strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
394// Note that this only supports literal values with safe-to-use reciprocals, and returns null if
395// Expression contains anything else.
396static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
397 const Expression& right) {
398 if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
399 return nullptr;
400 }
401 // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
402 double values[4];
403 int nslots = right.type().slotCount();
404 for (int index = 0; index < nslots; ++index) {
405 std::optional<double> value = right.getConstantValue(index);
406 if (!value) {
407 return nullptr;
408 }
410 if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
411 // The reciprocal can be represented safely as a finite 32-bit float.
412 values[index] = *value;
413 } else {
414 // The value is outside the 32-bit float range, or is NaN; do not optimize.
415 return nullptr;
416 }
417 }
418 // Turn the expression array into a compound constructor. (If this is a single-slot expression,
419 // this will return the literal as-is.)
420 return ConstructorCompound::MakeFromConstants(context, right.fPosition, right.type(), values);
421}
422
423static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
424 const Expression& right) {
425 switch (op.kind()) {
426 case Operator::Kind::SLASH:
427 case Operator::Kind::SLASHEQ:
428 case Operator::Kind::PERCENT:
429 case Operator::Kind::PERCENTEQ:
430 if (contains_constant_zero(right)) {
431 context.fErrors->error(pos, "division by zero");
432 return true;
433 }
434 return false;
435 default:
436 return false;
437 }
438}
439
441 const Expression* expr = &inExpr;
442 while (expr->is<VariableReference>()) {
443 const VariableReference& varRef = expr->as<VariableReference>();
444 if (varRef.refKind() != VariableRefKind::kRead) {
445 return nullptr;
446 }
447 const Variable& var = *varRef.variable();
448 if (!var.modifierFlags().isConst()) {
449 return nullptr;
450 }
451 expr = var.initialValue();
452 if (!expr) {
453 // Generally, const variables must have initial values. However, function parameters are
454 // an exception; they can be const but won't have an initial value.
455 return nullptr;
456 }
457 }
458 return Analysis::IsCompileTimeConstant(*expr) ? expr : nullptr;
459}
460
462 const Expression* expr = GetConstantValueOrNull(inExpr);
463 return expr ? expr : &inExpr;
464}
465
467 Position pos, std::unique_ptr<Expression> inExpr) {
468 const Expression* expr = GetConstantValueOrNull(*inExpr);
469 return expr ? expr->clone(pos) : std::move(inExpr);
470}
471
472static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
473 return left.type().isScalar() && right.type().isMatrix();
474}
475
476static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
477 return is_scalar_op_matrix(right, left);
478}
479
480static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
482 const Expression& left,
483 Operator op,
484 const Expression& right,
485 const Type& resultType) {
486 switch (op.kind()) {
488 if (!is_scalar_op_matrix(left, right) &&
489 ConstantFolder::IsConstantSplat(right, 0.0)) { // x + 0
490 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
491 resultType)) {
492 return expr;
493 }
494 }
495 if (!is_matrix_op_scalar(left, right) &&
496 ConstantFolder::IsConstantSplat(left, 0.0)) { // 0 + x
497 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
498 resultType)) {
499 return expr;
500 }
501 }
502 break;
503
504 case Operator::Kind::STAR:
505 if (is_constant_value(right, 1.0)) { // x * 1
506 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
507 resultType)) {
508 return expr;
509 }
510 }
511 if (is_constant_value(left, 1.0)) { // 1 * x
512 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
513 resultType)) {
514 return expr;
515 }
516 }
517 if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) { // x * 0
518 return zero_expression(context, pos, resultType);
519 }
520 if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) { // 0 * x
521 return zero_expression(context, pos, resultType);
522 }
523 if (is_constant_value(right, -1.0)) { // x * -1 (to `-x`)
524 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
525 resultType)) {
526 return expr;
527 }
528 }
529 if (is_constant_value(left, -1.0)) { // -1 * x (to `-x`)
530 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
531 resultType)) {
532 return expr;
533 }
534 }
535 break;
536
537 case Operator::Kind::MINUS:
538 if (!is_scalar_op_matrix(left, right) &&
539 ConstantFolder::IsConstantSplat(right, 0.0)) { // x - 0
540 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
541 resultType)) {
542 return expr;
543 }
544 }
545 if (!is_matrix_op_scalar(left, right) &&
546 ConstantFolder::IsConstantSplat(left, 0.0)) { // 0 - x
547 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
548 resultType)) {
549 return expr;
550 }
551 }
552 break;
553
554 case Operator::Kind::SLASH:
555 if (!is_scalar_op_matrix(left, right) &&
556 ConstantFolder::IsConstantSplat(right, 1.0)) { // x / 1
557 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
558 resultType)) {
559 return expr;
560 }
561 }
562 if (!left.type().isMatrix()) { // convert `x / 2` into `x * 0.5`
563 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
564 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
565 std::move(expr));
566 }
567 }
568 break;
569
570 case Operator::Kind::PLUSEQ:
571 case Operator::Kind::MINUSEQ:
572 if (ConstantFolder::IsConstantSplat(right, 0.0)) { // x += 0, x -= 0
573 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
574 resultType)) {
576 return var;
577 }
578 }
579 break;
580
581 case Operator::Kind::STAREQ:
582 if (is_constant_value(right, 1.0)) { // x *= 1
583 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
584 resultType)) {
586 return var;
587 }
588 }
589 break;
590
591 case Operator::Kind::SLASHEQ:
592 if (ConstantFolder::IsConstantSplat(right, 1.0)) { // x /= 1
593 if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
594 resultType)) {
596 return var;
597 }
598 }
599 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
600 return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
601 std::move(expr));
602 }
603 break;
604
605 default:
606 break;
607 }
608
609 return nullptr;
610}
611
612// The expression must be scalar, and represents the right-hand side of a division op. It can
613// contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
614// expression might be further simplified by the constant folding, if possible.
615static std::unique_ptr<Expression> one_over_scalar(const Context& context,
616 const Expression& right) {
617 SkASSERT(right.type().isScalar());
618 Position pos = right.fPosition;
619 return BinaryExpression::Make(context, pos,
620 Literal::Make(pos, 1.0, &right.type()),
621 Operator::Kind::SLASH,
622 right.clone());
623}
624
625static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
627 const Expression& left,
628 Operator op,
629 const Expression& right,
630 const Type& resultType) {
631 // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
632 // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
633 switch (op.kind()) {
636 if (left.type().isMatrix() && right.type().isScalar()) {
637 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
639 return BinaryExpression::Make(context, pos,
640 left.clone(),
641 multiplyOp,
642 one_over_scalar(context, right));
643 }
644 break;
645
646 default:
647 break;
648 }
649
650 return nullptr;
651}
652
653static std::unique_ptr<Expression> fold_expression(Position pos,
654 double result,
655 const Type* resultType) {
656 if (resultType->isNumber()) {
657 if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
658 // This result will fit inside its type.
659 } else {
660 // The value is outside the range or is NaN (all if-checks fail); do not optimize.
661 return nullptr;
662 }
663 }
664
665 return Literal::Make(pos, result, resultType);
666}
667
668static std::unique_ptr<Expression> fold_two_constants(const Context& context,
670 const Expression* left,
671 Operator op,
672 const Expression* right,
673 const Type& resultType) {
676 const Type& leftType = left->type();
677 const Type& rightType = right->type();
678
679 // Handle pairs of integer literals.
680 if (left->isIntLiteral() && right->isIntLiteral()) {
681 using SKSL_UINT = uint64_t;
682 SKSL_INT leftVal = left->as<Literal>().intValue();
683 SKSL_INT rightVal = right->as<Literal>().intValue();
684
685 // Note that fold_expression returns null if the result would overflow its type.
686 #define RESULT(Op) fold_expression(pos, (SKSL_INT)(leftVal) Op \
687 (SKSL_INT)(rightVal), &resultType)
688 #define URESULT(Op) fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
689 (SKSL_UINT)(rightVal)), &resultType)
690 switch (op.kind()) {
691 case Operator::Kind::PLUS: return URESULT(+);
692 case Operator::Kind::MINUS: return URESULT(-);
693 case Operator::Kind::STAR: return URESULT(*);
694 case Operator::Kind::SLASH:
695 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
696 context.fErrors->error(pos, "arithmetic overflow");
697 return nullptr;
698 }
699 return RESULT(/);
700
701 case Operator::Kind::PERCENT:
702 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
703 context.fErrors->error(pos, "arithmetic overflow");
704 return nullptr;
705 }
706 return RESULT(%);
707
708 case Operator::Kind::BITWISEAND: return RESULT(&);
709 case Operator::Kind::BITWISEOR: return RESULT(|);
710 case Operator::Kind::BITWISEXOR: return RESULT(^);
711 case Operator::Kind::EQEQ: return RESULT(==);
712 case Operator::Kind::NEQ: return RESULT(!=);
713 case Operator::Kind::GT: return RESULT(>);
714 case Operator::Kind::GTEQ: return RESULT(>=);
715 case Operator::Kind::LT: return RESULT(<);
716 case Operator::Kind::LTEQ: return RESULT(<=);
717 case Operator::Kind::SHL:
718 if (rightVal >= 0 && rightVal <= 31) {
719 // Left-shifting a negative (or really, any signed) value is undefined behavior
720 // in C++, but not in GLSL. Do the shift on unsigned values to avoid triggering
721 // an UBSAN error.
722 return URESULT(<<);
723 }
724 context.fErrors->error(pos, "shift value out of range");
725 return nullptr;
726
727 case Operator::Kind::SHR:
728 if (rightVal >= 0 && rightVal <= 31) {
729 return RESULT(>>);
730 }
731 context.fErrors->error(pos, "shift value out of range");
732 return nullptr;
733
734 default:
735 break;
736 }
737 #undef RESULT
738 #undef URESULT
739
740 return nullptr;
741 }
742
743 // Handle pairs of floating-point literals.
744 if (left->isFloatLiteral() && right->isFloatLiteral()) {
745 SKSL_FLOAT leftVal = left->as<Literal>().floatValue();
746 SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
747
748 #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
749 switch (op.kind()) {
750 case Operator::Kind::PLUS: return RESULT(+);
751 case Operator::Kind::MINUS: return RESULT(-);
752 case Operator::Kind::STAR: return RESULT(*);
753 case Operator::Kind::SLASH: return RESULT(/);
754 case Operator::Kind::EQEQ: return RESULT(==);
755 case Operator::Kind::NEQ: return RESULT(!=);
756 case Operator::Kind::GT: return RESULT(>);
757 case Operator::Kind::GTEQ: return RESULT(>=);
758 case Operator::Kind::LT: return RESULT(<);
759 case Operator::Kind::LTEQ: return RESULT(<=);
760 default: break;
761 }
762 #undef RESULT
763
764 return nullptr;
765 }
766
767 // Perform matrix multiplication.
768 if (op.kind() == Operator::Kind::STAR) {
769 if (leftType.isMatrix() && rightType.isMatrix()) {
770 return simplify_matrix_times_matrix(context, pos, *left, *right);
771 }
772 if (leftType.isVector() && rightType.isMatrix()) {
773 return simplify_vector_times_matrix(context, pos, *left, *right);
774 }
775 if (leftType.isMatrix() && rightType.isVector()) {
776 return simplify_matrix_times_vector(context, pos, *left, *right);
777 }
778 }
779
780 // Perform constant folding on pairs of vectors/matrices.
781 if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
782 return simplify_componentwise(context, pos, *left, op, *right);
783 }
784
785 // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
786 if (rightType.isScalar() && is_vec_or_mat(leftType) &&
787 leftType.componentType().matches(rightType)) {
788 return simplify_componentwise(context, pos,
789 *left, op, *splat_scalar(context, *right, left->type()));
790 }
791
792 // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
793 if (leftType.isScalar() && is_vec_or_mat(rightType) &&
794 rightType.componentType().matches(leftType)) {
795 return simplify_componentwise(context, pos,
796 *splat_scalar(context, *left, right->type()), op, *right);
797 }
798
799 // Perform constant folding on pairs of matrices, arrays or structs.
800 if ((leftType.isMatrix() && rightType.isMatrix()) ||
801 (leftType.isArray() && rightType.isArray()) ||
802 (leftType.isStruct() && rightType.isStruct())) {
803 return simplify_constant_equality(context, pos, *left, op, *right);
804 }
805
806 // We aren't able to constant-fold these expressions.
807 return nullptr;
808}
809
810std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
812 const Expression& leftExpr,
813 Operator op,
814 const Expression& rightExpr,
815 const Type& resultType) {
816 // Replace constant variables with their literal values.
817 const Expression* left = GetConstantValueForVariable(leftExpr);
818 const Expression* right = GetConstantValueForVariable(rightExpr);
819
820 // If this is the assignment operator, and both sides are the same trivial expression, this is
821 // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
822 // This can happen when other parts of the assignment are optimized away.
823 if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
824 return right->clone(pos);
825 }
826
827 // Simplify the expression when both sides are constant Boolean literals.
828 if (left->isBoolLiteral() && right->isBoolLiteral()) {
829 bool leftVal = left->as<Literal>().boolValue();
830 bool rightVal = right->as<Literal>().boolValue();
831 bool result;
832 switch (op.kind()) {
833 case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
834 case Operator::Kind::LOGICALOR: result = leftVal || rightVal; break;
835 case Operator::Kind::LOGICALXOR: result = leftVal ^ rightVal; break;
836 case Operator::Kind::EQEQ: result = leftVal == rightVal; break;
837 case Operator::Kind::NEQ: result = leftVal != rightVal; break;
838 default: return nullptr;
839 }
840 return Literal::MakeBool(context, pos, result);
841 }
842
843 // If the left side is a Boolean literal, apply short-circuit optimizations.
844 if (left->isBoolLiteral()) {
845 return short_circuit_boolean(pos, *left, op, *right);
846 }
847
848 // If the right side is a Boolean literal...
849 if (right->isBoolLiteral()) {
850 // ... and the left side has no side effects...
851 if (!Analysis::HasSideEffects(*left)) {
852 // We can reverse the expressions and short-circuit optimizations are still valid.
853 return short_circuit_boolean(pos, *right, op, *left);
854 }
855
856 // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
857 return eliminate_no_op_boolean(pos, *left, op, *right);
858 }
859
860 if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
861 // With == comparison, if both sides are the same trivial expression, this is self-
862 // comparison and is always true. (We are not concerned with NaN.)
863 return Literal::MakeBool(context, pos, /*value=*/true);
864 }
865
866 if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
867 // With != comparison, if both sides are the same trivial expression, this is self-
868 // comparison and is always false. (We are not concerned with NaN.)
869 return Literal::MakeBool(context, pos, /*value=*/false);
870 }
871
872 if (error_on_divide_by_zero(context, pos, op, *right)) {
873 return nullptr;
874 }
875
876 // Perform full constant folding when both sides are compile-time constants.
877 bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
878 bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
879 if (leftSideIsConstant && rightSideIsConstant) {
880 return fold_two_constants(context, pos, left, op, right, resultType);
881 }
882
883 if (context.fConfig->fSettings.fOptimize) {
884 // If just one side is constant, we might still be able to simplify arithmetic expressions
885 // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
886 if (leftSideIsConstant || rightSideIsConstant) {
887 if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
888 *right, resultType)) {
889 return expr;
890 }
891 }
892
893 // We can simplify some forms of matrix division even when neither side is constant.
894 if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
895 *right, resultType)) {
896 return expr;
897 }
898 }
899
900 // We aren't able to constant-fold.
901 return nullptr;
902}
903
904} // namespace SkSL
SkPoint pos
#define SkDEBUGFAILF(fmt,...)
Definition: SkAssert.h:119
#define SkASSERT(cond)
Definition: SkAssert.h:116
static constexpr double sk_ieee_double_divide(double numer, double denom)
void swap(sk_sp< T > &a, sk_sp< T > &b)
Definition: SkRefCnt.h:341
#define RESULT(Op)
#define URESULT(Op)
int64_t SKSL_INT
Definition: SkSLDefines.h:16
float SKSL_FLOAT
Definition: SkSLDefines.h:17
GLenum type
static std::unique_ptr< Expression > Make(const Context &context, Position pos, std::unique_ptr< Expression > left, Operator op, std::unique_ptr< Expression > right)
static bool GetConstantValue(const Expression &value, double *out)
static bool GetConstantInt(const Expression &value, SKSL_INT *out)
static std::unique_ptr< Expression > Simplify(const Context &context, Position pos, const Expression &left, Operator op, const Expression &right, const Type &resultType)
static bool IsConstantSplat(const Expression &expr, double value)
static const Expression * GetConstantValueOrNull(const Expression &value)
static const Expression * GetConstantValueForVariable(const Expression &value)
static std::unique_ptr< Expression > MakeConstantValueForVariable(Position pos, std::unique_ptr< Expression > expr)
static std::unique_ptr< Expression > MakeFromConstants(const Context &context, Position pos, const Type &type, const double values[])
static std::unique_ptr< Expression > Make(const Context &context, Position pos, const Type &type, ExpressionArray args)
static std::unique_ptr< Expression > Make(const Context &context, Position pos, const Type &type, std::unique_ptr< Expression > arg)
static std::unique_ptr< Expression > Make(const Context &context, Position pos, const Type &type, std::unique_ptr< Expression > arg)
ErrorReporter * fErrors
Definition: SkSLContext.h:36
ProgramConfig * fConfig
Definition: SkSLContext.h:33
void error(Position position, std::string_view msg)
bool isIntLiteral() const
virtual std::unique_ptr< Expression > clone(Position pos) const =0
const Type & type() const
virtual std::optional< double > getConstantValue(int n) const
bool is() const
Definition: SkSLIRNode.h:124
const T & as() const
Definition: SkSLIRNode.h:133
Position fPosition
Definition: SkSLIRNode.h:109
static std::unique_ptr< Literal > MakeBool(const Context &context, Position pos, bool value)
Definition: SkSLLiteral.h:69
static std::unique_ptr< Literal > Make(Position pos, double value, const Type *type)
Definition: SkSLLiteral.h:81
Kind kind() const
Definition: SkSLOperator.h:85
bool isAssignment() const
static std::unique_ptr< Expression > Make(const Context &context, Position pos, Operator op, std::unique_ptr< Expression > base)
virtual bool isArray() const
Definition: SkSLType.h:532
virtual bool isVector() const
Definition: SkSLType.h:524
virtual int rows() const
Definition: SkSLType.h:438
virtual const Type & componentType() const
Definition: SkSLType.h:404
bool isNumber() const
Definition: SkSLType.h:304
bool matches(const Type &other) const
Definition: SkSLType.h:269
virtual bool isMatrix() const
Definition: SkSLType.h:528
virtual int columns() const
Definition: SkSLType.h:429
virtual size_t slotCount() const
Definition: SkSLType.h:457
virtual bool isScalar() const
Definition: SkSLType.h:512
virtual double maximumValue() const
Definition: SkSLType.h:449
const Type & toCompound(const Context &context, int columns, int rows) const
virtual bool isStruct() const
Definition: SkSLType.h:540
virtual double minimumValue() const
Definition: SkSLType.h:444
const Variable * variable() const
const Expression * initialValue() const
ModifierFlags modifierFlags() const
Definition: SkSLVariable.h:89
void reserve_exact(int n)
Definition: SkTArray.h:181
static bool b
struct MyStruct a[10]
G_BEGIN_DECLS G_MODULE_EXPORT FlValue * args
uint8_t value
GAsyncResult * result
static float min(float r, float g, float b)
Definition: hsl.cpp:48
bool IsCompileTimeConstant(const Expression &expr)
bool IsSameExpressionTree(const Expression &left, const Expression &right)
bool HasSideEffects(const Expression &expr)
bool UpdateVariableRefKind(Expression *expr, VariableRefKind kind, ErrorReporter *errors=nullptr)
static std::unique_ptr< Expression > negate_expression(const Context &context, Position pos, const Expression &expr, const Type &type)
static std::unique_ptr< Expression > simplify_matrix_division(const Context &context, Position pos, const Expression &left, Operator op, const Expression &right, const Type &resultType)
static std::unique_ptr< Expression > simplify_matrix_multiplication(const Context &context, Position pos, const Expression &left, const Expression &right, int leftColumns, int leftRows, int rightColumns, int rightRows)
static bool error_on_divide_by_zero(const Context &context, Position pos, Operator op, const Expression &right)
static std::unique_ptr< Expression > eliminate_no_op_boolean(Position pos, const Expression &left, Operator op, const Expression &right)
static bool contains_constant_zero(const Expression &expr)
static std::unique_ptr< Expression > simplify_matrix_times_matrix(const Context &context, Position pos, const Expression &left, const Expression &right)
static bool is_vec_or_mat(const Type &type)
static bool is_matrix_op_scalar(const Expression &left, const Expression &right)
static std::unique_ptr< Expression > cast_expression(const Context &context, Position pos, const Expression &expr, const Type &type)
static bool is_constant_diagonal(const Expression &expr, double value)
static bool is_constant_value(const Expression &expr, double value)
static std::unique_ptr< Expression > simplify_matrix_times_vector(const Context &context, Position pos, const Expression &left, const Expression &right)
static std::unique_ptr< Expression > make_reciprocal_expression(const Context &context, const Expression &right)
static std::unique_ptr< Expression > short_circuit_boolean(Position pos, const Expression &left, Operator op, const Expression &right)
static bool is_scalar_op_matrix(const Expression &left, const Expression &right)
static std::unique_ptr< Expression > one_over_scalar(const Context &context, const Expression &right)
static std::unique_ptr< Expression > simplify_constant_equality(const Context &context, Position pos, const Expression &left, Operator op, const Expression &right)
static std::unique_ptr< Expression > fold_two_constants(const Context &context, Position pos, const Expression *left, Operator op, const Expression *right, const Type &resultType)
static std::unique_ptr< Expression > simplify_vector_times_matrix(const Context &context, Position pos, const Expression &left, const Expression &right)
static std::unique_ptr< Expression > simplify_componentwise(const Context &context, Position pos, const Expression &left, Operator op, const Expression &right)
static std::unique_ptr< Expression > splat_scalar(const Context &context, const Expression &scalar, const Type &type)
static std::unique_ptr< Expression > zero_expression(const Context &context, Position pos, const Type &type)
static std::unique_ptr< Expression > simplify_arithmetic(const Context &context, Position pos, const Expression &left, Operator op, const Expression &right, const Type &resultType)
static std::unique_ptr< Expression > fold_expression(Position pos, double result, const Type *resultType)
ProgramSettings fSettings