MLIR  19.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/SmallVector.h"
29 
30 using namespace mlir;
31 
32 namespace {
33 // This structure is to pass and return sets of loop parameters without
34 // confusing the order.
35 struct LoopParams {
36  Value lowerBound;
37  Value upperBound;
38  Value step;
39 };
40 } // namespace
41 
43  RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
44  ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
45  bool replaceIterOperandsUsesInLoop) {
46  if (loopNest.empty())
47  return {};
48  // This method is recursive (to make it more readable). Adding an
49  // assertion here to limit the recursion. (See
50  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
51  assert(loopNest.size() <= 10 &&
52  "exceeded recursion limit when yielding value from loop nest");
53 
54  // To yield a value from a perfectly nested loop nest, the following
55  // pattern needs to be created, i.e. starting with
56  //
57  // ```mlir
58  // scf.for .. {
59  // scf.for .. {
60  // scf.for .. {
61  // %value = ...
62  // }
63  // }
64  // }
65  // ```
66  //
67  // needs to be modified to
68  //
69  // ```mlir
70  // %0 = scf.for .. iter_args(%arg0 = %init) {
71  // %1 = scf.for .. iter_args(%arg1 = %arg0) {
72  // %2 = scf.for .. iter_args(%arg2 = %arg1) {
73  // %value = ...
74  // scf.yield %value
75  // }
76  // scf.yield %2
77  // }
78  // scf.yield %1
79  // }
80  // ```
81  //
82  // The inner most loop is handled using the `replaceWithAdditionalYields`
83  // that works on a single loop.
84  if (loopNest.size() == 1) {
85  auto innerMostLoop =
86  cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
87  rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
88  newYieldValuesFn));
89  return {innerMostLoop};
90  }
91  // The outer loops are modified by calling this method recursively
92  // - The return value of the inner loop is the value yielded by this loop.
93  // - The region iter args of this loop are the init_args for the inner loop.
94  SmallVector<scf::ForOp> newLoopNest;
95  NewYieldValuesFn fn =
96  [&](OpBuilder &innerBuilder, Location loc,
97  ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
98  newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
99  innerNewBBArgs, newYieldValuesFn,
100  replaceIterOperandsUsesInLoop);
101  return llvm::to_vector(llvm::map_range(
102  newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
103  [](OpResult r) -> Value { return r; }));
104  };
105  scf::ForOp outerMostLoop =
106  cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
107  rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
108  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
109  return newLoopNest;
110 }
111 
112 /// Outline a region with a single block into a new FuncOp.
113 /// Assumes the FuncOp result types is the type of the yielded operands of the
114 /// single block. This constraint makes it easy to determine the result.
115 /// This method also clones the `arith::ConstantIndexOp` at the start of
116 /// `outlinedFuncBody` to alloc simple canonicalizations. If `callOp` is
117 /// provided, it will be set to point to the operation that calls the outlined
118 /// function.
119 // TODO: support more than single-block regions.
120 // TODO: more flexible constant handling.
122  Location loc,
123  Region &region,
124  StringRef funcName,
125  func::CallOp *callOp) {
126  assert(!funcName.empty() && "funcName cannot be empty");
127  if (!region.hasOneBlock())
128  return failure();
129 
130  Block *originalBlock = &region.front();
131  Operation *originalTerminator = originalBlock->getTerminator();
132 
133  // Outline before current function.
134  OpBuilder::InsertionGuard g(rewriter);
135  rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());
136 
137  SetVector<Value> captures;
138  getUsedValuesDefinedAbove(region, captures);
139 
140  ValueRange outlinedValues(captures.getArrayRef());
141  SmallVector<Type> outlinedFuncArgTypes;
142  SmallVector<Location> outlinedFuncArgLocs;
143  // Region's arguments are exactly the first block's arguments as per
144  // Region::getArguments().
145  // Func's arguments are cat(regions's arguments, captures arguments).
146  for (BlockArgument arg : region.getArguments()) {
147  outlinedFuncArgTypes.push_back(arg.getType());
148  outlinedFuncArgLocs.push_back(arg.getLoc());
149  }
150  for (Value value : outlinedValues) {
151  outlinedFuncArgTypes.push_back(value.getType());
152  outlinedFuncArgLocs.push_back(value.getLoc());
153  }
154  FunctionType outlinedFuncType =
155  FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
156  originalTerminator->getOperandTypes());
157  auto outlinedFunc =
158  rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
159  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
160 
161  // Merge blocks while replacing the original block operands.
162  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
163  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
164  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
165  {
166  OpBuilder::InsertionGuard g(rewriter);
167  rewriter.setInsertionPointToEnd(outlinedFuncBody);
168  rewriter.mergeBlocks(
169  originalBlock, outlinedFuncBody,
170  outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
171  // Explicitly set up a new ReturnOp terminator.
172  rewriter.setInsertionPointToEnd(outlinedFuncBody);
173  rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
174  originalTerminator->getOperands());
175  }
176 
177  // Reconstruct the block that was deleted and add a
178  // terminator(call_results).
179  Block *newBlock = rewriter.createBlock(
180  &region, region.begin(),
181  TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
182  ArrayRef<Location>(outlinedFuncArgLocs)
183  .take_front(numOriginalBlockArguments));
184  {
185  OpBuilder::InsertionGuard g(rewriter);
186  rewriter.setInsertionPointToEnd(newBlock);
187  SmallVector<Value> callValues;
188  llvm::append_range(callValues, newBlock->getArguments());
189  llvm::append_range(callValues, outlinedValues);
190  auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
191  if (callOp)
192  *callOp = call;
193 
194  // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
195  // Clone `originalTerminator` to take the callOp results then erase it from
196  // `outlinedFuncBody`.
197  IRMapping bvm;
198  bvm.map(originalTerminator->getOperands(), call->getResults());
199  rewriter.clone(*originalTerminator, bvm);
200  rewriter.eraseOp(originalTerminator);
201  }
202 
203  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
204  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
205  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
206  outlinedValues.size()))) {
207  Value orig = std::get<0>(it);
208  Value repl = std::get<1>(it);
209  {
210  OpBuilder::InsertionGuard g(rewriter);
211  rewriter.setInsertionPointToStart(outlinedFuncBody);
212  if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
213  IRMapping bvm;
214  repl = rewriter.clone(*cst, bvm)->getResult(0);
215  }
216  }
217  orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
218  return outlinedFunc->isProperAncestor(opOperand.getOwner());
219  });
220  }
221 
222  return outlinedFunc;
223 }
224 
226  func::FuncOp *thenFn, StringRef thenFnName,
227  func::FuncOp *elseFn, StringRef elseFnName) {
228  IRRewriter rewriter(b);
229  Location loc = ifOp.getLoc();
230  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
231  if (thenFn && !ifOp.getThenRegion().empty()) {
232  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
233  rewriter, loc, ifOp.getThenRegion(), thenFnName);
234  if (failed(outlinedFuncOpOrFailure))
235  return failure();
236  *thenFn = *outlinedFuncOpOrFailure;
237  }
238  if (elseFn && !ifOp.getElseRegion().empty()) {
239  outlinedFuncOpOrFailure = outlineSingleBlockRegion(
240  rewriter, loc, ifOp.getElseRegion(), elseFnName);
241  if (failed(outlinedFuncOpOrFailure))
242  return failure();
243  *elseFn = *outlinedFuncOpOrFailure;
244  }
245  return success();
246 }
247 
250  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
251  bool rootEnclosesPloops = false;
252  for (Region &region : rootOp->getRegions()) {
253  for (Block &block : region.getBlocks()) {
254  for (Operation &op : block) {
255  bool enclosesPloops = getInnermostParallelLoops(&op, result);
256  rootEnclosesPloops |= enclosesPloops;
257  if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
258  rootEnclosesPloops = true;
259 
260  // Collect parallel loop if it is an innermost one.
261  if (!enclosesPloops)
262  result.push_back(ploop);
263  }
264  }
265  }
266  }
267  return rootEnclosesPloops;
268 }
269 
270 // Build the IR that performs ceil division of a positive value by a constant:
271 // ceildiv(a, B) = divis(a + (B-1), B)
272 // where divis is rounding-to-zero division.
273 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
274  int64_t divisor) {
275  assert(divisor > 0 && "expected positive divisor");
276  assert(dividend.getType().isIndex() && "expected index-typed value");
277 
278  Value divisorMinusOneCst =
279  builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
280  Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
281  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
282  return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
283 }
284 
285 // Build the IR that performs ceil division of a positive value by another
286 // positive value:
287 // ceildiv(a, b) = divis(a + (b - 1), b)
288 // where divis is rounding-to-zero division.
289 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
290  Value divisor) {
291  assert(dividend.getType().isIndex() && "expected index-typed value");
292 
293  Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
294  Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
295  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
296  return builder.create<arith::DivUIOp>(loc, sum, divisor);
297 }
298 
299 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
300 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
301 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
302 /// unrolled iteration using annotateFn.
304  Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
305  function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
306  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
307  ValueRange iterArgs, ValueRange yieldedValues) {
308  // Builder to insert unrolled bodies just before the terminator of the body of
309  // 'forOp'.
310  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
311 
312  if (!annotateFn)
313  annotateFn = [](unsigned, Operation *, OpBuilder) {};
314 
315  // Keep a pointer to the last non-terminator operation in the original block
316  // so that we know what to clone (since we are doing this in-place).
317  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
318 
319  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
320  SmallVector<Value, 4> lastYielded(yieldedValues);
321 
322  for (unsigned i = 1; i < unrollFactor; i++) {
323  IRMapping operandMap;
324 
325  // Prepare operand map.
326  operandMap.map(iterArgs, lastYielded);
327 
328  // If the induction variable is used, create a remapping to the value for
329  // this unrolled instance.
330  if (!forOpIV.use_empty()) {
331  Value ivUnroll = ivRemapFn(i, forOpIV, builder);
332  operandMap.map(forOpIV, ivUnroll);
333  }
334 
335  // Clone the original body of 'forOp'.
336  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
337  Operation *clonedOp = builder.clone(*it, operandMap);
338  annotateFn(i, clonedOp, builder);
339  }
340 
341  // Update yielded values.
342  for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
343  lastYielded[i] = operandMap.lookup(yieldedValues[i]);
344  }
345 
346  // Make sure we annotate the Ops in the original body. We do this last so that
347  // any annotations are not copied into the cloned Ops above.
348  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
349  annotateFn(0, &*it, builder);
350 
351  // Update operands of the yield statement.
352  loopBodyBlock->getTerminator()->setOperands(lastYielded);
353 }
354 
355 /// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
357  scf::ForOp forOp, uint64_t unrollFactor,
358  function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
359  assert(unrollFactor > 0 && "expected positive unroll factor");
360 
361  // Return if the loop body is empty.
362  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
363  return success();
364 
365  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
366  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
367  OpBuilder boundsBuilder(forOp);
368  IRRewriter rewriter(forOp.getContext());
369  auto loc = forOp.getLoc();
370  Value step = forOp.getStep();
371  Value upperBoundUnrolled;
372  Value stepUnrolled;
373  bool generateEpilogueLoop = true;
374 
375  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
376  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
377  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
378  if (lbCstOp && ubCstOp && stepCstOp) {
379  // Constant loop bounds computation.
380  int64_t lbCst = lbCstOp.value();
381  int64_t ubCst = ubCstOp.value();
382  int64_t stepCst = stepCstOp.value();
383  assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
384  "expected positive loop bounds and step");
385  int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
386 
387  if (unrollFactor == 1) {
388  if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
389  return failure();
390  return success();
391  }
392 
393  int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
394  int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
395  int64_t stepUnrolledCst = stepCst * unrollFactor;
396 
397  // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
398  generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
399  if (generateEpilogueLoop)
400  upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
401  loc, upperBoundUnrolledCst);
402  else
403  upperBoundUnrolled = forOp.getUpperBound();
404 
405  // Create constant for 'stepUnrolled'.
406  stepUnrolled = stepCst == stepUnrolledCst
407  ? step
408  : boundsBuilder.create<arith::ConstantIndexOp>(
409  loc, stepUnrolledCst);
410  } else {
411  // Dynamic loop bounds computation.
412  // TODO: Add dynamic asserts for negative lb/ub/step, or
413  // consider using ceilDiv from AffineApplyExpander.
414  auto lowerBound = forOp.getLowerBound();
415  auto upperBound = forOp.getUpperBound();
416  Value diff =
417  boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
418  Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
419  Value unrollFactorCst =
420  boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
421  Value tripCountRem =
422  boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
423  // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
424  Value tripCountEvenMultiple =
425  boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
426  // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
427  upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
428  loc, lowerBound,
429  boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
430  // Scale 'step' by 'unrollFactor'.
431  stepUnrolled =
432  boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
433  }
434 
435  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
436  if (generateEpilogueLoop) {
437  OpBuilder epilogueBuilder(forOp->getContext());
438  epilogueBuilder.setInsertionPoint(forOp->getBlock(),
439  std::next(Block::iterator(forOp)));
440  auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
441  epilogueForOp.setLowerBound(upperBoundUnrolled);
442 
443  // Update uses of loop results.
444  auto results = forOp.getResults();
445  auto epilogueResults = epilogueForOp.getResults();
446 
447  for (auto e : llvm::zip(results, epilogueResults)) {
448  std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
449  }
450  epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
451  epilogueForOp.getInitArgs().size(), results);
452  (void)epilogueForOp.promoteIfSingleIteration(rewriter);
453  }
454 
455  // Create unrolled loop.
456  forOp.setUpperBound(upperBoundUnrolled);
457  forOp.setStep(stepUnrolled);
458 
459  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
460  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
461 
463  forOp.getBody(), forOp.getInductionVar(), unrollFactor,
464  [&](unsigned i, Value iv, OpBuilder b) {
465  // iv' = iv + step * i;
466  auto stride = b.create<arith::MulIOp>(
467  loc, step, b.create<arith::ConstantIndexOp>(loc, i));
468  return b.create<arith::AddIOp>(loc, iv, stride);
469  },
470  annotateFn, iterArgs, yieldedValues);
471  // Promote the loop body up if this has turned into a single iteration loop.
472  (void)forOp.promoteIfSingleIteration(rewriter);
473  return success();
474 }
475 
476 /// Transform a loop with a strictly positive step
477 /// for %i = %lb to %ub step %s
478 /// into a 0-based loop with step 1
479 /// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480 /// %i = %ii * %s + %lb
481 /// Insert the induction variable remapping in the body of `inner`, which is
482 /// expected to be either `loop` or another loop perfectly nested under `loop`.
483 /// Insert the definition of new bounds immediate before `outer`, which is
484 /// expected to be either `loop` or its parent in the loop nest.
485 static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
486  Value lb, Value ub, Value step) {
487  // For non-index types, generate `arith` instructions
488  // Check if the loop is already known to have a constant zero lower bound or
489  // a constant one step.
490  bool isZeroBased = false;
491  if (auto lbCst = getConstantIntValue(lb))
492  isZeroBased = lbCst.value() == 0;
493 
494  bool isStepOne = false;
495  if (auto stepCst = getConstantIntValue(step))
496  isStepOne = stepCst.value() == 1;
497 
498  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
499  // assuming the step is strictly positive. Update the bounds and the step
500  // of the loop to go from 0 to the number of iterations, if necessary.
501  if (isZeroBased && isStepOne)
502  return {lb, ub, step};
503 
504  Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
505  Value newUpperBound =
506  isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
507 
508  Value newLowerBound = isZeroBased
509  ? lb
510  : rewriter.create<arith::ConstantOp>(
511  loc, rewriter.getZeroAttr(lb.getType()));
512  Value newStep = isStepOne
513  ? step
514  : rewriter.create<arith::ConstantOp>(
515  loc, rewriter.getIntegerAttr(step.getType(), 1));
516 
517  return {newLowerBound, newUpperBound, newStep};
518 }
519 
520 /// Get back the original induction variable values after loop normalization
522  Value normalizedIv, Value origLb,
523  Value origStep) {
524  Value denormalizedIv;
526  bool isStepOne = isConstantIntValue(origStep, 1);
527  bool isZeroBased = isConstantIntValue(origLb, 0);
528 
529  Value scaled = normalizedIv;
530  if (!isStepOne) {
531  scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
532  preserve.insert(scaled.getDefiningOp());
533  }
534  denormalizedIv = scaled;
535  if (!isZeroBased) {
536  denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
537  preserve.insert(denormalizedIv.getDefiningOp());
538  }
539 
540  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
541 }
542 
543 /// Helper function to multiply a sequence of values.
545  ArrayRef<Value> values) {
546  assert(!values.empty() && "unexpected empty list");
547  std::optional<Value> productOf;
548  for (auto v : values) {
549  auto vOne = getConstantIntValue(v);
550  if (vOne && vOne.value() == 1)
551  continue;
552  if (productOf)
553  productOf =
554  rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult();
555  else
556  productOf = v;
557  }
558  if (!productOf) {
559  productOf = rewriter
560  .create<arith::ConstantOp>(
561  loc, rewriter.getOneAttr(values.front().getType()))
562  .getResult();
563  }
564  return productOf.value();
565 }
566 
567 /// For each original loop, the value of the
568 /// induction variable can be obtained by dividing the induction variable of
569 /// the linearized loop by the total number of iterations of the loops nested
570 /// in it modulo the number of iterations in this loop (remove the values
571 /// related to the outer loops):
572 /// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
573 /// Compute these iteratively from the innermost loop by creating a "running
574 /// quotient" of division by the range.
575 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
577  Value linearizedIv, ArrayRef<Value> ubs) {
578  SmallVector<Value> delinearizedIvs(ubs.size());
579  SmallPtrSet<Operation *, 2> preservedUsers;
580 
581  llvm::BitVector isUbOne(ubs.size());
582  for (auto [index, ub] : llvm::enumerate(ubs)) {
583  auto ubCst = getConstantIntValue(ub);
584  if (ubCst && ubCst.value() == 1)
585  isUbOne.set(index);
586  }
587 
588  // Prune the lead ubs that are all ones.
589  unsigned numLeadingOneUbs = 0;
590  for (auto [index, ub] : llvm::enumerate(ubs)) {
591  if (!isUbOne.test(index)) {
592  break;
593  }
594  delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
595  loc, rewriter.getZeroAttr(ub.getType()));
596  numLeadingOneUbs++;
597  }
598 
599  Value previous = linearizedIv;
600  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
601  unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
602  if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
603  previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
604  preservedUsers.insert(previous.getDefiningOp());
605  }
606  Value iv = previous;
607  if (i != e - 1) {
608  if (!isUbOne.test(idx)) {
609  iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
610  preservedUsers.insert(iv.getDefiningOp());
611  } else {
612  iv = rewriter.create<arith::ConstantOp>(
613  loc, rewriter.getZeroAttr(ubs[idx].getType()));
614  }
615  }
616  delinearizedIvs[idx] = iv;
617  }
618  return {delinearizedIvs, preservedUsers};
619 }
620 
623  if (loops.size() < 2)
624  return failure();
625 
626  scf::ForOp innermost = loops.back();
627  scf::ForOp outermost = loops.front();
628 
629  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
630  // allows the following code to assume upperBound is the number of iterations.
631  for (auto loop : loops) {
632  OpBuilder::InsertionGuard g(rewriter);
633  rewriter.setInsertionPoint(outermost);
634  Value lb = loop.getLowerBound();
635  Value ub = loop.getUpperBound();
636  Value step = loop.getStep();
637  auto newLoopParams =
638  emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
639 
640  rewriter.modifyOpInPlace(loop, [&]() {
641  loop.setLowerBound(newLoopParams.lowerBound);
642  loop.setUpperBound(newLoopParams.upperBound);
643  loop.setStep(newLoopParams.step);
644  });
645 
646  rewriter.setInsertionPointToStart(innermost.getBody());
647  denormalizeInductionVariable(rewriter, loop.getLoc(),
648  loop.getInductionVar(), lb, step);
649  }
650 
651  // 2. Emit code computing the upper bound of the coalesced loop as product
652  // of the number of iterations of all loops.
653  OpBuilder::InsertionGuard g(rewriter);
654  rewriter.setInsertionPoint(outermost);
655  Location loc = outermost.getLoc();
656  SmallVector<Value> upperBounds = llvm::map_to_vector(
657  loops, [](auto loop) { return loop.getUpperBound(); });
658  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
659  outermost.setUpperBound(upperBound);
660 
661  rewriter.setInsertionPointToStart(innermost.getBody());
662  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
663  rewriter, loc, outermost.getInductionVar(), upperBounds);
664  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
665  preservedUsers);
666 
667  for (int i = loops.size() - 1; i > 0; --i) {
668  auto outerLoop = loops[i - 1];
669  auto innerLoop = loops[i];
670 
671  Operation *innerTerminator = innerLoop.getBody()->getTerminator();
672  auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
673  rewriter.eraseOp(innerTerminator);
674 
675  SmallVector<Value> innerBlockArgs;
676  innerBlockArgs.push_back(delinearizeIvs[i]);
677  llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
678  rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
679  Block::iterator(innerLoop), innerBlockArgs);
680  rewriter.replaceOp(innerLoop, yieldedVals);
681  }
682  return success();
683 }
684 
686  if (loops.empty()) {
687  return failure();
688  }
689  IRRewriter rewriter(loops.front().getContext());
690  return coalesceLoops(rewriter, loops);
691 }
692 
694  LogicalResult result(failure());
696  getPerfectlyNestedLoops(loops, op);
697 
698  // Look for a band of loops that can be coalesced, i.e. perfectly nested
699  // loops with bounds defined above some loop.
700 
701  // 1. For each loop, find above which parent loop its bounds operands are
702  // defined.
703  SmallVector<unsigned> operandsDefinedAbove(loops.size());
704  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
705  operandsDefinedAbove[i] = i;
706  for (unsigned j = 0; j < i; ++j) {
707  SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
708  loops[i].getUpperBound(),
709  loops[i].getStep()};
710  if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
711  operandsDefinedAbove[i] = j;
712  break;
713  }
714  }
715  }
716 
717  // 2. For each inner loop check that the iter_args for the immediately outer
718  // loop are the init for the immediately inner loop and that the yields of the
719  // return of the inner loop is the yield for the immediately outer loop. Keep
720  // track of where the chain starts from for each loop.
721  SmallVector<unsigned> iterArgChainStart(loops.size());
722  iterArgChainStart[0] = 0;
723  for (unsigned i = 1, e = loops.size(); i < e; ++i) {
724  // By default set the start of the chain to itself.
725  iterArgChainStart[i] = i;
726  auto outerloop = loops[i - 1];
727  auto innerLoop = loops[i];
728  if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
729  continue;
730  }
731  if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
732  continue;
733  }
734  auto outerloopTerminator = outerloop.getBody()->getTerminator();
735  if (!llvm::equal(outerloopTerminator->getOperands(),
736  innerLoop.getResults())) {
737  continue;
738  }
739  iterArgChainStart[i] = iterArgChainStart[i - 1];
740  }
741 
742  // 3. Identify bands of loops such that the operands of all of them are
743  // defined above the first loop in the band. Traverse the nest bottom-up
744  // so that modifications don't invalidate the inner loops.
745  for (unsigned end = loops.size(); end > 0; --end) {
746  unsigned start = 0;
747  for (; start < end - 1; ++start) {
748  auto maxPos =
749  *std::max_element(std::next(operandsDefinedAbove.begin(), start),
750  std::next(operandsDefinedAbove.begin(), end));
751  if (maxPos > start)
752  continue;
753  if (iterArgChainStart[end - 1] > start)
754  continue;
755  auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
756  if (succeeded(coalesceLoops(band)))
757  result = success();
758  break;
759  }
760  // If a band was found and transformed, keep looking at the loops above
761  // the outermost transformed loop.
762  if (start != end - 1)
763  end = start + 1;
764  }
765  return result;
766 }
767 
769  RewriterBase &rewriter, scf::ParallelOp loops,
770  ArrayRef<std::vector<unsigned>> combinedDimensions) {
771  OpBuilder::InsertionGuard g(rewriter);
772  rewriter.setInsertionPoint(loops);
773  Location loc = loops.getLoc();
774 
775  // Presort combined dimensions.
776  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
777  for (auto &dims : sortedDimensions)
778  llvm::sort(dims);
779 
780  // Normalize ParallelOp's iteration pattern.
781  SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
782  normalizedUpperBounds;
783  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
784  OpBuilder::InsertionGuard g2(rewriter);
785  rewriter.setInsertionPoint(loops);
786  Value lb = loops.getLowerBound()[i];
787  Value ub = loops.getUpperBound()[i];
788  Value step = loops.getStep()[i];
789  auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
790  normalizedLowerBounds.push_back(newLoopParams.lowerBound);
791  normalizedUpperBounds.push_back(newLoopParams.upperBound);
792  normalizedSteps.push_back(newLoopParams.step);
793 
794  rewriter.setInsertionPointToStart(loops.getBody());
795  denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
796  step);
797  }
798 
799  // Combine iteration spaces.
800  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
801  auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
802  auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
803  for (auto &sortedDimension : sortedDimensions) {
804  Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
805  for (auto idx : sortedDimension) {
806  newUpperBound = rewriter.create<arith::MulIOp>(
807  loc, newUpperBound, normalizedUpperBounds[idx]);
808  }
809  lowerBounds.push_back(cst0);
810  steps.push_back(cst1);
811  upperBounds.push_back(newUpperBound);
812  }
813 
814  // Create new ParallelLoop with conversions to the original induction values.
815  // The loop below uses divisions to get the relevant range of values in the
816  // new induction value that represent each range of the original induction
817  // value. The remainders then determine based on that range, which iteration
818  // of the original induction value this represents. This is a normalized value
819  // that is un-normalized already by the previous logic.
820  auto newPloop = rewriter.create<scf::ParallelOp>(
821  loc, lowerBounds, upperBounds, steps,
822  [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
823  for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
824  Value previous = ploopIVs[i];
825  unsigned numberCombinedDimensions = combinedDimensions[i].size();
826  // Iterate over all except the last induction value.
827  for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
828  unsigned idx = combinedDimensions[i][j];
829 
830  // Determine the current induction value's current loop iteration
831  Value iv = insideBuilder.create<arith::RemSIOp>(
832  loc, previous, normalizedUpperBounds[idx]);
833  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
834  loops.getRegion());
835 
836  // Remove the effect of the current induction value to prepare for
837  // the next value.
838  previous = insideBuilder.create<arith::DivSIOp>(
839  loc, previous, normalizedUpperBounds[idx]);
840  }
841 
842  // The final induction value is just the remaining value.
843  unsigned idx = combinedDimensions[i][0];
844  replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
845  previous, loops.getRegion());
846  }
847  });
848 
849  // Replace the old loop with the new loop.
850  loops.getBody()->back().erase();
851  newPloop.getBody()->getOperations().splice(
852  Block::iterator(newPloop.getBody()->back()),
853  loops.getBody()->getOperations());
854  loops.erase();
855 }
856 
857 // Hoist the ops within `outer` that appear before `inner`.
858 // Such ops include the ops that have been introduced by parametric tiling.
859 // Ops that come from triangular loops (i.e. that belong to the program slice
860 // rooted at `outer`) and ops that have side effects cannot be hoisted.
861 // Return failure when any op fails to hoist.
862 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
863  SetVector<Operation *> forwardSlice;
865  options.filter = [&inner](Operation *op) {
866  return op != inner.getOperation();
867  };
868  getForwardSlice(outer.getInductionVar(), &forwardSlice, options);
869  LogicalResult status = success();
871  for (auto &op : outer.getBody()->without_terminator()) {
872  // Stop when encountering the inner loop.
873  if (&op == inner.getOperation())
874  break;
875  // Skip over non-hoistable ops.
876  if (forwardSlice.count(&op) > 0) {
877  status = failure();
878  continue;
879  }
880  // Skip intermediate scf::ForOp, these are not considered a failure.
881  if (isa<scf::ForOp>(op))
882  continue;
883  // Skip other ops with regions.
884  if (op.getNumRegions() > 0) {
885  status = failure();
886  continue;
887  }
888  // Skip if op has side effects.
889  // TODO: loads to immutable memory regions are ok.
890  if (!isMemoryEffectFree(&op)) {
891  status = failure();
892  continue;
893  }
894  toHoist.push_back(&op);
895  }
896  auto *outerForOp = outer.getOperation();
897  for (auto *op : toHoist)
898  op->moveBefore(outerForOp);
899  return status;
900 }
901 
902 // Traverse the interTile and intraTile loops and try to hoist ops such that
903 // bands of perfectly nested loops are isolated.
904 // Return failure if either perfect interTile or perfect intraTile bands cannot
905 // be formed.
906 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
907  LogicalResult status = success();
908  const Loops &interTile = tileLoops.first;
909  const Loops &intraTile = tileLoops.second;
910  auto size = interTile.size();
911  assert(size == intraTile.size());
912  if (size <= 1)
913  return success();
914  for (unsigned s = 1; s < size; ++s)
915  status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
916  : failure();
917  for (unsigned s = 1; s < size; ++s)
918  status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
919  : failure();
920  return status;
921 }
922 
923 /// Collect perfectly nested loops starting from `rootForOps`. Loops are
924 /// perfectly nested if each loop is the first and only non-terminator operation
925 /// in the parent loop. Collect at most `maxLoops` loops and append them to
926 /// `forOps`.
927 template <typename T>
929  SmallVectorImpl<T> &forOps, T rootForOp,
930  unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
931  for (unsigned i = 0; i < maxLoops; ++i) {
932  forOps.push_back(rootForOp);
933  Block &body = rootForOp.getRegion().front();
934  if (body.begin() != std::prev(body.end(), 2))
935  return;
936 
937  rootForOp = dyn_cast<T>(&body.front());
938  if (!rootForOp)
939  return;
940  }
941 }
942 
943 static Loops stripmineSink(scf::ForOp forOp, Value factor,
944  ArrayRef<scf::ForOp> targets) {
945  auto originalStep = forOp.getStep();
946  auto iv = forOp.getInductionVar();
947 
948  OpBuilder b(forOp);
949  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
950 
951  Loops innerLoops;
952  for (auto t : targets) {
953  // Save information for splicing ops out of t when done
954  auto begin = t.getBody()->begin();
955  auto nOps = t.getBody()->getOperations().size();
956 
957  // Insert newForOp before the terminator of `t`.
958  auto b = OpBuilder::atBlockTerminator((t.getBody()));
959  Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
960  Value ub =
961  b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
962 
963  // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
964  auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
965  newForOp.getBody()->getOperations().splice(
966  newForOp.getBody()->getOperations().begin(),
967  t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
968  replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
969  newForOp.getRegion());
970 
971  innerLoops.push_back(newForOp);
972  }
973 
974  return innerLoops;
975 }
976 
977 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
978 // Returns the new for operation, nested immediately under `target`.
979 template <typename SizeType>
980 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
981  scf::ForOp target) {
982  // TODO: Use cheap structural assertions that targets are nested under
983  // forOp and that targets are not nested under each other when DominanceInfo
984  // exposes the capability. It seems overkill to construct a whole function
985  // dominance tree at this point.
986  auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
987  assert(res.size() == 1 && "Expected 1 inner forOp");
988  return res[0];
989 }
990 
992  ArrayRef<Value> sizes,
993  ArrayRef<scf::ForOp> targets) {
995  SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end());
996  for (auto it : llvm::zip(forOps, sizes)) {
997  auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
998  res.push_back(step);
999  currentTargets = step;
1000  }
1001  return res;
1002 }
1003 
1005  scf::ForOp target) {
1007  for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
1008  assert(loops.size() == 1);
1009  res.push_back(loops[0]);
1010  }
1011  return res;
1012 }
1013 
1014 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
1015  // Collect perfectly nested loops. If more size values provided than nested
1016  // loops available, truncate `sizes`.
1018  forOps.reserve(sizes.size());
1019  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1020  if (forOps.size() < sizes.size())
1021  sizes = sizes.take_front(forOps.size());
1022 
1023  return ::tile(forOps, sizes, forOps.back());
1024 }
1025 
1027  scf::ForOp root) {
1028  getPerfectlyNestedLoopsImpl(nestedLoops, root);
1029 }
1030 
1032  ArrayRef<int64_t> sizes) {
1033  // Collect perfectly nested loops. If more size values provided than nested
1034  // loops available, truncate `sizes`.
1036  forOps.reserve(sizes.size());
1037  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
1038  if (forOps.size() < sizes.size())
1039  sizes = sizes.take_front(forOps.size());
1040 
1041  // Compute the tile sizes such that i-th outer loop executes size[i]
1042  // iterations. Given that the loop current executes
1043  // numIterations = ceildiv((upperBound - lowerBound), step)
1044  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
1045  SmallVector<Value, 4> tileSizes;
1046  tileSizes.reserve(sizes.size());
1047  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
1048  assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
1049 
1050  auto forOp = forOps[i];
1051  OpBuilder builder(forOp);
1052  auto loc = forOp.getLoc();
1053  Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
1054  forOp.getLowerBound());
1055  Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
1056  Value iterationsPerBlock =
1057  ceilDivPositive(builder, loc, numIterations, sizes[i]);
1058  tileSizes.push_back(iterationsPerBlock);
1059  }
1060 
1061  // Call parametric tiling with the given sizes.
1062  auto intraTile = tile(forOps, tileSizes, forOps.back());
1063  TileLoops tileLoops = std::make_pair(forOps, intraTile);
1064 
1065  // TODO: for now we just ignore the result of band isolation.
1066  // In the future, mapping decisions may be impacted by the ability to
1067  // isolate perfectly nested bands.
1068  (void)tryIsolateBands(tileLoops);
1069 
1070  return tileLoops;
1071 }
1072 
1073 scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
1074  scf::ForallOp source,
1075  RewriterBase &rewriter) {
1076  unsigned numTargetOuts = target.getNumResults();
1077  unsigned numSourceOuts = source.getNumResults();
1078 
1079  // Create fused shared_outs.
1080  SmallVector<Value> fusedOuts;
1081  llvm::append_range(fusedOuts, target.getOutputs());
1082  llvm::append_range(fusedOuts, source.getOutputs());
1083 
1084  // Create a new scf.forall op after the source loop.
1085  rewriter.setInsertionPointAfter(source);
1086  scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
1087  source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
1088  source.getMixedStep(), fusedOuts, source.getMapping());
1089 
1090  // Map control operands.
1091  IRMapping mapping;
1092  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
1093  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1094 
1095  // Map shared outs.
1096  mapping.map(target.getRegionIterArgs(),
1097  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1098  mapping.map(source.getRegionIterArgs(),
1099  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1100 
1101  // Append everything except the terminator into the fused operation.
1102  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1103  for (Operation &op : target.getBody()->without_terminator())
1104  rewriter.clone(op, mapping);
1105  for (Operation &op : source.getBody()->without_terminator())
1106  rewriter.clone(op, mapping);
1107 
1108  // Fuse the old terminator in_parallel ops into the new one.
1109  scf::InParallelOp targetTerm = target.getTerminator();
1110  scf::InParallelOp sourceTerm = source.getTerminator();
1111  scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1112  rewriter.setInsertionPointToStart(fusedTerm.getBody());
1113  for (Operation &op : targetTerm.getYieldingOps())
1114  rewriter.clone(op, mapping);
1115  for (Operation &op : sourceTerm.getYieldingOps())
1116  rewriter.clone(op, mapping);
1117 
1118  // Replace old loops by substituting their uses by results of the fused loop.
1119  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1120  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1121 
1122  return fusedLoop;
1123 }
1124 
1125 scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
1126  scf::ForOp source,
1127  RewriterBase &rewriter) {
1128  unsigned numTargetOuts = target.getNumResults();
1129  unsigned numSourceOuts = source.getNumResults();
1130 
1131  // Create fused init_args, with target's init_args before source's init_args.
1132  SmallVector<Value> fusedInitArgs;
1133  llvm::append_range(fusedInitArgs, target.getInitArgs());
1134  llvm::append_range(fusedInitArgs, source.getInitArgs());
1135 
1136  // Create a new scf.for op after the source loop (with scf.yield terminator
1137  // (without arguments) only in case its init_args is empty).
1138  rewriter.setInsertionPointAfter(source);
1139  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
1140  source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1141  source.getStep(), fusedInitArgs);
1142 
1143  // Map original induction variables and operands to those of the fused loop.
1144  IRMapping mapping;
1145  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
1146  mapping.map(target.getRegionIterArgs(),
1147  fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1148  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1149  mapping.map(source.getRegionIterArgs(),
1150  fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1151 
1152  // Merge target's body into the new (fused) for loop and then source's body.
1153  rewriter.setInsertionPointToStart(fusedLoop.getBody());
1154  for (Operation &op : target.getBody()->without_terminator())
1155  rewriter.clone(op, mapping);
1156  for (Operation &op : source.getBody()->without_terminator())
1157  rewriter.clone(op, mapping);
1158 
1159  // Build fused yield results by appropriately mapping original yield operands.
1160  SmallVector<Value> yieldResults;
1161  for (Value operand : target.getBody()->getTerminator()->getOperands())
1162  yieldResults.push_back(mapping.lookupOrDefault(operand));
1163  for (Value operand : source.getBody()->getTerminator()->getOperands())
1164  yieldResults.push_back(mapping.lookupOrDefault(operand));
1165  if (!yieldResults.empty())
1166  rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1167 
1168  // Replace old loops by substituting their uses by results of the fused loop.
1169  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1170  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1171 
1172  return fusedLoop;
1173 }
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
Definition: Utils.cpp:906
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
Definition: Utils.cpp:928
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
Definition: Utils.cpp:862
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of scf::ForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor',...
Definition: Utils.cpp:303
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
Definition: Utils.cpp:943
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
Definition: Utils.cpp:576
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
Definition: Utils.cpp:273
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
Definition: Utils.cpp:544
static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, Value lb, Value ub, Value step)
Transform a loop with a strictly positive step for i = lb to ub step s into a 0-based loop with step ...
Definition: Utils.cpp:485
static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, Value origLb, Value origStep)
Get back the original induction variable values after loop normalization.
Definition: Utils.cpp:521
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:349
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
Definition: Builders.h:254
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
Definition: Region.h:205
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:702
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Definition: Value.cpp:81
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: LoopUtils.cpp:1624
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
Definition: Utils.cpp:1026
LogicalResult loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition: Utils.cpp:356
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
Definition: Utils.cpp:225
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
Definition: RegionUtils.cpp:28
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition: Utils.cpp:42
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition: Utils.cpp:693
std::pair< Loops, Loops > TileLoops
Definition: Utils.h:127
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
Definition: Utils.cpp:768
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
Definition: Utils.cpp:1014
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
Definition: Utils.cpp:248
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Definition: RegionUtils.cpp:63
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:991
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:121
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition: RegionUtils.h:24
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition: Utils.cpp:1073
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
Definition: Utils.cpp:685
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition: Utils.cpp:1125
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Definition: Utils.cpp:1031
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.