1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
85 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-block region");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 /// block+
126 /// `}`
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// return %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
148 void ExecuteRegionOp::print(OpAsmPrinter &p) {
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
182  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
231  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
266  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
270  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
285 MutableOperandRange
286 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after "
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
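// Illustrative sketch (hypothetical IR, not from this file) of how
// scf.condition routes control and its forwarded operands:
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.check"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%iter: i32):
//     %next = "test.step"(%iter) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// If %cond is known to be true, %arg flows into the "after" region's block
// argument; if known to be false, it becomes the scf.while result %res. With
// an unknown condition both successors are reported above.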
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange initArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
319  result.addOperands({lb, ub, step});
320  result.addOperands(initArgs);
321  for (Value v : initArgs)
322  result.addTypes(v.getType());
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : initArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (initArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
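// Illustrative usage sketch of the builder above (not from this file; the
// names `builder`, `loc`, `lb`, `ub`, `step`, and `init` are assumed to exist
// in the caller):
//
//   auto loop = builder.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &b, Location l, Value iv, ValueRange iterArgs) {
//         // Iteration arguments are present, so the callback must create the
//         // terminating scf.yield itself.
//         b.create<scf::YieldOp>(l, iterArgs);
//       });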
342 
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the body defines a single block argument for the induction
354  // variable.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
393  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a ForOp to its containing block if it can be
399 /// determined that the loop has a single iteration.
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401  std::optional<int64_t> tripCount =
402  constantTripCount(getLowerBound(), getUpperBound(), getStep());
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
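// Illustrative example (hypothetical IR) of the promotion above for a loop
// with a known trip count of 1:
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
//     %v = "test.use"(%i, %a) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
//
// The body is inlined with %i replaced by %c0 and %a by %init, and %r is
// replaced by %v.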
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
431 static void printInitializationList(OpAsmPrinter &p,
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
446 
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
479  SmallVector<OpAsmParser::Argument, 4> regionArgs;
480  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Set block argument types, so that they are known when parsing the region.
503  regionArgs.front().type = type;
504  for (auto [iterArg, type] :
505  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
506  iterArg.type = type;
507 
508  // Parse the body region.
509  Region *body = result.addRegion();
510  if (parser.parseRegion(*body, regionArgs))
511  return failure();
512  ForOp::ensureTerminator(*body, builder, result.location);
513 
514  // Resolve input operands. This should be done after parsing the region to
515  // catch invalid IR where operands were defined inside of the region.
516  if (parser.resolveOperand(lb, type, result.operands) ||
517  parser.resolveOperand(ub, type, result.operands) ||
518  parser.resolveOperand(step, type, result.operands))
519  return failure();
520  if (hasIterArgs) {
521  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
522  operands, result.types)) {
523  Type type = std::get<2>(argOperandType);
524  std::get<0>(argOperandType).type = type;
525  if (parser.resolveOperand(std::get<1>(argOperandType), type,
526  result.operands))
527  return failure();
528  }
529  }
530 
531  // Parse the optional attribute list.
532  if (parser.parseOptionalAttrDict(result.attributes))
533  return failure();
534 
535  return success();
536 }
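// Illustrative example (hypothetical IR) of the custom assembly handled by
// ForOp::print and ForOp::parse above:
//
//   %sum = scf.for %iv = %lb to %ub step %c1
//       iter_args(%acc = %init) -> (f32) {
//     %next = arith.addf %acc, %acc : f32
//     scf.yield %next : f32
//   }
//
// A trailing `: type` is printed/parsed only when the induction variable type
// is not `index`.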
537 
538 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
539 
540 Block::BlockArgListType ForOp::getRegionIterArgs() {
541  return getBody()->getArguments().drop_front(getNumInductionVars());
542 }
543 
544 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
545  return getInitArgsMutable();
546 }
547 
548 FailureOr<LoopLikeOpInterface>
549 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
550  ValueRange newInitOperands,
551  bool replaceInitOperandUsesInLoop,
552  const NewYieldValuesFn &newYieldValuesFn) {
553  // Create a new loop before the existing one, with the extra operands.
554  OpBuilder::InsertionGuard g(rewriter);
555  rewriter.setInsertionPoint(getOperation());
556  auto inits = llvm::to_vector(getInitArgs());
557  inits.append(newInitOperands.begin(), newInitOperands.end());
558  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
559  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
560  [](OpBuilder &, Location, Value, ValueRange) {});
561  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
562 
563  // Generate the new yield values and append them to the scf.yield operation.
564  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
565  ArrayRef<BlockArgument> newIterArgs =
566  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
567  {
568  OpBuilder::InsertionGuard g(rewriter);
569  rewriter.setInsertionPoint(yieldOp);
570  SmallVector<Value> newYieldedValues =
571  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
572  assert(newInitOperands.size() == newYieldedValues.size() &&
573  "expected as many new yield values as new iter operands");
574  rewriter.modifyOpInPlace(yieldOp, [&]() {
575  yieldOp.getResultsMutable().append(newYieldedValues);
576  });
577  }
578 
579  // Move the loop body to the new op.
580  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
581  newLoop.getBody()->getArguments().take_front(
582  getBody()->getNumArguments()));
583 
584  if (replaceInitOperandUsesInLoop) {
585  // Replace all uses of `newInitOperands` with the corresponding basic block
586  // arguments.
587  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
588  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
589  [&](OpOperand &use) {
590  Operation *user = use.getOwner();
591  return newLoop->isProperAncestor(user);
592  });
593  }
594  }
595 
596  // Replace the old loop.
597  rewriter.replaceOp(getOperation(),
598  newLoop->getResults().take_front(getNumResults()));
599  return cast<LoopLikeOpInterface>(newLoop.getOperation());
600 }
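// Illustrative usage sketch (not from this file; `rewriter`, `loop`, and
// `extra` are assumed to exist in the caller). The callback must return one
// new yielded value per new init operand:
//
//   FailureOr<LoopLikeOpInterface> maybeNewLoop =
//       loop.replaceWithAdditionalYields(
//           rewriter, /*newInitOperands=*/ValueRange{extra},
//           /*replaceInitOperandUsesInLoop=*/true,
//           [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
//             return SmallVector<Value>{newBbArgs.front()};
//           });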
601 
602 ForOp mlir::scf::getForInductionVarOwner(Value val) {
603  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
604  if (!ivArg)
605  return ForOp();
606  assert(ivArg.getOwner() && "unlinked block argument");
607  auto *containingOp = ivArg.getOwner()->getParentOp();
608  return dyn_cast_or_null<ForOp>(containingOp);
609 }
610 
611 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
612  return getInitArgs();
613 }
614 
615 void ForOp::getSuccessorRegions(RegionBranchPoint point,
616  SmallVectorImpl<RegionSuccessor> &regions) {
617  // Both the operation itself and the region may be branching into the body
618  // or back into the operation itself. It is possible for the loop not to
619  // enter the body.
620  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
621  regions.push_back(RegionSuccessor(getResults()));
622 }
623 
624 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
625 
626 /// Promotes the loop body of a forallOp to its containing block if it can be
627 /// determined that the loop has a single iteration.
628 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
629  for (auto [lb, ub, step] :
630  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
631  auto tripCount = constantTripCount(lb, ub, step);
632  if (!tripCount.has_value() || *tripCount != 1)
633  return failure();
634  }
635 
636  promote(rewriter, *this);
637  return success();
638 }
639 
640 Block::BlockArgListType ForallOp::getRegionIterArgs() {
641  return getBody()->getArguments().drop_front(getRank());
642 }
643 
644 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
645  return getOutputsMutable();
646 }
647 
648 /// Promotes the loop body of a scf::ForallOp to its containing block.
649 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
650  OpBuilder::InsertionGuard g(rewriter);
651  scf::InParallelOp terminator = forallOp.getTerminator();
652 
653  // Replace block arguments with lower bounds (replacements for IVs) and
654  // outputs.
655  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
656  bbArgReplacements.append(forallOp.getOutputs().begin(),
657  forallOp.getOutputs().end());
658 
659  // Move the loop body operations to the loop's containing block.
660  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
661  forallOp->getIterator(), bbArgReplacements);
662 
663  // Replace the terminator with tensor.insert_slice ops.
664  rewriter.setInsertionPointAfter(forallOp);
665  SmallVector<Value> results;
666  results.reserve(forallOp.getResults().size());
667  for (auto &yieldingOp : terminator.getYieldingOps()) {
668  auto parallelInsertSliceOp =
669  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
670 
671  Value dst = parallelInsertSliceOp.getDest();
672  Value src = parallelInsertSliceOp.getSource();
673  if (llvm::isa<TensorType>(src.getType())) {
674  results.push_back(rewriter.create<tensor::InsertSliceOp>(
675  forallOp.getLoc(), dst.getType(), src, dst,
676  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
677  parallelInsertSliceOp.getStrides(),
678  parallelInsertSliceOp.getStaticOffsets(),
679  parallelInsertSliceOp.getStaticSizes(),
680  parallelInsertSliceOp.getStaticStrides()));
681  } else {
682  llvm_unreachable("unsupported terminator");
683  }
684  }
685  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
686 
687  // Erase the old terminator and the loop.
688  rewriter.eraseOp(terminator);
689  rewriter.eraseOp(forallOp);
690 }
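// Illustrative example (hypothetical IR) of the promotion above:
//
//   %r = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<4xf32>) {
//     %v = "test.compute"(%o) : (tensor<4xf32>) -> tensor<1xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %v into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<4xf32>
//     }
//   }
//
// The body is inlined with %i replaced by the lower bound and %o by %t, the
// terminator becomes a tensor.insert_slice, and %r is replaced by its result.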
691 
692 LoopNest mlir::scf::buildLoopNest(
693  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
694  ValueRange steps, ValueRange iterArgs,
695  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
696  bodyBuilder) {
697  assert(lbs.size() == ubs.size() &&
698  "expected the same number of lower and upper bounds");
699  assert(lbs.size() == steps.size() &&
700  "expected the same number of lower bounds and steps");
701 
702  // If there are no bounds, call the body-building function and return early.
703  if (lbs.empty()) {
704  ValueVector results =
705  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
706  : ValueVector();
707  assert(results.size() == iterArgs.size() &&
708  "loop nest body must return as many values as loop has iteration "
709  "arguments");
710  return LoopNest{{}, std::move(results)};
711  }
712 
713  // First, create the loop structure iteratively using the body-builder
714  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
715  OpBuilder::InsertionGuard guard(builder);
716  SmallVector<scf::ForOp, 4> loops;
717  SmallVector<Value, 4> ivs;
718  loops.reserve(lbs.size());
719  ivs.reserve(lbs.size());
720  ValueRange currentIterArgs = iterArgs;
721  Location currentLoc = loc;
722  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
723  auto loop = builder.create<scf::ForOp>(
724  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
726  ValueRange args) {
727  ivs.push_back(iv);
728  // It is safe to store ValueRange args because it points to block
729  // arguments of a loop operation that we also own.
730  currentIterArgs = args;
731  currentLoc = nestedLoc;
732  });
733  // Set the builder to point to the body of the newly created loop. We don't
734  // do this in the callback because the builder is reset when the callback
735  // returns.
736  builder.setInsertionPointToStart(loop.getBody());
737  loops.push_back(loop);
738  }
739 
740  // For all loops but the innermost, yield the results of the nested loop.
741  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
742  builder.setInsertionPointToEnd(loops[i].getBody());
743  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
744  }
745 
746  // In the body of the innermost loop, call the body building function if any
747  // and yield its results.
748  builder.setInsertionPointToStart(loops.back().getBody());
749  ValueVector results = bodyBuilder
750  ? bodyBuilder(builder, currentLoc, ivs,
751  loops.back().getRegionIterArgs())
752  : ValueVector();
753  assert(results.size() == iterArgs.size() &&
754  "loop nest body must return as many values as loop has iteration "
755  "arguments");
756  builder.setInsertionPointToEnd(loops.back().getBody());
757  builder.create<scf::YieldOp>(loc, results);
758 
759  // Return the loops.
760  ValueVector nestResults;
761  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
762  return LoopNest{std::move(loops), std::move(nestResults)};
763 }
764 
765 LoopNest mlir::scf::buildLoopNest(
766  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
767  ValueRange steps,
768  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
769  // Delegate to the main function by wrapping the body builder.
770  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
771  [&bodyBuilder](OpBuilder &nestedBuilder,
772  Location nestedLoc, ValueRange ivs,
773  ValueRange) -> ValueVector {
774  if (bodyBuilder)
775  bodyBuilder(nestedBuilder, nestedLoc, ivs);
776  return {};
777  });
778 }
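// Illustrative usage sketch of buildLoopNest (not from this file; `builder`,
// `loc`, `c0`, `c1`, `n`, and `m` are assumed index-typed values):
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, /*lbs=*/ValueRange{c0, c0}, /*ubs=*/ValueRange{n, m},
//       /*steps=*/ValueRange{c1, c1},
//       [](OpBuilder &b, Location l, ValueRange ivs) {
//         // Innermost loop body; `ivs` holds the induction variables,
//         // outermost first.
//       });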
779 
780 SmallVector<Value> mlir::scf::replaceAndCastForOpIterArg(
781  RewriterBase &rewriter, scf::ForOp forOp,
782  OpOperand &operand, Value replacement,
783  const ValueTypeCastFnTy &castFn) {
784  assert(operand.getOwner() == forOp);
785  Type oldType = operand.get().getType(), newType = replacement.getType();
786 
787  // 1. Create new iter operands, exactly 1 is replaced.
788  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
789  "expected an iter OpOperand");
790  assert(operand.get().getType() != replacement.getType() &&
791  "Expected a different type");
792  SmallVector<Value> newIterOperands;
793  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
794  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
795  newIterOperands.push_back(replacement);
796  continue;
797  }
798  newIterOperands.push_back(opOperand.get());
799  }
800 
801  // 2. Create the new forOp shell.
802  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
803  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
804  forOp.getStep(), newIterOperands);
805  newForOp->setAttrs(forOp->getAttrs());
806  Block &newBlock = newForOp.getRegion().front();
807  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
808  newBlock.getArguments().end());
809 
810  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
811  // corresponding to the `replacement` value.
812  OpBuilder::InsertionGuard g(rewriter);
813  rewriter.setInsertionPointToStart(&newBlock);
814  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
815  &newForOp->getOpOperand(operand.getOperandNumber()));
816  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
817  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
818 
819  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
820  Block &oldBlock = forOp.getRegion().front();
821  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
822 
823  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
824  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
825  rewriter.setInsertionPoint(clonedYieldOp);
826  unsigned yieldIdx =
827  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
828  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
829  clonedYieldOp.getOperand(yieldIdx));
830  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
831  newYieldOperands[yieldIdx] = castOut;
832  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
833  rewriter.eraseOp(clonedYieldOp);
834 
835  // 6. Inject an outgoing cast op after the forOp.
836  rewriter.setInsertionPointAfter(newForOp);
837  SmallVector<Value> newResults = newForOp.getResults();
838  newResults[yieldIdx] =
839  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
840 
841  return newResults;
842 }
843 
844 namespace {
845 // Fold away ForOp iter arguments when:
846 // 1) The op yields the iter arguments.
847 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
848 // 3) The iter arguments have no use and the corresponding (operation) results
849 // have no use.
850 //
851 // These arguments must be defined outside of the ForOp region and can just be
852 // forwarded after simplifying the op inits, yields and returns.
853 //
854 // The implementation uses `inlineBlockBefore` to steal the content of the
855 // original ForOp and avoid cloning.
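// Illustrative example (hypothetical IR): in
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     scf.yield %a : f32
//   }
//
// the iter argument is forwarded unchanged, so the pattern drops it and
// replaces %r with %init.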
856 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
857  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
858 
859  LogicalResult matchAndRewrite(scf::ForOp forOp,
860  PatternRewriter &rewriter) const final {
861  bool canonicalize = false;
862 
863  // An internal flat vector of block transfer
864  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
865  // transformed block arguments. This plays the role of an
866  // IRMapping for the particular use case of calling into
867  // `inlineBlockBefore`.
868  int64_t numResults = forOp.getNumResults();
869  SmallVector<bool, 4> keepMask;
870  keepMask.reserve(numResults);
871  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
872  newResultValues;
873  newBlockTransferArgs.reserve(1 + numResults);
874  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
875  newIterArgs.reserve(forOp.getInitArgs().size());
876  newYieldValues.reserve(numResults);
877  newResultValues.reserve(numResults);
878  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
879  for (auto [init, arg, result, yielded] :
880  llvm::zip(forOp.getInitArgs(), // iter from outside
881  forOp.getRegionIterArgs(), // iter inside region
882  forOp.getResults(), // op results
883  forOp.getYieldedValues() // iter yield
884  )) {
885  // Forwarded is `true` when:
886  // 1) The region `iter` argument is yielded.
887  // 2) The corresponding init value (input) is yielded.
888  // 3) The region `iter` argument has no use, and the corresponding op
889  // result has no use.
890  bool forwarded = (arg == yielded) || (init == yielded) ||
891  (arg.use_empty() && result.use_empty());
892  if (forwarded) {
893  canonicalize = true;
894  keepMask.push_back(false);
895  newBlockTransferArgs.push_back(init);
896  newResultValues.push_back(init);
897  continue;
898  }
899 
900  // Check if a previous kept argument always has the same values for init
901  // and yielded values.
902  if (auto it = initYieldToArg.find({init, yielded});
903  it != initYieldToArg.end()) {
904  canonicalize = true;
905  keepMask.push_back(false);
906  auto [sameArg, sameResult] = it->second;
907  rewriter.replaceAllUsesWith(arg, sameArg);
908  rewriter.replaceAllUsesWith(result, sameResult);
909  // The replacement value doesn't matter because there are no uses.
910  newBlockTransferArgs.push_back(init);
911  newResultValues.push_back(init);
912  continue;
913  }
914 
915  // This value is kept.
916  initYieldToArg.insert({{init, yielded}, {arg, result}});
917  keepMask.push_back(true);
918  newIterArgs.push_back(init);
919  newYieldValues.push_back(yielded);
920  newBlockTransferArgs.push_back(Value()); // placeholder with null value
921  newResultValues.push_back(Value()); // placeholder with null value
922  }
923 
924  if (!canonicalize)
925  return failure();
926 
927  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
928  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
929  forOp.getStep(), newIterArgs);
930  newForOp->setAttrs(forOp->getAttrs());
931  Block &newBlock = newForOp.getRegion().front();
932 
933  // Replace the null placeholders with newly constructed values.
934  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
935  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
936  idx != e; ++idx) {
937  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
938  Value &newResultVal = newResultValues[idx];
939  assert((blockTransferArg && newResultVal) ||
940  (!blockTransferArg && !newResultVal));
941  if (!blockTransferArg) {
942  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
943  newResultVal = newForOp.getResult(collapsedIdx++);
944  }
945  }
946 
947  Block &oldBlock = forOp.getRegion().front();
948  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
949  "unexpected argument size mismatch");
950 
951  // No results case: the scf::ForOp builder already created a zero
952  // result terminator. Merge before this terminator and just get rid of the
953  // original terminator that has been merged in.
954  if (newIterArgs.empty()) {
955  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
956  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
957  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
958  rewriter.replaceOp(forOp, newResultValues);
959  return success();
960  }
961 
962  // General case: merge the blocks and rewrite the merged terminator.
963  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
964  OpBuilder::InsertionGuard g(rewriter);
965  rewriter.setInsertionPoint(mergedTerminator);
966  SmallVector<Value, 4> filteredOperands;
967  filteredOperands.reserve(newResultValues.size());
968  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
969  if (keepMask[idx])
970  filteredOperands.push_back(mergedTerminator.getOperand(idx));
971  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
972  filteredOperands);
973  };
974 
975  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
976  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
977  cloneFilteredTerminator(mergedYieldOp);
978  rewriter.eraseOp(mergedYieldOp);
979  rewriter.replaceOp(forOp, newResultValues);
980  return success();
981  }
982 };
983 
984 /// Utility function that tries to compute a constant diff between u and l.
985 /// Returns std::nullopt when the difference between the two values cannot be
986 /// determined statically.
987 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
988  IntegerAttr clb, cub;
989  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
990  llvm::APInt lbValue = clb.getValue();
991  llvm::APInt ubValue = cub.getValue();
992  return (ubValue - lbValue).getSExtValue();
993  }
994 
995  // Else a simple pattern match for x + c or c + x
996  llvm::APInt diff;
997  if (matchPattern(
998  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
999  matchPattern(
1000  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1001  return diff.getSExtValue();
1002  return std::nullopt;
1003 }
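// For example (hypothetical IR), with
//   %u = arith.addi %l, %c4 : index
// the helper above returns a constant difference of 4.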
1004 
1005 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1006 /// single-iteration loops with their bodies, and removes empty loops that
1007 /// iterate at least once and only return values defined outside of the loop.
1008 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1009  using OpRewritePattern<ForOp>::OpRewritePattern;
1010 
1011  LogicalResult matchAndRewrite(ForOp op,
1012  PatternRewriter &rewriter) const override {
1013  // If the upper bound is the same as the lower bound, the loop does not
1014  // iterate, just remove it.
1015  if (op.getLowerBound() == op.getUpperBound()) {
1016  rewriter.replaceOp(op, op.getInitArgs());
1017  return success();
1018  }
1019 
1020  std::optional<int64_t> diff =
1021  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1022  if (!diff)
1023  return failure();
1024 
1025  // If the loop is known to have 0 iterations, remove it.
1026  if (*diff <= 0) {
1027  rewriter.replaceOp(op, op.getInitArgs());
1028  return success();
1029  }
1030 
1031  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1032  if (!maybeStepValue)
1033  return failure();
1034 
1035  // If the loop is known to have 1 iteration, inline its body and remove the
1036  // loop.
1037  llvm::APInt stepValue = *maybeStepValue;
1038  if (stepValue.sge(*diff)) {
1039  SmallVector<Value, 4> blockArgs;
1040  blockArgs.reserve(op.getInitArgs().size() + 1);
1041  blockArgs.push_back(op.getLowerBound());
1042  llvm::append_range(blockArgs, op.getInitArgs());
1043  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1044  return success();
1045  }
1046 
1047  // Now we are left with loops that have more than one iteration.
1048  Block &block = op.getRegion().front();
1049  if (!llvm::hasSingleElement(block))
1050  return failure();
1051  // If the loop is empty, iterates at least once, and only returns values
1052  // defined outside of the loop, remove it and replace it with yield values.
1053  if (llvm::any_of(op.getYieldedValues(),
1054  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1055  return failure();
1056  rewriter.replaceOp(op, op.getYieldedValues());
1057  return success();
1058  }
1059 };
1060 
1061 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1062 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1063 ///
1064 /// ```
1065 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1066 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1067 /// -> (tensor<?x?xf32>) {
1068 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1069 /// scf.yield %2 : tensor<?x?xf32>
1070 /// }
1071 /// use_of(%1)
1072 /// ```
1073 ///
1074 /// folds into:
1075 ///
1076 /// ```
1077 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1078 /// -> (tensor<32x1024xf32>) {
1079 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1080 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1081 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1082 /// scf.yield %4 : tensor<32x1024xf32>
1083 /// }
1084 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1085 /// use_of(%1)
1086 /// ```
1087 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1088  using OpRewritePattern<ForOp>::OpRewritePattern;
1089 
1090  LogicalResult matchAndRewrite(ForOp op,
1091  PatternRewriter &rewriter) const override {
1092  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1093  OpOperand &iterOpOperand = std::get<0>(it);
1094  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1095  if (!incomingCast ||
1096  incomingCast.getSource().getType() == incomingCast.getType())
1097  continue;
1098  // Skip if the dest type of the cast does not preserve static information
1099  // present in the source type.
1100  if (!tensor::preservesStaticInformation(
1101  incomingCast.getDest().getType(),
1102  incomingCast.getSource().getType()))
1103  continue;
1104  if (!std::get<1>(it).hasOneUse())
1105  continue;
1106 
1107  // Create a new ForOp with that iter operand replaced.
1108  rewriter.replaceOp(
1109  op, replaceAndCastForOpIterArg(
1110  rewriter, op, iterOpOperand, incomingCast.getSource(),
1111  [](OpBuilder &b, Location loc, Type type, Value source) {
1112  return b.create<tensor::CastOp>(loc, type, source);
1113  }));
1114  return success();
1115  }
1116  return failure();
1117  }
1118 };
1119 
1120 } // namespace
1121 
1122 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123  MLIRContext *context) {
1124  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1125  context);
1126 }
1127 
1128 std::optional<APInt> ForOp::getConstantStep() {
1129  IntegerAttr step;
1130  if (matchPattern(getStep(), m_Constant(&step)))
1131  return step.getValue();
1132  return {};
1133 }
1134 
1135 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1136  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1137 }
1138 
1139 Speculation::Speculatability ForOp::getSpeculatability() {
1140  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1141  // and End.
1142  if (auto constantStep = getConstantStep())
1143  if (*constantStep == 1)
1144  return Speculation::RecursivelySpeculatable;
1145 
1146  // For Step != 1, the loop may not terminate. We can add more smarts here if
1147  // needed.
1148  return Speculation::NotSpeculatable;
1149 }
1150 
1151 //===----------------------------------------------------------------------===//
1152 // ForallOp
1153 //===----------------------------------------------------------------------===//
1154 
1155 LogicalResult ForallOp::verify() {
1156  unsigned numLoops = getRank();
1157  // Check number of outputs.
1158  if (getNumResults() != getOutputs().size())
1159  return emitOpError("produces ")
1160  << getNumResults() << " results, but has only "
1161  << getOutputs().size() << " outputs";
1162 
1163  // Check that the body defines block arguments for thread indices and outputs.
1164  auto *body = getBody();
1165  if (body->getNumArguments() != numLoops + getOutputs().size())
1166  return emitOpError("region expects ") << numLoops << " arguments";
1167  for (int64_t i = 0; i < numLoops; ++i)
1168  if (!body->getArgument(i).getType().isIndex())
1169  return emitOpError("expects ")
1170  << i << "-th block argument to be an index";
1171  for (unsigned i = 0; i < getOutputs().size(); ++i)
1172  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1173  return emitOpError("type mismatch between ")
1174  << i << "-th output and corresponding block argument";
1175  if (getMapping().has_value() && !getMapping()->empty()) {
1176  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1177  return emitOpError() << "mapping attribute size must match op rank";
1178  for (auto map : getMapping()->getValue()) {
1179  if (!isa<DeviceMappingAttrInterface>(map))
1180  return emitOpError()
1181  << getMappingAttrName() << " is not device mapping attribute";
1182  }
1183  }
1184 
1185  // Verify mixed static/dynamic control variables.
1186  Operation *op = getOperation();
1187  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1188  getStaticLowerBound(),
1189  getDynamicLowerBound())))
1190  return failure();
1191  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1192  getStaticUpperBound(),
1193  getDynamicUpperBound())))
1194  return failure();
1195  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1196  getStaticStep(), getDynamicStep())))
1197  return failure();
1198 
1199  return success();
1200 }
1201 
1202 void ForallOp::print(OpAsmPrinter &p) {
1203  Operation *op = getOperation();
1204  p << " (" << getInductionVars();
1205  if (isNormalized()) {
1206  p << ") in ";
1207  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1208  /*valueTypes=*/{}, /*scalables=*/{},
1209  OpAsmParser::Delimiter::Paren);
1210  } else {
1211  p << ") = ";
1212  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1213  /*valueTypes=*/{}, /*scalables=*/{},
1214  OpAsmParser::Delimiter::Paren);
1215  p << " to ";
1216  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1217  /*valueTypes=*/{}, /*scalables=*/{},
1218  OpAsmParser::Delimiter::Paren);
1219  p << " step ";
1220  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1221  /*valueTypes=*/{}, /*scalables=*/{},
1222  OpAsmParser::Delimiter::Paren);
1223  }
1224  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1225  p << " ";
1226  if (!getRegionOutArgs().empty())
1227  p << "-> (" << getResultTypes() << ") ";
1228  p.printRegion(getRegion(),
1229  /*printEntryBlockArgs=*/false,
1230  /*printBlockTerminators=*/getNumResults() > 0);
1231  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1232  getStaticLowerBoundAttrName(),
1233  getStaticUpperBoundAttrName(),
1234  getStaticStepAttrName()});
1235 }
1236 
1237 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1238  OpBuilder b(parser.getContext());
1239  auto indexType = b.getIndexType();
1240 
1241  // Parse an opening `(` followed by thread index variables followed by `)`
1242  // TODO: when we can refer to such "induction variable"-like handles from the
1243  // declarative assembly format, we can implement the parser as a custom hook.
1244  SmallVector<OpAsmParser::Argument, 4> ivs;
1245  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1246  return failure();
1247 
1248  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1249  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1250  dynamicSteps;
1251  if (succeeded(parser.parseOptionalKeyword("in"))) {
1252  // Parse upper bounds.
1253  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1254  /*valueTypes=*/nullptr,
1255  /*delimiter=*/OpAsmParser::Delimiter::Paren) ||
1256  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1257  return failure();
1258 
1259  unsigned numLoops = ivs.size();
1260  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1261  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1262  } else {
1263  // Parse lower bounds.
1264  if (parser.parseEqual() ||
1265  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1266  /*valueTypes=*/nullptr,
1267  /*delimiter=*/OpAsmParser::Delimiter::Paren) ||
1268 
1269  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1270  return failure();
1271 
1272  // Parse upper bounds.
1273  if (parser.parseKeyword("to") ||
1274  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1275  /*valueTypes=*/nullptr,
1276  /*delimiter=*/OpAsmParser::Delimiter::Paren) ||
1277  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1278  return failure();
1279 
1280  // Parse step values.
1281  if (parser.parseKeyword("step") ||
1282  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1283  /*valueTypes=*/nullptr,
1284  /*delimiter=*/OpAsmParser::Delimiter::Paren) ||
1285  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1286  return failure();
1287  }
1288 
1289  // Parse out operands and results.
1290  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1291  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1292  SMLoc outOperandsLoc = parser.getCurrentLocation();
1293  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1294  if (outOperands.size() != result.types.size())
1295  return parser.emitError(outOperandsLoc,
1296  "mismatch between out operands and types");
1297  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1298  parser.parseOptionalArrowTypeList(result.types) ||
1299  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1300  result.operands))
1301  return failure();
1302  }
1303 
1304  // Parse region.
1305  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1306  std::unique_ptr<Region> region = std::make_unique<Region>();
1307  for (auto &iv : ivs) {
1308  iv.type = b.getIndexType();
1309  regionArgs.push_back(iv);
1310  }
1311  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1312  auto &out = it.value();
1313  out.type = result.types[it.index()];
1314  regionArgs.push_back(out);
1315  }
1316  if (parser.parseRegion(*region, regionArgs))
1317  return failure();
1318 
1319  // Ensure terminator and move region.
1320  ForallOp::ensureTerminator(*region, b, result.location);
1321  result.addRegion(std::move(region));
1322 
1323  // Parse the optional attribute list.
1324  if (parser.parseOptionalAttrDict(result.attributes))
1325  return failure();
1326 
1327  result.addAttribute("staticLowerBound", staticLbs);
1328  result.addAttribute("staticUpperBound", staticUbs);
1329  result.addAttribute("staticStep", staticSteps);
1330  result.addAttribute("operandSegmentSizes",
1331  b.getDenseI32ArrayAttr(
1332  {static_cast<int32_t>(dynamicLbs.size()),
1333  static_cast<int32_t>(dynamicUbs.size()),
1334  static_cast<int32_t>(dynamicSteps.size()),
1335  static_cast<int32_t>(outOperands.size())}));
1336  return success();
1337 }
1338 
1339 // Builder that takes loop bounds.
1340 void ForallOp::build(
1341  OpBuilder &b, OperationState &result, ArrayRef<OpFoldResult> lbs,
1342  ArrayRef<OpFoldResult> ubs,
1343  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1344  std::optional<ArrayAttr> mapping,
1345  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1346  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1347  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1348  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1349  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1350  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1351 
1352  result.addOperands(dynamicLbs);
1353  result.addOperands(dynamicUbs);
1354  result.addOperands(dynamicSteps);
1355  result.addOperands(outputs);
1356  result.addTypes(TypeRange(outputs));
1357 
1358  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1359  b.getDenseI64ArrayAttr(staticLbs));
1360  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1361  b.getDenseI64ArrayAttr(staticUbs));
1362  result.addAttribute(getStaticStepAttrName(result.name),
1363  b.getDenseI64ArrayAttr(staticSteps));
1364  result.addAttribute(
1365  "operandSegmentSizes",
1366  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1367  static_cast<int32_t>(dynamicUbs.size()),
1368  static_cast<int32_t>(dynamicSteps.size()),
1369  static_cast<int32_t>(outputs.size())}));
1370  if (mapping.has_value()) {
1371  result.addAttribute(getMappingAttrName(result.name),
1372  mapping.value());
1373  }
1374 
1375  Region *bodyRegion = result.addRegion();
1376  OpBuilder::InsertionGuard g(b);
1377  b.createBlock(bodyRegion);
1378  Block &bodyBlock = bodyRegion->front();
1379 
1380  // Add block arguments for indices and outputs.
1381  bodyBlock.addArguments(
1382  SmallVector<Type>(lbs.size(), b.getIndexType()),
1383  SmallVector<Location>(staticLbs.size(), result.location));
1384  bodyBlock.addArguments(
1385  TypeRange(outputs),
1386  SmallVector<Location>(outputs.size(), result.location));
1387 
1388  b.setInsertionPointToStart(&bodyBlock);
1389  if (!bodyBuilderFn) {
1390  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1391  return;
1392  }
1393  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1394 }
1395 
1396 // Builder that takes loop bounds.
1397 void ForallOp::build(
1398  OpBuilder &b, OperationState &result,
1399  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1400  std::optional<ArrayAttr> mapping,
1401  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1402  unsigned numLoops = ubs.size();
1403  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1404  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1405  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1406 }
1407 
1408 // Checks if the lbs are zeros and steps are ones.
1409 bool ForallOp::isNormalized() {
1410  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1411  return llvm::all_of(results, [&](OpFoldResult ofr) {
1412  auto intValue = getConstantIntValue(ofr);
1413  return intValue.has_value() && intValue == val;
1414  });
1415  };
1416  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1417 }
1418 
1419 InParallelOp ForallOp::getTerminator() {
1420  return cast<InParallelOp>(getBody()->getTerminator());
1421 }
1422 
1423 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1424  SmallVector<Operation *> storeOps;
1425  InParallelOp inParallelOp = getTerminator();
1426  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1427  if (auto parallelInsertSliceOp =
1428  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1429  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1430  storeOps.push_back(parallelInsertSliceOp);
1431  }
1432  }
1433  return storeOps;
1434 }
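// Illustrative example (hypothetical IR): for a forall whose terminator
// contains
//   scf.forall.in_parallel {
//     tensor.parallel_insert_slice %v into %o[%i] [1] [1]
//         : tensor<1xf32> into tensor<?xf32>
//   }
// getCombiningOps(%o) returns that tensor.parallel_insert_slice; only ops
// whose destination is the queried block argument are collected.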
1435 
1436 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1437  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1438 }
1439 
1440 // Get lower bounds as OpFoldResult.
1441 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1442  Builder b(getOperation()->getContext());
1443  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1444 }
1445 
1446 // Get upper bounds as OpFoldResult.
1447 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1448  Builder b(getOperation()->getContext());
1449  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1450 }
1451 
1452 // Get steps as OpFoldResult.
1453 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1454  Builder b(getOperation()->getContext());
1455  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1456 }
1457 
1457 
1458 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1459  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1460  if (!tidxArg)
1461  return ForallOp();
1462  assert(tidxArg.getOwner() && "unlinked block argument");
1463  auto *containingOp = tidxArg.getOwner()->getParentOp();
1464  return dyn_cast<ForallOp>(containingOp);
1465 }
1466 
1467 namespace {
1468 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
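/// Illustrative example (hypothetical IR): with
///   %res = scf.forall ... shared_outs(%arg = %t) -> (tensor<?xf32>) {...}
///   %d = tensor.dim %res, %c0 : tensor<?xf32>
/// the pattern rewrites the dim source in place so that it reads
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
/// which is valid because the result is shape-tied to its shared_outs operand.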
1469 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1470  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1471 
1472  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1473  PatternRewriter &rewriter) const final {
1474  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1475  if (!forallOp)
1476  return failure();
1477  Value sharedOut =
1478  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1479  ->get();
1480  rewriter.modifyOpInPlace(
1481  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1482  return success();
1483  }
1484 };
1485 
1486 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1487 public:
1488  using OpRewritePattern<ForallOp>::OpRewritePattern;
1489 
1490  LogicalResult matchAndRewrite(ForallOp op,
1491  PatternRewriter &rewriter) const override {
1492  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1493  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1494  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1495  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1496  failed(foldDynamicIndexList(mixedUpperBound)) &&
1497  failed(foldDynamicIndexList(mixedStep)))
1498  return failure();
1499 
1500  rewriter.modifyOpInPlace(op, [&]() {
1501  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1502  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1503  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1504  staticLowerBound);
1505  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1506  op.setStaticLowerBound(staticLowerBound);
1507 
1508  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1509  staticUpperBound);
1510  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1511  op.setStaticUpperBound(staticUpperBound);
1512 
1513  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1514  op.getDynamicStepMutable().assign(dynamicStep);
1515  op.setStaticStep(staticStep);
1516 
1517  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1518  rewriter.getDenseI32ArrayAttr(
1519  {static_cast<int32_t>(dynamicLowerBound.size()),
1520  static_cast<int32_t>(dynamicUpperBound.size()),
1521  static_cast<int32_t>(dynamicStep.size()),
1522  static_cast<int32_t>(op.getNumResults())}));
1523  });
1524  return success();
1525  }
1526 };
1527 
1528 /// The following canonicalization pattern folds the iter arguments of an
1529 /// scf.forall op if:
1530 /// 1. The corresponding result has zero uses.
1531 /// 2. The iter argument is NOT being modified within the loop body.
1533 ///
1534 /// Example of first case :-
1535 /// INPUT:
1536 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1537 /// {
1538 /// ...
1539 /// <SOME USE OF %arg0>
1540 /// <SOME USE OF %arg1>
1541 /// <SOME USE OF %arg2>
1542 /// ...
1543 /// scf.forall.in_parallel {
1544 /// <STORE OP WITH DESTINATION %arg1>
1545 /// <STORE OP WITH DESTINATION %arg0>
1546 /// <STORE OP WITH DESTINATION %arg2>
1547 /// }
1548 /// }
1549 /// return %res#1
1550 ///
1551 /// OUTPUT:
1552 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1553 /// {
1554 /// ...
1555 /// <SOME USE OF %a>
1556 /// <SOME USE OF %new_arg0>
1557 /// <SOME USE OF %c>
1558 /// ...
1559 /// scf.forall.in_parallel {
1560 /// <STORE OP WITH DESTINATION %new_arg0>
1561 /// }
1562 /// }
1563 /// return %res
1564 ///
1565 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1566 /// scf.forall are replaced by their corresponding operands.
1567 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1568 /// of the scf.forall outside of the scf.forall.in_parallel terminator,
1569 /// this canonicalization remains valid. For more details, please refer
1570 /// to :
1571 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1572 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1573 /// handles tensor.parallel_insert_slice ops only.
1574 ///
1575 /// Example of second case :-
1576 /// INPUT:
1577 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1578 /// {
1579 /// ...
1580 /// <SOME USE OF %arg0>
1581 /// <SOME USE OF %arg1>
1582 /// ...
1583 /// scf.forall.in_parallel {
1584 /// <STORE OP WITH DESTINATION %arg1>
1585 /// }
1586 /// }
1587 /// return %res#0, %res#1
1588 ///
1589 /// OUTPUT:
1590 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1591 /// {
1592 /// ...
1593 /// <SOME USE OF %a>
1594 /// <SOME USE OF %new_arg0>
1595 /// ...
1596 /// scf.forall.in_parallel {
1597 /// <STORE OP WITH DESTINATION %new_arg0>
1598 /// }
1599 /// }
1600 /// return %a, %res
1601 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1602   using OpRewritePattern<ForallOp>::OpRewritePattern;
1603 
1604  LogicalResult matchAndRewrite(ForallOp forallOp,
1605  PatternRewriter &rewriter) const final {
1606  // Step 1: For a given i-th result of scf.forall, check the following:
1607  // a. If it has any use.
1608  // b. If the corresponding iter argument is being modified within
1609  // the loop, i.e. has at least one store op with the iter arg as
1610  // its destination operand. For this we use
1611  // ForallOp::getCombiningOps(iter_arg).
1612  //
1613  // Based on the check, we maintain the following:
1614  // a. `resultToDelete` - i-th result of scf.forall that'll be
1615  // deleted.
1616  // b. `resultToReplace` - i-th result of the old scf.forall
1617  // whose uses will be replaced by the new scf.forall.
1618  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1619  // corresponding to the i-th result with at least one use.
1620  SetVector<OpResult> resultToDelete;
1621  SmallVector<Value> resultToReplace;
1622  SmallVector<Value> newOuts;
1623  for (OpResult result : forallOp.getResults()) {
1624  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1625  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1626  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1627  resultToDelete.insert(result);
1628  } else {
1629  resultToReplace.push_back(result);
1630  newOuts.push_back(opOperand->get());
1631  }
1632  }
1633 
1634  // Return early if all results of scf.forall have at least one use and are
1635  // being modified within the loop.
1636  if (resultToDelete.empty())
1637  return failure();
1638 
1639  // Step 2: For the i-th result to be deleted, do the following:
1640  // a. Fetch the corresponding BlockArgument.
1641  // b. Look for store ops (currently tensor.parallel_insert_slice)
1642  // with the BlockArgument as its destination operand.
1643  // c. Remove the operations fetched in b.
1644  for (OpResult result : resultToDelete) {
1645  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1646  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1647  SmallVector<Operation *> combiningOps =
1648  forallOp.getCombiningOps(blockArg);
1649  for (Operation *combiningOp : combiningOps)
1650  rewriter.eraseOp(combiningOp);
1651  }
1652 
1653  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1654  // fetched earlier
1655  auto newForallOp = rewriter.create<scf::ForallOp>(
1656  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1657  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1658  forallOp.getMapping(),
1659  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1660 
1661  // Step 4. Merge the block of the old scf.forall into the newly created
1662  // scf.forall using the new set of arguments.
1663  Block *loopBody = forallOp.getBody();
1664  Block *newLoopBody = newForallOp.getBody();
1665  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1666  // Form initial new bbArg list with just the control operands of the new
1667  // scf.forall op.
1668  SmallVector<Value> newBlockArgs =
1669  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1670  [](BlockArgument b) -> Value { return b; });
1671  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1672  unsigned index = 0;
1673  // Take the new corresponding bbArg if the old bbArg was used as a
1674  // destination in the in_parallel op. For all other bbArgs, use the
1675  // corresponding init_arg from the old scf.forall op.
1676  for (OpResult result : forallOp.getResults()) {
1677  if (resultToDelete.count(result)) {
1678  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1679  } else {
1680  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1681  }
1682  }
1683  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1684 
1685  // Step 5. Replace the uses of the results of the old scf.forall with those
1686  // of the new scf.forall.
1687  for (auto &&[oldResult, newResult] :
1688  llvm::zip(resultToReplace, newForallOp->getResults()))
1689  rewriter.replaceAllUsesWith(oldResult, newResult);
1690 
1691  // Step 6. Replace the uses of those results that either have no use or are
1692  // not being modified within the loop with the corresponding
1693  // OpOperand.
1694  for (OpResult oldResult : resultToDelete)
1695  rewriter.replaceAllUsesWith(oldResult,
1696  forallOp.getTiedOpOperand(oldResult)->get());
1697  return success();
1698  }
1699 };
1700 
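/// Canonicalize scf.forall dimensions that perform zero or one iteration
/// (skipped when a mapping attribute is present): a zero-trip-count dimension
/// replaces the whole op by its shared_outs operands, an all-single-iteration
/// domain inlines the body, and otherwise the op is rebuilt without the
/// single-iteration dimensions. A schematic example (illustrative names):
///
///   scf.forall (%i, %j) in (1, 16) { ... use %i ... }
///
/// becomes
///
///   scf.forall (%j) in (16) { ... use %c0 ... }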
1701 struct ForallOpSingleOrZeroIterationDimsFolder
1702  : public OpRewritePattern<ForallOp> {
1703   using OpRewritePattern<ForallOp>::OpRewritePattern;
1704 
1705  LogicalResult matchAndRewrite(ForallOp op,
1706  PatternRewriter &rewriter) const override {
1707  // Do not fold dimensions if they are mapped to processing units.
1708  if (op.getMapping().has_value() && !op.getMapping()->empty())
1709  return failure();
1710  Location loc = op.getLoc();
1711 
1712  // Compute new loop bounds that omit all single-iteration loop dimensions.
1713  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1714  newMixedSteps;
1715  IRMapping mapping;
1716  for (auto [lb, ub, step, iv] :
1717  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1718  op.getMixedStep(), op.getInductionVars())) {
1719  auto numIterations = constantTripCount(lb, ub, step);
1720  if (numIterations.has_value()) {
1721  // Remove the loop if it performs zero iterations.
1722  if (*numIterations == 0) {
1723  rewriter.replaceOp(op, op.getOutputs());
1724  return success();
1725  }
1726  // Replace the loop induction variable by the lower bound if the loop
1727  // performs a single iteration. Otherwise, copy the loop bounds.
1728  if (*numIterations == 1) {
1729  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1730  continue;
1731  }
1732  }
1733  newMixedLowerBounds.push_back(lb);
1734  newMixedUpperBounds.push_back(ub);
1735  newMixedSteps.push_back(step);
1736  }
1737 
1738  // All of the loop dimensions perform a single iteration. Inline loop body.
1739  if (newMixedLowerBounds.empty()) {
1740  promote(rewriter, op);
1741  return success();
1742  }
1743 
1744  // Exit if none of the loop dimensions perform a single iteration.
1745  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1746  return rewriter.notifyMatchFailure(
1747  op, "no dimensions have 0 or 1 iterations");
1748  }
1749 
1750  // Replace the loop by a lower-dimensional loop.
1751  ForallOp newOp;
1752  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1753  newMixedUpperBounds, newMixedSteps,
1754  op.getOutputs(), std::nullopt, nullptr);
1755  newOp.getBodyRegion().getBlocks().clear();
1756  // The new loop needs to keep all attributes from the old one, except for
1757  // "operandSegmentSizes" and static loop bound attributes which capture
1758  // the outdated information of the old iteration domain.
1759  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1760  newOp.getStaticLowerBoundAttrName(),
1761  newOp.getStaticUpperBoundAttrName(),
1762  newOp.getStaticStepAttrName()};
1763  for (const auto &namedAttr : op->getAttrs()) {
1764  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1765  continue;
1766  rewriter.modifyOpInPlace(newOp, [&]() {
1767  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1768  });
1769  }
1770  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1771  newOp.getRegion().begin(), mapping);
1772  rewriter.replaceOp(op, newOp.getResults());
1773  return success();
1774  }
1775 };
1776 
1777 /// Replace the induction variables of single-trip-count dimensions with their lower bound.
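/// A schematic example (illustrative names):
///
///   scf.forall (%i, %j) in (1, %n) { ... use %i ... }
///
/// becomes
///
///   scf.forall (%i, %j) in (1, %n) { ... use %c0 ... }
///
/// where %c0 is a constant 0 index materialized next to the loop; %i itself
/// stays in place but becomes unused.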
1778 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1779   using OpRewritePattern<ForallOp>::OpRewritePattern;
1780 
1781  LogicalResult matchAndRewrite(ForallOp op,
1782  PatternRewriter &rewriter) const override {
1783  Location loc = op.getLoc();
1784  bool changed = false;
1785  for (auto [lb, ub, step, iv] :
1786  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1787  op.getMixedStep(), op.getInductionVars())) {
1788  if (iv.getUses().begin() == iv.getUses().end())
1789  continue;
1790  auto numIterations = constantTripCount(lb, ub, step);
1791  if (!numIterations.has_value() || numIterations.value() != 1) {
1792  continue;
1793  }
1794  rewriter.replaceAllUsesWith(
1795  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1796  changed = true;
1797  }
1798  return success(changed);
1799  }
1800 };
1801 
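/// Fold a tensor.cast that feeds a shared_outs operand into the scf.forall
/// when the cast only discards static information. The loop then operates on
/// the more static type; the block argument is cast back to the old type for
/// uses inside the body, and the results are cast back after the loop. A
/// schematic example (illustrative names and types):
///
///   %cast = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
///
/// becomes
///
///   %r0 = scf.forall ... shared_outs(%o = %t) -> (tensor<4xf32>) {
///     %oc = tensor.cast %o : tensor<4xf32> to tensor<?xf32>
///     ...   // body uses of the old block argument now use %oc; the
///     ...   // parallel_insert_slice destinations keep using %o directly
///   }
///   %r = tensor.cast %r0 : tensor<4xf32> to tensor<?xf32>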
1802 struct FoldTensorCastOfOutputIntoForallOp
1803  : public OpRewritePattern<scf::ForallOp> {
1804   using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1805 
1806  struct TypeCast {
1807  Type srcType;
1808  Type dstType;
1809  };
1810 
1811  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1812  PatternRewriter &rewriter) const final {
1813  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1814  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1815  for (auto en : llvm::enumerate(newOutputTensors)) {
1816  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1817  if (!castOp)
1818  continue;
1819 
1820  // Only casts that preserve static information, i.e. will make the
1821  // loop result type "more" static than before, will be folded.
1822  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1823  castOp.getSource().getType())) {
1824  continue;
1825  }
1826 
1827  tensorCastProducers[en.index()] =
1828  TypeCast{castOp.getSource().getType(), castOp.getType()};
1829  newOutputTensors[en.index()] = castOp.getSource();
1830  }
1831 
1832  if (tensorCastProducers.empty())
1833  return failure();
1834 
1835  // Create new loop.
1836  Location loc = forallOp.getLoc();
1837  auto newForallOp = rewriter.create<ForallOp>(
1838  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1839  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1840  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1841  auto castBlockArgs =
1842  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1843  for (auto [index, cast] : tensorCastProducers) {
1844  Value &oldTypeBBArg = castBlockArgs[index];
1845  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1846  nestedLoc, cast.dstType, oldTypeBBArg);
1847  }
1848 
1849  // Move old body into new parallel loop.
1850  SmallVector<Value> ivsBlockArgs =
1851  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1852  ivsBlockArgs.append(castBlockArgs);
1853  rewriter.mergeBlocks(forallOp.getBody(),
1854  bbArgs.front().getParentBlock(), ivsBlockArgs);
1855  });
1856 
1857  // After `mergeBlocks`, the destinations in the terminator were mapped to
1858  // the old-typed tensor.cast results of the output bbArgs. The
1859  // destinations have to be updated to point to the output bbArgs directly.
1860  auto terminator = newForallOp.getTerminator();
1861  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1862  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1863  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1864  insertSliceOp.getDestMutable().assign(outputBlockArg);
1865  }
1866 
1867  // Cast results back to the original types.
1868  rewriter.setInsertionPointAfter(newForallOp);
1869  SmallVector<Value> castResults = newForallOp.getResults();
1870  for (auto &item : tensorCastProducers) {
1871  Value &oldTypeResult = castResults[item.first];
1872  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1873  oldTypeResult);
1874  }
1875  rewriter.replaceOp(forallOp, castResults);
1876  return success();
1877  }
1878 };
1879 
1880 } // namespace
1881 
1882 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1883  MLIRContext *context) {
1884  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1885  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1886  ForallOpSingleOrZeroIterationDimsFolder,
1887  ForallOpReplaceConstantInductionVar>(context);
1888 }
1889 
1890 /// Given a region branch point, return the successor regions. These are the
1891 /// regions that may be selected during the flow of control.
1895 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1896                                    SmallVectorImpl<RegionSuccessor> &regions) {
1897  // Both the operation itself and the region may be branching into the body or
1898  // back into the operation itself. It is possible for the loop not to enter the
1899  // body.
1900  regions.push_back(RegionSuccessor(&getRegion()));
1901  regions.push_back(RegionSuccessor());
1902 }
1903 
1904 //===----------------------------------------------------------------------===//
1905 // InParallelOp
1906 //===----------------------------------------------------------------------===//
1907 
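// The scf.forall.in_parallel region is the terminator of an scf.forall body
// and aggregates the parallel stores into the shared outputs. A schematic
// example (illustrative names, types and slicing):
//
//   scf.forall (%i) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>) {
//     %slice = ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %slice into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<?xf32>
//     }
//   }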
1908 // Build an InParallelOp with an empty body block.
1909 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1910  OpBuilder::InsertionGuard g(b);
1911  Region *bodyRegion = result.addRegion();
1912  b.createBlock(bodyRegion);
1913 }
1914 
1915 LogicalResult InParallelOp::verify() {
1916  scf::ForallOp forallOp =
1917  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1918  if (!forallOp)
1919  return this->emitOpError("expected forall op parent");
1920 
1921  // TODO: InParallelOpInterface.
1922  for (Operation &op : getRegion().front().getOperations()) {
1923  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1924  return this->emitOpError("expected only ")
1925  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1926  }
1927 
1928  // Verify that inserts are into out block arguments.
1929  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1930  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1931  if (!llvm::is_contained(regionOutArgs, dest))
1932  return op.emitOpError("may only insert into an output block argument");
1933  }
1934  return success();
1935 }
1936 
1937 void InParallelOp::print(OpAsmPrinter &p) {
1938  p << " ";
1939  p.printRegion(getRegion(),
1940  /*printEntryBlockArgs=*/false,
1941  /*printBlockTerminators=*/false);
1942  p.printOptionalAttrDict(getOperation()->getAttrs());
1943 }
1944 
1945 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1946  auto &builder = parser.getBuilder();
1947 
1948  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1949  std::unique_ptr<Region> region = std::make_unique<Region>();
1950  if (parser.parseRegion(*region, regionOperands))
1951  return failure();
1952 
1953  if (region->empty())
1954  OpBuilder(builder.getContext()).createBlock(region.get());
1955  result.addRegion(std::move(region));
1956 
1957  // Parse the optional attribute list.
1958  if (parser.parseOptionalAttrDict(result.attributes))
1959  return failure();
1960  return success();
1961 }
1962 
1963 OpResult InParallelOp::getParentResult(int64_t idx) {
1964  return getOperation()->getParentOp()->getResult(idx);
1965 }
1966 
1967 SmallVector<BlockArgument> InParallelOp::getDests() {
1968  return llvm::to_vector<4>(
1969  llvm::map_range(getYieldingOps(), [](Operation &op) {
1970  // Add new ops here as needed.
1971  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1972  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1973  }));
1974 }
1975 
1976 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1977  return getRegion().front().getOperations();
1978 }
1979 
1980 //===----------------------------------------------------------------------===//
1981 // IfOp
1982 //===----------------------------------------------------------------------===//
1983 
1984 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1985  assert(a && "expected non-empty operation");
1986  assert(b && "expected non-empty operation");
1987 
1988  IfOp ifOp = a->getParentOfType<IfOp>();
1989  while (ifOp) {
1990  // Check if b is inside ifOp. (We already know that a is.)
1991  if (ifOp->isProperAncestor(b))
1992  // b is contained in ifOp. a and b are in mutually exclusive branches if
1993  // they are in different blocks of ifOp.
1994  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1995  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1996  // Check next enclosing IfOp.
1997  ifOp = ifOp->getParentOfType<IfOp>();
1998  }
1999 
2000  // Could not find a common IfOp among a's and b's ancestors.
2001  return false;
2002 }
2003 
2004 LogicalResult
2005 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2006  IfOp::Adaptor adaptor,
2007  SmallVectorImpl<Type> &inferredReturnTypes) {
2008  if (adaptor.getRegions().empty())
2009  return failure();
2010  Region *r = &adaptor.getThenRegion();
2011  if (r->empty())
2012  return failure();
2013  Block &b = r->front();
2014  if (b.empty())
2015  return failure();
2016  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2017  if (!yieldOp)
2018  return failure();
2019  TypeRange types = yieldOp.getOperandTypes();
2020  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2021  types.end());
2022  return success();
2023 }
2024 
2025 void IfOp::build(OpBuilder &builder, OperationState &result,
2026  TypeRange resultTypes, Value cond) {
2027  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2028  /*addElseBlock=*/false);
2029 }
2030 
2031 void IfOp::build(OpBuilder &builder, OperationState &result,
2032  TypeRange resultTypes, Value cond, bool addThenBlock,
2033  bool addElseBlock) {
2034  assert((!addElseBlock || addThenBlock) &&
2035  "must not create else block w/o then block");
2036  result.addTypes(resultTypes);
2037  result.addOperands(cond);
2038 
2039  // Add regions and blocks.
2040  OpBuilder::InsertionGuard guard(builder);
2041  Region *thenRegion = result.addRegion();
2042  if (addThenBlock)
2043  builder.createBlock(thenRegion);
2044  Region *elseRegion = result.addRegion();
2045  if (addElseBlock)
2046  builder.createBlock(elseRegion);
2047 }
2048 
2049 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2050  bool withElseRegion) {
2051  build(builder, result, TypeRange{}, cond, withElseRegion);
2052 }
2053 
2054 void IfOp::build(OpBuilder &builder, OperationState &result,
2055  TypeRange resultTypes, Value cond, bool withElseRegion) {
2056  result.addTypes(resultTypes);
2057  result.addOperands(cond);
2058 
2059  // Build then region.
2060  OpBuilder::InsertionGuard guard(builder);
2061  Region *thenRegion = result.addRegion();
2062  builder.createBlock(thenRegion);
2063  if (resultTypes.empty())
2064  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2065 
2066  // Build else region.
2067  Region *elseRegion = result.addRegion();
2068  if (withElseRegion) {
2069  builder.createBlock(elseRegion);
2070  if (resultTypes.empty())
2071  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2072  }
2073 }
2074 
2075 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2076  function_ref<void(OpBuilder &, Location)> thenBuilder,
2077  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2078  assert(thenBuilder && "the builder callback for 'then' must be present");
2079  result.addOperands(cond);
2080 
2081  // Build then region.
2082  OpBuilder::InsertionGuard guard(builder);
2083  Region *thenRegion = result.addRegion();
2084  builder.createBlock(thenRegion);
2085  thenBuilder(builder, result.location);
2086 
2087  // Build else region.
2088  Region *elseRegion = result.addRegion();
2089  if (elseBuilder) {
2090  builder.createBlock(elseRegion);
2091  elseBuilder(builder, result.location);
2092  }
2093 
2094  // Infer result types.
2095  SmallVector<Type> inferredReturnTypes;
2096  MLIRContext *ctx = builder.getContext();
2097  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2098  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2099  /*properties=*/nullptr, result.regions,
2100  inferredReturnTypes))) {
2101  result.addTypes(inferredReturnTypes);
2102  }
2103 }
2104 
2105 LogicalResult IfOp::verify() {
2106  if (getNumResults() != 0 && getElseRegion().empty())
2107  return emitOpError("must have an else block if defining values");
2108  return success();
2109 }
2110 
2111 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2112  // Create the 'then' and 'else' regions.
2113  result.regions.reserve(2);
2114  Region *thenRegion = result.addRegion();
2115  Region *elseRegion = result.addRegion();
2116 
2117  auto &builder = parser.getBuilder();
2118  OpAsmParser::UnresolvedOperand cond;
2119  Type i1Type = builder.getIntegerType(1);
2120  if (parser.parseOperand(cond) ||
2121  parser.resolveOperand(cond, i1Type, result.operands))
2122  return failure();
2123  // Parse optional results type list.
2124  if (parser.parseOptionalArrowTypeList(result.types))
2125  return failure();
2126  // Parse the 'then' region.
2127  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2128  return failure();
2129  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2130 
2131  // If we find an 'else' keyword then parse the 'else' region.
2132  if (!parser.parseOptionalKeyword("else")) {
2133  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2134  return failure();
2135  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2136  }
2137 
2138  // Parse the optional attribute list.
2139  if (parser.parseOptionalAttrDict(result.attributes))
2140  return failure();
2141  return success();
2142 }
2143 
2144 void IfOp::print(OpAsmPrinter &p) {
2145  bool printBlockTerminators = false;
2146 
2147  p << " " << getCondition();
2148  if (!getResults().empty()) {
2149  p << " -> (" << getResultTypes() << ")";
2150  // Print yield explicitly if the op defines values.
2151  printBlockTerminators = true;
2152  }
2153  p << ' ';
2154  p.printRegion(getThenRegion(),
2155  /*printEntryBlockArgs=*/false,
2156  /*printBlockTerminators=*/printBlockTerminators);
2157 
2158  // Print the 'else' region if it exists and has a block.
2159  auto &elseRegion = getElseRegion();
2160  if (!elseRegion.empty()) {
2161  p << " else ";
2162  p.printRegion(elseRegion,
2163  /*printEntryBlockArgs=*/false,
2164  /*printBlockTerminators=*/printBlockTerminators);
2165  }
2166 
2167  p.printOptionalAttrDict((*this)->getAttrs());
2168 }
2169 
2170 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2171                                SmallVectorImpl<RegionSuccessor> &regions) {
2172  // The `then` and the `else` region branch back to the parent operation.
2173  if (!point.isParent()) {
2174  regions.push_back(RegionSuccessor(getResults()));
2175  return;
2176  }
2177 
2178  regions.push_back(RegionSuccessor(&getThenRegion()));
2179 
2180  // Don't consider the else region if it is empty.
2181  Region *elseRegion = &this->getElseRegion();
2182  if (elseRegion->empty())
2183  regions.push_back(RegionSuccessor());
2184  else
2185  regions.push_back(RegionSuccessor(elseRegion));
2186 }
2187 
2188 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2189                                     SmallVectorImpl<RegionSuccessor> &regions) {
2190  FoldAdaptor adaptor(operands, *this);
2191  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2192  if (!boolAttr || boolAttr.getValue())
2193  regions.emplace_back(&getThenRegion());
2194 
2195  // If the else region is empty, execution continues after the parent op.
2196  if (!boolAttr || !boolAttr.getValue()) {
2197  if (!getElseRegion().empty())
2198  regions.emplace_back(&getElseRegion());
2199  else
2200  regions.emplace_back(getResults());
2201  }
2202 }
2203 
2204 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2205  SmallVectorImpl<OpFoldResult> &results) {
2206  // if (!c) then A() else B() -> if c then B() else A()
2207  if (getElseRegion().empty())
2208  return failure();
2209 
2210  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2211  if (!xorStmt)
2212  return failure();
2213 
2214  if (!matchPattern(xorStmt.getRhs(), m_One()))
2215  return failure();
2216 
2217  getConditionMutable().assign(xorStmt.getLhs());
2218  Block *thenBlock = &getThenRegion().front();
2219  // It would be nicer to use iplist::swap, but that has no implemented
2220  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2221  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2222  getElseRegion().getBlocks());
2223  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2224  getThenRegion().getBlocks(), thenBlock);
2225  return success();
2226 }
2227 
2228 void IfOp::getRegionInvocationBounds(
2229  ArrayRef<Attribute> operands,
2230  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2231  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2232  // If the condition is known, then one region is known to be executed once
2233  // and the other zero times.
2234  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2235  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2236  } else {
2237  // Non-constant condition. Each region may be executed 0 or 1 times.
2238  invocationBounds.assign(2, {0, 1});
2239  }
2240 }
2241 
2242 namespace {
2243 // Pattern to remove unused IfOp results.
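// A schematic example (illustrative names): with %0#1 unused,
//
//   %0:2 = scf.if %cond -> (i32, i32) {
//     scf.yield %a, %b : i32, i32
//   } else {
//     scf.yield %c, %d : i32, i32
//   }
//
// becomes
//
//   %0 = scf.if %cond -> (i32) {
//     scf.yield %a : i32
//   } else {
//     scf.yield %c : i32
//   }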
2244 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2245   using OpRewritePattern<IfOp>::OpRewritePattern;
2246 
2247  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2248  PatternRewriter &rewriter) const {
2249  // Move all operations to the destination block.
2250  rewriter.mergeBlocks(source, dest);
2251  // Replace the yield op by one that returns only the used values.
2252  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2253  SmallVector<Value, 4> usedOperands;
2254  llvm::transform(usedResults, std::back_inserter(usedOperands),
2255  [&](OpResult result) {
2256  return yieldOp.getOperand(result.getResultNumber());
2257  });
2258  rewriter.modifyOpInPlace(yieldOp,
2259  [&]() { yieldOp->setOperands(usedOperands); });
2260  }
2261 
2262  LogicalResult matchAndRewrite(IfOp op,
2263  PatternRewriter &rewriter) const override {
2264  // Compute the list of used results.
2265  SmallVector<OpResult, 4> usedResults;
2266  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2267  [](OpResult result) { return !result.use_empty(); });
2268 
2269  // Replace the operation if only a subset of its results have uses.
2270  if (usedResults.size() == op.getNumResults())
2271  return failure();
2272 
2273  // Compute the result types of the replacement operation.
2274  SmallVector<Type, 4> newTypes;
2275  llvm::transform(usedResults, std::back_inserter(newTypes),
2276  [](OpResult result) { return result.getType(); });
2277 
2278  // Create a replacement operation with empty then and else regions.
2279  auto newOp =
2280  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2281  rewriter.createBlock(&newOp.getThenRegion());
2282  rewriter.createBlock(&newOp.getElseRegion());
2283 
2284  // Move the bodies and replace the terminators (note there is a then and
2285  // an else region since the operation returns results).
2286  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2287  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2288 
2289  // Replace the operation by the new one.
2290  SmallVector<Value, 4> repResults(op.getNumResults());
2291  for (const auto &en : llvm::enumerate(usedResults))
2292  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2293  rewriter.replaceOp(op, repResults);
2294  return success();
2295  }
2296 };
2297 
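/// Inline the region selected by a constant condition: a constant true
/// condition inlines the then-region, a constant false condition inlines the
/// else-region (or erases the op if there is none). A schematic example:
///
///   %true = arith.constant true
///   scf.if %true {
///     "some.op"() : () -> ()
///   }
///
/// becomes
///
///   "some.op"() : () -> ()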
2298 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2299   using OpRewritePattern<IfOp>::OpRewritePattern;
2300 
2301  LogicalResult matchAndRewrite(IfOp op,
2302  PatternRewriter &rewriter) const override {
2303  BoolAttr condition;
2304  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2305  return failure();
2306 
2307  if (condition.getValue())
2308  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2309  else if (!op.getElseRegion().empty())
2310  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2311  else
2312  rewriter.eraseOp(op);
2313 
2314  return success();
2315  }
2316 };
2317 
2318 /// Hoist any yielded results whose operands are defined outside
2319 /// the if, to a select instruction.
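/// A schematic example (illustrative names), where %a and %b are defined
/// above the scf.if:
///
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
///
/// becomes (once the remaining result-less scf.if is cleaned up by the other
/// canonicalization patterns)
///
///   %r = arith.select %cond, %a, %b : i32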
2320 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2321   using OpRewritePattern<IfOp>::OpRewritePattern;
2322 
2323  LogicalResult matchAndRewrite(IfOp op,
2324  PatternRewriter &rewriter) const override {
2325  if (op->getNumResults() == 0)
2326  return failure();
2327 
2328  auto cond = op.getCondition();
2329  auto thenYieldArgs = op.thenYield().getOperands();
2330  auto elseYieldArgs = op.elseYield().getOperands();
2331 
2332  SmallVector<Type> nonHoistable;
2333  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2334  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2335  &op.getElseRegion() == falseVal.getParentRegion())
2336  nonHoistable.push_back(trueVal.getType());
2337  }
2338  // Early exit if there aren't any yielded values we can
2339  // hoist outside the if.
2340  if (nonHoistable.size() == op->getNumResults())
2341  return failure();
2342 
2343  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2344  /*withElseRegion=*/false);
2345  if (replacement.thenBlock())
2346  rewriter.eraseBlock(replacement.thenBlock());
2347  replacement.getThenRegion().takeBody(op.getThenRegion());
2348  replacement.getElseRegion().takeBody(op.getElseRegion());
2349 
2350  SmallVector<Value> results(op->getNumResults());
2351  assert(thenYieldArgs.size() == results.size());
2352  assert(elseYieldArgs.size() == results.size());
2353 
2354  SmallVector<Value> trueYields;
2355  SmallVector<Value> falseYields;
2356  rewriter.setInsertionPoint(replacement);
2357  for (const auto &it :
2358  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2359  Value trueVal = std::get<0>(it.value());
2360  Value falseVal = std::get<1>(it.value());
2361  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2362  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2363  results[it.index()] = replacement.getResult(trueYields.size());
2364  trueYields.push_back(trueVal);
2365  falseYields.push_back(falseVal);
2366  } else if (trueVal == falseVal)
2367  results[it.index()] = trueVal;
2368  else
2369  results[it.index()] = rewriter.create<arith::SelectOp>(
2370  op.getLoc(), cond, trueVal, falseVal);
2371  }
2372 
2373  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2374  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2375 
2376  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2377  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2378 
2379  rewriter.replaceOp(op, results);
2380  return success();
2381  }
2382 };
2383 
2384 /// Allow the true region of an if to assume the condition is true
2385 /// and vice versa. For example:
2386 ///
2387 /// scf.if %cmp {
2388 /// print(%cmp)
2389 /// }
2390 ///
2391 /// becomes
2392 ///
2393 /// scf.if %cmp {
2394 /// print(true)
2395 /// }
2396 ///
2397 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2398   using OpRewritePattern<IfOp>::OpRewritePattern;
2399 
2400  LogicalResult matchAndRewrite(IfOp op,
2401  PatternRewriter &rewriter) const override {
2402  // Early exit if the condition is constant since replacing a constant
2403  // in the body with another constant isn't a simplification.
2404  if (matchPattern(op.getCondition(), m_Constant()))
2405  return failure();
2406 
2407  bool changed = false;
2408  mlir::Type i1Ty = rewriter.getI1Type();
2409 
2410  // These variables serve to prevent creating duplicate constants
2411  // and hold constant true or false values.
2412  Value constantTrue = nullptr;
2413  Value constantFalse = nullptr;
2414 
2415  for (OpOperand &use :
2416  llvm::make_early_inc_range(op.getCondition().getUses())) {
2417  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2418  changed = true;
2419 
2420  if (!constantTrue)
2421  constantTrue = rewriter.create<arith::ConstantOp>(
2422  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2423 
2424  rewriter.modifyOpInPlace(use.getOwner(),
2425  [&]() { use.set(constantTrue); });
2426  } else if (op.getElseRegion().isAncestor(
2427  use.getOwner()->getParentRegion())) {
2428  changed = true;
2429 
2430  if (!constantFalse)
2431  constantFalse = rewriter.create<arith::ConstantOp>(
2432  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2433 
2434  rewriter.modifyOpInPlace(use.getOwner(),
2435  [&]() { use.set(constantFalse); });
2436  }
2437  }
2438 
2439  return success(changed);
2440  }
2441 };
2442 
2443 /// Remove any statements from an if that are equivalent to the condition
2444 /// or its negation. For example:
2445 ///
2446 /// %res:2 = scf.if %cmp {
2447 /// yield something(), true
2448 /// } else {
2449 /// yield something2(), false
2450 /// }
2451 /// print(%res#1)
2452 ///
2453 /// becomes
2454 /// %res = scf.if %cmp {
2455 /// yield something()
2456 /// } else {
2457 /// yield something2()
2458 /// }
2459 /// print(%cmp)
2460 ///
2461 /// Additionally if both branches yield the same value, replace all uses
2462 /// of the result with the yielded value.
2463 ///
2464 /// %res:2 = scf.if %cmp {
2465 /// yield something(), %arg1
2466 /// } else {
2467 /// yield something2(), %arg1
2468 /// }
2469 /// print(%res#1)
2470 ///
2471 /// becomes
2472 /// %res = scf.if %cmp {
2473 /// yield something()
2474 /// } else {
2475 /// yield something2()
2476 /// }
2477 /// print(%arg1)
2478 ///
2479 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2480   using OpRewritePattern<IfOp>::OpRewritePattern;
2481 
2482  LogicalResult matchAndRewrite(IfOp op,
2483  PatternRewriter &rewriter) const override {
2484  // Early exit if there are no results that could be replaced.
2485  if (op.getNumResults() == 0)
2486  return failure();
2487 
2488  auto trueYield =
2489  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2490  auto falseYield =
2491  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2492 
2493  rewriter.setInsertionPoint(op->getBlock(),
2494  op.getOperation()->getIterator());
2495  bool changed = false;
2496  Type i1Ty = rewriter.getI1Type();
2497  for (auto [trueResult, falseResult, opResult] :
2498  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2499  op.getResults())) {
2500  if (trueResult == falseResult) {
2501  if (!opResult.use_empty()) {
2502  opResult.replaceAllUsesWith(trueResult);
2503  changed = true;
2504  }
2505  continue;
2506  }
2507 
2508  BoolAttr trueYield, falseYield;
2509  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2510  !matchPattern(falseResult, m_Constant(&falseYield)))
2511  continue;
2512 
2513  bool trueVal = trueYield.getValue();
2514  bool falseVal = falseYield.getValue();
2515  if (!trueVal && falseVal) {
2516  if (!opResult.use_empty()) {
2517  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2518  Value notCond = rewriter.create<arith::XOrIOp>(
2519  op.getLoc(), op.getCondition(),
2520  constDialect
2521  ->materializeConstant(rewriter,
2522  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2523  op.getLoc())
2524  ->getResult(0));
2525  opResult.replaceAllUsesWith(notCond);
2526  changed = true;
2527  }
2528  }
2529  if (trueVal && !falseVal) {
2530  if (!opResult.use_empty()) {
2531  opResult.replaceAllUsesWith(op.getCondition());
2532  changed = true;
2533  }
2534  }
2535  }
2536  return success(changed);
2537  }
2538 };
2539 
2540 /// Merge any consecutive scf.if's with the same condition.
2541 ///
2542 /// scf.if %cond {
2543 /// firstCodeTrue();...
2544 /// } else {
2545 /// firstCodeFalse();...
2546 /// }
2547 /// %res = scf.if %cond {
2548 /// secondCodeTrue();...
2549 /// } else {
2550 /// secondCodeFalse();...
2551 /// }
2552 ///
2553 /// becomes
2554 /// %res = scf.if %cmp {
2555 /// firstCodeTrue();...
2556 /// secondCodeTrue();...
2557 /// } else {
2558 /// firstCodeFalse();...
2559 /// secondCodeFalse();...
2560 /// }
2561 struct CombineIfs : public OpRewritePattern<IfOp> {
2562   using OpRewritePattern<IfOp>::OpRewritePattern;
2563 
2564  LogicalResult matchAndRewrite(IfOp nextIf,
2565  PatternRewriter &rewriter) const override {
2566  Block *parent = nextIf->getBlock();
2567  if (nextIf == &parent->front())
2568  return failure();
2569 
2570  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2571  if (!prevIf)
2572  return failure();
2573 
2574  // Determine the logical then/else blocks when prevIf's
2575  // condition is used. Null means the block does not exist
2576  // in that case (e.g. empty else). If neither of these
2577  // are set, the two conditions cannot be compared.
2578  Block *nextThen = nullptr;
2579  Block *nextElse = nullptr;
2580  if (nextIf.getCondition() == prevIf.getCondition()) {
2581  nextThen = nextIf.thenBlock();
2582  if (!nextIf.getElseRegion().empty())
2583  nextElse = nextIf.elseBlock();
2584  }
2585  if (arith::XOrIOp notv =
2586  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2587  if (notv.getLhs() == prevIf.getCondition() &&
2588  matchPattern(notv.getRhs(), m_One())) {
2589  nextElse = nextIf.thenBlock();
2590  if (!nextIf.getElseRegion().empty())
2591  nextThen = nextIf.elseBlock();
2592  }
2593  }
2594  if (arith::XOrIOp notv =
2595  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2596  if (notv.getLhs() == nextIf.getCondition() &&
2597  matchPattern(notv.getRhs(), m_One())) {
2598  nextElse = nextIf.thenBlock();
2599  if (!nextIf.getElseRegion().empty())
2600  nextThen = nextIf.elseBlock();
2601  }
2602  }
2603 
2604  if (!nextThen && !nextElse)
2605  return failure();
2606 
2607  SmallVector<Value> prevElseYielded;
2608  if (!prevIf.getElseRegion().empty())
2609  prevElseYielded = prevIf.elseYield().getOperands();
2610  // Replace all uses of prevIf's results within nextIf with the
2611  // corresponding yielded values.
2612  for (auto it : llvm::zip(prevIf.getResults(),
2613  prevIf.thenYield().getOperands(), prevElseYielded))
2614  for (OpOperand &use :
2615  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2616  if (nextThen && nextThen->getParent()->isAncestor(
2617  use.getOwner()->getParentRegion())) {
2618  rewriter.startOpModification(use.getOwner());
2619  use.set(std::get<1>(it));
2620  rewriter.finalizeOpModification(use.getOwner());
2621  } else if (nextElse && nextElse->getParent()->isAncestor(
2622  use.getOwner()->getParentRegion())) {
2623  rewriter.startOpModification(use.getOwner());
2624  use.set(std::get<2>(it));
2625  rewriter.finalizeOpModification(use.getOwner());
2626  }
2627  }
2628 
2629  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2630  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2631 
2632  IfOp combinedIf = rewriter.create<IfOp>(
2633  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2634  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2635 
2636  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2637  combinedIf.getThenRegion(),
2638  combinedIf.getThenRegion().begin());
2639 
2640  if (nextThen) {
2641  YieldOp thenYield = combinedIf.thenYield();
2642  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2643  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2644  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2645 
2646  SmallVector<Value> mergedYields(thenYield.getOperands());
2647  llvm::append_range(mergedYields, thenYield2.getOperands());
2648  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2649  rewriter.eraseOp(thenYield);
2650  rewriter.eraseOp(thenYield2);
2651  }
2652 
2653  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2654  combinedIf.getElseRegion(),
2655  combinedIf.getElseRegion().begin());
2656 
2657  if (nextElse) {
2658  if (combinedIf.getElseRegion().empty()) {
2659  rewriter.inlineRegionBefore(*nextElse->getParent(),
2660  combinedIf.getElseRegion(),
2661  combinedIf.getElseRegion().begin());
2662  } else {
2663  YieldOp elseYield = combinedIf.elseYield();
2664  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2665  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2666 
2667  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2668 
2669  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2670  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2671 
2672  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2673  rewriter.eraseOp(elseYield);
2674  rewriter.eraseOp(elseYield2);
2675  }
2676  }
2677 
2678  SmallVector<Value> prevValues;
2679  SmallVector<Value> nextValues;
2680  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2681  if (pair.index() < prevIf.getNumResults())
2682  prevValues.push_back(pair.value());
2683  else
2684  nextValues.push_back(pair.value());
2685  }
2686  rewriter.replaceOp(prevIf, prevValues);
2687  rewriter.replaceOp(nextIf, nextValues);
2688  return success();
2689  }
2690 };
2691 
2692 /// Pattern to remove an empty else branch.
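/// A schematic example:
///
///   scf.if %cond {
///     "some.op"() : () -> ()
///   } else {
///   }
///
/// becomes
///
///   scf.if %cond {
///     "some.op"() : () -> ()
///   }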
2693 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2694   using OpRewritePattern<IfOp>::OpRewritePattern;
2695 
2696  LogicalResult matchAndRewrite(IfOp ifOp,
2697  PatternRewriter &rewriter) const override {
2698  // Cannot remove else region when there are operation results.
2699  if (ifOp.getNumResults())
2700  return failure();
2701  Block *elseBlock = ifOp.elseBlock();
2702  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2703  return failure();
2704  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2705  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2706  newIfOp.getThenRegion().begin());
2707  rewriter.eraseOp(ifOp);
2708  return success();
2709  }
2710 };
2711 
2712 /// Convert nested `if`s into `arith.andi` + single `if`.
2713 ///
2714 /// scf.if %arg0 {
2715 /// scf.if %arg1 {
2716 /// ...
2717 /// scf.yield
2718 /// }
2719 /// scf.yield
2720 /// }
2721 /// becomes
2722 ///
2723 /// %0 = arith.andi %arg0, %arg1
2724 /// scf.if %0 {
2725 /// ...
2726 /// scf.yield
2727 /// }
2728 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2729   using OpRewritePattern<IfOp>::OpRewritePattern;
2730 
2731  LogicalResult matchAndRewrite(IfOp op,
2732  PatternRewriter &rewriter) const override {
2733  auto nestedOps = op.thenBlock()->without_terminator();
2734  // Nested `if` must be the only op in block.
2735  if (!llvm::hasSingleElement(nestedOps))
2736  return failure();
2737 
2738  // If there is an else block, it can only yield
2739  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2740  return failure();
2741 
2742  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2743  if (!nestedIf)
2744  return failure();
2745 
2746  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2747  return failure();
2748 
2749  SmallVector<Value> thenYield(op.thenYield().getOperands());
2750  SmallVector<Value> elseYield;
2751  if (op.elseBlock())
2752  llvm::append_range(elseYield, op.elseYield().getOperands());
2753 
2754  // A list of indices for which we should upgrade the value yielded
2755  // in the else to a select.
2756  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2757 
2758  // If the outer scf.if yields a value produced by the inner scf.if,
2759  // only permit combining if the value yielded when the condition
2760  // is false in the outer scf.if is the same value yielded when the
2761  // inner scf.if condition is false.
2762  // Note that the array access to elseYield will not go out of bounds
2763  // since it must have the same length as thenYield, since they both
2764  // come from the same scf.if.
2765  for (const auto &tup : llvm::enumerate(thenYield)) {
2766  if (tup.value().getDefiningOp() == nestedIf) {
2767  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2768  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2769  elseYield[tup.index()]) {
2770  return failure();
2771  }
2772  // If the correctness test passes, we will yield the
2773  // corresponding value from the inner scf.if.
2774  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2775  continue;
2776  }
2777 
2778  // Otherwise, we need to ensure the else block of the combined
2779  // condition still returns the same value when the outer condition is
2780  // true and the inner condition is false. This can be accomplished if
2781  // the then value is defined outside the outer scf.if and we replace the
2782  // value with a select that considers just the outer condition. Since
2783  // the else region contains just the yield, its yielded value is
2784  // defined outside the scf.if, by definition.
2785 
2786  // If the then value is defined within the scf.if, bail.
2787  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2788  return failure();
2789  }
2790  elseYieldsToUpgradeToSelect.push_back(tup.index());
2791  }
2792 
2793  Location loc = op.getLoc();
2794  Value newCondition = rewriter.create<arith::AndIOp>(
2795  loc, op.getCondition(), nestedIf.getCondition());
2796  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2797  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2798 
2799  SmallVector<Value> results;
2800  llvm::append_range(results, newIf.getResults());
2801  rewriter.setInsertionPoint(newIf);
2802 
2803  for (auto idx : elseYieldsToUpgradeToSelect)
2804  results[idx] = rewriter.create<arith::SelectOp>(
2805  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2806 
2807  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2808  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2809  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2810  if (!elseYield.empty()) {
2811  rewriter.createBlock(&newIf.getElseRegion());
2812  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2813  rewriter.create<YieldOp>(loc, elseYield);
2814  }
2815  rewriter.replaceOp(op, results);
2816  return success();
2817  }
2818 };
2819 
2820 } // namespace
2821 
2822 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2823  MLIRContext *context) {
2824  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2825  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2826  RemoveStaticCondition, RemoveUnusedResults,
2827  ReplaceIfYieldWithConditionOrValue>(context);
2828 }
2829 
2830 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2831 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2832 Block *IfOp::elseBlock() {
2833  Region &r = getElseRegion();
2834  if (r.empty())
2835  return nullptr;
2836  return &r.back();
2837 }
2838 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2839 
2840 //===----------------------------------------------------------------------===//
2841 // ParallelOp
2842 //===----------------------------------------------------------------------===//
2843 
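// The custom assembly format handled by the parser and printer below looks
// roughly as follows (schematic; the reduction body is illustrative):
//
//   %sum = scf.parallel (%i) = (%lb) to (%ub) step (%s) init (%zero) -> f32 {
//     %partial = ...
//     scf.reduce(%partial : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %r = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %r : f32
//     }
//   }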
2844 void ParallelOp::build(
2845  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2846  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2847  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2848  bodyBuilderFn) {
2849  result.addOperands(lowerBounds);
2850  result.addOperands(upperBounds);
2851  result.addOperands(steps);
2852  result.addOperands(initVals);
2853  result.addAttribute(
2854  ParallelOp::getOperandSegmentSizeAttr(),
2855  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2856  static_cast<int32_t>(upperBounds.size()),
2857  static_cast<int32_t>(steps.size()),
2858  static_cast<int32_t>(initVals.size())}));
2859  result.addTypes(initVals.getTypes());
2860 
2861  OpBuilder::InsertionGuard guard(builder);
2862  unsigned numIVs = steps.size();
2863  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2864  SmallVector<Location, 8> argLocs(numIVs, result.location);
2865  Region *bodyRegion = result.addRegion();
2866  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2867 
2868  if (bodyBuilderFn) {
2869  builder.setInsertionPointToStart(bodyBlock);
2870  bodyBuilderFn(builder, result.location,
2871  bodyBlock->getArguments().take_front(numIVs),
2872  bodyBlock->getArguments().drop_front(numIVs));
2873  }
2874  // Add terminator only if there are no reductions.
2875  if (initVals.empty())
2876  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2877 }
2878 
2879 void ParallelOp::build(
2880  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2881  ValueRange upperBounds, ValueRange steps,
2882  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2883  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2884  // we don't capture a reference to a temporary by constructing the lambda at
2885  // function level.
2886  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2887  Location nestedLoc, ValueRange ivs,
2888  ValueRange) {
2889  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2890  };
2891  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2892  if (bodyBuilderFn)
2893  wrapper = wrappedBuilderFn;
2894 
2895  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2896  wrapper);
2897 }
2898 
2899 LogicalResult ParallelOp::verify() {
2900  // Check that there is at least one value in lowerBound, upperBound and step.
2901  // It is sufficient to test only step, because it is ensured already that the
2902  // number of elements in lowerBound, upperBound and step are the same.
2903  Operation::operand_range stepValues = getStep();
2904  if (stepValues.empty())
2905  return emitOpError(
2906  "needs at least one tuple element for lowerBound, upperBound and step");
2907 
2908  // Check whether all constant step values are positive.
2909  for (Value stepValue : stepValues)
2910  if (auto cst = getConstantIntValue(stepValue))
2911  if (*cst <= 0)
2912  return emitOpError("constant step operand must be positive");
2913 
2914  // Check that the body defines the same number of block arguments as the
2915  // number of tuple elements in step.
2916  Block *body = getBody();
2917  if (body->getNumArguments() != stepValues.size())
2918  return emitOpError() << "expects the same number of induction variables: "
2919  << body->getNumArguments()
2920  << " as bound and step values: " << stepValues.size();
2921  for (auto arg : body->getArguments())
2922  if (!arg.getType().isIndex())
2923  return emitOpError(
2924  "expects arguments for the induction variable to be of index type");
2925 
2926  // Check that the terminator is an scf.reduce op.
2927  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2928  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2929  if (!reduceOp)
2930  return failure();
2931 
2932  // Check that the number of results is the same as the number of reductions.
2933  auto resultsSize = getResults().size();
2934  auto reductionsSize = reduceOp.getReductions().size();
2935  auto initValsSize = getInitVals().size();
2936  if (resultsSize != reductionsSize)
2937  return emitOpError() << "expects number of results: " << resultsSize
2938  << " to be the same as number of reductions: "
2939  << reductionsSize;
2940  if (resultsSize != initValsSize)
2941  return emitOpError() << "expects number of results: " << resultsSize
2942  << " to be the same as number of initial values: "
2943  << initValsSize;
2944 
2945  // Check that the types of the results and reductions are the same.
2946  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2947  auto resultType = getOperation()->getResult(i).getType();
2948  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2949  if (resultType != reductionOperandType)
2950  return reduceOp.emitOpError()
2951  << "expects type of " << i
2952  << "-th reduction operand: " << reductionOperandType
2953  << " to be the same as the " << i
2954  << "-th result type: " << resultType;
2955  }
2956  return success();
2957 }
2958 
2959 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2960  auto &builder = parser.getBuilder();
2961  // Parse an opening `(` followed by induction variables followed by `)`
2962  SmallVector<OpAsmParser::Argument, 4> ivs;
2963  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2964  return failure();
2965 
2966  // Parse loop bounds.
2967  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2968  if (parser.parseEqual() ||
2969  parser.parseOperandList(lower, ivs.size(),
2970  OpAsmParser::Delimiter::Paren) ||
2971  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2972  return failure();
2973 
2974  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2975  if (parser.parseKeyword("to") ||
2976  parser.parseOperandList(upper, ivs.size(),
2977  OpAsmParser::Delimiter::Paren) ||
2978  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2979  return failure();
2980 
2981  // Parse step values.
2982  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2983  if (parser.parseKeyword("step") ||
2984  parser.parseOperandList(steps, ivs.size(),
2985  OpAsmParser::Delimiter::Paren) ||
2986  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2987  return failure();
2988 
2989  // Parse init values.
2990  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2991  if (succeeded(parser.parseOptionalKeyword("init"))) {
2992  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2993  return failure();
2994  }
2995 
2996  // Parse optional results in case there is a reduce.
2997  if (parser.parseOptionalArrowTypeList(result.types))
2998  return failure();
2999 
3000  // Now parse the body.
3001  Region *body = result.addRegion();
3002  for (auto &iv : ivs)
3003  iv.type = builder.getIndexType();
3004  if (parser.parseRegion(*body, ivs))
3005  return failure();
3006 
3007  // Set `operandSegmentSizes` attribute.
3008  result.addAttribute(
3009  ParallelOp::getOperandSegmentSizeAttr(),
3010  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3011  static_cast<int32_t>(upper.size()),
3012  static_cast<int32_t>(steps.size()),
3013  static_cast<int32_t>(initVals.size())}));
3014 
3015  // Parse attributes.
3016  if (parser.parseOptionalAttrDict(result.attributes) ||
3017  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3018  result.operands))
3019  return failure();
3020 
3021  // Add a terminator if none was parsed.
3022  ParallelOp::ensureTerminator(*body, builder, result.location);
3023  return success();
3024 }
3025 
3026 void ParallelOp::print(OpAsmPrinter &p) {
3027  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3028  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3029  if (!getInitVals().empty())
3030  p << " init (" << getInitVals() << ")";
3031  p.printOptionalArrowTypeList(getResultTypes());
3032  p << ' ';
3033  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 3034  p.printOptionalAttrDictWithKeyword(
 3035  (*this)->getAttrs(),
3036  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3037 }
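// As a minimal sketch (SSA names, buffer, and element type are illustrative
// only), the parser and printer above handle the following textual form of
// scf.parallel; the `init` clause and arrow type list only appear when the
// loop carries reductions:
//
//   %sum = scf.parallel (%iv) = (%lb) to (%ub) step (%c1)
//            init (%zero) -> f32 {
//     %elem = memref.load %buffer[%iv] : memref<100xf32>
//     scf.reduce(%elem : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %partial = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %partial : f32
//     }
//   }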
3038 
3039 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3040 
3041 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3042  return SmallVector<Value>{getBody()->getArguments()};
3043 }
3044 
3045 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3046  return getLowerBound();
3047 }
3048 
3049 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3050  return getUpperBound();
3051 }
3052 
3053 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3054  return getStep();
3055 }
3056 
 3057 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
 3058  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3059  if (!ivArg)
3060  return ParallelOp();
3061  assert(ivArg.getOwner() && "unlinked block argument");
3062  auto *containingOp = ivArg.getOwner()->getParentOp();
3063  return dyn_cast<ParallelOp>(containingOp);
3064 }
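// A minimal usage sketch of the helper above (the value `val` is assumed to
// come from elsewhere); a null ParallelOp is returned when `val` is not a
// block argument of an scf.parallel body:
//
//   if (ParallelOp loop = scf::getParallelForInductionVarOwner(val))
//     ... `val` is one of `loop`'s induction variables ...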
3065 
3066 namespace {
3067 // Collapse loop dimensions that perform a single iteration.
3068 struct ParallelOpSingleOrZeroIterationDimsFolder
3069  : public OpRewritePattern<ParallelOp> {
 3070  using OpRewritePattern<ParallelOp>::OpRewritePattern;
 3071 
3072  LogicalResult matchAndRewrite(ParallelOp op,
3073  PatternRewriter &rewriter) const override {
3074  Location loc = op.getLoc();
3075 
3076  // Compute new loop bounds that omit all single-iteration loop dimensions.
3077  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3078  IRMapping mapping;
3079  for (auto [lb, ub, step, iv] :
3080  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3081  op.getInductionVars())) {
3082  auto numIterations = constantTripCount(lb, ub, step);
3083  if (numIterations.has_value()) {
3084  // Remove the loop if it performs zero iterations.
3085  if (*numIterations == 0) {
3086  rewriter.replaceOp(op, op.getInitVals());
3087  return success();
3088  }
3089  // Replace the loop induction variable by the lower bound if the loop
3090  // performs a single iteration. Otherwise, copy the loop bounds.
3091  if (*numIterations == 1) {
3092  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3093  continue;
3094  }
3095  }
3096  newLowerBounds.push_back(lb);
3097  newUpperBounds.push_back(ub);
3098  newSteps.push_back(step);
3099  }
3100  // Exit if none of the loop dimensions perform a single iteration.
3101  if (newLowerBounds.size() == op.getLowerBound().size())
3102  return failure();
3103 
3104  if (newLowerBounds.empty()) {
 3105  // All of the loop dimensions perform a single iteration. Inline the
 3106  // loop body and the nested ReduceOp's reduction regions.
3107  SmallVector<Value> results;
3108  results.reserve(op.getInitVals().size());
3109  for (auto &bodyOp : op.getBody()->without_terminator())
3110  rewriter.clone(bodyOp, mapping);
3111  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3112  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3113  Block &reduceBlock = reduceOp.getReductions()[i].front();
3114  auto initValIndex = results.size();
3115  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3116  mapping.map(reduceBlock.getArgument(1),
3117  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3118  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3119  rewriter.clone(reduceBodyOp, mapping);
3120 
3121  auto result = mapping.lookupOrDefault(
3122  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3123  results.push_back(result);
3124  }
3125 
3126  rewriter.replaceOp(op, results);
3127  return success();
3128  }
 3129  // Replace the parallel loop with a lower-dimensional parallel loop.
3130  auto newOp =
3131  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3132  newSteps, op.getInitVals(), nullptr);
3133  // Erase the empty block that was inserted by the builder.
3134  rewriter.eraseBlock(newOp.getBody());
3135  // Clone the loop body and remap the block arguments of the collapsed loops
3136  // (inlining does not support a cancellable block argument mapping).
3137  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3138  newOp.getRegion().begin(), mapping);
3139  rewriter.replaceOp(op, newOp.getResults());
3140  return success();
3141  }
3142 };
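// Sketch of the rewrite above on hypothetical IR: the %i dimension below runs
// exactly once ((%c1 - %c0) ceildiv %c1 == 1), so it is collapsed and %i is
// replaced by its lower bound %c0; a dimension with zero iterations would
// instead fold the whole loop away to its init values.
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c8) step (%c1, %c1) { ... }
//
// becomes
//
//   scf.parallel (%j) = (%c0) to (%c8) step (%c1) { ... %c0 stands in for %i ... }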
3143 
3144 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
 3145  using OpRewritePattern<ParallelOp>::OpRewritePattern;
 3146 
3147  LogicalResult matchAndRewrite(ParallelOp op,
3148  PatternRewriter &rewriter) const override {
3149  Block &outerBody = *op.getBody();
3150  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3151  return failure();
3152 
3153  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3154  if (!innerOp)
3155  return failure();
3156 
3157  for (auto val : outerBody.getArguments())
3158  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3159  llvm::is_contained(innerOp.getUpperBound(), val) ||
3160  llvm::is_contained(innerOp.getStep(), val))
3161  return failure();
3162 
3163  // Reductions are not supported yet.
3164  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3165  return failure();
3166 
3167  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3168  ValueRange iterVals, ValueRange) {
3169  Block &innerBody = *innerOp.getBody();
3170  assert(iterVals.size() ==
3171  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3172  IRMapping mapping;
3173  mapping.map(outerBody.getArguments(),
3174  iterVals.take_front(outerBody.getNumArguments()));
3175  mapping.map(innerBody.getArguments(),
3176  iterVals.take_back(innerBody.getNumArguments()));
3177  for (Operation &op : innerBody.without_terminator())
3178  builder.clone(op, mapping);
3179  };
3180 
3181  auto concatValues = [](const auto &first, const auto &second) {
3182  SmallVector<Value> ret;
3183  ret.reserve(first.size() + second.size());
3184  ret.assign(first.begin(), first.end());
3185  ret.append(second.begin(), second.end());
3186  return ret;
3187  };
3188 
3189  auto newLowerBounds =
3190  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3191  auto newUpperBounds =
3192  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3193  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3194 
3195  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3196  newSteps, std::nullopt,
3197  bodyBuilder);
3198  return success();
3199  }
3200 };
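// Sketch of the merge above on hypothetical IR, valid because the outer body
// contains only the inner loop, the inner bounds do not use the outer
// induction variables, and neither loop carries reductions:
//
//   scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
//     scf.parallel (%j) = (%c0) to (%c8) step (%c1) {
//       "test.use"(%i, %j) : (index, index) -> ()
//     }
//   }
//
// becomes
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c8) step (%c1, %c1) {
//     "test.use"(%i, %j) : (index, index) -> ()
//   }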
3201 
3202 } // namespace
3203 
3204 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3205  MLIRContext *context) {
3206  results
3207  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3208  context);
3209 }
3210 
3211 /// Given the region at `index`, or the parent operation if `index` is None,
3212 /// return the successor regions. These are the regions that may be selected
3213 /// during the flow of control. `operands` is a set of optional attributes that
3214 /// correspond to a constant value for each operand, or null if that operand is
3215 /// not a constant.
 3216 void ParallelOp::getSuccessorRegions(
 3217  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
 3218  // Both the operation itself and the region may be branching into the body or
 3219  // back into the operation itself. It is possible for the loop not to enter the
3220  // body.
3221  regions.push_back(RegionSuccessor(&getRegion()));
3222  regions.push_back(RegionSuccessor());
3223 }
3224 
3225 //===----------------------------------------------------------------------===//
3226 // ReduceOp
3227 //===----------------------------------------------------------------------===//
3228 
3229 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3230 
3231 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3232  ValueRange operands) {
3233  result.addOperands(operands);
3234  for (Value v : operands) {
3235  OpBuilder::InsertionGuard guard(builder);
3236  Region *bodyRegion = result.addRegion();
3237  builder.createBlock(bodyRegion, {},
3238  ArrayRef<Type>{v.getType(), v.getType()},
3239  {result.location, result.location});
3240  }
3241 }
3242 
3243 LogicalResult ReduceOp::verifyRegions() {
3244  // The region of a ReduceOp has two arguments of the same type as its
3245  // corresponding operand.
3246  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3247  auto type = getOperands()[i].getType();
3248  Block &block = getReductions()[i].front();
3249  if (block.empty())
3250  return emitOpError() << i << "-th reduction has an empty body";
3251  if (block.getNumArguments() != 2 ||
3252  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3253  return arg.getType() != type;
3254  }))
3255  return emitOpError() << "expected two block arguments with type " << type
3256  << " in the " << i << "-th reduction region";
3257 
3258  // Check that the block is terminated by a ReduceReturnOp.
3259  if (!isa<ReduceReturnOp>(block.getTerminator()))
3260  return emitOpError("reduction bodies must be terminated with an "
3261  "'scf.reduce.return' op");
3262  }
3263 
3264  return success();
3265 }
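// For reference, a reduction region that satisfies the checks above looks like
// the following sketch (operand type and payload op are illustrative): two
// block arguments of the operand's type and an scf.reduce.return terminator
// yielding a value of that same type.
//
//   scf.reduce(%operand : f32) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %0 = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %0 : f32
//   }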
3266 
 3267 MutableOperandRange
 3268 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
 3269  // No operands are forwarded to the next iteration.
3270  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3271 }
3272 
3273 //===----------------------------------------------------------------------===//
3274 // ReduceReturnOp
3275 //===----------------------------------------------------------------------===//
3276 
3277 LogicalResult ReduceReturnOp::verify() {
3278  // The type of the return value should be the same type as the types of the
3279  // block arguments of the reduction body.
3280  Block *reductionBody = getOperation()->getBlock();
3281  // Should already be verified by an op trait.
3282  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3283  Type expectedResultType = reductionBody->getArgument(0).getType();
3284  if (expectedResultType != getResult().getType())
3285  return emitOpError() << "must have type " << expectedResultType
3286  << " (the type of the reduction inputs)";
3287  return success();
3288 }
3289 
3290 //===----------------------------------------------------------------------===//
3291 // WhileOp
3292 //===----------------------------------------------------------------------===//
3293 
3294 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3295  ::mlir::OperationState &odsState, TypeRange resultTypes,
3296  ValueRange inits, BodyBuilderFn beforeBuilder,
3297  BodyBuilderFn afterBuilder) {
3298  odsState.addOperands(inits);
3299  odsState.addTypes(resultTypes);
3300 
3301  OpBuilder::InsertionGuard guard(odsBuilder);
3302 
3303  // Build before region.
3304  SmallVector<Location, 4> beforeArgLocs;
3305  beforeArgLocs.reserve(inits.size());
3306  for (Value operand : inits) {
3307  beforeArgLocs.push_back(operand.getLoc());
3308  }
3309 
3310  Region *beforeRegion = odsState.addRegion();
3311  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3312  inits.getTypes(), beforeArgLocs);
3313  if (beforeBuilder)
3314  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3315 
3316  // Build after region.
3317  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3318 
3319  Region *afterRegion = odsState.addRegion();
3320  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3321  resultTypes, afterArgLocs);
3322 
3323  if (afterBuilder)
3324  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3325 }
3326 
3327 ConditionOp WhileOp::getConditionOp() {
3328  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3329 }
3330 
3331 YieldOp WhileOp::getYieldOp() {
3332  return cast<YieldOp>(getAfterBody()->getTerminator());
3333 }
3334 
3335 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3336  return getYieldOp().getResultsMutable();
3337 }
3338 
3339 Block::BlockArgListType WhileOp::getBeforeArguments() {
3340  return getBeforeBody()->getArguments();
3341 }
3342 
3343 Block::BlockArgListType WhileOp::getAfterArguments() {
3344  return getAfterBody()->getArguments();
3345 }
3346 
3347 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3348  return getBeforeArguments();
3349 }
3350 
3351 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3352  assert(point == getBefore() &&
3353  "WhileOp is expected to branch only to the first region");
3354  return getInits();
3355 }
3356 
 3357 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
 3358  SmallVectorImpl<RegionSuccessor> &regions) {
 3359  // The parent op always branches to the condition region.
3360  if (point.isParent()) {
3361  regions.emplace_back(&getBefore(), getBefore().getArguments());
3362  return;
3363  }
3364 
3365  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3366  "there are only two regions in a WhileOp");
3367  // The body region always branches back to the condition region.
3368  if (point == getAfter()) {
3369  regions.emplace_back(&getBefore(), getBefore().getArguments());
3370  return;
3371  }
3372 
3373  regions.emplace_back(getResults());
3374  regions.emplace_back(&getAfter(), getAfter().getArguments());
3375 }
3376 
3377 SmallVector<Region *> WhileOp::getLoopRegions() {
3378  return {&getBefore(), &getAfter()};
3379 }
3380 
3381 /// Parses a `while` op.
3382 ///
3383 /// op ::= `scf.while` assignments `:` function-type region `do` region
3384 /// `attributes` attribute-dict
3385 /// initializer ::= /* empty */ | `(` assignment-list `)`
3386 /// assignment-list ::= assignment | assignment `,` assignment-list
3387 /// assignment ::= ssa-value `=` ssa-value
 3388 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 3389  SmallVector<OpAsmParser::Argument, 4> regionArgs;
 3390  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
 3391  Region *before = result.addRegion();
3392  Region *after = result.addRegion();
3393 
3394  OptionalParseResult listResult =
3395  parser.parseOptionalAssignmentList(regionArgs, operands);
3396  if (listResult.has_value() && failed(listResult.value()))
3397  return failure();
3398 
3399  FunctionType functionType;
3400  SMLoc typeLoc = parser.getCurrentLocation();
3401  if (failed(parser.parseColonType(functionType)))
3402  return failure();
3403 
3404  result.addTypes(functionType.getResults());
3405 
3406  if (functionType.getNumInputs() != operands.size()) {
3407  return parser.emitError(typeLoc)
3408  << "expected as many input types as operands "
3409  << "(expected " << operands.size() << " got "
3410  << functionType.getNumInputs() << ")";
3411  }
3412 
3413  // Resolve input operands.
3414  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3415  parser.getCurrentLocation(),
3416  result.operands)))
3417  return failure();
3418 
3419  // Propagate the types into the region arguments.
3420  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3421  regionArgs[i].type = functionType.getInput(i);
3422 
 3423  return failure(parser.parseRegion(*before, regionArgs) ||
 3424  parser.parseKeyword("do") || parser.parseRegion(*after) ||
 3425  parser.parseOptionalAttrDictWithKeyword(result.attributes));
 3426 }
3427 
3428 /// Prints a `while` op.
 3429 void scf::WhileOp::print(OpAsmPrinter &p) {
 3430  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3431  p << " : ";
3432  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3433  p << ' ';
3434  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3435  p << " do ";
3436  p.printRegion(getAfter());
3437  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3438 }
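// A minimal sketch of the textual form produced and consumed by the printer
// and parser above (the callee names are illustrative):
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = func.call @should_continue(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = func.call @payload(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }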
3439 
3440 /// Verifies that two ranges of types match, i.e. have the same number of
3441 /// entries and that types are pairwise equals. Reports errors on the given
3442 /// operation in case of mismatch.
3443 template <typename OpTy>
3444 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3445  TypeRange right, StringRef message) {
3446  if (left.size() != right.size())
3447  return op.emitOpError("expects the same number of ") << message;
3448 
3449  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3450  if (left[i] != right[i]) {
3451  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3452  << message;
3453  diag.attachNote() << "for argument " << i << ", found " << left[i]
3454  << " and " << right[i];
3455  return diag;
3456  }
3457  }
3458 
3459  return success();
3460 }
3461 
3462 LogicalResult scf::WhileOp::verify() {
3463  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3464  *this, getBefore(),
3465  "expects the 'before' region to terminate with 'scf.condition'");
3466  if (!beforeTerminator)
3467  return failure();
3468 
3469  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3470  *this, getAfter(),
3471  "expects the 'after' region to terminate with 'scf.yield'");
3472  return success(afterTerminator != nullptr);
3473 }
3474 
3475 namespace {
3476 /// Replace uses of the condition within the do block with true, since otherwise
3477 /// the block would not be evaluated.
3478 ///
3479 /// scf.while (..) : (i1, ...) -> ... {
3480 /// %condition = call @evaluate_condition() : () -> i1
3481 /// scf.condition(%condition) %condition : i1, ...
3482 /// } do {
3483 /// ^bb0(%arg0: i1, ...):
3484 /// use(%arg0)
3485 /// ...
3486 ///
3487 /// becomes
3488 /// scf.while (..) : (i1, ...) -> ... {
3489 /// %condition = call @evaluate_condition() : () -> i1
3490 /// scf.condition(%condition) %condition : i1, ...
3491 /// } do {
3492 /// ^bb0(%arg0: i1, ...):
3493 /// use(%true)
3494 /// ...
3495 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
 3496  using OpRewritePattern<WhileOp>::OpRewritePattern;
 3497 
3498  LogicalResult matchAndRewrite(WhileOp op,
3499  PatternRewriter &rewriter) const override {
3500  auto term = op.getConditionOp();
3501 
 3502  // This variable serves to prevent creating duplicate constants
 3503  // and holds the constant true value.
3504  Value constantTrue = nullptr;
3505 
3506  bool replaced = false;
3507  for (auto yieldedAndBlockArgs :
3508  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3509  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3510  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3511  if (!constantTrue)
3512  constantTrue = rewriter.create<arith::ConstantOp>(
3513  op.getLoc(), term.getCondition().getType(),
3514  rewriter.getBoolAttr(true));
3515 
3516  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3517  constantTrue);
3518  replaced = true;
3519  }
3520  }
3521  }
3522  return success(replaced);
3523  }
3524 };
3525 
3526 /// Remove loop invariant arguments from `before` block of scf.while.
 3527 /// A before block argument is considered loop invariant if:
 3528 /// 1. the i-th yield operand is equal to the i-th while operand, or
 3529 /// 2. the i-th yield operand is the k-th after block argument, which is the
 3530 /// (k+1)-th condition operand, AND this (k+1)-th condition operand is equal
 3531 /// to the i-th iter argument/while operand.
3532 /// For the arguments which are removed, their uses inside scf.while
3533 /// are replaced with their corresponding initial value.
3534 ///
3535 /// Eg:
3536 /// INPUT :-
3537 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3538 /// ..., %argN_before = %N)
3539 /// {
3540 /// ...
3541 /// scf.condition(%cond) %arg1_before, %arg0_before,
3542 /// %arg2_before, %arg0_before, ...
3543 /// } do {
3544 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3545 /// ..., %argK_after):
3546 /// ...
3547 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3548 /// }
3549 ///
3550 /// OUTPUT :-
3551 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3552 /// %N)
3553 /// {
3554 /// ...
3555 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3556 /// } do {
3557 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3558 /// ..., %argK_after):
3559 /// ...
3560 /// scf.yield %arg1_after, ..., %argN
3561 /// }
3562 ///
3563 /// EXPLANATION:
3564 /// We iterate over each yield operand.
3565 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3566 /// %arg0_before, which in turn is the 0-th iter argument. So we
3567 /// remove 0-th before block argument and yield operand, and replace
3568 /// all uses of the 0-th before block argument with its initial value
3569 /// %a.
3570 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3571 /// value. So we remove this operand and the corresponding before
3572 /// block argument and replace all uses of 1-th before block argument
3573 /// with %b.
3574 struct RemoveLoopInvariantArgsFromBeforeBlock
3575  : public OpRewritePattern<WhileOp> {
 3576  using OpRewritePattern<WhileOp>::OpRewritePattern;
 3577 
3578  LogicalResult matchAndRewrite(WhileOp op,
3579  PatternRewriter &rewriter) const override {
3580  Block &afterBlock = *op.getAfterBody();
3581  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3582  ConditionOp condOp = op.getConditionOp();
3583  OperandRange condOpArgs = condOp.getArgs();
3584  Operation *yieldOp = afterBlock.getTerminator();
3585  ValueRange yieldOpArgs = yieldOp->getOperands();
3586 
3587  bool canSimplify = false;
3588  for (const auto &it :
3589  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3590  auto index = static_cast<unsigned>(it.index());
3591  auto [initVal, yieldOpArg] = it.value();
3592  // If i-th yield operand is equal to the i-th operand of the scf.while,
3593  // the i-th before block argument is a loop invariant.
3594  if (yieldOpArg == initVal) {
3595  canSimplify = true;
3596  break;
3597  }
 3598  // If the i-th yield operand is the k-th after block argument, then we
 3599  // check whether the (k+1)-th condition op operand is equal to either the
 3600  // i-th before block argument or the initial value of the i-th before block
 3601  // argument. If so, the i-th before block argument is a loop
 3602  // invariant.
3603  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3604  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3605  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3606  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3607  canSimplify = true;
3608  break;
3609  }
3610  }
3611  }
3612 
3613  if (!canSimplify)
3614  return failure();
3615 
3616  SmallVector<Value> newInitArgs, newYieldOpArgs;
3617  DenseMap<unsigned, Value> beforeBlockInitValMap;
3618  SmallVector<Location> newBeforeBlockArgLocs;
3619  for (const auto &it :
3620  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3621  auto index = static_cast<unsigned>(it.index());
3622  auto [initVal, yieldOpArg] = it.value();
3623 
3624  // If i-th yield operand is equal to the i-th operand of the scf.while,
3625  // the i-th before block argument is a loop invariant.
3626  if (yieldOpArg == initVal) {
3627  beforeBlockInitValMap.insert({index, initVal});
3628  continue;
3629  } else {
 3630  // If the i-th yield operand is the k-th after block argument, then we
 3631  // check whether the (k+1)-th condition op operand is equal to either the
 3632  // i-th before block argument or the initial value of the i-th before
 3633  // block argument. If so, the i-th before block
 3634  // argument is a loop invariant.
3635  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3636  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3637  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3638  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3639  beforeBlockInitValMap.insert({index, initVal});
3640  continue;
3641  }
3642  }
3643  }
3644  newInitArgs.emplace_back(initVal);
3645  newYieldOpArgs.emplace_back(yieldOpArg);
3646  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3647  }
3648 
3649  {
3650  OpBuilder::InsertionGuard g(rewriter);
3651  rewriter.setInsertionPoint(yieldOp);
3652  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3653  }
3654 
3655  auto newWhile =
3656  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3657 
3658  Block &newBeforeBlock = *rewriter.createBlock(
3659  &newWhile.getBefore(), /*insertPt*/ {},
3660  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3661 
3662  Block &beforeBlock = *op.getBeforeBody();
3663  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
 3664  // For each i-th before block argument we find its replacement value as follows:
 3665  // 1. If the i-th before block argument is a loop invariant, we fetch its
 3666  // initial value from `beforeBlockInitValMap` by querying for key `i`.
 3667  // 2. Otherwise we use the j-th new before block argument as the replacement
 3668  // value of the i-th before block argument.
3669  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
 3670  // If the index 'i' argument was a loop invariant, we fetch its initial
 3671  // value from `beforeBlockInitValMap`.
3672  if (beforeBlockInitValMap.count(i) != 0)
3673  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3674  else
3675  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3676  }
3677 
3678  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3679  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3680  newWhile.getAfter().begin());
3681 
3682  rewriter.replaceOp(op, newWhile.getResults());
3683  return success();
3684  }
3685 };
3686 
3687 /// Remove loop invariant value from result (condition op) of scf.while.
3688 /// A value is considered loop invariant if the final value yielded by
3689 /// scf.condition is defined outside of the `before` block. We remove the
3690 /// corresponding argument in `after` block and replace the use with the value.
3691 /// We also replace the use of the corresponding result of scf.while with the
3692 /// value.
3693 ///
3694 /// Eg:
3695 /// INPUT :-
3696 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3697 /// %argN_before = %N) {
3698 /// ...
3699 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3700 /// } do {
3701 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3702 /// ...
3703 /// some_func(%arg1_after)
3704 /// ...
3705 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3706 /// }
3707 ///
3708 /// OUTPUT :-
3709 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3710 /// ...
3711 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3712 /// } do {
3713 /// ^bb0(%arg0, %arg3, ..., %argM):
3714 /// ...
3715 /// some_func(%a)
3716 /// ...
3717 /// scf.yield %arg0, %b, ..., %argN
3718 /// }
3719 ///
3720 /// EXPLANATION:
3721 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3722 /// before block of scf.while, so they get removed.
3723 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3724 /// replaced by %b.
3725 /// 3. The corresponding after block argument %arg1_after's uses are
3726 /// replaced by %a and %arg2_after's uses are replaced by %b.
3727 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
 3728  using OpRewritePattern<WhileOp>::OpRewritePattern;
 3729 
3730  LogicalResult matchAndRewrite(WhileOp op,
3731  PatternRewriter &rewriter) const override {
3732  Block &beforeBlock = *op.getBeforeBody();
3733  ConditionOp condOp = op.getConditionOp();
3734  OperandRange condOpArgs = condOp.getArgs();
3735 
3736  bool canSimplify = false;
3737  for (Value condOpArg : condOpArgs) {
 3738  // Values not defined within the `before` block are considered loop
 3739  // invariant. Check whether any such value is forwarded by the condition
 3740  // op.
3741  if (condOpArg.getParentBlock() != &beforeBlock) {
3742  canSimplify = true;
3743  break;
3744  }
3745  }
3746 
3747  if (!canSimplify)
3748  return failure();
3749 
3750  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3751 
3752  SmallVector<Value> newCondOpArgs;
3753  SmallVector<Type> newAfterBlockType;
3754  DenseMap<unsigned, Value> condOpInitValMap;
3755  SmallVector<Location> newAfterBlockArgLocs;
3756  for (const auto &it : llvm::enumerate(condOpArgs)) {
3757  auto index = static_cast<unsigned>(it.index());
3758  Value condOpArg = it.value();
 3759  // Values not defined within the `before` block are considered loop
 3760  // invariant. We map the corresponding `index` to the invariant
 3761  // value.
3762  if (condOpArg.getParentBlock() != &beforeBlock) {
3763  condOpInitValMap.insert({index, condOpArg});
3764  } else {
3765  newCondOpArgs.emplace_back(condOpArg);
3766  newAfterBlockType.emplace_back(condOpArg.getType());
3767  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3768  }
3769  }
3770 
3771  {
3772  OpBuilder::InsertionGuard g(rewriter);
3773  rewriter.setInsertionPoint(condOp);
3774  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3775  newCondOpArgs);
3776  }
3777 
3778  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3779  op.getOperands());
3780 
3781  Block &newAfterBlock =
3782  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3783  newAfterBlockType, newAfterBlockArgLocs);
3784 
3785  Block &afterBlock = *op.getAfterBody();
 3786  // Since a new scf.condition op was created, we need to fetch the new
 3787  // `after` block arguments that will be used when replacing operations of
 3788  // the previous scf.while's `after` block. We also fetch the new result
 3789  // values.
3790  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3791  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3792  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3793  Value afterBlockArg, result;
 3794  // If the index 'i' argument was loop invariant, we fetch its value from
 3795  // the `condOpInitValMap` map.
3796  if (condOpInitValMap.count(i) != 0) {
3797  afterBlockArg = condOpInitValMap[i];
3798  result = afterBlockArg;
3799  } else {
3800  afterBlockArg = newAfterBlock.getArgument(j);
3801  result = newWhile.getResult(j);
3802  j++;
3803  }
3804  newAfterBlockArgs[i] = afterBlockArg;
3805  newWhileResults[i] = result;
3806  }
3807 
3808  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3809  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3810  newWhile.getBefore().begin());
3811 
3812  rewriter.replaceOp(op, newWhileResults);
3813  return success();
3814  }
3815 };
3816 
 3817 /// Remove WhileOp results that are also unused in the 'after' block.
3818 ///
3819 /// %0:2 = scf.while () : () -> (i32, i64) {
3820 /// %condition = "test.condition"() : () -> i1
3821 /// %v1 = "test.get_some_value"() : () -> i32
3822 /// %v2 = "test.get_some_value"() : () -> i64
3823 /// scf.condition(%condition) %v1, %v2 : i32, i64
3824 /// } do {
3825 /// ^bb0(%arg0: i32, %arg1: i64):
3826 /// "test.use"(%arg0) : (i32) -> ()
3827 /// scf.yield
3828 /// }
3829 /// return %0#0 : i32
3830 ///
3831 /// becomes
3832 /// %0 = scf.while () : () -> (i32) {
3833 /// %condition = "test.condition"() : () -> i1
3834 /// %v1 = "test.get_some_value"() : () -> i32
3835 /// %v2 = "test.get_some_value"() : () -> i64
3836 /// scf.condition(%condition) %v1 : i32
3837 /// } do {
3838 /// ^bb0(%arg0: i32):
3839 /// "test.use"(%arg0) : (i32) -> ()
3840 /// scf.yield
3841 /// }
3842 /// return %0 : i32
3843 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
 3844  using OpRewritePattern<WhileOp>::OpRewritePattern;
 3845 
3846  LogicalResult matchAndRewrite(WhileOp op,
3847  PatternRewriter &rewriter) const override {
3848  auto term = op.getConditionOp();
3849  auto afterArgs = op.getAfterArguments();
3850  auto termArgs = term.getArgs();
3851 
3852  // Collect results mapping, new terminator args and new result types.
3853  SmallVector<unsigned> newResultsIndices;
3854  SmallVector<Type> newResultTypes;
3855  SmallVector<Value> newTermArgs;
3856  SmallVector<Location> newArgLocs;
3857  bool needUpdate = false;
3858  for (const auto &it :
3859  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3860  auto i = static_cast<unsigned>(it.index());
3861  Value result = std::get<0>(it.value());
3862  Value afterArg = std::get<1>(it.value());
3863  Value termArg = std::get<2>(it.value());
3864  if (result.use_empty() && afterArg.use_empty()) {
3865  needUpdate = true;
3866  } else {
3867  newResultsIndices.emplace_back(i);
3868  newTermArgs.emplace_back(termArg);
3869  newResultTypes.emplace_back(result.getType());
3870  newArgLocs.emplace_back(result.getLoc());
3871  }
3872  }
3873 
3874  if (!needUpdate)
3875  return failure();
3876 
3877  {
3878  OpBuilder::InsertionGuard g(rewriter);
3879  rewriter.setInsertionPoint(term);
3880  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3881  newTermArgs);
3882  }
3883 
3884  auto newWhile =
3885  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3886 
3887  Block &newAfterBlock = *rewriter.createBlock(
3888  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3889 
3890  // Build new results list and new after block args (unused entries will be
3891  // null).
3892  SmallVector<Value> newResults(op.getNumResults());
3893  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3894  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3895  newResults[it.value()] = newWhile.getResult(it.index());
3896  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3897  }
3898 
3899  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3900  newWhile.getBefore().begin());
3901 
3902  Block &afterBlock = *op.getAfterBody();
3903  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3904 
3905  rewriter.replaceOp(op, newResults);
3906  return success();
3907  }
3908 };
3909 
3910 /// Replace operations equivalent to the condition in the do block with true,
3911 /// since otherwise the block would not be evaluated.
3912 ///
3913 /// scf.while (..) : (i32, ...) -> ... {
3914 /// %z = ... : i32
3915 /// %condition = cmpi pred %z, %a
3916 /// scf.condition(%condition) %z : i32, ...
3917 /// } do {
3918 /// ^bb0(%arg0: i32, ...):
3919 /// %condition2 = cmpi pred %arg0, %a
3920 /// use(%condition2)
3921 /// ...
3922 ///
3923 /// becomes
3924 /// scf.while (..) : (i32, ...) -> ... {
3925 /// %z = ... : i32
3926 /// %condition = cmpi pred %z, %a
3927 /// scf.condition(%condition) %z : i32, ...
3928 /// } do {
3929 /// ^bb0(%arg0: i32, ...):
3930 /// use(%true)
3931 /// ...
3932 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
 3933  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
 3934 
3935  LogicalResult matchAndRewrite(scf::WhileOp op,
3936  PatternRewriter &rewriter) const override {
3937  using namespace scf;
3938  auto cond = op.getConditionOp();
3939  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3940  if (!cmp)
3941  return failure();
3942  bool changed = false;
3943  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3944  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3945  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3946  continue;
3947  for (OpOperand &u :
3948  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3949  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3950  if (!cmp2)
3951  continue;
3952  // For a binary operator 1-opIdx gets the other side.
3953  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3954  continue;
3955  bool samePredicate;
3956  if (cmp2.getPredicate() == cmp.getPredicate())
3957  samePredicate = true;
3958  else if (cmp2.getPredicate() ==
3959  arith::invertPredicate(cmp.getPredicate()))
3960  samePredicate = false;
3961  else
3962  continue;
3963 
3964  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3965  1);
3966  changed = true;
3967  }
3968  }
3969  }
3970  return success(changed);
3971  }
3972 };
3973 
3974 /// Remove unused init/yield args.
3975 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
 3976  using OpRewritePattern<WhileOp>::OpRewritePattern;
 3977 
3978  LogicalResult matchAndRewrite(WhileOp op,
3979  PatternRewriter &rewriter) const override {
3980 
3981  if (!llvm::any_of(op.getBeforeArguments(),
3982  [](Value arg) { return arg.use_empty(); }))
3983  return rewriter.notifyMatchFailure(op, "No args to remove");
3984 
3985  YieldOp yield = op.getYieldOp();
3986 
3987  // Collect results mapping, new terminator args and new result types.
3988  SmallVector<Value> newYields;
3989  SmallVector<Value> newInits;
3990  llvm::BitVector argsToErase;
3991 
3992  size_t argsCount = op.getBeforeArguments().size();
3993  newYields.reserve(argsCount);
3994  newInits.reserve(argsCount);
3995  argsToErase.reserve(argsCount);
3996  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3997  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3998  if (beforeArg.use_empty()) {
3999  argsToErase.push_back(true);
4000  } else {
4001  argsToErase.push_back(false);
4002  newYields.emplace_back(yieldValue);
4003  newInits.emplace_back(initValue);
4004  }
4005  }
4006 
4007  Block &beforeBlock = *op.getBeforeBody();
4008  Block &afterBlock = *op.getAfterBody();
4009 
4010  beforeBlock.eraseArguments(argsToErase);
4011 
4012  Location loc = op.getLoc();
4013  auto newWhileOp =
4014  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4015  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4016  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4017  Block &newAfterBlock = *newWhileOp.getAfterBody();
4018 
4019  OpBuilder::InsertionGuard g(rewriter);
4020  rewriter.setInsertionPoint(yield);
4021  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4022 
4023  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4024  newBeforeBlock.getArguments());
4025  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4026  newAfterBlock.getArguments());
4027 
4028  rewriter.replaceOp(op, newWhileOp.getResults());
4029  return success();
4030  }
4031 };
4032 
4033 /// Remove duplicated ConditionOp args.
4034 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
 4035  using OpRewritePattern<WhileOp>::OpRewritePattern;
 4036 
4037  LogicalResult matchAndRewrite(WhileOp op,
4038  PatternRewriter &rewriter) const override {
4039  ConditionOp condOp = op.getConditionOp();
4040  ValueRange condOpArgs = condOp.getArgs();
4041 
4042  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4043 
4044  if (argsSet.size() == condOpArgs.size())
4045  return rewriter.notifyMatchFailure(op, "No results to remove");
4046 
4047  llvm::SmallDenseMap<Value, unsigned> argsMap;
4048  SmallVector<Value> newArgs;
4049  argsMap.reserve(condOpArgs.size());
4050  newArgs.reserve(condOpArgs.size());
4051  for (Value arg : condOpArgs) {
4052  if (!argsMap.count(arg)) {
4053  auto pos = static_cast<unsigned>(argsMap.size());
4054  argsMap.insert({arg, pos});
4055  newArgs.emplace_back(arg);
4056  }
4057  }
4058 
4059  ValueRange argsRange(newArgs);
4060 
4061  Location loc = op.getLoc();
4062  auto newWhileOp = rewriter.create<scf::WhileOp>(
4063  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4064  /*afterBody*/ nullptr);
4065  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4066  Block &newAfterBlock = *newWhileOp.getAfterBody();
4067 
4068  SmallVector<Value> afterArgsMapping;
4069  SmallVector<Value> resultsMapping;
4070  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4071  auto it = argsMap.find(arg);
4072  assert(it != argsMap.end());
4073  auto pos = it->second;
4074  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4075  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4076  }
4077 
4078  OpBuilder::InsertionGuard g(rewriter);
4079  rewriter.setInsertionPoint(condOp);
4080  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4081  argsRange);
4082 
4083  Block &beforeBlock = *op.getBeforeBody();
4084  Block &afterBlock = *op.getAfterBody();
4085 
4086  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4087  newBeforeBlock.getArguments());
4088  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4089  rewriter.replaceOp(op, resultsMapping);
4090  return success();
4091  }
4092 };
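// Sketch of the deduplication above on hypothetical IR: both condition args
// forward the same value, so the loop is rebuilt with a single result, and the
// second after-block argument and the second op result are remapped to the
// first one.
//
//   %0:2 = scf.while (...) : (...) -> (i32, i32) {
//     %c = "test.condition"() : () -> i1
//     %v = "test.value"() : () -> i32
//     scf.condition(%c) %v, %v : i32, i32
//   } do {
//   ^bb0(%a: i32, %b: i32): ...
//
// becomes
//
//   %0 = scf.while (...) : (...) -> i32 {
//     ...
//     scf.condition(%c) %v : i32
//   } do {
//   ^bb0(%a: i32): ...   // former uses of %b and %0#1 now use %a and %0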
4093 
 4094 /// If both ranges contain the same values, return mapping indices from args2
 4095 /// to args1. Otherwise return std::nullopt.
4096 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4097  ValueRange args2) {
4098  if (args1.size() != args2.size())
4099  return std::nullopt;
4100 
4101  SmallVector<unsigned> ret(args1.size());
4102  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4103  auto it = llvm::find(args2, arg1);
4104  if (it == args2.end())
4105  return std::nullopt;
4106 
4107  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4108  }
4109 
4110  return ret;
4111 }
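// Worked example for the helper above (illustrative values): for
// args1 = [%a, %b, %c] and args2 = [%c, %a, %b], every element of args1 is
// found in args2, so the result is [2, 0, 1], i.e. ret[k] is the position of
// args2[k] within args1. If any element of args1 is missing from args2, or the
// sizes differ, std::nullopt is returned.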
4112 
4113 static bool hasDuplicates(ValueRange args) {
4114  llvm::SmallDenseSet<Value> set;
4115  for (Value arg : args) {
4116  if (!set.insert(arg).second)
4117  return true;
4118  }
4119  return false;
4120 }
4121 
 4122 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
 4123 /// the `scf.condition` args into the same order as the block args. Update the
 4124 /// `after` block args and op result values accordingly.
4125 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4126 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
 4127  using OpRewritePattern<WhileOp>::OpRewritePattern;
 4128 
4129  LogicalResult matchAndRewrite(WhileOp loop,
4130  PatternRewriter &rewriter) const override {
4131  auto oldBefore = loop.getBeforeBody();
4132  ConditionOp oldTerm = loop.getConditionOp();
4133  ValueRange beforeArgs = oldBefore->getArguments();
4134  ValueRange termArgs = oldTerm.getArgs();
4135  if (beforeArgs == termArgs)
4136  return failure();
4137 
4138  if (hasDuplicates(termArgs))
4139  return failure();
4140 
4141  auto mapping = getArgsMapping(beforeArgs, termArgs);
4142  if (!mapping)
4143  return failure();
4144 
4145  {
4146  OpBuilder::InsertionGuard g(rewriter);
4147  rewriter.setInsertionPoint(oldTerm);
4148  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4149  beforeArgs);
4150  }
4151 
4152  auto oldAfter = loop.getAfterBody();
4153 
4154  SmallVector<Type> newResultTypes(beforeArgs.size());
4155  for (auto &&[i, j] : llvm::enumerate(*mapping))
4156  newResultTypes[j] = loop.getResult(i).getType();
4157 
4158  auto newLoop = rewriter.create<WhileOp>(
4159  loop.getLoc(), newResultTypes, loop.getInits(),
4160  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4161  auto newBefore = newLoop.getBeforeBody();
4162  auto newAfter = newLoop.getAfterBody();
4163 
4164  SmallVector<Value> newResults(beforeArgs.size());
4165  SmallVector<Value> newAfterArgs(beforeArgs.size());
4166  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4167  newResults[i] = newLoop.getResult(j);
4168  newAfterArgs[i] = newAfter->getArgument(j);
4169  }
4170 
4171  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4172  newBefore->getArguments());
4173  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4174  newAfterArgs);
4175 
4176  rewriter.replaceOp(loop, newResults);
4177  return success();
4178  }
4179 };
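// Sketch of the realignment above on hypothetical IR: the before block args
// are forwarded to scf.condition in swapped order, so the condition operands,
// the after block args, and the op results are permuted to match the block
// args.
//
//   scf.while (%a = %x, %b = %y) : (i32, i64) -> (i64, i32) {
//     ...
//     scf.condition(%c) %b, %a : i64, i32
//   } do { ^bb0(%0: i64, %1: i32): ... }
//
// becomes
//
//   scf.while (%a = %x, %b = %y) : (i32, i64) -> (i32, i64) {
//     ...
//     scf.condition(%c) %a, %b : i32, i64
//   } do { ^bb0(%0: i32, %1: i64): ... }   // uses rewired accordingly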
4180 } // namespace
4181 
4182 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4183  MLIRContext *context) {
4184  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4185  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4186  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4187  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4188 }
4189 
4190 //===----------------------------------------------------------------------===//
4191 // IndexSwitchOp
4192 //===----------------------------------------------------------------------===//
4193 
4194 /// Parse the case regions and values.
 4195 static ParseResult
 4196 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
 4197  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4198  SmallVector<int64_t> caseValues;
4199  while (succeeded(p.parseOptionalKeyword("case"))) {
4200  int64_t value;
4201  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4202  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4203  return failure();
4204  caseValues.push_back(value);
4205  }
4206  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4207  return success();
4208 }
4209 
4210 /// Print the case regions and values.
 4211 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
 4212  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4213  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4214  p.printNewline();
4215  p << "case " << value << ' ';
4216  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4217  }
4218 }
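// As a minimal sketch (case values and payload are illustrative), the case
// parser and printer above handle the region list of scf.index_switch in this
// form:
//
//   %0 = scf.index_switch %arg -> i32
//   case 2 {
//     %1 = arith.constant 10 : i32
//     scf.yield %1 : i32
//   }
//   default {
//     %2 = arith.constant 30 : i32
//     scf.yield %2 : i32
//   }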
4219 
4220 LogicalResult scf::IndexSwitchOp::verify() {
4221  if (getCases().size() != getCaseRegions().size()) {
4222  return emitOpError("has ")
4223  << getCaseRegions().size() << " case regions but "
4224  << getCases().size() << " case values";
4225  }
4226 
4227  DenseSet<int64_t> valueSet;
4228  for (int64_t value : getCases())
4229  if (!valueSet.insert(value).second)
4230  return emitOpError("has duplicate case value: ") << value;
4231  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4232  auto yield = dyn_cast<YieldOp>(region.front().back());
4233  if (!yield)
4234  return emitOpError("expected region to end with scf.yield, but got ")
4235  << region.front().back().getName();
4236 
4237  if (yield.getNumOperands() != getNumResults()) {
4238  return (emitOpError("expected each region to return ")
4239  << getNumResults() << " values, but " << name << " returns "
4240  << yield.getNumOperands())
4241  .attachNote(yield.getLoc())
4242  << "see yield operation here";
4243  }
4244  for (auto [idx, result, operand] :
4245  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4246  yield.getOperandTypes())) {
4247  if (result == operand)
4248  continue;
4249  return (emitOpError("expected result #")
4250  << idx << " of each region to be " << result)
4251  .attachNote(yield.getLoc())
4252  << name << " returns " << operand << " here";
4253  }
4254  return success();
4255  };
4256 
4257  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4258  return failure();
4259  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4260  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4261  return failure();
4262 
4263  return success();
4264 }
4265 
4266 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4267 
4268 Block &scf::IndexSwitchOp::getDefaultBlock() {
4269  return getDefaultRegion().front();
4270 }
4271 
4272 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4273  assert(idx < getNumCases() && "case index out-of-bounds");
4274  return getCaseRegions()[idx].front();
4275 }
4276 
 4277 void IndexSwitchOp::getSuccessorRegions(
 4278  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
 4279  // All regions branch back to the parent op.
4280  if (!point.isParent()) {
4281  successors.emplace_back(getResults());
4282  return;
4283  }
4284 
4285  llvm::copy(getRegions(), std::back_inserter(successors));
4286 }
4287 
4288 void IndexSwitchOp::getEntrySuccessorRegions(
4289  ArrayRef<Attribute> operands,
4290  SmallVectorImpl<RegionSuccessor> &successors) {
4291  FoldAdaptor adaptor(operands, *this);
4292 
4293  // If a constant was not provided, all regions are possible successors.
4294  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4295  if (!arg) {
4296  llvm::copy(getRegions(), std::back_inserter(successors));
4297  return;
4298  }
4299 
4300  // Otherwise, try to find a case with a matching value. If not, the
4301  // default region is the only successor.
4302  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4303  if (caseValue == arg.getInt()) {
4304  successors.emplace_back(&caseRegion);
4305  return;
4306  }
4307  }
4308  successors.emplace_back(&getDefaultRegion());
4309 }
4310 
 4311 void IndexSwitchOp::getRegionInvocationBounds(
 4312  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
 4313  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4314  if (!operandValue) {
4315  // All regions are invoked at most once.
4316  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4317  return;
4318  }
4319 
4320  unsigned liveIndex = getNumRegions() - 1;
4321  const auto *it = llvm::find(getCases(), operandValue.getInt());
4322  if (it != getCases().end())
4323  liveIndex = std::distance(getCases().begin(), it);
4324  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4325  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4326 }
4327 
4328 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
 4329  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
 4330 
4331  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4332  PatternRewriter &rewriter) const override {
 4333  // If `op.getArg()` is a constant, select the region that matches
 4334  // the constant value. Use the default region if no match is found.
4335  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4336  if (!maybeCst.has_value())
4337  return failure();
4338  int64_t cst = *maybeCst;
4339  int64_t caseIdx, e = op.getNumCases();
4340  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4341  if (cst == op.getCases()[caseIdx])
4342  break;
4343  }
4344 
4345  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4346  : op.getDefaultRegion();
4347  Block &source = r.front();
4348  Operation *terminator = source.getTerminator();
4349  SmallVector<Value> results = terminator->getOperands();
4350 
4351  rewriter.inlineBlockBefore(&source, op);
4352  rewriter.eraseOp(terminator);
4353  // Replace the operation with a potentially empty list of results.
4354  // Fold mechanism doesn't support the case where the result list is empty.
4355  rewriter.replaceOp(op, results);
4356 
4357  return success();
4358  }
4359 };
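// Sketch of the fold above on hypothetical IR: the switch argument is a known
// constant matching `case 2`, so that region's block is inlined before the op
// and the op's results are replaced by the region's yielded values.
//
//   %c2 = arith.constant 2 : index
//   %0 = scf.index_switch %c2 -> i32
//   case 2 {
//     %v = arith.constant 10 : i32
//     scf.yield %v : i32
//   }
//   default {
//     %d = arith.constant 0 : i32
//     scf.yield %d : i32
//   }
//
// becomes
//
//   %c2 = arith.constant 2 : index
//   %v = arith.constant 10 : i32   // uses of %0 are replaced by %v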
4360 
4361 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4362  MLIRContext *context) {
4363  results.add<FoldConstantCase>(context);
4364 }
4365 
4366 //===----------------------------------------------------------------------===//
4367 // TableGen'd op method definitions
4368 //===----------------------------------------------------------------------===//
4369 
4370 #define GET_OP_CLASSES
4371 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:736
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:728
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1286
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4196
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3444
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:431
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:92
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4211
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
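A tiny sketch of the IRMapping API above, as used when cloning regions: record a replacement and resolve values through it, falling back to the original value when nothing was mapped. Both helper names are hypothetical.

#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Record that `from` should be replaced by `to` during cloning.
static void recordMapping(IRMapping &mapping, Value from, Value to) {
  mapping.map(from, to);
}

// Remap `operand` through `mapping`; unmapped values are returned unchanged.
static Value remapOperand(const IRMapping &mapping, Value operand) {
  return mapping.lookupOrDefault(operand);
}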
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attributes'.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:571
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:582
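The following sketch (a hypothetical helper, not an MLIR API) combines several OpBuilder facilities from this listing: an InsertionGuard restores the insertion point after the builder is redirected to another block, and clone copies an operation through an IRMapping.

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"

using namespace mlir;

// Clone `op` at the start of `dest`, leaving the builder's original insertion
// point untouched once this function returns.
static Operation *cloneIntoBlock(OpBuilder &builder, Operation *op,
                                 Block *dest) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(dest);
  IRMapping mapping; // Records cloned results for later operand remapping.
  return builder.clone(*op, mapping);
}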
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator range over the underlying Values.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:412
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:656
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:632
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
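A hedged example of modifyOpInPlace from the listing above: it wraps an in-place attribute update in the start/finalize notifications the rewriter expects. The attribute name "scf.test_tag" and the helper tagOp are purely illustrative.

#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Attach a unit attribute to `op` as an in-place modification so that the
// rewriter (and any attached listeners) are notified of the change.
static void tagOp(RewriterBase &rewriter, Operation *op) {
  rewriter.modifyOpInPlace(op, [&] {
    op->setAttr("scf.test_tag", rewriter.getUnitAttr());
  });
}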
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3057
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:692
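A minimal sketch of how buildLoopNest can be used by clients of this dialect; the helper makeZeroBased2DNest and its arguments are hypothetical. It creates two nested scf.for loops from 0 to the given upper bounds with unit step.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static scf::LoopNest makeZeroBased2DNest(OpBuilder &builder, Location loc,
                                         Value ubOuter, Value ubInner) {
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  return scf::buildLoopNest(
      builder, loc, /*lbs=*/ValueRange{zero, zero},
      /*ubs=*/ValueRange{ubOuter, ubInner}, /*steps=*/ValueRange{one, one},
      [](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
        // `ivs` holds the induction variables of the outer and inner loop;
        // the scf.yield terminators are inserted automatically.
      });
}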
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1984
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:649
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:602
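A small usage sketch for getForInductionVarOwner (the helper name isForInductionVar is invented): the function returns a null ForOp when the value is not an scf.for induction variable, so it doubles as a predicate.

#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// True iff `value` is the induction variable of an enclosing scf.for.
static bool isForInductionVar(Value value) {
  if (scf::ForOp forOp = scf::getForInductionVarOwner(value)) {
    (void)forOp.getStep(); // The owning loop and its bounds are accessible here.
    return true;
  }
  return false;
}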
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1458
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:781
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:270
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
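A sketch combining matchPattern, m_ConstantInt and m_One from the entries above; isUnitStep is a hypothetical helper of the kind the loop canonicalizations in this file rely on.

#include "mlir/IR/Matchers.h"
#include "llvm/ADT/APInt.h"

using namespace mlir;

// Returns true if `step` is a constant integer equal to one.
static bool isUnitStep(Value step) {
  if (matchPattern(step, m_One()))
    return true;
  // Fall back to binding the constant value and inspecting it directly.
  APInt value;
  return matchPattern(step, m_ConstantInt(&value)) && value.isOne();
}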
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
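A sketch (the helper name and the header choice are assumptions) that combines getConstantIntValue and constantTripCount from the entries above, rejecting a step that folds to zero before asking for a static trip count.

#include "mlir/Dialect/Utils/StaticValueUtils.h"

#include <cstdint>
#include <optional>

using namespace mlir;

// Returns the static trip count of (lb, ub, step), or std::nullopt if any of
// the three is dynamic or the step folds to zero.
static std::optional<int64_t> safeTripCount(OpFoldResult lb, OpFoldResult ub,
                                            OpFoldResult step) {
  if (std::optional<int64_t> stepVal = getConstantIntValue(step))
    if (*stepVal == 0)
      return std::nullopt;
  return constantTripCount(lb, ub, step);
}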
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
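An illustrative helper (name made up) using getValueOrCreateConstantIndexOp from the listing above to materialize index Values from a mix of static and dynamic OpFoldResults.

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Turn each OpFoldResult into an index Value, creating arith.constant ops only
// for the attribute-backed entries; existing Values are reused as-is.
static SmallVector<Value> materializeIndices(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> values;
  for (OpFoldResult ofr : ofrs)
    values.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
  return values;
}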
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with all elements for which ShapedType::isDynamic is true replaced by the corresponding dynamicValues.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:419
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4331
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:368
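An illustrative OpRewritePattern, not one of the patterns defined in SCF.cpp: it erases scf.for loops that produce no results and whose body holds only the implicit terminator, showing the typical matchAndRewrite structure used throughout this file.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

namespace {
struct EraseEmptyFor : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out if the loop yields values or executes any real operation.
    if (op.getNumResults() != 0 ||
        !llvm::hasSingleElement(op.getBody()->getOperations()))
      return rewriter.notifyMatchFailure(op, "loop is not trivially empty");
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace

In practice such a pattern would be added to a RewritePatternSet via patterns.add<EraseEmptyFor>(context) and applied by a rewrite driver.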
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.