MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous loop transformation routines.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/IRMapping.h"
25#include "llvm/ADT/APInt.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/ADT/SmallVectorExtras.h"
29#include "llvm/Support/DebugLog.h"
30#include <cstdint>
31
32using namespace mlir;
33
34#define DEBUG_TYPE "scf-utils"
35
// Rewrites a perfectly nested `scf.for` nest so that every level additionally
// yields the values produced by `newYieldValuesFn`, threading
// `newIterOperands` through the nest as iter_args. Returns the newly created
// nest, outermost loop first.
// NOTE(review): the declaration line (doxygen line 36) is cut off above this
// chunk; judging by the recursive call below, these parameters belong to
// mlir::replaceLoopNestWithNewYields — confirm against SCF/Utils.h.
37 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
38 ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
39 bool replaceIterOperandsUsesInLoop) {
40 if (loopNest.empty())
41 return {};
42 // This method is recursive (to make it more readable). Adding an
43 // assertion here to limit the recursion. (See
44 // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
45 assert(loopNest.size() <= 10 &&
46 "exceeded recursion limit when yielding value from loop nest");
47
48 // To yield a value from a perfectly nested loop nest, the following
49 // pattern needs to be created, i.e. starting with
50 //
51 // ```mlir
52 // scf.for .. {
53 // scf.for .. {
54 // scf.for .. {
55 // %value = ...
56 // }
57 // }
58 // }
59 // ```
60 //
61 // needs to be modified to
62 //
63 // ```mlir
64 // %0 = scf.for .. iter_args(%arg0 = %init) {
65 // %1 = scf.for .. iter_args(%arg1 = %arg0) {
66 // %2 = scf.for .. iter_args(%arg2 = %arg1) {
67 // %value = ...
68 // scf.yield %value
69 // }
70 // scf.yield %2
71 // }
72 // scf.yield %1
73 // }
74 // ```
75 //
76 // The inner most loop is handled using the `replaceWithAdditionalYields`
77 // that works on a single loop.
78 if (loopNest.size() == 1) {
79 auto innerMostLoop =
80 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
81 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
82 newYieldValuesFn));
83 return {innerMostLoop};
84 }
85 // The outer loops are modified by calling this method recursively
86 // - The return value of the inner loop is the value yielded by this loop.
87 // - The region iter args of this loop are the init_args for the inner loop.
88 SmallVector<scf::ForOp> newLoopNest;
// NOTE(review): doxygen lines 89 and 91 are missing from this view —
// presumably `NewYieldValuesFn fn =` and the lambda's trailing
// `ArrayRef<BlockArgument> innerNewBBArgs) {` parameter (which is referenced
// below); confirm against the upstream source.
90 [&](OpBuilder &innerBuilder, Location loc,
92 newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
93 innerNewBBArgs, newYieldValuesFn,
94 replaceIterOperandsUsesInLoop);
// The inner nest's new results are what this level must yield.
95 return llvm::map_to_vector(
96 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
97 [](OpResult r) -> Value { return r; });
98 };
99 scf::ForOp outerMostLoop =
100 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
101 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
// The recursion filled `newLoopNest` with the inner loops; prepend the
// rebuilt outermost loop so the result is ordered outermost -> innermost.
102 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
103 return newLoopNest;
104}
105
106/// Outline a region with a single block into a new FuncOp.
107/// Assumes the FuncOp result types is the type of the yielded operands of the
108/// single block. This constraint makes it easy to determine the result.
109/// This method also clones the `arith::ConstantIndexOp` at the start of
110/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
111/// provided, it will be set to point to the operation that calls the outlined
112/// function.
113// TODO: support more than single-block regions.
114// TODO: more flexible constant handling.
115FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
116 Location loc,
117 Region &region,
118 StringRef funcName,
119 func::CallOp *callOp) {
120 assert(!funcName.empty() && "funcName cannot be empty");
// Only single-block regions are supported (see TODO above).
121 if (!region.hasOneBlock())
122 return failure();
123
124 Block *originalBlock = &region.front();
125 Operation *originalTerminator = originalBlock->getTerminator();
126
127 // Outline before current function.
128 OpBuilder::InsertionGuard g(rewriter);
129 rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
130
// Values defined above the region become extra trailing arguments of the
// outlined function.
131 SetVector<Value> captures;
132 getUsedValuesDefinedAbove(region, captures);
133
134 ValueRange outlinedValues(captures.getArrayRef());
135 SmallVector<Type> outlinedFuncArgTypes;
136 SmallVector<Location> outlinedFuncArgLocs;
137 // Region's arguments are exactly the first block's arguments as per
138 // Region::getArguments().
139 // Func's arguments are cat(regions's arguments, captures arguments).
140 for (BlockArgument arg : region.getArguments()) {
141 outlinedFuncArgTypes.push_back(arg.getType());
142 outlinedFuncArgLocs.push_back(arg.getLoc());
143 }
144 for (Value value : outlinedValues) {
145 outlinedFuncArgTypes.push_back(value.getType());
146 outlinedFuncArgLocs.push_back(value.getLoc());
147 }
// Result types mirror the operands of the region's terminator.
148 FunctionType outlinedFuncType =
149 FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
150 originalTerminator->getOperandTypes());
151 auto outlinedFunc =
152 func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
153 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
154
155 // Merge blocks while replacing the original block operands.
156 // Warning: `mergeBlocks` erases the original block, reconstruct it later.
157 int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
158 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
159 {
160 OpBuilder::InsertionGuard g(rewriter);
161 rewriter.setInsertionPointToEnd(outlinedFuncBody);
162 rewriter.mergeBlocks(
163 originalBlock, outlinedFuncBody,
164 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
165 // Explicitly set up a new ReturnOp terminator.
166 rewriter.setInsertionPointToEnd(outlinedFuncBody);
167 func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
168 originalTerminator->getOperands());
169 }
170
171 // Reconstruct the block that was deleted and add a
172 // terminator(call_results).
173 Block *newBlock = rewriter.createBlock(
174 &region, region.begin(),
175 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
176 ArrayRef<Location>(outlinedFuncArgLocs)
177 .take_front(numOriginalBlockArguments));
178 {
179 OpBuilder::InsertionGuard g(rewriter);
180 rewriter.setInsertionPointToEnd(newBlock);
// The call forwards the reconstructed block's arguments plus the captures.
181 SmallVector<Value> callValues;
182 llvm::append_range(callValues, newBlock->getArguments());
183 llvm::append_range(callValues, outlinedValues);
184 auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
185 if (callOp)
186 *callOp = call;
187
188 // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
189 // Clone `originalTerminator` to take the callOp results then erase it from
190 // `outlinedFuncBody`.
191 IRMapping bvm;
192 bvm.map(originalTerminator->getOperands(), call->getResults());
193 rewriter.clone(*originalTerminator, bvm);
194 rewriter.eraseOp(originalTerminator);
195 }
196
197 // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
198 // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
199 for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
200 outlinedValues.size()))) {
201 Value orig = std::get<0>(it);
202 Value repl = std::get<1>(it);
203 {
204 OpBuilder::InsertionGuard g(rewriter);
205 rewriter.setInsertionPointToStart(outlinedFuncBody);
// NOTE(review): doxygen line 206 is missing from this view — presumably a
// guard like `if (auto *cst = orig.getDefiningOp<arith::ConstantIndexOp>())`
// that introduces the `cst` used on the next line; confirm upstream.
207 repl = rewriter.clone(*cst)->getResult(0);
208 }
209 }
// Only rewrite uses that live inside the outlined function; uses in the
// original region must keep referring to the captured SSA value.
210 orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
211 return outlinedFunc->isProperAncestor(opOperand.getOwner());
212 });
213 }
214
215 return outlinedFunc;
216}
217
218LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
219 func::FuncOp *thenFn, StringRef thenFnName,
220 func::FuncOp *elseFn, StringRef elseFnName) {
221 IRRewriter rewriter(b);
222 Location loc = ifOp.getLoc();
223 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
224 if (thenFn && !ifOp.getThenRegion().empty()) {
225 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
226 rewriter, loc, ifOp.getThenRegion(), thenFnName);
227 if (failed(outlinedFuncOpOrFailure))
228 return failure();
229 *thenFn = *outlinedFuncOpOrFailure;
230 }
231 if (elseFn && !ifOp.getElseRegion().empty()) {
232 outlinedFuncOpOrFailure = outlineSingleBlockRegion(
233 rewriter, loc, ifOp.getElseRegion(), elseFnName);
234 if (failed(outlinedFuncOpOrFailure))
235 return failure();
236 *elseFn = *outlinedFuncOpOrFailure;
237 }
238 return success();
239}
240
// Recursively walks the regions of `rootOp`, appending every innermost
// scf.parallel loop (one that encloses no other scf.parallel) to `result`.
// Returns true if `rootOp` transitively encloses at least one scf.parallel.
// NOTE(review): the declaration (doxygen lines 241-242) is cut off above —
// presumably `bool mlir::getInnermostParallelLoops(Operation *rootOp,
// SmallVectorImpl<scf::ParallelOp> &result) {`; confirm against SCF/Utils.h.
243 assert(rootOp != nullptr && "Root operation must not be a nullptr.");
244 bool rootEnclosesPloops = false;
245 for (Region &region : rootOp->getRegions()) {
246 for (Block &block : region.getBlocks()) {
247 for (Operation &op : block) {
// Recurse first so we know whether `op` itself encloses parallel loops.
248 bool enclosesPloops = getInnermostParallelLoops(&op, result);
249 rootEnclosesPloops |= enclosesPloops;
250 if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
251 rootEnclosesPloops = true;
252
253 // Collect parallel loop if it is an innermost one.
254 if (!enclosesPloops)
255 result.push_back(ploop);
256 }
257 }
258 }
259 }
260 return rootEnclosesPloops;
261}
262
263// Build the IR that performs ceil division of a positive value by a constant:
264// ceildiv(a, B) = divis(a + (B-1), B)
265// where divis is rounding-to-zero division.
266static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
267 int64_t divisor) {
268 assert(divisor > 0 && "expected positive divisor");
269 assert(dividend.getType().isIntOrIndex() &&
270 "expected integer or index-typed value");
271
272 Value divisorMinusOneCst = arith::ConstantOp::create(
273 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
274 Value divisorCst = arith::ConstantOp::create(
275 builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
276 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
277 return arith::DivUIOp::create(builder, loc, sum, divisorCst);
278}
279
280// Build the IR that performs ceil division of a positive value by another
281// positive value:
282// ceildiv(a, b) = divis(a + (b - 1), b)
283// where divis is rounding-to-zero division.
284static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
285 Value divisor) {
286 assert(dividend.getType().isIntOrIndex() &&
287 "expected integer or index-typed value");
288 Value cstOne = arith::ConstantOp::create(
289 builder, loc, builder.getOneAttr(dividend.getType()));
290 Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
291 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
292 return arith::DivUIOp::create(builder, loc, sum, divisor);
293}
294
// Clones the body of `loopBodyBlock` (`unrollFactor` - 1) extra times in
// place, remapping the induction variable via `ivRemapFn` and rethreading
// iter_args/yielded values between copies. `annotateFn` (if provided) is
// invoked on every op with the index of its unrolled copy.
// NOTE(review): the declaration line (doxygen line 295) is cut off above —
// presumably `static void generateUnrolledLoop(`; confirm upstream.
296 Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
297 function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
298 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
299 ValueRange iterArgs, ValueRange yieldedValues,
300 IRMapping *clonedToSrcOpsMap) {
301
302 // Check if the op was cloned from another source op, and return it if found
303 // (or the same op if not found)
304 auto findOriginalSrcOp =
305 [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
306 Operation *srcOp = op;
307 // If the source op derives from another op: traverse the chain to find the
308 // original source op
309 while (srcOp && clonedToSrcOpsMap.contains(srcOp))
310 srcOp = clonedToSrcOpsMap.lookup(srcOp);
311 return srcOp;
312 };
313
314 // Builder to insert unrolled bodies just before the terminator of the body of
315 // the loop.
316 auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
317
// Fall back to a no-op annotation callback so the loop below can call it
// unconditionally.
318 static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
319 if (!annotateFn)
320 annotateFn = noopAnnotateFn;
321
322 // Keep a pointer to the last non-terminator operation in the original block
323 // so that we know what to clone (since we are doing this in-place).
324 Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
325
326 // Unroll the contents of the loop body (append unrollFactor - 1 additional
327 // copies).
328 SmallVector<Value, 4> lastYielded(yieldedValues);
329
330 for (unsigned i = 1; i < unrollFactor; i++) {
331 // Prepare operand map.
// Each copy consumes the values the previous copy would have yielded.
332 IRMapping operandMap;
333 operandMap.map(iterArgs, lastYielded);
334
335 // If the induction variable is used, create a remapping to the value for
336 // this unrolled instance.
337 if (!iv.use_empty()) {
338 Value ivUnroll = ivRemapFn(i, iv, builder);
339 operandMap.map(iv, ivUnroll);
340 }
341
342 // Clone the original body of 'forOp'.
343 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
344 Operation *srcOp = &(*it);
345 Operation *clonedOp = builder.clone(*srcOp, operandMap);
346 annotateFn(i, clonedOp, builder);
// Record clone provenance, collapsing chains so the map always points at
// the original (pre-unroll) source op.
347 if (clonedToSrcOpsMap)
348 clonedToSrcOpsMap->map(clonedOp,
349 findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
350 }
351
352 // Update yielded values.
353 for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
354 lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
355 }
356
357 // Make sure we annotate the Ops in the original body. We do this last so that
358 // any annotations are not copied into the cloned Ops above.
359 for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
360 annotateFn(0, &*it, builder);
361
362 // Update operands of the yield statement.
363 loopBodyBlock->getTerminator()->setOperands(lastYielded);
364}
365
366/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
367/// epilogue loop, if the loop is unrolled.
368FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
369 scf::ForOp forOp, uint64_t unrollFactor,
370 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
371 assert(unrollFactor > 0 && "expected positive unroll factor");
372
373 // Return if the loop body is empty.
374 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
375 return UnrolledLoopInfo{forOp, std::nullopt};
376
377 // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
378 // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
379 OpBuilder boundsBuilder(forOp);
380 IRRewriter rewriter(forOp.getContext());
381 auto loc = forOp.getLoc();
382 Value step = forOp.getStep();
383 Value upperBoundUnrolled;
384 Value stepUnrolled;
385 bool generateEpilogueLoop = true;
386
387 std::optional<APInt> constTripCount = forOp.getStaticTripCount();
388 if (constTripCount) {
389 // Constant loop bounds computation.
390 bool isUnsignedLoop = forOp.getUnsignedCmp();
391 // For unsigned loops, bounds must be zero-extended: narrow integer types
392 // (e.g. i1, i2, i3) may have bit patterns that are negative in a signed
393 // context (e.g., i1 value 1 has getSExtValue() == -1, getZExtValue() == 1).
394 // Zero-extension is only safe when the unsigned value fits in int64_t, i.e.
395 // the type's bitwidth is < 64. Bail out for 64-bit unsigned loops.
396 if (isUnsignedLoop) {
397 if (auto intTy = dyn_cast<IntegerType>(forOp.getUpperBound().getType()))
398 if (intTy.getWidth() >= 64)
399 return failure();
400 }
401 auto getLoopBound = [&](Value v) -> int64_t {
402 auto apInt = getConstantAPIntValue(v);
403 assert(apInt && "expected constant loop bound");
404 return isUnsignedLoop ? static_cast<int64_t>(apInt->first.getZExtValue())
405 : apInt->first.getSExtValue();
406 };
407 int64_t lbCst = getLoopBound(forOp.getLowerBound());
408 int64_t ubCst = getLoopBound(forOp.getUpperBound());
409 int64_t stepCst = getLoopBound(step);
// Factor 1 means "no unrolling": just try to promote single-iteration loops.
410 if (unrollFactor == 1) {
411 if (constTripCount->isOne() &&
412 failed(forOp.promoteIfSingleIteration(rewriter)))
413 return failure();
414 return UnrolledLoopInfo{forOp, std::nullopt};
415 }
416
417 uint64_t tripCount = constTripCount->getZExtValue();
418 uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
419 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
420 int64_t stepUnrolledCst = stepCst * unrollFactor;
421
422 // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
423 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
424 if (generateEpilogueLoop)
425 upperBoundUnrolled = arith::ConstantOp::create(
426 boundsBuilder, loc,
427 boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
428 upperBoundUnrolledCst));
429 else
430 upperBoundUnrolled = forOp.getUpperBound();
431
432 // Create constant for 'stepUnrolled'. When the main loop has zero
433 // iterations (tripCountEvenMultiple == 0), keep the original step.
434 // stepCst * unrollFactor may produce a value that, when truncated to the
435 // bound type's bitwidth during IntegerAttr construction, wraps to zero; a
436 // zero step causes constantTripCount to return nullopt instead of 0, which
437 // prevents the zero-trip main loop from being elided.
438 bool mainLoopHasNoIter = (tripCountEvenMultiple == 0);
439 bool stepUnchanged = (stepCst == stepUnrolledCst);
440 stepUnrolled =
441 (mainLoopHasNoIter || stepUnchanged)
442 ? step
443 : arith::ConstantOp::create(boundsBuilder, loc,
444 boundsBuilder.getIntegerAttr(
445 step.getType(), stepUnrolledCst));
446 } else {
447 // Dynamic loop bounds computation.
448 // TODO: Add dynamic asserts for negative lb/ub/step, or
449 // consider using ceilDiv from AffineApplyExpander.
450 auto lowerBound = forOp.getLowerBound();
451 auto upperBound = forOp.getUpperBound();
452 Value diff =
453 arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
454 Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
455 Value unrollFactorCst = arith::ConstantOp::create(
456 boundsBuilder, loc,
457 boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
458 Value tripCountRem =
459 arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
460 // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
461 Value tripCountEvenMultiple =
462 arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
463 // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
464 upperBoundUnrolled = arith::AddIOp::create(
465 boundsBuilder, loc, lowerBound,
466 arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
467 // Scale 'step' by 'unrollFactor'.
468 stepUnrolled =
469 arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
470 }
471
472 UnrolledLoopInfo resultLoops;
473
474 // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
475 if (generateEpilogueLoop) {
476 OpBuilder epilogueBuilder(forOp->getContext());
477 epilogueBuilder.setInsertionPointAfter(forOp);
478 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
479 epilogueForOp.setLowerBound(upperBoundUnrolled);
480
481 // Update uses of loop results.
482 auto results = forOp.getResults();
483 auto epilogueResults = epilogueForOp.getResults();
484
485 for (auto e : llvm::zip(results, epilogueResults)) {
486 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
487 }
// The epilogue's iter_args are seeded from the main loop's results.
488 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
489 epilogueForOp.getInitArgs().size(), results);
490 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
491 resultLoops.epilogueLoopOp = epilogueForOp;
492 }
493
494 // Create unrolled loop.
495 forOp.setUpperBound(upperBoundUnrolled);
496 forOp.setStep(stepUnrolled);
497
498 auto iterArgs = ValueRange(forOp.getRegionIterArgs());
499 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
500
// NOTE(review): doxygen line 501 is missing from this view — presumably the
// call `generateUnrolledLoop(` taking the arguments below; confirm upstream.
502 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
503 [&](unsigned i, Value iv, OpBuilder b) {
504 // iv' = iv + step * i;
505 auto stride = arith::MulIOp::create(
506 b, loc, step,
507 arith::ConstantOp::create(b, loc,
508 b.getIntegerAttr(iv.getType(), i)));
509 return arith::AddIOp::create(b, loc, iv, stride);
510 },
511 annotateFn, iterArgs, yieldedValues);
512 // Promote the loop body up if this has turned into a single iteration loop.
513 if (forOp.promoteIfSingleIteration(rewriter).failed())
514 resultLoops.mainLoopOp = forOp;
515 return resultLoops;
516}
517
518/// Unrolls this loop completely.
519LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
520 IRRewriter rewriter(forOp.getContext());
521 std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
522 if (!mayBeConstantTripCount.has_value())
523 return failure();
524 const APInt &tripCount = *mayBeConstantTripCount;
525 if (tripCount.isZero())
526 return success();
527 if (tripCount.isOne())
528 return forOp.promoteIfSingleIteration(rewriter);
529 return loopUnrollByFactor(forOp, tripCount.getZExtValue());
530}
531
532/// Check if bounds of all inner loops are defined outside of `forOp`
533/// and return false if not.
534static bool areInnerBoundsInvariant(scf::ForOp forOp) {
535 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
536 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
537 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
538 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
539 return WalkResult::interrupt();
540
541 return WalkResult::advance();
542 });
543 return !walkResult.wasInterrupted();
544}
545
546/// Unrolls and jams this loop by the specified factor.
547LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
548 uint64_t unrollJamFactor) {
549 assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
550
551 if (unrollJamFactor == 1)
552 return success();
553
554 // If any control operand of any inner loop of `forOp` is defined within
555 // `forOp`, no unroll jam.
556 if (!areInnerBoundsInvariant(forOp)) {
557 LDBG() << "failed to unroll and jam: inner bounds are not invariant";
558 return failure();
559 }
560
561 // Currently, scf.for operations with results are not supported.
562 if (forOp->getNumResults() > 0) {
563 LDBG() << "failed to unroll and jam: unsupported loop with results";
564 return failure();
565 }
566
567 // Currently, only a constant trip count that is a multiple of the unroll
568 // factor is supported.
569 std::optional<APInt> tripCount = forOp.getStaticTripCount();
570 if (!tripCount.has_value()) {
571 // If the trip count is dynamic, do not unroll & jam.
572 LDBG() << "failed to unroll and jam: trip count could not be determined";
573 return failure();
574 }
575 uint64_t tripCountValue = tripCount->getZExtValue();
576 if (tripCountValue == 0)
577 return success();
578 if (unrollJamFactor > tripCountValue) {
579 LDBG() << "unroll and jam factor is greater than trip count, set factor to "
580 "trip "
581 "count";
582 unrollJamFactor = tripCountValue;
583 } else if (tripCountValue % unrollJamFactor != 0) {
584 LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
585 "multiple of unroll jam factor";
586 return failure();
587 }
588
589 // Nothing in the loop body other than the terminator.
590 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
591 return success();
592
593 // Gather all sub-blocks to jam upon the loop being unrolled.
// NOTE(review): doxygen line 594 is missing from this view — presumably the
// declaration of the gatherer `jbg` used below (e.g.
// `JamBlockGatherer<scf::ForOp> jbg;`); confirm upstream.
595 jbg.walk(forOp);
596 auto &subBlocks = jbg.subBlocks;
597
598 // Collect inner loops.
599 SmallVector<scf::ForOp> innerLoops;
600 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
601
602 // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
603 // iteration. There are (`unrollJamFactor` - 1) iterations.
604 SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
605
606 // For any loop with iter_args, replace it with a new loop that has
607 // `unrollJamFactor` copies of its iterOperands, iter_args and yield
608 // operands.
609 SmallVector<scf::ForOp> newInnerLoops;
610 IRRewriter rewriter(forOp.getContext());
611 for (scf::ForOp oldForOp : innerLoops) {
612 SmallVector<Value> dupIterOperands, dupYieldOperands;
613 ValueRange oldIterOperands = oldForOp.getInits();
614 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
615 ValueRange oldYieldOperands =
616 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
617 // Get additional iterOperands, iterArgs, and yield operands. We will
618 // fix iterOperands and yield operands after cloning of sub-blocks.
619 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
620 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
621 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
622 }
623 // Create a new loop with additional iterOperands, iter_args and yield
624 // operands. This new loop will take the loop body of the original loop.
625 bool forOpReplaced = oldForOp == forOp;
626 scf::ForOp newForOp =
627 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
628 rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
629 [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
630 return dupYieldOperands;
631 }));
632 newInnerLoops.push_back(newForOp);
633 // `forOp` has been replaced with a new loop.
634 if (forOpReplaced)
635 forOp = newForOp;
636 // Update `operandMaps` for `newForOp` iterArgs and results.
637 ValueRange newIterArgs = newForOp.getRegionIterArgs();
638 unsigned oldNumIterArgs = oldIterArgs.size();
639 ValueRange newResults = newForOp.getResults();
640 unsigned oldNumResults = newResults.size() / unrollJamFactor;
641 assert(oldNumIterArgs == oldNumResults &&
642 "oldNumIterArgs must be the same as oldNumResults");
643 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
644 for (unsigned j = 0; j < oldNumIterArgs; ++j) {
645 // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
646 // results. Update `operandMaps[i - 1]` to map old iterArgs and results
647 // to those in the `i`th new set.
648 operandMaps[i - 1].map(newIterArgs[j],
649 newIterArgs[i * oldNumIterArgs + j]);
650 operandMaps[i - 1].map(newResults[j],
651 newResults[i * oldNumResults + j]);
652 }
653 }
654 }
655
656 // Scale the step of loop being unroll-jammed by the unroll-jam factor.
657 rewriter.setInsertionPoint(forOp);
658 int64_t step = forOp.getConstantStep()->getSExtValue();
659 auto newStep = rewriter.createOrFold<arith::MulIOp>(
660 forOp.getLoc(), forOp.getStep(),
661 rewriter.createOrFold<arith::ConstantOp>(
662 forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
663 forOp.setStep(newStep);
664 auto forOpIV = forOp.getInductionVar();
665
666 // Unroll and jam (appends unrollJamFactor - 1 additional copies).
667 for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
668 for (auto &subBlock : subBlocks) {
669 // Builder to insert unroll-jammed bodies. Insert right at the end of
670 // sub-block.
671 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
672
673 // If the induction variable is used, create a remapping to the value for
674 // this unrolled instance.
675 if (!forOpIV.use_empty()) {
676 // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
677 auto ivTag = builder.createOrFold<arith::ConstantOp>(
678 forOp.getLoc(), builder.getIndexAttr(step * i));
679 auto ivUnroll =
680 builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
681 operandMaps[i - 1].map(forOpIV, ivUnroll);
682 }
683 // Clone the sub-block being unroll-jammed.
684 for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
685 builder.clone(*it, operandMaps[i - 1]);
686 }
687 // Fix iterOperands and yield op operands of newly created loops.
688 for (auto newForOp : newInnerLoops) {
689 unsigned oldNumIterOperands =
690 newForOp.getNumRegionIterArgs() / unrollJamFactor;
691 unsigned numControlOperands = newForOp.getNumControlOperands();
692 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
693 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
694 assert(oldNumIterOperands == oldNumYieldOperands &&
695 "oldNumIterOperands must be the same as oldNumYieldOperands");
696 for (unsigned j = 0; j < oldNumIterOperands; ++j) {
697 // The `i`th duplication of an old iterOperand or yield op operand
698 // needs to be replaced with a mapped value from `operandMaps[i - 1]`
699 // if such mapped value exists.
700 newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
701 operandMaps[i - 1].lookupOrDefault(
702 newForOp.getOperand(numControlOperands + j)));
703 yieldOp.setOperand(
704 i * oldNumYieldOperands + j,
705 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
706 }
707 }
708 }
709
710 // Promote the loop body up if this has turned into a single iteration loop.
711 (void)forOp.promoteIfSingleIteration(rewriter);
712 return success();
713}
714
// Builds the normalized range (offset 0, stride 1, size = ceildiv(ub - lb,
// step)) for index-typed bounds, using a folded affine.apply so constant
// inputs fold to a constant size.
// NOTE(review): doxygen lines 715 and 717 are missing from this view —
// presumably the declaration `static Range
// emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter,` and the
// `OpFoldResult ub,` parameter; confirm upstream.
716 Location loc, OpFoldResult lb,
718 OpFoldResult step) {
719 Range normalizedLoopBounds;
720 normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
721 normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
722 AffineExpr s0, s1, s2;
723 bindSymbols(rewriter.getContext(), s0, s1, s2);
// size = ceildiv(ub - lb, step), with s0 = lb, s1 = ub, s2 = step.
724 AffineExpr e = (s1 - s0).ceilDiv(s2);
725 normalizedLoopBounds.size =
726 affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
727 return normalizedLoopBounds;
728}
729
// Normalizes loop bounds to lb = 0, step = 1; index types go through the
// affine helper above, other integer types get explicit `arith` ops.
// NOTE(review): doxygen lines 730-731 are missing from this view —
// presumably the declaration `Range mlir::emitNormalizedLoopBounds(
// RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub,`;
// confirm against SCF/Utils.h.
732 OpFoldResult step) {
733 if (getType(lb).isIndex()) {
734 return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
735 }
736 // For non-index types, generate `arith` instructions
737 // Check if the loop is already known to have a constant zero lower bound or
738 // a constant one step.
739 bool isZeroBased = false;
740 if (auto lbCst = getConstantIntValue(lb))
741 isZeroBased = lbCst.value() == 0;
742
743 bool isStepOne = false;
744 if (auto stepCst = getConstantIntValue(step))
745 isStepOne = stepCst.value() == 1;
746
747 Type rangeType = getType(lb);
748 assert(rangeType == getType(ub) && rangeType == getType(step) &&
749 "expected matching types");
750
751 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
752 // assuming the step is strictly positive. Update the bounds and the step
753 // of the loop to go from 0 to the number of iterations, if necessary.
754 if (isZeroBased && isStepOne)
755 return {lb, ub, step};
756
757 OpFoldResult diff = ub;
758 if (!isZeroBased) {
759 diff = rewriter.createOrFold<arith::SubIOp>(
760 loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
761 getValueOrCreateConstantIntOp(rewriter, loc, lb));
762 }
763 OpFoldResult newUpperBound = diff;
764 if (!isStepOne) {
765 newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
766 loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
767 getValueOrCreateConstantIntOp(rewriter, loc, step));
768 }
769
770 OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
771 OpFoldResult newStep = rewriter.getOneAttr(rangeType);
772
773 return {newLowerBound, newUpperBound, newStep};
774}
775
// Rewrites uses of a normalized (0-based, step-1) index-typed induction
// variable back in terms of the original bounds: iv * origStep + origLb.
// NOTE(review): the declaration line (doxygen line 776) is cut off above —
// presumably `static void denormalizeInductionVariableForIndexType(
// RewriterBase &rewriter,`; confirm upstream.
777 Location loc,
778 Value normalizedIv,
779 OpFoldResult origLb,
780 OpFoldResult origStep) {
781 AffineExpr d0, s0, s1;
782 bindSymbols(rewriter.getContext(), s0, s1);
783 bindDims(rewriter.getContext(), d0);
// denormalized = iv * origStep + origLb, as a folded affine.apply.
784 AffineExpr e = d0 * s1 + s0;
// NOTE(review): doxygen line 785 is missing from this view — presumably
// `OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(`;
// confirm upstream.
786 rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
787 Value denormalizedIvVal =
788 getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
789 SmallPtrSet<Operation *, 1> preservedUses;
790 // If an `affine.apply` operation is generated for denormalization, the use
791 // of `origLb` in those ops must not be replaced. These aren't generated
792 // when `origLb == 0` and `origStep == 1`.
793 if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
794 if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
795 preservedUses.insert(preservedUse);
796 }
797 }
798 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
799}
800
// Rewrites uses of a normalized induction variable back in terms of the
// original lower bound and step: iv * origStep + origLb. Index types go
// through the affine helper above; other integer types use `arith` ops.
// NOTE(review): the declaration line (doxygen line 801) is cut off above —
// presumably `void mlir::denormalizeInductionVariable(RewriterBase &rewriter,
// Location loc,`; confirm against SCF/Utils.h.
802 Value normalizedIv, OpFoldResult origLb,
803 OpFoldResult origStep) {
804 if (getType(origLb).isIndex()) {
805 return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
806 origLb, origStep);
807 }
808 Value denormalizedIv;
// NOTE(review): doxygen line 809 is missing from this view — presumably the
// declaration of the `preserve` set used below (e.g.
// `SmallPtrSet<Operation *, 2> preserve;`); confirm upstream.
810 bool isStepOne = isOneInteger(origStep);
811 bool isZeroBased = isZeroInteger(origLb);
812
813 Value scaled = normalizedIv;
814 if (!isStepOne) {
815 Value origStepValue =
816 getValueOrCreateConstantIntOp(rewriter, loc, origStep);
817 scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
// The mul/add ops themselves consume `normalizedIv`; exclude them from RAUW.
818 preserve.insert(scaled.getDefiningOp());
819 }
820 denormalizedIv = scaled;
821 if (!isZeroBased) {
822 Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
823 denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
824 preserve.insert(denormalizedIv.getDefiningOp());
825 }
826
827 rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
828}
829
831 ArrayRef<OpFoldResult> values) {
832 assert(!values.empty() && "unexecpted empty array");
833 AffineExpr s0, s1;
834 bindSymbols(rewriter.getContext(), s0, s1);
835 AffineExpr mul = s0 * s1;
836 OpFoldResult products = rewriter.getIndexAttr(1);
837 for (auto v : values) {
839 rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
840 }
841 return products;
842}
843
844/// Helper function to multiply a sequence of values.
846 ArrayRef<Value> values) {
847 assert(!values.empty() && "unexpected empty list");
848 if (getType(values.front()).isIndex()) {
850 OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
851 return getValueOrCreateConstantIndexOp(rewriter, loc, product);
852 }
853 std::optional<Value> productOf;
854 for (auto v : values) {
855 auto vOne = getConstantIntValue(v);
856 if (vOne && vOne.value() == 1)
857 continue;
858 if (productOf)
859 productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
860 .getResult();
861 else
862 productOf = v;
863 }
864 if (!productOf) {
865 productOf = arith::ConstantOp::create(
866 rewriter, loc, rewriter.getOneAttr(getType(values.front())))
867 .getResult();
868 }
869 return productOf.value();
870}
871
872/// For each original loop, the value of the
873/// induction variable can be obtained by dividing the induction variable of
874/// the linearized loop by the total number of iterations of the loops nested
875/// in it modulo the number of iterations in this loop (remove the values
876/// related to the outer loops):
877/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
878/// Compute these iteratively from the innermost loop by creating a "running
879/// quotient" of division by the range.
880static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
882 Value linearizedIv, ArrayRef<Value> ubs) {
883
884 if (linearizedIv.getType().isIndex()) {
885 Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
886 rewriter, loc, linearizedIv, ubs);
887 auto resultVals = llvm::map_to_vector(
888 delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
889 return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
890 }
891
892 SmallVector<Value> delinearizedIvs(ubs.size());
893 SmallPtrSet<Operation *, 2> preservedUsers;
894
895 llvm::BitVector isUbOne(ubs.size());
896 for (auto [index, ub] : llvm::enumerate(ubs)) {
897 auto ubCst = getConstantIntValue(ub);
898 if (ubCst && ubCst.value() == 1)
899 isUbOne.set(index);
900 }
901
902 // Prune the lead ubs that are all ones.
903 unsigned numLeadingOneUbs = 0;
904 for (auto [index, ub] : llvm::enumerate(ubs)) {
905 if (!isUbOne.test(index)) {
906 break;
907 }
908 delinearizedIvs[index] = arith::ConstantOp::create(
909 rewriter, loc, rewriter.getZeroAttr(ub.getType()));
910 numLeadingOneUbs++;
911 }
912
913 Value previous = linearizedIv;
914 for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
915 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
916 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
917 previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
918 preservedUsers.insert(previous.getDefiningOp());
919 }
920 Value iv = previous;
921 if (i != e - 1) {
922 if (!isUbOne.test(idx)) {
923 iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
924 preservedUsers.insert(iv.getDefiningOp());
925 } else {
926 iv = arith::ConstantOp::create(
927 rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
928 }
929 }
930 delinearizedIvs[idx] = iv;
931 }
932 return {delinearizedIvs, preservedUsers};
933}
934
935LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
937 if (loops.size() < 2)
938 return failure();
939
940 scf::ForOp innermost = loops.back();
941 scf::ForOp outermost = loops.front();
942
943 // Bail out if any loop has a known zero step, as normalization
944 // would result in a division by zero.
945 for (auto loop : loops) {
946 if (auto step = getConstantIntValue(loop.getStep())) {
947 if (step.value() == 0) {
948 return failure();
949 }
950 }
951 }
952 // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
953 // allows the following code to assume upperBound is the number of iterations.
954 for (auto loop : loops) {
955 OpBuilder::InsertionGuard g(rewriter);
956 rewriter.setInsertionPoint(outermost);
957 Value lb = loop.getLowerBound();
958 Value ub = loop.getUpperBound();
959 Value step = loop.getStep();
960 auto newLoopRange =
961 emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
962
963 rewriter.modifyOpInPlace(loop, [&]() {
964 loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
965 newLoopRange.offset));
966 loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
967 newLoopRange.size));
968 loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
969 newLoopRange.stride));
970 });
971 rewriter.setInsertionPointToStart(innermost.getBody());
972 denormalizeInductionVariable(rewriter, loop.getLoc(),
973 loop.getInductionVar(), lb, step);
974 }
975
976 // 2. Emit code computing the upper bound of the coalesced loop as product
977 // of the number of iterations of all loops.
978 OpBuilder::InsertionGuard g(rewriter);
979 rewriter.setInsertionPoint(outermost);
980 Location loc = outermost.getLoc();
981 SmallVector<Value> upperBounds = llvm::map_to_vector(
982 loops, [](auto loop) { return loop.getUpperBound(); });
983 Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
984 outermost.setUpperBound(upperBound);
985
986 // Insert delinearization at the start of the outermost loop body.
987 rewriter.setInsertionPointToStart(outermost.getBody());
988 auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
989 rewriter, loc, outermost.getInductionVar(), upperBounds);
990 rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
991 preservedUsers);
992
993 for (int i = loops.size() - 1; i > 0; --i) {
994 auto outerLoop = loops[i - 1];
995 auto innerLoop = loops[i];
996
997 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
998 auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
999 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
1000 for (Value &yieldedVal : yieldedVals) {
1001 // The yielded value may be an iteration argument of the inner loop
1002 // which is about to be inlined.
1003 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
1004 if (iter != innerLoop.getRegionIterArgs().end()) {
1005 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
1006 // `outerLoop` iter args identical to the `innerLoop` init args.
1007 assert(iterArgIndex < innerLoop.getInitArgs().size());
1008 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
1009 }
1010 }
1011 rewriter.eraseOp(innerTerminator);
1012
1013 SmallVector<Value> innerBlockArgs;
1014 innerBlockArgs.push_back(delinearizeIvs[i]);
1015 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
1016 rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
1017 Block::iterator(innerLoop), innerBlockArgs);
1018 rewriter.replaceOp(innerLoop, yieldedVals);
1019 }
1020 return success();
1021}
1022
1024 if (loops.empty()) {
1025 return failure();
1026 }
1027 IRRewriter rewriter(loops.front().getContext());
1028 return coalesceLoops(rewriter, loops);
1029}
1030
1031LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
1032 LogicalResult result(failure());
1034 getPerfectlyNestedLoops(loops, op);
1035
1036 // Look for a band of loops that can be coalesced, i.e. perfectly nested
1037 // loops with bounds defined above some loop.
1038
1039 // 1. For each loop, find above which parent loop its bounds operands are
1040 // defined.
1041 SmallVector<unsigned> operandsDefinedAbove(loops.size());
1042 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1043 operandsDefinedAbove[i] = i;
1044 for (unsigned j = 0; j < i; ++j) {
1045 SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1046 loops[i].getUpperBound(),
1047 loops[i].getStep()};
1048 if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1049 operandsDefinedAbove[i] = j;
1050 break;
1051 }
1052 }
1053 }
1054
1055 // 2. For each inner loop check that the iter_args for the immediately outer
1056 // loop are the init for the immediately inner loop and that the yields of the
1057 // return of the inner loop is the yield for the immediately outer loop. Keep
1058 // track of where the chain starts from for each loop.
1059 SmallVector<unsigned> iterArgChainStart(loops.size());
1060 iterArgChainStart[0] = 0;
1061 for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1062 // By default set the start of the chain to itself.
1063 iterArgChainStart[i] = i;
1064 auto outerloop = loops[i - 1];
1065 auto innerLoop = loops[i];
1066 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1067 continue;
1068 }
1069 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1070 continue;
1071 }
1072 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1073 if (!llvm::equal(outerloopTerminator->getOperands(),
1074 innerLoop.getResults())) {
1075 continue;
1076 }
1077 iterArgChainStart[i] = iterArgChainStart[i - 1];
1078 }
1079
1080 // 3. Identify bands of loops such that the operands of all of them are
1081 // defined above the first loop in the band. Traverse the nest bottom-up
1082 // so that modifications don't invalidate the inner loops.
1083 for (unsigned end = loops.size(); end > 0; --end) {
1084 unsigned start = 0;
1085 for (; start < end - 1; ++start) {
1086 auto maxPos =
1087 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1088 std::next(operandsDefinedAbove.begin(), end));
1089 if (maxPos > start)
1090 continue;
1091 if (iterArgChainStart[end - 1] > start)
1092 continue;
1093 auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1094 if (succeeded(coalesceLoops(band)))
1095 result = success();
1096 break;
1097 }
1098 // If a band was found and transformed, keep looking at the loops above
1099 // the outermost transformed loop.
1100 if (start != end - 1)
1101 end = start + 1;
1102 }
1103 return result;
1104}
1105
1107 RewriterBase &rewriter, scf::ParallelOp loops,
1108 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1109 OpBuilder::InsertionGuard g(rewriter);
1110 rewriter.setInsertionPoint(loops);
1111 Location loc = loops.getLoc();
1112
1113 // Presort combined dimensions.
1114 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1115 for (auto &dims : sortedDimensions)
1116 llvm::sort(dims);
1117
1118 // Normalize ParallelOp's iteration pattern.
1119 SmallVector<Value, 3> normalizedUpperBounds;
1120 for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1121 OpBuilder::InsertionGuard g2(rewriter);
1122 rewriter.setInsertionPoint(loops);
1123 Value lb = loops.getLowerBound()[i];
1124 Value ub = loops.getUpperBound()[i];
1125 Value step = loops.getStep()[i];
1126 auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1127 normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1128 rewriter, loops.getLoc(), newLoopRange.size));
1129
1130 rewriter.setInsertionPointToStart(loops.getBody());
1131 denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1132 step);
1133 }
1134
1135 // Combine iteration spaces.
1136 SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1137 auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1138 auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1139 for (auto &sortedDimension : sortedDimensions) {
1140 Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1141 for (auto idx : sortedDimension) {
1142 newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1143 normalizedUpperBounds[idx]);
1144 }
1145 lowerBounds.push_back(cst0);
1146 steps.push_back(cst1);
1147 upperBounds.push_back(newUpperBound);
1148 }
1149
1150 // Create new ParallelLoop with conversions to the original induction values.
1151 // The loop below uses divisions to get the relevant range of values in the
1152 // new induction value that represent each range of the original induction
1153 // value. The remainders then determine based on that range, which iteration
1154 // of the original induction value this represents. This is a normalized value
1155 // that is un-normalized already by the previous logic.
1156 auto newPloop = scf::ParallelOp::create(
1157 rewriter, loc, lowerBounds, upperBounds, steps,
1158 [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1159 for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1160 Value previous = ploopIVs[i];
1161 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1162 // Iterate over all except the last induction value.
1163 for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1164 unsigned idx = combinedDimensions[i][j];
1165
1166 // Determine the current induction value's current loop iteration
1167 Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1168 normalizedUpperBounds[idx]);
1169 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1170 loops.getRegion());
1171
1172 // Remove the effect of the current induction value to prepare for
1173 // the next value.
1174 previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1175 normalizedUpperBounds[idx]);
1176 }
1177
1178 // The final induction value is just the remaining value.
1179 unsigned idx = combinedDimensions[i][0];
1180 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1181 previous, loops.getRegion());
1182 }
1183 });
1184
1185 // Replace the old loop with the new loop.
1186 loops.getBody()->back().erase();
1187 newPloop.getBody()->getOperations().splice(
1188 Block::iterator(newPloop.getBody()->back()),
1189 loops.getBody()->getOperations());
1190 loops.erase();
1191}
1192
1193// Hoist the ops within `outer` that appear before `inner`.
1194// Such ops include the ops that have been introduced by parametric tiling.
1195// Ops that come from triangular loops (i.e. that belong to the program slice
1196// rooted at `outer`) and ops that have side effects cannot be hoisted.
1197// Return failure when any op fails to hoist.
1198static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1199 SetVector<Operation *> forwardSlice;
1201 options.filter = [&inner](Operation *op) {
1202 return op != inner.getOperation();
1203 };
1204 getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1205 LogicalResult status = success();
1207 for (auto &op : outer.getBody()->without_terminator()) {
1208 // Stop when encountering the inner loop.
1209 if (&op == inner.getOperation())
1210 break;
1211 // Skip over non-hoistable ops.
1212 if (forwardSlice.count(&op) > 0) {
1213 status = failure();
1214 continue;
1215 }
1216 // Skip intermediate scf::ForOp, these are not considered a failure.
1217 if (isa<scf::ForOp>(op))
1218 continue;
1219 // Skip other ops with regions.
1220 if (op.getNumRegions() > 0) {
1221 status = failure();
1222 continue;
1223 }
1224 // Skip if op has side effects.
1225 // TODO: loads to immutable memory regions are ok.
1226 if (!isMemoryEffectFree(&op)) {
1227 status = failure();
1228 continue;
1229 }
1230 toHoist.push_back(&op);
1231 }
1232 auto *outerForOp = outer.getOperation();
1233 for (auto *op : toHoist)
1234 op->moveBefore(outerForOp);
1235 return status;
1236}
1237
1238// Traverse the interTile and intraTile loops and try to hoist ops such that
1239// bands of perfectly nested loops are isolated.
1240// Return failure if either perfect interTile or perfect intraTile bands cannot
1241// be formed.
1242static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1243 LogicalResult status = success();
1244 const Loops &interTile = tileLoops.first;
1245 const Loops &intraTile = tileLoops.second;
1246 auto size = interTile.size();
1247 assert(size == intraTile.size());
1248 if (size <= 1)
1249 return success();
1250 for (unsigned s = 1; s < size; ++s)
1251 status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1252 : failure();
1253 for (unsigned s = 1; s < size; ++s)
1254 status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1255 : failure();
1256 return status;
1257}
1258
1259/// Collect perfectly nested loops starting from `rootForOps`. Loops are
1260/// perfectly nested if each loop is the first and only non-terminator operation
1261/// in the parent loop. Collect at most `maxLoops` loops and append them to
1262/// `forOps`.
1263template <typename T>
1265 SmallVectorImpl<T> &forOps, T rootForOp,
1266 unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1267 for (unsigned i = 0; i < maxLoops; ++i) {
1268 forOps.push_back(rootForOp);
1269 Block &body = rootForOp.getRegion().front();
1270 if (body.begin() != std::prev(body.end(), 2))
1271 return;
1272
1273 rootForOp = dyn_cast<T>(&body.front());
1274 if (!rootForOp)
1275 return;
1276 }
1277}
1278
/// Strip-mine `forOp` by `factor`: multiply its step by `factor` and, inside
/// each loop in `targets`, create a new inner loop that iterates one strip
/// [iv, min(upperBound, iv + newStep)) with the original step, moving the
/// target's body into it.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  // The strip-mined loop now advances a whole strip at a time:
  // newStep = originalStep * factor.
  forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    assert(!t.getUnsignedCmp() && "unsigned loops are not supported");

    // Save information for splicing ops out of t when done
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    // NOTE: this `b` intentionally shadows the outer builder.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    // Clamp the strip's end to the loop bound: ub = min(upperBound, iv + step).
    Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
    Value ub =
        arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    // The last op (the terminator of `t`) stays in place.
    auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
1315
1316// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1317// Returns the new for operation, nested immediately under `target`.
1318template <typename SizeType>
1319static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1320 scf::ForOp target) {
1321 // TODO: Use cheap structural assertions that targets are nested under
1322 // forOp and that targets are not nested under each other when DominanceInfo
1323 // exposes the capability. It seems overkill to construct a whole function
1324 // dominance tree at this point.
1325 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1326 assert(res.size() == 1 && "Expected 1 inner forOp");
1327 return res[0];
1328}
1329
1331 ArrayRef<Value> sizes,
1332 ArrayRef<scf::ForOp> targets) {
1334 SmallVector<scf::ForOp, 8> currentTargets(targets);
1335 for (auto it : llvm::zip(forOps, sizes)) {
1336 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1337 res.push_back(step);
1338 currentTargets = step;
1339 }
1340 return res;
1341}
1342
1344 scf::ForOp target) {
1346 for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1347 res.push_back(llvm::getSingleElement(loops));
1348 return res;
1349}
1350
1352 // Collect perfectly nested loops. If more size values provided than nested
1353 // loops available, truncate `sizes`.
1355 forOps.reserve(sizes.size());
1356 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1357 if (forOps.size() < sizes.size())
1358 sizes = sizes.take_front(forOps.size());
1359
1360 return ::tile(forOps, sizes, forOps.back());
1361}
1362
1364 scf::ForOp root) {
1365 getPerfectlyNestedLoopsImpl(nestedLoops, root);
1366}
1367
1369 ArrayRef<int64_t> sizes) {
1370 // Collect perfectly nested loops. If more size values provided than nested
1371 // loops available, truncate `sizes`.
1373 forOps.reserve(sizes.size());
1374 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1375 if (forOps.size() < sizes.size())
1376 sizes = sizes.take_front(forOps.size());
1377
1378 // The strip-mining transformation splices loop bodies into a new inner loop
1379 // without threading iter_args. If any of the collected loops carries
1380 // iter_args, the splice would produce invalid IR (yielded values from the
1381 // inner scope used in the outer terminator). Skip the transformation in
1382 // that case.
1383 if (llvm::any_of(forOps,
1384 [](scf::ForOp op) { return !op.getInitArgs().empty(); }))
1385 return {};
1386
1387 // Compute the tile sizes such that i-th outer loop executes size[i]
1388 // iterations. Given that the loop current executes
1389 // numIterations = ceildiv((upperBound - lowerBound), step)
1390 // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1391 SmallVector<Value, 4> tileSizes;
1392 tileSizes.reserve(sizes.size());
1393 for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1394 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1395
1396 auto forOp = forOps[i];
1397 OpBuilder builder(forOp);
1398 auto loc = forOp.getLoc();
1399 Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1400 forOp.getLowerBound());
1401 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1402 Value iterationsPerBlock =
1403 ceilDivPositive(builder, loc, numIterations, sizes[i]);
1404 tileSizes.push_back(iterationsPerBlock);
1405 }
1406
1407 // Call parametric tiling with the given sizes.
1408 auto intraTile = tile(forOps, tileSizes, forOps.back());
1409 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1410
1411 // TODO: for now we just ignore the result of band isolation.
1412 // In the future, mapping decisions may be impacted by the ability to
1413 // isolate perfectly nested bands.
1414 (void)tryIsolateBands(tileLoops);
1415
1416 return tileLoops;
1417}
1418
1420 scf::ForallOp source,
1421 RewriterBase &rewriter) {
1422 unsigned numTargetOuts = target.getNumResults();
1423 unsigned numSourceOuts = source.getNumResults();
1424
1425 // Create fused shared_outs.
1426 SmallVector<Value> fusedOuts;
1427 llvm::append_range(fusedOuts, target.getOutputs());
1428 llvm::append_range(fusedOuts, source.getOutputs());
1429
1430 // Create a new scf.forall op after the source loop.
1431 rewriter.setInsertionPointAfter(source);
1432 scf::ForallOp fusedLoop = scf::ForallOp::create(
1433 rewriter, source.getLoc(), source.getMixedLowerBound(),
1434 source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1435 source.getMapping());
1436
1437 // Map control operands.
1438 IRMapping mapping;
1439 mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1440 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1441
1442 // Map shared outs.
1443 mapping.map(target.getRegionIterArgs(),
1444 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1445 mapping.map(source.getRegionIterArgs(),
1446 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1447
1448 // Append everything except the terminator into the fused operation.
1449 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1450 for (Operation &op : target.getBody()->without_terminator())
1451 rewriter.clone(op, mapping);
1452 for (Operation &op : source.getBody()->without_terminator())
1453 rewriter.clone(op, mapping);
1454
1455 // Fuse the old terminator in_parallel ops into the new one.
1456 scf::InParallelOp targetTerm = target.getTerminator();
1457 scf::InParallelOp sourceTerm = source.getTerminator();
1458 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1459 rewriter.setInsertionPointToStart(fusedTerm.getBody());
1460 for (Operation &op : targetTerm.getYieldingOps())
1461 rewriter.clone(op, mapping);
1462 for (Operation &op : sourceTerm.getYieldingOps())
1463 rewriter.clone(op, mapping);
1464
1465 // Replace old loops by substituting their uses by results of the fused loop.
1466 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1467 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1468
1469 return fusedLoop;
1470}
1471
1473 scf::ForOp source,
1474 RewriterBase &rewriter) {
1475 assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1476 "incompatible signedness");
1477 unsigned numTargetOuts = target.getNumResults();
1478 unsigned numSourceOuts = source.getNumResults();
1479
1480 // Create fused init_args, with target's init_args before source's init_args.
1481 SmallVector<Value> fusedInitArgs;
1482 llvm::append_range(fusedInitArgs, target.getInitArgs());
1483 llvm::append_range(fusedInitArgs, source.getInitArgs());
1484
1485 // Create a new scf.for op after the source loop (with scf.yield terminator
1486 // (without arguments) only in case its init_args is empty).
1487 rewriter.setInsertionPointAfter(source);
1488 scf::ForOp fusedLoop = scf::ForOp::create(
1489 rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1490 source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
1491 source.getUnsignedCmp());
1492
1493 // Map original induction variables and operands to those of the fused loop.
1494 IRMapping mapping;
1495 mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1496 mapping.map(target.getRegionIterArgs(),
1497 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1498 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1499 mapping.map(source.getRegionIterArgs(),
1500 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1501
1502 // Merge target's body into the new (fused) for loop and then source's body.
1503 rewriter.setInsertionPointToStart(fusedLoop.getBody());
1504 for (Operation &op : target.getBody()->without_terminator())
1505 rewriter.clone(op, mapping);
1506 for (Operation &op : source.getBody()->without_terminator())
1507 rewriter.clone(op, mapping);
1508
1509 // Build fused yield results by appropriately mapping original yield operands.
1510 SmallVector<Value> yieldResults;
1511 for (Value operand : target.getBody()->getTerminator()->getOperands())
1512 yieldResults.push_back(mapping.lookupOrDefault(operand));
1513 for (Value operand : source.getBody()->getTerminator()->getOperands())
1514 yieldResults.push_back(mapping.lookupOrDefault(operand));
1515 if (!yieldResults.empty())
1516 scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1517
1518 // Replace old loops by substituting their uses by results of the fused loop.
1519 rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1520 rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1521
1522 return fusedLoop;
1523}
1524
1525FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1526 scf::ForallOp forallOp) {
1527 SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1528 SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1529 SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1530
1531 if (forallOp.isNormalized())
1532 return forallOp;
1533
1534 OpBuilder::InsertionGuard g(rewriter);
1535 auto loc = forallOp.getLoc();
1536 rewriter.setInsertionPoint(forallOp);
1538 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1539 Range normalizedLoopParams =
1540 emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1541 newUbs.push_back(normalizedLoopParams.size);
1542 }
1543 (void)foldDynamicIndexList(newUbs);
1544
1545 // Use the normalized builder since the lower bounds are always 0 and the
1546 // steps are always 1.
1547 auto normalizedForallOp = scf::ForallOp::create(
1548 rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1549 [](OpBuilder &, Location, ValueRange) {});
1550
1551 rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1552 normalizedForallOp.getBodyRegion(),
1553 normalizedForallOp.getBodyRegion().begin());
1554 // Remove the original empty block in the new loop.
1555 rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1556
1557 rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1558 // Update the users of the original loop variables.
1559 for (auto [idx, iv] :
1560 llvm::enumerate(normalizedForallOp.getInductionVars())) {
1561 auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1562 auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1563 denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1564 }
1565
1566 rewriter.replaceOp(forallOp, normalizedForallOp);
1567 return normalizedForallOp;
1568}
1569
1572 assert(!loops.empty() && "unexpected empty loop nest");
1573 if (loops.size() == 1)
1574 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1575 for (auto [outerLoop, innerLoop] :
1576 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1577 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1578 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1579 if (!outerFor || !innerFor)
1580 return false;
1581 auto outerBBArgs = outerFor.getRegionIterArgs();
1582 auto innerIterArgs = innerFor.getInitArgs();
1583 if (outerBBArgs.size() != innerIterArgs.size())
1584 return false;
1585
1586 for (auto [outerBBArg, innerIterArg] :
1587 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1588 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1589 innerIterArg != outerBBArg)
1590 return false;
1591 }
1592
1593 ValueRange outerYields =
1594 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1595 ValueRange innerResults = innerFor.getResults();
1596 if (outerYields.size() != innerResults.size())
1597 return false;
1598 for (auto [outerYield, innerResult] :
1599 llvm::zip_equal(outerYields, innerResults)) {
1600 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1601 outerYield != innerResult)
1602 return false;
1603 }
1604 }
1605 return true;
1606}
1607
// Return the statically-known (lb, ub, step) triples of `loopOp`, or an empty
// result if any bound is unknown or dynamic.
// NOTE(review): the return-type line of this definition and the declaration
// of `loopRanges` are elided in this chunk; presumably an optional/vector of
// a loop-range struct (any type where `return {}` means "unknown") — confirm
// against the header.
mlir::getConstLoopBounds(mlir::LoopLikeOpInterface loopOp) {
  // Any LoopLikeOpInterface implementation may decline to report bounds.
  std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
  std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
  std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
  if (!loBnds || !upBnds || !steps)
    return {};
  for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
    auto lbCst = getConstantIntValue(lb);
    auto ubCst = getConstantIntValue(ub);
    auto stepCst = getConstantIntValue(step);
    // Every bound of every loop dimension must be a compile-time constant.
    if (!lbCst || !ubCst || !stepCst)
      return {};
    loopRanges.emplace_back(*lbCst, *ubCst, *stepCst);
  }
  return loopRanges;
}
1626
1628mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
1629 std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1630 std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1631 std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1632 if (!loBnds || !upBnds || !steps)
1633 return {};
1635 for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1636 // TODO(#178506): Signedness is not handled correctly here.
1637 std::optional<llvm::APInt> numIter = constantTripCount(
1638 lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
1639 if (!numIter)
1640 return {};
1641 tripCounts.push_back(*numIter);
1642 }
1643 return tripCounts;
1644}
1645
1646FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
1647 scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
1648 RewriterBase &rewriter,
1649 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1650 IRMapping *clonedToSrcOpsMap) {
1651 const unsigned numLoops = op.getNumLoops();
1652 assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
1653 "Expected positive unroll factors");
1654 assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
1655 "Expected non-empty unroll factors of size <= to the number of loops");
1656
1657 // Bail out if no valid unroll factors were provided
1658 if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
1659 return rewriter.notifyMatchFailure(
1660 op, "Unrolling not applied if all factors are 1");
1661
1662 // Return if the loop body is empty.
1663 if (llvm::hasSingleElement(op.getBody()->getOperations()))
1664 return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
1665
1666 // If the provided unroll factors do not cover all the loop dims, they are
1667 // applied to the inner loop dimensions.
1668 const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
1669
1670 // Make sure that the unroll factors divide the iteration space evenly
1671 // TODO: Support unrolling loops with dynamic iteration spaces.
1673 if (tripCounts.empty())
1674 return rewriter.notifyMatchFailure(
1675 op, "Failed to compute constant trip counts for the loop. Note that "
1676 "dynamic loop sizes are not supported.");
1677
1678 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1679 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1680 if (tripCounts[dimIdx].urem(unrollFactor) != 0)
1681 return rewriter.notifyMatchFailure(
1682 op, "Unroll factors don't divide the iteration space evenly");
1683 }
1684
1685 std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
1686 if (!maybeFoldSteps)
1687 return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
1689 for (auto step : *maybeFoldSteps)
1690 steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
1691
1692 for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1693 const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1694 if (unrollFactor == 1)
1695 continue;
1696 const size_t origStep = steps[dimIdx];
1697 const int64_t newStep = origStep * unrollFactor;
1698 IRMapping clonedToSrcOpsMap;
1699
1700 ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
1701 auto yieldedValues = op.getBody()->getTerminator()->getOperands();
1702
1704 op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
1705 [&](unsigned i, Value iv, OpBuilder b) {
1706 // iv' = iv + step * i;
1707 const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
1708 const auto map =
1709 b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
1710 return affine::AffineApplyOp::create(b, iv.getLoc(), map,
1711 ValueRange{iv});
1712 },
1713 /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
1714
1715 // Update loop step
1716 auto prevInsertPoint = rewriter.saveInsertionPoint();
1717 rewriter.setInsertionPoint(op);
1718 op.getStepMutable()[dimIdx].assign(
1719 arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
1720 rewriter.restoreInsertionPoint(prevInsertPoint);
1721 }
1722 return op;
1723}
return success()
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition Utils.cpp:830
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition Utils.cpp:1242
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOp.
Definition Utils.cpp:1264
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition Utils.cpp:1198
static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition Utils.cpp:715
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition Utils.cpp:1279
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition Utils.cpp:266
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition Utils.cpp:845
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition Utils.cpp:881
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition Utils.cpp:776
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition Utils.cpp:534
static int64_t product(ArrayRef< int64_t > vals)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
#define mul(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType::iterator iterator
Definition Block.h:150
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator end()
Definition Block.h:154
iterator begin()
Definition Block.h:153
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext * getContext() const
Definition Builders.h:56
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition Builders.h:387
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition Builders.h:254
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition Builders.h:392
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
operand_type_range getOperandTypes()
Definition Operation.h:423
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
Definition Operation.h:441
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
iterator begin()
Definition Region.h:55
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition Value.cpp:91
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition SCF.cpp:115
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition Utils.cpp:1363
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition Utils.cpp:1570
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition Utils.cpp:218
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition Utils.cpp:36
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an affine.for to find a band to coalesce.
Definition Utils.cpp:1031
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
void generateUnrolledLoop(Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues, IRMapping *clonedToSrcOpsMap=nullptr)
Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs' and 'yieldedValues' as the...
Definition Utils.cpp:295
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:105
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition Utils.cpp:519
llvm::SmallVector< llvm::APInt > getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp)
Get constant trip counts for each of the induction variables of the given loop operation.
Definition Utils.cpp:1628
std::pair< Loops, Loops > TileLoops
Definition Utils.h:156
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SmallVector< std::tuple< int64_t, int64_t, int64_t > > getConstLoopBounds(mlir::LoopLikeOpInterface loopOp)
Get constant loop bounds and steps for each of the induction variables of the given loop operation,...
Definition Utils.cpp:1609
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned > > combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition Utils.cpp:1106
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
std::optional< std::pair< APInt, bool > > getConstantAPIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SliceOptions ForwardSliceOptions
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition Utils.cpp:1351
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition Utils.cpp:368
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition Utils.cpp:547
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition Utils.cpp:241
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
FailureOr< scf::ParallelOp > parallelLoopUnrollByFactors(scf::ParallelOp op, ArrayRef< uint64_t > unrollFactors, RewriterBase &rewriter, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, IRMapping *clonedToSrcOpsMap=nullptr)
Unroll this scf::Parallel loop by the specified unroll factors.
Definition Utils.cpp:1646
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition Utils.cpp:115
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition Utils.cpp:801
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition Utils.cpp:1419
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition Utils.cpp:1023
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition Utils.cpp:1472
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition Utils.cpp:1368
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition Utils.cpp:730
SmallVector< scf::ForOp, 8 > Loops
Tile a nest of standard for loops rooted at rootForOp by finding such parametric tile sizes that the ...
Definition Utils.h:155
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step,...
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition Utils.cpp:1525
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition Utils.h:117
std::optional< scf::ForOp > mainLoopOp
Definition Utils.h:116
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.