MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/APInt.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/DebugLog.h"
29 #include <cstdint>
30 
31 using namespace mlir;
32 
33 #define DEBUG_TYPE "scf-utils"
34 
36  RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
37  ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
38  bool replaceIterOperandsUsesInLoop) {
39  if (loopNest.empty())
40  return {};
41  // This method is recursive (to make it more readable). Adding an
42  // assertion here to limit the recursion. (See
43  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
44  assert(loopNest.size() <= 10 &&
45  "exceeded recursion limit when yielding value from loop nest");
46 
47  // To yield a value from a perfectly nested loop nest, the following
48  // pattern needs to be created, i.e. starting with
49  //
50  // ```mlir
51  // scf.for .. {
52  // scf.for .. {
53  // scf.for .. {
54  // %value = ...
55  // }
56  // }
57  // }
58  // ```
59  //
60  // needs to be modified to
61  //
62  // ```mlir
63  // %0 = scf.for .. iter_args(%arg0 = %init) {
64  // %1 = scf.for .. iter_args(%arg1 = %arg0) {
65  // %2 = scf.for .. iter_args(%arg2 = %arg1) {
66  // %value = ...
67  // scf.yield %value
68  // }
69  // scf.yield %2
70  // }
71  // scf.yield %1
72  // }
73  // ```
74  //
75  // The inner most loop is handled using the `replaceWithAdditionalYields`
76  // that works on a single loop.
77  if (loopNest.size() == 1) {
78  auto innerMostLoop =
79  cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
80  rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
81  newYieldValuesFn));
82  return {innerMostLoop};
83  }
84  // The outer loops are modified by calling this method recursively
85  // - The return value of the inner loop is the value yielded by this loop.
86  // - The region iter args of this loop are the init_args for the inner loop.
87  SmallVector<scf::ForOp> newLoopNest;
88  NewYieldValuesFn fn =
89  [&](OpBuilder &innerBuilder, Location loc,
90  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
91  newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
92  innerNewBBArgs, newYieldValuesFn,
93  replaceIterOperandsUsesInLoop);
94  return llvm::to_vector(llvm::map_range(
95  newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
96  [](OpResult r) -> Value { return r; }));
97  };
98  scf::ForOp outerMostLoop =
99  cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
100  rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
101  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
102  return newLoopNest;
103 }
104 
105 /// Outline a region with a single block into a new FuncOp.
106 /// Assumes the FuncOp result types is the type of the yielded operands of the
107 /// single block. This constraint makes it easy to determine the result.
108 /// This method also clones the `arith::ConstantIndexOp` at the start of
109 /// `outlinedFuncBody` to alloc simple canonicalizations. If `callOp` is
110 /// provided, it will be set to point to the operation that calls the outlined
111 /// function.
112 // TODO: support more than single-block regions.
113 // TODO: more flexible constant handling.
114 FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
115  Location loc,
116  Region &region,
117  StringRef funcName,
118  func::CallOp *callOp) {
119  assert(!funcName.empty() && "funcName cannot be empty");
120  if (!region.hasOneBlock())
121  return failure();
122 
123  Block *originalBlock = &region.front();
124  Operation *originalTerminator = originalBlock->getTerminator();
125 
126  // Outline before current function.
127  OpBuilder::InsertionGuard g(rewriter);
128  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
129 
130  SetVector<Value> captures;
131  getUsedValuesDefinedAbove(region, captures);
132 
133  ValueRange outlinedValues(captures.getArrayRef());
134  SmallVector<Type> outlinedFuncArgTypes;
135  SmallVector<Location> outlinedFuncArgLocs;
136  // Region's arguments are exactly the first block's arguments as per
137  // Region::getArguments().
138  // Func's arguments are cat(regions's arguments, captures arguments).
139  for (BlockArgument arg : region.getArguments()) {
140  outlinedFuncArgTypes.push_back(arg.getType());
141  outlinedFuncArgLocs.push_back(arg.getLoc());
142  }
143  for (Value value : outlinedValues) {
144  outlinedFuncArgTypes.push_back(value.getType());
145  outlinedFuncArgLocs.push_back(value.getLoc());
146  }
147  FunctionType outlinedFuncType =
148  FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
149  originalTerminator->getOperandTypes());
150  auto outlinedFunc =
151  func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
152  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
153 
154  // Merge blocks while replacing the original block operands.
155  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
156  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
157  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
158  {
159  OpBuilder::InsertionGuard g(rewriter);
160  rewriter.setInsertionPointToEnd(outlinedFuncBody);
161  rewriter.mergeBlocks(
162  originalBlock, outlinedFuncBody,
163  outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
164  // Explicitly set up a new ReturnOp terminator.
165  rewriter.setInsertionPointToEnd(outlinedFuncBody);
166  func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
167  originalTerminator->getOperands());
168  }
169 
170  // Reconstruct the block that was deleted and add a
171  // terminator(call_results).
172  Block *newBlock = rewriter.createBlock(
173  &region, region.begin(),
174  TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
175  ArrayRef<Location>(outlinedFuncArgLocs)
176  .take_front(numOriginalBlockArguments));
177  {
178  OpBuilder::InsertionGuard g(rewriter);
179  rewriter.setInsertionPointToEnd(newBlock);
180  SmallVector<Value> callValues;
181  llvm::append_range(callValues, newBlock->getArguments());
182  llvm::append_range(callValues, outlinedValues);
183  auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
184  if (callOp)
185  *callOp = call;
186 
187  // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
188  // Clone `originalTerminator` to take the callOp results then erase it from
189  // `outlinedFuncBody`.
190  IRMapping bvm;
191  bvm.map(originalTerminator->getOperands(), call->getResults());
192  rewriter.clone(*originalTerminator, bvm);
193  rewriter.eraseOp(originalTerminator);
194  }
195 
196  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
197  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
198  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
199  outlinedValues.size()))) {
200  Value orig = std::get<0>(it);
201  Value repl = std::get<1>(it);
202  {
203  OpBuilder::InsertionGuard g(rewriter);
204  rewriter.setInsertionPointToStart(outlinedFuncBody);
205  if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
206  repl = rewriter.clone(*cst)->getResult(0);
207  }
208  }
209  orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
210  return outlinedFunc->isProperAncestor(opOperand.getOwner());
211  });
212  }
213 
214  return outlinedFunc;
215 }
216 
217 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
218  func::FuncOp *thenFn, StringRef thenFnName,
219  func::FuncOp *elseFn, StringRef elseFnName) {
220  IRRewriter rewriter(b);
221  Location loc = ifOp.getLoc();
222  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
223  if (thenFn && !ifOp.getThenRegion().empty()) {
224  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
225  rewriter, loc, ifOp.getThenRegion(), thenFnName);
226  if (failed(outlinedFuncOpOrFailure))
227  return failure();
228  *thenFn = *outlinedFuncOpOrFailure;
229  }
230  if (elseFn && !ifOp.getElseRegion().empty()) {
231  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
232  rewriter, loc, ifOp.getElseRegion(), elseFnName);
233  if (failed(outlinedFuncOpOrFailure))
234  return failure();
235  *elseFn = *outlinedFuncOpOrFailure;
236  }
237  return success();
238 }
239 
242  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
243  bool rootEnclosesPloops = false;
244  for (Region &region : rootOp->getRegions()) {
245  for (Block &block : region.getBlocks()) {
246  for (Operation &op : block) {
247  bool enclosesPloops = getInnermostParallelLoops(&op, result);
248  rootEnclosesPloops |= enclosesPloops;
249  if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
250  rootEnclosesPloops = true;
251 
252  // Collect parallel loop if it is an innermost one.
253  if (!enclosesPloops)
254  result.push_back(ploop);
255  }
256  }
257  }
258  }
259  return rootEnclosesPloops;
260 }
261 
262 // Build the IR that performs ceil division of a positive value by a constant:
263 // ceildiv(a, B) = divis(a + (B-1), B)
264 // where divis is rounding-to-zero division.
265 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
266  int64_t divisor) {
267  assert(divisor > 0 && "expected positive divisor");
268  assert(dividend.getType().isIntOrIndex() &&
269  "expected integer or index-typed value");
270 
271  Value divisorMinusOneCst = arith::ConstantOp::create(
272  builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
273  Value divisorCst = arith::ConstantOp::create(
274  builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
275  Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
276  return arith::DivUIOp::create(builder, loc, sum, divisorCst);
277 }
278 
279 // Build the IR that performs ceil division of a positive value by another
280 // positive value:
281 // ceildiv(a, b) = divis(a + (b - 1), b)
282 // where divis is rounding-to-zero division.
283 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
284  Value divisor) {
285  assert(dividend.getType().isIntOrIndex() &&
286  "expected integer or index-typed value");
287  Value cstOne = arith::ConstantOp::create(
288  builder, loc, builder.getOneAttr(dividend.getType()));
289  Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
290  Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
291  return arith::DivUIOp::create(builder, loc, sum, divisor);
292 }
293 
294 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
295 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
296 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
297 /// unrolled iteration using annotateFn.
299  Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
300  function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
301  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
302  ValueRange iterArgs, ValueRange yieldedValues) {
303  // Builder to insert unrolled bodies just before the terminator of the body of
304  // 'forOp'.
305  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
306 
307  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
308  if (!annotateFn)
309  annotateFn = defaultAnnotateFn;
310 
311  // Keep a pointer to the last non-terminator operation in the original block
312  // so that we know what to clone (since we are doing this in-place).
313  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
314 
315  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
316  SmallVector<Value, 4> lastYielded(yieldedValues);
317 
318  for (unsigned i = 1; i < unrollFactor; i++) {
319  IRMapping operandMap;
320 
321  // Prepare operand map.
322  operandMap.map(iterArgs, lastYielded);
323 
324  // If the induction variable is used, create a remapping to the value for
325  // this unrolled instance.
326  if (!forOpIV.use_empty()) {
327  Value ivUnroll = ivRemapFn(i, forOpIV, builder);
328  operandMap.map(forOpIV, ivUnroll);
329  }
330 
331  // Clone the original body of 'forOp'.
332  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
333  Operation *clonedOp = builder.clone(*it, operandMap);
334  annotateFn(i, clonedOp, builder);
335  }
336 
337  // Update yielded values.
338  for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
339  lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
340  }
341 
342  // Make sure we annotate the Ops in the original body. We do this last so that
343  // any annotations are not copied into the cloned Ops above.
344  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
345  annotateFn(0, &*it, builder);
346 
347  // Update operands of the yield statement.
348  loopBodyBlock->getTerminator()->setOperands(lastYielded);
349 }
350 
351 /// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
352 /// epilogue loop, if the loop is unrolled.
353 FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
354  scf::ForOp forOp, uint64_t unrollFactor,
355  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
356  assert(unrollFactor > 0 && "expected positive unroll factor");
357 
358  // Return if the loop body is empty.
359  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
360  return UnrolledLoopInfo{forOp, std::nullopt};
361 
362  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
363  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
364  OpBuilder boundsBuilder(forOp);
365  IRRewriter rewriter(forOp.getContext());
366  auto loc = forOp.getLoc();
367  Value step = forOp.getStep();
368  Value upperBoundUnrolled;
369  Value stepUnrolled;
370  bool generateEpilogueLoop = true;
371 
372  std::optional<APInt> constTripCount = forOp.getStaticTripCount();
373  if (constTripCount) {
374  // Constant loop bounds computation.
375  int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
376  int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
377  int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
378  if (unrollFactor == 1) {
379  if (*constTripCount == 1 &&
380  failed(forOp.promoteIfSingleIteration(rewriter)))
381  return failure();
382  return UnrolledLoopInfo{forOp, std::nullopt};
383  }
384 
385  int64_t tripCountEvenMultiple =
386  constTripCount->getSExtValue() -
387  (constTripCount->getSExtValue() % unrollFactor);
388  int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
389  int64_t stepUnrolledCst = stepCst * unrollFactor;
390 
391  // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
392  generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
393  if (generateEpilogueLoop)
394  upperBoundUnrolled = arith::ConstantOp::create(
395  boundsBuilder, loc,
396  boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
397  upperBoundUnrolledCst));
398  else
399  upperBoundUnrolled = forOp.getUpperBound();
400 
401  // Create constant for 'stepUnrolled'.
402  stepUnrolled =
403  stepCst == stepUnrolledCst
404  ? step
405  : arith::ConstantOp::create(boundsBuilder, loc,
406  boundsBuilder.getIntegerAttr(
407  step.getType(), stepUnrolledCst));
408  } else {
409  // Dynamic loop bounds computation.
410  // TODO: Add dynamic asserts for negative lb/ub/step, or
411  // consider using ceilDiv from AffineApplyExpander.
412  auto lowerBound = forOp.getLowerBound();
413  auto upperBound = forOp.getUpperBound();
414  Value diff =
415  arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
416  Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
417  Value unrollFactorCst = arith::ConstantOp::create(
418  boundsBuilder, loc,
419  boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
420  Value tripCountRem =
421  arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
422  // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
423  Value tripCountEvenMultiple =
424  arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
425  // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
426  upperBoundUnrolled = arith::AddIOp::create(
427  boundsBuilder, loc, lowerBound,
428  arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
429  // Scale 'step' by 'unrollFactor'.
430  stepUnrolled =
431  arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
432  }
433 
434  UnrolledLoopInfo resultLoops;
435 
436  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
437  if (generateEpilogueLoop) {
438  OpBuilder epilogueBuilder(forOp->getContext());
439  epilogueBuilder.setInsertionPointAfter(forOp);
440  auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
441  epilogueForOp.setLowerBound(upperBoundUnrolled);
442 
443  // Update uses of loop results.
444  auto results = forOp.getResults();
445  auto epilogueResults = epilogueForOp.getResults();
446 
447  for (auto e : llvm::zip(results, epilogueResults)) {
448  std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
449  }
450  epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
451  epilogueForOp.getInitArgs().size(), results);
452  if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
453  resultLoops.epilogueLoopOp = epilogueForOp;
454  }
455 
456  // Create unrolled loop.
457  forOp.setUpperBound(upperBoundUnrolled);
458  forOp.setStep(stepUnrolled);
459 
460  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
461  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
462 
464  forOp.getBody(), forOp.getInductionVar(), unrollFactor,
465  [&](unsigned i, Value iv, OpBuilder b) {
466  // iv' = iv + step * i;
467  auto stride = arith::MulIOp::create(
468  b, loc, step,
469  arith::ConstantOp::create(b, loc,
470  b.getIntegerAttr(iv.getType(), i)));
471  return arith::AddIOp::create(b, loc, iv, stride);
472  },
473  annotateFn, iterArgs, yieldedValues);
474  // Promote the loop body up if this has turned into a single iteration loop.
475  if (forOp.promoteIfSingleIteration(rewriter).failed())
476  resultLoops.mainLoopOp = forOp;
477  return resultLoops;
478 }
479 
480 /// Unrolls this loop completely.
481 LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
482  IRRewriter rewriter(forOp.getContext());
483  std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
484  if (!mayBeConstantTripCount.has_value())
485  return failure();
486  const APInt &tripCount = *mayBeConstantTripCount;
487  if (tripCount.isZero())
488  return success();
489  if (tripCount.getSExtValue() == 1)
490  return forOp.promoteIfSingleIteration(rewriter);
491  return loopUnrollByFactor(forOp, tripCount.getSExtValue());
492 }
493 
494 /// Check if bounds of all inner loops are defined outside of `forOp`
495 /// and return false if not.
496 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
497  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
498  if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
499  !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
500  !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
501  return WalkResult::interrupt();
502 
503  return WalkResult::advance();
504  });
505  return !walkResult.wasInterrupted();
506 }
507 
508 /// Unrolls and jams this loop by the specified factor.
509 LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
510  uint64_t unrollJamFactor) {
511  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
512 
513  if (unrollJamFactor == 1)
514  return success();
515 
516  // If any control operand of any inner loop of `forOp` is defined within
517  // `forOp`, no unroll jam.
518  if (!areInnerBoundsInvariant(forOp)) {
519  LDBG() << "failed to unroll and jam: inner bounds are not invariant";
520  return failure();
521  }
522 
523  // Currently, for operations with results are not supported.
524  if (forOp->getNumResults() > 0) {
525  LDBG() << "failed to unroll and jam: unsupported loop with results";
526  return failure();
527  }
528 
529  // Currently, only constant trip count that divided by the unroll factor is
530  // supported.
531  std::optional<APInt> tripCount = forOp.getStaticTripCount();
532  if (!tripCount.has_value()) {
533  // If the trip count is dynamic, do not unroll & jam.
534  LDBG() << "failed to unroll and jam: trip count could not be determined";
535  return failure();
536  }
537  if (unrollJamFactor > tripCount->getZExtValue()) {
538  LDBG() << "unroll and jam factor is greater than trip count, set factor to "
539  "trip "
540  "count";
541  unrollJamFactor = tripCount->getZExtValue();
542  } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
543  LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
544  "multiple of unroll jam factor";
545  return failure();
546  }
547 
548  // Nothing in the loop body other than the terminator.
549  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
550  return success();
551 
552  // Gather all sub-blocks to jam upon the loop being unrolled.
554  jbg.walk(forOp);
555  auto &subBlocks = jbg.subBlocks;
556 
557  // Collect inner loops.
558  SmallVector<scf::ForOp> innerLoops;
559  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
560 
561  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
562  // iteration. There are (`unrollJamFactor` - 1) iterations.
563  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
564 
565  // For any loop with iter_args, replace it with a new loop that has
566  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
567  // operands.
568  SmallVector<scf::ForOp> newInnerLoops;
569  IRRewriter rewriter(forOp.getContext());
570  for (scf::ForOp oldForOp : innerLoops) {
571  SmallVector<Value> dupIterOperands, dupYieldOperands;
572  ValueRange oldIterOperands = oldForOp.getInits();
573  ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
574  ValueRange oldYieldOperands =
575  cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
576  // Get additional iterOperands, iterArgs, and yield operands. We will
577  // fix iterOperands and yield operands after cloning of sub-blocks.
578  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
579  dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
580  dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
581  }
582  // Create a new loop with additional iterOperands, iter_args and yield
583  // operands. This new loop will take the loop body of the original loop.
584  bool forOpReplaced = oldForOp == forOp;
585  scf::ForOp newForOp =
586  cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
587  rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
588  [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
589  return dupYieldOperands;
590  }));
591  newInnerLoops.push_back(newForOp);
592  // `forOp` has been replaced with a new loop.
593  if (forOpReplaced)
594  forOp = newForOp;
595  // Update `operandMaps` for `newForOp` iterArgs and results.
596  ValueRange newIterArgs = newForOp.getRegionIterArgs();
597  unsigned oldNumIterArgs = oldIterArgs.size();
598  ValueRange newResults = newForOp.getResults();
599  unsigned oldNumResults = newResults.size() / unrollJamFactor;
600  assert(oldNumIterArgs == oldNumResults &&
601  "oldNumIterArgs must be the same as oldNumResults");
602  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
603  for (unsigned j = 0; j < oldNumIterArgs; ++j) {
604  // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
605  // results. Update `operandMaps[i - 1]` to map old iterArgs and results
606  // to those in the `i`th new set.
607  operandMaps[i - 1].map(newIterArgs[j],
608  newIterArgs[i * oldNumIterArgs + j]);
609  operandMaps[i - 1].map(newResults[j],
610  newResults[i * oldNumResults + j]);
611  }
612  }
613  }
614 
615  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
616  rewriter.setInsertionPoint(forOp);
617  int64_t step = forOp.getConstantStep()->getSExtValue();
618  auto newStep = rewriter.createOrFold<arith::MulIOp>(
619  forOp.getLoc(), forOp.getStep(),
620  rewriter.createOrFold<arith::ConstantOp>(
621  forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
622  forOp.setStep(newStep);
623  auto forOpIV = forOp.getInductionVar();
624 
625  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
626  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
627  for (auto &subBlock : subBlocks) {
628  // Builder to insert unroll-jammed bodies. Insert right at the end of
629  // sub-block.
630  OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
631 
632  // If the induction variable is used, create a remapping to the value for
633  // this unrolled instance.
634  if (!forOpIV.use_empty()) {
635  // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
636  auto ivTag = builder.createOrFold<arith::ConstantOp>(
637  forOp.getLoc(), builder.getIndexAttr(step * i));
638  auto ivUnroll =
639  builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
640  operandMaps[i - 1].map(forOpIV, ivUnroll);
641  }
642  // Clone the sub-block being unroll-jammed.
643  for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
644  builder.clone(*it, operandMaps[i - 1]);
645  }
646  // Fix iterOperands and yield op operands of newly created loops.
647  for (auto newForOp : newInnerLoops) {
648  unsigned oldNumIterOperands =
649  newForOp.getNumRegionIterArgs() / unrollJamFactor;
650  unsigned numControlOperands = newForOp.getNumControlOperands();
651  auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
652  unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
653  assert(oldNumIterOperands == oldNumYieldOperands &&
654  "oldNumIterOperands must be the same as oldNumYieldOperands");
655  for (unsigned j = 0; j < oldNumIterOperands; ++j) {
656  // The `i`th duplication of an old iterOperand or yield op operand
657  // needs to be replaced with a mapped value from `operandMaps[i - 1]`
658  // if such mapped value exists.
659  newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
660  operandMaps[i - 1].lookupOrDefault(
661  newForOp.getOperand(numControlOperands + j)));
662  yieldOp.setOperand(
663  i * oldNumYieldOperands + j,
664  operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
665  }
666  }
667  }
668 
669  // Promote the loop body up if this has turned into a single iteration loop.
670  (void)forOp.promoteIfSingleIteration(rewriter);
671  return success();
672 }
673 
675  OpFoldResult lb, OpFoldResult ub,
676  OpFoldResult step) {
677  Range normalizedLoopBounds;
678  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
679  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
680  AffineExpr s0, s1, s2;
681  bindSymbols(rewriter.getContext(), s0, s1, s2);
682  AffineExpr e = (s1 - s0).ceilDiv(s2);
683  normalizedLoopBounds.size =
684  affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
685  return normalizedLoopBounds;
686 }
687 
689  OpFoldResult lb, OpFoldResult ub,
690  OpFoldResult step) {
691  if (getType(lb).isIndex()) {
692  return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
693  }
694  // For non-index types, generate `arith` instructions
695  // Check if the loop is already known to have a constant zero lower bound or
696  // a constant one step.
697  bool isZeroBased = false;
698  if (auto lbCst = getConstantIntValue(lb))
699  isZeroBased = lbCst.value() == 0;
700 
701  bool isStepOne = false;
702  if (auto stepCst = getConstantIntValue(step))
703  isStepOne = stepCst.value() == 1;
704 
705  Type rangeType = getType(lb);
706  assert(rangeType == getType(ub) && rangeType == getType(step) &&
707  "expected matching types");
708 
709  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
710  // assuming the step is strictly positive. Update the bounds and the step
711  // of the loop to go from 0 to the number of iterations, if necessary.
712  if (isZeroBased && isStepOne)
713  return {lb, ub, step};
714 
715  OpFoldResult diff = ub;
716  if (!isZeroBased) {
717  diff = rewriter.createOrFold<arith::SubIOp>(
718  loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
719  getValueOrCreateConstantIntOp(rewriter, loc, lb));
720  }
721  OpFoldResult newUpperBound = diff;
722  if (!isStepOne) {
723  newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
724  loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
725  getValueOrCreateConstantIntOp(rewriter, loc, step));
726  }
727 
728  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
729  OpFoldResult newStep = rewriter.getOneAttr(rangeType);
730 
731  return {newLowerBound, newUpperBound, newStep};
732 }
733 
735  Location loc,
736  Value normalizedIv,
737  OpFoldResult origLb,
738  OpFoldResult origStep) {
739  AffineExpr d0, s0, s1;
740  bindSymbols(rewriter.getContext(), s0, s1);
741  bindDims(rewriter.getContext(), d0);
742  AffineExpr e = d0 * s1 + s0;
744  rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
745  Value denormalizedIvVal =
746  getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
747  SmallPtrSet<Operation *, 1> preservedUses;
748  // If an op is generated for denormalization, its use of `normalizedIv`
749  // must not be replaced. No such op is expected to be generated when
750  // `origLb == 0` and `origStep == 1`.
751  if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
752  if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
753  preservedUses.insert(preservedUse);
754  }
755  }
756  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
757 }
758 
// Replaces the uses of `normalizedIv` with `normalizedIv * origStep + origLb`.
// Index-typed bounds are forwarded to
// `denormalizeInductionVariableForIndexType`. Otherwise the value is built
// with `arith::MulIOp` / `arith::AddIOp`, skipping the multiply when the step
// is one and the add when the lower bound is zero.
// NOTE(review): this listing dropped line 759, which opens the signature.
760  Value normalizedIv, OpFoldResult origLb,
761  OpFoldResult origStep) {
762  if (getType(origLb).isIndex()) {
763  return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
764  origLb, origStep);
765  }
766  Value denormalizedIv;
// NOTE(review): listing line 767 is missing here; `preserve`, used below, is
// presumably declared on it.
768  bool isStepOne = isOneInteger(origStep);
769  bool isZeroBased = isZeroInteger(origLb);
770 
771  Value scaled = normalizedIv;
772  if (!isStepOne) {
773  Value origStepValue =
774  getValueOrCreateConstantIntOp(rewriter, loc, origStep);
775  scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
// The multiply itself uses `normalizedIv`, so it is excluded from the
// replacement below.
776  preserve.insert(scaled.getDefiningOp());
777  }
778  denormalizedIv = scaled;
779  if (!isZeroBased) {
780  Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
781  denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
// When the step is one, `scaled` is `normalizedIv`, so the add uses it
// directly and must be excluded as well.
782  preserve.insert(denormalizedIv.getDefiningOp());
783  }
784 
785  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
786 }
787 
// Returns the product of `values`, accumulated through the affine expression
// `s0 * s1` starting from the index attribute 1. Asserts that `values` is not
// empty.
// NOTE(review): this listing dropped line 788, which opens the signature.
789  ArrayRef<OpFoldResult> values) {
790  assert(!values.empty() && "unexecpted empty array");
791  AffineExpr s0, s1;
792  bindSymbols(rewriter.getContext(), s0, s1);
793  AffineExpr mul = s0 * s1;
794  OpFoldResult products = rewriter.getIndexAttr(1);
795  for (auto v : values) {
// NOTE(review): listing line 796 is missing here. The line below is the
// argument list of a call that combines the running `products` with `v`;
// its result is presumably assigned back to `products`.
797  rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
798  }
799  return products;
800 }
801 
802 /// Helper function to multiply a sequence of values.
/// Index-typed values are multiplied through `getProductOfIndexes`. Other
/// values are chained with `arith::MulIOp`, skipping operands that are the
/// constant 1. If every operand is skipped, a constant 1 of the type of the
/// first value is returned. Asserts that `values` is not empty.
// NOTE(review): this listing dropped line 803, which opens the signature.
804  ArrayRef<Value> values) {
805  assert(!values.empty() && "unexpected empty list");
806  if (getType(values.front()).isIndex()) {
// NOTE(review): listing line 807 is missing here; `ofrs`, used below, is
// presumably defined on it.
808  OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
809  return getValueOrCreateConstantIndexOp(rewriter, loc, product);
810  }
// Stays empty until the first operand that is not the constant 1 is seen.
811  std::optional<Value> productOf;
812  for (auto v : values) {
813  auto vOne = getConstantIntValue(v);
814  if (vOne && vOne.value() == 1)
815  continue;
816  if (productOf)
817  productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
818  .getResult();
819  else
820  productOf = v;
821  }
// Every operand was the constant 1: materialize the product as a constant 1.
822  if (!productOf) {
823  productOf = arith::ConstantOp::create(
824  rewriter, loc, rewriter.getOneAttr(getType(values.front())))
825  .getResult();
826  }
827  return productOf.value();
828 }
829 
830 /// For each original loop, the value of the
831 /// induction variable can be obtained by dividing the induction variable of
832 /// the linearized loop by the total number of iterations of the loops nested
833 /// in it modulo the number of iterations in this loop (remove the values
834 /// related to the outer loops):
835 /// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
836 /// Compute these iteratively from the innermost loop by creating a "running
837 /// quotient" of division by the range.
/// Returns the delinearized induction variables, in the order of `ubs`,
/// together with the set of created operations that use `linearizedIv` or a
/// value derived from it.
838 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
// NOTE(review): listing line 839, which carries the function name and the
// first parameters, is missing here.
840  Value linearizedIv, ArrayRef<Value> ubs) {
841 
// Index type: a single `affine::AffineDelinearizeIndexOp` produces all values.
842  if (linearizedIv.getType().isIndex()) {
843  Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
844  rewriter, loc, linearizedIv, ubs);
845  auto resultVals = llvm::map_to_vector(
846  delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
847  return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
848  }
849 
850  SmallVector<Value> delinearizedIvs(ubs.size());
851  SmallPtrSet<Operation *, 2> preservedUsers;
852 
// Mark the upper bounds that are the constant 1.
853  llvm::BitVector isUbOne(ubs.size());
854  for (auto [index, ub] : llvm::enumerate(ubs)) {
855  auto ubCst = getConstantIntValue(ub);
856  if (ubCst && ubCst.value() == 1)
857  isUbOne.set(index);
858  }
859 
860  // Prune the leading ubs that are all ones: their induction variable is the
// constant zero.
861  unsigned numLeadingOneUbs = 0;
862  for (auto [index, ub] : llvm::enumerate(ubs)) {
863  if (!isUbOne.test(index)) {
864  break;
865  }
866  delinearizedIvs[index] = arith::ConstantOp::create(
867  rewriter, loc, rewriter.getZeroAttr(ub.getType()));
868  numLeadingOneUbs++;
869  }
870 
// Walk the remaining positions from the last one (`idx == ubs.size() - 1`)
// down to `numLeadingOneUbs`, keeping the running quotient in `previous`.
871  Value previous = linearizedIv;
872  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
873  unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
// Divide by the bound of the position handled in the previous iteration,
// unless that bound is one.
874  if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
875  previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
876  preservedUsers.insert(previous.getDefiningOp());
877  }
878  Value iv = previous;
// The last position visited takes the remaining quotient as is; every other
// position takes the remainder, or zero when its bound is one.
879  if (i != e - 1) {
880  if (!isUbOne.test(idx)) {
881  iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
882  preservedUsers.insert(iv.getDefiningOp());
883  } else {
884  iv = arith::ConstantOp::create(
885  rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
886  }
887  }
888  delinearizedIvs[idx] = iv;
889  }
890  return {delinearizedIvs, preservedUsers};
891 }
892 
// Coalesces `loops` (`loops.front()` is treated as the outermost loop and
// `loops.back()` as the innermost) into the outermost loop. Returns failure
// when fewer than two loops are given, success otherwise.
893 LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
// NOTE(review): listing line 894, which declares `loops`, is missing here.
895  if (loops.size() < 2)
896  return failure();
897 
898  scf::ForOp innermost = loops.back();
899  scf::ForOp outermost = loops.front();
900 
901  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
902  // allows the following code to assume upperBound is the number of iterations.
903  for (auto loop : loops) {
904  OpBuilder::InsertionGuard g(rewriter);
// The normalized bounds are emitted before the outermost loop.
905  rewriter.setInsertionPoint(outermost);
906  Value lb = loop.getLowerBound();
907  Value ub = loop.getUpperBound();
908  Value step = loop.getStep();
909  auto newLoopRange =
910  emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
911 
912  rewriter.modifyOpInPlace(loop, [&]() {
913  loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
914  newLoopRange.offset));
915  loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
916  newLoopRange.size));
917  loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
918  newLoopRange.stride));
919  });
// The original induction variable value is recomputed at the start of the
// innermost body.
920  rewriter.setInsertionPointToStart(innermost.getBody());
921  denormalizeInductionVariable(rewriter, loop.getLoc(),
922  loop.getInductionVar(), lb, step);
923  }
924 
925  // 2. Emit code computing the upper bound of the coalesced loop as product
926  // of the number of iterations of all loops.
927  OpBuilder::InsertionGuard g(rewriter);
928  rewriter.setInsertionPoint(outermost);
929  Location loc = outermost.getLoc();
930  SmallVector<Value> upperBounds = llvm::map_to_vector(
931  loops, [](auto loop) { return loop.getUpperBound(); });
932  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
// NOTE(review): this updates `outermost` directly, whereas the bounds in
// step 1 are updated inside `rewriter.modifyOpInPlace` - confirm this is
// intended.
933  outermost.setUpperBound(upperBound);
934 
// 3. Recover one induction variable per original loop from the induction
// variable of `outermost`, at the start of the innermost body.
935  rewriter.setInsertionPointToStart(innermost.getBody());
936  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
937  rewriter, loc, outermost.getInductionVar(), upperBounds);
938  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
939  preservedUsers);
940 
// 4. From the innermost loop outwards, inline each loop body into its
// immediately enclosing loop and replace the inlined loop by the values it
// yielded.
941  for (int i = loops.size() - 1; i > 0; --i) {
942  auto outerLoop = loops[i - 1];
943  auto innerLoop = loops[i];
944 
945  Operation *innerTerminator = innerLoop.getBody()->getTerminator();
946  auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
947  assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
948  for (Value &yieldedVal : yieldedVals) {
949  // The yielded value may be an iteration argument of the inner loop
950  // which is about to be inlined.
951  auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
952  if (iter != innerLoop.getRegionIterArgs().end()) {
953  unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
954  // `outerLoop` iter args identical to the `innerLoop` init args.
955  assert(iterArgIndex < innerLoop.getInitArgs().size());
956  yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
957  }
958  }
959  rewriter.eraseOp(innerTerminator);
960 
// The inner block arguments are replaced by the delinearized induction
// variable followed by the outer loop's iter args.
961  SmallVector<Value> innerBlockArgs;
962  innerBlockArgs.push_back(delinearizeIvs[i]);
963  llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
964  rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
965  Block::iterator(innerLoop), innerBlockArgs);
966  rewriter.replaceOp(innerLoop, yieldedVals);
967  }
968  return success();
969 }
970 
// Convenience overload: returns failure for an empty `loops`, otherwise builds
// an `IRRewriter` on the context of the first loop and forwards to
// `coalesceLoops(rewriter, loops)`.
// NOTE(review): this listing dropped line 971, which carries the signature.
972  if (loops.empty()) {
973  return failure();
974  }
975  IRRewriter rewriter(loops.front().getContext());
976  return coalesceLoops(rewriter, loops);
977 }
978 
// Collects the loops perfectly nested under `op` and coalesces every band of
// them that qualifies (see the numbered steps below). Returns success if at
// least one band was coalesced, failure otherwise.
979 LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
980  LogicalResult result(failure());
// NOTE(review): listing line 981 is missing here; `loops`, used below, is
// presumably declared on it.
982  getPerfectlyNestedLoops(loops, op);
983 
984  // Look for a band of loops that can be coalesced, i.e. perfectly nested
985  // loops with bounds defined above some loop.
986 
987  // 1. For each loop, find above which parent loop its bounds operands are
988  // defined.
989  SmallVector<unsigned> operandsDefinedAbove(loops.size());
990  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
// Default: the bounds are only known to be defined above the loop itself.
991  operandsDefinedAbove[i] = i;
992  for (unsigned j = 0; j < i; ++j) {
993  SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
994  loops[i].getUpperBound(),
995  loops[i].getStep()};
996  if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
997  operandsDefinedAbove[i] = j;
998  break;
999  }
1000  }
1001  }
1002 
1003  // 2. For each inner loop check that the iter_args of the immediately outer
1004  // loop are the init args of the inner loop and that the results of the
1005  // inner loop are the values yielded by the immediately outer loop. Keep
1006  // track of where the chain starts from for each loop.
1007  SmallVector<unsigned> iterArgChainStart(loops.size());
1008  iterArgChainStart[0] = 0;
1009  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1010  // By default set the start of the chain to itself.
1011  iterArgChainStart[i] = i;
1012  auto outerloop = loops[i - 1];
1013  auto innerLoop = loops[i];
1014  if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1015  continue;
1016  }
1017  if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1018  continue;
1019  }
1020  auto outerloopTerminator = outerloop.getBody()->getTerminator();
1021  if (!llvm::equal(outerloopTerminator->getOperands(),
1022  innerLoop.getResults())) {
1023  continue;
1024  }
// The chain is unbroken: inherit the chain start of the outer loop.
1025  iterArgChainStart[i] = iterArgChainStart[i - 1];
1026  }
1027 
1028  // 3. Identify bands of loops such that the operands of all of them are
1029  // defined above the first loop in the band. Traverse the nest bottom-up
1030  // so that modifications don't invalidate the inner loops.
1031  for (unsigned end = loops.size(); end > 0; --end) {
1032  unsigned start = 0;
// Try the widest band [start, end) first and shrink it from the outside.
1033  for (; start < end - 1; ++start) {
1034  auto maxPos =
1035  *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1036  std::next(operandsDefinedAbove.begin(), end));
1037  if (maxPos > start)
1038  continue;
1039  if (iterArgChainStart[end - 1] > start)
1040  continue;
1041  auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1042  if (succeeded(coalesceLoops(band)))
1043  result = success();
1044  break;
1045  }
1046  // If a band was found and transformed, keep looking at the loops above
1047  // the outermost transformed loop.
1048  if (start != end - 1)
1049  end = start + 1;
1050  }
1051  return result;
1052 }
1053 
// Replaces the `scf::ParallelOp` `loops` by a new parallel loop that has one
// dimension per entry of `combinedDimensions`; each new dimension iterates
// over the product of the normalized iteration counts of the original
// dimensions listed in that entry.
// NOTE(review): this listing dropped line 1054, which opens the signature.
1055  RewriterBase &rewriter, scf::ParallelOp loops,
1056  ArrayRef<std::vector<unsigned>> combinedDimensions) {
1057  OpBuilder::InsertionGuard g(rewriter);
1058  rewriter.setInsertionPoint(loops);
1059  Location loc = loops.getLoc();
1060 
1061  // Presort combined dimensions.
1062  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1063  for (auto &dims : sortedDimensions)
1064  llvm::sort(dims);
1065 
1066  // Normalize ParallelOp's iteration pattern.
1067  SmallVector<Value, 3> normalizedUpperBounds;
1068  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1069  OpBuilder::InsertionGuard g2(rewriter);
1070  rewriter.setInsertionPoint(loops);
1071  Value lb = loops.getLowerBound()[i];
1072  Value ub = loops.getUpperBound()[i];
1073  Value step = loops.getStep()[i];
1074  auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1075  normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1076  rewriter, loops.getLoc(), newLoopRange.size));
1077 
// Uses of the original induction variable are rewritten to the denormalized
// value at the start of the old loop body.
1078  rewriter.setInsertionPointToStart(loops.getBody());
1079  denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1080  step);
1081  }
1082 
1083  // Combine iteration spaces.
1084  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1085  auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1086  auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1087  for (auto &sortedDimension : sortedDimensions) {
1088  Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1089  for (auto idx : sortedDimension) {
1090  newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1091  normalizedUpperBounds[idx]);
1092  }
1093  lowerBounds.push_back(cst0);
1094  steps.push_back(cst1);
1095  upperBounds.push_back(newUpperBound);
1096  }
1097 
1098  // Create new ParallelLoop with conversions to the original induction values.
1099  // The loop below uses divisions to get the relevant range of values in the
1100  // new induction value that represent each range of the original induction
1101  // value. The remainders then determine based on that range, which iteration
1102  // of the original induction value this represents. This is a normalized value
1103  // that is un-normalized already by the previous logic.
// NOTE(review): the body builder below indexes `combinedDimensions`, not the
// presorted `sortedDimensions` - confirm this is intended.
1104  auto newPloop = scf::ParallelOp::create(
1105  rewriter, loc, lowerBounds, upperBounds, steps,
1106  [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1107  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1108  Value previous = ploopIVs[i];
1109  unsigned numberCombinedDimensions = combinedDimensions[i].size();
// NOTE(review): an empty group would make `numberCombinedDimensions - 1`
// wrap around (unsigned) and would read element 0 below - presumably
// callers never pass empty groups; confirm.
1110  // Iterate over all except the last induction value.
1111  for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1112  unsigned idx = combinedDimensions[i][j];
1113 
1114  // Determine the current induction value's current loop iteration
1115  Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1116  normalizedUpperBounds[idx]);
1117  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1118  loops.getRegion());
1119 
1120  // Remove the effect of the current induction value to prepare for
1121  // the next value.
1122  previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1123  normalizedUpperBounds[idx]);
1124  }
1125 
1126  // The final induction value is just the remaining value.
1127  unsigned idx = combinedDimensions[i][0];
1128  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1129  previous, loops.getRegion());
1130  }
1131  });
1132 
1133  // Replace the old loop with the new loop: drop the old terminator, move the
// remaining old body ops in front of the new loop's last op, then erase the
// old loop.
1134  loops.getBody()->back().erase();
1135  newPloop.getBody()->getOperations().splice(
1136  Block::iterator(newPloop.getBody()->back()),
1137  loops.getBody()->getOperations());
1138  loops.erase();
1139 }
1140 
1141 // Hoist the ops within `outer` that appear before `inner`.
1142 // Such ops include the ops that have been introduced by parametric tiling.
1143 // Ops that come from triangular loops (i.e. that belong to the program slice
1144 // rooted at `outer`) and ops that have side effects cannot be hoisted.
1145 // Return failure when any op fails to hoist.
// Ops that can be hoisted are still moved before `outer` when failure is
// returned.
1146 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1147  SetVector<Operation *> forwardSlice;
// NOTE(review): listing line 1148 is missing here; `options`, used below, is
// presumably declared on it.
// The slice computation filters out `inner` itself.
1149  options.filter = [&inner](Operation *op) {
1150  return op != inner.getOperation();
1151  };
1152  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1153  LogicalResult status = success();
// NOTE(review): listing line 1154 is missing here; `toHoist`, used below, is
// presumably declared on it.
1155  for (auto &op : outer.getBody()->without_terminator()) {
1156  // Stop when encountering the inner loop.
1157  if (&op == inner.getOperation())
1158  break;
1159  // Skip over non-hoistable ops, i.e. ops in the forward slice of the
// induction variable of `outer`.
1160  if (forwardSlice.count(&op) > 0) {
1161  status = failure();
1162  continue;
1163  }
1164  // Skip intermediate scf::ForOp, these are not considered a failure.
1165  if (isa<scf::ForOp>(op))
1166  continue;
1167  // Skip other ops with regions.
1168  if (op.getNumRegions() > 0) {
1169  status = failure();
1170  continue;
1171  }
1172  // Skip if op has side effects.
1173  // TODO: loads to immutable memory regions are ok.
1174  if (!isMemoryEffectFree(&op)) {
1175  status = failure();
1176  continue;
1177  }
1178  toHoist.push_back(&op);
1179  }
// Moving is deferred until after the walk so the block is not modified while
// it is being iterated.
1180  auto *outerForOp = outer.getOperation();
1181  for (auto *op : toHoist)
1182  op->moveBefore(outerForOp);
1183  return status;
1184 }
1185 
1186 // Traverse the interTile and intraTile loops and try to hoist ops such that
1187 // bands of perfectly nested loops are isolated.
1188 // Return failure if either perfect interTile or perfect intraTile bands cannot
1189 // be formed.
1190 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1191  LogicalResult status = success();
1192  const Loops &interTile = tileLoops.first;
1193  const Loops &intraTile = tileLoops.second;
1194  auto size = interTile.size();
1195  assert(size == intraTile.size());
1196  if (size <= 1)
1197  return success();
1198  for (unsigned s = 1; s < size; ++s)
1199  status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1200  : failure();
1201  for (unsigned s = 1; s < size; ++s)
1202  status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1203  : failure();
1204  return status;
1205 }
1206 
1207 /// Collect perfectly nested loops starting from `rootForOp`. Loops are
1208 /// perfectly nested if each loop is the first and only non-terminator operation
1209 /// in the parent loop. Collect at most `maxLoops` loops and append them to
1210 /// `forOps`.
1211 template <typename T>
// NOTE(review): listing line 1212, which carries the function name, is missing
// here.
1213  SmallVectorImpl<T> &forOps, T rootForOp,
1214  unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1215  for (unsigned i = 0; i < maxLoops; ++i) {
1216  forOps.push_back(rootForOp);
1217  Block &body = rootForOp.getRegion().front();
// Stop unless the first op of the body is also the second-to-last one, i.e.
// the body holds exactly one op in front of its last op.
1218  if (body.begin() != std::prev(body.end(), 2))
1219  return;
1220 
// Stop when that single op is not a loop of type `T`.
1221  rootForOp = dyn_cast<T>(&body.front());
1222  if (!rootForOp)
1223  return;
1224  }
1225 }
1226 
1227 static Loops stripmineSink(scf::ForOp forOp, Value factor,
1228  ArrayRef<scf::ForOp> targets) {
1229  assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
1230  auto originalStep = forOp.getStep();
1231  auto iv = forOp.getInductionVar();
1232 
1233  OpBuilder b(forOp);
1234  forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1235 
1236  Loops innerLoops;
1237  for (auto t : targets) {
1238  assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
1239 
1240  // Save information for splicing ops out of t when done
1241  auto begin = t.getBody()->begin();
1242  auto nOps = t.getBody()->getOperations().size();
1243 
1244  // Insert newForOp before the terminator of `t`.
1245  auto b = OpBuilder::atBlockTerminator((t.getBody()));
1246  Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1247  Value ub =
1248  arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1249 
1250  // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
1251  auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1252  newForOp.getBody()->getOperations().splice(
1253  newForOp.getBody()->getOperations().begin(),
1254  t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1255  replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1256  newForOp.getRegion());
1257 
1258  innerLoops.push_back(newForOp);
1259  }
1260 
1261  return innerLoops;
1262 }
1263 
1264 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1265 // Returns the new for operation, nested immediately under `target`.
1266 template <typename SizeType>
1267 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1268  scf::ForOp target) {
1269  // TODO: Use cheap structural assertions that targets are nested under
1270  // forOp and that targets are not nested under each other when DominanceInfo
1271  // exposes the capability. It seems overkill to construct a whole function
1272  // dominance tree at this point.
1273  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1274  assert(res.size() == 1 && "Expected 1 inner forOp");
1275  return res[0];
1276 }
1277 
// Strip-mines each loop of `forOps` by the matching entry of `sizes` and sinks
// it under the current targets: initially `targets`, then the loops produced
// by the previous strip-mining step. Returns one list of new loops per
// processed (loop, size) pair.
// NOTE(review): this listing dropped line 1278, which opens the signature.
1279  ArrayRef<Value> sizes,
1280  ArrayRef<scf::ForOp> targets) {
// NOTE(review): listing line 1281 is missing here; `res`, used below, is
// presumably declared on it.
1282  SmallVector<scf::ForOp, 8> currentTargets(targets);
1283  for (auto it : llvm::zip(forOps, sizes)) {
1284  auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1285  res.push_back(step);
// The loops just created become the targets of the next strip-mining step.
1286  currentTargets = step;
1287  }
1288  return res;
1289 }
1290 
// Single-target variant of `tile`: forwards to the multi-target version with
// `target` as the only target and keeps the single loop produced at each step.
// NOTE(review): this listing dropped line 1291, which opens the signature.
1292  scf::ForOp target) {
// NOTE(review): listing line 1293 is missing here; `res`, used below, is
// presumably declared on it.
1294  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1295  res.push_back(llvm::getSingleElement(loops));
1296  return res;
1297 }
1298 
// Tiles the loops perfectly nested under `rootForOp` with `sizes`, using the
// innermost collected loop as the sink target.
1299 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1300  // Collect perfectly nested loops. If more size values provided than nested
1301  // loops available, truncate `sizes`.
// NOTE(review): listing line 1302 is missing here; `forOps`, used below, is
// presumably declared on it.
1303  forOps.reserve(sizes.size());
1304  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1305  if (forOps.size() < sizes.size())
1306  sizes = sizes.take_front(forOps.size());
1307 
1308  return ::tile(forOps, sizes, forOps.back());
1309 }
1310 
// Appends the loops perfectly nested under `root`, starting with `root`
// itself, to `nestedLoops` by forwarding to `getPerfectlyNestedLoopsImpl`.
// NOTE(review): this listing dropped line 1311, which opens the signature.
1312  scf::ForOp root) {
1313  getPerfectlyNestedLoopsImpl(nestedLoops, root);
1314 }
1315 
// Tiles the loops perfectly nested under `rootForOp` so that the i-th outer
// loop executes `sizes[i]` iterations, then attempts to isolate the resulting
// bands. Returns the pair (original loops, new intra-tile loops).
// NOTE(review): this listing dropped line 1316, which opens the signature.
1317  ArrayRef<int64_t> sizes) {
1318  // Collect perfectly nested loops. If more size values provided than nested
1319  // loops available, truncate `sizes`.
// NOTE(review): listing line 1320 is missing here; `forOps`, used below, is
// presumably declared on it.
1321  forOps.reserve(sizes.size());
1322  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1323  if (forOps.size() < sizes.size())
1324  sizes = sizes.take_front(forOps.size());
1325 
1326  // Compute the tile sizes such that i-th outer loop executes size[i]
1327  // iterations. Given that the loop currently executes
1328  // numIterations = ceildiv((upperBound - lowerBound), step)
1329  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1330  SmallVector<Value, 4> tileSizes;
1331  tileSizes.reserve(sizes.size());
1332  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1333  assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1334 
1335  auto forOp = forOps[i];
// The tile size computation is emitted right before the loop it belongs to.
1336  OpBuilder builder(forOp);
1337  auto loc = forOp.getLoc();
1338  Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1339  forOp.getLowerBound());
1340  Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1341  Value iterationsPerBlock =
1342  ceilDivPositive(builder, loc, numIterations, sizes[i]);
1343  tileSizes.push_back(iterationsPerBlock);
1344  }
1345 
1346  // Call parametric tiling with the given sizes.
1347  auto intraTile = tile(forOps, tileSizes, forOps.back());
1348  TileLoops tileLoops = std::make_pair(forOps, intraTile);
1349 
1350  // TODO: for now we just ignore the result of band isolation.
1351  // In the future, mapping decisions may be impacted by the ability to
1352  // isolate perfectly nested bands.
1353  (void)tryIsolateBands(tileLoops);
1354 
1355  return tileLoops;
1356 }
1357 
1358 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
1359  scf::ForallOp source,
1360  RewriterBase &rewriter) {
1361  unsigned numTargetOuts = target.getNumResults();
1362  unsigned numSourceOuts = source.getNumResults();
1363 
1364  // Create fused shared_outs.
1365  SmallVector<Value> fusedOuts;
1366  llvm::append_range(fusedOuts, target.getOutputs());
1367  llvm::append_range(fusedOuts, source.getOutputs());
1368 
1369  // Create a new scf.forall op after the source loop.
1370  rewriter.setInsertionPointAfter(source);
1371  scf::ForallOp fusedLoop = scf::ForallOp::create(
1372  rewriter, source.getLoc(), source.getMixedLowerBound(),
1373  source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1374  source.getMapping());
1375 
1376  // Map control operands.
1377  IRMapping mapping;
1378  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1379  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1380 
1381  // Map shared outs.
1382  mapping.map(target.getRegionIterArgs(),
1383  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1384  mapping.map(source.getRegionIterArgs(),
1385  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1386 
1387  // Append everything except the terminator into the fused operation.
1388  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1389  for (Operation &op : target.getBody()->without_terminator())
1390  rewriter.clone(op, mapping);
1391  for (Operation &op : source.getBody()->without_terminator())
1392  rewriter.clone(op, mapping);
1393 
1394  // Fuse the old terminator in_parallel ops into the new one.
1395  scf::InParallelOp targetTerm = target.getTerminator();
1396  scf::InParallelOp sourceTerm = source.getTerminator();
1397  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1398  rewriter.setInsertionPointToStart(fusedTerm.getBody());
1399  for (Operation &op : targetTerm.getYieldingOps())
1400  rewriter.clone(op, mapping);
1401  for (Operation &op : sourceTerm.getYieldingOps())
1402  rewriter.clone(op, mapping);
1403 
1404  // Replace old loops by substituting their uses by results of the fused loop.
1405  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1406  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1407 
1408  return fusedLoop;
1409 }
1410 
1411 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
1412  scf::ForOp source,
1413  RewriterBase &rewriter) {
1414  assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1415  "incompatible signedness");
1416  unsigned numTargetOuts = target.getNumResults();
1417  unsigned numSourceOuts = source.getNumResults();
1418 
1419  // Create fused init_args, with target's init_args before source's init_args.
1420  SmallVector<Value> fusedInitArgs;
1421  llvm::append_range(fusedInitArgs, target.getInitArgs());
1422  llvm::append_range(fusedInitArgs, source.getInitArgs());
1423 
1424  // Create a new scf.for op after the source loop (with scf.yield terminator
1425  // (without arguments) only in case its init_args is empty).
1426  rewriter.setInsertionPointAfter(source);
1427  scf::ForOp fusedLoop = scf::ForOp::create(
1428  rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1429  source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
1430  source.getUnsignedCmp());
1431 
1432  // Map original induction variables and operands to those of the fused loop.
1433  IRMapping mapping;
1434  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1435  mapping.map(target.getRegionIterArgs(),
1436  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1437  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1438  mapping.map(source.getRegionIterArgs(),
1439  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1440 
1441  // Merge target's body into the new (fused) for loop and then source's body.
1442  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1443  for (Operation &op : target.getBody()->without_terminator())
1444  rewriter.clone(op, mapping);
1445  for (Operation &op : source.getBody()->without_terminator())
1446  rewriter.clone(op, mapping);
1447 
1448  // Build fused yield results by appropriately mapping original yield operands.
1449  SmallVector<Value> yieldResults;
1450  for (Value operand : target.getBody()->getTerminator()->getOperands())
1451  yieldResults.push_back(mapping.lookupOrDefault(operand));
1452  for (Value operand : source.getBody()->getTerminator()->getOperands())
1453  yieldResults.push_back(mapping.lookupOrDefault(operand));
1454  if (!yieldResults.empty())
1455  scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1456 
1457  // Replace old loops by substituting their uses by results of the fused loop.
1458  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1459  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1460 
1461  return fusedLoop;
1462 }
1463 
// Returns `forallOp` itself when it is already normalized. Otherwise creates a
// new forall op from the normalized upper bounds, moves the body of `forallOp`
// into it, rewrites the uses of the induction variables to their denormalized
// values, replaces `forallOp` by the new op and returns the new op.
1464 FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1465  scf::ForallOp forallOp) {
// NOTE(review): the mixed bounds are materialized before the early return
// below, although they are only used on the non-normalized path.
1466  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1467  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1468  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1469 
1470  if (forallOp.isNormalized())
1471  return forallOp;
1472 
1473  OpBuilder::InsertionGuard g(rewriter);
1474  auto loc = forallOp.getLoc();
1475  rewriter.setInsertionPoint(forallOp);
// NOTE(review): listing line 1476 is missing here; `newUbs`, used below, is
// presumably declared on it.
1477  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1478  Range normalizedLoopParams =
1479  emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1480  newUbs.push_back(normalizedLoopParams.size);
1481  }
1482  (void)foldDynamicIndexList(newUbs);
1483 
1484  // Use the normalized builder since the lower bounds are always 0 and the
1485  // steps are always 1.
1486  auto normalizedForallOp = scf::ForallOp::create(
1487  rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1488  [](OpBuilder &, Location, ValueRange) {});
1489 
// Move the original body in front of the block created with the new op.
1490  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1491  normalizedForallOp.getBodyRegion(),
1492  normalizedForallOp.getBodyRegion().begin());
1493  // Remove the original empty block in the new loop.
1494  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1495 
1496  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1497  // Update the users of the original loop variables.
1498  for (auto [idx, iv] :
1499  llvm::enumerate(normalizedForallOp.getInductionVars())) {
1500  auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1501  auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1502  denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1503  }
1504 
1505  rewriter.replaceOp(forallOp, normalizedForallOp);
1506  return normalizedForallOp;
1507 }
1508 
// Returns true when every entry of `loops` is an `scf::ForOp` and, for each
// pair of consecutive loops, (a) the init args of the inner loop are exactly
// the iter args of the outer loop, each having a single use, and (b) the
// values yielded by the outer loop are exactly the results of the inner loop,
// each having a single use. Asserts that `loops` is not empty.
// NOTE(review): the signature lines (listing lines 1509-1510) are missing from
// this listing.
1511  assert(!loops.empty() && "unexpected empty loop nest");
// A single loop only needs to be an `scf::ForOp`.
1512  if (loops.size() == 1)
1513  return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
// Visit every (outer, inner) pair of consecutive loops.
1514  for (auto [outerLoop, innerLoop] :
1515  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1516  auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1517  auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1518  if (!outerFor || !innerFor)
1519  return false;
1520  auto outerBBArgs = outerFor.getRegionIterArgs();
1521  auto innerIterArgs = innerFor.getInitArgs();
1522  if (outerBBArgs.size() != innerIterArgs.size())
1523  return false;
1524 
// Each outer iter arg must be used only once, as the matching inner init arg.
1525  for (auto [outerBBArg, innerIterArg] :
1526  llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1527  if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1528  innerIterArg != outerBBArg)
1529  return false;
1530  }
1531 
1532  ValueRange outerYields =
1533  cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1534  ValueRange innerResults = innerFor.getResults();
1535  if (outerYields.size() != innerResults.size())
1536  return false;
// Each inner result must be used only once, as the matching outer yield
// operand.
1537  for (auto [outerYield, innerResult] :
1538  llvm::zip_equal(outerYields, innerResults)) {
1539  if (!llvm::hasSingleElement(innerResult.getUses()) ||
1540  outerYield != innerResult)
1541  return false;
1542  }
1543  }
1544  return true;
1545 }
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition: Utils.cpp:788
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition: Utils.cpp:1190
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
Definition: Utils.cpp:1212
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition: Utils.cpp:1146
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of scf::ForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor',...
Definition: Utils.cpp:298
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition: Utils.cpp:1227
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition: Utils.cpp:839
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition: Utils.cpp:265
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition: Utils.cpp:803
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition: Utils.cpp:734
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition: Utils.cpp:674
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition: Utils.cpp:496
static int64_t product(ArrayRef< int64_t > vals)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:323
MLIRContext * getContext() const
Definition: Builders.h:56
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:341
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition: Builders.h:252
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:519
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
result_range getResults()
Definition: Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition: Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:710
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition: Value.cpp:91
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1584
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition: Utils.cpp:1311
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition: Utils.cpp:1509
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition: Utils.cpp:217
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:35
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition: Utils.cpp:35
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition: Utils.cpp:979
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
std::pair< Loops, Loops > TileLoops
Definition: Utils.h:155
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:102
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition: Utils.cpp:481
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition: Utils.cpp:1054
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition: Utils.cpp:1299
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition: Utils.cpp:353
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition: Utils.cpp:509
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition: Utils.cpp:240
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:70
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1278
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:114
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition: RegionUtils.h:26
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition: Utils.cpp:759
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition: Utils.cpp:1358
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition: Utils.cpp:971
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition: Utils.cpp:1411
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition: Utils.cpp:1316
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition: Utils.cpp:688
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition: Utils.cpp:1464
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition: Utils.h:116
std::optional< scf::ForOp > mainLoopOp
Definition: Utils.h:115
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.