MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/APInt.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/DebugLog.h"
29 #include <cstdint>
30 
31 using namespace mlir;
32 
33 #define DEBUG_TYPE "scf-utils"
34 
    // NOTE(review): the declarator line of this definition (the
    // `SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(` line) is
    // missing from this rendering of the file — restore it before compiling.
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  // An empty nest has nothing to rewrite.
  if (loopNest.empty())
    return {};
  // This method is recursive (to make it more readable). Adding an
  // assertion here to limit the recursion. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // To yield a value from a perfectly nested loop nest, the following
  // pattern needs to be created, i.e. starting with
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // needs to be modified to
  //
  // ```mlir
  //  %0 = scf.for .. iter_args(%arg0 = %init) {
  //    %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //      %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //        %value = ...
  //        scf.yield %value
  //      }
  //      scf.yield %2
  //    }
  //    scf.yield %1
  //  }
  // ```
  //
  // The inner most loop is handled using the `replaceWithAdditionalYields`
  // that works on a single loop.
  if (loopNest.size() == 1) {
    auto innerMostLoop =
        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
            newYieldValuesFn));
    return {innerMostLoop};
  }
  // The outer loops are modified by calling this method recursively
  // - The return value of the inner loop is the value yielded by this loop.
  // - The region iter args of this loop are the init_args for the inner loop.
  SmallVector<scf::ForOp> newLoopNest;
  NewYieldValuesFn fn =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    // Recurse on the rest of the nest; this loop yields the trailing results
    // of the (new) immediately nested loop.
    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
                                               innerNewBBArgs, newYieldValuesFn,
                                               replaceIterOperandsUsesInLoop);
    return llvm::to_vector(llvm::map_range(
        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
        [](OpResult r) -> Value { return r; }));
  };
  scf::ForOp outerMostLoop =
      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
  return newLoopNest;
}
104 
105 /// Outline a region with a single block into a new FuncOp.
106 /// Assumes the FuncOp result types is the type of the yielded operands of the
107 /// single block. This constraint makes it easy to determine the result.
108 /// This method also clones the `arith::ConstantIndexOp` at the start of
109 /// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
110 /// provided, it will be set to point to the operation that calls the outlined
111 /// function.
112 // TODO: support more than single-block regions.
113 // TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  // Only single-block regions are supported (see TODO above).
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  // Values defined above `region` must become extra arguments of the
  // outlined function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // The outlined function returns whatever the original terminator yielded.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
                           originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // The call forwards the reconstructed block arguments plus the captures.
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // Index constants are re-materialized inside the function instead of
      // being passed as arguments, enabling simple canonicalizations.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        repl = rewriter.clone(*cst)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
216 
217 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
218  func::FuncOp *thenFn, StringRef thenFnName,
219  func::FuncOp *elseFn, StringRef elseFnName) {
220  IRRewriter rewriter(b);
221  Location loc = ifOp.getLoc();
222  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
223  if (thenFn && !ifOp.getThenRegion().empty()) {
224  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
225  rewriter, loc, ifOp.getThenRegion(), thenFnName);
226  if (failed(outlinedFuncOpOrFailure))
227  return failure();
228  *thenFn = *outlinedFuncOpOrFailure;
229  }
230  if (elseFn && !ifOp.getElseRegion().empty()) {
231  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
232  rewriter, loc, ifOp.getElseRegion(), elseFnName);
233  if (failed(outlinedFuncOpOrFailure))
234  return failure();
235  *elseFn = *outlinedFuncOpOrFailure;
236  }
237  return success();
238 }
239 
  // NOTE(review): the declarator line of this definition (presumably
  // `bool mlir::getInnermostParallelLoops(Operation *rootOp,
  // SmallVectorImpl<scf::ParallelOp> &result)`) is missing from this
  // rendering of the file — restore it before compiling.
  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
  // Whether any region under `rootOp` transitively contains an scf.parallel.
  bool rootEnclosesPloops = false;
  for (Region &region : rootOp->getRegions()) {
    for (Block &block : region.getBlocks()) {
      for (Operation &op : block) {
        // Recurse first so we know whether `op` itself encloses any
        // parallel loops.
        bool enclosesPloops = getInnermostParallelLoops(&op, result);
        rootEnclosesPloops |= enclosesPloops;
        if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
          rootEnclosesPloops = true;

          // Collect parallel loop if it is an innermost one.
          if (!enclosesPloops)
            result.push_back(ploop);
        }
      }
    }
  }
  return rootEnclosesPloops;
}
261 
262 // Build the IR that performs ceil division of a positive value by a constant:
263 // ceildiv(a, B) = divis(a + (B-1), B)
264 // where divis is rounding-to-zero division.
265 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
266  int64_t divisor) {
267  assert(divisor > 0 && "expected positive divisor");
268  assert(dividend.getType().isIntOrIndex() &&
269  "expected integer or index-typed value");
270 
271  Value divisorMinusOneCst = arith::ConstantOp::create(
272  builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
273  Value divisorCst = arith::ConstantOp::create(
274  builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
275  Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
276  return arith::DivUIOp::create(builder, loc, sum, divisorCst);
277 }
278 
279 // Build the IR that performs ceil division of a positive value by another
280 // positive value:
281 // ceildiv(a, b) = divis(a + (b - 1), b)
282 // where divis is rounding-to-zero division.
283 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
284  Value divisor) {
285  assert(dividend.getType().isIntOrIndex() &&
286  "expected integer or index-typed value");
287  Value cstOne = arith::ConstantOp::create(
288  builder, loc, builder.getOneAttr(dividend.getType()));
289  Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
290  Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
291  return arith::DivUIOp::create(builder, loc, sum, divisor);
292 }
293 
    // NOTE(review): the declarator line of this definition (presumably
    // `static void generateUnrolledLoop(`) is missing from this rendering of
    // the file — restore it before compiling.
    Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues,
    IRMapping *clonedToSrcOpsMap) {

  // Check if the op was cloned from another source op, and return it if found
  // (or the same op if not found).
  auto findOriginalSrcOp =
      [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
    Operation *srcOp = op;
    // If the source op derives from another op: traverse the chain to find the
    // original source op.
    while (srcOp && clonedToSrcOpsMap.contains(srcOp))
      srcOp = clonedToSrcOpsMap.lookup(srcOp);
    return srcOp;
  };

  // Builder to insert unrolled bodies just before the terminator of the body of
  // the loop.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  // Default the annotation callback to a no-op so it can be called
  // unconditionally below.
  static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
  if (!annotateFn)
    annotateFn = noopAnnotateFn;

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of the loop body (append unrollFactor - 1 additional
  // copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    // Prepare operand map: each copy consumes the previous copy's yields.
    IRMapping operandMap;
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!iv.use_empty()) {
      Value ivUnroll = ivRemapFn(i, iv, builder);
      operandMap.map(iv, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
      Operation *srcOp = &(*it);
      Operation *clonedOp = builder.clone(*srcOp, operandMap);
      annotateFn(i, clonedOp, builder);
      // Record the clone's provenance, collapsing chains of clones back to
      // the original op.
      if (clonedToSrcOpsMap)
        clonedToSrcOpsMap->map(clonedOp,
                               findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
    }

    // Update yielded values.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
  }

  // Make sure we annotate the Ops in the original body. We do this last so that
  // any annotations are not copied into the cloned Ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
364 
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// epilogue loop, if the loop is unrolled.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{forOp, std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<APInt> constTripCount = forOp.getStaticTripCount();
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
    if (unrollFactor == 1) {
      // Unrolling by 1 is a no-op except that a known single-iteration loop
      // can still be promoted into its parent block.
      if (*constTripCount == 1 &&
          failed(forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{forOp, std::nullopt};
    }

    // Largest multiple of `unrollFactor` not exceeding the trip count; the
    // remainder iterations go to the epilogue loop.
    int64_t tripCountEvenMultiple =
        constTripCount->getSExtValue() -
        (constTripCount->getSExtValue() % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = arith::ConstantOp::create(
          boundsBuilder, loc,
          boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
                                       upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    stepUnrolled =
        stepCst == stepUnrolledCst
            ? step
            : arith::ConstantOp::create(boundsBuilder, loc,
                                        boundsBuilder.getIntegerAttr(
                                            step.getType(), stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst = arith::ConstantOp::create(
        boundsBuilder, loc,
        boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
    Value tripCountRem =
        arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = arith::AddIOp::create(
        boundsBuilder, loc, lowerBound,
        arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    // The epilogue resumes its iter_args from the main loop's results.
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getInitArgs().size(), results);
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  // NOTE(review): the callee line of this call (presumably
  // `generateUnrolledLoop(`) is missing from this rendering of the file;
  // only its arguments are visible below — restore it before compiling.
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = arith::MulIOp::create(
            b, loc, step,
            arith::ConstantOp::create(b, loc,
                                      b.getIntegerAttr(iv.getType(), i)));
        return arith::AddIOp::create(b, loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}
493 
494 /// Unrolls this loop completely.
495 LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
496  IRRewriter rewriter(forOp.getContext());
497  std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
498  if (!mayBeConstantTripCount.has_value())
499  return failure();
500  const APInt &tripCount = *mayBeConstantTripCount;
501  if (tripCount.isZero())
502  return success();
503  if (tripCount.getSExtValue() == 1)
504  return forOp.promoteIfSingleIteration(rewriter);
505  return loopUnrollByFactor(forOp, tripCount.getSExtValue());
506 }
507 
508 /// Check if bounds of all inner loops are defined outside of `forOp`
509 /// and return false if not.
510 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
511  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
512  if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
513  !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
514  !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
515  return WalkResult::interrupt();
516 
517  return WalkResult::advance();
518  });
519  return !walkResult.wasInterrupted();
520 }
521 
/// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG() << "failed to unroll and jam: inner bounds are not invariant";
    return failure();
  }

  // Currently, for operations with results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG() << "failed to unroll and jam: unsupported loop with results";
    return failure();
  }

  // Currently, only constant trip count that divided by the unroll factor is
  // supported.
  std::optional<APInt> tripCount = forOp.getStaticTripCount();
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG() << "failed to unroll and jam: trip count could not be determined";
    return failure();
  }
  if (unrollJamFactor > tripCount->getZExtValue()) {
    LDBG() << "unroll and jam factor is greater than trip count, set factor to "
              "trip "
              "count";
    unrollJamFactor = tripCount->getZExtValue();
  } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
    LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
              "multiple of unroll jam factor";
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Gather all sub-blocks to jam upon the loop being unrolled.
  // NOTE(review): the declaration of `jbg` (presumably a
  // `JamBlockGatherer<scf::ForOp>`) is missing from this rendering of the
  // file — restore it before compiling.
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(newIterArgs[j],
                               newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(newResults[j],
                               newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
  rewriter.setInsertionPoint(forOp);
  int64_t step = forOp.getConstantStep()->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), forOp.getStep(),
      rewriter.createOrFold<arith::ConstantOp>(
          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            forOp.getLoc(), builder.getIndexAttr(step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
        operandMaps[i - 1].map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such mapped value exists.
        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
                            operandMaps[i - 1].lookupOrDefault(
                                newForOp.getOperand(numControlOperands + j)));
        yieldOp.setOperand(
            i * oldNumYieldOperands + j,
            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}
687 
    // NOTE(review): the declarator line of this definition (presumably
    // `static Range emitNormalizedLoopBoundsForIndexType(RewriterBase
    // &rewriter,`) is missing from this rendering of the file.
    Location loc, OpFoldResult lb,
    OpFoldResult ub,
    OpFoldResult step) {
  // A normalized loop runs from 0 with stride 1.
  Range normalizedLoopBounds;
  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
  // size = ceildiv(ub - lb, step), computed as a folded composed
  // affine.apply over (lb, ub, step).
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  AffineExpr e = (s1 - s0).ceilDiv(s2);
  normalizedLoopBounds.size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
  return normalizedLoopBounds;
}
702 
    // NOTE(review): the declarator line of this definition (presumably
    // `Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location
    // loc,`) is missing from this rendering of the file.
    OpFoldResult lb, OpFoldResult ub,
    OpFoldResult step) {
  // Index-typed bounds are normalized with affine ops rather than arith.
  if (getType(lb).isIndex()) {
    return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
  }
  // For non-index types, generate `arith` instructions
  // Check if the loop is already known to have a constant zero lower bound or
  // a constant one step.
  bool isZeroBased = false;
  if (auto lbCst = getConstantIntValue(lb))
    isZeroBased = lbCst.value() == 0;

  bool isStepOne = false;
  if (auto stepCst = getConstantIntValue(step))
    isStepOne = stepCst.value() == 1;

  Type rangeType = getType(lb);
  assert(rangeType == getType(ub) && rangeType == getType(step) &&
         "expected matching types");

  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
  // assuming the step is strictly positive. Update the bounds and the step
  // of the loop to go from 0 to the number of iterations, if necessary.
  if (isZeroBased && isStepOne)
    return {lb, ub, step};

  OpFoldResult diff = ub;
  if (!isZeroBased) {
    diff = rewriter.createOrFold<arith::SubIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
        getValueOrCreateConstantIntOp(rewriter, loc, lb));
  }
  OpFoldResult newUpperBound = diff;
  if (!isStepOne) {
    newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
        getValueOrCreateConstantIntOp(rewriter, loc, step));
  }

  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
  OpFoldResult newStep = rewriter.getOneAttr(rangeType);

  return {newLowerBound, newUpperBound, newStep};
}
748 
    // NOTE(review): the declarator line of this definition (presumably
    // `static void denormalizeInductionVariableForIndexType(RewriterBase
    // &rewriter,`) is missing from this rendering of the file.
    Location loc,
    Value normalizedIv,
    OpFoldResult origLb,
    OpFoldResult origStep) {
  // denormalized iv = normalizedIv * origStep + origLb, built as a folded
  // composed affine.apply.
  AffineExpr d0, s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  AffineExpr e = d0 * s1 + s0;
  // NOTE(review): the line defining `denormalizedIv` (presumably a call to
  // affine::makeComposedFoldedAffineApply) is missing from this rendering of
  // the file; only its argument list is visible below.
      rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value denormalizedIvVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
  SmallPtrSet<Operation *, 1> preservedUses;
  // If an `affine.apply` operation is generated for denormalization, the use
  // of `origLb` in those ops must not be replaced. These aren't generated
  // when `origLb == 0` and `origStep == 1`.
  if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
      preservedUses.insert(preservedUse);
    }
  }
  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
}
773 
    // NOTE(review): the declarator line of this definition (presumably
    // `void mlir::denormalizeInductionVariable(RewriterBase &rewriter,
    // Location loc,`) is missing from this rendering of the file.
    Value normalizedIv, OpFoldResult origLb,
    OpFoldResult origStep) {
  // Index-typed IVs are denormalized with affine ops rather than arith.
  if (getType(origLb).isIndex()) {
    return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
                                                    origLb, origStep);
  }
  Value denormalizedIv;
  // NOTE(review): the declaration of `preserve` (presumably a
  // `SmallPtrSet<Operation *, N>`) is missing from this rendering of the
  // file — restore it before compiling.
  bool isStepOne = isOneInteger(origStep);
  bool isZeroBased = isZeroInteger(origLb);

  // scaled = normalizedIv * origStep (skipped when the step is one).
  Value scaled = normalizedIv;
  if (!isStepOne) {
    Value origStepValue =
        getValueOrCreateConstantIntOp(rewriter, loc, origStep);
    scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
    // The scaling op itself legitimately uses `normalizedIv`; keep that use.
    preserve.insert(scaled.getDefiningOp());
  }
  // denormalized iv = scaled + origLb (skipped when the lb is zero).
  denormalizedIv = scaled;
  if (!isZeroBased) {
    Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
    denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
    preserve.insert(denormalizedIv.getDefiningOp());
  }

  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}
802 
804  ArrayRef<OpFoldResult> values) {
805  assert(!values.empty() && "unexecpted empty array");
806  AffineExpr s0, s1;
807  bindSymbols(rewriter.getContext(), s0, s1);
808  AffineExpr mul = s0 * s1;
809  OpFoldResult products = rewriter.getIndexAttr(1);
810  for (auto v : values) {
812  rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
813  }
814  return products;
815 }
816 
817 /// Helper function to multiply a sequence of values.
819  ArrayRef<Value> values) {
820  assert(!values.empty() && "unexpected empty list");
821  if (getType(values.front()).isIndex()) {
823  OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
824  return getValueOrCreateConstantIndexOp(rewriter, loc, product);
825  }
826  std::optional<Value> productOf;
827  for (auto v : values) {
828  auto vOne = getConstantIntValue(v);
829  if (vOne && vOne.value() == 1)
830  continue;
831  if (productOf)
832  productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
833  .getResult();
834  else
835  productOf = v;
836  }
837  if (!productOf) {
838  productOf = arith::ConstantOp::create(
839  rewriter, loc, rewriter.getOneAttr(getType(values.front())))
840  .getResult();
841  }
842  return productOf.value();
843 }
844 
845 /// For each original loop, the value of the
846 /// induction variable can be obtained by dividing the induction variable of
847 /// the linearized loop by the total number of iterations of the loops nested
848 /// in it modulo the number of iterations in this loop (remove the values
849 /// related to the outer loops):
850 /// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
851 /// Compute these iteratively from the innermost loop by creating a "running
852 /// quotient" of division by the range.
853 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
855  Value linearizedIv, ArrayRef<Value> ubs) {
856 
857  if (linearizedIv.getType().isIndex()) {
858  Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
859  rewriter, loc, linearizedIv, ubs);
860  auto resultVals = llvm::map_to_vector(
861  delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
862  return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
863  }
864 
865  SmallVector<Value> delinearizedIvs(ubs.size());
866  SmallPtrSet<Operation *, 2> preservedUsers;
867 
868  llvm::BitVector isUbOne(ubs.size());
869  for (auto [index, ub] : llvm::enumerate(ubs)) {
870  auto ubCst = getConstantIntValue(ub);
871  if (ubCst && ubCst.value() == 1)
872  isUbOne.set(index);
873  }
874 
875  // Prune the lead ubs that are all ones.
876  unsigned numLeadingOneUbs = 0;
877  for (auto [index, ub] : llvm::enumerate(ubs)) {
878  if (!isUbOne.test(index)) {
879  break;
880  }
881  delinearizedIvs[index] = arith::ConstantOp::create(
882  rewriter, loc, rewriter.getZeroAttr(ub.getType()));
883  numLeadingOneUbs++;
884  }
885 
886  Value previous = linearizedIv;
887  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
888  unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
889  if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
890  previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
891  preservedUsers.insert(previous.getDefiningOp());
892  }
893  Value iv = previous;
894  if (i != e - 1) {
895  if (!isUbOne.test(idx)) {
896  iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
897  preservedUsers.insert(iv.getDefiningOp());
898  } else {
899  iv = arith::ConstantOp::create(
900  rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
901  }
902  }
903  delinearizedIvs[idx] = iv;
904  }
905  return {delinearizedIvs, preservedUsers};
906 }
907 
908 LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
910  if (loops.size() < 2)
911  return failure();
912 
913  scf::ForOp innermost = loops.back();
914  scf::ForOp outermost = loops.front();
915 
916  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
917  // allows the following code to assume upperBound is the number of iterations.
918  for (auto loop : loops) {
919  OpBuilder::InsertionGuard g(rewriter);
920  rewriter.setInsertionPoint(outermost);
921  Value lb = loop.getLowerBound();
922  Value ub = loop.getUpperBound();
923  Value step = loop.getStep();
924  auto newLoopRange =
925  emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
926 
927  rewriter.modifyOpInPlace(loop, [&]() {
928  loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
929  newLoopRange.offset));
930  loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
931  newLoopRange.size));
932  loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
933  newLoopRange.stride));
934  });
935  rewriter.setInsertionPointToStart(innermost.getBody());
936  denormalizeInductionVariable(rewriter, loop.getLoc(),
937  loop.getInductionVar(), lb, step);
938  }
939 
940  // 2. Emit code computing the upper bound of the coalesced loop as product
941  // of the number of iterations of all loops.
942  OpBuilder::InsertionGuard g(rewriter);
943  rewriter.setInsertionPoint(outermost);
944  Location loc = outermost.getLoc();
945  SmallVector<Value> upperBounds = llvm::map_to_vector(
946  loops, [](auto loop) { return loop.getUpperBound(); });
947  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
948  outermost.setUpperBound(upperBound);
949 
950  rewriter.setInsertionPointToStart(innermost.getBody());
951  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
952  rewriter, loc, outermost.getInductionVar(), upperBounds);
953  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
954  preservedUsers);
955 
956  for (int i = loops.size() - 1; i > 0; --i) {
957  auto outerLoop = loops[i - 1];
958  auto innerLoop = loops[i];
959 
960  Operation *innerTerminator = innerLoop.getBody()->getTerminator();
961  auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
962  assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
963  for (Value &yieldedVal : yieldedVals) {
964  // The yielded value may be an iteration argument of the inner loop
965  // which is about to be inlined.
966  auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
967  if (iter != innerLoop.getRegionIterArgs().end()) {
968  unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
969  // `outerLoop` iter args identical to the `innerLoop` init args.
970  assert(iterArgIndex < innerLoop.getInitArgs().size());
971  yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
972  }
973  }
974  rewriter.eraseOp(innerTerminator);
975 
976  SmallVector<Value> innerBlockArgs;
977  innerBlockArgs.push_back(delinearizeIvs[i]);
978  llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
979  rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
980  Block::iterator(innerLoop), innerBlockArgs);
981  rewriter.replaceOp(innerLoop, yieldedVals);
982  }
983  return success();
984 }
985 
987  if (loops.empty()) {
988  return failure();
989  }
990  IRRewriter rewriter(loops.front().getContext());
991  return coalesceLoops(rewriter, loops);
992 }
993 
994 LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
995  LogicalResult result(failure());
997  getPerfectlyNestedLoops(loops, op);
998 
999  // Look for a band of loops that can be coalesced, i.e. perfectly nested
1000  // loops with bounds defined above some loop.
1001 
1002  // 1. For each loop, find above which parent loop its bounds operands are
1003  // defined.
1004  SmallVector<unsigned> operandsDefinedAbove(loops.size());
1005  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1006  operandsDefinedAbove[i] = i;
1007  for (unsigned j = 0; j < i; ++j) {
1008  SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1009  loops[i].getUpperBound(),
1010  loops[i].getStep()};
1011  if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1012  operandsDefinedAbove[i] = j;
1013  break;
1014  }
1015  }
1016  }
1017 
1018  // 2. For each inner loop check that the iter_args for the immediately outer
1019  // loop are the init for the immediately inner loop and that the yields of the
1020  // return of the inner loop is the yield for the immediately outer loop. Keep
1021  // track of where the chain starts from for each loop.
1022  SmallVector<unsigned> iterArgChainStart(loops.size());
1023  iterArgChainStart[0] = 0;
1024  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1025  // By default set the start of the chain to itself.
1026  iterArgChainStart[i] = i;
1027  auto outerloop = loops[i - 1];
1028  auto innerLoop = loops[i];
1029  if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1030  continue;
1031  }
1032  if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1033  continue;
1034  }
1035  auto outerloopTerminator = outerloop.getBody()->getTerminator();
1036  if (!llvm::equal(outerloopTerminator->getOperands(),
1037  innerLoop.getResults())) {
1038  continue;
1039  }
1040  iterArgChainStart[i] = iterArgChainStart[i - 1];
1041  }
1042 
1043  // 3. Identify bands of loops such that the operands of all of them are
1044  // defined above the first loop in the band. Traverse the nest bottom-up
1045  // so that modifications don't invalidate the inner loops.
1046  for (unsigned end = loops.size(); end > 0; --end) {
1047  unsigned start = 0;
1048  for (; start < end - 1; ++start) {
1049  auto maxPos =
1050  *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1051  std::next(operandsDefinedAbove.begin(), end));
1052  if (maxPos > start)
1053  continue;
1054  if (iterArgChainStart[end - 1] > start)
1055  continue;
1056  auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1057  if (succeeded(coalesceLoops(band)))
1058  result = success();
1059  break;
1060  }
1061  // If a band was found and transformed, keep looking at the loops above
1062  // the outermost transformed loop.
1063  if (start != end - 1)
1064  end = start + 1;
1065  }
1066  return result;
1067 }
1068 
1070  RewriterBase &rewriter, scf::ParallelOp loops,
1071  ArrayRef<std::vector<unsigned>> combinedDimensions) {
1072  OpBuilder::InsertionGuard g(rewriter);
1073  rewriter.setInsertionPoint(loops);
1074  Location loc = loops.getLoc();
1075 
1076  // Presort combined dimensions.
1077  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1078  for (auto &dims : sortedDimensions)
1079  llvm::sort(dims);
1080 
1081  // Normalize ParallelOp's iteration pattern.
1082  SmallVector<Value, 3> normalizedUpperBounds;
1083  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1084  OpBuilder::InsertionGuard g2(rewriter);
1085  rewriter.setInsertionPoint(loops);
1086  Value lb = loops.getLowerBound()[i];
1087  Value ub = loops.getUpperBound()[i];
1088  Value step = loops.getStep()[i];
1089  auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1090  normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1091  rewriter, loops.getLoc(), newLoopRange.size));
1092 
1093  rewriter.setInsertionPointToStart(loops.getBody());
1094  denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1095  step);
1096  }
1097 
1098  // Combine iteration spaces.
1099  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1100  auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1101  auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1102  for (auto &sortedDimension : sortedDimensions) {
1103  Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1104  for (auto idx : sortedDimension) {
1105  newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1106  normalizedUpperBounds[idx]);
1107  }
1108  lowerBounds.push_back(cst0);
1109  steps.push_back(cst1);
1110  upperBounds.push_back(newUpperBound);
1111  }
1112 
1113  // Create new ParallelLoop with conversions to the original induction values.
1114  // The loop below uses divisions to get the relevant range of values in the
1115  // new induction value that represent each range of the original induction
1116  // value. The remainders then determine based on that range, which iteration
1117  // of the original induction value this represents. This is a normalized value
1118  // that is un-normalized already by the previous logic.
1119  auto newPloop = scf::ParallelOp::create(
1120  rewriter, loc, lowerBounds, upperBounds, steps,
1121  [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1122  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1123  Value previous = ploopIVs[i];
1124  unsigned numberCombinedDimensions = combinedDimensions[i].size();
1125  // Iterate over all except the last induction value.
1126  for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1127  unsigned idx = combinedDimensions[i][j];
1128 
1129  // Determine the current induction value's current loop iteration
1130  Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1131  normalizedUpperBounds[idx]);
1132  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1133  loops.getRegion());
1134 
1135  // Remove the effect of the current induction value to prepare for
1136  // the next value.
1137  previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1138  normalizedUpperBounds[idx]);
1139  }
1140 
1141  // The final induction value is just the remaining value.
1142  unsigned idx = combinedDimensions[i][0];
1143  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1144  previous, loops.getRegion());
1145  }
1146  });
1147 
1148  // Replace the old loop with the new loop.
1149  loops.getBody()->back().erase();
1150  newPloop.getBody()->getOperations().splice(
1151  Block::iterator(newPloop.getBody()->back()),
1152  loops.getBody()->getOperations());
1153  loops.erase();
1154 }
1155 
1156 // Hoist the ops within `outer` that appear before `inner`.
1157 // Such ops include the ops that have been introduced by parametric tiling.
1158 // Ops that come from triangular loops (i.e. that belong to the program slice
1159 // rooted at `outer`) and ops that have side effects cannot be hoisted.
1160 // Return failure when any op fails to hoist.
1161 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1162  SetVector<Operation *> forwardSlice;
1164  options.filter = [&inner](Operation *op) {
1165  return op != inner.getOperation();
1166  };
1167  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1168  LogicalResult status = success();
1170  for (auto &op : outer.getBody()->without_terminator()) {
1171  // Stop when encountering the inner loop.
1172  if (&op == inner.getOperation())
1173  break;
1174  // Skip over non-hoistable ops.
1175  if (forwardSlice.count(&op) > 0) {
1176  status = failure();
1177  continue;
1178  }
1179  // Skip intermediate scf::ForOp, these are not considered a failure.
1180  if (isa<scf::ForOp>(op))
1181  continue;
1182  // Skip other ops with regions.
1183  if (op.getNumRegions() > 0) {
1184  status = failure();
1185  continue;
1186  }
1187  // Skip if op has side effects.
1188  // TODO: loads to immutable memory regions are ok.
1189  if (!isMemoryEffectFree(&op)) {
1190  status = failure();
1191  continue;
1192  }
1193  toHoist.push_back(&op);
1194  }
1195  auto *outerForOp = outer.getOperation();
1196  for (auto *op : toHoist)
1197  op->moveBefore(outerForOp);
1198  return status;
1199 }
1200 
1201 // Traverse the interTile and intraTile loops and try to hoist ops such that
1202 // bands of perfectly nested loops are isolated.
1203 // Return failure if either perfect interTile or perfect intraTile bands cannot
1204 // be formed.
1205 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1206  LogicalResult status = success();
1207  const Loops &interTile = tileLoops.first;
1208  const Loops &intraTile = tileLoops.second;
1209  auto size = interTile.size();
1210  assert(size == intraTile.size());
1211  if (size <= 1)
1212  return success();
1213  for (unsigned s = 1; s < size; ++s)
1214  status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1215  : failure();
1216  for (unsigned s = 1; s < size; ++s)
1217  status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1218  : failure();
1219  return status;
1220 }
1221 
1222 /// Collect perfectly nested loops starting from `rootForOps`. Loops are
1223 /// perfectly nested if each loop is the first and only non-terminator operation
1224 /// in the parent loop. Collect at most `maxLoops` loops and append them to
1225 /// `forOps`.
1226 template <typename T>
1228  SmallVectorImpl<T> &forOps, T rootForOp,
1229  unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1230  for (unsigned i = 0; i < maxLoops; ++i) {
1231  forOps.push_back(rootForOp);
1232  Block &body = rootForOp.getRegion().front();
1233  if (body.begin() != std::prev(body.end(), 2))
1234  return;
1235 
1236  rootForOp = dyn_cast<T>(&body.front());
1237  if (!rootForOp)
1238  return;
1239  }
1240 }
1241 
1242 static Loops stripmineSink(scf::ForOp forOp, Value factor,
1243  ArrayRef<scf::ForOp> targets) {
1244  assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
1245  auto originalStep = forOp.getStep();
1246  auto iv = forOp.getInductionVar();
1247 
1248  OpBuilder b(forOp);
1249  forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1250 
1251  Loops innerLoops;
1252  for (auto t : targets) {
1253  assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
1254 
1255  // Save information for splicing ops out of t when done
1256  auto begin = t.getBody()->begin();
1257  auto nOps = t.getBody()->getOperations().size();
1258 
1259  // Insert newForOp before the terminator of `t`.
1260  auto b = OpBuilder::atBlockTerminator((t.getBody()));
1261  Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1262  Value ub =
1263  arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1264 
1265  // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
1266  auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1267  newForOp.getBody()->getOperations().splice(
1268  newForOp.getBody()->getOperations().begin(),
1269  t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1270  replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1271  newForOp.getRegion());
1272 
1273  innerLoops.push_back(newForOp);
1274  }
1275 
1276  return innerLoops;
1277 }
1278 
1279 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1280 // Returns the new for operation, nested immediately under `target`.
1281 template <typename SizeType>
1282 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1283  scf::ForOp target) {
1284  // TODO: Use cheap structural assertions that targets are nested under
1285  // forOp and that targets are not nested under each other when DominanceInfo
1286  // exposes the capability. It seems overkill to construct a whole function
1287  // dominance tree at this point.
1288  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1289  assert(res.size() == 1 && "Expected 1 inner forOp");
1290  return res[0];
1291 }
1292 
1294  ArrayRef<Value> sizes,
1295  ArrayRef<scf::ForOp> targets) {
1297  SmallVector<scf::ForOp, 8> currentTargets(targets);
1298  for (auto it : llvm::zip(forOps, sizes)) {
1299  auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1300  res.push_back(step);
1301  currentTargets = step;
1302  }
1303  return res;
1304 }
1305 
1307  scf::ForOp target) {
1309  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1310  res.push_back(llvm::getSingleElement(loops));
1311  return res;
1312 }
1313 
1314 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1315  // Collect perfectly nested loops. If more size values provided than nested
1316  // loops available, truncate `sizes`.
1318  forOps.reserve(sizes.size());
1319  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1320  if (forOps.size() < sizes.size())
1321  sizes = sizes.take_front(forOps.size());
1322 
1323  return ::tile(forOps, sizes, forOps.back());
1324 }
1325 
1327  scf::ForOp root) {
1328  getPerfectlyNestedLoopsImpl(nestedLoops, root);
1329 }
1330 
1332  ArrayRef<int64_t> sizes) {
1333  // Collect perfectly nested loops. If more size values provided than nested
1334  // loops available, truncate `sizes`.
1336  forOps.reserve(sizes.size());
1337  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1338  if (forOps.size() < sizes.size())
1339  sizes = sizes.take_front(forOps.size());
1340 
1341  // Compute the tile sizes such that i-th outer loop executes size[i]
1342  // iterations. Given that the loop current executes
1343  // numIterations = ceildiv((upperBound - lowerBound), step)
1344  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1345  SmallVector<Value, 4> tileSizes;
1346  tileSizes.reserve(sizes.size());
1347  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1348  assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1349 
1350  auto forOp = forOps[i];
1351  OpBuilder builder(forOp);
1352  auto loc = forOp.getLoc();
1353  Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1354  forOp.getLowerBound());
1355  Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1356  Value iterationsPerBlock =
1357  ceilDivPositive(builder, loc, numIterations, sizes[i]);
1358  tileSizes.push_back(iterationsPerBlock);
1359  }
1360 
1361  // Call parametric tiling with the given sizes.
1362  auto intraTile = tile(forOps, tileSizes, forOps.back());
1363  TileLoops tileLoops = std::make_pair(forOps, intraTile);
1364 
1365  // TODO: for now we just ignore the result of band isolation.
1366  // In the future, mapping decisions may be impacted by the ability to
1367  // isolate perfectly nested bands.
1368  (void)tryIsolateBands(tileLoops);
1369 
1370  return tileLoops;
1371 }
1372 
1373 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
1374  scf::ForallOp source,
1375  RewriterBase &rewriter) {
1376  unsigned numTargetOuts = target.getNumResults();
1377  unsigned numSourceOuts = source.getNumResults();
1378 
1379  // Create fused shared_outs.
1380  SmallVector<Value> fusedOuts;
1381  llvm::append_range(fusedOuts, target.getOutputs());
1382  llvm::append_range(fusedOuts, source.getOutputs());
1383 
1384  // Create a new scf.forall op after the source loop.
1385  rewriter.setInsertionPointAfter(source);
1386  scf::ForallOp fusedLoop = scf::ForallOp::create(
1387  rewriter, source.getLoc(), source.getMixedLowerBound(),
1388  source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1389  source.getMapping());
1390 
1391  // Map control operands.
1392  IRMapping mapping;
1393  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1394  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1395 
1396  // Map shared outs.
1397  mapping.map(target.getRegionIterArgs(),
1398  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1399  mapping.map(source.getRegionIterArgs(),
1400  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1401 
1402  // Append everything except the terminator into the fused operation.
1403  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1404  for (Operation &op : target.getBody()->without_terminator())
1405  rewriter.clone(op, mapping);
1406  for (Operation &op : source.getBody()->without_terminator())
1407  rewriter.clone(op, mapping);
1408 
1409  // Fuse the old terminator in_parallel ops into the new one.
1410  scf::InParallelOp targetTerm = target.getTerminator();
1411  scf::InParallelOp sourceTerm = source.getTerminator();
1412  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1413  rewriter.setInsertionPointToStart(fusedTerm.getBody());
1414  for (Operation &op : targetTerm.getYieldingOps())
1415  rewriter.clone(op, mapping);
1416  for (Operation &op : sourceTerm.getYieldingOps())
1417  rewriter.clone(op, mapping);
1418 
1419  // Replace old loops by substituting their uses by results of the fused loop.
1420  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1421  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1422 
1423  return fusedLoop;
1424 }
1425 
1426 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
1427  scf::ForOp source,
1428  RewriterBase &rewriter) {
1429  assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1430  "incompatible signedness");
1431  unsigned numTargetOuts = target.getNumResults();
1432  unsigned numSourceOuts = source.getNumResults();
1433 
1434  // Create fused init_args, with target's init_args before source's init_args.
1435  SmallVector<Value> fusedInitArgs;
1436  llvm::append_range(fusedInitArgs, target.getInitArgs());
1437  llvm::append_range(fusedInitArgs, source.getInitArgs());
1438 
1439  // Create a new scf.for op after the source loop (with scf.yield terminator
1440  // (without arguments) only in case its init_args is empty).
1441  rewriter.setInsertionPointAfter(source);
1442  scf::ForOp fusedLoop = scf::ForOp::create(
1443  rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1444  source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
1445  source.getUnsignedCmp());
1446 
1447  // Map original induction variables and operands to those of the fused loop.
1448  IRMapping mapping;
1449  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1450  mapping.map(target.getRegionIterArgs(),
1451  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1452  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1453  mapping.map(source.getRegionIterArgs(),
1454  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1455 
1456  // Merge target's body into the new (fused) for loop and then source's body.
1457  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1458  for (Operation &op : target.getBody()->without_terminator())
1459  rewriter.clone(op, mapping);
1460  for (Operation &op : source.getBody()->without_terminator())
1461  rewriter.clone(op, mapping);
1462 
1463  // Build fused yield results by appropriately mapping original yield operands.
1464  SmallVector<Value> yieldResults;
1465  for (Value operand : target.getBody()->getTerminator()->getOperands())
1466  yieldResults.push_back(mapping.lookupOrDefault(operand));
1467  for (Value operand : source.getBody()->getTerminator()->getOperands())
1468  yieldResults.push_back(mapping.lookupOrDefault(operand));
1469  if (!yieldResults.empty())
1470  scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1471 
1472  // Replace old loops by substituting their uses by results of the fused loop.
1473  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1474  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1475 
1476  return fusedLoop;
1477 }
1478 
1479 FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1480  scf::ForallOp forallOp) {
1481  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1482  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1483  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1484 
1485  if (forallOp.isNormalized())
1486  return forallOp;
1487 
1488  OpBuilder::InsertionGuard g(rewriter);
1489  auto loc = forallOp.getLoc();
1490  rewriter.setInsertionPoint(forallOp);
1492  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1493  Range normalizedLoopParams =
1494  emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1495  newUbs.push_back(normalizedLoopParams.size);
1496  }
1497  (void)foldDynamicIndexList(newUbs);
1498 
1499  // Use the normalized builder since the lower bounds are always 0 and the
1500  // steps are always 1.
1501  auto normalizedForallOp = scf::ForallOp::create(
1502  rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1503  [](OpBuilder &, Location, ValueRange) {});
1504 
1505  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1506  normalizedForallOp.getBodyRegion(),
1507  normalizedForallOp.getBodyRegion().begin());
1508  // Remove the original empty block in the new loop.
1509  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1510 
1511  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1512  // Update the users of the original loop variables.
1513  for (auto [idx, iv] :
1514  llvm::enumerate(normalizedForallOp.getInductionVars())) {
1515  auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1516  auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1517  denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1518  }
1519 
1520  rewriter.replaceOp(forallOp, normalizedForallOp);
1521  return normalizedForallOp;
1522 }
1523 
1526  assert(!loops.empty() && "unexpected empty loop nest");
1527  if (loops.size() == 1)
1528  return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1529  for (auto [outerLoop, innerLoop] :
1530  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1531  auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1532  auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1533  if (!outerFor || !innerFor)
1534  return false;
1535  auto outerBBArgs = outerFor.getRegionIterArgs();
1536  auto innerIterArgs = innerFor.getInitArgs();
1537  if (outerBBArgs.size() != innerIterArgs.size())
1538  return false;
1539 
1540  for (auto [outerBBArg, innerIterArg] :
1541  llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1542  if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1543  innerIterArg != outerBBArg)
1544  return false;
1545  }
1546 
1547  ValueRange outerYields =
1548  cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1549  ValueRange innerResults = innerFor.getResults();
1550  if (outerYields.size() != innerResults.size())
1551  return false;
1552  for (auto [outerYield, innerResult] :
1553  llvm::zip_equal(outerYields, innerResults)) {
1554  if (!llvm::hasSingleElement(innerResult.getUses()) ||
1555  outerYield != innerResult)
1556  return false;
1557  }
1558  }
1559  return true;
1560 }
1561 
1563 mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
1564  std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
1565  std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
1566  std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
1567  if (!loBnds || !upBnds || !steps)
1568  return {};
1569  llvm::SmallVector<int64_t> tripCounts;
1570  for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
1571  std::optional<llvm::APInt> numIter = constantTripCount(
1572  lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
1573  if (!numIter)
1574  return {};
1575  tripCounts.push_back(numIter->getSExtValue());
1576  }
1577  return tripCounts;
1578 }
1579 
1580 FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
1581  scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
1582  RewriterBase &rewriter,
1583  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1584  IRMapping *clonedToSrcOpsMap) {
1585  const unsigned numLoops = op.getNumLoops();
1586  assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
1587  "Expected positive unroll factors");
1588  assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
1589  "Expected non-empty unroll factors of size <= to the number of loops");
1590 
1591  // Bail out if no valid unroll factors were provided
1592  if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
1593  return rewriter.notifyMatchFailure(
1594  op, "Unrolling not applied if all factors are 1");
1595 
1596  // Return if the loop body is empty.
1597  if (llvm::hasSingleElement(op.getBody()->getOperations()))
1598  return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
1599 
1600  // If the provided unroll factors do not cover all the loop dims, they are
1601  // applied to the inner loop dimensions.
1602  const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
1603 
1604  // Make sure that the unroll factors divide the iteration space evenly
1605  // TODO: Support unrolling loops with dynamic iteration spaces.
1606  const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
1607  if (tripCounts.empty())
1608  return rewriter.notifyMatchFailure(
1609  op, "Failed to compute constant trip counts for the loop. Note that "
1610  "dynamic loop sizes are not supported.");
1611 
1612  for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1613  const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1614  if (tripCounts[dimIdx] % unrollFactor)
1615  return rewriter.notifyMatchFailure(
1616  op, "Unroll factors don't divide the iteration space evenly");
1617  }
1618 
1619  std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
1620  if (!maybeFoldSteps)
1621  return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
1622  llvm::SmallVector<size_t> steps{};
1623  for (auto step : *maybeFoldSteps)
1624  steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
1625 
1626  for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
1627  const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
1628  if (unrollFactor == 1)
1629  continue;
1630  const size_t origStep = steps[dimIdx];
1631  const int64_t newStep = origStep * unrollFactor;
1632  IRMapping clonedToSrcOpsMap;
1633 
1634  ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
1635  auto yieldedValues = op.getBody()->getTerminator()->getOperands();
1636 
1638  op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
1639  [&](unsigned i, Value iv, OpBuilder b) {
1640  // iv' = iv + step * i;
1641  const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
1642  const auto map =
1643  b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
1644  return affine::AffineApplyOp::create(b, iv.getLoc(), map,
1645  ValueRange{iv});
1646  },
1647  /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
1648 
1649  // Update loop step
1650  auto prevInsertPoint = rewriter.saveInsertionPoint();
1651  rewriter.setInsertionPoint(op);
1652  op.getStepMutable()[dimIdx].assign(
1653  arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
1654  rewriter.restoreInsertionPoint(prevInsertPoint);
1655  }
1656  return op;
1657 }
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition: Utils.cpp:803
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition: Utils.cpp:1205
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOp.
Definition: Utils.cpp:1227
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition: Utils.cpp:1161
static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition: Utils.cpp:688
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition: Utils.cpp:1242
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition: Utils.cpp:854
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition: Utils.cpp:265
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition: Utils.cpp:818
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition: Utils.cpp:749
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition: Utils.cpp:510
static int64_t product(ArrayRef< int64_t > vals)
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of AffineForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor'...
Definition: LoopUtils.cpp:899
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
#define mul(a, b)
Base type for affine expression.
Definition: AffineExpr.h:68
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
MLIRContext * getContext() const
Definition: Builders.h:56
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:342
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
Definition: Builders.h:385
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition: Builders.h:252
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
Definition: Builders.h:390
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition: Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:710
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition: Value.cpp:91
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1469
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1584
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition: SCF.cpp:114
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition: Utils.cpp:1326
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition: Utils.cpp:1524
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition: Utils.cpp:217
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:35
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition: Utils.cpp:35
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition: Utils.cpp:994
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
void generateUnrolledLoop(Block *loopBodyBlock, Value iv, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues, IRMapping *clonedToSrcOpsMap=nullptr)
Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs' and 'yieldedValues' as the...
Definition: Utils.cpp:294
std::pair< Loops, Loops > TileLoops
Definition: Utils.h:155
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:102
llvm::SmallVector< int64_t > getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp)
Get constant trip counts for each of the induction variables of the given loop operation.
Definition: Utils.cpp:1563
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition: Utils.cpp:495
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition: Utils.cpp:1069
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition: Utils.cpp:1314
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition: Utils.cpp:367
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition: Utils.cpp:523
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition: Utils.cpp:240
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
FailureOr< scf::ParallelOp > parallelLoopUnrollByFactors(scf::ParallelOp op, ArrayRef< uint64_t > unrollFactors, RewriterBase &rewriter, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, IRMapping *clonedToSrcOpsMap=nullptr)
Unroll this scf::Parallel loop by the specified unroll factors.
Definition: Utils.cpp:1580
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:70
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1293
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:114
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition: RegionUtils.h:26
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition: Utils.cpp:774
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition: Utils.cpp:1373
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition: Utils.cpp:986
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition: Utils.cpp:1426
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition: Utils.cpp:1331
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition: Utils.cpp:703
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition: Utils.cpp:1479
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition: Utils.h:116
std::optional< scf::ForOp > mainLoopOp
Definition: Utils.h:115
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.