Utils.cpp
1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/OpDefinition.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Interfaces/SideEffectInterfaces.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/DebugLog.h"
28 #include <cstdint>
29 
30 using namespace mlir;
31 
32 #define DEBUG_TYPE "scf-utils"
33 
34 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
35  RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
36  ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
37  bool replaceIterOperandsUsesInLoop) {
38  if (loopNest.empty())
39  return {};
40  // This method is recursive (to make it more readable). Adding an
41  // assertion here to limit the recursion. (See
42  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
43  assert(loopNest.size() <= 10 &&
44  "exceeded recursion limit when yielding value from loop nest");
45 
46  // To yield a value from a perfectly nested loop nest, a loop nest of the
47  // following form
48  //
49  // ```mlir
50  // scf.for .. {
51  // scf.for .. {
52  // scf.for .. {
53  // %value = ...
54  // }
55  // }
56  // }
57  // ```
58  //
59  // needs to be modified to
60  //
61  // ```mlir
62  // %0 = scf.for .. iter_args(%arg0 = %init) {
63  // %1 = scf.for .. iter_args(%arg1 = %arg0) {
64  // %2 = scf.for .. iter_args(%arg2 = %arg1) {
65  // %value = ...
66  // scf.yield %value
67  // }
68  // scf.yield %2
69  // }
70  // scf.yield %1
71  // }
72  // ```
73  //
74  // The innermost loop is handled using `replaceWithAdditionalYields`,
75  // which works on a single loop.
76  if (loopNest.size() == 1) {
77  auto innerMostLoop =
78  cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
79  rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
80  newYieldValuesFn));
81  return {innerMostLoop};
82  }
83  // The outer loops are modified by calling this method recursively
84  // - The return value of the inner loop is the value yielded by this loop.
85  // - The region iter args of this loop are the init_args for the inner loop.
86  SmallVector<scf::ForOp> newLoopNest;
87  NewYieldValuesFn fn =
88  [&](OpBuilder &innerBuilder, Location loc,
89  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
90  newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
91  innerNewBBArgs, newYieldValuesFn,
92  replaceIterOperandsUsesInLoop);
93  return llvm::to_vector(llvm::map_range(
94  newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
95  [](OpResult r) -> Value { return r; }));
96  };
97  scf::ForOp outerMostLoop =
98  cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
99  rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
100  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
101  return newLoopNest;
102 }
103 
104 /// Outline a region with a single block into a new FuncOp.
105 /// Assumes the FuncOp result types are the types of the operands yielded by
106 /// the single block. This constraint makes it easy to determine the result.
107 /// This method also clones the `arith::ConstantIndexOp` at the start of
108 /// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
109 /// provided, it will be set to point to the operation that calls the outlined
110 /// function.
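///
/// As a rough illustration (the `test.yield` terminator below is hypothetical),
/// a single-block region such as
///
/// ```mlir
///   ^bb0(%arg0: index):
///     %0 = arith.addi %arg0, %c1 : index
///     "test.yield"(%0) : (index) -> ()
/// ```
///
/// where %c1 is captured from above, becomes a func.func that takes %arg0 and
/// %c1 as arguments and returns %0, while the original block is rebuilt as a
/// func.call to that function followed by a clone of the original terminator
/// taking the call results.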
111 // TODO: support more than single-block regions.
112 // TODO: more flexible constant handling.
113 FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
114  Location loc,
115  Region &region,
116  StringRef funcName,
117  func::CallOp *callOp) {
118  assert(!funcName.empty() && "funcName cannot be empty");
119  if (!region.hasOneBlock())
120  return failure();
121 
122  Block *originalBlock = &region.front();
123  Operation *originalTerminator = originalBlock->getTerminator();
124 
125  // Outline before current function.
126  OpBuilder::InsertionGuard g(rewriter);
127  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
128 
129  SetVector<Value> captures;
130  getUsedValuesDefinedAbove(region, captures);
131 
132  ValueRange outlinedValues(captures.getArrayRef());
133  SmallVector<Type> outlinedFuncArgTypes;
134  SmallVector<Location> outlinedFuncArgLocs;
135  // Region's arguments are exactly the first block's arguments as per
136  // Region::getArguments().
137  // Func's arguments are cat(region's arguments, captured values).
138  for (BlockArgument arg : region.getArguments()) {
139  outlinedFuncArgTypes.push_back(arg.getType());
140  outlinedFuncArgLocs.push_back(arg.getLoc());
141  }
142  for (Value value : outlinedValues) {
143  outlinedFuncArgTypes.push_back(value.getType());
144  outlinedFuncArgLocs.push_back(value.getLoc());
145  }
146  FunctionType outlinedFuncType =
147  FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
148  originalTerminator->getOperandTypes());
149  auto outlinedFunc =
150  func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
151  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
152 
153  // Merge blocks while replacing the original block operands.
154  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
155  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
156  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
157  {
158  OpBuilder::InsertionGuard g(rewriter);
159  rewriter.setInsertionPointToEnd(outlinedFuncBody);
160  rewriter.mergeBlocks(
161  originalBlock, outlinedFuncBody,
162  outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
163  // Explicitly set up a new ReturnOp terminator.
164  rewriter.setInsertionPointToEnd(outlinedFuncBody);
165  func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
166  originalTerminator->getOperands());
167  }
168 
169  // Reconstruct the block that was deleted and add a terminator that
170  // forwards the call results.
171  Block *newBlock = rewriter.createBlock(
172  &region, region.begin(),
173  TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
174  ArrayRef<Location>(outlinedFuncArgLocs)
175  .take_front(numOriginalBlockArguments));
176  {
177  OpBuilder::InsertionGuard g(rewriter);
178  rewriter.setInsertionPointToEnd(newBlock);
179  SmallVector<Value> callValues;
180  llvm::append_range(callValues, newBlock->getArguments());
181  llvm::append_range(callValues, outlinedValues);
182  auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
183  if (callOp)
184  *callOp = call;
185 
186  // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
187  // Clone `originalTerminator` to take the callOp results then erase it from
188  // `outlinedFuncBody`.
189  IRMapping bvm;
190  bvm.map(originalTerminator->getOperands(), call->getResults());
191  rewriter.clone(*originalTerminator, bvm);
192  rewriter.eraseOp(originalTerminator);
193  }
194 
195  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
196  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
197  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
198  outlinedValues.size()))) {
199  Value orig = std::get<0>(it);
200  Value repl = std::get<1>(it);
201  {
202  OpBuilder::InsertionGuard g(rewriter);
203  rewriter.setInsertionPointToStart(outlinedFuncBody);
204  if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
205  IRMapping bvm;
206  repl = rewriter.clone(*cst, bvm)->getResult(0);
207  }
208  }
209  orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
210  return outlinedFunc->isProperAncestor(opOperand.getOwner());
211  });
212  }
213 
214  return outlinedFunc;
215 }
216 
217 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
218  func::FuncOp *thenFn, StringRef thenFnName,
219  func::FuncOp *elseFn, StringRef elseFnName) {
220  IRRewriter rewriter(b);
221  Location loc = ifOp.getLoc();
222  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
223  if (thenFn && !ifOp.getThenRegion().empty()) {
224  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
225  rewriter, loc, ifOp.getThenRegion(), thenFnName);
226  if (failed(outlinedFuncOpOrFailure))
227  return failure();
228  *thenFn = *outlinedFuncOpOrFailure;
229  }
230  if (elseFn && !ifOp.getElseRegion().empty()) {
231  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
232  rewriter, loc, ifOp.getElseRegion(), elseFnName);
233  if (failed(outlinedFuncOpOrFailure))
234  return failure();
235  *elseFn = *outlinedFuncOpOrFailure;
236  }
237  return success();
238 }
239 
240 bool mlir::getInnermostParallelLoops(Operation *rootOp,
241  SmallVectorImpl<scf::ParallelOp> &result) {
242  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
243  bool rootEnclosesPloops = false;
244  for (Region &region : rootOp->getRegions()) {
245  for (Block &block : region.getBlocks()) {
246  for (Operation &op : block) {
247  bool enclosesPloops = getInnermostParallelLoops(&op, result);
248  rootEnclosesPloops |= enclosesPloops;
249  if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
250  rootEnclosesPloops = true;
251 
252  // Collect parallel loop if it is an innermost one.
253  if (!enclosesPloops)
254  result.push_back(ploop);
255  }
256  }
257  }
258  }
259  return rootEnclosesPloops;
260 }
261 
262 // Build the IR that performs ceil division of a positive value by a constant:
263 // ceildiv(a, B) = divis(a + (B-1), B)
264 // where divis is rounding-to-zero division.
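// For example, ceildiv(10, 4) = divis(10 + 3, 4) = 3, while an exact multiple
// such as ceildiv(12, 4) = divis(12 + 3, 4) = 3 is unaffected by the added
// bias.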
265 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
266  int64_t divisor) {
267  assert(divisor > 0 && "expected positive divisor");
268  assert(dividend.getType().isIntOrIndex() &&
269  "expected integer or index-typed value");
270 
271  Value divisorMinusOneCst = arith::ConstantOp::create(
272  builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
273  Value divisorCst = arith::ConstantOp::create(
274  builder, loc, builder.getIntegerAttr(dividend.getType(), divisor));
275  Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
276  return arith::DivUIOp::create(builder, loc, sum, divisorCst);
277 }
278 
279 // Build the IR that performs ceil division of a positive value by another
280 // positive value:
281 // ceildiv(a, b) = divis(a + (b - 1), b)
282 // where divis is rounding-to-zero division.
283 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
284  Value divisor) {
285  assert(dividend.getType().isIntOrIndex() &&
286  "expected integer or index-typed value");
287  Value cstOne = arith::ConstantOp::create(
288  builder, loc, builder.getOneAttr(dividend.getType()));
289  Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
290  Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
291  return arith::DivUIOp::create(builder, loc, sum, divisor);
292 }
293 
294 /// Returns the trip count of `forOp` if its lower bound, upper bound and step
295 /// are constants, or std::nullopt otherwise. The trip count is computed as
296 /// ceilDiv(upperBound - lowerBound, step).
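/// For example, a loop from 0 to 10 with step 3 has trip count
/// ceilDiv(10 - 0, 3) = 4 (iterations 0, 3, 6 and 9).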
297 static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
298  return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
299  forOp.getStep());
300 }
301 
302 /// Generates `unrollFactor` copies of the loop body in 'loopBodyBlock', whose
303 /// induction variable is 'forOpIV', calling 'ivRemapFn' to remap 'forOpIV'
304 /// for each unrolled body. If specified, annotates the Ops in each unrolled
305 /// iteration using annotateFn.
306 static void generateUnrolledLoop(
307  Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
308  function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
309  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
310  ValueRange iterArgs, ValueRange yieldedValues) {
311  // Builder to insert unrolled bodies just before the terminator of the body of
312  // 'forOp'.
313  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
314 
315  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
316  if (!annotateFn)
317  annotateFn = defaultAnnotateFn;
318 
319  // Keep a pointer to the last non-terminator operation in the original block
320  // so that we know what to clone (since we are doing this in-place).
321  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
322 
323  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
324  SmallVector<Value, 4> lastYielded(yieldedValues);
325 
326  for (unsigned i = 1; i < unrollFactor; i++) {
327  IRMapping operandMap;
328 
329  // Prepare operand map.
330  operandMap.map(iterArgs, lastYielded);
331 
332  // If the induction variable is used, create a remapping to the value for
333  // this unrolled instance.
334  if (!forOpIV.use_empty()) {
335  Value ivUnroll = ivRemapFn(i, forOpIV, builder);
336  operandMap.map(forOpIV, ivUnroll);
337  }
338 
339  // Clone the original body of 'forOp'.
340  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
341  Operation *clonedOp = builder.clone(*it, operandMap);
342  annotateFn(i, clonedOp, builder);
343  }
344 
345  // Update yielded values.
346  for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
347  lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
348  }
349 
350  // Make sure we annotate the Ops in the original body. We do this last so that
351  // any annotations are not copied into the cloned Ops above.
352  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
353  annotateFn(0, &*it, builder);
354 
355  // Update operands of the yield statement.
356  loopBodyBlock->getTerminator()->setOperands(lastYielded);
357 }
358 
359 /// Unrolls 'forOp' by 'unrollFactor' and returns the unrolled main loop and
360 /// the epilogue loop, if the loop was unrolled.
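///
/// As an illustrative sketch (the body op `some.op` is hypothetical), unrolling
///
/// ```mlir
///   scf.for %i = %c0 to %c9 step %c1 {
///     "some.op"(%i) : (index) -> ()
///   }
/// ```
///
/// by a factor of 2 produces a main loop from %c0 to %c8 with step %c2 whose
/// body applies "some.op" to %i and %i + 1, plus an epilogue loop covering the
/// remaining iteration at %i = 8.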
361 FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
362  scf::ForOp forOp, uint64_t unrollFactor,
363  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
364  assert(unrollFactor > 0 && "expected positive unroll factor");
365 
366  // Return if the loop body is empty.
367  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
368  return UnrolledLoopInfo{forOp, std::nullopt};
369 
370  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
371  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
372  OpBuilder boundsBuilder(forOp);
373  IRRewriter rewriter(forOp.getContext());
374  auto loc = forOp.getLoc();
375  Value step = forOp.getStep();
376  Value upperBoundUnrolled;
377  Value stepUnrolled;
378  bool generateEpilogueLoop = true;
379 
380  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
381  if (constTripCount) {
382  // Constant loop bounds computation.
383  int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
384  int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
385  int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
386  if (unrollFactor == 1) {
387  if (*constTripCount == 1 &&
388  failed(forOp.promoteIfSingleIteration(rewriter)))
389  return failure();
390  return UnrolledLoopInfo{forOp, std::nullopt};
391  }
392 
393  int64_t tripCountEvenMultiple =
394  *constTripCount - (*constTripCount % unrollFactor);
395  int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
396  int64_t stepUnrolledCst = stepCst * unrollFactor;
397 
398  // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
399  generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
400  if (generateEpilogueLoop)
401  upperBoundUnrolled = arith::ConstantOp::create(
402  boundsBuilder, loc,
403  boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
404  upperBoundUnrolledCst));
405  else
406  upperBoundUnrolled = forOp.getUpperBound();
407 
408  // Create constant for 'stepUnrolled'.
409  stepUnrolled =
410  stepCst == stepUnrolledCst
411  ? step
412  : arith::ConstantOp::create(boundsBuilder, loc,
413  boundsBuilder.getIntegerAttr(
414  step.getType(), stepUnrolledCst));
415  } else {
416  // Dynamic loop bounds computation.
417  // TODO: Add dynamic asserts for negative lb/ub/step, or
418  // consider using ceilDiv from AffineApplyExpander.
419  auto lowerBound = forOp.getLowerBound();
420  auto upperBound = forOp.getUpperBound();
421  Value diff =
422  arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
423  Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
424  Value unrollFactorCst = arith::ConstantOp::create(
425  boundsBuilder, loc,
426  boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
427  Value tripCountRem =
428  arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
429  // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
430  Value tripCountEvenMultiple =
431  arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
432  // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
433  upperBoundUnrolled = arith::AddIOp::create(
434  boundsBuilder, loc, lowerBound,
435  arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
436  // Scale 'step' by 'unrollFactor'.
437  stepUnrolled =
438  arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
439  }
440 
441  UnrolledLoopInfo resultLoops;
442 
443  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
444  if (generateEpilogueLoop) {
445  OpBuilder epilogueBuilder(forOp->getContext());
446  epilogueBuilder.setInsertionPointAfter(forOp);
447  auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
448  epilogueForOp.setLowerBound(upperBoundUnrolled);
449 
450  // Update uses of loop results.
451  auto results = forOp.getResults();
452  auto epilogueResults = epilogueForOp.getResults();
453 
454  for (auto e : llvm::zip(results, epilogueResults)) {
455  std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
456  }
457  epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
458  epilogueForOp.getInitArgs().size(), results);
459  if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
460  resultLoops.epilogueLoopOp = epilogueForOp;
461  }
462 
463  // Create unrolled loop.
464  forOp.setUpperBound(upperBoundUnrolled);
465  forOp.setStep(stepUnrolled);
466 
467  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
468  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
469 
470  generateUnrolledLoop(
471  forOp.getBody(), forOp.getInductionVar(), unrollFactor,
472  [&](unsigned i, Value iv, OpBuilder b) {
473  // iv' = iv + step * i;
474  auto stride = arith::MulIOp::create(
475  b, loc, step,
476  arith::ConstantOp::create(b, loc,
477  b.getIntegerAttr(iv.getType(), i)));
478  return arith::AddIOp::create(b, loc, iv, stride);
479  },
480  annotateFn, iterArgs, yieldedValues);
481  // Promote the loop body up if this has turned into a single iteration loop.
482  if (forOp.promoteIfSingleIteration(rewriter).failed())
483  resultLoops.mainLoopOp = forOp;
484  return resultLoops;
485 }
486 
487 /// Unrolls this loop completely.
488 LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
489  IRRewriter rewriter(forOp.getContext());
490  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
491  if (!mayBeConstantTripCount.has_value())
492  return failure();
493  uint64_t tripCount = *mayBeConstantTripCount;
494  if (tripCount == 0)
495  return success();
496  if (tripCount == 1)
497  return forOp.promoteIfSingleIteration(rewriter);
498  return loopUnrollByFactor(forOp, tripCount);
499 }
500 
501 /// Check if bounds of all inner loops are defined outside of `forOp`
502 /// and return false if not.
503 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
504  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
505  if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
506  !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
507  !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
508  return WalkResult::interrupt();
509 
510  return WalkResult::advance();
511  });
512  return !walkResult.wasInterrupted();
513 }
514 
515 /// Unrolls and jams this loop by the specified factor.
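///
/// As a rough sketch (the body op `use` is hypothetical), unrolling and
/// jamming the following nest by a factor of 2
///
/// ```mlir
///   scf.for %i = %c0 to %c8 step %c1 {
///     scf.for %j = %c0 to %c4 step %c1 {
///       "use"(%i, %j) : (index, index) -> ()
///     }
///   }
/// ```
///
/// doubles the step of %i and places the body copies for %i and %i + 1 inside
/// a single inner loop over %j, instead of duplicating the inner loop itself
/// as plain unrolling would.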
516 LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
517  uint64_t unrollJamFactor) {
518  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
519 
520  if (unrollJamFactor == 1)
521  return success();
522 
523  // If any control operand of any inner loop of `forOp` is defined within
524  // `forOp`, no unroll jam.
525  if (!areInnerBoundsInvariant(forOp)) {
526  LDBG() << "failed to unroll and jam: inner bounds are not invariant";
527  return failure();
528  }
529 
530  // Currently, `scf.for` operations with results are not supported.
531  if (forOp->getNumResults() > 0) {
532  LDBG() << "failed to unroll and jam: unsupported loop with results";
533  return failure();
534  }
535 
536  // Currently, only a constant trip count that is a multiple of the unroll
537  // factor is supported.
538  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
539  if (!tripCount.has_value()) {
540  // If the trip count is dynamic, do not unroll & jam.
541  LDBG() << "failed to unroll and jam: trip count could not be determined";
542  return failure();
543  }
544  if (unrollJamFactor > *tripCount) {
545  LDBG() << "unroll and jam factor is greater than trip count, set factor to "
546  "trip "
547  "count";
548  unrollJamFactor = *tripCount;
549  } else if (*tripCount % unrollJamFactor != 0) {
550  LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
551  "multiple of unroll jam factor";
552  return failure();
553  }
554 
555  // Nothing in the loop body other than the terminator.
556  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
557  return success();
558 
559  // Gather all sub-blocks to jam upon the loop being unrolled.
560  JamBlockGatherer<scf::ForOp> jbg;
561  jbg.walk(forOp);
562  auto &subBlocks = jbg.subBlocks;
563 
564  // Collect inner loops.
565  SmallVector<scf::ForOp> innerLoops;
566  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
567 
568  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
569  // iteration. There are (`unrollJamFactor` - 1) iterations.
570  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
571 
572  // For any loop with iter_args, replace it with a new loop that has
573  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
574  // operands.
575  SmallVector<scf::ForOp> newInnerLoops;
576  IRRewriter rewriter(forOp.getContext());
577  for (scf::ForOp oldForOp : innerLoops) {
578  SmallVector<Value> dupIterOperands, dupYieldOperands;
579  ValueRange oldIterOperands = oldForOp.getInits();
580  ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
581  ValueRange oldYieldOperands =
582  cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
583  // Get additional iterOperands, iterArgs, and yield operands. We will
584  // fix iterOperands and yield operands after cloning of sub-blocks.
585  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
586  dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
587  dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
588  }
589  // Create a new loop with additional iterOperands, iter_args and yield
590  // operands. This new loop will take the loop body of the original loop.
591  bool forOpReplaced = oldForOp == forOp;
592  scf::ForOp newForOp =
593  cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
594  rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
595  [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
596  return dupYieldOperands;
597  }));
598  newInnerLoops.push_back(newForOp);
599  // `forOp` has been replaced with a new loop.
600  if (forOpReplaced)
601  forOp = newForOp;
602  // Update `operandMaps` for `newForOp` iterArgs and results.
603  ValueRange newIterArgs = newForOp.getRegionIterArgs();
604  unsigned oldNumIterArgs = oldIterArgs.size();
605  ValueRange newResults = newForOp.getResults();
606  unsigned oldNumResults = newResults.size() / unrollJamFactor;
607  assert(oldNumIterArgs == oldNumResults &&
608  "oldNumIterArgs must be the same as oldNumResults");
609  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
610  for (unsigned j = 0; j < oldNumIterArgs; ++j) {
611  // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
612  // results. Update `operandMaps[i - 1]` to map old iterArgs and results
613  // to those in the `i`th new set.
614  operandMaps[i - 1].map(newIterArgs[j],
615  newIterArgs[i * oldNumIterArgs + j]);
616  operandMaps[i - 1].map(newResults[j],
617  newResults[i * oldNumResults + j]);
618  }
619  }
620  }
621 
622  // Scale the step of the loop being unroll-jammed by the unroll-jam factor.
623  rewriter.setInsertionPoint(forOp);
624  int64_t step = forOp.getConstantStep()->getSExtValue();
625  auto newStep = rewriter.createOrFold<arith::MulIOp>(
626  forOp.getLoc(), forOp.getStep(),
627  rewriter.createOrFold<arith::ConstantOp>(
628  forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
629  forOp.setStep(newStep);
630  auto forOpIV = forOp.getInductionVar();
631 
632  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
633  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
634  for (auto &subBlock : subBlocks) {
635  // Builder to insert unroll-jammed bodies. Insert right at the end of
636  // sub-block.
637  OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
638 
639  // If the induction variable is used, create a remapping to the value for
640  // this unrolled instance.
641  if (!forOpIV.use_empty()) {
642  // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
643  auto ivTag = builder.createOrFold<arith::ConstantOp>(
644  forOp.getLoc(), builder.getIndexAttr(step * i));
645  auto ivUnroll =
646  builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
647  operandMaps[i - 1].map(forOpIV, ivUnroll);
648  }
649  // Clone the sub-block being unroll-jammed.
650  for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
651  builder.clone(*it, operandMaps[i - 1]);
652  }
653  // Fix iterOperands and yield op operands of newly created loops.
654  for (auto newForOp : newInnerLoops) {
655  unsigned oldNumIterOperands =
656  newForOp.getNumRegionIterArgs() / unrollJamFactor;
657  unsigned numControlOperands = newForOp.getNumControlOperands();
658  auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
659  unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
660  assert(oldNumIterOperands == oldNumYieldOperands &&
661  "oldNumIterOperands must be the same as oldNumYieldOperands");
662  for (unsigned j = 0; j < oldNumIterOperands; ++j) {
663  // The `i`th duplication of an old iterOperand or yield op operand
664  // needs to be replaced with a mapped value from `operandMaps[i - 1]`
665  // if such mapped value exists.
666  newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
667  operandMaps[i - 1].lookupOrDefault(
668  newForOp.getOperand(numControlOperands + j)));
669  yieldOp.setOperand(
670  i * oldNumYieldOperands + j,
671  operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
672  }
673  }
674  }
675 
676  // Promote the loop body up if this has turned into a single iteration loop.
677  (void)forOp.promoteIfSingleIteration(rewriter);
678  return success();
679 }
680 
681 static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
682  OpFoldResult lb, OpFoldResult ub,
683  OpFoldResult step) {
684  Range normalizedLoopBounds;
685  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
686  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
687  AffineExpr s0, s1, s2;
688  bindSymbols(rewriter.getContext(), s0, s1, s2);
689  AffineExpr e = (s1 - s0).ceilDiv(s2);
690  normalizedLoopBounds.size =
691  affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
692  return normalizedLoopBounds;
693 }
694 
695 Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
696  OpFoldResult lb, OpFoldResult ub,
697  OpFoldResult step) {
698  if (getType(lb).isIndex()) {
699  return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
700  }
701  // For non-index types, generate `arith` instructions
702  // Check if the loop is already known to have a constant zero lower bound or
703  // a constant one step.
704  bool isZeroBased = false;
705  if (auto lbCst = getConstantIntValue(lb))
706  isZeroBased = lbCst.value() == 0;
707 
708  bool isStepOne = false;
709  if (auto stepCst = getConstantIntValue(step))
710  isStepOne = stepCst.value() == 1;
711 
712  Type rangeType = getType(lb);
713  assert(rangeType == getType(ub) && rangeType == getType(step) &&
714  "expected matching types");
715 
716  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
717  // assuming the step is strictly positive. Update the bounds and the step
718  // of the loop to go from 0 to the number of iterations, if necessary.
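  // For example, a loop over [2, 13) with step 3 executes
  // ceilDiv(13 - 2, 3) = 4 iterations and is rewritten to run over [0, 4)
  // with step 1; the induction variable is denormalized separately.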
719  if (isZeroBased && isStepOne)
720  return {lb, ub, step};
721 
722  OpFoldResult diff = ub;
723  if (!isZeroBased) {
724  diff = rewriter.createOrFold<arith::SubIOp>(
725  loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
726  getValueOrCreateConstantIntOp(rewriter, loc, lb));
727  }
728  OpFoldResult newUpperBound = diff;
729  if (!isStepOne) {
730  newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
731  loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
732  getValueOrCreateConstantIntOp(rewriter, loc, step));
733  }
734 
735  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
736  OpFoldResult newStep = rewriter.getOneAttr(rangeType);
737 
738  return {newLowerBound, newUpperBound, newStep};
739 }
740 
741 static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
742  Location loc,
743  Value normalizedIv,
744  OpFoldResult origLb,
745  OpFoldResult origStep) {
746  AffineExpr d0, s0, s1;
747  bindSymbols(rewriter.getContext(), s0, s1);
748  bindDims(rewriter.getContext(), d0);
749  AffineExpr e = d0 * s1 + s0;
750  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
751  rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
752  Value denormalizedIvVal =
753  getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
754  SmallPtrSet<Operation *, 1> preservedUses;
755  // If an `affine.apply` operation is generated for denormalization, the use
756  // of `origLb` in those ops must not be replaced. These ops are not
757  // generated when `origLb == 0` and `origStep == 1`.
758  if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
759  if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
760  preservedUses.insert(preservedUse);
761  }
762  }
763  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
764 }
765 
766 void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
767  Value normalizedIv, OpFoldResult origLb,
768  OpFoldResult origStep) {
769  if (getType(origLb).isIndex()) {
770  return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
771  origLb, origStep);
772  }
773  Value denormalizedIv;
774  SmallPtrSet<Operation *, 2> preserve;
775  bool isStepOne = isOneInteger(origStep);
776  bool isZeroBased = isZeroInteger(origLb);
777 
778  Value scaled = normalizedIv;
779  if (!isStepOne) {
780  Value origStepValue =
781  getValueOrCreateConstantIntOp(rewriter, loc, origStep);
782  scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
783  preserve.insert(scaled.getDefiningOp());
784  }
785  denormalizedIv = scaled;
786  if (!isZeroBased) {
787  Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
788  denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
789  preserve.insert(denormalizedIv.getDefiningOp());
790  }
791 
792  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
793 }
794 
795 static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
796  ArrayRef<OpFoldResult> values) {
797  assert(!values.empty() && "unexpected empty array");
798  AffineExpr s0, s1;
799  bindSymbols(rewriter.getContext(), s0, s1);
800  AffineExpr mul = s0 * s1;
801  OpFoldResult products = rewriter.getIndexAttr(1);
802  for (auto v : values) {
803  products = affine::makeComposedFoldedAffineApply(
804  rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
805  }
806  return products;
807 }
808 
809 /// Helper function to multiply a sequence of values.
810 Value mlir::getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
811  ArrayRef<Value> values) {
812  assert(!values.empty() && "unexpected empty list");
813  if (getType(values.front()).isIndex()) {
814  SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
815  OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
816  return getValueOrCreateConstantIndexOp(rewriter, loc, product);
817  }
818  std::optional<Value> productOf;
819  for (auto v : values) {
820  auto vOne = getConstantIntValue(v);
821  if (vOne && vOne.value() == 1)
822  continue;
823  if (productOf)
824  productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
825  .getResult();
826  else
827  productOf = v;
828  }
829  if (!productOf) {
830  productOf = arith::ConstantOp::create(
831  rewriter, loc, rewriter.getOneAttr(getType(values.front())))
832  .getResult();
833  }
834  return productOf.value();
835 }
836 
837 /// For each original loop, the value of the
838 /// induction variable can be obtained by dividing the induction variable of
839 /// the linearized loop by the total number of iterations of the loops nested
840 /// in it modulo the number of iterations in this loop (remove the values
841 /// related to the outer loops):
842 /// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
843 /// Compute these iteratively from the innermost loop by creating a "running
844 /// quotient" of division by the range.
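///
/// For example, for a loop nest with iteration ranges (4, 3, 5) listed from
/// outermost to innermost, a linearized induction variable of 37 delinearizes
/// to (floordiv(37, 15) mod 4, floordiv(37, 5) mod 3, 37 mod 5) = (2, 1, 2),
/// since 2 * 15 + 1 * 5 + 2 = 37.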
845 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
846 delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
847  Value linearizedIv, ArrayRef<Value> ubs) {
848 
849  if (linearizedIv.getType().isIndex()) {
850  Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
851  rewriter, loc, linearizedIv, ubs);
852  auto resultVals = llvm::map_to_vector(
853  delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
854  return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
855  }
856 
857  SmallVector<Value> delinearizedIvs(ubs.size());
858  SmallPtrSet<Operation *, 2> preservedUsers;
859 
860  llvm::BitVector isUbOne(ubs.size());
861  for (auto [index, ub] : llvm::enumerate(ubs)) {
862  auto ubCst = getConstantIntValue(ub);
863  if (ubCst && ubCst.value() == 1)
864  isUbOne.set(index);
865  }
866 
867  // Prune the leading ubs that are all ones.
868  unsigned numLeadingOneUbs = 0;
869  for (auto [index, ub] : llvm::enumerate(ubs)) {
870  if (!isUbOne.test(index)) {
871  break;
872  }
873  delinearizedIvs[index] = arith::ConstantOp::create(
874  rewriter, loc, rewriter.getZeroAttr(ub.getType()));
875  numLeadingOneUbs++;
876  }
877 
878  Value previous = linearizedIv;
879  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
880  unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
881  if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
882  previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
883  preservedUsers.insert(previous.getDefiningOp());
884  }
885  Value iv = previous;
886  if (i != e - 1) {
887  if (!isUbOne.test(idx)) {
888  iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
889  preservedUsers.insert(iv.getDefiningOp());
890  } else {
891  iv = arith::ConstantOp::create(
892  rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
893  }
894  }
895  delinearizedIvs[idx] = iv;
896  }
897  return {delinearizedIvs, preservedUsers};
898 }
899 
900 LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
901  MutableArrayRef<scf::ForOp> loops) {
902  if (loops.size() < 2)
903  return failure();
904 
905  scf::ForOp innermost = loops.back();
906  scf::ForOp outermost = loops.front();
907 
908  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
909  // allows the following code to assume upperBound is the number of iterations.
910  for (auto loop : loops) {
911  OpBuilder::InsertionGuard g(rewriter);
912  rewriter.setInsertionPoint(outermost);
913  Value lb = loop.getLowerBound();
914  Value ub = loop.getUpperBound();
915  Value step = loop.getStep();
916  auto newLoopRange =
917  emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
918 
919  rewriter.modifyOpInPlace(loop, [&]() {
920  loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
921  newLoopRange.offset));
922  loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
923  newLoopRange.size));
924  loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
925  newLoopRange.stride));
926  });
927  rewriter.setInsertionPointToStart(innermost.getBody());
928  denormalizeInductionVariable(rewriter, loop.getLoc(),
929  loop.getInductionVar(), lb, step);
930  }
931 
932  // 2. Emit code computing the upper bound of the coalesced loop as product
933  // of the number of iterations of all loops.
934  OpBuilder::InsertionGuard g(rewriter);
935  rewriter.setInsertionPoint(outermost);
936  Location loc = outermost.getLoc();
937  SmallVector<Value> upperBounds = llvm::map_to_vector(
938  loops, [](auto loop) { return loop.getUpperBound(); });
939  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
940  outermost.setUpperBound(upperBound);
941 
942  rewriter.setInsertionPointToStart(innermost.getBody());
943  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
944  rewriter, loc, outermost.getInductionVar(), upperBounds);
945  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
946  preservedUsers);
947 
948  for (int i = loops.size() - 1; i > 0; --i) {
949  auto outerLoop = loops[i - 1];
950  auto innerLoop = loops[i];
951 
952  Operation *innerTerminator = innerLoop.getBody()->getTerminator();
953  auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
954  assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
955  for (Value &yieldedVal : yieldedVals) {
956  // The yielded value may be an iteration argument of the inner loop
957  // which is about to be inlined.
958  auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
959  if (iter != innerLoop.getRegionIterArgs().end()) {
960  unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
961  // `outerLoop` iter args identical to the `innerLoop` init args.
962  assert(iterArgIndex < innerLoop.getInitArgs().size());
963  yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
964  }
965  }
966  rewriter.eraseOp(innerTerminator);
967 
968  SmallVector<Value> innerBlockArgs;
969  innerBlockArgs.push_back(delinearizeIvs[i]);
970  llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
971  rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
972  Block::iterator(innerLoop), innerBlockArgs);
973  rewriter.replaceOp(innerLoop, yieldedVals);
974  }
975  return success();
976 }
977 
978 LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
979  if (loops.empty()) {
980  return failure();
981  }
982  IRRewriter rewriter(loops.front().getContext());
983  return coalesceLoops(rewriter, loops);
984 }
985 
986 LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
987  LogicalResult result(failure());
988  SmallVector<scf::ForOp, 4> loops;
989  getPerfectlyNestedLoops(loops, op);
990 
991  // Look for a band of loops that can be coalesced, i.e. perfectly nested
992  // loops with bounds defined above some loop.
993 
994  // 1. For each loop, find above which parent loop its bounds operands are
995  // defined.
996  SmallVector<unsigned> operandsDefinedAbove(loops.size());
997  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
998  operandsDefinedAbove[i] = i;
999  for (unsigned j = 0; j < i; ++j) {
1000  SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1001  loops[i].getUpperBound(),
1002  loops[i].getStep()};
1003  if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1004  operandsDefinedAbove[i] = j;
1005  break;
1006  }
1007  }
1008  }
1009 
1010  // 2. For each inner loop, check that the iter_args of the immediately outer
1011  // loop are the init values of the inner loop, and that the results of the
1012  // inner loop are the values yielded by the immediately outer loop. Keep
1013  // track of where the chain starts from for each loop.
1014  SmallVector<unsigned> iterArgChainStart(loops.size());
1015  iterArgChainStart[0] = 0;
1016  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1017  // By default set the start of the chain to itself.
1018  iterArgChainStart[i] = i;
1019  auto outerloop = loops[i - 1];
1020  auto innerLoop = loops[i];
1021  if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1022  continue;
1023  }
1024  if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1025  continue;
1026  }
1027  auto outerloopTerminator = outerloop.getBody()->getTerminator();
1028  if (!llvm::equal(outerloopTerminator->getOperands(),
1029  innerLoop.getResults())) {
1030  continue;
1031  }
1032  iterArgChainStart[i] = iterArgChainStart[i - 1];
1033  }
1034 
1035  // 3. Identify bands of loops such that the operands of all of them are
1036  // defined above the first loop in the band. Traverse the nest bottom-up
1037  // so that modifications don't invalidate the inner loops.
1038  for (unsigned end = loops.size(); end > 0; --end) {
1039  unsigned start = 0;
1040  for (; start < end - 1; ++start) {
1041  auto maxPos =
1042  *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1043  std::next(operandsDefinedAbove.begin(), end));
1044  if (maxPos > start)
1045  continue;
1046  if (iterArgChainStart[end - 1] > start)
1047  continue;
1048  auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1049  if (succeeded(coalesceLoops(band)))
1050  result = success();
1051  break;
1052  }
1053  // If a band was found and transformed, keep looking at the loops above
1054  // the outermost transformed loop.
1055  if (start != end - 1)
1056  end = start + 1;
1057  }
1058  return result;
1059 }
1060 
1061 void mlir::collapseParallelLoops(
1062  RewriterBase &rewriter, scf::ParallelOp loops,
1063  ArrayRef<std::vector<unsigned>> combinedDimensions) {
1064  OpBuilder::InsertionGuard g(rewriter);
1065  rewriter.setInsertionPoint(loops);
1066  Location loc = loops.getLoc();
1067 
1068  // Presort combined dimensions.
1069  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1070  for (auto &dims : sortedDimensions)
1071  llvm::sort(dims);
1072 
1073  // Normalize ParallelOp's iteration pattern.
1074  SmallVector<Value, 3> normalizedUpperBounds;
1075  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1076  OpBuilder::InsertionGuard g2(rewriter);
1077  rewriter.setInsertionPoint(loops);
1078  Value lb = loops.getLowerBound()[i];
1079  Value ub = loops.getUpperBound()[i];
1080  Value step = loops.getStep()[i];
1081  auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1082  normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1083  rewriter, loops.getLoc(), newLoopRange.size));
1084 
1085  rewriter.setInsertionPointToStart(loops.getBody());
1086  denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1087  step);
1088  }
1089 
1090  // Combine iteration spaces.
1091  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1092  auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1093  auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
1094  for (auto &sortedDimension : sortedDimensions) {
1095  Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
1096  for (auto idx : sortedDimension) {
1097  newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1098  normalizedUpperBounds[idx]);
1099  }
1100  lowerBounds.push_back(cst0);
1101  steps.push_back(cst1);
1102  upperBounds.push_back(newUpperBound);
1103  }
1104 
1105  // Create new ParallelLoop with conversions to the original induction values.
1106  // The loop below uses divisions to get the relevant range of values in the
1107  // new induction value that represent each range of the original induction
1108  // value. The remainders then determine, based on that range, which iteration
1109  // of the original induction value this represents. This is a normalized value
1110  // that has already been un-normalized by the logic emitted above.
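  // For instance, combining two dimensions with normalized sizes 3 and 5 into
  // one collapsed induction value %c in [0, 15) recovers the original
  // (normalized) indices as (%c divsi 5, %c remsi 5).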
1111  auto newPloop = scf::ParallelOp::create(
1112  rewriter, loc, lowerBounds, upperBounds, steps,
1113  [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1114  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1115  Value previous = ploopIVs[i];
1116  unsigned numberCombinedDimensions = combinedDimensions[i].size();
1117  // Iterate over all except the last induction value.
1118  for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1119  unsigned idx = combinedDimensions[i][j];
1120 
1121  // Determine this induction value's current loop iteration.
1122  Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1123  normalizedUpperBounds[idx]);
1124  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1125  loops.getRegion());
1126 
1127  // Remove the effect of the current induction value to prepare for
1128  // the next value.
1129  previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1130  normalizedUpperBounds[idx]);
1131  }
1132 
1133  // The final induction value is just the remaining value.
1134  unsigned idx = combinedDimensions[i][0];
1135  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1136  previous, loops.getRegion());
1137  }
1138  });
1139 
1140  // Replace the old loop with the new loop.
1141  loops.getBody()->back().erase();
1142  newPloop.getBody()->getOperations().splice(
1143  Block::iterator(newPloop.getBody()->back()),
1144  loops.getBody()->getOperations());
1145  loops.erase();
1146 }
1147 
1148 // Hoist the ops within `outer` that appear before `inner`.
1149 // Such ops include the ops that have been introduced by parametric tiling.
1150 // Ops that come from triangular loops (i.e. that belong to the program slice
1151 // rooted at `outer`) and ops that have side effects cannot be hoisted.
1152 // Return failure when any op fails to hoist.
1153 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1154  SetVector<Operation *> forwardSlice;
1155  ForwardSliceOptions options;
1156  options.filter = [&inner](Operation *op) {
1157  return op != inner.getOperation();
1158  };
1159  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1160  LogicalResult status = success();
1161  SmallVector<Operation *> toHoist;
1162  for (auto &op : outer.getBody()->without_terminator()) {
1163  // Stop when encountering the inner loop.
1164  if (&op == inner.getOperation())
1165  break;
1166  // Skip over non-hoistable ops.
1167  if (forwardSlice.count(&op) > 0) {
1168  status = failure();
1169  continue;
1170  }
1171  // Skip intermediate scf::ForOp, these are not considered a failure.
1172  if (isa<scf::ForOp>(op))
1173  continue;
1174  // Skip other ops with regions.
1175  if (op.getNumRegions() > 0) {
1176  status = failure();
1177  continue;
1178  }
1179  // Skip if op has side effects.
1180  // TODO: loads to immutable memory regions are ok.
1181  if (!isMemoryEffectFree(&op)) {
1182  status = failure();
1183  continue;
1184  }
1185  toHoist.push_back(&op);
1186  }
1187  auto *outerForOp = outer.getOperation();
1188  for (auto *op : toHoist)
1189  op->moveBefore(outerForOp);
1190  return status;
1191 }
1192 
1193 // Traverse the interTile and intraTile loops and try to hoist ops such that
1194 // bands of perfectly nested loops are isolated.
1195 // Return failure if either perfect interTile or perfect intraTile bands cannot
1196 // be formed.
1197 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1198  LogicalResult status = success();
1199  const Loops &interTile = tileLoops.first;
1200  const Loops &intraTile = tileLoops.second;
1201  auto size = interTile.size();
1202  assert(size == intraTile.size());
1203  if (size <= 1)
1204  return success();
1205  for (unsigned s = 1; s < size; ++s)
1206  status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1207  : failure();
1208  for (unsigned s = 1; s < size; ++s)
1209  status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1210  : failure();
1211  return status;
1212 }
1213 
1214 /// Collect perfectly nested loops starting from `rootForOps`. Loops are
1215 /// perfectly nested if each loop is the first and only non-terminator operation
1216 /// in the parent loop. Collect at most `maxLoops` loops and append them to
1217 /// `forOps`.
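///
/// For example (illustrative only), in
///
/// ```mlir
///   scf.for .. {
///     scf.for .. {
///       "body.op"() : () -> ()
///     }
///   }
/// ```
///
/// both loops are collected because the inner scf.for is the only
/// non-terminator operation in the outer body; any additional operation next
/// to the inner loop would stop the collection at the outer loop.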
1218 template <typename T>
1219 static void getPerfectlyNestedLoopsImpl(
1220  SmallVectorImpl<T> &forOps, T rootForOp,
1221  unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1222  for (unsigned i = 0; i < maxLoops; ++i) {
1223  forOps.push_back(rootForOp);
1224  Block &body = rootForOp.getRegion().front();
1225  if (body.begin() != std::prev(body.end(), 2))
1226  return;
1227 
1228  rootForOp = dyn_cast<T>(&body.front());
1229  if (!rootForOp)
1230  return;
1231  }
1232 }
1233 
1234 static Loops stripmineSink(scf::ForOp forOp, Value factor,
1235  ArrayRef<scf::ForOp> targets) {
1236  auto originalStep = forOp.getStep();
1237  auto iv = forOp.getInductionVar();
1238 
1239  OpBuilder b(forOp);
1240  forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1241 
1242  Loops innerLoops;
1243  for (auto t : targets) {
1244  // Save information for splicing ops out of t when done
1245  auto begin = t.getBody()->begin();
1246  auto nOps = t.getBody()->getOperations().size();
1247 
1248  // Insert newForOp before the terminator of `t`.
1249  auto b = OpBuilder::atBlockTerminator((t.getBody()));
1250  Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1251  Value ub =
1252  arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1253 
1254  // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
1255  auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1256  newForOp.getBody()->getOperations().splice(
1257  newForOp.getBody()->getOperations().begin(),
1258  t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1259  replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1260  newForOp.getRegion());
1261 
1262  innerLoops.push_back(newForOp);
1263  }
1264 
1265  return innerLoops;
1266 }
1267 
1268 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1269 // Returns the new for operation, nested immediately under `target`.
1270 template <typename SizeType>
1271 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1272  scf::ForOp target) {
1273  // TODO: Use cheap structural assertions that targets are nested under
1274  // forOp and that targets are not nested under each other when DominanceInfo
1275  // exposes the capability. It seems overkill to construct a whole function
1276  // dominance tree at this point.
1277  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1278  assert(res.size() == 1 && "Expected 1 inner forOp");
1279  return res[0];
1280 }
1281 
1282 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
1283  ArrayRef<Value> sizes,
1284  ArrayRef<scf::ForOp> targets) {
1285  SmallVector<Loops, 8> res;
1286  SmallVector<scf::ForOp, 8> currentTargets(targets);
1287  for (auto it : llvm::zip(forOps, sizes)) {
1288  auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1289  res.push_back(step);
1290  currentTargets = step;
1291  }
1292  return res;
1293 }
1294 
1295 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
1296  scf::ForOp target) {
1297  Loops res;
1298  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
1299  res.push_back(llvm::getSingleElement(loops));
1300  return res;
1301 }
1302 
1303 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1304  // Collect perfectly nested loops. If more size values are provided than
1305  // there are nested loops, truncate `sizes`.
1306  SmallVector<scf::ForOp, 4> forOps;
1307  forOps.reserve(sizes.size());
1308  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1309  if (forOps.size() < sizes.size())
1310  sizes = sizes.take_front(forOps.size());
1311 
1312  return ::tile(forOps, sizes, forOps.back());
1313 }
1314 
1315 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
1316  scf::ForOp root) {
1317  getPerfectlyNestedLoopsImpl(nestedLoops, root);
1318 }
1319 
1320 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
1321  ArrayRef<int64_t> sizes) {
1322  // Collect perfectly nested loops. If more size values are provided than
1323  // there are nested loops, truncate `sizes`.
1324  SmallVector<scf::ForOp, 4> forOps;
1325  forOps.reserve(sizes.size());
1326  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1327  if (forOps.size() < sizes.size())
1328  sizes = sizes.take_front(forOps.size());
1329 
1330  // Compute the tile sizes such that the i-th outer loop executes size[i]
1331  // iterations. Given that the loop currently executes
1332  // numIterations = ceildiv((upperBound - lowerBound), step)
1333  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
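  // For example, requesting size 4 for a loop with 10 iterations gives a tile
  // size of ceilDiv(10, 4) = 3, so the outer loop executes ceilDiv(10, 3) = 4
  // iterations, the last tile covering a single iteration.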
1334  SmallVector<Value, 4> tileSizes;
1335  tileSizes.reserve(sizes.size());
1336  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1337  assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1338 
1339  auto forOp = forOps[i];
1340  OpBuilder builder(forOp);
1341  auto loc = forOp.getLoc();
1342  Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1343  forOp.getLowerBound());
1344  Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1345  Value iterationsPerBlock =
1346  ceilDivPositive(builder, loc, numIterations, sizes[i]);
1347  tileSizes.push_back(iterationsPerBlock);
1348  }
1349 
1350  // Call parametric tiling with the given sizes.
1351  auto intraTile = tile(forOps, tileSizes, forOps.back());
1352  TileLoops tileLoops = std::make_pair(forOps, intraTile);
1353 
1354  // TODO: for now we just ignore the result of band isolation.
1355  // In the future, mapping decisions may be impacted by the ability to
1356  // isolate perfectly nested bands.
1357  (void)tryIsolateBands(tileLoops);
1358 
1359  return tileLoops;
1360 }
1361 
1362 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
1363  scf::ForallOp source,
1364  RewriterBase &rewriter) {
1365  unsigned numTargetOuts = target.getNumResults();
1366  unsigned numSourceOuts = source.getNumResults();
1367 
1368  // Create fused shared_outs.
1369  SmallVector<Value> fusedOuts;
1370  llvm::append_range(fusedOuts, target.getOutputs());
1371  llvm::append_range(fusedOuts, source.getOutputs());
1372 
1373  // Create a new scf.forall op after the source loop.
1374  rewriter.setInsertionPointAfter(source);
1375  scf::ForallOp fusedLoop = scf::ForallOp::create(
1376  rewriter, source.getLoc(), source.getMixedLowerBound(),
1377  source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1378  source.getMapping());
1379 
1380  // Map control operands.
1381  IRMapping mapping;
1382  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1383  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1384 
1385  // Map shared outs.
1386  mapping.map(target.getRegionIterArgs(),
1387  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1388  mapping.map(source.getRegionIterArgs(),
1389  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1390 
1391  // Append everything except the terminator into the fused operation.
1392  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1393  for (Operation &op : target.getBody()->without_terminator())
1394  rewriter.clone(op, mapping);
1395  for (Operation &op : source.getBody()->without_terminator())
1396  rewriter.clone(op, mapping);
1397 
1398  // Fuse the old terminator in_parallel ops into the new one.
1399  scf::InParallelOp targetTerm = target.getTerminator();
1400  scf::InParallelOp sourceTerm = source.getTerminator();
1401  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1402  rewriter.setInsertionPointToStart(fusedTerm.getBody());
1403  for (Operation &op : targetTerm.getYieldingOps())
1404  rewriter.clone(op, mapping);
1405  for (Operation &op : sourceTerm.getYieldingOps())
1406  rewriter.clone(op, mapping);
1407 
1408  // Replace old loops by substituting their uses by results of the fused loop.
1409  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1410  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1411 
1412  return fusedLoop;
1413 }
1414 
1415 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
1416  scf::ForOp source,
1417  RewriterBase &rewriter) {
1418  unsigned numTargetOuts = target.getNumResults();
1419  unsigned numSourceOuts = source.getNumResults();
1420 
1421  // Create fused init_args, with target's init_args before source's init_args.
1422  SmallVector<Value> fusedInitArgs;
1423  llvm::append_range(fusedInitArgs, target.getInitArgs());
1424  llvm::append_range(fusedInitArgs, source.getInitArgs());
1425 
 1426  // Create a new scf.for op after the source loop. The builder inserts an
 1427  // scf.yield terminator (without arguments) only when init_args is empty.
1428  rewriter.setInsertionPointAfter(source);
1429  scf::ForOp fusedLoop = scf::ForOp::create(
1430  rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1431  source.getStep(), fusedInitArgs);
1432 
1433  // Map original induction variables and operands to those of the fused loop.
1434  IRMapping mapping;
1435  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1436  mapping.map(target.getRegionIterArgs(),
1437  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1438  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1439  mapping.map(source.getRegionIterArgs(),
1440  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1441 
1442  // Merge target's body into the new (fused) for loop and then source's body.
1443  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1444  for (Operation &op : target.getBody()->without_terminator())
1445  rewriter.clone(op, mapping);
1446  for (Operation &op : source.getBody()->without_terminator())
1447  rewriter.clone(op, mapping);
1448 
1449  // Build fused yield results by appropriately mapping original yield operands.
1450  SmallVector<Value> yieldResults;
1451  for (Value operand : target.getBody()->getTerminator()->getOperands())
1452  yieldResults.push_back(mapping.lookupOrDefault(operand));
1453  for (Value operand : source.getBody()->getTerminator()->getOperands())
1454  yieldResults.push_back(mapping.lookupOrDefault(operand));
1455  if (!yieldResults.empty())
1456  scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1457 
1458  // Replace old loops by substituting their uses by results of the fused loop.
1459  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1460  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1461 
1462  return fusedLoop;
1463 }
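
A corresponding sketch for the `scf.for` case (again with illustrative names, bounds, and types): the fused loop carries both sets of `iter_args`, target's first, and yields the remapped terminator operands of both originals:

```mlir
// Before: two independent loops with the same lower bound, upper bound, step.
%t = scf.for %i = %c0 to %c128 step %c1 iter_args(%acc0 = %init0) -> (f32) { .. }
%s = scf.for %j = %c0 to %c128 step %c1 iter_args(%acc1 = %init1) -> (f32) { .. }

// After: a single loop; %t is replaced by %f#0 and %s by %f#1.
%f:2 = scf.for %i = %c0 to %c128 step %c1
    iter_args(%acc0 = %init0, %acc1 = %init1) -> (f32, f32) {
  ..
  // %v0 / %v1 stand in for the remapped values yielded by target and source.
  scf.yield %v0, %v1 : f32, f32
}
```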
1464 
1465 FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1466  scf::ForallOp forallOp) {
1467  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1468  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1469  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1470 
1471  if (forallOp.isNormalized())
1472  return forallOp;
1473 
1474  OpBuilder::InsertionGuard g(rewriter);
1475  auto loc = forallOp.getLoc();
 1476  rewriter.setInsertionPoint(forallOp);
 1477  SmallVector<OpFoldResult> newUbs;
1478  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1479  Range normalizedLoopParams =
1480  emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1481  newUbs.push_back(normalizedLoopParams.size);
1482  }
1483  (void)foldDynamicIndexList(newUbs);
1484 
1485  // Use the normalized builder since the lower bounds are always 0 and the
1486  // steps are always 1.
1487  auto normalizedForallOp = scf::ForallOp::create(
1488  rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1489  [](OpBuilder &, Location, ValueRange) {});
1490 
1491  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1492  normalizedForallOp.getBodyRegion(),
1493  normalizedForallOp.getBodyRegion().begin());
1494  // Remove the original empty block in the new loop.
1495  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1496 
1497  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1498  // Update the users of the original loop variables.
1499  for (auto [idx, iv] :
1500  llvm::enumerate(normalizedForallOp.getInductionVars())) {
1501  auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1502  auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1503  denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1504  }
1505 
1506  rewriter.replaceOp(forallOp, normalizedForallOp);
1507  return normalizedForallOp;
1508 }
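
As a sketch of the normalization (bounds and names below are placeholders), a non-zero-based, non-unit-step `scf.forall` is rebuilt with lower bounds 0 and steps 1, and each original induction variable value is recomputed in the body via `denormalizeInductionVariable`:

```mlir
// Before.
scf.forall (%i) = (%lb) to (%ub) step (%st) {
  .. use %i ..
}

// After: the new upper bound is the trip count ceil((%ub - %lb) / %st), and the
// body reconstructs the original value as %lb + %iv * %st (shown here with
// arith ops for clarity; the actual materialization may fold to affine forms).
scf.forall (%iv) in (%trip) {
  %scaled = arith.muli %iv, %st : index
  %i = arith.addi %lb, %scaled : index
  .. use %i ..
}
```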