MLIR 22.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous loop transformation routines.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/DebugLog.h"
29#include <cstdint>
30
31using namespace mlir;
32
33#define DEBUG_TYPE "scf-utils"
34
36 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
37 ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
38 bool replaceIterOperandsUsesInLoop) {
39 if (loopNest.empty())
40 return {};
41 // This method is recursive (to make it more readable). Adding an
42 // assertion here to limit the recursion. (See
43 // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
44 assert(loopNest.size() <= 10 &&
45 "exceeded recursion limit when yielding value from loop nest");
46
47 // To yield a value from a perfectly nested loop nest, the following
48 // pattern needs to be created, i.e. starting with
49 //
50 // ```mlir
51 // scf.for .. {
52 // scf.for .. {
53 // scf.for .. {
54 // %value = ...
55 // }
56 // }
57 // }
58 // ```
59 //
60 // needs to be modified to
61 //
62 // ```mlir
63 // %0 = scf.for .. iter_args(%arg0 = %init) {
64 // %1 = scf.for .. iter_args(%arg1 = %arg0) {
65 // %2 = scf.for .. iter_args(%arg2 = %arg1) {
66 // %value = ...
67 // scf.yield %value
68 // }
69 // scf.yield %2
70 // }
71 // scf.yield %1
72 // }
73 // ```
74 //
75 // The inner most loop is handled using the `replaceWithAdditionalYields`
76 // that works on a single loop.
77 if (loopNest.size() == 1) {
78 auto innerMostLoop =
79 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
80 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
81 newYieldValuesFn));
82 return {innerMostLoop};
83 }
84 // The outer loops are modified by calling this method recursively
85 // - The return value of the inner loop is the value yielded by this loop.
86 // - The region iter args of this loop are the init_args for the inner loop.
87 SmallVector<scf::ForOp> newLoopNest;
89 [&](OpBuilder &innerBuilder, Location loc,
91 newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
92 innerNewBBArgs, newYieldValuesFn,
93 replaceIterOperandsUsesInLoop);
94 return llvm::to_vector(llvm::map_range(
95 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
96 [](OpResult r) -> Value { return r; }));
97 };
98 scf::ForOp outerMostLoop =
99 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
100 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
101 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
102 return newLoopNest;
103}
104
105/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types are the types of the yielded operands of
/// the single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
110/// provided, it will be set to point to the operation that calls the outlined
111/// function.
112// TODO: support more than single-block regions.
113// TODO: more flexible constant handling.
114FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
115 Location loc,
116 Region &region,
117 StringRef funcName,
118 func::CallOp *callOp) {
119 assert(!funcName.empty() && "funcName cannot be empty");
120 if (!region.hasOneBlock())
121 return failure();
122
123 Block *originalBlock = &region.front();
124 Operation *originalTerminator = originalBlock->getTerminator();
125
126 // Outline before current function.
127 OpBuilder::InsertionGuard g(rewriter);
128 rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
129
130 SetVector<Value> captures;
131 getUsedValuesDefinedAbove(region, captures);
132
133 ValueRange outlinedValues(captures.getArrayRef());
134 SmallVector<Type> outlinedFuncArgTypes;
135 SmallVector<Location> outlinedFuncArgLocs;
136 // Region's arguments are exactly the first block's arguments as per
137 // Region::getArguments().
138 // Func's arguments are cat(regions's arguments, captures arguments).
139 for (BlockArgument arg : region.getArguments()) {
140 outlinedFuncArgTypes.push_back(arg.getType());
141 outlinedFuncArgLocs.push_back(arg.getLoc());
142 }
143 for (Value value : outlinedValues) {
144 outlinedFuncArgTypes.push_back(value.getType());
145 outlinedFuncArgLocs.push_back(value.getLoc());
146 }
147 FunctionType outlinedFuncType =
148 FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
149 originalTerminator->getOperandTypes());
150 auto outlinedFunc =
151 func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
152 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
153
154 // Merge blocks while replacing the original block operands.
155 // Warning: `mergeBlocks` erases the original block, reconstruct it later.
156 int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
157 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
158 {
159 OpBuilder::InsertionGuard g(rewriter);
160 rewriter.setInsertionPointToEnd(outlinedFuncBody);
161 rewriter.mergeBlocks(
162 originalBlock, outlinedFuncBody,
163 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
164 // Explicitly set up a new ReturnOp terminator.
165 rewriter.setInsertionPointToEnd(outlinedFuncBody);
166 func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
167 originalTerminator->getOperands());
168 }
169
170 // Reconstruct the block that was deleted and add a
171 // terminator(call_results).
172 Block *newBlock = rewriter.createBlock(
173 &region, region.begin(),
174 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
175 ArrayRef<Location>(outlinedFuncArgLocs)
176 .take_front(numOriginalBlockArguments));
177 {
178 OpBuilder::InsertionGuard g(rewriter);
179 rewriter.setInsertionPointToEnd(newBlock);
180 SmallVector<Value> callValues;
181 llvm::append_range(callValues, newBlock->getArguments());
182 llvm::append_range(callValues, outlinedValues);
183 auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
184 if (callOp)
185 *callOp = call;
186
187 // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
188 // Clone `originalTerminator` to take the callOp results then erase it from
189 // `outlinedFuncBody`.
190 IRMapping bvm;
191 bvm.map(originalTerminator->getOperands(), call->getResults());
192 rewriter.clone(*originalTerminator, bvm);
193 rewriter.eraseOp(originalTerminator);
194 }
195
196 // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
197 // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
198 for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
199 outlinedValues.size()))) {
200 Value orig = std::get<0>(it);
201 Value repl = std::get<1>(it);
202 {
203 OpBuilder::InsertionGuard g(rewriter);
204 rewriter.setInsertionPointToStart(outlinedFuncBody);
206 repl = rewriter.clone(*cst)->getResult(0);
207 }
208 }
209 orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
210 return outlinedFunc->isProperAncestor(opOperand.getOwner());
211 });
212 }
213
214 return outlinedFunc;
215}
216
217LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
218 func::FuncOp *thenFn, StringRef thenFnName,
219 func::FuncOp *elseFn, StringRef elseFnName) {
220 IRRewriter rewriter(b);
221 Location loc = ifOp.getLoc();
222 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
223 if (thenFn && !ifOp.getThenRegion().empty()) {
224 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
225 rewriter, loc, ifOp.getThenRegion(), thenFnName);
226 if (failed(outlinedFuncOpOrFailure))
227 return failure();
228 *thenFn = *outlinedFuncOpOrFailure;
229 }
230 if (elseFn && !ifOp.getElseRegion().empty()) {
231 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
232 rewriter, loc, ifOp.getElseRegion(), elseFnName);
233 if (failed(outlinedFuncOpOrFailure))
234 return failure();
235 *elseFn = *outlinedFuncOpOrFailure;
236 }
237 return success();
238}
239
242 assert(rootOp != nullptr && "Root operation must not be a nullptr.");
243 bool rootEnclosesPloops = false;
244 for (Region &region : rootOp->getRegions()) {
245 for (Block &block : region.getBlocks()) {
246 for (Operation &op : block) {
247 bool enclosesPloops = getInnermostParallelLoops(&op, result);
248 rootEnclosesPloops |= enclosesPloops;
249 if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
250 rootEnclosesPloops = true;
251
252 // Collect parallel loop if it is an innermost one.
253 if (!enclosesPloops)
254 result.push_back(ploop);
255 }
256 }
257 }
258 }
259 return rootEnclosesPloops;
260}
261
262// Build the IR that performs ceil division of a positive value by a constant:
263// ceildiv(a, B) = divis(a + (B-1), B)
264// where divis is rounding-to-zero division.
265static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
266 int64_t divisor) {
267 assert(divisor > 0 && "expected positive divisor");
268 assert(dividend.getType().isIntOrIndex() &&
269 "expected integer or index-typed value");
270
271 Value divisorMinusOneCst = arith::ConstantOp::create(
272 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
273 Value divisorCst = arith::ConstantOp::create(
274 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
275 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
276 return arith::DivUIOp::create(builder, loc, sum, divisorCst);
277}
278
279// Build the IR that performs ceil division of a positive value by another
280// positive value:
281// ceildiv(a, b) = divis(a + (b - 1), b)
282// where divis is rounding-to-zero division.
283static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
284 Value divisor) {
285 assert(dividend.getType().isIntOrIndex() &&
286 "expected integer or index-typed value");
287 Value cstOne = arith::ConstantOp::create(
288 builder, loc, builder.getOneAttr(dividend.getType()));
289 Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
290 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
291 return arith::DivUIOp::create(builder, loc, sum, divisor);
292}
293
295 Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
296 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
297 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
298 ValueRange iterArgs, ValueRange yieldedValues,
299 IRMapping *clonedToSrcOpsMap) {
300
301 // Check if the op was cloned from another source op, and return it if found
302 // (or the same op if not found)
303 auto findOriginalSrcOp =
304 [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
305 Operation *srcOp = op;
306 // If the source op derives from another op: traverse the chain to find the
307 // original source op
308 while (srcOp && clonedToSrcOpsMap.contains(srcOp))
309 srcOp = clonedToSrcOpsMap.lookup(srcOp);
310 return srcOp;
311 };
312
313 // Builder to insert unrolled bodies just before the terminator of the body of
314 // the loop.
315 auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
316
317 static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
318 if (!annotateFn)
319 annotateFn = noopAnnotateFn;
320
321 // Keep a pointer to the last non-terminator operation in the original block
322 // so that we know what to clone (since we are doing this in-place).
323 Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
324
325 // Unroll the contents of the loop body (append unrollFactor - 1 additional
326 // copies).
327 SmallVector<Value, 4> lastYielded(yieldedValues);
328
329 for (unsigned i = 1; i < unrollFactor; i++) {
330 // Prepare operand map.
331 IRMapping operandMap;
332 operandMap.map(iterArgs, lastYielded);
333
334 // If the induction variable is used, create a remapping to the value for
335 // this unrolled instance.
336 if (!iv.use_empty()) {
337 Value ivUnroll = ivRemapFn(i, iv, builder);
338 operandMap.map(iv, ivUnroll);
339 }
340
341 // Clone the original body of 'forOp'.
342 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
343 Operation *srcOp = &(*it);
344 Operation *clonedOp = builder.clone(*srcOp, operandMap);
345 annotateFn(i, clonedOp, builder);
346 if (clonedToSrcOpsMap)
347 clonedToSrcOpsMap->map(clonedOp,
348 findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
349 }
350
351 // Update yielded values.
352 for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
353 lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
354 }
355
356 // Make sure we annotate the Ops in the original body. We do this last so that
357 // any annotations are not copied into the cloned Ops above.
358 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
359 annotateFn(0, &*it, builder);
360
361 // Update operands of the yield statement.
362 loopBodyBlock->getTerminator()->setOperands(lastYielded);
363}
364
365/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
366/// epilogue loop, if the loop is unrolled.
367FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
368 scf::ForOp forOp, uint64_t unrollFactor,
369 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
370 assert(unrollFactor > 0 && "expected positive unroll factor");
371
372 // Return if the loop body is empty.
373 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
374 return UnrolledLoopInfo{forOp, std::nullopt};
375
376 // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
377 // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
378 OpBuilder boundsBuilder(forOp);
379 IRRewriter rewriter(forOp.getContext());
380 auto loc = forOp.getLoc();
381 Value step = forOp.getStep();
382 Value upperBoundUnrolled;
383 Value stepUnrolled;
384 bool generateEpilogueLoop = true;
385
386 std::optional<APInt> constTripCount = forOp.getStaticTripCount();
387 if (constTripCount) {
388 // Constant loop bounds computation.
389 int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
390 int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
391 int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
392 if (unrollFactor == 1) {
393 if (*constTripCount == 1 &&
394 failed(forOp.promoteIfSingleIteration(rewriter)))
395 return failure();
396 return UnrolledLoopInfo{forOp, std::nullopt};
397 }
398
399 int64_t tripCountEvenMultiple =
400 constTripCount->getSExtValue() -
401 (constTripCount->getSExtValue() % unrollFactor);
402 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
403 int64_t stepUnrolledCst = stepCst * unrollFactor;
404
405 // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
406 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
407 if (generateEpilogueLoop)
408 upperBoundUnrolled = arith::ConstantOp::create(
409 boundsBuilder, loc,
410 boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
411 upperBoundUnrolledCst));
412 else
413 upperBoundUnrolled = forOp.getUpperBound();
414
415 // Create constant for 'stepUnrolled'.
416 stepUnrolled =
417 stepCst == stepUnrolledCst
418 ? step
419 : arith::ConstantOp::create(boundsBuilder, loc,
420 boundsBuilder.getIntegerAttr(
421 step.getType(), stepUnrolledCst));
422 } else {
423 // Dynamic loop bounds computation.
424 // TODO: Add dynamic asserts for negative lb/ub/step, or
425 // consider using ceilDiv from AffineApplyExpander.
426 auto lowerBound = forOp.getLowerBound();
427 auto upperBound = forOp.getUpperBound();
428 Value diff =
429 arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
430 Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
431 Value unrollFactorCst = arith::ConstantOp::create(
432 boundsBuilder, loc,
433 boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
434 Value tripCountRem =
435 arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
436 // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
437 Value tripCountEvenMultiple =
438 arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
439 // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
440 upperBoundUnrolled = arith::AddIOp::create(
441 boundsBuilder, loc, lowerBound,
442 arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
443 // Scale 'step' by 'unrollFactor'.
444 stepUnrolled =
445 arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
446 }
447
448 UnrolledLoopInfo resultLoops;
449
450 // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
451 if (generateEpilogueLoop) {
452 OpBuilder epilogueBuilder(forOp->getContext());
453 epilogueBuilder.setInsertionPointAfter(forOp);
454 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
455 epilogueForOp.setLowerBound(upperBoundUnrolled);
456
457 // Update uses of loop results.
458 auto results = forOp.getResults();
459 auto epilogueResults = epilogueForOp.getResults();
460
461 for (auto e : llvm::zip(results, epilogueResults)) {
462 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
463 }
464 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
465 epilogueForOp.getInitArgs().size(), results);
466 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
467 resultLoops.epilogueLoopOp = epilogueForOp;
468 }
469
470 // Create unrolled loop.
471 forOp.setUpperBound(upperBoundUnrolled);
472 forOp.setStep(stepUnrolled);
473
474 auto iterArgs = ValueRange(forOp.getRegionIterArgs());
475 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
476
478 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
479 [&](unsigned i, Value iv, OpBuilder b) {
480 // iv' = iv + step * i;
481 auto stride = arith::MulIOp::create(
482 b, loc, step,
483 arith::ConstantOp::create(b, loc,
484 b.getIntegerAttr(iv.getType(), i)));
485 return arith::AddIOp::create(b, loc, iv, stride);
486 },
487 annotateFn, iterArgs, yieldedValues);
488 // Promote the loop body up if this has turned into a single iteration loop.
489 if (forOp.promoteIfSingleIteration(rewriter).failed())
490 resultLoops.mainLoopOp = forOp;
491 return resultLoops;
492}
493
494/// Unrolls this loop completely.
495LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
496 IRRewriter rewriter(forOp.getContext());
497 std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
498 if (!mayBeConstantTripCount.has_value())
499 return failure();
500 const APInt &tripCount = *mayBeConstantTripCount;
501 if (tripCount.isZero())
502 return success();
503 if (tripCount.getSExtValue() == 1)
504 return forOp.promoteIfSingleIteration(rewriter);
505 return loopUnrollByFactor(forOp, tripCount.getSExtValue());
506}
507
508/// Check if bounds of all inner loops are defined outside of `forOp`
509/// and return false if not.
510static bool areInnerBoundsInvariant(scf::ForOp forOp) {
511 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
512 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
513 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
514 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
515 return WalkResult::interrupt();
516
517 return WalkResult::advance();
518 });
519 return !walkResult.wasInterrupted();
520}
521
522/// Unrolls and jams this loop by the specified factor.
523LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
524 uint64_t unrollJamFactor) {
525 assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
526
527 if (unrollJamFactor == 1)
528 return success();
529
530 // If any control operand of any inner loop of `forOp` is defined within
531 // `forOp`, no unroll jam.
532 if (!areInnerBoundsInvariant(forOp)) {
533 LDBG() << "failed to unroll and jam: inner bounds are not invariant";
534 return failure();
535 }
536
537 // Currently, for operations with results are not supported.
538 if (forOp->getNumResults() > 0) {
539 LDBG() << "failed to unroll and jam: unsupported loop with results";
540 return failure();
541 }
542
543 // Currently, only constant trip count that divided by the unroll factor is
544 // supported.
545 std::optional<APInt> tripCount = forOp.getStaticTripCount();
546 if (!tripCount.has_value()) {
547 // If the trip count is dynamic, do not unroll & jam.
548 LDBG() << "failed to unroll and jam: trip count could not be determined";
549 return failure();
550 }
551 if (unrollJamFactor > tripCount->getZExtValue()) {
552 LDBG() << "unroll and jam factor is greater than trip count, set factor to "
553 "trip "
554 "count";
555 unrollJamFactor = tripCount->getZExtValue();
556 } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
557 LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
558 "multiple of unroll jam factor";
559 return failure();
560 }
561
562 // Nothing in the loop body other than the terminator.
563 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
564 return success();
565
566 // Gather all sub-blocks to jam upon the loop being unrolled.
568 jbg.walk(forOp);
569 auto &subBlocks = jbg.subBlocks;
570
571 // Collect inner loops.
572 SmallVector<scf::ForOp> innerLoops;
573 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
574
575 // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
576 // iteration. There are (`unrollJamFactor` - 1) iterations.
577 SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
578
579 // For any loop with iter_args, replace it with a new loop that has
580 // `unrollJamFactor` copies of its iterOperands, iter_args and yield
581 // operands.
582 SmallVector<scf::ForOp> newInnerLoops;
583 IRRewriter rewriter(forOp.getContext());
584 for (scf::ForOp oldForOp : innerLoops) {
585 SmallVector<Value> dupIterOperands, dupYieldOperands;
586 ValueRange oldIterOperands = oldForOp.getInits();
587 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
588 ValueRange oldYieldOperands =
589 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
590 // Get additional iterOperands, iterArgs, and yield operands. We will
591 // fix iterOperands and yield operands after cloning of sub-blocks.
592 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
593 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
594 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
595 }
596 // Create a new loop with additional iterOperands, iter_args and yield
597 // operands. This new loop will take the loop body of the original loop.
598 bool forOpReplaced = oldForOp == forOp;
599 scf::ForOp newForOp =
600 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
601 rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
602 [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
603 return dupYieldOperands;
604 }));
605 newInnerLoops.push_back(newForOp);
606 // `forOp` has been replaced with a new loop.
607 if (forOpReplaced)
608 forOp = newForOp;
609 // Update `operandMaps` for `newForOp` iterArgs and results.
610 ValueRange newIterArgs = newForOp.getRegionIterArgs();
611 unsigned oldNumIterArgs = oldIterArgs.size();
612 ValueRange newResults = newForOp.getResults();
613 unsigned oldNumResults = newResults.size() / unrollJamFactor;
614 assert(oldNumIterArgs == oldNumResults &&
615 "oldNumIterArgs must be the same as oldNumResults");
616 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
617 for (unsigned j = 0; j < oldNumIterArgs; ++j) {
618 // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
619 // results. Update `operandMaps[i - 1]` to map old iterArgs and results
620 // to those in the `i`th new set.
621 operandMaps[i - 1].map(newIterArgs[j],
622 newIterArgs[i * oldNumIterArgs + j]);
623 operandMaps[i - 1].map(newResults[j],
624 newResults[i * oldNumResults + j]);
625 }
626 }
627 }
628
629 // Scale the step of loop being unroll-jammed by the unroll-jam factor.
630 rewriter.setInsertionPoint(forOp);
631 int64_t step = forOp.getConstantStep()->getSExtValue();
632 auto newStep = rewriter.createOrFold<arith::MulIOp>(
633 forOp.getLoc(), forOp.getStep(),
634 rewriter.createOrFold<arith::ConstantOp>(
635 forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
636 forOp.setStep(newStep);
637 auto forOpIV = forOp.getInductionVar();
638
639 // Unroll and jam (appends unrollJamFactor - 1 additional copies).
640 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
641 for (auto &subBlock : subBlocks) {
642 // Builder to insert unroll-jammed bodies. Insert right at the end of
643 // sub-block.
644 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
645
646 // If the induction variable is used, create a remapping to the value for
647 // this unrolled instance.
648 if (!forOpIV.use_empty()) {
649 // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
650 auto ivTag = builder.createOrFold<arith::ConstantOp>(
651 forOp.getLoc(), builder.getIndexAttr(step * i));
652 auto ivUnroll =
653 builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
654 operandMaps[i - 1].map(forOpIV, ivUnroll);
655 }
656 // Clone the sub-block being unroll-jammed.
657 for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
658 builder.clone(*it, operandMaps[i - 1]);
659 }
660 // Fix iterOperands and yield op operands of newly created loops.
661 for (auto newForOp : newInnerLoops) {
662 unsigned oldNumIterOperands =
663 newForOp.getNumRegionIterArgs() / unrollJamFactor;
664 unsigned numControlOperands = newForOp.getNumControlOperands();
665 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
666 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
667 assert(oldNumIterOperands == oldNumYieldOperands &&
668 "oldNumIterOperands must be the same as oldNumYieldOperands");
669 for (unsigned j = 0; j < oldNumIterOperands; ++j) {
670 // The `i`th duplication of an old iterOperand or yield op operand
671 // needs to be replaced with a mapped value from `operandMaps[i - 1]`
672 // if such mapped value exists.
673 newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
674 operandMaps[i - 1].lookupOrDefault(
675 newForOp.getOperand(numControlOperands + j)));
676 yieldOp.setOperand(
677 i * oldNumYieldOperands + j,
678 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
679 }
680 }
681 }
682
683 // Promote the loop body up if this has turned into a single iteration loop.
684 (void)forOp.promoteIfSingleIteration(rewriter);
685 return success();
686}
687
689 Location loc, OpFoldResult lb,
691 OpFoldResult step) {
692 Range normalizedLoopBounds;
693 normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
694 normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
695 AffineExpr s0, s1, s2;
696 bindSymbols(rewriter.getContext(), s0, s1, s2);
697 AffineExpr e = (s1 - s0).ceilDiv(s2);
698 normalizedLoopBounds.size =
699 affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
700 return normalizedLoopBounds;
701}
702
705 OpFoldResult step) {
706 if (getType(lb).isIndex()) {
707 return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
708 }
709 // For non-index types, generate `arith` instructions
710 // Check if the loop is already known to have a constant zero lower bound or
711 // a constant one step.
712 bool isZeroBased = false;
713 if (auto lbCst = getConstantIntValue(lb))
714 isZeroBased = lbCst.value() == 0;
715
716 bool isStepOne = false;
717 if (auto stepCst = getConstantIntValue(step))
718 isStepOne = stepCst.value() == 1;
719
720 Type rangeType = getType(lb);
721 assert(rangeType == getType(ub) && rangeType == getType(step) &&
722 "expected matching types");
723
724 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
725 // assuming the step is strictly positive. Update the bounds and the step
726 // of the loop to go from 0 to the number of iterations, if necessary.
727 if (isZeroBased && isStepOne)
728 return {lb, ub, step};
729
730 OpFoldResult diff = ub;
731 if (!isZeroBased) {
732 diff = rewriter.createOrFold<arith::SubIOp>(
733 loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
734 getValueOrCreateConstantIntOp(rewriter, loc, lb));
735 }
736 OpFoldResult newUpperBound = diff;
737 if (!isStepOne) {
738 newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
739 loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
740 getValueOrCreateConstantIntOp(rewriter, loc, step));
741 }
742
743 OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
744 OpFoldResult newStep = rewriter.getOneAttr(rangeType);
745
746 return {newLowerBound, newUpperBound, newStep};
747}
748
750 Location loc,
751 Value normalizedIv,
752 OpFoldResult origLb,
753 OpFoldResult origStep) {
754 AffineExpr d0, s0, s1;
755 bindSymbols(rewriter.getContext(), s0, s1);
756 bindDims(rewriter.getContext(), d0);
757 AffineExpr e = d0 * s1 + s0;
759 rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
760 Value denormalizedIvVal =
761 getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
762 SmallPtrSet<Operation *, 1> preservedUses;
763 // If an `affine.apply` operation is generated for denormalization, the use
764 // of `origLb` in those ops must not be replaced. These arent not generated
765 // when `origLb == 0` and `origStep == 1`.
766 if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
767 if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
768 preservedUses.insert(preservedUse);
769 }
770 }
771 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
772}
773
775 Value normalizedIv, OpFoldResult origLb,
776 OpFoldResult origStep) {
777 if (getType(origLb).isIndex()) {
778 return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
779 origLb, origStep);
780 }
781 Value denormalizedIv;
783 bool isStepOne = isOneInteger(origStep);
784 bool isZeroBased = isZeroInteger(origLb);
785
786 Value scaled = normalizedIv;
787 if (!isStepOne) {
788 Value origStepValue =
789 getValueOrCreateConstantIntOp(rewriter, loc, origStep);
790 scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
791 preserve.insert(scaled.getDefiningOp());
792 }
793 denormalizedIv = scaled;
794 if (!isZeroBased) {
795 Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
796 denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
797 preserve.insert(denormalizedIv.getDefiningOp());
798 }
799
800 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
801}
802
804 ArrayRef<OpFoldResult> values) {
805 assert(!values.empty() && "unexecpted empty array");
806 AffineExpr s0, s1;
807 bindSymbols(rewriter.getContext(), s0, s1);
808 AffineExpr mul = s0 * s1;
809 OpFoldResult products = rewriter.getIndexAttr(1);
810 for (auto v : values) {
812 rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
813 }
814 return products;
815}
816
817/// Helper function to multiply a sequence of values.
819 ArrayRef<Value> values) {
820 assert(!values.empty() && "unexpected empty list");
821 if (getType(values.front()).isIndex()) {
823 OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
824 return getValueOrCreateConstantIndexOp(rewriter, loc, product);
825 }
826 std::optional<Value> productOf;
827 for (auto v : values) {
828 auto vOne = getConstantIntValue(v);
829 if (vOne && vOne.value() == 1)
830 continue;
831 if (productOf)
832 productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
833 .getResult();
834 else
835 productOf = v;
836 }
837 if (!productOf) {
838 productOf = arith::ConstantOp::create(
839 rewriter, loc, rewriter.getOneAttr(getType(values.front())))
840 .getResult();
841 }
842 return productOf.value();
843}
844
845/// For each original loop, the value of the
846/// induction variable can be obtained by dividing the induction variable of
847/// the linearized loop by the total number of iterations of the loops nested
848/// in it modulo the number of iterations in this loop (remove the values
849/// related to the outer loops):
850/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
851/// Compute these iteratively from the innermost loop by creating a "running
852/// quotient" of division by the range.
853static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
855 Value linearizedIv, ArrayRef<Value> ubs) {
856
857 if (linearizedIv.getType().isIndex()) {
858 Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
859 rewriter, loc, linearizedIv, ubs);
860 auto resultVals = llvm::map_to_vector(
861 delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
862 return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
863 }
864
865 SmallVector<Value> delinearizedIvs(ubs.size());
866 SmallPtrSet<Operation *, 2> preservedUsers;
867
868 llvm::BitVector isUbOne(ubs.size());
869 for (auto [index, ub] : llvm::enumerate(ubs)) {
870 auto ubCst = getConstantIntValue(ub);
871 if (ubCst && ubCst.value() == 1)
872 isUbOne.set(index);
873 }
874
875 // Prune the lead ubs that are all ones.
876 unsigned numLeadingOneUbs = 0;
877 for (auto [index, ub] : llvm::enumerate(ubs)) {
878 if (!isUbOne.test(index)) {
879 break;
880 }
881 delinearizedIvs[index] = arith::ConstantOp::create(
882 rewriter, loc, rewriter.getZeroAttr(ub.getType()));
883 numLeadingOneUbs++;
884 }
885
886 Value previous = linearizedIv;
887 for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
888 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
889 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
890 previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
891 preservedUsers.insert(previous.getDefiningOp());
892 }
893 Value iv = previous;
894 if (i != e - 1) {
895 if (!isUbOne.test(idx)) {
896 iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
897 preservedUsers.insert(iv.getDefiningOp());
898 } else {
899 iv = arith::ConstantOp::create(
900 rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
901 }
902 }
903 delinearizedIvs[idx] = iv;
904 }
905 return {delinearizedIvs, preservedUsers};
906}
907
908LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
910 if (loops.size() < 2)
911 return failure();
912
913 scf::ForOp innermost = loops.back();
914 scf::ForOp outermost = loops.front();
915
916 // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
917 // allows the following code to assume upperBound is the number of iterations.
918 for (auto loop : loops) {
919 OpBuilder::InsertionGuard g(rewriter);
920 rewriter.setInsertionPoint(outermost);
921 Value lb = loop.getLowerBound();
922 Value ub = loop.getUpperBound();
923 Value step = loop.getStep();
924 auto newLoopRange =
925 emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
926
927 rewriter.modifyOpInPlace(loop, [&]() {
928 loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
929 newLoopRange.offset));
930 loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
931 newLoopRange.size));
932 loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
933 newLoopRange.stride));
934 });
935 rewriter.setInsertionPointToStart(innermost.getBody());
936 denormalizeInductionVariable(rewriter, loop.getLoc(),
937 loop.getInductionVar(), lb, step);
938 }
939
940 // 2. Emit code computing the upper bound of the coalesced loop as product
941 // of the number of iterations of all loops.
942 OpBuilder::InsertionGuard g(rewriter);
943 rewriter.setInsertionPoint(outermost);
944 Location loc = outermost.getLoc();
945 SmallVector<Value> upperBounds = llvm::map_to_vector(
946 loops, [](auto loop) { return loop.getUpperBound(); });
947 Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
948 outermost.setUpperBound(upperBound);
949
950 rewriter.setInsertionPointToStart(innermost.getBody());
951 auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
952 rewriter, loc, outermost.getInductionVar(), upperBounds);
953 rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
954 preservedUsers);
955
956 for (int i = loops.size() - 1; i > 0; --i) {
957 auto outerLoop = loops[i - 1];
958 auto innerLoop = loops[i];
959
960 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
961 auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
962 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
963 for (Value &yieldedVal : yieldedVals) {
964 // The yielded value may be an iteration argument of the inner loop
965 // which is about to be inlined.
966 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
967 if (iter != innerLoop.getRegionIterArgs().end()) {
968 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
969 // `outerLoop` iter args identical to the `innerLoop` init args.
970 assert(iterArgIndex < innerLoop.getInitArgs().size());
971 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
972 }
973 }
974 rewriter.eraseOp(innerTerminator);
975
976 SmallVector<Value> innerBlockArgs;
977 innerBlockArgs.push_back(delinearizeIvs[i]);
978 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
979 rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
980 Block::iterator(innerLoop), innerBlockArgs);
981 rewriter.replaceOp(innerLoop, yieldedVals);
982 }
983 return success();
984}
985
987 if (loops.empty()) {
988 return failure();
989 }
990 IRRewriter rewriter(loops.front().getContext());
991 return coalesceLoops(rewriter, loops);
992}
993
994LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
995 LogicalResult result(failure());
997 getPerfectlyNestedLoops(loops, op);
998
999 // Look for a band of loops that can be coalesced, i.e. perfectly nested
1000 // loops with bounds defined above some loop.
1001
1002 // 1. For each loop, find above which parent loop its bounds operands are
1003 // defined.
1004 SmallVector<unsigned> operandsDefinedAbove(loops.size());
1005 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1006 operandsDefinedAbove[i] = i;
1007 for (unsigned j = 0; j < i; ++j) {
1008 SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1009 loops[i].getUpperBound(),
1010 loops[i].getStep()};
1011 if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1012 operandsDefinedAbove[i] = j;
1013 break;
1014 }
1015 }
1016 }
1017
1018 // 2. For each inner loop check that the iter_args for the immediately outer
1019 // loop are the init for the immediately inner loop and that the yields of the
1020 // return of the inner loop is the yield for the immediately outer loop. Keep
1021 // track of where the chain starts from for each loop.
1022 SmallVector<unsigned> iterArgChainStart(loops.size());
1023 iterArgChainStart[0] = 0;
1024 for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1025 // By default set the start of the chain to itself.
1026 iterArgChainStart[i] = i;
1027 auto outerloop = loops[i - 1];
1028 auto innerLoop = loops[i];
1029 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1030 continue;
1031 }
1032 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1033 continue;
1034 }
1035 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1036 if (!llvm::equal(outerloopTerminator->getOperands(),
1037 innerLoop.getResults())) {
1038 continue;
1039 }
1040 iterArgChainStart[i] = iterArgChainStart[i - 1];
1041 }
1042
1043 // 3. Identify bands of loops such that the operands of all of them are
1044 // defined above the first loop in the band. Traverse the nest bottom-up
1045 // so that modifications don't invalidate the inner loops.
1046 for (unsigned end = loops.size(); end > 0; --end) {
1047 unsigned start = 0;
1048 for (; start < end - 1; ++start) {
1049 auto maxPos =
1050 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1051 std::next(operandsDefinedAbove.begin(), end));
1052 if (maxPos > start)
1053 continue;
1054 if (iterArgChainStart[end - 1] > start)
1055 continue;
1056 auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1057 if (succeeded(coalesceLoops(band)))
1058 result = success();
1059 break;
1060 }
1061 // If a band was found and transformed, keep looking at the loops above
1062 // the outermost transformed loop.
1063 if (start != end - 1)
1064 end = start + 1;
1065 }
1066 return result;
1067}
1068
1070 RewriterBase &rewriter, scf::ParallelOp loops,
1071 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1072 OpBuilder::InsertionGuard g(rewriter);
1073 rewriter.setInsertionPoint(loops);
1074 Location loc = loops.getLoc();
1075
1076 // Presort combined dimensions.
1077 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1078 for (auto &dims : sortedDimensions)
1079 llvm::sort(dims);
1080
1081 // Normalize ParallelOp's iteration pattern.
1082 SmallVector<Value, 3> normalizedUpperBounds;
1083 for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1084 OpBuilder::InsertionGuard g2(rewriter);
1085 rewriter.setInsertionPoint(loops);
1086 Value lb = loops.getLowerBound()[i];
1087 Value ub = loops.getUpperBound()[i];
1088 Value step = loops.getStep()[i];
1089 auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1090 normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1091 rewriter, loops.getLoc(), newLoopRange.size));
1092
1093 rewriter.setInsertionPointToStart(loops.getBody());
1094 denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1095 step);
1096 }
1097
1098 // Combine iteration spaces.
1099 SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1100 auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1101 auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1102 for (auto &sortedDimension : sortedDimensions) {
1103 Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1104 for (auto idx : sortedDimension) {
1105 newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1106 normalizedUpperBounds[idx]);
1107 }
1108 lowerBounds.push_back(cst0);
1109 steps.push_back(cst1);
1110 upperBounds.push_back(newUpperBound);
1111 }
1112
1113 // Create new ParallelLoop with conversions to the original induction values.
1114 // The loop below uses divisions to get the relevant range of values in the
1115 // new induction value that represent each range of the original induction
1116 // value. The remainders then determine based on that range, which iteration
1117 // of the original induction value this represents. This is a normalized value
1118 // that is un-normalized already by the previous logic.
1119 auto newPloop = scf::ParallelOp::create(
1120 rewriter, loc, lowerBounds, upperBounds, steps,
1121 [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1122 for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1123 Value previous = ploopIVs[i];
1124 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1125 // Iterate over all except the last induction value.
1126 for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1127 unsigned idx = combinedDimensions[i][j];
1128
1129 // Determine the current induction value's current loop iteration
1130 Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1131 normalizedUpperBounds[idx]);
1132 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1133 loops.getRegion());
1134
1135 // Remove the effect of the current induction value to prepare for
1136 // the next value.
1137 previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1138 normalizedUpperBounds[idx]);
1139 }
1140
1141 // The final induction value is just the remaining value.
1142 unsigned idx = combinedDimensions[i][0];
1143 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1144 previous, loops.getRegion());
1145 }
1146 });
1147
1148 // Replace the old loop with the new loop.
1149 loops.getBody()->back().erase();
1150 newPloop.getBody()->getOperations().splice(
1151 Block::iterator(newPloop.getBody()->back()),
1152 loops.getBody()->getOperations());
1153 loops.erase();
1154}
1155
1156// Hoist the ops within `outer` that appear before `inner`.
1157// Such ops include the ops that have been introduced by parametric tiling.
1158// Ops that come from triangular loops (i.e. that belong to the program slice
1159// rooted at `outer`) and ops that have side effects cannot be hoisted.
1160// Return failure when any op fails to hoist.
1161static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1162 SetVector<Operation *> forwardSlice;
1164 options.filter = [&inner](Operation *op) {
1165 return op != inner.getOperation();
1166 };
1167 getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1168 LogicalResult status = success();
1170 for (auto &op : outer.getBody()->without_terminator()) {
1171 // Stop when encountering the inner loop.
1172 if (&op == inner.getOperation())
1173 break;
1174 // Skip over non-hoistable ops.
1175 if (forwardSlice.count(&op) > 0) {
1176 status = failure();
1177 continue;
1178 }
1179 // Skip intermediate scf::ForOp, these are not considered a failure.
1180 if (isa<scf::ForOp>(op))
1181 continue;
1182 // Skip other ops with regions.
1183 if (op.getNumRegions() > 0) {
1184 status = failure();
1185 continue;
1186 }
1187 // Skip if op has side effects.
1188 // TODO: loads to immutable memory regions are ok.
1189 if (!isMemoryEffectFree(&op)) {
1190 status = failure();
1191 continue;
1192 }
1193 toHoist.push_back(&op);
1194 }
1195 auto *outerForOp = outer.getOperation();
1196 for (auto *op : toHoist)
1197 op->moveBefore(outerForOp);
1198 return status;
1199}
1200
1201// Traverse the interTile and intraTile loops and try to hoist ops such that
1202// bands of perfectly nested loops are isolated.
1203// Return failure if either perfect interTile or perfect intraTile bands cannot
1204// be formed.
1205static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1206 LogicalResult status = success();
1207 const Loops &interTile = tileLoops.first;
1208 const Loops &intraTile = tileLoops.second;
1209 auto size = interTile.size();
1210 assert(size == intraTile.size());
1211 if (size <= 1)
1212 return success();
1213 for (unsigned s = 1; s < size; ++s)
1214 status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1215 : failure();
1216 for (unsigned s = 1; s < size; ++s)
1217 status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1218 : failure();
1219 return status;
1220}
1221
1222/// Collect perfectly nested loops starting from `rootForOp`. Loops are
1223/// perfectly nested if each loop is the first and only non-terminator operation
1224/// in the parent loop. Collect at most `maxLoops` loops and append them to
1225/// `forOps`.
1226template <typename T>
1228 SmallVectorImpl<T> &forOps, T rootForOp,
1229 unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1230 for (unsigned i = 0; i < maxLoops; ++i) {
1231 forOps.push_back(rootForOp);
1232 Block &body = rootForOp.getRegion().front();
1233 if (body.begin() != std::prev(body.end(), 2))
1234 return;
1235
1236 rootForOp = dyn_cast<T>(&body.front());
1237 if (!rootForOp)
1238 return;
1239 }
1240}
1241
1242static Loops stripmineSink(scf::ForOp forOp, Value factor,
1243 ArrayRef<scf::ForOp> targets) {
1244 assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
1245 auto originalStep = forOp.getStep();
1246 auto iv = forOp.getInductionVar();
1247
1248 OpBuilder b(forOp);
1249 forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1250
1251 Loops innerLoops;
1252 for (auto t : targets) {
1253 assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
1254
1255 // Save information for splicing ops out of t when done
1256 auto begin = t.getBody()->begin();
1257 auto nOps = t.getBody()->getOperations().size();
1258
1259 // Insert newForOp before the terminator of `t`.
1260 auto b = OpBuilder::atBlockTerminator((t.getBody()));
1261 Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1262 Value ub =
1263 arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1264
1265 // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
1266 auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1267 newForOp.getBody()->getOperations().splice(
1268 newForOp.getBody()->getOperations().begin(),
1269 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1270 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1271 newForOp.getRegion());
1272
1273 innerLoops.push_back(newForOp);
1274 }
1275
1276 return innerLoops;
1277}
1278
1279// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1280// Returns the new for operation, nested immediately under `target`.
1281template <typename SizeType>
1282static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1283 scf::ForOp target) {
1284 // TODO: Use cheap structural assertions that targets are nested under
1285 // forOp and that targets are not nested under each other when DominanceInfo
1286 // exposes the capability. It seems overkill to construct a whole function
1287 // dominance tree at this point.
1288 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1289 assert(res.size() == 1 && "Expected 1 inner forOp");
1290 return res[0];
1291}
1292
1294 ArrayRef<Value> sizes,
1295 ArrayRef<scf::ForOp> targets) {
1297 SmallVector<scf::ForOp, 8> currentTargets(targets);
1298 for (auto it : llvm::zip(forOps, sizes)) {
1299 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1300 res.push_back(step);
1301 currentTargets = step;
1302 }
1303 return res;
1304}
1305
1307 scf::ForOp target) {
1309 for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1310 res.push_back(llvm::getSingleElement(loops));
1311 return res;
1312}
1313
1315 // Collect perfectly nested loops. If more size values provided than nested
1316 // loops available, truncate `sizes`.
1318 forOps.reserve(sizes.size());
1319 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1320 if (forOps.size() < sizes.size())
1321 sizes = sizes.take_front(forOps.size());
1322
1323 return ::tile(forOps, sizes, forOps.back());
1324}
1325
1327 scf::ForOp root) {
1328 getPerfectlyNestedLoopsImpl(nestedLoops, root);
1329}
1330
1332 ArrayRef<int64_t> sizes) {
1333 // Collect perfectly nested loops. If more size values provided than nested
1334 // loops available, truncate `sizes`.
1336 forOps.reserve(sizes.size());
1337 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1338 if (forOps.size() < sizes.size())
1339 sizes = sizes.take_front(forOps.size());
1340
1341 // Compute the tile sizes such that i-th outer loop executes size[i]
1342 // iterations. Given that the loop currently executes
1343 // numIterations = ceildiv((upperBound - lowerBound), step)
1344 // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1345 SmallVector<Value, 4> tileSizes;
1346 tileSizes.reserve(sizes.size());
1347 for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1348 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1349
1350 auto forOp = forOps[i];
1351 OpBuilder builder(forOp);
1352 auto loc = forOp.getLoc();
1353 Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1354 forOp.getLowerBound());
1355 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1356 Value iterationsPerBlock =
1357 ceilDivPositive(builder, loc, numIterations, sizes[i]);
1358 tileSizes.push_back(iterationsPerBlock);
1359 }
1360
1361 // Call parametric tiling with the given sizes.
1362 auto intraTile = tile(forOps, tileSizes, forOps.back());
1363 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1364
1365 // TODO: for now we just ignore the result of band isolation.
1366 // In the future, mapping decisions may be impacted by the ability to
1367 // isolate perfectly nested bands.
1368 (void)tryIsolateBands(tileLoops);
1369
1370 return tileLoops;
1371}
1372
1374 scf::ForallOp source,
1375 RewriterBase &rewriter) {
1376 unsigned numTargetOuts = target.getNumResults();
1377 unsigned numSourceOuts = source.getNumResults();
1378
1379 // Create fused shared_outs.
1380 SmallVector<Value> fusedOuts;
1381 llvm::append_range(fusedOuts, target.getOutputs());
1382 llvm::append_range(fusedOuts, source.getOutputs());
1383
1384 // Create a new scf.forall op after the source loop.
1385 rewriter.setInsertionPointAfter(source);
1386 scf::ForallOp fusedLoop = scf::ForallOp::create(
1387 rewriter, source.getLoc(), source.getMixedLowerBound(),
1388 source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1389 source.getMapping());
1390
1391 // Map control operands.
1392 IRMapping mapping;
1393 mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1394 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1395
1396 // Map shared outs.
1397 mapping.map(target.getRegionIterArgs(),
1398 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1399 mapping.map(source.getRegionIterArgs(),
1400 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1401
1402 // Append everything except the terminator into the fused operation.
1403 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1404 for (Operation &op : target.getBody()->without_terminator())
1405 rewriter.clone(op, mapping);
1406 for (Operation &op : source.getBody()->without_terminator())
1407 rewriter.clone(op, mapping);
1408
1409 // Fuse the old terminator in_parallel ops into the new one.
1410 scf::InParallelOp targetTerm = target.getTerminator();
1411 scf::InParallelOp sourceTerm = source.getTerminator();
1412 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1413 rewriter.setInsertionPointToStart(fusedTerm.getBody());
1414 for (Operation &op : targetTerm.getYieldingOps())
1415 rewriter.clone(op, mapping);
1416 for (Operation &op : sourceTerm.getYieldingOps())
1417 rewriter.clone(op, mapping);
1418
1419 // Replace old loops by substituting their uses by results of the fused loop.
1420 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1421 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1422
1423 return fusedLoop;
1424}
1425
1427 scf::ForOp source,
1428 RewriterBase &rewriter) {
1429 assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1430 "incompatible signedness");
1431 unsigned numTargetOuts = target.getNumResults();
1432 unsigned numSourceOuts = source.getNumResults();
1433
1434 // Create fused init_args, with target's init_args before source's init_args.
1435 SmallVector<Value> fusedInitArgs;
1436 llvm::append_range(fusedInitArgs, target.getInitArgs());
1437 llvm::append_range(fusedInitArgs, source.getInitArgs());
1438
1439 // Create a new scf.for op after the source loop (with scf.yield terminator
1440 // (without arguments) only in case its init_args is empty).
1441 rewriter.setInsertionPointAfter(source);
1442 scf::ForOp fusedLoop = scf::ForOp::create(
1443 rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1444 source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
1445 source.getUnsignedCmp());
1446
1447 // Map original induction variables and operands to those of the fused loop.
1448 IRMapping mapping;
1449 mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1450 mapping.map(target.getRegionIterArgs(),
1451 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1452 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1453 mapping.map(source.getRegionIterArgs(),
1454 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1455
1456 // Merge target's body into the new (fused) for loop and then source's body.
1457 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1458 for (Operation &op : target.getBody()->without_terminator())
1459 rewriter.clone(op, mapping);
1460 for (Operation &op : source.getBody()->without_terminator())
1461 rewriter.clone(op, mapping);
1462
1463 // Build fused yield results by appropriately mapping original yield operands.
1464 SmallVector<Value> yieldResults;
1465 for (Value operand : target.getBody()->getTerminator()->getOperands())
1466 yieldResults.push_back(mapping.lookupOrDefault(operand));
1467 for (Value operand : source.getBody()->getTerminator()->getOperands())
1468 yieldResults.push_back(mapping.lookupOrDefault(operand));
1469 if (!yieldResults.empty())
1470 scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1471
1472 // Replace old loops by substituting their uses by results of the fused loop.
1473 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1474 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1475
1476 return fusedLoop;
1477}
1478
1479FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1480 scf::ForallOp forallOp) {
1481 SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1482 SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1483 SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1484
1485 if (forallOp.isNormalized())
1486 return forallOp;
1487
1488 OpBuilder::InsertionGuard g(rewriter);
1489 auto loc = forallOp.getLoc();
1490 rewriter.setInsertionPoint(forallOp);
1492 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1493 Range normalizedLoopParams =
1494 emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1495 newUbs.push_back(normalizedLoopParams.size);
1496 }
1497 (void)foldDynamicIndexList(newUbs);
1498
1499 // Use the normalized builder since the lower bounds are always 0 and the
1500 // steps are always 1.
1501 auto normalizedForallOp = scf::ForallOp::create(
1502 rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1503 [](OpBuilder &, Location, ValueRange) {});
1504
1505 rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1506 normalizedForallOp.getBodyRegion(),
1507 normalizedForallOp.getBodyRegion().begin());
1508 // Remove the original empty block in the new loop.
1509 rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1510
1511 rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1512 // Update the users of the original loop variables.
1513 for (auto [idx, iv] :
1514 llvm::enumerate(normalizedForallOp.getInductionVars())) {
1515 auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1516 auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1517 denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1518 }
1519
1520 rewriter.replaceOp(forallOp, normalizedForallOp);
1521 return normalizedForallOp;
1522}
1523
1526 assert(!loops.empty() && "unexpected empty loop nest");
1527 if (loops.size() == 1)
1528 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1529 for (auto [outerLoop, innerLoop] :
1530 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1531 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1532 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1533 if (!outerFor || !innerFor)
1534 return false;
1535 auto outerBBArgs = outerFor.getRegionIterArgs();
1536 auto innerIterArgs = innerFor.getInitArgs();
1537 if (outerBBArgs.size() != innerIterArgs.size())
1538 return false;
1539
1540 for (auto [outerBBArg, innerIterArg] :
1541 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1542 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1543 innerIterArg != outerBBArg)
1544 return false;
1545 }
1546
1547 ValueRange outerYields =
1548 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1549 ValueRange innerResults = innerFor.getResults();
1550 if (outerYields.size() != innerResults.size())
1551 return false;
1552 for (auto [outerYield, innerResult] :
1553 llvm::zip_equal(outerYields, innerResults)) {
1554 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1555 outerYield != innerResult)
1556 return false;
1557 }
1558 }
1559 return true;
1560}
1561
1563mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
1564 std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1565 std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1566 std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1567 if (!loBnds || !upBnds || !steps)
1568 return {};
1569 llvm::SmallVector<int64_t> tripCounts;
1570 for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1571 std::optional<llvm::APInt> numIter = constantTripCount(
1572 lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
1573 if (!numIter)
1574 return {};
1575 tripCounts.push_back(numIter->getSExtValue());
1576 }
1577 return tripCounts;
1578}
1579
1580FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
1581 scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
1582 RewriterBase &rewriter,
1583 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1584 IRMapping *clonedToSrcOpsMap) {
1585 const unsigned numLoops = op.getNumLoops();
1586 assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
1587 "Expected positive unroll factors");
1588 assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
1589 "Expected non-empty unroll factors of size <= to the number of loops");
1590
1591 // Bail out if no valid unroll factors were provided
1592 if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
1593 return rewriter.notifyMatchFailure(
1594 op, "Unrolling not applied if all factors are 1");
1595
1596 // Return if the loop body is empty.
1597 if (llvm::hasSingleElement(op.getBody()->getOperations()))
1598 return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
1599
1600 // If the provided unroll factors do not cover all the loop dims, they are
1601 // applied to the inner loop dimensions.
1602 const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
1603
1604 // Make sure that the unroll factors divide the iteration space evenly
1605 // TODO: Support unrolling loops with dynamic iteration spaces.
1607 if (tripCounts.empty())
1608 return rewriter.notifyMatchFailure(
1609 op, "Failed to compute constant trip counts for the loop. Note that "
1610 "dynamic loop sizes are not supported.");
1611
1612 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1613 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1614 if (tripCounts[dimIdx] % unrollFactor)
1615 return rewriter.notifyMatchFailure(
1616 op, "Unroll factors don't divide the iteration space evenly");
1617 }
1618
1619 std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
1620 if (!maybeFoldSteps)
1621 return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
1623 for (auto step : *maybeFoldSteps)
1624 steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
1625
1626 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1627 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1628 if (unrollFactor == 1)
1629 continue;
1630 const size_t origStep = steps[dimIdx];
1631 const int64_t newStep = origStep * unrollFactor;
1632 IRMapping clonedToSrcOpsMap;
1633
1634 ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
1635 auto yieldedValues = op.getBody()->getTerminator()->getOperands();
1636
1638 op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
1639 [&](unsigned i, Value iv, OpBuilder b) {
1640 // iv' = iv + step * i;
1641 const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
1642 const auto map =
1643 b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
1644 return affine::AffineApplyOp::create(b, iv.getLoc(), map,
1645 ValueRange{iv});
1646 },
1647 /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
1648
1649 // Update loop step
1650 auto prevInsertPoint = rewriter.saveInsertionPoint();
1651 rewriter.setInsertionPoint(op);
1652 op.getStepMutable()[dimIdx].assign(
1653 arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
1654 rewriter.restoreInsertionPoint(prevInsertPoint);
1655 }
1656 return op;
1657}
return success()
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition Utils.cpp:803
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition Utils.cpp:1205
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOp.
Definition Utils.cpp:1227
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition Utils.cpp:1161
static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition Utils.cpp:688
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition Utils.cpp:1242
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition Utils.cpp:265
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition Utils.cpp:818
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition Utils.cpp:854
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition Utils.cpp:749
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition Utils.cpp:510
static int64_t product(ArrayRef< int64_t > vals)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
#define mul(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:140
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
iterator end()
Definition Block.h:144
iterator begin()
Definition Block.h:143
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
MLIRContext * getContext() const
Definition Builders.h:56
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:342
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition Builders.h:385
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition Builders.h:252
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition Builders.h:390
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
Definition Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:114
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition Utils.cpp:1326
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition Utils.cpp:1524
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition Utils.cpp:217
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition Utils.cpp:35
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition Utils.cpp:994
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
void generateUnrolledLoop(Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues, IRMapping *clonedToSrcOpsMap=nullptr)
Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs' and 'yieldedValues' as the...
Definition Utils.cpp:294
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:102
llvm::SmallVector< int64_t > getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp)
Get constant trip counts for each of the induction variables of the given loop operation.
Definition Utils.cpp:1563
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:495
std::pair< Loops, Loops > TileLoops
Definition Utils.h:155
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned > > combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition Utils.cpp:1069
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
SliceOptions ForwardSliceOptions
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition Utils.cpp:1314
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition Utils.cpp:367
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition Utils.cpp:523
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition Utils.cpp:240
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
FailureOr< scf::ParallelOp > parallelLoopUnrollByFactors(scf::ParallelOp op, ArrayRef< uint64_t > unrollFactors, RewriterBase &rewriter, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, IRMapping *clonedToSrcOpsMap=nullptr)
Unroll this scf::Parallel loop by the specified unroll factors.
Definition Utils.cpp:1580
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1293
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition Utils.cpp:114
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition Utils.cpp:774
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition Utils.cpp:1373
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition Utils.cpp:986
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition Utils.cpp:1426
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition Utils.cpp:1331
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition Utils.cpp:703
SmallVector< scf::ForOp, 8 > Loops
Tile a nest of standard for loops rooted at rootForOp by finding such parametric tile sizes that the ...
Definition Utils.h:154
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition Utils.cpp:1479
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition Utils.h:116
std::optional< scf::ForOp > mainLoopOp
Definition Utils.h:115
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.