//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/SmallVectorExtras.h"
29#include "llvm/Support/DebugLog.h"
30#include <cstdint>
31
32using namespace mlir;
33
34#define DEBUG_TYPE "scf-utils"
35
37 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
38 ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
39 bool replaceIterOperandsUsesInLoop) {
40 if (loopNest.empty())
41 return {};
42 // This method is recursive (to make it more readable). Adding an
43 // assertion here to limit the recursion. (See
44 // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
45 assert(loopNest.size() <= 10 &&
46 "exceeded recursion limit when yielding value from loop nest");
47
48 // To yield a value from a perfectly nested loop nest, the following
49 // pattern needs to be created, i.e. starting with
50 //
51 // ```mlir
52 // scf.for .. {
53 // scf.for .. {
54 // scf.for .. {
55 // %value = ...
56 // }
57 // }
58 // }
59 // ```
60 //
61 // needs to be modified to
62 //
63 // ```mlir
64 // %0 = scf.for .. iter_args(%arg0 = %init) {
65 // %1 = scf.for .. iter_args(%arg1 = %arg0) {
66 // %2 = scf.for .. iter_args(%arg2 = %arg1) {
67 // %value = ...
68 // scf.yield %value
69 // }
70 // scf.yield %2
71 // }
72 // scf.yield %1
73 // }
74 // ```
75 //
76 // The inner most loop is handled using the `replaceWithAdditionalYields`
77 // that works on a single loop.
78 if (loopNest.size() == 1) {
79 auto innerMostLoop =
80 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
81 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
82 newYieldValuesFn));
83 return {innerMostLoop};
84 }
85 // The outer loops are modified by calling this method recursively
86 // - The return value of the inner loop is the value yielded by this loop.
87 // - The region iter args of this loop are the init_args for the inner loop.
88 SmallVector<scf::ForOp> newLoopNest;
90 [&](OpBuilder &innerBuilder, Location loc,
92 newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
93 innerNewBBArgs, newYieldValuesFn,
94 replaceIterOperandsUsesInLoop);
95 return llvm::map_to_vector(
96 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
97 [](OpResult r) -> Value { return r; });
98 };
99 scf::ForOp outerMostLoop =
100 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
101 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
102 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
103 return newLoopNest;
104}
105
106/// Outline a region with a single block into a new FuncOp.
107/// Assumes the FuncOp result types is the type of the yielded operands of the
108/// single block. This constraint makes it easy to determine the result.
109/// This method also clones the `arith::ConstantIndexOp` at the start of
110/// `outlinedFuncBody` to alloc simple canonicalizations. If `callOp` is
111/// provided, it will be set to point to the operation that calls the outlined
112/// function.
113// TODO: support more than single-block regions.
114// TODO: more flexible constant handling.
115FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
116 Location loc,
117 Region &region,
118 StringRef funcName,
119 func::CallOp *callOp) {
120 assert(!funcName.empty() && "funcName cannot be empty");
121 if (!region.hasOneBlock())
122 return failure();
123
124 Block *originalBlock = &region.front();
125 Operation *originalTerminator = originalBlock->getTerminator();
126
127 // Outline before current function.
128 OpBuilder::InsertionGuard g(rewriter);
129 rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
130
131 SetVector<Value> captures;
132 getUsedValuesDefinedAbove(region, captures);
133
134 ValueRange outlinedValues(captures.getArrayRef());
135 SmallVector<Type> outlinedFuncArgTypes;
136 SmallVector<Location> outlinedFuncArgLocs;
137 // Region's arguments are exactly the first block's arguments as per
138 // Region::getArguments().
139 // Func's arguments are cat(regions's arguments, captures arguments).
140 for (BlockArgument arg : region.getArguments()) {
141 outlinedFuncArgTypes.push_back(arg.getType());
142 outlinedFuncArgLocs.push_back(arg.getLoc());
143 }
144 for (Value value : outlinedValues) {
145 outlinedFuncArgTypes.push_back(value.getType());
146 outlinedFuncArgLocs.push_back(value.getLoc());
147 }
148 FunctionType outlinedFuncType =
149 FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
150 originalTerminator->getOperandTypes());
151 auto outlinedFunc =
152 func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
153 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
154
155 // Merge blocks while replacing the original block operands.
156 // Warning: `mergeBlocks` erases the original block, reconstruct it later.
157 int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
158 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
159 {
160 OpBuilder::InsertionGuard g(rewriter);
161 rewriter.setInsertionPointToEnd(outlinedFuncBody);
162 rewriter.mergeBlocks(
163 originalBlock, outlinedFuncBody,
164 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
165 // Explicitly set up a new ReturnOp terminator.
166 rewriter.setInsertionPointToEnd(outlinedFuncBody);
167 func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
168 originalTerminator->getOperands());
169 }
170
171 // Reconstruct the block that was deleted and add a
172 // terminator(call_results).
173 Block *newBlock = rewriter.createBlock(
174 &region, region.begin(),
175 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
176 ArrayRef<Location>(outlinedFuncArgLocs)
177 .take_front(numOriginalBlockArguments));
178 {
179 OpBuilder::InsertionGuard g(rewriter);
180 rewriter.setInsertionPointToEnd(newBlock);
181 SmallVector<Value> callValues;
182 llvm::append_range(callValues, newBlock->getArguments());
183 llvm::append_range(callValues, outlinedValues);
184 auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
185 if (callOp)
186 *callOp = call;
187
188 // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
189 // Clone `originalTerminator` to take the callOp results then erase it from
190 // `outlinedFuncBody`.
191 IRMapping bvm;
192 bvm.map(originalTerminator->getOperands(), call->getResults());
193 rewriter.clone(*originalTerminator, bvm);
194 rewriter.eraseOp(originalTerminator);
195 }
196
197 // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
198 // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
199 for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
200 outlinedValues.size()))) {
201 Value orig = std::get<0>(it);
202 Value repl = std::get<1>(it);
203 {
204 OpBuilder::InsertionGuard g(rewriter);
205 rewriter.setInsertionPointToStart(outlinedFuncBody);
207 repl = rewriter.clone(*cst)->getResult(0);
208 }
209 }
210 orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
211 return outlinedFunc->isProperAncestor(opOperand.getOwner());
212 });
213 }
214
215 return outlinedFunc;
216}
217
218LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
219 func::FuncOp *thenFn, StringRef thenFnName,
220 func::FuncOp *elseFn, StringRef elseFnName) {
221 IRRewriter rewriter(b);
222 Location loc = ifOp.getLoc();
223 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
224 if (thenFn && !ifOp.getThenRegion().empty()) {
225 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
226 rewriter, loc, ifOp.getThenRegion(), thenFnName);
227 if (failed(outlinedFuncOpOrFailure))
228 return failure();
229 *thenFn = *outlinedFuncOpOrFailure;
230 }
231 if (elseFn && !ifOp.getElseRegion().empty()) {
232 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
233 rewriter, loc, ifOp.getElseRegion(), elseFnName);
234 if (failed(outlinedFuncOpOrFailure))
235 return failure();
236 *elseFn = *outlinedFuncOpOrFailure;
237 }
238 return success();
239}
240
243 assert(rootOp != nullptr && "Root operation must not be a nullptr.");
244 bool rootEnclosesPloops = false;
245 for (Region &region : rootOp->getRegions()) {
246 for (Block &block : region.getBlocks()) {
247 for (Operation &op : block) {
248 bool enclosesPloops = getInnermostParallelLoops(&op, result);
249 rootEnclosesPloops |= enclosesPloops;
250 if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
251 rootEnclosesPloops = true;
252
253 // Collect parallel loop if it is an innermost one.
254 if (!enclosesPloops)
255 result.push_back(ploop);
256 }
257 }
258 }
259 }
260 return rootEnclosesPloops;
261}
262
263// Build the IR that performs ceil division of a positive value by a constant:
264// ceildiv(a, B) = divis(a + (B-1), B)
265// where divis is rounding-to-zero division.
266static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
267 int64_t divisor) {
268 assert(divisor > 0 && "expected positive divisor");
269 assert(dividend.getType().isIntOrIndex() &&
270 "expected integer or index-typed value");
271
272 Value divisorMinusOneCst = arith::ConstantOp::create(
273 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
274 Value divisorCst = arith::ConstantOp::create(
275 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
276 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
277 return arith::DivUIOp::create(builder, loc, sum, divisorCst);
278}
279
280// Build the IR that performs ceil division of a positive value by another
281// positive value:
282// ceildiv(a, b) = divis(a + (b - 1), b)
283// where divis is rounding-to-zero division.
284static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
285 Value divisor) {
286 assert(dividend.getType().isIntOrIndex() &&
287 "expected integer or index-typed value");
288 Value cstOne = arith::ConstantOp::create(
289 builder, loc, builder.getOneAttr(dividend.getType()));
290 Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
291 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
292 return arith::DivUIOp::create(builder, loc, sum, divisor);
293}
294
296 Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
297 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
298 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
299 ValueRange iterArgs, ValueRange yieldedValues,
300 IRMapping *clonedToSrcOpsMap) {
301
302 // Check if the op was cloned from another source op, and return it if found
303 // (or the same op if not found)
304 auto findOriginalSrcOp =
305 [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
306 Operation *srcOp = op;
307 // If the source op derives from another op: traverse the chain to find the
308 // original source op
309 while (srcOp && clonedToSrcOpsMap.contains(srcOp))
310 srcOp = clonedToSrcOpsMap.lookup(srcOp);
311 return srcOp;
312 };
313
314 // Builder to insert unrolled bodies just before the terminator of the body of
315 // the loop.
316 auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
317
318 static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
319 if (!annotateFn)
320 annotateFn = noopAnnotateFn;
321
322 // Keep a pointer to the last non-terminator operation in the original block
323 // so that we know what to clone (since we are doing this in-place).
324 Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
325
326 // Unroll the contents of the loop body (append unrollFactor - 1 additional
327 // copies).
328 SmallVector<Value, 4> lastYielded(yieldedValues);
329
330 for (unsigned i = 1; i < unrollFactor; i++) {
331 // Prepare operand map.
332 IRMapping operandMap;
333 operandMap.map(iterArgs, lastYielded);
334
335 // If the induction variable is used, create a remapping to the value for
336 // this unrolled instance.
337 if (!iv.use_empty()) {
338 Value ivUnroll = ivRemapFn(i, iv, builder);
339 operandMap.map(iv, ivUnroll);
340 }
341
342 // Clone the original body of 'forOp'.
343 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
344 Operation *srcOp = &(*it);
345 Operation *clonedOp = builder.clone(*srcOp, operandMap);
346 annotateFn(i, clonedOp, builder);
347 if (clonedToSrcOpsMap)
348 clonedToSrcOpsMap->map(clonedOp,
349 findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
350 }
351
352 // Update yielded values.
353 for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
354 lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
355 }
356
357 // Make sure we annotate the Ops in the original body. We do this last so that
358 // any annotations are not copied into the cloned Ops above.
359 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
360 annotateFn(0, &*it, builder);
361
362 // Update operands of the yield statement.
363 loopBodyBlock->getTerminator()->setOperands(lastYielded);
364}
365
366/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
367/// epilogue loop, if the loop is unrolled.
368FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
369 scf::ForOp forOp, uint64_t unrollFactor,
370 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
371 assert(unrollFactor > 0 && "expected positive unroll factor");
372
373 // Return if the loop body is empty.
374 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
375 return UnrolledLoopInfo{forOp, std::nullopt};
376
377 // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
378 // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
379 OpBuilder boundsBuilder(forOp);
380 IRRewriter rewriter(forOp.getContext());
381 auto loc = forOp.getLoc();
382 Value step = forOp.getStep();
383 Value upperBoundUnrolled;
384 Value stepUnrolled;
385 bool generateEpilogueLoop = true;
386
387 std::optional<APInt> constTripCount = forOp.getStaticTripCount();
388 if (constTripCount) {
389 // Constant loop bounds computation.
390 int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
391 int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
392 int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
393 if (unrollFactor == 1) {
394 if (constTripCount->isOne() &&
395 failed(forOp.promoteIfSingleIteration(rewriter)))
396 return failure();
397 return UnrolledLoopInfo{forOp, std::nullopt};
398 }
399
400 uint64_t tripCount = constTripCount->getZExtValue();
401 uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
402 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
403 int64_t stepUnrolledCst = stepCst * unrollFactor;
404
405 // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
406 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
407 if (generateEpilogueLoop)
408 upperBoundUnrolled = arith::ConstantOp::create(
409 boundsBuilder, loc,
410 boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
411 upperBoundUnrolledCst));
412 else
413 upperBoundUnrolled = forOp.getUpperBound();
414
415 // Create constant for 'stepUnrolled'.
416 stepUnrolled =
417 stepCst == stepUnrolledCst
418 ? step
419 : arith::ConstantOp::create(boundsBuilder, loc,
420 boundsBuilder.getIntegerAttr(
421 step.getType(), stepUnrolledCst));
422 } else {
423 // Dynamic loop bounds computation.
424 // TODO: Add dynamic asserts for negative lb/ub/step, or
425 // consider using ceilDiv from AffineApplyExpander.
426 auto lowerBound = forOp.getLowerBound();
427 auto upperBound = forOp.getUpperBound();
428 Value diff =
429 arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
430 Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
431 Value unrollFactorCst = arith::ConstantOp::create(
432 boundsBuilder, loc,
433 boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
434 Value tripCountRem =
435 arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
436 // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
437 Value tripCountEvenMultiple =
438 arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
439 // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
440 upperBoundUnrolled = arith::AddIOp::create(
441 boundsBuilder, loc, lowerBound,
442 arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
443 // Scale 'step' by 'unrollFactor'.
444 stepUnrolled =
445 arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
446 }
447
448 UnrolledLoopInfo resultLoops;
449
450 // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
451 if (generateEpilogueLoop) {
452 OpBuilder epilogueBuilder(forOp->getContext());
453 epilogueBuilder.setInsertionPointAfter(forOp);
454 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
455 epilogueForOp.setLowerBound(upperBoundUnrolled);
456
457 // Update uses of loop results.
458 auto results = forOp.getResults();
459 auto epilogueResults = epilogueForOp.getResults();
460
461 for (auto e : llvm::zip(results, epilogueResults)) {
462 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
463 }
464 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
465 epilogueForOp.getInitArgs().size(), results);
466 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
467 resultLoops.epilogueLoopOp = epilogueForOp;
468 }
469
470 // Create unrolled loop.
471 forOp.setUpperBound(upperBoundUnrolled);
472 forOp.setStep(stepUnrolled);
473
474 auto iterArgs = ValueRange(forOp.getRegionIterArgs());
475 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
476
478 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
479 [&](unsigned i, Value iv, OpBuilder b) {
480 // iv' = iv + step * i;
481 auto stride = arith::MulIOp::create(
482 b, loc, step,
483 arith::ConstantOp::create(b, loc,
484 b.getIntegerAttr(iv.getType(), i)));
485 return arith::AddIOp::create(b, loc, iv, stride);
486 },
487 annotateFn, iterArgs, yieldedValues);
488 // Promote the loop body up if this has turned into a single iteration loop.
489 if (forOp.promoteIfSingleIteration(rewriter).failed())
490 resultLoops.mainLoopOp = forOp;
491 return resultLoops;
492}
493
494/// Unrolls this loop completely.
495LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
496 IRRewriter rewriter(forOp.getContext());
497 std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
498 if (!mayBeConstantTripCount.has_value())
499 return failure();
500 const APInt &tripCount = *mayBeConstantTripCount;
501 if (tripCount.isZero())
502 return success();
503 if (tripCount.isOne())
504 return forOp.promoteIfSingleIteration(rewriter);
505 return loopUnrollByFactor(forOp, tripCount.getZExtValue());
506}
507
508/// Check if bounds of all inner loops are defined outside of `forOp`
509/// and return false if not.
510static bool areInnerBoundsInvariant(scf::ForOp forOp) {
511 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
512 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
513 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
514 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
515 return WalkResult::interrupt();
516
517 return WalkResult::advance();
518 });
519 return !walkResult.wasInterrupted();
520}
521
522/// Unrolls and jams this loop by the specified factor.
523LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
524 uint64_t unrollJamFactor) {
525 assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
526
527 if (unrollJamFactor == 1)
528 return success();
529
530 // If any control operand of any inner loop of `forOp` is defined within
531 // `forOp`, no unroll jam.
532 if (!areInnerBoundsInvariant(forOp)) {
533 LDBG() << "failed to unroll and jam: inner bounds are not invariant";
534 return failure();
535 }
536
537 // Currently, for operations with results are not supported.
538 if (forOp->getNumResults() > 0) {
539 LDBG() << "failed to unroll and jam: unsupported loop with results";
540 return failure();
541 }
542
543 // Currently, only constant trip count that divided by the unroll factor is
544 // supported.
545 std::optional<APInt> tripCount = forOp.getStaticTripCount();
546 if (!tripCount.has_value()) {
547 // If the trip count is dynamic, do not unroll & jam.
548 LDBG() << "failed to unroll and jam: trip count could not be determined";
549 return failure();
550 }
551 uint64_t tripCountValue = tripCount->getZExtValue();
552 if (tripCountValue == 0)
553 return success();
554 if (unrollJamFactor > tripCountValue) {
555 LDBG() << "unroll and jam factor is greater than trip count, set factor to "
556 "trip "
557 "count";
558 unrollJamFactor = tripCountValue;
559 } else if (tripCountValue % unrollJamFactor != 0) {
560 LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
561 "multiple of unroll jam factor";
562 return failure();
563 }
564
565 // Nothing in the loop body other than the terminator.
566 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
567 return success();
568
569 // Gather all sub-blocks to jam upon the loop being unrolled.
571 jbg.walk(forOp);
572 auto &subBlocks = jbg.subBlocks;
573
574 // Collect inner loops.
575 SmallVector<scf::ForOp> innerLoops;
576 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
577
578 // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
579 // iteration. There are (`unrollJamFactor` - 1) iterations.
580 SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
581
582 // For any loop with iter_args, replace it with a new loop that has
583 // `unrollJamFactor` copies of its iterOperands, iter_args and yield
584 // operands.
585 SmallVector<scf::ForOp> newInnerLoops;
586 IRRewriter rewriter(forOp.getContext());
587 for (scf::ForOp oldForOp : innerLoops) {
588 SmallVector<Value> dupIterOperands, dupYieldOperands;
589 ValueRange oldIterOperands = oldForOp.getInits();
590 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
591 ValueRange oldYieldOperands =
592 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
593 // Get additional iterOperands, iterArgs, and yield operands. We will
594 // fix iterOperands and yield operands after cloning of sub-blocks.
595 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
596 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
597 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
598 }
599 // Create a new loop with additional iterOperands, iter_args and yield
600 // operands. This new loop will take the loop body of the original loop.
601 bool forOpReplaced = oldForOp == forOp;
602 scf::ForOp newForOp =
603 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
604 rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
605 [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
606 return dupYieldOperands;
607 }));
608 newInnerLoops.push_back(newForOp);
609 // `forOp` has been replaced with a new loop.
610 if (forOpReplaced)
611 forOp = newForOp;
612 // Update `operandMaps` for `newForOp` iterArgs and results.
613 ValueRange newIterArgs = newForOp.getRegionIterArgs();
614 unsigned oldNumIterArgs = oldIterArgs.size();
615 ValueRange newResults = newForOp.getResults();
616 unsigned oldNumResults = newResults.size() / unrollJamFactor;
617 assert(oldNumIterArgs == oldNumResults &&
618 "oldNumIterArgs must be the same as oldNumResults");
619 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
620 for (unsigned j = 0; j < oldNumIterArgs; ++j) {
621 // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
622 // results. Update `operandMaps[i - 1]` to map old iterArgs and results
623 // to those in the `i`th new set.
624 operandMaps[i - 1].map(newIterArgs[j],
625 newIterArgs[i * oldNumIterArgs + j]);
626 operandMaps[i - 1].map(newResults[j],
627 newResults[i * oldNumResults + j]);
628 }
629 }
630 }
631
632 // Scale the step of loop being unroll-jammed by the unroll-jam factor.
633 rewriter.setInsertionPoint(forOp);
634 int64_t step = forOp.getConstantStep()->getSExtValue();
635 auto newStep = rewriter.createOrFold<arith::MulIOp>(
636 forOp.getLoc(), forOp.getStep(),
637 rewriter.createOrFold<arith::ConstantOp>(
638 forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
639 forOp.setStep(newStep);
640 auto forOpIV = forOp.getInductionVar();
641
642 // Unroll and jam (appends unrollJamFactor - 1 additional copies).
643 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
644 for (auto &subBlock : subBlocks) {
645 // Builder to insert unroll-jammed bodies. Insert right at the end of
646 // sub-block.
647 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
648
649 // If the induction variable is used, create a remapping to the value for
650 // this unrolled instance.
651 if (!forOpIV.use_empty()) {
652 // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
653 auto ivTag = builder.createOrFold<arith::ConstantOp>(
654 forOp.getLoc(), builder.getIndexAttr(step * i));
655 auto ivUnroll =
656 builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
657 operandMaps[i - 1].map(forOpIV, ivUnroll);
658 }
659 // Clone the sub-block being unroll-jammed.
660 for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
661 builder.clone(*it, operandMaps[i - 1]);
662 }
663 // Fix iterOperands and yield op operands of newly created loops.
664 for (auto newForOp : newInnerLoops) {
665 unsigned oldNumIterOperands =
666 newForOp.getNumRegionIterArgs() / unrollJamFactor;
667 unsigned numControlOperands = newForOp.getNumControlOperands();
668 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
669 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
670 assert(oldNumIterOperands == oldNumYieldOperands &&
671 "oldNumIterOperands must be the same as oldNumYieldOperands");
672 for (unsigned j = 0; j < oldNumIterOperands; ++j) {
673 // The `i`th duplication of an old iterOperand or yield op operand
674 // needs to be replaced with a mapped value from `operandMaps[i - 1]`
675 // if such mapped value exists.
676 newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
677 operandMaps[i - 1].lookupOrDefault(
678 newForOp.getOperand(numControlOperands + j)));
679 yieldOp.setOperand(
680 i * oldNumYieldOperands + j,
681 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
682 }
683 }
684 }
685
686 // Promote the loop body up if this has turned into a single iteration loop.
687 (void)forOp.promoteIfSingleIteration(rewriter);
688 return success();
689}
690
692 Location loc, OpFoldResult lb,
694 OpFoldResult step) {
695 Range normalizedLoopBounds;
696 normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
697 normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
698 AffineExpr s0, s1, s2;
699 bindSymbols(rewriter.getContext(), s0, s1, s2);
700 AffineExpr e = (s1 - s0).ceilDiv(s2);
701 normalizedLoopBounds.size =
702 affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
703 return normalizedLoopBounds;
704}
705
708 OpFoldResult step) {
709 if (getType(lb).isIndex()) {
710 return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
711 }
712 // For non-index types, generate `arith` instructions
713 // Check if the loop is already known to have a constant zero lower bound or
714 // a constant one step.
715 bool isZeroBased = false;
716 if (auto lbCst = getConstantIntValue(lb))
717 isZeroBased = lbCst.value() == 0;
718
719 bool isStepOne = false;
720 if (auto stepCst = getConstantIntValue(step))
721 isStepOne = stepCst.value() == 1;
722
723 Type rangeType = getType(lb);
724 assert(rangeType == getType(ub) && rangeType == getType(step) &&
725 "expected matching types");
726
727 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
728 // assuming the step is strictly positive. Update the bounds and the step
729 // of the loop to go from 0 to the number of iterations, if necessary.
730 if (isZeroBased && isStepOne)
731 return {lb, ub, step};
732
733 OpFoldResult diff = ub;
734 if (!isZeroBased) {
735 diff = rewriter.createOrFold<arith::SubIOp>(
736 loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
737 getValueOrCreateConstantIntOp(rewriter, loc, lb));
738 }
739 OpFoldResult newUpperBound = diff;
740 if (!isStepOne) {
741 newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
742 loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
743 getValueOrCreateConstantIntOp(rewriter, loc, step));
744 }
745
746 OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
747 OpFoldResult newStep = rewriter.getOneAttr(rangeType);
748
749 return {newLowerBound, newUpperBound, newStep};
750}
751
753 Location loc,
754 Value normalizedIv,
755 OpFoldResult origLb,
756 OpFoldResult origStep) {
757 AffineExpr d0, s0, s1;
758 bindSymbols(rewriter.getContext(), s0, s1);
759 bindDims(rewriter.getContext(), d0);
760 AffineExpr e = d0 * s1 + s0;
762 rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
763 Value denormalizedIvVal =
764 getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
765 SmallPtrSet<Operation *, 1> preservedUses;
766 // If an `affine.apply` operation is generated for denormalization, the use
767 // of `origLb` in those ops must not be replaced. These arent not generated
768 // when `origLb == 0` and `origStep == 1`.
769 if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
770 if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
771 preservedUses.insert(preservedUse);
772 }
773 }
774 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
775}
776
778 Value normalizedIv, OpFoldResult origLb,
779 OpFoldResult origStep) {
780 if (getType(origLb).isIndex()) {
781 return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
782 origLb, origStep);
783 }
784 Value denormalizedIv;
786 bool isStepOne = isOneInteger(origStep);
787 bool isZeroBased = isZeroInteger(origLb);
788
789 Value scaled = normalizedIv;
790 if (!isStepOne) {
791 Value origStepValue =
792 getValueOrCreateConstantIntOp(rewriter, loc, origStep);
793 scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
794 preserve.insert(scaled.getDefiningOp());
795 }
796 denormalizedIv = scaled;
797 if (!isZeroBased) {
798 Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
799 denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
800 preserve.insert(denormalizedIv.getDefiningOp());
801 }
802
803 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
804}
805
807 ArrayRef<OpFoldResult> values) {
808 assert(!values.empty() && "unexecpted empty array");
809 AffineExpr s0, s1;
810 bindSymbols(rewriter.getContext(), s0, s1);
811 AffineExpr mul = s0 * s1;
812 OpFoldResult products = rewriter.getIndexAttr(1);
813 for (auto v : values) {
815 rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
816 }
817 return products;
818}
819
820/// Helper function to multiply a sequence of values.
822 ArrayRef<Value> values) {
823 assert(!values.empty() && "unexpected empty list");
824 if (getType(values.front()).isIndex()) {
826 OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
827 return getValueOrCreateConstantIndexOp(rewriter, loc, product);
828 }
829 std::optional<Value> productOf;
830 for (auto v : values) {
831 auto vOne = getConstantIntValue(v);
832 if (vOne && vOne.value() == 1)
833 continue;
834 if (productOf)
835 productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
836 .getResult();
837 else
838 productOf = v;
839 }
840 if (!productOf) {
841 productOf = arith::ConstantOp::create(
842 rewriter, loc, rewriter.getOneAttr(getType(values.front())))
843 .getResult();
844 }
845 return productOf.value();
846}
847
848/// For each original loop, the value of the
849/// induction variable can be obtained by dividing the induction variable of
850/// the linearized loop by the total number of iterations of the loops nested
851/// in it modulo the number of iterations in this loop (remove the values
852/// related to the outer loops):
853/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
854/// Compute these iteratively from the innermost loop by creating a "running
855/// quotient" of division by the range.
856static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
858 Value linearizedIv, ArrayRef<Value> ubs) {
859
860 if (linearizedIv.getType().isIndex()) {
861 Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
862 rewriter, loc, linearizedIv, ubs);
863 auto resultVals = llvm::map_to_vector(
864 delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
865 return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
866 }
867
868 SmallVector<Value> delinearizedIvs(ubs.size());
869 SmallPtrSet<Operation *, 2> preservedUsers;
870
871 llvm::BitVector isUbOne(ubs.size());
872 for (auto [index, ub] : llvm::enumerate(ubs)) {
873 auto ubCst = getConstantIntValue(ub);
874 if (ubCst && ubCst.value() == 1)
875 isUbOne.set(index);
876 }
877
878 // Prune the lead ubs that are all ones.
879 unsigned numLeadingOneUbs = 0;
880 for (auto [index, ub] : llvm::enumerate(ubs)) {
881 if (!isUbOne.test(index)) {
882 break;
883 }
884 delinearizedIvs[index] = arith::ConstantOp::create(
885 rewriter, loc, rewriter.getZeroAttr(ub.getType()));
886 numLeadingOneUbs++;
887 }
888
889 Value previous = linearizedIv;
890 for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
891 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
892 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
893 previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
894 preservedUsers.insert(previous.getDefiningOp());
895 }
896 Value iv = previous;
897 if (i != e - 1) {
898 if (!isUbOne.test(idx)) {
899 iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
900 preservedUsers.insert(iv.getDefiningOp());
901 } else {
902 iv = arith::ConstantOp::create(
903 rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
904 }
905 }
906 delinearizedIvs[idx] = iv;
907 }
908 return {delinearizedIvs, preservedUsers};
909}
910
911LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
913 if (loops.size() < 2)
914 return failure();
915
916 scf::ForOp innermost = loops.back();
917 scf::ForOp outermost = loops.front();
918
919 // Bail out if any loop has a known zero step, as normalization
920 // would result in a division by zero.
921 for (auto loop : loops) {
922 if (auto step = getConstantIntValue(loop.getStep())) {
923 if (step.value() == 0) {
924 return failure();
925 }
926 }
927 }
928 // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
929 // allows the following code to assume upperBound is the number of iterations.
930 for (auto loop : loops) {
931 OpBuilder::InsertionGuard g(rewriter);
932 rewriter.setInsertionPoint(outermost);
933 Value lb = loop.getLowerBound();
934 Value ub = loop.getUpperBound();
935 Value step = loop.getStep();
936 auto newLoopRange =
937 emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
938
939 rewriter.modifyOpInPlace(loop, [&]() {
940 loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
941 newLoopRange.offset));
942 loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
943 newLoopRange.size));
944 loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
945 newLoopRange.stride));
946 });
947 rewriter.setInsertionPointToStart(innermost.getBody());
948 denormalizeInductionVariable(rewriter, loop.getLoc(),
949 loop.getInductionVar(), lb, step);
950 }
951
952 // 2. Emit code computing the upper bound of the coalesced loop as product
953 // of the number of iterations of all loops.
954 OpBuilder::InsertionGuard g(rewriter);
955 rewriter.setInsertionPoint(outermost);
956 Location loc = outermost.getLoc();
957 SmallVector<Value> upperBounds = llvm::map_to_vector(
958 loops, [](auto loop) { return loop.getUpperBound(); });
959 Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
960 outermost.setUpperBound(upperBound);
961
962 // Insert delinearization at the start of the outermost loop body.
963 rewriter.setInsertionPointToStart(outermost.getBody());
964 auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
965 rewriter, loc, outermost.getInductionVar(), upperBounds);
966 rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
967 preservedUsers);
968
969 for (int i = loops.size() - 1; i > 0; --i) {
970 auto outerLoop = loops[i - 1];
971 auto innerLoop = loops[i];
972
973 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
974 auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
975 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
976 for (Value &yieldedVal : yieldedVals) {
977 // The yielded value may be an iteration argument of the inner loop
978 // which is about to be inlined.
979 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
980 if (iter != innerLoop.getRegionIterArgs().end()) {
981 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
982 // `outerLoop` iter args identical to the `innerLoop` init args.
983 assert(iterArgIndex < innerLoop.getInitArgs().size());
984 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
985 }
986 }
987 rewriter.eraseOp(innerTerminator);
988
989 SmallVector<Value> innerBlockArgs;
990 innerBlockArgs.push_back(delinearizeIvs[i]);
991 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
992 rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
993 Block::iterator(innerLoop), innerBlockArgs);
994 rewriter.replaceOp(innerLoop, yieldedVals);
995 }
996 return success();
997}
998
1000 if (loops.empty()) {
1001 return failure();
1002 }
1003 IRRewriter rewriter(loops.front().getContext());
1004 return coalesceLoops(rewriter, loops);
1005}
1006
1007LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
1008 LogicalResult result(failure());
1010 getPerfectlyNestedLoops(loops, op);
1011
1012 // Look for a band of loops that can be coalesced, i.e. perfectly nested
1013 // loops with bounds defined above some loop.
1014
1015 // 1. For each loop, find above which parent loop its bounds operands are
1016 // defined.
1017 SmallVector<unsigned> operandsDefinedAbove(loops.size());
1018 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1019 operandsDefinedAbove[i] = i;
1020 for (unsigned j = 0; j < i; ++j) {
1021 SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1022 loops[i].getUpperBound(),
1023 loops[i].getStep()};
1024 if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1025 operandsDefinedAbove[i] = j;
1026 break;
1027 }
1028 }
1029 }
1030
1031 // 2. For each inner loop check that the iter_args for the immediately outer
1032 // loop are the init for the immediately inner loop and that the yields of the
1033 // return of the inner loop is the yield for the immediately outer loop. Keep
1034 // track of where the chain starts from for each loop.
1035 SmallVector<unsigned> iterArgChainStart(loops.size());
1036 iterArgChainStart[0] = 0;
1037 for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1038 // By default set the start of the chain to itself.
1039 iterArgChainStart[i] = i;
1040 auto outerloop = loops[i - 1];
1041 auto innerLoop = loops[i];
1042 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1043 continue;
1044 }
1045 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1046 continue;
1047 }
1048 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1049 if (!llvm::equal(outerloopTerminator->getOperands(),
1050 innerLoop.getResults())) {
1051 continue;
1052 }
1053 iterArgChainStart[i] = iterArgChainStart[i - 1];
1054 }
1055
1056 // 3. Identify bands of loops such that the operands of all of them are
1057 // defined above the first loop in the band. Traverse the nest bottom-up
1058 // so that modifications don't invalidate the inner loops.
1059 for (unsigned end = loops.size(); end > 0; --end) {
1060 unsigned start = 0;
1061 for (; start < end - 1; ++start) {
1062 auto maxPos =
1063 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1064 std::next(operandsDefinedAbove.begin(), end));
1065 if (maxPos > start)
1066 continue;
1067 if (iterArgChainStart[end - 1] > start)
1068 continue;
1069 auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1070 if (succeeded(coalesceLoops(band)))
1071 result = success();
1072 break;
1073 }
1074 // If a band was found and transformed, keep looking at the loops above
1075 // the outermost transformed loop.
1076 if (start != end - 1)
1077 end = start + 1;
1078 }
1079 return result;
1080}
1081
1083 RewriterBase &rewriter, scf::ParallelOp loops,
1084 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1085 OpBuilder::InsertionGuard g(rewriter);
1086 rewriter.setInsertionPoint(loops);
1087 Location loc = loops.getLoc();
1088
1089 // Presort combined dimensions.
1090 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1091 for (auto &dims : sortedDimensions)
1092 llvm::sort(dims);
1093
1094 // Normalize ParallelOp's iteration pattern.
1095 SmallVector<Value, 3> normalizedUpperBounds;
1096 for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1097 OpBuilder::InsertionGuard g2(rewriter);
1098 rewriter.setInsertionPoint(loops);
1099 Value lb = loops.getLowerBound()[i];
1100 Value ub = loops.getUpperBound()[i];
1101 Value step = loops.getStep()[i];
1102 auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1103 normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1104 rewriter, loops.getLoc(), newLoopRange.size));
1105
1106 rewriter.setInsertionPointToStart(loops.getBody());
1107 denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1108 step);
1109 }
1110
1111 // Combine iteration spaces.
1112 SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1113 auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1114 auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1115 for (auto &sortedDimension : sortedDimensions) {
1116 Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1117 for (auto idx : sortedDimension) {
1118 newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1119 normalizedUpperBounds[idx]);
1120 }
1121 lowerBounds.push_back(cst0);
1122 steps.push_back(cst1);
1123 upperBounds.push_back(newUpperBound);
1124 }
1125
1126 // Create new ParallelLoop with conversions to the original induction values.
1127 // The loop below uses divisions to get the relevant range of values in the
1128 // new induction value that represent each range of the original induction
1129 // value. The remainders then determine based on that range, which iteration
1130 // of the original induction value this represents. This is a normalized value
1131 // that is un-normalized already by the previous logic.
1132 auto newPloop = scf::ParallelOp::create(
1133 rewriter, loc, lowerBounds, upperBounds, steps,
1134 [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1135 for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1136 Value previous = ploopIVs[i];
1137 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1138 // Iterate over all except the last induction value.
1139 for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1140 unsigned idx = combinedDimensions[i][j];
1141
1142 // Determine the current induction value's current loop iteration
1143 Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1144 normalizedUpperBounds[idx]);
1145 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1146 loops.getRegion());
1147
1148 // Remove the effect of the current induction value to prepare for
1149 // the next value.
1150 previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1151 normalizedUpperBounds[idx]);
1152 }
1153
1154 // The final induction value is just the remaining value.
1155 unsigned idx = combinedDimensions[i][0];
1156 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1157 previous, loops.getRegion());
1158 }
1159 });
1160
1161 // Replace the old loop with the new loop.
1162 loops.getBody()->back().erase();
1163 newPloop.getBody()->getOperations().splice(
1164 Block::iterator(newPloop.getBody()->back()),
1165 loops.getBody()->getOperations());
1166 loops.erase();
1167}
1168
1169// Hoist the ops within `outer` that appear before `inner`.
1170// Such ops include the ops that have been introduced by parametric tiling.
1171// Ops that come from triangular loops (i.e. that belong to the program slice
1172// rooted at `outer`) and ops that have side effects cannot be hoisted.
1173// Return failure when any op fails to hoist.
1174static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1175 SetVector<Operation *> forwardSlice;
1177 options.filter = [&inner](Operation *op) {
1178 return op != inner.getOperation();
1179 };
1180 getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1181 LogicalResult status = success();
1183 for (auto &op : outer.getBody()->without_terminator()) {
1184 // Stop when encountering the inner loop.
1185 if (&op == inner.getOperation())
1186 break;
1187 // Skip over non-hoistable ops.
1188 if (forwardSlice.count(&op) > 0) {
1189 status = failure();
1190 continue;
1191 }
1192 // Skip intermediate scf::ForOp, these are not considered a failure.
1193 if (isa<scf::ForOp>(op))
1194 continue;
1195 // Skip other ops with regions.
1196 if (op.getNumRegions() > 0) {
1197 status = failure();
1198 continue;
1199 }
1200 // Skip if op has side effects.
1201 // TODO: loads to immutable memory regions are ok.
1202 if (!isMemoryEffectFree(&op)) {
1203 status = failure();
1204 continue;
1205 }
1206 toHoist.push_back(&op);
1207 }
1208 auto *outerForOp = outer.getOperation();
1209 for (auto *op : toHoist)
1210 op->moveBefore(outerForOp);
1211 return status;
1212}
1213
1214// Traverse the interTile and intraTile loops and try to hoist ops such that
1215// bands of perfectly nested loops are isolated.
1216// Return failure if either perfect interTile or perfect intraTile bands cannot
1217// be formed.
1218static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1219 LogicalResult status = success();
1220 const Loops &interTile = tileLoops.first;
1221 const Loops &intraTile = tileLoops.second;
1222 auto size = interTile.size();
1223 assert(size == intraTile.size());
1224 if (size <= 1)
1225 return success();
1226 for (unsigned s = 1; s < size; ++s)
1227 status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1228 : failure();
1229 for (unsigned s = 1; s < size; ++s)
1230 status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1231 : failure();
1232 return status;
1233}
1234
1235/// Collect perfectly nested loops starting from `rootForOps`. Loops are
1236/// perfectly nested if each loop is the first and only non-terminator operation
1237/// in the parent loop. Collect at most `maxLoops` loops and append them to
1238/// `forOps`.
1239template <typename T>
1241 SmallVectorImpl<T> &forOps, T rootForOp,
1242 unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1243 for (unsigned i = 0; i < maxLoops; ++i) {
1244 forOps.push_back(rootForOp);
1245 Block &body = rootForOp.getRegion().front();
1246 if (body.begin() != std::prev(body.end(), 2))
1247 return;
1248
1249 rootForOp = dyn_cast<T>(&body.front());
1250 if (!rootForOp)
1251 return;
1252 }
1253}
1254
// Stripmine `forOp` by `factor`: the outer loop's step is multiplied by
// `factor`, and under each loop in `targets` a new inner loop re-runs the
// original step over one tile, clamped to the original upper bound with min.
// Returns the newly created inner loops (one per target).
1255 static Loops stripmineSink(scf::ForOp forOp, Value factor,
1256 ArrayRef<scf::ForOp> targets) {
1257 assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
1258 auto originalStep = forOp.getStep();
1259 auto iv = forOp.getInductionVar();
1260
1261 OpBuilder b(forOp);
// The outer loop now advances one tile (originalStep * factor) at a time.
1262 forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1263
1264 Loops innerLoops;
1265 for (auto t : targets) {
1266 assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
1267
1268 // Save information for splicing ops out of t when done
1269 auto begin = t.getBody()->begin();
1270 auto nOps = t.getBody()->getOperations().size();
1271
1272 // Insert newForOp before the terminator of `t`.
// NOTE: this inner `b` intentionally shadows the outer builder.
1273 auto b = OpBuilder::atBlockTerminator((t.getBody()));
1274 Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1275 Value ub =
1276 arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1277
1278 // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
// `nOps` was captured before the addi/minsi/for ops were appended, so the
// splice moves exactly the original non-terminator body of `t`.
1279 auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1280 newForOp.getBody()->getOperations().splice(
1281 newForOp.getBody()->getOperations().begin(),
1282 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1283 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1284 newForOp.getRegion());
1285
1286 innerLoops.push_back(newForOp);
1287 }
1288
1289 return innerLoops;
1290}
1291
1292// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1293// Returns the new for operation, nested immediately under `target`.
1294template <typename SizeType>
1295static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1296 scf::ForOp target) {
1297 // TODO: Use cheap structural assertions that targets are nested under
1298 // forOp and that targets are not nested under each other when DominanceInfo
1299 // exposes the capability. It seems overkill to construct a whole function
1300 // dominance tree at this point.
1301 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1302 assert(res.size() == 1 && "Expected 1 inner forOp");
1303 return res[0];
1304}
1305
1307 ArrayRef<Value> sizes,
1308 ArrayRef<scf::ForOp> targets) {
1310 SmallVector<scf::ForOp, 8> currentTargets(targets);
1311 for (auto it : llvm::zip(forOps, sizes)) {
1312 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1313 res.push_back(step);
1314 currentTargets = step;
1315 }
1316 return res;
1317}
1318
1320 scf::ForOp target) {
1322 for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1323 res.push_back(llvm::getSingleElement(loops));
1324 return res;
1325}
1326
1328 // Collect perfectly nested loops. If more size values provided than nested
1329 // loops available, truncate `sizes`.
1331 forOps.reserve(sizes.size());
1332 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1333 if (forOps.size() < sizes.size())
1334 sizes = sizes.take_front(forOps.size());
1335
1336 return ::tile(forOps, sizes, forOps.back());
1337}
1338
1340 scf::ForOp root) {
1341 getPerfectlyNestedLoopsImpl(nestedLoops, root);
1342}
1343
1345 ArrayRef<int64_t> sizes) {
1346 // Collect perfectly nested loops. If more size values provided than nested
1347 // loops available, truncate `sizes`.
1349 forOps.reserve(sizes.size());
1350 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1351 if (forOps.size() < sizes.size())
1352 sizes = sizes.take_front(forOps.size());
1353
1354 // Compute the tile sizes such that i-th outer loop executes size[i]
1355 // iterations. Given that the loop current executes
1356 // numIterations = ceildiv((upperBound - lowerBound), step)
1357 // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1358 SmallVector<Value, 4> tileSizes;
1359 tileSizes.reserve(sizes.size());
1360 for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1361 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1362
1363 auto forOp = forOps[i];
1364 OpBuilder builder(forOp);
1365 auto loc = forOp.getLoc();
1366 Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1367 forOp.getLowerBound());
1368 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1369 Value iterationsPerBlock =
1370 ceilDivPositive(builder, loc, numIterations, sizes[i]);
1371 tileSizes.push_back(iterationsPerBlock);
1372 }
1373
1374 // Call parametric tiling with the given sizes.
1375 auto intraTile = tile(forOps, tileSizes, forOps.back());
1376 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1377
1378 // TODO: for now we just ignore the result of band isolation.
1379 // In the future, mapping decisions may be impacted by the ability to
1380 // isolate perfectly nested bands.
1381 (void)tryIsolateBands(tileLoops);
1382
1383 return tileLoops;
1384}
1385
1387 scf::ForallOp source,
1388 RewriterBase &rewriter) {
1389 unsigned numTargetOuts = target.getNumResults();
1390 unsigned numSourceOuts = source.getNumResults();
1391
1392 // Create fused shared_outs.
1393 SmallVector<Value> fusedOuts;
1394 llvm::append_range(fusedOuts, target.getOutputs());
1395 llvm::append_range(fusedOuts, source.getOutputs());
1396
1397 // Create a new scf.forall op after the source loop.
1398 rewriter.setInsertionPointAfter(source);
1399 scf::ForallOp fusedLoop = scf::ForallOp::create(
1400 rewriter, source.getLoc(), source.getMixedLowerBound(),
1401 source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1402 source.getMapping());
1403
1404 // Map control operands.
1405 IRMapping mapping;
1406 mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1407 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1408
1409 // Map shared outs.
1410 mapping.map(target.getRegionIterArgs(),
1411 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1412 mapping.map(source.getRegionIterArgs(),
1413 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1414
1415 // Append everything except the terminator into the fused operation.
1416 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1417 for (Operation &op : target.getBody()->without_terminator())
1418 rewriter.clone(op, mapping);
1419 for (Operation &op : source.getBody()->without_terminator())
1420 rewriter.clone(op, mapping);
1421
1422 // Fuse the old terminator in_parallel ops into the new one.
1423 scf::InParallelOp targetTerm = target.getTerminator();
1424 scf::InParallelOp sourceTerm = source.getTerminator();
1425 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1426 rewriter.setInsertionPointToStart(fusedTerm.getBody());
1427 for (Operation &op : targetTerm.getYieldingOps())
1428 rewriter.clone(op, mapping);
1429 for (Operation &op : sourceTerm.getYieldingOps())
1430 rewriter.clone(op, mapping);
1431
1432 // Replace old loops by substituting their uses by results of the fused loop.
1433 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1434 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1435
1436 return fusedLoop;
1437}
1438
1440 scf::ForOp source,
1441 RewriterBase &rewriter) {
1442 assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1443 "incompatible signedness");
1444 unsigned numTargetOuts = target.getNumResults();
1445 unsigned numSourceOuts = source.getNumResults();
1446
1447 // Create fused init_args, with target's init_args before source's init_args.
1448 SmallVector<Value> fusedInitArgs;
1449 llvm::append_range(fusedInitArgs, target.getInitArgs());
1450 llvm::append_range(fusedInitArgs, source.getInitArgs());
1451
1452 // Create a new scf.for op after the source loop (with scf.yield terminator
1453 // (without arguments) only in case its init_args is empty).
1454 rewriter.setInsertionPointAfter(source);
1455 scf::ForOp fusedLoop = scf::ForOp::create(
1456 rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1457 source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
1458 source.getUnsignedCmp());
1459
1460 // Map original induction variables and operands to those of the fused loop.
1461 IRMapping mapping;
1462 mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1463 mapping.map(target.getRegionIterArgs(),
1464 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1465 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1466 mapping.map(source.getRegionIterArgs(),
1467 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1468
1469 // Merge target's body into the new (fused) for loop and then source's body.
1470 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1471 for (Operation &op : target.getBody()->without_terminator())
1472 rewriter.clone(op, mapping);
1473 for (Operation &op : source.getBody()->without_terminator())
1474 rewriter.clone(op, mapping);
1475
1476 // Build fused yield results by appropriately mapping original yield operands.
1477 SmallVector<Value> yieldResults;
1478 for (Value operand : target.getBody()->getTerminator()->getOperands())
1479 yieldResults.push_back(mapping.lookupOrDefault(operand));
1480 for (Value operand : source.getBody()->getTerminator()->getOperands())
1481 yieldResults.push_back(mapping.lookupOrDefault(operand));
1482 if (!yieldResults.empty())
1483 scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1484
1485 // Replace old loops by substituting their uses by results of the fused loop.
1486 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1487 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1488
1489 return fusedLoop;
1490}
1491
1492FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1493 scf::ForallOp forallOp) {
1494 SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1495 SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1496 SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1497
1498 if (forallOp.isNormalized())
1499 return forallOp;
1500
1501 OpBuilder::InsertionGuard g(rewriter);
1502 auto loc = forallOp.getLoc();
1503 rewriter.setInsertionPoint(forallOp);
1505 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1506 Range normalizedLoopParams =
1507 emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1508 newUbs.push_back(normalizedLoopParams.size);
1509 }
1510 (void)foldDynamicIndexList(newUbs);
1511
1512 // Use the normalized builder since the lower bounds are always 0 and the
1513 // steps are always 1.
1514 auto normalizedForallOp = scf::ForallOp::create(
1515 rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1516 [](OpBuilder &, Location, ValueRange) {});
1517
1518 rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1519 normalizedForallOp.getBodyRegion(),
1520 normalizedForallOp.getBodyRegion().begin());
1521 // Remove the original empty block in the new loop.
1522 rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1523
1524 rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1525 // Update the users of the original loop variables.
1526 for (auto [idx, iv] :
1527 llvm::enumerate(normalizedForallOp.getInductionVars())) {
1528 auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1529 auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1530 denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1531 }
1532
1533 rewriter.replaceOp(forallOp, normalizedForallOp);
1534 return normalizedForallOp;
1535}
1536
1539 assert(!loops.empty() && "unexpected empty loop nest");
1540 if (loops.size() == 1)
1541 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1542 for (auto [outerLoop, innerLoop] :
1543 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1544 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1545 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1546 if (!outerFor || !innerFor)
1547 return false;
1548 auto outerBBArgs = outerFor.getRegionIterArgs();
1549 auto innerIterArgs = innerFor.getInitArgs();
1550 if (outerBBArgs.size() != innerIterArgs.size())
1551 return false;
1552
1553 for (auto [outerBBArg, innerIterArg] :
1554 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1555 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1556 innerIterArg != outerBBArg)
1557 return false;
1558 }
1559
1560 ValueRange outerYields =
1561 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1562 ValueRange innerResults = innerFor.getResults();
1563 if (outerYields.size() != innerResults.size())
1564 return false;
1565 for (auto [outerYield, innerResult] :
1566 llvm::zip_equal(outerYields, innerResults)) {
1567 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1568 outerYield != innerResult)
1569 return false;
1570 }
1571 }
1572 return true;
1573}
1574
/// Return the constant (lb, ub, step) triple for each dimension of `loopOp`,
/// or an empty result if any bound/step is unavailable or non-constant.
// NOTE(review): the doc-page extraction dropped source lines 1575 (this
// definition's return type) and 1582 (the declaration of `loopRanges`);
// consult upstream Utils.cpp for the exact types.
1576 mlir::getConstLoopBounds(mlir::LoopLikeOpInterface loopOp) {
1577 std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1578 std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1579 std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1580 if (!loBnds || !upBnds || !steps)
1581 return {};
1583 for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1584 auto lbCst = getConstantIntValue(lb);
1585 auto ubCst = getConstantIntValue(ub);
1586 auto stepCst = getConstantIntValue(step);
// Bail out entirely as soon as any dimension is non-constant.
1587 if (!lbCst || !ubCst || !stepCst)
1588 return {};
1589 loopRanges.emplace_back(*lbCst, *ubCst, *stepCst);
1590 }
1591 return loopRanges;
1592}
1593
/// Return the constant trip count of each dimension of `loopOp`, or an empty
/// result if any bound/step is unavailable or the trip count is not constant.
// NOTE(review): the doc-page extraction dropped source lines 1594 (this
// definition's return type) and 1601 (the declaration of `tripCounts`,
// presumably a vector of llvm::APInt given the push_back below — confirm
// against upstream Utils.cpp).
1595 mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
1596 std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1597 std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1598 std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1599 if (!loBnds || !upBnds || !steps)
1600 return {};
1602 for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1603 // TODO(#178506): Signedness is not handled correctly here.
1604 std::optional<llvm::APInt> numIter = constantTripCount(
1605 lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
1606 if (!numIter)
1607 return {};
1608 tripCounts.push_back(*numIter);
1609 }
1610 return tripCounts;
1611}
1612
/// Unroll the body of `op` along each (inner) dimension by the corresponding
/// factor in `unrollFactors`, multiplying that dimension's step accordingly.
/// Bails out (match failure) for all-one factors, empty bodies, non-constant
/// trip counts, or factors that do not divide the iteration space evenly.
// NOTE(review): the doc-page extraction dropped source lines 1639 (the
// declaration/initialization of `tripCounts`), 1655 (the declaration of
// `steps`), and 1670 (the call that performs the actual body unrolling whose
// arguments appear below) — consult upstream Utils.cpp.
1613 FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
1614 scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
1615 RewriterBase &rewriter,
1616 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1617 IRMapping *clonedToSrcOpsMap) {
1618 const unsigned numLoops = op.getNumLoops();
1619 assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
1620 "Expected positive unroll factors");
1621 assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
1622 "Expected non-empty unroll factors of size <= to the number of loops");
1623
1624 // Bail out if no valid unroll factors were provided
1625 if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
1626 return rewriter.notifyMatchFailure(
1627 op, "Unrolling not applied if all factors are 1");
1628
1629 // Return if the loop body is empty.
1630 if (llvm::hasSingleElement(op.getBody()->getOperations()))
1631 return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
1632
1633 // If the provided unroll factors do not cover all the loop dims, they are
1634 // applied to the inner loop dimensions.
1635 const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
1636
1637 // Make sure that the unroll factors divide the iteration space evenly
1638 // TODO: Support unrolling loops with dynamic iteration spaces.
1640 if (tripCounts.empty())
1641 return rewriter.notifyMatchFailure(
1642 op, "Failed to compute constant trip counts for the loop. Note that "
1643 "dynamic loop sizes are not supported.");
1644
1645 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1646 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1647 if (tripCounts[dimIdx].urem(unrollFactor) != 0)
1648 return rewriter.notifyMatchFailure(
1649 op, "Unroll factors don't divide the iteration space evenly");
1650 }
1651
1652 std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
1653 if (!maybeFoldSteps)
1654 return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
1656 for (auto step : *maybeFoldSteps)
1657 steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
1658
1659 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1660 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1661 if (unrollFactor == 1)
1662 continue;
1663 const size_t origStep = steps[dimIdx];
1664 const int64_t newStep = origStep * unrollFactor;
// NOTE(review): this local IRMapping shadows the `clonedToSrcOpsMap`
// pointer parameter, so the caller's map is never populated — looks like a
// bug; confirm intent against upstream.
1665 IRMapping clonedToSrcOpsMap;
1666
1667 ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
1668 auto yieldedValues = op.getBody()->getTerminator()->getOperands();
1669
1671 op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
1672 [&](unsigned i, Value iv, OpBuilder b) {
1673 // iv' = iv + step * i;
1674 const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
1675 const auto map =
1676 b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
1677 return affine::AffineApplyOp::create(b, iv.getLoc(), map,
1678 ValueRange{iv});
1679 },
1680 /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
1681
1682 // Update loop step
1683 auto prevInsertPoint = rewriter.saveInsertionPoint();
1684 rewriter.setInsertionPoint(op);
1685 op.getStepMutable()[dimIdx].assign(
1686 arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
1687 rewriter.restoreInsertionPoint(prevInsertPoint);
1688 }
1689 return op;
1690}
return success()
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition Utils.cpp:806
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition Utils.cpp:1218
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
Definition Utils.cpp:1240
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition Utils.cpp:1174
static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition Utils.cpp:691
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition Utils.cpp:1255
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition Utils.cpp:266
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition Utils.cpp:821
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition Utils.cpp:857
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition Utils.cpp:752
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition Utils.cpp:510
static int64_t product(ArrayRef< int64_t > vals)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
#define mul(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext * getContext() const
Definition Builders.h:56
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition Builders.h:387
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition Builders.h:254
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition Builders.h:392
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
operand_type_range getOperandTypes()
Definition Operation.h:426
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
Definition Operation.h:444
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition Utils.cpp:1339
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition Utils.cpp:1537
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition Utils.cpp:218
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition Utils.cpp:36
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an affine.for to find a band to coalesce.
Definition Utils.cpp:1007
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
void generateUnrolledLoop(Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues, IRMapping *clonedToSrcOpsMap=nullptr)
Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs' and 'yieldedValues' as the...
Definition Utils.cpp:295
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:103
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:495
llvm::SmallVector< llvm::APInt > getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp)
Get constant trip counts for each of the induction variables of the given loop operation.
Definition Utils.cpp:1595
std::pair< Loops, Loops > TileLoops
Definition Utils.h:156
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SmallVector< std::tuple< int64_t, int64_t, int64_t > > getConstLoopBounds(mlir::LoopLikeOpInterface loopOp)
Get constant loop bounds and steps for each of the induction variables of the given loop operation,...
Definition Utils.cpp:1576
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned > > combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition Utils.cpp:1082
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
SliceOptions ForwardSliceOptions
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition Utils.cpp:1327
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition Utils.cpp:368
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition Utils.cpp:523
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition Utils.cpp:241
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
FailureOr< scf::ParallelOp > parallelLoopUnrollByFactors(scf::ParallelOp op, ArrayRef< uint64_t > unrollFactors, RewriterBase &rewriter, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, IRMapping *clonedToSrcOpsMap=nullptr)
Unroll this scf::Parallel loop by the specified unroll factors.
Definition Utils.cpp:1613
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1306
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition Utils.cpp:115
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition Utils.cpp:777
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition Utils.cpp:1386
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition Utils.cpp:999
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition Utils.cpp:1439
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition Utils.cpp:1344
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition Utils.cpp:706
SmallVector< scf::ForOp, 8 > Loops
Tile a nest of standard for loops rooted at rootForOp by finding such parametric tile sizes that the ...
Definition Utils.h:155
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition Utils.cpp:1492
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition Utils.h:117
std::optional< scf::ForOp > mainLoopOp
Definition Utils.h:116
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.