MLIR  20.0.0git
SCF.cpp
Go to the documentation of this file.
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
/// Inliner hooks for the scf dialect. Both legality queries below answer
/// "legal" unconditionally, and an inlined `scf.yield` terminator is resolved
/// by forwarding its operands to the values it is supposed to replace.
39 struct SCFInlinerInterface : public DialectInlinerInterface {
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in the scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator: when it is an `scf.yield`, all uses of
53  // each value in `valuesToRepl` are redirected to the matching yield operand.
    // Any other terminator is left untouched. Required when the region has only
    // one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
    // Pair each value to replace with the yield operand at the same position.
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
/// Registers everything the scf dialect provides: its operations, the inliner
/// interface, and the interfaces promised for individual ops.
70 void SCFDialect::initialize() {
    // The operation list is expanded from the generated SCFOps.cpp.inc.
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
    // NOTE(review): "promised" interfaces are presumably implemented as
    // external models registered elsewhere -- confirm.
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
/// Returns the terminator on success. On failure, emits `errorMessage` on `op`
/// (with a "terminator here" note on the block's last operation, if there is
/// one) and returns a null TerminatorTy.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
    // Only inspect the last operation when the region has a non-empty first
    // block.
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
    // Either there is no last operation or it is not a TerminatorTy.
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
/// `blockArgs` are the values substituted for the block's arguments when the
/// block is inlined in front of `op`. The terminator itself is erased.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-block region");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
    // Grab the terminator operands before the block is inlined; they become
    // the replacement values for the results of `op`.
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` (`->` function-result-type)? `{`
125 /// block+
126 /// `}` attr-dict?
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// scf.yield %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
    // The result type list is optional.
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it. The region takes no arguments;
    // an optional attribute dictionary may follow it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
/// Verifies that the region has at least one block and that its first block
/// takes no arguments.
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
    // Bail out unless the region consists of exactly one block.
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
    // Splice the block in front of the op; the operands of its terminator
    // replace the op's results.
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
    // Only apply when the parent is a function-like op or another
    // execute_region.
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
    // Split the enclosing block at `op`: `postBlock` starts with `op` and holds
    // everything after it.
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
    // Jump from the end of `prevBlock` into the first block of the region.
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
    // Turn every scf.yield terminator into a branch to `postBlock` that carries
    // the yielded values.
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
    // Move the region's blocks into the parent region, in front of `postBlock`.
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
    // One `postBlock` argument per op result; these receive the branch operands
    // created above and replace the op's results.
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
267 }
268 
/// Appends the possible successors of `point` to `regions`: the body region
/// when branching from the op itself, otherwise the op's results.
269 void ExecuteRegionOp::getSuccessorRegions(
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after"
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
/// Appends the possible successors of this condition op to `regions`, using a
/// constant condition operand (if one is known) to prune the impossible one.
294 void ConditionOp::getSuccessorRegions(
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
    // `boolAttr` is null when the condition is not a known BoolAttr constant;
    // in that case both successors are recorded.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
/// Builds an scf.for with bounds `lb`/`ub`, step `step` and loop-carried
/// `initArgs`. One result is added per init arg, with the init arg's type. The
/// body block receives the induction variable (typed like `lb`) followed by
/// one argument per init arg. If `bodyBuilder` is provided, it is invoked to
/// populate the body.
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange initArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
    // Operand order: lower bound, upper bound, step, then the init args.
319  result.addOperands({lb, ub, step});
320  result.addOperands(initArgs);
321  for (Value v : initArgs)
322  result.addTypes(v.getType());
    // The induction variable takes the type of the lower bound.
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : initArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (initArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
    // Hand the callback the induction variable and the iter args separately.
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
342 
/// Verifies the op-level invariant that there is exactly one result per init
/// arg.
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
/// Verifies the body region against the op: the induction variable type, the
/// number of region iter args, and the pairwise type agreement of init args,
/// region iter args and results.
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the induction variable has the same type as the lower bound.
354  // NOTE(review): the upper bound and step are not compared here; presumably
    // their types are constrained elsewhere -- confirm.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
    // `i` tracks the position reported in the diagnostics below.
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
    // Init arg (element 0) vs. result (element 2).
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
    // Region iter arg (element 1) vs. result (element 2).
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
/// Returns the single induction variable wrapped in a one-element vector.
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
// NOTE(review): the body of getLoopLowerBounds() is not visible in this
// listing; presumably it mirrors getLoopSteps() using the lower bound --
// confirm.
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
386 }
387 
/// Returns the step wrapped in a one-element vector of OpFoldResult.
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
// NOTE(review): the body of getLoopUpperBounds() is not visible in this
// listing; presumably it mirrors getLoopSteps() using the upper bound --
// confirm.
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
394 }
395 
/// Returns all results of the op.
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if it can be
399 /// determined that the loop has a single iteration. Returns failure, without
/// modifying anything, when the trip count is unknown or differs from 1.
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
    // NOTE(review): the initializer of `tripCount` is not visible in this
    // listing; presumably a constant trip count computed from the bounds and
    // step, as in ForallOp::promoteIfSingleIteration -- confirm.
401  std::optional<int64_t> tripCount =
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
446 
/// Prints the custom assembly form:
///   %iv = %lb to %ub step %step [iter_args(...) -> (types)] [: iv-type]
///   region [attr-dict]
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
    // The iter_args list and the result types appear only when the loop
    // carries values.
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
    // The induction variable type is spelled out only when it is not `index`.
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
    // Entry block arguments were already printed above; the terminator is
    // printed only when there are init args.
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
/// Parses the custom assembly form produced by ForOp::print: induction
/// variable and bounds, optional `iter_args` with result types, optional
/// induction variable type (defaulting to index), the body region and an
/// optional attribute dictionary.
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
    // NOTE(review): the declarations of `regionArgs` and `operands` are not
    // visible in this listing.
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
    // `regionArgs` holds the induction variable plus one entry per iter arg,
    // hence the `+ 1`.
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Resolve input operands.
    // The induction variable and all three control operands share `type`.
503  regionArgs.front().type = type;
504  if (parser.resolveOperand(lb, type, result.operands) ||
505  parser.resolveOperand(ub, type, result.operands) ||
506  parser.resolveOperand(step, type, result.operands))
507  return failure();
508  if (hasIterArgs) {
    // Each region iter arg and its init operand take the corresponding result
    // type.
509  for (auto argOperandType :
510  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
    // This `type` shadows the induction variable type declared above.
511  Type type = std::get<2>(argOperandType);
512  std::get<0>(argOperandType).type = type;
513  if (parser.resolveOperand(std::get<1>(argOperandType), type,
514  result.operands))
515  return failure();
516  }
517  }
518 
519  // Parse the body region.
520  Region *body = result.addRegion();
521  if (parser.parseRegion(*body, regionArgs))
522  return failure();
523 
524  ForOp::ensureTerminator(*body, builder, result.location);
525 
526  // Parse the optional attribute list.
527  if (parser.parseOptionalAttrDict(result.attributes))
528  return failure();
529 
530  return success();
531 }
532 
/// The loop has a single region: its body.
533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534 
/// Returns the body block arguments that follow the induction variable(s).
535 Block::BlockArgListType ForOp::getRegionIterArgs() {
536  return getBody()->getArguments().drop_front(getNumInductionVars());
537 }
538 
/// Returns the mutable init-arg operands.
539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
540  return getInitArgsMutable();
541 }
542 
/// Replaces this loop with a new scf.for that carries `newInitOperands` in
/// addition to the existing init args. `newYieldValuesFn` produces the values
/// yielded for the new iter args. When `replaceInitOperandUsesInLoop` is set,
/// uses of `newInitOperands` inside the new loop are redirected to the new
/// region iter args. Returns the new loop.
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
545  ValueRange newInitOperands,
546  bool replaceInitOperandUsesInLoop,
547  const NewYieldValuesFn &newYieldValuesFn) {
548  // Create a new loop before the existing one, with the extra operands.
549  OpBuilder::InsertionGuard g(rewriter);
550  rewriter.setInsertionPoint(getOperation());
551  auto inits = llvm::to_vector(getInitArgs());
552  inits.append(newInitOperands.begin(), newInitOperands.end());
    // The empty body builder leaves the new body without a terminator; the old
    // body, including its yield, is merged in below.
553  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
554  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
555  [](OpBuilder &, Location, Value, ValueRange) {});
556  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
    // The new iter args are the trailing block arguments of the new loop.
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
    // The old block arguments map onto the leading arguments of the new block.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
    // Only uses owned by operations nested inside the new loop are rewritten.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
    // Only the leading results correspond to the results of the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
596 
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
/// Returns the init args, regardless of `point`.
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
/// Appends both possible successors to `regions`, independent of `point`: the
/// body region (receiving the region iter args) and the op's results.
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for the loop not to enter
614  // the body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
/// The loop has a single region: its body.
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration. Returns failure, without
/// modifying anything, if any dimension has an unknown trip count or a trip
/// count different from 1.
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
    // Every dimension must have a constant trip count of exactly 1.
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
/// Returns the body block arguments that follow the first `getRank()`
/// arguments.
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
/// Returns the mutable output operands.
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
/// The body is inlined in front of the loop with the lower bounds standing in
/// for the induction variables and the outputs for the remaining block
/// arguments. Each tensor.parallel_insert_slice in the terminator becomes a
/// tensor.insert_slice created after the loop, whose result replaces the
/// corresponding loop result. The terminator and the loop are then erased.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
    // Every yielding op is expected to be a tensor.parallel_insert_slice.
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
    // Only tensor sources are supported; anything else is treated as
    // unreachable.
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
686 
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
774 
777  OpOperand &operand, Value replacement,
778  const ValueTypeCastFnTy &castFn) {
779  assert(operand.getOwner() == forOp);
780  Type oldType = operand.get().getType(), newType = replacement.getType();
781 
782  // 1. Create new iter operands, exactly 1 is replaced.
783  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
784  "expected an iter OpOperand");
785  assert(operand.get().getType() != replacement.getType() &&
786  "Expected a different type");
787  SmallVector<Value> newIterOperands;
788  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
789  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
790  newIterOperands.push_back(replacement);
791  continue;
792  }
793  newIterOperands.push_back(opOperand.get());
794  }
795 
796  // 2. Create the new forOp shell.
797  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
798  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799  forOp.getStep(), newIterOperands);
800  newForOp->setAttrs(forOp->getAttrs());
801  Block &newBlock = newForOp.getRegion().front();
802  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
803  newBlock.getArguments().end());
804 
805  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806  // corresponding to the `replacement` value.
807  OpBuilder::InsertionGuard g(rewriter);
808  rewriter.setInsertionPointToStart(&newBlock);
809  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
810  &newForOp->getOpOperand(operand.getOperandNumber()));
811  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
813 
814  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815  Block &oldBlock = forOp.getRegion().front();
816  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
817 
818  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
820  rewriter.setInsertionPoint(clonedYieldOp);
821  unsigned yieldIdx =
822  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
823  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824  clonedYieldOp.getOperand(yieldIdx));
825  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
826  newYieldOperands[yieldIdx] = castOut;
827  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828  rewriter.eraseOp(clonedYieldOp);
829 
830  // 6. Inject an outgoing cast op after the forOp.
831  rewriter.setInsertionPointAfter(newForOp);
832  SmallVector<Value> newResults = newForOp.getResults();
833  newResults[yieldIdx] =
834  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
835 
836  return newResults;
837 }
838 
839 namespace {
840 // Fold away ForOp iter arguments when:
841 // 1) The op yields the iter arguments.
842 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
843 // 3) The iter arguments have no use and the corresponding (operation) results
844 // have no use.
845 //
846 // These arguments must be defined outside of
847 // the ForOp region and can just be forwarded after simplifying the op inits,
848 // yields and returns.
849 //
850 // The implementation uses `inlineBlockBefore` to steal the content of the
851 // original ForOp and avoid cloning.
// Rewrite pattern on scf::ForOp that drops iter_args which are only forwarded
// and replaces the loop with one carrying the remaining iter_args.
852 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
854 
  // Returns failure() when no iter_arg qualifies as forwarded; otherwise
  // creates the reduced loop, moves the old body into it and replaces `forOp`.
855  LogicalResult matchAndRewrite(scf::ForOp forOp,
856  PatternRewriter &rewriter) const final {
857  bool canonicalize = false;
858 
859  // An internal flat vector of block transfer
860  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
861  // transformed block argument mappings. This plays the role of a
862  // IRMapping for the particular use case of calling into
863  // `inlineBlockBefore`.
864  int64_t numResults = forOp.getNumResults();
  // keepMask[i] is true when the i-th iter_arg is kept in the new loop.
865  SmallVector<bool, 4> keepMask;
866  keepMask.reserve(numResults);
867  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
868  newResultValues;
869  newBlockTransferArgs.reserve(1 + numResults);
870  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
871  newIterArgs.reserve(forOp.getInitArgs().size());
872  newYieldValues.reserve(numResults);
873  newResultValues.reserve(numResults);
  // Classify every iter_arg. Forwarded ones map straight to their init value;
  // kept ones get null placeholders that are patched once the new loop exists.
874  for (auto [init, arg, result, yielded] :
875  llvm::zip(forOp.getInitArgs(), // iter from outside
876  forOp.getRegionIterArgs(), // iter inside region
877  forOp.getResults(), // op results
878  forOp.getYieldedValues() // iter yield
879  )) {
880  // Forwarded is `true` when:
881  // 1) The region `iter` argument is yielded.
882  // 2) The region `iter` argument the corresponding input is yielded.
883  // 3) The region `iter` argument has no use, and the corresponding op
884  // result has no use.
885  bool forwarded = (arg == yielded) || (init == yielded) ||
886  (arg.use_empty() && result.use_empty());
887  keepMask.push_back(!forwarded);
888  canonicalize |= forwarded;
889  if (forwarded) {
890  newBlockTransferArgs.push_back(init);
891  newResultValues.push_back(init);
892  continue;
893  }
894  newIterArgs.push_back(init);
895  newYieldValues.push_back(yielded);
896  newBlockTransferArgs.push_back(Value()); // placeholder with null value
897  newResultValues.push_back(Value()); // placeholder with null value
898  }
899 
  // No iter_arg is forwarded: there is nothing to simplify.
900  if (!canonicalize)
901  return failure();
902 
  // Create the replacement loop with only the kept init values and copy the
  // attributes of the original op onto it.
903  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
904  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
905  forOp.getStep(), newIterArgs);
906  newForOp->setAttrs(forOp->getAttrs());
907  Block &newBlock = newForOp.getRegion().front();
908 
909  // Replace the null placeholders with newly constructed values.
910  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
  // `idx` walks the original iter_args, `collapsedIdx` the kept ones.
911  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
912  idx != e; ++idx) {
913  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
914  Value &newResultVal = newResultValues[idx];
  // Both entries are either already set (forwarded) or both null (kept).
915  assert((blockTransferArg && newResultVal) ||
916  (!blockTransferArg && !newResultVal));
917  if (!blockTransferArg) {
918  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
919  newResultVal = newForOp.getResult(collapsedIdx++);
920  }
921  }
922 
923  Block &oldBlock = forOp.getRegion().front();
924  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
925  "unexpected argument size mismatch");
926 
927  // No results case: the scf::ForOp builder already created a zero
928  // result terminator. Merge before this terminator and just get rid of the
929  // original terminator that has been merged in.
930  if (newIterArgs.empty()) {
931  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
932  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
  // The op right before the new terminator is the old, inlined terminator.
933  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
934  rewriter.replaceOp(forOp, newResultValues);
935  return success();
936  }
937 
938  // No terminator case: merge and rewrite the merged terminator.
  // The helper creates a new yield, in front of the merged one, that keeps
  // only the operands whose keepMask entry is set.
939  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
940  OpBuilder::InsertionGuard g(rewriter);
941  rewriter.setInsertionPoint(mergedTerminator);
942  SmallVector<Value, 4> filteredOperands;
943  filteredOperands.reserve(newResultValues.size());
944  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
945  if (keepMask[idx])
946  filteredOperands.push_back(mergedTerminator.getOperand(idx));
947  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
948  filteredOperands);
949  };
950 
951  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
952  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
953  cloneFilteredTerminator(mergedYieldOp);
  // The original (unfiltered) yield is no longer needed.
954  rewriter.eraseOp(mergedYieldOp);
955  rewriter.replaceOp(forOp, newResultValues);
956  return success();
957  }
958 };
959 
960 /// Util function that tries to compute a constant diff between u and l.
961 /// Returns std::nullopt when the difference between two AffineValueMap is
962 /// dynamic.
963 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
964  IntegerAttr clb, cub;
965  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
966  llvm::APInt lbValue = clb.getValue();
967  llvm::APInt ubValue = cub.getValue();
968  return (ubValue - lbValue).getSExtValue();
969  }
970 
971  // Else a simple pattern match for x + c or c + x
972  llvm::APInt diff;
973  if (matchPattern(
974  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
975  matchPattern(
976  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
977  return diff.getSExtValue();
978  return std::nullopt;
979 }
980 
981 /// Rewriting pattern that erases loops that are known not to iterate, replaces
982 /// single-iteration loops with their bodies, and removes empty loops that
983 /// iterate at least once and only return values defined outside of the loop.
984 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
986 
  // Tries, in order: identical bounds, a constant non-positive bound
  // difference (both replace the loop by its init args), a single iteration
  // (inline the body), and an empty body yielding only outside values.
987  LogicalResult matchAndRewrite(ForOp op,
988  PatternRewriter &rewriter) const override {
989  // If the upper bound is the same as the lower bound, the loop does not
990  // iterate, just remove it.
991  if (op.getLowerBound() == op.getUpperBound()) {
992  rewriter.replaceOp(op, op.getInitArgs());
993  return success();
994  }
995 
996  std::optional<int64_t> diff =
997  computeConstDiff(op.getLowerBound(), op.getUpperBound());
  // Without a constant bound difference none of the cases below applies.
998  if (!diff)
999  return failure();
1000 
1001  // If the loop is known to have 0 iterations, remove it.
1002  if (*diff <= 0) {
1003  rewriter.replaceOp(op, op.getInitArgs());
1004  return success();
1005  }
1006 
1007  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1008  if (!maybeStepValue)
1009  return failure();
1010 
1011  // If the loop is known to have 1 iteration, inline its body and remove the
1012  // loop.
1013  llvm::APInt stepValue = *maybeStepValue;
1014  if (stepValue.sge(*diff)) {
  // Block arguments for the inlined body: the lower bound stands in for the
  // induction variable, followed by the init args.
1015  SmallVector<Value, 4> blockArgs;
1016  blockArgs.reserve(op.getInitArgs().size() + 1);
1017  blockArgs.push_back(op.getLowerBound());
1018  llvm::append_range(blockArgs, op.getInitArgs());
1019  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1020  return success();
1021  }
1022 
1023  // Now we are left with loops that have more than 1 iteration.
1024  Block &block = op.getRegion().front();
  // A single operation in the block means the body holds only its terminator.
1025  if (!llvm::hasSingleElement(block))
1026  return failure();
1027  // If the loop is empty, iterates at least once, and only returns values
1028  // defined outside of the loop, remove it and replace it with yield values.
1029  if (llvm::any_of(op.getYieldedValues(),
1030  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1031  return failure();
1032  rewriter.replaceOp(op, op.getYieldedValues());
1033  return success();
1034  }
1035 };
1036 
1037 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1038 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1039 ///
1040 /// ```
1041 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1042 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1043 /// -> (tensor<?x?xf32>) {
1044 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1045 /// scf.yield %2 : tensor<?x?xf32>
1046 /// }
1047 /// use_of(%1)
1048 /// ```
1049 ///
1050 /// folds into:
1051 ///
1052 /// ```
1053 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1054 /// -> (tensor<32x1024xf32>) {
1055 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1056 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1057 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1058 /// scf.yield %4 : tensor<32x1024xf32>
1059 /// }
1060 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1061 /// use_of(%1)
1062 /// ```
1063 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1065 
  // Rewrites the first init operand (paired with its loop result) that passes
  // all checks below; returns failure() if no operand qualifies.
1066  LogicalResult matchAndRewrite(ForOp op,
1067  PatternRewriter &rewriter) const override {
1068  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1069  OpOperand &iterOpOperand = std::get<0>(it);
  // Only init operands produced by a type-changing tensor.cast are candidates.
1070  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1071  if (!incomingCast ||
1072  incomingCast.getSource().getType() == incomingCast.getType())
1073  continue;
1074  // If the dest type of the cast does not preserve static information in
1075  // the source type.
1077  incomingCast.getDest().getType(),
1078  incomingCast.getSource().getType()))
1079  continue;
  // Skip unless the matching loop result has exactly one use.
1080  if (!std::get<1>(it).hasOneUse())
1081  continue;
1082 
1083  // Create a new ForOp with that iter operand replaced.
  // The lambda is the cast-building callback: it creates a tensor.cast of
  // `source` to `type` at `loc`.
1084  rewriter.replaceOp(
1086  rewriter, op, iterOpOperand, incomingCast.getSource(),
1087  [](OpBuilder &b, Location loc, Type type, Value source) {
1088  return b.create<tensor::CastOp>(loc, type, source);
1089  }));
1090  return success();
1091  }
1092  return failure();
1093  }
1094 };
1095 
1096 } // namespace
1097 
1098 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1099  MLIRContext *context) {
1100  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1101  context);
1102 }
1103 
1104 std::optional<APInt> ForOp::getConstantStep() {
1105  IntegerAttr step;
1106  if (matchPattern(getStep(), m_Constant(&step)))
1107  return step.getValue();
1108  return {};
1109 }
1110 
1111 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1112  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1113 }
1114 
// Reports how speculatable the loop is; the decision depends only on whether
// the step is the constant 1 (see the comments on the two cases below).
1115 Speculation::Speculatability ForOp::getSpeculatability() {
1116  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1117  // and End.
1118  if (auto constantStep = getConstantStep())
1119  if (*constantStep == 1)
1121 
1122  // For Step != 1, the loop may not terminate. We can add more smarts here if
1123  // needed.
1125 }
1127 //===----------------------------------------------------------------------===//
1128 // ForallOp
1129 //===----------------------------------------------------------------------===//
1130 
1131 LogicalResult ForallOp::verify() {
1132  unsigned numLoops = getRank();
1133  // Check number of outputs.
1134  if (getNumResults() != getOutputs().size())
1135  return emitOpError("produces ")
1136  << getNumResults() << " results, but has only "
1137  << getOutputs().size() << " outputs";
1138 
1139  // Check that the body defines block arguments for thread indices and outputs.
1140  auto *body = getBody();
1141  if (body->getNumArguments() != numLoops + getOutputs().size())
1142  return emitOpError("region expects ") << numLoops << " arguments";
1143  for (int64_t i = 0; i < numLoops; ++i)
1144  if (!body->getArgument(i).getType().isIndex())
1145  return emitOpError("expects ")
1146  << i << "-th block argument to be an index";
1147  for (unsigned i = 0; i < getOutputs().size(); ++i)
1148  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1149  return emitOpError("type mismatch between ")
1150  << i << "-th output and corresponding block argument";
1151  if (getMapping().has_value() && !getMapping()->empty()) {
1152  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1153  return emitOpError() << "mapping attribute size must match op rank";
1154  for (auto map : getMapping()->getValue()) {
1155  if (!isa<DeviceMappingAttrInterface>(map))
1156  return emitOpError()
1157  << getMappingAttrName() << " is not device mapping attribute";
1158  }
1159  }
1160 
1161  // Verify mixed static/dynamic control variables.
1162  Operation *op = getOperation();
1163  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1164  getStaticLowerBound(),
1165  getDynamicLowerBound())))
1166  return failure();
1167  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1168  getStaticUpperBound(),
1169  getDynamicUpperBound())))
1170  return failure();
1171  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1172  getStaticStep(), getDynamicStep())))
1173  return failure();
1174 
1175  return success();
1176 }
1177 
// Custom assembly printer for scf.forall. A normalized loop (see
// isNormalized()) is printed in the short `(ivs) in (ubs)` form; otherwise the
// `(ivs) = (lbs) to (ubs) step (steps)` form is used.
1178 void ForallOp::print(OpAsmPrinter &p) {
1179  Operation *op = getOperation();
1180  p << " (" << getInductionVars();
1181  if (isNormalized()) {
1182  p << ") in ";
1183  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1184  /*valueTypes=*/{}, /*scalables=*/{},
1186  } else {
1187  p << ") = ";
1188  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1189  /*valueTypes=*/{}, /*scalables=*/{},
1191  p << " to ";
1192  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1193  /*valueTypes=*/{}, /*scalables=*/{},
1195  p << " step ";
1196  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1197  /*valueTypes=*/{}, /*scalables=*/{},
1199  }
1200  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1201  p << " ";
  // The result type list is only printed when there are shared outputs.
1202  if (!getRegionOutArgs().empty())
1203  p << "-> (" << getResultTypes() << ") ";
  // Entry block arguments are never printed here; the terminator is printed
  // only when the op has results.
1204  p.printRegion(getRegion(),
1205  /*printEntryBlockArgs=*/false,
1206  /*printBlockTerminators=*/getNumResults() > 0);
  // These attributes are elided from the trailing attribute dictionary.
1207  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1208  getStaticLowerBoundAttrName(),
1209  getStaticUpperBoundAttrName(),
1210  getStaticStepAttrName()});
1211 }
1212 
// Custom assembly parser for scf.forall. Accepts the short `in (ubs)` form
// (lower bounds default to 0, steps to 1) and the explicit
// `= (lbs) to (ubs) step (steps)` form, then the optional `shared_outs` list,
// the body region and the optional attribute dictionary.
1213 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1214  OpBuilder b(parser.getContext());
1215  auto indexType = b.getIndexType();
1216 
1217  // Parse an opening `(` followed by thread index variables followed by `)`
1218  // TODO: when we can refer to such "induction variable"-like handles from the
1219  // declarative assembly format, we can implement the parser as a custom hook.
1222  return failure();
1223 
1224  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1225  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1226  dynamicSteps;
1227  if (succeeded(parser.parseOptionalKeyword("in"))) {
1228  // Parse upper bounds.
1229  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1230  /*valueTypes=*/nullptr,
1232  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1233  return failure();
1234 
  // Short form: synthesize all-zero lower bounds and all-one steps.
1235  unsigned numLoops = ivs.size();
1236  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1237  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1238  } else {
1239  // Parse lower bounds.
1240  if (parser.parseEqual() ||
1241  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1242  /*valueTypes=*/nullptr,
1244 
1245  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1246  return failure();
1247 
1248  // Parse upper bounds.
1249  if (parser.parseKeyword("to") ||
1250  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1251  /*valueTypes=*/nullptr,
1253  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1254  return failure();
1255 
1256  // Parse step values.
1257  if (parser.parseKeyword("step") ||
1258  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1259  /*valueTypes=*/nullptr,
1261  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1262  return failure();
1263  }
1264 
1265  // Parse out operands and results.
1268  SMLoc outOperandsLoc = parser.getCurrentLocation();
1269  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
  // NOTE(review): this size comparison runs before parseAssignmentList and
  // parseOptionalArrowTypeList below have been given `outOperands` and
  // `result.types`, so it looks like it can never fire - confirm whether it
  // was meant to come after them.
1270  if (outOperands.size() != result.types.size())
1271  return parser.emitError(outOperandsLoc,
1272  "mismatch between out operands and types");
1273  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1274  parser.parseOptionalArrowTypeList(result.types) ||
1275  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1276  result.operands))
1277  return failure();
1278  }
1279 
1280  // Parse region.
1282  std::unique_ptr<Region> region = std::make_unique<Region>();
  // Region arguments: index-typed induction variables first ...
1283  for (auto &iv : ivs) {
1284  iv.type = b.getIndexType();
1285  regionArgs.push_back(iv);
1286  }
  // ... then one argument per shared output, typed like the matching result.
1287  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1288  auto &out = it.value();
1289  out.type = result.types[it.index()];
1290  regionArgs.push_back(out);
1291  }
1292  if (parser.parseRegion(*region, regionArgs))
1293  return failure();
1294 
1295  // Ensure terminator and move region.
1296  ForallOp::ensureTerminator(*region, b, result.location);
1297  result.addRegion(std::move(region));
1298 
1299  // Parse the optional attribute list.
1300  if (parser.parseOptionalAttrDict(result.attributes))
1301  return failure();
1302 
  // Record the static bounds and the operand segment sizes, in the order
  // lower bounds, upper bounds, steps, outputs.
1303  result.addAttribute("staticLowerBound", staticLbs);
1304  result.addAttribute("staticUpperBound", staticUbs);
1305  result.addAttribute("staticStep", staticSteps);
1306  result.addAttribute("operandSegmentSizes",
1308  {static_cast<int32_t>(dynamicLbs.size()),
1309  static_cast<int32_t>(dynamicUbs.size()),
1310  static_cast<int32_t>(dynamicSteps.size()),
1311  static_cast<int32_t>(outOperands.size())}));
1312  return success();
1313 }
1314 
1315 // Builder that takes explicit lower bounds, upper bounds and steps (each a
 // list of OpFoldResult), the shared outputs, an optional mapping and an
 // optional body-builder callback.
1316 void ForallOp::build(
1319  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1320  std::optional<ArrayAttr> mapping,
1321  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Split each mixed list into its dynamic Values and static integers.
1322  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1323  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1324  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1325  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1326  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1327 
  // Operand order: dynamic lower bounds, upper bounds, steps, then outputs.
  // The op has one result per output, with the same type.
1328  result.addOperands(dynamicLbs);
1329  result.addOperands(dynamicUbs);
1330  result.addOperands(dynamicSteps);
1331  result.addOperands(outputs);
1332  result.addTypes(TypeRange(outputs));
1333 
1334  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1335  b.getDenseI64ArrayAttr(staticLbs));
1336  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1337  b.getDenseI64ArrayAttr(staticUbs));
1338  result.addAttribute(getStaticStepAttrName(result.name),
1339  b.getDenseI64ArrayAttr(staticSteps));
1340  result.addAttribute(
1341  "operandSegmentSizes",
1342  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1343  static_cast<int32_t>(dynamicUbs.size()),
1344  static_cast<int32_t>(dynamicSteps.size()),
1345  static_cast<int32_t>(outputs.size())}));
1346  if (mapping.has_value()) {
1348  mapping.value());
1349  }
1350 
1351  Region *bodyRegion = result.addRegion();
1353  b.createBlock(bodyRegion);
1354  Block &bodyBlock = bodyRegion->front();
1355 
1356  // Add block arguments for indices and outputs.
1357  bodyBlock.addArguments(
1358  SmallVector<Type>(lbs.size(), b.getIndexType()),
1359  SmallVector<Location>(staticLbs.size(), result.location));
1360  bodyBlock.addArguments(
1361  TypeRange(outputs),
1362  SmallVector<Location>(outputs.size(), result.location));
1363 
1364  b.setInsertionPointToStart(&bodyBlock);
  // Without a body builder only the terminator is ensured; otherwise the
  // callback receives all block arguments and populates the body itself.
1365  if (!bodyBuilderFn) {
1366  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1367  return;
1368  }
1369  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1370 }
1371 
1372 // Builder that takes only upper bounds; every lower bound is set to the
 // index attribute 0 and every step to the index attribute 1 before
 // delegating to the builder that takes lbs, ubs and steps.
1373 void ForallOp::build(
1375  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1376  std::optional<ArrayAttr> mapping,
1377  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1378  unsigned numLoops = ubs.size();
1379  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1380  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1381  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1382 }
1383 
1384 // Checks if the lbs are zeros and steps are ones.
1385 bool ForallOp::isNormalized() {
1386  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1387  return llvm::all_of(results, [&](OpFoldResult ofr) {
1388  auto intValue = getConstantIntValue(ofr);
1389  return intValue.has_value() && intValue == val;
1390  });
1391  };
1392  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1393 }
1394 
1395 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1396 // unaware of the fact that our terminator also needs a region to be
1397 // well-formed. We override it here to ensure that we do the right thing.
1398 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1399  Location loc) {
1401  ForallOp>::ensureTerminator(region, builder, loc);
  // NOTE(review): the dyn_cast result is used without a null check; if the
  // terminator is always an InParallelOp at this point, cast<> would state
  // that more directly - confirm.
1402  auto terminator =
1403  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
  // Give the terminator's own region a block if it does not have one yet.
1404  if (terminator.getRegion().empty())
1405  builder.createBlock(&terminator.getRegion());
1406 }
1407 
1408 InParallelOp ForallOp::getTerminator() {
1409  return cast<InParallelOp>(getBody()->getTerminator());
1410 }
1411 
1412 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1413  SmallVector<Operation *> storeOps;
1414  InParallelOp inParallelOp = getTerminator();
1415  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1416  if (auto parallelInsertSliceOp =
1417  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1418  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1419  storeOps.push_back(parallelInsertSliceOp);
1420  }
1421  }
1422  return storeOps;
1423 }
1424 
1425 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1426  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1427 }
1428 
1429 // Get lower bounds as OpFoldResult.
1430 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1431  Builder b(getOperation()->getContext());
1432  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1433 }
1434 
1435 // Get upper bounds as OpFoldResult.
1436 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1437  Builder b(getOperation()->getContext());
1438  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1439 }
1440 
1441 // Get steps as OpFoldResult.
1442 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1443  Builder b(getOperation()->getContext());
1444  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1445 }
1446 
 // Maps `val` to a ForallOp: a value that is not a block argument yields a
 // null ForallOp; otherwise the parent op of the argument's block is
 // dyn_cast to ForallOp.
1448  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1449  if (!tidxArg)
1450  return ForallOp();
1451  assert(tidxArg.getOwner() && "unlinked block argument");
 // NOTE(review): the parent op is passed to dyn_cast without a null check -
 // confirm that a block without a parent op cannot reach this point.
1452  auto *containingOp = tidxArg.getOwner()->getParentOp();
1453  return dyn_cast<ForallOp>(containingOp);
1454 }
1455 
1456 namespace {
1457 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1458 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1460 
  // Matches a tensor.dim whose source is a result of a ForallOp and redirects
  // the source, in place, to the operand tied to that result.
1461  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1462  PatternRewriter &rewriter) const final {
1463  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1464  if (!forallOp)
1465  return failure();
  // The source is known to be an OpResult here since it has a defining op.
1466  Value sharedOut =
1467  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1468  ->get();
1469  rewriter.modifyOpInPlace(
1470  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1471  return success();
1472  }
1473 };
1474 
// Pattern that folds the mixed (static/dynamic) lower bounds, upper bounds and
// steps of a ForallOp via foldDynamicIndexList and updates the op in place.
1475 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1476 public:
1478 
1479  LogicalResult matchAndRewrite(ForallOp op,
1480  PatternRewriter &rewriter) const override {
1481  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1482  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1483  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
  // Fail only when none of the three lists could be folded. The && chain
  // short-circuits: once one list folds, the lists after it are not attempted
  // in this invocation.
1484  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1485  failed(foldDynamicIndexList(mixedUpperBound)) &&
1486  failed(foldDynamicIndexList(mixedStep)))
1487  return failure();
1488 
  // Re-split each mixed list into dynamic operands and static values and
  // write both back onto the op.
1489  rewriter.modifyOpInPlace(op, [&]() {
1490  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1491  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1492  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1493  staticLowerBound);
1494  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1495  op.setStaticLowerBound(staticLowerBound);
1496 
1497  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1498  staticUpperBound);
1499  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1500  op.setStaticUpperBound(staticUpperBound);
1501 
1502  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1503  op.getDynamicStepMutable().assign(dynamicStep);
1504  op.setStaticStep(staticStep);
1505 
  // Refresh the operand segment sizes; the last segment uses the number of
  // results of the op.
1506  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1507  rewriter.getDenseI32ArrayAttr(
1508  {static_cast<int32_t>(dynamicLowerBound.size()),
1509  static_cast<int32_t>(dynamicUpperBound.size()),
1510  static_cast<int32_t>(dynamicStep.size()),
1511  static_cast<int32_t>(op.getNumResults())}));
1512  });
1513  return success();
1514  }
1515 };
1516 
1517 /// The following canonicalization pattern folds an iter argument of an
1518 /// scf.forall op if either of the following holds :-
1519 /// 1. The corresponding result has zero uses.
1520 /// 2. The iter argument is NOT being modified within the loop body.
1521 ///
1522 ///
1523 /// Example of first case :-
1524 /// INPUT:
1525 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1526 /// {
1527 /// ...
1528 /// <SOME USE OF %arg0>
1529 /// <SOME USE OF %arg1>
1530 /// <SOME USE OF %arg2>
1531 /// ...
1532 /// scf.forall.in_parallel {
1533 /// <STORE OP WITH DESTINATION %arg1>
1534 /// <STORE OP WITH DESTINATION %arg0>
1535 /// <STORE OP WITH DESTINATION %arg2>
1536 /// }
1537 /// }
1538 /// return %res#1
1539 ///
1540 /// OUTPUT:
1541 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1542 /// {
1543 /// ...
1544 /// <SOME USE OF %a>
1545 /// <SOME USE OF %new_arg0>
1546 /// <SOME USE OF %c>
1547 /// ...
1548 /// scf.forall.in_parallel {
1549 /// <STORE OP WITH DESTINATION %new_arg0>
1550 /// }
1551 /// }
1552 /// return %res
1553 ///
1554 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1555 /// scf.forall is replaced by their corresponding operands.
1556 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1557 /// of the scf.forall besides within scf.forall.in_parallel terminator,
1558 /// this canonicalization remains valid. For more details, please refer
1559 /// to :
1560 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1561 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1562 /// handles tensor.parallel_insert_slice ops only.
1563 ///
1564 /// Example of second case :-
1565 /// INPUT:
1566 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1567 /// {
1568 /// ...
1569 /// <SOME USE OF %arg0>
1570 /// <SOME USE OF %arg1>
1571 /// ...
1572 /// scf.forall.in_parallel {
1573 /// <STORE OP WITH DESTINATION %arg1>
1574 /// }
1575 /// }
1576 /// return %res#0, %res#1
1577 ///
1578 /// OUTPUT:
1579 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1580 /// {
1581 /// ...
1582 /// <SOME USE OF %a>
1583 /// <SOME USE OF %new_arg0>
1584 /// ...
1585 /// scf.forall.in_parallel {
1586 /// <STORE OP WITH DESTINATION %new_arg0>
1587 /// }
1588 /// }
1589 /// return %a, %res
1590 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1592 
  // Returns failure() when every result is both used and has a combining op;
  // otherwise builds a new scf.forall with the remaining shared_outs, merges
  // the old body into it and redirects the uses of the old results.
1593  LogicalResult matchAndRewrite(ForallOp forallOp,
1594  PatternRewriter &rewriter) const final {
1595  // Step 1: For a given i-th result of scf.forall, check the following :-
1596  // a. If it has any use.
1597  // b. If the corresponding iter argument is being modified within
1598  // the loop, i.e. has at least one store op with the iter arg as
1599  // its destination operand. For this we use
1600  // ForallOp::getCombiningOps(iter_arg).
1601  //
1602  // Based on the check we maintain the following :-
1603  // a. `resultToDelete` - i-th result of scf.forall that'll be
1604  // deleted.
1605  // b. `resultToReplace` - i-th result of the old scf.forall
1606  // whose uses will be replaced by the new scf.forall.
1607  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1608  // corresponding to the i-th result with at least one use.
1609  SetVector<OpResult> resultToDelete;
1610  SmallVector<Value> resultToReplace;
1611  SmallVector<Value> newOuts;
1612  for (OpResult result : forallOp.getResults()) {
1613  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1614  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1615  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1616  resultToDelete.insert(result);
1617  } else {
1618  resultToReplace.push_back(result);
1619  newOuts.push_back(opOperand->get());
1620  }
1621  }
1622 
1623  // Return early if all results of scf.forall have at least one use and are
1624  // modified within the loop.
1625  if (resultToDelete.empty())
1626  return failure();
1627 
1628  // Step 2: For the i-th result, do the following :-
1629  // a. Fetch the corresponding BlockArgument.
1630  // b. Look for store ops (currently tensor.parallel_insert_slice)
1631  // with the BlockArgument as its destination operand.
1632  // c. Remove the operations fetched in b.
1633  for (OpResult result : resultToDelete) {
1634  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1635  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1636  SmallVector<Operation *> combiningOps =
1637  forallOp.getCombiningOps(blockArg);
1638  for (Operation *combiningOp : combiningOps)
1639  rewriter.eraseOp(combiningOp);
1640  }
1641 
1642  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1643  // fetched earlier
  // The body builder is deliberately a no-op: the body is filled by the block
  // merge in step 4.
1644  auto newForallOp = rewriter.create<scf::ForallOp>(
1645  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1646  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1647  forallOp.getMapping(),
1648  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1649 
1650  // Step 4. Merge the block of the old scf.forall into the newly created
1651  // scf.forall using the new set of arguments.
1652  Block *loopBody = forallOp.getBody();
1653  Block *newLoopBody = newForallOp.getBody();
1654  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1655  // Form initial new bbArg list with just the control operands of the new
1656  // scf.forall op.
1657  SmallVector<Value> newBlockArgs =
1658  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1659  [](BlockArgument b) -> Value { return b; });
1660  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
  // `index` counts the shared_outs block arguments of the new op consumed so
  // far.
1661  unsigned index = 0;
1662  // Take the new corresponding bbArg if the old bbArg was used as a
1663  // destination in the in_parallel op. For all other bbArgs, use the
1664  // corresponding init_arg from the old scf.forall op.
1665  for (OpResult result : forallOp.getResults()) {
1666  if (resultToDelete.count(result)) {
1667  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1668  } else {
1669  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1670  }
1671  }
1672  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1673 
1674  // Step 5. Replace the uses of result of old scf.forall with that of the new
1675  // scf.forall.
1676  for (auto &&[oldResult, newResult] :
1677  llvm::zip(resultToReplace, newForallOp->getResults()))
1678  rewriter.replaceAllUsesWith(oldResult, newResult);
1679 
1680  // Step 6. Replace the uses of those values that either has no use or are
1681  // not being modified within the loop with the corresponding
1682  // OpOperand.
1683  for (OpResult oldResult : resultToDelete)
1684  rewriter.replaceAllUsesWith(oldResult,
1685  forallOp.getTiedOpOperand(oldResult)->get());
1686  return success();
1687  }
1688 };
1689 
1690 struct ForallOpSingleOrZeroIterationDimsFolder
1691  : public OpRewritePattern<ForallOp> {
1693 
1694  LogicalResult matchAndRewrite(ForallOp op,
1695  PatternRewriter &rewriter) const override {
1696  // Do not fold dimensions if they are mapped to processing units.
1697  if (op.getMapping().has_value() && !op.getMapping()->empty())
1698  return failure();
1699  Location loc = op.getLoc();
1700 
1701  // Compute new loop bounds that omit all single-iteration loop dimensions.
1702  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1703  newMixedSteps;
1704  IRMapping mapping;
1705  for (auto [lb, ub, step, iv] :
1706  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1707  op.getMixedStep(), op.getInductionVars())) {
1708  auto numIterations = constantTripCount(lb, ub, step);
1709  if (numIterations.has_value()) {
1710  // Remove the loop if it performs zero iterations.
1711  if (*numIterations == 0) {
1712  rewriter.replaceOp(op, op.getOutputs());
1713  return success();
1714  }
1715  // Replace the loop induction variable by the lower bound if the loop
1716  // performs a single iteration. Otherwise, copy the loop bounds.
1717  if (*numIterations == 1) {
1718  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1719  continue;
1720  }
1721  }
1722  newMixedLowerBounds.push_back(lb);
1723  newMixedUpperBounds.push_back(ub);
1724  newMixedSteps.push_back(step);
1725  }
1726 
1727  // All of the loop dimensions perform a single iteration. Inline loop body.
1728  if (newMixedLowerBounds.empty()) {
1729  promote(rewriter, op);
1730  return success();
1731  }
1732 
1733  // Exit if none of the loop dimensions perform a single iteration.
1734  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1735  return rewriter.notifyMatchFailure(
1736  op, "no dimensions have 0 or 1 iterations");
1737  }
1738 
1739  // Replace the loop by a lower-dimensional loop.
1740  ForallOp newOp;
1741  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1742  newMixedUpperBounds, newMixedSteps,
1743  op.getOutputs(), std::nullopt, nullptr);
1744  newOp.getBodyRegion().getBlocks().clear();
1745  // The new loop needs to keep all attributes from the old one, except for
1746  // "operandSegmentSizes" and static loop bound attributes which capture
1747  // the outdated information of the old iteration domain.
1748  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1749  newOp.getStaticLowerBoundAttrName(),
1750  newOp.getStaticUpperBoundAttrName(),
1751  newOp.getStaticStepAttrName()};
1752  for (const auto &namedAttr : op->getAttrs()) {
1753  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1754  continue;
1755  rewriter.modifyOpInPlace(newOp, [&]() {
1756  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1757  });
1758  }
1759  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1760  newOp.getRegion().begin(), mapping);
1761  rewriter.replaceOp(op, newOp.getResults());
1762  return success();
1763  }
1764 };
1765 
1766 /// Replace all induction vars with a single trip count with their lower bound.
1767 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1769 
1770  LogicalResult matchAndRewrite(ForallOp op,
1771  PatternRewriter &rewriter) const override {
1772  Location loc = op.getLoc();
1773  bool changed = false;
1774  for (auto [lb, ub, step, iv] :
1775  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1776  op.getMixedStep(), op.getInductionVars())) {
1777  if (iv.getUses().begin() == iv.getUses().end())
1778  continue;
1779  auto numIterations = constantTripCount(lb, ub, step);
1780  if (!numIterations.has_value() || numIterations.value() != 1) {
1781  continue;
1782  }
1783  rewriter.replaceAllUsesWith(
1784  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1785  changed = true;
1786  }
1787  return success(changed);
1788  }
1789 };
1790 
1791 struct FoldTensorCastOfOutputIntoForallOp
1792  : public OpRewritePattern<scf::ForallOp> {
1794 
1795  struct TypeCast {
1796  Type srcType;
1797  Type dstType;
1798  };
1799 
1800  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1801  PatternRewriter &rewriter) const final {
1802  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1803  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1804  for (auto en : llvm::enumerate(newOutputTensors)) {
1805  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1806  if (!castOp)
1807  continue;
1808 
1809  // Only casts that that preserve static information, i.e. will make the
1810  // loop result type "more" static than before, will be folded.
1811  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1812  castOp.getSource().getType())) {
1813  continue;
1814  }
1815 
1816  tensorCastProducers[en.index()] =
1817  TypeCast{castOp.getSource().getType(), castOp.getType()};
1818  newOutputTensors[en.index()] = castOp.getSource();
1819  }
1820 
1821  if (tensorCastProducers.empty())
1822  return failure();
1823 
1824  // Create new loop.
1825  Location loc = forallOp.getLoc();
1826  auto newForallOp = rewriter.create<ForallOp>(
1827  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1828  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1829  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1830  auto castBlockArgs =
1831  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1832  for (auto [index, cast] : tensorCastProducers) {
1833  Value &oldTypeBBArg = castBlockArgs[index];
1834  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1835  nestedLoc, cast.dstType, oldTypeBBArg);
1836  }
1837 
1838  // Move old body into new parallel loop.
1839  SmallVector<Value> ivsBlockArgs =
1840  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1841  ivsBlockArgs.append(castBlockArgs);
1842  rewriter.mergeBlocks(forallOp.getBody(),
1843  bbArgs.front().getParentBlock(), ivsBlockArgs);
1844  });
1845 
1846  // After `mergeBlocks` happened, the destinations in the terminator were
1847  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1848  // destination have to be updated to point to the output bbArgs directly.
1849  auto terminator = newForallOp.getTerminator();
1850  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1851  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1852  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1853  insertSliceOp.getDestMutable().assign(outputBlockArg);
1854  }
1855 
1856  // Cast results back to the original types.
1857  rewriter.setInsertionPointAfter(newForallOp);
1858  SmallVector<Value> castResults = newForallOp.getResults();
1859  for (auto &item : tensorCastProducers) {
1860  Value &oldTypeResult = castResults[item.first];
1861  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1862  oldTypeResult);
1863  }
1864  rewriter.replaceOp(forallOp, castResults);
1865  return success();
1866  }
1867 };
1868 
1869 } // namespace
1870 
1871 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1872  MLIRContext *context) {
1873  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1874  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1875  ForallOpSingleOrZeroIterationDimsFolder,
1876  ForallOpReplaceConstantInductionVar>(context);
1877 }
1878 
1879 /// Given the region at `index`, or the parent operation if `index` is None,
1880 /// return the successor regions. These are the regions that may be selected
1881 /// during the flow of control. `operands` is a set of optional attributes that
1882 /// correspond to a constant value for each operand, or null if that operand is
1883 /// not a constant.
1884 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1886  // Both the operation itself and the region may be branching into the body or
1887  // back into the operation itself. It is possible for loop not to enter the
1888  // body.
1889  regions.push_back(RegionSuccessor(&getRegion()));
1890  regions.push_back(RegionSuccessor());
1891 }
1892 
1893 //===----------------------------------------------------------------------===//
1894 // InParallelOp
1895 //===----------------------------------------------------------------------===//
1896 
1897 // Build a InParallelOp with mixed static and dynamic entries.
1898 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1900  Region *bodyRegion = result.addRegion();
1901  b.createBlock(bodyRegion);
1902 }
1903 
1904 LogicalResult InParallelOp::verify() {
1905  scf::ForallOp forallOp =
1906  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1907  if (!forallOp)
1908  return this->emitOpError("expected forall op parent");
1909 
1910  // TODO: InParallelOpInterface.
1911  for (Operation &op : getRegion().front().getOperations()) {
1912  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1913  return this->emitOpError("expected only ")
1914  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1915  }
1916 
1917  // Verify that inserts are into out block arguments.
1918  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1919  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1920  if (!llvm::is_contained(regionOutArgs, dest))
1921  return op.emitOpError("may only insert into an output block argument");
1922  }
1923  return success();
1924 }
1925 
1927  p << " ";
1928  p.printRegion(getRegion(),
1929  /*printEntryBlockArgs=*/false,
1930  /*printBlockTerminators=*/false);
1931  p.printOptionalAttrDict(getOperation()->getAttrs());
1932 }
1933 
1934 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1935  auto &builder = parser.getBuilder();
1936 
1938  std::unique_ptr<Region> region = std::make_unique<Region>();
1939  if (parser.parseRegion(*region, regionOperands))
1940  return failure();
1941 
1942  if (region->empty())
1943  OpBuilder(builder.getContext()).createBlock(region.get());
1944  result.addRegion(std::move(region));
1945 
1946  // Parse the optional attribute list.
1947  if (parser.parseOptionalAttrDict(result.attributes))
1948  return failure();
1949  return success();
1950 }
1951 
1952 OpResult InParallelOp::getParentResult(int64_t idx) {
1953  return getOperation()->getParentOp()->getResult(idx);
1954 }
1955 
1956 SmallVector<BlockArgument> InParallelOp::getDests() {
1957  return llvm::to_vector<4>(
1958  llvm::map_range(getYieldingOps(), [](Operation &op) {
1959  // Add new ops here as needed.
1960  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1961  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1962  }));
1963 }
1964 
1965 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1966  return getRegion().front().getOperations();
1967 }
1968 
1969 //===----------------------------------------------------------------------===//
1970 // IfOp
1971 //===----------------------------------------------------------------------===//
1972 
1974  assert(a && "expected non-empty operation");
1975  assert(b && "expected non-empty operation");
1976 
1977  IfOp ifOp = a->getParentOfType<IfOp>();
1978  while (ifOp) {
1979  // Check if b is inside ifOp. (We already know that a is.)
1980  if (ifOp->isProperAncestor(b))
1981  // b is contained in ifOp. a and b are in mutually exclusive branches if
1982  // they are in different blocks of ifOp.
1983  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1984  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1985  // Check next enclosing IfOp.
1986  ifOp = ifOp->getParentOfType<IfOp>();
1987  }
1988 
1989  // Could not find a common IfOp among a's and b's ancestors.
1990  return false;
1991 }
1992 
1993 LogicalResult
1994 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1995  IfOp::Adaptor adaptor,
1996  SmallVectorImpl<Type> &inferredReturnTypes) {
1997  if (adaptor.getRegions().empty())
1998  return failure();
1999  Region *r = &adaptor.getThenRegion();
2000  if (r->empty())
2001  return failure();
2002  Block &b = r->front();
2003  if (b.empty())
2004  return failure();
2005  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2006  if (!yieldOp)
2007  return failure();
2008  TypeRange types = yieldOp.getOperandTypes();
2009  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2010  types.end());
2011  return success();
2012 }
2013 
2014 void IfOp::build(OpBuilder &builder, OperationState &result,
2015  TypeRange resultTypes, Value cond) {
2016  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2017  /*addElseBlock=*/false);
2018 }
2019 
2020 void IfOp::build(OpBuilder &builder, OperationState &result,
2021  TypeRange resultTypes, Value cond, bool addThenBlock,
2022  bool addElseBlock) {
2023  assert((!addElseBlock || addThenBlock) &&
2024  "must not create else block w/o then block");
2025  result.addTypes(resultTypes);
2026  result.addOperands(cond);
2027 
2028  // Add regions and blocks.
2029  OpBuilder::InsertionGuard guard(builder);
2030  Region *thenRegion = result.addRegion();
2031  if (addThenBlock)
2032  builder.createBlock(thenRegion);
2033  Region *elseRegion = result.addRegion();
2034  if (addElseBlock)
2035  builder.createBlock(elseRegion);
2036 }
2037 
2038 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2039  bool withElseRegion) {
2040  build(builder, result, TypeRange{}, cond, withElseRegion);
2041 }
2042 
2043 void IfOp::build(OpBuilder &builder, OperationState &result,
2044  TypeRange resultTypes, Value cond, bool withElseRegion) {
2045  result.addTypes(resultTypes);
2046  result.addOperands(cond);
2047 
2048  // Build then region.
2049  OpBuilder::InsertionGuard guard(builder);
2050  Region *thenRegion = result.addRegion();
2051  builder.createBlock(thenRegion);
2052  if (resultTypes.empty())
2053  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2054 
2055  // Build else region.
2056  Region *elseRegion = result.addRegion();
2057  if (withElseRegion) {
2058  builder.createBlock(elseRegion);
2059  if (resultTypes.empty())
2060  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2061  }
2062 }
2063 
2064 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2065  function_ref<void(OpBuilder &, Location)> thenBuilder,
2066  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2067  assert(thenBuilder && "the builder callback for 'then' must be present");
2068  result.addOperands(cond);
2069 
2070  // Build then region.
2071  OpBuilder::InsertionGuard guard(builder);
2072  Region *thenRegion = result.addRegion();
2073  builder.createBlock(thenRegion);
2074  thenBuilder(builder, result.location);
2075 
2076  // Build else region.
2077  Region *elseRegion = result.addRegion();
2078  if (elseBuilder) {
2079  builder.createBlock(elseRegion);
2080  elseBuilder(builder, result.location);
2081  }
2082 
2083  // Infer result types.
2084  SmallVector<Type> inferredReturnTypes;
2085  MLIRContext *ctx = builder.getContext();
2086  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2087  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2088  /*properties=*/nullptr, result.regions,
2089  inferredReturnTypes))) {
2090  result.addTypes(inferredReturnTypes);
2091  }
2092 }
2093 
2094 LogicalResult IfOp::verify() {
2095  if (getNumResults() != 0 && getElseRegion().empty())
2096  return emitOpError("must have an else block if defining values");
2097  return success();
2098 }
2099 
2100 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2101  // Create the regions for 'then'.
2102  result.regions.reserve(2);
2103  Region *thenRegion = result.addRegion();
2104  Region *elseRegion = result.addRegion();
2105 
2106  auto &builder = parser.getBuilder();
2108  Type i1Type = builder.getIntegerType(1);
2109  if (parser.parseOperand(cond) ||
2110  parser.resolveOperand(cond, i1Type, result.operands))
2111  return failure();
2112  // Parse optional results type list.
2113  if (parser.parseOptionalArrowTypeList(result.types))
2114  return failure();
2115  // Parse the 'then' region.
2116  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2117  return failure();
2118  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2119 
2120  // If we find an 'else' keyword then parse the 'else' region.
2121  if (!parser.parseOptionalKeyword("else")) {
2122  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2123  return failure();
2124  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2125  }
2126 
2127  // Parse the optional attribute list.
2128  if (parser.parseOptionalAttrDict(result.attributes))
2129  return failure();
2130  return success();
2131 }
2132 
2133 void IfOp::print(OpAsmPrinter &p) {
2134  bool printBlockTerminators = false;
2135 
2136  p << " " << getCondition();
2137  if (!getResults().empty()) {
2138  p << " -> (" << getResultTypes() << ")";
2139  // Print yield explicitly if the op defines values.
2140  printBlockTerminators = true;
2141  }
2142  p << ' ';
2143  p.printRegion(getThenRegion(),
2144  /*printEntryBlockArgs=*/false,
2145  /*printBlockTerminators=*/printBlockTerminators);
2146 
2147  // Print the 'else' regions if it exists and has a block.
2148  auto &elseRegion = getElseRegion();
2149  if (!elseRegion.empty()) {
2150  p << " else ";
2151  p.printRegion(elseRegion,
2152  /*printEntryBlockArgs=*/false,
2153  /*printBlockTerminators=*/printBlockTerminators);
2154  }
2155 
2156  p.printOptionalAttrDict((*this)->getAttrs());
2157 }
2158 
2159 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2161  // The `then` and the `else` region branch back to the parent operation.
2162  if (!point.isParent()) {
2163  regions.push_back(RegionSuccessor(getResults()));
2164  return;
2165  }
2166 
2167  regions.push_back(RegionSuccessor(&getThenRegion()));
2168 
2169  // Don't consider the else region if it is empty.
2170  Region *elseRegion = &this->getElseRegion();
2171  if (elseRegion->empty())
2172  regions.push_back(RegionSuccessor());
2173  else
2174  regions.push_back(RegionSuccessor(elseRegion));
2175 }
2176 
2177 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2179  FoldAdaptor adaptor(operands, *this);
2180  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2181  if (!boolAttr || boolAttr.getValue())
2182  regions.emplace_back(&getThenRegion());
2183 
2184  // If the else region is empty, execution continues after the parent op.
2185  if (!boolAttr || !boolAttr.getValue()) {
2186  if (!getElseRegion().empty())
2187  regions.emplace_back(&getElseRegion());
2188  else
2189  regions.emplace_back(getResults());
2190  }
2191 }
2192 
2193 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2194  SmallVectorImpl<OpFoldResult> &results) {
2195  // if (!c) then A() else B() -> if c then B() else A()
2196  if (getElseRegion().empty())
2197  return failure();
2198 
2199  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2200  if (!xorStmt)
2201  return failure();
2202 
2203  if (!matchPattern(xorStmt.getRhs(), m_One()))
2204  return failure();
2205 
2206  getConditionMutable().assign(xorStmt.getLhs());
2207  Block *thenBlock = &getThenRegion().front();
2208  // It would be nicer to use iplist::swap, but that has no implemented
2209  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2210  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2211  getElseRegion().getBlocks());
2212  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2213  getThenRegion().getBlocks(), thenBlock);
2214  return success();
2215 }
2216 
2217 void IfOp::getRegionInvocationBounds(
2218  ArrayRef<Attribute> operands,
2219  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2220  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2221  // If the condition is known, then one region is known to be executed once
2222  // and the other zero times.
2223  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2224  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2225  } else {
2226  // Non-constant condition. Each region may be executed 0 or 1 times.
2227  invocationBounds.assign(2, {0, 1});
2228  }
2229 }
2230 
2231 namespace {
2232 // Pattern to remove unused IfOp results.
2233 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2235 
2236  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2237  PatternRewriter &rewriter) const {
2238  // Move all operations to the destination block.
2239  rewriter.mergeBlocks(source, dest);
2240  // Replace the yield op by one that returns only the used values.
2241  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2242  SmallVector<Value, 4> usedOperands;
2243  llvm::transform(usedResults, std::back_inserter(usedOperands),
2244  [&](OpResult result) {
2245  return yieldOp.getOperand(result.getResultNumber());
2246  });
2247  rewriter.modifyOpInPlace(yieldOp,
2248  [&]() { yieldOp->setOperands(usedOperands); });
2249  }
2250 
2251  LogicalResult matchAndRewrite(IfOp op,
2252  PatternRewriter &rewriter) const override {
2253  // Compute the list of used results.
2254  SmallVector<OpResult, 4> usedResults;
2255  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2256  [](OpResult result) { return !result.use_empty(); });
2257 
2258  // Replace the operation if only a subset of its results have uses.
2259  if (usedResults.size() == op.getNumResults())
2260  return failure();
2261 
2262  // Compute the result types of the replacement operation.
2263  SmallVector<Type, 4> newTypes;
2264  llvm::transform(usedResults, std::back_inserter(newTypes),
2265  [](OpResult result) { return result.getType(); });
2266 
2267  // Create a replacement operation with empty then and else regions.
2268  auto newOp =
2269  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2270  rewriter.createBlock(&newOp.getThenRegion());
2271  rewriter.createBlock(&newOp.getElseRegion());
2272 
2273  // Move the bodies and replace the terminators (note there is a then and
2274  // an else region since the operation returns results).
2275  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2276  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2277 
2278  // Replace the operation by the new one.
2279  SmallVector<Value, 4> repResults(op.getNumResults());
2280  for (const auto &en : llvm::enumerate(usedResults))
2281  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2282  rewriter.replaceOp(op, repResults);
2283  return success();
2284  }
2285 };
2286 
2287 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2289 
2290  LogicalResult matchAndRewrite(IfOp op,
2291  PatternRewriter &rewriter) const override {
2292  BoolAttr condition;
2293  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2294  return failure();
2295 
2296  if (condition.getValue())
2297  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2298  else if (!op.getElseRegion().empty())
2299  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2300  else
2301  rewriter.eraseOp(op);
2302 
2303  return success();
2304  }
2305 };
2306 
2307 /// Hoist any yielded results whose operands are defined outside
2308 /// the if, to a select instruction.
2309 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2311 
2312  LogicalResult matchAndRewrite(IfOp op,
2313  PatternRewriter &rewriter) const override {
2314  if (op->getNumResults() == 0)
2315  return failure();
2316 
2317  auto cond = op.getCondition();
2318  auto thenYieldArgs = op.thenYield().getOperands();
2319  auto elseYieldArgs = op.elseYield().getOperands();
2320 
2321  SmallVector<Type> nonHoistable;
2322  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2323  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2324  &op.getElseRegion() == falseVal.getParentRegion())
2325  nonHoistable.push_back(trueVal.getType());
2326  }
2327  // Early exit if there aren't any yielded values we can
2328  // hoist outside the if.
2329  if (nonHoistable.size() == op->getNumResults())
2330  return failure();
2331 
2332  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2333  /*withElseRegion=*/false);
2334  if (replacement.thenBlock())
2335  rewriter.eraseBlock(replacement.thenBlock());
2336  replacement.getThenRegion().takeBody(op.getThenRegion());
2337  replacement.getElseRegion().takeBody(op.getElseRegion());
2338 
2339  SmallVector<Value> results(op->getNumResults());
2340  assert(thenYieldArgs.size() == results.size());
2341  assert(elseYieldArgs.size() == results.size());
2342 
2343  SmallVector<Value> trueYields;
2344  SmallVector<Value> falseYields;
2345  rewriter.setInsertionPoint(replacement);
2346  for (const auto &it :
2347  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2348  Value trueVal = std::get<0>(it.value());
2349  Value falseVal = std::get<1>(it.value());
2350  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2351  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2352  results[it.index()] = replacement.getResult(trueYields.size());
2353  trueYields.push_back(trueVal);
2354  falseYields.push_back(falseVal);
2355  } else if (trueVal == falseVal)
2356  results[it.index()] = trueVal;
2357  else
2358  results[it.index()] = rewriter.create<arith::SelectOp>(
2359  op.getLoc(), cond, trueVal, falseVal);
2360  }
2361 
2362  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2363  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2364 
2365  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2366  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2367 
2368  rewriter.replaceOp(op, results);
2369  return success();
2370  }
2371 };
2372 
2373 /// Allow the true region of an if to assume the condition is true
2374 /// and vice versa. For example:
2375 ///
2376 /// scf.if %cmp {
2377 /// print(%cmp)
2378 /// }
2379 ///
2380 /// becomes
2381 ///
2382 /// scf.if %cmp {
2383 /// print(true)
2384 /// }
2385 ///
2386 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2388 
2389  LogicalResult matchAndRewrite(IfOp op,
2390  PatternRewriter &rewriter) const override {
2391  // Early exit if the condition is constant since replacing a constant
2392  // in the body with another constant isn't a simplification.
2393  if (matchPattern(op.getCondition(), m_Constant()))
2394  return failure();
2395 
2396  bool changed = false;
2397  mlir::Type i1Ty = rewriter.getI1Type();
2398 
2399  // These variables serve to prevent creating duplicate constants
2400  // and hold constant true or false values.
2401  Value constantTrue = nullptr;
2402  Value constantFalse = nullptr;
2403 
2404  for (OpOperand &use :
2405  llvm::make_early_inc_range(op.getCondition().getUses())) {
2406  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2407  changed = true;
2408 
2409  if (!constantTrue)
2410  constantTrue = rewriter.create<arith::ConstantOp>(
2411  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2412 
2413  rewriter.modifyOpInPlace(use.getOwner(),
2414  [&]() { use.set(constantTrue); });
2415  } else if (op.getElseRegion().isAncestor(
2416  use.getOwner()->getParentRegion())) {
2417  changed = true;
2418 
2419  if (!constantFalse)
2420  constantFalse = rewriter.create<arith::ConstantOp>(
2421  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2422 
2423  rewriter.modifyOpInPlace(use.getOwner(),
2424  [&]() { use.set(constantFalse); });
2425  }
2426  }
2427 
2428  return success(changed);
2429  }
2430 };
2431 
2432 /// Remove any statements from an if that are equivalent to the condition
2433 /// or its negation. For example:
2434 ///
2435 /// %res:2 = scf.if %cmp {
2436 /// yield something(), true
2437 /// } else {
2438 /// yield something2(), false
2439 /// }
2440 /// print(%res#1)
2441 ///
2442 /// becomes
2443 /// %res = scf.if %cmp {
2444 /// yield something()
2445 /// } else {
2446 /// yield something2()
2447 /// }
2448 /// print(%cmp)
2449 ///
2450 /// Additionally if both branches yield the same value, replace all uses
2451 /// of the result with the yielded value.
2452 ///
2453 /// %res:2 = scf.if %cmp {
2454 /// yield something(), %arg1
2455 /// } else {
2456 /// yield something2(), %arg1
2457 /// }
2458 /// print(%res#1)
2459 ///
2460 /// becomes
2461 /// %res = scf.if %cmp {
2462 /// yield something()
2463 /// } else {
2464 /// yield something2()
2465 /// }
2466 /// print(%arg1)
2467 ///
2468 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2470 
2471  LogicalResult matchAndRewrite(IfOp op,
2472  PatternRewriter &rewriter) const override {
2473  // Early exit if there are no results that could be replaced.
2474  if (op.getNumResults() == 0)
2475  return failure();
2476 
2477  auto trueYield =
2478  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2479  auto falseYield =
2480  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2481 
2482  rewriter.setInsertionPoint(op->getBlock(),
2483  op.getOperation()->getIterator());
2484  bool changed = false;
2485  Type i1Ty = rewriter.getI1Type();
2486  for (auto [trueResult, falseResult, opResult] :
2487  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2488  op.getResults())) {
2489  if (trueResult == falseResult) {
2490  if (!opResult.use_empty()) {
2491  opResult.replaceAllUsesWith(trueResult);
2492  changed = true;
2493  }
2494  continue;
2495  }
2496 
2497  BoolAttr trueYield, falseYield;
2498  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2499  !matchPattern(falseResult, m_Constant(&falseYield)))
2500  continue;
2501 
2502  bool trueVal = trueYield.getValue();
2503  bool falseVal = falseYield.getValue();
2504  if (!trueVal && falseVal) {
2505  if (!opResult.use_empty()) {
2506  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2507  Value notCond = rewriter.create<arith::XOrIOp>(
2508  op.getLoc(), op.getCondition(),
2509  constDialect
2510  ->materializeConstant(rewriter,
2511  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2512  op.getLoc())
2513  ->getResult(0));
2514  opResult.replaceAllUsesWith(notCond);
2515  changed = true;
2516  }
2517  }
2518  if (trueVal && !falseVal) {
2519  if (!opResult.use_empty()) {
2520  opResult.replaceAllUsesWith(op.getCondition());
2521  changed = true;
2522  }
2523  }
2524  }
2525  return success(changed);
2526  }
2527 };
2528 
2529 /// Merge any consecutive scf.if's with the same condition.
2530 ///
2531 /// scf.if %cond {
2532 /// firstCodeTrue();...
2533 /// } else {
2534 /// firstCodeFalse();...
2535 /// }
2536 /// %res = scf.if %cond {
2537 /// secondCodeTrue();...
2538 /// } else {
2539 /// secondCodeFalse();...
2540 /// }
2541 ///
2542 /// becomes
2543 /// %res = scf.if %cmp {
2544 /// firstCodeTrue();...
2545 /// secondCodeTrue();...
2546 /// } else {
2547 /// firstCodeFalse();...
2548 /// secondCodeFalse();...
2549 /// }
2550 struct CombineIfs : public OpRewritePattern<IfOp> {
2552 
2553  LogicalResult matchAndRewrite(IfOp nextIf,
2554  PatternRewriter &rewriter) const override {
2555  Block *parent = nextIf->getBlock();
2556  if (nextIf == &parent->front())
2557  return failure();
2558 
2559  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2560  if (!prevIf)
2561  return failure();
2562 
2563  // Determine the logical then/else blocks when prevIf's
2564  // condition is used. Null means the block does not exist
2565  // in that case (e.g. empty else). If neither of these
2566  // are set, the two conditions cannot be compared.
2567  Block *nextThen = nullptr;
2568  Block *nextElse = nullptr;
2569  if (nextIf.getCondition() == prevIf.getCondition()) {
2570  nextThen = nextIf.thenBlock();
2571  if (!nextIf.getElseRegion().empty())
2572  nextElse = nextIf.elseBlock();
2573  }
2574  if (arith::XOrIOp notv =
2575  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2576  if (notv.getLhs() == prevIf.getCondition() &&
2577  matchPattern(notv.getRhs(), m_One())) {
2578  nextElse = nextIf.thenBlock();
2579  if (!nextIf.getElseRegion().empty())
2580  nextThen = nextIf.elseBlock();
2581  }
2582  }
2583  if (arith::XOrIOp notv =
2584  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2585  if (notv.getLhs() == nextIf.getCondition() &&
2586  matchPattern(notv.getRhs(), m_One())) {
2587  nextElse = nextIf.thenBlock();
2588  if (!nextIf.getElseRegion().empty())
2589  nextThen = nextIf.elseBlock();
2590  }
2591  }
2592 
2593  if (!nextThen && !nextElse)
2594  return failure();
2595 
2596  SmallVector<Value> prevElseYielded;
2597  if (!prevIf.getElseRegion().empty())
2598  prevElseYielded = prevIf.elseYield().getOperands();
2599  // Replace all uses of return values of op within nextIf with the
2600  // corresponding yields
2601  for (auto it : llvm::zip(prevIf.getResults(),
2602  prevIf.thenYield().getOperands(), prevElseYielded))
2603  for (OpOperand &use :
2604  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2605  if (nextThen && nextThen->getParent()->isAncestor(
2606  use.getOwner()->getParentRegion())) {
2607  rewriter.startOpModification(use.getOwner());
2608  use.set(std::get<1>(it));
2609  rewriter.finalizeOpModification(use.getOwner());
2610  } else if (nextElse && nextElse->getParent()->isAncestor(
2611  use.getOwner()->getParentRegion())) {
2612  rewriter.startOpModification(use.getOwner());
2613  use.set(std::get<2>(it));
2614  rewriter.finalizeOpModification(use.getOwner());
2615  }
2616  }
2617 
2618  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2619  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2620 
2621  IfOp combinedIf = rewriter.create<IfOp>(
2622  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2623  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2624 
2625  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2626  combinedIf.getThenRegion(),
2627  combinedIf.getThenRegion().begin());
2628 
2629  if (nextThen) {
2630  YieldOp thenYield = combinedIf.thenYield();
2631  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2632  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2633  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2634 
2635  SmallVector<Value> mergedYields(thenYield.getOperands());
2636  llvm::append_range(mergedYields, thenYield2.getOperands());
2637  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2638  rewriter.eraseOp(thenYield);
2639  rewriter.eraseOp(thenYield2);
2640  }
2641 
2642  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2643  combinedIf.getElseRegion(),
2644  combinedIf.getElseRegion().begin());
2645 
2646  if (nextElse) {
2647  if (combinedIf.getElseRegion().empty()) {
2648  rewriter.inlineRegionBefore(*nextElse->getParent(),
2649  combinedIf.getElseRegion(),
2650  combinedIf.getElseRegion().begin());
2651  } else {
2652  YieldOp elseYield = combinedIf.elseYield();
2653  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2654  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2655 
2656  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2657 
2658  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2659  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2660 
2661  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2662  rewriter.eraseOp(elseYield);
2663  rewriter.eraseOp(elseYield2);
2664  }
2665  }
2666 
2667  SmallVector<Value> prevValues;
2668  SmallVector<Value> nextValues;
2669  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2670  if (pair.index() < prevIf.getNumResults())
2671  prevValues.push_back(pair.value());
2672  else
2673  nextValues.push_back(pair.value());
2674  }
2675  rewriter.replaceOp(prevIf, prevValues);
2676  rewriter.replaceOp(nextIf, nextValues);
2677  return success();
2678  }
2679 };
2680 
2681 /// Pattern to remove an empty else branch.
2682 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2684 
2685  LogicalResult matchAndRewrite(IfOp ifOp,
2686  PatternRewriter &rewriter) const override {
2687  // Cannot remove else region when there are operation results.
2688  if (ifOp.getNumResults())
2689  return failure();
2690  Block *elseBlock = ifOp.elseBlock();
2691  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2692  return failure();
2693  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2694  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2695  newIfOp.getThenRegion().begin());
2696  rewriter.eraseOp(ifOp);
2697  return success();
2698  }
2699 };
2700 
2701 /// Convert nested `if`s into `arith.andi` + single `if`.
2702 ///
2703 /// scf.if %arg0 {
2704 /// scf.if %arg1 {
2705 /// ...
2706 /// scf.yield
2707 /// }
2708 /// scf.yield
2709 /// }
2710 /// becomes
2711 ///
2712 /// %0 = arith.andi %arg0, %arg1
2713 /// scf.if %0 {
2714 /// ...
2715 /// scf.yield
2716 /// }
2717 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2719 
2720  LogicalResult matchAndRewrite(IfOp op,
2721  PatternRewriter &rewriter) const override {
2722  auto nestedOps = op.thenBlock()->without_terminator();
2723  // Nested `if` must be the only op in block.
2724  if (!llvm::hasSingleElement(nestedOps))
2725  return failure();
2726 
2727  // If there is an else block, it can only yield
2728  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2729  return failure();
2730 
2731  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2732  if (!nestedIf)
2733  return failure();
2734 
2735  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2736  return failure();
2737 
2738  SmallVector<Value> thenYield(op.thenYield().getOperands());
2739  SmallVector<Value> elseYield;
2740  if (op.elseBlock())
2741  llvm::append_range(elseYield, op.elseYield().getOperands());
2742 
2743  // A list of indices for which we should upgrade the value yielded
2744  // in the else to a select.
2745  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2746 
2747  // If the outer scf.if yields a value produced by the inner scf.if,
2748  // only permit combining if the value yielded when the condition
2749  // is false in the outer scf.if is the same value yielded when the
2750  // inner scf.if condition is false.
2751  // Note that the array access to elseYield will not go out of bounds
2752  // since it must have the same length as thenYield, since they both
2753  // come from the same scf.if.
2754  for (const auto &tup : llvm::enumerate(thenYield)) {
2755  if (tup.value().getDefiningOp() == nestedIf) {
2756  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2757  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2758  elseYield[tup.index()]) {
2759  return failure();
2760  }
2761  // If the correctness test passes, we will yield
2762  // corresponding value from the inner scf.if
2763  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2764  continue;
2765  }
2766 
2767  // Otherwise, we need to ensure the else block of the combined
2768  // condition still returns the same value when the outer condition is
2769  // true and the inner condition is false. This can be accomplished if
2770  // the then value is defined outside the outer scf.if and we replace the
2771  // value with a select that considers just the outer condition. Since
2772  // the else region contains just the yield, its yielded value is
2773  // defined outside the scf.if, by definition.
2774 
2775  // If the then value is defined within the scf.if, bail.
2776  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2777  return failure();
2778  }
2779  elseYieldsToUpgradeToSelect.push_back(tup.index());
2780  }
2781 
2782  Location loc = op.getLoc();
2783  Value newCondition = rewriter.create<arith::AndIOp>(
2784  loc, op.getCondition(), nestedIf.getCondition());
2785  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2786  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2787 
2788  SmallVector<Value> results;
2789  llvm::append_range(results, newIf.getResults());
2790  rewriter.setInsertionPoint(newIf);
2791 
2792  for (auto idx : elseYieldsToUpgradeToSelect)
2793  results[idx] = rewriter.create<arith::SelectOp>(
2794  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2795 
2796  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2797  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2798  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2799  if (!elseYield.empty()) {
2800  rewriter.createBlock(&newIf.getElseRegion());
2801  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2802  rewriter.create<YieldOp>(loc, elseYield);
2803  }
2804  rewriter.replaceOp(op, results);
2805  return success();
2806  }
2807 };
2808 
2809 } // namespace
2810 
2811 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2812  MLIRContext *context) {
2813  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2814  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2815  RemoveStaticCondition, RemoveUnusedResults,
2816  ReplaceIfYieldWithConditionOrValue>(context);
2817 }
2818 
2819 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2820 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2821 Block *IfOp::elseBlock() {
2822  Region &r = getElseRegion();
2823  if (r.empty())
2824  return nullptr;
2825  return &r.back();
2826 }
2827 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2828 
2829 //===----------------------------------------------------------------------===//
2830 // ParallelOp
2831 //===----------------------------------------------------------------------===//
2832 
2833 void ParallelOp::build(
2834  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2835  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2837  bodyBuilderFn) {
2838  result.addOperands(lowerBounds);
2839  result.addOperands(upperBounds);
2840  result.addOperands(steps);
2841  result.addOperands(initVals);
2842  result.addAttribute(
2843  ParallelOp::getOperandSegmentSizeAttr(),
2844  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2845  static_cast<int32_t>(upperBounds.size()),
2846  static_cast<int32_t>(steps.size()),
2847  static_cast<int32_t>(initVals.size())}));
2848  result.addTypes(initVals.getTypes());
2849 
2850  OpBuilder::InsertionGuard guard(builder);
2851  unsigned numIVs = steps.size();
2852  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2853  SmallVector<Location, 8> argLocs(numIVs, result.location);
2854  Region *bodyRegion = result.addRegion();
2855  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2856 
2857  if (bodyBuilderFn) {
2858  builder.setInsertionPointToStart(bodyBlock);
2859  bodyBuilderFn(builder, result.location,
2860  bodyBlock->getArguments().take_front(numIVs),
2861  bodyBlock->getArguments().drop_front(numIVs));
2862  }
2863  // Add terminator only if there are no reductions.
2864  if (initVals.empty())
2865  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2866 }
2867 
2868 void ParallelOp::build(
2869  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2870  ValueRange upperBounds, ValueRange steps,
2871  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2872  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2873  // we don't capture a reference to a temporary by constructing the lambda at
2874  // function level.
2875  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2876  Location nestedLoc, ValueRange ivs,
2877  ValueRange) {
2878  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2879  };
2880  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2881  if (bodyBuilderFn)
2882  wrapper = wrappedBuilderFn;
2883 
2884  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2885  wrapper);
2886 }
2887 
2888 LogicalResult ParallelOp::verify() {
2889  // Check that there is at least one value in lowerBound, upperBound and step.
2890  // It is sufficient to test only step, because it is ensured already that the
2891  // number of elements in lowerBound, upperBound and step are the same.
2892  Operation::operand_range stepValues = getStep();
2893  if (stepValues.empty())
2894  return emitOpError(
2895  "needs at least one tuple element for lowerBound, upperBound and step");
2896 
2897  // Check whether all constant step values are positive.
2898  for (Value stepValue : stepValues)
2899  if (auto cst = getConstantIntValue(stepValue))
2900  if (*cst <= 0)
2901  return emitOpError("constant step operand must be positive");
2902 
2903  // Check that the body defines the same number of block arguments as the
2904  // number of tuple elements in step.
2905  Block *body = getBody();
2906  if (body->getNumArguments() != stepValues.size())
2907  return emitOpError() << "expects the same number of induction variables: "
2908  << body->getNumArguments()
2909  << " as bound and step values: " << stepValues.size();
2910  for (auto arg : body->getArguments())
2911  if (!arg.getType().isIndex())
2912  return emitOpError(
2913  "expects arguments for the induction variable to be of index type");
2914 
2915  // Check that the terminator is an scf.reduce op.
2916  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2917  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2918  if (!reduceOp)
2919  return failure();
2920 
2921  // Check that the number of results is the same as the number of reductions.
2922  auto resultsSize = getResults().size();
2923  auto reductionsSize = reduceOp.getReductions().size();
2924  auto initValsSize = getInitVals().size();
2925  if (resultsSize != reductionsSize)
2926  return emitOpError() << "expects number of results: " << resultsSize
2927  << " to be the same as number of reductions: "
2928  << reductionsSize;
2929  if (resultsSize != initValsSize)
2930  return emitOpError() << "expects number of results: " << resultsSize
2931  << " to be the same as number of initial values: "
2932  << initValsSize;
2933 
2934  // Check that the types of the results and reductions are the same.
2935  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2936  auto resultType = getOperation()->getResult(i).getType();
2937  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2938  if (resultType != reductionOperandType)
2939  return reduceOp.emitOpError()
2940  << "expects type of " << i
2941  << "-th reduction operand: " << reductionOperandType
2942  << " to be the same as the " << i
2943  << "-th result type: " << resultType;
2944  }
2945  return success();
2946 }
2947 
2948 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2949  auto &builder = parser.getBuilder();
2950  // Parse an opening `(` followed by induction variables followed by `)`
2953  return failure();
2954 
2955  // Parse loop bounds.
2957  if (parser.parseEqual() ||
2958  parser.parseOperandList(lower, ivs.size(),
2960  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2961  return failure();
2962 
2964  if (parser.parseKeyword("to") ||
2965  parser.parseOperandList(upper, ivs.size(),
2967  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2968  return failure();
2969 
2970  // Parse step values.
2972  if (parser.parseKeyword("step") ||
2973  parser.parseOperandList(steps, ivs.size(),
2975  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2976  return failure();
2977 
2978  // Parse init values.
2980  if (succeeded(parser.parseOptionalKeyword("init"))) {
2981  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2982  return failure();
2983  }
2984 
2985  // Parse optional results in case there is a reduce.
2986  if (parser.parseOptionalArrowTypeList(result.types))
2987  return failure();
2988 
2989  // Now parse the body.
2990  Region *body = result.addRegion();
2991  for (auto &iv : ivs)
2992  iv.type = builder.getIndexType();
2993  if (parser.parseRegion(*body, ivs))
2994  return failure();
2995 
2996  // Set `operandSegmentSizes` attribute.
2997  result.addAttribute(
2998  ParallelOp::getOperandSegmentSizeAttr(),
2999  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3000  static_cast<int32_t>(upper.size()),
3001  static_cast<int32_t>(steps.size()),
3002  static_cast<int32_t>(initVals.size())}));
3003 
3004  // Parse attributes.
3005  if (parser.parseOptionalAttrDict(result.attributes) ||
3006  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3007  result.operands))
3008  return failure();
3009 
3010  // Add a terminator if none was parsed.
3011  ParallelOp::ensureTerminator(*body, builder, result.location);
3012  return success();
3013 }
3014 
3016  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3017  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3018  if (!getInitVals().empty())
3019  p << " init (" << getInitVals() << ")";
3020  p.printOptionalArrowTypeList(getResultTypes());
3021  p << ' ';
3022  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3024  (*this)->getAttrs(),
3025  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3026 }
3027 
3028 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3029 
3030 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3031  return SmallVector<Value>{getBody()->getArguments()};
3032 }
3033 
3034 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3035  return getLowerBound();
3036 }
3037 
3038 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3039  return getUpperBound();
3040 }
3041 
3042 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3043  return getStep();
3044 }
3045 
3047  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3048  if (!ivArg)
3049  return ParallelOp();
3050  assert(ivArg.getOwner() && "unlinked block argument");
3051  auto *containingOp = ivArg.getOwner()->getParentOp();
3052  return dyn_cast<ParallelOp>(containingOp);
3053 }
3054 
3055 namespace {
3056 // Collapse loop dimensions that perform a single iteration.
3057 struct ParallelOpSingleOrZeroIterationDimsFolder
3058  : public OpRewritePattern<ParallelOp> {
3060 
3061  LogicalResult matchAndRewrite(ParallelOp op,
3062  PatternRewriter &rewriter) const override {
3063  Location loc = op.getLoc();
3064 
3065  // Compute new loop bounds that omit all single-iteration loop dimensions.
3066  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3067  IRMapping mapping;
3068  for (auto [lb, ub, step, iv] :
3069  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3070  op.getInductionVars())) {
3071  auto numIterations = constantTripCount(lb, ub, step);
3072  if (numIterations.has_value()) {
3073  // Remove the loop if it performs zero iterations.
3074  if (*numIterations == 0) {
3075  rewriter.replaceOp(op, op.getInitVals());
3076  return success();
3077  }
3078  // Replace the loop induction variable by the lower bound if the loop
3079  // performs a single iteration. Otherwise, copy the loop bounds.
3080  if (*numIterations == 1) {
3081  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3082  continue;
3083  }
3084  }
3085  newLowerBounds.push_back(lb);
3086  newUpperBounds.push_back(ub);
3087  newSteps.push_back(step);
3088  }
3089  // Exit if none of the loop dimensions perform a single iteration.
3090  if (newLowerBounds.size() == op.getLowerBound().size())
3091  return failure();
3092 
3093  if (newLowerBounds.empty()) {
3094  // All of the loop dimensions perform a single iteration. Inline
3095  // loop body and nested ReduceOp's
3096  SmallVector<Value> results;
3097  results.reserve(op.getInitVals().size());
3098  for (auto &bodyOp : op.getBody()->without_terminator())
3099  rewriter.clone(bodyOp, mapping);
3100  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3101  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3102  Block &reduceBlock = reduceOp.getReductions()[i].front();
3103  auto initValIndex = results.size();
3104  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3105  mapping.map(reduceBlock.getArgument(1),
3106  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3107  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3108  rewriter.clone(reduceBodyOp, mapping);
3109 
3110  auto result = mapping.lookupOrDefault(
3111  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3112  results.push_back(result);
3113  }
3114 
3115  rewriter.replaceOp(op, results);
3116  return success();
3117  }
3118  // Replace the parallel loop by lower-dimensional parallel loop.
3119  auto newOp =
3120  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3121  newSteps, op.getInitVals(), nullptr);
3122  // Erase the empty block that was inserted by the builder.
3123  rewriter.eraseBlock(newOp.getBody());
3124  // Clone the loop body and remap the block arguments of the collapsed loops
3125  // (inlining does not support a cancellable block argument mapping).
3126  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3127  newOp.getRegion().begin(), mapping);
3128  rewriter.replaceOp(op, newOp.getResults());
3129  return success();
3130  }
3131 };
3132 
3133 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3135 
3136  LogicalResult matchAndRewrite(ParallelOp op,
3137  PatternRewriter &rewriter) const override {
3138  Block &outerBody = *op.getBody();
3139  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3140  return failure();
3141 
3142  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3143  if (!innerOp)
3144  return failure();
3145 
3146  for (auto val : outerBody.getArguments())
3147  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3148  llvm::is_contained(innerOp.getUpperBound(), val) ||
3149  llvm::is_contained(innerOp.getStep(), val))
3150  return failure();
3151 
3152  // Reductions are not supported yet.
3153  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3154  return failure();
3155 
3156  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3157  ValueRange iterVals, ValueRange) {
3158  Block &innerBody = *innerOp.getBody();
3159  assert(iterVals.size() ==
3160  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3161  IRMapping mapping;
3162  mapping.map(outerBody.getArguments(),
3163  iterVals.take_front(outerBody.getNumArguments()));
3164  mapping.map(innerBody.getArguments(),
3165  iterVals.take_back(innerBody.getNumArguments()));
3166  for (Operation &op : innerBody.without_terminator())
3167  builder.clone(op, mapping);
3168  };
3169 
3170  auto concatValues = [](const auto &first, const auto &second) {
3171  SmallVector<Value> ret;
3172  ret.reserve(first.size() + second.size());
3173  ret.assign(first.begin(), first.end());
3174  ret.append(second.begin(), second.end());
3175  return ret;
3176  };
3177 
3178  auto newLowerBounds =
3179  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3180  auto newUpperBounds =
3181  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3182  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3183 
3184  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3185  newSteps, std::nullopt,
3186  bodyBuilder);
3187  return success();
3188  }
3189 };
3190 
3191 } // namespace
3192 
3193 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3194  MLIRContext *context) {
3195  results
3196  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3197  context);
3198 }
3199 
3200 /// Given the region at `index`, or the parent operation if `index` is None,
3201 /// return the successor regions. These are the regions that may be selected
3202 /// during the flow of control. `operands` is a set of optional attributes that
3203 /// correspond to a constant value for each operand, or null if that operand is
3204 /// not a constant.
3205 void ParallelOp::getSuccessorRegions(
3207  // Both the operation itself and the region may be branching into the body or
3208  // back into the operation itself. It is possible for loop not to enter the
3209  // body.
3210  regions.push_back(RegionSuccessor(&getRegion()));
3211  regions.push_back(RegionSuccessor());
3212 }
3213 
3214 //===----------------------------------------------------------------------===//
3215 // ReduceOp
3216 //===----------------------------------------------------------------------===//
3217 
3218 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3219 
3220 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3221  ValueRange operands) {
3222  result.addOperands(operands);
3223  for (Value v : operands) {
3224  OpBuilder::InsertionGuard guard(builder);
3225  Region *bodyRegion = result.addRegion();
3226  builder.createBlock(bodyRegion, {},
3227  ArrayRef<Type>{v.getType(), v.getType()},
3228  {result.location, result.location});
3229  }
3230 }
3231 
3232 LogicalResult ReduceOp::verifyRegions() {
3233  // The region of a ReduceOp has two arguments of the same type as its
3234  // corresponding operand.
3235  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3236  auto type = getOperands()[i].getType();
3237  Block &block = getReductions()[i].front();
3238  if (block.empty())
3239  return emitOpError() << i << "-th reduction has an empty body";
3240  if (block.getNumArguments() != 2 ||
3241  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3242  return arg.getType() != type;
3243  }))
3244  return emitOpError() << "expected two block arguments with type " << type
3245  << " in the " << i << "-th reduction region";
3246 
3247  // Check that the block is terminated by a ReduceReturnOp.
3248  if (!isa<ReduceReturnOp>(block.getTerminator()))
3249  return emitOpError("reduction bodies must be terminated with an "
3250  "'scf.reduce.return' op");
3251  }
3252 
3253  return success();
3254 }
3255 
3258  // No operands are forwarded to the next iteration.
3259  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3260 }
3261 
3262 //===----------------------------------------------------------------------===//
3263 // ReduceReturnOp
3264 //===----------------------------------------------------------------------===//
3265 
3266 LogicalResult ReduceReturnOp::verify() {
3267  // The type of the return value should be the same type as the types of the
3268  // block arguments of the reduction body.
3269  Block *reductionBody = getOperation()->getBlock();
3270  // Should already be verified by an op trait.
3271  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3272  Type expectedResultType = reductionBody->getArgument(0).getType();
3273  if (expectedResultType != getResult().getType())
3274  return emitOpError() << "must have type " << expectedResultType
3275  << " (the type of the reduction inputs)";
3276  return success();
3277 }
3278 
3279 //===----------------------------------------------------------------------===//
3280 // WhileOp
3281 //===----------------------------------------------------------------------===//
3282 
3283 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3284  ::mlir::OperationState &odsState, TypeRange resultTypes,
3285  ValueRange inits, BodyBuilderFn beforeBuilder,
3286  BodyBuilderFn afterBuilder) {
3287  odsState.addOperands(inits);
3288  odsState.addTypes(resultTypes);
3289 
3290  OpBuilder::InsertionGuard guard(odsBuilder);
3291 
3292  // Build before region.
3293  SmallVector<Location, 4> beforeArgLocs;
3294  beforeArgLocs.reserve(inits.size());
3295  for (Value operand : inits) {
3296  beforeArgLocs.push_back(operand.getLoc());
3297  }
3298 
3299  Region *beforeRegion = odsState.addRegion();
3300  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3301  inits.getTypes(), beforeArgLocs);
3302  if (beforeBuilder)
3303  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3304 
3305  // Build after region.
3306  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3307 
3308  Region *afterRegion = odsState.addRegion();
3309  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3310  resultTypes, afterArgLocs);
3311 
3312  if (afterBuilder)
3313  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3314 }
3315 
3316 ConditionOp WhileOp::getConditionOp() {
3317  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3318 }
3319 
3320 YieldOp WhileOp::getYieldOp() {
3321  return cast<YieldOp>(getAfterBody()->getTerminator());
3322 }
3323 
3324 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3325  return getYieldOp().getResultsMutable();
3326 }
3327 
3328 Block::BlockArgListType WhileOp::getBeforeArguments() {
3329  return getBeforeBody()->getArguments();
3330 }
3331 
3332 Block::BlockArgListType WhileOp::getAfterArguments() {
3333  return getAfterBody()->getArguments();
3334 }
3335 
3336 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3337  return getBeforeArguments();
3338 }
3339 
3340 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3341  assert(point == getBefore() &&
3342  "WhileOp is expected to branch only to the first region");
3343  return getInits();
3344 }
3345 
3346 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3348  // The parent op always branches to the condition region.
3349  if (point.isParent()) {
3350  regions.emplace_back(&getBefore(), getBefore().getArguments());
3351  return;
3352  }
3353 
3354  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3355  "there are only two regions in a WhileOp");
3356  // The body region always branches back to the condition region.
3357  if (point == getAfter()) {
3358  regions.emplace_back(&getBefore(), getBefore().getArguments());
3359  return;
3360  }
3361 
3362  regions.emplace_back(getResults());
3363  regions.emplace_back(&getAfter(), getAfter().getArguments());
3364 }
3365 
3366 SmallVector<Region *> WhileOp::getLoopRegions() {
3367  return {&getBefore(), &getAfter()};
3368 }
3369 
3370 /// Parses a `while` op.
3371 ///
3372 /// op ::= `scf.while` assignments `:` function-type region `do` region
3373 /// `attributes` attribute-dict
3374 /// initializer ::= /* empty */ | `(` assignment-list `)`
3375 /// assignment-list ::= assignment | assignment `,` assignment-list
3376 /// assignment ::= ssa-value `=` ssa-value
3377 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3380  Region *before = result.addRegion();
3381  Region *after = result.addRegion();
3382 
3383  OptionalParseResult listResult =
3384  parser.parseOptionalAssignmentList(regionArgs, operands);
3385  if (listResult.has_value() && failed(listResult.value()))
3386  return failure();
3387 
3388  FunctionType functionType;
3389  SMLoc typeLoc = parser.getCurrentLocation();
3390  if (failed(parser.parseColonType(functionType)))
3391  return failure();
3392 
3393  result.addTypes(functionType.getResults());
3394 
3395  if (functionType.getNumInputs() != operands.size()) {
3396  return parser.emitError(typeLoc)
3397  << "expected as many input types as operands "
3398  << "(expected " << operands.size() << " got "
3399  << functionType.getNumInputs() << ")";
3400  }
3401 
3402  // Resolve input operands.
3403  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3404  parser.getCurrentLocation(),
3405  result.operands)))
3406  return failure();
3407 
3408  // Propagate the types into the region arguments.
3409  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3410  regionArgs[i].type = functionType.getInput(i);
3411 
3412  return failure(parser.parseRegion(*before, regionArgs) ||
3413  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3415 }
3416 
3417 /// Prints a `while` op.
3419  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3420  p << " : ";
3421  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3422  p << ' ';
3423  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3424  p << " do ";
3425  p.printRegion(getAfter());
3426  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3427 }
3428 
3429 /// Verifies that two ranges of types match, i.e. have the same number of
3430 /// entries and that types are pairwise equals. Reports errors on the given
3431 /// operation in case of mismatch.
3432 template <typename OpTy>
3433 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3434  TypeRange right, StringRef message) {
3435  if (left.size() != right.size())
3436  return op.emitOpError("expects the same number of ") << message;
3437 
3438  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3439  if (left[i] != right[i]) {
3440  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3441  << message;
3442  diag.attachNote() << "for argument " << i << ", found " << left[i]
3443  << " and " << right[i];
3444  return diag;
3445  }
3446  }
3447 
3448  return success();
3449 }
3450 
3451 LogicalResult scf::WhileOp::verify() {
3452  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3453  *this, getBefore(),
3454  "expects the 'before' region to terminate with 'scf.condition'");
3455  if (!beforeTerminator)
3456  return failure();
3457 
3458  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3459  *this, getAfter(),
3460  "expects the 'after' region to terminate with 'scf.yield'");
3461  return success(afterTerminator != nullptr);
3462 }
3463 
3464 namespace {
3465 /// Replace uses of the condition within the do block with true, since otherwise
3466 /// the block would not be evaluated.
3467 ///
3468 /// scf.while (..) : (i1, ...) -> ... {
3469 /// %condition = call @evaluate_condition() : () -> i1
3470 /// scf.condition(%condition) %condition : i1, ...
3471 /// } do {
3472 /// ^bb0(%arg0: i1, ...):
3473 /// use(%arg0)
3474 /// ...
3475 ///
3476 /// becomes
3477 /// scf.while (..) : (i1, ...) -> ... {
3478 /// %condition = call @evaluate_condition() : () -> i1
3479 /// scf.condition(%condition) %condition : i1, ...
3480 /// } do {
3481 /// ^bb0(%arg0: i1, ...):
3482 /// use(%true)
3483 /// ...
3484 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3486 
3487  LogicalResult matchAndRewrite(WhileOp op,
3488  PatternRewriter &rewriter) const override {
3489  auto term = op.getConditionOp();
3490 
3491  // These variables serve to prevent creating duplicate constants
3492  // and hold constant true or false values.
3493  Value constantTrue = nullptr;
3494 
3495  bool replaced = false;
3496  for (auto yieldedAndBlockArgs :
3497  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3498  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3499  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3500  if (!constantTrue)
3501  constantTrue = rewriter.create<arith::ConstantOp>(
3502  op.getLoc(), term.getCondition().getType(),
3503  rewriter.getBoolAttr(true));
3504 
3505  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3506  constantTrue);
3507  replaced = true;
3508  }
3509  }
3510  }
3511  return success(replaced);
3512  }
3513 };
3514 
3515 /// Remove loop invariant arguments from `before` block of scf.while.
3516 /// A before block argument is considered loop invariant if :-
3517 /// 1. i-th yield operand is equal to the i-th while operand.
3518 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3519 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3520 /// iter argument/while operand.
3521 /// For the arguments which are removed, their uses inside scf.while
3522 /// are replaced with their corresponding initial value.
3523 ///
3524 /// Eg:
3525 /// INPUT :-
3526 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3527 /// ..., %argN_before = %N)
3528 /// {
3529 /// ...
3530 /// scf.condition(%cond) %arg1_before, %arg0_before,
3531 /// %arg2_before, %arg0_before, ...
3532 /// } do {
3533 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3534 /// ..., %argK_after):
3535 /// ...
3536 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3537 /// }
3538 ///
3539 /// OUTPUT :-
3540 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3541 /// %N)
3542 /// {
3543 /// ...
3544 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3545 /// } do {
3546 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3547 /// ..., %argK_after):
3548 /// ...
3549 /// scf.yield %arg1_after, ..., %argN
3550 /// }
3551 ///
3552 /// EXPLANATION:
3553 /// We iterate over each yield operand.
3554 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3555 /// %arg0_before, which in turn is the 0-th iter argument. So we
3556 /// remove 0-th before block argument and yield operand, and replace
3557 /// all uses of the 0-th before block argument with its initial value
3558 /// %a.
3559 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3560 /// value. So we remove this operand and the corresponding before
3561 /// block argument and replace all uses of 1-th before block argument
3562 /// with %b.
3563 struct RemoveLoopInvariantArgsFromBeforeBlock
3564  : public OpRewritePattern<WhileOp> {
3566 
3567  LogicalResult matchAndRewrite(WhileOp op,
3568  PatternRewriter &rewriter) const override {
3569  Block &afterBlock = *op.getAfterBody();
3570  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3571  ConditionOp condOp = op.getConditionOp();
3572  OperandRange condOpArgs = condOp.getArgs();
3573  Operation *yieldOp = afterBlock.getTerminator();
3574  ValueRange yieldOpArgs = yieldOp->getOperands();
3575 
3576  bool canSimplify = false;
3577  for (const auto &it :
3578  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3579  auto index = static_cast<unsigned>(it.index());
3580  auto [initVal, yieldOpArg] = it.value();
3581  // If i-th yield operand is equal to the i-th operand of the scf.while,
3582  // the i-th before block argument is a loop invariant.
3583  if (yieldOpArg == initVal) {
3584  canSimplify = true;
3585  break;
3586  }
3587  // If the i-th yield operand is k-th after block argument, then we check
3588  // if the (k+1)-th condition op operand is equal to either the i-th before
3589  // block argument or the initial value of i-th before block argument. If
3590  // the comparison results `true`, i-th before block argument is a loop
3591  // invariant.
3592  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3593  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3594  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3595  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3596  canSimplify = true;
3597  break;
3598  }
3599  }
3600  }
3601 
3602  if (!canSimplify)
3603  return failure();
3604 
3605  SmallVector<Value> newInitArgs, newYieldOpArgs;
3606  DenseMap<unsigned, Value> beforeBlockInitValMap;
3607  SmallVector<Location> newBeforeBlockArgLocs;
3608  for (const auto &it :
3609  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3610  auto index = static_cast<unsigned>(it.index());
3611  auto [initVal, yieldOpArg] = it.value();
3612 
3613  // If i-th yield operand is equal to the i-th operand of the scf.while,
3614  // the i-th before block argument is a loop invariant.
3615  if (yieldOpArg == initVal) {
3616  beforeBlockInitValMap.insert({index, initVal});
3617  continue;
3618  } else {
3619  // If the i-th yield operand is k-th after block argument, then we check
3620  // if the (k+1)-th condition op operand is equal to either the i-th
3621  // before block argument or the initial value of i-th before block
3622  // argument. If the comparison results `true`, i-th before block
3623  // argument is a loop invariant.
3624  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3625  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3626  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3627  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3628  beforeBlockInitValMap.insert({index, initVal});
3629  continue;
3630  }
3631  }
3632  }
3633  newInitArgs.emplace_back(initVal);
3634  newYieldOpArgs.emplace_back(yieldOpArg);
3635  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3636  }
3637 
3638  {
3639  OpBuilder::InsertionGuard g(rewriter);
3640  rewriter.setInsertionPoint(yieldOp);
3641  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3642  }
3643 
3644  auto newWhile =
3645  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3646 
3647  Block &newBeforeBlock = *rewriter.createBlock(
3648  &newWhile.getBefore(), /*insertPt*/ {},
3649  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3650 
3651  Block &beforeBlock = *op.getBeforeBody();
3652  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3653  // For each i-th before block argument we find it's replacement value as :-
3654  // 1. If i-th before block argument is a loop invariant, we fetch it's
3655  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3656  // 2. Else we fetch j-th new before block argument as the replacement
3657  // value of i-th before block argument.
3658  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3659  // If the index 'i' argument was a loop invariant we fetch it's initial
3660  // value from `beforeBlockInitValMap`.
3661  if (beforeBlockInitValMap.count(i) != 0)
3662  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3663  else
3664  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3665  }
3666 
3667  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3668  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3669  newWhile.getAfter().begin());
3670 
3671  rewriter.replaceOp(op, newWhile.getResults());
3672  return success();
3673  }
3674 };
3675 
3676 /// Remove loop invariant value from result (condition op) of scf.while.
3677 /// A value is considered loop invariant if the final value yielded by
3678 /// scf.condition is defined outside of the `before` block. We remove the
3679 /// corresponding argument in `after` block and replace the use with the value.
3680 /// We also replace the use of the corresponding result of scf.while with the
3681 /// value.
3682 ///
3683 /// Eg:
3684 /// INPUT :-
3685 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3686 /// %argN_before = %N) {
3687 /// ...
3688 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3689 /// } do {
3690 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3691 /// ...
3692 /// some_func(%arg1_after)
3693 /// ...
3694 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3695 /// }
3696 ///
3697 /// OUTPUT :-
3698 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3699 /// ...
3700 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3701 /// } do {
3702 /// ^bb0(%arg0, %arg3, ..., %argM):
3703 /// ...
3704 /// some_func(%a)
3705 /// ...
3706 /// scf.yield %arg0, %b, ..., %argN
3707 /// }
3708 ///
3709 /// EXPLANATION:
3710 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3711 /// before block of scf.while, so they get removed.
3712 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3713 /// replaced by %b.
3714 /// 3. The corresponding after block argument %arg1_after's uses are
3715 /// replaced by %a and %arg2_after's uses are replaced by %b.
3716 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3718 
3719  LogicalResult matchAndRewrite(WhileOp op,
3720  PatternRewriter &rewriter) const override {
3721  Block &beforeBlock = *op.getBeforeBody();
3722  ConditionOp condOp = op.getConditionOp();
3723  OperandRange condOpArgs = condOp.getArgs();
3724 
3725  bool canSimplify = false;
3726  for (Value condOpArg : condOpArgs) {
3727  // Those values not defined within `before` block will be considered as
3728  // loop invariant values. We map the corresponding `index` with their
3729  // value.
3730  if (condOpArg.getParentBlock() != &beforeBlock) {
3731  canSimplify = true;
3732  break;
3733  }
3734  }
3735 
3736  if (!canSimplify)
3737  return failure();
3738 
3739  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3740 
3741  SmallVector<Value> newCondOpArgs;
3742  SmallVector<Type> newAfterBlockType;
3743  DenseMap<unsigned, Value> condOpInitValMap;
3744  SmallVector<Location> newAfterBlockArgLocs;
3745  for (const auto &it : llvm::enumerate(condOpArgs)) {
3746  auto index = static_cast<unsigned>(it.index());
3747  Value condOpArg = it.value();
3748  // Those values not defined within `before` block will be considered as
3749  // loop invariant values. We map the corresponding `index` with their
3750  // value.
3751  if (condOpArg.getParentBlock() != &beforeBlock) {
3752  condOpInitValMap.insert({index, condOpArg});
3753  } else {
3754  newCondOpArgs.emplace_back(condOpArg);
3755  newAfterBlockType.emplace_back(condOpArg.getType());
3756  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3757  }
3758  }
3759 
3760  {
3761  OpBuilder::InsertionGuard g(rewriter);
3762  rewriter.setInsertionPoint(condOp);
3763  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3764  newCondOpArgs);
3765  }
3766 
3767  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3768  op.getOperands());
3769 
3770  Block &newAfterBlock =
3771  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3772  newAfterBlockType, newAfterBlockArgLocs);
3773 
3774  Block &afterBlock = *op.getAfterBody();
3775  // Since a new scf.condition op was created, we need to fetch the new
3776  // `after` block arguments which will be used while replacing operations of
3777  // previous scf.while's `after` blocks. We'd also be fetching new result
3778  // values too.
3779  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3780  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3781  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3782  Value afterBlockArg, result;
3783  // If index 'i' argument was loop invariant we fetch it's value from the
3784  // `condOpInitMap` map.
3785  if (condOpInitValMap.count(i) != 0) {
3786  afterBlockArg = condOpInitValMap[i];
3787  result = afterBlockArg;
3788  } else {
3789  afterBlockArg = newAfterBlock.getArgument(j);
3790  result = newWhile.getResult(j);
3791  j++;
3792  }
3793  newAfterBlockArgs[i] = afterBlockArg;
3794  newWhileResults[i] = result;
3795  }
3796 
3797  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3798  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3799  newWhile.getBefore().begin());
3800 
3801  rewriter.replaceOp(op, newWhileResults);
3802  return success();
3803  }
3804 };
3805 
3806 /// Remove WhileOp results that are also unused in 'after' block.
3807 ///
3808 /// %0:2 = scf.while () : () -> (i32, i64) {
3809 /// %condition = "test.condition"() : () -> i1
3810 /// %v1 = "test.get_some_value"() : () -> i32
3811 /// %v2 = "test.get_some_value"() : () -> i64
3812 /// scf.condition(%condition) %v1, %v2 : i32, i64
3813 /// } do {
3814 /// ^bb0(%arg0: i32, %arg1: i64):
3815 /// "test.use"(%arg0) : (i32) -> ()
3816 /// scf.yield
3817 /// }
3818 /// return %0#0 : i32
3819 ///
3820 /// becomes
3821 /// %0 = scf.while () : () -> (i32) {
3822 /// %condition = "test.condition"() : () -> i1
3823 /// %v1 = "test.get_some_value"() : () -> i32
3824 /// %v2 = "test.get_some_value"() : () -> i64
3825 /// scf.condition(%condition) %v1 : i32
3826 /// } do {
3827 /// ^bb0(%arg0: i32):
3828 /// "test.use"(%arg0) : (i32) -> ()
3829 /// scf.yield
3830 /// }
3831 /// return %0 : i32
3832 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3834 
3835  LogicalResult matchAndRewrite(WhileOp op,
3836  PatternRewriter &rewriter) const override {
3837  auto term = op.getConditionOp();
3838  auto afterArgs = op.getAfterArguments();
3839  auto termArgs = term.getArgs();
3840 
3841  // Collect results mapping, new terminator args and new result types.
3842  SmallVector<unsigned> newResultsIndices;
3843  SmallVector<Type> newResultTypes;
3844  SmallVector<Value> newTermArgs;
3845  SmallVector<Location> newArgLocs;
3846  bool needUpdate = false;
3847  for (const auto &it :
3848  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3849  auto i = static_cast<unsigned>(it.index());
3850  Value result = std::get<0>(it.value());
3851  Value afterArg = std::get<1>(it.value());
3852  Value termArg = std::get<2>(it.value());
3853  if (result.use_empty() && afterArg.use_empty()) {
3854  needUpdate = true;
3855  } else {
3856  newResultsIndices.emplace_back(i);
3857  newTermArgs.emplace_back(termArg);
3858  newResultTypes.emplace_back(result.getType());
3859  newArgLocs.emplace_back(result.getLoc());
3860  }
3861  }
3862 
3863  if (!needUpdate)
3864  return failure();
3865 
3866  {
3867  OpBuilder::InsertionGuard g(rewriter);
3868  rewriter.setInsertionPoint(term);
3869  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3870  newTermArgs);
3871  }
3872 
3873  auto newWhile =
3874  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3875 
3876  Block &newAfterBlock = *rewriter.createBlock(
3877  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3878 
3879  // Build new results list and new after block args (unused entries will be
3880  // null).
3881  SmallVector<Value> newResults(op.getNumResults());
3882  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3883  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3884  newResults[it.value()] = newWhile.getResult(it.index());
3885  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3886  }
3887 
3888  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3889  newWhile.getBefore().begin());
3890 
3891  Block &afterBlock = *op.getAfterBody();
3892  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3893 
3894  rewriter.replaceOp(op, newResults);
3895  return success();
3896  }
3897 };
3898 
3899 /// Replace operations equivalent to the condition in the do block with true,
3900 /// since otherwise the block would not be evaluated.
3901 ///
3902 /// scf.while (..) : (i32, ...) -> ... {
3903 /// %z = ... : i32
3904 /// %condition = cmpi pred %z, %a
3905 /// scf.condition(%condition) %z : i32, ...
3906 /// } do {
3907 /// ^bb0(%arg0: i32, ...):
3908 /// %condition2 = cmpi pred %arg0, %a
3909 /// use(%condition2)
3910 /// ...
3911 ///
3912 /// becomes
3913 /// scf.while (..) : (i32, ...) -> ... {
3914 /// %z = ... : i32
3915 /// %condition = cmpi pred %z, %a
3916 /// scf.condition(%condition) %z : i32, ...
3917 /// } do {
3918 /// ^bb0(%arg0: i32, ...):
3919 /// use(%true)
3920 /// ...
3921 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3923 
3924  LogicalResult matchAndRewrite(scf::WhileOp op,
3925  PatternRewriter &rewriter) const override {
3926  using namespace scf;
3927  auto cond = op.getConditionOp();
3928  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3929  if (!cmp)
3930  return failure();
3931  bool changed = false;
3932  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3933  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3934  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3935  continue;
3936  for (OpOperand &u :
3937  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3938  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3939  if (!cmp2)
3940  continue;
3941  // For a binary operator 1-opIdx gets the other side.
3942  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3943  continue;
3944  bool samePredicate;
3945  if (cmp2.getPredicate() == cmp.getPredicate())
3946  samePredicate = true;
3947  else if (cmp2.getPredicate() ==
3948  arith::invertPredicate(cmp.getPredicate()))
3949  samePredicate = false;
3950  else
3951  continue;
3952 
3953  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3954  1);
3955  changed = true;
3956  }
3957  }
3958  }
3959  return success(changed);
3960  }
3961 };
3962 
3963 /// Remove unused init/yield args.
3964 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3966 
3967  LogicalResult matchAndRewrite(WhileOp op,
3968  PatternRewriter &rewriter) const override {
3969 
3970  if (!llvm::any_of(op.getBeforeArguments(),
3971  [](Value arg) { return arg.use_empty(); }))
3972  return rewriter.notifyMatchFailure(op, "No args to remove");
3973 
3974  YieldOp yield = op.getYieldOp();
3975 
3976  // Collect results mapping, new terminator args and new result types.
3977  SmallVector<Value> newYields;
3978  SmallVector<Value> newInits;
3979  llvm::BitVector argsToErase;
3980 
3981  size_t argsCount = op.getBeforeArguments().size();
3982  newYields.reserve(argsCount);
3983  newInits.reserve(argsCount);
3984  argsToErase.reserve(argsCount);
3985  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3986  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3987  if (beforeArg.use_empty()) {
3988  argsToErase.push_back(true);
3989  } else {
3990  argsToErase.push_back(false);
3991  newYields.emplace_back(yieldValue);
3992  newInits.emplace_back(initValue);
3993  }
3994  }
3995 
3996  Block &beforeBlock = *op.getBeforeBody();
3997  Block &afterBlock = *op.getAfterBody();
3998 
3999  beforeBlock.eraseArguments(argsToErase);
4000 
4001  Location loc = op.getLoc();
4002  auto newWhileOp =
4003  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4004  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4005  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4006  Block &newAfterBlock = *newWhileOp.getAfterBody();
4007 
4008  OpBuilder::InsertionGuard g(rewriter);
4009  rewriter.setInsertionPoint(yield);
4010  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4011 
4012  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4013  newBeforeBlock.getArguments());
4014  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4015  newAfterBlock.getArguments());
4016 
4017  rewriter.replaceOp(op, newWhileOp.getResults());
4018  return success();
4019  }
4020 };
4021 
4022 /// Remove duplicated ConditionOp args.
4023 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4025 
4026  LogicalResult matchAndRewrite(WhileOp op,
4027  PatternRewriter &rewriter) const override {
4028  ConditionOp condOp = op.getConditionOp();
4029  ValueRange condOpArgs = condOp.getArgs();
4030 
4032  for (Value arg : condOpArgs)
4033  argsSet.insert(arg);
4034 
4035  if (argsSet.size() == condOpArgs.size())
4036  return rewriter.notifyMatchFailure(op, "No results to remove");
4037 
4038  llvm::SmallDenseMap<Value, unsigned> argsMap;
4039  SmallVector<Value> newArgs;
4040  argsMap.reserve(condOpArgs.size());
4041  newArgs.reserve(condOpArgs.size());
4042  for (Value arg : condOpArgs) {
4043  if (!argsMap.count(arg)) {
4044  auto pos = static_cast<unsigned>(argsMap.size());
4045  argsMap.insert({arg, pos});
4046  newArgs.emplace_back(arg);
4047  }
4048  }
4049 
4050  ValueRange argsRange(newArgs);
4051 
4052  Location loc = op.getLoc();
4053  auto newWhileOp = rewriter.create<scf::WhileOp>(
4054  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4055  /*afterBody*/ nullptr);
4056  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4057  Block &newAfterBlock = *newWhileOp.getAfterBody();
4058 
4059  SmallVector<Value> afterArgsMapping;
4060  SmallVector<Value> resultsMapping;
4061  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4062  auto it = argsMap.find(arg);
4063  assert(it != argsMap.end());
4064  auto pos = it->second;
4065  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4066  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4067  }
4068 
4069  OpBuilder::InsertionGuard g(rewriter);
4070  rewriter.setInsertionPoint(condOp);
4071  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4072  argsRange);
4073 
4074  Block &beforeBlock = *op.getBeforeBody();
4075  Block &afterBlock = *op.getAfterBody();
4076 
4077  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4078  newBeforeBlock.getArguments());
4079  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4080  rewriter.replaceOp(op, resultsMapping);
4081  return success();
4082  }
4083 };
4084 
4085 /// If both ranges contain same values return mappping indices from args2 to
4086 /// args1. Otherwise return std::nullopt.
4087 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4088  ValueRange args2) {
4089  if (args1.size() != args2.size())
4090  return std::nullopt;
4091 
4092  SmallVector<unsigned> ret(args1.size());
4093  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4094  auto it = llvm::find(args2, arg1);
4095  if (it == args2.end())
4096  return std::nullopt;
4097 
4098  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4099  }
4100 
4101  return ret;
4102 }
4103 
4104 static bool hasDuplicates(ValueRange args) {
4105  llvm::SmallDenseSet<Value> set;
4106  for (Value arg : args) {
4107  if (!set.insert(arg).second)
4108  return true;
4109  }
4110  return false;
4111 }
4112 
4113 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4114 /// `scf.condition` args into same order as block args. Update `after` block
4115 /// args and op result values accordingly.
4116 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4117 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4119 
4120  LogicalResult matchAndRewrite(WhileOp loop,
4121  PatternRewriter &rewriter) const override {
4122  auto oldBefore = loop.getBeforeBody();
4123  ConditionOp oldTerm = loop.getConditionOp();
4124  ValueRange beforeArgs = oldBefore->getArguments();
4125  ValueRange termArgs = oldTerm.getArgs();
4126  if (beforeArgs == termArgs)
4127  return failure();
4128 
4129  if (hasDuplicates(termArgs))
4130  return failure();
4131 
4132  auto mapping = getArgsMapping(beforeArgs, termArgs);
4133  if (!mapping)
4134  return failure();
4135 
4136  {
4137  OpBuilder::InsertionGuard g(rewriter);
4138  rewriter.setInsertionPoint(oldTerm);
4139  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4140  beforeArgs);
4141  }
4142 
4143  auto oldAfter = loop.getAfterBody();
4144 
4145  SmallVector<Type> newResultTypes(beforeArgs.size());
4146  for (auto &&[i, j] : llvm::enumerate(*mapping))
4147  newResultTypes[j] = loop.getResult(i).getType();
4148 
4149  auto newLoop = rewriter.create<WhileOp>(
4150  loop.getLoc(), newResultTypes, loop.getInits(),
4151  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4152  auto newBefore = newLoop.getBeforeBody();
4153  auto newAfter = newLoop.getAfterBody();
4154 
4155  SmallVector<Value> newResults(beforeArgs.size());
4156  SmallVector<Value> newAfterArgs(beforeArgs.size());
4157  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4158  newResults[i] = newLoop.getResult(j);
4159  newAfterArgs[i] = newAfter->getArgument(j);
4160  }
4161 
4162  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4163  newBefore->getArguments());
4164  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4165  newAfterArgs);
4166 
4167  rewriter.replaceOp(loop, newResults);
4168  return success();
4169  }
4170 };
4171 } // namespace
4172 
4173 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4174  MLIRContext *context) {
4175  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4176  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4177  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4178  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4179 }
4180 
4181 //===----------------------------------------------------------------------===//
4182 // IndexSwitchOp
4183 //===----------------------------------------------------------------------===//
4184 
4185 /// Parse the case regions and values.
4186 static ParseResult
4188  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4189  SmallVector<int64_t> caseValues;
4190  while (succeeded(p.parseOptionalKeyword("case"))) {
4191  int64_t value;
4192  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4193  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4194  return failure();
4195  caseValues.push_back(value);
4196  }
4197  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4198  return success();
4199 }
4200 
4201 /// Print the case regions and values.
4203  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4204  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4205  p.printNewline();
4206  p << "case " << value << ' ';
4207  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4208  }
4209 }
4210 
4211 LogicalResult scf::IndexSwitchOp::verify() {
4212  if (getCases().size() != getCaseRegions().size()) {
4213  return emitOpError("has ")
4214  << getCaseRegions().size() << " case regions but "
4215  << getCases().size() << " case values";
4216  }
4217 
4218  DenseSet<int64_t> valueSet;
4219  for (int64_t value : getCases())
4220  if (!valueSet.insert(value).second)
4221  return emitOpError("has duplicate case value: ") << value;
4222  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4223  auto yield = dyn_cast<YieldOp>(region.front().back());
4224  if (!yield)
4225  return emitOpError("expected region to end with scf.yield, but got ")
4226  << region.front().back().getName();
4227 
4228  if (yield.getNumOperands() != getNumResults()) {
4229  return (emitOpError("expected each region to return ")
4230  << getNumResults() << " values, but " << name << " returns "
4231  << yield.getNumOperands())
4232  .attachNote(yield.getLoc())
4233  << "see yield operation here";
4234  }
4235  for (auto [idx, result, operand] :
4236  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4237  yield.getOperandTypes())) {
4238  if (result == operand)
4239  continue;
4240  return (emitOpError("expected result #")
4241  << idx << " of each region to be " << result)
4242  .attachNote(yield.getLoc())
4243  << name << " returns " << operand << " here";
4244  }
4245  return success();
4246  };
4247 
4248  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4249  return failure();
4250  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4251  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4252  return failure();
4253 
4254  return success();
4255 }
4256 
4257 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4258 
4259 Block &scf::IndexSwitchOp::getDefaultBlock() {
4260  return getDefaultRegion().front();
4261 }
4262 
4263 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4264  assert(idx < getNumCases() && "case index out-of-bounds");
4265  return getCaseRegions()[idx].front();
4266 }
4267 
4268 void IndexSwitchOp::getSuccessorRegions(
4270  // All regions branch back to the parent op.
4271  if (!point.isParent()) {
4272  successors.emplace_back(getResults());
4273  return;
4274  }
4275 
4276  llvm::copy(getRegions(), std::back_inserter(successors));
4277 }
4278 
4279 void IndexSwitchOp::getEntrySuccessorRegions(
4280  ArrayRef<Attribute> operands,
4281  SmallVectorImpl<RegionSuccessor> &successors) {
4282  FoldAdaptor adaptor(operands, *this);
4283 
4284  // If a constant was not provided, all regions are possible successors.
4285  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4286  if (!arg) {
4287  llvm::copy(getRegions(), std::back_inserter(successors));
4288  return;
4289  }
4290 
4291  // Otherwise, try to find a case with a matching value. If not, the
4292  // default region is the only successor.
4293  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4294  if (caseValue == arg.getInt()) {
4295  successors.emplace_back(&caseRegion);
4296  return;
4297  }
4298  }
4299  successors.emplace_back(&getDefaultRegion());
4300 }
4301 
4302 void IndexSwitchOp::getRegionInvocationBounds(
4304  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4305  if (!operandValue) {
4306  // All regions are invoked at most once.
4307  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4308  return;
4309  }
4310 
4311  unsigned liveIndex = getNumRegions() - 1;
4312  const auto *it = llvm::find(getCases(), operandValue.getInt());
4313  if (it != getCases().end())
4314  liveIndex = std::distance(getCases().begin(), it);
4315  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4316  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4317 }
4318 
4319 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4321 
4322  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4323  PatternRewriter &rewriter) const override {
4324  // If `op.getArg()` is a constant, select the region that matches with
4325  // the constant value. Use the default region if no matche is found.
4326  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4327  if (!maybeCst.has_value())
4328  return failure();
4329  int64_t cst = *maybeCst;
4330  int64_t caseIdx, e = op.getNumCases();
4331  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4332  if (cst == op.getCases()[caseIdx])
4333  break;
4334  }
4335 
4336  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4337  : op.getDefaultRegion();
4338  Block &source = r.front();
4339  Operation *terminator = source.getTerminator();
4340  SmallVector<Value> results = terminator->getOperands();
4341 
4342  rewriter.inlineBlockBefore(&source, op);
4343  rewriter.eraseOp(terminator);
4344  // Replace the operation with a potentially empty list of results.
4345  // Fold mechanism doesn't support the case where the result list is empty.
4346  rewriter.replaceOp(op, results);
4347 
4348  return success();
4349  }
4350 };
4351 
4352 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4353  MLIRContext *context) {
4354  results.add<FoldConstantCase>(context);
4355 }
4356 
4357 //===----------------------------------------------------------------------===//
4358 // TableGen'd op method definitions
4359 //===----------------------------------------------------------------------===//
4360 
4361 #define GET_OP_CLASSES
4362 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:722
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:714
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4187
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3433
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:431
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:92
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4202
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:142
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:197
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:262
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:201
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:105
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:134
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:91
IndexType getIndexType()
Definition: Builders.cpp:89
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:582
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:609
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:464
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:491
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:593
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:620
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of a AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3046
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1973
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1447
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:776
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:266
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4322
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.