MLIR  20.0.0git
SCF.cpp
Go to the documentation of this file.
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
/// Dialect initialization hook: registers every SCF operation, the dialect's
/// inliner interface, and the interfaces SCF ops promise to implement.
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
// Make SCF regions/ops inlinable through SCFInlinerInterface.
75  addInterfaces<SCFInlinerInterface>();
// Promised interfaces: the listed ops are declared to implement these
// interfaces, but no implementation is registered here.
// NOTE(review): the implementations are presumably attached elsewhere as
// external models (bufferization / value-bounds libraries) — confirm the
// registration site.
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
// Single-op form of the same mechanism: only scf.for promises value bounds.
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-region block");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 /// block+
126 /// `}`
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// return %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after"
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange initArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
319  result.addOperands({lb, ub, step});
320  result.addOperands(initArgs);
321  for (Value v : initArgs)
322  result.addTypes(v.getType());
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : initArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (initArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
342 
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the body defines as single block argument for the induction
354  // variable.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if the forOp
399 /// it can be determined that the loop has a single iteration.
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401  std::optional<int64_t> tripCount =
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
446 
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Resolve input operands.
503  regionArgs.front().type = type;
504  if (parser.resolveOperand(lb, type, result.operands) ||
505  parser.resolveOperand(ub, type, result.operands) ||
506  parser.resolveOperand(step, type, result.operands))
507  return failure();
508  if (hasIterArgs) {
509  for (auto argOperandType :
510  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
511  Type type = std::get<2>(argOperandType);
512  std::get<0>(argOperandType).type = type;
513  if (parser.resolveOperand(std::get<1>(argOperandType), type,
514  result.operands))
515  return failure();
516  }
517  }
518 
519  // Parse the body region.
520  Region *body = result.addRegion();
521  if (parser.parseRegion(*body, regionArgs))
522  return failure();
523 
524  ForOp::ensureTerminator(*body, builder, result.location);
525 
526  // Parse the optional attribute list.
527  if (parser.parseOptionalAttrDict(result.attributes))
528  return failure();
529 
530  return success();
531 }
532 
533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534 
535 Block::BlockArgListType ForOp::getRegionIterArgs() {
536  return getBody()->getArguments().drop_front(getNumInductionVars());
537 }
538 
539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
540  return getInitArgsMutable();
541 }
542 
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
545  ValueRange newInitOperands,
546  bool replaceInitOperandUsesInLoop,
547  const NewYieldValuesFn &newYieldValuesFn) {
548  // Create a new loop before the existing one, with the extra operands.
549  OpBuilder::InsertionGuard g(rewriter);
550  rewriter.setInsertionPoint(getOperation());
551  auto inits = llvm::to_vector(getInitArgs());
552  inits.append(newInitOperands.begin(), newInitOperands.end());
553  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
554  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
555  [](OpBuilder &, Location, Value, ValueRange) {});
556  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
596 
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for loop not to enter the
614  // body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
686 
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
774 
777  OpOperand &operand, Value replacement,
778  const ValueTypeCastFnTy &castFn) {
779  assert(operand.getOwner() == forOp);
780  Type oldType = operand.get().getType(), newType = replacement.getType();
781 
782  // 1. Create new iter operands, exactly 1 is replaced.
783  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
784  "expected an iter OpOperand");
785  assert(operand.get().getType() != replacement.getType() &&
786  "Expected a different type");
787  SmallVector<Value> newIterOperands;
788  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
789  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
790  newIterOperands.push_back(replacement);
791  continue;
792  }
793  newIterOperands.push_back(opOperand.get());
794  }
795 
796  // 2. Create the new forOp shell.
797  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
798  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799  forOp.getStep(), newIterOperands);
800  newForOp->setAttrs(forOp->getAttrs());
801  Block &newBlock = newForOp.getRegion().front();
802  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
803  newBlock.getArguments().end());
804 
805  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806  // corresponding to the `replacement` value.
807  OpBuilder::InsertionGuard g(rewriter);
808  rewriter.setInsertionPointToStart(&newBlock);
809  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
810  &newForOp->getOpOperand(operand.getOperandNumber()));
811  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
813 
814  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815  Block &oldBlock = forOp.getRegion().front();
816  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
817 
818  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
820  rewriter.setInsertionPoint(clonedYieldOp);
821  unsigned yieldIdx =
822  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
823  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824  clonedYieldOp.getOperand(yieldIdx));
825  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
826  newYieldOperands[yieldIdx] = castOut;
827  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828  rewriter.eraseOp(clonedYieldOp);
829 
830  // 6. Inject an outgoing cast op after the forOp.
831  rewriter.setInsertionPointAfter(newForOp);
832  SmallVector<Value> newResults = newForOp.getResults();
833  newResults[yieldIdx] =
834  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
835 
836  return newResults;
837 }
838 
839 namespace {
840 // Fold away ForOp iter arguments when:
841 // 1) The op yields the iter arguments.
842 // 2) The iter arguments have no use and the corresponding outer region
843 // iterators (inputs) are yielded.
844 // 3) The iter arguments have no use and the corresponding (operation) results
845 // have no use.
846 //
847 // These arguments must be defined outside of
848 // the ForOp region and can just be forwarded after simplifying the op inits,
849 // yields and returns.
850 //
851 // The implementation uses `inlineBlockBefore` to steal the content of the
852 // original ForOp and avoid cloning.
853 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
855 
856  LogicalResult matchAndRewrite(scf::ForOp forOp,
857  PatternRewriter &rewriter) const final {
858  bool canonicalize = false;
859 
860  // An internal flat vector of block transfer
861  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
862  // transformed block argument mappings. This plays the role of a
863  // IRMapping for the particular use case of calling into
864  // `inlineBlockBefore`.
865  int64_t numResults = forOp.getNumResults();
866  SmallVector<bool, 4> keepMask;
867  keepMask.reserve(numResults);
868  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
869  newResultValues;
870  newBlockTransferArgs.reserve(1 + numResults);
871  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
872  newIterArgs.reserve(forOp.getInitArgs().size());
873  newYieldValues.reserve(numResults);
874  newResultValues.reserve(numResults);
875  for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
876  forOp.getRegionIterArgs(), // iter inside region
877  forOp.getResults(), // op results
878  forOp.getYieldedValues() // iter yield
879  )) {
880  // Forwarded is `true` when:
881  // 1) The region `iter` argument is yielded.
882  // 2) The region `iter` argument has no use, and the corresponding iter
883  // operand (input) is yielded.
884  // 3) The region `iter` argument has no use, and the corresponding op
885  // result has no use.
886  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
887  (std::get<1>(it).use_empty() &&
888  (std::get<0>(it) == std::get<3>(it) ||
889  std::get<2>(it).use_empty())));
890  keepMask.push_back(!forwarded);
891  canonicalize |= forwarded;
892  if (forwarded) {
893  newBlockTransferArgs.push_back(std::get<0>(it));
894  newResultValues.push_back(std::get<0>(it));
895  continue;
896  }
897  newIterArgs.push_back(std::get<0>(it));
898  newYieldValues.push_back(std::get<3>(it));
899  newBlockTransferArgs.push_back(Value()); // placeholder with null value
900  newResultValues.push_back(Value()); // placeholder with null value
901  }
902 
903  if (!canonicalize)
904  return failure();
905 
906  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
907  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
908  forOp.getStep(), newIterArgs);
909  newForOp->setAttrs(forOp->getAttrs());
910  Block &newBlock = newForOp.getRegion().front();
911 
912  // Replace the null placeholders with newly constructed values.
913  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
914  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
915  idx != e; ++idx) {
916  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
917  Value &newResultVal = newResultValues[idx];
918  assert((blockTransferArg && newResultVal) ||
919  (!blockTransferArg && !newResultVal));
920  if (!blockTransferArg) {
921  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
922  newResultVal = newForOp.getResult(collapsedIdx++);
923  }
924  }
925 
926  Block &oldBlock = forOp.getRegion().front();
927  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
928  "unexpected argument size mismatch");
929 
930  // No results case: the scf::ForOp builder already created a zero
931  // result terminator. Merge before this terminator and just get rid of the
932  // original terminator that has been merged in.
933  if (newIterArgs.empty()) {
934  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
935  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
936  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
937  rewriter.replaceOp(forOp, newResultValues);
938  return success();
939  }
940 
941  // No terminator case: merge and rewrite the merged terminator.
942  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
943  OpBuilder::InsertionGuard g(rewriter);
944  rewriter.setInsertionPoint(mergedTerminator);
945  SmallVector<Value, 4> filteredOperands;
946  filteredOperands.reserve(newResultValues.size());
947  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
948  if (keepMask[idx])
949  filteredOperands.push_back(mergedTerminator.getOperand(idx));
950  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
951  filteredOperands);
952  };
953 
954  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
955  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
956  cloneFilteredTerminator(mergedYieldOp);
957  rewriter.eraseOp(mergedYieldOp);
958  rewriter.replaceOp(forOp, newResultValues);
959  return success();
960  }
961 };
962 
963 /// Util function that tries to compute a constant diff between u and l.
964 /// Returns std::nullopt when the difference between two AffineValueMap is
965 /// dynamic.
966 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
967  IntegerAttr clb, cub;
968  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
969  llvm::APInt lbValue = clb.getValue();
970  llvm::APInt ubValue = cub.getValue();
971  return (ubValue - lbValue).getSExtValue();
972  }
973 
974  // Else a simple pattern match for x + c or c + x
975  llvm::APInt diff;
976  if (matchPattern(
977  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
978  matchPattern(
979  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
980  return diff.getSExtValue();
981  return std::nullopt;
982 }
983 
984 /// Rewriting pattern that erases loops that are known not to iterate, replaces
985 /// single-iteration loops with their bodies, and removes empty loops that
986 /// iterate at least once and only return values defined outside of the loop.
987 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
989 
990  LogicalResult matchAndRewrite(ForOp op,
991  PatternRewriter &rewriter) const override {
992  // If the upper bound is the same as the lower bound, the loop does not
993  // iterate, just remove it.
994  if (op.getLowerBound() == op.getUpperBound()) {
995  rewriter.replaceOp(op, op.getInitArgs());
996  return success();
997  }
998 
999  std::optional<int64_t> diff =
1000  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1001  if (!diff)
1002  return failure();
1003 
1004  // If the loop is known to have 0 iterations, remove it.
1005  if (*diff <= 0) {
1006  rewriter.replaceOp(op, op.getInitArgs());
1007  return success();
1008  }
1009 
1010  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1011  if (!maybeStepValue)
1012  return failure();
1013 
1014  // If the loop is known to have 1 iteration, inline its body and remove the
1015  // loop.
1016  llvm::APInt stepValue = *maybeStepValue;
1017  if (stepValue.sge(*diff)) {
1018  SmallVector<Value, 4> blockArgs;
1019  blockArgs.reserve(op.getInitArgs().size() + 1);
1020  blockArgs.push_back(op.getLowerBound());
1021  llvm::append_range(blockArgs, op.getInitArgs());
1022  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1023  return success();
1024  }
1025 
1026  // Now we are left with loops that have more than 1 iterations.
1027  Block &block = op.getRegion().front();
1028  if (!llvm::hasSingleElement(block))
1029  return failure();
1030  // If the loop is empty, iterates at least once, and only returns values
1031  // defined outside of the loop, remove it and replace it with yield values.
1032  if (llvm::any_of(op.getYieldedValues(),
1033  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1034  return failure();
1035  rewriter.replaceOp(op, op.getYieldedValues());
1036  return success();
1037  }
1038 };
1039 
1040 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
1041 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1042 ///
1043 /// ```
1044 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1045 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1046 /// -> (tensor<?x?xf32>) {
1047 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1048 /// scf.yield %2 : tensor<?x?xf32>
1049 /// }
1050 /// use_of(%1)
1051 /// ```
1052 ///
1053 /// folds into:
1054 ///
1055 /// ```
1056 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1057 /// -> (tensor<32x1024xf32>) {
1058 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1059 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1060 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1061 /// scf.yield %4 : tensor<32x1024xf32>
1062 /// }
1063 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1064 /// use_of(%1)
1065 /// ```
1066 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1068 
1069  LogicalResult matchAndRewrite(ForOp op,
1070  PatternRewriter &rewriter) const override {
1071  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1072  OpOperand &iterOpOperand = std::get<0>(it);
1073  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1074  if (!incomingCast ||
1075  incomingCast.getSource().getType() == incomingCast.getType())
1076  continue;
1077  // If the dest type of the cast does not preserve static information in
1078  // the source type.
1080  incomingCast.getDest().getType(),
1081  incomingCast.getSource().getType()))
1082  continue;
1083  if (!std::get<1>(it).hasOneUse())
1084  continue;
1085 
1086  // Create a new ForOp with that iter operand replaced.
1087  rewriter.replaceOp(
1089  rewriter, op, iterOpOperand, incomingCast.getSource(),
1090  [](OpBuilder &b, Location loc, Type type, Value source) {
1091  return b.create<tensor::CastOp>(loc, type, source);
1092  }));
1093  return success();
1094  }
1095  return failure();
1096  }
1097 };
1098 
1099 } // namespace
1100 
1101 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1102  MLIRContext *context) {
1103  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1104  context);
1105 }
1106 
1107 std::optional<APInt> ForOp::getConstantStep() {
1108  IntegerAttr step;
1109  if (matchPattern(getStep(), m_Constant(&step)))
1110  return step.getValue();
1111  return {};
1112 }
1113 
1114 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1115  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1116 }
1117 
1118 Speculation::Speculatability ForOp::getSpeculatability() {
1119  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1120  // and End.
1121  if (auto constantStep = getConstantStep())
1122  if (*constantStep == 1)
1124 
1125  // For Step != 1, the loop may not terminate. We can add more smarts here if
1126  // needed.
1128 }
1129 
1130 //===----------------------------------------------------------------------===//
1131 // ForallOp
1132 //===----------------------------------------------------------------------===//
1133 
1134 LogicalResult ForallOp::verify() {
1135  unsigned numLoops = getRank();
1136  // Check number of outputs.
1137  if (getNumResults() != getOutputs().size())
1138  return emitOpError("produces ")
1139  << getNumResults() << " results, but has only "
1140  << getOutputs().size() << " outputs";
1141 
1142  // Check that the body defines block arguments for thread indices and outputs.
1143  auto *body = getBody();
1144  if (body->getNumArguments() != numLoops + getOutputs().size())
1145  return emitOpError("region expects ") << numLoops << " arguments";
1146  for (int64_t i = 0; i < numLoops; ++i)
1147  if (!body->getArgument(i).getType().isIndex())
1148  return emitOpError("expects ")
1149  << i << "-th block argument to be an index";
1150  for (unsigned i = 0; i < getOutputs().size(); ++i)
1151  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1152  return emitOpError("type mismatch between ")
1153  << i << "-th output and corresponding block argument";
1154  if (getMapping().has_value() && !getMapping()->empty()) {
1155  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1156  return emitOpError() << "mapping attribute size must match op rank";
1157  for (auto map : getMapping()->getValue()) {
1158  if (!isa<DeviceMappingAttrInterface>(map))
1159  return emitOpError()
1160  << getMappingAttrName() << " is not device mapping attribute";
1161  }
1162  }
1163 
1164  // Verify mixed static/dynamic control variables.
1165  Operation *op = getOperation();
1166  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1167  getStaticLowerBound(),
1168  getDynamicLowerBound())))
1169  return failure();
1170  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1171  getStaticUpperBound(),
1172  getDynamicUpperBound())))
1173  return failure();
1174  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1175  getStaticStep(), getDynamicStep())))
1176  return failure();
1177 
1178  return success();
1179 }
1180 
1181 void ForallOp::print(OpAsmPrinter &p) {
1182  Operation *op = getOperation();
1183  p << " (" << getInductionVars();
1184  if (isNormalized()) {
1185  p << ") in ";
1186  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1187  /*valueTypes=*/{}, /*scalables=*/{},
1189  } else {
1190  p << ") = ";
1191  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1192  /*valueTypes=*/{}, /*scalables=*/{},
1194  p << " to ";
1195  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1196  /*valueTypes=*/{}, /*scalables=*/{},
1198  p << " step ";
1199  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1200  /*valueTypes=*/{}, /*scalables=*/{},
1202  }
1203  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1204  p << " ";
1205  if (!getRegionOutArgs().empty())
1206  p << "-> (" << getResultTypes() << ") ";
1207  p.printRegion(getRegion(),
1208  /*printEntryBlockArgs=*/false,
1209  /*printBlockTerminators=*/getNumResults() > 0);
1210  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1211  getStaticLowerBoundAttrName(),
1212  getStaticUpperBoundAttrName(),
1213  getStaticStepAttrName()});
1214 }
1215 
1216 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1217  OpBuilder b(parser.getContext());
1218  auto indexType = b.getIndexType();
1219 
1220  // Parse an opening `(` followed by thread index variables followed by `)`
1221  // TODO: when we can refer to such "induction variable"-like handles from the
1222  // declarative assembly format, we can implement the parser as a custom hook.
1225  return failure();
1226 
1227  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1228  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1229  dynamicSteps;
1230  if (succeeded(parser.parseOptionalKeyword("in"))) {
1231  // Parse upper bounds.
1232  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1233  /*valueTypes=*/nullptr,
1235  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1236  return failure();
1237 
1238  unsigned numLoops = ivs.size();
1239  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1240  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1241  } else {
1242  // Parse lower bounds.
1243  if (parser.parseEqual() ||
1244  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1245  /*valueTypes=*/nullptr,
1247 
1248  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1249  return failure();
1250 
1251  // Parse upper bounds.
1252  if (parser.parseKeyword("to") ||
1253  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1254  /*valueTypes=*/nullptr,
1256  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1257  return failure();
1258 
1259  // Parse step values.
1260  if (parser.parseKeyword("step") ||
1261  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1262  /*valueTypes=*/nullptr,
1264  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1265  return failure();
1266  }
1267 
1268  // Parse out operands and results.
1271  SMLoc outOperandsLoc = parser.getCurrentLocation();
1272  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1273  if (outOperands.size() != result.types.size())
1274  return parser.emitError(outOperandsLoc,
1275  "mismatch between out operands and types");
1276  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1277  parser.parseOptionalArrowTypeList(result.types) ||
1278  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1279  result.operands))
1280  return failure();
1281  }
1282 
1283  // Parse region.
1285  std::unique_ptr<Region> region = std::make_unique<Region>();
1286  for (auto &iv : ivs) {
1287  iv.type = b.getIndexType();
1288  regionArgs.push_back(iv);
1289  }
1290  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1291  auto &out = it.value();
1292  out.type = result.types[it.index()];
1293  regionArgs.push_back(out);
1294  }
1295  if (parser.parseRegion(*region, regionArgs))
1296  return failure();
1297 
1298  // Ensure terminator and move region.
1299  ForallOp::ensureTerminator(*region, b, result.location);
1300  result.addRegion(std::move(region));
1301 
1302  // Parse the optional attribute list.
1303  if (parser.parseOptionalAttrDict(result.attributes))
1304  return failure();
1305 
1306  result.addAttribute("staticLowerBound", staticLbs);
1307  result.addAttribute("staticUpperBound", staticUbs);
1308  result.addAttribute("staticStep", staticSteps);
1309  result.addAttribute("operandSegmentSizes",
1311  {static_cast<int32_t>(dynamicLbs.size()),
1312  static_cast<int32_t>(dynamicUbs.size()),
1313  static_cast<int32_t>(dynamicSteps.size()),
1314  static_cast<int32_t>(outOperands.size())}));
1315  return success();
1316 }
1317 
1318 // Builder that takes loop bounds.
1319 void ForallOp::build(
1322  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1323  std::optional<ArrayAttr> mapping,
1324  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1325  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1326  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1327  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1328  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1329  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1330 
1331  result.addOperands(dynamicLbs);
1332  result.addOperands(dynamicUbs);
1333  result.addOperands(dynamicSteps);
1334  result.addOperands(outputs);
1335  result.addTypes(TypeRange(outputs));
1336 
1337  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1338  b.getDenseI64ArrayAttr(staticLbs));
1339  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1340  b.getDenseI64ArrayAttr(staticUbs));
1341  result.addAttribute(getStaticStepAttrName(result.name),
1342  b.getDenseI64ArrayAttr(staticSteps));
1343  result.addAttribute(
1344  "operandSegmentSizes",
1345  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1346  static_cast<int32_t>(dynamicUbs.size()),
1347  static_cast<int32_t>(dynamicSteps.size()),
1348  static_cast<int32_t>(outputs.size())}));
1349  if (mapping.has_value()) {
1351  mapping.value());
1352  }
1353 
1354  Region *bodyRegion = result.addRegion();
1356  b.createBlock(bodyRegion);
1357  Block &bodyBlock = bodyRegion->front();
1358 
1359  // Add block arguments for indices and outputs.
1360  bodyBlock.addArguments(
1361  SmallVector<Type>(lbs.size(), b.getIndexType()),
1362  SmallVector<Location>(staticLbs.size(), result.location));
1363  bodyBlock.addArguments(
1364  TypeRange(outputs),
1365  SmallVector<Location>(outputs.size(), result.location));
1366 
1367  b.setInsertionPointToStart(&bodyBlock);
1368  if (!bodyBuilderFn) {
1369  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1370  return;
1371  }
1372  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1373 }
1374 
1375 // Builder that takes loop bounds.
1376 void ForallOp::build(
1378  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1379  std::optional<ArrayAttr> mapping,
1380  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1381  unsigned numLoops = ubs.size();
1382  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1383  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1384  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1385 }
1386 
1387 // Checks if the lbs are zeros and steps are ones.
1388 bool ForallOp::isNormalized() {
1389  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1390  return llvm::all_of(results, [&](OpFoldResult ofr) {
1391  auto intValue = getConstantIntValue(ofr);
1392  return intValue.has_value() && intValue == val;
1393  });
1394  };
1395  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1396 }
1397 
1398 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1399 // unaware of the fact that our terminator also needs a region to be
1400 // well-formed. We override it here to ensure that we do the right thing.
1401 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1402  Location loc) {
1404  ForallOp>::ensureTerminator(region, builder, loc);
1405  auto terminator =
1406  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1407  if (terminator.getRegion().empty())
1408  builder.createBlock(&terminator.getRegion());
1409 }
1410 
1411 InParallelOp ForallOp::getTerminator() {
1412  return cast<InParallelOp>(getBody()->getTerminator());
1413 }
1414 
1415 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1416  SmallVector<Operation *> storeOps;
1417  InParallelOp inParallelOp = getTerminator();
1418  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1419  if (auto parallelInsertSliceOp =
1420  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1421  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1422  storeOps.push_back(parallelInsertSliceOp);
1423  }
1424  }
1425  return storeOps;
1426 }
1427 
1428 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1429  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1430 }
1431 
1432 // Get lower bounds as OpFoldResult.
1433 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1434  Builder b(getOperation()->getContext());
1435  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1436 }
1437 
1438 // Get upper bounds as OpFoldResult.
1439 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1440  Builder b(getOperation()->getContext());
1441  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1442 }
1443 
1444 // Get steps as OpFoldResult.
1445 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1446  Builder b(getOperation()->getContext());
1447  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1448 }
1449 
1451  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1452  if (!tidxArg)
1453  return ForallOp();
1454  assert(tidxArg.getOwner() && "unlinked block argument");
1455  auto *containingOp = tidxArg.getOwner()->getParentOp();
1456  return dyn_cast<ForallOp>(containingOp);
1457 }
1458 
1459 namespace {
1460 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1461 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1463 
1464  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1465  PatternRewriter &rewriter) const final {
1466  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1467  if (!forallOp)
1468  return failure();
1469  Value sharedOut =
1470  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1471  ->get();
1472  rewriter.modifyOpInPlace(
1473  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1474  return success();
1475  }
1476 };
1477 
1478 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1479 public:
1481 
1482  LogicalResult matchAndRewrite(ForallOp op,
1483  PatternRewriter &rewriter) const override {
1484  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1485  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1486  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1487  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1488  failed(foldDynamicIndexList(mixedUpperBound)) &&
1489  failed(foldDynamicIndexList(mixedStep)))
1490  return failure();
1491 
1492  rewriter.modifyOpInPlace(op, [&]() {
1493  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1494  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1495  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1496  staticLowerBound);
1497  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1498  op.setStaticLowerBound(staticLowerBound);
1499 
1500  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1501  staticUpperBound);
1502  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1503  op.setStaticUpperBound(staticUpperBound);
1504 
1505  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1506  op.getDynamicStepMutable().assign(dynamicStep);
1507  op.setStaticStep(staticStep);
1508 
1509  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1510  rewriter.getDenseI32ArrayAttr(
1511  {static_cast<int32_t>(dynamicLowerBound.size()),
1512  static_cast<int32_t>(dynamicUpperBound.size()),
1513  static_cast<int32_t>(dynamicStep.size()),
1514  static_cast<int32_t>(op.getNumResults())}));
1515  });
1516  return success();
1517  }
1518 };
1519 
1520 /// The following canonicalization pattern folds the iter arguments of
1521 /// scf.forall op if :-
1522 /// 1. The corresponding result has zero uses.
1523 /// 2. The iter argument is NOT being modified within the loop body.
1524 /// uses.
1525 ///
1526 /// Example of first case :-
1527 /// INPUT:
1528 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1529 /// {
1530 /// ...
1531 /// <SOME USE OF %arg0>
1532 /// <SOME USE OF %arg1>
1533 /// <SOME USE OF %arg2>
1534 /// ...
1535 /// scf.forall.in_parallel {
1536 /// <STORE OP WITH DESTINATION %arg1>
1537 /// <STORE OP WITH DESTINATION %arg0>
1538 /// <STORE OP WITH DESTINATION %arg2>
1539 /// }
1540 /// }
1541 /// return %res#1
1542 ///
1543 /// OUTPUT:
1544 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1545 /// {
1546 /// ...
1547 /// <SOME USE OF %a>
1548 /// <SOME USE OF %new_arg0>
1549 /// <SOME USE OF %c>
1550 /// ...
1551 /// scf.forall.in_parallel {
1552 /// <STORE OP WITH DESTINATION %new_arg0>
1553 /// }
1554 /// }
1555 /// return %res
1556 ///
1557 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1558 /// scf.forall is replaced by their corresponding operands.
1559 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1560 /// of the scf.forall besides within scf.forall.in_parallel terminator,
1561 /// this canonicalization remains valid. For more details, please refer
1562 /// to :
1563 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1564 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1565 /// handles tensor.parallel_insert_slice ops only.
1566 ///
1567 /// Example of second case :-
1568 /// INPUT:
1569 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1570 /// {
1571 /// ...
1572 /// <SOME USE OF %arg0>
1573 /// <SOME USE OF %arg1>
1574 /// ...
1575 /// scf.forall.in_parallel {
1576 /// <STORE OP WITH DESTINATION %arg1>
1577 /// }
1578 /// }
1579 /// return %res#0, %res#1
1580 ///
1581 /// OUTPUT:
1582 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1583 /// {
1584 /// ...
1585 /// <SOME USE OF %a>
1586 /// <SOME USE OF %new_arg0>
1587 /// ...
1588 /// scf.forall.in_parallel {
1589 /// <STORE OP WITH DESTINATION %new_arg0>
1590 /// }
1591 /// }
1592 /// return %a, %res
1593 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1595 
1596  LogicalResult matchAndRewrite(ForallOp forallOp,
1597  PatternRewriter &rewriter) const final {
1598  // Step 1: For a given i-th result of scf.forall, check the following :-
1599  // a. If it has any use.
1600  // b. If the corresponding iter argument is being modified within
1601  // the loop, i.e. has at least one store op with the iter arg as
1602  // its destination operand. For this we use
1603  // ForallOp::getCombiningOps(iter_arg).
1604  //
1605  // Based on the check we maintain the following :-
1606  // a. `resultToDelete` - i-th result of scf.forall that'll be
1607  // deleted.
1608  // b. `resultToReplace` - i-th result of the old scf.forall
1609  // whose uses will be replaced by the new scf.forall.
1610  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1611  // corresponding to the i-th result with at least one use.
1612  SetVector<OpResult> resultToDelete;
1613  SmallVector<Value> resultToReplace;
1614  SmallVector<Value> newOuts;
1615  for (OpResult result : forallOp.getResults()) {
1616  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1617  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1618  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1619  resultToDelete.insert(result);
1620  } else {
1621  resultToReplace.push_back(result);
1622  newOuts.push_back(opOperand->get());
1623  }
1624  }
1625 
1626  // Return early if all results of scf.forall have at least one use and being
1627  // modified within the loop.
1628  if (resultToDelete.empty())
1629  return failure();
1630 
1631  // Step 2: For the the i-th result, do the following :-
1632  // a. Fetch the corresponding BlockArgument.
1633  // b. Look for store ops (currently tensor.parallel_insert_slice)
1634  // with the BlockArgument as its destination operand.
1635  // c. Remove the operations fetched in b.
1636  for (OpResult result : resultToDelete) {
1637  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1638  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1639  SmallVector<Operation *> combiningOps =
1640  forallOp.getCombiningOps(blockArg);
1641  for (Operation *combiningOp : combiningOps)
1642  rewriter.eraseOp(combiningOp);
1643  }
1644 
1645  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1646  // fetched earlier
1647  auto newForallOp = rewriter.create<scf::ForallOp>(
1648  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1649  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1650  forallOp.getMapping(),
1651  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1652 
1653  // Step 4. Merge the block of the old scf.forall into the newly created
1654  // scf.forall using the new set of arguments.
1655  Block *loopBody = forallOp.getBody();
1656  Block *newLoopBody = newForallOp.getBody();
1657  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1658  // Form initial new bbArg list with just the control operands of the new
1659  // scf.forall op.
1660  SmallVector<Value> newBlockArgs =
1661  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1662  [](BlockArgument b) -> Value { return b; });
1663  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1664  unsigned index = 0;
1665  // Take the new corresponding bbArg if the old bbArg was used as a
1666  // destination in the in_parallel op. For all other bbArgs, use the
1667  // corresponding init_arg from the old scf.forall op.
1668  for (OpResult result : forallOp.getResults()) {
1669  if (resultToDelete.count(result)) {
1670  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1671  } else {
1672  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1673  }
1674  }
1675  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1676 
1677  // Step 5. Replace the uses of result of old scf.forall with that of the new
1678  // scf.forall.
1679  for (auto &&[oldResult, newResult] :
1680  llvm::zip(resultToReplace, newForallOp->getResults()))
1681  rewriter.replaceAllUsesWith(oldResult, newResult);
1682 
1683  // Step 6. Replace the uses of those values that either has no use or are
1684  // not being modified within the loop with the corresponding
1685  // OpOperand.
1686  for (OpResult oldResult : resultToDelete)
1687  rewriter.replaceAllUsesWith(oldResult,
1688  forallOp.getTiedOpOperand(oldResult)->get());
1689  return success();
1690  }
1691 };
1692 
1693 struct ForallOpSingleOrZeroIterationDimsFolder
1694  : public OpRewritePattern<ForallOp> {
1696 
1697  LogicalResult matchAndRewrite(ForallOp op,
1698  PatternRewriter &rewriter) const override {
1699  // Do not fold dimensions if they are mapped to processing units.
1700  if (op.getMapping().has_value() && !op.getMapping()->empty())
1701  return failure();
1702  Location loc = op.getLoc();
1703 
1704  // Compute new loop bounds that omit all single-iteration loop dimensions.
1705  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1706  newMixedSteps;
1707  IRMapping mapping;
1708  for (auto [lb, ub, step, iv] :
1709  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1710  op.getMixedStep(), op.getInductionVars())) {
1711  auto numIterations = constantTripCount(lb, ub, step);
1712  if (numIterations.has_value()) {
1713  // Remove the loop if it performs zero iterations.
1714  if (*numIterations == 0) {
1715  rewriter.replaceOp(op, op.getOutputs());
1716  return success();
1717  }
1718  // Replace the loop induction variable by the lower bound if the loop
1719  // performs a single iteration. Otherwise, copy the loop bounds.
1720  if (*numIterations == 1) {
1721  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1722  continue;
1723  }
1724  }
1725  newMixedLowerBounds.push_back(lb);
1726  newMixedUpperBounds.push_back(ub);
1727  newMixedSteps.push_back(step);
1728  }
1729 
1730  // All of the loop dimensions perform a single iteration. Inline loop body.
1731  if (newMixedLowerBounds.empty()) {
1732  promote(rewriter, op);
1733  return success();
1734  }
1735 
1736  // Exit if none of the loop dimensions perform a single iteration.
1737  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1738  return rewriter.notifyMatchFailure(
1739  op, "no dimensions have 0 or 1 iterations");
1740  }
1741 
1742  // Replace the loop by a lower-dimensional loop.
1743  ForallOp newOp;
1744  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1745  newMixedUpperBounds, newMixedSteps,
1746  op.getOutputs(), std::nullopt, nullptr);
1747  newOp.getBodyRegion().getBlocks().clear();
1748  // The new loop needs to keep all attributes from the old one, except for
1749  // "operandSegmentSizes" and static loop bound attributes which capture
1750  // the outdated information of the old iteration domain.
1751  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1752  newOp.getStaticLowerBoundAttrName(),
1753  newOp.getStaticUpperBoundAttrName(),
1754  newOp.getStaticStepAttrName()};
1755  for (const auto &namedAttr : op->getAttrs()) {
1756  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1757  continue;
1758  rewriter.modifyOpInPlace(newOp, [&]() {
1759  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1760  });
1761  }
1762  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1763  newOp.getRegion().begin(), mapping);
1764  rewriter.replaceOp(op, newOp.getResults());
1765  return success();
1766  }
1767 };
1768 
1769 /// Replace all induction vars with a single trip count with their lower bound.
1770 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1772 
1773  LogicalResult matchAndRewrite(ForallOp op,
1774  PatternRewriter &rewriter) const override {
1775  Location loc = op.getLoc();
1776  bool changed = false;
1777  for (auto [lb, ub, step, iv] :
1778  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1779  op.getMixedStep(), op.getInductionVars())) {
1780  if (iv.getUses().begin() == iv.getUses().end())
1781  continue;
1782  auto numIterations = constantTripCount(lb, ub, step);
1783  if (!numIterations.has_value() || numIterations.value() != 1) {
1784  continue;
1785  }
1786  rewriter.replaceAllUsesWith(
1787  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1788  changed = true;
1789  }
1790  return success(changed);
1791  }
1792 };
1793 
1794 struct FoldTensorCastOfOutputIntoForallOp
1795  : public OpRewritePattern<scf::ForallOp> {
1797 
1798  struct TypeCast {
1799  Type srcType;
1800  Type dstType;
1801  };
1802 
1803  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1804  PatternRewriter &rewriter) const final {
1805  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1806  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1807  for (auto en : llvm::enumerate(newOutputTensors)) {
1808  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1809  if (!castOp)
1810  continue;
1811 
1812  // Only casts that that preserve static information, i.e. will make the
1813  // loop result type "more" static than before, will be folded.
1814  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1815  castOp.getSource().getType())) {
1816  continue;
1817  }
1818 
1819  tensorCastProducers[en.index()] =
1820  TypeCast{castOp.getSource().getType(), castOp.getType()};
1821  newOutputTensors[en.index()] = castOp.getSource();
1822  }
1823 
1824  if (tensorCastProducers.empty())
1825  return failure();
1826 
1827  // Create new loop.
1828  Location loc = forallOp.getLoc();
1829  auto newForallOp = rewriter.create<ForallOp>(
1830  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1831  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1832  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1833  auto castBlockArgs =
1834  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1835  for (auto [index, cast] : tensorCastProducers) {
1836  Value &oldTypeBBArg = castBlockArgs[index];
1837  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1838  nestedLoc, cast.dstType, oldTypeBBArg);
1839  }
1840 
1841  // Move old body into new parallel loop.
1842  SmallVector<Value> ivsBlockArgs =
1843  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1844  ivsBlockArgs.append(castBlockArgs);
1845  rewriter.mergeBlocks(forallOp.getBody(),
1846  bbArgs.front().getParentBlock(), ivsBlockArgs);
1847  });
1848 
1849  // After `mergeBlocks` happened, the destinations in the terminator were
1850  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1851  // destination have to be updated to point to the output bbArgs directly.
1852  auto terminator = newForallOp.getTerminator();
1853  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1854  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1855  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1856  insertSliceOp.getDestMutable().assign(outputBlockArg);
1857  }
1858 
1859  // Cast results back to the original types.
1860  rewriter.setInsertionPointAfter(newForallOp);
1861  SmallVector<Value> castResults = newForallOp.getResults();
1862  for (auto &item : tensorCastProducers) {
1863  Value &oldTypeResult = castResults[item.first];
1864  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1865  oldTypeResult);
1866  }
1867  rewriter.replaceOp(forallOp, castResults);
1868  return success();
1869  }
1870 };
1871 
1872 } // namespace
1873 
1874 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1875  MLIRContext *context) {
1876  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1877  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1878  ForallOpSingleOrZeroIterationDimsFolder,
1879  ForallOpReplaceConstantInductionVar>(context);
1880 }
1881 
1882 /// Given the region at `index`, or the parent operation if `index` is None,
1883 /// return the successor regions. These are the regions that may be selected
1884 /// during the flow of control. `operands` is a set of optional attributes that
1885 /// correspond to a constant value for each operand, or null if that operand is
1886 /// not a constant.
1887 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1889  // Both the operation itself and the region may be branching into the body or
1890  // back into the operation itself. It is possible for loop not to enter the
1891  // body.
1892  regions.push_back(RegionSuccessor(&getRegion()));
1893  regions.push_back(RegionSuccessor());
1894 }
1895 
1896 //===----------------------------------------------------------------------===//
1897 // InParallelOp
1898 //===----------------------------------------------------------------------===//
1899 
1900 // Build a InParallelOp with mixed static and dynamic entries.
1901 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1903  Region *bodyRegion = result.addRegion();
1904  b.createBlock(bodyRegion);
1905 }
1906 
1907 LogicalResult InParallelOp::verify() {
1908  scf::ForallOp forallOp =
1909  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1910  if (!forallOp)
1911  return this->emitOpError("expected forall op parent");
1912 
1913  // TODO: InParallelOpInterface.
1914  for (Operation &op : getRegion().front().getOperations()) {
1915  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1916  return this->emitOpError("expected only ")
1917  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1918  }
1919 
1920  // Verify that inserts are into out block arguments.
1921  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1922  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1923  if (!llvm::is_contained(regionOutArgs, dest))
1924  return op.emitOpError("may only insert into an output block argument");
1925  }
1926  return success();
1927 }
1928 
1930  p << " ";
1931  p.printRegion(getRegion(),
1932  /*printEntryBlockArgs=*/false,
1933  /*printBlockTerminators=*/false);
1934  p.printOptionalAttrDict(getOperation()->getAttrs());
1935 }
1936 
1937 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1938  auto &builder = parser.getBuilder();
1939 
1941  std::unique_ptr<Region> region = std::make_unique<Region>();
1942  if (parser.parseRegion(*region, regionOperands))
1943  return failure();
1944 
1945  if (region->empty())
1946  OpBuilder(builder.getContext()).createBlock(region.get());
1947  result.addRegion(std::move(region));
1948 
1949  // Parse the optional attribute list.
1950  if (parser.parseOptionalAttrDict(result.attributes))
1951  return failure();
1952  return success();
1953 }
1954 
1955 OpResult InParallelOp::getParentResult(int64_t idx) {
1956  return getOperation()->getParentOp()->getResult(idx);
1957 }
1958 
1959 SmallVector<BlockArgument> InParallelOp::getDests() {
1960  return llvm::to_vector<4>(
1961  llvm::map_range(getYieldingOps(), [](Operation &op) {
1962  // Add new ops here as needed.
1963  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1964  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1965  }));
1966 }
1967 
1968 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1969  return getRegion().front().getOperations();
1970 }
1971 
1972 //===----------------------------------------------------------------------===//
1973 // IfOp
1974 //===----------------------------------------------------------------------===//
1975 
1977  assert(a && "expected non-empty operation");
1978  assert(b && "expected non-empty operation");
1979 
1980  IfOp ifOp = a->getParentOfType<IfOp>();
1981  while (ifOp) {
1982  // Check if b is inside ifOp. (We already know that a is.)
1983  if (ifOp->isProperAncestor(b))
1984  // b is contained in ifOp. a and b are in mutually exclusive branches if
1985  // they are in different blocks of ifOp.
1986  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1987  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1988  // Check next enclosing IfOp.
1989  ifOp = ifOp->getParentOfType<IfOp>();
1990  }
1991 
1992  // Could not find a common IfOp among a's and b's ancestors.
1993  return false;
1994 }
1995 
1996 LogicalResult
1997 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1998  IfOp::Adaptor adaptor,
1999  SmallVectorImpl<Type> &inferredReturnTypes) {
2000  if (adaptor.getRegions().empty())
2001  return failure();
2002  Region *r = &adaptor.getThenRegion();
2003  if (r->empty())
2004  return failure();
2005  Block &b = r->front();
2006  if (b.empty())
2007  return failure();
2008  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2009  if (!yieldOp)
2010  return failure();
2011  TypeRange types = yieldOp.getOperandTypes();
2012  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2013  types.end());
2014  return success();
2015 }
2016 
2017 void IfOp::build(OpBuilder &builder, OperationState &result,
2018  TypeRange resultTypes, Value cond) {
2019  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2020  /*addElseBlock=*/false);
2021 }
2022 
2023 void IfOp::build(OpBuilder &builder, OperationState &result,
2024  TypeRange resultTypes, Value cond, bool addThenBlock,
2025  bool addElseBlock) {
2026  assert((!addElseBlock || addThenBlock) &&
2027  "must not create else block w/o then block");
2028  result.addTypes(resultTypes);
2029  result.addOperands(cond);
2030 
2031  // Add regions and blocks.
2032  OpBuilder::InsertionGuard guard(builder);
2033  Region *thenRegion = result.addRegion();
2034  if (addThenBlock)
2035  builder.createBlock(thenRegion);
2036  Region *elseRegion = result.addRegion();
2037  if (addElseBlock)
2038  builder.createBlock(elseRegion);
2039 }
2040 
2041 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2042  bool withElseRegion) {
2043  build(builder, result, TypeRange{}, cond, withElseRegion);
2044 }
2045 
2046 void IfOp::build(OpBuilder &builder, OperationState &result,
2047  TypeRange resultTypes, Value cond, bool withElseRegion) {
2048  result.addTypes(resultTypes);
2049  result.addOperands(cond);
2050 
2051  // Build then region.
2052  OpBuilder::InsertionGuard guard(builder);
2053  Region *thenRegion = result.addRegion();
2054  builder.createBlock(thenRegion);
2055  if (resultTypes.empty())
2056  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2057 
2058  // Build else region.
2059  Region *elseRegion = result.addRegion();
2060  if (withElseRegion) {
2061  builder.createBlock(elseRegion);
2062  if (resultTypes.empty())
2063  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2064  }
2065 }
2066 
2067 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2068  function_ref<void(OpBuilder &, Location)> thenBuilder,
2069  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2070  assert(thenBuilder && "the builder callback for 'then' must be present");
2071  result.addOperands(cond);
2072 
2073  // Build then region.
2074  OpBuilder::InsertionGuard guard(builder);
2075  Region *thenRegion = result.addRegion();
2076  builder.createBlock(thenRegion);
2077  thenBuilder(builder, result.location);
2078 
2079  // Build else region.
2080  Region *elseRegion = result.addRegion();
2081  if (elseBuilder) {
2082  builder.createBlock(elseRegion);
2083  elseBuilder(builder, result.location);
2084  }
2085 
2086  // Infer result types.
2087  SmallVector<Type> inferredReturnTypes;
2088  MLIRContext *ctx = builder.getContext();
2089  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2090  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2091  /*properties=*/nullptr, result.regions,
2092  inferredReturnTypes))) {
2093  result.addTypes(inferredReturnTypes);
2094  }
2095 }
2096 
2097 LogicalResult IfOp::verify() {
2098  if (getNumResults() != 0 && getElseRegion().empty())
2099  return emitOpError("must have an else block if defining values");
2100  return success();
2101 }
2102 
2103 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2104  // Create the regions for 'then'.
2105  result.regions.reserve(2);
2106  Region *thenRegion = result.addRegion();
2107  Region *elseRegion = result.addRegion();
2108 
2109  auto &builder = parser.getBuilder();
2111  Type i1Type = builder.getIntegerType(1);
2112  if (parser.parseOperand(cond) ||
2113  parser.resolveOperand(cond, i1Type, result.operands))
2114  return failure();
2115  // Parse optional results type list.
2116  if (parser.parseOptionalArrowTypeList(result.types))
2117  return failure();
2118  // Parse the 'then' region.
2119  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2120  return failure();
2121  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2122 
2123  // If we find an 'else' keyword then parse the 'else' region.
2124  if (!parser.parseOptionalKeyword("else")) {
2125  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2126  return failure();
2127  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2128  }
2129 
2130  // Parse the optional attribute list.
2131  if (parser.parseOptionalAttrDict(result.attributes))
2132  return failure();
2133  return success();
2134 }
2135 
2136 void IfOp::print(OpAsmPrinter &p) {
2137  bool printBlockTerminators = false;
2138 
2139  p << " " << getCondition();
2140  if (!getResults().empty()) {
2141  p << " -> (" << getResultTypes() << ")";
2142  // Print yield explicitly if the op defines values.
2143  printBlockTerminators = true;
2144  }
2145  p << ' ';
2146  p.printRegion(getThenRegion(),
2147  /*printEntryBlockArgs=*/false,
2148  /*printBlockTerminators=*/printBlockTerminators);
2149 
2150  // Print the 'else' regions if it exists and has a block.
2151  auto &elseRegion = getElseRegion();
2152  if (!elseRegion.empty()) {
2153  p << " else ";
2154  p.printRegion(elseRegion,
2155  /*printEntryBlockArgs=*/false,
2156  /*printBlockTerminators=*/printBlockTerminators);
2157  }
2158 
2159  p.printOptionalAttrDict((*this)->getAttrs());
2160 }
2161 
2162 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2164  // The `then` and the `else` region branch back to the parent operation.
2165  if (!point.isParent()) {
2166  regions.push_back(RegionSuccessor(getResults()));
2167  return;
2168  }
2169 
2170  regions.push_back(RegionSuccessor(&getThenRegion()));
2171 
2172  // Don't consider the else region if it is empty.
2173  Region *elseRegion = &this->getElseRegion();
2174  if (elseRegion->empty())
2175  regions.push_back(RegionSuccessor());
2176  else
2177  regions.push_back(RegionSuccessor(elseRegion));
2178 }
2179 
2180 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2182  FoldAdaptor adaptor(operands, *this);
2183  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2184  if (!boolAttr || boolAttr.getValue())
2185  regions.emplace_back(&getThenRegion());
2186 
2187  // If the else region is empty, execution continues after the parent op.
2188  if (!boolAttr || !boolAttr.getValue()) {
2189  if (!getElseRegion().empty())
2190  regions.emplace_back(&getElseRegion());
2191  else
2192  regions.emplace_back(getResults());
2193  }
2194 }
2195 
2196 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2197  SmallVectorImpl<OpFoldResult> &results) {
2198  // if (!c) then A() else B() -> if c then B() else A()
2199  if (getElseRegion().empty())
2200  return failure();
2201 
2202  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2203  if (!xorStmt)
2204  return failure();
2205 
2206  if (!matchPattern(xorStmt.getRhs(), m_One()))
2207  return failure();
2208 
2209  getConditionMutable().assign(xorStmt.getLhs());
2210  Block *thenBlock = &getThenRegion().front();
2211  // It would be nicer to use iplist::swap, but that has no implemented
2212  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2213  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2214  getElseRegion().getBlocks());
2215  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2216  getThenRegion().getBlocks(), thenBlock);
2217  return success();
2218 }
2219 
2220 void IfOp::getRegionInvocationBounds(
2221  ArrayRef<Attribute> operands,
2222  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2223  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2224  // If the condition is known, then one region is known to be executed once
2225  // and the other zero times.
2226  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2227  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2228  } else {
2229  // Non-constant condition. Each region may be executed 0 or 1 times.
2230  invocationBounds.assign(2, {0, 1});
2231  }
2232 }
2233 
2234 namespace {
2235 // Pattern to remove unused IfOp results.
2236 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2238 
2239  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2240  PatternRewriter &rewriter) const {
2241  // Move all operations to the destination block.
2242  rewriter.mergeBlocks(source, dest);
2243  // Replace the yield op by one that returns only the used values.
2244  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2245  SmallVector<Value, 4> usedOperands;
2246  llvm::transform(usedResults, std::back_inserter(usedOperands),
2247  [&](OpResult result) {
2248  return yieldOp.getOperand(result.getResultNumber());
2249  });
2250  rewriter.modifyOpInPlace(yieldOp,
2251  [&]() { yieldOp->setOperands(usedOperands); });
2252  }
2253 
2254  LogicalResult matchAndRewrite(IfOp op,
2255  PatternRewriter &rewriter) const override {
2256  // Compute the list of used results.
2257  SmallVector<OpResult, 4> usedResults;
2258  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2259  [](OpResult result) { return !result.use_empty(); });
2260 
2261  // Replace the operation if only a subset of its results have uses.
2262  if (usedResults.size() == op.getNumResults())
2263  return failure();
2264 
2265  // Compute the result types of the replacement operation.
2266  SmallVector<Type, 4> newTypes;
2267  llvm::transform(usedResults, std::back_inserter(newTypes),
2268  [](OpResult result) { return result.getType(); });
2269 
2270  // Create a replacement operation with empty then and else regions.
2271  auto newOp =
2272  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2273  rewriter.createBlock(&newOp.getThenRegion());
2274  rewriter.createBlock(&newOp.getElseRegion());
2275 
2276  // Move the bodies and replace the terminators (note there is a then and
2277  // an else region since the operation returns results).
2278  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2279  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2280 
2281  // Replace the operation by the new one.
2282  SmallVector<Value, 4> repResults(op.getNumResults());
2283  for (const auto &en : llvm::enumerate(usedResults))
2284  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2285  rewriter.replaceOp(op, repResults);
2286  return success();
2287  }
2288 };
2289 
2290 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2292 
2293  LogicalResult matchAndRewrite(IfOp op,
2294  PatternRewriter &rewriter) const override {
2295  BoolAttr condition;
2296  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2297  return failure();
2298 
2299  if (condition.getValue())
2300  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2301  else if (!op.getElseRegion().empty())
2302  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2303  else
2304  rewriter.eraseOp(op);
2305 
2306  return success();
2307  }
2308 };
2309 
2310 /// Hoist any yielded results whose operands are defined outside
2311 /// the if, to a select instruction.
2312 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2314 
2315  LogicalResult matchAndRewrite(IfOp op,
2316  PatternRewriter &rewriter) const override {
2317  if (op->getNumResults() == 0)
2318  return failure();
2319 
2320  auto cond = op.getCondition();
2321  auto thenYieldArgs = op.thenYield().getOperands();
2322  auto elseYieldArgs = op.elseYield().getOperands();
2323 
2324  SmallVector<Type> nonHoistable;
2325  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2326  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2327  &op.getElseRegion() == falseVal.getParentRegion())
2328  nonHoistable.push_back(trueVal.getType());
2329  }
2330  // Early exit if there aren't any yielded values we can
2331  // hoist outside the if.
2332  if (nonHoistable.size() == op->getNumResults())
2333  return failure();
2334 
2335  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2336  /*withElseRegion=*/false);
2337  if (replacement.thenBlock())
2338  rewriter.eraseBlock(replacement.thenBlock());
2339  replacement.getThenRegion().takeBody(op.getThenRegion());
2340  replacement.getElseRegion().takeBody(op.getElseRegion());
2341 
2342  SmallVector<Value> results(op->getNumResults());
2343  assert(thenYieldArgs.size() == results.size());
2344  assert(elseYieldArgs.size() == results.size());
2345 
2346  SmallVector<Value> trueYields;
2347  SmallVector<Value> falseYields;
2348  rewriter.setInsertionPoint(replacement);
2349  for (const auto &it :
2350  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2351  Value trueVal = std::get<0>(it.value());
2352  Value falseVal = std::get<1>(it.value());
2353  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2354  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2355  results[it.index()] = replacement.getResult(trueYields.size());
2356  trueYields.push_back(trueVal);
2357  falseYields.push_back(falseVal);
2358  } else if (trueVal == falseVal)
2359  results[it.index()] = trueVal;
2360  else
2361  results[it.index()] = rewriter.create<arith::SelectOp>(
2362  op.getLoc(), cond, trueVal, falseVal);
2363  }
2364 
2365  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2366  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2367 
2368  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2369  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2370 
2371  rewriter.replaceOp(op, results);
2372  return success();
2373  }
2374 };
2375 
2376 /// Allow the true region of an if to assume the condition is true
2377 /// and vice versa. For example:
2378 ///
2379 /// scf.if %cmp {
2380 /// print(%cmp)
2381 /// }
2382 ///
2383 /// becomes
2384 ///
2385 /// scf.if %cmp {
2386 /// print(true)
2387 /// }
2388 ///
2389 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2391 
2392  LogicalResult matchAndRewrite(IfOp op,
2393  PatternRewriter &rewriter) const override {
2394  // Early exit if the condition is constant since replacing a constant
2395  // in the body with another constant isn't a simplification.
2396  if (matchPattern(op.getCondition(), m_Constant()))
2397  return failure();
2398 
2399  bool changed = false;
2400  mlir::Type i1Ty = rewriter.getI1Type();
2401 
2402  // These variables serve to prevent creating duplicate constants
2403  // and hold constant true or false values.
2404  Value constantTrue = nullptr;
2405  Value constantFalse = nullptr;
2406 
2407  for (OpOperand &use :
2408  llvm::make_early_inc_range(op.getCondition().getUses())) {
2409  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2410  changed = true;
2411 
2412  if (!constantTrue)
2413  constantTrue = rewriter.create<arith::ConstantOp>(
2414  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2415 
2416  rewriter.modifyOpInPlace(use.getOwner(),
2417  [&]() { use.set(constantTrue); });
2418  } else if (op.getElseRegion().isAncestor(
2419  use.getOwner()->getParentRegion())) {
2420  changed = true;
2421 
2422  if (!constantFalse)
2423  constantFalse = rewriter.create<arith::ConstantOp>(
2424  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2425 
2426  rewriter.modifyOpInPlace(use.getOwner(),
2427  [&]() { use.set(constantFalse); });
2428  }
2429  }
2430 
2431  return success(changed);
2432  }
2433 };
2434 
2435 /// Remove any statements from an if that are equivalent to the condition
2436 /// or its negation. For example:
2437 ///
2438 /// %res:2 = scf.if %cmp {
2439 /// yield something(), true
2440 /// } else {
2441 /// yield something2(), false
2442 /// }
2443 /// print(%res#1)
2444 ///
2445 /// becomes
2446 /// %res = scf.if %cmp {
2447 /// yield something()
2448 /// } else {
2449 /// yield something2()
2450 /// }
2451 /// print(%cmp)
2452 ///
2453 /// Additionally if both branches yield the same value, replace all uses
2454 /// of the result with the yielded value.
2455 ///
2456 /// %res:2 = scf.if %cmp {
2457 /// yield something(), %arg1
2458 /// } else {
2459 /// yield something2(), %arg1
2460 /// }
2461 /// print(%res#1)
2462 ///
2463 /// becomes
2464 /// %res = scf.if %cmp {
2465 /// yield something()
2466 /// } else {
2467 /// yield something2()
2468 /// }
2469 /// print(%arg1)
2470 ///
2471 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2473 
2474  LogicalResult matchAndRewrite(IfOp op,
2475  PatternRewriter &rewriter) const override {
2476  // Early exit if there are no results that could be replaced.
2477  if (op.getNumResults() == 0)
2478  return failure();
2479 
2480  auto trueYield =
2481  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2482  auto falseYield =
2483  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2484 
2485  rewriter.setInsertionPoint(op->getBlock(),
2486  op.getOperation()->getIterator());
2487  bool changed = false;
2488  Type i1Ty = rewriter.getI1Type();
2489  for (auto [trueResult, falseResult, opResult] :
2490  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2491  op.getResults())) {
2492  if (trueResult == falseResult) {
2493  if (!opResult.use_empty()) {
2494  opResult.replaceAllUsesWith(trueResult);
2495  changed = true;
2496  }
2497  continue;
2498  }
2499 
2500  BoolAttr trueYield, falseYield;
2501  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2502  !matchPattern(falseResult, m_Constant(&falseYield)))
2503  continue;
2504 
2505  bool trueVal = trueYield.getValue();
2506  bool falseVal = falseYield.getValue();
2507  if (!trueVal && falseVal) {
2508  if (!opResult.use_empty()) {
2509  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2510  Value notCond = rewriter.create<arith::XOrIOp>(
2511  op.getLoc(), op.getCondition(),
2512  constDialect
2513  ->materializeConstant(rewriter,
2514  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2515  op.getLoc())
2516  ->getResult(0));
2517  opResult.replaceAllUsesWith(notCond);
2518  changed = true;
2519  }
2520  }
2521  if (trueVal && !falseVal) {
2522  if (!opResult.use_empty()) {
2523  opResult.replaceAllUsesWith(op.getCondition());
2524  changed = true;
2525  }
2526  }
2527  }
2528  return success(changed);
2529  }
2530 };
2531 
2532 /// Merge any consecutive scf.if's with the same condition.
2533 ///
2534 /// scf.if %cond {
2535 /// firstCodeTrue();...
2536 /// } else {
2537 /// firstCodeFalse();...
2538 /// }
2539 /// %res = scf.if %cond {
2540 /// secondCodeTrue();...
2541 /// } else {
2542 /// secondCodeFalse();...
2543 /// }
2544 ///
2545 /// becomes
2546 /// %res = scf.if %cmp {
2547 /// firstCodeTrue();...
2548 /// secondCodeTrue();...
2549 /// } else {
2550 /// firstCodeFalse();...
2551 /// secondCodeFalse();...
2552 /// }
2553 struct CombineIfs : public OpRewritePattern<IfOp> {
2555 
2556  LogicalResult matchAndRewrite(IfOp nextIf,
2557  PatternRewriter &rewriter) const override {
2558  Block *parent = nextIf->getBlock();
2559  if (nextIf == &parent->front())
2560  return failure();
2561 
2562  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2563  if (!prevIf)
2564  return failure();
2565 
2566  // Determine the logical then/else blocks when prevIf's
2567  // condition is used. Null means the block does not exist
2568  // in that case (e.g. empty else). If neither of these
2569  // are set, the two conditions cannot be compared.
2570  Block *nextThen = nullptr;
2571  Block *nextElse = nullptr;
2572  if (nextIf.getCondition() == prevIf.getCondition()) {
2573  nextThen = nextIf.thenBlock();
2574  if (!nextIf.getElseRegion().empty())
2575  nextElse = nextIf.elseBlock();
2576  }
2577  if (arith::XOrIOp notv =
2578  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2579  if (notv.getLhs() == prevIf.getCondition() &&
2580  matchPattern(notv.getRhs(), m_One())) {
2581  nextElse = nextIf.thenBlock();
2582  if (!nextIf.getElseRegion().empty())
2583  nextThen = nextIf.elseBlock();
2584  }
2585  }
2586  if (arith::XOrIOp notv =
2587  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2588  if (notv.getLhs() == nextIf.getCondition() &&
2589  matchPattern(notv.getRhs(), m_One())) {
2590  nextElse = nextIf.thenBlock();
2591  if (!nextIf.getElseRegion().empty())
2592  nextThen = nextIf.elseBlock();
2593  }
2594  }
2595 
2596  if (!nextThen && !nextElse)
2597  return failure();
2598 
2599  SmallVector<Value> prevElseYielded;
2600  if (!prevIf.getElseRegion().empty())
2601  prevElseYielded = prevIf.elseYield().getOperands();
2602  // Replace all uses of return values of op within nextIf with the
2603  // corresponding yields
2604  for (auto it : llvm::zip(prevIf.getResults(),
2605  prevIf.thenYield().getOperands(), prevElseYielded))
2606  for (OpOperand &use :
2607  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2608  if (nextThen && nextThen->getParent()->isAncestor(
2609  use.getOwner()->getParentRegion())) {
2610  rewriter.startOpModification(use.getOwner());
2611  use.set(std::get<1>(it));
2612  rewriter.finalizeOpModification(use.getOwner());
2613  } else if (nextElse && nextElse->getParent()->isAncestor(
2614  use.getOwner()->getParentRegion())) {
2615  rewriter.startOpModification(use.getOwner());
2616  use.set(std::get<2>(it));
2617  rewriter.finalizeOpModification(use.getOwner());
2618  }
2619  }
2620 
2621  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2622  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2623 
2624  IfOp combinedIf = rewriter.create<IfOp>(
2625  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2626  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2627 
2628  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2629  combinedIf.getThenRegion(),
2630  combinedIf.getThenRegion().begin());
2631 
2632  if (nextThen) {
2633  YieldOp thenYield = combinedIf.thenYield();
2634  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2635  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2636  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2637 
2638  SmallVector<Value> mergedYields(thenYield.getOperands());
2639  llvm::append_range(mergedYields, thenYield2.getOperands());
2640  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2641  rewriter.eraseOp(thenYield);
2642  rewriter.eraseOp(thenYield2);
2643  }
2644 
2645  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2646  combinedIf.getElseRegion(),
2647  combinedIf.getElseRegion().begin());
2648 
2649  if (nextElse) {
2650  if (combinedIf.getElseRegion().empty()) {
2651  rewriter.inlineRegionBefore(*nextElse->getParent(),
2652  combinedIf.getElseRegion(),
2653  combinedIf.getElseRegion().begin());
2654  } else {
2655  YieldOp elseYield = combinedIf.elseYield();
2656  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2657  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2658 
2659  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2660 
2661  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2662  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2663 
2664  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2665  rewriter.eraseOp(elseYield);
2666  rewriter.eraseOp(elseYield2);
2667  }
2668  }
2669 
2670  SmallVector<Value> prevValues;
2671  SmallVector<Value> nextValues;
2672  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2673  if (pair.index() < prevIf.getNumResults())
2674  prevValues.push_back(pair.value());
2675  else
2676  nextValues.push_back(pair.value());
2677  }
2678  rewriter.replaceOp(prevIf, prevValues);
2679  rewriter.replaceOp(nextIf, nextValues);
2680  return success();
2681  }
2682 };
2683 
2684 /// Pattern to remove an empty else branch.
2685 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2687 
2688  LogicalResult matchAndRewrite(IfOp ifOp,
2689  PatternRewriter &rewriter) const override {
2690  // Cannot remove else region when there are operation results.
2691  if (ifOp.getNumResults())
2692  return failure();
2693  Block *elseBlock = ifOp.elseBlock();
2694  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2695  return failure();
2696  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2697  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2698  newIfOp.getThenRegion().begin());
2699  rewriter.eraseOp(ifOp);
2700  return success();
2701  }
2702 };
2703 
2704 /// Convert nested `if`s into `arith.andi` + single `if`.
2705 ///
2706 /// scf.if %arg0 {
2707 /// scf.if %arg1 {
2708 /// ...
2709 /// scf.yield
2710 /// }
2711 /// scf.yield
2712 /// }
2713 /// becomes
2714 ///
2715 /// %0 = arith.andi %arg0, %arg1
2716 /// scf.if %0 {
2717 /// ...
2718 /// scf.yield
2719 /// }
2720 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2722 
2723  LogicalResult matchAndRewrite(IfOp op,
2724  PatternRewriter &rewriter) const override {
2725  auto nestedOps = op.thenBlock()->without_terminator();
2726  // Nested `if` must be the only op in block.
2727  if (!llvm::hasSingleElement(nestedOps))
2728  return failure();
2729 
2730  // If there is an else block, it can only yield
2731  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2732  return failure();
2733 
2734  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2735  if (!nestedIf)
2736  return failure();
2737 
2738  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2739  return failure();
2740 
2741  SmallVector<Value> thenYield(op.thenYield().getOperands());
2742  SmallVector<Value> elseYield;
2743  if (op.elseBlock())
2744  llvm::append_range(elseYield, op.elseYield().getOperands());
2745 
2746  // A list of indices for which we should upgrade the value yielded
2747  // in the else to a select.
2748  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2749 
2750  // If the outer scf.if yields a value produced by the inner scf.if,
2751  // only permit combining if the value yielded when the condition
2752  // is false in the outer scf.if is the same value yielded when the
2753  // inner scf.if condition is false.
2754  // Note that the array access to elseYield will not go out of bounds
2755  // since it must have the same length as thenYield, since they both
2756  // come from the same scf.if.
2757  for (const auto &tup : llvm::enumerate(thenYield)) {
2758  if (tup.value().getDefiningOp() == nestedIf) {
2759  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2760  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2761  elseYield[tup.index()]) {
2762  return failure();
2763  }
2764  // If the correctness test passes, we will yield
2765  // corresponding value from the inner scf.if
2766  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2767  continue;
2768  }
2769 
2770  // Otherwise, we need to ensure the else block of the combined
2771  // condition still returns the same value when the outer condition is
2772  // true and the inner condition is false. This can be accomplished if
2773  // the then value is defined outside the outer scf.if and we replace the
2774  // value with a select that considers just the outer condition. Since
2775  // the else region contains just the yield, its yielded value is
2776  // defined outside the scf.if, by definition.
2777 
2778  // If the then value is defined within the scf.if, bail.
2779  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2780  return failure();
2781  }
2782  elseYieldsToUpgradeToSelect.push_back(tup.index());
2783  }
2784 
2785  Location loc = op.getLoc();
2786  Value newCondition = rewriter.create<arith::AndIOp>(
2787  loc, op.getCondition(), nestedIf.getCondition());
2788  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2789  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2790 
2791  SmallVector<Value> results;
2792  llvm::append_range(results, newIf.getResults());
2793  rewriter.setInsertionPoint(newIf);
2794 
2795  for (auto idx : elseYieldsToUpgradeToSelect)
2796  results[idx] = rewriter.create<arith::SelectOp>(
2797  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2798 
2799  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2800  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2801  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2802  if (!elseYield.empty()) {
2803  rewriter.createBlock(&newIf.getElseRegion());
2804  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2805  rewriter.create<YieldOp>(loc, elseYield);
2806  }
2807  rewriter.replaceOp(op, results);
2808  return success();
2809  }
2810 };
2811 
2812 } // namespace
2813 
2814 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2815  MLIRContext *context) {
2816  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2817  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2818  RemoveStaticCondition, RemoveUnusedResults,
2819  ReplaceIfYieldWithConditionOrValue>(context);
2820 }
2821 
2822 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2823 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2824 Block *IfOp::elseBlock() {
2825  Region &r = getElseRegion();
2826  if (r.empty())
2827  return nullptr;
2828  return &r.back();
2829 }
2830 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2831 
2832 //===----------------------------------------------------------------------===//
2833 // ParallelOp
2834 //===----------------------------------------------------------------------===//
2835 
2836 void ParallelOp::build(
2837  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2838  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2840  bodyBuilderFn) {
2841  result.addOperands(lowerBounds);
2842  result.addOperands(upperBounds);
2843  result.addOperands(steps);
2844  result.addOperands(initVals);
2845  result.addAttribute(
2846  ParallelOp::getOperandSegmentSizeAttr(),
2847  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2848  static_cast<int32_t>(upperBounds.size()),
2849  static_cast<int32_t>(steps.size()),
2850  static_cast<int32_t>(initVals.size())}));
2851  result.addTypes(initVals.getTypes());
2852 
2853  OpBuilder::InsertionGuard guard(builder);
2854  unsigned numIVs = steps.size();
2855  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2856  SmallVector<Location, 8> argLocs(numIVs, result.location);
2857  Region *bodyRegion = result.addRegion();
2858  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2859 
2860  if (bodyBuilderFn) {
2861  builder.setInsertionPointToStart(bodyBlock);
2862  bodyBuilderFn(builder, result.location,
2863  bodyBlock->getArguments().take_front(numIVs),
2864  bodyBlock->getArguments().drop_front(numIVs));
2865  }
2866  // Add terminator only if there are no reductions.
2867  if (initVals.empty())
2868  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2869 }
2870 
2871 void ParallelOp::build(
2872  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2873  ValueRange upperBounds, ValueRange steps,
2874  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2875  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2876  // we don't capture a reference to a temporary by constructing the lambda at
2877  // function level.
2878  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2879  Location nestedLoc, ValueRange ivs,
2880  ValueRange) {
2881  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2882  };
2883  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2884  if (bodyBuilderFn)
2885  wrapper = wrappedBuilderFn;
2886 
2887  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2888  wrapper);
2889 }
2890 
2891 LogicalResult ParallelOp::verify() {
2892  // Check that there is at least one value in lowerBound, upperBound and step.
2893  // It is sufficient to test only step, because it is ensured already that the
2894  // number of elements in lowerBound, upperBound and step are the same.
2895  Operation::operand_range stepValues = getStep();
2896  if (stepValues.empty())
2897  return emitOpError(
2898  "needs at least one tuple element for lowerBound, upperBound and step");
2899 
2900  // Check whether all constant step values are positive.
2901  for (Value stepValue : stepValues)
2902  if (auto cst = getConstantIntValue(stepValue))
2903  if (*cst <= 0)
2904  return emitOpError("constant step operand must be positive");
2905 
2906  // Check that the body defines the same number of block arguments as the
2907  // number of tuple elements in step.
2908  Block *body = getBody();
2909  if (body->getNumArguments() != stepValues.size())
2910  return emitOpError() << "expects the same number of induction variables: "
2911  << body->getNumArguments()
2912  << " as bound and step values: " << stepValues.size();
2913  for (auto arg : body->getArguments())
2914  if (!arg.getType().isIndex())
2915  return emitOpError(
2916  "expects arguments for the induction variable to be of index type");
2917 
2918  // Check that the terminator is an scf.reduce op.
2919  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2920  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2921  if (!reduceOp)
2922  return failure();
2923 
2924  // Check that the number of results is the same as the number of reductions.
2925  auto resultsSize = getResults().size();
2926  auto reductionsSize = reduceOp.getReductions().size();
2927  auto initValsSize = getInitVals().size();
2928  if (resultsSize != reductionsSize)
2929  return emitOpError() << "expects number of results: " << resultsSize
2930  << " to be the same as number of reductions: "
2931  << reductionsSize;
2932  if (resultsSize != initValsSize)
2933  return emitOpError() << "expects number of results: " << resultsSize
2934  << " to be the same as number of initial values: "
2935  << initValsSize;
2936 
2937  // Check that the types of the results and reductions are the same.
2938  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2939  auto resultType = getOperation()->getResult(i).getType();
2940  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2941  if (resultType != reductionOperandType)
2942  return reduceOp.emitOpError()
2943  << "expects type of " << i
2944  << "-th reduction operand: " << reductionOperandType
2945  << " to be the same as the " << i
2946  << "-th result type: " << resultType;
2947  }
2948  return success();
2949 }
2950 
2951 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2952  auto &builder = parser.getBuilder();
2953  // Parse an opening `(` followed by induction variables followed by `)`
2956  return failure();
2957 
2958  // Parse loop bounds.
2960  if (parser.parseEqual() ||
2961  parser.parseOperandList(lower, ivs.size(),
2963  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2964  return failure();
2965 
2967  if (parser.parseKeyword("to") ||
2968  parser.parseOperandList(upper, ivs.size(),
2970  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2971  return failure();
2972 
2973  // Parse step values.
2975  if (parser.parseKeyword("step") ||
2976  parser.parseOperandList(steps, ivs.size(),
2978  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2979  return failure();
2980 
2981  // Parse init values.
2983  if (succeeded(parser.parseOptionalKeyword("init"))) {
2984  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2985  return failure();
2986  }
2987 
2988  // Parse optional results in case there is a reduce.
2989  if (parser.parseOptionalArrowTypeList(result.types))
2990  return failure();
2991 
2992  // Now parse the body.
2993  Region *body = result.addRegion();
2994  for (auto &iv : ivs)
2995  iv.type = builder.getIndexType();
2996  if (parser.parseRegion(*body, ivs))
2997  return failure();
2998 
2999  // Set `operandSegmentSizes` attribute.
3000  result.addAttribute(
3001  ParallelOp::getOperandSegmentSizeAttr(),
3002  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3003  static_cast<int32_t>(upper.size()),
3004  static_cast<int32_t>(steps.size()),
3005  static_cast<int32_t>(initVals.size())}));
3006 
3007  // Parse attributes.
3008  if (parser.parseOptionalAttrDict(result.attributes) ||
3009  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3010  result.operands))
3011  return failure();
3012 
3013  // Add a terminator if none was parsed.
3014  ParallelOp::ensureTerminator(*body, builder, result.location);
3015  return success();
3016 }
3017 
3019  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3020  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3021  if (!getInitVals().empty())
3022  p << " init (" << getInitVals() << ")";
3023  p.printOptionalArrowTypeList(getResultTypes());
3024  p << ' ';
3025  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3027  (*this)->getAttrs(),
3028  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3029 }
3030 
3031 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3032 
3033 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3034  return SmallVector<Value>{getBody()->getArguments()};
3035 }
3036 
3037 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3038  return getLowerBound();
3039 }
3040 
3041 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3042  return getUpperBound();
3043 }
3044 
3045 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3046  return getStep();
3047 }
3048 
3050  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3051  if (!ivArg)
3052  return ParallelOp();
3053  assert(ivArg.getOwner() && "unlinked block argument");
3054  auto *containingOp = ivArg.getOwner()->getParentOp();
3055  return dyn_cast<ParallelOp>(containingOp);
3056 }
3057 
3058 namespace {
3059 // Collapse loop dimensions that perform a single iteration.
3060 struct ParallelOpSingleOrZeroIterationDimsFolder
3061  : public OpRewritePattern<ParallelOp> {
3063 
3064  LogicalResult matchAndRewrite(ParallelOp op,
3065  PatternRewriter &rewriter) const override {
3066  Location loc = op.getLoc();
3067 
3068  // Compute new loop bounds that omit all single-iteration loop dimensions.
3069  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3070  IRMapping mapping;
3071  for (auto [lb, ub, step, iv] :
3072  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3073  op.getInductionVars())) {
3074  auto numIterations = constantTripCount(lb, ub, step);
3075  if (numIterations.has_value()) {
3076  // Remove the loop if it performs zero iterations.
3077  if (*numIterations == 0) {
3078  rewriter.replaceOp(op, op.getInitVals());
3079  return success();
3080  }
3081  // Replace the loop induction variable by the lower bound if the loop
3082  // performs a single iteration. Otherwise, copy the loop bounds.
3083  if (*numIterations == 1) {
3084  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3085  continue;
3086  }
3087  }
3088  newLowerBounds.push_back(lb);
3089  newUpperBounds.push_back(ub);
3090  newSteps.push_back(step);
3091  }
3092  // Exit if none of the loop dimensions perform a single iteration.
3093  if (newLowerBounds.size() == op.getLowerBound().size())
3094  return failure();
3095 
3096  if (newLowerBounds.empty()) {
3097  // All of the loop dimensions perform a single iteration. Inline
3098  // loop body and nested ReduceOp's
3099  SmallVector<Value> results;
3100  results.reserve(op.getInitVals().size());
3101  for (auto &bodyOp : op.getBody()->without_terminator())
3102  rewriter.clone(bodyOp, mapping);
3103  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3104  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3105  Block &reduceBlock = reduceOp.getReductions()[i].front();
3106  auto initValIndex = results.size();
3107  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3108  mapping.map(reduceBlock.getArgument(1),
3109  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3110  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3111  rewriter.clone(reduceBodyOp, mapping);
3112 
3113  auto result = mapping.lookupOrDefault(
3114  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3115  results.push_back(result);
3116  }
3117 
3118  rewriter.replaceOp(op, results);
3119  return success();
3120  }
3121  // Replace the parallel loop by lower-dimensional parallel loop.
3122  auto newOp =
3123  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3124  newSteps, op.getInitVals(), nullptr);
3125  // Erase the empty block that was inserted by the builder.
3126  rewriter.eraseBlock(newOp.getBody());
3127  // Clone the loop body and remap the block arguments of the collapsed loops
3128  // (inlining does not support a cancellable block argument mapping).
3129  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3130  newOp.getRegion().begin(), mapping);
3131  rewriter.replaceOp(op, newOp.getResults());
3132  return success();
3133  }
3134 };
3135 
3136 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3138 
3139  LogicalResult matchAndRewrite(ParallelOp op,
3140  PatternRewriter &rewriter) const override {
3141  Block &outerBody = *op.getBody();
3142  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3143  return failure();
3144 
3145  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3146  if (!innerOp)
3147  return failure();
3148 
3149  for (auto val : outerBody.getArguments())
3150  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3151  llvm::is_contained(innerOp.getUpperBound(), val) ||
3152  llvm::is_contained(innerOp.getStep(), val))
3153  return failure();
3154 
3155  // Reductions are not supported yet.
3156  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3157  return failure();
3158 
3159  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3160  ValueRange iterVals, ValueRange) {
3161  Block &innerBody = *innerOp.getBody();
3162  assert(iterVals.size() ==
3163  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3164  IRMapping mapping;
3165  mapping.map(outerBody.getArguments(),
3166  iterVals.take_front(outerBody.getNumArguments()));
3167  mapping.map(innerBody.getArguments(),
3168  iterVals.take_back(innerBody.getNumArguments()));
3169  for (Operation &op : innerBody.without_terminator())
3170  builder.clone(op, mapping);
3171  };
3172 
3173  auto concatValues = [](const auto &first, const auto &second) {
3174  SmallVector<Value> ret;
3175  ret.reserve(first.size() + second.size());
3176  ret.assign(first.begin(), first.end());
3177  ret.append(second.begin(), second.end());
3178  return ret;
3179  };
3180 
3181  auto newLowerBounds =
3182  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3183  auto newUpperBounds =
3184  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3185  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3186 
3187  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3188  newSteps, std::nullopt,
3189  bodyBuilder);
3190  return success();
3191  }
3192 };
3193 
3194 } // namespace
3195 
3196 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3197  MLIRContext *context) {
3198  results
3199  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3200  context);
3201 }
3202 
3203 /// Given the region at `index`, or the parent operation if `index` is None,
3204 /// return the successor regions. These are the regions that may be selected
3205 /// during the flow of control. `operands` is a set of optional attributes that
3206 /// correspond to a constant value for each operand, or null if that operand is
3207 /// not a constant.
3208 void ParallelOp::getSuccessorRegions(
3210  // Both the operation itself and the region may be branching into the body or
3211  // back into the operation itself. It is possible for loop not to enter the
3212  // body.
3213  regions.push_back(RegionSuccessor(&getRegion()));
3214  regions.push_back(RegionSuccessor());
3215 }
3216 
3217 //===----------------------------------------------------------------------===//
3218 // ReduceOp
3219 //===----------------------------------------------------------------------===//
3220 
3221 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3222 
3223 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3224  ValueRange operands) {
3225  result.addOperands(operands);
3226  for (Value v : operands) {
3227  OpBuilder::InsertionGuard guard(builder);
3228  Region *bodyRegion = result.addRegion();
3229  builder.createBlock(bodyRegion, {},
3230  ArrayRef<Type>{v.getType(), v.getType()},
3231  {result.location, result.location});
3232  }
3233 }
3234 
3235 LogicalResult ReduceOp::verifyRegions() {
3236  // The region of a ReduceOp has two arguments of the same type as its
3237  // corresponding operand.
3238  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3239  auto type = getOperands()[i].getType();
3240  Block &block = getReductions()[i].front();
3241  if (block.empty())
3242  return emitOpError() << i << "-th reduction has an empty body";
3243  if (block.getNumArguments() != 2 ||
3244  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3245  return arg.getType() != type;
3246  }))
3247  return emitOpError() << "expected two block arguments with type " << type
3248  << " in the " << i << "-th reduction region";
3249 
3250  // Check that the block is terminated by a ReduceReturnOp.
3251  if (!isa<ReduceReturnOp>(block.getTerminator()))
3252  return emitOpError("reduction bodies must be terminated with an "
3253  "'scf.reduce.return' op");
3254  }
3255 
3256  return success();
3257 }
3258 
3261  // No operands are forwarded to the next iteration.
3262  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3263 }
3264 
3265 //===----------------------------------------------------------------------===//
3266 // ReduceReturnOp
3267 //===----------------------------------------------------------------------===//
3268 
3269 LogicalResult ReduceReturnOp::verify() {
3270  // The type of the return value should be the same type as the types of the
3271  // block arguments of the reduction body.
3272  Block *reductionBody = getOperation()->getBlock();
3273  // Should already be verified by an op trait.
3274  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3275  Type expectedResultType = reductionBody->getArgument(0).getType();
3276  if (expectedResultType != getResult().getType())
3277  return emitOpError() << "must have type " << expectedResultType
3278  << " (the type of the reduction inputs)";
3279  return success();
3280 }
3281 
3282 //===----------------------------------------------------------------------===//
3283 // WhileOp
3284 //===----------------------------------------------------------------------===//
3285 
3286 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3287  ::mlir::OperationState &odsState, TypeRange resultTypes,
3288  ValueRange inits, BodyBuilderFn beforeBuilder,
3289  BodyBuilderFn afterBuilder) {
3290  odsState.addOperands(inits);
3291  odsState.addTypes(resultTypes);
3292 
3293  OpBuilder::InsertionGuard guard(odsBuilder);
3294 
3295  // Build before region.
3296  SmallVector<Location, 4> beforeArgLocs;
3297  beforeArgLocs.reserve(inits.size());
3298  for (Value operand : inits) {
3299  beforeArgLocs.push_back(operand.getLoc());
3300  }
3301 
3302  Region *beforeRegion = odsState.addRegion();
3303  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3304  inits.getTypes(), beforeArgLocs);
3305  if (beforeBuilder)
3306  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3307 
3308  // Build after region.
3309  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3310 
3311  Region *afterRegion = odsState.addRegion();
3312  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3313  resultTypes, afterArgLocs);
3314 
3315  if (afterBuilder)
3316  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3317 }
3318 
3319 ConditionOp WhileOp::getConditionOp() {
3320  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3321 }
3322 
3323 YieldOp WhileOp::getYieldOp() {
3324  return cast<YieldOp>(getAfterBody()->getTerminator());
3325 }
3326 
3327 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3328  return getYieldOp().getResultsMutable();
3329 }
3330 
3331 Block::BlockArgListType WhileOp::getBeforeArguments() {
3332  return getBeforeBody()->getArguments();
3333 }
3334 
3335 Block::BlockArgListType WhileOp::getAfterArguments() {
3336  return getAfterBody()->getArguments();
3337 }
3338 
3339 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3340  return getBeforeArguments();
3341 }
3342 
3343 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3344  assert(point == getBefore() &&
3345  "WhileOp is expected to branch only to the first region");
3346  return getInits();
3347 }
3348 
3349 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3351  // The parent op always branches to the condition region.
3352  if (point.isParent()) {
3353  regions.emplace_back(&getBefore(), getBefore().getArguments());
3354  return;
3355  }
3356 
3357  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3358  "there are only two regions in a WhileOp");
3359  // The body region always branches back to the condition region.
3360  if (point == getAfter()) {
3361  regions.emplace_back(&getBefore(), getBefore().getArguments());
3362  return;
3363  }
3364 
3365  regions.emplace_back(getResults());
3366  regions.emplace_back(&getAfter(), getAfter().getArguments());
3367 }
3368 
3369 SmallVector<Region *> WhileOp::getLoopRegions() {
3370  return {&getBefore(), &getAfter()};
3371 }
3372 
3373 /// Parses a `while` op.
3374 ///
3375 /// op ::= `scf.while` assignments `:` function-type region `do` region
3376 /// `attributes` attribute-dict
3377 /// initializer ::= /* empty */ | `(` assignment-list `)`
3378 /// assignment-list ::= assignment | assignment `,` assignment-list
3379 /// assignment ::= ssa-value `=` ssa-value
3380 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3383  Region *before = result.addRegion();
3384  Region *after = result.addRegion();
3385 
3386  OptionalParseResult listResult =
3387  parser.parseOptionalAssignmentList(regionArgs, operands);
3388  if (listResult.has_value() && failed(listResult.value()))
3389  return failure();
3390 
3391  FunctionType functionType;
3392  SMLoc typeLoc = parser.getCurrentLocation();
3393  if (failed(parser.parseColonType(functionType)))
3394  return failure();
3395 
3396  result.addTypes(functionType.getResults());
3397 
3398  if (functionType.getNumInputs() != operands.size()) {
3399  return parser.emitError(typeLoc)
3400  << "expected as many input types as operands "
3401  << "(expected " << operands.size() << " got "
3402  << functionType.getNumInputs() << ")";
3403  }
3404 
3405  // Resolve input operands.
3406  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3407  parser.getCurrentLocation(),
3408  result.operands)))
3409  return failure();
3410 
3411  // Propagate the types into the region arguments.
3412  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3413  regionArgs[i].type = functionType.getInput(i);
3414 
3415  return failure(parser.parseRegion(*before, regionArgs) ||
3416  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3418 }
3419 
3420 /// Prints a `while` op.
3422  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3423  p << " : ";
3424  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3425  p << ' ';
3426  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3427  p << " do ";
3428  p.printRegion(getAfter());
3429  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3430 }
3431 
3432 /// Verifies that two ranges of types match, i.e. have the same number of
3433 /// entries and that types are pairwise equals. Reports errors on the given
3434 /// operation in case of mismatch.
3435 template <typename OpTy>
3436 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3437  TypeRange right, StringRef message) {
3438  if (left.size() != right.size())
3439  return op.emitOpError("expects the same number of ") << message;
3440 
3441  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3442  if (left[i] != right[i]) {
3443  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3444  << message;
3445  diag.attachNote() << "for argument " << i << ", found " << left[i]
3446  << " and " << right[i];
3447  return diag;
3448  }
3449  }
3450 
3451  return success();
3452 }
3453 
3454 LogicalResult scf::WhileOp::verify() {
3455  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3456  *this, getBefore(),
3457  "expects the 'before' region to terminate with 'scf.condition'");
3458  if (!beforeTerminator)
3459  return failure();
3460 
3461  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3462  *this, getAfter(),
3463  "expects the 'after' region to terminate with 'scf.yield'");
3464  return success(afterTerminator != nullptr);
3465 }
3466 
3467 namespace {
3468 /// Replace uses of the condition within the do block with true, since otherwise
3469 /// the block would not be evaluated.
3470 ///
3471 /// scf.while (..) : (i1, ...) -> ... {
3472 /// %condition = call @evaluate_condition() : () -> i1
3473 /// scf.condition(%condition) %condition : i1, ...
3474 /// } do {
3475 /// ^bb0(%arg0: i1, ...):
3476 /// use(%arg0)
3477 /// ...
3478 ///
3479 /// becomes
3480 /// scf.while (..) : (i1, ...) -> ... {
3481 /// %condition = call @evaluate_condition() : () -> i1
3482 /// scf.condition(%condition) %condition : i1, ...
3483 /// } do {
3484 /// ^bb0(%arg0: i1, ...):
3485 /// use(%true)
3486 /// ...
3487 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3489 
3490  LogicalResult matchAndRewrite(WhileOp op,
3491  PatternRewriter &rewriter) const override {
3492  auto term = op.getConditionOp();
3493 
3494  // These variables serve to prevent creating duplicate constants
3495  // and hold constant true or false values.
3496  Value constantTrue = nullptr;
3497 
3498  bool replaced = false;
3499  for (auto yieldedAndBlockArgs :
3500  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3501  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3502  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3503  if (!constantTrue)
3504  constantTrue = rewriter.create<arith::ConstantOp>(
3505  op.getLoc(), term.getCondition().getType(),
3506  rewriter.getBoolAttr(true));
3507 
3508  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3509  constantTrue);
3510  replaced = true;
3511  }
3512  }
3513  }
3514  return success(replaced);
3515  }
3516 };
3517 
3518 /// Remove loop invariant arguments from `before` block of scf.while.
3519 /// A before block argument is considered loop invariant if :-
3520 /// 1. i-th yield operand is equal to the i-th while operand.
3521 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3522 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3523 /// iter argument/while operand.
3524 /// For the arguments which are removed, their uses inside scf.while
3525 /// are replaced with their corresponding initial value.
3526 ///
3527 /// Eg:
3528 /// INPUT :-
3529 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3530 /// ..., %argN_before = %N)
3531 /// {
3532 /// ...
3533 /// scf.condition(%cond) %arg1_before, %arg0_before,
3534 /// %arg2_before, %arg0_before, ...
3535 /// } do {
3536 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3537 /// ..., %argK_after):
3538 /// ...
3539 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3540 /// }
3541 ///
3542 /// OUTPUT :-
3543 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3544 /// %N)
3545 /// {
3546 /// ...
3547 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3548 /// } do {
3549 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3550 /// ..., %argK_after):
3551 /// ...
3552 /// scf.yield %arg1_after, ..., %argN
3553 /// }
3554 ///
3555 /// EXPLANATION:
3556 /// We iterate over each yield operand.
3557 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3558 /// %arg0_before, which in turn is the 0-th iter argument. So we
3559 /// remove 0-th before block argument and yield operand, and replace
3560 /// all uses of the 0-th before block argument with its initial value
3561 /// %a.
3562 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3563 /// value. So we remove this operand and the corresponding before
3564 /// block argument and replace all uses of 1-th before block argument
3565 /// with %b.
3566 struct RemoveLoopInvariantArgsFromBeforeBlock
3567  : public OpRewritePattern<WhileOp> {
3569 
3570  LogicalResult matchAndRewrite(WhileOp op,
3571  PatternRewriter &rewriter) const override {
3572  Block &afterBlock = *op.getAfterBody();
3573  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3574  ConditionOp condOp = op.getConditionOp();
3575  OperandRange condOpArgs = condOp.getArgs();
3576  Operation *yieldOp = afterBlock.getTerminator();
3577  ValueRange yieldOpArgs = yieldOp->getOperands();
3578 
3579  bool canSimplify = false;
3580  for (const auto &it :
3581  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3582  auto index = static_cast<unsigned>(it.index());
3583  auto [initVal, yieldOpArg] = it.value();
3584  // If i-th yield operand is equal to the i-th operand of the scf.while,
3585  // the i-th before block argument is a loop invariant.
3586  if (yieldOpArg == initVal) {
3587  canSimplify = true;
3588  break;
3589  }
3590  // If the i-th yield operand is k-th after block argument, then we check
3591  // if the (k+1)-th condition op operand is equal to either the i-th before
3592  // block argument or the initial value of i-th before block argument. If
3593  // the comparison results `true`, i-th before block argument is a loop
3594  // invariant.
3595  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3596  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3597  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3598  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3599  canSimplify = true;
3600  break;
3601  }
3602  }
3603  }
3604 
3605  if (!canSimplify)
3606  return failure();
3607 
3608  SmallVector<Value> newInitArgs, newYieldOpArgs;
3609  DenseMap<unsigned, Value> beforeBlockInitValMap;
3610  SmallVector<Location> newBeforeBlockArgLocs;
3611  for (const auto &it :
3612  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3613  auto index = static_cast<unsigned>(it.index());
3614  auto [initVal, yieldOpArg] = it.value();
3615 
3616  // If i-th yield operand is equal to the i-th operand of the scf.while,
3617  // the i-th before block argument is a loop invariant.
3618  if (yieldOpArg == initVal) {
3619  beforeBlockInitValMap.insert({index, initVal});
3620  continue;
3621  } else {
3622  // If the i-th yield operand is k-th after block argument, then we check
3623  // if the (k+1)-th condition op operand is equal to either the i-th
3624  // before block argument or the initial value of i-th before block
3625  // argument. If the comparison results `true`, i-th before block
3626  // argument is a loop invariant.
3627  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3628  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3629  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3630  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3631  beforeBlockInitValMap.insert({index, initVal});
3632  continue;
3633  }
3634  }
3635  }
3636  newInitArgs.emplace_back(initVal);
3637  newYieldOpArgs.emplace_back(yieldOpArg);
3638  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3639  }
3640 
3641  {
3642  OpBuilder::InsertionGuard g(rewriter);
3643  rewriter.setInsertionPoint(yieldOp);
3644  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3645  }
3646 
3647  auto newWhile =
3648  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3649 
3650  Block &newBeforeBlock = *rewriter.createBlock(
3651  &newWhile.getBefore(), /*insertPt*/ {},
3652  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3653 
3654  Block &beforeBlock = *op.getBeforeBody();
3655  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3656  // For each i-th before block argument we find it's replacement value as :-
3657  // 1. If i-th before block argument is a loop invariant, we fetch it's
3658  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3659  // 2. Else we fetch j-th new before block argument as the replacement
3660  // value of i-th before block argument.
3661  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3662  // If the index 'i' argument was a loop invariant we fetch it's initial
3663  // value from `beforeBlockInitValMap`.
3664  if (beforeBlockInitValMap.count(i) != 0)
3665  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3666  else
3667  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3668  }
3669 
3670  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3671  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3672  newWhile.getAfter().begin());
3673 
3674  rewriter.replaceOp(op, newWhile.getResults());
3675  return success();
3676  }
3677 };
3678 
3679 /// Remove loop invariant value from result (condition op) of scf.while.
3680 /// A value is considered loop invariant if the final value yielded by
3681 /// scf.condition is defined outside of the `before` block. We remove the
3682 /// corresponding argument in `after` block and replace the use with the value.
3683 /// We also replace the use of the corresponding result of scf.while with the
3684 /// value.
3685 ///
3686 /// Eg:
3687 /// INPUT :-
3688 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3689 /// %argN_before = %N) {
3690 /// ...
3691 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3692 /// } do {
3693 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3694 /// ...
3695 /// some_func(%arg1_after)
3696 /// ...
3697 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3698 /// }
3699 ///
3700 /// OUTPUT :-
3701 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3702 /// ...
3703 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3704 /// } do {
3705 /// ^bb0(%arg0, %arg3, ..., %argM):
3706 /// ...
3707 /// some_func(%a)
3708 /// ...
3709 /// scf.yield %arg0, %b, ..., %argN
3710 /// }
3711 ///
3712 /// EXPLANATION:
3713 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3714 /// before block of scf.while, so they get removed.
3715 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3716 /// replaced by %b.
3717 /// 3. The corresponding after block argument %arg1_after's uses are
3718 /// replaced by %a and %arg2_after's uses are replaced by %b.
3719 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3721 
3722  LogicalResult matchAndRewrite(WhileOp op,
3723  PatternRewriter &rewriter) const override {
3724  Block &beforeBlock = *op.getBeforeBody();
3725  ConditionOp condOp = op.getConditionOp();
3726  OperandRange condOpArgs = condOp.getArgs();
3727 
3728  bool canSimplify = false;
3729  for (Value condOpArg : condOpArgs) {
3730  // Those values not defined within `before` block will be considered as
3731  // loop invariant values. We map the corresponding `index` with their
3732  // value.
3733  if (condOpArg.getParentBlock() != &beforeBlock) {
3734  canSimplify = true;
3735  break;
3736  }
3737  }
3738 
3739  if (!canSimplify)
3740  return failure();
3741 
3742  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3743 
3744  SmallVector<Value> newCondOpArgs;
3745  SmallVector<Type> newAfterBlockType;
3746  DenseMap<unsigned, Value> condOpInitValMap;
3747  SmallVector<Location> newAfterBlockArgLocs;
3748  for (const auto &it : llvm::enumerate(condOpArgs)) {
3749  auto index = static_cast<unsigned>(it.index());
3750  Value condOpArg = it.value();
3751  // Those values not defined within `before` block will be considered as
3752  // loop invariant values. We map the corresponding `index` with their
3753  // value.
3754  if (condOpArg.getParentBlock() != &beforeBlock) {
3755  condOpInitValMap.insert({index, condOpArg});
3756  } else {
3757  newCondOpArgs.emplace_back(condOpArg);
3758  newAfterBlockType.emplace_back(condOpArg.getType());
3759  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3760  }
3761  }
3762 
3763  {
3764  OpBuilder::InsertionGuard g(rewriter);
3765  rewriter.setInsertionPoint(condOp);
3766  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3767  newCondOpArgs);
3768  }
3769 
3770  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3771  op.getOperands());
3772 
3773  Block &newAfterBlock =
3774  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3775  newAfterBlockType, newAfterBlockArgLocs);
3776 
3777  Block &afterBlock = *op.getAfterBody();
3778  // Since a new scf.condition op was created, we need to fetch the new
3779  // `after` block arguments which will be used while replacing operations of
3780  // previous scf.while's `after` blocks. We'd also be fetching new result
3781  // values too.
3782  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3783  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3784  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3785  Value afterBlockArg, result;
3786  // If index 'i' argument was loop invariant we fetch it's value from the
3787  // `condOpInitMap` map.
3788  if (condOpInitValMap.count(i) != 0) {
3789  afterBlockArg = condOpInitValMap[i];
3790  result = afterBlockArg;
3791  } else {
3792  afterBlockArg = newAfterBlock.getArgument(j);
3793  result = newWhile.getResult(j);
3794  j++;
3795  }
3796  newAfterBlockArgs[i] = afterBlockArg;
3797  newWhileResults[i] = result;
3798  }
3799 
3800  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3801  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3802  newWhile.getBefore().begin());
3803 
3804  rewriter.replaceOp(op, newWhileResults);
3805  return success();
3806  }
3807 };
3808 
3809 /// Remove WhileOp results that are also unused in 'after' block.
3810 ///
3811 /// %0:2 = scf.while () : () -> (i32, i64) {
3812 /// %condition = "test.condition"() : () -> i1
3813 /// %v1 = "test.get_some_value"() : () -> i32
3814 /// %v2 = "test.get_some_value"() : () -> i64
3815 /// scf.condition(%condition) %v1, %v2 : i32, i64
3816 /// } do {
3817 /// ^bb0(%arg0: i32, %arg1: i64):
3818 /// "test.use"(%arg0) : (i32) -> ()
3819 /// scf.yield
3820 /// }
3821 /// return %0#0 : i32
3822 ///
3823 /// becomes
3824 /// %0 = scf.while () : () -> (i32) {
3825 /// %condition = "test.condition"() : () -> i1
3826 /// %v1 = "test.get_some_value"() : () -> i32
3827 /// %v2 = "test.get_some_value"() : () -> i64
3828 /// scf.condition(%condition) %v1 : i32
3829 /// } do {
3830 /// ^bb0(%arg0: i32):
3831 /// "test.use"(%arg0) : (i32) -> ()
3832 /// scf.yield
3833 /// }
3834 /// return %0 : i32
3835 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3837 
3838  LogicalResult matchAndRewrite(WhileOp op,
3839  PatternRewriter &rewriter) const override {
3840  auto term = op.getConditionOp();
3841  auto afterArgs = op.getAfterArguments();
3842  auto termArgs = term.getArgs();
3843 
3844  // Collect results mapping, new terminator args and new result types.
3845  SmallVector<unsigned> newResultsIndices;
3846  SmallVector<Type> newResultTypes;
3847  SmallVector<Value> newTermArgs;
3848  SmallVector<Location> newArgLocs;
3849  bool needUpdate = false;
3850  for (const auto &it :
3851  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3852  auto i = static_cast<unsigned>(it.index());
3853  Value result = std::get<0>(it.value());
3854  Value afterArg = std::get<1>(it.value());
3855  Value termArg = std::get<2>(it.value());
3856  if (result.use_empty() && afterArg.use_empty()) {
3857  needUpdate = true;
3858  } else {
3859  newResultsIndices.emplace_back(i);
3860  newTermArgs.emplace_back(termArg);
3861  newResultTypes.emplace_back(result.getType());
3862  newArgLocs.emplace_back(result.getLoc());
3863  }
3864  }
3865 
3866  if (!needUpdate)
3867  return failure();
3868 
3869  {
3870  OpBuilder::InsertionGuard g(rewriter);
3871  rewriter.setInsertionPoint(term);
3872  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3873  newTermArgs);
3874  }
3875 
3876  auto newWhile =
3877  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3878 
3879  Block &newAfterBlock = *rewriter.createBlock(
3880  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3881 
3882  // Build new results list and new after block args (unused entries will be
3883  // null).
3884  SmallVector<Value> newResults(op.getNumResults());
3885  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3886  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3887  newResults[it.value()] = newWhile.getResult(it.index());
3888  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3889  }
3890 
3891  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3892  newWhile.getBefore().begin());
3893 
3894  Block &afterBlock = *op.getAfterBody();
3895  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3896 
3897  rewriter.replaceOp(op, newResults);
3898  return success();
3899  }
3900 };
3901 
3902 /// Replace operations equivalent to the condition in the do block with true,
3903 /// since otherwise the block would not be evaluated.
3904 ///
3905 /// scf.while (..) : (i32, ...) -> ... {
3906 /// %z = ... : i32
3907 /// %condition = cmpi pred %z, %a
3908 /// scf.condition(%condition) %z : i32, ...
3909 /// } do {
3910 /// ^bb0(%arg0: i32, ...):
3911 /// %condition2 = cmpi pred %arg0, %a
3912 /// use(%condition2)
3913 /// ...
3914 ///
3915 /// becomes
3916 /// scf.while (..) : (i32, ...) -> ... {
3917 /// %z = ... : i32
3918 /// %condition = cmpi pred %z, %a
3919 /// scf.condition(%condition) %z : i32, ...
3920 /// } do {
3921 /// ^bb0(%arg0: i32, ...):
3922 /// use(%true)
3923 /// ...
3924 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3926 
3927  LogicalResult matchAndRewrite(scf::WhileOp op,
3928  PatternRewriter &rewriter) const override {
3929  using namespace scf;
3930  auto cond = op.getConditionOp();
3931  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3932  if (!cmp)
3933  return failure();
3934  bool changed = false;
3935  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3936  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3937  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3938  continue;
3939  for (OpOperand &u :
3940  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3941  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3942  if (!cmp2)
3943  continue;
3944  // For a binary operator 1-opIdx gets the other side.
3945  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3946  continue;
3947  bool samePredicate;
3948  if (cmp2.getPredicate() == cmp.getPredicate())
3949  samePredicate = true;
3950  else if (cmp2.getPredicate() ==
3951  arith::invertPredicate(cmp.getPredicate()))
3952  samePredicate = false;
3953  else
3954  continue;
3955 
3956  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3957  1);
3958  changed = true;
3959  }
3960  }
3961  }
3962  return success(changed);
3963  }
3964 };
3965 
3966 /// Remove unused init/yield args.
3967 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3969 
3970  LogicalResult matchAndRewrite(WhileOp op,
3971  PatternRewriter &rewriter) const override {
3972 
3973  if (!llvm::any_of(op.getBeforeArguments(),
3974  [](Value arg) { return arg.use_empty(); }))
3975  return rewriter.notifyMatchFailure(op, "No args to remove");
3976 
3977  YieldOp yield = op.getYieldOp();
3978 
3979  // Collect results mapping, new terminator args and new result types.
3980  SmallVector<Value> newYields;
3981  SmallVector<Value> newInits;
3982  llvm::BitVector argsToErase;
3983 
3984  size_t argsCount = op.getBeforeArguments().size();
3985  newYields.reserve(argsCount);
3986  newInits.reserve(argsCount);
3987  argsToErase.reserve(argsCount);
3988  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3989  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3990  if (beforeArg.use_empty()) {
3991  argsToErase.push_back(true);
3992  } else {
3993  argsToErase.push_back(false);
3994  newYields.emplace_back(yieldValue);
3995  newInits.emplace_back(initValue);
3996  }
3997  }
3998 
3999  Block &beforeBlock = *op.getBeforeBody();
4000  Block &afterBlock = *op.getAfterBody();
4001 
4002  beforeBlock.eraseArguments(argsToErase);
4003 
4004  Location loc = op.getLoc();
4005  auto newWhileOp =
4006  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4007  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4008  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4009  Block &newAfterBlock = *newWhileOp.getAfterBody();
4010 
4011  OpBuilder::InsertionGuard g(rewriter);
4012  rewriter.setInsertionPoint(yield);
4013  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4014 
4015  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4016  newBeforeBlock.getArguments());
4017  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4018  newAfterBlock.getArguments());
4019 
4020  rewriter.replaceOp(op, newWhileOp.getResults());
4021  return success();
4022  }
4023 };
4024 
4025 /// Remove duplicated ConditionOp args.
4026 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4028 
4029  LogicalResult matchAndRewrite(WhileOp op,
4030  PatternRewriter &rewriter) const override {
4031  ConditionOp condOp = op.getConditionOp();
4032  ValueRange condOpArgs = condOp.getArgs();
4033 
4035  for (Value arg : condOpArgs)
4036  argsSet.insert(arg);
4037 
4038  if (argsSet.size() == condOpArgs.size())
4039  return rewriter.notifyMatchFailure(op, "No results to remove");
4040 
4041  llvm::SmallDenseMap<Value, unsigned> argsMap;
4042  SmallVector<Value> newArgs;
4043  argsMap.reserve(condOpArgs.size());
4044  newArgs.reserve(condOpArgs.size());
4045  for (Value arg : condOpArgs) {
4046  if (!argsMap.count(arg)) {
4047  auto pos = static_cast<unsigned>(argsMap.size());
4048  argsMap.insert({arg, pos});
4049  newArgs.emplace_back(arg);
4050  }
4051  }
4052 
4053  ValueRange argsRange(newArgs);
4054 
4055  Location loc = op.getLoc();
4056  auto newWhileOp = rewriter.create<scf::WhileOp>(
4057  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4058  /*afterBody*/ nullptr);
4059  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4060  Block &newAfterBlock = *newWhileOp.getAfterBody();
4061 
4062  SmallVector<Value> afterArgsMapping;
4063  SmallVector<Value> resultsMapping;
4064  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4065  auto it = argsMap.find(arg);
4066  assert(it != argsMap.end());
4067  auto pos = it->second;
4068  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4069  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4070  }
4071 
4072  OpBuilder::InsertionGuard g(rewriter);
4073  rewriter.setInsertionPoint(condOp);
4074  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4075  argsRange);
4076 
4077  Block &beforeBlock = *op.getBeforeBody();
4078  Block &afterBlock = *op.getAfterBody();
4079 
4080  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4081  newBeforeBlock.getArguments());
4082  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4083  rewriter.replaceOp(op, resultsMapping);
4084  return success();
4085  }
4086 };
4087 
4088 /// If both ranges contain same values return mappping indices from args2 to
4089 /// args1. Otherwise return std::nullopt.
4090 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4091  ValueRange args2) {
4092  if (args1.size() != args2.size())
4093  return std::nullopt;
4094 
4095  SmallVector<unsigned> ret(args1.size());
4096  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4097  auto it = llvm::find(args2, arg1);
4098  if (it == args2.end())
4099  return std::nullopt;
4100 
4101  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4102  }
4103 
4104  return ret;
4105 }
4106 
4107 static bool hasDuplicates(ValueRange args) {
4108  llvm::SmallDenseSet<Value> set;
4109  for (Value arg : args) {
4110  if (!set.insert(arg).second)
4111  return true;
4112  }
4113  return false;
4114 }
4115 
4116 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4117 /// `scf.condition` args into same order as block args. Update `after` block
4118 /// args and op result values accordingly.
4119 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4120 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4122 
4123  LogicalResult matchAndRewrite(WhileOp loop,
4124  PatternRewriter &rewriter) const override {
4125  auto oldBefore = loop.getBeforeBody();
4126  ConditionOp oldTerm = loop.getConditionOp();
4127  ValueRange beforeArgs = oldBefore->getArguments();
4128  ValueRange termArgs = oldTerm.getArgs();
4129  if (beforeArgs == termArgs)
4130  return failure();
4131 
4132  if (hasDuplicates(termArgs))
4133  return failure();
4134 
4135  auto mapping = getArgsMapping(beforeArgs, termArgs);
4136  if (!mapping)
4137  return failure();
4138 
4139  {
4140  OpBuilder::InsertionGuard g(rewriter);
4141  rewriter.setInsertionPoint(oldTerm);
4142  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4143  beforeArgs);
4144  }
4145 
4146  auto oldAfter = loop.getAfterBody();
4147 
4148  SmallVector<Type> newResultTypes(beforeArgs.size());
4149  for (auto &&[i, j] : llvm::enumerate(*mapping))
4150  newResultTypes[j] = loop.getResult(i).getType();
4151 
4152  auto newLoop = rewriter.create<WhileOp>(
4153  loop.getLoc(), newResultTypes, loop.getInits(),
4154  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4155  auto newBefore = newLoop.getBeforeBody();
4156  auto newAfter = newLoop.getAfterBody();
4157 
4158  SmallVector<Value> newResults(beforeArgs.size());
4159  SmallVector<Value> newAfterArgs(beforeArgs.size());
4160  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4161  newResults[i] = newLoop.getResult(j);
4162  newAfterArgs[i] = newAfter->getArgument(j);
4163  }
4164 
4165  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4166  newBefore->getArguments());
4167  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4168  newAfterArgs);
4169 
4170  rewriter.replaceOp(loop, newResults);
4171  return success();
4172  }
4173 };
4174 } // namespace
4175 
4176 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4177  MLIRContext *context) {
4178  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4179  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4180  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4181  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4182 }
4183 
4184 //===----------------------------------------------------------------------===//
4185 // IndexSwitchOp
4186 //===----------------------------------------------------------------------===//
4187 
4188 /// Parse the case regions and values.
4189 static ParseResult
4191  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4192  SmallVector<int64_t> caseValues;
4193  while (succeeded(p.parseOptionalKeyword("case"))) {
4194  int64_t value;
4195  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4196  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4197  return failure();
4198  caseValues.push_back(value);
4199  }
4200  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4201  return success();
4202 }
4203 
4204 /// Print the case regions and values.
4206  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4207  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4208  p.printNewline();
4209  p << "case " << value << ' ';
4210  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4211  }
4212 }
4213 
4214 LogicalResult scf::IndexSwitchOp::verify() {
4215  if (getCases().size() != getCaseRegions().size()) {
4216  return emitOpError("has ")
4217  << getCaseRegions().size() << " case regions but "
4218  << getCases().size() << " case values";
4219  }
4220 
4221  DenseSet<int64_t> valueSet;
4222  for (int64_t value : getCases())
4223  if (!valueSet.insert(value).second)
4224  return emitOpError("has duplicate case value: ") << value;
4225  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4226  auto yield = dyn_cast<YieldOp>(region.front().back());
4227  if (!yield)
4228  return emitOpError("expected region to end with scf.yield, but got ")
4229  << region.front().back().getName();
4230 
4231  if (yield.getNumOperands() != getNumResults()) {
4232  return (emitOpError("expected each region to return ")
4233  << getNumResults() << " values, but " << name << " returns "
4234  << yield.getNumOperands())
4235  .attachNote(yield.getLoc())
4236  << "see yield operation here";
4237  }
4238  for (auto [idx, result, operand] :
4239  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4240  yield.getOperandTypes())) {
4241  if (result == operand)
4242  continue;
4243  return (emitOpError("expected result #")
4244  << idx << " of each region to be " << result)
4245  .attachNote(yield.getLoc())
4246  << name << " returns " << operand << " here";
4247  }
4248  return success();
4249  };
4250 
4251  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4252  return failure();
4253  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4254  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4255  return failure();
4256 
4257  return success();
4258 }
4259 
4260 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4261 
4262 Block &scf::IndexSwitchOp::getDefaultBlock() {
4263  return getDefaultRegion().front();
4264 }
4265 
4266 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4267  assert(idx < getNumCases() && "case index out-of-bounds");
4268  return getCaseRegions()[idx].front();
4269 }
4270 
4271 void IndexSwitchOp::getSuccessorRegions(
4273  // All regions branch back to the parent op.
4274  if (!point.isParent()) {
4275  successors.emplace_back(getResults());
4276  return;
4277  }
4278 
4279  llvm::copy(getRegions(), std::back_inserter(successors));
4280 }
4281 
4282 void IndexSwitchOp::getEntrySuccessorRegions(
4283  ArrayRef<Attribute> operands,
4284  SmallVectorImpl<RegionSuccessor> &successors) {
4285  FoldAdaptor adaptor(operands, *this);
4286 
4287  // If a constant was not provided, all regions are possible successors.
4288  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4289  if (!arg) {
4290  llvm::copy(getRegions(), std::back_inserter(successors));
4291  return;
4292  }
4293 
4294  // Otherwise, try to find a case with a matching value. If not, the
4295  // default region is the only successor.
4296  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4297  if (caseValue == arg.getInt()) {
4298  successors.emplace_back(&caseRegion);
4299  return;
4300  }
4301  }
4302  successors.emplace_back(&getDefaultRegion());
4303 }
4304 
4305 void IndexSwitchOp::getRegionInvocationBounds(
4307  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4308  if (!operandValue) {
4309  // All regions are invoked at most once.
4310  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4311  return;
4312  }
4313 
4314  unsigned liveIndex = getNumRegions() - 1;
4315  const auto *it = llvm::find(getCases(), operandValue.getInt());
4316  if (it != getCases().end())
4317  liveIndex = std::distance(getCases().begin(), it);
4318  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4319  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4320 }
4321 
4322 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4324 
4325  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4326  PatternRewriter &rewriter) const override {
4327  // If `op.getArg()` is a constant, select the region that matches with
4328  // the constant value. Use the default region if no matche is found.
4329  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4330  if (!maybeCst.has_value())
4331  return failure();
4332  int64_t cst = *maybeCst;
4333  int64_t caseIdx, e = op.getNumCases();
4334  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4335  if (cst == op.getCases()[caseIdx])
4336  break;
4337  }
4338 
4339  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4340  : op.getDefaultRegion();
4341  Block &source = r.front();
4342  Operation *terminator = source.getTerminator();
4343  SmallVector<Value> results = terminator->getOperands();
4344 
4345  rewriter.inlineBlockBefore(&source, op);
4346  rewriter.eraseOp(terminator);
4347  // Replace the operation with a potentially empty list of results.
4348  // Fold mechanism doesn't support the case where the result list is empty.
4349  rewriter.replaceOp(op, results);
4350 
4351  return success();
4352  }
4353 };
4354 
4355 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4356  MLIRContext *context) {
4357  results.add<FoldConstantCase>(context);
4358 }
4359 
4360 //===----------------------------------------------------------------------===//
4361 // TableGen'd op method definitions
4362 //===----------------------------------------------------------------------===//
4363 
4364 #define GET_OP_CLASSES
4365 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:722
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:714
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4190
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3436
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:431
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:92
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4205
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:615
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:593
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:620
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3049
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1976
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of an scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1450
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:776
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:266
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4325
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.