1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/PatternMatch.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/SmallPtrSet.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/DebugLog.h"
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::scf;
37 
38 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
39 
40 //===----------------------------------------------------------------------===//
41 // SCFDialect Dialect Interfaces
42 //===----------------------------------------------------------------------===//
43 
44 namespace {
45 struct SCFInlinerInterface : public DialectInlinerInterface {
46  using DialectInlinerInterface::DialectInlinerInterface;
47  // We don't have any special restrictions on what can be inlined into
48  // destination regions (e.g. while/conditional bodies). Always allow it.
49  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
50  IRMapping &valueMapping) const final {
51  return true;
52  }
53  // Operations in scf dialect are always legal to inline since they are
54  // pure.
55  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
56  return true;
57  }
58  // Handle the given inlined terminator by replacing it with a new operation
59  // as necessary. Required when the region has only one block.
60  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
61  auto retValOp = dyn_cast<scf::YieldOp>(op);
62  if (!retValOp)
63  return;
64 
65  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
66  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
67  }
68  }
69 };
70 } // namespace
71 
72 //===----------------------------------------------------------------------===//
73 // SCFDialect
74 //===----------------------------------------------------------------------===//
75 
76 void SCFDialect::initialize() {
77  addOperations<
78 #define GET_OP_LIST
79 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
80  >();
81  addInterfaces<SCFInlinerInterface>();
82  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
83  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
84  InParallelOp, ReduceReturnOp>();
85  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
86  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
87  ForallOp, InParallelOp, WhileOp, YieldOp>();
88  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
89 }
90 
91 /// Default callback for IfOp builders. Inserts a yield without arguments.
92 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
93  scf::YieldOp::create(builder, loc);
94 }
95 
96 /// Verifies that the first block of the given `region` is terminated by a
97 /// TerminatorTy. Reports errors on the given operation if it is not the case.
98 template <typename TerminatorTy>
99 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
100  StringRef errorMessage) {
101  Operation *terminatorOperation = nullptr;
102  if (!region.empty() && !region.front().empty()) {
103  terminatorOperation = &region.front().back();
104  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
105  return yield;
106  }
107  auto diag = op->emitOpError(errorMessage);
108  if (terminatorOperation)
109  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
110  return nullptr;
111 }
112 
113 /// Helper function to compute the difference between two values. This is used
114 /// by the loop implementations to compute the trip count.
115 static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
116  bool isSigned) {
117  llvm::APSInt diff;
118  auto addOp = ub.getDefiningOp<arith::AddIOp>();
119  if (!addOp)
120  return std::nullopt;
121  if ((isSigned && !addOp.hasNoSignedWrap()) ||
122  (!isSigned && !addOp.hasNoUnsignedWrap()))
123  return std::nullopt;
124 
125  if (addOp.getLhs() != lb ||
126  !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
127  return std::nullopt;
128  return diff;
129 }
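// For illustration (assumed IR, not from an existing test): given
//   %ub = arith.addi %lb, %c16 overflow<nsw> : index
// computeUbMinusLb(%lb, %ub, /*isSigned=*/true) returns 16, because the upper
// bound is the lower bound plus a constant and the addition is known not to
// wrap. In all other cases the helper conservatively returns std::nullopt.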
130 
131 //===----------------------------------------------------------------------===//
132 // ExecuteRegionOp
133 //===----------------------------------------------------------------------===//
134 
135 /// Replaces the given op with the contents of the given single-block region,
136 /// using the operands of the block terminator to replace operation results.
137 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
138  Region &region, ValueRange blockArgs = {}) {
139  assert(region.hasOneBlock() && "expected single-block region");
140  Block *block = &region.front();
141  Operation *terminator = block->getTerminator();
142  ValueRange results = terminator->getOperands();
143  rewriter.inlineBlockBefore(block, op, blockArgs);
144  rewriter.replaceOp(op, results);
145  rewriter.eraseOp(terminator);
146 }
147 
148 ///
149 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
150 /// block+
151 /// `}`
152 ///
153 /// Example:
154 /// scf.execute_region -> i32 {
155 /// %idx = load %rI[%i] : memref<128xi32>
156 /// return %idx : i32
157 /// }
158 ///
159 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
160  OperationState &result) {
161  if (parser.parseOptionalArrowTypeList(result.types))
162  return failure();
163 
164  if (succeeded(parser.parseOptionalKeyword("no_inline")))
165  result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
166 
167  // Introduce the body region and parse it.
168  Region *body = result.addRegion();
169  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
170  parser.parseOptionalAttrDict(result.attributes))
171  return failure();
172 
173  return success();
174 }
175 
176 void ExecuteRegionOp::print(OpAsmPrinter &p) {
177  p.printOptionalArrowTypeList(getResultTypes());
178  p << ' ';
179  if (getNoInline())
180  p << "no_inline ";
181  p.printRegion(getRegion(),
182  /*printEntryBlockArgs=*/false,
183  /*printBlockTerminators=*/true);
184  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"});
185 }
186 
187 LogicalResult ExecuteRegionOp::verify() {
188  if (getRegion().empty())
189  return emitOpError("region needs to have at least one block");
190  if (getRegion().front().getNumArguments() > 0)
191  return emitOpError("region cannot have any arguments");
192  return success();
193 }
194 
195 // Inline an ExecuteRegionOp if it only contains one block.
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %x = "test.val"() : () -> i64
199 // scf.yield %x : i64
200 // }
201 // "test.bar"(%v) : (i64) -> ()
202 //
203 // becomes
204 //
205 // "test.foo"() : () -> ()
206 // %x = "test.val"() : () -> i64
207 // "test.bar"(%x) : (i64) -> ()
208 //
209 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
210  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
211 
212  LogicalResult matchAndRewrite(ExecuteRegionOp op,
213  PatternRewriter &rewriter) const override {
214  if (!op.getRegion().hasOneBlock() || op.getNoInline())
215  return failure();
216  replaceOpWithRegion(rewriter, op, op.getRegion());
217  return success();
218  }
219 };
220 
221 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
222 // TODO generalize the conditions for operations which can be inlined into.
223 // func @func_execute_region_elim() {
224 // "test.foo"() : () -> ()
225 // %v = scf.execute_region -> i64 {
226 // %c = "test.cmp"() : () -> i1
227 // cf.cond_br %c, ^bb2, ^bb3
228 // ^bb2:
229 // %x = "test.val1"() : () -> i64
230 // cf.br ^bb4(%x : i64)
231 // ^bb3:
232 // %y = "test.val2"() : () -> i64
233 // cf.br ^bb4(%y : i64)
234 // ^bb4(%z : i64):
235 // scf.yield %z : i64
236 // }
237 // "test.bar"(%v) : (i64) -> ()
238 // return
239 // }
240 //
241 // becomes
242 //
243 // func @func_execute_region_elim() {
244 // "test.foo"() : () -> ()
245 // %c = "test.cmp"() : () -> i1
246 // cf.cond_br %c, ^bb1, ^bb2
247 // ^bb1: // pred: ^bb0
248 // %x = "test.val1"() : () -> i64
249 // cf.br ^bb3(%x : i64)
250 // ^bb2: // pred: ^bb0
251 // %y = "test.val2"() : () -> i64
252 // cf.br ^bb3(%y : i64)
253 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
254 // "test.bar"(%z) : (i64) -> ()
255 // return
256 // }
257 //
258 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
259  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
260 
261  LogicalResult matchAndRewrite(ExecuteRegionOp op,
262  PatternRewriter &rewriter) const override {
263  if (op.getNoInline())
264  return failure();
265  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
266  return failure();
267 
268  Block *prevBlock = op->getBlock();
269  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
270  rewriter.setInsertionPointToEnd(prevBlock);
271 
272  cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
273 
274  for (Block &blk : op.getRegion()) {
275  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
276  rewriter.setInsertionPoint(yieldOp);
277  cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
278  yieldOp.getResults());
279  rewriter.eraseOp(yieldOp);
280  }
281  }
282 
283  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
284  SmallVector<Value> blockArgs;
285 
286  for (auto res : op.getResults())
287  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
288 
289  rewriter.replaceOp(op, blockArgs);
290  return success();
291  }
292 };
293 
294 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
295  MLIRContext *context) {
296  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
297 }
298 
299 void ExecuteRegionOp::getSuccessorRegions(
300  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
301  // If the predecessor is the ExecuteRegionOp, branch into the body.
302  if (point.isParent()) {
303  regions.push_back(RegionSuccessor(&getRegion()));
304  return;
305  }
306 
307  // Otherwise, the region branches back to the parent operation.
308  regions.push_back(RegionSuccessor(getResults()));
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // ConditionOp
313 //===----------------------------------------------------------------------===//
314 
315 MutableOperandRange
316 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
317  assert((point.isParent() || point == getParentOp().getAfter()) &&
318  "condition op can only exit the loop or branch to the after"
319  "region");
320  // Pass all operands except the condition to the successor region.
321  return getArgsMutable();
322 }
323 
324 void ConditionOp::getSuccessorRegions(
325  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
326  FoldAdaptor adaptor(operands, *this);
327 
328  WhileOp whileOp = getParentOp();
329 
330  // Condition can either lead to the after region or back to the parent op
331  // depending on whether the condition is true or not.
332  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
333  if (!boolAttr || boolAttr.getValue())
334  regions.emplace_back(&whileOp.getAfter(),
335  whileOp.getAfter().getArguments());
336  if (!boolAttr || !boolAttr.getValue())
337  regions.emplace_back(whileOp.getResults());
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // ForOp
342 //===----------------------------------------------------------------------===//
343 
344 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
345  Value ub, Value step, ValueRange initArgs,
346  BodyBuilderFn bodyBuilder, bool unsignedCmp) {
347  OpBuilder::InsertionGuard guard(builder);
348 
349  if (unsignedCmp)
350  result.addAttribute(getUnsignedCmpAttrName(result.name),
351  builder.getUnitAttr());
352  result.addOperands({lb, ub, step});
353  result.addOperands(initArgs);
354  for (Value v : initArgs)
355  result.addTypes(v.getType());
356  Type t = lb.getType();
357  Region *bodyRegion = result.addRegion();
358  Block *bodyBlock = builder.createBlock(bodyRegion);
359  bodyBlock->addArgument(t, result.location);
360  for (Value v : initArgs)
361  bodyBlock->addArgument(v.getType(), v.getLoc());
362 
363  // Create the default terminator if the builder is not provided and if the
364  // iteration arguments are not provided. Otherwise, leave this to the caller
365  // because we don't know which values to return from the loop.
366  if (initArgs.empty() && !bodyBuilder) {
367  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
368  } else if (bodyBuilder) {
369  OpBuilder::InsertionGuard guard(builder);
370  builder.setInsertionPointToStart(bodyBlock);
371  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
372  bodyBlock->getArguments().drop_front());
373  }
374 }
375 
376 LogicalResult ForOp::verify() {
377  // Check that the number of init args and op results is the same.
378  if (getInitArgs().size() != getNumResults())
379  return emitOpError(
380  "mismatch in number of loop-carried values and defined values");
381 
382  return success();
383 }
384 
385 LogicalResult ForOp::verifyRegions() {
386  // Check that the body defines a single block argument for the induction
387  // variable.
388  if (getInductionVar().getType() != getLowerBound().getType())
389  return emitOpError(
390  "expected induction variable to be same type as bounds and step");
391 
392  if (getNumRegionIterArgs() != getNumResults())
393  return emitOpError(
394  "mismatch in number of basic block args and defined values");
395 
396  auto initArgs = getInitArgs();
397  auto iterArgs = getRegionIterArgs();
398  auto opResults = getResults();
399  unsigned i = 0;
400  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
401  if (std::get<0>(e).getType() != std::get<2>(e).getType())
402  return emitOpError() << "types mismatch between " << i
403  << "th iter operand and defined value";
404  if (std::get<1>(e).getType() != std::get<2>(e).getType())
405  return emitOpError() << "types mismatch between " << i
406  << "th iter region arg and defined value";
407 
408  ++i;
409  }
410  return success();
411 }
412 
413 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
414  return SmallVector<Value>{getInductionVar()};
415 }
416 
417 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
418  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
419 }
420 
421 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
422  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
423 }
424 
425 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
426  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
427 }
428 
429 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
430 
431 /// Promotes the loop body of a forOp to its containing block if it can be
432 /// determined that the loop has a single iteration.
433 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
434  std::optional<APInt> tripCount = getStaticTripCount();
435  LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
436  << " for loop "
437  << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
438  if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
439  return failure();
440 
441  if (*tripCount == 0) {
442  rewriter.replaceAllUsesWith(getResults(), getInitArgs());
443  rewriter.eraseOp(*this);
444  return success();
445  }
446 
447  // Replace all results with the yielded values.
448  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
449  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
450 
451  // Replace block arguments with lower bound (replacement for IV) and
452  // iter_args.
453  SmallVector<Value> bbArgReplacements;
454  bbArgReplacements.push_back(getLowerBound());
455  llvm::append_range(bbArgReplacements, getInitArgs());
456 
457  // Move the loop body operations to the loop's containing block.
458  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
459  getOperation()->getIterator(), bbArgReplacements);
460 
461  // Erase the old terminator and the loop.
462  rewriter.eraseOp(yieldOp);
463  rewriter.eraseOp(*this);
464 
465  return success();
466 }
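// For illustration (assumed IR): a provably single-iteration loop such as
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%acc = %init) -> (f32) {
//     %v = "test.compute"(%i, %acc) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
// is promoted to
//   %v = "test.compute"(%c0, %init) : (index, f32) -> f32
// with %r replaced by %v; a zero-trip-count loop is simply replaced by its
// init values.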
467 
468 /// Prints the initialization list in the form of
469 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
470 /// where 'inner' values are assumed to be region arguments and 'outer' values
471 /// are regular SSA values.
472 static void printInitializationList(OpAsmPrinter &p,
473  Block::BlockArgListType blocksArgs,
474  ValueRange initializers,
475  StringRef prefix = "") {
476  assert(blocksArgs.size() == initializers.size() &&
477  "expected same length of arguments and initializers");
478  if (initializers.empty())
479  return;
480 
481  p << prefix << '(';
482  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
483  p << std::get<0>(it) << " = " << std::get<1>(it);
484  });
485  p << ")";
486 }
487 
488 void ForOp::print(OpAsmPrinter &p) {
489  if (getUnsignedCmp())
490  p << " unsigned";
491 
492  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
493  << getUpperBound() << " step " << getStep();
494 
495  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
496  if (!getInitArgs().empty())
497  p << " -> (" << getInitArgs().getTypes() << ')';
498  p << ' ';
499  if (Type t = getInductionVar().getType(); !t.isIndex())
500  p << " : " << t << ' ';
501  p.printRegion(getRegion(),
502  /*printEntryBlockArgs=*/false,
503  /*printBlockTerminators=*/!getInitArgs().empty());
504  p.printOptionalAttrDict((*this)->getAttrs(),
505  /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
506 }
507 
508 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
509  auto &builder = parser.getBuilder();
510  Type type;
511 
512  OpAsmParser::Argument inductionVariable;
513  OpAsmParser::UnresolvedOperand lb, ub, step;
514 
515  if (succeeded(parser.parseOptionalKeyword("unsigned")))
516  result.addAttribute(getUnsignedCmpAttrName(result.name),
517  builder.getUnitAttr());
518 
519  // Parse the induction variable followed by '='.
520  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
521  // Parse loop bounds.
522  parser.parseOperand(lb) || parser.parseKeyword("to") ||
523  parser.parseOperand(ub) || parser.parseKeyword("step") ||
524  parser.parseOperand(step))
525  return failure();
526 
527  // Parse the optional initial iteration arguments.
528  SmallVector<OpAsmParser::Argument, 4> regionArgs;
529  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
530  regionArgs.push_back(inductionVariable);
531 
532  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
533  if (hasIterArgs) {
534  // Parse assignment list and results type list.
535  if (parser.parseAssignmentList(regionArgs, operands) ||
536  parser.parseArrowTypeList(result.types))
537  return failure();
538  }
539 
540  if (regionArgs.size() != result.types.size() + 1)
541  return parser.emitError(
542  parser.getNameLoc(),
543  "mismatch in number of loop-carried values and defined values");
544 
545  // Parse optional type, else assume Index.
546  if (parser.parseOptionalColon())
547  type = builder.getIndexType();
548  else if (parser.parseType(type))
549  return failure();
550 
551  // Set block argument types, so that they are known when parsing the region.
552  regionArgs.front().type = type;
553  for (auto [iterArg, type] :
554  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
555  iterArg.type = type;
556 
557  // Parse the body region.
558  Region *body = result.addRegion();
559  if (parser.parseRegion(*body, regionArgs))
560  return failure();
561  ForOp::ensureTerminator(*body, builder, result.location);
562 
563  // Resolve input operands. This should be done after parsing the region to
564  // catch invalid IR where operands were defined inside of the region.
565  if (parser.resolveOperand(lb, type, result.operands) ||
566  parser.resolveOperand(ub, type, result.operands) ||
567  parser.resolveOperand(step, type, result.operands))
568  return failure();
569  if (hasIterArgs) {
570  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
571  operands, result.types)) {
572  Type type = std::get<2>(argOperandType);
573  std::get<0>(argOperandType).type = type;
574  if (parser.resolveOperand(std::get<1>(argOperandType), type,
575  result.operands))
576  return failure();
577  }
578  }
579 
580  // Parse the optional attribute list.
581  if (parser.parseOptionalAttrDict(result.attributes))
582  return failure();
583 
584  return success();
585 }
586 
587 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
588 
589 Block::BlockArgListType ForOp::getRegionIterArgs() {
590  return getBody()->getArguments().drop_front(getNumInductionVars());
591 }
592 
593 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
594  return getInitArgsMutable();
595 }
596 
597 FailureOr<LoopLikeOpInterface>
598 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
599  ValueRange newInitOperands,
600  bool replaceInitOperandUsesInLoop,
601  const NewYieldValuesFn &newYieldValuesFn) {
602  // Create a new loop before the existing one, with the extra operands.
603  OpBuilder::InsertionGuard g(rewriter);
604  rewriter.setInsertionPoint(getOperation());
605  auto inits = llvm::to_vector(getInitArgs());
606  inits.append(newInitOperands.begin(), newInitOperands.end());
607  scf::ForOp newLoop = scf::ForOp::create(
608  rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
609  [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
610  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
611 
612  // Generate the new yield values and append them to the scf.yield operation.
613  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
614  ArrayRef<BlockArgument> newIterArgs =
615  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
616  {
617  OpBuilder::InsertionGuard g(rewriter);
618  rewriter.setInsertionPoint(yieldOp);
619  SmallVector<Value> newYieldedValues =
620  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
621  assert(newInitOperands.size() == newYieldedValues.size() &&
622  "expected as many new yield values as new iter operands");
623  rewriter.modifyOpInPlace(yieldOp, [&]() {
624  yieldOp.getResultsMutable().append(newYieldedValues);
625  });
626  }
627 
628  // Move the loop body to the new op.
629  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
630  newLoop.getBody()->getArguments().take_front(
631  getBody()->getNumArguments()));
632 
633  if (replaceInitOperandUsesInLoop) {
634  // Replace all uses of `newInitOperands` with the corresponding basic block
635  // arguments.
636  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
637  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
638  [&](OpOperand &use) {
639  Operation *user = use.getOwner();
640  return newLoop->isProperAncestor(user);
641  });
642  }
643  }
644 
645  // Replace the old loop.
646  rewriter.replaceOp(getOperation(),
647  newLoop->getResults().take_front(getNumResults()));
648  return cast<LoopLikeOpInterface>(newLoop.getOperation());
649 }
650 
651 ForOp mlir::scf::getForInductionVarOwner(Value val) {
652  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
653  if (!ivArg)
654  return ForOp();
655  assert(ivArg.getOwner() && "unlinked block argument");
656  auto *containingOp = ivArg.getOwner()->getParentOp();
657  return dyn_cast_or_null<ForOp>(containingOp);
658 }
659 
660 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
661  return getInitArgs();
662 }
663 
664 void ForOp::getSuccessorRegions(RegionBranchPoint point,
665  SmallVectorImpl<RegionSuccessor> &regions) {
666  // Both the operation itself and the region may be branching into the body or
667  // back into the operation itself. It is possible for the loop not to enter
668  // the body.
669  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
670  regions.push_back(RegionSuccessor(getResults()));
671 }
672 
673 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
674 
675 /// Promotes the loop body of a forallOp to its containing block if it can be
676 /// determined that the loop has a single iteration.
677 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
678  for (auto [lb, ub, step] :
679  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
680  auto tripCount =
681  constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
682  if (!tripCount.has_value() || *tripCount != 1)
683  return failure();
684  }
685 
686  promote(rewriter, *this);
687  return success();
688 }
689 
690 Block::BlockArgListType ForallOp::getRegionIterArgs() {
691  return getBody()->getArguments().drop_front(getRank());
692 }
693 
694 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
695  return getOutputsMutable();
696 }
697 
698 /// Promotes the loop body of a scf::ForallOp to its containing block.
699 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
700  OpBuilder::InsertionGuard g(rewriter);
701  scf::InParallelOp terminator = forallOp.getTerminator();
702 
703  // Replace block arguments with lower bounds (replacements for IVs) and
704  // outputs.
705  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
706  bbArgReplacements.append(forallOp.getOutputs().begin(),
707  forallOp.getOutputs().end());
708 
709  // Move the loop body operations to the loop's containing block.
710  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
711  forallOp->getIterator(), bbArgReplacements);
712 
713  // Replace the terminator with tensor.insert_slice ops.
714  rewriter.setInsertionPointAfter(forallOp);
715  SmallVector<Value> results;
716  results.reserve(forallOp.getResults().size());
717  for (auto &yieldingOp : terminator.getYieldingOps()) {
718  auto parallelInsertSliceOp =
719  dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
720  if (!parallelInsertSliceOp)
721  continue;
722 
723  Value dst = parallelInsertSliceOp.getDest();
724  Value src = parallelInsertSliceOp.getSource();
725  if (llvm::isa<TensorType>(src.getType())) {
726  results.push_back(tensor::InsertSliceOp::create(
727  rewriter, forallOp.getLoc(), dst.getType(), src, dst,
728  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
729  parallelInsertSliceOp.getStrides(),
730  parallelInsertSliceOp.getStaticOffsets(),
731  parallelInsertSliceOp.getStaticSizes(),
732  parallelInsertSliceOp.getStaticStrides()));
733  } else {
734  llvm_unreachable("unsupported terminator");
735  }
736  }
737  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
738 
739  // Erase the old terminator and the loop.
740  rewriter.eraseOp(terminator);
741  rewriter.eraseOp(forallOp);
742 }
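// For illustration (assumed IR): a terminator entry such as
//   scf.forall.in_parallel {
//     tensor.parallel_insert_slice %src into %out[0] [4] [1]
//         : tensor<4xf32> into tensor<8xf32>
//   }
// is rewritten after promotion into
//   %r = tensor.insert_slice %src into %dest[0] [4] [1]
//       : tensor<4xf32> into tensor<8xf32>
// where %dest is the corresponding forall output, and the matching forall
// result is replaced by %r.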
743 
744 LoopNest mlir::scf::buildLoopNest(
745  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
746  ValueRange steps, ValueRange iterArgs,
747  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
748  bodyBuilder) {
749  assert(lbs.size() == ubs.size() &&
750  "expected the same number of lower and upper bounds");
751  assert(lbs.size() == steps.size() &&
752  "expected the same number of lower bounds and steps");
753 
754  // If there are no bounds, call the body-building function and return early.
755  if (lbs.empty()) {
756  ValueVector results =
757  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
758  : ValueVector();
759  assert(results.size() == iterArgs.size() &&
760  "loop nest body must return as many values as loop has iteration "
761  "arguments");
762  return LoopNest{{}, std::move(results)};
763  }
764 
765  // First, create the loop structure iteratively using the body-builder
766  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
767  OpBuilder::InsertionGuard guard(builder);
768  SmallVector<scf::ForOp, 4> loops;
769  SmallVector<Value, 4> ivs;
770  loops.reserve(lbs.size());
771  ivs.reserve(lbs.size());
772  ValueRange currentIterArgs = iterArgs;
773  Location currentLoc = loc;
774  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
775  auto loop = scf::ForOp::create(
776  builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
777  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
778  ValueRange args) {
779  ivs.push_back(iv);
780  // It is safe to store ValueRange args because it points to block
781  // arguments of a loop operation that we also own.
782  currentIterArgs = args;
783  currentLoc = nestedLoc;
784  });
785  // Set the builder to point to the body of the newly created loop. We don't
786  // do this in the callback because the builder is reset when the callback
787  // returns.
788  builder.setInsertionPointToStart(loop.getBody());
789  loops.push_back(loop);
790  }
791 
792  // For all loops but the innermost, yield the results of the nested loop.
793  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
794  builder.setInsertionPointToEnd(loops[i].getBody());
795  scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
796  }
797 
798  // In the body of the innermost loop, call the body building function if any
799  // and yield its results.
800  builder.setInsertionPointToStart(loops.back().getBody());
801  ValueVector results = bodyBuilder
802  ? bodyBuilder(builder, currentLoc, ivs,
803  loops.back().getRegionIterArgs())
804  : ValueVector();
805  assert(results.size() == iterArgs.size() &&
806  "loop nest body must return as many values as loop has iteration "
807  "arguments");
808  builder.setInsertionPointToEnd(loops.back().getBody());
809  scf::YieldOp::create(builder, loc, results);
810 
811  // Return the loops.
812  ValueVector nestResults;
813  llvm::append_range(nestResults, loops.front().getResults());
814  return LoopNest{std::move(loops), std::move(nestResults)};
815 }
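// Example usage (a sketch; `lbs`, `ubs`, `steps` and `init` are assumed to be
// loop-compatible values created by the caller):
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, lbs, ubs, steps, ValueRange{init},
//       [&](OpBuilder &b, Location l, ValueRange ivs, ValueRange iterArgs)
//           -> scf::ValueVector {
//         // Compute and return the next iteration arguments here.
//         return {iterArgs[0]};
//       });
//   Value result = nest.results.front();
// The callback runs in the innermost body and must return exactly one value
// per iteration argument; the yields of the outer loops are created here.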
816 
817 LoopNest mlir::scf::buildLoopNest(
818  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
819  ValueRange steps,
820  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
821  // Delegate to the main function by wrapping the body builder.
822  return buildLoopNest(builder, loc, lbs, ubs, steps, {},
823  [&bodyBuilder](OpBuilder &nestedBuilder,
824  Location nestedLoc, ValueRange ivs,
825  ValueRange) -> ValueVector {
826  if (bodyBuilder)
827  bodyBuilder(nestedBuilder, nestedLoc, ivs);
828  return {};
829  });
830 }
831 
832 SmallVector<Value>
833 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
834  OpOperand &operand, Value replacement,
835  const ValueTypeCastFnTy &castFn) {
836  assert(operand.getOwner() == forOp);
837  Type oldType = operand.get().getType(), newType = replacement.getType();
838 
839  // 1. Create new iter operands, exactly 1 is replaced.
840  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
841  "expected an iter OpOperand");
842  assert(operand.get().getType() != replacement.getType() &&
843  "Expected a different type");
844  SmallVector<Value> newIterOperands;
845  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
846  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
847  newIterOperands.push_back(replacement);
848  continue;
849  }
850  newIterOperands.push_back(opOperand.get());
851  }
852 
853  // 2. Create the new forOp shell.
854  scf::ForOp newForOp = scf::ForOp::create(
855  rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
856  forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
857  forOp.getUnsignedCmp());
858  newForOp->setAttrs(forOp->getAttrs());
859  Block &newBlock = newForOp.getRegion().front();
860  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
861  newBlock.getArguments().end());
862 
863  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
864  // corresponding to the `replacement` value.
865  OpBuilder::InsertionGuard g(rewriter);
866  rewriter.setInsertionPointToStart(&newBlock);
867  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
868  &newForOp->getOpOperand(operand.getOperandNumber()));
869  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
870  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
871 
872  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
873  Block &oldBlock = forOp.getRegion().front();
874  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
875 
876  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
877  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
878  rewriter.setInsertionPoint(clonedYieldOp);
879  unsigned yieldIdx =
880  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
881  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
882  clonedYieldOp.getOperand(yieldIdx));
883  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
884  newYieldOperands[yieldIdx] = castOut;
885  scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
886  rewriter.eraseOp(clonedYieldOp);
887 
888  // 6. Inject an outgoing cast op after the forOp.
889  rewriter.setInsertionPointAfter(newForOp);
890  SmallVector<Value> newResults = newForOp.getResults();
891  newResults[yieldIdx] =
892  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
893 
894  return newResults;
895 }
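// The ForOpTensorCastFolder pattern below is a concrete user of this helper:
// it supplies tensor::CastOp as the cast operation in order to pull casts on
// iter_args/results into the loop body.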
896 
897 namespace {
898 // Fold away ForOp iter arguments when:
899 // 1) The op yields the iter arguments.
900 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
901 // 3) The iter arguments have no use and the corresponding (operation) results
902 // have no use.
903 //
904 // These arguments must be defined outside of the ForOp region and can just be
905 // forwarded after simplifying the op inits, yields and returns.
906 //
907 // The implementation uses `inlineBlockBefore` to steal the content of the
908 // original ForOp and avoid cloning.
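// For illustration (assumed IR): in
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     ...
//     scf.yield %a : f32  // iter argument yielded unchanged
//   }
// the iter argument is forwarded: the loop is rebuilt without it and every
// use of %r is replaced by %init.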
909 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
910  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
911 
912  LogicalResult matchAndRewrite(scf::ForOp forOp,
913  PatternRewriter &rewriter) const final {
914  bool canonicalize = false;
915 
916  // An internal flat vector of block transfer
917  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
918  // transformed block arguments. This plays the role of an
919  // IRMapping for the particular use case of calling into
920  // `inlineBlockBefore`.
921  int64_t numResults = forOp.getNumResults();
922  SmallVector<bool, 4> keepMask;
923  keepMask.reserve(numResults);
924  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
925  newResultValues;
926  newBlockTransferArgs.reserve(1 + numResults);
927  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
928  newIterArgs.reserve(forOp.getInitArgs().size());
929  newYieldValues.reserve(numResults);
930  newResultValues.reserve(numResults);
931  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
932  for (auto [init, arg, result, yielded] :
933  llvm::zip(forOp.getInitArgs(), // iter from outside
934  forOp.getRegionIterArgs(), // iter inside region
935  forOp.getResults(), // op results
936  forOp.getYieldedValues() // iter yield
937  )) {
938  // Forwarded is `true` when:
939  // 1) The region `iter` argument is yielded.
940  // 2) The input corresponding to the region `iter` argument is yielded.
941  // 3) The region `iter` argument has no use, and the corresponding op
942  // result has no use.
943  bool forwarded = (arg == yielded) || (init == yielded) ||
944  (arg.use_empty() && result.use_empty());
945  if (forwarded) {
946  canonicalize = true;
947  keepMask.push_back(false);
948  newBlockTransferArgs.push_back(init);
949  newResultValues.push_back(init);
950  continue;
951  }
952 
953  // Check if a previous kept argument always has the same values for init
954  // and yielded values.
955  if (auto it = initYieldToArg.find({init, yielded});
956  it != initYieldToArg.end()) {
957  canonicalize = true;
958  keepMask.push_back(false);
959  auto [sameArg, sameResult] = it->second;
960  rewriter.replaceAllUsesWith(arg, sameArg);
961  rewriter.replaceAllUsesWith(result, sameResult);
962  // The replacement value doesn't matter because there are no uses.
963  newBlockTransferArgs.push_back(init);
964  newResultValues.push_back(init);
965  continue;
966  }
967 
968  // This value is kept.
969  initYieldToArg.insert({{init, yielded}, {arg, result}});
970  keepMask.push_back(true);
971  newIterArgs.push_back(init);
972  newYieldValues.push_back(yielded);
973  newBlockTransferArgs.push_back(Value()); // placeholder with null value
974  newResultValues.push_back(Value()); // placeholder with null value
975  }
976 
977  if (!canonicalize)
978  return failure();
979 
980  scf::ForOp newForOp =
981  scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
982  forOp.getUpperBound(), forOp.getStep(), newIterArgs,
983  /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
984  newForOp->setAttrs(forOp->getAttrs());
985  Block &newBlock = newForOp.getRegion().front();
986 
987  // Replace the null placeholders with newly constructed values.
988  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
989  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
990  idx != e; ++idx) {
991  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
992  Value &newResultVal = newResultValues[idx];
993  assert((blockTransferArg && newResultVal) ||
994  (!blockTransferArg && !newResultVal));
995  if (!blockTransferArg) {
996  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
997  newResultVal = newForOp.getResult(collapsedIdx++);
998  }
999  }
1000 
1001  Block &oldBlock = forOp.getRegion().front();
1002  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
1003  "unexpected argument size mismatch");
1004 
1005  // No results case: the scf::ForOp builder already created a zero
1006  // result terminator. Merge before this terminator and just get rid of the
1007  // original terminator that has been merged in.
1008  if (newIterArgs.empty()) {
1009  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1010  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
1011  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
1012  rewriter.replaceOp(forOp, newResultValues);
1013  return success();
1014  }
1015 
1016  // No terminator case: merge and rewrite the merged terminator.
1017  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
1018  OpBuilder::InsertionGuard g(rewriter);
1019  rewriter.setInsertionPoint(mergedTerminator);
1020  SmallVector<Value, 4> filteredOperands;
1021  filteredOperands.reserve(newResultValues.size());
1022  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
1023  if (keepMask[idx])
1024  filteredOperands.push_back(mergedTerminator.getOperand(idx));
1025  scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
1026  filteredOperands);
1027  };
1028 
1029  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1030  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1031  cloneFilteredTerminator(mergedYieldOp);
1032  rewriter.eraseOp(mergedYieldOp);
1033  rewriter.replaceOp(forOp, newResultValues);
1034  return success();
1035  }
1036 };
1037 
1038 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1039 /// single-iteration loops with their bodies, and removes empty loops that
1040 /// iterate at least once and only return values defined outside of the loop.
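// For illustration (assumed IR): a statically empty range such as
//   %r = scf.for %i = %c4 to %c4 step %c1 iter_args(%a = %init) -> (f32) {...}
// folds to %init, and an empty-bodied loop with a known trip count above one
// that only yields externally defined values, e.g.
//   %r = scf.for %i = %c0 to %c8 step %c1 iter_args(%a = %init) -> (f32) {
//     scf.yield %ext : f32
//   }
// folds to %ext.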
1041 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1042  using OpRewritePattern<ForOp>::OpRewritePattern;
1043 
1044  LogicalResult matchAndRewrite(ForOp op,
1045  PatternRewriter &rewriter) const override {
1046  std::optional<APInt> tripCount = op.getStaticTripCount();
1047  if (!tripCount.has_value())
1048  return rewriter.notifyMatchFailure(op,
1049  "can't compute constant trip count");
1050 
1051  if (tripCount->isZero()) {
1052  LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
1053  << OpWithFlags(op, OpPrintingFlags().skipRegions());
1054  rewriter.replaceOp(op, op.getInitArgs());
1055  return success();
1056  }
1057 
1058  if (tripCount->getSExtValue() == 1) {
1059  LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
1060  << OpWithFlags(op, OpPrintingFlags().skipRegions());
1061  SmallVector<Value, 4> blockArgs;
1062  blockArgs.reserve(op.getInitArgs().size() + 1);
1063  blockArgs.push_back(op.getLowerBound());
1064  llvm::append_range(blockArgs, op.getInitArgs());
1065  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1066  return success();
1067  }
1068 
1069  // Now we are left with loops that have more than 1 iteration.
1070  Block &block = op.getRegion().front();
1071  if (!llvm::hasSingleElement(block))
1072  return failure();
1073  // The loop is empty and iterates at least once; if it only returns values
1074  // defined outside of the loop, remove it and replace it with the yielded values.
1075  if (llvm::any_of(op.getYieldedValues(),
1076  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1077  return failure();
1078  LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
1079  "yield operands for loop "
1080  << OpWithFlags(op, OpPrintingFlags().skipRegions());
1081  rewriter.replaceOp(op, op.getYieldedValues());
1082  return success();
1083  }
1084 };
1085 
1086 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1087 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1088 ///
1089 /// ```
1090 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1091 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1092 /// -> (tensor<?x?xf32>) {
1093 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1094 /// scf.yield %2 : tensor<?x?xf32>
1095 /// }
1096 /// use_of(%1)
1097 /// ```
1098 ///
1099 /// folds into:
1100 ///
1101 /// ```
1102 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1103 /// -> (tensor<32x1024xf32>) {
1104 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1105 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1106 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1107 /// scf.yield %4 : tensor<32x1024xf32>
1108 /// }
1109 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1110 /// use_of(%1)
1111 /// ```
1112 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1113  using OpRewritePattern<ForOp>::OpRewritePattern;
1114 
1115  LogicalResult matchAndRewrite(ForOp op,
1116  PatternRewriter &rewriter) const override {
1117  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1118  OpOperand &iterOpOperand = std::get<0>(it);
1119  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1120  if (!incomingCast ||
1121  incomingCast.getSource().getType() == incomingCast.getType())
1122  continue;
1123  // If the dest type of the cast does not preserve static information in
1124  // the source type.
1125  if (!tensor::preservesStaticInformation(
1126  incomingCast.getDest().getType(),
1127  incomingCast.getSource().getType()))
1128  continue;
1129  if (!std::get<1>(it).hasOneUse())
1130  continue;
1131 
1132  // Create a new ForOp with that iter operand replaced.
1133  rewriter.replaceOp(
1134  op, replaceAndCastForOpIterArg(
1135  rewriter, op, iterOpOperand, incomingCast.getSource(),
1136  [](OpBuilder &b, Location loc, Type type, Value source) {
1137  return tensor::CastOp::create(b, loc, type, source);
1138  }));
1139  return success();
1140  }
1141  return failure();
1142  }
1143 };
1144 
1145 } // namespace
1146 
1147 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1148  MLIRContext *context) {
1149  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1150  context);
1151 }
1152 
1153 std::optional<APInt> ForOp::getConstantStep() {
1154  IntegerAttr step;
1155  if (matchPattern(getStep(), m_Constant(&step)))
1156  return step.getValue();
1157  return {};
1158 }
1159 
1160 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1161  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1162 }
1163 
1164 Speculation::Speculatability ForOp::getSpeculatability() {
1165  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1166  // and End.
1167  if (auto constantStep = getConstantStep())
1168  if (*constantStep == 1)
1169  return Speculation::RecursivelySpeculatable;
1170 
1171  // For Step != 1, the loop may not terminate. We can add more smarts here if
1172  // needed.
1173  return Speculation::NotSpeculatable;
1174 }
1175 
1176 std::optional<APInt> ForOp::getStaticTripCount() {
1177  return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
1178  /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // ForallOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 LogicalResult ForallOp::verify() {
1186  unsigned numLoops = getRank();
1187  // Check number of outputs.
1188  if (getNumResults() != getOutputs().size())
1189  return emitOpError("produces ")
1190  << getNumResults() << " results, but has only "
1191  << getOutputs().size() << " outputs";
1192 
1193  // Check that the body defines block arguments for thread indices and outputs.
1194  auto *body = getBody();
1195  if (body->getNumArguments() != numLoops + getOutputs().size())
1196  return emitOpError("region expects ") << numLoops << " arguments";
1197  for (int64_t i = 0; i < numLoops; ++i)
1198  if (!body->getArgument(i).getType().isIndex())
1199  return emitOpError("expects ")
1200  << i << "-th block argument to be an index";
1201  for (unsigned i = 0; i < getOutputs().size(); ++i)
1202  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1203  return emitOpError("type mismatch between ")
1204  << i << "-th output and corresponding block argument";
1205  if (getMapping().has_value() && !getMapping()->empty()) {
1206  if (getDeviceMappingAttrs().size() != numLoops)
1207  return emitOpError() << "mapping attribute size must match op rank";
1208  if (failed(getDeviceMaskingAttr()))
1209  return emitOpError() << getMappingAttrName()
1210  << " supports at most one device masking attribute";
1211  }
1212 
1213  // Verify mixed static/dynamic control variables.
1214  Operation *op = getOperation();
1215  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1216  getStaticLowerBound(),
1217  getDynamicLowerBound())))
1218  return failure();
1219  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1220  getStaticUpperBound(),
1221  getDynamicUpperBound())))
1222  return failure();
1223  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1224  getStaticStep(), getDynamicStep())))
1225  return failure();
1226 
1227  return success();
1228 }
1229 
1230 void ForallOp::print(OpAsmPrinter &p) {
1231  Operation *op = getOperation();
1232  p << " (" << getInductionVars();
1233  if (isNormalized()) {
1234  p << ") in ";
1235  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1236  /*valueTypes=*/{}, /*scalables=*/{},
1237  OpAsmParser::Delimiter::Paren);
1238  } else {
1239  p << ") = ";
1240  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1241  /*valueTypes=*/{}, /*scalables=*/{},
1242  OpAsmParser::Delimiter::Paren);
1243  p << " to ";
1244  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1245  /*valueTypes=*/{}, /*scalables=*/{},
1246  OpAsmParser::Delimiter::Paren);
1247  p << " step ";
1248  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1249  /*valueTypes=*/{}, /*scalables=*/{},
1250  OpAsmParser::Delimiter::Paren);
1251  }
1252  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1253  p << " ";
1254  if (!getRegionOutArgs().empty())
1255  p << "-> (" << getResultTypes() << ") ";
1256  p.printRegion(getRegion(),
1257  /*printEntryBlockArgs=*/false,
1258  /*printBlockTerminators=*/getNumResults() > 0);
1259  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1260  getStaticLowerBoundAttrName(),
1261  getStaticUpperBoundAttrName(),
1262  getStaticStepAttrName()});
1263 }
1264 
1265 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1266  OpBuilder b(parser.getContext());
1267  auto indexType = b.getIndexType();
1268 
1269  // Parse an opening `(` followed by thread index variables followed by `)`
1270  // TODO: when we can refer to such "induction variable"-like handles from the
1271  // declarative assembly format, we can implement the parser as a custom hook.
1272  SmallVector<OpAsmParser::Argument, 4> ivs;
1273  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1274  return failure();
1275 
1276  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1277  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1278  dynamicSteps;
1279  if (succeeded(parser.parseOptionalKeyword("in"))) {
1280  // Parse upper bounds.
1281  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1282  /*valueTypes=*/nullptr,
1283  OpAsmParser::Delimiter::Paren) ||
1284  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1285  return failure();
1286 
1287  unsigned numLoops = ivs.size();
1288  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1289  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1290  } else {
1291  // Parse lower bounds.
1292  if (parser.parseEqual() ||
1293  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1294  /*valueTypes=*/nullptr,
1295  OpAsmParser::Delimiter::Paren) ||
1296 
1297  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1298  return failure();
1299 
1300  // Parse upper bounds.
1301  if (parser.parseKeyword("to") ||
1302  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1303  /*valueTypes=*/nullptr,
1304  OpAsmParser::Delimiter::Paren) ||
1305  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1306  return failure();
1307 
1308  // Parse step values.
1309  if (parser.parseKeyword("step") ||
1310  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1311  /*valueTypes=*/nullptr,
1312  OpAsmParser::Delimiter::Paren) ||
1313  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1314  return failure();
1315  }
1316 
1317  // Parse out operands and results.
1318  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1319  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1320  SMLoc outOperandsLoc = parser.getCurrentLocation();
1321  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1322  if (outOperands.size() != result.types.size())
1323  return parser.emitError(outOperandsLoc,
1324  "mismatch between out operands and types");
1325  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1326  parser.parseOptionalArrowTypeList(result.types) ||
1327  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1328  result.operands))
1329  return failure();
1330  }
1331 
1332  // Parse region.
1333  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1334  std::unique_ptr<Region> region = std::make_unique<Region>();
1335  for (auto &iv : ivs) {
1336  iv.type = b.getIndexType();
1337  regionArgs.push_back(iv);
1338  }
1339  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1340  auto &out = it.value();
1341  out.type = result.types[it.index()];
1342  regionArgs.push_back(out);
1343  }
1344  if (parser.parseRegion(*region, regionArgs))
1345  return failure();
1346 
1347  // Ensure terminator and move region.
1348  ForallOp::ensureTerminator(*region, b, result.location);
1349  result.addRegion(std::move(region));
1350 
1351  // Parse the optional attribute list.
1352  if (parser.parseOptionalAttrDict(result.attributes))
1353  return failure();
1354 
1355  result.addAttribute("staticLowerBound", staticLbs);
1356  result.addAttribute("staticUpperBound", staticUbs);
1357  result.addAttribute("staticStep", staticSteps);
1358  result.addAttribute("operandSegmentSizes",
1359  b.getDenseI32ArrayAttr(
1360  {static_cast<int32_t>(dynamicLbs.size()),
1361  static_cast<int32_t>(dynamicUbs.size()),
1362  static_cast<int32_t>(dynamicSteps.size()),
1363  static_cast<int32_t>(outOperands.size())}));
1364  return success();
1365 }
1366 
1367 // Builder that takes loop bounds.
1368 void ForallOp::build(
1369  OpBuilder &b, OperationState &result, ArrayRef<OpFoldResult> lbs,
1370  ArrayRef<OpFoldResult> ubs,
1371  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1372  std::optional<ArrayAttr> mapping,
1373  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1374  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1375  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1376  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1377  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1378  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1379 
1380  result.addOperands(dynamicLbs);
1381  result.addOperands(dynamicUbs);
1382  result.addOperands(dynamicSteps);
1383  result.addOperands(outputs);
1384  result.addTypes(TypeRange(outputs));
1385 
1386  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1387  b.getDenseI64ArrayAttr(staticLbs));
1388  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1389  b.getDenseI64ArrayAttr(staticUbs));
1390  result.addAttribute(getStaticStepAttrName(result.name),
1391  b.getDenseI64ArrayAttr(staticSteps));
1392  result.addAttribute(
1393  "operandSegmentSizes",
1394  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1395  static_cast<int32_t>(dynamicUbs.size()),
1396  static_cast<int32_t>(dynamicSteps.size()),
1397  static_cast<int32_t>(outputs.size())}));
1398  if (mapping.has_value()) {
1399  result.addAttribute(ForallOp::getMappingAttrName(result.name),
1400  mapping.value());
1401  }
1402 
1403  Region *bodyRegion = result.addRegion();
1404  OpBuilder::InsertionGuard g(b);
1405  b.createBlock(bodyRegion);
1406  Block &bodyBlock = bodyRegion->front();
1407 
1408  // Add block arguments for indices and outputs.
1409  bodyBlock.addArguments(
1410  SmallVector<Type>(lbs.size(), b.getIndexType()),
1411  SmallVector<Location>(staticLbs.size(), result.location));
1412  bodyBlock.addArguments(
1413  TypeRange(outputs),
1414  SmallVector<Location>(outputs.size(), result.location));
1415 
1416  b.setInsertionPointToStart(&bodyBlock);
1417  if (!bodyBuilderFn) {
1418  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1419  return;
1420  }
1421  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1422 }
1423 
1424 // Builder that takes loop bounds.
1425 void ForallOp::build(
1426  OpBuilder &b, OperationState &result,
1427  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1428  std::optional<ArrayAttr> mapping,
1429  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1430  unsigned numLoops = ubs.size();
1431  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1432  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1433  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1434 }
1435 
1436 // Checks if the lbs are zeros and steps are ones.
1437 bool ForallOp::isNormalized() {
1438  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1439  return llvm::all_of(results, [&](OpFoldResult ofr) {
1440  auto intValue = getConstantIntValue(ofr);
1441  return intValue.has_value() && intValue == val;
1442  });
1443  };
1444  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1445 }
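// For example, `scf.forall (%i, %j) in (16, 32)` (implicit zero lower bounds
// and unit steps) is normalized, whereas
// `scf.forall (%i) = (4) to (%n) step (2)` is not.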
1446 
1447 InParallelOp ForallOp::getTerminator() {
1448  return cast<InParallelOp>(getBody()->getTerminator());
1449 }
1450 
1451 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1452  SmallVector<Operation *> storeOps;
1453  for (Operation *user : bbArg.getUsers()) {
1454  if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
1455  storeOps.push_back(parallelOp);
1456  }
1457  }
1458  return storeOps;
1459 }
1460 
1461 SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1462  SmallVector<DeviceMappingAttrInterface> res;
1463  if (!getMapping())
1464  return res;
1465  for (auto attr : getMapping()->getValue()) {
1466  auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1467  if (m)
1468  res.push_back(m);
1469  }
1470  return res;
1471 }
1472 
1473 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1474  DeviceMaskingAttrInterface res;
1475  if (!getMapping())
1476  return res;
1477  for (auto attr : getMapping()->getValue()) {
1478  auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1479  if (m && res)
1480  return failure();
1481  if (m)
1482  res = m;
1483  }
1484  return res;
1485 }
1486 
1487 bool ForallOp::usesLinearMapping() {
1488  SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1489  if (ifaces.empty())
1490  return false;
1491  return ifaces.front().isLinearMapping();
1492 }
1493 
1494 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1495  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1496 }
1497 
1498 // Get lower bounds as OpFoldResult.
1499 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1500  Builder b(getOperation()->getContext());
1501  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1502 }
1503 
1504 // Get upper bounds as OpFoldResult.
1505 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1506  Builder b(getOperation()->getContext());
1507  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1508 }
1509 
1510 // Get steps as OpFoldResult.
1511 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1512  Builder b(getOperation()->getContext());
1513  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1514 }
1515 
1516 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1517  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1518  if (!tidxArg)
1519  return ForallOp();
1520  assert(tidxArg.getOwner() && "unlinked block argument");
1521  auto *containingOp = tidxArg.getOwner()->getParentOp();
1522  return dyn_cast<ForallOp>(containingOp);
1523 }
1524 
1525 namespace {
1526 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
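/// Illustrative sketch (hypothetical IR, not taken from a test):
///   %res = scf.forall ... shared_outs(%arg = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %res, %c0 : tensor<?xf32>
/// becomes
///   %d = tensor.dim %t, %c0 : tensor<?xf32>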
1527 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1528  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1529 
1530  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1531  PatternRewriter &rewriter) const final {
1532  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1533  if (!forallOp)
1534  return failure();
1535  Value sharedOut =
1536  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1537  ->get();
1538  rewriter.modifyOpInPlace(
1539  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1540  return success();
1541  }
1542 };
1543 
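/// Fold dynamic lower bounds, upper bounds, and steps of an scf.forall that
/// are defined by constants into their static attribute form. A hypothetical
/// sketch of the effect:
///   %c64 = arith.constant 64 : index
///   scf.forall (%i) in (%c64) shared_outs(%o = %t) -> (tensor<64xf32>)
/// becomes
///   scf.forall (%i) in (64) shared_outs(%o = %t) -> (tensor<64xf32>)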
1544 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1545 public:
1546  using OpRewritePattern<ForallOp>::OpRewritePattern;
1547 
1548  LogicalResult matchAndRewrite(ForallOp op,
1549  PatternRewriter &rewriter) const override {
1550  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1551  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1552  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1553  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1554  failed(foldDynamicIndexList(mixedUpperBound)) &&
1555  failed(foldDynamicIndexList(mixedStep)))
1556  return failure();
1557 
1558  rewriter.modifyOpInPlace(op, [&]() {
1559  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1560  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1561  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1562  staticLowerBound);
1563  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1564  op.setStaticLowerBound(staticLowerBound);
1565 
1566  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1567  staticUpperBound);
1568  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1569  op.setStaticUpperBound(staticUpperBound);
1570 
1571  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1572  op.getDynamicStepMutable().assign(dynamicStep);
1573  op.setStaticStep(staticStep);
1574 
1575  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1576  rewriter.getDenseI32ArrayAttr(
1577  {static_cast<int32_t>(dynamicLowerBound.size()),
1578  static_cast<int32_t>(dynamicUpperBound.size()),
1579  static_cast<int32_t>(dynamicStep.size()),
1580  static_cast<int32_t>(op.getNumResults())}));
1581  });
1582  return success();
1583  }
1584 };
1585 
1586 /// The following canonicalization pattern folds the iter arguments of an
1587 /// scf.forall op if either:
1588 /// 1. The corresponding result has zero uses, or
1589 /// 2. The iter argument is NOT being modified within the loop body (no
1590 ///    combining store op targets it).
1591 ///
1592 /// Example of first case :-
1593 /// INPUT:
1594 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1595 /// {
1596 /// ...
1597 /// <SOME USE OF %arg0>
1598 /// <SOME USE OF %arg1>
1599 /// <SOME USE OF %arg2>
1600 /// ...
1601 /// scf.forall.in_parallel {
1602 /// <STORE OP WITH DESTINATION %arg1>
1603 /// <STORE OP WITH DESTINATION %arg0>
1604 /// <STORE OP WITH DESTINATION %arg2>
1605 /// }
1606 /// }
1607 /// return %res#1
1608 ///
1609 /// OUTPUT:
1610 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1611 /// {
1612 /// ...
1613 /// <SOME USE OF %a>
1614 /// <SOME USE OF %new_arg0>
1615 /// <SOME USE OF %c>
1616 /// ...
1617 /// scf.forall.in_parallel {
1618 /// <STORE OP WITH DESTINATION %new_arg0>
1619 /// }
1620 /// }
1621 /// return %res
1622 ///
1623 /// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
1624 /// scf.forall are replaced by their corresponding operands.
1625 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1626 /// of the scf.forall besides within scf.forall.in_parallel terminator,
1627 /// this canonicalization remains valid. For more details, please refer
1628 /// to :
1629 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1630 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1631 /// handles tensor.parallel_insert_slice ops only.
1632 ///
1633 /// Example of second case :-
1634 /// INPUT:
1635 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1636 /// {
1637 /// ...
1638 /// <SOME USE OF %arg0>
1639 /// <SOME USE OF %arg1>
1640 /// ...
1641 /// scf.forall.in_parallel {
1642 /// <STORE OP WITH DESTINATION %arg1>
1643 /// }
1644 /// }
1645 /// return %res#0, %res#1
1646 ///
1647 /// OUTPUT:
1648 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1649 /// {
1650 /// ...
1651 /// <SOME USE OF %a>
1652 /// <SOME USE OF %new_arg0>
1653 /// ...
1654 /// scf.forall.in_parallel {
1655 /// <STORE OP WITH DESTINATION %new_arg0>
1656 /// }
1657 /// }
1658 /// return %a, %res
1659 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1660  using OpRewritePattern<ForallOp>::OpRewritePattern;
1661 
1662  LogicalResult matchAndRewrite(ForallOp forallOp,
1663  PatternRewriter &rewriter) const final {
1664  // Step 1: For a given i-th result of scf.forall, check the following :-
1665  // a. If it has any use.
1666  // b. If the corresponding iter argument is being modified within
1667  // the loop, i.e. has at least one store op with the iter arg as
1668  // its destination operand. For this we use
1669  // ForallOp::getCombiningOps(iter_arg).
1670  //
1671  // Based on the check we maintain the following :-
1672  // a. `resultToDelete` - i-th result of scf.forall that'll be
1673  // deleted.
1674  // b. `resultToReplace` - i-th result of the old scf.forall
1675  // whose uses will be replaced by the new scf.forall.
1676  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1677  // corresponding to the i-th result with at least one use.
1678  SetVector<OpResult> resultToDelete;
1679  SmallVector<Value> resultToReplace;
1680  SmallVector<Value> newOuts;
1681  for (OpResult result : forallOp.getResults()) {
1682  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1683  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1684  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1685  resultToDelete.insert(result);
1686  } else {
1687  resultToReplace.push_back(result);
1688  newOuts.push_back(opOperand->get());
1689  }
1690  }
1691 
1692  // Return early if all results of scf.forall have at least one use and are
1693  // being modified within the loop.
1694  if (resultToDelete.empty())
1695  return failure();
1696 
1697  // Step 2: For the i-th result, do the following :-
1698  // a. Fetch the corresponding BlockArgument.
1699  // b. Look for store ops (currently tensor.parallel_insert_slice)
1700  // with the BlockArgument as its destination operand.
1701  // c. Remove the operations fetched in b.
1702  for (OpResult result : resultToDelete) {
1703  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1704  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1705  SmallVector<Operation *> combiningOps =
1706  forallOp.getCombiningOps(blockArg);
1707  for (Operation *combiningOp : combiningOps)
1708  rewriter.eraseOp(combiningOp);
1709  }
1710 
1711  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1712  // fetched earlier
1713  auto newForallOp = scf::ForallOp::create(
1714  rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1715  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1716  forallOp.getMapping(),
1717  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1718 
1719  // Step 4. Merge the block of the old scf.forall into the newly created
1720  // scf.forall using the new set of arguments.
1721  Block *loopBody = forallOp.getBody();
1722  Block *newLoopBody = newForallOp.getBody();
1723  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1724  // Form initial new bbArg list with just the control operands of the new
1725  // scf.forall op.
1726  SmallVector<Value> newBlockArgs =
1727  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1728  [](BlockArgument b) -> Value { return b; });
1729  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1730  unsigned index = 0;
1731  // Take the new corresponding bbArg if the old bbArg was used as a
1732  // destination in the in_parallel op. For all other bbArgs, use the
1733  // corresponding init_arg from the old scf.forall op.
1734  for (OpResult result : forallOp.getResults()) {
1735  if (resultToDelete.count(result)) {
1736  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1737  } else {
1738  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1739  }
1740  }
1741  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1742 
1743  // Step 5. Replace the uses of the results of the old scf.forall with those
1744  // of the new scf.forall.
1745  for (auto &&[oldResult, newResult] :
1746  llvm::zip(resultToReplace, newForallOp->getResults()))
1747  rewriter.replaceAllUsesWith(oldResult, newResult);
1748 
1749  // Step 6. Replace the uses of those results that either have no use or are
1750  // not being modified within the loop with the corresponding
1751  // OpOperand.
1752  for (OpResult oldResult : resultToDelete)
1753  rewriter.replaceAllUsesWith(oldResult,
1754  forallOp.getTiedOpOperand(oldResult)->get());
1755  return success();
1756  }
1757 };
1758 
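/// Fold away scf.forall dimensions that perform at most one iteration (skipped
/// when a device mapping is attached). A hypothetical sketch:
///   scf.forall (%i, %j) in (1, 128) shared_outs(%o = %t) -> (tensor<128xf32>)
/// becomes, with uses of %i replaced by the constant lower bound 0,
///   scf.forall (%j) in (128) shared_outs(%o = %t) -> (tensor<128xf32>)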
1759 struct ForallOpSingleOrZeroIterationDimsFolder
1760  : public OpRewritePattern<ForallOp> {
1761  using OpRewritePattern<ForallOp>::OpRewritePattern;
1762 
1763  LogicalResult matchAndRewrite(ForallOp op,
1764  PatternRewriter &rewriter) const override {
1765  // Do not fold dimensions if they are mapped to processing units.
1766  if (op.getMapping().has_value() && !op.getMapping()->empty())
1767  return failure();
1768  Location loc = op.getLoc();
1769 
1770  // Compute new loop bounds that omit all single-iteration loop dimensions.
1771  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1772  newMixedSteps;
1773  IRMapping mapping;
1774  for (auto [lb, ub, step, iv] :
1775  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1776  op.getMixedStep(), op.getInductionVars())) {
1777  auto numIterations =
1778  constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1779  if (numIterations.has_value()) {
1780  // Remove the loop if it performs zero iterations.
1781  if (*numIterations == 0) {
1782  rewriter.replaceOp(op, op.getOutputs());
1783  return success();
1784  }
1785  // Replace the loop induction variable by the lower bound if the loop
1786  // performs a single iteration. Otherwise, copy the loop bounds.
1787  if (*numIterations == 1) {
1788  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1789  continue;
1790  }
1791  }
1792  newMixedLowerBounds.push_back(lb);
1793  newMixedUpperBounds.push_back(ub);
1794  newMixedSteps.push_back(step);
1795  }
1796 
1797  // All of the loop dimensions perform a single iteration. Inline loop body.
1798  if (newMixedLowerBounds.empty()) {
1799  promote(rewriter, op);
1800  return success();
1801  }
1802 
1803  // Exit if none of the loop dimensions perform zero or a single iteration.
1804  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1805  return rewriter.notifyMatchFailure(
1806  op, "no dimensions have 0 or 1 iterations");
1807  }
1808 
1809  // Replace the loop by a lower-dimensional loop.
1810  ForallOp newOp;
1811  newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1812  newMixedUpperBounds, newMixedSteps,
1813  op.getOutputs(), std::nullopt, nullptr);
1814  newOp.getBodyRegion().getBlocks().clear();
1815  // The new loop needs to keep all attributes from the old one, except for
1816  // "operandSegmentSizes" and static loop bound attributes which capture
1817  // the outdated information of the old iteration domain.
1818  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1819  newOp.getStaticLowerBoundAttrName(),
1820  newOp.getStaticUpperBoundAttrName(),
1821  newOp.getStaticStepAttrName()};
1822  for (const auto &namedAttr : op->getAttrs()) {
1823  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1824  continue;
1825  rewriter.modifyOpInPlace(newOp, [&]() {
1826  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1827  });
1828  }
1829  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1830  newOp.getRegion().begin(), mapping);
1831  rewriter.replaceOp(op, newOp.getResults());
1832  return success();
1833  }
1834 };
1835 
1836 /// Replace induction variables of loops with a single trip count by their lower bound.
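/// For example (an illustrative sketch): in `scf.forall (%i, %j) in (1, %n)`,
/// all uses of %i are replaced by the constant 0 while the loop is kept as is.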
1837 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1838  using OpRewritePattern<ForallOp>::OpRewritePattern;
1839 
1840  LogicalResult matchAndRewrite(ForallOp op,
1841  PatternRewriter &rewriter) const override {
1842  Location loc = op.getLoc();
1843  bool changed = false;
1844  for (auto [lb, ub, step, iv] :
1845  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1846  op.getMixedStep(), op.getInductionVars())) {
1847  if (iv.hasNUses(0))
1848  continue;
1849  auto numIterations =
1850  constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
1851  if (!numIterations.has_value() || numIterations.value() != 1) {
1852  continue;
1853  }
1854  rewriter.replaceAllUsesWith(
1855  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1856  changed = true;
1857  }
1858  return success(changed);
1859  }
1860 };
1861 
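/// Fold a tensor.cast of a shared_outs operand into the scf.forall when the
/// cast only adds static information. A hypothetical sketch:
///   %cast = tensor.cast %t : tensor<8xf32> to tensor<?xf32>
///   %res = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
/// becomes
///   %res = scf.forall ... shared_outs(%o = %t) -> (tensor<8xf32>) { ... }
///   %old = tensor.cast %res : tensor<8xf32> to tensor<?xf32>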
1862 struct FoldTensorCastOfOutputIntoForallOp
1863  : public OpRewritePattern<scf::ForallOp> {
1864  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1865 
1866  struct TypeCast {
1867  Type srcType;
1868  Type dstType;
1869  };
1870 
1871  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1872  PatternRewriter &rewriter) const final {
1873  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1874  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1875  for (auto en : llvm::enumerate(newOutputTensors)) {
1876  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1877  if (!castOp)
1878  continue;
1879 
1880  // Only casts that preserve static information, i.e. will make the
1881  // loop result type "more" static than before, will be folded.
1882  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1883  castOp.getSource().getType())) {
1884  continue;
1885  }
1886 
1887  tensorCastProducers[en.index()] =
1888  TypeCast{castOp.getSource().getType(), castOp.getType()};
1889  newOutputTensors[en.index()] = castOp.getSource();
1890  }
1891 
1892  if (tensorCastProducers.empty())
1893  return failure();
1894 
1895  // Create new loop.
1896  Location loc = forallOp.getLoc();
1897  auto newForallOp = ForallOp::create(
1898  rewriter, loc, forallOp.getMixedLowerBound(),
1899  forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1900  newOutputTensors, forallOp.getMapping(),
1901  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1902  auto castBlockArgs =
1903  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1904  for (auto [index, cast] : tensorCastProducers) {
1905  Value &oldTypeBBArg = castBlockArgs[index];
1906  oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1907  cast.dstType, oldTypeBBArg);
1908  }
1909 
1910  // Move old body into new parallel loop.
1911  SmallVector<Value> ivsBlockArgs =
1912  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1913  ivsBlockArgs.append(castBlockArgs);
1914  rewriter.mergeBlocks(forallOp.getBody(),
1915  bbArgs.front().getParentBlock(), ivsBlockArgs);
1916  });
1917 
1918  // After `mergeBlocks`, the destinations in the terminator are mapped to the
1919  // old-typed tensor.cast results of the output bbArgs. The destinations have
1920  // to be updated to point to the output bbArgs directly.
1921  auto terminator = newForallOp.getTerminator();
1922  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1923  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1924  if (auto parallelCombiningOp =
1925  dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
1926  parallelCombiningOp.getUpdatedDestinations().assign(outputBlockArg);
1927  }
1928  }
1929 
1930  // Cast results back to the original types.
1931  rewriter.setInsertionPointAfter(newForallOp);
1932  SmallVector<Value> castResults = newForallOp.getResults();
1933  for (auto &item : tensorCastProducers) {
1934  Value &oldTypeResult = castResults[item.first];
1935  oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1936  oldTypeResult);
1937  }
1938  rewriter.replaceOp(forallOp, castResults);
1939  return success();
1940  }
1941 };
1942 
1943 } // namespace
1944 
1945 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1946  MLIRContext *context) {
1947  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1948  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1949  ForallOpSingleOrZeroIterationDimsFolder,
1950  ForallOpReplaceConstantInductionVar>(context);
1951 }
1952 
1953 /// Given a branch point (the parent operation or one of its regions), return
1954 /// the successor regions, i.e. the regions that may be selected next during
1955 /// the flow of control.
1958 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1959  SmallVectorImpl<RegionSuccessor> &regions) {
1960  // In accordance with the semantics of forall, its body is executed in
1961  // parallel by multiple threads. We should not expect to branch back into
1962  // the forall body after the region's execution is complete.
1963  if (point.isParent())
1964  regions.push_back(RegionSuccessor(&getRegion()));
1965  else
1966  regions.push_back(RegionSuccessor());
1967 }
1968 
1969 //===----------------------------------------------------------------------===//
1970 // InParallelOp
1971 //===----------------------------------------------------------------------===//
1972 
1973 // Build an InParallelOp with an empty body block.
1974 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1975  OpBuilder::InsertionGuard g(b);
1976  Region *bodyRegion = result.addRegion();
1977  b.createBlock(bodyRegion);
1978 }
1979 
1980 LogicalResult InParallelOp::verify() {
1981  scf::ForallOp forallOp =
1982  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1983  if (!forallOp)
1984  return this->emitOpError("expected forall op parent");
1985 
1986  for (Operation &op : getRegion().front().getOperations()) {
1987  auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
1988  if (!parallelCombiningOp) {
1989  return this->emitOpError("expected only ParallelCombiningOpInterface")
1990  << " ops";
1991  }
1992 
1993  // Verify that inserts are into out block arguments.
1994  MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
1995  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1996  for (OpOperand &dest : dests) {
1997  if (!llvm::is_contained(regionOutArgs, dest.get()))
1998  return op.emitOpError("may only insert into an output block argument");
1999  }
2000  }
2001 
2002  return success();
2003 }
2004 
2005 void InParallelOp::print(OpAsmPrinter &p) {
2006  p << " ";
2007  p.printRegion(getRegion(),
2008  /*printEntryBlockArgs=*/false,
2009  /*printBlockTerminators=*/false);
2010  p.printOptionalAttrDict(getOperation()->getAttrs());
2011 }
2012 
2013 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2014  auto &builder = parser.getBuilder();
2015 
2016  SmallVector<OpAsmParser::Argument, 8> regionOperands;
2017  std::unique_ptr<Region> region = std::make_unique<Region>();
2018  if (parser.parseRegion(*region, regionOperands))
2019  return failure();
2020 
2021  if (region->empty())
2022  OpBuilder(builder.getContext()).createBlock(region.get());
2023  result.addRegion(std::move(region));
2024 
2025  // Parse the optional attribute list.
2026  if (parser.parseOptionalAttrDict(result.attributes))
2027  return failure();
2028  return success();
2029 }
2030 
2031 OpResult InParallelOp::getParentResult(int64_t idx) {
2032  return getOperation()->getParentOp()->getResult(idx);
2033 }
2034 
2035 SmallVector<BlockArgument> InParallelOp::getDests() {
2036  SmallVector<BlockArgument> updatedDests;
2037  for (Operation &yieldingOp : getYieldingOps()) {
2038  auto parallelCombiningOp =
2039  dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
2040  if (!parallelCombiningOp)
2041  continue;
2042  for (OpOperand &updatedOperand :
2043  parallelCombiningOp.getUpdatedDestinations())
2044  updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
2045  }
2046  return updatedDests;
2047 }
2048 
2049 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2050  return getRegion().front().getOperations();
2051 }
2052 
2053 //===----------------------------------------------------------------------===//
2054 // IfOp
2055 //===----------------------------------------------------------------------===//
2056 
2057 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
2058  assert(a && "expected non-empty operation");
2059  assert(b && "expected non-empty operation");
2060 
2061  IfOp ifOp = a->getParentOfType<IfOp>();
2062  while (ifOp) {
2063  // Check if b is inside ifOp. (We already know that a is.)
2064  if (ifOp->isProperAncestor(b))
2065  // b is contained in ifOp. a and b are in mutually exclusive branches if
2066  // they are in different blocks of ifOp.
2067  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2068  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2069  // Check next enclosing IfOp.
2070  ifOp = ifOp->getParentOfType<IfOp>();
2071  }
2072 
2073  // Could not find a common IfOp among a's and b's ancestors.
2074  return false;
2075 }
2076 
2077 LogicalResult
2078 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2079  IfOp::Adaptor adaptor,
2080  SmallVectorImpl<Type> &inferredReturnTypes) {
2081  if (adaptor.getRegions().empty())
2082  return failure();
2083  Region *r = &adaptor.getThenRegion();
2084  if (r->empty())
2085  return failure();
2086  Block &b = r->front();
2087  if (b.empty())
2088  return failure();
2089  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2090  if (!yieldOp)
2091  return failure();
2092  TypeRange types = yieldOp.getOperandTypes();
2093  llvm::append_range(inferredReturnTypes, types);
2094  return success();
2095 }
2096 
2097 void IfOp::build(OpBuilder &builder, OperationState &result,
2098  TypeRange resultTypes, Value cond) {
2099  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2100  /*addElseBlock=*/false);
2101 }
2102 
2103 void IfOp::build(OpBuilder &builder, OperationState &result,
2104  TypeRange resultTypes, Value cond, bool addThenBlock,
2105  bool addElseBlock) {
2106  assert((!addElseBlock || addThenBlock) &&
2107  "must not create else block w/o then block");
2108  result.addTypes(resultTypes);
2109  result.addOperands(cond);
2110 
2111  // Add regions and blocks.
2112  OpBuilder::InsertionGuard guard(builder);
2113  Region *thenRegion = result.addRegion();
2114  if (addThenBlock)
2115  builder.createBlock(thenRegion);
2116  Region *elseRegion = result.addRegion();
2117  if (addElseBlock)
2118  builder.createBlock(elseRegion);
2119 }
2120 
2121 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2122  bool withElseRegion) {
2123  build(builder, result, TypeRange{}, cond, withElseRegion);
2124 }
2125 
2126 void IfOp::build(OpBuilder &builder, OperationState &result,
2127  TypeRange resultTypes, Value cond, bool withElseRegion) {
2128  result.addTypes(resultTypes);
2129  result.addOperands(cond);
2130 
2131  // Build then region.
2132  OpBuilder::InsertionGuard guard(builder);
2133  Region *thenRegion = result.addRegion();
2134  builder.createBlock(thenRegion);
2135  if (resultTypes.empty())
2136  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2137 
2138  // Build else region.
2139  Region *elseRegion = result.addRegion();
2140  if (withElseRegion) {
2141  builder.createBlock(elseRegion);
2142  if (resultTypes.empty())
2143  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2144  }
2145 }
2146 
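// A usage sketch for the callback-based builder below (illustrative, not from
// this file; `b`, `loc`, `cond`, `v0`, and `v1` are assumed to exist). Result
// types are inferred from the values yielded in the callbacks:
//
//   auto ifOp = scf::IfOp::create(
//       b, loc, cond,
//       [&](OpBuilder &b, Location loc) { scf::YieldOp::create(b, loc, v0); },
//       [&](OpBuilder &b, Location loc) { scf::YieldOp::create(b, loc, v1); });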
2147 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2148  function_ref<void(OpBuilder &, Location)> thenBuilder,
2149  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2150  assert(thenBuilder && "the builder callback for 'then' must be present");
2151  result.addOperands(cond);
2152 
2153  // Build then region.
2154  OpBuilder::InsertionGuard guard(builder);
2155  Region *thenRegion = result.addRegion();
2156  builder.createBlock(thenRegion);
2157  thenBuilder(builder, result.location);
2158 
2159  // Build else region.
2160  Region *elseRegion = result.addRegion();
2161  if (elseBuilder) {
2162  builder.createBlock(elseRegion);
2163  elseBuilder(builder, result.location);
2164  }
2165 
2166  // Infer result types.
2167  SmallVector<Type> inferredReturnTypes;
2168  MLIRContext *ctx = builder.getContext();
2169  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2170  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2171  /*properties=*/nullptr, result.regions,
2172  inferredReturnTypes))) {
2173  result.addTypes(inferredReturnTypes);
2174  }
2175 }
2176 
2177 LogicalResult IfOp::verify() {
2178  if (getNumResults() != 0 && getElseRegion().empty())
2179  return emitOpError("must have an else block if defining values");
2180  return success();
2181 }
2182 
2183 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2184  // Create the regions for 'then'.
2185  result.regions.reserve(2);
2186  Region *thenRegion = result.addRegion();
2187  Region *elseRegion = result.addRegion();
2188 
2189  auto &builder = parser.getBuilder();
2190  OpAsmParser::UnresolvedOperand cond;
2191  Type i1Type = builder.getIntegerType(1);
2192  if (parser.parseOperand(cond) ||
2193  parser.resolveOperand(cond, i1Type, result.operands))
2194  return failure();
2195  // Parse optional results type list.
2196  if (parser.parseOptionalArrowTypeList(result.types))
2197  return failure();
2198  // Parse the 'then' region.
2199  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2200  return failure();
2201  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2202 
2203  // If we find an 'else' keyword then parse the 'else' region.
2204  if (!parser.parseOptionalKeyword("else")) {
2205  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2206  return failure();
2207  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2208  }
2209 
2210  // Parse the optional attribute list.
2211  if (parser.parseOptionalAttrDict(result.attributes))
2212  return failure();
2213  return success();
2214 }
2215 
2216 void IfOp::print(OpAsmPrinter &p) {
2217  bool printBlockTerminators = false;
2218 
2219  p << " " << getCondition();
2220  if (!getResults().empty()) {
2221  p << " -> (" << getResultTypes() << ")";
2222  // Print yield explicitly if the op defines values.
2223  printBlockTerminators = true;
2224  }
2225  p << ' ';
2226  p.printRegion(getThenRegion(),
2227  /*printEntryBlockArgs=*/false,
2228  /*printBlockTerminators=*/printBlockTerminators);
2229 
2230  // Print the 'else' region if it exists and has a block.
2231  auto &elseRegion = getElseRegion();
2232  if (!elseRegion.empty()) {
2233  p << " else ";
2234  p.printRegion(elseRegion,
2235  /*printEntryBlockArgs=*/false,
2236  /*printBlockTerminators=*/printBlockTerminators);
2237  }
2238 
2239  p.printOptionalAttrDict((*this)->getAttrs());
2240 }
2241 
2242 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2243  SmallVectorImpl<RegionSuccessor> &regions) {
2244  // The `then` and the `else` region branch back to the parent operation.
2245  if (!point.isParent()) {
2246  regions.push_back(RegionSuccessor(getResults()));
2247  return;
2248  }
2249 
2250  regions.push_back(RegionSuccessor(&getThenRegion()));
2251 
2252  // Don't consider the else region if it is empty.
2253  Region *elseRegion = &this->getElseRegion();
2254  if (elseRegion->empty())
2255  regions.push_back(RegionSuccessor());
2256  else
2257  regions.push_back(RegionSuccessor(elseRegion));
2258 }
2259 
2260 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2261  SmallVectorImpl<RegionSuccessor> &regions) {
2262  FoldAdaptor adaptor(operands, *this);
2263  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2264  if (!boolAttr || boolAttr.getValue())
2265  regions.emplace_back(&getThenRegion());
2266 
2267  // If the else region is empty, execution continues after the parent op.
2268  if (!boolAttr || !boolAttr.getValue()) {
2269  if (!getElseRegion().empty())
2270  regions.emplace_back(&getElseRegion());
2271  else
2272  regions.emplace_back(getResults());
2273  }
2274 }
2275 
2276 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2277  SmallVectorImpl<OpFoldResult> &results) {
2278  // if (!c) then A() else B() -> if c then B() else A()
2279  if (getElseRegion().empty())
2280  return failure();
2281 
2282  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2283  if (!xorStmt)
2284  return failure();
2285 
2286  if (!matchPattern(xorStmt.getRhs(), m_One()))
2287  return failure();
2288 
2289  getConditionMutable().assign(xorStmt.getLhs());
2290  Block *thenBlock = &getThenRegion().front();
2291  // It would be nicer to use iplist::swap, but that has no implemented
2292  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2293  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2294  getElseRegion().getBlocks());
2295  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2296  getThenRegion().getBlocks(), thenBlock);
2297  return success();
2298 }
2299 
2300 void IfOp::getRegionInvocationBounds(
2301  ArrayRef<Attribute> operands,
2302  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2303  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2304  // If the condition is known, then one region is known to be executed once
2305  // and the other zero times.
2306  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2307  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2308  } else {
2309  // Non-constant condition. Each region may be executed 0 or 1 times.
2310  invocationBounds.assign(2, {0, 1});
2311  }
2312 }
2313 
2314 namespace {
2315 // Pattern to remove unused IfOp results.
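// Illustrative sketch: given `%r:2 = scf.if %cond -> (i32, i32) {...}` where
// only %r#0 has uses, the op is rebuilt as `%r = scf.if %cond -> (i32) {...}`
// with the yields trimmed accordingly.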
2316 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2317  using OpRewritePattern<IfOp>::OpRewritePattern;
2318 
2319  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2320  PatternRewriter &rewriter) const {
2321  // Move all operations to the destination block.
2322  rewriter.mergeBlocks(source, dest);
2323  // Replace the yield op by one that returns only the used values.
2324  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2325  SmallVector<Value, 4> usedOperands;
2326  llvm::transform(usedResults, std::back_inserter(usedOperands),
2327  [&](OpResult result) {
2328  return yieldOp.getOperand(result.getResultNumber());
2329  });
2330  rewriter.modifyOpInPlace(yieldOp,
2331  [&]() { yieldOp->setOperands(usedOperands); });
2332  }
2333 
2334  LogicalResult matchAndRewrite(IfOp op,
2335  PatternRewriter &rewriter) const override {
2336  // Compute the list of used results.
2337  SmallVector<OpResult, 4> usedResults;
2338  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2339  [](OpResult result) { return !result.use_empty(); });
2340 
2341  // Replace the operation if only a subset of its results have uses.
2342  if (usedResults.size() == op.getNumResults())
2343  return failure();
2344 
2345  // Compute the result types of the replacement operation.
2346  SmallVector<Type, 4> newTypes;
2347  llvm::transform(usedResults, std::back_inserter(newTypes),
2348  [](OpResult result) { return result.getType(); });
2349 
2350  // Create a replacement operation with empty then and else regions.
2351  auto newOp =
2352  IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2353  rewriter.createBlock(&newOp.getThenRegion());
2354  rewriter.createBlock(&newOp.getElseRegion());
2355 
2356  // Move the bodies and replace the terminators (note there is a then and
2357  // an else region since the operation returns results).
2358  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2359  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2360 
2361  // Replace the operation by the new one.
2362  SmallVector<Value, 4> repResults(op.getNumResults());
2363  for (const auto &en : llvm::enumerate(usedResults))
2364  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2365  rewriter.replaceOp(op, repResults);
2366  return success();
2367  }
2368 };
2369 
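/// Replace an scf.if whose condition is a constant by the contents of the
/// taken region (or erase the op when the condition is false and there is no
/// else region). A hypothetical sketch:
///   %true = arith.constant true
///   scf.if %true { "some.op"() : () -> () }
/// becomes
///   "some.op"() : () -> ()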
2370 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2371  using OpRewritePattern<IfOp>::OpRewritePattern;
2372 
2373  LogicalResult matchAndRewrite(IfOp op,
2374  PatternRewriter &rewriter) const override {
2375  BoolAttr condition;
2376  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2377  return failure();
2378 
2379  if (condition.getValue())
2380  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2381  else if (!op.getElseRegion().empty())
2382  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2383  else
2384  rewriter.eraseOp(op);
2385 
2386  return success();
2387  }
2388 };
2389 
2390 /// Hoist any yielded results whose operands are defined outside
2391 /// the if, to a select instruction.
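///
/// Illustrative sketch (assuming %a and %b are defined above the scf.if):
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// becomes
///   %r = arith.select %cond, %a, %b : i32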
2392 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2393  using OpRewritePattern<IfOp>::OpRewritePattern;
2394 
2395  LogicalResult matchAndRewrite(IfOp op,
2396  PatternRewriter &rewriter) const override {
2397  if (op->getNumResults() == 0)
2398  return failure();
2399 
2400  auto cond = op.getCondition();
2401  auto thenYieldArgs = op.thenYield().getOperands();
2402  auto elseYieldArgs = op.elseYield().getOperands();
2403 
2404  SmallVector<Type> nonHoistable;
2405  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2406  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2407  &op.getElseRegion() == falseVal.getParentRegion())
2408  nonHoistable.push_back(trueVal.getType());
2409  }
2410  // Early exit if there aren't any yielded values we can
2411  // hoist outside the if.
2412  if (nonHoistable.size() == op->getNumResults())
2413  return failure();
2414 
2415  IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2416  /*withElseRegion=*/false);
2417  if (replacement.thenBlock())
2418  rewriter.eraseBlock(replacement.thenBlock());
2419  replacement.getThenRegion().takeBody(op.getThenRegion());
2420  replacement.getElseRegion().takeBody(op.getElseRegion());
2421 
2422  SmallVector<Value> results(op->getNumResults());
2423  assert(thenYieldArgs.size() == results.size());
2424  assert(elseYieldArgs.size() == results.size());
2425 
2426  SmallVector<Value> trueYields;
2427  SmallVector<Value> falseYields;
2428  rewriter.setInsertionPoint(replacement);
2429  for (const auto &it :
2430  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2431  Value trueVal = std::get<0>(it.value());
2432  Value falseVal = std::get<1>(it.value());
2433  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2434  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2435  results[it.index()] = replacement.getResult(trueYields.size());
2436  trueYields.push_back(trueVal);
2437  falseYields.push_back(falseVal);
2438  } else if (trueVal == falseVal)
2439  results[it.index()] = trueVal;
2440  else
2441  results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2442  cond, trueVal, falseVal);
2443  }
2444 
2445  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2446  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2447 
2448  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2449  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2450 
2451  rewriter.replaceOp(op, results);
2452  return success();
2453  }
2454 };
2455 
2456 /// Allow the true region of an if to assume the condition is true
2457 /// and vice versa. For example:
2458 ///
2459 /// scf.if %cmp {
2460 /// print(%cmp)
2461 /// }
2462 ///
2463 /// becomes
2464 ///
2465 /// scf.if %cmp {
2466 /// print(true)
2467 /// }
2468 ///
2469 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2470  using OpRewritePattern<IfOp>::OpRewritePattern;
2471 
2472  LogicalResult matchAndRewrite(IfOp op,
2473  PatternRewriter &rewriter) const override {
2474  // Early exit if the condition is constant since replacing a constant
2475  // in the body with another constant isn't a simplification.
2476  if (matchPattern(op.getCondition(), m_Constant()))
2477  return failure();
2478 
2479  bool changed = false;
2480  mlir::Type i1Ty = rewriter.getI1Type();
2481 
2482  // These variables serve to prevent creating duplicate constants
2483  // and hold constant true or false values.
2484  Value constantTrue = nullptr;
2485  Value constantFalse = nullptr;
2486 
2487  for (OpOperand &use :
2488  llvm::make_early_inc_range(op.getCondition().getUses())) {
2489  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2490  changed = true;
2491 
2492  if (!constantTrue)
2493  constantTrue = rewriter.create<arith::ConstantOp>(
2494  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2495 
2496  rewriter.modifyOpInPlace(use.getOwner(),
2497  [&]() { use.set(constantTrue); });
2498  } else if (op.getElseRegion().isAncestor(
2499  use.getOwner()->getParentRegion())) {
2500  changed = true;
2501 
2502  if (!constantFalse)
2503  constantFalse = rewriter.create<arith::ConstantOp>(
2504  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2505 
2506  rewriter.modifyOpInPlace(use.getOwner(),
2507  [&]() { use.set(constantFalse); });
2508  }
2509  }
2510 
2511  return success(changed);
2512  }
2513 };
2514 
2515 /// Remove any statements from an if that are equivalent to the condition
2516 /// or its negation. For example:
2517 ///
2518 /// %res:2 = scf.if %cmp {
2519 /// yield something(), true
2520 /// } else {
2521 /// yield something2(), false
2522 /// }
2523 /// print(%res#1)
2524 ///
2525 /// becomes
2526 /// %res = scf.if %cmp {
2527 /// yield something()
2528 /// } else {
2529 /// yield something2()
2530 /// }
2531 /// print(%cmp)
2532 ///
2533 /// Additionally if both branches yield the same value, replace all uses
2534 /// of the result with the yielded value.
2535 ///
2536 /// %res:2 = scf.if %cmp {
2537 /// yield something(), %arg1
2538 /// } else {
2539 /// yield something2(), %arg1
2540 /// }
2541 /// print(%res#1)
2542 ///
2543 /// becomes
2544 /// %res = scf.if %cmp {
2545 /// yield something()
2546 /// } else {
2547 /// yield something2()
2548 /// }
2549 /// print(%arg1)
2550 ///
2551 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2552  using OpRewritePattern<IfOp>::OpRewritePattern;
2553 
2554  LogicalResult matchAndRewrite(IfOp op,
2555  PatternRewriter &rewriter) const override {
2556  // Early exit if there are no results that could be replaced.
2557  if (op.getNumResults() == 0)
2558  return failure();
2559 
2560  auto trueYield =
2561  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2562  auto falseYield =
2563  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2564 
2565  rewriter.setInsertionPoint(op->getBlock(),
2566  op.getOperation()->getIterator());
2567  bool changed = false;
2568  Type i1Ty = rewriter.getI1Type();
2569  for (auto [trueResult, falseResult, opResult] :
2570  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2571  op.getResults())) {
2572  if (trueResult == falseResult) {
2573  if (!opResult.use_empty()) {
2574  opResult.replaceAllUsesWith(trueResult);
2575  changed = true;
2576  }
2577  continue;
2578  }
2579 
2580  BoolAttr trueYield, falseYield;
2581  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2582  !matchPattern(falseResult, m_Constant(&falseYield)))
2583  continue;
2584 
2585  bool trueVal = trueYield.getValue();
2586  bool falseVal = falseYield.getValue();
2587  if (!trueVal && falseVal) {
2588  if (!opResult.use_empty()) {
2589  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2590  Value notCond = arith::XOrIOp::create(
2591  rewriter, op.getLoc(), op.getCondition(),
2592  constDialect
2593  ->materializeConstant(rewriter,
2594  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2595  op.getLoc())
2596  ->getResult(0));
2597  opResult.replaceAllUsesWith(notCond);
2598  changed = true;
2599  }
2600  }
2601  if (trueVal && !falseVal) {
2602  if (!opResult.use_empty()) {
2603  opResult.replaceAllUsesWith(op.getCondition());
2604  changed = true;
2605  }
2606  }
2607  }
2608  return success(changed);
2609  }
2610 };
2611 
2612 /// Merge any consecutive scf.if's with the same condition.
2613 ///
2614 /// scf.if %cond {
2615 /// firstCodeTrue();...
2616 /// } else {
2617 /// firstCodeFalse();...
2618 /// }
2619 /// %res = scf.if %cond {
2620 /// secondCodeTrue();...
2621 /// } else {
2622 /// secondCodeFalse();...
2623 /// }
2624 ///
2625 /// becomes
2626 /// %res = scf.if %cmp {
2627 /// firstCodeTrue();...
2628 /// secondCodeTrue();...
2629 /// } else {
2630 /// firstCodeFalse();...
2631 /// secondCodeFalse();...
2632 /// }
2633 struct CombineIfs : public OpRewritePattern<IfOp> {
2634  using OpRewritePattern<IfOp>::OpRewritePattern;
2635 
2636  LogicalResult matchAndRewrite(IfOp nextIf,
2637  PatternRewriter &rewriter) const override {
2638  Block *parent = nextIf->getBlock();
2639  if (nextIf == &parent->front())
2640  return failure();
2641 
2642  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2643  if (!prevIf)
2644  return failure();
2645 
2646  // Determine the logical then/else blocks when prevIf's
2647  // condition is used. Null means the block does not exist
2648  // in that case (e.g. empty else). If neither of these
2649  // are set, the two conditions cannot be compared.
2650  Block *nextThen = nullptr;
2651  Block *nextElse = nullptr;
2652  if (nextIf.getCondition() == prevIf.getCondition()) {
2653  nextThen = nextIf.thenBlock();
2654  if (!nextIf.getElseRegion().empty())
2655  nextElse = nextIf.elseBlock();
2656  }
2657  if (arith::XOrIOp notv =
2658  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2659  if (notv.getLhs() == prevIf.getCondition() &&
2660  matchPattern(notv.getRhs(), m_One())) {
2661  nextElse = nextIf.thenBlock();
2662  if (!nextIf.getElseRegion().empty())
2663  nextThen = nextIf.elseBlock();
2664  }
2665  }
2666  if (arith::XOrIOp notv =
2667  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2668  if (notv.getLhs() == nextIf.getCondition() &&
2669  matchPattern(notv.getRhs(), m_One())) {
2670  nextElse = nextIf.thenBlock();
2671  if (!nextIf.getElseRegion().empty())
2672  nextThen = nextIf.elseBlock();
2673  }
2674  }
2675 
2676  if (!nextThen && !nextElse)
2677  return failure();
2678 
2679  SmallVector<Value> prevElseYielded;
2680  if (!prevIf.getElseRegion().empty())
2681  prevElseYielded = prevIf.elseYield().getOperands();
2682  // Replace all uses of prevIf's results within nextIf with the
2683  // corresponding yielded values.
2684  for (auto it : llvm::zip(prevIf.getResults(),
2685  prevIf.thenYield().getOperands(), prevElseYielded))
2686  for (OpOperand &use :
2687  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2688  if (nextThen && nextThen->getParent()->isAncestor(
2689  use.getOwner()->getParentRegion())) {
2690  rewriter.startOpModification(use.getOwner());
2691  use.set(std::get<1>(it));
2692  rewriter.finalizeOpModification(use.getOwner());
2693  } else if (nextElse && nextElse->getParent()->isAncestor(
2694  use.getOwner()->getParentRegion())) {
2695  rewriter.startOpModification(use.getOwner());
2696  use.set(std::get<2>(it));
2697  rewriter.finalizeOpModification(use.getOwner());
2698  }
2699  }
2700 
2701  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2702  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2703 
2704  IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2705  prevIf.getCondition(), /*hasElse=*/false);
2706  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2707 
2708  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2709  combinedIf.getThenRegion(),
2710  combinedIf.getThenRegion().begin());
2711 
2712  if (nextThen) {
2713  YieldOp thenYield = combinedIf.thenYield();
2714  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2715  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2716  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2717 
2718  SmallVector<Value> mergedYields(thenYield.getOperands());
2719  llvm::append_range(mergedYields, thenYield2.getOperands());
2720  YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2721  rewriter.eraseOp(thenYield);
2722  rewriter.eraseOp(thenYield2);
2723  }
2724 
2725  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2726  combinedIf.getElseRegion(),
2727  combinedIf.getElseRegion().begin());
2728 
2729  if (nextElse) {
2730  if (combinedIf.getElseRegion().empty()) {
2731  rewriter.inlineRegionBefore(*nextElse->getParent(),
2732  combinedIf.getElseRegion(),
2733  combinedIf.getElseRegion().begin());
2734  } else {
2735  YieldOp elseYield = combinedIf.elseYield();
2736  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2737  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2738 
2739  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2740 
2741  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2742  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2743 
2744  YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2745  rewriter.eraseOp(elseYield);
2746  rewriter.eraseOp(elseYield2);
2747  }
2748  }
2749 
2750  SmallVector<Value> prevValues;
2751  SmallVector<Value> nextValues;
2752  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2753  if (pair.index() < prevIf.getNumResults())
2754  prevValues.push_back(pair.value());
2755  else
2756  nextValues.push_back(pair.value());
2757  }
2758  rewriter.replaceOp(prevIf, prevValues);
2759  rewriter.replaceOp(nextIf, nextValues);
2760  return success();
2761  }
2762 };
2763 
2764 /// Pattern to remove an empty else branch.
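/// Illustrative sketch:
///   scf.if %cond {
///     "some.op"() : () -> ()
///   } else {
///   }
/// becomes
///   scf.if %cond {
///     "some.op"() : () -> ()
///   }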
2765 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2766  using OpRewritePattern<IfOp>::OpRewritePattern;
2767 
2768  LogicalResult matchAndRewrite(IfOp ifOp,
2769  PatternRewriter &rewriter) const override {
2770  // Cannot remove else region when there are operation results.
2771  if (ifOp.getNumResults())
2772  return failure();
2773  Block *elseBlock = ifOp.elseBlock();
2774  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2775  return failure();
2776  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2777  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2778  newIfOp.getThenRegion().begin());
2779  rewriter.eraseOp(ifOp);
2780  return success();
2781  }
2782 };
2783 
2784 /// Convert nested `if`s into `arith.andi` + single `if`.
2785 ///
2786 /// scf.if %arg0 {
2787 /// scf.if %arg1 {
2788 /// ...
2789 /// scf.yield
2790 /// }
2791 /// scf.yield
2792 /// }
2793 /// becomes
2794 ///
2795 /// %0 = arith.andi %arg0, %arg1
2796 /// scf.if %0 {
2797 /// ...
2798 /// scf.yield
2799 /// }
2800 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2801  using OpRewritePattern<IfOp>::OpRewritePattern;
2802 
2803  LogicalResult matchAndRewrite(IfOp op,
2804  PatternRewriter &rewriter) const override {
2805  auto nestedOps = op.thenBlock()->without_terminator();
2806  // Nested `if` must be the only op in block.
2807  if (!llvm::hasSingleElement(nestedOps))
2808  return failure();
2809 
2810  // If there is an else block, it can only yield
2811  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2812  return failure();
2813 
2814  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2815  if (!nestedIf)
2816  return failure();
2817 
2818  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2819  return failure();
2820 
2821  SmallVector<Value> thenYield(op.thenYield().getOperands());
2822  SmallVector<Value> elseYield;
2823  if (op.elseBlock())
2824  llvm::append_range(elseYield, op.elseYield().getOperands());
2825 
2826  // A list of indices for which we should upgrade the value yielded
2827  // in the else to a select.
2828  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2829 
2830  // If the outer scf.if yields a value produced by the inner scf.if,
2831  // only permit combining if the value yielded when the condition
2832  // is false in the outer scf.if is the same value yielded when the
2833  // inner scf.if condition is false.
2834  // Note that the array access to elseYield will not go out of bounds
2835  // since it must have the same length as thenYield, since they both
2836  // come from the same scf.if.
2837  for (const auto &tup : llvm::enumerate(thenYield)) {
2838  if (tup.value().getDefiningOp() == nestedIf) {
2839  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2840  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2841  elseYield[tup.index()]) {
2842  return failure();
2843  }
2844  // If the correctness test passes, we will yield
2845  // corresponding value from the inner scf.if
2846  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2847  continue;
2848  }
2849 
2850  // Otherwise, we need to ensure the else block of the combined
2851  // condition still returns the same value when the outer condition is
2852  // true and the inner condition is false. This can be accomplished if
2853  // the then value is defined outside the outer scf.if and we replace the
2854  // value with a select that considers just the outer condition. Since
2855  // the else region contains just the yield, its yielded value is
2856  // defined outside the scf.if, by definition.
2857 
2858  // If the then value is defined within the scf.if, bail.
2859  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2860  return failure();
2861  }
2862  elseYieldsToUpgradeToSelect.push_back(tup.index());
2863  }
2864 
2865  Location loc = op.getLoc();
2866  Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2867  nestedIf.getCondition());
2868  auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2869  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2870 
2871  SmallVector<Value> results;
2872  llvm::append_range(results, newIf.getResults());
2873  rewriter.setInsertionPoint(newIf);
2874 
2875  for (auto idx : elseYieldsToUpgradeToSelect)
2876  results[idx] =
2877  arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2878  thenYield[idx], elseYield[idx]);
2879 
2880  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2881  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2882  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2883  if (!elseYield.empty()) {
2884  rewriter.createBlock(&newIf.getElseRegion());
2885  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2886  YieldOp::create(rewriter, loc, elseYield);
2887  }
2888  rewriter.replaceOp(op, results);
2889  return success();
2890  }
2891 };
2892 
2893 } // namespace
2894 
2895 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2896  MLIRContext *context) {
2897  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2898  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2899  RemoveStaticCondition, RemoveUnusedResults,
2900  ReplaceIfYieldWithConditionOrValue>(context);
2901 }
2902 
2903 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2904 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2905 Block *IfOp::elseBlock() {
2906  Region &r = getElseRegion();
2907  if (r.empty())
2908  return nullptr;
2909  return &r.back();
2910 }
2911 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2912 
2913 //===----------------------------------------------------------------------===//
2914 // ParallelOp
2915 //===----------------------------------------------------------------------===//
2916 
2917 void ParallelOp::build(
2918  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2919  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2920  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2921  bodyBuilderFn) {
2922  result.addOperands(lowerBounds);
2923  result.addOperands(upperBounds);
2924  result.addOperands(steps);
2925  result.addOperands(initVals);
2926  result.addAttribute(
2927  ParallelOp::getOperandSegmentSizeAttr(),
2928  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2929  static_cast<int32_t>(upperBounds.size()),
2930  static_cast<int32_t>(steps.size()),
2931  static_cast<int32_t>(initVals.size())}));
2932  result.addTypes(initVals.getTypes());
2933 
2934  OpBuilder::InsertionGuard guard(builder);
2935  unsigned numIVs = steps.size();
2936  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2937  SmallVector<Location, 8> argLocs(numIVs, result.location);
2938  Region *bodyRegion = result.addRegion();
2939  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2940 
2941  if (bodyBuilderFn) {
2942  builder.setInsertionPointToStart(bodyBlock);
2943  bodyBuilderFn(builder, result.location,
2944  bodyBlock->getArguments().take_front(numIVs),
2945  bodyBlock->getArguments().drop_front(numIVs));
2946  }
2947  // Add terminator only if there are no reductions.
2948  if (initVals.empty())
2949  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2950 }
2951 
2952 void ParallelOp::build(
2953  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2954  ValueRange upperBounds, ValueRange steps,
2955  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2956  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2957  // we don't capture a reference to a temporary by constructing the lambda at
2958  // function level.
2959  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2960  Location nestedLoc, ValueRange ivs,
2961  ValueRange) {
2962  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2963  };
2964  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2965  if (bodyBuilderFn)
2966  wrapper = wrappedBuilderFn;
2967 
2968  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2969  wrapper);
2970 }
2971 
2972 LogicalResult ParallelOp::verify() {
2973  // Check that there is at least one value in lowerBound, upperBound and step.
2974  // It is sufficient to test only step, because it is ensured already that the
2975  // number of elements in lowerBound, upperBound and step are the same.
2976  Operation::operand_range stepValues = getStep();
2977  if (stepValues.empty())
2978  return emitOpError(
2979  "needs at least one tuple element for lowerBound, upperBound and step");
2980 
2981  // Check whether all constant step values are positive.
2982  for (Value stepValue : stepValues)
2983  if (auto cst = getConstantIntValue(stepValue))
2984  if (*cst <= 0)
2985  return emitOpError("constant step operand must be positive");
2986 
2987  // Check that the body defines the same number of block arguments as the
2988  // number of tuple elements in step.
2989  Block *body = getBody();
2990  if (body->getNumArguments() != stepValues.size())
2991  return emitOpError() << "expects the same number of induction variables: "
2992  << body->getNumArguments()
2993  << " as bound and step values: " << stepValues.size();
2994  for (auto arg : body->getArguments())
2995  if (!arg.getType().isIndex())
2996  return emitOpError(
2997  "expects arguments for the induction variable to be of index type");
2998 
2999  // Check that the terminator is an scf.reduce op.
3000  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
3001  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
3002  if (!reduceOp)
3003  return failure();
3004 
3005  // Check that the number of results is the same as the number of reductions.
3006  auto resultsSize = getResults().size();
3007  auto reductionsSize = reduceOp.getReductions().size();
3008  auto initValsSize = getInitVals().size();
3009  if (resultsSize != reductionsSize)
3010  return emitOpError() << "expects number of results: " << resultsSize
3011  << " to be the same as number of reductions: "
3012  << reductionsSize;
3013  if (resultsSize != initValsSize)
3014  return emitOpError() << "expects number of results: " << resultsSize
3015  << " to be the same as number of initial values: "
3016  << initValsSize;
3017 
3018  // Check that the types of the results and reductions are the same.
3019  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
3020  auto resultType = getOperation()->getResult(i).getType();
3021  auto reductionOperandType = reduceOp.getOperands()[i].getType();
3022  if (resultType != reductionOperandType)
3023  return reduceOp.emitOpError()
3024  << "expects type of " << i
3025  << "-th reduction operand: " << reductionOperandType
3026  << " to be the same as the " << i
3027  << "-th result type: " << resultType;
3028  }
3029  return success();
3030 }
3031 
3032 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
3033  auto &builder = parser.getBuilder();
3034  // Parse an opening `(` followed by induction variables followed by `)`
3035 SmallVector<OpAsmParser::Argument, 4> ivs;
3036 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
3037  return failure();
3038 
3039  // Parse loop bounds.
3040 SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
3041 if (parser.parseEqual() ||
3042  parser.parseOperandList(lower, ivs.size(),
3043  OpAsmParser::Delimiter::Paren) ||
3044  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
3045  return failure();
3046 
3047 SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3048 if (parser.parseKeyword("to") ||
3049  parser.parseOperandList(upper, ivs.size(),
3050  OpAsmParser::Delimiter::Paren) ||
3051  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
3052  return failure();
3053 
3054  // Parse step values.
3055 SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3056 if (parser.parseKeyword("step") ||
3057  parser.parseOperandList(steps, ivs.size(),
3058  OpAsmParser::Delimiter::Paren) ||
3059  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
3060  return failure();
3061 
3062  // Parse init values.
3063 SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3064 if (succeeded(parser.parseOptionalKeyword("init"))) {
3065  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
3066  return failure();
3067  }
3068 
3069  // Parse optional results in case there is a reduce.
3070  if (parser.parseOptionalArrowTypeList(result.types))
3071  return failure();
3072 
3073  // Now parse the body.
3074  Region *body = result.addRegion();
3075  for (auto &iv : ivs)
3076  iv.type = builder.getIndexType();
3077  if (parser.parseRegion(*body, ivs))
3078  return failure();
3079 
3080  // Set `operandSegmentSizes` attribute.
3081  result.addAttribute(
3082  ParallelOp::getOperandSegmentSizeAttr(),
3083  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3084  static_cast<int32_t>(upper.size()),
3085  static_cast<int32_t>(steps.size()),
3086  static_cast<int32_t>(initVals.size())}));
3087 
3088  // Parse attributes.
3089  if (parser.parseOptionalAttrDict(result.attributes) ||
3090  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3091  result.operands))
3092  return failure();
3093 
3094  // Add a terminator if none was parsed.
3095  ParallelOp::ensureTerminator(*body, builder, result.location);
3096  return success();
3097 }
3098 
3099 void ParallelOp::print(OpAsmPrinter &p) {
3100 p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3101  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3102  if (!getInitVals().empty())
3103  p << " init (" << getInitVals() << ")";
3104  p.printOptionalArrowTypeList(getResultTypes());
3105  p << ' ';
3106  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3107 p.printOptionalAttrDict(
3108 (*this)->getAttrs(),
3109  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3110 }
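// Example of the textual form produced and consumed by the print/parse
// methods above (illustrative only; SSA names and types are hypothetical):
//
//   %sum = scf.parallel (%i) = (%lb) to (%ub) step (%s) init (%zero) -> f32 {
//     %elem = memref.load %buf[%i] : memref<?xf32>
//     scf.reduce(%elem : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %add = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %add : f32
//     }
//   }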
3111 
3112 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3113 
3114 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3115  return SmallVector<Value>{getBody()->getArguments()};
3116 }
3117 
3118 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3119  return getLowerBound();
3120 }
3121 
3122 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3123  return getUpperBound();
3124 }
3125 
3126 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3127  return getStep();
3128 }
3129 
3130 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3131 auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3132  if (!ivArg)
3133  return ParallelOp();
3134  assert(ivArg.getOwner() && "unlinked block argument");
3135  auto *containingOp = ivArg.getOwner()->getParentOp();
3136  return dyn_cast<ParallelOp>(containingOp);
3137 }
3138 
3139 namespace {
3140 // Collapse loop dimensions that perform a single iteration.
3141 struct ParallelOpSingleOrZeroIterationDimsFolder
3142  : public OpRewritePattern<ParallelOp> {
3143 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3144
3145  LogicalResult matchAndRewrite(ParallelOp op,
3146  PatternRewriter &rewriter) const override {
3147  Location loc = op.getLoc();
3148 
3149  // Compute new loop bounds that omit all single-iteration loop dimensions.
3150  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3151  IRMapping mapping;
3152  for (auto [lb, ub, step, iv] :
3153  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3154  op.getInductionVars())) {
3155  auto numIterations =
3156  constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
3157  if (numIterations.has_value()) {
3158  // Remove the loop if it performs zero iterations.
3159  if (*numIterations == 0) {
3160  rewriter.replaceOp(op, op.getInitVals());
3161  return success();
3162  }
3163  // Replace the loop induction variable by the lower bound if the loop
3164  // performs a single iteration. Otherwise, copy the loop bounds.
3165  if (*numIterations == 1) {
3166  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3167  continue;
3168  }
3169  }
3170  newLowerBounds.push_back(lb);
3171  newUpperBounds.push_back(ub);
3172  newSteps.push_back(step);
3173  }
3174  // Exit if none of the loop dimensions perform a single iteration.
3175  if (newLowerBounds.size() == op.getLowerBound().size())
3176  return failure();
3177 
3178  if (newLowerBounds.empty()) {
3179 // All of the loop dimensions perform a single iteration. Inline the
3180 // loop body and the nested ReduceOps.
3181  SmallVector<Value> results;
3182  results.reserve(op.getInitVals().size());
3183  for (auto &bodyOp : op.getBody()->without_terminator())
3184  rewriter.clone(bodyOp, mapping);
3185  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3186  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3187  Block &reduceBlock = reduceOp.getReductions()[i].front();
3188  auto initValIndex = results.size();
3189  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3190  mapping.map(reduceBlock.getArgument(1),
3191  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3192  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3193  rewriter.clone(reduceBodyOp, mapping);
3194 
3195  auto result = mapping.lookupOrDefault(
3196  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3197  results.push_back(result);
3198  }
3199 
3200  rewriter.replaceOp(op, results);
3201  return success();
3202  }
3203 // Replace the parallel loop by a lower-dimensional parallel loop.
3204  auto newOp =
3205  ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3206  newUpperBounds, newSteps, op.getInitVals(), nullptr);
3207  // Erase the empty block that was inserted by the builder.
3208  rewriter.eraseBlock(newOp.getBody());
3209  // Clone the loop body and remap the block arguments of the collapsed loops
3210  // (inlining does not support a cancellable block argument mapping).
3211  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3212  newOp.getRegion().begin(), mapping);
3213  rewriter.replaceOp(op, newOp.getResults());
3214  return success();
3215  }
3216 };
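// Illustrative effect of ParallelOpSingleOrZeroIterationDimsFolder
// (hypothetical IR): a dimension with a constant trip count of one is dropped
// and its induction variable is replaced by the lower bound, e.g.
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c128) step (%c1, %c1) { ... }
//
// collapses to
//
//   scf.parallel (%j) = (%c0) to (%c128) step (%c1) { ... }
//
// while a dimension with a zero trip count removes the loop entirely and
// replaces its results with the init values.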
3217 
3218 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3219 using OpRewritePattern<ParallelOp>::OpRewritePattern;
3220
3221  LogicalResult matchAndRewrite(ParallelOp op,
3222  PatternRewriter &rewriter) const override {
3223  Block &outerBody = *op.getBody();
3224  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3225  return failure();
3226 
3227  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3228  if (!innerOp)
3229  return failure();
3230 
3231  for (auto val : outerBody.getArguments())
3232  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3233  llvm::is_contained(innerOp.getUpperBound(), val) ||
3234  llvm::is_contained(innerOp.getStep(), val))
3235  return failure();
3236 
3237  // Reductions are not supported yet.
3238  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3239  return failure();
3240 
3241  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3242  ValueRange iterVals, ValueRange) {
3243  Block &innerBody = *innerOp.getBody();
3244  assert(iterVals.size() ==
3245  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3246  IRMapping mapping;
3247  mapping.map(outerBody.getArguments(),
3248  iterVals.take_front(outerBody.getNumArguments()));
3249  mapping.map(innerBody.getArguments(),
3250  iterVals.take_back(innerBody.getNumArguments()));
3251  for (Operation &op : innerBody.without_terminator())
3252  builder.clone(op, mapping);
3253  };
3254 
3255  auto concatValues = [](const auto &first, const auto &second) {
3256  SmallVector<Value> ret;
3257  ret.reserve(first.size() + second.size());
3258  ret.assign(first.begin(), first.end());
3259  ret.append(second.begin(), second.end());
3260  return ret;
3261  };
3262 
3263  auto newLowerBounds =
3264  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3265  auto newUpperBounds =
3266  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3267  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3268 
3269  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3270  newSteps, ValueRange(),
3271  bodyBuilder);
3272  return success();
3273  }
3274 };
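// Illustrative effect of MergeNestedParallelLoops (hypothetical IR): a
// perfectly nested pair of parallel loops without reductions, where the inner
// bounds do not use the outer induction variables,
//
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) { ... }
//   }
//
// is merged into a single multi-dimensional loop
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) { ... }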
3275 
3276 } // namespace
3277 
3278 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3279  MLIRContext *context) {
3280  results
3281  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3282  context);
3283 }
3284 
3285 /// Given a branch point, which is either the parent operation or one of its
3286 /// regions, populate `regions` with the successor regions. These are the
3287 /// regions that may be selected next during the flow of control: the parallel
3288 /// op may branch into its body region or bypass it entirely and transfer
3289 /// control back to its parent.
3290 void ParallelOp::getSuccessorRegions(
3291 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3292 // Both the operation itself and the region may branch into the body, or
3293 // back into the operation itself. It is also possible for the loop not to
3294 // enter the body at all.
3295  regions.push_back(RegionSuccessor(&getRegion()));
3296  regions.push_back(RegionSuccessor());
3297 }
3298 
3299 //===----------------------------------------------------------------------===//
3300 // ReduceOp
3301 //===----------------------------------------------------------------------===//
3302 
3303 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3304 
3305 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3306  ValueRange operands) {
3307  result.addOperands(operands);
3308  for (Value v : operands) {
3309  OpBuilder::InsertionGuard guard(builder);
3310  Region *bodyRegion = result.addRegion();
3311  builder.createBlock(bodyRegion, {},
3312  ArrayRef<Type>{v.getType(), v.getType()},
3313  {result.location, result.location});
3314  }
3315 }
3316 
3317 LogicalResult ReduceOp::verifyRegions() {
3318  // The region of a ReduceOp has two arguments of the same type as its
3319  // corresponding operand.
3320  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3321  auto type = getOperands()[i].getType();
3322  Block &block = getReductions()[i].front();
3323  if (block.empty())
3324  return emitOpError() << i << "-th reduction has an empty body";
3325  if (block.getNumArguments() != 2 ||
3326  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3327  return arg.getType() != type;
3328  }))
3329  return emitOpError() << "expected two block arguments with type " << type
3330  << " in the " << i << "-th reduction region";
3331 
3332  // Check that the block is terminated by a ReduceReturnOp.
3333  if (!isa<ReduceReturnOp>(block.getTerminator()))
3334  return emitOpError("reduction bodies must be terminated with an "
3335  "'scf.reduce.return' op");
3336  }
3337 
3338  return success();
3339 }
3340 
3341 MutableOperandRange
3342 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3343 // No operands are forwarded to the next iteration.
3344  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3345 }
3346 
3347 //===----------------------------------------------------------------------===//
3348 // ReduceReturnOp
3349 //===----------------------------------------------------------------------===//
3350 
3351 LogicalResult ReduceReturnOp::verify() {
3352  // The type of the return value should be the same type as the types of the
3353  // block arguments of the reduction body.
3354  Block *reductionBody = getOperation()->getBlock();
3355  // Should already be verified by an op trait.
3356  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3357  Type expectedResultType = reductionBody->getArgument(0).getType();
3358  if (expectedResultType != getResult().getType())
3359  return emitOpError() << "must have type " << expectedResultType
3360  << " (the type of the reduction inputs)";
3361  return success();
3362 }
3363 
3364 //===----------------------------------------------------------------------===//
3365 // WhileOp
3366 //===----------------------------------------------------------------------===//
3367 
3368 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3369  ::mlir::OperationState &odsState, TypeRange resultTypes,
3370  ValueRange inits, BodyBuilderFn beforeBuilder,
3371  BodyBuilderFn afterBuilder) {
3372  odsState.addOperands(inits);
3373  odsState.addTypes(resultTypes);
3374 
3375  OpBuilder::InsertionGuard guard(odsBuilder);
3376 
3377  // Build before region.
3378  SmallVector<Location, 4> beforeArgLocs;
3379  beforeArgLocs.reserve(inits.size());
3380  for (Value operand : inits) {
3381  beforeArgLocs.push_back(operand.getLoc());
3382  }
3383 
3384  Region *beforeRegion = odsState.addRegion();
3385  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3386  inits.getTypes(), beforeArgLocs);
3387  if (beforeBuilder)
3388  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3389 
3390  // Build after region.
3391  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3392 
3393  Region *afterRegion = odsState.addRegion();
3394  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3395  resultTypes, afterArgLocs);
3396 
3397  if (afterBuilder)
3398  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3399 }
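// Illustrative usage sketch (not part of this file; `builder`, `loc`, `init`
// and `ub` are index-typed values assumed to exist in the caller): creating a
// while loop through the builder above with callbacks for both regions.
//
//   auto loop = scf::WhileOp::create(
//       builder, loc, TypeRange{builder.getIndexType()}, ValueRange{init},
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         Value cond = arith::CmpIOp::create(
//             b, l, arith::CmpIPredicate::slt, args[0], ub);
//         scf::ConditionOp::create(b, l, cond, args);
//       },
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         scf::YieldOp::create(b, l, args);
//       });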
3400 
3401 ConditionOp WhileOp::getConditionOp() {
3402  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3403 }
3404 
3405 YieldOp WhileOp::getYieldOp() {
3406  return cast<YieldOp>(getAfterBody()->getTerminator());
3407 }
3408 
3409 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3410  return getYieldOp().getResultsMutable();
3411 }
3412 
3413 Block::BlockArgListType WhileOp::getBeforeArguments() {
3414  return getBeforeBody()->getArguments();
3415 }
3416 
3417 Block::BlockArgListType WhileOp::getAfterArguments() {
3418  return getAfterBody()->getArguments();
3419 }
3420 
3421 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3422  return getBeforeArguments();
3423 }
3424 
3425 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3426  assert(point == getBefore() &&
3427  "WhileOp is expected to branch only to the first region");
3428  return getInits();
3429 }
3430 
3431 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3432 SmallVectorImpl<RegionSuccessor> &regions) {
3433 // The parent op always branches to the condition region.
3434  if (point.isParent()) {
3435  regions.emplace_back(&getBefore(), getBefore().getArguments());
3436  return;
3437  }
3438 
3439  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3440  "there are only two regions in a WhileOp");
3441  // The body region always branches back to the condition region.
3442  if (point == getAfter()) {
3443  regions.emplace_back(&getBefore(), getBefore().getArguments());
3444  return;
3445  }
3446 
3447  regions.emplace_back(getResults());
3448  regions.emplace_back(&getAfter(), getAfter().getArguments());
3449 }
3450 
3451 SmallVector<Region *> WhileOp::getLoopRegions() {
3452  return {&getBefore(), &getAfter()};
3453 }
3454 
3455 /// Parses a `while` op.
3456 ///
3457 /// op ::= `scf.while` assignments `:` function-type region `do` region
3458 /// `attributes` attribute-dict
3459 /// initializer ::= /* empty */ | `(` assignment-list `)`
3460 /// assignment-list ::= assignment | assignment `,` assignment-list
3461 /// assignment ::= ssa-value `=` ssa-value
3462 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3463 SmallVector<OpAsmParser::Argument, 4> regionArgs;
3464 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3465 Region *before = result.addRegion();
3466  Region *after = result.addRegion();
3467 
3468  OptionalParseResult listResult =
3469  parser.parseOptionalAssignmentList(regionArgs, operands);
3470  if (listResult.has_value() && failed(listResult.value()))
3471  return failure();
3472 
3473  FunctionType functionType;
3474  SMLoc typeLoc = parser.getCurrentLocation();
3475  if (failed(parser.parseColonType(functionType)))
3476  return failure();
3477 
3478  result.addTypes(functionType.getResults());
3479 
3480  if (functionType.getNumInputs() != operands.size()) {
3481  return parser.emitError(typeLoc)
3482  << "expected as many input types as operands " << "(expected "
3483  << operands.size() << " got " << functionType.getNumInputs() << ")";
3484  }
3485 
3486  // Resolve input operands.
3487  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3488  parser.getCurrentLocation(),
3489  result.operands)))
3490  return failure();
3491 
3492  // Propagate the types into the region arguments.
3493  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3494  regionArgs[i].type = functionType.getInput(i);
3495 
3496  return failure(parser.parseRegion(*before, regionArgs) ||
3497  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3498 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3499 }
3500 
3501 /// Prints a `while` op.
3502 void scf::WhileOp::print(OpAsmPrinter &p) {
3503 printInitializationList(p, getBeforeArguments(), getInits(), " ");
3504  p << " : ";
3505  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3506  p << ' ';
3507  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3508  p << " do ";
3509  p.printRegion(getAfter());
3510  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3511 }
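// Example of the textual form produced and consumed by the print/parse
// methods above (illustrative only; SSA names are hypothetical):
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = arith.cmpi slt, %arg, %limit : i32
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg0: i32):
//     %next = arith.addi %arg0, %c1 : i32
//     scf.yield %next : i32
//   }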
3512 
3513 /// Verifies that two ranges of types match, i.e. have the same number of
3514 /// entries and that types are pairwise equals. Reports errors on the given
3515 /// operation in case of mismatch.
3516 template <typename OpTy>
3517 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3518  TypeRange right, StringRef message) {
3519  if (left.size() != right.size())
3520  return op.emitOpError("expects the same number of ") << message;
3521 
3522  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3523  if (left[i] != right[i]) {
3524  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3525  << message;
3526  diag.attachNote() << "for argument " << i << ", found " << left[i]
3527  << " and " << right[i];
3528  return diag;
3529  }
3530  }
3531 
3532  return success();
3533 }
3534 
3535 LogicalResult scf::WhileOp::verify() {
3536  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3537  *this, getBefore(),
3538  "expects the 'before' region to terminate with 'scf.condition'");
3539  if (!beforeTerminator)
3540  return failure();
3541 
3542  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3543  *this, getAfter(),
3544  "expects the 'after' region to terminate with 'scf.yield'");
3545  return success(afterTerminator != nullptr);
3546 }
3547 
3548 namespace {
3549 /// Replace uses of the condition within the do block with true, since otherwise
3550 /// the block would not be evaluated.
3551 ///
3552 /// scf.while (..) : (i1, ...) -> ... {
3553 /// %condition = call @evaluate_condition() : () -> i1
3554 /// scf.condition(%condition) %condition : i1, ...
3555 /// } do {
3556 /// ^bb0(%arg0: i1, ...):
3557 /// use(%arg0)
3558 /// ...
3559 ///
3560 /// becomes
3561 /// scf.while (..) : (i1, ...) -> ... {
3562 /// %condition = call @evaluate_condition() : () -> i1
3563 /// scf.condition(%condition) %condition : i1, ...
3564 /// } do {
3565 /// ^bb0(%arg0: i1, ...):
3566 /// use(%true)
3567 /// ...
3568 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3569 using OpRewritePattern<WhileOp>::OpRewritePattern;
3570
3571  LogicalResult matchAndRewrite(WhileOp op,
3572  PatternRewriter &rewriter) const override {
3573  auto term = op.getConditionOp();
3574 
3575 // This variable caches the constant true value so that it is created at
3576 // most once while replacing uses below.
3577  Value constantTrue = nullptr;
3578 
3579  bool replaced = false;
3580  for (auto yieldedAndBlockArgs :
3581  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3582  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3583  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3584  if (!constantTrue)
3585  constantTrue = arith::ConstantOp::create(
3586  rewriter, op.getLoc(), term.getCondition().getType(),
3587  rewriter.getBoolAttr(true));
3588 
3589  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3590  constantTrue);
3591  replaced = true;
3592  }
3593  }
3594  }
3595  return success(replaced);
3596  }
3597 };
3598 
3599 /// Remove loop invariant arguments from `before` block of scf.while.
3600 /// A before block argument is considered loop invariant if :-
3601 /// 1. i-th yield operand is equal to the i-th while operand.
3602 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3603 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3604 /// iter argument/while operand.
3605 /// For the arguments which are removed, their uses inside scf.while
3606 /// are replaced with their corresponding initial value.
3607 ///
3608 /// Eg:
3609 /// INPUT :-
3610 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3611 /// ..., %argN_before = %N)
3612 /// {
3613 /// ...
3614 /// scf.condition(%cond) %arg1_before, %arg0_before,
3615 /// %arg2_before, %arg0_before, ...
3616 /// } do {
3617 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3618 /// ..., %argK_after):
3619 /// ...
3620 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3621 /// }
3622 ///
3623 /// OUTPUT :-
3624 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3625 /// %N)
3626 /// {
3627 /// ...
3628 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3629 /// } do {
3630 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3631 /// ..., %argK_after):
3632 /// ...
3633 /// scf.yield %arg1_after, ..., %argN
3634 /// }
3635 ///
3636 /// EXPLANATION:
3637 /// We iterate over each yield operand.
3638 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3639 /// %arg0_before, which in turn is the 0-th iter argument. So we
3640 /// remove 0-th before block argument and yield operand, and replace
3641 /// all uses of the 0-th before block argument with its initial value
3642 /// %a.
3643 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3644 /// value. So we remove this operand and the corresponding before
3645 /// block argument and replace all uses of 1-th before block argument
3646 /// with %b.
3647 struct RemoveLoopInvariantArgsFromBeforeBlock
3648  : public OpRewritePattern<WhileOp> {
3649 using OpRewritePattern<WhileOp>::OpRewritePattern;
3650
3651  LogicalResult matchAndRewrite(WhileOp op,
3652  PatternRewriter &rewriter) const override {
3653  Block &afterBlock = *op.getAfterBody();
3654  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3655  ConditionOp condOp = op.getConditionOp();
3656  OperandRange condOpArgs = condOp.getArgs();
3657  Operation *yieldOp = afterBlock.getTerminator();
3658  ValueRange yieldOpArgs = yieldOp->getOperands();
3659 
3660  bool canSimplify = false;
3661  for (const auto &it :
3662  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3663  auto index = static_cast<unsigned>(it.index());
3664  auto [initVal, yieldOpArg] = it.value();
3665  // If i-th yield operand is equal to the i-th operand of the scf.while,
3666  // the i-th before block argument is a loop invariant.
3667  if (yieldOpArg == initVal) {
3668  canSimplify = true;
3669  break;
3670  }
3671  // If the i-th yield operand is k-th after block argument, then we check
3672  // if the (k+1)-th condition op operand is equal to either the i-th before
3673  // block argument or the initial value of i-th before block argument. If
3674  // the comparison results `true`, i-th before block argument is a loop
3675  // invariant.
3676  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3677  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3678  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3679  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3680  canSimplify = true;
3681  break;
3682  }
3683  }
3684  }
3685 
3686  if (!canSimplify)
3687  return failure();
3688 
3689  SmallVector<Value> newInitArgs, newYieldOpArgs;
3690  DenseMap<unsigned, Value> beforeBlockInitValMap;
3691  SmallVector<Location> newBeforeBlockArgLocs;
3692  for (const auto &it :
3693  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3694  auto index = static_cast<unsigned>(it.index());
3695  auto [initVal, yieldOpArg] = it.value();
3696 
3697  // If i-th yield operand is equal to the i-th operand of the scf.while,
3698  // the i-th before block argument is a loop invariant.
3699  if (yieldOpArg == initVal) {
3700  beforeBlockInitValMap.insert({index, initVal});
3701  continue;
3702  } else {
3703  // If the i-th yield operand is k-th after block argument, then we check
3704  // if the (k+1)-th condition op operand is equal to either the i-th
3705  // before block argument or the initial value of i-th before block
3706  // argument. If the comparison results `true`, i-th before block
3707  // argument is a loop invariant.
3708  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3709  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3710  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3711  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3712  beforeBlockInitValMap.insert({index, initVal});
3713  continue;
3714  }
3715  }
3716  }
3717  newInitArgs.emplace_back(initVal);
3718  newYieldOpArgs.emplace_back(yieldOpArg);
3719  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3720  }
3721 
3722  {
3723  OpBuilder::InsertionGuard g(rewriter);
3724  rewriter.setInsertionPoint(yieldOp);
3725  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3726  }
3727 
3728  auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3729  newInitArgs);
3730 
3731  Block &newBeforeBlock = *rewriter.createBlock(
3732  &newWhile.getBefore(), /*insertPt*/ {},
3733  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3734 
3735  Block &beforeBlock = *op.getBeforeBody();
3736  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3737 // For each i-th before block argument we find its replacement value as :-
3738 // 1. If the i-th before block argument is a loop invariant, we fetch its
3739 // initial value from `beforeBlockInitValMap` by querying for key `i`.
3740  // 2. Else we fetch j-th new before block argument as the replacement
3741  // value of i-th before block argument.
3742  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3743 // If the index 'i' argument was a loop invariant, we fetch its initial
3744  // value from `beforeBlockInitValMap`.
3745  if (beforeBlockInitValMap.count(i) != 0)
3746  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3747  else
3748  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3749  }
3750 
3751  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3752  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3753  newWhile.getAfter().begin());
3754 
3755  rewriter.replaceOp(op, newWhile.getResults());
3756  return success();
3757  }
3758 };
3759 
3760 /// Remove loop invariant value from result (condition op) of scf.while.
3761 /// A value is considered loop invariant if the final value yielded by
3762 /// scf.condition is defined outside of the `before` block. We remove the
3763 /// corresponding argument in `after` block and replace the use with the value.
3764 /// We also replace the use of the corresponding result of scf.while with the
3765 /// value.
3766 ///
3767 /// Eg:
3768 /// INPUT :-
3769 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3770 /// %argN_before = %N) {
3771 /// ...
3772 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3773 /// } do {
3774 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3775 /// ...
3776 /// some_func(%arg1_after)
3777 /// ...
3778 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3779 /// }
3780 ///
3781 /// OUTPUT :-
3782 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3783 /// ...
3784 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3785 /// } do {
3786 /// ^bb0(%arg0, %arg3, ..., %argM):
3787 /// ...
3788 /// some_func(%a)
3789 /// ...
3790 /// scf.yield %arg0, %b, ..., %argN
3791 /// }
3792 ///
3793 /// EXPLANATION:
3794 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3795 /// before block of scf.while, so they get removed.
3796 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3797 /// replaced by %b.
3798 /// 3. The corresponding after block argument %arg1_after's uses are
3799 /// replaced by %a and %arg2_after's uses are replaced by %b.
3800 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3801 using OpRewritePattern<WhileOp>::OpRewritePattern;
3802
3803  LogicalResult matchAndRewrite(WhileOp op,
3804  PatternRewriter &rewriter) const override {
3805  Block &beforeBlock = *op.getBeforeBody();
3806  ConditionOp condOp = op.getConditionOp();
3807  OperandRange condOpArgs = condOp.getArgs();
3808 
3809  bool canSimplify = false;
3810  for (Value condOpArg : condOpArgs) {
3811 // Values not defined within the `before` block are considered loop
3812 // invariant. Finding any such value means the op can be simplified,
3813 // so we can stop searching.
3814  if (condOpArg.getParentBlock() != &beforeBlock) {
3815  canSimplify = true;
3816  break;
3817  }
3818  }
3819 
3820  if (!canSimplify)
3821  return failure();
3822 
3823  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3824 
3825  SmallVector<Value> newCondOpArgs;
3826  SmallVector<Type> newAfterBlockType;
3827  DenseMap<unsigned, Value> condOpInitValMap;
3828  SmallVector<Location> newAfterBlockArgLocs;
3829  for (const auto &it : llvm::enumerate(condOpArgs)) {
3830  auto index = static_cast<unsigned>(it.index());
3831  Value condOpArg = it.value();
3832 // Values not defined within the `before` block are considered loop
3833 // invariant values. We map the corresponding `index` to the invariant
3834 // value.
3835  if (condOpArg.getParentBlock() != &beforeBlock) {
3836  condOpInitValMap.insert({index, condOpArg});
3837  } else {
3838  newCondOpArgs.emplace_back(condOpArg);
3839  newAfterBlockType.emplace_back(condOpArg.getType());
3840  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3841  }
3842  }
3843 
3844  {
3845  OpBuilder::InsertionGuard g(rewriter);
3846  rewriter.setInsertionPoint(condOp);
3847  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3848  newCondOpArgs);
3849  }
3850 
3851  auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3852  op.getOperands());
3853 
3854  Block &newAfterBlock =
3855  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3856  newAfterBlockType, newAfterBlockArgLocs);
3857 
3858  Block &afterBlock = *op.getAfterBody();
3859  // Since a new scf.condition op was created, we need to fetch the new
3860  // `after` block arguments which will be used while replacing operations of
3861 // the previous scf.while's `after` block. We also fetch the new result
3862 // values here.
3863  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3864  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3865  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3866  Value afterBlockArg, result;
3867 // If the index 'i' argument was loop invariant, we fetch its value from
3868 // the `condOpInitValMap` map.
3869  if (condOpInitValMap.count(i) != 0) {
3870  afterBlockArg = condOpInitValMap[i];
3871  result = afterBlockArg;
3872  } else {
3873  afterBlockArg = newAfterBlock.getArgument(j);
3874  result = newWhile.getResult(j);
3875  j++;
3876  }
3877  newAfterBlockArgs[i] = afterBlockArg;
3878  newWhileResults[i] = result;
3879  }
3880 
3881  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3882  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3883  newWhile.getBefore().begin());
3884 
3885  rewriter.replaceOp(op, newWhileResults);
3886  return success();
3887  }
3888 };
3889 
3890 /// Remove WhileOp results that are also unused in 'after' block.
3891 ///
3892 /// %0:2 = scf.while () : () -> (i32, i64) {
3893 /// %condition = "test.condition"() : () -> i1
3894 /// %v1 = "test.get_some_value"() : () -> i32
3895 /// %v2 = "test.get_some_value"() : () -> i64
3896 /// scf.condition(%condition) %v1, %v2 : i32, i64
3897 /// } do {
3898 /// ^bb0(%arg0: i32, %arg1: i64):
3899 /// "test.use"(%arg0) : (i32) -> ()
3900 /// scf.yield
3901 /// }
3902 /// return %0#0 : i32
3903 ///
3904 /// becomes
3905 /// %0 = scf.while () : () -> (i32) {
3906 /// %condition = "test.condition"() : () -> i1
3907 /// %v1 = "test.get_some_value"() : () -> i32
3908 /// %v2 = "test.get_some_value"() : () -> i64
3909 /// scf.condition(%condition) %v1 : i32
3910 /// } do {
3911 /// ^bb0(%arg0: i32):
3912 /// "test.use"(%arg0) : (i32) -> ()
3913 /// scf.yield
3914 /// }
3915 /// return %0 : i32
3916 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3917 using OpRewritePattern<WhileOp>::OpRewritePattern;
3918
3919  LogicalResult matchAndRewrite(WhileOp op,
3920  PatternRewriter &rewriter) const override {
3921  auto term = op.getConditionOp();
3922  auto afterArgs = op.getAfterArguments();
3923  auto termArgs = term.getArgs();
3924 
3925  // Collect results mapping, new terminator args and new result types.
3926  SmallVector<unsigned> newResultsIndices;
3927  SmallVector<Type> newResultTypes;
3928  SmallVector<Value> newTermArgs;
3929  SmallVector<Location> newArgLocs;
3930  bool needUpdate = false;
3931  for (const auto &it :
3932  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3933  auto i = static_cast<unsigned>(it.index());
3934  Value result = std::get<0>(it.value());
3935  Value afterArg = std::get<1>(it.value());
3936  Value termArg = std::get<2>(it.value());
3937  if (result.use_empty() && afterArg.use_empty()) {
3938  needUpdate = true;
3939  } else {
3940  newResultsIndices.emplace_back(i);
3941  newTermArgs.emplace_back(termArg);
3942  newResultTypes.emplace_back(result.getType());
3943  newArgLocs.emplace_back(result.getLoc());
3944  }
3945  }
3946 
3947  if (!needUpdate)
3948  return failure();
3949 
3950  {
3951  OpBuilder::InsertionGuard g(rewriter);
3952  rewriter.setInsertionPoint(term);
3953  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3954  newTermArgs);
3955  }
3956 
3957  auto newWhile =
3958  WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
3959 
3960  Block &newAfterBlock = *rewriter.createBlock(
3961  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3962 
3963  // Build new results list and new after block args (unused entries will be
3964  // null).
3965  SmallVector<Value> newResults(op.getNumResults());
3966  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3967  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3968  newResults[it.value()] = newWhile.getResult(it.index());
3969  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3970  }
3971 
3972  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3973  newWhile.getBefore().begin());
3974 
3975  Block &afterBlock = *op.getAfterBody();
3976  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3977 
3978  rewriter.replaceOp(op, newResults);
3979  return success();
3980  }
3981 };
3982 
3983 /// Replace operations equivalent to the condition in the do block with true,
3984 /// since otherwise the block would not be evaluated.
3985 ///
3986 /// scf.while (..) : (i32, ...) -> ... {
3987 /// %z = ... : i32
3988 /// %condition = cmpi pred %z, %a
3989 /// scf.condition(%condition) %z : i32, ...
3990 /// } do {
3991 /// ^bb0(%arg0: i32, ...):
3992 /// %condition2 = cmpi pred %arg0, %a
3993 /// use(%condition2)
3994 /// ...
3995 ///
3996 /// becomes
3997 /// scf.while (..) : (i32, ...) -> ... {
3998 /// %z = ... : i32
3999 /// %condition = cmpi pred %z, %a
4000 /// scf.condition(%condition) %z : i32, ...
4001 /// } do {
4002 /// ^bb0(%arg0: i32, ...):
4003 /// use(%true)
4004 /// ...
4005 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
4006 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
4007
4008  LogicalResult matchAndRewrite(scf::WhileOp op,
4009  PatternRewriter &rewriter) const override {
4010  using namespace scf;
4011  auto cond = op.getConditionOp();
4012  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
4013  if (!cmp)
4014  return failure();
4015  bool changed = false;
4016  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
4017  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
4018  if (std::get<0>(tup) != cmp.getOperand(opIdx))
4019  continue;
4020  for (OpOperand &u :
4021  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
4022  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
4023  if (!cmp2)
4024  continue;
4025  // For a binary operator 1-opIdx gets the other side.
4026  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
4027  continue;
4028  bool samePredicate;
4029  if (cmp2.getPredicate() == cmp.getPredicate())
4030  samePredicate = true;
4031  else if (cmp2.getPredicate() ==
4032  arith::invertPredicate(cmp.getPredicate()))
4033  samePredicate = false;
4034  else
4035  continue;
4036 
4037  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
4038  1);
4039  changed = true;
4040  }
4041  }
4042  }
4043  return success(changed);
4044  }
4045 };
4046 
4047 /// Remove unused init/yield args.
4048 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
4049 using OpRewritePattern<WhileOp>::OpRewritePattern;
4050
4051  LogicalResult matchAndRewrite(WhileOp op,
4052  PatternRewriter &rewriter) const override {
4053 
4054  if (!llvm::any_of(op.getBeforeArguments(),
4055  [](Value arg) { return arg.use_empty(); }))
4056  return rewriter.notifyMatchFailure(op, "No args to remove");
4057 
4058  YieldOp yield = op.getYieldOp();
4059 
4060  // Collect results mapping, new terminator args and new result types.
4061  SmallVector<Value> newYields;
4062  SmallVector<Value> newInits;
4063  llvm::BitVector argsToErase;
4064 
4065  size_t argsCount = op.getBeforeArguments().size();
4066  newYields.reserve(argsCount);
4067  newInits.reserve(argsCount);
4068  argsToErase.reserve(argsCount);
4069  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4070  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4071  if (beforeArg.use_empty()) {
4072  argsToErase.push_back(true);
4073  } else {
4074  argsToErase.push_back(false);
4075  newYields.emplace_back(yieldValue);
4076  newInits.emplace_back(initValue);
4077  }
4078  }
4079 
4080  Block &beforeBlock = *op.getBeforeBody();
4081  Block &afterBlock = *op.getAfterBody();
4082 
4083  beforeBlock.eraseArguments(argsToErase);
4084 
4085  Location loc = op.getLoc();
4086  auto newWhileOp =
4087  WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
4088  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4089  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4090  Block &newAfterBlock = *newWhileOp.getAfterBody();
4091 
4092  OpBuilder::InsertionGuard g(rewriter);
4093  rewriter.setInsertionPoint(yield);
4094  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4095 
4096  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4097  newBeforeBlock.getArguments());
4098  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4099  newAfterBlock.getArguments());
4100 
4101  rewriter.replaceOp(op, newWhileOp.getResults());
4102  return success();
4103  }
4104 };
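// Illustrative effect of WhileRemoveUnusedArgs (hypothetical IR): if the
// `before` argument %arg1 has no uses, its init value and yield operand are
// dropped as well, so
//
//   %r = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i32 { ... }
//
// becomes
//
//   %r = scf.while (%arg0 = %a) : (i32) -> i32 { ... }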
4105 
4106 /// Remove duplicated ConditionOp args.
4107 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4108  using OpRewritePattern::OpRewritePattern;
4109 
4110  LogicalResult matchAndRewrite(WhileOp op,
4111  PatternRewriter &rewriter) const override {
4112  ConditionOp condOp = op.getConditionOp();
4113  ValueRange condOpArgs = condOp.getArgs();
4114 
4115  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4116 
4117  if (argsSet.size() == condOpArgs.size())
4118  return rewriter.notifyMatchFailure(op, "No results to remove");
4119 
4120  llvm::SmallDenseMap<Value, unsigned> argsMap;
4121  SmallVector<Value> newArgs;
4122  argsMap.reserve(condOpArgs.size());
4123  newArgs.reserve(condOpArgs.size());
4124  for (Value arg : condOpArgs) {
4125  if (!argsMap.count(arg)) {
4126  auto pos = static_cast<unsigned>(argsMap.size());
4127  argsMap.insert({arg, pos});
4128  newArgs.emplace_back(arg);
4129  }
4130  }
4131 
4132  ValueRange argsRange(newArgs);
4133 
4134  Location loc = op.getLoc();
4135  auto newWhileOp =
4136  scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4137  /*beforeBody*/ nullptr,
4138  /*afterBody*/ nullptr);
4139  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4140  Block &newAfterBlock = *newWhileOp.getAfterBody();
4141 
4142  SmallVector<Value> afterArgsMapping;
4143  SmallVector<Value> resultsMapping;
4144  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4145  auto it = argsMap.find(arg);
4146  assert(it != argsMap.end());
4147  auto pos = it->second;
4148  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4149  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4150  }
4151 
4152  OpBuilder::InsertionGuard g(rewriter);
4153  rewriter.setInsertionPoint(condOp);
4154  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4155  argsRange);
4156 
4157  Block &beforeBlock = *op.getBeforeBody();
4158  Block &afterBlock = *op.getAfterBody();
4159 
4160  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4161  newBeforeBlock.getArguments());
4162  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4163  rewriter.replaceOp(op, resultsMapping);
4164  return success();
4165  }
4166 };
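// Illustrative effect of WhileRemoveDuplicatedResults (hypothetical IR):
// duplicated scf.condition args are forwarded only once, and uses of the
// duplicated results and `after` block arguments are redirected to the
// remaining copy, e.g.
//
//   scf.condition(%cond) %v, %v : i32, i32
//
// becomes
//
//   scf.condition(%cond) %v : i32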
4167 
4168 /// If both ranges contain the same values, return the mapping of indices from
4169 /// args2 to args1. Otherwise return std::nullopt.
4170 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4171  ValueRange args2) {
4172  if (args1.size() != args2.size())
4173  return std::nullopt;
4174 
4175  SmallVector<unsigned> ret(args1.size());
4176  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4177  auto it = llvm::find(args2, arg1);
4178  if (it == args2.end())
4179  return std::nullopt;
4180 
4181  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4182  }
4183 
4184  return ret;
4185 }
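// For example (hypothetical values): for args1 = [%a, %b, %c] and
// args2 = [%b, %c, %a] the returned mapping is [1, 2, 0], i.e.
// args2[k] == args1[mapping[k]] for every index k.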
4186 
4187 static bool hasDuplicates(ValueRange args) {
4188  llvm::SmallDenseSet<Value> set;
4189  for (Value arg : args) {
4190  if (!set.insert(arg).second)
4191  return true;
4192  }
4193  return false;
4194 }
4195 
4196 /// If the `before` block args are directly forwarded to `scf.condition`,
4197 /// rearrange the `scf.condition` args into the same order as the block args.
4198 /// Update the `after` block args and op result values accordingly.
4199 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4200 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4201  using OpRewritePattern::OpRewritePattern;
4202 
4203  LogicalResult matchAndRewrite(WhileOp loop,
4204  PatternRewriter &rewriter) const override {
4205  auto oldBefore = loop.getBeforeBody();
4206  ConditionOp oldTerm = loop.getConditionOp();
4207  ValueRange beforeArgs = oldBefore->getArguments();
4208  ValueRange termArgs = oldTerm.getArgs();
4209  if (beforeArgs == termArgs)
4210  return failure();
4211 
4212  if (hasDuplicates(termArgs))
4213  return failure();
4214 
4215  auto mapping = getArgsMapping(beforeArgs, termArgs);
4216  if (!mapping)
4217  return failure();
4218 
4219  {
4220  OpBuilder::InsertionGuard g(rewriter);
4221  rewriter.setInsertionPoint(oldTerm);
4222  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4223  beforeArgs);
4224  }
4225 
4226  auto oldAfter = loop.getAfterBody();
4227 
4228  SmallVector<Type> newResultTypes(beforeArgs.size());
4229  for (auto &&[i, j] : llvm::enumerate(*mapping))
4230  newResultTypes[j] = loop.getResult(i).getType();
4231 
4232  auto newLoop = WhileOp::create(
4233  rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4234  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4235  auto newBefore = newLoop.getBeforeBody();
4236  auto newAfter = newLoop.getAfterBody();
4237 
4238  SmallVector<Value> newResults(beforeArgs.size());
4239  SmallVector<Value> newAfterArgs(beforeArgs.size());
4240  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4241  newResults[i] = newLoop.getResult(j);
4242  newAfterArgs[i] = newAfter->getArgument(j);
4243  }
4244 
4245  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4246  newBefore->getArguments());
4247  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4248  newAfterArgs);
4249 
4250  rewriter.replaceOp(loop, newResults);
4251  return success();
4252  }
4253 };
4254 } // namespace
4255 
4256 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4257  MLIRContext *context) {
4258  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4259  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4260  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4261  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4262 }
4263 
4264 //===----------------------------------------------------------------------===//
4265 // IndexSwitchOp
4266 //===----------------------------------------------------------------------===//
4267 
4268 /// Parse the case regions and values.
4269 static ParseResult
4270 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4271 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4272  SmallVector<int64_t> caseValues;
4273  while (succeeded(p.parseOptionalKeyword("case"))) {
4274  int64_t value;
4275  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4276  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4277  return failure();
4278  caseValues.push_back(value);
4279  }
4280  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4281  return success();
4282 }
4283 
4284 /// Print the case regions and values.
4285 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4286 DenseI64ArrayAttr cases, RegionRange caseRegions) {
4287  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4288  p.printNewline();
4289  p << "case " << value << ' ';
4290  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4291  }
4292 }
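// Example of the case syntax handled by the parse/print helpers above
// (illustrative only; values and types are hypothetical):
//
//   %0 = scf.index_switch %index -> i32
//   case 2 {
//     scf.yield %a : i32
//   }
//   case 5 {
//     scf.yield %b : i32
//   }
//   default {
//     scf.yield %c : i32
//   }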
4293 
4294 LogicalResult scf::IndexSwitchOp::verify() {
4295  if (getCases().size() != getCaseRegions().size()) {
4296  return emitOpError("has ")
4297  << getCaseRegions().size() << " case regions but "
4298  << getCases().size() << " case values";
4299  }
4300 
4301  DenseSet<int64_t> valueSet;
4302  for (int64_t value : getCases())
4303  if (!valueSet.insert(value).second)
4304  return emitOpError("has duplicate case value: ") << value;
4305  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4306  auto yield = dyn_cast<YieldOp>(region.front().back());
4307  if (!yield)
4308  return emitOpError("expected region to end with scf.yield, but got ")
4309  << region.front().back().getName();
4310 
4311  if (yield.getNumOperands() != getNumResults()) {
4312  return (emitOpError("expected each region to return ")
4313  << getNumResults() << " values, but " << name << " returns "
4314  << yield.getNumOperands())
4315  .attachNote(yield.getLoc())
4316  << "see yield operation here";
4317  }
4318  for (auto [idx, result, operand] :
4319  llvm::enumerate(getResultTypes(), yield.getOperands())) {
4320  if (!operand)
4321  return yield.emitOpError() << "operand " << idx << " is null\n";
4322  if (result == operand.getType())
4323  continue;
4324  return (emitOpError("expected result #")
4325  << idx << " of each region to be " << result)
4326  .attachNote(yield.getLoc())
4327  << name << " returns " << operand.getType() << " here";
4328  }
4329  return success();
4330  };
4331 
4332  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4333  return failure();
4334  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4335  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4336  return failure();
4337 
4338  return success();
4339 }
4340 
4341 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4342 
4343 Block &scf::IndexSwitchOp::getDefaultBlock() {
4344  return getDefaultRegion().front();
4345 }
4346 
4347 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4348  assert(idx < getNumCases() && "case index out-of-bounds");
4349  return getCaseRegions()[idx].front();
4350 }
4351 
4352 void IndexSwitchOp::getSuccessorRegions(
4353 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4354 // All regions branch back to the parent op.
4355  if (!point.isParent()) {
4356  successors.emplace_back(getResults());
4357  return;
4358  }
4359 
4360  llvm::append_range(successors, getRegions());
4361 }
4362 
4363 void IndexSwitchOp::getEntrySuccessorRegions(
4364  ArrayRef<Attribute> operands,
4365  SmallVectorImpl<RegionSuccessor> &successors) {
4366  FoldAdaptor adaptor(operands, *this);
4367 
4368  // If a constant was not provided, all regions are possible successors.
4369  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4370  if (!arg) {
4371  llvm::append_range(successors, getRegions());
4372  return;
4373  }
4374 
4375  // Otherwise, try to find a case with a matching value. If not, the
4376  // default region is the only successor.
4377  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4378  if (caseValue == arg.getInt()) {
4379  successors.emplace_back(&caseRegion);
4380  return;
4381  }
4382  }
4383  successors.emplace_back(&getDefaultRegion());
4384 }
4385 
4386 void IndexSwitchOp::getRegionInvocationBounds(
4387 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4388 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4389  if (!operandValue) {
4390  // All regions are invoked at most once.
4391  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4392  return;
4393  }
4394 
4395  unsigned liveIndex = getNumRegions() - 1;
4396  const auto *it = llvm::find(getCases(), operandValue.getInt());
4397  if (it != getCases().end())
4398  liveIndex = std::distance(getCases().begin(), it);
4399  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4400  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4401 }
4402 
4403 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4404 using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4405
4406  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4407  PatternRewriter &rewriter) const override {
4408 // If `op.getArg()` is a constant, select the region that matches the
4409 // constant value. Use the default region if no match is found.
4410  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4411  if (!maybeCst.has_value())
4412  return failure();
4413  int64_t cst = *maybeCst;
4414  int64_t caseIdx, e = op.getNumCases();
4415  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4416  if (cst == op.getCases()[caseIdx])
4417  break;
4418  }
4419 
4420  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4421  : op.getDefaultRegion();
4422  Block &source = r.front();
4423  Operation *terminator = source.getTerminator();
4424  SmallVector<Value> results = terminator->getOperands();
4425 
4426  rewriter.inlineBlockBefore(&source, op);
4427  rewriter.eraseOp(terminator);
4428 // Replace the operation with a potentially empty list of results; the fold
4429 // mechanism doesn't support the case where the result list is empty.
4430  rewriter.replaceOp(op, results);
4431 
4432  return success();
4433  }
4434 };
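// Illustrative effect of FoldConstantCase (hypothetical IR): with a constant
// switch argument the selected region is inlined and the switch disappears,
// e.g.
//
//   %c2 = arith.constant 2 : index
//   %0 = scf.index_switch %c2 -> i32
//   case 2 { scf.yield %a : i32 }
//   default { scf.yield %b : i32 }
//
// folds so that uses of %0 are replaced by %a.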
4435 
4436 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4437  MLIRContext *context) {
4438  results.add<FoldConstantCase>(context);
4439 }
4440 
4441 //===----------------------------------------------------------------------===//
4442 // TableGen'd op method definitions
4443 //===----------------------------------------------------------------------===//
4444 
4445 #define GET_OP_CLASSES
4446 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:757
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:749
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:131
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1365
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:137
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4270
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3517
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:472
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:99
static std::optional< llvm::APSInt > computeUbMinusLb(Value lb, Value ub, bool isSigned)
Helper function to compute the difference between two values.
Definition: SCF.cpp:115
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4285
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
Definition: Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
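The Block accessors above compose into the usual traversal idiom; a small sketch, assuming `block` is a valid Block*:

for (Operation &nested : block->without_terminator()) {
  // Inspect every operation in the block except its terminator.
}
Operation *terminator = block->getTerminator();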
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
UnitAttr getUnitAttr()
Definition: Builders.cpp:98
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:167
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
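A short sketch of how the Builder helpers listed above are typically used to construct attributes and types, assuming `ctx` is a valid MLIRContext*:

Builder b(ctx);
IntegerAttr fortyTwo = b.getIndexAttr(42);
DenseI64ArrayAttr cases = b.getDenseI64ArrayAttr({0, 1, 2});
BoolAttr flag = b.getBoolAttr(true);
IntegerType i1 = b.getI1Type();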
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desired resultant type.
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInterface can be invoked.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, and interpreting the results.
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attributes'.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Builders.cpp:569
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:596
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:412
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation regions empty.
Definition: Builders.h:593
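A minimal sketch of the OpBuilder insertion-point API listed above, assuming `builder` is an OpBuilder and `block` an existing Block*; the guard restores the original insertion point on scope exit:

{
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToEnd(block);
  scf::YieldOp::create(builder, builder.getUnknownLoc());
}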
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g. Operation::print).
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:622
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
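A hedged sketch of a rewrite pattern exercising the rewriter hooks listed above; the pattern itself is illustrative (a zero-addition cleanup of the kind arith folders already perform) and is not defined in this file:

struct DropAddOfZero : OpRewritePattern<arith::AddIOp> {
  using OpRewritePattern<arith::AddIOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::AddIOp op,
                                PatternRewriter &rewriter) const override {
    if (!matchPattern(op.getRhs(), m_Zero()))
      return rewriter.notifyMatchFailure(op, "rhs is not the constant 0");
    // x + 0 == x: forward the lhs and erase the add.
    rewriter.replaceOp(op, op.getLhs());
    return success();
  }
};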
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
Definition: LoopUtils.cpp:119
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3130
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:92
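A small sketch: the same helper can be called directly to terminate a freshly built block the way the default scf.if builder does, assuming `builder` is an OpBuilder, `loc` a Location, and `bodyBlock` a block that legally ends with an empty scf.yield:

OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToEnd(bodyBlock); // `bodyBlock` is an assumed Block*
scf::buildTerminatedBody(builder, loc);    // emits `scf.yield` with no operands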
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:744
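A hedged usage sketch, assuming `b` is an OpBuilder and `lb`, `ub`, `step` are index-typed Values: this builds a 2-D perfect nest via the overload without iter_args and lets the callback fill the innermost body, with terminators inserted automatically.

scf::LoopNest nest = scf::buildLoopNest(
    b, loc, ValueRange{lb, lb}, ValueRange{ub, ub}, ValueRange{step, step},
    [](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
      // `ivs` holds one induction variable per loop in the nest.
    });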
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:2057
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:699
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:651
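Sketch of the intended query, assuming `iv` is a Value that may or may not be an scf.for induction variable:

if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
  // `iv` is the induction variable of `forOp`; its bounds and step are available.
  Value lowerBound = forOp.getLowerBound();
  (void)lowerBound;
}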
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1516
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:833
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source tensor type.
Definition: TensorOps.cpp:270
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
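Sketch, assuming `ofr` is an OpFoldResult (e.g. a loop step): the helper yields a value only when the integer is statically known.

std::optional<int64_t> maybeStep = getConstantIntValue(ofr);
if (maybeStep && *maybeStep == 1) {
  // The step is known to be the constant 1.
}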
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with every element for which ShapedType::isDynamic is true replaced by the corresponding element of dynamicValues.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
std::optional< APInt > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned, llvm::function_ref< std::optional< llvm::APSInt >(Value, Value, bool)> computeUbMinusLb)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
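The underlying arithmetic, as a hedged standalone sketch assuming non-overflowing signed constants and a positive step; the real helper operates on OpFoldResults and APInts rather than raw int64_t:

int64_t tripCountSketch(int64_t lb, int64_t ub, int64_t step) {
  if (ub <= lb)
    return 0;                          // empty iteration space
  return (ub - lb + step - 1) / step;  // ceildiv(ub - lb, step)
}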
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4406
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:261
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:212
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.