MLIR  21.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <cstdint>
33 
34 using namespace mlir;
35 
36 #define DEBUG_TYPE "scf-utils"
37 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
38 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
39 
41  RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
42  ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
43  bool replaceIterOperandsUsesInLoop) {
44  if (loopNest.empty())
45  return {};
46  // This method is recursive (to make it more readable). Adding an
47  // assertion here to limit the recursion. (See
48  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
49  assert(loopNest.size() <= 10 &&
50  "exceeded recursion limit when yielding value from loop nest");
51 
52  // To yield a value from a perfectly nested loop nest, the following
53  // pattern needs to be created, i.e. starting with
54  //
55  // ```mlir
56  // scf.for .. {
57  // scf.for .. {
58  // scf.for .. {
59  // %value = ...
60  // }
61  // }
62  // }
63  // ```
64  //
65  // needs to be modified to
66  //
67  // ```mlir
68  // %0 = scf.for .. iter_args(%arg0 = %init) {
69  // %1 = scf.for .. iter_args(%arg1 = %arg0) {
70  // %2 = scf.for .. iter_args(%arg2 = %arg1) {
71  // %value = ...
72  // scf.yield %value
73  // }
74  // scf.yield %2
75  // }
76  // scf.yield %1
77  // }
78  // ```
79  //
80  // The inner most loop is handled using the `replaceWithAdditionalYields`
81  // that works on a single loop.
82  if (loopNest.size() == 1) {
83  auto innerMostLoop =
84  cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
85  rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
86  newYieldValuesFn));
87  return {innerMostLoop};
88  }
89  // The outer loops are modified by calling this method recursively
90  // - The return value of the inner loop is the value yielded by this loop.
91  // - The region iter args of this loop are the init_args for the inner loop.
92  SmallVector<scf::ForOp> newLoopNest;
93  NewYieldValuesFn fn =
94  [&](OpBuilder &innerBuilder, Location loc,
95  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
96  newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
97  innerNewBBArgs, newYieldValuesFn,
98  replaceIterOperandsUsesInLoop);
99  return llvm::to_vector(llvm::map_range(
100  newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
101  [](OpResult r) -> Value { return r; }));
102  };
103  scf::ForOp outerMostLoop =
104  cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
105  rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
106  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
107  return newLoopNest;
108 }
109 
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types are the types of the yielded operands of
/// the single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  // Only single-block regions are supported (see TODO above).
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  // Values defined above the region become extra trailing arguments of the
  // outlined function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(region's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // Result types mirror the operands of the original terminator.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // Call operands = new block arguments followed by the captured values.
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // Prefer a cloned constant over a function argument so folding works.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        IRMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
222 
223 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
224  func::FuncOp *thenFn, StringRef thenFnName,
225  func::FuncOp *elseFn, StringRef elseFnName) {
226  IRRewriter rewriter(b);
227  Location loc = ifOp.getLoc();
228  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
229  if (thenFn && !ifOp.getThenRegion().empty()) {
230  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
231  rewriter, loc, ifOp.getThenRegion(), thenFnName);
232  if (failed(outlinedFuncOpOrFailure))
233  return failure();
234  *thenFn = *outlinedFuncOpOrFailure;
235  }
236  if (elseFn && !ifOp.getElseRegion().empty()) {
237  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
238  rewriter, loc, ifOp.getElseRegion(), elseFnName);
239  if (failed(outlinedFuncOpOrFailure))
240  return failure();
241  *elseFn = *outlinedFuncOpOrFailure;
242  }
243  return success();
244 }
245 
248  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
249  bool rootEnclosesPloops = false;
250  for (Region &region : rootOp->getRegions()) {
251  for (Block &block : region.getBlocks()) {
252  for (Operation &op : block) {
253  bool enclosesPloops = getInnermostParallelLoops(&op, result);
254  rootEnclosesPloops |= enclosesPloops;
255  if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
256  rootEnclosesPloops = true;
257 
258  // Collect parallel loop if it is an innermost one.
259  if (!enclosesPloops)
260  result.push_back(ploop);
261  }
262  }
263  }
264  }
265  return rootEnclosesPloops;
266 }
267 
268 // Build the IR that performs ceil division of a positive value by a constant:
269 // ceildiv(a, B) = divis(a + (B-1), B)
270 // where divis is rounding-to-zero division.
271 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
272  int64_t divisor) {
273  assert(divisor > 0 && "expected positive divisor");
274  assert(dividend.getType().isIntOrIndex() &&
275  "expected integer or index-typed value");
276 
277  Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
278  loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
279  Value divisorCst = builder.create<arith::ConstantOp>(
280  loc, builder.getIntegerAttr(dividend.getType(), divisor));
281  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
282  return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
283 }
284 
285 // Build the IR that performs ceil division of a positive value by another
286 // positive value:
287 // ceildiv(a, b) = divis(a + (b - 1), b)
288 // where divis is rounding-to-zero division.
289 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
290  Value divisor) {
291  assert(dividend.getType().isIntOrIndex() &&
292  "expected integer or index-typed value");
293  Value cstOne = builder.create<arith::ConstantOp>(
294  loc, builder.getOneAttr(dividend.getType()));
295  Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
296  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
297  return builder.create<arith::DivUIOp>(loc, sum, divisor);
298 }
299 
300 /// Returns the trip count of `forOp` if its' low bound, high bound and step are
301 /// constants, or optional otherwise. Trip count is computed as
302 /// ceilDiv(highBound - lowBound, step).
303 static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
304  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
305  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
306  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
307  if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
308  return {};
309 
310  // Constant loop bounds computation.
311  int64_t lbCst = lbCstOp.value();
312  int64_t ubCst = ubCstOp.value();
313  int64_t stepCst = stepCstOp.value();
314  assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
315  "expected positive loop bounds and step");
316  return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
317 }
318 
319 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
320 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
321 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
322 /// unrolled iteration using annotateFn.
324  Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
325  function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
326  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
327  ValueRange iterArgs, ValueRange yieldedValues) {
328  // Builder to insert unrolled bodies just before the terminator of the body of
329  // 'forOp'.
330  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
331 
332  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
333  if (!annotateFn)
334  annotateFn = defaultAnnotateFn;
335 
336  // Keep a pointer to the last non-terminator operation in the original block
337  // so that we know what to clone (since we are doing this in-place).
338  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
339 
340  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
341  SmallVector<Value, 4> lastYielded(yieldedValues);
342 
343  for (unsigned i = 1; i < unrollFactor; i++) {
344  IRMapping operandMap;
345 
346  // Prepare operand map.
347  operandMap.map(iterArgs, lastYielded);
348 
349  // If the induction variable is used, create a remapping to the value for
350  // this unrolled instance.
351  if (!forOpIV.use_empty()) {
352  Value ivUnroll = ivRemapFn(i, forOpIV, builder);
353  operandMap.map(forOpIV, ivUnroll);
354  }
355 
356  // Clone the original body of 'forOp'.
357  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
358  Operation *clonedOp = builder.clone(*it, operandMap);
359  annotateFn(i, clonedOp, builder);
360  }
361 
362  // Update yielded values.
363  for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
364  lastYielded[i] = operandMap.lookupOrDefault(yieldedValues[i]);
365  }
366 
367  // Make sure we annotate the Ops in the original body. We do this last so that
368  // any annotations are not copied into the cloned Ops above.
369  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
370  annotateFn(0, &*it, builder);
371 
372  // Update operands of the yield statement.
373  loopBodyBlock->getTerminator()->setOperands(lastYielded);
374 }
375 
376 /// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
377 /// eplilog loop, if the loop is unrolled.
378 FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
379  scf::ForOp forOp, uint64_t unrollFactor,
380  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
381  assert(unrollFactor > 0 && "expected positive unroll factor");
382 
383  // Return if the loop body is empty.
384  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
385  return UnrolledLoopInfo{forOp, std::nullopt};
386 
387  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
388  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
389  OpBuilder boundsBuilder(forOp);
390  IRRewriter rewriter(forOp.getContext());
391  auto loc = forOp.getLoc();
392  Value step = forOp.getStep();
393  Value upperBoundUnrolled;
394  Value stepUnrolled;
395  bool generateEpilogueLoop = true;
396 
397  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
398  if (constTripCount) {
399  // Constant loop bounds computation.
400  int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
401  int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
402  int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
403  if (unrollFactor == 1) {
404  if (*constTripCount == 1 &&
405  failed(forOp.promoteIfSingleIteration(rewriter)))
406  return failure();
407  return UnrolledLoopInfo{forOp, std::nullopt};
408  }
409 
410  int64_t tripCountEvenMultiple =
411  *constTripCount - (*constTripCount % unrollFactor);
412  int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
413  int64_t stepUnrolledCst = stepCst * unrollFactor;
414 
415  // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
416  generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
417  if (generateEpilogueLoop)
418  upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
419  loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
420  upperBoundUnrolledCst));
421  else
422  upperBoundUnrolled = forOp.getUpperBound();
423 
424  // Create constant for 'stepUnrolled'.
425  stepUnrolled = stepCst == stepUnrolledCst
426  ? step
427  : boundsBuilder.create<arith::ConstantOp>(
428  loc, boundsBuilder.getIntegerAttr(
429  step.getType(), stepUnrolledCst));
430  } else {
431  // Dynamic loop bounds computation.
432  // TODO: Add dynamic asserts for negative lb/ub/step, or
433  // consider using ceilDiv from AffineApplyExpander.
434  auto lowerBound = forOp.getLowerBound();
435  auto upperBound = forOp.getUpperBound();
436  Value diff =
437  boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
438  Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
439  Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
440  loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
441  Value tripCountRem =
442  boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
443  // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
444  Value tripCountEvenMultiple =
445  boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
446  // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
447  upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
448  loc, lowerBound,
449  boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
450  // Scale 'step' by 'unrollFactor'.
451  stepUnrolled =
452  boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
453  }
454 
455  UnrolledLoopInfo resultLoops;
456 
457  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
458  if (generateEpilogueLoop) {
459  OpBuilder epilogueBuilder(forOp->getContext());
460  epilogueBuilder.setInsertionPointAfter(forOp);
461  auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
462  epilogueForOp.setLowerBound(upperBoundUnrolled);
463 
464  // Update uses of loop results.
465  auto results = forOp.getResults();
466  auto epilogueResults = epilogueForOp.getResults();
467 
468  for (auto e : llvm::zip(results, epilogueResults)) {
469  std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
470  }
471  epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
472  epilogueForOp.getInitArgs().size(), results);
473  if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
474  resultLoops.epilogueLoopOp = epilogueForOp;
475  }
476 
477  // Create unrolled loop.
478  forOp.setUpperBound(upperBoundUnrolled);
479  forOp.setStep(stepUnrolled);
480 
481  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
482  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
483 
485  forOp.getBody(), forOp.getInductionVar(), unrollFactor,
486  [&](unsigned i, Value iv, OpBuilder b) {
487  // iv' = iv + step * i;
488  auto stride = b.create<arith::MulIOp>(
489  loc, step,
490  b.create<arith::ConstantOp>(loc,
491  b.getIntegerAttr(iv.getType(), i)));
492  return b.create<arith::AddIOp>(loc, iv, stride);
493  },
494  annotateFn, iterArgs, yieldedValues);
495  // Promote the loop body up if this has turned into a single iteration loop.
496  if (forOp.promoteIfSingleIteration(rewriter).failed())
497  resultLoops.mainLoopOp = forOp;
498  return resultLoops;
499 }
500 
501 /// Unrolls this loop completely.
502 LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
503  IRRewriter rewriter(forOp.getContext());
504  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
505  if (!mayBeConstantTripCount.has_value())
506  return failure();
507  uint64_t tripCount = *mayBeConstantTripCount;
508  if (tripCount == 0)
509  return success();
510  if (tripCount == 1)
511  return forOp.promoteIfSingleIteration(rewriter);
512  return loopUnrollByFactor(forOp, tripCount);
513 }
514 
515 /// Check if bounds of all inner loops are defined outside of `forOp`
516 /// and return false if not.
517 static bool areInnerBoundsInvariant(scf::ForOp forOp) {
518  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
519  if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
520  !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
521  !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
522  return WalkResult::interrupt();
523 
524  return WalkResult::advance();
525  });
526  return !walkResult.wasInterrupted();
527 }
528 
529 /// Unrolls and jams this loop by the specified factor.
530 LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
531  uint64_t unrollJamFactor) {
532  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
533 
534  if (unrollJamFactor == 1)
535  return success();
536 
537  // If any control operand of any inner loop of `forOp` is defined within
538  // `forOp`, no unroll jam.
539  if (!areInnerBoundsInvariant(forOp)) {
540  LDBG("failed to unroll and jam: inner bounds are not invariant");
541  return failure();
542  }
543 
544  // Currently, for operations with results are not supported.
545  if (forOp->getNumResults() > 0) {
546  LDBG("failed to unroll and jam: unsupported loop with results");
547  return failure();
548  }
549 
550  // Currently, only constant trip count that divided by the unroll factor is
551  // supported.
552  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
553  if (!tripCount.has_value()) {
554  // If the trip count is dynamic, do not unroll & jam.
555  LDBG("failed to unroll and jam: trip count could not be determined");
556  return failure();
557  }
558  if (unrollJamFactor > *tripCount) {
559  LDBG("unroll and jam factor is greater than trip count, set factor to trip "
560  "count");
561  unrollJamFactor = *tripCount;
562  } else if (*tripCount % unrollJamFactor != 0) {
563  LDBG("failed to unroll and jam: unsupported trip count that is not a "
564  "multiple of unroll jam factor");
565  return failure();
566  }
567 
568  // Nothing in the loop body other than the terminator.
569  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
570  return success();
571 
572  // Gather all sub-blocks to jam upon the loop being unrolled.
574  jbg.walk(forOp);
575  auto &subBlocks = jbg.subBlocks;
576 
577  // Collect inner loops.
578  SmallVector<scf::ForOp> innerLoops;
579  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
580 
581  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
582  // iteration. There are (`unrollJamFactor` - 1) iterations.
583  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
584 
585  // For any loop with iter_args, replace it with a new loop that has
586  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
587  // operands.
588  SmallVector<scf::ForOp> newInnerLoops;
589  IRRewriter rewriter(forOp.getContext());
590  for (scf::ForOp oldForOp : innerLoops) {
591  SmallVector<Value> dupIterOperands, dupYieldOperands;
592  ValueRange oldIterOperands = oldForOp.getInits();
593  ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
594  ValueRange oldYieldOperands =
595  cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
596  // Get additional iterOperands, iterArgs, and yield operands. We will
597  // fix iterOperands and yield operands after cloning of sub-blocks.
598  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
599  dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
600  dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
601  }
602  // Create a new loop with additional iterOperands, iter_args and yield
603  // operands. This new loop will take the loop body of the original loop.
604  bool forOpReplaced = oldForOp == forOp;
605  scf::ForOp newForOp =
606  cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
607  rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
608  [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
609  return dupYieldOperands;
610  }));
611  newInnerLoops.push_back(newForOp);
612  // `forOp` has been replaced with a new loop.
613  if (forOpReplaced)
614  forOp = newForOp;
615  // Update `operandMaps` for `newForOp` iterArgs and results.
616  ValueRange newIterArgs = newForOp.getRegionIterArgs();
617  unsigned oldNumIterArgs = oldIterArgs.size();
618  ValueRange newResults = newForOp.getResults();
619  unsigned oldNumResults = newResults.size() / unrollJamFactor;
620  assert(oldNumIterArgs == oldNumResults &&
621  "oldNumIterArgs must be the same as oldNumResults");
622  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
623  for (unsigned j = 0; j < oldNumIterArgs; ++j) {
624  // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
625  // results. Update `operandMaps[i - 1]` to map old iterArgs and results
626  // to those in the `i`th new set.
627  operandMaps[i - 1].map(newIterArgs[j],
628  newIterArgs[i * oldNumIterArgs + j]);
629  operandMaps[i - 1].map(newResults[j],
630  newResults[i * oldNumResults + j]);
631  }
632  }
633  }
634 
635  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
636  rewriter.setInsertionPoint(forOp);
637  int64_t step = forOp.getConstantStep()->getSExtValue();
638  auto newStep = rewriter.createOrFold<arith::MulIOp>(
639  forOp.getLoc(), forOp.getStep(),
640  rewriter.createOrFold<arith::ConstantOp>(
641  forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
642  forOp.setStep(newStep);
643  auto forOpIV = forOp.getInductionVar();
644 
645  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
646  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
647  for (auto &subBlock : subBlocks) {
648  // Builder to insert unroll-jammed bodies. Insert right at the end of
649  // sub-block.
650  OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
651 
652  // If the induction variable is used, create a remapping to the value for
653  // this unrolled instance.
654  if (!forOpIV.use_empty()) {
655  // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
656  auto ivTag = builder.createOrFold<arith::ConstantOp>(
657  forOp.getLoc(), builder.getIndexAttr(step * i));
658  auto ivUnroll =
659  builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
660  operandMaps[i - 1].map(forOpIV, ivUnroll);
661  }
662  // Clone the sub-block being unroll-jammed.
663  for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
664  builder.clone(*it, operandMaps[i - 1]);
665  }
666  // Fix iterOperands and yield op operands of newly created loops.
667  for (auto newForOp : newInnerLoops) {
668  unsigned oldNumIterOperands =
669  newForOp.getNumRegionIterArgs() / unrollJamFactor;
670  unsigned numControlOperands = newForOp.getNumControlOperands();
671  auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
672  unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
673  assert(oldNumIterOperands == oldNumYieldOperands &&
674  "oldNumIterOperands must be the same as oldNumYieldOperands");
675  for (unsigned j = 0; j < oldNumIterOperands; ++j) {
676  // The `i`th duplication of an old iterOperand or yield op operand
677  // needs to be replaced with a mapped value from `operandMaps[i - 1]`
678  // if such mapped value exists.
679  newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
680  operandMaps[i - 1].lookupOrDefault(
681  newForOp.getOperand(numControlOperands + j)));
682  yieldOp.setOperand(
683  i * oldNumYieldOperands + j,
684  operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
685  }
686  }
687  }
688 
689  // Promote the loop body up if this has turned into a single iteration loop.
690  (void)forOp.promoteIfSingleIteration(rewriter);
691  return success();
692 }
693 
695  OpFoldResult lb, OpFoldResult ub,
696  OpFoldResult step) {
697  Range normalizedLoopBounds;
698  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
699  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
700  AffineExpr s0, s1, s2;
701  bindSymbols(rewriter.getContext(), s0, s1, s2);
702  AffineExpr e = (s1 - s0).ceilDiv(s2);
703  normalizedLoopBounds.size =
704  affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
705  return normalizedLoopBounds;
706 }
707 
709  OpFoldResult lb, OpFoldResult ub,
710  OpFoldResult step) {
711  if (getType(lb).isIndex()) {
712  return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
713  }
714  // For non-index types, generate `arith` instructions
715  // Check if the loop is already known to have a constant zero lower bound or
716  // a constant one step.
717  bool isZeroBased = false;
718  if (auto lbCst = getConstantIntValue(lb))
719  isZeroBased = lbCst.value() == 0;
720 
721  bool isStepOne = false;
722  if (auto stepCst = getConstantIntValue(step))
723  isStepOne = stepCst.value() == 1;
724 
725  Type rangeType = getType(lb);
726  assert(rangeType == getType(ub) && rangeType == getType(step) &&
727  "expected matching types");
728 
729  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
730  // assuming the step is strictly positive. Update the bounds and the step
731  // of the loop to go from 0 to the number of iterations, if necessary.
732  if (isZeroBased && isStepOne)
733  return {lb, ub, step};
734 
735  OpFoldResult diff = ub;
736  if (!isZeroBased) {
737  diff = rewriter.createOrFold<arith::SubIOp>(
738  loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
739  getValueOrCreateConstantIntOp(rewriter, loc, lb));
740  }
741  OpFoldResult newUpperBound = diff;
742  if (!isStepOne) {
743  newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
744  loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
745  getValueOrCreateConstantIntOp(rewriter, loc, step));
746  }
747 
748  OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
749  OpFoldResult newStep = rewriter.getOneAttr(rangeType);
750 
751  return {newLowerBound, newUpperBound, newStep};
752 }
753 
755  Location loc,
756  Value normalizedIv,
757  OpFoldResult origLb,
758  OpFoldResult origStep) {
759  AffineExpr d0, s0, s1;
760  bindSymbols(rewriter.getContext(), s0, s1);
761  bindDims(rewriter.getContext(), d0);
762  AffineExpr e = d0 * s1 + s0;
764  rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
765  Value denormalizedIvVal =
766  getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
767  SmallPtrSet<Operation *, 1> preservedUses;
768  // If an `affine.apply` operation is generated for denormalization, the use
769  // of `origLb` in those ops must not be replaced. These arent not generated
770  // when `origLb == 0` and `origStep == 1`.
771  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
772  if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
773  preservedUses.insert(preservedUse);
774  }
775  }
776  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
777 }
778 
780  Value normalizedIv, OpFoldResult origLb,
781  OpFoldResult origStep) {
782  if (getType(origLb).isIndex()) {
783  return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
784  origLb, origStep);
785  }
786  Value denormalizedIv;
788  bool isStepOne = isConstantIntValue(origStep, 1);
789  bool isZeroBased = isConstantIntValue(origLb, 0);
790 
791  Value scaled = normalizedIv;
792  if (!isStepOne) {
793  Value origStepValue =
794  getValueOrCreateConstantIntOp(rewriter, loc, origStep);
795  scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
796  preserve.insert(scaled.getDefiningOp());
797  }
798  denormalizedIv = scaled;
799  if (!isZeroBased) {
800  Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
801  denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
802  preserve.insert(denormalizedIv.getDefiningOp());
803  }
804 
805  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
806 }
807 
809  ArrayRef<OpFoldResult> values) {
810  assert(!values.empty() && "unexecpted empty array");
811  AffineExpr s0, s1;
812  bindSymbols(rewriter.getContext(), s0, s1);
813  AffineExpr mul = s0 * s1;
814  OpFoldResult products = rewriter.getIndexAttr(1);
815  for (auto v : values) {
817  rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
818  }
819  return products;
820 }
821 
822 /// Helper function to multiply a sequence of values.
824  ArrayRef<Value> values) {
825  assert(!values.empty() && "unexpected empty list");
826  if (getType(values.front()).isIndex()) {
828  OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
829  return getValueOrCreateConstantIndexOp(rewriter, loc, product);
830  }
831  std::optional<Value> productOf;
832  for (auto v : values) {
833  auto vOne = getConstantIntValue(v);
834  if (vOne && vOne.value() == 1)
835  continue;
836  if (productOf)
837  productOf =
838  rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
839  else
840  productOf = v;
841  }
842  if (!productOf) {
843  productOf = rewriter
844  .create<arith::ConstantOp>(
845  loc, rewriter.getOneAttr(getType(values.front())))
846  .getResult();
847  }
848  return productOf.value();
849 }
850 
851 /// For each original loop, the value of the
852 /// induction variable can be obtained by dividing the induction variable of
853 /// the linearized loop by the total number of iterations of the loops nested
854 /// in it modulo the number of iterations in this loop (remove the values
855 /// related to the outer loops):
856 /// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
857 /// Compute these iteratively from the innermost loop by creating a "running
858 /// quotient" of division by the range.
859 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
861  Value linearizedIv, ArrayRef<Value> ubs) {
862 
863  if (linearizedIv.getType().isIndex()) {
864  Operation *delinearizedOp =
865  rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
866  ubs);
867  auto resultVals = llvm::map_to_vector(
868  delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
869  return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
870  }
871 
872  SmallVector<Value> delinearizedIvs(ubs.size());
873  SmallPtrSet<Operation *, 2> preservedUsers;
874 
875  llvm::BitVector isUbOne(ubs.size());
876  for (auto [index, ub] : llvm::enumerate(ubs)) {
877  auto ubCst = getConstantIntValue(ub);
878  if (ubCst && ubCst.value() == 1)
879  isUbOne.set(index);
880  }
881 
882  // Prune the lead ubs that are all ones.
883  unsigned numLeadingOneUbs = 0;
884  for (auto [index, ub] : llvm::enumerate(ubs)) {
885  if (!isUbOne.test(index)) {
886  break;
887  }
888  delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
889  loc, rewriter.getZeroAttr(ub.getType()));
890  numLeadingOneUbs++;
891  }
892 
893  Value previous = linearizedIv;
894  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
895  unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
896  if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
897  previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
898  preservedUsers.insert(previous.getDefiningOp());
899  }
900  Value iv = previous;
901  if (i != e - 1) {
902  if (!isUbOne.test(idx)) {
903  iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
904  preservedUsers.insert(iv.getDefiningOp());
905  } else {
906  iv = rewriter.create<arith::ConstantOp>(
907  loc, rewriter.getZeroAttr(ubs[idx].getType()));
908  }
909  }
910  delinearizedIvs[idx] = iv;
911  }
912  return {delinearizedIvs, preservedUsers};
913 }
914 
915 LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
917  if (loops.size() < 2)
918  return failure();
919 
920  scf::ForOp innermost = loops.back();
921  scf::ForOp outermost = loops.front();
922 
923  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
924  // allows the following code to assume upperBound is the number of iterations.
925  for (auto loop : loops) {
926  OpBuilder::InsertionGuard g(rewriter);
927  rewriter.setInsertionPoint(outermost);
928  Value lb = loop.getLowerBound();
929  Value ub = loop.getUpperBound();
930  Value step = loop.getStep();
931  auto newLoopRange =
932  emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
933 
934  rewriter.modifyOpInPlace(loop, [&]() {
935  loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
936  newLoopRange.offset));
937  loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
938  newLoopRange.size));
939  loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
940  newLoopRange.stride));
941  });
942  rewriter.setInsertionPointToStart(innermost.getBody());
943  denormalizeInductionVariable(rewriter, loop.getLoc(),
944  loop.getInductionVar(), lb, step);
945  }
946 
947  // 2. Emit code computing the upper bound of the coalesced loop as product
948  // of the number of iterations of all loops.
949  OpBuilder::InsertionGuard g(rewriter);
950  rewriter.setInsertionPoint(outermost);
951  Location loc = outermost.getLoc();
952  SmallVector<Value> upperBounds = llvm::map_to_vector(
953  loops, [](auto loop) { return loop.getUpperBound(); });
954  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
955  outermost.setUpperBound(upperBound);
956 
957  rewriter.setInsertionPointToStart(innermost.getBody());
958  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
959  rewriter, loc, outermost.getInductionVar(), upperBounds);
960  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
961  preservedUsers);
962 
963  for (int i = loops.size() - 1; i > 0; --i) {
964  auto outerLoop = loops[i - 1];
965  auto innerLoop = loops[i];
966 
967  Operation *innerTerminator = innerLoop.getBody()->getTerminator();
968  auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
969  assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
970  for (Value &yieldedVal : yieldedVals) {
971  // The yielded value may be an iteration argument of the inner loop
972  // which is about to be inlined.
973  auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
974  if (iter != innerLoop.getRegionIterArgs().end()) {
975  unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
976  // `outerLoop` iter args identical to the `innerLoop` init args.
977  assert(iterArgIndex < innerLoop.getInitArgs().size());
978  yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
979  }
980  }
981  rewriter.eraseOp(innerTerminator);
982 
983  SmallVector<Value> innerBlockArgs;
984  innerBlockArgs.push_back(delinearizeIvs[i]);
985  llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
986  rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
987  Block::iterator(innerLoop), innerBlockArgs);
988  rewriter.replaceOp(innerLoop, yieldedVals);
989  }
990  return success();
991 }
992 
994  if (loops.empty()) {
995  return failure();
996  }
997  IRRewriter rewriter(loops.front().getContext());
998  return coalesceLoops(rewriter, loops);
999 }
1000 
1001 LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
1002  LogicalResult result(failure());
1004  getPerfectlyNestedLoops(loops, op);
1005 
1006  // Look for a band of loops that can be coalesced, i.e. perfectly nested
1007  // loops with bounds defined above some loop.
1008 
1009  // 1. For each loop, find above which parent loop its bounds operands are
1010  // defined.
1011  SmallVector<unsigned> operandsDefinedAbove(loops.size());
1012  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1013  operandsDefinedAbove[i] = i;
1014  for (unsigned j = 0; j < i; ++j) {
1015  SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
1016  loops[i].getUpperBound(),
1017  loops[i].getStep()};
1018  if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
1019  operandsDefinedAbove[i] = j;
1020  break;
1021  }
1022  }
1023  }
1024 
1025  // 2. For each inner loop check that the iter_args for the immediately outer
1026  // loop are the init for the immediately inner loop and that the yields of the
1027  // return of the inner loop is the yield for the immediately outer loop. Keep
1028  // track of where the chain starts from for each loop.
1029  SmallVector<unsigned> iterArgChainStart(loops.size());
1030  iterArgChainStart[0] = 0;
1031  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
1032  // By default set the start of the chain to itself.
1033  iterArgChainStart[i] = i;
1034  auto outerloop = loops[i - 1];
1035  auto innerLoop = loops[i];
1036  if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1037  continue;
1038  }
1039  if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1040  continue;
1041  }
1042  auto outerloopTerminator = outerloop.getBody()->getTerminator();
1043  if (!llvm::equal(outerloopTerminator->getOperands(),
1044  innerLoop.getResults())) {
1045  continue;
1046  }
1047  iterArgChainStart[i] = iterArgChainStart[i - 1];
1048  }
1049 
1050  // 3. Identify bands of loops such that the operands of all of them are
1051  // defined above the first loop in the band. Traverse the nest bottom-up
1052  // so that modifications don't invalidate the inner loops.
1053  for (unsigned end = loops.size(); end > 0; --end) {
1054  unsigned start = 0;
1055  for (; start < end - 1; ++start) {
1056  auto maxPos =
1057  *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1058  std::next(operandsDefinedAbove.begin(), end));
1059  if (maxPos > start)
1060  continue;
1061  if (iterArgChainStart[end - 1] > start)
1062  continue;
1063  auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
1064  if (succeeded(coalesceLoops(band)))
1065  result = success();
1066  break;
1067  }
1068  // If a band was found and transformed, keep looking at the loops above
1069  // the outermost transformed loop.
1070  if (start != end - 1)
1071  end = start + 1;
1072  }
1073  return result;
1074 }
1075 
1077  RewriterBase &rewriter, scf::ParallelOp loops,
1078  ArrayRef<std::vector<unsigned>> combinedDimensions) {
1079  OpBuilder::InsertionGuard g(rewriter);
1080  rewriter.setInsertionPoint(loops);
1081  Location loc = loops.getLoc();
1082 
1083  // Presort combined dimensions.
1084  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1085  for (auto &dims : sortedDimensions)
1086  llvm::sort(dims);
1087 
1088  // Normalize ParallelOp's iteration pattern.
1089  SmallVector<Value, 3> normalizedUpperBounds;
1090  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1091  OpBuilder::InsertionGuard g2(rewriter);
1092  rewriter.setInsertionPoint(loops);
1093  Value lb = loops.getLowerBound()[i];
1094  Value ub = loops.getUpperBound()[i];
1095  Value step = loops.getStep()[i];
1096  auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
1097  normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
1098  rewriter, loops.getLoc(), newLoopRange.size));
1099 
1100  rewriter.setInsertionPointToStart(loops.getBody());
1101  denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
1102  step);
1103  }
1104 
1105  // Combine iteration spaces.
1106  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
1107  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1108  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1109  for (auto &sortedDimension : sortedDimensions) {
1110  Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1111  for (auto idx : sortedDimension) {
1112  newUpperBound = rewriter.create<arith::MulIOp>(
1113  loc, newUpperBound, normalizedUpperBounds[idx]);
1114  }
1115  lowerBounds.push_back(cst0);
1116  steps.push_back(cst1);
1117  upperBounds.push_back(newUpperBound);
1118  }
1119 
1120  // Create new ParallelLoop with conversions to the original induction values.
1121  // The loop below uses divisions to get the relevant range of values in the
1122  // new induction value that represent each range of the original induction
1123  // value. The remainders then determine based on that range, which iteration
1124  // of the original induction value this represents. This is a normalized value
1125  // that is un-normalized already by the previous logic.
1126  auto newPloop = rewriter.create<scf::ParallelOp>(
1127  loc, lowerBounds, upperBounds, steps,
1128  [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
1129  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1130  Value previous = ploopIVs[i];
1131  unsigned numberCombinedDimensions = combinedDimensions[i].size();
1132  // Iterate over all except the last induction value.
1133  for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1134  unsigned idx = combinedDimensions[i][j];
1135 
1136  // Determine the current induction value's current loop iteration
1137  Value iv = insideBuilder.create<arith::RemSIOp>(
1138  loc, previous, normalizedUpperBounds[idx]);
1139  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1140  loops.getRegion());
1141 
1142  // Remove the effect of the current induction value to prepare for
1143  // the next value.
1144  previous = insideBuilder.create<arith::DivSIOp>(
1145  loc, previous, normalizedUpperBounds[idx]);
1146  }
1147 
1148  // The final induction value is just the remaining value.
1149  unsigned idx = combinedDimensions[i][0];
1150  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
1151  previous, loops.getRegion());
1152  }
1153  });
1154 
1155  // Replace the old loop with the new loop.
1156  loops.getBody()->back().erase();
1157  newPloop.getBody()->getOperations().splice(
1158  Block::iterator(newPloop.getBody()->back()),
1159  loops.getBody()->getOperations());
1160  loops.erase();
1161 }
1162 
1163 // Hoist the ops within `outer` that appear before `inner`.
1164 // Such ops include the ops that have been introduced by parametric tiling.
1165 // Ops that come from triangular loops (i.e. that belong to the program slice
1166 // rooted at `outer`) and ops that have side effects cannot be hoisted.
1167 // Return failure when any op fails to hoist.
1168 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
1169  SetVector<Operation *> forwardSlice;
1171  options.filter = [&inner](Operation *op) {
1172  return op != inner.getOperation();
1173  };
1174  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
1175  LogicalResult status = success();
1177  for (auto &op : outer.getBody()->without_terminator()) {
1178  // Stop when encountering the inner loop.
1179  if (&op == inner.getOperation())
1180  break;
1181  // Skip over non-hoistable ops.
1182  if (forwardSlice.count(&op) > 0) {
1183  status = failure();
1184  continue;
1185  }
1186  // Skip intermediate scf::ForOp, these are not considered a failure.
1187  if (isa<scf::ForOp>(op))
1188  continue;
1189  // Skip other ops with regions.
1190  if (op.getNumRegions() > 0) {
1191  status = failure();
1192  continue;
1193  }
1194  // Skip if op has side effects.
1195  // TODO: loads to immutable memory regions are ok.
1196  if (!isMemoryEffectFree(&op)) {
1197  status = failure();
1198  continue;
1199  }
1200  toHoist.push_back(&op);
1201  }
1202  auto *outerForOp = outer.getOperation();
1203  for (auto *op : toHoist)
1204  op->moveBefore(outerForOp);
1205  return status;
1206 }
1207 
1208 // Traverse the interTile and intraTile loops and try to hoist ops such that
1209 // bands of perfectly nested loops are isolated.
1210 // Return failure if either perfect interTile or perfect intraTile bands cannot
1211 // be formed.
1212 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
1213  LogicalResult status = success();
1214  const Loops &interTile = tileLoops.first;
1215  const Loops &intraTile = tileLoops.second;
1216  auto size = interTile.size();
1217  assert(size == intraTile.size());
1218  if (size <= 1)
1219  return success();
1220  for (unsigned s = 1; s < size; ++s)
1221  status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
1222  : failure();
1223  for (unsigned s = 1; s < size; ++s)
1224  status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
1225  : failure();
1226  return status;
1227 }
1228 
1229 /// Collect perfectly nested loops starting from `rootForOps`. Loops are
1230 /// perfectly nested if each loop is the first and only non-terminator operation
1231 /// in the parent loop. Collect at most `maxLoops` loops and append them to
1232 /// `forOps`.
1233 template <typename T>
1235  SmallVectorImpl<T> &forOps, T rootForOp,
1236  unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
1237  for (unsigned i = 0; i < maxLoops; ++i) {
1238  forOps.push_back(rootForOp);
1239  Block &body = rootForOp.getRegion().front();
1240  if (body.begin() != std::prev(body.end(), 2))
1241  return;
1242 
1243  rootForOp = dyn_cast<T>(&body.front());
1244  if (!rootForOp)
1245  return;
1246  }
1247 }
1248 
/// Stripmine `forOp` by `factor`: scale the outer loop's step by `factor`,
/// then, inside each loop in `targets`, create a new inner `scf.for` that
/// covers one scaled-step chunk `[iv, min(outer ub, iv + newStep))` with the
/// original step, and move the target's body into it. Returns the newly
/// created inner loops, one per target.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  // The outer loop now advances `factor` original steps per iteration.
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    // Clamp the chunk's upper bound so the last chunk does not run past the
    // outer loop's original upper bound.
    Value ub =
        b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    // `begin`/`nOps` were captured before the ops above were created, so this
    // moves exactly the original body of `t` minus its terminator.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    // Inside the moved body, the outer induction variable is replaced by the
    // inner loop's induction variable.
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
1282 
1283 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1284 // Returns the new for operation, nested immediately under `target`.
1285 template <typename SizeType>
1286 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
1287  scf::ForOp target) {
1288  // TODO: Use cheap structural assertions that targets are nested under
1289  // forOp and that targets are not nested under each other when DominanceInfo
1290  // exposes the capability. It seems overkill to construct a whole function
1291  // dominance tree at this point.
1292  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
1293  assert(res.size() == 1 && "Expected 1 inner forOp");
1294  return res[0];
1295 }
1296 
1298  ArrayRef<Value> sizes,
1299  ArrayRef<scf::ForOp> targets) {
1301  SmallVector<scf::ForOp, 8> currentTargets(targets);
1302  for (auto it : llvm::zip(forOps, sizes)) {
1303  auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1304  res.push_back(step);
1305  currentTargets = step;
1306  }
1307  return res;
1308 }
1309 
1311  scf::ForOp target) {
1313  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
1314  assert(loops.size() == 1);
1315  res.push_back(loops[0]);
1316  }
1317  return res;
1318 }
1319 
1320 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1321  // Collect perfectly nested loops. If more size values provided than nested
1322  // loops available, truncate `sizes`.
1324  forOps.reserve(sizes.size());
1325  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1326  if (forOps.size() < sizes.size())
1327  sizes = sizes.take_front(forOps.size());
1328 
1329  return ::tile(forOps, sizes, forOps.back());
1330 }
1331 
1333  scf::ForOp root) {
1334  getPerfectlyNestedLoopsImpl(nestedLoops, root);
1335 }
1336 
1338  ArrayRef<int64_t> sizes) {
1339  // Collect perfectly nested loops. If more size values provided than nested
1340  // loops available, truncate `sizes`.
1342  forOps.reserve(sizes.size());
1343  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1344  if (forOps.size() < sizes.size())
1345  sizes = sizes.take_front(forOps.size());
1346 
1347  // Compute the tile sizes such that i-th outer loop executes size[i]
1348  // iterations. Given that the loop current executes
1349  // numIterations = ceildiv((upperBound - lowerBound), step)
1350  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1351  SmallVector<Value, 4> tileSizes;
1352  tileSizes.reserve(sizes.size());
1353  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1354  assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1355 
1356  auto forOp = forOps[i];
1357  OpBuilder builder(forOp);
1358  auto loc = forOp.getLoc();
1359  Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
1360  forOp.getLowerBound());
1361  Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1362  Value iterationsPerBlock =
1363  ceilDivPositive(builder, loc, numIterations, sizes[i]);
1364  tileSizes.push_back(iterationsPerBlock);
1365  }
1366 
1367  // Call parametric tiling with the given sizes.
1368  auto intraTile = tile(forOps, tileSizes, forOps.back());
1369  TileLoops tileLoops = std::make_pair(forOps, intraTile);
1370 
1371  // TODO: for now we just ignore the result of band isolation.
1372  // In the future, mapping decisions may be impacted by the ability to
1373  // isolate perfectly nested bands.
1374  (void)tryIsolateBands(tileLoops);
1375 
1376  return tileLoops;
1377 }
1378 
/// Fuse two sibling `scf.forall` loops into a single `scf.forall`. The fused
/// loop is created after `source` and takes its bounds, steps and mapping
/// from `source`; its shared_outs are `target`'s outputs followed by
/// `source`'s. Both bodies and both `in_parallel` terminators are cloned into
/// the fused loop, and each original loop is replaced by its slice of the
/// fused results. Callers are expected to ensure the loops are independent
/// and have compatible iteration spaces (this function does not verify that).
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused shared_outs: target's outputs first, then source's.
  SmallVector<Value> fusedOuts;
  llvm::append_range(fusedOuts, target.getOutputs());
  llvm::append_range(fusedOuts, source.getOutputs());

  // Create a new scf.forall op after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
      source.getMixedStep(), fusedOuts, source.getMapping());

  // Map control operands: both loops' induction variables map to the fused
  // loop's induction variables.
  IRMapping mapping;
  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());

  // Map shared outs: target's iter args map to the leading fused iter args,
  // source's to the trailing ones, matching the order in `fusedOuts`.
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Append everything except the terminator into the fused operation.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Fuse the old terminator in_parallel ops into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  for (Operation &op : targetTerm.getYieldingOps())
    rewriter.clone(op, mapping);
  for (Operation &op : sourceTerm.getYieldingOps())
    rewriter.clone(op, mapping);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}
1430 
/// Fuse two sibling `scf.for` loops into a single `scf.for`. The fused loop
/// is created after `source` and takes its bounds and step from `source`; its
/// init_args are `target`'s followed by `source`'s. Both bodies are cloned
/// into the fused loop, a combined `scf.yield` is built from both original
/// yields (omitted entirely when there are no results), and each original
/// loop is replaced by its slice of the fused results. Callers are expected
/// to ensure the loops are independent and share compatible iteration spaces
/// (this function does not verify that).
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                scf::ForOp source,
                                                RewriterBase &rewriter) {
  unsigned numTargetOuts = target.getNumResults();
  unsigned numSourceOuts = source.getNumResults();

  // Create fused init_args, with target's init_args before source's init_args.
  SmallVector<Value> fusedInitArgs;
  llvm::append_range(fusedInitArgs, target.getInitArgs());
  llvm::append_range(fusedInitArgs, source.getInitArgs());

  // Create a new scf.for op after the source loop (with scf.yield terminator
  // (without arguments) only in case its init_args is empty).
  rewriter.setInsertionPointAfter(source);
  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
      source.getStep(), fusedInitArgs);

  // Map original induction variables and operands to those of the fused loop.
  // Target's iter args map to the leading fused iter args, source's to the
  // trailing ones, matching the order in `fusedInitArgs`.
  IRMapping mapping;
  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(target.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
  mapping.map(source.getRegionIterArgs(),
              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));

  // Merge target's body into the new (fused) for loop and then source's body.
  rewriter.setInsertionPointToStart(fusedLoop.getBody());
  for (Operation &op : target.getBody()->without_terminator())
    rewriter.clone(op, mapping);
  for (Operation &op : source.getBody()->without_terminator())
    rewriter.clone(op, mapping);

  // Build fused yield results by appropriately mapping original yield operands.
  SmallVector<Value> yieldResults;
  for (Value operand : target.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  for (Value operand : source.getBody()->getTerminator()->getOperands())
    yieldResults.push_back(mapping.lookupOrDefault(operand));
  if (!yieldResults.empty())
    rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);

  // Replace old loops by substituting their uses by results of the fused loop.
  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));

  return fusedLoop;
}
1480 
1481 FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
1482  scf::ForallOp forallOp) {
1483  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
1484  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
1485  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
1486 
1487  if (llvm::all_of(
1488  lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
1489  llvm::all_of(
1490  steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
1491  return forallOp;
1492  }
1493 
1494  SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
1495  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1496  Range normalizedLoopParams =
1497  emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
1498  newLbs.push_back(normalizedLoopParams.offset);
1499  newUbs.push_back(normalizedLoopParams.size);
1500  newSteps.push_back(normalizedLoopParams.stride);
1501  }
1502 
1503  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
1504  forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1505  forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
1506 
1507  rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
1508  normalizedForallOp.getBodyRegion(),
1509  normalizedForallOp.getBodyRegion().begin());
1510 
1511  rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
1512  return success();
1513 }
static std::optional< int64_t > getConstantTripCount(scf::ForOp forOp)
Returns the trip count of forOp if its lower bound, upper bound and step are constants,...
Definition: Utils.cpp:303
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
Definition: Utils.cpp:808
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition: Utils.cpp:1212
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
Definition: Utils.cpp:1234
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition: Utils.cpp:1168
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of scf::ForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor',...
Definition: Utils.cpp:323
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition: Utils.cpp:1249
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition: Utils.cpp:860
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition: Utils.cpp:271
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition: Utils.cpp:823
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Definition: Utils.cpp:754
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Definition: Utils.cpp:694
#define LDBG(X)
Definition: Utils.cpp:38
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
Definition: Utils.cpp:517
static int64_t product(ArrayRef< int64_t > vals)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
Definition: AffineExpr.h:68
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType::iterator iterator
Definition: Block.h:140
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:56
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:338
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition: Builders.h:250
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
result_range getResults()
Definition: Operation.h:415
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition: Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:45
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition: Value.cpp:81
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1198
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1589
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition: Utils.cpp:1332
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition: Utils.cpp:223
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:32
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition: Utils.cpp:40
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition: Utils.cpp:1001
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
std::pair< Loops, Loops > TileLoops
Definition: Utils.h:155
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
Definition: Utils.cpp:502
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition: Utils.cpp:1076
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition: Utils.cpp:1320
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition: Utils.cpp:378
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
Definition: Utils.cpp:530
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition: Utils.cpp:246
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:67
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1297
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:119
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition: RegionUtils.h:24
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
Definition: Utils.cpp:779
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition: Utils.cpp:1379
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition: Utils.cpp:993
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition: Utils.cpp:1431
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition: Utils.cpp:1337
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
Definition: Utils.cpp:708
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
Definition: Utils.cpp:1481
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
void walk(Operation *op)
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
std::optional< scf::ForOp > epilogueLoopOp
Definition: Utils.h:116
std::optional< scf::ForOp > mainLoopOp
Definition: Utils.h:115
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.