1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Interfaces/FunctionInterfaces.h"
24 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
40  using DialectInlinerInterface::DialectInlinerInterface;
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
77  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78  InParallelOp, ReduceReturnOp>();
79  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81  ForallOp, InParallelOp, WhileOp, YieldOp>();
82  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
83 }
84 
85 /// Default callback for IfOp builders. Inserts a yield without arguments.
86 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
87  scf::YieldOp::create(builder, loc);
88 }
89 
90 /// Verifies that the first block of the given `region` is terminated by a
91 /// TerminatorTy. Reports errors on the given operation if it is not the case.
92 template <typename TerminatorTy>
93 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
94  StringRef errorMessage) {
95  Operation *terminatorOperation = nullptr;
96  if (!region.empty() && !region.front().empty()) {
97  terminatorOperation = &region.front().back();
98  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
99  return yield;
100  }
101  auto diag = op->emitOpError(errorMessage);
102  if (terminatorOperation)
103  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
104  return nullptr;
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // ExecuteRegionOp
109 //===----------------------------------------------------------------------===//
110 
111 /// Replaces the given op with the contents of the given single-block region,
112 /// using the operands of the block terminator to replace operation results.
113 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
114  Region &region, ValueRange blockArgs = {}) {
115  assert(region.hasOneBlock() && "expected single-block region");
116  Block *block = &region.front();
117  Operation *terminator = block->getTerminator();
118  ValueRange results = terminator->getOperands();
119  rewriter.inlineBlockBefore(block, op, blockArgs);
120  rewriter.replaceOp(op, results);
121  rewriter.eraseOp(terminator);
122 }
123 
124 ///
125 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
126 /// block+
127 /// `}`
128 ///
129 /// Example:
130 /// scf.execute_region -> i32 {
131 /// %idx = load %rI[%i] : memref<128xi32>
132 /// return %idx : i32
133 /// }
134 ///
135 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
136  OperationState &result) {
137  if (parser.parseOptionalArrowTypeList(result.types))
138  return failure();
139 
140  if (succeeded(parser.parseOptionalKeyword("no_inline")))
141  result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
142 
143  // Introduce the body region and parse it.
144  Region *body = result.addRegion();
145  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
146  parser.parseOptionalAttrDict(result.attributes))
147  return failure();
148 
149  return success();
150 }
151 
152 void ExecuteRegionOp::print(OpAsmPrinter &p) {
153  p.printOptionalArrowTypeList(getResultTypes());
154  p << ' ';
155  if (getNoInline())
156  p << "no_inline ";
157  p.printRegion(getRegion(),
158  /*printEntryBlockArgs=*/false,
159  /*printBlockTerminators=*/true);
160  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
161 }
162 
163 LogicalResult ExecuteRegionOp::verify() {
164  if (getRegion().empty())
165  return emitOpError("region needs to have at least one block");
166  if (getRegion().front().getNumArguments() > 0)
167  return emitOpError("region cannot have any arguments");
168  return success();
169 }
170 
171 // Inline an ExecuteRegionOp if it only contains one block.
172 // "test.foo"() : () -> ()
173 // %v = scf.execute_region -> i64 {
174 // %x = "test.val"() : () -> i64
175 // scf.yield %x : i64
176 // }
177 // "test.bar"(%v) : (i64) -> ()
178 //
179 // becomes
180 //
181 // "test.foo"() : () -> ()
182 // %x = "test.val"() : () -> i64
183 // "test.bar"(%x) : (i64) -> ()
184 //
185 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
186  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
187 
188  LogicalResult matchAndRewrite(ExecuteRegionOp op,
189  PatternRewriter &rewriter) const override {
190  if (!op.getRegion().hasOneBlock() || op.getNoInline())
191  return failure();
192  replaceOpWithRegion(rewriter, op, op.getRegion());
193  return success();
194  }
195 };
196 
197 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
198 // TODO generalize the conditions for operations which can be inlined into.
199 // func @func_execute_region_elim() {
200 // "test.foo"() : () -> ()
201 // %v = scf.execute_region -> i64 {
202 // %c = "test.cmp"() : () -> i1
203 // cf.cond_br %c, ^bb2, ^bb3
204 // ^bb2:
205 // %x = "test.val1"() : () -> i64
206 // cf.br ^bb4(%x : i64)
207 // ^bb3:
208 // %y = "test.val2"() : () -> i64
209 // cf.br ^bb4(%y : i64)
210 // ^bb4(%z : i64):
211 // scf.yield %z : i64
212 // }
213 // "test.bar"(%v) : (i64) -> ()
214 // return
215 // }
216 //
217 // becomes
218 //
219 // func @func_execute_region_elim() {
220 // "test.foo"() : () -> ()
221 // %c = "test.cmp"() : () -> i1
222 // cf.cond_br %c, ^bb1, ^bb2
223 // ^bb1: // pred: ^bb0
224 // %x = "test.val1"() : () -> i64
225 // cf.br ^bb3(%x : i64)
226 // ^bb2: // pred: ^bb0
227 // %y = "test.val2"() : () -> i64
228 // cf.br ^bb3(%y : i64)
229 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
230 // "test.bar"(%z) : (i64) -> ()
231 // return
232 // }
233 //
234 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
235  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
236 
237  LogicalResult matchAndRewrite(ExecuteRegionOp op,
238  PatternRewriter &rewriter) const override {
239  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
240  return failure();
241 
242  Block *prevBlock = op->getBlock();
243  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
244  rewriter.setInsertionPointToEnd(prevBlock);
245 
246  cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
247 
248  for (Block &blk : op.getRegion()) {
249  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
250  rewriter.setInsertionPoint(yieldOp);
251  cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
252  yieldOp.getResults());
253  rewriter.eraseOp(yieldOp);
254  }
255  }
256 
257  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
258  SmallVector<Value> blockArgs;
259 
260  for (auto res : op.getResults())
261  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
262 
263  rewriter.replaceOp(op, blockArgs);
264  return success();
265  }
266 };
267 
268 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
269  MLIRContext *context) {
270  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
271 }
272 
273 void ExecuteRegionOp::getSuccessorRegions(
274  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
275  // If the predecessor is the ExecuteRegionOp, branch into the body.
276  if (point.isParent()) {
277  regions.push_back(RegionSuccessor(&getRegion()));
278  return;
279  }
280 
281  // Otherwise, the region branches back to the parent operation.
282  regions.push_back(RegionSuccessor(getResults()));
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // ConditionOp
287 //===----------------------------------------------------------------------===//
288 
289 MutableOperandRange
290 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
291  assert((point.isParent() || point == getParentOp().getAfter()) &&
292  "condition op can only exit the loop or branch to the after "
293  "region");
294  // Pass all operands except the condition to the successor region.
295  return getArgsMutable();
296 }
297 
298 void ConditionOp::getSuccessorRegions(
299  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
300  FoldAdaptor adaptor(operands, *this);
301 
302  WhileOp whileOp = getParentOp();
303 
304  // Condition can either lead to the after region or back to the parent op
305  // depending on whether the condition is true or not.
306  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
307  if (!boolAttr || boolAttr.getValue())
308  regions.emplace_back(&whileOp.getAfter(),
309  whileOp.getAfter().getArguments());
310  if (!boolAttr || !boolAttr.getValue())
311  regions.emplace_back(whileOp.getResults());
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // ForOp
316 //===----------------------------------------------------------------------===//
317 
318 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
319  Value ub, Value step, ValueRange initArgs,
320  BodyBuilderFn bodyBuilder, bool unsignedCmp) {
321  OpBuilder::InsertionGuard guard(builder);
322 
323  if (unsignedCmp)
324  result.addAttribute(getUnsignedCmpAttrName(result.name),
325  builder.getUnitAttr());
326  result.addOperands({lb, ub, step});
327  result.addOperands(initArgs);
328  for (Value v : initArgs)
329  result.addTypes(v.getType());
330  Type t = lb.getType();
331  Region *bodyRegion = result.addRegion();
332  Block *bodyBlock = builder.createBlock(bodyRegion);
333  bodyBlock->addArgument(t, result.location);
334  for (Value v : initArgs)
335  bodyBlock->addArgument(v.getType(), v.getLoc());
336 
337  // Create the default terminator if the builder is not provided and if the
338  // iteration arguments are not provided. Otherwise, leave this to the caller
339  // because we don't know which values to return from the loop.
340  if (initArgs.empty() && !bodyBuilder) {
341  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
342  } else if (bodyBuilder) {
343  OpBuilder::InsertionGuard guard(builder);
344  builder.setInsertionPointToStart(bodyBlock);
345  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
346  bodyBlock->getArguments().drop_front());
347  }
348 }
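// A minimal usage sketch of this builder (illustrative only, not part of the
// dialect sources). It assumes `builder`, `loc`, `lb`, `ub`, `step`, and
// `init` are provided by the caller, and builds a loop that adds the
// induction variable to a single loop-carried value:
//
//   scf::ForOp loop = scf::ForOp::create(
//       builder, loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &b, Location nestedLoc, Value iv, ValueRange iterArgs) {
//         Value next = arith::AddIOp::create(b, nestedLoc, iterArgs[0], iv);
//         scf::YieldOp::create(b, nestedLoc, ValueRange{next});
//       });
//
// Because init args and a body builder are supplied, the terminator is created
// by the callback rather than by ensureTerminator, as described above.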
349 
350 LogicalResult ForOp::verify() {
351  // Check that the number of init args and op results is the same.
352  if (getInitArgs().size() != getNumResults())
353  return emitOpError(
354  "mismatch in number of loop-carried values and defined values");
355 
356  return success();
357 }
358 
359 LogicalResult ForOp::verifyRegions() {
360  // Check that the body defines a single block argument for the induction
361  // variable.
362  if (getInductionVar().getType() != getLowerBound().getType())
363  return emitOpError(
364  "expected induction variable to be same type as bounds and step");
365 
366  if (getNumRegionIterArgs() != getNumResults())
367  return emitOpError(
368  "mismatch in number of basic block args and defined values");
369 
370  auto initArgs = getInitArgs();
371  auto iterArgs = getRegionIterArgs();
372  auto opResults = getResults();
373  unsigned i = 0;
374  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
375  if (std::get<0>(e).getType() != std::get<2>(e).getType())
376  return emitOpError() << "types mismatch between " << i
377  << "th iter operand and defined value";
378  if (std::get<1>(e).getType() != std::get<2>(e).getType())
379  return emitOpError() << "types mismatch between " << i
380  << "th iter region arg and defined value";
381 
382  ++i;
383  }
384  return success();
385 }
386 
387 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
388  return SmallVector<Value>{getInductionVar()};
389 }
390 
391 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
392  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
393 }
394 
395 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
396  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
397 }
398 
399 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
400  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
401 }
402 
403 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
404 
405 /// Promotes the loop body of a forOp to its containing block if it can be
406 /// determined that the loop has a single iteration.
407 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
408  std::optional<int64_t> tripCount =
409  constantTripCount(getLowerBound(), getUpperBound(), getStep());
410  if (!tripCount.has_value() || tripCount != 1)
411  return failure();
412 
413  // Replace all results with the yielded values.
414  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
415  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
416 
417  // Replace block arguments with lower bound (replacement for IV) and
418  // iter_args.
419  SmallVector<Value> bbArgReplacements;
420  bbArgReplacements.push_back(getLowerBound());
421  llvm::append_range(bbArgReplacements, getInitArgs());
422 
423  // Move the loop body operations to the loop's containing block.
424  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
425  getOperation()->getIterator(), bbArgReplacements);
426 
427  // Erase the old terminator and the loop.
428  rewriter.eraseOp(yieldOp);
429  rewriter.eraseOp(*this);
430 
431  return success();
432 }
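// Illustrative example (assuming %lb, %ub, %step are constants 0, 1, 1, so the
// trip count is exactly one):
//
//   %r = scf.for %i = %lb to %ub step %step iter_args(%acc = %init) -> (f32) {
//     %v = "test.compute"(%i, %acc) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
//
// is promoted to
//
//   %v = "test.compute"(%lb, %init) : (index, f32) -> f32
//
// with all uses of %r replaced by %v.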
433 
434 /// Prints the initialization list in the form of
435 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
436 /// where 'inner' values are assumed to be region arguments and 'outer' values
437 /// are regular SSA values.
438 static void printInitializationList(OpAsmPrinter &p,
439  Block::BlockArgListType blocksArgs,
440  ValueRange initializers,
441  StringRef prefix = "") {
442  assert(blocksArgs.size() == initializers.size() &&
443  "expected same length of arguments and initializers");
444  if (initializers.empty())
445  return;
446 
447  p << prefix << '(';
448  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
449  p << std::get<0>(it) << " = " << std::get<1>(it);
450  });
451  p << ")";
452 }
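// For example (illustrative), a loop with two iter_args prints as
//   " iter_args(%arg1 = %init0, %arg2 = %init1)"
// when called with the prefix " iter_args"; nothing is printed when the
// initializer list is empty.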
453 
454 void ForOp::print(OpAsmPrinter &p) {
455  if (getUnsignedCmp())
456  p << " unsigned";
457 
458  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
459  << getUpperBound() << " step " << getStep();
460 
461  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
462  if (!getInitArgs().empty())
463  p << " -> (" << getInitArgs().getTypes() << ')';
464  p << ' ';
465  if (Type t = getInductionVar().getType(); !t.isIndex())
466  p << " : " << t << ' ';
467  p.printRegion(getRegion(),
468  /*printEntryBlockArgs=*/false,
469  /*printBlockTerminators=*/!getInitArgs().empty());
470  p.printOptionalAttrDict((*this)->getAttrs(),
471  /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
472 }
473 
474 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
475  auto &builder = parser.getBuilder();
476  Type type;
477 
478  OpAsmParser::Argument inductionVariable;
479  OpAsmParser::UnresolvedOperand lb, ub, step;
480 
481  if (succeeded(parser.parseOptionalKeyword("unsigned")))
482  result.addAttribute(getUnsignedCmpAttrName(result.name),
483  builder.getUnitAttr());
484 
485  // Parse the induction variable followed by '='.
486  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
487  // Parse loop bounds.
488  parser.parseOperand(lb) || parser.parseKeyword("to") ||
489  parser.parseOperand(ub) || parser.parseKeyword("step") ||
490  parser.parseOperand(step))
491  return failure();
492 
493  // Parse the optional initial iteration arguments.
494  SmallVector<OpAsmParser::Argument, 4> regionArgs;
495  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
496  regionArgs.push_back(inductionVariable);
497 
498  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
499  if (hasIterArgs) {
500  // Parse assignment list and results type list.
501  if (parser.parseAssignmentList(regionArgs, operands) ||
502  parser.parseArrowTypeList(result.types))
503  return failure();
504  }
505 
506  if (regionArgs.size() != result.types.size() + 1)
507  return parser.emitError(
508  parser.getNameLoc(),
509  "mismatch in number of loop-carried values and defined values");
510 
511  // Parse optional type, else assume Index.
512  if (parser.parseOptionalColon())
513  type = builder.getIndexType();
514  else if (parser.parseType(type))
515  return failure();
516 
517  // Set block argument types, so that they are known when parsing the region.
518  regionArgs.front().type = type;
519  for (auto [iterArg, type] :
520  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
521  iterArg.type = type;
522 
523  // Parse the body region.
524  Region *body = result.addRegion();
525  if (parser.parseRegion(*body, regionArgs))
526  return failure();
527  ForOp::ensureTerminator(*body, builder, result.location);
528 
529  // Resolve input operands. This should be done after parsing the region to
530  // catch invalid IR where operands were defined inside of the region.
531  if (parser.resolveOperand(lb, type, result.operands) ||
532  parser.resolveOperand(ub, type, result.operands) ||
533  parser.resolveOperand(step, type, result.operands))
534  return failure();
535  if (hasIterArgs) {
536  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
537  operands, result.types)) {
538  Type type = std::get<2>(argOperandType);
539  std::get<0>(argOperandType).type = type;
540  if (parser.resolveOperand(std::get<1>(argOperandType), type,
541  result.operands))
542  return failure();
543  }
544  }
545 
546  // Parse the optional attribute list.
547  if (parser.parseOptionalAttrDict(result.attributes))
548  return failure();
549 
550  return success();
551 }
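// An illustrative example of the custom form handled by print/parse above:
//
//   %sum = scf.for %iv = %c0 to %cN step %c1 iter_args(%acc = %zero) -> (f32) {
//     %t = "test.get"(%iv) : (index) -> f32
//     %next = arith.addf %acc, %t : f32
//     scf.yield %next : f32
//   }
//
// A leading `unsigned` keyword sets the unsigned-comparison attribute, and a
// trailing `: <int-type>` before the region selects a non-index induction
// variable type.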
552 
553 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
554 
555 Block::BlockArgListType ForOp::getRegionIterArgs() {
556  return getBody()->getArguments().drop_front(getNumInductionVars());
557 }
558 
559 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
560  return getInitArgsMutable();
561 }
562 
563 FailureOr<LoopLikeOpInterface>
564 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
565  ValueRange newInitOperands,
566  bool replaceInitOperandUsesInLoop,
567  const NewYieldValuesFn &newYieldValuesFn) {
568  // Create a new loop before the existing one, with the extra operands.
569  OpBuilder::InsertionGuard g(rewriter);
570  rewriter.setInsertionPoint(getOperation());
571  auto inits = llvm::to_vector(getInitArgs());
572  inits.append(newInitOperands.begin(), newInitOperands.end());
573  scf::ForOp newLoop = scf::ForOp::create(
574  rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
575  [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
576  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
577 
578  // Generate the new yield values and append them to the scf.yield operation.
579  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
580  ArrayRef<BlockArgument> newIterArgs =
581  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
582  {
583  OpBuilder::InsertionGuard g(rewriter);
584  rewriter.setInsertionPoint(yieldOp);
585  SmallVector<Value> newYieldedValues =
586  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
587  assert(newInitOperands.size() == newYieldedValues.size() &&
588  "expected as many new yield values as new iter operands");
589  rewriter.modifyOpInPlace(yieldOp, [&]() {
590  yieldOp.getResultsMutable().append(newYieldedValues);
591  });
592  }
593 
594  // Move the loop body to the new op.
595  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
596  newLoop.getBody()->getArguments().take_front(
597  getBody()->getNumArguments()));
598 
599  if (replaceInitOperandUsesInLoop) {
600  // Replace all uses of `newInitOperands` with the corresponding basic block
601  // arguments.
602  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
603  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
604  [&](OpOperand &use) {
605  Operation *user = use.getOwner();
606  return newLoop->isProperAncestor(user);
607  });
608  }
609  }
610 
611  // Replace the old loop.
612  rewriter.replaceOp(getOperation(),
613  newLoop->getResults().take_front(getNumResults()));
614  return cast<LoopLikeOpInterface>(newLoop.getOperation());
615 }
616 
617 ForOp mlir::scf::getForInductionVarOwner(Value val) {
616 
618  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
619  if (!ivArg)
620  return ForOp();
621  assert(ivArg.getOwner() && "unlinked block argument");
622  auto *containingOp = ivArg.getOwner()->getParentOp();
623  return dyn_cast_or_null<ForOp>(containingOp);
624 }
625 
626 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
627  return getInitArgs();
628 }
629 
630 void ForOp::getSuccessorRegions(RegionBranchPoint point,
631  SmallVectorImpl<RegionSuccessor> &regions) {
632  // Both the operation itself and the region may be branching into the body or
633  // back into the operation itself. It is possible for the loop not to enter
634  // the body.
635  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
636  regions.push_back(RegionSuccessor(getResults()));
637 }
638 
639 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
640 
641 /// Promotes the loop body of a forallOp to its containing block if it can be
642 /// determined that the loop has a single iteration.
643 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
644  for (auto [lb, ub, step] :
645  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
646  auto tripCount = constantTripCount(lb, ub, step);
647  if (!tripCount.has_value() || *tripCount != 1)
648  return failure();
649  }
650 
651  promote(rewriter, *this);
652  return success();
653 }
654 
655 Block::BlockArgListType ForallOp::getRegionIterArgs() {
656  return getBody()->getArguments().drop_front(getRank());
657 }
658 
659 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
660  return getOutputsMutable();
661 }
662 
663 /// Promotes the loop body of a scf::ForallOp to its containing block.
664 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
665  OpBuilder::InsertionGuard g(rewriter);
666  scf::InParallelOp terminator = forallOp.getTerminator();
667 
668  // Replace block arguments with lower bounds (replacements for IVs) and
669  // outputs.
670  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
671  bbArgReplacements.append(forallOp.getOutputs().begin(),
672  forallOp.getOutputs().end());
673 
674  // Move the loop body operations to the loop's containing block.
675  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
676  forallOp->getIterator(), bbArgReplacements);
677 
678  // Replace the terminator with tensor.insert_slice ops.
679  rewriter.setInsertionPointAfter(forallOp);
680  SmallVector<Value> results;
681  results.reserve(forallOp.getResults().size());
682  for (auto &yieldingOp : terminator.getYieldingOps()) {
683  auto parallelInsertSliceOp =
684  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
685 
686  Value dst = parallelInsertSliceOp.getDest();
687  Value src = parallelInsertSliceOp.getSource();
688  if (llvm::isa<TensorType>(src.getType())) {
689  results.push_back(tensor::InsertSliceOp::create(
690  rewriter, forallOp.getLoc(), dst.getType(), src, dst,
691  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
692  parallelInsertSliceOp.getStrides(),
693  parallelInsertSliceOp.getStaticOffsets(),
694  parallelInsertSliceOp.getStaticSizes(),
695  parallelInsertSliceOp.getStaticStrides()));
696  } else {
697  llvm_unreachable("unsupported terminator");
698  }
699  }
700  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
701 
702  // Erase the old terminator and the loop.
703  rewriter.eraseOp(terminator);
704  rewriter.eraseOp(forallOp);
705 }
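// Illustrative example: promoting
//
//   %r = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<10xf32>) {
//     %s = "test.compute"(%i) : (index) -> tensor<1xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %s into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<10xf32>
//     }
//   }
//
// inlines the body (with %i replaced by the lower bound and %o by %t),
// rewrites the parallel_insert_slice into a tensor.insert_slice, and replaces
// %r with the inserted result.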
706 
707 LoopNest mlir::scf::buildLoopNest(
708  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
709  ValueRange steps, ValueRange iterArgs,
710  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
711  bodyBuilder) {
712  assert(lbs.size() == ubs.size() &&
713  "expected the same number of lower and upper bounds");
714  assert(lbs.size() == steps.size() &&
715  "expected the same number of lower bounds and steps");
716 
717  // If there are no bounds, call the body-building function and return early.
718  if (lbs.empty()) {
719  ValueVector results =
720  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
721  : ValueVector();
722  assert(results.size() == iterArgs.size() &&
723  "loop nest body must return as many values as loop has iteration "
724  "arguments");
725  return LoopNest{{}, std::move(results)};
726  }
727 
728  // First, create the loop structure iteratively using the body-builder
729  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
730  OpBuilder::InsertionGuard guard(builder);
731  SmallVector<scf::ForOp, 4> loops;
732  SmallVector<Value, 4> ivs;
733  loops.reserve(lbs.size());
734  ivs.reserve(lbs.size());
735  ValueRange currentIterArgs = iterArgs;
736  Location currentLoc = loc;
737  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
738  auto loop = scf::ForOp::create(
739  builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
740  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
741  ValueRange args) {
742  ivs.push_back(iv);
743  // It is safe to store ValueRange args because it points to block
744  // arguments of a loop operation that we also own.
745  currentIterArgs = args;
746  currentLoc = nestedLoc;
747  });
748  // Set the builder to point to the body of the newly created loop. We don't
749  // do this in the callback because the builder is reset when the callback
750  // returns.
751  builder.setInsertionPointToStart(loop.getBody());
752  loops.push_back(loop);
753  }
754 
755  // For all loops but the innermost, yield the results of the nested loop.
756  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
757  builder.setInsertionPointToEnd(loops[i].getBody());
758  scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
759  }
760 
761  // In the body of the innermost loop, call the body building function if any
762  // and yield its results.
763  builder.setInsertionPointToStart(loops.back().getBody());
764  ValueVector results = bodyBuilder
765  ? bodyBuilder(builder, currentLoc, ivs,
766  loops.back().getRegionIterArgs())
767  : ValueVector();
768  assert(results.size() == iterArgs.size() &&
769  "loop nest body must return as many values as loop has iteration "
770  "arguments");
771  builder.setInsertionPointToEnd(loops.back().getBody());
772  scf::YieldOp::create(builder, loc, results);
773 
774  // Return the loops.
775  ValueVector nestResults;
776  llvm::append_range(nestResults, loops.front().getResults());
777  return LoopNest{std::move(loops), std::move(nestResults)};
778 }
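// A minimal usage sketch (illustrative only). It assumes `builder`, `loc`,
// equally sized `lbs`/`ubs`/`steps`, and an `init` value are provided by the
// caller, and builds a nest that threads one accumulator through all loops:
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, lbs, ubs, steps, ValueRange{init},
//       [](OpBuilder &b, Location nestedLoc, ValueRange ivs,
//          ValueRange iterArgs) -> scf::ValueVector {
//         Value next = arith::AddIOp::create(b, nestedLoc, iterArgs[0], ivs[0]);
//         return {next};
//       });
//   Value result = nest.results.front();
//
// The callback runs in the innermost loop body; the intermediate yields are
// created automatically as described above.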
779 
780 LoopNest mlir::scf::buildLoopNest(
781  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
782  ValueRange steps,
783  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
784  // Delegate to the main function by wrapping the body builder.
785  return buildLoopNest(builder, loc, lbs, ubs, steps, {},
786  [&bodyBuilder](OpBuilder &nestedBuilder,
787  Location nestedLoc, ValueRange ivs,
788  ValueRange) -> ValueVector {
789  if (bodyBuilder)
790  bodyBuilder(nestedBuilder, nestedLoc, ivs);
791  return {};
792  });
793 }
794 
795 SmallVector<Value>
796 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
797  OpOperand &operand, Value replacement,
798  const ValueTypeCastFnTy &castFn) {
799  assert(operand.getOwner() == forOp);
800  Type oldType = operand.get().getType(), newType = replacement.getType();
801 
802  // 1. Create new iter operands, exactly 1 is replaced.
803  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
804  "expected an iter OpOperand");
805  assert(operand.get().getType() != replacement.getType() &&
806  "Expected a different type");
807  SmallVector<Value> newIterOperands;
808  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
809  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
810  newIterOperands.push_back(replacement);
811  continue;
812  }
813  newIterOperands.push_back(opOperand.get());
814  }
815 
816  // 2. Create the new forOp shell.
817  scf::ForOp newForOp = scf::ForOp::create(
818  rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
819  forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
820  forOp.getUnsignedCmp());
821  newForOp->setAttrs(forOp->getAttrs());
822  Block &newBlock = newForOp.getRegion().front();
823  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
824  newBlock.getArguments().end());
825 
826  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
827  // corresponding to the `replacement` value.
828  OpBuilder::InsertionGuard g(rewriter);
829  rewriter.setInsertionPointToStart(&newBlock);
830  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
831  &newForOp->getOpOperand(operand.getOperandNumber()));
832  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
833  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
834 
835  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
836  Block &oldBlock = forOp.getRegion().front();
837  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
838 
839  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
840  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
841  rewriter.setInsertionPoint(clonedYieldOp);
842  unsigned yieldIdx =
843  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
844  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
845  clonedYieldOp.getOperand(yieldIdx));
846  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
847  newYieldOperands[yieldIdx] = castOut;
848  scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
849  rewriter.eraseOp(clonedYieldOp);
850 
851  // 6. Inject an outgoing cast op after the forOp.
852  rewriter.setInsertionPointAfter(newForOp);
853  SmallVector<Value> newResults = newForOp.getResults();
854  newResults[yieldIdx] =
855  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
856 
857  return newResults;
858 }
859 
860 namespace {
861 // Fold away ForOp iter arguments when:
862 // 1) The op yields the iter arguments.
863 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
864 // 3) The iter arguments have no use and the corresponding (operation) results
865 // have no use.
866 //
867 // These arguments must be defined outside of the ForOp region and can just be
868 // forwarded after simplifying the op inits, yields and returns.
869 //
870 // The implementation uses `inlineBlockBefore` to steal the content of the
871 // original ForOp and avoid cloning.
872 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
873  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
874 
875  LogicalResult matchAndRewrite(scf::ForOp forOp,
876  PatternRewriter &rewriter) const final {
877  bool canonicalize = false;
878 
879  // An internal flat vector of block transfer
880  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
881  // transformed block arguments. This plays the role of an
882  // IRMapping for the particular use case of calling into
883  // `inlineBlockBefore`.
884  int64_t numResults = forOp.getNumResults();
885  SmallVector<bool, 4> keepMask;
886  keepMask.reserve(numResults);
887  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
888  newResultValues;
889  newBlockTransferArgs.reserve(1 + numResults);
890  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
891  newIterArgs.reserve(forOp.getInitArgs().size());
892  newYieldValues.reserve(numResults);
893  newResultValues.reserve(numResults);
894  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
895  for (auto [init, arg, result, yielded] :
896  llvm::zip(forOp.getInitArgs(), // iter from outside
897  forOp.getRegionIterArgs(), // iter inside region
898  forOp.getResults(), // op results
899  forOp.getYieldedValues() // iter yield
900  )) {
901  // Forwarded is `true` when:
902  // 1) The region `iter` argument is yielded.
903  // 2) The input (init) corresponding to the region `iter` argument is yielded.
904  // 3) The region `iter` argument has no use, and the corresponding op
905  // result has no use.
906  bool forwarded = (arg == yielded) || (init == yielded) ||
907  (arg.use_empty() && result.use_empty());
908  if (forwarded) {
909  canonicalize = true;
910  keepMask.push_back(false);
911  newBlockTransferArgs.push_back(init);
912  newResultValues.push_back(init);
913  continue;
914  }
915 
916  // Check if a previous kept argument always has the same values for init
917  // and yielded values.
918  if (auto it = initYieldToArg.find({init, yielded});
919  it != initYieldToArg.end()) {
920  canonicalize = true;
921  keepMask.push_back(false);
922  auto [sameArg, sameResult] = it->second;
923  rewriter.replaceAllUsesWith(arg, sameArg);
924  rewriter.replaceAllUsesWith(result, sameResult);
925  // The replacement value doesn't matter because there are no uses.
926  newBlockTransferArgs.push_back(init);
927  newResultValues.push_back(init);
928  continue;
929  }
930 
931  // This value is kept.
932  initYieldToArg.insert({{init, yielded}, {arg, result}});
933  keepMask.push_back(true);
934  newIterArgs.push_back(init);
935  newYieldValues.push_back(yielded);
936  newBlockTransferArgs.push_back(Value()); // placeholder with null value
937  newResultValues.push_back(Value()); // placeholder with null value
938  }
939 
940  if (!canonicalize)
941  return failure();
942 
943  scf::ForOp newForOp =
944  scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
945  forOp.getUpperBound(), forOp.getStep(), newIterArgs,
946  /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
947  newForOp->setAttrs(forOp->getAttrs());
948  Block &newBlock = newForOp.getRegion().front();
949 
950  // Replace the null placeholders with newly constructed values.
951  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
952  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
953  idx != e; ++idx) {
954  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
955  Value &newResultVal = newResultValues[idx];
956  assert((blockTransferArg && newResultVal) ||
957  (!blockTransferArg && !newResultVal));
958  if (!blockTransferArg) {
959  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
960  newResultVal = newForOp.getResult(collapsedIdx++);
961  }
962  }
963 
964  Block &oldBlock = forOp.getRegion().front();
965  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
966  "unexpected argument size mismatch");
967 
968  // No results case: the scf::ForOp builder already created a zero
969  // result terminator. Merge before this terminator and just get rid of the
970  // original terminator that has been merged in.
971  if (newIterArgs.empty()) {
972  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
973  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
974  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
975  rewriter.replaceOp(forOp, newResultValues);
976  return success();
977  }
978 
979  // No terminator case: merge and rewrite the merged terminator.
980  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
981  OpBuilder::InsertionGuard g(rewriter);
982  rewriter.setInsertionPoint(mergedTerminator);
983  SmallVector<Value, 4> filteredOperands;
984  filteredOperands.reserve(newResultValues.size());
985  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
986  if (keepMask[idx])
987  filteredOperands.push_back(mergedTerminator.getOperand(idx));
988  scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
989  filteredOperands);
990  };
991 
992  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
993  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
994  cloneFilteredTerminator(mergedYieldOp);
995  rewriter.eraseOp(mergedYieldOp);
996  rewriter.replaceOp(forOp, newResultValues);
997  return success();
998  }
999 };
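// Illustrative example of case 1 above: in
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     "test.use"(%a) : (f32) -> ()
//     scf.yield %a : f32
//   }
//
// the iter argument is yielded unchanged, so it is dropped: %a is replaced by
// %init inside the body and %r by %init after the loop.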
1000 
1001 /// Util function that tries to compute a constant diff between u and l.
1002 /// Returns std::nullopt when the difference between the two values cannot
1003 /// be determined statically.
1004 static std::optional<APInt> computeConstDiff(Value l, Value u) {
1005  IntegerAttr clb, cub;
1006  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
1007  llvm::APInt lbValue = clb.getValue();
1008  llvm::APInt ubValue = cub.getValue();
1009  return ubValue - lbValue;
1010  }
1011 
1012  // Else a simple pattern match for x + c or c + x
1013  llvm::APInt diff;
1014  if (matchPattern(
1015  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
1016  matchPattern(
1017  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1018  return diff;
1019  return std::nullopt;
1020 }
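// For example (illustrative): with %l = arith.constant 2 and %u =
// arith.constant 10 this returns 8; with %u = arith.addi %l, %c4 (where %c4 is
// a constant) it returns 4; otherwise std::nullopt is returned.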
1021 
1022 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1023 /// single-iteration loops with their bodies, and removes empty loops that
1024 /// iterate at least once and only return values defined outside of the loop.
1025 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1026  using OpRewritePattern<ForOp>::OpRewritePattern;
1027 
1028  LogicalResult matchAndRewrite(ForOp op,
1029  PatternRewriter &rewriter) const override {
1030  // If the upper bound is the same as the lower bound, the loop does not
1031  // iterate, just remove it.
1032  if (op.getLowerBound() == op.getUpperBound()) {
1033  rewriter.replaceOp(op, op.getInitArgs());
1034  return success();
1035  }
1036 
1037  std::optional<APInt> diff =
1038  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1039  if (!diff)
1040  return failure();
1041 
1042  // If the loop is known to have 0 iterations, remove it.
1043  bool zeroOrLessIterations =
1044  diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
1045  if (zeroOrLessIterations) {
1046  rewriter.replaceOp(op, op.getInitArgs());
1047  return success();
1048  }
1049 
1050  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1051  if (!maybeStepValue)
1052  return failure();
1053 
1054  // If the loop is known to have 1 iteration, inline its body and remove the
1055  // loop.
1056  llvm::APInt stepValue = *maybeStepValue;
1057  if (stepValue.sge(*diff)) {
1058  SmallVector<Value, 4> blockArgs;
1059  blockArgs.reserve(op.getInitArgs().size() + 1);
1060  blockArgs.push_back(op.getLowerBound());
1061  llvm::append_range(blockArgs, op.getInitArgs());
1062  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1063  return success();
1064  }
1065 
1066  // Now we are left with loops that have more than one iteration.
1067  Block &block = op.getRegion().front();
1068  if (!llvm::hasSingleElement(block))
1069  return failure();
1070  // If the loop is empty, iterates at least once, and only returns values
1071  // defined outside of the loop, remove it and replace it with yield values.
1072  if (llvm::any_of(op.getYieldedValues(),
1073  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1074  return failure();
1075  rewriter.replaceOp(op, op.getYieldedValues());
1076  return success();
1077  }
1078 };
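// Illustrative examples: a loop with %lb == %ub, or (for signed loops) a known
// non-positive constant bound difference, is replaced by its init values; a
// loop whose constant step is at least the constant bound difference, e.g.
//   scf.for %i = %c0 to %c4 step %c8 { ... }
// runs exactly once and is replaced by its inlined body with %i mapped to %c0.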
1079 
1080 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1081 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1082 ///
1083 /// ```
1084 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1085 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1086 /// -> (tensor<?x?xf32>) {
1087 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1088 /// scf.yield %2 : tensor<?x?xf32>
1089 /// }
1090 /// use_of(%1)
1091 /// ```
1092 ///
1093 /// folds into:
1094 ///
1095 /// ```
1096 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1097 /// -> (tensor<32x1024xf32>) {
1098 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1099 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1100 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1101 /// scf.yield %4 : tensor<32x1024xf32>
1102 /// }
1103 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1104 /// use_of(%1)
1105 /// ```
1106 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1107  using OpRewritePattern<ForOp>::OpRewritePattern;
1108 
1109  LogicalResult matchAndRewrite(ForOp op,
1110  PatternRewriter &rewriter) const override {
1111  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1112  OpOperand &iterOpOperand = std::get<0>(it);
1113  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1114  if (!incomingCast ||
1115  incomingCast.getSource().getType() == incomingCast.getType())
1116  continue;
1117  // Skip the cast if its dest type does not preserve the static information
1118  // present in the source type.
1119  if (!tensor::preservesStaticInformation(
1120  incomingCast.getDest().getType(),
1121  incomingCast.getSource().getType()))
1122  continue;
1123  if (!std::get<1>(it).hasOneUse())
1124  continue;
1125 
1126  // Create a new ForOp with that iter operand replaced.
1127  rewriter.replaceOp(
1128  op, replaceAndCastForOpIterArg(
1129  rewriter, op, iterOpOperand, incomingCast.getSource(),
1130  [](OpBuilder &b, Location loc, Type type, Value source) {
1131  return tensor::CastOp::create(b, loc, type, source);
1132  }));
1133  return success();
1134  }
1135  return failure();
1136  }
1137 };
1138 
1139 } // namespace
1140 
1141 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1142  MLIRContext *context) {
1143  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1144  context);
1145 }
1146 
1147 std::optional<APInt> ForOp::getConstantStep() {
1148  IntegerAttr step;
1149  if (matchPattern(getStep(), m_Constant(&step)))
1150  return step.getValue();
1151  return {};
1152 }
1153 
1154 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1155  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1156 }
1157 
1158 Speculation::Speculatability ForOp::getSpeculatability() {
1159  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1160  // and End.
1161  if (auto constantStep = getConstantStep())
1162  if (*constantStep == 1)
1163  return Speculation::RecursivelySpeculatable;
1164 
1165  // For Step != 1, the loop may not terminate. We can add more smarts here if
1166  // needed.
1167  return Speculation::NotSpeculatable;
1168 }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // ForallOp
1172 //===----------------------------------------------------------------------===//
1173 
1174 LogicalResult ForallOp::verify() {
1175  unsigned numLoops = getRank();
1176  // Check number of outputs.
1177  if (getNumResults() != getOutputs().size())
1178  return emitOpError("produces ")
1179  << getNumResults() << " results, but has only "
1180  << getOutputs().size() << " outputs";
1181 
1182  // Check that the body defines block arguments for thread indices and outputs.
1183  auto *body = getBody();
1184  if (body->getNumArguments() != numLoops + getOutputs().size())
1185  return emitOpError("region expects ") << numLoops << " arguments";
1186  for (int64_t i = 0; i < numLoops; ++i)
1187  if (!body->getArgument(i).getType().isIndex())
1188  return emitOpError("expects ")
1189  << i << "-th block argument to be an index";
1190  for (unsigned i = 0; i < getOutputs().size(); ++i)
1191  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1192  return emitOpError("type mismatch between ")
1193  << i << "-th output and corresponding block argument";
1194  if (getMapping().has_value() && !getMapping()->empty()) {
1195  if (getDeviceMappingAttrs().size() != numLoops)
1196  return emitOpError() << "mapping attribute size must match op rank";
1197  if (failed(getDeviceMaskingAttr()))
1198  return emitOpError() << getMappingAttrName()
1199  << " supports at most one device masking attribute";
1200  }
1201 
1202  // Verify mixed static/dynamic control variables.
1203  Operation *op = getOperation();
1204  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1205  getStaticLowerBound(),
1206  getDynamicLowerBound())))
1207  return failure();
1208  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1209  getStaticUpperBound(),
1210  getDynamicUpperBound())))
1211  return failure();
1212  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1213  getStaticStep(), getDynamicStep())))
1214  return failure();
1215 
1216  return success();
1217 }
1218 
1219 void ForallOp::print(OpAsmPrinter &p) {
1220  Operation *op = getOperation();
1221  p << " (" << getInductionVars();
1222  if (isNormalized()) {
1223  p << ") in ";
1224  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1225  /*valueTypes=*/{}, /*scalables=*/{},
1226  OpAsmParser::Delimiter::Paren);
1227  } else {
1228  p << ") = ";
1229  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1230  /*valueTypes=*/{}, /*scalables=*/{},
1231  OpAsmParser::Delimiter::Paren);
1232  p << " to ";
1233  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1234  /*valueTypes=*/{}, /*scalables=*/{},
1235  OpAsmParser::Delimiter::Paren);
1236  p << " step ";
1237  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1238  /*valueTypes=*/{}, /*scalables=*/{},
1239  OpAsmParser::Delimiter::Paren);
1240  }
1241  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1242  p << " ";
1243  if (!getRegionOutArgs().empty())
1244  p << "-> (" << getResultTypes() << ") ";
1245  p.printRegion(getRegion(),
1246  /*printEntryBlockArgs=*/false,
1247  /*printBlockTerminators=*/getNumResults() > 0);
1248  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1249  getStaticLowerBoundAttrName(),
1250  getStaticUpperBoundAttrName(),
1251  getStaticStepAttrName()});
1252 }
1253 
1254 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1255  OpBuilder b(parser.getContext());
1256  auto indexType = b.getIndexType();
1257 
1258  // Parse an opening `(` followed by thread index variables followed by `)`
1259  // TODO: when we can refer to such "induction variable"-like handles from the
1260  // declarative assembly format, we can implement the parser as a custom hook.
1261  SmallVector<OpAsmParser::Argument, 4> ivs;
1262  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1263  return failure();
1264 
1265  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1266  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1267  dynamicSteps;
1268  if (succeeded(parser.parseOptionalKeyword("in"))) {
1269  // Parse upper bounds.
1270  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1271  /*valueTypes=*/nullptr,
1272  OpAsmParser::Delimiter::Paren) ||
1273  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1274  return failure();
1275 
1276  unsigned numLoops = ivs.size();
1277  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1278  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1279  } else {
1280  // Parse lower bounds.
1281  if (parser.parseEqual() ||
1282  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1283  /*valueTypes=*/nullptr,
1284  OpAsmParser::Delimiter::Paren) ||
1285 
1286  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1287  return failure();
1288 
1289  // Parse upper bounds.
1290  if (parser.parseKeyword("to") ||
1291  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1292  /*valueTypes=*/nullptr,
1293  OpAsmParser::Delimiter::Paren) ||
1294  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1295  return failure();
1296 
1297  // Parse step values.
1298  if (parser.parseKeyword("step") ||
1299  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1300  /*valueTypes=*/nullptr,
1301  OpAsmParser::Delimiter::Paren) ||
1302  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1303  return failure();
1304  }
1305 
1306  // Parse out operands and results.
1307  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1308  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1309  SMLoc outOperandsLoc = parser.getCurrentLocation();
1310  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1311  if (outOperands.size() != result.types.size())
1312  return parser.emitError(outOperandsLoc,
1313  "mismatch between out operands and types");
1314  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1315  parser.parseOptionalArrowTypeList(result.types) ||
1316  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1317  result.operands))
1318  return failure();
1319  }
1320 
1321  // Parse region.
1322  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1323  std::unique_ptr<Region> region = std::make_unique<Region>();
1324  for (auto &iv : ivs) {
1325  iv.type = b.getIndexType();
1326  regionArgs.push_back(iv);
1327  }
1328  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1329  auto &out = it.value();
1330  out.type = result.types[it.index()];
1331  regionArgs.push_back(out);
1332  }
1333  if (parser.parseRegion(*region, regionArgs))
1334  return failure();
1335 
1336  // Ensure terminator and move region.
1337  ForallOp::ensureTerminator(*region, b, result.location);
1338  result.addRegion(std::move(region));
1339 
1340  // Parse the optional attribute list.
1341  if (parser.parseOptionalAttrDict(result.attributes))
1342  return failure();
1343 
1344  result.addAttribute("staticLowerBound", staticLbs);
1345  result.addAttribute("staticUpperBound", staticUbs);
1346  result.addAttribute("staticStep", staticSteps);
1347  result.addAttribute("operandSegmentSizes",
1348  b.getDenseI32ArrayAttr(
1349  {static_cast<int32_t>(dynamicLbs.size()),
1350  static_cast<int32_t>(dynamicUbs.size()),
1351  static_cast<int32_t>(dynamicSteps.size()),
1352  static_cast<int32_t>(outOperands.size())}));
1353  return success();
1354 }
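// An illustrative example of the normalized custom form handled above:
//
//   %res = scf.forall (%i, %j) in (%d0, %d1) shared_outs(%o = %t)
//       -> (tensor<?x?xf32>) {
//     %tile = "test.compute"(%i, %j, %o)
//         : (index, index, tensor<?x?xf32>) -> tensor<1x1xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tile into %o[%i, %j] [1, 1] [1, 1]
//           : tensor<1x1xf32> into tensor<?x?xf32>
//     }
//   }
//
// The non-normalized form spells all bounds explicitly, e.g.
//   scf.forall (%i) = (%lb) to (%ub) step (%step) ...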
1355 
1356 // Builder that takes loop bounds.
1357 void ForallOp::build(
1358  OpBuilder &b, OperationState &result, ArrayRef<OpFoldResult> lbs,
1359  ArrayRef<OpFoldResult> ubs,
1360  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1361  std::optional<ArrayAttr> mapping,
1362  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1363  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1364  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1365  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1366  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1367  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1368 
1369  result.addOperands(dynamicLbs);
1370  result.addOperands(dynamicUbs);
1371  result.addOperands(dynamicSteps);
1372  result.addOperands(outputs);
1373  result.addTypes(TypeRange(outputs));
1374 
1375  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1376  b.getDenseI64ArrayAttr(staticLbs));
1377  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1378  b.getDenseI64ArrayAttr(staticUbs));
1379  result.addAttribute(getStaticStepAttrName(result.name),
1380  b.getDenseI64ArrayAttr(staticSteps));
1381  result.addAttribute(
1382  "operandSegmentSizes",
1383  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1384  static_cast<int32_t>(dynamicUbs.size()),
1385  static_cast<int32_t>(dynamicSteps.size()),
1386  static_cast<int32_t>(outputs.size())}));
1387  if (mapping.has_value()) {
1388  result.addAttribute(getMappingAttrName(result.name),
1389  mapping.value());
1390  }
1391 
1392  Region *bodyRegion = result.addRegion();
1393  OpBuilder::InsertionGuard g(b);
1394  b.createBlock(bodyRegion);
1395  Block &bodyBlock = bodyRegion->front();
1396 
1397  // Add block arguments for indices and outputs.
1398  bodyBlock.addArguments(
1399  SmallVector<Type>(lbs.size(), b.getIndexType()),
1400  SmallVector<Location>(staticLbs.size(), result.location));
1401  bodyBlock.addArguments(
1402  TypeRange(outputs),
1403  SmallVector<Location>(outputs.size(), result.location));
1404 
1405  b.setInsertionPointToStart(&bodyBlock);
1406  if (!bodyBuilderFn) {
1407  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1408  return;
1409  }
1410  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1411 }
1412 
1413 // Builder that takes loop bounds.
1414 void ForallOp::build(
1415  OpBuilder &b, OperationState &result,
1416  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1417  std::optional<ArrayAttr> mapping,
1418  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1419  unsigned numLoops = ubs.size();
1420  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1421  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1422  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1423 }
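// A minimal usage sketch of the normalized builder above (illustrative only).
// It assumes `b`, `loc`, a vector of upper bounds `ubs`, and a shared output
// tensor `dest` are provided by the caller. Since this overload only calls
// ensureTerminator when no body builder is given, the sketch's callback also
// creates the scf.forall.in_parallel terminator itself:
//
//   auto forallOp = scf::ForallOp::create(
//       b, loc, ubs, ValueRange{dest}, /*mapping=*/std::nullopt,
//       [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
//         // `args` holds the induction variables followed by the block
//         // arguments tied to the shared outputs.
//         scf::InParallelOp::create(nested, nestedLoc);
//       });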
1424 
1425 // Checks if the lbs are zeros and steps are ones.
1426 bool ForallOp::isNormalized() {
1427  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1428  return llvm::all_of(results, [&](OpFoldResult ofr) {
1429  auto intValue = getConstantIntValue(ofr);
1430  return intValue.has_value() && intValue == val;
1431  });
1432  };
1433  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1434 }
1435 
1436 InParallelOp ForallOp::getTerminator() {
1437  return cast<InParallelOp>(getBody()->getTerminator());
1438 }
1439 
1440 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1441  SmallVector<Operation *> storeOps;
1442  InParallelOp inParallelOp = getTerminator();
1443  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1444  if (auto parallelInsertSliceOp =
1445  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1446  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1447  storeOps.push_back(parallelInsertSliceOp);
1448  }
1449  }
1450  return storeOps;
1451 }
1452 
1453 SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1454  SmallVector<DeviceMappingAttrInterface> res;
1455  if (!getMapping())
1456  return res;
1457  for (auto attr : getMapping()->getValue()) {
1458  auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1459  if (m)
1460  res.push_back(m);
1461  }
1462  return res;
1463 }
1464 
1465 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1466  DeviceMaskingAttrInterface res;
1467  if (!getMapping())
1468  return res;
1469  for (auto attr : getMapping()->getValue()) {
1470  auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1471  if (m && res)
1472  return failure();
1473  if (m)
1474  res = m;
1475  }
1476  return res;
1477 }
1478 
1479 bool ForallOp::usesLinearMapping() {
1480  SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1481  if (ifaces.empty())
1482  return false;
1483  return ifaces.front().isLinearMapping();
1484 }
1485 
1486 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1487  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1488 }
1489 
1490 // Get lower bounds as OpFoldResult.
1491 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1492  Builder b(getOperation()->getContext());
1493  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1494 }
1495 
1496 // Get upper bounds as OpFoldResult.
1497 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1498  Builder b(getOperation()->getContext());
1499  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1500 }
1501 
1502 // Get steps as OpFoldResult.
1503 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1504  Builder b(getOperation()->getContext());
1505  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1506 }
1507 
1508 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1509  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1510  if (!tidxArg)
1511  return ForallOp();
1512  assert(tidxArg.getOwner() && "unlinked block argument");
1513  auto *containingOp = tidxArg.getOwner()->getParentOp();
1514  return dyn_cast<ForallOp>(containingOp);
1515 }
1516 
1517 namespace {
1518 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1519 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1520  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1521 
1522  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1523  PatternRewriter &rewriter) const final {
1524  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1525  if (!forallOp)
1526  return failure();
1527  Value sharedOut =
1528  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1529  ->get();
1530  rewriter.modifyOpInPlace(
1531  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1532  return success();
1533  }
1534 };
1535 
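/// Fold constant dynamic lower bound, upper bound, and step operands of an
/// scf.forall op into their static counterparts, updating the operand segment
/// sizes accordingly.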
1536 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1537 public:
1538  using OpRewritePattern<ForallOp>::OpRewritePattern;
1539 
1540  LogicalResult matchAndRewrite(ForallOp op,
1541  PatternRewriter &rewriter) const override {
1542  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1543  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1544  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1545  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1546  failed(foldDynamicIndexList(mixedUpperBound)) &&
1547  failed(foldDynamicIndexList(mixedStep)))
1548  return failure();
1549 
1550  rewriter.modifyOpInPlace(op, [&]() {
1551  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1552  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1553  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1554  staticLowerBound);
1555  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1556  op.setStaticLowerBound(staticLowerBound);
1557 
1558  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1559  staticUpperBound);
1560  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1561  op.setStaticUpperBound(staticUpperBound);
1562 
1563  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1564  op.getDynamicStepMutable().assign(dynamicStep);
1565  op.setStaticStep(staticStep);
1566 
1567  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1568  rewriter.getDenseI32ArrayAttr(
1569  {static_cast<int32_t>(dynamicLowerBound.size()),
1570  static_cast<int32_t>(dynamicUpperBound.size()),
1571  static_cast<int32_t>(dynamicStep.size()),
1572  static_cast<int32_t>(op.getNumResults())}));
1573  });
1574  return success();
1575  }
1576 };
1577 
1578 /// The following canonicalization pattern folds the iter arguments of
1579 /// scf.forall op if :-
1580 /// 1. The corresponding result has zero uses.
1581 /// 2. The iter argument is NOT being modified within the loop body, i.e. it
1582 ///    has no combining store ops that use it as a destination.
1583 ///
1584 /// Example of first case :-
1585 /// INPUT:
1586 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1587 /// {
1588 /// ...
1589 /// <SOME USE OF %arg0>
1590 /// <SOME USE OF %arg1>
1591 /// <SOME USE OF %arg2>
1592 /// ...
1593 /// scf.forall.in_parallel {
1594 /// <STORE OP WITH DESTINATION %arg1>
1595 /// <STORE OP WITH DESTINATION %arg0>
1596 /// <STORE OP WITH DESTINATION %arg2>
1597 /// }
1598 /// }
1599 /// return %res#1
1600 ///
1601 /// OUTPUT:
1602 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1603 /// {
1604 /// ...
1605 /// <SOME USE OF %a>
1606 /// <SOME USE OF %new_arg0>
1607 /// <SOME USE OF %c>
1608 /// ...
1609 /// scf.forall.in_parallel {
1610 /// <STORE OP WITH DESTINATION %new_arg0>
1611 /// }
1612 /// }
1613 /// return %res
1614 ///
1615 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1616 /// scf.forall are replaced by their corresponding operands.
1617 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1618 /// of the scf.forall besides within scf.forall.in_parallel terminator,
1619 /// this canonicalization remains valid. For more details, please refer
1620 /// to :
1621 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1622 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1623 /// handles tensor.parallel_insert_slice ops only.
1624 ///
1625 /// Example of second case :-
1626 /// INPUT:
1627 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1628 /// {
1629 /// ...
1630 /// <SOME USE OF %arg0>
1631 /// <SOME USE OF %arg1>
1632 /// ...
1633 /// scf.forall.in_parallel {
1634 /// <STORE OP WITH DESTINATION %arg1>
1635 /// }
1636 /// }
1637 /// return %res#0, %res#1
1638 ///
1639 /// OUTPUT:
1640 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1641 /// {
1642 /// ...
1643 /// <SOME USE OF %a>
1644 /// <SOME USE OF %new_arg0>
1645 /// ...
1646 /// scf.forall.in_parallel {
1647 /// <STORE OP WITH DESTINATION %new_arg0>
1648 /// }
1649 /// }
1650 /// return %a, %res
1651 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1652  using OpRewritePattern<ForallOp>::OpRewritePattern;
1653 
1654  LogicalResult matchAndRewrite(ForallOp forallOp,
1655  PatternRewriter &rewriter) const final {
1656  // Step 1: For a given i-th result of scf.forall, check the following :-
1657  // a. If it has any use.
1658  // b. If the corresponding iter argument is being modified within
1659  // the loop, i.e. has at least one store op with the iter arg as
1660  // its destination operand. For this we use
1661  // ForallOp::getCombiningOps(iter_arg).
1662  //
1663  // Based on the check we maintain the following :-
1664  // a. `resultToDelete` - i-th result of scf.forall that'll be
1665  // deleted.
1666  // b. `resultToReplace` - i-th result of the old scf.forall
1667  // whose uses will be replaced by the new scf.forall.
1668  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1669  // corresponding to the i-th result with at least one use.
1670  SetVector<OpResult> resultToDelete;
1671  SmallVector<Value> resultToReplace;
1672  SmallVector<Value> newOuts;
1673  for (OpResult result : forallOp.getResults()) {
1674  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1675  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1676  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1677  resultToDelete.insert(result);
1678  } else {
1679  resultToReplace.push_back(result);
1680  newOuts.push_back(opOperand->get());
1681  }
1682  }
1683 
1684  // Return early if all results of scf.forall have at least one use and are
1685  // being modified within the loop.
1686  if (resultToDelete.empty())
1687  return failure();
1688 
1689  // Step 2: For the i-th result, do the following :-
1690  // a. Fetch the corresponding BlockArgument.
1691  // b. Look for store ops (currently tensor.parallel_insert_slice)
1692  // with the BlockArgument as its destination operand.
1693  // c. Remove the operations fetched in b.
1694  for (OpResult result : resultToDelete) {
1695  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1696  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1697  SmallVector<Operation *> combiningOps =
1698  forallOp.getCombiningOps(blockArg);
1699  for (Operation *combiningOp : combiningOps)
1700  rewriter.eraseOp(combiningOp);
1701  }
1702 
1703  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1704  // fetched earlier.
1705  auto newForallOp = scf::ForallOp::create(
1706  rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1707  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1708  forallOp.getMapping(),
1709  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1710 
1711  // Step 4. Merge the block of the old scf.forall into the newly created
1712  // scf.forall using the new set of arguments.
1713  Block *loopBody = forallOp.getBody();
1714  Block *newLoopBody = newForallOp.getBody();
1715  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1716  // Form initial new bbArg list with just the control operands of the new
1717  // scf.forall op.
1718  SmallVector<Value> newBlockArgs =
1719  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1720  [](BlockArgument b) -> Value { return b; });
1721  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1722  unsigned index = 0;
1723  // Take the new corresponding bbArg if the old bbArg was used as a
1724  // destination in the in_parallel op. For all other bbArgs, use the
1725  // corresponding init_arg from the old scf.forall op.
1726  for (OpResult result : forallOp.getResults()) {
1727  if (resultToDelete.count(result)) {
1728  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1729  } else {
1730  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1731  }
1732  }
1733  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1734 
1735  // Step 5. Replace the uses of the results of the old scf.forall with those
1736  // of the new scf.forall.
1737  for (auto &&[oldResult, newResult] :
1738  llvm::zip(resultToReplace, newForallOp->getResults()))
1739  rewriter.replaceAllUsesWith(oldResult, newResult);
1740 
1741  // Step 6. Replace the uses of those values that either have no use or are
1742  // not being modified within the loop with the corresponding
1743  // OpOperand.
1744  for (OpResult oldResult : resultToDelete)
1745  rewriter.replaceAllUsesWith(oldResult,
1746  forallOp.getTiedOpOperand(oldResult)->get());
1747  return success();
1748  }
1749 };
1750 
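/// Canonicalize scf.forall dimensions that perform at most one iteration: a
/// loop with a zero-trip dimension is replaced by its outputs, and
/// single-iteration dimensions are removed, with their induction variables
/// replaced by the corresponding lower bound.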
1751 struct ForallOpSingleOrZeroIterationDimsFolder
1752  : public OpRewritePattern<ForallOp> {
1753  using OpRewritePattern<ForallOp>::OpRewritePattern;
1754 
1755  LogicalResult matchAndRewrite(ForallOp op,
1756  PatternRewriter &rewriter) const override {
1757  // Do not fold dimensions if they are mapped to processing units.
1758  if (op.getMapping().has_value() && !op.getMapping()->empty())
1759  return failure();
1760  Location loc = op.getLoc();
1761 
1762  // Compute new loop bounds that omit all single-iteration loop dimensions.
1763  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1764  newMixedSteps;
1765  IRMapping mapping;
1766  for (auto [lb, ub, step, iv] :
1767  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1768  op.getMixedStep(), op.getInductionVars())) {
1769  auto numIterations = constantTripCount(lb, ub, step);
1770  if (numIterations.has_value()) {
1771  // Remove the loop if it performs zero iterations.
1772  if (*numIterations == 0) {
1773  rewriter.replaceOp(op, op.getOutputs());
1774  return success();
1775  }
1776  // Replace the loop induction variable by the lower bound if the loop
1777  // performs a single iteration. Otherwise, copy the loop bounds.
1778  if (*numIterations == 1) {
1779  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1780  continue;
1781  }
1782  }
1783  newMixedLowerBounds.push_back(lb);
1784  newMixedUpperBounds.push_back(ub);
1785  newMixedSteps.push_back(step);
1786  }
1787 
1788  // All of the loop dimensions perform a single iteration. Inline loop body.
1789  if (newMixedLowerBounds.empty()) {
1790  promote(rewriter, op);
1791  return success();
1792  }
1793 
1794  // Exit if none of the loop dimensions perform a single iteration.
1795  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1796  return rewriter.notifyMatchFailure(
1797  op, "no dimensions have 0 or 1 iterations");
1798  }
1799 
1800  // Replace the loop by a lower-dimensional loop.
1801  ForallOp newOp;
1802  newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1803  newMixedUpperBounds, newMixedSteps,
1804  op.getOutputs(), std::nullopt, nullptr);
1805  newOp.getBodyRegion().getBlocks().clear();
1806  // The new loop needs to keep all attributes from the old one, except for
1807  // "operandSegmentSizes" and static loop bound attributes which capture
1808  // the outdated information of the old iteration domain.
1809  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1810  newOp.getStaticLowerBoundAttrName(),
1811  newOp.getStaticUpperBoundAttrName(),
1812  newOp.getStaticStepAttrName()};
1813  for (const auto &namedAttr : op->getAttrs()) {
1814  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1815  continue;
1816  rewriter.modifyOpInPlace(newOp, [&]() {
1817  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1818  });
1819  }
1820  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1821  newOp.getRegion().begin(), mapping);
1822  rewriter.replaceOp(op, newOp.getResults());
1823  return success();
1824  }
1825 };
1826 
1827 /// Replace induction vars whose dimension has a single trip count with their lower bound.
1828 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1829  using OpRewritePattern<ForallOp>::OpRewritePattern;
1830 
1831  LogicalResult matchAndRewrite(ForallOp op,
1832  PatternRewriter &rewriter) const override {
1833  Location loc = op.getLoc();
1834  bool changed = false;
1835  for (auto [lb, ub, step, iv] :
1836  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1837  op.getMixedStep(), op.getInductionVars())) {
1838  if (iv.hasNUses(0))
1839  continue;
1840  auto numIterations = constantTripCount(lb, ub, step);
1841  if (!numIterations.has_value() || numIterations.value() != 1) {
1842  continue;
1843  }
1844  rewriter.replaceAllUsesWith(
1845  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1846  changed = true;
1847  }
1848  return success(changed);
1849  }
1850 };
1851 
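/// Fold tensor.cast ops feeding the shared_outs of an scf.forall into the loop
/// when doing so makes the loop result types more static; compensating casts
/// are created for the block arguments inside the body and for the results
/// after the loop.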
1852 struct FoldTensorCastOfOutputIntoForallOp
1853  : public OpRewritePattern<scf::ForallOp> {
1854  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1855 
1856  struct TypeCast {
1857  Type srcType;
1858  Type dstType;
1859  };
1860 
1861  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1862  PatternRewriter &rewriter) const final {
1863  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1864  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1865  for (auto en : llvm::enumerate(newOutputTensors)) {
1866  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1867  if (!castOp)
1868  continue;
1869 
1870  // Only casts that preserve static information, i.e. that will make the
1871  // loop result type "more" static than before, will be folded.
1872  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1873  castOp.getSource().getType())) {
1874  continue;
1875  }
1876 
1877  tensorCastProducers[en.index()] =
1878  TypeCast{castOp.getSource().getType(), castOp.getType()};
1879  newOutputTensors[en.index()] = castOp.getSource();
1880  }
1881 
1882  if (tensorCastProducers.empty())
1883  return failure();
1884 
1885  // Create new loop.
1886  Location loc = forallOp.getLoc();
1887  auto newForallOp = ForallOp::create(
1888  rewriter, loc, forallOp.getMixedLowerBound(),
1889  forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1890  newOutputTensors, forallOp.getMapping(),
1891  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1892  auto castBlockArgs =
1893  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1894  for (auto [index, cast] : tensorCastProducers) {
1895  Value &oldTypeBBArg = castBlockArgs[index];
1896  oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1897  cast.dstType, oldTypeBBArg);
1898  }
1899 
1900  // Move old body into new parallel loop.
1901  SmallVector<Value> ivsBlockArgs =
1902  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1903  ivsBlockArgs.append(castBlockArgs);
1904  rewriter.mergeBlocks(forallOp.getBody(),
1905  bbArgs.front().getParentBlock(), ivsBlockArgs);
1906  });
1907 
1908  // After `mergeBlocks`, the destinations in the terminator were mapped to
1909  // the old-typed tensor.cast results of the output bbArgs. The destinations
1910  // have to be updated to point to the output bbArgs directly.
1911  auto terminator = newForallOp.getTerminator();
1912  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1913  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1914  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1915  insertSliceOp.getDestMutable().assign(outputBlockArg);
1916  }
1917 
1918  // Cast results back to the original types.
1919  rewriter.setInsertionPointAfter(newForallOp);
1920  SmallVector<Value> castResults = newForallOp.getResults();
1921  for (auto &item : tensorCastProducers) {
1922  Value &oldTypeResult = castResults[item.first];
1923  oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1924  oldTypeResult);
1925  }
1926  rewriter.replaceOp(forallOp, castResults);
1927  return success();
1928  }
1929 };
1930 
1931 } // namespace
1932 
1933 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1934  MLIRContext *context) {
1935  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1936  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1937  ForallOpSingleOrZeroIterationDimsFolder,
1938  ForallOpReplaceConstantInductionVar>(context);
1939 }
1940 
1941 /// Given the region branch `point`, or the parent operation if `point` is
1942 /// the parent, return the successor regions. These are the regions that may
1943 /// be selected during the flow of control. For scf.forall, the body region
1944 /// is entered from the parent and control always returns to the parent once
1945 /// the parallel execution completes.
1946 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1947  SmallVectorImpl<RegionSuccessor> &regions) {
1948  // In accordance with the semantics of forall, its body is executed in
1949  // parallel by multiple threads. We should not expect to branch back into
1950  // the forall body after the region's execution is complete.
1951  if (point.isParent())
1952  regions.push_back(RegionSuccessor(&getRegion()));
1953  else
1954  regions.push_back(RegionSuccessor());
1955 }
1956 
1957 //===----------------------------------------------------------------------===//
1958 // InParallelOp
1959 //===----------------------------------------------------------------------===//
1960 
1961 // Build an InParallelOp with an empty body block.
1962 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1963  OpBuilder::InsertionGuard g(b);
1964  Region *bodyRegion = result.addRegion();
1965  b.createBlock(bodyRegion);
1966 }
1967 
1968 LogicalResult InParallelOp::verify() {
1969  scf::ForallOp forallOp =
1970  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1971  if (!forallOp)
1972  return this->emitOpError("expected forall op parent");
1973 
1974  // TODO: InParallelOpInterface.
1975  for (Operation &op : getRegion().front().getOperations()) {
1976  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1977  return this->emitOpError("expected only ")
1978  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1979  }
1980 
1981  // Verify that inserts are into out block arguments.
1982  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1983  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1984  if (!llvm::is_contained(regionOutArgs, dest))
1985  return op.emitOpError("may only insert into an output block argument");
1986  }
1987  return success();
1988 }
1989 
1990 void InParallelOp::print(OpAsmPrinter &p) {
1991  p << " ";
1992  p.printRegion(getRegion(),
1993  /*printEntryBlockArgs=*/false,
1994  /*printBlockTerminators=*/false);
1995  p.printOptionalAttrDict(getOperation()->getAttrs());
1996 }
1997 
1998 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1999  auto &builder = parser.getBuilder();
2000 
2001  SmallVector<OpAsmParser::Argument, 8> regionOperands;
2002  std::unique_ptr<Region> region = std::make_unique<Region>();
2003  if (parser.parseRegion(*region, regionOperands))
2004  return failure();
2005 
2006  if (region->empty())
2007  OpBuilder(builder.getContext()).createBlock(region.get());
2008  result.addRegion(std::move(region));
2009 
2010  // Parse the optional attribute list.
2011  if (parser.parseOptionalAttrDict(result.attributes))
2012  return failure();
2013  return success();
2014 }
2015 
2016 OpResult InParallelOp::getParentResult(int64_t idx) {
2017  return getOperation()->getParentOp()->getResult(idx);
2018 }
2019 
2020 SmallVector<BlockArgument> InParallelOp::getDests() {
2021  return llvm::to_vector<4>(
2022  llvm::map_range(getYieldingOps(), [](Operation &op) {
2023  // Add new ops here as needed.
2024  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
2025  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
2026  }));
2027 }
2028 
2029 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2030  return getRegion().front().getOperations();
2031 }
2032 
2033 //===----------------------------------------------------------------------===//
2034 // IfOp
2035 //===----------------------------------------------------------------------===//
2036 
2037 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
2038  assert(a && "expected non-empty operation");
2039  assert(b && "expected non-empty operation");
2040 
2041  IfOp ifOp = a->getParentOfType<IfOp>();
2042  while (ifOp) {
2043  // Check if b is inside ifOp. (We already know that a is.)
2044  if (ifOp->isProperAncestor(b))
2045  // b is contained in ifOp. a and b are in mutually exclusive branches if
2046  // they are in different blocks of ifOp.
2047  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2048  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2049  // Check next enclosing IfOp.
2050  ifOp = ifOp->getParentOfType<IfOp>();
2051  }
2052 
2053  // Could not find a common IfOp among a's and b's ancestors.
2054  return false;
2055 }
2056 
2057 LogicalResult
2058 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2059  IfOp::Adaptor adaptor,
2060  SmallVectorImpl<Type> &inferredReturnTypes) {
2061  if (adaptor.getRegions().empty())
2062  return failure();
2063  Region *r = &adaptor.getThenRegion();
2064  if (r->empty())
2065  return failure();
2066  Block &b = r->front();
2067  if (b.empty())
2068  return failure();
2069  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2070  if (!yieldOp)
2071  return failure();
2072  TypeRange types = yieldOp.getOperandTypes();
2073  llvm::append_range(inferredReturnTypes, types);
2074  return success();
2075 }
2076 
2077 void IfOp::build(OpBuilder &builder, OperationState &result,
2078  TypeRange resultTypes, Value cond) {
2079  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2080  /*addElseBlock=*/false);
2081 }
2082 
2083 void IfOp::build(OpBuilder &builder, OperationState &result,
2084  TypeRange resultTypes, Value cond, bool addThenBlock,
2085  bool addElseBlock) {
2086  assert((!addElseBlock || addThenBlock) &&
2087  "must not create else block w/o then block");
2088  result.addTypes(resultTypes);
2089  result.addOperands(cond);
2090 
2091  // Add regions and blocks.
2092  OpBuilder::InsertionGuard guard(builder);
2093  Region *thenRegion = result.addRegion();
2094  if (addThenBlock)
2095  builder.createBlock(thenRegion);
2096  Region *elseRegion = result.addRegion();
2097  if (addElseBlock)
2098  builder.createBlock(elseRegion);
2099 }
2100 
2101 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2102  bool withElseRegion) {
2103  build(builder, result, TypeRange{}, cond, withElseRegion);
2104 }
2105 
2106 void IfOp::build(OpBuilder &builder, OperationState &result,
2107  TypeRange resultTypes, Value cond, bool withElseRegion) {
2108  result.addTypes(resultTypes);
2109  result.addOperands(cond);
2110 
2111  // Build then region.
2112  OpBuilder::InsertionGuard guard(builder);
2113  Region *thenRegion = result.addRegion();
2114  builder.createBlock(thenRegion);
2115  if (resultTypes.empty())
2116  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2117 
2118  // Build else region.
2119  Region *elseRegion = result.addRegion();
2120  if (withElseRegion) {
2121  builder.createBlock(elseRegion);
2122  if (resultTypes.empty())
2123  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2124  }
2125 }
2126 
2127 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2128  function_ref<void(OpBuilder &, Location)> thenBuilder,
2129  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2130  assert(thenBuilder && "the builder callback for 'then' must be present");
2131  result.addOperands(cond);
2132 
2133  // Build then region.
2134  OpBuilder::InsertionGuard guard(builder);
2135  Region *thenRegion = result.addRegion();
2136  builder.createBlock(thenRegion);
2137  thenBuilder(builder, result.location);
2138 
2139  // Build else region.
2140  Region *elseRegion = result.addRegion();
2141  if (elseBuilder) {
2142  builder.createBlock(elseRegion);
2143  elseBuilder(builder, result.location);
2144  }
2145 
2146  // Infer result types.
2147  SmallVector<Type> inferredReturnTypes;
2148  MLIRContext *ctx = builder.getContext();
2149  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2150  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2151  /*properties=*/nullptr, result.regions,
2152  inferredReturnTypes))) {
2153  result.addTypes(inferredReturnTypes);
2154  }
2155 }
2156 
2157 LogicalResult IfOp::verify() {
2158  if (getNumResults() != 0 && getElseRegion().empty())
2159  return emitOpError("must have an else block if defining values");
2160  return success();
2161 }
2162 
2163 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2164  // Create the regions for 'then'.
2165  result.regions.reserve(2);
2166  Region *thenRegion = result.addRegion();
2167  Region *elseRegion = result.addRegion();
2168 
2169  auto &builder = parser.getBuilder();
2170  OpAsmParser::UnresolvedOperand cond;
2171  Type i1Type = builder.getIntegerType(1);
2172  if (parser.parseOperand(cond) ||
2173  parser.resolveOperand(cond, i1Type, result.operands))
2174  return failure();
2175  // Parse optional results type list.
2176  if (parser.parseOptionalArrowTypeList(result.types))
2177  return failure();
2178  // Parse the 'then' region.
2179  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2180  return failure();
2181  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2182 
2183  // If we find an 'else' keyword then parse the 'else' region.
2184  if (!parser.parseOptionalKeyword("else")) {
2185  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2186  return failure();
2187  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2188  }
2189 
2190  // Parse the optional attribute list.
2191  if (parser.parseOptionalAttrDict(result.attributes))
2192  return failure();
2193  return success();
2194 }
2195 
2196 void IfOp::print(OpAsmPrinter &p) {
2197  bool printBlockTerminators = false;
2198 
2199  p << " " << getCondition();
2200  if (!getResults().empty()) {
2201  p << " -> (" << getResultTypes() << ")";
2202  // Print yield explicitly if the op defines values.
2203  printBlockTerminators = true;
2204  }
2205  p << ' ';
2206  p.printRegion(getThenRegion(),
2207  /*printEntryBlockArgs=*/false,
2208  /*printBlockTerminators=*/printBlockTerminators);
2209 
2210  // Print the 'else' region if it exists and has a block.
2211  auto &elseRegion = getElseRegion();
2212  if (!elseRegion.empty()) {
2213  p << " else ";
2214  p.printRegion(elseRegion,
2215  /*printEntryBlockArgs=*/false,
2216  /*printBlockTerminators=*/printBlockTerminators);
2217  }
2218 
2219  p.printOptionalAttrDict((*this)->getAttrs());
2220 }
2221 
2222 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2223  SmallVectorImpl<RegionSuccessor> &regions) {
2224  // The `then` and the `else` region branch back to the parent operation.
2225  if (!point.isParent()) {
2226  regions.push_back(RegionSuccessor(getResults()));
2227  return;
2228  }
2229 
2230  regions.push_back(RegionSuccessor(&getThenRegion()));
2231 
2232  // Don't consider the else region if it is empty.
2233  Region *elseRegion = &this->getElseRegion();
2234  if (elseRegion->empty())
2235  regions.push_back(RegionSuccessor());
2236  else
2237  regions.push_back(RegionSuccessor(elseRegion));
2238 }
2239 
2240 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2241  SmallVectorImpl<RegionSuccessor> &regions) {
2242  FoldAdaptor adaptor(operands, *this);
2243  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2244  if (!boolAttr || boolAttr.getValue())
2245  regions.emplace_back(&getThenRegion());
2246 
2247  // If the else region is empty, execution continues after the parent op.
2248  if (!boolAttr || !boolAttr.getValue()) {
2249  if (!getElseRegion().empty())
2250  regions.emplace_back(&getElseRegion());
2251  else
2252  regions.emplace_back(getResults());
2253  }
2254 }
2255 
2256 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2257  SmallVectorImpl<OpFoldResult> &results) {
2258  // if (!c) then A() else B() -> if c then B() else A()
2259  if (getElseRegion().empty())
2260  return failure();
2261 
2262  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2263  if (!xorStmt)
2264  return failure();
2265 
2266  if (!matchPattern(xorStmt.getRhs(), m_One()))
2267  return failure();
2268 
2269  getConditionMutable().assign(xorStmt.getLhs());
2270  Block *thenBlock = &getThenRegion().front();
2271  // It would be nicer to use iplist::swap, but that has no implemented
2272  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2273  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2274  getElseRegion().getBlocks());
2275  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2276  getThenRegion().getBlocks(), thenBlock);
2277  return success();
2278 }
2279 
2280 void IfOp::getRegionInvocationBounds(
2281  ArrayRef<Attribute> operands,
2282  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2283  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2284  // If the condition is known, then one region is known to be executed once
2285  // and the other zero times.
2286  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2287  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2288  } else {
2289  // Non-constant condition. Each region may be executed 0 or 1 times.
2290  invocationBounds.assign(2, {0, 1});
2291  }
2292 }
2293 
2294 namespace {
2295 // Pattern to remove unused IfOp results.
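// E.g. if only %r#0 of `%r:2 = scf.if %c -> (i32, i32) {...}` has uses, the op
// is rebuilt as `%r = scf.if %c -> (i32) {...}` and both yields are trimmed to
// the single used value (a sketch with hypothetical names).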
2296 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2297  using OpRewritePattern<IfOp>::OpRewritePattern;
2298 
2299  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2300  PatternRewriter &rewriter) const {
2301  // Move all operations to the destination block.
2302  rewriter.mergeBlocks(source, dest);
2303  // Replace the yield op by one that returns only the used values.
2304  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2305  SmallVector<Value, 4> usedOperands;
2306  llvm::transform(usedResults, std::back_inserter(usedOperands),
2307  [&](OpResult result) {
2308  return yieldOp.getOperand(result.getResultNumber());
2309  });
2310  rewriter.modifyOpInPlace(yieldOp,
2311  [&]() { yieldOp->setOperands(usedOperands); });
2312  }
2313 
2314  LogicalResult matchAndRewrite(IfOp op,
2315  PatternRewriter &rewriter) const override {
2316  // Compute the list of used results.
2317  SmallVector<OpResult, 4> usedResults;
2318  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2319  [](OpResult result) { return !result.use_empty(); });
2320 
2321  // Replace the operation if only a subset of its results have uses.
2322  if (usedResults.size() == op.getNumResults())
2323  return failure();
2324 
2325  // Compute the result types of the replacement operation.
2326  SmallVector<Type, 4> newTypes;
2327  llvm::transform(usedResults, std::back_inserter(newTypes),
2328  [](OpResult result) { return result.getType(); });
2329 
2330  // Create a replacement operation with empty then and else regions.
2331  auto newOp =
2332  IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2333  rewriter.createBlock(&newOp.getThenRegion());
2334  rewriter.createBlock(&newOp.getElseRegion());
2335 
2336  // Move the bodies and replace the terminators (note there is a then and
2337  // an else region since the operation returns results).
2338  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2339  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2340 
2341  // Replace the operation by the new one.
2342  SmallVector<Value, 4> repResults(op.getNumResults());
2343  for (const auto &en : llvm::enumerate(usedResults))
2344  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2345  rewriter.replaceOp(op, repResults);
2346  return success();
2347  }
2348 };
2349 
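// Fold an scf.if whose condition is a constant: inline the region that is
// statically known to execute, or erase the op when the condition is false
// and there is no else region.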
2350 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2351  using OpRewritePattern<IfOp>::OpRewritePattern;
2352 
2353  LogicalResult matchAndRewrite(IfOp op,
2354  PatternRewriter &rewriter) const override {
2355  BoolAttr condition;
2356  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2357  return failure();
2358 
2359  if (condition.getValue())
2360  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2361  else if (!op.getElseRegion().empty())
2362  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2363  else
2364  rewriter.eraseOp(op);
2365 
2366  return success();
2367  }
2368 };
2369 
2370 /// Hoist any yielded results whose operands are defined outside
2371 /// the if, to a select instruction.
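///
/// A sketch with hypothetical value names, where %a and %b are assumed to be
/// defined above the scf.if:
///
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
///
/// has its yielded values hoisted into
///
///   %r = arith.select %cond, %a, %b : i32
///
/// leaving a result-less scf.if behind for other patterns to clean up.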
2372 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2373  using OpRewritePattern<IfOp>::OpRewritePattern;
2374 
2375  LogicalResult matchAndRewrite(IfOp op,
2376  PatternRewriter &rewriter) const override {
2377  if (op->getNumResults() == 0)
2378  return failure();
2379 
2380  auto cond = op.getCondition();
2381  auto thenYieldArgs = op.thenYield().getOperands();
2382  auto elseYieldArgs = op.elseYield().getOperands();
2383 
2384  SmallVector<Type> nonHoistable;
2385  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2386  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2387  &op.getElseRegion() == falseVal.getParentRegion())
2388  nonHoistable.push_back(trueVal.getType());
2389  }
2390  // Early exit if there aren't any yielded values we can
2391  // hoist outside the if.
2392  if (nonHoistable.size() == op->getNumResults())
2393  return failure();
2394 
2395  IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2396  /*withElseRegion=*/false);
2397  if (replacement.thenBlock())
2398  rewriter.eraseBlock(replacement.thenBlock());
2399  replacement.getThenRegion().takeBody(op.getThenRegion());
2400  replacement.getElseRegion().takeBody(op.getElseRegion());
2401 
2402  SmallVector<Value> results(op->getNumResults());
2403  assert(thenYieldArgs.size() == results.size());
2404  assert(elseYieldArgs.size() == results.size());
2405 
2406  SmallVector<Value> trueYields;
2407  SmallVector<Value> falseYields;
2408  rewriter.setInsertionPoint(replacement);
2409  for (const auto &it :
2410  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2411  Value trueVal = std::get<0>(it.value());
2412  Value falseVal = std::get<1>(it.value());
2413  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2414  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2415  results[it.index()] = replacement.getResult(trueYields.size());
2416  trueYields.push_back(trueVal);
2417  falseYields.push_back(falseVal);
2418  } else if (trueVal == falseVal)
2419  results[it.index()] = trueVal;
2420  else
2421  results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2422  cond, trueVal, falseVal);
2423  }
2424 
2425  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2426  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2427 
2428  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2429  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2430 
2431  rewriter.replaceOp(op, results);
2432  return success();
2433  }
2434 };
2435 
2436 /// Remove any statements from an if that are equivalent to the condition
2437 /// or its negation. For example:
2438 ///
2439 /// %res:2 = scf.if %cmp {
2440 /// yield something(), true
2441 /// } else {
2442 /// yield something2(), false
2443 /// }
2444 /// print(%res#1)
2445 ///
2446 /// becomes
2447 /// %res = scf.if %cmp {
2448 /// yield something()
2449 /// } else {
2450 /// yield something2()
2451 /// }
2452 /// print(%cmp)
2453 ///
2454 /// Additionally if both branches yield the same value, replace all uses
2455 /// of the result with the yielded value.
2456 ///
2457 /// %res:2 = scf.if %cmp {
2458 /// yield something(), %arg1
2459 /// } else {
2460 /// yield something2(), %arg1
2461 /// }
2462 /// print(%res#1)
2463 ///
2464 /// becomes
2465 /// %res = scf.if %cmp {
2466 /// yield something()
2467 /// } else {
2468 /// yield something2()
2469 /// }
2470 /// print(%arg1)
2471 ///
2472 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2473  using OpRewritePattern<IfOp>::OpRewritePattern;
2474 
2475  LogicalResult matchAndRewrite(IfOp op,
2476  PatternRewriter &rewriter) const override {
2477  // Early exit if there are no results that could be replaced.
2478  if (op.getNumResults() == 0)
2479  return failure();
2480 
2481  auto trueYield =
2482  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2483  auto falseYield =
2484  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2485 
2486  rewriter.setInsertionPoint(op->getBlock(),
2487  op.getOperation()->getIterator());
2488  bool changed = false;
2489  Type i1Ty = rewriter.getI1Type();
2490  for (auto [trueResult, falseResult, opResult] :
2491  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2492  op.getResults())) {
2493  if (trueResult == falseResult) {
2494  if (!opResult.use_empty()) {
2495  opResult.replaceAllUsesWith(trueResult);
2496  changed = true;
2497  }
2498  continue;
2499  }
2500 
2501  BoolAttr trueYield, falseYield;
2502  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2503  !matchPattern(falseResult, m_Constant(&falseYield)))
2504  continue;
2505 
2506  bool trueVal = trueYield.getValue();
2507  bool falseVal = falseYield.getValue();
2508  if (!trueVal && falseVal) {
2509  if (!opResult.use_empty()) {
2510  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2511  Value notCond = arith::XOrIOp::create(
2512  rewriter, op.getLoc(), op.getCondition(),
2513  constDialect
2514  ->materializeConstant(rewriter,
2515  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2516  op.getLoc())
2517  ->getResult(0));
2518  opResult.replaceAllUsesWith(notCond);
2519  changed = true;
2520  }
2521  }
2522  if (trueVal && !falseVal) {
2523  if (!opResult.use_empty()) {
2524  opResult.replaceAllUsesWith(op.getCondition());
2525  changed = true;
2526  }
2527  }
2528  }
2529  return success(changed);
2530  }
2531 };
2532 
2533 /// Merge any consecutive scf.if's with the same condition.
2534 ///
2535 /// scf.if %cond {
2536 /// firstCodeTrue();...
2537 /// } else {
2538 /// firstCodeFalse();...
2539 /// }
2540 /// %res = scf.if %cond {
2541 /// secondCodeTrue();...
2542 /// } else {
2543 /// secondCodeFalse();...
2544 /// }
2545 ///
2546 /// becomes
2547 /// %res = scf.if %cmp {
2548 /// firstCodeTrue();...
2549 /// secondCodeTrue();...
2550 /// } else {
2551 /// firstCodeFalse();...
2552 /// secondCodeFalse();...
2553 /// }
2554 struct CombineIfs : public OpRewritePattern<IfOp> {
2555  using OpRewritePattern<IfOp>::OpRewritePattern;
2556 
2557  LogicalResult matchAndRewrite(IfOp nextIf,
2558  PatternRewriter &rewriter) const override {
2559  Block *parent = nextIf->getBlock();
2560  if (nextIf == &parent->front())
2561  return failure();
2562 
2563  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2564  if (!prevIf)
2565  return failure();
2566 
2567  // Determine the logical then/else blocks when prevIf's
2568  // condition is used. Null means the block does not exist
2569  // in that case (e.g. empty else). If neither of these
2570  // is set, the two conditions cannot be compared.
2571  Block *nextThen = nullptr;
2572  Block *nextElse = nullptr;
2573  if (nextIf.getCondition() == prevIf.getCondition()) {
2574  nextThen = nextIf.thenBlock();
2575  if (!nextIf.getElseRegion().empty())
2576  nextElse = nextIf.elseBlock();
2577  }
2578  if (arith::XOrIOp notv =
2579  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2580  if (notv.getLhs() == prevIf.getCondition() &&
2581  matchPattern(notv.getRhs(), m_One())) {
2582  nextElse = nextIf.thenBlock();
2583  if (!nextIf.getElseRegion().empty())
2584  nextThen = nextIf.elseBlock();
2585  }
2586  }
2587  if (arith::XOrIOp notv =
2588  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2589  if (notv.getLhs() == nextIf.getCondition() &&
2590  matchPattern(notv.getRhs(), m_One())) {
2591  nextElse = nextIf.thenBlock();
2592  if (!nextIf.getElseRegion().empty())
2593  nextThen = nextIf.elseBlock();
2594  }
2595  }
2596 
2597  if (!nextThen && !nextElse)
2598  return failure();
2599 
2600  SmallVector<Value> prevElseYielded;
2601  if (!prevIf.getElseRegion().empty())
2602  prevElseYielded = prevIf.elseYield().getOperands();
2603  // Replace all uses of prevIf's results within nextIf with the
2604  // corresponding yielded values.
2605  for (auto it : llvm::zip(prevIf.getResults(),
2606  prevIf.thenYield().getOperands(), prevElseYielded))
2607  for (OpOperand &use :
2608  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2609  if (nextThen && nextThen->getParent()->isAncestor(
2610  use.getOwner()->getParentRegion())) {
2611  rewriter.startOpModification(use.getOwner());
2612  use.set(std::get<1>(it));
2613  rewriter.finalizeOpModification(use.getOwner());
2614  } else if (nextElse && nextElse->getParent()->isAncestor(
2615  use.getOwner()->getParentRegion())) {
2616  rewriter.startOpModification(use.getOwner());
2617  use.set(std::get<2>(it));
2618  rewriter.finalizeOpModification(use.getOwner());
2619  }
2620  }
2621 
2622  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2623  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2624 
2625  IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2626  prevIf.getCondition(), /*hasElse=*/false);
2627  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2628 
2629  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2630  combinedIf.getThenRegion(),
2631  combinedIf.getThenRegion().begin());
2632 
2633  if (nextThen) {
2634  YieldOp thenYield = combinedIf.thenYield();
2635  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2636  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2637  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2638 
2639  SmallVector<Value> mergedYields(thenYield.getOperands());
2640  llvm::append_range(mergedYields, thenYield2.getOperands());
2641  YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2642  rewriter.eraseOp(thenYield);
2643  rewriter.eraseOp(thenYield2);
2644  }
2645 
2646  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2647  combinedIf.getElseRegion(),
2648  combinedIf.getElseRegion().begin());
2649 
2650  if (nextElse) {
2651  if (combinedIf.getElseRegion().empty()) {
2652  rewriter.inlineRegionBefore(*nextElse->getParent(),
2653  combinedIf.getElseRegion(),
2654  combinedIf.getElseRegion().begin());
2655  } else {
2656  YieldOp elseYield = combinedIf.elseYield();
2657  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2658  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2659 
2660  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2661 
2662  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2663  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2664 
2665  YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2666  rewriter.eraseOp(elseYield);
2667  rewriter.eraseOp(elseYield2);
2668  }
2669  }
2670 
2671  SmallVector<Value> prevValues;
2672  SmallVector<Value> nextValues;
2673  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2674  if (pair.index() < prevIf.getNumResults())
2675  prevValues.push_back(pair.value());
2676  else
2677  nextValues.push_back(pair.value());
2678  }
2679  rewriter.replaceOp(prevIf, prevValues);
2680  rewriter.replaceOp(nextIf, nextValues);
2681  return success();
2682  }
2683 };
2684 
2685 /// Pattern to remove an empty else branch.
2686 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2687  using OpRewritePattern<IfOp>::OpRewritePattern;
2688 
2689  LogicalResult matchAndRewrite(IfOp ifOp,
2690  PatternRewriter &rewriter) const override {
2691  // Cannot remove else region when there are operation results.
2692  if (ifOp.getNumResults())
2693  return failure();
2694  Block *elseBlock = ifOp.elseBlock();
2695  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2696  return failure();
2697  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2698  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2699  newIfOp.getThenRegion().begin());
2700  rewriter.eraseOp(ifOp);
2701  return success();
2702  }
2703 };
2704 
2705 /// Convert nested `if`s into `arith.andi` + single `if`.
2706 ///
2707 /// scf.if %arg0 {
2708 /// scf.if %arg1 {
2709 /// ...
2710 /// scf.yield
2711 /// }
2712 /// scf.yield
2713 /// }
2714 /// becomes
2715 ///
2716 /// %0 = arith.andi %arg0, %arg1
2717 /// scf.if %0 {
2718 /// ...
2719 /// scf.yield
2720 /// }
2721 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2722  using OpRewritePattern<IfOp>::OpRewritePattern;
2723 
2724  LogicalResult matchAndRewrite(IfOp op,
2725  PatternRewriter &rewriter) const override {
2726  auto nestedOps = op.thenBlock()->without_terminator();
2727  // Nested `if` must be the only op in block.
2728  if (!llvm::hasSingleElement(nestedOps))
2729  return failure();
2730 
2731  // If there is an else block, it may only contain a yield.
2732  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2733  return failure();
2734 
2735  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2736  if (!nestedIf)
2737  return failure();
2738 
2739  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2740  return failure();
2741 
2742  SmallVector<Value> thenYield(op.thenYield().getOperands());
2743  SmallVector<Value> elseYield;
2744  if (op.elseBlock())
2745  llvm::append_range(elseYield, op.elseYield().getOperands());
2746 
2747  // A list of indices for which we should upgrade the value yielded
2748  // in the else to a select.
2749  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2750 
2751  // If the outer scf.if yields a value produced by the inner scf.if,
2752  // only permit combining if the value yielded when the condition
2753  // is false in the outer scf.if is the same value yielded when the
2754  // inner scf.if condition is false.
2755  // Note that the array access to elseYield will not go out of bounds
2756  // since it must have the same length as thenYield, since they both
2757  // come from the same scf.if.
2758  for (const auto &tup : llvm::enumerate(thenYield)) {
2759  if (tup.value().getDefiningOp() == nestedIf) {
2760  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2761  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2762  elseYield[tup.index()]) {
2763  return failure();
2764  }
2765  // If the correctness test passes, we will yield the
2766  // corresponding value from the inner scf.if.
2767  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2768  continue;
2769  }
2770 
2771  // Otherwise, we need to ensure the else block of the combined
2772  // condition still returns the same value when the outer condition is
2773  // true and the inner condition is false. This can be accomplished if
2774  // the then value is defined outside the outer scf.if and we replace the
2775  // value with a select that considers just the outer condition. Since
2776  // the else region contains just the yield, its yielded value is
2777  // defined outside the scf.if, by definition.
2778 
2779  // If the then value is defined within the scf.if, bail.
2780  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2781  return failure();
2782  }
2783  elseYieldsToUpgradeToSelect.push_back(tup.index());
2784  }
2785 
2786  Location loc = op.getLoc();
2787  Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2788  nestedIf.getCondition());
2789  auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2790  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2791 
2792  SmallVector<Value> results;
2793  llvm::append_range(results, newIf.getResults());
2794  rewriter.setInsertionPoint(newIf);
2795 
2796  for (auto idx : elseYieldsToUpgradeToSelect)
2797  results[idx] =
2798  arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2799  thenYield[idx], elseYield[idx]);
2800 
2801  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2802  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2803  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2804  if (!elseYield.empty()) {
2805  rewriter.createBlock(&newIf.getElseRegion());
2806  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2807  YieldOp::create(rewriter, loc, elseYield);
2808  }
2809  rewriter.replaceOp(op, results);
2810  return success();
2811  }
2812 };
2813 
2814 } // namespace
2815 
2816 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2817  MLIRContext *context) {
2818  results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
2819  RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
2820  ReplaceIfYieldWithConditionOrValue>(context);
2821 }
2822 
2823 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2824 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2825 Block *IfOp::elseBlock() {
2826  Region &r = getElseRegion();
2827  if (r.empty())
2828  return nullptr;
2829  return &r.back();
2830 }
2831 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2832 
2833 //===----------------------------------------------------------------------===//
2834 // ParallelOp
2835 //===----------------------------------------------------------------------===//
2836 
2837 void ParallelOp::build(
2838  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2839  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2840  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2841  bodyBuilderFn) {
2842  result.addOperands(lowerBounds);
2843  result.addOperands(upperBounds);
2844  result.addOperands(steps);
2845  result.addOperands(initVals);
2846  result.addAttribute(
2847  ParallelOp::getOperandSegmentSizeAttr(),
2848  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2849  static_cast<int32_t>(upperBounds.size()),
2850  static_cast<int32_t>(steps.size()),
2851  static_cast<int32_t>(initVals.size())}));
2852  result.addTypes(initVals.getTypes());
2853 
2854  OpBuilder::InsertionGuard guard(builder);
2855  unsigned numIVs = steps.size();
2856  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2857  SmallVector<Location, 8> argLocs(numIVs, result.location);
2858  Region *bodyRegion = result.addRegion();
2859  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2860 
2861  if (bodyBuilderFn) {
2862  builder.setInsertionPointToStart(bodyBlock);
2863  bodyBuilderFn(builder, result.location,
2864  bodyBlock->getArguments().take_front(numIVs),
2865  bodyBlock->getArguments().drop_front(numIVs));
2866  }
2867  // Add terminator only if there are no reductions.
2868  if (initVals.empty())
2869  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2870 }
2871 
2872 void ParallelOp::build(
2873  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2874  ValueRange upperBounds, ValueRange steps,
2875  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2876  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2877  // we don't capture a reference to a temporary by constructing the lambda at
2878  // function level.
2879  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2880  Location nestedLoc, ValueRange ivs,
2881  ValueRange) {
2882  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2883  };
2884  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2885  if (bodyBuilderFn)
2886  wrapper = wrappedBuilderFn;
2887 
2888  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2889  wrapper);
2890 }
2891 
2892 LogicalResult ParallelOp::verify() {
2893  // Check that there is at least one value in lowerBound, upperBound and step.
2894  // It is sufficient to test only step, because it is already ensured that
2895  // lowerBound, upperBound and step have the same number of elements.
2896  Operation::operand_range stepValues = getStep();
2897  if (stepValues.empty())
2898  return emitOpError(
2899  "needs at least one tuple element for lowerBound, upperBound and step");
2900 
2901  // Check whether all constant step values are positive.
2902  for (Value stepValue : stepValues)
2903  if (auto cst = getConstantIntValue(stepValue))
2904  if (*cst <= 0)
2905  return emitOpError("constant step operand must be positive");
2906 
2907  // Check that the body defines the same number of block arguments as the
2908  // number of tuple elements in step.
2909  Block *body = getBody();
2910  if (body->getNumArguments() != stepValues.size())
2911  return emitOpError() << "expects the same number of induction variables: "
2912  << body->getNumArguments()
2913  << " as bound and step values: " << stepValues.size();
2914  for (auto arg : body->getArguments())
2915  if (!arg.getType().isIndex())
2916  return emitOpError(
2917  "expects arguments for the induction variable to be of index type");
2918 
2919  // Check that the terminator is an scf.reduce op.
2920  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2921  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2922  if (!reduceOp)
2923  return failure();
2924 
2925  // Check that the number of results is the same as the number of reductions.
2926  auto resultsSize = getResults().size();
2927  auto reductionsSize = reduceOp.getReductions().size();
2928  auto initValsSize = getInitVals().size();
2929  if (resultsSize != reductionsSize)
2930  return emitOpError() << "expects number of results: " << resultsSize
2931  << " to be the same as number of reductions: "
2932  << reductionsSize;
2933  if (resultsSize != initValsSize)
2934  return emitOpError() << "expects number of results: " << resultsSize
2935  << " to be the same as number of initial values: "
2936  << initValsSize;
2937 
2938  // Check that the types of the results and reductions are the same.
2939  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2940  auto resultType = getOperation()->getResult(i).getType();
2941  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2942  if (resultType != reductionOperandType)
2943  return reduceOp.emitOpError()
2944  << "expects type of " << i
2945  << "-th reduction operand: " << reductionOperandType
2946  << " to be the same as the " << i
2947  << "-th result type: " << resultType;
2948  }
2949  return success();
2950 }
2951 
2952 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2953  auto &builder = parser.getBuilder();
2954  // Parse an opening `(` followed by induction variables followed by `)`
2955  SmallVector<OpAsmParser::Argument, 4> ivs;
2956  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2957  return failure();
2958 
2959  // Parse loop bounds.
2960  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2961  if (parser.parseEqual() ||
2962  parser.parseOperandList(lower, ivs.size(),
2963  OpAsmParser::Delimiter::Paren) ||
2964  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2965  return failure();
2966 
2967  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2968  if (parser.parseKeyword("to") ||
2969  parser.parseOperandList(upper, ivs.size(),
2970  OpAsmParser::Delimiter::Paren) ||
2971  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2972  return failure();
2973 
2974  // Parse step values.
2975  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2976  if (parser.parseKeyword("step") ||
2977  parser.parseOperandList(steps, ivs.size(),
2978  OpAsmParser::Delimiter::Paren) ||
2979  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2980  return failure();
2981 
2982  // Parse init values.
2983  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2984  if (succeeded(parser.parseOptionalKeyword("init"))) {
2985  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2986  return failure();
2987  }
2988 
2989  // Parse optional results in case there is a reduce.
2990  if (parser.parseOptionalArrowTypeList(result.types))
2991  return failure();
2992 
2993  // Now parse the body.
2994  Region *body = result.addRegion();
2995  for (auto &iv : ivs)
2996  iv.type = builder.getIndexType();
2997  if (parser.parseRegion(*body, ivs))
2998  return failure();
2999 
3000  // Set `operandSegmentSizes` attribute.
3001  result.addAttribute(
3002  ParallelOp::getOperandSegmentSizeAttr(),
3003  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3004  static_cast<int32_t>(upper.size()),
3005  static_cast<int32_t>(steps.size()),
3006  static_cast<int32_t>(initVals.size())}));
3007 
3008  // Parse attributes.
3009  if (parser.parseOptionalAttrDict(result.attributes) ||
3010  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3011  result.operands))
3012  return failure();
3013 
3014  // Add a terminator if none was parsed.
3015  ParallelOp::ensureTerminator(*body, builder, result.location);
3016  return success();
3017 }
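// For illustration only (the SSA names below are hypothetical), the parser
// above and the printer that follows handle the textual form:
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1)
//       init (%init) -> f32 {
//     ...
//     scf.reduce(%operand : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %sum = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %sum : f32
//     }
//   }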
3018 
3019 void ParallelOp::print(OpAsmPrinter &p) {
3020  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3021  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3022  if (!getInitVals().empty())
3023  p << " init (" << getInitVals() << ")";
3024  p.printOptionalArrowTypeList(getResultTypes());
3025  p << ' ';
3026  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3027  p.printOptionalAttrDictWithKeyword(
3028  (*this)->getAttrs(),
3029  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3030 }
3031 
3032 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3033 
3034 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3035  return SmallVector<Value>{getBody()->getArguments()};
3036 }
3037 
3038 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3039  return getLowerBound();
3040 }
3041 
3042 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3043  return getUpperBound();
3044 }
3045 
3046 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3047  return getStep();
3048 }
3049 
3050 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3051  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3052  if (!ivArg)
3053  return ParallelOp();
3054  assert(ivArg.getOwner() && "unlinked block argument");
3055  auto *containingOp = ivArg.getOwner()->getParentOp();
3056  return dyn_cast<ParallelOp>(containingOp);
3057 }
3058 
3059 namespace {
3060 // Collapse loop dimensions that perform a single iteration.
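// For illustration (the constants are hypothetical): a dimension whose bounds
// and step imply exactly one iteration, such as %i in
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %cN) step (%c1, %c1) { ... }
// is dropped and uses of %i are remapped to its lower bound %c0; a dimension
// known to run zero times folds the whole op away to its init values.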
3061 struct ParallelOpSingleOrZeroIterationDimsFolder
3062  : public OpRewritePattern<ParallelOp> {
3063  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3064 
3065  LogicalResult matchAndRewrite(ParallelOp op,
3066  PatternRewriter &rewriter) const override {
3067  Location loc = op.getLoc();
3068 
3069  // Compute new loop bounds that omit all single-iteration loop dimensions.
3070  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3071  IRMapping mapping;
3072  for (auto [lb, ub, step, iv] :
3073  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3074  op.getInductionVars())) {
3075  auto numIterations = constantTripCount(lb, ub, step);
3076  if (numIterations.has_value()) {
3077  // Remove the loop if it performs zero iterations.
3078  if (*numIterations == 0) {
3079  rewriter.replaceOp(op, op.getInitVals());
3080  return success();
3081  }
3082  // Replace the loop induction variable by the lower bound if the loop
3083  // performs a single iteration. Otherwise, copy the loop bounds.
3084  if (*numIterations == 1) {
3085  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3086  continue;
3087  }
3088  }
3089  newLowerBounds.push_back(lb);
3090  newUpperBounds.push_back(ub);
3091  newSteps.push_back(step);
3092  }
3093  // Exit if none of the loop dimensions perform a single iteration.
3094  if (newLowerBounds.size() == op.getLowerBound().size())
3095  return failure();
3096 
3097  if (newLowerBounds.empty()) {
3098  // All of the loop dimensions perform a single iteration. Inline the
3099  // loop body and the nested ReduceOps.
3100  SmallVector<Value> results;
3101  results.reserve(op.getInitVals().size());
3102  for (auto &bodyOp : op.getBody()->without_terminator())
3103  rewriter.clone(bodyOp, mapping);
3104  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3105  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3106  Block &reduceBlock = reduceOp.getReductions()[i].front();
3107  auto initValIndex = results.size();
3108  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3109  mapping.map(reduceBlock.getArgument(1),
3110  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3111  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3112  rewriter.clone(reduceBodyOp, mapping);
3113 
3114  auto result = mapping.lookupOrDefault(
3115  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3116  results.push_back(result);
3117  }
3118 
3119  rewriter.replaceOp(op, results);
3120  return success();
3121  }
3122  // Replace the parallel loop by lower-dimensional parallel loop.
3123  auto newOp =
3124  ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3125  newUpperBounds, newSteps, op.getInitVals(), nullptr);
3126  // Erase the empty block that was inserted by the builder.
3127  rewriter.eraseBlock(newOp.getBody());
3128  // Clone the loop body and remap the block arguments of the collapsed loops
3129  // (inlining does not support a cancellable block argument mapping).
3130  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3131  newOp.getRegion().begin(), mapping);
3132  rewriter.replaceOp(op, newOp.getResults());
3133  return success();
3134  }
3135 };
3136 
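// Merge a perfectly nested pair of scf.parallel ops into a single op. This
// applies only when the outer body contains nothing but the inner loop,
// neither loop carries reductions, and the inner bounds and steps do not use
// the outer induction variables. For illustration (hypothetical values):
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) { ... }
//   }
// becomes
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) { ... }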
3137 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3138  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3139 
3140  LogicalResult matchAndRewrite(ParallelOp op,
3141  PatternRewriter &rewriter) const override {
3142  Block &outerBody = *op.getBody();
3143  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3144  return failure();
3145 
3146  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3147  if (!innerOp)
3148  return failure();
3149 
3150  for (auto val : outerBody.getArguments())
3151  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3152  llvm::is_contained(innerOp.getUpperBound(), val) ||
3153  llvm::is_contained(innerOp.getStep(), val))
3154  return failure();
3155 
3156  // Reductions are not supported yet.
3157  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3158  return failure();
3159 
3160  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3161  ValueRange iterVals, ValueRange) {
3162  Block &innerBody = *innerOp.getBody();
3163  assert(iterVals.size() ==
3164  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3165  IRMapping mapping;
3166  mapping.map(outerBody.getArguments(),
3167  iterVals.take_front(outerBody.getNumArguments()));
3168  mapping.map(innerBody.getArguments(),
3169  iterVals.take_back(innerBody.getNumArguments()));
3170  for (Operation &op : innerBody.without_terminator())
3171  builder.clone(op, mapping);
3172  };
3173 
3174  auto concatValues = [](const auto &first, const auto &second) {
3175  SmallVector<Value> ret;
3176  ret.reserve(first.size() + second.size());
3177  ret.assign(first.begin(), first.end());
3178  ret.append(second.begin(), second.end());
3179  return ret;
3180  };
3181 
3182  auto newLowerBounds =
3183  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3184  auto newUpperBounds =
3185  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3186  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3187 
3188  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3189  newSteps, ValueRange(),
3190  bodyBuilder);
3191  return success();
3192  }
3193 };
3194 
3195 } // namespace
3196 
3197 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3198  MLIRContext *context) {
3199  results
3200  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3201  context);
3202 }
3203 
3204 /// Return the successor regions for the given branch point. These are the
3205 /// regions that may be selected during the flow of control.
3209 void ParallelOp::getSuccessorRegions(
3210  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3211  // Both the operation itself and the region may be branching into the body or
3212  // back into the operation itself. It is possible for the loop not to enter the
3213  // body.
3214  regions.push_back(RegionSuccessor(&getRegion()));
3215  regions.push_back(RegionSuccessor());
3216 }
3217 
3218 //===----------------------------------------------------------------------===//
3219 // ReduceOp
3220 //===----------------------------------------------------------------------===//
3221 
3222 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3223 
3224 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3225  ValueRange operands) {
3226  result.addOperands(operands);
3227  for (Value v : operands) {
3228  OpBuilder::InsertionGuard guard(builder);
3229  Region *bodyRegion = result.addRegion();
3230  builder.createBlock(bodyRegion, {},
3231  ArrayRef<Type>{v.getType(), v.getType()},
3232  {result.location, result.location});
3233  }
3234 }
3235 
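// For reference (the f32 element type is just an example), the builders above
// create one region per operand, and the verifier below expects each region to
// have the shape:
//   scf.reduce(%operand : f32) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %res = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %res : f32
//   }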
3236 LogicalResult ReduceOp::verifyRegions() {
3237  // The region of a ReduceOp has two arguments of the same type as its
3238  // corresponding operand.
3239  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3240  auto type = getOperands()[i].getType();
3241  Block &block = getReductions()[i].front();
3242  if (block.empty())
3243  return emitOpError() << i << "-th reduction has an empty body";
3244  if (block.getNumArguments() != 2 ||
3245  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3246  return arg.getType() != type;
3247  }))
3248  return emitOpError() << "expected two block arguments with type " << type
3249  << " in the " << i << "-th reduction region";
3250 
3251  // Check that the block is terminated by a ReduceReturnOp.
3252  if (!isa<ReduceReturnOp>(block.getTerminator()))
3253  return emitOpError("reduction bodies must be terminated with an "
3254  "'scf.reduce.return' op");
3255  }
3256 
3257  return success();
3258 }
3259 
3260 MutableOperandRange
3261 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3262  // No operands are forwarded to the next iteration.
3263  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3264 }
3265 
3266 //===----------------------------------------------------------------------===//
3267 // ReduceReturnOp
3268 //===----------------------------------------------------------------------===//
3269 
3270 LogicalResult ReduceReturnOp::verify() {
3271  // The type of the return value should be the same type as the types of the
3272  // block arguments of the reduction body.
3273  Block *reductionBody = getOperation()->getBlock();
3274  // Should already be verified by an op trait.
3275  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3276  Type expectedResultType = reductionBody->getArgument(0).getType();
3277  if (expectedResultType != getResult().getType())
3278  return emitOpError() << "must have type " << expectedResultType
3279  << " (the type of the reduction inputs)";
3280  return success();
3281 }
3282 
3283 //===----------------------------------------------------------------------===//
3284 // WhileOp
3285 //===----------------------------------------------------------------------===//
3286 
3287 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3288  ::mlir::OperationState &odsState, TypeRange resultTypes,
3289  ValueRange inits, BodyBuilderFn beforeBuilder,
3290  BodyBuilderFn afterBuilder) {
3291  odsState.addOperands(inits);
3292  odsState.addTypes(resultTypes);
3293 
3294  OpBuilder::InsertionGuard guard(odsBuilder);
3295 
3296  // Build before region.
3297  SmallVector<Location, 4> beforeArgLocs;
3298  beforeArgLocs.reserve(inits.size());
3299  for (Value operand : inits) {
3300  beforeArgLocs.push_back(operand.getLoc());
3301  }
3302 
3303  Region *beforeRegion = odsState.addRegion();
3304  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3305  inits.getTypes(), beforeArgLocs);
3306  if (beforeBuilder)
3307  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3308 
3309  // Build after region.
3310  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3311 
3312  Region *afterRegion = odsState.addRegion();
3313  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3314  resultTypes, afterArgLocs);
3315 
3316  if (afterBuilder)
3317  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3318 }
3319 
3320 ConditionOp WhileOp::getConditionOp() {
3321  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3322 }
3323 
3324 YieldOp WhileOp::getYieldOp() {
3325  return cast<YieldOp>(getAfterBody()->getTerminator());
3326 }
3327 
3328 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3329  return getYieldOp().getResultsMutable();
3330 }
3331 
3332 Block::BlockArgListType WhileOp::getBeforeArguments() {
3333  return getBeforeBody()->getArguments();
3334 }
3335 
3336 Block::BlockArgListType WhileOp::getAfterArguments() {
3337  return getAfterBody()->getArguments();
3338 }
3339 
3340 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3341  return getBeforeArguments();
3342 }
3343 
3344 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3345  assert(point == getBefore() &&
3346  "WhileOp is expected to branch only to the first region");
3347  return getInits();
3348 }
3349 
3350 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3351  SmallVectorImpl<RegionSuccessor> &regions) {
3352  // The parent op always branches to the condition region.
3353  if (point.isParent()) {
3354  regions.emplace_back(&getBefore(), getBefore().getArguments());
3355  return;
3356  }
3357 
3358  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3359  "there are only two regions in a WhileOp");
3360  // The body region always branches back to the condition region.
3361  if (point == getAfter()) {
3362  regions.emplace_back(&getBefore(), getBefore().getArguments());
3363  return;
3364  }
3365 
3366  regions.emplace_back(getResults());
3367  regions.emplace_back(&getAfter(), getAfter().getArguments());
3368 }
3369 
3370 SmallVector<Region *> WhileOp::getLoopRegions() {
3371  return {&getBefore(), &getAfter()};
3372 }
3373 
3374 /// Parses a `while` op.
3375 ///
3376 /// op ::= `scf.while` assignments `:` function-type region `do` region
3377 /// `attributes` attribute-dict
3378 /// initializer ::= /* empty */ | `(` assignment-list `)`
3379 /// assignment-list ::= assignment | assignment `,` assignment-list
3380 /// assignment ::= ssa-value `=` ssa-value
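///
/// For illustration, a typical instance of this syntax (all SSA names are
/// hypothetical) is:
///   %res = scf.while (%arg = %init) : (i32) -> i32 {
///     %cond = arith.cmpi slt, %arg, %limit : i32
///     scf.condition(%cond) %arg : i32
///   } do {
///   ^bb0(%arg2: i32):
///     %next = arith.addi %arg2, %step : i32
///     scf.yield %next : i32
///   }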
3381 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3382  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3383  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3384  Region *before = result.addRegion();
3385  Region *after = result.addRegion();
3386 
3387  OptionalParseResult listResult =
3388  parser.parseOptionalAssignmentList(regionArgs, operands);
3389  if (listResult.has_value() && failed(listResult.value()))
3390  return failure();
3391 
3392  FunctionType functionType;
3393  SMLoc typeLoc = parser.getCurrentLocation();
3394  if (failed(parser.parseColonType(functionType)))
3395  return failure();
3396 
3397  result.addTypes(functionType.getResults());
3398 
3399  if (functionType.getNumInputs() != operands.size()) {
3400  return parser.emitError(typeLoc)
3401  << "expected as many input types as operands " << "(expected "
3402  << operands.size() << " got " << functionType.getNumInputs() << ")";
3403  }
3404 
3405  // Resolve input operands.
3406  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3407  parser.getCurrentLocation(),
3408  result.operands)))
3409  return failure();
3410 
3411  // Propagate the types into the region arguments.
3412  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3413  regionArgs[i].type = functionType.getInput(i);
3414 
3415  return failure(parser.parseRegion(*before, regionArgs) ||
3416  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3417  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3418 }
3419 
3420 /// Prints a `while` op.
3421 void scf::WhileOp::print(OpAsmPrinter &p) {
3422  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3423  p << " : ";
3424  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3425  p << ' ';
3426  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3427  p << " do ";
3428  p.printRegion(getAfter());
3429  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3430 }
3431 
3432 /// Verifies that two ranges of types match, i.e. have the same number of
3433 /// entries and that types are pairwise equals. Reports errors on the given
3434 /// operation in case of mismatch.
3435 template <typename OpTy>
3436 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3437  TypeRange right, StringRef message) {
3438  if (left.size() != right.size())
3439  return op.emitOpError("expects the same number of ") << message;
3440 
3441  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3442  if (left[i] != right[i]) {
3443  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3444  << message;
3445  diag.attachNote() << "for argument " << i << ", found " << left[i]
3446  << " and " << right[i];
3447  return diag;
3448  }
3449  }
3450 
3451  return success();
3452 }
3453 
3454 LogicalResult scf::WhileOp::verify() {
3455  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3456  *this, getBefore(),
3457  "expects the 'before' region to terminate with 'scf.condition'");
3458  if (!beforeTerminator)
3459  return failure();
3460 
3461  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3462  *this, getAfter(),
3463  "expects the 'after' region to terminate with 'scf.yield'");
3464  return success(afterTerminator != nullptr);
3465 }
3466 
3467 namespace {
3468 /// Replace uses of the condition within the do block with true, since otherwise
3469 /// the block would not be evaluated.
3470 ///
3471 /// scf.while (..) : (i1, ...) -> ... {
3472 /// %condition = call @evaluate_condition() : () -> i1
3473 /// scf.condition(%condition) %condition : i1, ...
3474 /// } do {
3475 /// ^bb0(%arg0: i1, ...):
3476 /// use(%arg0)
3477 /// ...
3478 ///
3479 /// becomes
3480 /// scf.while (..) : (i1, ...) -> ... {
3481 /// %condition = call @evaluate_condition() : () -> i1
3482 /// scf.condition(%condition) %condition : i1, ...
3483 /// } do {
3484 /// ^bb0(%arg0: i1, ...):
3485 /// use(%true)
3486 /// ...
3487 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3488  using OpRewritePattern<WhileOp>::OpRewritePattern;
3489 
3490  LogicalResult matchAndRewrite(WhileOp op,
3491  PatternRewriter &rewriter) const override {
3492  auto term = op.getConditionOp();
3493 
3494  // These variables serve to prevent creating duplicate constants
3495  // and hold constant true or false values.
3496  Value constantTrue = nullptr;
3497 
3498  bool replaced = false;
3499  for (auto yieldedAndBlockArgs :
3500  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3501  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3502  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3503  if (!constantTrue)
3504  constantTrue = arith::ConstantOp::create(
3505  rewriter, op.getLoc(), term.getCondition().getType(),
3506  rewriter.getBoolAttr(true));
3507 
3508  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3509  constantTrue);
3510  replaced = true;
3511  }
3512  }
3513  }
3514  return success(replaced);
3515  }
3516 };
3517 
3518 /// Remove loop invariant arguments from `before` block of scf.while.
3519 /// A before block argument is considered loop invariant if:
3520 /// 1. i-th yield operand is equal to the i-th while operand.
3521 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3522 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3523 /// iter argument/while operand.
3524 /// For the arguments which are removed, their uses inside scf.while
3525 /// are replaced with their corresponding initial value.
3526 ///
3527 /// Eg:
3528 /// INPUT :-
3529 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3530 /// ..., %argN_before = %N)
3531 /// {
3532 /// ...
3533 /// scf.condition(%cond) %arg1_before, %arg0_before,
3534 /// %arg2_before, %arg0_before, ...
3535 /// } do {
3536 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3537 /// ..., %argK_after):
3538 /// ...
3539 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3540 /// }
3541 ///
3542 /// OUTPUT :-
3543 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3544 /// %N)
3545 /// {
3546 /// ...
3547 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3548 /// } do {
3549 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3550 /// ..., %argK_after):
3551 /// ...
3552 /// scf.yield %arg1_after, ..., %argN
3553 /// }
3554 ///
3555 /// EXPLANATION:
3556 /// We iterate over each yield operand.
3557 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3558 /// %arg0_before, which in turn is the 0-th iter argument. So we
3559 /// remove 0-th before block argument and yield operand, and replace
3560 /// all uses of the 0-th before block argument with its initial value
3561 /// %a.
3562 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3563 /// value. So we remove this operand and the corresponding before
3564 /// block argument and replace all uses of 1-th before block argument
3565 /// with %b.
3566 struct RemoveLoopInvariantArgsFromBeforeBlock
3567  : public OpRewritePattern<WhileOp> {
3568  using OpRewritePattern<WhileOp>::OpRewritePattern;
3569 
3570  LogicalResult matchAndRewrite(WhileOp op,
3571  PatternRewriter &rewriter) const override {
3572  Block &afterBlock = *op.getAfterBody();
3573  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3574  ConditionOp condOp = op.getConditionOp();
3575  OperandRange condOpArgs = condOp.getArgs();
3576  Operation *yieldOp = afterBlock.getTerminator();
3577  ValueRange yieldOpArgs = yieldOp->getOperands();
3578 
3579  bool canSimplify = false;
3580  for (const auto &it :
3581  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3582  auto index = static_cast<unsigned>(it.index());
3583  auto [initVal, yieldOpArg] = it.value();
3584  // If i-th yield operand is equal to the i-th operand of the scf.while,
3585  // the i-th before block argument is a loop invariant.
3586  if (yieldOpArg == initVal) {
3587  canSimplify = true;
3588  break;
3589  }
3590  // If the i-th yield operand is k-th after block argument, then we check
3591  // if the (k+1)-th condition op operand is equal to either the i-th before
3592  // block argument or the initial value of i-th before block argument. If
3593  // the comparison results `true`, i-th before block argument is a loop
3594  // invariant.
3595  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3596  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3597  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3598  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3599  canSimplify = true;
3600  break;
3601  }
3602  }
3603  }
3604 
3605  if (!canSimplify)
3606  return failure();
3607 
3608  SmallVector<Value> newInitArgs, newYieldOpArgs;
3609  DenseMap<unsigned, Value> beforeBlockInitValMap;
3610  SmallVector<Location> newBeforeBlockArgLocs;
3611  for (const auto &it :
3612  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3613  auto index = static_cast<unsigned>(it.index());
3614  auto [initVal, yieldOpArg] = it.value();
3615 
3616  // If i-th yield operand is equal to the i-th operand of the scf.while,
3617  // the i-th before block argument is a loop invariant.
3618  if (yieldOpArg == initVal) {
3619  beforeBlockInitValMap.insert({index, initVal});
3620  continue;
3621  } else {
3622  // If the i-th yield operand is k-th after block argument, then we check
3623  // if the (k+1)-th condition op operand is equal to either the i-th
3624  // before block argument or the initial value of i-th before block
3625  // argument. If the comparison results `true`, i-th before block
3626  // argument is a loop invariant.
3627  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3628  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3629  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3630  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3631  beforeBlockInitValMap.insert({index, initVal});
3632  continue;
3633  }
3634  }
3635  }
3636  newInitArgs.emplace_back(initVal);
3637  newYieldOpArgs.emplace_back(yieldOpArg);
3638  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3639  }
3640 
3641  {
3642  OpBuilder::InsertionGuard g(rewriter);
3643  rewriter.setInsertionPoint(yieldOp);
3644  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3645  }
3646 
3647  auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3648  newInitArgs);
3649 
3650  Block &newBeforeBlock = *rewriter.createBlock(
3651  &newWhile.getBefore(), /*insertPt*/ {},
3652  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3653 
3654  Block &beforeBlock = *op.getBeforeBody();
3655  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3656  // For each i-th before block argument we find its replacement value as:
3657  // 1. If the i-th before block argument is a loop invariant, we fetch its
3658  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3659  // 2. Else we fetch j-th new before block argument as the replacement
3660  // value of i-th before block argument.
3661  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3662  // If the index 'i' argument was a loop invariant, we fetch its initial
3663  // value from `beforeBlockInitValMap`.
3664  if (beforeBlockInitValMap.count(i) != 0)
3665  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3666  else
3667  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3668  }
3669 
3670  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3671  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3672  newWhile.getAfter().begin());
3673 
3674  rewriter.replaceOp(op, newWhile.getResults());
3675  return success();
3676  }
3677 };
3678 
3679 /// Remove loop invariant value from result (condition op) of scf.while.
3680 /// A value is considered loop invariant if the final value yielded by
3681 /// scf.condition is defined outside of the `before` block. We remove the
3682 /// corresponding argument in `after` block and replace the use with the value.
3683 /// We also replace the use of the corresponding result of scf.while with the
3684 /// value.
3685 ///
3686 /// Eg:
3687 /// INPUT :-
3688 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3689 /// %argN_before = %N) {
3690 /// ...
3691 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3692 /// } do {
3693 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3694 /// ...
3695 /// some_func(%arg1_after)
3696 /// ...
3697 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3698 /// }
3699 ///
3700 /// OUTPUT :-
3701 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3702 /// ...
3703 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3704 /// } do {
3705 /// ^bb0(%arg0, %arg3, ..., %argM):
3706 /// ...
3707 /// some_func(%a)
3708 /// ...
3709 /// scf.yield %arg0, %b, ..., %argN
3710 /// }
3711 ///
3712 /// EXPLANATION:
3713 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3714 /// before block of scf.while, so they get removed.
3715 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3716 /// replaced by %b.
3717 /// 3. The corresponding after block argument %arg1_after's uses are
3718 /// replaced by %a and %arg2_after's uses are replaced by %b.
3719 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3720  using OpRewritePattern<WhileOp>::OpRewritePattern;
3721 
3722  LogicalResult matchAndRewrite(WhileOp op,
3723  PatternRewriter &rewriter) const override {
3724  Block &beforeBlock = *op.getBeforeBody();
3725  ConditionOp condOp = op.getConditionOp();
3726  OperandRange condOpArgs = condOp.getArgs();
3727 
3728  bool canSimplify = false;
3729  for (Value condOpArg : condOpArgs) {
3730  // Values not defined within the `before` block are considered loop
3731  // invariant. If any such value is forwarded by scf.condition, the loop
3732  // can be simplified.
3733  if (condOpArg.getParentBlock() != &beforeBlock) {
3734  canSimplify = true;
3735  break;
3736  }
3737  }
3738 
3739  if (!canSimplify)
3740  return failure();
3741 
3742  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3743 
3744  SmallVector<Value> newCondOpArgs;
3745  SmallVector<Type> newAfterBlockType;
3746  DenseMap<unsigned, Value> condOpInitValMap;
3747  SmallVector<Location> newAfterBlockArgLocs;
3748  for (const auto &it : llvm::enumerate(condOpArgs)) {
3749  auto index = static_cast<unsigned>(it.index());
3750  Value condOpArg = it.value();
3751  // Those values not defined within `before` block will be considered as
3752  // loop invariant values. We map the corresponding `index` with their
3753  // value.
3754  if (condOpArg.getParentBlock() != &beforeBlock) {
3755  condOpInitValMap.insert({index, condOpArg});
3756  } else {
3757  newCondOpArgs.emplace_back(condOpArg);
3758  newAfterBlockType.emplace_back(condOpArg.getType());
3759  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3760  }
3761  }
3762 
3763  {
3764  OpBuilder::InsertionGuard g(rewriter);
3765  rewriter.setInsertionPoint(condOp);
3766  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3767  newCondOpArgs);
3768  }
3769 
3770  auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3771  op.getOperands());
3772 
3773  Block &newAfterBlock =
3774  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3775  newAfterBlockType, newAfterBlockArgLocs);
3776 
3777  Block &afterBlock = *op.getAfterBody();
3778  // Since a new scf.condition op was created, we need to fetch the new
3779  // `after` block arguments, which will be used while replacing operations of
3780  // the previous scf.while's `after` block. We also fetch the new result
3781  // values.
3782  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3783  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3784  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3785  Value afterBlockArg, result;
3786  // If the index 'i' argument was loop invariant, we fetch its value from
3787  // the `condOpInitValMap` map.
3788  if (condOpInitValMap.count(i) != 0) {
3789  afterBlockArg = condOpInitValMap[i];
3790  result = afterBlockArg;
3791  } else {
3792  afterBlockArg = newAfterBlock.getArgument(j);
3793  result = newWhile.getResult(j);
3794  j++;
3795  }
3796  newAfterBlockArgs[i] = afterBlockArg;
3797  newWhileResults[i] = result;
3798  }
3799 
3800  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3801  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3802  newWhile.getBefore().begin());
3803 
3804  rewriter.replaceOp(op, newWhileResults);
3805  return success();
3806  }
3807 };
3808 
3809 /// Remove WhileOp results that are also unused in 'after' block.
3810 ///
3811 /// %0:2 = scf.while () : () -> (i32, i64) {
3812 /// %condition = "test.condition"() : () -> i1
3813 /// %v1 = "test.get_some_value"() : () -> i32
3814 /// %v2 = "test.get_some_value"() : () -> i64
3815 /// scf.condition(%condition) %v1, %v2 : i32, i64
3816 /// } do {
3817 /// ^bb0(%arg0: i32, %arg1: i64):
3818 /// "test.use"(%arg0) : (i32) -> ()
3819 /// scf.yield
3820 /// }
3821 /// return %0#0 : i32
3822 ///
3823 /// becomes
3824 /// %0 = scf.while () : () -> (i32) {
3825 /// %condition = "test.condition"() : () -> i1
3826 /// %v1 = "test.get_some_value"() : () -> i32
3827 /// %v2 = "test.get_some_value"() : () -> i64
3828 /// scf.condition(%condition) %v1 : i32
3829 /// } do {
3830 /// ^bb0(%arg0: i32):
3831 /// "test.use"(%arg0) : (i32) -> ()
3832 /// scf.yield
3833 /// }
3834 /// return %0 : i32
3835 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3836  using OpRewritePattern<WhileOp>::OpRewritePattern;
3837 
3838  LogicalResult matchAndRewrite(WhileOp op,
3839  PatternRewriter &rewriter) const override {
3840  auto term = op.getConditionOp();
3841  auto afterArgs = op.getAfterArguments();
3842  auto termArgs = term.getArgs();
3843 
3844  // Collect results mapping, new terminator args and new result types.
3845  SmallVector<unsigned> newResultsIndices;
3846  SmallVector<Type> newResultTypes;
3847  SmallVector<Value> newTermArgs;
3848  SmallVector<Location> newArgLocs;
3849  bool needUpdate = false;
3850  for (const auto &it :
3851  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3852  auto i = static_cast<unsigned>(it.index());
3853  Value result = std::get<0>(it.value());
3854  Value afterArg = std::get<1>(it.value());
3855  Value termArg = std::get<2>(it.value());
3856  if (result.use_empty() && afterArg.use_empty()) {
3857  needUpdate = true;
3858  } else {
3859  newResultsIndices.emplace_back(i);
3860  newTermArgs.emplace_back(termArg);
3861  newResultTypes.emplace_back(result.getType());
3862  newArgLocs.emplace_back(result.getLoc());
3863  }
3864  }
3865 
3866  if (!needUpdate)
3867  return failure();
3868 
3869  {
3870  OpBuilder::InsertionGuard g(rewriter);
3871  rewriter.setInsertionPoint(term);
3872  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3873  newTermArgs);
3874  }
3875 
3876  auto newWhile =
3877  WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
3878 
3879  Block &newAfterBlock = *rewriter.createBlock(
3880  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3881 
3882  // Build new results list and new after block args (unused entries will be
3883  // null).
3884  SmallVector<Value> newResults(op.getNumResults());
3885  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3886  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3887  newResults[it.value()] = newWhile.getResult(it.index());
3888  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3889  }
3890 
3891  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3892  newWhile.getBefore().begin());
3893 
3894  Block &afterBlock = *op.getAfterBody();
3895  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3896 
3897  rewriter.replaceOp(op, newResults);
3898  return success();
3899  }
3900 };
3901 
3902 /// Replace operations equivalent to the condition in the do block with true,
3903 /// since otherwise the block would not be evaluated.
3904 ///
3905 /// scf.while (..) : (i32, ...) -> ... {
3906 /// %z = ... : i32
3907 /// %condition = cmpi pred %z, %a
3908 /// scf.condition(%condition) %z : i32, ...
3909 /// } do {
3910 /// ^bb0(%arg0: i32, ...):
3911 /// %condition2 = cmpi pred %arg0, %a
3912 /// use(%condition2)
3913 /// ...
3914 ///
3915 /// becomes
3916 /// scf.while (..) : (i32, ...) -> ... {
3917 /// %z = ... : i32
3918 /// %condition = cmpi pred %z, %a
3919 /// scf.condition(%condition) %z : i32, ...
3920 /// } do {
3921 /// ^bb0(%arg0: i32, ...):
3922 /// use(%true)
3923 /// ...
3924 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3925  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3926 
3927  LogicalResult matchAndRewrite(scf::WhileOp op,
3928  PatternRewriter &rewriter) const override {
3929  using namespace scf;
3930  auto cond = op.getConditionOp();
3931  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3932  if (!cmp)
3933  return failure();
3934  bool changed = false;
3935  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3936  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3937  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3938  continue;
3939  for (OpOperand &u :
3940  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3941  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3942  if (!cmp2)
3943  continue;
3944  // For a binary operator 1-opIdx gets the other side.
3945  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3946  continue;
3947  bool samePredicate;
3948  if (cmp2.getPredicate() == cmp.getPredicate())
3949  samePredicate = true;
3950  else if (cmp2.getPredicate() ==
3951  arith::invertPredicate(cmp.getPredicate()))
3952  samePredicate = false;
3953  else
3954  continue;
3955 
3956  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3957  1);
3958  changed = true;
3959  }
3960  }
3961  }
3962  return success(changed);
3963  }
3964 };
3965 
3966 /// Remove unused init/yield args.
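/// For illustration (hypothetical values): if the before-block argument %arg1
/// is unused,
///   %r = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i32 { ... }
/// is rewritten to drop %arg1 together with its init and yield operands:
///   %r = scf.while (%arg0 = %a) : (i32) -> i32 { ... }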
3967 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3968  using OpRewritePattern<WhileOp>::OpRewritePattern;
3969 
3970  LogicalResult matchAndRewrite(WhileOp op,
3971  PatternRewriter &rewriter) const override {
3972 
3973  if (!llvm::any_of(op.getBeforeArguments(),
3974  [](Value arg) { return arg.use_empty(); }))
3975  return rewriter.notifyMatchFailure(op, "No args to remove");
3976 
3977  YieldOp yield = op.getYieldOp();
3978 
3979  // Collect results mapping, new terminator args and new result types.
3980  SmallVector<Value> newYields;
3981  SmallVector<Value> newInits;
3982  llvm::BitVector argsToErase;
3983 
3984  size_t argsCount = op.getBeforeArguments().size();
3985  newYields.reserve(argsCount);
3986  newInits.reserve(argsCount);
3987  argsToErase.reserve(argsCount);
3988  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3989  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3990  if (beforeArg.use_empty()) {
3991  argsToErase.push_back(true);
3992  } else {
3993  argsToErase.push_back(false);
3994  newYields.emplace_back(yieldValue);
3995  newInits.emplace_back(initValue);
3996  }
3997  }
3998 
3999  Block &beforeBlock = *op.getBeforeBody();
4000  Block &afterBlock = *op.getAfterBody();
4001 
4002  beforeBlock.eraseArguments(argsToErase);
4003 
4004  Location loc = op.getLoc();
4005  auto newWhileOp =
4006  WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4007  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4008  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4009  Block &newAfterBlock = *newWhileOp.getAfterBody();
4010 
4011  OpBuilder::InsertionGuard g(rewriter);
4012  rewriter.setInsertionPoint(yield);
4013  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4014 
4015  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4016  newBeforeBlock.getArguments());
4017  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4018  newAfterBlock.getArguments());
4019 
4020  rewriter.replaceOp(op, newWhileOp.getResults());
4021  return success();
4022  }
4023 };
4024 
4025 /// Remove duplicated ConditionOp args.
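/// For illustration (hypothetical values): if the terminator forwards the same
/// value twice, e.g.
///   scf.condition(%cond) %v, %v : i32, i32
/// the duplicate is dropped and the corresponding op results and after-block
/// arguments are remapped to the single remaining copy.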
4026 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4027  using OpRewritePattern::OpRewritePattern;
4028 
4029  LogicalResult matchAndRewrite(WhileOp op,
4030  PatternRewriter &rewriter) const override {
4031  ConditionOp condOp = op.getConditionOp();
4032  ValueRange condOpArgs = condOp.getArgs();
4033 
4034  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4035 
4036  if (argsSet.size() == condOpArgs.size())
4037  return rewriter.notifyMatchFailure(op, "No results to remove");
4038 
4039  llvm::SmallDenseMap<Value, unsigned> argsMap;
4040  SmallVector<Value> newArgs;
4041  argsMap.reserve(condOpArgs.size());
4042  newArgs.reserve(condOpArgs.size());
4043  for (Value arg : condOpArgs) {
4044  if (!argsMap.count(arg)) {
4045  auto pos = static_cast<unsigned>(argsMap.size());
4046  argsMap.insert({arg, pos});
4047  newArgs.emplace_back(arg);
4048  }
4049  }
4050 
4051  ValueRange argsRange(newArgs);
4052 
4053  Location loc = op.getLoc();
4054  auto newWhileOp =
4055  scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4056  /*beforeBody*/ nullptr,
4057  /*afterBody*/ nullptr);
4058  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4059  Block &newAfterBlock = *newWhileOp.getAfterBody();
4060 
4061  SmallVector<Value> afterArgsMapping;
4062  SmallVector<Value> resultsMapping;
4063  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4064  auto it = argsMap.find(arg);
4065  assert(it != argsMap.end());
4066  auto pos = it->second;
4067  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4068  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4069  }
4070 
4071  OpBuilder::InsertionGuard g(rewriter);
4072  rewriter.setInsertionPoint(condOp);
4073  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4074  argsRange);
4075 
4076  Block &beforeBlock = *op.getBeforeBody();
4077  Block &afterBlock = *op.getAfterBody();
4078 
4079  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4080  newBeforeBlock.getArguments());
4081  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4082  rewriter.replaceOp(op, resultsMapping);
4083  return success();
4084  }
4085 };
4086 
4087 /// If both ranges contain the same values, return mapping indices from args2
4088 /// to args1. Otherwise return std::nullopt.
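/// For illustration: getArgsMapping((%a, %b), (%b, %a)) returns [1, 0], since
/// args2[0] (= %b) is found at position 1 of args1 and args2[1] (= %a) at
/// position 0.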
4089 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4090  ValueRange args2) {
4091  if (args1.size() != args2.size())
4092  return std::nullopt;
4093 
4094  SmallVector<unsigned> ret(args1.size());
4095  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4096  auto it = llvm::find(args2, arg1);
4097  if (it == args2.end())
4098  return std::nullopt;
4099 
4100  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4101  }
4102 
4103  return ret;
4104 }
4105 
4106 static bool hasDuplicates(ValueRange args) {
4107  llvm::SmallDenseSet<Value> set;
4108  for (Value arg : args) {
4109  if (!set.insert(arg).second)
4110  return true;
4111  }
4112  return false;
4113 }
4114 
4115 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4116 /// `scf.condition` args into same order as block args. Update `after` block
4117 /// args and op result values accordingly.
4118 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
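/// For illustration (hypothetical values): with before-block arguments
/// (%a, %b) and a terminator `scf.condition(%c) %b, %a`, the terminator is
/// rewritten to forward (%a, %b) and the after-block arguments and op results
/// are permuted to match.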
4119 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4120  using OpRewritePattern::OpRewritePattern;
4121 
4122  LogicalResult matchAndRewrite(WhileOp loop,
4123  PatternRewriter &rewriter) const override {
4124  auto oldBefore = loop.getBeforeBody();
4125  ConditionOp oldTerm = loop.getConditionOp();
4126  ValueRange beforeArgs = oldBefore->getArguments();
4127  ValueRange termArgs = oldTerm.getArgs();
4128  if (beforeArgs == termArgs)
4129  return failure();
4130 
4131  if (hasDuplicates(termArgs))
4132  return failure();
4133 
4134  auto mapping = getArgsMapping(beforeArgs, termArgs);
4135  if (!mapping)
4136  return failure();
4137 
4138  {
4139  OpBuilder::InsertionGuard g(rewriter);
4140  rewriter.setInsertionPoint(oldTerm);
4141  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4142  beforeArgs);
4143  }
4144 
4145  auto oldAfter = loop.getAfterBody();
4146 
4147  SmallVector<Type> newResultTypes(beforeArgs.size());
4148  for (auto &&[i, j] : llvm::enumerate(*mapping))
4149  newResultTypes[j] = loop.getResult(i).getType();
4150 
4151  auto newLoop = WhileOp::create(
4152  rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4153  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4154  auto newBefore = newLoop.getBeforeBody();
4155  auto newAfter = newLoop.getAfterBody();
4156 
4157  SmallVector<Value> newResults(beforeArgs.size());
4158  SmallVector<Value> newAfterArgs(beforeArgs.size());
4159  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4160  newResults[i] = newLoop.getResult(j);
4161  newAfterArgs[i] = newAfter->getArgument(j);
4162  }
4163 
4164  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4165  newBefore->getArguments());
4166  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4167  newAfterArgs);
4168 
4169  rewriter.replaceOp(loop, newResults);
4170  return success();
4171  }
4172 };
4173 } // namespace
4174 
4175 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4176  MLIRContext *context) {
4177  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4178  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4179  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4180  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4181 }
4182 
4183 //===----------------------------------------------------------------------===//
4184 // IndexSwitchOp
4185 //===----------------------------------------------------------------------===//
4186 
4187 /// Parse the case regions and values.
4188 static ParseResult
4189 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4190  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4191  SmallVector<int64_t> caseValues;
4192  while (succeeded(p.parseOptionalKeyword("case"))) {
4193  int64_t value;
4194  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4195  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4196  return failure();
4197  caseValues.push_back(value);
4198  }
4199  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4200  return success();
4201 }
4202 
4203 /// Print the case regions and values.
4204 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4205  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4206  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4207  p.printNewline();
4208  p << "case " << value << ' ';
4209  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4210  }
4211 }
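// For illustration, these two helpers handle the `case <value> <region>` list
// of scf.index_switch; in context (case values and yields are hypothetical):
//   scf.index_switch %index -> i32
//   case 2 {
//     scf.yield %a : i32
//   }
//   default {
//     scf.yield %b : i32
//   }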
4212 
4213 LogicalResult scf::IndexSwitchOp::verify() {
4214  if (getCases().size() != getCaseRegions().size()) {
4215  return emitOpError("has ")
4216  << getCaseRegions().size() << " case regions but "
4217  << getCases().size() << " case values";
4218  }
4219 
4220  DenseSet<int64_t> valueSet;
4221  for (int64_t value : getCases())
4222  if (!valueSet.insert(value).second)
4223  return emitOpError("has duplicate case value: ") << value;
4224  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4225  auto yield = dyn_cast<YieldOp>(region.front().back());
4226  if (!yield)
4227  return emitOpError("expected region to end with scf.yield, but got ")
4228  << region.front().back().getName();
4229 
4230  if (yield.getNumOperands() != getNumResults()) {
4231  return (emitOpError("expected each region to return ")
4232  << getNumResults() << " values, but " << name << " returns "
4233  << yield.getNumOperands())
4234  .attachNote(yield.getLoc())
4235  << "see yield operation here";
4236  }
4237  for (auto [idx, result, operand] :
4238  llvm::enumerate(getResultTypes(), yield.getOperands())) {
4239  if (!operand)
4240  return yield.emitOpError() << "operand " << idx << " is null\n";
4241  if (result == operand.getType())
4242  continue;
4243  return (emitOpError("expected result #")
4244  << idx << " of each region to be " << result)
4245  .attachNote(yield.getLoc())
4246  << name << " returns " << operand.getType() << " here";
4247  }
4248  return success();
4249  };
4250 
4251  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4252  return failure();
4253  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4254  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4255  return failure();
4256 
4257  return success();
4258 }
4259 
4260 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4261 
4262 Block &scf::IndexSwitchOp::getDefaultBlock() {
4263  return getDefaultRegion().front();
4264 }
4265 
4266 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4267  assert(idx < getNumCases() && "case index out-of-bounds");
4268  return getCaseRegions()[idx].front();
4269 }
4270 
4271 void IndexSwitchOp::getSuccessorRegions(
4272  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4273  // All regions branch back to the parent op.
4274  if (!point.isParent()) {
4275  successors.emplace_back(getResults());
4276  return;
4277  }
4278 
4279  llvm::append_range(successors, getRegions());
4280 }
4281 
4282 void IndexSwitchOp::getEntrySuccessorRegions(
4283  ArrayRef<Attribute> operands,
4284  SmallVectorImpl<RegionSuccessor> &successors) {
4285  FoldAdaptor adaptor(operands, *this);
4286 
4287  // If a constant was not provided, all regions are possible successors.
4288  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4289  if (!arg) {
4290  llvm::append_range(successors, getRegions());
4291  return;
4292  }
4293 
4294  // Otherwise, try to find a case with a matching value. If not, the
4295  // default region is the only successor.
4296  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4297  if (caseValue == arg.getInt()) {
4298  successors.emplace_back(&caseRegion);
4299  return;
4300  }
4301  }
4302  successors.emplace_back(&getDefaultRegion());
4303 }
4304 
4305 void IndexSwitchOp::getRegionInvocationBounds(
4306  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4307  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4308  if (!operandValue) {
4309  // All regions are invoked at most once.
4310  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4311  return;
4312  }
4313 
4314  unsigned liveIndex = getNumRegions() - 1;
4315  const auto *it = llvm::find(getCases(), operandValue.getInt());
4316  if (it != getCases().end())
4317  liveIndex = std::distance(getCases().begin(), it);
4318  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4319  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4320 }
4321 
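/// Fold an scf.index_switch whose argument is a known constant: the matching
/// case region (or the default region when no case matches) is inlined in
/// place of the switch and the op is replaced by that region's yielded values.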
4322 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4323  using OpRewritePattern::OpRewritePattern;
4324 
4325  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4326  PatternRewriter &rewriter) const override {
4327  // If `op.getArg()` is a constant, select the region that matches the
4328  // constant value. Use the default region if no match is found.
4329  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4330  if (!maybeCst.has_value())
4331  return failure();
4332  int64_t cst = *maybeCst;
4333  int64_t caseIdx, e = op.getNumCases();
4334  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4335  if (cst == op.getCases()[caseIdx])
4336  break;
4337  }
4338 
4339  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4340  : op.getDefaultRegion();
4341  Block &source = r.front();
4342  Operation *terminator = source.getTerminator();
4343  SmallVector<Value> results = terminator->getOperands();
4344 
4345  rewriter.inlineBlockBefore(&source, op);
4346  rewriter.eraseOp(terminator);
4347  // Replace the operation with a potentially empty list of results.
4348  // The fold mechanism doesn't support the case where the result list is empty.
4349  rewriter.replaceOp(op, results);
4350 
4351  return success();
4352  }
4353 };
4354 
4355 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4356  MLIRContext *context) {
4357  results.add<FoldConstantCase>(context);
4358 }
4359 
4360 //===----------------------------------------------------------------------===//
4361 // TableGen'd op method definitions
4362 //===----------------------------------------------------------------------===//
4363 
4364 #define GET_OP_CLASSES
4365 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:757
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:749
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:131
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1337
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:113
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4189
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3436
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:438
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:93
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4204
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
UnitAttr getUnitAttr()
Definition: Builders.cpp:93
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
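The Builder helpers above are commonly used together when assembling attributes and types; a minimal sketch, assuming an MLIRContext *ctx is available:

  Builder b(ctx);
  IntegerAttr lb = b.getIndexAttr(0);                       // index-typed attribute
  BoolAttr flag = b.getBoolAttr(true);                      // signless i1 attribute
  DenseI64ArrayAttr sizes = b.getDenseI64ArrayAttr({4, 8});
  Type i1 = b.getI1Type();
  Type idx = b.getIndexType();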
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired type.
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
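A hedged sketch of how IRMapping is typically combined with OpBuilder::clone (documented further below); `builder`, `op`, `oldVal`, `newVal`, and `someOtherVal` are assumed to exist:

  IRMapping mapping;
  mapping.map(oldVal, newVal);                      // remap oldVal -> newVal
  Value v = mapping.lookupOrDefault(oldVal);        // == newVal
  Value w = mapping.lookupOrDefault(someOtherVal);  // unmapped: returns someOtherVal
  Operation *copy = builder.clone(*op, mapping);    // operands remapped via `mapping`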
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInterface can be invoked.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the result list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with the 'attributes' keyword.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:575
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to appear immediately after it.
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation's regions empty.
Definition: Builders.h:585
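A minimal sketch of insertion-point management with the OpBuilder methods above, assuming an OpBuilder `builder`, a Region `region` with an entry block, and an existing Operation *someOp:

  {
    OpBuilder::InsertionGuard guard(builder);          // RAII: remembers the old point
    builder.setInsertionPointToStart(&region.front());
    // ...create operations at the top of the entry block...
  } // previous insertion point restored here
  builder.setInsertionPointAfter(someOp);              // continue after an existing op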
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator range over the underlying Values.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:783
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations.
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:612
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
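To show how these rewriter hooks fit together, here is a hedged sketch in the spirit of the single-block ExecuteRegionOp inliner that this file registers; it assumes the includes and usings of SCF.cpp and may differ in detail from the actual pattern:

  struct InlineSingleBlockExecuteRegion
      : public OpRewritePattern<scf::ExecuteRegionOp> {
    using OpRewritePattern<scf::ExecuteRegionOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(scf::ExecuteRegionOp op,
                                  PatternRewriter &rewriter) const override {
      // Only the single-block case: the yield maps 1:1 onto the op's results.
      if (!op.getRegion().hasOneBlock())
        return rewriter.notifyMatchFailure(op, "expected a single block");
      Block *block = &op.getRegion().front();
      Operation *terminator = block->getTerminator();
      auto yielded = llvm::to_vector(terminator->getOperands());
      // Splice the body right before the execute_region, then drop the yield
      // and forward its operands as the replacement values.
      rewriter.inlineBlockBefore(block, op->getBlock(), op->getIterator());
      rewriter.eraseOp(terminator);
      rewriter.replaceOp(op, yielded);
      return success();
    }
  };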
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop is known to have a single iteration.
Definition: LoopUtils.cpp:119
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3050
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:86
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:707
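A hedged usage sketch of buildLoopNest; `builder`, `loc`, and the index-typed Values `zero`, `one`, `n`, and `init` are assumed to already exist:

  SmallVector<Value> lbs = {zero, zero}, ubs = {n, n}, steps = {one, one};
  scf::LoopNest nest = scf::buildLoopNest(
      builder, loc, lbs, ubs, steps, ValueRange{init},
      [](OpBuilder &b, Location nestedLoc, ValueRange ivs, ValueRange args)
          -> scf::ValueVector {
        // Innermost body: thread the single iter_arg through unchanged.
        return {args.front()};
      });
  // Assuming LoopNest exposes the generated loops as `loops`.
  scf::ForOp outerLoop = nest.loops.front();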
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:2037
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:664
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:617
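A small hedged sketch of recovering the owning loop from a suspected induction variable `v` (an assumed Value):

  if (scf::ForOp forOp = scf::getForInductionVarOwner(v)) {
    // `v` is the induction variable of this scf.for; its bounds and step
    // are now reachable through the op.
    Value step = forOp.getStep();
    (void)step;
  }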
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1508
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:796
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source type.
Definition: TensorOps.cpp:270
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
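For example (a sketch; `lb`, `ub`, and `step` are assumed OpFoldResults taken from some loop):

  std::optional<int64_t> trip = constantTripCount(lb, ub, step);
  if (trip && *trip == 1) {
    // Single-iteration loop: a candidate for promotion into its parent block.
  }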
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction code.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, where each element for which ShapedType::isDynamic is true is replaced by the corresponding value from dynamicValues.
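A hedged round-trip sketch combining dispatchIndexOpFoldResults and getMixedValues; `mixed` is an assumed SmallVector<OpFoldResult> and `ctx` an assumed MLIRContext*:

  SmallVector<Value> dynamicVals;
  SmallVector<int64_t> staticVals;
  dispatchIndexOpFoldResults(mixed, dynamicVals, staticVals); // split into static/dynamic
  SmallVector<OpFoldResult> rebuilt =
      getMixedValues(staticVals, dynamicVals, ctx);           // recombine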
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4325
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:237
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:188
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.