Flutter Engine
The Flutter Engine
Loading...
Searching...
No Matches
SkSLGetLoopUnrollInfo.cpp
Go to the documentation of this file.
1/*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
27
28#include <cmath>
29#include <memory>
30
31namespace SkSL {
32
class Context;

// Upper bound on the number of iterations we are willing to unroll. A loop that runs 100,000
// times or more would generate a program larger than our size limit, so it is rejected.
static constexpr int kLoopTerminationLimit = 100'000;
37
38static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) {
39 if ((forwards && start > end) || (!forwards && start < end)) {
40 // The loop starts in a completed state (the start has already advanced past the end).
41 return 0;
42 }
43 if ((delta == 0.0) || forwards != (delta > 0.0)) {
44 // The loop does not progress toward a completed state, and will never terminate.
46 }
47 double iterations = sk_ieee_double_divide(end - start, delta);
48 double count = std::ceil(iterations);
49 if (inclusive && (count == iterations)) {
50 count += 1.0;
51 }
52 if (count > kLoopTerminationLimit || !std::isfinite(count)) {
53 // The loop runs for more iterations than we can safely unroll.
55 }
56 return (int)count;
57}
58
59std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(const Context& context,
60 Position loopPos,
61 const ForLoopPositions& positions,
62 const Statement* loopInitializer,
63 std::unique_ptr<Expression>* loopTest,
64 const Expression* loopNext,
65 const Statement* loopStatement,
66 ErrorReporter* errorPtr) {
68 ErrorReporter& errors = errorPtr ? *errorPtr : unused;
69
70 auto loopInfo = std::make_unique<LoopUnrollInfo>();
71
72 //
73 // init_declaration has the form: type_specifier identifier = constant_expression
74 //
75 if (!loopInitializer) {
76 Position pos = positions.initPosition.valid() ? positions.initPosition : loopPos;
77 errors.error(pos, "missing init declaration");
78 return nullptr;
79 }
80 if (!loopInitializer->is<VarDeclaration>()) {
81 errors.error(loopInitializer->fPosition, "invalid init declaration");
82 return nullptr;
83 }
84 const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>();
85 if (!initDecl.baseType().isNumber()) {
86 errors.error(loopInitializer->fPosition, "invalid type for loop index");
87 return nullptr;
88 }
89 if (initDecl.arraySize() != 0) {
90 errors.error(loopInitializer->fPosition, "invalid type for loop index");
91 return nullptr;
92 }
93 if (!initDecl.value()) {
94 errors.error(loopInitializer->fPosition, "missing loop index initializer");
95 return nullptr;
96 }
97 if (!ConstantFolder::GetConstantValue(*initDecl.value(), &loopInfo->fStart)) {
98 errors.error(loopInitializer->fPosition,
99 "loop index initializer must be a constant expression");
100 return nullptr;
101 }
102
103 loopInfo->fIndex = initDecl.var();
104
105 auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) {
106 return expr->is<VariableReference>() &&
107 expr->as<VariableReference>().variable() == loopInfo->fIndex;
108 };
109
110 //
111 // condition has the form: loop_index relational_operator constant_expression
112 //
113 if (!loopTest || !*loopTest) {
114 Position pos = positions.conditionPosition.valid() ? positions.conditionPosition : loopPos;
115 errors.error(pos, "missing condition");
116 return nullptr;
117 }
118 if (!loopTest->get()->is<BinaryExpression>()) {
119 errors.error(loopTest->get()->fPosition, "invalid condition");
120 return nullptr;
121 }
122 const BinaryExpression* cond = &loopTest->get()->as<BinaryExpression>();
123 if (!is_loop_index(cond->left())) {
124 errors.error(cond->fPosition, "expected loop index on left hand side of condition");
125 return nullptr;
126 }
127 // relational_operator is one of: > >= < <= == or !=
128 switch (cond->getOperator().kind()) {
129 case Operator::Kind::GT:
130 case Operator::Kind::GTEQ:
131 case Operator::Kind::LT:
132 case Operator::Kind::LTEQ:
133 case Operator::Kind::EQEQ:
134 case Operator::Kind::NEQ:
135 break;
136 default:
137 errors.error(cond->fPosition, "invalid relational operator");
138 return nullptr;
139 }
140 double loopEnd = 0;
141 if (!ConstantFolder::GetConstantValue(*cond->right(), &loopEnd)) {
142 errors.error(cond->fPosition, "loop index must be compared with a constant expression");
143 return nullptr;
144 }
145
146 //
147 // expression has one of the following forms:
148 // loop_index++
149 // loop_index--
150 // loop_index += constant_expression
151 // loop_index -= constant_expression
152 // The spec doesn't mention prefix increment and decrement, but there is some consensus that
153 // it's an oversight, so we allow those as well.
154 //
155 if (!loopNext) {
156 Position pos = positions.nextPosition.valid() ? positions.nextPosition : loopPos;
157 errors.error(pos, "missing loop expression");
158 return nullptr;
159 }
160 switch (loopNext->kind()) {
161 case Expression::Kind::kBinary: {
162 const BinaryExpression& next = loopNext->as<BinaryExpression>();
163 if (!is_loop_index(next.left())) {
164 errors.error(loopNext->fPosition, "expected loop index in loop expression");
165 return nullptr;
166 }
167 if (!ConstantFolder::GetConstantValue(*next.right(), &loopInfo->fDelta)) {
168 errors.error(loopNext->fPosition,
169 "loop index must be modified by a constant expression");
170 return nullptr;
171 }
172 switch (next.getOperator().kind()) {
173 case Operator::Kind::PLUSEQ: break;
174 case Operator::Kind::MINUSEQ: loopInfo->fDelta = -loopInfo->fDelta; break;
175 default:
176 errors.error(loopNext->fPosition, "invalid operator in loop expression");
177 return nullptr;
178 }
179 break;
180 }
181 case Expression::Kind::kPrefix: {
182 const PrefixExpression& next = loopNext->as<PrefixExpression>();
183 if (!is_loop_index(next.operand())) {
184 errors.error(loopNext->fPosition, "expected loop index in loop expression");
185 return nullptr;
186 }
187 switch (next.getOperator().kind()) {
188 case Operator::Kind::PLUSPLUS: loopInfo->fDelta = 1; break;
189 case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
190 default:
191 errors.error(loopNext->fPosition, "invalid operator in loop expression");
192 return nullptr;
193 }
194 break;
195 }
196 case Expression::Kind::kPostfix: {
197 const PostfixExpression& next = loopNext->as<PostfixExpression>();
198 if (!is_loop_index(next.operand())) {
199 errors.error(loopNext->fPosition, "expected loop index in loop expression");
200 return nullptr;
201 }
202 switch (next.getOperator().kind()) {
203 case Operator::Kind::PLUSPLUS: loopInfo->fDelta = 1; break;
204 case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
205 default:
206 errors.error(loopNext->fPosition, "invalid operator in loop expression");
207 return nullptr;
208 }
209 break;
210 }
211 default:
212 errors.error(loopNext->fPosition, "invalid loop expression");
213 return nullptr;
214 }
215
216 //
217 // Within the body of the loop, the loop index is not statically assigned to, nor is it used as
218 // argument to a function 'out' or 'inout' parameter.
219 //
220 if (Analysis::StatementWritesToVariable(*loopStatement, *initDecl.var())) {
221 errors.error(loopStatement->fPosition,
222 "loop index must not be modified within body of the loop");
223 return nullptr;
224 }
225
226 // Finally, compute the iteration count, based on the bounds, and the termination operator.
227 loopInfo->fCount = 0;
228
229 switch (cond->getOperator().kind()) {
230 case Operator::Kind::LT:
231 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
232 /*forwards=*/true, /*inclusive=*/false);
233 break;
234
235 case Operator::Kind::GT:
236 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
237 /*forwards=*/false, /*inclusive=*/false);
238 break;
239
240 case Operator::Kind::LTEQ:
241 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
242 /*forwards=*/true, /*inclusive=*/true);
243 break;
244
245 case Operator::Kind::GTEQ:
246 loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
247 /*forwards=*/false, /*inclusive=*/true);
248 break;
249
250 case Operator::Kind::NEQ: {
251 float iterations = sk_ieee_double_divide(loopEnd - loopInfo->fStart, loopInfo->fDelta);
252 loopInfo->fCount = std::ceil(iterations);
253 if (loopInfo->fCount < 0 || loopInfo->fCount != iterations ||
254 !std::isfinite(iterations)) {
255 // The loop doesn't reach the exact endpoint and so will never terminate.
256 loopInfo->fCount = kLoopTerminationLimit;
257 }
258 if (loopInfo->fIndex->type().componentType().isFloat()) {
259 // Rewrite `x != n` tests as `x < n` or `x > n` depending on the loop direction.
260 // Less-than and greater-than tests avoid infinite loops caused by rounding error.
261 Operator::Kind op = (loopInfo->fDelta > 0) ? Operator::Kind::LT
262 : Operator::Kind::GT;
263 *loopTest = BinaryExpression::Make(context,
264 cond->fPosition,
265 cond->left()->clone(),
266 op,
267 cond->right()->clone());
268 cond = &loopTest->get()->as<BinaryExpression>();
269 }
270 break;
271 }
272 case Operator::Kind::EQEQ: {
273 if (loopInfo->fStart == loopEnd) {
274 // Start and end begin in the same place, so we can run one iteration...
275 if (loopInfo->fDelta) {
276 // ... and then they diverge, so the loop terminates.
277 loopInfo->fCount = 1;
278 } else {
279 // ... but they never diverge, so the loop runs forever.
280 loopInfo->fCount = kLoopTerminationLimit;
281 }
282 } else {
283 // Start never equals end, so the loop will not run a single iteration.
284 loopInfo->fCount = 0;
285 }
286 break;
287 }
288 default: SkUNREACHABLE;
289 }
290
291 SkASSERT(loopInfo->fCount >= 0);
292 if (loopInfo->fCount >= kLoopTerminationLimit) {
293 errors.error(loopPos, "loop must guarantee termination in fewer iterations");
294 return nullptr;
295 }
296
297 return loopInfo;
298}
299
300} // namespace SkSL
static bool unused
int count
SkPoint pos
static float next(float f)
#define SkUNREACHABLE
Definition SkAssert.h:135
#define SkASSERT(cond)
Definition SkAssert.h:116
static constexpr double sk_ieee_double_divide(double numer, double denom)
std::unique_ptr< Expression > & left()
std::unique_ptr< Expression > & right()
static std::unique_ptr< Expression > Make(const Context &context, Position pos, std::unique_ptr< Expression > left, Operator op, std::unique_ptr< Expression > right)
static bool GetConstantValue(const Expression &value, double *out)
Kind kind() const
bool is() const
Definition SkSLIRNode.h:124
const T & as() const
Definition SkSLIRNode.h:133
Position fPosition
Definition SkSLIRNode.h:109
Kind kind() const
bool valid() const
bool isNumber() const
Definition SkSLType.h:304
const Type & baseType() const
std::unique_ptr< Expression > & value()
Variable * var() const
const Variable * variable() const
glong glong end
std::unique_ptr< LoopUnrollInfo > GetLoopUnrollInfo(const Context &context, Position pos, const ForLoopPositions &positions, const Statement *loopInitializer, std::unique_ptr< Expression > *loopTestPtr, const Expression *loopNext, const Statement *loopStatement, ErrorReporter *errors)
bool StatementWritesToVariable(const Statement &stmt, const Variable &var)
static constexpr int kLoopTerminationLimit
static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive)