MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous loop transformation routines.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/SmallVectorExtras.h"
29#include "llvm/Support/DebugLog.h"
30#include <cstdint>
31
32using namespace mlir;
33
34#define DEBUG_TYPE "scf-utils"
35
37 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
38 ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
39 bool replaceIterOperandsUsesInLoop) {
40 if (loopNest.empty())
41 return {};
42 // This method is recursive (to make it more readable). Adding an
43 // assertion here to limit the recursion. (See
44 // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
45 assert(loopNest.size() <= 10 &&
46 "exceeded recursion limit when yielding value from loop nest");
47
48 // To yield a value from a perfectly nested loop nest, the following
49 // pattern needs to be created, i.e. starting with
50 //
51 // ```mlir
52 // scf.for .. {
53 // scf.for .. {
54 // scf.for .. {
55 // %value = ...
56 // }
57 // }
58 // }
59 // ```
60 //
61 // needs to be modified to
62 //
63 // ```mlir
64 // %0 = scf.for .. iter_args(%arg0 = %init) {
65 // %1 = scf.for .. iter_args(%arg1 = %arg0) {
66 // %2 = scf.for .. iter_args(%arg2 = %arg1) {
67 // %value = ...
68 // scf.yield %value
69 // }
70 // scf.yield %2
71 // }
72 // scf.yield %1
73 // }
74 // ```
75 //
76 // The inner most loop is handled using the `replaceWithAdditionalYields`
77 // that works on a single loop.
78 if (loopNest.size() == 1) {
79 auto innerMostLoop =
80 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
81 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
82 newYieldValuesFn));
83 return {innerMostLoop};
84 }
85 // The outer loops are modified by calling this method recursively
86 // - The return value of the inner loop is the value yielded by this loop.
87 // - The region iter args of this loop are the init_args for the inner loop.
88 SmallVector<scf::ForOp> newLoopNest;
90 [&](OpBuilder &innerBuilder, Location loc,
92 newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
93 innerNewBBArgs, newYieldValuesFn,
94 replaceIterOperandsUsesInLoop);
95 return llvm::map_to_vector(
96 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
97 [](OpResult r) -> Value { return r; });
98 };
99 scf::ForOp outerMostLoop =
100 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
101 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
102 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
103 return newLoopNest;
104}
105
106/// Outline a region with a single block into a new FuncOp.
107/// Assumes the FuncOp result types are the types of the yielded operands of
108/// the single block. This constraint makes it easy to determine the result.
109/// This method also clones the `arith::ConstantIndexOp` at the start of
110/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
111/// provided, it will be set to point to the operation that calls the outlined
112/// function.
113// TODO: support more than single-block regions.
114// TODO: more flexible constant handling.
115FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
116 Location loc,
117 Region &region,
118 StringRef funcName,
119 func::CallOp *callOp) {
120 assert(!funcName.empty() && "funcName cannot be empty");
121 if (!region.hasOneBlock())
122 return failure();
123
124 Block *originalBlock = &region.front();
125 Operation *originalTerminator = originalBlock->getTerminator();
126
127 // Outline before current function.
128 OpBuilder::InsertionGuard g(rewriter);
129 rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
130
131 SetVector<Value> captures;
132 getUsedValuesDefinedAbove(region, captures);
133
134 ValueRange outlinedValues(captures.getArrayRef());
135 SmallVector<Type> outlinedFuncArgTypes;
136 SmallVector<Location> outlinedFuncArgLocs;
137 // Region's arguments are exactly the first block's arguments as per
138 // Region::getArguments().
139 // Func's arguments are cat(regions's arguments, captures arguments).
140 for (BlockArgument arg : region.getArguments()) {
141 outlinedFuncArgTypes.push_back(arg.getType());
142 outlinedFuncArgLocs.push_back(arg.getLoc());
143 }
144 for (Value value : outlinedValues) {
145 outlinedFuncArgTypes.push_back(value.getType());
146 outlinedFuncArgLocs.push_back(value.getLoc());
147 }
148 FunctionType outlinedFuncType =
149 FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
150 originalTerminator->getOperandTypes());
151 auto outlinedFunc =
152 func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
153 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
154
155 // Merge blocks while replacing the original block operands.
156 // Warning: `mergeBlocks` erases the original block, reconstruct it later.
157 int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
158 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
159 {
160 OpBuilder::InsertionGuard g(rewriter);
161 rewriter.setInsertionPointToEnd(outlinedFuncBody);
162 rewriter.mergeBlocks(
163 originalBlock, outlinedFuncBody,
164 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
165 // Explicitly set up a new ReturnOp terminator.
166 rewriter.setInsertionPointToEnd(outlinedFuncBody);
167 func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
168 originalTerminator->getOperands());
169 }
170
171 // Reconstruct the block that was deleted and add a
172 // terminator(call_results).
173 Block *newBlock = rewriter.createBlock(
174 &region, region.begin(),
175 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
176 ArrayRef<Location>(outlinedFuncArgLocs)
177 .take_front(numOriginalBlockArguments));
178 {
179 OpBuilder::InsertionGuard g(rewriter);
180 rewriter.setInsertionPointToEnd(newBlock);
181 SmallVector<Value> callValues;
182 llvm::append_range(callValues, newBlock->getArguments());
183 llvm::append_range(callValues, outlinedValues);
184 auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
185 if (callOp)
186 *callOp = call;
187
188 // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
189 // Clone `originalTerminator` to take the callOp results then erase it from
190 // `outlinedFuncBody`.
191 IRMapping bvm;
192 bvm.map(originalTerminator->getOperands(), call->getResults());
193 rewriter.clone(*originalTerminator, bvm);
194 rewriter.eraseOp(originalTerminator);
195 }
196
197 // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
198 // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
199 for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
200 outlinedValues.size()))) {
201 Value orig = std::get<0>(it);
202 Value repl = std::get<1>(it);
203 {
204 OpBuilder::InsertionGuard g(rewriter);
205 rewriter.setInsertionPointToStart(outlinedFuncBody);
207 repl = rewriter.clone(*cst)->getResult(0);
208 }
209 }
210 orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
211 return outlinedFunc->isProperAncestor(opOperand.getOwner());
212 });
213 }
214
215 return outlinedFunc;
216}
217
218LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
219 func::FuncOp *thenFn, StringRef thenFnName,
220 func::FuncOp *elseFn, StringRef elseFnName) {
221 IRRewriter rewriter(b);
222 Location loc = ifOp.getLoc();
223 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
224 if (thenFn && !ifOp.getThenRegion().empty()) {
225 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
226 rewriter, loc, ifOp.getThenRegion(), thenFnName);
227 if (failed(outlinedFuncOpOrFailure))
228 return failure();
229 *thenFn = *outlinedFuncOpOrFailure;
230 }
231 if (elseFn && !ifOp.getElseRegion().empty()) {
232 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
233 rewriter, loc, ifOp.getElseRegion(), elseFnName);
234 if (failed(outlinedFuncOpOrFailure))
235 return failure();
236 *elseFn = *outlinedFuncOpOrFailure;
237 }
238 return success();
239}
240
243 assert(rootOp != nullptr && "Root operation must not be a nullptr.");
244 bool rootEnclosesPloops = false;
245 for (Region &region : rootOp->getRegions()) {
246 for (Block &block : region.getBlocks()) {
247 for (Operation &op : block) {
248 bool enclosesPloops = getInnermostParallelLoops(&op, result);
249 rootEnclosesPloops |= enclosesPloops;
250 if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
251 rootEnclosesPloops = true;
252
253 // Collect parallel loop if it is an innermost one.
254 if (!enclosesPloops)
255 result.push_back(ploop);
256 }
257 }
258 }
259 }
260 return rootEnclosesPloops;
261}
262
263// Build the IR that performs ceil division of a positive value by a constant:
264// ceildiv(a, B) = divis(a + (B-1), B)
265// where divis is rounding-to-zero division.
266static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
267 int64_t divisor) {
268 assert(divisor > 0 && "expected positive divisor");
269 assert(dividend.getType().isIntOrIndex() &&
270 "expected integer or index-typed value");
271
272 Value divisorMinusOneCst = arith::ConstantOp::create(
273 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
274 Value divisorCst = arith::ConstantOp::create(
275 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
276 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
277 return arith::DivUIOp::create(builder, loc, sum, divisorCst);
278}
279
280// Build the IR that performs ceil division of a positive value by another
281// positive value:
282// ceildiv(a, b) = divis(a + (b - 1), b)
283// where divis is rounding-to-zero division.
284static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
285 Value divisor) {
286 assert(dividend.getType().isIntOrIndex() &&
287 "expected integer or index-typed value");
288 Value cstOne = arith::ConstantOp::create(
289 builder, loc, builder.getOneAttr(dividend.getType()));
290 Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
291 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
292 return arith::DivUIOp::create(builder, loc, sum, divisor);
293}
294
296 Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
297 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
298 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
299 ValueRange iterArgs, ValueRange yieldedValues,
300 IRMapping *clonedToSrcOpsMap) {
301
302 // Check if the op was cloned from another source op, and return it if found
303 // (or the same op if not found)
304 auto findOriginalSrcOp =
305 [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
306 Operation *srcOp = op;
307 // If the source op derives from another op: traverse the chain to find the
308 // original source op
309 while (srcOp && clonedToSrcOpsMap.contains(srcOp))
310 srcOp = clonedToSrcOpsMap.lookup(srcOp);
311 return srcOp;
312 };
313
314 // Builder to insert unrolled bodies just before the terminator of the body of
315 // the loop.
316 auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
317
318 static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
319 if (!annotateFn)
320 annotateFn = noopAnnotateFn;
321
322 // Keep a pointer to the last non-terminator operation in the original block
323 // so that we know what to clone (since we are doing this in-place).
324 Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
325
326 // Unroll the contents of the loop body (append unrollFactor - 1 additional
327 // copies).
328 SmallVector<Value, 4> lastYielded(yieldedValues);
329
330 for (unsigned i = 1; i < unrollFactor; i++) {
331 // Prepare operand map.
332 IRMapping operandMap;
333 operandMap.map(iterArgs, lastYielded);
334
335 // If the induction variable is used, create a remapping to the value for
336 // this unrolled instance.
337 if (!iv.use_empty()) {
338 Value ivUnroll = ivRemapFn(i, iv, builder);
339 operandMap.map(iv, ivUnroll);
340 }
341
342 // Clone the original body of 'forOp'.
343 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
344 Operation *srcOp = &(*it);
345 Operation *clonedOp = builder.clone(*srcOp, operandMap);
346 annotateFn(i, clonedOp, builder);
347 if (clonedToSrcOpsMap)
348 clonedToSrcOpsMap->map(clonedOp,
349 findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
350 }
351
352 // Update yielded values.
353 for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
354 lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
355 }
356
357 // Make sure we annotate the Ops in the original body. We do this last so that
358 // any annotations are not copied into the cloned Ops above.
359 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
360 annotateFn(0, &*it, builder);
361
362 // Update operands of the yield statement.
363 loopBodyBlock->getTerminator()->setOperands(lastYielded);
364}
365
366/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
367/// epilogue loop, if the loop is unrolled.
368FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
369 scf::ForOp forOp, uint64_t unrollFactor,
370 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
371 assert(unrollFactor > 0 && "expected positive unroll factor");
372
373 // Return if the loop body is empty.
374 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
375 return UnrolledLoopInfo{forOp, std::nullopt};
376
377 // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
378 // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
379 OpBuilder boundsBuilder(forOp);
380 IRRewriter rewriter(forOp.getContext());
381 auto loc = forOp.getLoc();
382 Value step = forOp.getStep();
383 Value upperBoundUnrolled;
384 Value stepUnrolled;
385 bool generateEpilogueLoop = true;
386
387 std::optional<APInt> constTripCount = forOp.getStaticTripCount();
388 if (constTripCount) {
389 // Constant loop bounds computation.
390 int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
391 int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
392 int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
393 if (unrollFactor == 1) {
394 if (constTripCount->isOne() &&
395 failed(forOp.promoteIfSingleIteration(rewriter)))
396 return failure();
397 return UnrolledLoopInfo{forOp, std::nullopt};
398 }
399
400 uint64_t tripCount = constTripCount->getZExtValue();
401 uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
402 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
403 int64_t stepUnrolledCst = stepCst * unrollFactor;
404
405 // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
406 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
407 if (generateEpilogueLoop)
408 upperBoundUnrolled = arith::ConstantOp::create(
409 boundsBuilder, loc,
410 boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
411 upperBoundUnrolledCst));
412 else
413 upperBoundUnrolled = forOp.getUpperBound();
414
415 // Create constant for 'stepUnrolled'.
416 stepUnrolled =
417 stepCst == stepUnrolledCst
418 ? step
419 : arith::ConstantOp::create(boundsBuilder, loc,
420 boundsBuilder.getIntegerAttr(
421 step.getType(), stepUnrolledCst));
422 } else {
423 // Dynamic loop bounds computation.
424 // TODO: Add dynamic asserts for negative lb/ub/step, or
425 // consider using ceilDiv from AffineApplyExpander.
426 auto lowerBound = forOp.getLowerBound();
427 auto upperBound = forOp.getUpperBound();
428 Value diff =
429 arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
430 Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
431 Value unrollFactorCst = arith::ConstantOp::create(
432 boundsBuilder, loc,
433 boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
434 Value tripCountRem =
435 arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
436 // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
437 Value tripCountEvenMultiple =
438 arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
439 // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
440 upperBoundUnrolled = arith::AddIOp::create(
441 boundsBuilder, loc, lowerBound,
442 arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
443 // Scale 'step' by 'unrollFactor'.
444 stepUnrolled =
445 arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
446 }
447
448 UnrolledLoopInfo resultLoops;
449
450 // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
451 if (generateEpilogueLoop) {
452 OpBuilder epilogueBuilder(forOp->getContext());
453 epilogueBuilder.setInsertionPointAfter(forOp);
454 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
455 epilogueForOp.setLowerBound(upperBoundUnrolled);
456
457 // Update uses of loop results.
458 auto results = forOp.getResults();
459 auto epilogueResults = epilogueForOp.getResults();
460
461 for (auto e : llvm::zip(results, epilogueResults)) {
462 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
463 }
464 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
465 epilogueForOp.getInitArgs().size(), results);
466 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
467 resultLoops.epilogueLoopOp = epilogueForOp;
468 }
469
470 // Create unrolled loop.
471 forOp.setUpperBound(upperBoundUnrolled);
472 forOp.setStep(stepUnrolled);
473
474 auto iterArgs = ValueRange(forOp.getRegionIterArgs());
475 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
476
478 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
479 [&](unsigned i, Value iv, OpBuilder b) {
480 // iv' = iv + step * i;
481 auto stride = arith::MulIOp::create(
482 b, loc, step,
483 arith::ConstantOp::create(b, loc,
484 b.getIntegerAttr(iv.getType(), i)));
485 return arith::AddIOp::create(b, loc, iv, stride);
486 },
487 annotateFn, iterArgs, yieldedValues);
488 // Promote the loop body up if this has turned into a single iteration loop.
489 if (forOp.promoteIfSingleIteration(rewriter).failed())
490 resultLoops.mainLoopOp = forOp;
491 return resultLoops;
492}
493
494/// Unrolls this loop completely.
495LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
496 IRRewriter rewriter(forOp.getContext());
497 std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
498 if (!mayBeConstantTripCount.has_value())
499 return failure();
500 const APInt &tripCount = *mayBeConstantTripCount;
501 if (tripCount.isZero())
502 return success();
503 if (tripCount.isOne())
504 return forOp.promoteIfSingleIteration(rewriter);
505 return loopUnrollByFactor(forOp, tripCount.getZExtValue());
506}
507
508/// Check if bounds of all inner loops are defined outside of `forOp`
509/// and return false if not.
510static bool areInnerBoundsInvariant(scf::ForOp forOp) {
511 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
512 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
513 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
514 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
515 return WalkResult::interrupt();
516
517 return WalkResult::advance();
518 });
519 return !walkResult.wasInterrupted();
520}
521
522/// Unrolls and jams this loop by the specified factor.
523LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
524 uint64_t unrollJamFactor) {
525 assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
526
527 if (unrollJamFactor == 1)
528 return success();
529
530 // If any control operand of any inner loop of `forOp` is defined within
531 // `forOp`, no unroll jam.
532 if (!areInnerBoundsInvariant(forOp)) {
533 LDBG() << "failed to unroll and jam: inner bounds are not invariant";
534 return failure();
535 }
536
537 // Currently, loops with results are not supported.
538 if (forOp->getNumResults() > 0) {
539 LDBG() << "failed to unroll and jam: unsupported loop with results";
540 return failure();
541 }
542
543 // Currently, only constant trip counts that are divisible by the unroll
544 // factor are supported.
545 std::optional<APInt> tripCount = forOp.getStaticTripCount();
546 if (!tripCount.has_value()) {
547 // If the trip count is dynamic, do not unroll & jam.
548 LDBG() << "failed to unroll and jam: trip count could not be determined";
549 return failure();
550 }
551 uint64_t tripCountValue = tripCount->getZExtValue();
552 if (unrollJamFactor > tripCountValue) {
553 LDBG() << "unroll and jam factor is greater than trip count, set factor to "
554 "trip "
555 "count";
556 unrollJamFactor = tripCountValue;
557 } else if (tripCountValue % unrollJamFactor != 0) {
558 LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
559 "multiple of unroll jam factor";
560 return failure();
561 }
562
563 // Nothing in the loop body other than the terminator.
564 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
565 return success();
566
567 // Gather all sub-blocks to jam upon the loop being unrolled.
569 jbg.walk(forOp);
570 auto &subBlocks = jbg.subBlocks;
571
572 // Collect inner loops.
573 SmallVector<scf::ForOp> innerLoops;
574 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
575
576 // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
577 // iteration. There are (`unrollJamFactor` - 1) iterations.
578 SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
579
580 // For any loop with iter_args, replace it with a new loop that has
581 // `unrollJamFactor` copies of its iterOperands, iter_args and yield
582 // operands.
583 SmallVector<scf::ForOp> newInnerLoops;
584 IRRewriter rewriter(forOp.getContext());
585 for (scf::ForOp oldForOp : innerLoops) {
586 SmallVector<Value> dupIterOperands, dupYieldOperands;
587 ValueRange oldIterOperands = oldForOp.getInits();
588 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
589 ValueRange oldYieldOperands =
590 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
591 // Get additional iterOperands, iterArgs, and yield operands. We will
592 // fix iterOperands and yield operands after cloning of sub-blocks.
593 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
594 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
595 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
596 }
597 // Create a new loop with additional iterOperands, iter_args and yield
598 // operands. This new loop will take the loop body of the original loop.
599 bool forOpReplaced = oldForOp == forOp;
600 scf::ForOp newForOp =
601 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
602 rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
603 [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
604 return dupYieldOperands;
605 }));
606 newInnerLoops.push_back(newForOp);
607 // `forOp` has been replaced with a new loop.
608 if (forOpReplaced)
609 forOp = newForOp;
610 // Update `operandMaps` for `newForOp` iterArgs and results.
611 ValueRange newIterArgs = newForOp.getRegionIterArgs();
612 unsigned oldNumIterArgs = oldIterArgs.size();
613 ValueRange newResults = newForOp.getResults();
614 unsigned oldNumResults = newResults.size() / unrollJamFactor;
615 assert(oldNumIterArgs == oldNumResults &&
616 "oldNumIterArgs must be the same as oldNumResults");
617 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
618 for (unsigned j = 0; j < oldNumIterArgs; ++j) {
619 // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
620 // results. Update `operandMaps[i - 1]` to map old iterArgs and results
621 // to those in the `i`th new set.
622 operandMaps[i - 1].map(newIterArgs[j],
623 newIterArgs[i * oldNumIterArgs + j]);
624 operandMaps[i - 1].map(newResults[j],
625 newResults[i * oldNumResults + j]);
626 }
627 }
628 }
629
630 // Scale the step of loop being unroll-jammed by the unroll-jam factor.
631 rewriter.setInsertionPoint(forOp);
632 int64_t step = forOp.getConstantStep()->getSExtValue();
633 auto newStep = rewriter.createOrFold<arith::MulIOp>(
634 forOp.getLoc(), forOp.getStep(),
635 rewriter.createOrFold<arith::ConstantOp>(
636 forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
637 forOp.setStep(newStep);
638 auto forOpIV = forOp.getInductionVar();
639
640 // Unroll and jam (appends unrollJamFactor - 1 additional copies).
641 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
642 for (auto &subBlock : subBlocks) {
643 // Builder to insert unroll-jammed bodies. Insert right at the end of
644 // sub-block.
645 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
646
647 // If the induction variable is used, create a remapping to the value for
648 // this unrolled instance.
649 if (!forOpIV.use_empty()) {
650 // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
651 auto ivTag = builder.createOrFold<arith::ConstantOp>(
652 forOp.getLoc(), builder.getIndexAttr(step * i));
653 auto ivUnroll =
654 builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
655 operandMaps[i - 1].map(forOpIV, ivUnroll);
656 }
657 // Clone the sub-block being unroll-jammed.
658 for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
659 builder.clone(*it, operandMaps[i - 1]);
660 }
661 // Fix iterOperands and yield op operands of newly created loops.
662 for (auto newForOp : newInnerLoops) {
663 unsigned oldNumIterOperands =
664 newForOp.getNumRegionIterArgs() / unrollJamFactor;
665 unsigned numControlOperands = newForOp.getNumControlOperands();
666 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
667 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
668 assert(oldNumIterOperands == oldNumYieldOperands &&
669 "oldNumIterOperands must be the same as oldNumYieldOperands");
670 for (unsigned j = 0; j < oldNumIterOperands; ++j) {
671 // The `i`th duplication of an old iterOperand or yield op operand
672 // needs to be replaced with a mapped value from `operandMaps[i - 1]`
673 // if such mapped value exists.
674 newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
675 operandMaps[i - 1].lookupOrDefault(
676 newForOp.getOperand(numControlOperands + j)));
677 yieldOp.setOperand(
678 i * oldNumYieldOperands + j,
679 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
680 }
681 }
682 }
683
684 // Promote the loop body up if this has turned into a single iteration loop.
685 (void)forOp.promoteIfSingleIteration(rewriter);
686 return success();
687}
688
690 Location loc, OpFoldResult lb,
692 OpFoldResult step) {
693 Range normalizedLoopBounds;
694 normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
695 normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
696 AffineExpr s0, s1, s2;
697 bindSymbols(rewriter.getContext(), s0, s1, s2);
698 AffineExpr e = (s1 - s0).ceilDiv(s2);
699 normalizedLoopBounds.size =
700 affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
701 return normalizedLoopBounds;
702}
703
706 OpFoldResult step) {
707 if (getType(lb).isIndex()) {
708 return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
709 }
710 // For non-index types, generate `arith` instructions
711 // Check if the loop is already known to have a constant zero lower bound or
712 // a constant one step.
713 bool isZeroBased = false;
714 if (auto lbCst = getConstantIntValue(lb))
715 isZeroBased = lbCst.value() == 0;
716
717 bool isStepOne = false;
718 if (auto stepCst = getConstantIntValue(step))
719 isStepOne = stepCst.value() == 1;
720
721 Type rangeType = getType(lb);
722 assert(rangeType == getType(ub) && rangeType == getType(step) &&
723 "expected matching types");
724
725 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
726 // assuming the step is strictly positive. Update the bounds and the step
727 // of the loop to go from 0 to the number of iterations, if necessary.
728 if (isZeroBased && isStepOne)
729 return {lb, ub, step};
730
731 OpFoldResult diff = ub;
732 if (!isZeroBased) {
733 diff = rewriter.createOrFold<arith::SubIOp>(
734 loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
735 getValueOrCreateConstantIntOp(rewriter, loc, lb));
736 }
737 OpFoldResult newUpperBound = diff;
738 if (!isStepOne) {
739 newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
740 loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
741 getValueOrCreateConstantIntOp(rewriter, loc, step));
742 }
743
744 OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
745 OpFoldResult newStep = rewriter.getOneAttr(rangeType);
746
747 return {newLowerBound, newUpperBound, newStep};
748}
749
751 Location loc,
752 Value normalizedIv,
753 OpFoldResult origLb,
754 OpFoldResult origStep) {
755 AffineExpr d0, s0, s1;
756 bindSymbols(rewriter.getContext(), s0, s1);
757 bindDims(rewriter.getContext(), d0);
758 AffineExpr e = d0 * s1 + s0;
760 rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
761 Value denormalizedIvVal =
762 getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
763 SmallPtrSet<Operation *, 1> preservedUses;
764 // If an `affine.apply` operation is generated for denormalization, the use
765 // of `normalizedIv` in that op must not be replaced. Such an op is not
766 // generated when `origLb == 0` and `origStep == 1`.
767 if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
768 if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
769 preservedUses.insert(preservedUse);
770 }
771 }
772 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
773}
774
776 Value normalizedIv, OpFoldResult origLb,
777 OpFoldResult origStep) {
778 if (getType(origLb).isIndex()) {
779 return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
780 origLb, origStep);
781 }
782 Value denormalizedIv;
784 bool isStepOne = isOneInteger(origStep);
785 bool isZeroBased = isZeroInteger(origLb);
786
787 Value scaled = normalizedIv;
788 if (!isStepOne) {
789 Value origStepValue =
790 getValueOrCreateConstantIntOp(rewriter, loc, origStep);
791 scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
792 preserve.insert(scaled.getDefiningOp());
793 }
794 denormalizedIv = scaled;
795 if (!isZeroBased) {
796 Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
797 denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
798 preserve.insert(denormalizedIv.getDefiningOp());
799 }
800
801 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
802}
803
805 ArrayRef<OpFoldResult> values) {
806 assert(!values.empty() && "unexecpted empty array");
807 AffineExpr s0, s1;
808 bindSymbols(rewriter.getContext(), s0, s1);
809 AffineExpr mul = s0 * s1;
810 OpFoldResult products = rewriter.getIndexAttr(1);
811 for (auto v : values) {
813 rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
814 }
815 return products;
816}
817
818/// Helper function to multiply a sequence of values.
820 ArrayRef<Value> values) {
821 assert(!values.empty() && "unexpected empty list");
822 if (getType(values.front()).isIndex()) {
824 OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
825 return getValueOrCreateConstantIndexOp(rewriter, loc, product);
826 }
827 std::optional<Value> productOf;
828 for (auto v : values) {
829 auto vOne = getConstantIntValue(v);
830 if (vOne && vOne.value() == 1)
831 continue;
832 if (productOf)
833 productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
834 .getResult();
835 else
836 productOf = v;
837 }
838 if (!productOf) {
839 productOf = arith::ConstantOp::create(
840 rewriter, loc, rewriter.getOneAttr(getType(values.front())))
841 .getResult();
842 }
843 return productOf.value();
844}
845
846/// For each original loop, the value of the
847/// induction variable can be obtained by dividing the induction variable of
848/// the linearized loop by the total number of iterations of the loops nested
849/// in it modulo the number of iterations in this loop (remove the values
850/// related to the outer loops):
851/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
852/// Compute these iteratively from the innermost loop by creating a "running
853/// quotient" of division by the range.
854static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
856 Value linearizedIv, ArrayRef<Value> ubs) {
857
858 if (linearizedIv.getType().isIndex()) {
859 Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
860 rewriter, loc, linearizedIv, ubs);
861 auto resultVals = llvm::map_to_vector(
862 delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
863 return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
864 }
865
866 SmallVector<Value> delinearizedIvs(ubs.size());
867 SmallPtrSet<Operation *, 2> preservedUsers;
868
869 llvm::BitVector isUbOne(ubs.size());
870 for (auto [index, ub] : llvm::enumerate(ubs)) {
871 auto ubCst = getConstantIntValue(ub);
872 if (ubCst && ubCst.value() == 1)
873 isUbOne.set(index);
874 }
875
876 // Prune the leading ubs that are all ones.
877 unsigned numLeadingOneUbs = 0;
878 for (auto [index, ub] : llvm::enumerate(ubs)) {
879 if (!isUbOne.test(index)) {
880 break;
881 }
882 delinearizedIvs[index] = arith::ConstantOp::create(
883 rewriter, loc, rewriter.getZeroAttr(ub.getType()));
884 numLeadingOneUbs++;
885 }
886
887 Value previous = linearizedIv;
888 for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
889 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
890 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
891 previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
892 preservedUsers.insert(previous.getDefiningOp());
893 }
894 Value iv = previous;
895 if (i != e - 1) {
896 if (!isUbOne.test(idx)) {
897 iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
898 preservedUsers.insert(iv.getDefiningOp());
899 } else {
900 iv = arith::ConstantOp::create(
901 rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
902 }
903 }
904 delinearizedIvs[idx] = iv;
905 }
906 return {delinearizedIvs, preservedUsers};
907}
908
909LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
911 if (loops.size() < 2)
912 return failure();
913
914 scf::ForOp innermost = loops.back();
915 scf::ForOp outermost = loops.front();
916
917 // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
918 // allows the following code to assume upperBound is the number of iterations.
919 for (auto loop : loops) {
920 OpBuilder::InsertionGuard g(rewriter);
921 rewriter.setInsertionPoint(outermost);
922 Value lb = loop.getLowerBound();
923 Value ub = loop.getUpperBound();
924 Value step = loop.getStep();
925 auto newLoopRange =
926 emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
927
928 rewriter.modifyOpInPlace(loop, [&]() {
929 loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
930 newLoopRange.offset));
931 loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
932 newLoopRange.size));
933 loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
934 newLoopRange.stride));
935 });
936 rewriter.setInsertionPointToStart(innermost.getBody());
937 denormalizeInductionVariable(rewriter, loop.getLoc(),
938 loop.getInductionVar(), lb, step);
939 }
940
941 // 2. Emit code computing the upper bound of the coalesced loop as product
942 // of the number of iterations of all loops.
943 OpBuilder::InsertionGuard g(rewriter);
944 rewriter.setInsertionPoint(outermost);
945 Location loc = outermost.getLoc();
946 SmallVector<Value> upperBounds = llvm::map_to_vector(
947 loops, [](auto loop) { return loop.getUpperBound(); });
948 Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
949 outermost.setUpperBound(upperBound);
950
951 rewriter.setInsertionPointToStart(innermost.getBody());
952 auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
953 rewriter, loc, outermost.getInductionVar(), upperBounds);
954 rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
955 preservedUsers);
956
957 for (int i = loops.size() - 1; i > 0; --i) {
958 auto outerLoop = loops[i - 1];
959 auto innerLoop = loops[i];
960
961 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
962 auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
963 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
964 for (Value &yieldedVal : yieldedVals) {
965 // The yielded value may be an iteration argument of the inner loop
966 // which is about to be inlined.
967 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
968 if (iter != innerLoop.getRegionIterArgs().end()) {
969 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
970 // `outerLoop` iter args identical to the `innerLoop` init args.
971 assert(iterArgIndex < innerLoop.getInitArgs().size());
972 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
973 }
974 }
975 rewriter.eraseOp(innerTerminator);
976
977 SmallVector<Value> innerBlockArgs;
978 innerBlockArgs.push_back(delinearizeIvs[i]);
979 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
980 rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
981 Block::iterator(innerLoop), innerBlockArgs);
982 rewriter.replaceOp(innerLoop, yieldedVals);
983 }
984 return success();
985}
986
988 if (loops.empty()) {
989 return failure();
990 }
991 IRRewriter rewriter(loops.front().getContext());
992 return coalesceLoops(rewriter, loops);
993}
994
995LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
996 LogicalResult result(failure());
998 getPerfectlyNestedLoops(loops, op);
999
1000 // Look for a band of loops that can be coalesced, i.e. perfectly nested
1001 // loops with bounds defined above some loop.
1002
1003 // 1. For each loop, find above which parent loop its bounds operands are
1004 // defined.
1005 SmallVector<unsigned> operandsDefinedAbove(loops.size());
1006 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1007 operandsDefinedAbove[i] = i;
1008 for (unsigned j = 0; j < i; ++j) {
1009 SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1010 loops[i].getUpperBound(),
1011 loops[i].getStep()};
1012 if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1013 operandsDefinedAbove[i] = j;
1014 break;
1015 }
1016 }
1017 }
1018
1019 // 2. For each inner loop check that the iter_args of the immediately outer
1020 // loop are the init args of the immediately inner loop, and that the results
1021 // of the inner loop are exactly what the immediately outer loop yields. Keep
1022 // track of where the chain starts from for each loop.
1023 SmallVector<unsigned> iterArgChainStart(loops.size());
1024 iterArgChainStart[0] = 0;
1025 for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1026 // By default set the start of the chain to itself.
1027 iterArgChainStart[i] = i;
1028 auto outerloop = loops[i - 1];
1029 auto innerLoop = loops[i];
1030 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1031 continue;
1032 }
1033 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1034 continue;
1035 }
1036 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1037 if (!llvm::equal(outerloopTerminator->getOperands(),
1038 innerLoop.getResults())) {
1039 continue;
1040 }
1041 iterArgChainStart[i] = iterArgChainStart[i - 1];
1042 }
1043
1044 // 3. Identify bands of loops such that the operands of all of them are
1045 // defined above the first loop in the band. Traverse the nest bottom-up
1046 // so that modifications don't invalidate the inner loops.
1047 for (unsigned end = loops.size(); end > 0; --end) {
1048 unsigned start = 0;
1049 for (; start < end - 1; ++start) {
1050 auto maxPos =
1051 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1052 std::next(operandsDefinedAbove.begin(), end));
1053 if (maxPos > start)
1054 continue;
1055 if (iterArgChainStart[end - 1] > start)
1056 continue;
1057 auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1058 if (succeeded(coalesceLoops(band)))
1059 result = success();
1060 break;
1061 }
1062 // If a band was found and transformed, keep looking at the loops above
1063 // the outermost transformed loop.
1064 if (start != end - 1)
1065 end = start + 1;
1066 }
1067 return result;
1068}
1069
1071 RewriterBase &rewriter, scf::ParallelOp loops,
1072 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1073 OpBuilder::InsertionGuard g(rewriter);
1074 rewriter.setInsertionPoint(loops);
1075 Location loc = loops.getLoc();
1076
1077 // Presort combined dimensions.
1078 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1079 for (auto &dims : sortedDimensions)
1080 llvm::sort(dims);
1081
1082 // Normalize ParallelOp's iteration pattern.
1083 SmallVector<Value, 3> normalizedUpperBounds;
1084 for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1085 OpBuilder::InsertionGuard g2(rewriter);
1086 rewriter.setInsertionPoint(loops);
1087 Value lb = loops.getLowerBound()[i];
1088 Value ub = loops.getUpperBound()[i];
1089 Value step = loops.getStep()[i];
1090 auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1091 normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1092 rewriter, loops.getLoc(), newLoopRange.size));
1093
1094 rewriter.setInsertionPointToStart(loops.getBody());
1095 denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1096 step);
1097 }
1098
1099 // Combine iteration spaces.
1100 SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1101 auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1102 auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1103 for (auto &sortedDimension : sortedDimensions) {
1104 Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1105 for (auto idx : sortedDimension) {
1106 newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1107 normalizedUpperBounds[idx]);
1108 }
1109 lowerBounds.push_back(cst0);
1110 steps.push_back(cst1);
1111 upperBounds.push_back(newUpperBound);
1112 }
1113
1114 // Create new ParallelLoop with conversions to the original induction values.
1115 // The loop below uses divisions to get the relevant range of values in the
1116 // new induction value that represent each range of the original induction
1117 // value. The remainders then determine, within that range, which iteration
1118 // of the original induction value this represents. This is a normalized value
1119 // that the previously emitted denormalization maps back to the original one.
1120 auto newPloop = scf::ParallelOp::create(
1121 rewriter, loc, lowerBounds, upperBounds, steps,
1122 [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1123 for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1124 Value previous = ploopIVs[i];
1125 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1126 // Iterate over all except the last induction value.
1127 for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1128 unsigned idx = combinedDimensions[i][j];
1129
1130 // Determine the current induction value's current loop iteration
1131 Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1132 normalizedUpperBounds[idx]);
1133 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1134 loops.getRegion());
1135
1136 // Remove the effect of the current induction value to prepare for
1137 // the next value.
1138 previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1139 normalizedUpperBounds[idx]);
1140 }
1141
1142 // The final induction value is just the remaining value.
1143 unsigned idx = combinedDimensions[i][0];
1144 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1145 previous, loops.getRegion());
1146 }
1147 });
1148
1149 // Replace the old loop with the new loop.
1150 loops.getBody()->back().erase();
1151 newPloop.getBody()->getOperations().splice(
1152 Block::iterator(newPloop.getBody()->back()),
1153 loops.getBody()->getOperations());
1154 loops.erase();
1155}
1156
1157// Hoist the ops within `outer` that appear before `inner`.
1158// Such ops include the ops that have been introduced by parametric tiling.
1159// Ops that come from triangular loops (i.e. that belong to the program slice
1160// rooted at `outer`) and ops that have side effects cannot be hoisted.
1161// Return failure when any op fails to hoist.
1162static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1163 SetVector<Operation *> forwardSlice;
1165 options.filter = [&inner](Operation *op) {
1166 return op != inner.getOperation();
1167 };
1168 getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1169 LogicalResult status = success();
1171 for (auto &op : outer.getBody()->without_terminator()) {
1172 // Stop when encountering the inner loop.
1173 if (&op == inner.getOperation())
1174 break;
1175 // Skip over non-hoistable ops.
1176 if (forwardSlice.count(&op) > 0) {
1177 status = failure();
1178 continue;
1179 }
1180 // Skip intermediate scf::ForOp, these are not considered a failure.
1181 if (isa<scf::ForOp>(op))
1182 continue;
1183 // Skip other ops with regions.
1184 if (op.getNumRegions() > 0) {
1185 status = failure();
1186 continue;
1187 }
1188 // Skip if op has side effects.
1189 // TODO: loads to immutable memory regions are ok.
1190 if (!isMemoryEffectFree(&op)) {
1191 status = failure();
1192 continue;
1193 }
1194 toHoist.push_back(&op);
1195 }
1196 auto *outerForOp = outer.getOperation();
1197 for (auto *op : toHoist)
1198 op->moveBefore(outerForOp);
1199 return status;
1200}
1201
1202// Traverse the interTile and intraTile loops and try to hoist ops such that
1203// bands of perfectly nested loops are isolated.
1204// Return failure if either perfect interTile or perfect intraTile bands cannot
1205// be formed.
1206static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1207 LogicalResult status = success();
1208 const Loops &interTile = tileLoops.first;
1209 const Loops &intraTile = tileLoops.second;
1210 auto size = interTile.size();
1211 assert(size == intraTile.size());
1212 if (size <= 1)
1213 return success();
1214 for (unsigned s = 1; s < size; ++s)
1215 status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1216 : failure();
1217 for (unsigned s = 1; s < size; ++s)
1218 status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1219 : failure();
1220 return status;
1221}
1222
1223/// Collect perfectly nested loops starting from `rootForOp`. Loops are
1224/// perfectly nested if each loop is the first and only non-terminator operation
1225/// in the parent loop. Collect at most `maxLoops` loops and append them to
1226/// `forOps`.
1227template <typename T>
1229 SmallVectorImpl<T> &forOps, T rootForOp,
1230 unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1231 for (unsigned i = 0; i < maxLoops; ++i) {
1232 forOps.push_back(rootForOp);
1233 Block &body = rootForOp.getRegion().front();
1234 if (body.begin() != std::prev(body.end(), 2))
1235 return;
1236
1237 rootForOp = dyn_cast<T>(&body.front());
1238 if (!rootForOp)
1239 return;
1240 }
1241}
1242
1243static Loops stripmineSink(scf::ForOp forOp, Value factor,
1244 ArrayRef<scf::ForOp> targets) {
1245 assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
1246 auto originalStep = forOp.getStep();
1247 auto iv = forOp.getInductionVar();
1248
1249 OpBuilder b(forOp);
1250 forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1251
1252 Loops innerLoops;
1253 for (auto t : targets) {
1254 assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
1255
1256 // Save information for splicing ops out of t when done
1257 auto begin = t.getBody()->begin();
1258 auto nOps = t.getBody()->getOperations().size();
1259
1260 // Insert newForOp before the terminator of `t`.
1261 auto b = OpBuilder::atBlockTerminator((t.getBody()));
1262 Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1263 Value ub =
1264 arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1265
1266 // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
1267 auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1268 newForOp.getBody()->getOperations().splice(
1269 newForOp.getBody()->getOperations().begin(),
1270 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1271 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1272 newForOp.getRegion());
1273
1274 innerLoops.push_back(newForOp);
1275 }
1276
1277 return innerLoops;
1278}
1279
1280// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1281// Returns the new for operation, nested immediately under `target`.
1282template <typename SizeType>
1283static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1284 scf::ForOp target) {
1285 // TODO: Use cheap structural assertions that targets are nested under
1286 // forOp and that targets are not nested under each other when DominanceInfo
1287 // exposes the capability. It seems overkill to construct a whole function
1288 // dominance tree at this point.
1289 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1290 assert(res.size() == 1 && "Expected 1 inner forOp");
1291 return res[0];
1292}
1293
1295 ArrayRef<Value> sizes,
1296 ArrayRef<scf::ForOp> targets) {
1298 SmallVector<scf::ForOp, 8> currentTargets(targets);
1299 for (auto it : llvm::zip(forOps, sizes)) {
1300 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1301 res.push_back(step);
1302 currentTargets = step;
1303 }
1304 return res;
1305}
1306
1308 scf::ForOp target) {
1310 for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1311 res.push_back(llvm::getSingleElement(loops));
1312 return res;
1313}
1314
1316 // Collect perfectly nested loops. If more size values provided than nested
1317 // loops available, truncate `sizes`.
1319 forOps.reserve(sizes.size());
1320 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1321 if (forOps.size() < sizes.size())
1322 sizes = sizes.take_front(forOps.size());
1323
1324 return ::tile(forOps, sizes, forOps.back());
1325}
1326
1328 scf::ForOp root) {
1329 getPerfectlyNestedLoopsImpl(nestedLoops, root);
1330}
1331
1333 ArrayRef<int64_t> sizes) {
1334 // Collect perfectly nested loops. If more size values provided than nested
1335 // loops available, truncate `sizes`.
1337 forOps.reserve(sizes.size());
1338 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1339 if (forOps.size() < sizes.size())
1340 sizes = sizes.take_front(forOps.size());
1341
1342 // Compute the tile sizes such that i-th outer loop executes size[i]
1343 // iterations. Given that the loop currently executes
1344 // numIterations = ceildiv((upperBound - lowerBound), step)
1345 // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1346 SmallVector<Value, 4> tileSizes;
1347 tileSizes.reserve(sizes.size());
1348 for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1349 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1350
1351 auto forOp = forOps[i];
1352 OpBuilder builder(forOp);
1353 auto loc = forOp.getLoc();
1354 Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1355 forOp.getLowerBound());
1356 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1357 Value iterationsPerBlock =
1358 ceilDivPositive(builder, loc, numIterations, sizes[i]);
1359 tileSizes.push_back(iterationsPerBlock);
1360 }
1361
1362 // Call parametric tiling with the given sizes.
1363 auto intraTile = tile(forOps, tileSizes, forOps.back());
1364 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1365
1366 // TODO: for now we just ignore the result of band isolation.
1367 // In the future, mapping decisions may be impacted by the ability to
1368 // isolate perfectly nested bands.
1369 (void)tryIsolateBands(tileLoops);
1370
1371 return tileLoops;
1372}
1373
1375 scf::ForallOp source,
1376 RewriterBase &rewriter) {
1377 unsigned numTargetOuts = target.getNumResults();
1378 unsigned numSourceOuts = source.getNumResults();
1379
1380 // Create fused shared_outs.
1381 SmallVector<Value> fusedOuts;
1382 llvm::append_range(fusedOuts, target.getOutputs());
1383 llvm::append_range(fusedOuts, source.getOutputs());
1384
1385 // Create a new scf.forall op after the source loop.
1386 rewriter.setInsertionPointAfter(source);
1387 scf::ForallOp fusedLoop = scf::ForallOp::create(
1388 rewriter, source.getLoc(), source.getMixedLowerBound(),
1389 source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1390 source.getMapping());
1391
1392 // Map control operands.
1393 IRMapping mapping;
1394 mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1395 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1396
1397 // Map shared outs.
1398 mapping.map(target.getRegionIterArgs(),
1399 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1400 mapping.map(source.getRegionIterArgs(),
1401 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1402
1403 // Append everything except the terminator into the fused operation.
1404 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1405 for (Operation &op : target.getBody()->without_terminator())
1406 rewriter.clone(op, mapping);
1407 for (Operation &op : source.getBody()->without_terminator())
1408 rewriter.clone(op, mapping);
1409
1410 // Fuse the old terminator in_parallel ops into the new one.
1411 scf::InParallelOp targetTerm = target.getTerminator();
1412 scf::InParallelOp sourceTerm = source.getTerminator();
1413 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1414 rewriter.setInsertionPointToStart(fusedTerm.getBody());
1415 for (Operation &op : targetTerm.getYieldingOps())
1416 rewriter.clone(op, mapping);
1417 for (Operation &op : sourceTerm.getYieldingOps())
1418 rewriter.clone(op, mapping);
1419
1420 // Replace old loops by substituting their uses by results of the fused loop.
1421 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1422 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1423
1424 return fusedLoop;
1425}
1426
1428 scf::ForOp source,
1429 RewriterBase &rewriter) {
1430 assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1431 "incompatible signedness");
1432 unsigned numTargetOuts = target.getNumResults();
1433 unsigned numSourceOuts = source.getNumResults();
1434
1435 // Create fused init_args, with target's init_args before source's init_args.
1436 SmallVector<Value> fusedInitArgs;
1437 llvm::append_range(fusedInitArgs, target.getInitArgs());
1438 llvm::append_range(fusedInitArgs, source.getInitArgs());
1439
1440 // Create a new scf.for op after the source loop (with scf.yield terminator
1441 // (without arguments) only in case its init_args is empty).
1442 rewriter.setInsertionPointAfter(source);
1443 scf::ForOp fusedLoop = scf::ForOp::create(
1444 rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1445 source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
1446 source.getUnsignedCmp());
1447
1448 // Map original induction variables and operands to those of the fused loop.
1449 IRMapping mapping;
1450 mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1451 mapping.map(target.getRegionIterArgs(),
1452 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1453 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1454 mapping.map(source.getRegionIterArgs(),
1455 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1456
1457 // Merge target's body into the new (fused) for loop and then source's body.
1458 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1459 for (Operation &op : target.getBody()->without_terminator())
1460 rewriter.clone(op, mapping);
1461 for (Operation &op : source.getBody()->without_terminator())
1462 rewriter.clone(op, mapping);
1463
1464 // Build fused yield results by appropriately mapping original yield operands.
1465 SmallVector<Value> yieldResults;
1466 for (Value operand : target.getBody()->getTerminator()->getOperands())
1467 yieldResults.push_back(mapping.lookupOrDefault(operand));
1468 for (Value operand : source.getBody()->getTerminator()->getOperands())
1469 yieldResults.push_back(mapping.lookupOrDefault(operand));
1470 if (!yieldResults.empty())
1471 scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1472
1473 // Replace old loops by substituting their uses by results of the fused loop.
1474 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1475 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1476
1477 return fusedLoop;
1478}
1479
1480FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1481 scf::ForallOp forallOp) {
1482 SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1483 SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1484 SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1485
1486 if (forallOp.isNormalized())
1487 return forallOp;
1488
1489 OpBuilder::InsertionGuard g(rewriter);
1490 auto loc = forallOp.getLoc();
1491 rewriter.setInsertionPoint(forallOp);
1493 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1494 Range normalizedLoopParams =
1495 emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1496 newUbs.push_back(normalizedLoopParams.size);
1497 }
1498 (void)foldDynamicIndexList(newUbs);
1499
1500 // Use the normalized builder since the lower bounds are always 0 and the
1501 // steps are always 1.
1502 auto normalizedForallOp = scf::ForallOp::create(
1503 rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1504 [](OpBuilder &, Location, ValueRange) {});
1505
1506 rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1507 normalizedForallOp.getBodyRegion(),
1508 normalizedForallOp.getBodyRegion().begin());
1509 // Remove the original empty block in the new loop.
1510 rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1511
1512 rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1513 // Update the users of the original loop variables.
1514 for (auto [idx, iv] :
1515 llvm::enumerate(normalizedForallOp.getInductionVars())) {
1516 auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1517 auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1518 denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1519 }
1520
1521 rewriter.replaceOp(forallOp, normalizedForallOp);
1522 return normalizedForallOp;
1523}
1524
1527 assert(!loops.empty() && "unexpected empty loop nest");
1528 if (loops.size() == 1)
1529 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1530 for (auto [outerLoop, innerLoop] :
1531 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1532 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1533 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1534 if (!outerFor || !innerFor)
1535 return false;
1536 auto outerBBArgs = outerFor.getRegionIterArgs();
1537 auto innerIterArgs = innerFor.getInitArgs();
1538 if (outerBBArgs.size() != innerIterArgs.size())
1539 return false;
1540
1541 for (auto [outerBBArg, innerIterArg] :
1542 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1543 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1544 innerIterArg != outerBBArg)
1545 return false;
1546 }
1547
1548 ValueRange outerYields =
1549 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1550 ValueRange innerResults = innerFor.getResults();
1551 if (outerYields.size() != innerResults.size())
1552 return false;
1553 for (auto [outerYield, innerResult] :
1554 llvm::zip_equal(outerYields, innerResults)) {
1555 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1556 outerYield != innerResult)
1557 return false;
1558 }
1559 }
1560 return true;
1561}
1562
1564mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
1565 std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1566 std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1567 std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1568 if (!loBnds || !upBnds || !steps)
1569 return {};
1571 for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1572 // TODO(#178506): Signedness is not handled correctly here.
1573 std::optional<llvm::APInt> numIter = constantTripCount(
1574 lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
1575 if (!numIter)
1576 return {};
1577 tripCounts.push_back(*numIter);
1578 }
1579 return tripCounts;
1580}
1581
1582FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
1583 scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
1584 RewriterBase &rewriter,
1585 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1586 IRMapping *clonedToSrcOpsMap) {
1587 const unsigned numLoops = op.getNumLoops();
1588 assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
1589 "Expected positive unroll factors");
1590 assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
1591 "Expected non-empty unroll factors of size <= to the number of loops");
1592
1593 // Bail out if no valid unroll factors were provided
1594 if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
1595 return rewriter.notifyMatchFailure(
1596 op, "Unrolling not applied if all factors are 1");
1597
1598 // Return if the loop body is empty.
1599 if (llvm::hasSingleElement(op.getBody()->getOperations()))
1600 return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
1601
1602 // If the provided unroll factors do not cover all the loop dims, they are
1603 // applied to the inner loop dimensions.
1604 const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
1605
1606 // Make sure that the unroll factors divide the iteration space evenly
1607 // TODO: Support unrolling loops with dynamic iteration spaces.
1609 if (tripCounts.empty())
1610 return rewriter.notifyMatchFailure(
1611 op, "Failed to compute constant trip counts for the loop. Note that "
1612 "dynamic loop sizes are not supported.");
1613
1614 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1615 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1616 if (tripCounts[dimIdx].urem(unrollFactor) != 0)
1617 return rewriter.notifyMatchFailure(
1618 op, "Unroll factors don't divide the iteration space evenly");
1619 }
1620
1621 std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
1622 if (!maybeFoldSteps)
1623 return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
1625 for (auto step : *maybeFoldSteps)
1626 steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
1627
1628 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1629 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1630 if (unrollFactor == 1)
1631 continue;
1632 const size_t origStep = steps[dimIdx];
1633 const int64_t newStep = origStep * unrollFactor;
1634 IRMapping clonedToSrcOpsMap;
1635
1636 ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
1637 auto yieldedValues = op.getBody()->getTerminator()->getOperands();
1638
1640 op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
1641 [&](unsigned i, Value iv, OpBuilder b) {
1642 // iv' = iv + step * i;
1643 const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
1644 const auto map =
1645 b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
1646 return affine::AffineApplyOp::create(b, iv.getLoc(), map,
1647 ValueRange{iv});
1648 },
1649 /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
1650
1651 // Update loop step
1652 auto prevInsertPoint = rewriter.saveInsertionPoint();
1653 rewriter.setInsertionPoint(op);
1654 op.getStepMutable()[dimIdx].assign(
1655 arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
1656 rewriter.restoreInsertionPoint(prevInsertPoint);
1657 }
1658 return op;
1659}
return success()
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition Utils.cpp:804
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition Utils.cpp:1206
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
Definition Utils.cpp:1228
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition Utils.cpp:1162
static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition Utils.cpp:689
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition Utils.cpp:1243
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition Utils.cpp:266
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition Utils.cpp:819
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition Utils.cpp:855
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition Utils.cpp:750
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition Utils.cpp:510
static int64_t product(ArrayRef< int64_t > vals)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
#define mul(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext * getContext() const
Definition Builders.h:56
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition Builders.h:387
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition Builders.h:254
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition Builders.h:392
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
Definition Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition Utils.cpp:1327
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition Utils.cpp:1525
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition Utils.cpp:218
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition Utils.cpp:36
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition Utils.cpp:995
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
Definition AffineExpr.h:311
void generateUnrolledLoop(Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues, IRMapping *clonedToSrcOpsMap=nullptr)
Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs' and 'yieldedValues' as the...
Definition Utils.cpp:295
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:103
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:495
llvm::SmallVector< llvm::APInt > getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp)
Get constant trip counts for each of the induction variables of the given loop operation.
Definition Utils.cpp:1564
std::pair< Loops, Loops > TileLoops
Definition Utils.h:155
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned > > combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition Utils.cpp:1070
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SliceOptions ForwardSliceOptions
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition Utils.cpp:1315
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition Utils.cpp:368
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition Utils.cpp:523
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition Utils.cpp:241
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
Definition AffineExpr.h:325
FailureOr< scf::ParallelOp > parallelLoopUnrollByFactors(scf::ParallelOp op, ArrayRef< uint64_t > unrollFactors, RewriterBase &rewriter, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, IRMapping *clonedToSrcOpsMap=nullptr)
Unroll this scf::Parallel loop by the specified unroll factors.
Definition Utils.cpp:1582
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1294
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition Utils.cpp:115
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition Utils.cpp:775
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition Utils.cpp:1374
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition Utils.cpp:987
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition Utils.cpp:1427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition Utils.cpp:1332
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition Utils.cpp:704
SmallVector< scf::ForOp, 8 > Loops
Tile a nest of standard for loops rooted at rootForOp by finding such parametric tile sizes that the ...
Definition Utils.h:154
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition Utils.cpp:1480
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition Utils.h:116
std::optional< scf::ForOp > mainLoopOp
Definition Utils.h:115
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.