1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
85 static void buildTerminatedBody(OpBuilder &builder, Location loc) {
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-block region");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 /// block+
126 /// `}`
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// return %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
148 void ExecuteRegionOp::print(OpAsmPrinter &p) {
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
182  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
231  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
266  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
270  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
285 MutableOperandRange
286 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after "
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
295  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
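//
// Illustrative example (operand and op names are hypothetical): in
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.check"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.step"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// the scf.condition op forwards %arg either to the "after" region (when
// %cond is true) or to the results of the scf.while (when it is false).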
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange initArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
319  result.addOperands({lb, ub, step});
320  result.addOperands(initArgs);
321  for (Value v : initArgs)
322  result.addTypes(v.getType());
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : initArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (initArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
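// A minimal usage sketch (hypothetical call site; variable names are
// placeholders):
//
//   auto loop = builder.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
//         b.create<scf::YieldOp>(loc, iterArgs);
//       });
//
// When a body builder is supplied it must create the terminator itself; a
// default scf.yield is inserted only when there are no iteration arguments
// and no body builder.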
342 
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the body defines a single block argument for the induction
354  // variable.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
385  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
393  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if it can be
399 /// determined that the loop has a single iteration.
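/// Illustrative example: with %lb = 0, %ub = 1, and %step = 1,
///
///   %r = scf.for %i = %lb to %ub step %step iter_args(%a = %init) -> (f32) {
///     %v = "test.compute"(%i, %a) : (index, f32) -> f32
///     scf.yield %v : f32
///   }
///
/// has a trip count of one; the body is inlined into the parent block with
/// %i replaced by %lb, and %r is replaced by the yielded %v.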
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401  std::optional<int64_t> tripCount =
402  constantTripCount(getLowerBound(), getUpperBound(), getStep());
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
431 static void printInitializationList(OpAsmPrinter &p,
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
446 
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
479  SmallVector<OpAsmParser::Argument, 4> regionArgs;
480  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Set block argument types, so that they are known when parsing the region.
503  regionArgs.front().type = type;
504  for (auto [iterArg, type] :
505  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
506  iterArg.type = type;
507 
508  // Parse the body region.
509  Region *body = result.addRegion();
510  if (parser.parseRegion(*body, regionArgs))
511  return failure();
512  ForOp::ensureTerminator(*body, builder, result.location);
513 
514  // Resolve input operands. This should be done after parsing the region to
515  // catch invalid IR where operands were defined inside of the region.
516  if (parser.resolveOperand(lb, type, result.operands) ||
517  parser.resolveOperand(ub, type, result.operands) ||
518  parser.resolveOperand(step, type, result.operands))
519  return failure();
520  if (hasIterArgs) {
521  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
522  operands, result.types)) {
523  Type type = std::get<2>(argOperandType);
524  std::get<0>(argOperandType).type = type;
525  if (parser.resolveOperand(std::get<1>(argOperandType), type,
526  result.operands))
527  return failure();
528  }
529  }
530 
531  // Parse the optional attribute list.
532  if (parser.parseOptionalAttrDict(result.attributes))
533  return failure();
534 
535  return success();
536 }
537 
538 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
539 
540 Block::BlockArgListType ForOp::getRegionIterArgs() {
541  return getBody()->getArguments().drop_front(getNumInductionVars());
542 }
543 
544 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
545  return getInitArgsMutable();
546 }
547 
548 FailureOr<LoopLikeOpInterface>
549 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
550  ValueRange newInitOperands,
551  bool replaceInitOperandUsesInLoop,
552  const NewYieldValuesFn &newYieldValuesFn) {
553  // Create a new loop before the existing one, with the extra operands.
554  OpBuilder::InsertionGuard g(rewriter);
555  rewriter.setInsertionPoint(getOperation());
556  auto inits = llvm::to_vector(getInitArgs());
557  inits.append(newInitOperands.begin(), newInitOperands.end());
558  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
559  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
560  [](OpBuilder &, Location, Value, ValueRange) {});
561  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
562 
563  // Generate the new yield values and append them to the scf.yield operation.
564  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
565  ArrayRef<BlockArgument> newIterArgs =
566  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
567  {
568  OpBuilder::InsertionGuard g(rewriter);
569  rewriter.setInsertionPoint(yieldOp);
570  SmallVector<Value> newYieldedValues =
571  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
572  assert(newInitOperands.size() == newYieldedValues.size() &&
573  "expected as many new yield values as new iter operands");
574  rewriter.modifyOpInPlace(yieldOp, [&]() {
575  yieldOp.getResultsMutable().append(newYieldedValues);
576  });
577  }
578 
579  // Move the loop body to the new op.
580  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
581  newLoop.getBody()->getArguments().take_front(
582  getBody()->getNumArguments()));
583 
584  if (replaceInitOperandUsesInLoop) {
585  // Replace all uses of `newInitOperands` with the corresponding basic block
586  // arguments.
587  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
588  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
589  [&](OpOperand &use) {
590  Operation *user = use.getOwner();
591  return newLoop->isProperAncestor(user);
592  });
593  }
594  }
595 
596  // Replace the old loop.
597  rewriter.replaceOp(getOperation(),
598  newLoop->getResults().take_front(getNumResults()));
599  return cast<LoopLikeOpInterface>(newLoop.getOperation());
600 }
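// Illustrative sketch (values are hypothetical): given
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) { ... }
//
// calling replaceWithAdditionalYields with one extra init operand %y yields
//
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%a = %x, %b = %y)
//       -> (f32, f32) { ... }
//
// where the additional yielded value is produced by `newYieldValuesFn` and
// the original loop's results are replaced by the leading results of the
// new loop.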
601 
602 ForOp mlir::scf::getForInductionVarOwner(Value val) {
603  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
604  if (!ivArg)
605  return ForOp();
606  assert(ivArg.getOwner() && "unlinked block argument");
607  auto *containingOp = ivArg.getOwner()->getParentOp();
608  return dyn_cast_or_null<ForOp>(containingOp);
609 }
610 
611 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
612  return getInitArgs();
613 }
614 
615 void ForOp::getSuccessorRegions(RegionBranchPoint point,
616  SmallVectorImpl<RegionSuccessor> &regions) {
617  // Both the operation itself and the region may branch into the body or
618  // back into the parent operation. It is possible for the loop not to
619  // enter the body.
620  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
621  regions.push_back(RegionSuccessor(getResults()));
622 }
623 
624 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
625 
626 /// Promotes the loop body of a forallOp to its containing block if it can be
627 /// determined that the loop has a single iteration.
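/// Illustrative example: a forall whose trip counts are all one, such as
///
///   %r = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<4xf32>) {
///     %s = "test.produce"() : () -> tensor<4xf32>
///     scf.forall.in_parallel {
///       tensor.parallel_insert_slice %s into %o[0] [4] [1]
///           : tensor<4xf32> into tensor<4xf32>
///     }
///   }
///
/// is promoted: the body is inlined into the parent block and the terminator
/// is rewritten into a tensor.insert_slice of %s into %t.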
628 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
629  for (auto [lb, ub, step] :
630  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
631  auto tripCount = constantTripCount(lb, ub, step);
632  if (!tripCount.has_value() || *tripCount != 1)
633  return failure();
634  }
635 
636  promote(rewriter, *this);
637  return success();
638 }
639 
640 Block::BlockArgListType ForallOp::getRegionIterArgs() {
641  return getBody()->getArguments().drop_front(getRank());
642 }
643 
644 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
645  return getOutputsMutable();
646 }
647 
648 /// Promotes the loop body of a scf::ForallOp to its containing block.
649 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
650  OpBuilder::InsertionGuard g(rewriter);
651  scf::InParallelOp terminator = forallOp.getTerminator();
652 
653  // Replace block arguments with lower bounds (replacements for IVs) and
654  // outputs.
655  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
656  bbArgReplacements.append(forallOp.getOutputs().begin(),
657  forallOp.getOutputs().end());
658 
659  // Move the loop body operations to the loop's containing block.
660  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
661  forallOp->getIterator(), bbArgReplacements);
662 
663  // Replace the terminator with tensor.insert_slice ops.
664  rewriter.setInsertionPointAfter(forallOp);
665  SmallVector<Value> results;
666  results.reserve(forallOp.getResults().size());
667  for (auto &yieldingOp : terminator.getYieldingOps()) {
668  auto parallelInsertSliceOp =
669  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
670 
671  Value dst = parallelInsertSliceOp.getDest();
672  Value src = parallelInsertSliceOp.getSource();
673  if (llvm::isa<TensorType>(src.getType())) {
674  results.push_back(rewriter.create<tensor::InsertSliceOp>(
675  forallOp.getLoc(), dst.getType(), src, dst,
676  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
677  parallelInsertSliceOp.getStrides(),
678  parallelInsertSliceOp.getStaticOffsets(),
679  parallelInsertSliceOp.getStaticSizes(),
680  parallelInsertSliceOp.getStaticStrides()));
681  } else {
682  llvm_unreachable("unsupported terminator");
683  }
684  }
685  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
686 
687  // Erase the old terminator and the loop.
688  rewriter.eraseOp(terminator);
689  rewriter.eraseOp(forallOp);
690 }
691 
692 LoopNest mlir::scf::buildLoopNest(
693  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
694  ValueRange steps, ValueRange iterArgs,
695  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
696  bodyBuilder) {
697  assert(lbs.size() == ubs.size() &&
698  "expected the same number of lower and upper bounds");
699  assert(lbs.size() == steps.size() &&
700  "expected the same number of lower bounds and steps");
701 
702  // If there are no bounds, call the body-building function and return early.
703  if (lbs.empty()) {
704  ValueVector results =
705  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
706  : ValueVector();
707  assert(results.size() == iterArgs.size() &&
708  "loop nest body must return as many values as loop has iteration "
709  "arguments");
710  return LoopNest{{}, std::move(results)};
711  }
712 
713  // First, create the loop structure iteratively using the body-builder
714  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
715  OpBuilder::InsertionGuard guard(builder);
716  SmallVector<scf::ForOp, 4> loops;
717  SmallVector<Value, 4> ivs;
718  loops.reserve(lbs.size());
719  ivs.reserve(lbs.size());
720  ValueRange currentIterArgs = iterArgs;
721  Location currentLoc = loc;
722  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
723  auto loop = builder.create<scf::ForOp>(
724  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
725  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
726  ValueRange args) {
727  ivs.push_back(iv);
728  // It is safe to store ValueRange args because it points to block
729  // arguments of a loop operation that we also own.
730  currentIterArgs = args;
731  currentLoc = nestedLoc;
732  });
733  // Set the builder to point to the body of the newly created loop. We don't
734  // do this in the callback because the builder is reset when the callback
735  // returns.
736  builder.setInsertionPointToStart(loop.getBody());
737  loops.push_back(loop);
738  }
739 
740  // For all loops but the innermost, yield the results of the nested loop.
741  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
742  builder.setInsertionPointToEnd(loops[i].getBody());
743  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
744  }
745 
746  // In the body of the innermost loop, call the body building function if any
747  // and yield its results.
748  builder.setInsertionPointToStart(loops.back().getBody());
749  ValueVector results = bodyBuilder
750  ? bodyBuilder(builder, currentLoc, ivs,
751  loops.back().getRegionIterArgs())
752  : ValueVector();
753  assert(results.size() == iterArgs.size() &&
754  "loop nest body must return as many values as loop has iteration "
755  "arguments");
756  builder.setInsertionPointToEnd(loops.back().getBody());
757  builder.create<scf::YieldOp>(loc, results);
758 
759  // Return the loops.
760  ValueVector nestResults;
761  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
762  return LoopNest{std::move(loops), std::move(nestResults)};
763 }
764 
765 LoopNest mlir::scf::buildLoopNest(
766  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
767  ValueRange steps,
768  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
769  // Delegate to the main function by wrapping the body builder.
770  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
771  [&bodyBuilder](OpBuilder &nestedBuilder,
772  Location nestedLoc, ValueRange ivs,
773  ValueRange) -> ValueVector {
774  if (bodyBuilder)
775  bodyBuilder(nestedBuilder, nestedLoc, ivs);
776  return {};
777  });
778 }
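// A minimal usage sketch (hypothetical call site) for building a 2-D nest:
//
//   SmallVector<Value> lbs = {lb0, lb1}, ubs = {ub0, ub1}, steps = {s0, s1};
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, lbs, ubs, steps,
//       [](OpBuilder &b, Location loc, ValueRange ivs) {
//         // Emit the innermost body here using ivs[0] and ivs[1].
//       });
//
// The returned LoopNest exposes the created loops and the results of the
// outermost loop.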
779 
780 SmallVector<Value> mlir::scf::replaceAndCastForOpIterArg(
781  RewriterBase &rewriter, scf::ForOp forOp,
782  OpOperand &operand, Value replacement,
783  const ValueTypeCastFnTy &castFn) {
784  assert(operand.getOwner() == forOp);
785  Type oldType = operand.get().getType(), newType = replacement.getType();
786 
787  // 1. Create new iter operands, exactly 1 is replaced.
788  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
789  "expected an iter OpOperand");
790  assert(operand.get().getType() != replacement.getType() &&
791  "Expected a different type");
792  SmallVector<Value> newIterOperands;
793  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
794  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
795  newIterOperands.push_back(replacement);
796  continue;
797  }
798  newIterOperands.push_back(opOperand.get());
799  }
800 
801  // 2. Create the new forOp shell.
802  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
803  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
804  forOp.getStep(), newIterOperands);
805  newForOp->setAttrs(forOp->getAttrs());
806  Block &newBlock = newForOp.getRegion().front();
807  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
808  newBlock.getArguments().end());
809 
810  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
811  // corresponding to the `replacement` value.
812  OpBuilder::InsertionGuard g(rewriter);
813  rewriter.setInsertionPointToStart(&newBlock);
814  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
815  &newForOp->getOpOperand(operand.getOperandNumber()));
816  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
817  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
818 
819  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
820  Block &oldBlock = forOp.getRegion().front();
821  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
822 
823  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
824  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
825  rewriter.setInsertionPoint(clonedYieldOp);
826  unsigned yieldIdx =
827  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
828  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
829  clonedYieldOp.getOperand(yieldIdx));
830  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
831  newYieldOperands[yieldIdx] = castOut;
832  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
833  rewriter.eraseOp(clonedYieldOp);
834 
835  // 6. Inject an outgoing cast op after the forOp.
836  rewriter.setInsertionPointAfter(newForOp);
837  SmallVector<Value> newResults = newForOp.getResults();
838  newResults[yieldIdx] =
839  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
840 
841  return newResults;
842 }
843 
844 namespace {
845 // Fold away ForOp iter arguments when:
846 // 1) The op yields the iter arguments.
847 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
848 // 3) The iter arguments have no use and the corresponding (operation) results
849 // have no use.
850 //
851 // These arguments must be defined outside of the ForOp region and can just be
852 // forwarded after simplifying the op inits, yields and returns.
853 //
854 // The implementation uses `inlineBlockBefore` to steal the content of the
855 // original ForOp and avoid cloning.
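// Illustrative example of case 1): in
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     scf.yield %a : f32
//   }
//
// the iter argument is yielded unchanged, so it is folded away and all uses
// of %r are replaced by %init.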
856 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
857  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
858 
859  LogicalResult matchAndRewrite(scf::ForOp forOp,
860  PatternRewriter &rewriter) const final {
861  bool canonicalize = false;
862 
863  // An internal flat vector of block transfer
864  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
865  // transformed block arguments. This plays the role of an
866  // IRMapping for the particular use case of calling into
867  // `inlineBlockBefore`.
868  int64_t numResults = forOp.getNumResults();
869  SmallVector<bool, 4> keepMask;
870  keepMask.reserve(numResults);
871  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
872  newResultValues;
873  newBlockTransferArgs.reserve(1 + numResults);
874  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
875  newIterArgs.reserve(forOp.getInitArgs().size());
876  newYieldValues.reserve(numResults);
877  newResultValues.reserve(numResults);
878  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
879  for (auto [init, arg, result, yielded] :
880  llvm::zip(forOp.getInitArgs(), // iter from outside
881  forOp.getRegionIterArgs(), // iter inside region
882  forOp.getResults(), // op results
883  forOp.getYieldedValues() // iter yield
884  )) {
885  // Forwarded is `true` when:
886  // 1) The region `iter` argument is yielded.
887  // 2) The input corresponding to the region `iter` argument is yielded.
888  // 3) The region `iter` argument has no use, and the corresponding op
889  // result has no use.
890  bool forwarded = (arg == yielded) || (init == yielded) ||
891  (arg.use_empty() && result.use_empty());
892  if (forwarded) {
893  canonicalize = true;
894  keepMask.push_back(false);
895  newBlockTransferArgs.push_back(init);
896  newResultValues.push_back(init);
897  continue;
898  }
899 
900  // Check if a previous kept argument always has the same values for init
901  // and yielded values.
902  if (auto it = initYieldToArg.find({init, yielded});
903  it != initYieldToArg.end()) {
904  canonicalize = true;
905  keepMask.push_back(false);
906  auto [sameArg, sameResult] = it->second;
907  rewriter.replaceAllUsesWith(arg, sameArg);
908  rewriter.replaceAllUsesWith(result, sameResult);
909  // The replacement value doesn't matter because there are no uses.
910  newBlockTransferArgs.push_back(init);
911  newResultValues.push_back(init);
912  continue;
913  }
914 
915  // This value is kept.
916  initYieldToArg.insert({{init, yielded}, {arg, result}});
917  keepMask.push_back(true);
918  newIterArgs.push_back(init);
919  newYieldValues.push_back(yielded);
920  newBlockTransferArgs.push_back(Value()); // placeholder with null value
921  newResultValues.push_back(Value()); // placeholder with null value
922  }
923 
924  if (!canonicalize)
925  return failure();
926 
927  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
928  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
929  forOp.getStep(), newIterArgs);
930  newForOp->setAttrs(forOp->getAttrs());
931  Block &newBlock = newForOp.getRegion().front();
932 
933  // Replace the null placeholders with newly constructed values.
934  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
935  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
936  idx != e; ++idx) {
937  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
938  Value &newResultVal = newResultValues[idx];
939  assert((blockTransferArg && newResultVal) ||
940  (!blockTransferArg && !newResultVal));
941  if (!blockTransferArg) {
942  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
943  newResultVal = newForOp.getResult(collapsedIdx++);
944  }
945  }
946 
947  Block &oldBlock = forOp.getRegion().front();
948  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
949  "unexpected argument size mismatch");
950 
951  // No results case: the scf::ForOp builder already created a zero
952  // result terminator. Merge before this terminator and just get rid of the
953  // original terminator that has been merged in.
954  if (newIterArgs.empty()) {
955  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
956  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
957  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
958  rewriter.replaceOp(forOp, newResultValues);
959  return success();
960  }
961 
962  // No terminator case: merge and rewrite the merged terminator.
963  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
964  OpBuilder::InsertionGuard g(rewriter);
965  rewriter.setInsertionPoint(mergedTerminator);
966  SmallVector<Value, 4> filteredOperands;
967  filteredOperands.reserve(newResultValues.size());
968  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
969  if (keepMask[idx])
970  filteredOperands.push_back(mergedTerminator.getOperand(idx));
971  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
972  filteredOperands);
973  };
974 
975  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
976  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
977  cloneFilteredTerminator(mergedYieldOp);
978  rewriter.eraseOp(mergedYieldOp);
979  rewriter.replaceOp(forOp, newResultValues);
980  return success();
981  }
982 };
983 
984 /// Utility function that tries to compute a constant difference between u
985 /// and l. Returns std::nullopt when the difference cannot be determined
986 /// statically.
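/// Illustrative example: for
///   %l = arith.constant 2 : index
///   %u = arith.constant 10 : index
/// the result is 8; when %u is defined as `arith.addi %l, %c4`, the result
/// is 4; in all other cases std::nullopt is returned.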
987 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
988  IntegerAttr clb, cub;
989  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
990  llvm::APInt lbValue = clb.getValue();
991  llvm::APInt ubValue = cub.getValue();
992  return (ubValue - lbValue).getSExtValue();
993  }
994 
995  // Else a simple pattern match for x + c or c + x
996  llvm::APInt diff;
997  if (matchPattern(
998  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
999  matchPattern(
1000  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1001  return diff.getSExtValue();
1002  return std::nullopt;
1003 }
1004 
1005 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1006 /// single-iteration loops with their bodies, and removes empty loops that
1007 /// iterate at least once and only return values defined outside of the loop.
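/// Illustrative example: because the lower and upper bounds are the same
/// value,
///
///   %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%a = %init) -> (f32) {
///     ...
///   }
///
/// never iterates and is replaced by %init; a loop whose constant step is at
/// least as large as its constant bound difference is inlined as a single
/// iteration instead.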
1008 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1009  using OpRewritePattern<ForOp>::OpRewritePattern;
1010 
1011  LogicalResult matchAndRewrite(ForOp op,
1012  PatternRewriter &rewriter) const override {
1013  // If the upper bound is the same as the lower bound, the loop does not
1014  // iterate, just remove it.
1015  if (op.getLowerBound() == op.getUpperBound()) {
1016  rewriter.replaceOp(op, op.getInitArgs());
1017  return success();
1018  }
1019 
1020  std::optional<int64_t> diff =
1021  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1022  if (!diff)
1023  return failure();
1024 
1025  // If the loop is known to have 0 iterations, remove it.
1026  if (*diff <= 0) {
1027  rewriter.replaceOp(op, op.getInitArgs());
1028  return success();
1029  }
1030 
1031  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1032  if (!maybeStepValue)
1033  return failure();
1034 
1035  // If the loop is known to have 1 iteration, inline its body and remove the
1036  // loop.
1037  llvm::APInt stepValue = *maybeStepValue;
1038  if (stepValue.sge(*diff)) {
1039  SmallVector<Value, 4> blockArgs;
1040  blockArgs.reserve(op.getInitArgs().size() + 1);
1041  blockArgs.push_back(op.getLowerBound());
1042  llvm::append_range(blockArgs, op.getInitArgs());
1043  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1044  return success();
1045  }
1046 
1047  // Now we are left with loops that have more than one iteration.
1048  Block &block = op.getRegion().front();
1049  if (!llvm::hasSingleElement(block))
1050  return failure();
1051  // If the loop is empty, iterates at least once, and only returns values
1052  // defined outside of the loop, remove it and replace it with yield values.
1053  if (llvm::any_of(op.getYieldedValues(),
1054  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1055  return failure();
1056  rewriter.replaceOp(op, op.getYieldedValues());
1057  return success();
1058  }
1059 };
1060 
1061 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1062 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1063 ///
1064 /// ```
1065 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1066 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1067 /// -> (tensor<?x?xf32>) {
1068 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1069 /// scf.yield %2 : tensor<?x?xf32>
1070 /// }
1071 /// use_of(%1)
1072 /// ```
1073 ///
1074 /// folds into:
1075 ///
1076 /// ```
1077 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1078 /// -> (tensor<32x1024xf32>) {
1079 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1080 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1081 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1082 /// scf.yield %4 : tensor<32x1024xf32>
1083 /// }
1084 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1085 /// use_of(%1)
1086 /// ```
1087 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1088  using OpRewritePattern<ForOp>::OpRewritePattern;
1089 
1090  LogicalResult matchAndRewrite(ForOp op,
1091  PatternRewriter &rewriter) const override {
1092  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1093  OpOperand &iterOpOperand = std::get<0>(it);
1094  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1095  if (!incomingCast ||
1096  incomingCast.getSource().getType() == incomingCast.getType())
1097  continue;
1098  // Skip if the dest type of the cast does not preserve static information
1099  // present in the source type.
1100  if (!tensor::preservesStaticInformation(
1101  incomingCast.getDest().getType(),
1102  incomingCast.getSource().getType()))
1103  continue;
1104  if (!std::get<1>(it).hasOneUse())
1105  continue;
1106 
1107  // Create a new ForOp with that iter operand replaced.
1108  rewriter.replaceOp(
1109  op, replaceAndCastForOpIterArg(
1110  rewriter, op, iterOpOperand, incomingCast.getSource(),
1111  [](OpBuilder &b, Location loc, Type type, Value source) {
1112  return b.create<tensor::CastOp>(loc, type, source);
1113  }));
1114  return success();
1115  }
1116  return failure();
1117  }
1118 };
1119 
1120 } // namespace
1121 
1122 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123  MLIRContext *context) {
1124  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1125  context);
1126 }
1127 
1128 std::optional<APInt> ForOp::getConstantStep() {
1129  IntegerAttr step;
1130  if (matchPattern(getStep(), m_Constant(&step)))
1131  return step.getValue();
1132  return {};
1133 }
1134 
1135 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1136  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1137 }
1138 
1139 Speculation::Speculatability ForOp::getSpeculatability() {
1140  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1141  // and End.
1142  if (auto constantStep = getConstantStep())
1143  if (*constantStep == 1)
1144  return Speculation::RecursivelySpeculatable;
1145 
1146  // For Step != 1, the loop may not terminate. We can add more smarts here if
1147  // needed.
1148  return Speculation::NotSpeculatable;
1149 }
1150 
1151 //===----------------------------------------------------------------------===//
1152 // ForallOp
1153 //===----------------------------------------------------------------------===//
1154 
1155 LogicalResult ForallOp::verify() {
1156  unsigned numLoops = getRank();
1157  // Check number of outputs.
1158  if (getNumResults() != getOutputs().size())
1159  return emitOpError("produces ")
1160  << getNumResults() << " results, but has only "
1161  << getOutputs().size() << " outputs";
1162 
1163  // Check that the body defines block arguments for thread indices and outputs.
1164  auto *body = getBody();
1165  if (body->getNumArguments() != numLoops + getOutputs().size())
1166  return emitOpError("region expects ") << numLoops << " arguments";
1167  for (int64_t i = 0; i < numLoops; ++i)
1168  if (!body->getArgument(i).getType().isIndex())
1169  return emitOpError("expects ")
1170  << i << "-th block argument to be an index";
1171  for (unsigned i = 0; i < getOutputs().size(); ++i)
1172  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1173  return emitOpError("type mismatch between ")
1174  << i << "-th output and corresponding block argument";
1175  if (getMapping().has_value() && !getMapping()->empty()) {
1176  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1177  return emitOpError() << "mapping attribute size must match op rank";
1178  for (auto map : getMapping()->getValue()) {
1179  if (!isa<DeviceMappingAttrInterface>(map))
1180  return emitOpError()
1181  << getMappingAttrName() << " is not device mapping attribute";
1182  }
1183  }
1184 
1185  // Verify mixed static/dynamic control variables.
1186  Operation *op = getOperation();
1187  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1188  getStaticLowerBound(),
1189  getDynamicLowerBound())))
1190  return failure();
1191  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1192  getStaticUpperBound(),
1193  getDynamicUpperBound())))
1194  return failure();
1195  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1196  getStaticStep(), getDynamicStep())))
1197  return failure();
1198 
1199  return success();
1200 }
1201 
1202 void ForallOp::print(OpAsmPrinter &p) {
1203  Operation *op = getOperation();
1204  p << " (" << getInductionVars();
1205  if (isNormalized()) {
1206  p << ") in ";
1207  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1208  /*valueTypes=*/{}, /*scalables=*/{},
1209  OpAsmParser::Delimiter::Paren);
1210  } else {
1211  p << ") = ";
1212  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1213  /*valueTypes=*/{}, /*scalables=*/{},
1214  OpAsmParser::Delimiter::Paren);
1215  p << " to ";
1216  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1217  /*valueTypes=*/{}, /*scalables=*/{},
1218  OpAsmParser::Delimiter::Paren);
1219  p << " step ";
1220  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1221  /*valueTypes=*/{}, /*scalables=*/{},
1222  OpAsmParser::Delimiter::Paren);
1223  }
1224  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1225  p << " ";
1226  if (!getRegionOutArgs().empty())
1227  p << "-> (" << getResultTypes() << ") ";
1228  p.printRegion(getRegion(),
1229  /*printEntryBlockArgs=*/false,
1230  /*printBlockTerminators=*/getNumResults() > 0);
1231  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1232  getStaticLowerBoundAttrName(),
1233  getStaticUpperBoundAttrName(),
1234  getStaticStepAttrName()});
1235 }
1236 
1237 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1238  OpBuilder b(parser.getContext());
1239  auto indexType = b.getIndexType();
1240 
1241  // Parse an opening `(` followed by thread index variables followed by `)`
1242  // TODO: when we can refer to such "induction variable"-like handles from the
1243  // declarative assembly format, we can implement the parser as a custom hook.
1244  SmallVector<OpAsmParser::Argument, 4> ivs;
1245  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1246  return failure();
1247 
1248  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1249  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1250  dynamicSteps;
1251  if (succeeded(parser.parseOptionalKeyword("in"))) {
1252  // Parse upper bounds.
1253  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1254  /*valueTypes=*/nullptr,
1256  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1257  return failure();
1258 
1259  unsigned numLoops = ivs.size();
1260  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1261  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1262  } else {
1263  // Parse lower bounds.
1264  if (parser.parseEqual() ||
1265  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1266  /*valueTypes=*/nullptr,
1268 
1269  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1270  return failure();
1271 
1272  // Parse upper bounds.
1273  if (parser.parseKeyword("to") ||
1274  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1275  /*valueTypes=*/nullptr,
1277  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1278  return failure();
1279 
1280  // Parse step values.
1281  if (parser.parseKeyword("step") ||
1282  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1283  /*valueTypes=*/nullptr,
1285  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1286  return failure();
1287  }
1288 
1289  // Parse out operands and results.
1290  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1291  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1292  SMLoc outOperandsLoc = parser.getCurrentLocation();
1293  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1294  if (outOperands.size() != result.types.size())
1295  return parser.emitError(outOperandsLoc,
1296  "mismatch between out operands and types");
1297  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1298  parser.parseOptionalArrowTypeList(result.types) ||
1299  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1300  result.operands))
1301  return failure();
1302  }
1303 
1304  // Parse region.
1305  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1306  std::unique_ptr<Region> region = std::make_unique<Region>();
1307  for (auto &iv : ivs) {
1308  iv.type = b.getIndexType();
1309  regionArgs.push_back(iv);
1310  }
1311  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1312  auto &out = it.value();
1313  out.type = result.types[it.index()];
1314  regionArgs.push_back(out);
1315  }
1316  if (parser.parseRegion(*region, regionArgs))
1317  return failure();
1318 
1319  // Ensure terminator and move region.
1320  ForallOp::ensureTerminator(*region, b, result.location);
1321  result.addRegion(std::move(region));
1322 
1323  // Parse the optional attribute list.
1324  if (parser.parseOptionalAttrDict(result.attributes))
1325  return failure();
1326 
1327  result.addAttribute("staticLowerBound", staticLbs);
1328  result.addAttribute("staticUpperBound", staticUbs);
1329  result.addAttribute("staticStep", staticSteps);
1330  result.addAttribute("operandSegmentSizes",
1331  b.getDenseI32ArrayAttr(
1332  {static_cast<int32_t>(dynamicLbs.size()),
1333  static_cast<int32_t>(dynamicUbs.size()),
1334  static_cast<int32_t>(dynamicSteps.size()),
1335  static_cast<int32_t>(outOperands.size())}));
1336  return success();
1337 }
1338 
1339 // Builder that takes loop bounds.
1340 void ForallOp::build(
1341  OpBuilder &b, OperationState &result,
1342  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1343  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1344  std::optional<ArrayAttr> mapping,
1345  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1346  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1347  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1348  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1349  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1350  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1351 
1352  result.addOperands(dynamicLbs);
1353  result.addOperands(dynamicUbs);
1354  result.addOperands(dynamicSteps);
1355  result.addOperands(outputs);
1356  result.addTypes(TypeRange(outputs));
1357 
1358  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1359  b.getDenseI64ArrayAttr(staticLbs));
1360  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1361  b.getDenseI64ArrayAttr(staticUbs));
1362  result.addAttribute(getStaticStepAttrName(result.name),
1363  b.getDenseI64ArrayAttr(staticSteps));
1364  result.addAttribute(
1365  "operandSegmentSizes",
1366  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1367  static_cast<int32_t>(dynamicUbs.size()),
1368  static_cast<int32_t>(dynamicSteps.size()),
1369  static_cast<int32_t>(outputs.size())}));
1370  if (mapping.has_value()) {
1371  result.addAttribute(getMappingAttrName(result.name),
1372  mapping.value());
1373  }
1374 
1375  Region *bodyRegion = result.addRegion();
1376  OpBuilder::InsertionGuard g(b);
1377  b.createBlock(bodyRegion);
1378  Block &bodyBlock = bodyRegion->front();
1379 
1380  // Add block arguments for indices and outputs.
1381  bodyBlock.addArguments(
1382  SmallVector<Type>(lbs.size(), b.getIndexType()),
1383  SmallVector<Location>(staticLbs.size(), result.location));
1384  bodyBlock.addArguments(
1385  TypeRange(outputs),
1386  SmallVector<Location>(outputs.size(), result.location));
1387 
1388  b.setInsertionPointToStart(&bodyBlock);
1389  if (!bodyBuilderFn) {
1390  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1391  return;
1392  }
1393  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1394 }
1395 
1396 // Builder that takes loop bounds.
1397 void ForallOp::build(
1398  OpBuilder &b, OperationState &result,
1399  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1400  std::optional<ArrayAttr> mapping,
1401  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1402  unsigned numLoops = ubs.size();
1403  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1404  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1405  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1406 }
1407 
1408 // Checks if the lbs are zeros and steps are ones.
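// For example, `scf.forall (%i, %j) in (%d0, %d1) { ... }` is normalized
// (implicit zero lower bounds and unit steps), while a forall with an
// explicit non-zero lower bound or non-unit step is not.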
1409 bool ForallOp::isNormalized() {
1410  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1411  return llvm::all_of(results, [&](OpFoldResult ofr) {
1412  auto intValue = getConstantIntValue(ofr);
1413  return intValue.has_value() && intValue == val;
1414  });
1415  };
1416  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1417 }
1418 
1419 InParallelOp ForallOp::getTerminator() {
1420  return cast<InParallelOp>(getBody()->getTerminator());
1421 }
1422 
1423 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1424  SmallVector<Operation *> storeOps;
1425  InParallelOp inParallelOp = getTerminator();
1426  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1427  if (auto parallelInsertSliceOp =
1428  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1429  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1430  storeOps.push_back(parallelInsertSliceOp);
1431  }
1432  }
1433  return storeOps;
1434 }
1435 
1436 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1437  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1438 }
1439 
1440 // Get lower bounds as OpFoldResult.
1441 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1442  Builder b(getOperation()->getContext());
1443  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1444 }
1445 
1446 // Get upper bounds as OpFoldResult.
1447 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1448  Builder b(getOperation()->getContext());
1449  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1450 }
1451 
1452 // Get steps as OpFoldResult.
1453 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1454  Builder b(getOperation()->getContext());
1455  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1456 }
1457 
1458 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1459  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1460  if (!tidxArg)
1461  return ForallOp();
1462  assert(tidxArg.getOwner() && "unlinked block argument");
1463  auto *containingOp = tidxArg.getOwner()->getParentOp();
1464  return dyn_cast<ForallOp>(containingOp);
1465 }
1466 
1467 namespace {
1468 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
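/// Illustrative example (types are hypothetical):
///
///   %r = scf.forall ... shared_outs(%o = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %r, %c0 : tensor<?xf32>
///
/// is rewritten so that %d is computed directly as tensor.dim %t, %c0.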
1469 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1470  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1471 
1472  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1473  PatternRewriter &rewriter) const final {
1474  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1475  if (!forallOp)
1476  return failure();
1477  Value sharedOut =
1478  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1479  ->get();
1480  rewriter.modifyOpInPlace(
1481  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1482  return success();
1483  }
1484 };
1485 
1486 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1487 public:
1488  using OpRewritePattern<ForallOp>::OpRewritePattern;
1489 
1490  LogicalResult matchAndRewrite(ForallOp op,
1491  PatternRewriter &rewriter) const override {
1492  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1493  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1494  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1495  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1496  failed(foldDynamicIndexList(mixedUpperBound)) &&
1497  failed(foldDynamicIndexList(mixedStep)))
1498  return failure();
1499 
1500  rewriter.modifyOpInPlace(op, [&]() {
1501  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1502  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1503  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1504  staticLowerBound);
1505  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1506  op.setStaticLowerBound(staticLowerBound);
1507 
1508  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1509  staticUpperBound);
1510  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1511  op.setStaticUpperBound(staticUpperBound);
1512 
1513  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1514  op.getDynamicStepMutable().assign(dynamicStep);
1515  op.setStaticStep(staticStep);
1516 
1517  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1518  rewriter.getDenseI32ArrayAttr(
1519  {static_cast<int32_t>(dynamicLowerBound.size()),
1520  static_cast<int32_t>(dynamicUpperBound.size()),
1521  static_cast<int32_t>(dynamicStep.size()),
1522  static_cast<int32_t>(op.getNumResults())}));
1523  });
1524  return success();
1525  }
1526 };
1527 
1528 /// The following canonicalization pattern folds the iter arguments of an
1529 /// scf.forall op if :-
1530 /// 1. The corresponding result has zero uses.
1531 /// 2. The iter argument is NOT being modified within the loop body.
1533 ///
1534 /// Example of first case :-
1535 /// INPUT:
1536 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1537 /// {
1538 /// ...
1539 /// <SOME USE OF %arg0>
1540 /// <SOME USE OF %arg1>
1541 /// <SOME USE OF %arg2>
1542 /// ...
1543 /// scf.forall.in_parallel {
1544 /// <STORE OP WITH DESTINATION %arg1>
1545 /// <STORE OP WITH DESTINATION %arg0>
1546 /// <STORE OP WITH DESTINATION %arg2>
1547 /// }
1548 /// }
1549 /// return %res#1
1550 ///
1551 /// OUTPUT:
1552 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1553 /// {
1554 /// ...
1555 /// <SOME USE OF %a>
1556 /// <SOME USE OF %new_arg0>
1557 /// <SOME USE OF %c>
1558 /// ...
1559 /// scf.forall.in_parallel {
1560 /// <STORE OP WITH DESTINATION %new_arg0>
1561 /// }
1562 /// }
1563 /// return %res
1564 ///
1565 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1566 /// scf.forall are replaced by their corresponding operands.
1567 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1568 /// of the scf.forall outside the scf.forall.in_parallel terminator,
1569 /// this canonicalization remains valid. For more details, please refer
1570 /// to :
1571 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1572 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1573 /// handles tensor.parallel_insert_slice ops only.
1574 ///
1575 /// Example of second case :-
1576 /// INPUT:
1577 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1578 /// {
1579 /// ...
1580 /// <SOME USE OF %arg0>
1581 /// <SOME USE OF %arg1>
1582 /// ...
1583 /// scf.forall.in_parallel {
1584 /// <STORE OP WITH DESTINATION %arg1>
1585 /// }
1586 /// }
1587 /// return %res#0, %res#1
1588 ///
1589 /// OUTPUT:
1590 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1591 /// {
1592 /// ...
1593 /// <SOME USE OF %a>
1594 /// <SOME USE OF %new_arg0>
1595 /// ...
1596 /// scf.forall.in_parallel {
1597 /// <STORE OP WITH DESTINATION %new_arg0>
1598 /// }
1599 /// }
1600 /// return %a, %res
1601 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1602  using OpRewritePattern<ForallOp>::OpRewritePattern;
1603 
1604  LogicalResult matchAndRewrite(ForallOp forallOp,
1605  PatternRewriter &rewriter) const final {
1606  // Step 1: For a given i-th result of scf.forall, check the following :-
1607  // a. If it has any use.
1608  // b. If the corresponding iter argument is being modified within
1609  // the loop, i.e. has at least one store op with the iter arg as
1610  // its destination operand. For this we use
1611  // ForallOp::getCombiningOps(iter_arg).
1612  //
1613  // Based on the check we maintain the following :-
1614  // a. `resultToDelete` - i-th result of scf.forall that'll be
1615  // deleted.
1616  // b. `resultToReplace` - i-th result of the old scf.forall
1617  // whose uses will be replaced by the new scf.forall.
1618  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1619  // corresponding to the i-th result with at least one use.
1620  SetVector<OpResult> resultToDelete;
1621  SmallVector<Value> resultToReplace;
1622  SmallVector<Value> newOuts;
1623  for (OpResult result : forallOp.getResults()) {
1624  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1625  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1626  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1627  resultToDelete.insert(result);
1628  } else {
1629  resultToReplace.push_back(result);
1630  newOuts.push_back(opOperand->get());
1631  }
1632  }
1633 
1634  // Return early if all results of scf.forall have at least one use and are
1635  // being modified within the loop.
1636  if (resultToDelete.empty())
1637  return failure();
1638 
1639  // Step 2: For the i-th result, do the following :-
1640  // a. Fetch the corresponding BlockArgument.
1641  // b. Look for store ops (currently tensor.parallel_insert_slice)
1642  // with the BlockArgument as its destination operand.
1643  // c. Remove the operations fetched in b.
1644  for (OpResult result : resultToDelete) {
1645  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1646  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1647  SmallVector<Operation *> combiningOps =
1648  forallOp.getCombiningOps(blockArg);
1649  for (Operation *combiningOp : combiningOps)
1650  rewriter.eraseOp(combiningOp);
1651  }
1652 
1653  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1654  // fetched earlier
1655  auto newForallOp = rewriter.create<scf::ForallOp>(
1656  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1657  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1658  forallOp.getMapping(),
1659  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1660 
1661  // Step 4. Merge the block of the old scf.forall into the newly created
1662  // scf.forall using the new set of arguments.
1663  Block *loopBody = forallOp.getBody();
1664  Block *newLoopBody = newForallOp.getBody();
1665  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1666  // Form initial new bbArg list with just the control operands of the new
1667  // scf.forall op.
1668  SmallVector<Value> newBlockArgs =
1669  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1670  [](BlockArgument b) -> Value { return b; });
1671  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1672  unsigned index = 0;
1673  // Take the new corresponding bbArg if the old bbArg was used as a
1674  // destination in the in_parallel op. For all other bbArgs, use the
1675  // corresponding init_arg from the old scf.forall op.
1676  for (OpResult result : forallOp.getResults()) {
1677  if (resultToDelete.count(result)) {
1678  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1679  } else {
1680  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1681  }
1682  }
1683  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1684 
1685  // Step 5. Replace the uses of the results of the old scf.forall with those
1686  // of the new scf.forall.
1687  for (auto &&[oldResult, newResult] :
1688  llvm::zip(resultToReplace, newForallOp->getResults()))
1689  rewriter.replaceAllUsesWith(oldResult, newResult);
1690 
1691  // Step 6. Replace the uses of those values that either have no use or are
1692  // not being modified within the loop with the corresponding
1693  // OpOperand.
1694  for (OpResult oldResult : resultToDelete)
1695  rewriter.replaceAllUsesWith(oldResult,
1696  forallOp.getTiedOpOperand(oldResult)->get());
1697  return success();
1698  }
1699 };
1700 
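/// Fold away scf.forall dimensions that perform at most one iteration: a loop
/// with a zero-iteration dimension is replaced by its shared_outs, and
/// single-iteration dimensions are dropped with uses of their induction
/// variables replaced by the corresponding lower bound. A schematic example
/// (bounds are illustrative):
///
///   scf.forall (%i, %j) in (1, 128) shared_outs(%o = %t) { ... use %i ... }
/// becomes
///   scf.forall (%j) in (128) shared_outs(%o = %t) { ... use %c0 ... }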
1701 struct ForallOpSingleOrZeroIterationDimsFolder
1702  : public OpRewritePattern<ForallOp> {
1703  using OpRewritePattern<ForallOp>::OpRewritePattern;
1704 
1705  LogicalResult matchAndRewrite(ForallOp op,
1706  PatternRewriter &rewriter) const override {
1707  // Do not fold dimensions if they are mapped to processing units.
1708  if (op.getMapping().has_value() && !op.getMapping()->empty())
1709  return failure();
1710  Location loc = op.getLoc();
1711 
1712  // Compute new loop bounds that omit all single-iteration loop dimensions.
1713  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1714  newMixedSteps;
1715  IRMapping mapping;
1716  for (auto [lb, ub, step, iv] :
1717  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1718  op.getMixedStep(), op.getInductionVars())) {
1719  auto numIterations = constantTripCount(lb, ub, step);
1720  if (numIterations.has_value()) {
1721  // Remove the loop if it performs zero iterations.
1722  if (*numIterations == 0) {
1723  rewriter.replaceOp(op, op.getOutputs());
1724  return success();
1725  }
1726  // Replace the loop induction variable by the lower bound if the loop
1727  // performs a single iteration. Otherwise, copy the loop bounds.
1728  if (*numIterations == 1) {
1729  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1730  continue;
1731  }
1732  }
1733  newMixedLowerBounds.push_back(lb);
1734  newMixedUpperBounds.push_back(ub);
1735  newMixedSteps.push_back(step);
1736  }
1737 
1738  // All of the loop dimensions perform a single iteration. Inline loop body.
1739  if (newMixedLowerBounds.empty()) {
1740  promote(rewriter, op);
1741  return success();
1742  }
1743 
1744  // Exit if none of the loop dimensions perform a single iteration.
1745  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1746  return rewriter.notifyMatchFailure(
1747  op, "no dimensions have 0 or 1 iterations");
1748  }
1749 
1750  // Replace the loop by a lower-dimensional loop.
1751  ForallOp newOp;
1752  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1753  newMixedUpperBounds, newMixedSteps,
1754  op.getOutputs(), std::nullopt, nullptr);
1755  newOp.getBodyRegion().getBlocks().clear();
1756  // The new loop needs to keep all attributes from the old one, except for
1757  // "operandSegmentSizes" and static loop bound attributes which capture
1758  // the outdated information of the old iteration domain.
1759  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1760  newOp.getStaticLowerBoundAttrName(),
1761  newOp.getStaticUpperBoundAttrName(),
1762  newOp.getStaticStepAttrName()};
1763  for (const auto &namedAttr : op->getAttrs()) {
1764  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1765  continue;
1766  rewriter.modifyOpInPlace(newOp, [&]() {
1767  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1768  });
1769  }
1770  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1771  newOp.getRegion().begin(), mapping);
1772  rewriter.replaceOp(op, newOp.getResults());
1773  return success();
1774  }
1775 };
1776 
1777 /// Replace induction variables of single-trip-count dimensions by their lower bound.
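/// For example (schematic; the loop itself is left unchanged, only uses of
/// the induction variable are rewritten):
///
///   scf.forall (%i, %j) in (1, %n) { ... use %i ... }
/// becomes
///   scf.forall (%i, %j) in (1, %n) { ... use %c0 ... }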
1778 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1779  using OpRewritePattern<ForallOp>::OpRewritePattern;
1780 
1781  LogicalResult matchAndRewrite(ForallOp op,
1782  PatternRewriter &rewriter) const override {
1783  Location loc = op.getLoc();
1784  bool changed = false;
1785  for (auto [lb, ub, step, iv] :
1786  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1787  op.getMixedStep(), op.getInductionVars())) {
1788  if (iv.getUses().begin() == iv.getUses().end())
1789  continue;
1790  auto numIterations = constantTripCount(lb, ub, step);
1791  if (!numIterations.has_value() || numIterations.value() != 1) {
1792  continue;
1793  }
1794  rewriter.replaceAllUsesWith(
1795  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1796  changed = true;
1797  }
1798  return success(changed);
1799  }
1800 };
1801 
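/// Fold a tensor.cast of a shared_outs operand into the scf.forall when the
/// cast only erases static information, so that the loop operates on the more
/// static type and a cast back to the original type is emitted afterwards.
/// A schematic example (types are illustrative):
///
///   %cast = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
/// becomes
///   %r0 = scf.forall ... shared_outs(%o = %t) -> (tensor<4xf32>) { ... }
///   %r = tensor.cast %r0 : tensor<4xf32> to tensor<?xf32>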
1802 struct FoldTensorCastOfOutputIntoForallOp
1803  : public OpRewritePattern<scf::ForallOp> {
1804  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1805 
1806  struct TypeCast {
1807  Type srcType;
1808  Type dstType;
1809  };
1810 
1811  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1812  PatternRewriter &rewriter) const final {
1813  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1814  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1815  for (auto en : llvm::enumerate(newOutputTensors)) {
1816  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1817  if (!castOp)
1818  continue;
1819 
1820  // Only casts that preserve static information, i.e. that make the
1821  // loop result type "more" static than before, will be folded.
1822  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1823  castOp.getSource().getType())) {
1824  continue;
1825  }
1826 
1827  tensorCastProducers[en.index()] =
1828  TypeCast{castOp.getSource().getType(), castOp.getType()};
1829  newOutputTensors[en.index()] = castOp.getSource();
1830  }
1831 
1832  if (tensorCastProducers.empty())
1833  return failure();
1834 
1835  // Create new loop.
1836  Location loc = forallOp.getLoc();
1837  auto newForallOp = rewriter.create<ForallOp>(
1838  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1839  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1840  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1841  auto castBlockArgs =
1842  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1843  for (auto [index, cast] : tensorCastProducers) {
1844  Value &oldTypeBBArg = castBlockArgs[index];
1845  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1846  nestedLoc, cast.dstType, oldTypeBBArg);
1847  }
1848 
1849  // Move old body into new parallel loop.
1850  SmallVector<Value> ivsBlockArgs =
1851  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1852  ivsBlockArgs.append(castBlockArgs);
1853  rewriter.mergeBlocks(forallOp.getBody(),
1854  bbArgs.front().getParentBlock(), ivsBlockArgs);
1855  });
1856 
1857  // After `mergeBlocks`, the destinations in the terminator were mapped to
1858  // the old-typed tensor.cast results of the output bbArgs. The destinations
1859  // have to be updated to point to the output bbArgs directly.
1860  auto terminator = newForallOp.getTerminator();
1861  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1862  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1863  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1864  insertSliceOp.getDestMutable().assign(outputBlockArg);
1865  }
1866 
1867  // Cast results back to the original types.
1868  rewriter.setInsertionPointAfter(newForallOp);
1869  SmallVector<Value> castResults = newForallOp.getResults();
1870  for (auto &item : tensorCastProducers) {
1871  Value &oldTypeResult = castResults[item.first];
1872  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1873  oldTypeResult);
1874  }
1875  rewriter.replaceOp(forallOp, castResults);
1876  return success();
1877  }
1878 };
1879 
1880 } // namespace
1881 
1882 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1883  MLIRContext *context) {
1884  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1885  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1886  ForallOpSingleOrZeroIterationDimsFolder,
1887  ForallOpReplaceConstantInductionVar>(context);
1888 }
1889 
1890 /// Given the region at `index`, or the parent operation if `index` is None,
1891 /// return the successor regions. These are the regions that may be selected
1892 /// during the flow of control. `operands` is a set of optional attributes that
1893 /// correspond to a constant value for each operand, or null if that operand is
1894 /// not a constant.
1895 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1896  SmallVectorImpl<RegionSuccessor> &regions) {
1897  // Both the operation itself and the region may be branching into the body
1898  // or back into the operation itself. It is possible for the loop not to
1899  // enter the body.
1900  regions.push_back(RegionSuccessor(&getRegion()));
1901  regions.push_back(RegionSuccessor());
1902 }
1903 
1904 //===----------------------------------------------------------------------===//
1905 // InParallelOp
1906 //===----------------------------------------------------------------------===//
1907 
1908 // Build an InParallelOp with an empty body block.
1909 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1910  OpBuilder::InsertionGuard g(b);
1911  Region *bodyRegion = result.addRegion();
1912  b.createBlock(bodyRegion);
1913 }
1914 
1915 LogicalResult InParallelOp::verify() {
1916  scf::ForallOp forallOp =
1917  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1918  if (!forallOp)
1919  return this->emitOpError("expected forall op parent");
1920 
1921  // TODO: InParallelOpInterface.
1922  for (Operation &op : getRegion().front().getOperations()) {
1923  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1924  return this->emitOpError("expected only ")
1925  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1926  }
1927 
1928  // Verify that inserts are into out block arguments.
1929  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1930  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1931  if (!llvm::is_contained(regionOutArgs, dest))
1932  return op.emitOpError("may only insert into an output block argument");
1933  }
1934  return success();
1935 }
1936 
1937 void InParallelOp::print(OpAsmPrinter &p) {
1938  p << " ";
1939  p.printRegion(getRegion(),
1940  /*printEntryBlockArgs=*/false,
1941  /*printBlockTerminators=*/false);
1942  p.printOptionalAttrDict(getOperation()->getAttrs());
1943 }
1944 
1945 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1946  auto &builder = parser.getBuilder();
1947 
1948  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1949  std::unique_ptr<Region> region = std::make_unique<Region>();
1950  if (parser.parseRegion(*region, regionOperands))
1951  return failure();
1952 
1953  if (region->empty())
1954  OpBuilder(builder.getContext()).createBlock(region.get());
1955  result.addRegion(std::move(region));
1956 
1957  // Parse the optional attribute list.
1958  if (parser.parseOptionalAttrDict(result.attributes))
1959  return failure();
1960  return success();
1961 }
1962 
1963 OpResult InParallelOp::getParentResult(int64_t idx) {
1964  return getOperation()->getParentOp()->getResult(idx);
1965 }
1966 
1967 SmallVector<BlockArgument> InParallelOp::getDests() {
1968  return llvm::to_vector<4>(
1969  llvm::map_range(getYieldingOps(), [](Operation &op) {
1970  // Add new ops here as needed.
1971  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1972  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1973  }));
1974 }
1975 
1976 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1977  return getRegion().front().getOperations();
1978 }
1979 
1980 //===----------------------------------------------------------------------===//
1981 // IfOp
1982 //===----------------------------------------------------------------------===//
1983 
1984 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1985  assert(a && "expected non-empty operation");
1986  assert(b && "expected non-empty operation");
1987 
1988  IfOp ifOp = a->getParentOfType<IfOp>();
1989  while (ifOp) {
1990  // Check if b is inside ifOp. (We already know that a is.)
1991  if (ifOp->isProperAncestor(b))
1992  // b is contained in ifOp. a and b are in mutually exclusive branches if
1993  // they are in different blocks of ifOp.
1994  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1995  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1996  // Check next enclosing IfOp.
1997  ifOp = ifOp->getParentOfType<IfOp>();
1998  }
1999 
2000  // Could not find a common IfOp among a's and b's ancestors.
2001  return false;
2002 }
2003 
2004 LogicalResult
2005 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2006  IfOp::Adaptor adaptor,
2007  SmallVectorImpl<Type> &inferredReturnTypes) {
2008  if (adaptor.getRegions().empty())
2009  return failure();
2010  Region *r = &adaptor.getThenRegion();
2011  if (r->empty())
2012  return failure();
2013  Block &b = r->front();
2014  if (b.empty())
2015  return failure();
2016  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2017  if (!yieldOp)
2018  return failure();
2019  TypeRange types = yieldOp.getOperandTypes();
2020  llvm::append_range(inferredReturnTypes, types);
2021  return success();
2022 }
2023 
2024 void IfOp::build(OpBuilder &builder, OperationState &result,
2025  TypeRange resultTypes, Value cond) {
2026  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2027  /*addElseBlock=*/false);
2028 }
2029 
2030 void IfOp::build(OpBuilder &builder, OperationState &result,
2031  TypeRange resultTypes, Value cond, bool addThenBlock,
2032  bool addElseBlock) {
2033  assert((!addElseBlock || addThenBlock) &&
2034  "must not create else block w/o then block");
2035  result.addTypes(resultTypes);
2036  result.addOperands(cond);
2037 
2038  // Add regions and blocks.
2039  OpBuilder::InsertionGuard guard(builder);
2040  Region *thenRegion = result.addRegion();
2041  if (addThenBlock)
2042  builder.createBlock(thenRegion);
2043  Region *elseRegion = result.addRegion();
2044  if (addElseBlock)
2045  builder.createBlock(elseRegion);
2046 }
2047 
2048 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2049  bool withElseRegion) {
2050  build(builder, result, TypeRange{}, cond, withElseRegion);
2051 }
2052 
2053 void IfOp::build(OpBuilder &builder, OperationState &result,
2054  TypeRange resultTypes, Value cond, bool withElseRegion) {
2055  result.addTypes(resultTypes);
2056  result.addOperands(cond);
2057 
2058  // Build then region.
2059  OpBuilder::InsertionGuard guard(builder);
2060  Region *thenRegion = result.addRegion();
2061  builder.createBlock(thenRegion);
2062  if (resultTypes.empty())
2063  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2064 
2065  // Build else region.
2066  Region *elseRegion = result.addRegion();
2067  if (withElseRegion) {
2068  builder.createBlock(elseRegion);
2069  if (resultTypes.empty())
2070  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2071  }
2072 }
2073 
2074 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2075  function_ref<void(OpBuilder &, Location)> thenBuilder,
2076  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2077  assert(thenBuilder && "the builder callback for 'then' must be present");
2078  result.addOperands(cond);
2079 
2080  // Build then region.
2081  OpBuilder::InsertionGuard guard(builder);
2082  Region *thenRegion = result.addRegion();
2083  builder.createBlock(thenRegion);
2084  thenBuilder(builder, result.location);
2085 
2086  // Build else region.
2087  Region *elseRegion = result.addRegion();
2088  if (elseBuilder) {
2089  builder.createBlock(elseRegion);
2090  elseBuilder(builder, result.location);
2091  }
2092 
2093  // Infer result types.
2094  SmallVector<Type> inferredReturnTypes;
2095  MLIRContext *ctx = builder.getContext();
2096  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2097  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2098  /*properties=*/nullptr, result.regions,
2099  inferredReturnTypes))) {
2100  result.addTypes(inferredReturnTypes);
2101  }
2102 }
2103 
2104 LogicalResult IfOp::verify() {
2105  if (getNumResults() != 0 && getElseRegion().empty())
2106  return emitOpError("must have an else block if defining values");
2107  return success();
2108 }
2109 
2110 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2111  // Create the regions for 'then'.
2112  result.regions.reserve(2);
2113  Region *thenRegion = result.addRegion();
2114  Region *elseRegion = result.addRegion();
2115 
2116  auto &builder = parser.getBuilder();
2117  OpAsmParser::UnresolvedOperand cond;
2118  Type i1Type = builder.getIntegerType(1);
2119  if (parser.parseOperand(cond) ||
2120  parser.resolveOperand(cond, i1Type, result.operands))
2121  return failure();
2122  // Parse optional results type list.
2123  if (parser.parseOptionalArrowTypeList(result.types))
2124  return failure();
2125  // Parse the 'then' region.
2126  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2127  return failure();
2128  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2129 
2130  // If we find an 'else' keyword then parse the 'else' region.
2131  if (!parser.parseOptionalKeyword("else")) {
2132  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2133  return failure();
2134  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2135  }
2136 
2137  // Parse the optional attribute list.
2138  if (parser.parseOptionalAttrDict(result.attributes))
2139  return failure();
2140  return success();
2141 }
2142 
2143 void IfOp::print(OpAsmPrinter &p) {
2144  bool printBlockTerminators = false;
2145 
2146  p << " " << getCondition();
2147  if (!getResults().empty()) {
2148  p << " -> (" << getResultTypes() << ")";
2149  // Print yield explicitly if the op defines values.
2150  printBlockTerminators = true;
2151  }
2152  p << ' ';
2153  p.printRegion(getThenRegion(),
2154  /*printEntryBlockArgs=*/false,
2155  /*printBlockTerminators=*/printBlockTerminators);
2156 
2157  // Print the 'else' region if it exists and has a block.
2158  auto &elseRegion = getElseRegion();
2159  if (!elseRegion.empty()) {
2160  p << " else ";
2161  p.printRegion(elseRegion,
2162  /*printEntryBlockArgs=*/false,
2163  /*printBlockTerminators=*/printBlockTerminators);
2164  }
2165 
2166  p.printOptionalAttrDict((*this)->getAttrs());
2167 }
2168 
2169 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2170  SmallVectorImpl<RegionSuccessor> &regions) {
2171  // The `then` and the `else` region branch back to the parent operation.
2172  if (!point.isParent()) {
2173  regions.push_back(RegionSuccessor(getResults()));
2174  return;
2175  }
2176 
2177  regions.push_back(RegionSuccessor(&getThenRegion()));
2178 
2179  // Don't consider the else region if it is empty.
2180  Region *elseRegion = &this->getElseRegion();
2181  if (elseRegion->empty())
2182  regions.push_back(RegionSuccessor());
2183  else
2184  regions.push_back(RegionSuccessor(elseRegion));
2185 }
2186 
2187 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2188  SmallVectorImpl<RegionSuccessor> &regions) {
2189  FoldAdaptor adaptor(operands, *this);
2190  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2191  if (!boolAttr || boolAttr.getValue())
2192  regions.emplace_back(&getThenRegion());
2193 
2194  // If the else region is empty, execution continues after the parent op.
2195  if (!boolAttr || !boolAttr.getValue()) {
2196  if (!getElseRegion().empty())
2197  regions.emplace_back(&getElseRegion());
2198  else
2199  regions.emplace_back(getResults());
2200  }
2201 }
2202 
2203 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2204  SmallVectorImpl<OpFoldResult> &results) {
2205  // if (!c) then A() else B() -> if c then B() else A()
2206  if (getElseRegion().empty())
2207  return failure();
2208 
2209  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2210  if (!xorStmt)
2211  return failure();
2212 
2213  if (!matchPattern(xorStmt.getRhs(), m_One()))
2214  return failure();
2215 
2216  getConditionMutable().assign(xorStmt.getLhs());
2217  Block *thenBlock = &getThenRegion().front();
2218  // It would be nicer to use iplist::swap, but that has no implemented
2219  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2220  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2221  getElseRegion().getBlocks());
2222  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2223  getThenRegion().getBlocks(), thenBlock);
2224  return success();
2225 }
2226 
2227 void IfOp::getRegionInvocationBounds(
2228  ArrayRef<Attribute> operands,
2229  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2230  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2231  // If the condition is known, then one region is known to be executed once
2232  // and the other zero times.
2233  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2234  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2235  } else {
2236  // Non-constant condition. Each region may be executed 0 or 1 times.
2237  invocationBounds.assign(2, {0, 1});
2238  }
2239 }
2240 
2241 namespace {
2242 // Pattern to remove unused IfOp results.
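// A schematic example (values and types are illustrative): with only %r#0
// used below, the unused second result and its yielded operands are dropped.
//
//   %r:2 = scf.if %c -> (i32, i32) { scf.yield %a, %b } else { scf.yield %x, %y }
// becomes
//   %r = scf.if %c -> (i32) { scf.yield %a } else { scf.yield %x }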
2243 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2244  using OpRewritePattern<IfOp>::OpRewritePattern;
2245 
2246  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2247  PatternRewriter &rewriter) const {
2248  // Move all operations to the destination block.
2249  rewriter.mergeBlocks(source, dest);
2250  // Replace the yield op by one that returns only the used values.
2251  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2252  SmallVector<Value, 4> usedOperands;
2253  llvm::transform(usedResults, std::back_inserter(usedOperands),
2254  [&](OpResult result) {
2255  return yieldOp.getOperand(result.getResultNumber());
2256  });
2257  rewriter.modifyOpInPlace(yieldOp,
2258  [&]() { yieldOp->setOperands(usedOperands); });
2259  }
2260 
2261  LogicalResult matchAndRewrite(IfOp op,
2262  PatternRewriter &rewriter) const override {
2263  // Compute the list of used results.
2264  SmallVector<OpResult, 4> usedResults;
2265  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2266  [](OpResult result) { return !result.use_empty(); });
2267 
2268  // Replace the operation if only a subset of its results have uses.
2269  if (usedResults.size() == op.getNumResults())
2270  return failure();
2271 
2272  // Compute the result types of the replacement operation.
2273  SmallVector<Type, 4> newTypes;
2274  llvm::transform(usedResults, std::back_inserter(newTypes),
2275  [](OpResult result) { return result.getType(); });
2276 
2277  // Create a replacement operation with empty then and else regions.
2278  auto newOp =
2279  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2280  rewriter.createBlock(&newOp.getThenRegion());
2281  rewriter.createBlock(&newOp.getElseRegion());
2282 
2283  // Move the bodies and replace the terminators (note there is a then and
2284  // an else region since the operation returns results).
2285  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2286  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2287 
2288  // Replace the operation by the new one.
2289  SmallVector<Value, 4> repResults(op.getNumResults());
2290  for (const auto &en : llvm::enumerate(usedResults))
2291  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2292  rewriter.replaceOp(op, repResults);
2293  return success();
2294  }
2295 };
2296 
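// Inline the branch selected by a constant condition. Schematically,
// `scf.if %true { A() } else { B() }` is replaced by `A()`; if the condition
// is false and there is no else block, the op is simply erased.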
2297 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2298  using OpRewritePattern<IfOp>::OpRewritePattern;
2299 
2300  LogicalResult matchAndRewrite(IfOp op,
2301  PatternRewriter &rewriter) const override {
2302  BoolAttr condition;
2303  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2304  return failure();
2305 
2306  if (condition.getValue())
2307  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2308  else if (!op.getElseRegion().empty())
2309  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2310  else
2311  rewriter.eraseOp(op);
2312 
2313  return success();
2314  }
2315 };
2316 
2317 /// Hoist any yielded results whose operands are defined outside
2318 /// the if, to a select instruction.
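/// A schematic example (operands are illustrative and defined above the if):
///
///   %r = scf.if %c -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// becomes
///   %r = arith.select %c, %a, %b : i32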
2319 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2320  using OpRewritePattern<IfOp>::OpRewritePattern;
2321 
2322  LogicalResult matchAndRewrite(IfOp op,
2323  PatternRewriter &rewriter) const override {
2324  if (op->getNumResults() == 0)
2325  return failure();
2326 
2327  auto cond = op.getCondition();
2328  auto thenYieldArgs = op.thenYield().getOperands();
2329  auto elseYieldArgs = op.elseYield().getOperands();
2330 
2331  SmallVector<Type> nonHoistable;
2332  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2333  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2334  &op.getElseRegion() == falseVal.getParentRegion())
2335  nonHoistable.push_back(trueVal.getType());
2336  }
2337  // Early exit if there aren't any yielded values we can
2338  // hoist outside the if.
2339  if (nonHoistable.size() == op->getNumResults())
2340  return failure();
2341 
2342  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2343  /*withElseRegion=*/false);
2344  if (replacement.thenBlock())
2345  rewriter.eraseBlock(replacement.thenBlock());
2346  replacement.getThenRegion().takeBody(op.getThenRegion());
2347  replacement.getElseRegion().takeBody(op.getElseRegion());
2348 
2349  SmallVector<Value> results(op->getNumResults());
2350  assert(thenYieldArgs.size() == results.size());
2351  assert(elseYieldArgs.size() == results.size());
2352 
2353  SmallVector<Value> trueYields;
2354  SmallVector<Value> falseYields;
2355  rewriter.setInsertionPoint(replacement);
2356  for (const auto &it :
2357  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2358  Value trueVal = std::get<0>(it.value());
2359  Value falseVal = std::get<1>(it.value());
2360  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2361  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2362  results[it.index()] = replacement.getResult(trueYields.size());
2363  trueYields.push_back(trueVal);
2364  falseYields.push_back(falseVal);
2365  } else if (trueVal == falseVal)
2366  results[it.index()] = trueVal;
2367  else
2368  results[it.index()] = rewriter.create<arith::SelectOp>(
2369  op.getLoc(), cond, trueVal, falseVal);
2370  }
2371 
2372  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2373  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2374 
2375  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2376  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2377 
2378  rewriter.replaceOp(op, results);
2379  return success();
2380  }
2381 };
2382 
2383 /// Allow the true region of an if to assume the condition is true
2384 /// and vice versa. For example:
2385 ///
2386 /// scf.if %cmp {
2387 /// print(%cmp)
2388 /// }
2389 ///
2390 /// becomes
2391 ///
2392 /// scf.if %cmp {
2393 /// print(true)
2394 /// }
2395 ///
2396 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2397  using OpRewritePattern<IfOp>::OpRewritePattern;
2398 
2399  LogicalResult matchAndRewrite(IfOp op,
2400  PatternRewriter &rewriter) const override {
2401  // Early exit if the condition is constant since replacing a constant
2402  // in the body with another constant isn't a simplification.
2403  if (matchPattern(op.getCondition(), m_Constant()))
2404  return failure();
2405 
2406  bool changed = false;
2407  mlir::Type i1Ty = rewriter.getI1Type();
2408 
2409  // These variables serve to prevent creating duplicate constants
2410  // and hold constant true or false values.
2411  Value constantTrue = nullptr;
2412  Value constantFalse = nullptr;
2413 
2414  for (OpOperand &use :
2415  llvm::make_early_inc_range(op.getCondition().getUses())) {
2416  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2417  changed = true;
2418 
2419  if (!constantTrue)
2420  constantTrue = rewriter.create<arith::ConstantOp>(
2421  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2422 
2423  rewriter.modifyOpInPlace(use.getOwner(),
2424  [&]() { use.set(constantTrue); });
2425  } else if (op.getElseRegion().isAncestor(
2426  use.getOwner()->getParentRegion())) {
2427  changed = true;
2428 
2429  if (!constantFalse)
2430  constantFalse = rewriter.create<arith::ConstantOp>(
2431  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2432 
2433  rewriter.modifyOpInPlace(use.getOwner(),
2434  [&]() { use.set(constantFalse); });
2435  }
2436  }
2437 
2438  return success(changed);
2439  }
2440 };
2441 
2442 /// Remove any statements from an if that are equivalent to the condition
2443 /// or its negation. For example:
2444 ///
2445 /// %res:2 = scf.if %cmp {
2446 /// yield something(), true
2447 /// } else {
2448 /// yield something2(), false
2449 /// }
2450 /// print(%res#1)
2451 ///
2452 /// becomes
2453 /// %res = scf.if %cmp {
2454 /// yield something()
2455 /// } else {
2456 /// yield something2()
2457 /// }
2458 /// print(%cmp)
2459 ///
2460 /// Additionally if both branches yield the same value, replace all uses
2461 /// of the result with the yielded value.
2462 ///
2463 /// %res:2 = scf.if %cmp {
2464 /// yield something(), %arg1
2465 /// } else {
2466 /// yield something2(), %arg1
2467 /// }
2468 /// print(%res#1)
2469 ///
2470 /// becomes
2471 /// %res = scf.if %cmp {
2472 /// yield something()
2473 /// } else {
2474 /// yield something2()
2475 /// }
2476 /// print(%arg1)
2477 ///
2478 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2479  using OpRewritePattern<IfOp>::OpRewritePattern;
2480 
2481  LogicalResult matchAndRewrite(IfOp op,
2482  PatternRewriter &rewriter) const override {
2483  // Early exit if there are no results that could be replaced.
2484  if (op.getNumResults() == 0)
2485  return failure();
2486 
2487  auto trueYield =
2488  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2489  auto falseYield =
2490  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2491 
2492  rewriter.setInsertionPoint(op->getBlock(),
2493  op.getOperation()->getIterator());
2494  bool changed = false;
2495  Type i1Ty = rewriter.getI1Type();
2496  for (auto [trueResult, falseResult, opResult] :
2497  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2498  op.getResults())) {
2499  if (trueResult == falseResult) {
2500  if (!opResult.use_empty()) {
2501  opResult.replaceAllUsesWith(trueResult);
2502  changed = true;
2503  }
2504  continue;
2505  }
2506 
2507  BoolAttr trueYield, falseYield;
2508  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2509  !matchPattern(falseResult, m_Constant(&falseYield)))
2510  continue;
2511 
2512  bool trueVal = trueYield.getValue();
2513  bool falseVal = falseYield.getValue();
2514  if (!trueVal && falseVal) {
2515  if (!opResult.use_empty()) {
2516  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2517  Value notCond = rewriter.create<arith::XOrIOp>(
2518  op.getLoc(), op.getCondition(),
2519  constDialect
2520  ->materializeConstant(rewriter,
2521  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2522  op.getLoc())
2523  ->getResult(0));
2524  opResult.replaceAllUsesWith(notCond);
2525  changed = true;
2526  }
2527  }
2528  if (trueVal && !falseVal) {
2529  if (!opResult.use_empty()) {
2530  opResult.replaceAllUsesWith(op.getCondition());
2531  changed = true;
2532  }
2533  }
2534  }
2535  return success(changed);
2536  }
2537 };
2538 
2539 /// Merge any consecutive scf.if's with the same condition.
2540 ///
2541 /// scf.if %cond {
2542 /// firstCodeTrue();...
2543 /// } else {
2544 /// firstCodeFalse();...
2545 /// }
2546 /// %res = scf.if %cond {
2547 /// secondCodeTrue();...
2548 /// } else {
2549 /// secondCodeFalse();...
2550 /// }
2551 ///
2552 /// becomes
2553 /// %res = scf.if %cond {
2554 /// firstCodeTrue();...
2555 /// secondCodeTrue();...
2556 /// } else {
2557 /// firstCodeFalse();...
2558 /// secondCodeFalse();...
2559 /// }
2560 struct CombineIfs : public OpRewritePattern<IfOp> {
2561  using OpRewritePattern<IfOp>::OpRewritePattern;
2562 
2563  LogicalResult matchAndRewrite(IfOp nextIf,
2564  PatternRewriter &rewriter) const override {
2565  Block *parent = nextIf->getBlock();
2566  if (nextIf == &parent->front())
2567  return failure();
2568 
2569  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2570  if (!prevIf)
2571  return failure();
2572 
2573  // Determine the logical then/else blocks when prevIf's
2574  // condition is used. Null means the block does not exist
2575  // in that case (e.g. empty else). If neither of these
2576  // are set, the two conditions cannot be compared.
2577  Block *nextThen = nullptr;
2578  Block *nextElse = nullptr;
2579  if (nextIf.getCondition() == prevIf.getCondition()) {
2580  nextThen = nextIf.thenBlock();
2581  if (!nextIf.getElseRegion().empty())
2582  nextElse = nextIf.elseBlock();
2583  }
2584  if (arith::XOrIOp notv =
2585  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2586  if (notv.getLhs() == prevIf.getCondition() &&
2587  matchPattern(notv.getRhs(), m_One())) {
2588  nextElse = nextIf.thenBlock();
2589  if (!nextIf.getElseRegion().empty())
2590  nextThen = nextIf.elseBlock();
2591  }
2592  }
2593  if (arith::XOrIOp notv =
2594  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2595  if (notv.getLhs() == nextIf.getCondition() &&
2596  matchPattern(notv.getRhs(), m_One())) {
2597  nextElse = nextIf.thenBlock();
2598  if (!nextIf.getElseRegion().empty())
2599  nextThen = nextIf.elseBlock();
2600  }
2601  }
2602 
2603  if (!nextThen && !nextElse)
2604  return failure();
2605 
2606  SmallVector<Value> prevElseYielded;
2607  if (!prevIf.getElseRegion().empty())
2608  prevElseYielded = prevIf.elseYield().getOperands();
2609  // Replace all uses of prevIf's results within nextIf with the
2610  // corresponding yielded values.
2611  for (auto it : llvm::zip(prevIf.getResults(),
2612  prevIf.thenYield().getOperands(), prevElseYielded))
2613  for (OpOperand &use :
2614  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2615  if (nextThen && nextThen->getParent()->isAncestor(
2616  use.getOwner()->getParentRegion())) {
2617  rewriter.startOpModification(use.getOwner());
2618  use.set(std::get<1>(it));
2619  rewriter.finalizeOpModification(use.getOwner());
2620  } else if (nextElse && nextElse->getParent()->isAncestor(
2621  use.getOwner()->getParentRegion())) {
2622  rewriter.startOpModification(use.getOwner());
2623  use.set(std::get<2>(it));
2624  rewriter.finalizeOpModification(use.getOwner());
2625  }
2626  }
2627 
2628  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2629  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2630 
2631  IfOp combinedIf = rewriter.create<IfOp>(
2632  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2633  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2634 
2635  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2636  combinedIf.getThenRegion(),
2637  combinedIf.getThenRegion().begin());
2638 
2639  if (nextThen) {
2640  YieldOp thenYield = combinedIf.thenYield();
2641  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2642  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2643  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2644 
2645  SmallVector<Value> mergedYields(thenYield.getOperands());
2646  llvm::append_range(mergedYields, thenYield2.getOperands());
2647  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2648  rewriter.eraseOp(thenYield);
2649  rewriter.eraseOp(thenYield2);
2650  }
2651 
2652  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2653  combinedIf.getElseRegion(),
2654  combinedIf.getElseRegion().begin());
2655 
2656  if (nextElse) {
2657  if (combinedIf.getElseRegion().empty()) {
2658  rewriter.inlineRegionBefore(*nextElse->getParent(),
2659  combinedIf.getElseRegion(),
2660  combinedIf.getElseRegion().begin());
2661  } else {
2662  YieldOp elseYield = combinedIf.elseYield();
2663  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2664  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2665 
2666  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2667 
2668  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2669  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2670 
2671  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2672  rewriter.eraseOp(elseYield);
2673  rewriter.eraseOp(elseYield2);
2674  }
2675  }
2676 
2677  SmallVector<Value> prevValues;
2678  SmallVector<Value> nextValues;
2679  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2680  if (pair.index() < prevIf.getNumResults())
2681  prevValues.push_back(pair.value());
2682  else
2683  nextValues.push_back(pair.value());
2684  }
2685  rewriter.replaceOp(prevIf, prevValues);
2686  rewriter.replaceOp(nextIf, nextValues);
2687  return success();
2688  }
2689 };
2690 
2691 /// Pattern to remove an empty else branch.
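/// This only applies when the scf.if has no results, e.g. (schematic):
///
///   scf.if %c {
///     "some.op"() : () -> ()
///   } else {
///   }
/// becomes
///   scf.if %c {
///     "some.op"() : () -> ()
///   }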
2692 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2693  using OpRewritePattern<IfOp>::OpRewritePattern;
2694 
2695  LogicalResult matchAndRewrite(IfOp ifOp,
2696  PatternRewriter &rewriter) const override {
2697  // Cannot remove else region when there are operation results.
2698  if (ifOp.getNumResults())
2699  return failure();
2700  Block *elseBlock = ifOp.elseBlock();
2701  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2702  return failure();
2703  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2704  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2705  newIfOp.getThenRegion().begin());
2706  rewriter.eraseOp(ifOp);
2707  return success();
2708  }
2709 };
2710 
2711 /// Convert nested `if`s into `arith.andi` + single `if`.
2712 ///
2713 /// scf.if %arg0 {
2714 /// scf.if %arg1 {
2715 /// ...
2716 /// scf.yield
2717 /// }
2718 /// scf.yield
2719 /// }
2720 /// becomes
2721 ///
2722 /// %0 = arith.andi %arg0, %arg1
2723 /// scf.if %0 {
2724 /// ...
2725 /// scf.yield
2726 /// }
2727 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2728  using OpRewritePattern<IfOp>::OpRewritePattern;
2729 
2730  LogicalResult matchAndRewrite(IfOp op,
2731  PatternRewriter &rewriter) const override {
2732  auto nestedOps = op.thenBlock()->without_terminator();
2733  // Nested `if` must be the only op in block.
2734  if (!llvm::hasSingleElement(nestedOps))
2735  return failure();
2736 
2737  // If there is an else block, it may only contain a yield.
2738  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2739  return failure();
2740 
2741  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2742  if (!nestedIf)
2743  return failure();
2744 
2745  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2746  return failure();
2747 
2748  SmallVector<Value> thenYield(op.thenYield().getOperands());
2749  SmallVector<Value> elseYield;
2750  if (op.elseBlock())
2751  llvm::append_range(elseYield, op.elseYield().getOperands());
2752 
2753  // A list of indices for which we should upgrade the value yielded
2754  // in the else to a select.
2755  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2756 
2757  // If the outer scf.if yields a value produced by the inner scf.if,
2758  // only permit combining if the value yielded when the condition
2759  // is false in the outer scf.if is the same value yielded when the
2760  // inner scf.if condition is false.
2761  // Note that the array access to elseYield will not go out of bounds
2762  // since it must have the same length as thenYield, because they both
2763  // come from the same scf.if.
2764  for (const auto &tup : llvm::enumerate(thenYield)) {
2765  if (tup.value().getDefiningOp() == nestedIf) {
2766  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2767  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2768  elseYield[tup.index()]) {
2769  return failure();
2770  }
2771  // If the correctness test passes, we will yield the
2772  // corresponding value from the inner scf.if.
2773  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2774  continue;
2775  }
2776 
2777  // Otherwise, we need to ensure the else block of the combined
2778  // condition still returns the same value when the outer condition is
2779  // true and the inner condition is false. This can be accomplished if
2780  // the then value is defined outside the outer scf.if and we replace the
2781  // value with a select that considers just the outer condition. Since
2782  // the else region contains just the yield, its yielded value is
2783  // defined outside the scf.if, by definition.
2784 
2785  // If the then value is defined within the scf.if, bail.
2786  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2787  return failure();
2788  }
2789  elseYieldsToUpgradeToSelect.push_back(tup.index());
2790  }
2791 
2792  Location loc = op.getLoc();
2793  Value newCondition = rewriter.create<arith::AndIOp>(
2794  loc, op.getCondition(), nestedIf.getCondition());
2795  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2796  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2797 
2798  SmallVector<Value> results;
2799  llvm::append_range(results, newIf.getResults());
2800  rewriter.setInsertionPoint(newIf);
2801 
2802  for (auto idx : elseYieldsToUpgradeToSelect)
2803  results[idx] = rewriter.create<arith::SelectOp>(
2804  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2805 
2806  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2807  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2808  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2809  if (!elseYield.empty()) {
2810  rewriter.createBlock(&newIf.getElseRegion());
2811  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2812  rewriter.create<YieldOp>(loc, elseYield);
2813  }
2814  rewriter.replaceOp(op, results);
2815  return success();
2816  }
2817 };
2818 
2819 } // namespace
2820 
2821 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2822  MLIRContext *context) {
2823  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2824  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2825  RemoveStaticCondition, RemoveUnusedResults,
2826  ReplaceIfYieldWithConditionOrValue>(context);
2827 }
2828 
2829 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2830 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2831 Block *IfOp::elseBlock() {
2832  Region &r = getElseRegion();
2833  if (r.empty())
2834  return nullptr;
2835  return &r.back();
2836 }
2837 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2838 
2839 //===----------------------------------------------------------------------===//
2840 // ParallelOp
2841 //===----------------------------------------------------------------------===//
2842 
2843 void ParallelOp::build(
2844  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2845  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2846  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2847  bodyBuilderFn) {
2848  result.addOperands(lowerBounds);
2849  result.addOperands(upperBounds);
2850  result.addOperands(steps);
2851  result.addOperands(initVals);
2852  result.addAttribute(
2853  ParallelOp::getOperandSegmentSizeAttr(),
2854  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2855  static_cast<int32_t>(upperBounds.size()),
2856  static_cast<int32_t>(steps.size()),
2857  static_cast<int32_t>(initVals.size())}));
2858  result.addTypes(initVals.getTypes());
2859 
2860  OpBuilder::InsertionGuard guard(builder);
2861  unsigned numIVs = steps.size();
2862  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2863  SmallVector<Location, 8> argLocs(numIVs, result.location);
2864  Region *bodyRegion = result.addRegion();
2865  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2866 
2867  if (bodyBuilderFn) {
2868  builder.setInsertionPointToStart(bodyBlock);
2869  bodyBuilderFn(builder, result.location,
2870  bodyBlock->getArguments().take_front(numIVs),
2871  bodyBlock->getArguments().drop_front(numIVs));
2872  }
2873  // Add terminator only if there are no reductions.
2874  if (initVals.empty())
2875  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2876 }
2877 
2878 void ParallelOp::build(
2879  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2880  ValueRange upperBounds, ValueRange steps,
2881  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2882  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2883  // we don't capture a reference to a temporary by constructing the lambda at
2884  // function level.
2885  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2886  Location nestedLoc, ValueRange ivs,
2887  ValueRange) {
2888  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2889  };
2890  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2891  if (bodyBuilderFn)
2892  wrapper = wrappedBuilderFn;
2893 
2894  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2895  wrapper);
2896 }
2897 
2898 LogicalResult ParallelOp::verify() {
2899  // Check that there is at least one value in lowerBound, upperBound and step.
2900  // It is sufficient to test only step, because it is already ensured that the
2901  // number of elements in lowerBound, upperBound and step is the same.
2902  Operation::operand_range stepValues = getStep();
2903  if (stepValues.empty())
2904  return emitOpError(
2905  "needs at least one tuple element for lowerBound, upperBound and step");
2906 
2907  // Check whether all constant step values are positive.
2908  for (Value stepValue : stepValues)
2909  if (auto cst = getConstantIntValue(stepValue))
2910  if (*cst <= 0)
2911  return emitOpError("constant step operand must be positive");
2912 
2913  // Check that the body defines the same number of block arguments as the
2914  // number of tuple elements in step.
2915  Block *body = getBody();
2916  if (body->getNumArguments() != stepValues.size())
2917  return emitOpError() << "expects the same number of induction variables: "
2918  << body->getNumArguments()
2919  << " as bound and step values: " << stepValues.size();
2920  for (auto arg : body->getArguments())
2921  if (!arg.getType().isIndex())
2922  return emitOpError(
2923  "expects arguments for the induction variable to be of index type");
2924 
2925  // Check that the terminator is an scf.reduce op.
2926  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2927  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2928  if (!reduceOp)
2929  return failure();
2930 
2931  // Check that the number of results is the same as the number of reductions.
2932  auto resultsSize = getResults().size();
2933  auto reductionsSize = reduceOp.getReductions().size();
2934  auto initValsSize = getInitVals().size();
2935  if (resultsSize != reductionsSize)
2936  return emitOpError() << "expects number of results: " << resultsSize
2937  << " to be the same as number of reductions: "
2938  << reductionsSize;
2939  if (resultsSize != initValsSize)
2940  return emitOpError() << "expects number of results: " << resultsSize
2941  << " to be the same as number of initial values: "
2942  << initValsSize;
2943 
2944  // Check that the types of the results and reductions are the same.
2945  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2946  auto resultType = getOperation()->getResult(i).getType();
2947  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2948  if (resultType != reductionOperandType)
2949  return reduceOp.emitOpError()
2950  << "expects type of " << i
2951  << "-th reduction operand: " << reductionOperandType
2952  << " to be the same as the " << i
2953  << "-th result type: " << resultType;
2954  }
2955  return success();
2956 }
2957 
2958 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2959  auto &builder = parser.getBuilder();
2960  // Parse an opening `(` followed by induction variables followed by `)`.
2961  SmallVector<OpAsmParser::Argument, 4> ivs;
2962  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2963  return failure();
2964 
2965  // Parse loop bounds.
2966  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2967  if (parser.parseEqual() ||
2968  parser.parseOperandList(lower, ivs.size(),
2969  OpAsmParser::Delimiter::Paren) ||
2970  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2971  return failure();
2972 
2973  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2974  if (parser.parseKeyword("to") ||
2975  parser.parseOperandList(upper, ivs.size(),
2976  OpAsmParser::Delimiter::Paren) ||
2977  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2978  return failure();
2979 
2980  // Parse step values.
2981  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2982  if (parser.parseKeyword("step") ||
2983  parser.parseOperandList(steps, ivs.size(),
2984  OpAsmParser::Delimiter::Paren) ||
2985  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2986  return failure();
2987 
2988  // Parse init values.
2989  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2990  if (succeeded(parser.parseOptionalKeyword("init"))) {
2991  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2992  return failure();
2993  }
2994 
2995  // Parse optional results in case there is a reduce.
2996  if (parser.parseOptionalArrowTypeList(result.types))
2997  return failure();
2998 
2999  // Now parse the body.
3000  Region *body = result.addRegion();
3001  for (auto &iv : ivs)
3002  iv.type = builder.getIndexType();
3003  if (parser.parseRegion(*body, ivs))
3004  return failure();
3005 
3006  // Set `operandSegmentSizes` attribute.
3007  result.addAttribute(
3008  ParallelOp::getOperandSegmentSizeAttr(),
3009  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3010  static_cast<int32_t>(upper.size()),
3011  static_cast<int32_t>(steps.size()),
3012  static_cast<int32_t>(initVals.size())}));
3013 
3014  // Parse attributes.
3015  if (parser.parseOptionalAttrDict(result.attributes) ||
3016  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3017  result.operands))
3018  return failure();
3019 
3020  // Add a terminator if none was parsed.
3021  ParallelOp::ensureTerminator(*body, builder, result.location);
3022  return success();
3023 }
3024 
3025 void ParallelOp::print(OpAsmPrinter &p) {
3026  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3027  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3028  if (!getInitVals().empty())
3029  p << " init (" << getInitVals() << ")";
3030  p.printOptionalArrowTypeList(getResultTypes());
3031  p << ' ';
3032  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3033  p.printOptionalAttrDictWithKeyword(
3034  (*this)->getAttrs(),
3035  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3036 }
3037 
3038 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3039 
3040 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3041  return SmallVector<Value>{getBody()->getArguments()};
3042 }
3043 
3044 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3045  return getLowerBound();
3046 }
3047 
3048 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3049  return getUpperBound();
3050 }
3051 
3052 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3053  return getStep();
3054 }
3055 
3056 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3057  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3058  if (!ivArg)
3059  return ParallelOp();
3060  assert(ivArg.getOwner() && "unlinked block argument");
3061  auto *containingOp = ivArg.getOwner()->getParentOp();
3062  return dyn_cast<ParallelOp>(containingOp);
3063 }
3064 
3065 namespace {
3066 // Collapse loop dimensions that perform a single iteration.
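// Added illustration (hypothetical values): a dimension with lb = 0, ub = 1,
// step = 1 runs exactly once, so its induction variable is replaced by the
// lower bound and the dimension is dropped; a dimension known to run zero
// times causes the whole loop to be replaced by its init values.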
3067 struct ParallelOpSingleOrZeroIterationDimsFolder
3068  : public OpRewritePattern<ParallelOp> {
3069  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3070
3071  LogicalResult matchAndRewrite(ParallelOp op,
3072  PatternRewriter &rewriter) const override {
3073  Location loc = op.getLoc();
3074 
3075  // Compute new loop bounds that omit all single-iteration loop dimensions.
3076  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3077  IRMapping mapping;
3078  for (auto [lb, ub, step, iv] :
3079  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3080  op.getInductionVars())) {
3081  auto numIterations = constantTripCount(lb, ub, step);
3082  if (numIterations.has_value()) {
3083  // Remove the loop if it performs zero iterations.
3084  if (*numIterations == 0) {
3085  rewriter.replaceOp(op, op.getInitVals());
3086  return success();
3087  }
3088  // Replace the loop induction variable by the lower bound if the loop
3089  // performs a single iteration. Otherwise, copy the loop bounds.
3090  if (*numIterations == 1) {
3091  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3092  continue;
3093  }
3094  }
3095  newLowerBounds.push_back(lb);
3096  newUpperBounds.push_back(ub);
3097  newSteps.push_back(step);
3098  }
3099  // Exit if none of the loop dimensions perform a single iteration.
3100  if (newLowerBounds.size() == op.getLowerBound().size())
3101  return failure();
3102 
3103  if (newLowerBounds.empty()) {
3104  // All of the loop dimensions perform a single iteration. Inline the
3105  // loop body and the nested ReduceOps.
3106  SmallVector<Value> results;
3107  results.reserve(op.getInitVals().size());
3108  for (auto &bodyOp : op.getBody()->without_terminator())
3109  rewriter.clone(bodyOp, mapping);
3110  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3111  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3112  Block &reduceBlock = reduceOp.getReductions()[i].front();
3113  auto initValIndex = results.size();
3114  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3115  mapping.map(reduceBlock.getArgument(1),
3116  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3117  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3118  rewriter.clone(reduceBodyOp, mapping);
3119 
3120  auto result = mapping.lookupOrDefault(
3121  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3122  results.push_back(result);
3123  }
3124 
3125  rewriter.replaceOp(op, results);
3126  return success();
3127  }
3128  // Replace the parallel loop by a lower-dimensional parallel loop.
3129  auto newOp =
3130  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3131  newSteps, op.getInitVals(), nullptr);
3132  // Erase the empty block that was inserted by the builder.
3133  rewriter.eraseBlock(newOp.getBody());
3134  // Clone the loop body and remap the block arguments of the collapsed loops
3135  // (inlining does not support a cancellable block argument mapping).
3136  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3137  newOp.getRegion().begin(), mapping);
3138  rewriter.replaceOp(op, newOp.getResults());
3139  return success();
3140  }
3141 };
3142 
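/// Merge a perfectly nested scf.parallel into its parent when the outer body
/// contains nothing besides the inner loop (added summary; loops with
/// reductions are rejected below). Sketch with hypothetical values:
///
///   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
///     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) {
///       "test.use"(%i, %j) : (index, index) -> ()
///     }
///   }
///
/// becomes
///
///   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
///     "test.use"(%i, %j) : (index, index) -> ()
///   }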
3143 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3144  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3145
3146  LogicalResult matchAndRewrite(ParallelOp op,
3147  PatternRewriter &rewriter) const override {
3148  Block &outerBody = *op.getBody();
3149  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3150  return failure();
3151 
3152  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3153  if (!innerOp)
3154  return failure();
3155 
3156  for (auto val : outerBody.getArguments())
3157  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3158  llvm::is_contained(innerOp.getUpperBound(), val) ||
3159  llvm::is_contained(innerOp.getStep(), val))
3160  return failure();
3161 
3162  // Reductions are not supported yet.
3163  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3164  return failure();
3165 
3166  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3167  ValueRange iterVals, ValueRange) {
3168  Block &innerBody = *innerOp.getBody();
3169  assert(iterVals.size() ==
3170  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3171  IRMapping mapping;
3172  mapping.map(outerBody.getArguments(),
3173  iterVals.take_front(outerBody.getNumArguments()));
3174  mapping.map(innerBody.getArguments(),
3175  iterVals.take_back(innerBody.getNumArguments()));
3176  for (Operation &op : innerBody.without_terminator())
3177  builder.clone(op, mapping);
3178  };
3179 
3180  auto concatValues = [](const auto &first, const auto &second) {
3181  SmallVector<Value> ret;
3182  ret.reserve(first.size() + second.size());
3183  ret.assign(first.begin(), first.end());
3184  ret.append(second.begin(), second.end());
3185  return ret;
3186  };
3187 
3188  auto newLowerBounds =
3189  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3190  auto newUpperBounds =
3191  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3192  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3193 
3194  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3195  newSteps, std::nullopt,
3196  bodyBuilder);
3197  return success();
3198  }
3199 };
3200 
3201 } // namespace
3202 
3203 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3204  MLIRContext *context) {
3205  results
3206  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3207  context);
3208 }
3209 
3210 /// Return the successor regions for the parallel op: these are the regions
3211 /// that may be selected during the flow of control. Control from the parent
3212 /// op may enter the loop body or skip it entirely, and after the body control
3213 /// returns to the parent op; hence both the body region and the parent are
3214 /// reported as possible successors.
3215 void ParallelOp::getSuccessorRegions(
3216  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3217  // Both the operation itself and the region may be branching into the body or
3218  // back into the operation itself. It is possible for the loop not to enter
3219  // the body.
3220  regions.push_back(RegionSuccessor(&getRegion()));
3221  regions.push_back(RegionSuccessor());
3222 }
3223 
3224 //===----------------------------------------------------------------------===//
3225 // ReduceOp
3226 //===----------------------------------------------------------------------===//
3227 
3228 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3229 
3230 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3231  ValueRange operands) {
3232  result.addOperands(operands);
3233  for (Value v : operands) {
3234  OpBuilder::InsertionGuard guard(builder);
3235  Region *bodyRegion = result.addRegion();
3236  builder.createBlock(bodyRegion, {},
3237  ArrayRef<Type>{v.getType(), v.getType()},
3238  {result.location, result.location});
3239  }
3240 }
3241 
3242 LogicalResult ReduceOp::verifyRegions() {
3243  // The region of a ReduceOp has two arguments of the same type as its
3244  // corresponding operand.
3245  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3246  auto type = getOperands()[i].getType();
3247  Block &block = getReductions()[i].front();
3248  if (block.empty())
3249  return emitOpError() << i << "-th reduction has an empty body";
3250  if (block.getNumArguments() != 2 ||
3251  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3252  return arg.getType() != type;
3253  }))
3254  return emitOpError() << "expected two block arguments with type " << type
3255  << " in the " << i << "-th reduction region";
3256 
3257  // Check that the block is terminated by a ReduceReturnOp.
3258  if (!isa<ReduceReturnOp>(block.getTerminator()))
3259  return emitOpError("reduction bodies must be terminated with an "
3260  "'scf.reduce.return' op");
3261  }
3262 
3263  return success();
3264 }
3265 
3266 MutableOperandRange
3267 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3268  // No operands are forwarded to the next iteration.
3269  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3270 }
3271 
3272 //===----------------------------------------------------------------------===//
3273 // ReduceReturnOp
3274 //===----------------------------------------------------------------------===//
3275 
3276 LogicalResult ReduceReturnOp::verify() {
3277  // The type of the return value should be the same type as the types of the
3278  // block arguments of the reduction body.
3279  Block *reductionBody = getOperation()->getBlock();
3280  // Should already be verified by an op trait.
3281  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3282  Type expectedResultType = reductionBody->getArgument(0).getType();
3283  if (expectedResultType != getResult().getType())
3284  return emitOpError() << "must have type " << expectedResultType
3285  << " (the type of the reduction inputs)";
3286  return success();
3287 }
3288 
3289 //===----------------------------------------------------------------------===//
3290 // WhileOp
3291 //===----------------------------------------------------------------------===//
3292 
3293 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3294  ::mlir::OperationState &odsState, TypeRange resultTypes,
3295  ValueRange inits, BodyBuilderFn beforeBuilder,
3296  BodyBuilderFn afterBuilder) {
3297  odsState.addOperands(inits);
3298  odsState.addTypes(resultTypes);
3299 
3300  OpBuilder::InsertionGuard guard(odsBuilder);
3301 
3302  // Build before region.
3303  SmallVector<Location, 4> beforeArgLocs;
3304  beforeArgLocs.reserve(inits.size());
3305  for (Value operand : inits) {
3306  beforeArgLocs.push_back(operand.getLoc());
3307  }
3308 
3309  Region *beforeRegion = odsState.addRegion();
3310  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3311  inits.getTypes(), beforeArgLocs);
3312  if (beforeBuilder)
3313  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3314 
3315  // Build after region.
3316  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3317 
3318  Region *afterRegion = odsState.addRegion();
3319  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3320  resultTypes, afterArgLocs);
3321 
3322  if (afterBuilder)
3323  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3324 }
3325 
3326 ConditionOp WhileOp::getConditionOp() {
3327  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3328 }
3329 
3330 YieldOp WhileOp::getYieldOp() {
3331  return cast<YieldOp>(getAfterBody()->getTerminator());
3332 }
3333 
3334 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3335  return getYieldOp().getResultsMutable();
3336 }
3337 
3338 Block::BlockArgListType WhileOp::getBeforeArguments() {
3339  return getBeforeBody()->getArguments();
3340 }
3341 
3342 Block::BlockArgListType WhileOp::getAfterArguments() {
3343  return getAfterBody()->getArguments();
3344 }
3345 
3346 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3347  return getBeforeArguments();
3348 }
3349 
3350 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3351  assert(point == getBefore() &&
3352  "WhileOp is expected to branch only to the first region");
3353  return getInits();
3354 }
3355 
3356 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3357  SmallVectorImpl<RegionSuccessor> &regions) {
3358  // The parent op always branches to the condition region.
3359  if (point.isParent()) {
3360  regions.emplace_back(&getBefore(), getBefore().getArguments());
3361  return;
3362  }
3363 
3364  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3365  "there are only two regions in a WhileOp");
3366  // The body region always branches back to the condition region.
3367  if (point == getAfter()) {
3368  regions.emplace_back(&getBefore(), getBefore().getArguments());
3369  return;
3370  }
3371 
3372  regions.emplace_back(getResults());
3373  regions.emplace_back(&getAfter(), getAfter().getArguments());
3374 }
3375 
3376 SmallVector<Region *> WhileOp::getLoopRegions() {
3377  return {&getBefore(), &getAfter()};
3378 }
3379 
3380 /// Parses a `while` op.
3381 ///
3382 /// op ::= `scf.while` assignments `:` function-type region `do` region
3383 /// `attributes` attribute-dict
3384 /// initializer ::= /* empty */ | `(` assignment-list `)`
3385 /// assignment-list ::= assignment | assignment `,` assignment-list
3386 /// assignment ::= ssa-value `=` ssa-value
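/// Example of the accepted form (added sketch; names, types, and the test.*
/// ops are hypothetical):
///
///   %res = scf.while (%arg = %init) : (i32) -> i32 {
///     %cond = "test.condition"(%arg) : (i32) -> i1
///     scf.condition(%cond) %arg : i32
///   } do {
///   ^bb0(%arg2: i32):
///     %next = "test.next"(%arg2) : (i32) -> i32
///     scf.yield %next : i32
///   }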
3387 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3388  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3389  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3390  Region *before = result.addRegion();
3391  Region *after = result.addRegion();
3392 
3393  OptionalParseResult listResult =
3394  parser.parseOptionalAssignmentList(regionArgs, operands);
3395  if (listResult.has_value() && failed(listResult.value()))
3396  return failure();
3397 
3398  FunctionType functionType;
3399  SMLoc typeLoc = parser.getCurrentLocation();
3400  if (failed(parser.parseColonType(functionType)))
3401  return failure();
3402 
3403  result.addTypes(functionType.getResults());
3404 
3405  if (functionType.getNumInputs() != operands.size()) {
3406  return parser.emitError(typeLoc)
3407  << "expected as many input types as operands "
3408  << "(expected " << operands.size() << " got "
3409  << functionType.getNumInputs() << ")";
3410  }
3411 
3412  // Resolve input operands.
3413  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3414  parser.getCurrentLocation(),
3415  result.operands)))
3416  return failure();
3417 
3418  // Propagate the types into the region arguments.
3419  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3420  regionArgs[i].type = functionType.getInput(i);
3421 
3422  return failure(parser.parseRegion(*before, regionArgs) ||
3423  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3424  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3425 }
3426 
3427 /// Prints a `while` op.
3428 void scf::WhileOp::print(OpAsmPrinter &p) {
3429  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3430  p << " : ";
3431  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3432  p << ' ';
3433  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3434  p << " do ";
3435  p.printRegion(getAfter());
3436  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3437 }
3438 
3439 /// Verifies that two ranges of types match, i.e. have the same number of
3440 /// entries and that types are pairwise equals. Reports errors on the given
3441 /// operation in case of mismatch.
3442 template <typename OpTy>
3443 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3444  TypeRange right, StringRef message) {
3445  if (left.size() != right.size())
3446  return op.emitOpError("expects the same number of ") << message;
3447 
3448  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3449  if (left[i] != right[i]) {
3450  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3451  << message;
3452  diag.attachNote() << "for argument " << i << ", found " << left[i]
3453  << " and " << right[i];
3454  return diag;
3455  }
3456  }
3457 
3458  return success();
3459 }
3460 
3461 LogicalResult scf::WhileOp::verify() {
3462  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3463  *this, getBefore(),
3464  "expects the 'before' region to terminate with 'scf.condition'");
3465  if (!beforeTerminator)
3466  return failure();
3467 
3468  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3469  *this, getAfter(),
3470  "expects the 'after' region to terminate with 'scf.yield'");
3471  return success(afterTerminator != nullptr);
3472 }
3473 
3474 namespace {
3475 /// Replace uses of the condition within the do block with true, since otherwise
3476 /// the block would not be evaluated.
3477 ///
3478 /// scf.while (..) : (i1, ...) -> ... {
3479 /// %condition = call @evaluate_condition() : () -> i1
3480 /// scf.condition(%condition) %condition : i1, ...
3481 /// } do {
3482 /// ^bb0(%arg0: i1, ...):
3483 /// use(%arg0)
3484 /// ...
3485 ///
3486 /// becomes
3487 /// scf.while (..) : (i1, ...) -> ... {
3488 /// %condition = call @evaluate_condition() : () -> i1
3489 /// scf.condition(%condition) %condition : i1, ...
3490 /// } do {
3491 /// ^bb0(%arg0: i1, ...):
3492 /// use(%true)
3493 /// ...
3494 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3495  using OpRewritePattern<WhileOp>::OpRewritePattern;
3496
3497  LogicalResult matchAndRewrite(WhileOp op,
3498  PatternRewriter &rewriter) const override {
3499  auto term = op.getConditionOp();
3500 
3501  // This variable serves to prevent creating duplicate constants and
3502  // holds the constant true value.
3503  Value constantTrue = nullptr;
3504 
3505  bool replaced = false;
3506  for (auto yieldedAndBlockArgs :
3507  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3508  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3509  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3510  if (!constantTrue)
3511  constantTrue = rewriter.create<arith::ConstantOp>(
3512  op.getLoc(), term.getCondition().getType(),
3513  rewriter.getBoolAttr(true));
3514 
3515  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3516  constantTrue);
3517  replaced = true;
3518  }
3519  }
3520  }
3521  return success(replaced);
3522  }
3523 };
3524 
3525 /// Remove loop invariant arguments from `before` block of scf.while.
3526 /// A before block argument is considered loop invariant if :-
3527 /// 1. i-th yield operand is equal to the i-th while operand.
3528 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3529 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3530 /// iter argument/while operand.
3531 /// For the arguments which are removed, their uses inside scf.while
3532 /// are replaced with their corresponding initial value.
3533 ///
3534 /// Eg:
3535 /// INPUT :-
3536 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3537 /// ..., %argN_before = %N)
3538 /// {
3539 /// ...
3540 /// scf.condition(%cond) %arg1_before, %arg0_before,
3541 /// %arg2_before, %arg0_before, ...
3542 /// } do {
3543 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3544 /// ..., %argK_after):
3545 /// ...
3546 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3547 /// }
3548 ///
3549 /// OUTPUT :-
3550 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3551 /// %N)
3552 /// {
3553 /// ...
3554 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3555 /// } do {
3556 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3557 /// ..., %argK_after):
3558 /// ...
3559 /// scf.yield %arg1_after, ..., %argN
3560 /// }
3561 ///
3562 /// EXPLANATION:
3563 /// We iterate over each yield operand.
3564 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3565 /// %arg0_before, which in turn is the 0-th iter argument. So we
3566 /// remove 0-th before block argument and yield operand, and replace
3567 /// all uses of the 0-th before block argument with its initial value
3568 /// %a.
3569 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3570 /// value. So we remove this operand and the corresponding before
3571 /// block argument and replace all uses of 1-th before block argument
3572 /// with %b.
3573 struct RemoveLoopInvariantArgsFromBeforeBlock
3574  : public OpRewritePattern<WhileOp> {
3575  using OpRewritePattern<WhileOp>::OpRewritePattern;
3576
3577  LogicalResult matchAndRewrite(WhileOp op,
3578  PatternRewriter &rewriter) const override {
3579  Block &afterBlock = *op.getAfterBody();
3580  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3581  ConditionOp condOp = op.getConditionOp();
3582  OperandRange condOpArgs = condOp.getArgs();
3583  Operation *yieldOp = afterBlock.getTerminator();
3584  ValueRange yieldOpArgs = yieldOp->getOperands();
3585 
3586  bool canSimplify = false;
3587  for (const auto &it :
3588  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3589  auto index = static_cast<unsigned>(it.index());
3590  auto [initVal, yieldOpArg] = it.value();
3591  // If i-th yield operand is equal to the i-th operand of the scf.while,
3592  // the i-th before block argument is a loop invariant.
3593  if (yieldOpArg == initVal) {
3594  canSimplify = true;
3595  break;
3596  }
3597  // If the i-th yield operand is k-th after block argument, then we check
3598  // if the (k+1)-th condition op operand is equal to either the i-th before
3599  // block argument or the initial value of i-th before block argument. If
3600  // the comparison results `true`, i-th before block argument is a loop
3601  // invariant.
3602  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3603  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3604  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3605  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3606  canSimplify = true;
3607  break;
3608  }
3609  }
3610  }
3611 
3612  if (!canSimplify)
3613  return failure();
3614 
3615  SmallVector<Value> newInitArgs, newYieldOpArgs;
3616  DenseMap<unsigned, Value> beforeBlockInitValMap;
3617  SmallVector<Location> newBeforeBlockArgLocs;
3618  for (const auto &it :
3619  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3620  auto index = static_cast<unsigned>(it.index());
3621  auto [initVal, yieldOpArg] = it.value();
3622 
3623  // If i-th yield operand is equal to the i-th operand of the scf.while,
3624  // the i-th before block argument is a loop invariant.
3625  if (yieldOpArg == initVal) {
3626  beforeBlockInitValMap.insert({index, initVal});
3627  continue;
3628  } else {
3629  // If the i-th yield operand is k-th after block argument, then we check
3630  // if the (k+1)-th condition op operand is equal to either the i-th
3631  // before block argument or the initial value of i-th before block
3632  // argument. If the comparison results `true`, i-th before block
3633  // argument is a loop invariant.
3634  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3635  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3636  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3637  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3638  beforeBlockInitValMap.insert({index, initVal});
3639  continue;
3640  }
3641  }
3642  }
3643  newInitArgs.emplace_back(initVal);
3644  newYieldOpArgs.emplace_back(yieldOpArg);
3645  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3646  }
3647 
3648  {
3649  OpBuilder::InsertionGuard g(rewriter);
3650  rewriter.setInsertionPoint(yieldOp);
3651  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3652  }
3653 
3654  auto newWhile =
3655  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3656 
3657  Block &newBeforeBlock = *rewriter.createBlock(
3658  &newWhile.getBefore(), /*insertPt*/ {},
3659  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3660 
3661  Block &beforeBlock = *op.getBeforeBody();
3662  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3663  // For each i-th before block argument we find its replacement value as :-
3664  // 1. If i-th before block argument is a loop invariant, we fetch its
3665  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3666  // 2. Else we fetch j-th new before block argument as the replacement
3667  // value of i-th before block argument.
3668  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3669  // If the index 'i' argument was a loop invariant, we fetch its initial
3670  // value from `beforeBlockInitValMap`.
3671  if (beforeBlockInitValMap.count(i) != 0)
3672  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3673  else
3674  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3675  }
3676 
3677  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3678  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3679  newWhile.getAfter().begin());
3680 
3681  rewriter.replaceOp(op, newWhile.getResults());
3682  return success();
3683  }
3684 };
3685 
3686 /// Remove loop invariant value from result (condition op) of scf.while.
3687 /// A value is considered loop invariant if the final value yielded by
3688 /// scf.condition is defined outside of the `before` block. We remove the
3689 /// corresponding argument in `after` block and replace the use with the value.
3690 /// We also replace the use of the corresponding result of scf.while with the
3691 /// value.
3692 ///
3693 /// Eg:
3694 /// INPUT :-
3695 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3696 /// %argN_before = %N) {
3697 /// ...
3698 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3699 /// } do {
3700 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3701 /// ...
3702 /// some_func(%arg1_after)
3703 /// ...
3704 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3705 /// }
3706 ///
3707 /// OUTPUT :-
3708 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3709 /// ...
3710 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3711 /// } do {
3712 /// ^bb0(%arg0, %arg3, ..., %argM):
3713 /// ...
3714 /// some_func(%a)
3715 /// ...
3716 /// scf.yield %arg0, %b, ..., %argN
3717 /// }
3718 ///
3719 /// EXPLANATION:
3720 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3721 /// before block of scf.while, so they get removed.
3722 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3723 /// replaced by %b.
3724 /// 3. The corresponding after block argument %arg1_after's uses are
3725 /// replaced by %a and %arg2_after's uses are replaced by %b.
3726 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3727  using OpRewritePattern<WhileOp>::OpRewritePattern;
3728
3729  LogicalResult matchAndRewrite(WhileOp op,
3730  PatternRewriter &rewriter) const override {
3731  Block &beforeBlock = *op.getBeforeBody();
3732  ConditionOp condOp = op.getConditionOp();
3733  OperandRange condOpArgs = condOp.getArgs();
3734 
3735  bool canSimplify = false;
3736  for (Value condOpArg : condOpArgs) {
3737  // Values not defined within the `before` block are considered loop
3738  // invariant. If any such value is forwarded by scf.condition, the op can
3739  // be simplified.
3740  if (condOpArg.getParentBlock() != &beforeBlock) {
3741  canSimplify = true;
3742  break;
3743  }
3744  }
3745 
3746  if (!canSimplify)
3747  return failure();
3748 
3749  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3750 
3751  SmallVector<Value> newCondOpArgs;
3752  SmallVector<Type> newAfterBlockType;
3753  DenseMap<unsigned, Value> condOpInitValMap;
3754  SmallVector<Location> newAfterBlockArgLocs;
3755  for (const auto &it : llvm::enumerate(condOpArgs)) {
3756  auto index = static_cast<unsigned>(it.index());
3757  Value condOpArg = it.value();
3758  // Those values not defined within `before` block will be considered as
3759  // loop invariant values. We map the corresponding `index` with their
3760  // value.
3761  if (condOpArg.getParentBlock() != &beforeBlock) {
3762  condOpInitValMap.insert({index, condOpArg});
3763  } else {
3764  newCondOpArgs.emplace_back(condOpArg);
3765  newAfterBlockType.emplace_back(condOpArg.getType());
3766  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3767  }
3768  }
3769 
3770  {
3771  OpBuilder::InsertionGuard g(rewriter);
3772  rewriter.setInsertionPoint(condOp);
3773  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3774  newCondOpArgs);
3775  }
3776 
3777  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3778  op.getOperands());
3779 
3780  Block &newAfterBlock =
3781  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3782  newAfterBlockType, newAfterBlockArgLocs);
3783 
3784  Block &afterBlock = *op.getAfterBody();
3785  // Since a new scf.condition op was created, we need to fetch the new
3786  // `after` block arguments which will be used while replacing operations of
3787  // the previous scf.while's `after` block. We also fetch the new result
3788  // values.
3789  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3790  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3791  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3792  Value afterBlockArg, result;
3793  // If the index 'i' argument was loop invariant, we fetch its value from
3794  // the `condOpInitValMap` map.
3795  if (condOpInitValMap.count(i) != 0) {
3796  afterBlockArg = condOpInitValMap[i];
3797  result = afterBlockArg;
3798  } else {
3799  afterBlockArg = newAfterBlock.getArgument(j);
3800  result = newWhile.getResult(j);
3801  j++;
3802  }
3803  newAfterBlockArgs[i] = afterBlockArg;
3804  newWhileResults[i] = result;
3805  }
3806 
3807  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3808  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3809  newWhile.getBefore().begin());
3810 
3811  rewriter.replaceOp(op, newWhileResults);
3812  return success();
3813  }
3814 };
3815 
3816 /// Remove WhileOp results that are also unused in 'after' block.
3817 ///
3818 /// %0:2 = scf.while () : () -> (i32, i64) {
3819 /// %condition = "test.condition"() : () -> i1
3820 /// %v1 = "test.get_some_value"() : () -> i32
3821 /// %v2 = "test.get_some_value"() : () -> i64
3822 /// scf.condition(%condition) %v1, %v2 : i32, i64
3823 /// } do {
3824 /// ^bb0(%arg0: i32, %arg1: i64):
3825 /// "test.use"(%arg0) : (i32) -> ()
3826 /// scf.yield
3827 /// }
3828 /// return %0#0 : i32
3829 ///
3830 /// becomes
3831 /// %0 = scf.while () : () -> (i32) {
3832 /// %condition = "test.condition"() : () -> i1
3833 /// %v1 = "test.get_some_value"() : () -> i32
3834 /// %v2 = "test.get_some_value"() : () -> i64
3835 /// scf.condition(%condition) %v1 : i32
3836 /// } do {
3837 /// ^bb0(%arg0: i32):
3838 /// "test.use"(%arg0) : (i32) -> ()
3839 /// scf.yield
3840 /// }
3841 /// return %0 : i32
3842 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3843  using OpRewritePattern<WhileOp>::OpRewritePattern;
3844
3845  LogicalResult matchAndRewrite(WhileOp op,
3846  PatternRewriter &rewriter) const override {
3847  auto term = op.getConditionOp();
3848  auto afterArgs = op.getAfterArguments();
3849  auto termArgs = term.getArgs();
3850 
3851  // Collect results mapping, new terminator args and new result types.
3852  SmallVector<unsigned> newResultsIndices;
3853  SmallVector<Type> newResultTypes;
3854  SmallVector<Value> newTermArgs;
3855  SmallVector<Location> newArgLocs;
3856  bool needUpdate = false;
3857  for (const auto &it :
3858  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3859  auto i = static_cast<unsigned>(it.index());
3860  Value result = std::get<0>(it.value());
3861  Value afterArg = std::get<1>(it.value());
3862  Value termArg = std::get<2>(it.value());
3863  if (result.use_empty() && afterArg.use_empty()) {
3864  needUpdate = true;
3865  } else {
3866  newResultsIndices.emplace_back(i);
3867  newTermArgs.emplace_back(termArg);
3868  newResultTypes.emplace_back(result.getType());
3869  newArgLocs.emplace_back(result.getLoc());
3870  }
3871  }
3872 
3873  if (!needUpdate)
3874  return failure();
3875 
3876  {
3877  OpBuilder::InsertionGuard g(rewriter);
3878  rewriter.setInsertionPoint(term);
3879  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3880  newTermArgs);
3881  }
3882 
3883  auto newWhile =
3884  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3885 
3886  Block &newAfterBlock = *rewriter.createBlock(
3887  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3888 
3889  // Build new results list and new after block args (unused entries will be
3890  // null).
3891  SmallVector<Value> newResults(op.getNumResults());
3892  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3893  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3894  newResults[it.value()] = newWhile.getResult(it.index());
3895  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3896  }
3897 
3898  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3899  newWhile.getBefore().begin());
3900 
3901  Block &afterBlock = *op.getAfterBody();
3902  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3903 
3904  rewriter.replaceOp(op, newResults);
3905  return success();
3906  }
3907 };
3908 
3909 /// Replace operations equivalent to the condition in the do block with true,
3910 /// since otherwise the block would not be evaluated.
3911 ///
3912 /// scf.while (..) : (i32, ...) -> ... {
3913 /// %z = ... : i32
3914 /// %condition = cmpi pred %z, %a
3915 /// scf.condition(%condition) %z : i32, ...
3916 /// } do {
3917 /// ^bb0(%arg0: i32, ...):
3918 /// %condition2 = cmpi pred %arg0, %a
3919 /// use(%condition2)
3920 /// ...
3921 ///
3922 /// becomes
3923 /// scf.while (..) : (i32, ...) -> ... {
3924 /// %z = ... : i32
3925 /// %condition = cmpi pred %z, %a
3926 /// scf.condition(%condition) %z : i32, ...
3927 /// } do {
3928 /// ^bb0(%arg0: i32, ...):
3929 /// use(%true)
3930 /// ...
3931 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3932  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3933
3934  LogicalResult matchAndRewrite(scf::WhileOp op,
3935  PatternRewriter &rewriter) const override {
3936  using namespace scf;
3937  auto cond = op.getConditionOp();
3938  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3939  if (!cmp)
3940  return failure();
3941  bool changed = false;
3942  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3943  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3944  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3945  continue;
3946  for (OpOperand &u :
3947  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3948  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3949  if (!cmp2)
3950  continue;
3951  // For a binary operator 1-opIdx gets the other side.
3952  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3953  continue;
3954  bool samePredicate;
3955  if (cmp2.getPredicate() == cmp.getPredicate())
3956  samePredicate = true;
3957  else if (cmp2.getPredicate() ==
3958  arith::invertPredicate(cmp.getPredicate()))
3959  samePredicate = false;
3960  else
3961  continue;
3962 
3963  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3964  1);
3965  changed = true;
3966  }
3967  }
3968  }
3969  return success(changed);
3970  }
3971 };
3972 
3973 /// Remove unused init/yield args.
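/// Added sketch (hypothetical values): when a `before` block argument such as
/// %arg1 below has no uses, its init value and the corresponding yield operand
/// are dropped as well.
///
///   %0 = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i32 { ... }
///
/// becomes
///
///   %0 = scf.while (%arg0 = %a) : (i32) -> i32 { ... }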
3974 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3975  using OpRewritePattern<WhileOp>::OpRewritePattern;
3976
3977  LogicalResult matchAndRewrite(WhileOp op,
3978  PatternRewriter &rewriter) const override {
3979 
3980  if (!llvm::any_of(op.getBeforeArguments(),
3981  [](Value arg) { return arg.use_empty(); }))
3982  return rewriter.notifyMatchFailure(op, "No args to remove");
3983 
3984  YieldOp yield = op.getYieldOp();
3985 
3986  // Collect results mapping, new terminator args and new result types.
3987  SmallVector<Value> newYields;
3988  SmallVector<Value> newInits;
3989  llvm::BitVector argsToErase;
3990 
3991  size_t argsCount = op.getBeforeArguments().size();
3992  newYields.reserve(argsCount);
3993  newInits.reserve(argsCount);
3994  argsToErase.reserve(argsCount);
3995  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3996  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3997  if (beforeArg.use_empty()) {
3998  argsToErase.push_back(true);
3999  } else {
4000  argsToErase.push_back(false);
4001  newYields.emplace_back(yieldValue);
4002  newInits.emplace_back(initValue);
4003  }
4004  }
4005 
4006  Block &beforeBlock = *op.getBeforeBody();
4007  Block &afterBlock = *op.getAfterBody();
4008 
4009  beforeBlock.eraseArguments(argsToErase);
4010 
4011  Location loc = op.getLoc();
4012  auto newWhileOp =
4013  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4014  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4015  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4016  Block &newAfterBlock = *newWhileOp.getAfterBody();
4017 
4018  OpBuilder::InsertionGuard g(rewriter);
4019  rewriter.setInsertionPoint(yield);
4020  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4021 
4022  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4023  newBeforeBlock.getArguments());
4024  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4025  newAfterBlock.getArguments());
4026 
4027  rewriter.replaceOp(op, newWhileOp.getResults());
4028  return success();
4029  }
4030 };
4031 
4032 /// Remove duplicated ConditionOp args.
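/// Added sketch (hypothetical values): when the same value is forwarded more
/// than once by scf.condition, the duplicated result and `after` block
/// argument are folded into one.
///
///   scf.condition(%cond) %v, %v : i32, i32
///
/// becomes
///
///   scf.condition(%cond) %v : i32
///
/// with uses of the removed result and block argument redirected to the
/// remaining one.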
4033 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4034  using OpRewritePattern<WhileOp>::OpRewritePattern;
4035
4036  LogicalResult matchAndRewrite(WhileOp op,
4037  PatternRewriter &rewriter) const override {
4038  ConditionOp condOp = op.getConditionOp();
4039  ValueRange condOpArgs = condOp.getArgs();
4040 
4041  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4042 
4043  if (argsSet.size() == condOpArgs.size())
4044  return rewriter.notifyMatchFailure(op, "No results to remove");
4045 
4046  llvm::SmallDenseMap<Value, unsigned> argsMap;
4047  SmallVector<Value> newArgs;
4048  argsMap.reserve(condOpArgs.size());
4049  newArgs.reserve(condOpArgs.size());
4050  for (Value arg : condOpArgs) {
4051  if (!argsMap.count(arg)) {
4052  auto pos = static_cast<unsigned>(argsMap.size());
4053  argsMap.insert({arg, pos});
4054  newArgs.emplace_back(arg);
4055  }
4056  }
4057 
4058  ValueRange argsRange(newArgs);
4059 
4060  Location loc = op.getLoc();
4061  auto newWhileOp = rewriter.create<scf::WhileOp>(
4062  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4063  /*afterBody*/ nullptr);
4064  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4065  Block &newAfterBlock = *newWhileOp.getAfterBody();
4066 
4067  SmallVector<Value> afterArgsMapping;
4068  SmallVector<Value> resultsMapping;
4069  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4070  auto it = argsMap.find(arg);
4071  assert(it != argsMap.end());
4072  auto pos = it->second;
4073  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4074  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4075  }
4076 
4077  OpBuilder::InsertionGuard g(rewriter);
4078  rewriter.setInsertionPoint(condOp);
4079  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4080  argsRange);
4081 
4082  Block &beforeBlock = *op.getBeforeBody();
4083  Block &afterBlock = *op.getAfterBody();
4084 
4085  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4086  newBeforeBlock.getArguments());
4087  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4088  rewriter.replaceOp(op, resultsMapping);
4089  return success();
4090  }
4091 };
4092 
4093 /// If both ranges contain the same values, return mapping indices from args2
4094 /// to args1. Otherwise return std::nullopt.
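/// Added example: for args1 = [a, b, c] and args2 = [c, a, b] the result is
/// [2, 0, 1], i.e. ret[k] is the position in args1 at which args2[k] appears.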
4095 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4096  ValueRange args2) {
4097  if (args1.size() != args2.size())
4098  return std::nullopt;
4099 
4100  SmallVector<unsigned> ret(args1.size());
4101  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4102  auto it = llvm::find(args2, arg1);
4103  if (it == args2.end())
4104  return std::nullopt;
4105 
4106  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4107  }
4108 
4109  return ret;
4110 }
4111 
4112 static bool hasDuplicates(ValueRange args) {
4113  llvm::SmallDenseSet<Value> set;
4114  for (Value arg : args) {
4115  if (!set.insert(arg).second)
4116  return true;
4117  }
4118  return false;
4119 }
4120 
4121 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4122 /// `scf.condition` args into the same order as the block args. Update the
4123 /// `after` block args and op result values accordingly.
4124 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
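/// Added sketch (hypothetical values): with `before` block arguments
/// (%arg0, %arg1) and terminator `scf.condition(%c) %arg1, %arg0`, the
/// terminator is rewritten to `scf.condition(%c) %arg0, %arg1`, and the
/// `after` block arguments and op results are permuted to compensate.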
4125 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4126  using OpRewritePattern<WhileOp>::OpRewritePattern;
4127
4128  LogicalResult matchAndRewrite(WhileOp loop,
4129  PatternRewriter &rewriter) const override {
4130  auto oldBefore = loop.getBeforeBody();
4131  ConditionOp oldTerm = loop.getConditionOp();
4132  ValueRange beforeArgs = oldBefore->getArguments();
4133  ValueRange termArgs = oldTerm.getArgs();
4134  if (beforeArgs == termArgs)
4135  return failure();
4136 
4137  if (hasDuplicates(termArgs))
4138  return failure();
4139 
4140  auto mapping = getArgsMapping(beforeArgs, termArgs);
4141  if (!mapping)
4142  return failure();
4143 
4144  {
4145  OpBuilder::InsertionGuard g(rewriter);
4146  rewriter.setInsertionPoint(oldTerm);
4147  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4148  beforeArgs);
4149  }
4150 
4151  auto oldAfter = loop.getAfterBody();
4152 
4153  SmallVector<Type> newResultTypes(beforeArgs.size());
4154  for (auto &&[i, j] : llvm::enumerate(*mapping))
4155  newResultTypes[j] = loop.getResult(i).getType();
4156 
4157  auto newLoop = rewriter.create<WhileOp>(
4158  loop.getLoc(), newResultTypes, loop.getInits(),
4159  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4160  auto newBefore = newLoop.getBeforeBody();
4161  auto newAfter = newLoop.getAfterBody();
4162 
4163  SmallVector<Value> newResults(beforeArgs.size());
4164  SmallVector<Value> newAfterArgs(beforeArgs.size());
4165  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4166  newResults[i] = newLoop.getResult(j);
4167  newAfterArgs[i] = newAfter->getArgument(j);
4168  }
4169 
4170  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4171  newBefore->getArguments());
4172  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4173  newAfterArgs);
4174 
4175  rewriter.replaceOp(loop, newResults);
4176  return success();
4177  }
4178 };
4179 } // namespace
4180 
4181 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4182  MLIRContext *context) {
4183  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4184  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4185  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4186  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4187 }
4188 
4189 //===----------------------------------------------------------------------===//
4190 // IndexSwitchOp
4191 //===----------------------------------------------------------------------===//
4192 
4193 /// Parse the case regions and values.
4194 static ParseResult
4195 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4196  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4197  SmallVector<int64_t> caseValues;
4198  while (succeeded(p.parseOptionalKeyword("case"))) {
4199  int64_t value;
4200  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4201  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4202  return failure();
4203  caseValues.push_back(value);
4204  }
4205  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4206  return success();
4207 }
4208 
4209 /// Print the case regions and values.
4210 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4211  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4212  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4213  p.printNewline();
4214  p << "case " << value << ' ';
4215  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4216  }
4217 }
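// Added illustration (hypothetical values): together with the generated
// printer for the op, the cases print roughly as:
//
//   %0 = scf.index_switch %index -> i32
//   case 2 {
//     %1 = arith.constant 10 : i32
//     scf.yield %1 : i32
//   }
//   default {
//     %2 = arith.constant 0 : i32
//     scf.yield %2 : i32
//   }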
4218 
4219 LogicalResult scf::IndexSwitchOp::verify() {
4220  if (getCases().size() != getCaseRegions().size()) {
4221  return emitOpError("has ")
4222  << getCaseRegions().size() << " case regions but "
4223  << getCases().size() << " case values";
4224  }
4225 
4226  DenseSet<int64_t> valueSet;
4227  for (int64_t value : getCases())
4228  if (!valueSet.insert(value).second)
4229  return emitOpError("has duplicate case value: ") << value;
4230  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4231  auto yield = dyn_cast<YieldOp>(region.front().back());
4232  if (!yield)
4233  return emitOpError("expected region to end with scf.yield, but got ")
4234  << region.front().back().getName();
4235 
4236  if (yield.getNumOperands() != getNumResults()) {
4237  return (emitOpError("expected each region to return ")
4238  << getNumResults() << " values, but " << name << " returns "
4239  << yield.getNumOperands())
4240  .attachNote(yield.getLoc())
4241  << "see yield operation here";
4242  }
4243  for (auto [idx, result, operand] :
4244  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4245  yield.getOperandTypes())) {
4246  if (result == operand)
4247  continue;
4248  return (emitOpError("expected result #")
4249  << idx << " of each region to be " << result)
4250  .attachNote(yield.getLoc())
4251  << name << " returns " << operand << " here";
4252  }
4253  return success();
4254  };
4255 
4256  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4257  return failure();
4258  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4259  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4260  return failure();
4261 
4262  return success();
4263 }
4264 
4265 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4266 
4267 Block &scf::IndexSwitchOp::getDefaultBlock() {
4268  return getDefaultRegion().front();
4269 }
4270 
4271 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4272  assert(idx < getNumCases() && "case index out-of-bounds");
4273  return getCaseRegions()[idx].front();
4274 }
4275 
4276 void IndexSwitchOp::getSuccessorRegions(
4277  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4278  // All regions branch back to the parent op.
4279  if (!point.isParent()) {
4280  successors.emplace_back(getResults());
4281  return;
4282  }
4283 
4284  llvm::copy(getRegions(), std::back_inserter(successors));
4285 }
4286 
4287 void IndexSwitchOp::getEntrySuccessorRegions(
4288  ArrayRef<Attribute> operands,
4289  SmallVectorImpl<RegionSuccessor> &successors) {
4290  FoldAdaptor adaptor(operands, *this);
4291 
4292  // If a constant was not provided, all regions are possible successors.
4293  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4294  if (!arg) {
4295  llvm::copy(getRegions(), std::back_inserter(successors));
4296  return;
4297  }
4298 
4299  // Otherwise, try to find a case with a matching value. If not, the
4300  // default region is the only successor.
4301  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4302  if (caseValue == arg.getInt()) {
4303  successors.emplace_back(&caseRegion);
4304  return;
4305  }
4306  }
4307  successors.emplace_back(&getDefaultRegion());
4308 }
4309 
4310 void IndexSwitchOp::getRegionInvocationBounds(
4311  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4312  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4313  if (!operandValue) {
4314  // All regions are invoked at most once.
4315  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4316  return;
4317  }
4318 
4319  unsigned liveIndex = getNumRegions() - 1;
4320  const auto *it = llvm::find(getCases(), operandValue.getInt());
4321  if (it != getCases().end())
4322  liveIndex = std::distance(getCases().begin(), it);
4323  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4324  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4325 }
4326 
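/// Fold scf.index_switch with a constant argument (added summary): inline the
/// case region whose value matches the constant, or the default region if none
/// matches, and replace the switch with the values yielded by that region.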
4327 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4328  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4329
4330  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4331  PatternRewriter &rewriter) const override {
4332  // If `op.getArg()` is a constant, select the region that matches the
4333  // constant value. Use the default region if no match is found.
4334  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4335  if (!maybeCst.has_value())
4336  return failure();
4337  int64_t cst = *maybeCst;
4338  int64_t caseIdx, e = op.getNumCases();
4339  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4340  if (cst == op.getCases()[caseIdx])
4341  break;
4342  }
4343 
4344  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4345  : op.getDefaultRegion();
4346  Block &source = r.front();
4347  Operation *terminator = source.getTerminator();
4348  SmallVector<Value> results = terminator->getOperands();
4349 
4350  rewriter.inlineBlockBefore(&source, op);
4351  rewriter.eraseOp(terminator);
4352  // Replace the operation with a potentially empty list of results.
4353  // Fold mechanism doesn't support the case where the result list is empty.
4354  rewriter.replaceOp(op, results);
4355 
4356  return success();
4357  }
4358 };
4359 
4360 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4361  MLIRContext *context) {
4362  results.add<FoldConstantCase>(context);
4363 }
4364 
4365 //===----------------------------------------------------------------------===//
4366 // TableGen'd op method definitions
4367 //===----------------------------------------------------------------------===//
4368 
4369 #define GET_OP_CLASSES
4370 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:736
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:728
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1286
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4195
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3443
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:431
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:92
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4210
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:295
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:307
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
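Sketch of the clone-with-remapping idiom these two methods enable (illustrative helper; not from this file):
static Operation *cloneWithRemap(OpBuilder &builder, Operation *op,
                                 Value oldValue, Value newValue) {
  IRMapping mapping;
  mapping.map(oldValue, newValue);          // uses of oldValue become newValue
  Operation *copy = builder.clone(*op, mapping);
  // Values without an entry fall through unchanged.
  (void)mapping.lookupOrDefault(oldValue);  // == newValue here
  return copy;
}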
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
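A sketch of the parse/print pair these hooks support, for a hypothetical op with the form "foo.bar %operand attr-dict : type" (names and syntax are illustrative):
static ParseResult parseFooBar(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand operand;
  Type type;
  if (parser.parseOperand(operand) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(operand, type, result.operands))
    return failure();
  result.addTypes(type);
  return success();
}

static void printFooBar(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperand(0);
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getOperand(0).getType();
}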
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:571
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation regions empty.
Definition: Builders.h:582
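A minimal sketch of the insertion-point discipline these methods implement (illustrative helper):
static Operation *cloneAtStart(OpBuilder &builder, Block *block,
                               Operation *op) {
  OpBuilder::InsertionGuard guard(builder); // restores the point on return
  builder.setInsertionPointToStart(block);
  return builder.clone(*op);                // copy lands at the block's start
}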
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:433
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:445
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the operation belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
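Sketch of the kind of Operation queries used throughout this file's verifiers (the diagnostic text and helper are illustrative):
static LogicalResult checkIndexOperands(Operation *op) {
  if (auto parentFor = op->getParentOfType<scf::ForOp>())
    (void)parentFor;                  // closest enclosing scf.for, if any
  for (Value operand : op->getOperands())
    if (!operand.getType().isIndex())
      return op->emitOpError("expected all operands to be of index type");
  return success();
}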
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:815
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:362
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:606
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:582
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
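The rewriter methods above are normally exercised through patterns shaped like the following sketch (the pattern and the "visited" attribute are hypothetical, not part of SCF):
struct TagForOp : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    if (op->hasAttr("visited"))
      return rewriter.notifyMatchFailure(op, "already tagged");
    // In-place modifications must be bracketed; modifyOpInPlace does so.
    rewriter.modifyOpInPlace(
        op, [&] { op->setAttr("visited", rewriter.getUnitAttr()); });
    return success();
  }
};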
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:194
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
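A small sketch combining these Value accessors (illustrative helper name):
static Operation *definingOpIfUsed(Value value) {
  if (value.use_empty())
    return nullptr;                 // nothing depends on this value
  return value.getDefiningOp();     // null when the value is a block argument
}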
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3056
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:692
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1984
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:649
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:602
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1458
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:781
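Sketch of scf::buildLoopNest building a two-dimensional nest (the bounds are reused for both dimensions purely for illustration; the helper name is hypothetical):
static void emitSquareNest(OpBuilder &b, Location loc, Value lb, Value ub,
                           Value step) {
  SmallVector<Value> lbs = {lb, lb}, ubs = {ub, ub}, steps = {step, step};
  scf::buildLoopNest(b, loc, lbs, ubs, steps,
                     [](OpBuilder &nested, Location nestedLoc,
                        ValueRange ivs) {
                       // One induction variable per loop is passed in `ivs`;
                       // the innermost body would be built with `nested` here.
                       (void)nested; (void)nestedLoc; (void)ivs;
                     });
}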
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:270
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
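Sketch of the constant-matching idiom these helpers support (the wrapper function is illustrative):
static std::optional<int64_t> getConstantStep(Value step) {
  APInt value;
  if (matchPattern(step, m_ConstantInt(&value)))
    return value.getSExtValue();    // step is defined by an integer constant
  return std::nullopt;              // step is not a compile-time constant
}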
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4330
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:323
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
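A sketch of populating an OperationState generically before OpBuilder::create materializes it (the op name and attribute are placeholders; a real builder would normally target a registered op):
static Operation *buildPlaceholderOp(OpBuilder &b, Location loc,
                                     ValueRange operands, Type resultType) {
  OperationState state(loc, "mydialect.placeholder"); // hypothetical op name
  state.addOperands(operands);
  state.addTypes(resultType);
  state.addAttribute("tag", b.getStringAttr("example"));
  state.addRegion();                                  // one empty region
  return b.create(state);
}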