1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
40  using DialectInlinerInterface::DialectInlinerInterface;
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
77  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78  InParallelOp, ReduceReturnOp>();
79  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81  ForallOp, InParallelOp, WhileOp, YieldOp>();
82  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
83 }
84 
85 /// Default callback for IfOp builders. Inserts a yield without arguments.
87  builder.create<scf::YieldOp>(loc);
88 }
89 
90 /// Verifies that the first block of the given `region` is terminated by a
91 /// TerminatorTy. Reports errors on the given operation if it is not the case.
92 template <typename TerminatorTy>
93 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
94  StringRef errorMessage) {
95  Operation *terminatorOperation = nullptr;
96  if (!region.empty() && !region.front().empty()) {
97  terminatorOperation = &region.front().back();
98  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
99  return yield;
100  }
101  auto diag = op->emitOpError(errorMessage);
102  if (terminatorOperation)
103  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
104  return nullptr;
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // ExecuteRegionOp
109 //===----------------------------------------------------------------------===//
110 
111 /// Replaces the given op with the contents of the given single-block region,
112 /// using the operands of the block terminator to replace operation results.
113 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
114  Region &region, ValueRange blockArgs = {}) {
115  assert(llvm::hasSingleElement(region) && "expected single-block region");
116  Block *block = &region.front();
117  Operation *terminator = block->getTerminator();
118  ValueRange results = terminator->getOperands();
119  rewriter.inlineBlockBefore(block, op, blockArgs);
120  rewriter.replaceOp(op, results);
121  rewriter.eraseOp(terminator);
122 }
123 
124 ///
125 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
126 /// block+
127 /// `}`
128 ///
129 /// Example:
130 /// scf.execute_region -> i32 {
131 /// %idx = load %rI[%i] : memref<128xi32>
132 /// return %idx : i32
133 /// }
134 ///
135 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
136  OperationState &result) {
137  if (parser.parseOptionalArrowTypeList(result.types))
138  return failure();
139 
140  // Introduce the body region and parse it.
141  Region *body = result.addRegion();
142  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
143  parser.parseOptionalAttrDict(result.attributes))
144  return failure();
145 
146  return success();
147 }
148 
149 void ExecuteRegionOp::print(OpAsmPrinter &p) {
150  p.printOptionalArrowTypeList(getResultTypes());
151 
152  p << ' ';
153  p.printRegion(getRegion(),
154  /*printEntryBlockArgs=*/false,
155  /*printBlockTerminators=*/true);
156 
157  p.printOptionalAttrDict((*this)->getAttrs());
158 }
159 
160 LogicalResult ExecuteRegionOp::verify() {
161  if (getRegion().empty())
162  return emitOpError("region needs to have at least one block");
163  if (getRegion().front().getNumArguments() > 0)
164  return emitOpError("region cannot have any arguments");
165  return success();
166 }
167 
168 // Inline an ExecuteRegionOp if it only contains one block.
169 // "test.foo"() : () -> ()
170 // %v = scf.execute_region -> i64 {
171 // %x = "test.val"() : () -> i64
172 // scf.yield %x : i64
173 // }
174 // "test.bar"(%v) : (i64) -> ()
175 //
176 // becomes
177 //
178 // "test.foo"() : () -> ()
179 // %x = "test.val"() : () -> i64
180 // "test.bar"(%x) : (i64) -> ()
181 //
182 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
184 
185  LogicalResult matchAndRewrite(ExecuteRegionOp op,
186  PatternRewriter &rewriter) const override {
187  if (!llvm::hasSingleElement(op.getRegion()))
188  return failure();
189  replaceOpWithRegion(rewriter, op, op.getRegion());
190  return success();
191  }
192 };
193 
194 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
195 // TODO generalize the conditions for operations which can be inlined into.
196 // func @func_execute_region_elim() {
197 // "test.foo"() : () -> ()
198 // %v = scf.execute_region -> i64 {
199 // %c = "test.cmp"() : () -> i1
200 // cf.cond_br %c, ^bb2, ^bb3
201 // ^bb2:
202 // %x = "test.val1"() : () -> i64
203 // cf.br ^bb4(%x : i64)
204 // ^bb3:
205 // %y = "test.val2"() : () -> i64
206 // cf.br ^bb4(%y : i64)
207 // ^bb4(%z : i64):
208 // scf.yield %z : i64
209 // }
210 // "test.bar"(%v) : (i64) -> ()
211 // return
212 // }
213 //
214 // becomes
215 //
216 // func @func_execute_region_elim() {
217 // "test.foo"() : () -> ()
218 // %c = "test.cmp"() : () -> i1
219 // cf.cond_br %c, ^bb1, ^bb2
220 // ^bb1: // pred: ^bb0
221 // %x = "test.val1"() : () -> i64
222 // cf.br ^bb3(%x : i64)
223 // ^bb2: // pred: ^bb0
224 // %y = "test.val2"() : () -> i64
225 // cf.br ^bb3(%y : i64)
226 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
227 // "test.bar"(%z) : (i64) -> ()
228 // return
229 // }
230 //
231 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
232  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
233 
234  LogicalResult matchAndRewrite(ExecuteRegionOp op,
235  PatternRewriter &rewriter) const override {
236  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
237  return failure();
238 
239  Block *prevBlock = op->getBlock();
240  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
241  rewriter.setInsertionPointToEnd(prevBlock);
242 
243  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 
245  for (Block &blk : op.getRegion()) {
246  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247  rewriter.setInsertionPoint(yieldOp);
248  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
249  yieldOp.getResults());
250  rewriter.eraseOp(yieldOp);
251  }
252  }
253 
254  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
255  SmallVector<Value> blockArgs;
256 
257  for (auto res : op.getResults())
258  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
259 
260  rewriter.replaceOp(op, blockArgs);
261  return success();
262  }
263 };
264 
265 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
266  MLIRContext *context) {
267  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
268 }
269 
270 void ExecuteRegionOp::getSuccessorRegions(
271  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
272  // If the predecessor is the ExecuteRegionOp, branch into the body.
273  if (point.isParent()) {
274  regions.push_back(RegionSuccessor(&getRegion()));
275  return;
276  }
277 
278  // Otherwise, the region branches back to the parent operation.
279  regions.push_back(RegionSuccessor(getResults()));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // ConditionOp
284 //===----------------------------------------------------------------------===//
285 
286 MutableOperandRange
287 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
288  assert((point.isParent() || point == getParentOp().getAfter()) &&
289  "condition op can only exit the loop or branch to the after "
290  "region");
291  // Pass all operands except the condition to the successor region.
292  return getArgsMutable();
293 }
294 
295 void ConditionOp::getSuccessorRegions(
296  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
297  FoldAdaptor adaptor(operands, *this);
298 
299  WhileOp whileOp = getParentOp();
300 
301  // Condition can either lead to the after region or back to the parent op
302  // depending on whether the condition is true or not.
303  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
304  if (!boolAttr || boolAttr.getValue())
305  regions.emplace_back(&whileOp.getAfter(),
306  whileOp.getAfter().getArguments());
307  if (!boolAttr || !boolAttr.getValue())
308  regions.emplace_back(whileOp.getResults());
309 }
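
// Illustrative sketch (editor's example, not part of the upstream file): for
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.make_condition"() : () -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.next"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// a constant true condition reports only the "after" region as successor, a
// constant false condition reports only the scf.while results, and an unknown
// condition reports both.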
310 
311 //===----------------------------------------------------------------------===//
312 // ForOp
313 //===----------------------------------------------------------------------===//
314 
315 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
316  Value ub, Value step, ValueRange initArgs,
317  BodyBuilderFn bodyBuilder) {
318  OpBuilder::InsertionGuard guard(builder);
319 
320  result.addOperands({lb, ub, step});
321  result.addOperands(initArgs);
322  for (Value v : initArgs)
323  result.addTypes(v.getType());
324  Type t = lb.getType();
325  Region *bodyRegion = result.addRegion();
326  Block *bodyBlock = builder.createBlock(bodyRegion);
327  bodyBlock->addArgument(t, result.location);
328  for (Value v : initArgs)
329  bodyBlock->addArgument(v.getType(), v.getLoc());
330 
331  // Create the default terminator if the builder is not provided and if the
332  // iteration arguments are not provided. Otherwise, leave this to the caller
333  // because we don't know which values to return from the loop.
334  if (initArgs.empty() && !bodyBuilder) {
335  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
336  } else if (bodyBuilder) {
337  OpBuilder::InsertionGuard guard(builder);
338  builder.setInsertionPointToStart(bodyBlock);
339  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
340  bodyBlock->getArguments().drop_front());
341  }
342 }
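
// Minimal usage sketch of this builder (editor's example; assumes an OpBuilder
// `b`, a Location `loc`, and previously created Values `lb`, `ub`, `step`, and
// `init`; not part of the upstream file):
//
//   auto loop = b.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
//         // Body: forward the loop-carried value unchanged.
//         nested.create<scf::YieldOp>(nestedLoc, args);
//       });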
343 
344 LogicalResult ForOp::verify() {
345  // Check that the number of init args and op results is the same.
346  if (getInitArgs().size() != getNumResults())
347  return emitOpError(
348  "mismatch in number of loop-carried values and defined values");
349 
350  return success();
351 }
352 
353 LogicalResult ForOp::verifyRegions() {
354  // Check that the body defines a single block argument for the induction
355  // variable.
356  if (getInductionVar().getType() != getLowerBound().getType())
357  return emitOpError(
358  "expected induction variable to be same type as bounds and step");
359 
360  if (getNumRegionIterArgs() != getNumResults())
361  return emitOpError(
362  "mismatch in number of basic block args and defined values");
363 
364  auto initArgs = getInitArgs();
365  auto iterArgs = getRegionIterArgs();
366  auto opResults = getResults();
367  unsigned i = 0;
368  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369  if (std::get<0>(e).getType() != std::get<2>(e).getType())
370  return emitOpError() << "types mismatch between " << i
371  << "th iter operand and defined value";
372  if (std::get<1>(e).getType() != std::get<2>(e).getType())
373  return emitOpError() << "types mismatch between " << i
374  << "th iter region arg and defined value";
375 
376  ++i;
377  }
378  return success();
379 }
380 
381 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
382  return SmallVector<Value>{getInductionVar()};
383 }
384 
385 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
386  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
387 }
388 
389 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
390  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
391 }
392 
393 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
394  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
395 }
396 
397 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
398 
399 /// Promotes the loop body of a forOp to its containing block if it can be
400 /// determined that the loop has a single iteration.
401 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
402  std::optional<int64_t> tripCount =
403  constantTripCount(getLowerBound(), getUpperBound(), getStep());
404  if (!tripCount.has_value() || tripCount != 1)
405  return failure();
406 
407  // Replace all results with the yielded values.
408  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
409  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
410 
411  // Replace block arguments with lower bound (replacement for IV) and
412  // iter_args.
413  SmallVector<Value> bbArgReplacements;
414  bbArgReplacements.push_back(getLowerBound());
415  llvm::append_range(bbArgReplacements, getInitArgs());
416 
417  // Move the loop body operations to the loop's containing block.
418  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
419  getOperation()->getIterator(), bbArgReplacements);
420 
421  // Erase the old terminator and the loop.
422  rewriter.eraseOp(yieldOp);
423  rewriter.eraseOp(*this);
424 
425  return success();
426 }
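
// Illustrative effect (editor's sketch, not from the upstream file): with a
// constant trip count of one,
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (i32) {
//     %v = "test.use"(%i, %a) : (index, i32) -> i32
//     scf.yield %v : i32
//   }
//
// is promoted to
//
//   %v = "test.use"(%c0, %init) : (index, i32) -> i32
//
// and all uses of %r are replaced by %v.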
427 
428 /// Prints the initialization list in the form of
429 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
430 /// where 'inner' values are assumed to be region arguments and 'outer' values
431 /// are regular SSA values.
432 static void printInitializationList(OpAsmPrinter &p,
433  Block::BlockArgListType blocksArgs,
434  ValueRange initializers,
435  StringRef prefix = "") {
436  assert(blocksArgs.size() == initializers.size() &&
437  "expected same length of arguments and initializers");
438  if (initializers.empty())
439  return;
440 
441  p << prefix << '(';
442  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
443  p << std::get<0>(it) << " = " << std::get<1>(it);
444  });
445  p << ")";
446 }
447 
448 void ForOp::print(OpAsmPrinter &p) {
449  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
450  << getUpperBound() << " step " << getStep();
451 
452  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
453  if (!getInitArgs().empty())
454  p << " -> (" << getInitArgs().getTypes() << ')';
455  p << ' ';
456  if (Type t = getInductionVar().getType(); !t.isIndex())
457  p << " : " << t << ' ';
458  p.printRegion(getRegion(),
459  /*printEntryBlockArgs=*/false,
460  /*printBlockTerminators=*/!getInitArgs().empty());
461  p.printOptionalAttrDict((*this)->getAttrs());
462 }
463 
464 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
465  auto &builder = parser.getBuilder();
466  Type type;
467 
468  OpAsmParser::Argument inductionVariable;
469  OpAsmParser::UnresolvedOperand lb, ub, step;
470 
471  // Parse the induction variable followed by '='.
472  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
473  // Parse loop bounds.
474  parser.parseOperand(lb) || parser.parseKeyword("to") ||
475  parser.parseOperand(ub) || parser.parseKeyword("step") ||
476  parser.parseOperand(step))
477  return failure();
478 
479  // Parse the optional initial iteration arguments.
480  SmallVector<OpAsmParser::Argument, 4> regionArgs;
481  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
482  regionArgs.push_back(inductionVariable);
483 
484  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
485  if (hasIterArgs) {
486  // Parse assignment list and results type list.
487  if (parser.parseAssignmentList(regionArgs, operands) ||
488  parser.parseArrowTypeList(result.types))
489  return failure();
490  }
491 
492  if (regionArgs.size() != result.types.size() + 1)
493  return parser.emitError(
494  parser.getNameLoc(),
495  "mismatch in number of loop-carried values and defined values");
496 
497  // Parse optional type, else assume Index.
498  if (parser.parseOptionalColon())
499  type = builder.getIndexType();
500  else if (parser.parseType(type))
501  return failure();
502 
503  // Set block argument types, so that they are known when parsing the region.
504  regionArgs.front().type = type;
505  for (auto [iterArg, type] :
506  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
507  iterArg.type = type;
508 
509  // Parse the body region.
510  Region *body = result.addRegion();
511  if (parser.parseRegion(*body, regionArgs))
512  return failure();
513  ForOp::ensureTerminator(*body, builder, result.location);
514 
515  // Resolve input operands. This should be done after parsing the region to
516  // catch invalid IR where operands were defined inside of the region.
517  if (parser.resolveOperand(lb, type, result.operands) ||
518  parser.resolveOperand(ub, type, result.operands) ||
519  parser.resolveOperand(step, type, result.operands))
520  return failure();
521  if (hasIterArgs) {
522  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
523  operands, result.types)) {
524  Type type = std::get<2>(argOperandType);
525  std::get<0>(argOperandType).type = type;
526  if (parser.resolveOperand(std::get<1>(argOperandType), type,
527  result.operands))
528  return failure();
529  }
530  }
531 
532  // Parse the optional attribute list.
533  if (parser.parseOptionalAttrDict(result.attributes))
534  return failure();
535 
536  return success();
537 }
538 
539 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
540 
541 Block::BlockArgListType ForOp::getRegionIterArgs() {
542  return getBody()->getArguments().drop_front(getNumInductionVars());
543 }
544 
545 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
546  return getInitArgsMutable();
547 }
548 
549 FailureOr<LoopLikeOpInterface>
550 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
551  ValueRange newInitOperands,
552  bool replaceInitOperandUsesInLoop,
553  const NewYieldValuesFn &newYieldValuesFn) {
554  // Create a new loop before the existing one, with the extra operands.
555  OpBuilder::InsertionGuard g(rewriter);
556  rewriter.setInsertionPoint(getOperation());
557  auto inits = llvm::to_vector(getInitArgs());
558  inits.append(newInitOperands.begin(), newInitOperands.end());
559  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
560  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
561  [](OpBuilder &, Location, Value, ValueRange) {});
562  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
563 
564  // Generate the new yield values and append them to the scf.yield operation.
565  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
566  ArrayRef<BlockArgument> newIterArgs =
567  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
568  {
569  OpBuilder::InsertionGuard g(rewriter);
570  rewriter.setInsertionPoint(yieldOp);
571  SmallVector<Value> newYieldedValues =
572  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
573  assert(newInitOperands.size() == newYieldedValues.size() &&
574  "expected as many new yield values as new iter operands");
575  rewriter.modifyOpInPlace(yieldOp, [&]() {
576  yieldOp.getResultsMutable().append(newYieldedValues);
577  });
578  }
579 
580  // Move the loop body to the new op.
581  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
582  newLoop.getBody()->getArguments().take_front(
583  getBody()->getNumArguments()));
584 
585  if (replaceInitOperandUsesInLoop) {
586  // Replace all uses of `newInitOperands` with the corresponding basic block
587  // arguments.
588  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
589  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
590  [&](OpOperand &use) {
591  Operation *user = use.getOwner();
592  return newLoop->isProperAncestor(user);
593  });
594  }
595  }
596 
597  // Replace the old loop.
598  rewriter.replaceOp(getOperation(),
599  newLoop->getResults().take_front(getNumResults()));
600  return cast<LoopLikeOpInterface>(newLoop.getOperation());
601 }
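
// Usage sketch (editor's example; assumes a RewriterBase `rewriter`, an
// existing scf::ForOp `loop`, and a loop-invariant Value `extraInit`; not part
// of the upstream file): append one extra loop-carried value that simply
// forwards its iter_arg.
//
//   FailureOr<LoopLikeOpInterface> newLoop = loop.replaceWithAdditionalYields(
//       rewriter, /*newInitOperands=*/ValueRange{extraInit},
//       /*replaceInitOperandUsesInLoop=*/false,
//       [](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
//         return SmallVector<Value>{newBbArgs[0]};
//       });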
602 
603 ForOp mlir::scf::getForInductionVarOwner(Value val) {
604  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
605  if (!ivArg)
606  return ForOp();
607  assert(ivArg.getOwner() && "unlinked block argument");
608  auto *containingOp = ivArg.getOwner()->getParentOp();
609  return dyn_cast_or_null<ForOp>(containingOp);
610 }
611 
612 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
613  return getInitArgs();
614 }
615 
616 void ForOp::getSuccessorRegions(RegionBranchPoint point,
617  SmallVectorImpl<RegionSuccessor> &regions) {
618  // Both the operation itself and the region may be branching into the body or
619  // back into the operation itself. It is possible for the loop not to enter the
620  // body.
621  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
622  regions.push_back(RegionSuccessor(getResults()));
623 }
624 
625 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
626 
627 /// Promotes the loop body of a forallOp to its containing block if it can be
628 /// determined that the loop has a single iteration.
629 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
630  for (auto [lb, ub, step] :
631  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
632  auto tripCount = constantTripCount(lb, ub, step);
633  if (!tripCount.has_value() || *tripCount != 1)
634  return failure();
635  }
636 
637  promote(rewriter, *this);
638  return success();
639 }
640 
641 Block::BlockArgListType ForallOp::getRegionIterArgs() {
642  return getBody()->getArguments().drop_front(getRank());
643 }
644 
645 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
646  return getOutputsMutable();
647 }
648 
649 /// Promotes the loop body of a scf::ForallOp to its containing block.
650 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
651  OpBuilder::InsertionGuard g(rewriter);
652  scf::InParallelOp terminator = forallOp.getTerminator();
653 
654  // Replace block arguments with lower bounds (replacements for IVs) and
655  // outputs.
656  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
657  bbArgReplacements.append(forallOp.getOutputs().begin(),
658  forallOp.getOutputs().end());
659 
660  // Move the loop body operations to the loop's containing block.
661  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
662  forallOp->getIterator(), bbArgReplacements);
663 
664  // Replace the terminator with tensor.insert_slice ops.
665  rewriter.setInsertionPointAfter(forallOp);
666  SmallVector<Value> results;
667  results.reserve(forallOp.getResults().size());
668  for (auto &yieldingOp : terminator.getYieldingOps()) {
669  auto parallelInsertSliceOp =
670  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
671 
672  Value dst = parallelInsertSliceOp.getDest();
673  Value src = parallelInsertSliceOp.getSource();
674  if (llvm::isa<TensorType>(src.getType())) {
675  results.push_back(rewriter.create<tensor::InsertSliceOp>(
676  forallOp.getLoc(), dst.getType(), src, dst,
677  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
678  parallelInsertSliceOp.getStrides(),
679  parallelInsertSliceOp.getStaticOffsets(),
680  parallelInsertSliceOp.getStaticSizes(),
681  parallelInsertSliceOp.getStaticStrides()));
682  } else {
683  llvm_unreachable("unsupported terminator");
684  }
685  }
686  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
687 
688  // Erase the old terminator and the loop.
689  rewriter.eraseOp(terminator);
690  rewriter.eraseOp(forallOp);
691 }
692 
693 LoopNest mlir::scf::buildLoopNest(
694  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
695  ValueRange steps, ValueRange iterArgs,
696  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
697  bodyBuilder) {
698  assert(lbs.size() == ubs.size() &&
699  "expected the same number of lower and upper bounds");
700  assert(lbs.size() == steps.size() &&
701  "expected the same number of lower bounds and steps");
702 
703  // If there are no bounds, call the body-building function and return early.
704  if (lbs.empty()) {
705  ValueVector results =
706  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
707  : ValueVector();
708  assert(results.size() == iterArgs.size() &&
709  "loop nest body must return as many values as loop has iteration "
710  "arguments");
711  return LoopNest{{}, std::move(results)};
712  }
713 
714  // First, create the loop structure iteratively using the body-builder
715  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
716  OpBuilder::InsertionGuard guard(builder);
717  SmallVector<scf::ForOp, 4> loops;
718  SmallVector<Value, 4> ivs;
719  loops.reserve(lbs.size());
720  ivs.reserve(lbs.size());
721  ValueRange currentIterArgs = iterArgs;
722  Location currentLoc = loc;
723  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
724  auto loop = builder.create<scf::ForOp>(
725  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
726  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
727  ValueRange args) {
728  ivs.push_back(iv);
729  // It is safe to store ValueRange args because it points to block
730  // arguments of a loop operation that we also own.
731  currentIterArgs = args;
732  currentLoc = nestedLoc;
733  });
734  // Set the builder to point to the body of the newly created loop. We don't
735  // do this in the callback because the builder is reset when the callback
736  // returns.
737  builder.setInsertionPointToStart(loop.getBody());
738  loops.push_back(loop);
739  }
740 
741  // For all loops but the innermost, yield the results of the nested loop.
742  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
743  builder.setInsertionPointToEnd(loops[i].getBody());
744  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
745  }
746 
747  // In the body of the innermost loop, call the body building function if any
748  // and yield its results.
749  builder.setInsertionPointToStart(loops.back().getBody());
750  ValueVector results = bodyBuilder
751  ? bodyBuilder(builder, currentLoc, ivs,
752  loops.back().getRegionIterArgs())
753  : ValueVector();
754  assert(results.size() == iterArgs.size() &&
755  "loop nest body must return as many values as loop has iteration "
756  "arguments");
757  builder.setInsertionPointToEnd(loops.back().getBody());
758  builder.create<scf::YieldOp>(loc, results);
759 
760  // Return the loops.
761  ValueVector nestResults;
762  llvm::append_range(nestResults, loops.front().getResults());
763  return LoopNest{std::move(loops), std::move(nestResults)};
764 }
765 
766 LoopNest mlir::scf::buildLoopNest(
767  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
768  ValueRange steps,
769  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
770  // Delegate to the main function by wrapping the body builder.
771  return buildLoopNest(builder, loc, lbs, ubs, steps, {},
772  [&bodyBuilder](OpBuilder &nestedBuilder,
773  Location nestedLoc, ValueRange ivs,
774  ValueRange) -> ValueVector {
775  if (bodyBuilder)
776  bodyBuilder(nestedBuilder, nestedLoc, ivs);
777  return {};
778  });
779 }
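
// Usage sketch (editor's example; assumes an OpBuilder `b`, a Location `loc`,
// and index Values lb0/lb1, ub0/ub1, step0/step1; not part of the upstream
// file): build a 2-D loop nest without iteration arguments.
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       b, loc, ValueRange{lb0, lb1}, ValueRange{ub0, ub1},
//       ValueRange{step0, step1},
//       [](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
//         // Emit the innermost body here using ivs[0] and ivs[1].
//       });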
780 
781 SmallVector<Value>
782 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
783  OpOperand &operand, Value replacement,
784  const ValueTypeCastFnTy &castFn) {
785  assert(operand.getOwner() == forOp);
786  Type oldType = operand.get().getType(), newType = replacement.getType();
787 
788  // 1. Create new iter operands, exactly 1 is replaced.
789  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
790  "expected an iter OpOperand");
791  assert(operand.get().getType() != replacement.getType() &&
792  "Expected a different type");
793  SmallVector<Value> newIterOperands;
794  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
795  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
796  newIterOperands.push_back(replacement);
797  continue;
798  }
799  newIterOperands.push_back(opOperand.get());
800  }
801 
802  // 2. Create the new forOp shell.
803  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
804  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
805  forOp.getStep(), newIterOperands);
806  newForOp->setAttrs(forOp->getAttrs());
807  Block &newBlock = newForOp.getRegion().front();
808  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
809  newBlock.getArguments().end());
810 
811  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
812  // corresponding to the `replacement` value.
813  OpBuilder::InsertionGuard g(rewriter);
814  rewriter.setInsertionPointToStart(&newBlock);
815  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
816  &newForOp->getOpOperand(operand.getOperandNumber()));
817  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
818  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
819 
820  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
821  Block &oldBlock = forOp.getRegion().front();
822  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
823 
824  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
825  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
826  rewriter.setInsertionPoint(clonedYieldOp);
827  unsigned yieldIdx =
828  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
829  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
830  clonedYieldOp.getOperand(yieldIdx));
831  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
832  newYieldOperands[yieldIdx] = castOut;
833  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
834  rewriter.eraseOp(clonedYieldOp);
835 
836  // 6. Inject an outgoing cast op after the forOp.
837  rewriter.setInsertionPointAfter(newForOp);
838  SmallVector<Value> newResults = newForOp.getResults();
839  newResults[yieldIdx] =
840  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
841 
842  return newResults;
843 }
844 
845 namespace {
846 // Fold away ForOp iter arguments when:
847 // 1) The op yields the iter arguments.
848 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
849 // 3) The iter arguments have no use and the corresponding (operation) results
850 // have no use.
851 //
852 // These arguments must be defined outside of the ForOp region and can just be
853 // forwarded after simplifying the op inits, yields and returns.
854 //
855 // The implementation uses `inlineBlockBefore` to steal the content of the
856 // original ForOp and avoid cloning.
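//
// Illustrative example (editor's sketch, not from the upstream file): in
//
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%a = %x, %b = %y)
//       -> (i32, i32) {
//     %v = "test.compute"(%a) : (i32) -> i32
//     scf.yield %v, %b : i32, i32
//   }
//
// the second iter_arg merely forwards %y, so it is folded away: uses of %r#1
// are replaced by %y and the loop is rebuilt with a single iter_arg.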
857 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
858  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
859 
860  LogicalResult matchAndRewrite(scf::ForOp forOp,
861  PatternRewriter &rewriter) const final {
862  bool canonicalize = false;
863 
864  // An internal flat vector of block transfer
865  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
866  // transformed block argument mappings. This plays the role of an
867  // IRMapping for the particular use case of calling into
868  // `inlineBlockBefore`.
869  int64_t numResults = forOp.getNumResults();
870  SmallVector<bool, 4> keepMask;
871  keepMask.reserve(numResults);
872  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
873  newResultValues;
874  newBlockTransferArgs.reserve(1 + numResults);
875  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
876  newIterArgs.reserve(forOp.getInitArgs().size());
877  newYieldValues.reserve(numResults);
878  newResultValues.reserve(numResults);
879  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
880  for (auto [init, arg, result, yielded] :
881  llvm::zip(forOp.getInitArgs(), // iter from outside
882  forOp.getRegionIterArgs(), // iter inside region
883  forOp.getResults(), // op results
884  forOp.getYieldedValues() // iter yield
885  )) {
886  // Forwarded is `true` when:
887  // 1) The region `iter` argument is yielded.
888  // 2) The region `iter` argument's corresponding input is yielded.
889  // 3) The region `iter` argument has no use, and the corresponding op
890  // result has no use.
891  bool forwarded = (arg == yielded) || (init == yielded) ||
892  (arg.use_empty() && result.use_empty());
893  if (forwarded) {
894  canonicalize = true;
895  keepMask.push_back(false);
896  newBlockTransferArgs.push_back(init);
897  newResultValues.push_back(init);
898  continue;
899  }
900 
901  // Check if a previous kept argument always has the same values for init
902  // and yielded values.
903  if (auto it = initYieldToArg.find({init, yielded});
904  it != initYieldToArg.end()) {
905  canonicalize = true;
906  keepMask.push_back(false);
907  auto [sameArg, sameResult] = it->second;
908  rewriter.replaceAllUsesWith(arg, sameArg);
909  rewriter.replaceAllUsesWith(result, sameResult);
910  // The replacement value doesn't matter because there are no uses.
911  newBlockTransferArgs.push_back(init);
912  newResultValues.push_back(init);
913  continue;
914  }
915 
916  // This value is kept.
917  initYieldToArg.insert({{init, yielded}, {arg, result}});
918  keepMask.push_back(true);
919  newIterArgs.push_back(init);
920  newYieldValues.push_back(yielded);
921  newBlockTransferArgs.push_back(Value()); // placeholder with null value
922  newResultValues.push_back(Value()); // placeholder with null value
923  }
924 
925  if (!canonicalize)
926  return failure();
927 
928  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
929  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
930  forOp.getStep(), newIterArgs);
931  newForOp->setAttrs(forOp->getAttrs());
932  Block &newBlock = newForOp.getRegion().front();
933 
934  // Replace the null placeholders with newly constructed values.
935  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
936  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
937  idx != e; ++idx) {
938  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
939  Value &newResultVal = newResultValues[idx];
940  assert((blockTransferArg && newResultVal) ||
941  (!blockTransferArg && !newResultVal));
942  if (!blockTransferArg) {
943  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
944  newResultVal = newForOp.getResult(collapsedIdx++);
945  }
946  }
947 
948  Block &oldBlock = forOp.getRegion().front();
949  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
950  "unexpected argument size mismatch");
951 
952  // No results case: the scf::ForOp builder already created a zero
953  // result terminator. Merge before this terminator and just get rid of the
954  // original terminator that has been merged in.
955  if (newIterArgs.empty()) {
956  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
957  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
958  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
959  rewriter.replaceOp(forOp, newResultValues);
960  return success();
961  }
962 
963  // No terminator case: merge and rewrite the merged terminator.
964  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
965  OpBuilder::InsertionGuard g(rewriter);
966  rewriter.setInsertionPoint(mergedTerminator);
967  SmallVector<Value, 4> filteredOperands;
968  filteredOperands.reserve(newResultValues.size());
969  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
970  if (keepMask[idx])
971  filteredOperands.push_back(mergedTerminator.getOperand(idx));
972  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
973  filteredOperands);
974  };
975 
976  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
977  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
978  cloneFilteredTerminator(mergedYieldOp);
979  rewriter.eraseOp(mergedYieldOp);
980  rewriter.replaceOp(forOp, newResultValues);
981  return success();
982  }
983 };
984 
985 /// Util function that tries to compute a constant diff between u and l.
986 /// Returns std::nullopt when the difference cannot be determined to be a
987 /// compile-time constant.
988 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
989  IntegerAttr clb, cub;
990  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
991  llvm::APInt lbValue = clb.getValue();
992  llvm::APInt ubValue = cub.getValue();
993  return (ubValue - lbValue).getSExtValue();
994  }
995 
996  // Else a simple pattern match for x + c or c + x
997  llvm::APInt diff;
998  if (matchPattern(
999  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
1000  matchPattern(
1001  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1002  return diff.getSExtValue();
1003  return std::nullopt;
1004 }
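
// For example (editor's note): if %u is defined as `arith.addi %l, %c12`,
// computeConstDiff(%l, %u) returns 12; when neither the constant case nor the
// add pattern applies, it returns std::nullopt.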
1005 
1006 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1007 /// single-iteration loops with their bodies, and removes empty loops that
1008 /// iterate at least once and only return values defined outside of the loop.
1009 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1010  using OpRewritePattern<ForOp>::OpRewritePattern;
1011 
1012  LogicalResult matchAndRewrite(ForOp op,
1013  PatternRewriter &rewriter) const override {
1014  // If the upper bound is the same as the lower bound, the loop does not
1015  // iterate, just remove it.
1016  if (op.getLowerBound() == op.getUpperBound()) {
1017  rewriter.replaceOp(op, op.getInitArgs());
1018  return success();
1019  }
1020 
1021  std::optional<int64_t> diff =
1022  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1023  if (!diff)
1024  return failure();
1025 
1026  // If the loop is known to have 0 iterations, remove it.
1027  if (*diff <= 0) {
1028  rewriter.replaceOp(op, op.getInitArgs());
1029  return success();
1030  }
1031 
1032  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1033  if (!maybeStepValue)
1034  return failure();
1035 
1036  // If the loop is known to have 1 iteration, inline its body and remove the
1037  // loop.
1038  llvm::APInt stepValue = *maybeStepValue;
1039  if (stepValue.sge(*diff)) {
1040  SmallVector<Value, 4> blockArgs;
1041  blockArgs.reserve(op.getInitArgs().size() + 1);
1042  blockArgs.push_back(op.getLowerBound());
1043  llvm::append_range(blockArgs, op.getInitArgs());
1044  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1045  return success();
1046  }
1047 
1048  // Now we are left with loops that have more than one iteration.
1049  Block &block = op.getRegion().front();
1050  if (!llvm::hasSingleElement(block))
1051  return failure();
1052  // If the loop is empty, iterates at least once, and only returns values
1053  // defined outside of the loop, remove it and replace it with yield values.
1054  if (llvm::any_of(op.getYieldedValues(),
1055  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1056  return failure();
1057  rewriter.replaceOp(op, op.getYieldedValues());
1058  return success();
1059  }
1060 };
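
// Illustrative examples (editor's sketch, not from the upstream file):
//
//   scf.for %i = %c0 to %c0 step %c1 { ... }
//     // lb == ub: zero iterations, the loop is erased.
//
//   %r = scf.for %i = %c0 to %c4 step %c8 iter_args(%a = %init) -> (i32) {...}
//     // step >= trip distance: exactly one iteration, the body is inlined
//     // with %i replaced by %c0 and %a by %init.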
1061 
1062 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1063 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1064 ///
1065 /// ```
1066 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1067 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1068 /// -> (tensor<?x?xf32>) {
1069 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1070 /// scf.yield %2 : tensor<?x?xf32>
1071 /// }
1072 /// use_of(%1)
1073 /// ```
1074 ///
1075 /// folds into:
1076 ///
1077 /// ```
1078 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1079 /// -> (tensor<32x1024xf32>) {
1080 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1081 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1082 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1083 /// scf.yield %4 : tensor<32x1024xf32>
1084 /// }
1085 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1086 /// use_of(%1)
1087 /// ```
1088 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1089  using OpRewritePattern<ForOp>::OpRewritePattern;
1090 
1091  LogicalResult matchAndRewrite(ForOp op,
1092  PatternRewriter &rewriter) const override {
1093  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1094  OpOperand &iterOpOperand = std::get<0>(it);
1095  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1096  if (!incomingCast ||
1097  incomingCast.getSource().getType() == incomingCast.getType())
1098  continue;
1099  // Skip if the dest type of the cast does not preserve static information
1100  // in the source type.
1101  if (!tensor::preservesStaticInformation(
1102  incomingCast.getDest().getType(),
1103  incomingCast.getSource().getType()))
1104  continue;
1105  if (!std::get<1>(it).hasOneUse())
1106  continue;
1107 
1108  // Create a new ForOp with that iter operand replaced.
1109  rewriter.replaceOp(
1110  op, replaceAndCastForOpIterArg(
1111  rewriter, op, iterOpOperand, incomingCast.getSource(),
1112  [](OpBuilder &b, Location loc, Type type, Value source) {
1113  return b.create<tensor::CastOp>(loc, type, source);
1114  }));
1115  return success();
1116  }
1117  return failure();
1118  }
1119 };
1120 
1121 } // namespace
1122 
1123 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1124  MLIRContext *context) {
1125  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1126  context);
1127 }
1128 
1129 std::optional<APInt> ForOp::getConstantStep() {
1130  IntegerAttr step;
1131  if (matchPattern(getStep(), m_Constant(&step)))
1132  return step.getValue();
1133  return {};
1134 }
1135 
1136 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1137  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1138 }
1139 
1140 Speculation::Speculatability ForOp::getSpeculatability() {
1141  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1142  // and End.
1143  if (auto constantStep = getConstantStep())
1144  if (*constantStep == 1)
1145  return Speculation::RecursivelySpeculatable;
1146 
1147  // For Step != 1, the loop may not terminate. We can add more smarts here if
1148  // needed.
1149  return Speculation::NotSpeculatable;
1150 }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // ForallOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 LogicalResult ForallOp::verify() {
1157  unsigned numLoops = getRank();
1158  // Check number of outputs.
1159  if (getNumResults() != getOutputs().size())
1160  return emitOpError("produces ")
1161  << getNumResults() << " results, but has only "
1162  << getOutputs().size() << " outputs";
1163 
1164  // Check that the body defines block arguments for thread indices and outputs.
1165  auto *body = getBody();
1166  if (body->getNumArguments() != numLoops + getOutputs().size())
1167  return emitOpError("region expects ") << numLoops << " arguments";
1168  for (int64_t i = 0; i < numLoops; ++i)
1169  if (!body->getArgument(i).getType().isIndex())
1170  return emitOpError("expects ")
1171  << i << "-th block argument to be an index";
1172  for (unsigned i = 0; i < getOutputs().size(); ++i)
1173  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1174  return emitOpError("type mismatch between ")
1175  << i << "-th output and corresponding block argument";
1176  if (getMapping().has_value() && !getMapping()->empty()) {
1177  if (getDeviceMappingAttrs().size() != numLoops)
1178  return emitOpError() << "mapping attribute size must match op rank";
1179  if (failed(getDeviceMaskingAttr()))
1180  return emitOpError() << getMappingAttrName()
1181  << " supports at most one device masking attribute";
1182  }
1183 
1184  // Verify mixed static/dynamic control variables.
1185  Operation *op = getOperation();
1186  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1187  getStaticLowerBound(),
1188  getDynamicLowerBound())))
1189  return failure();
1190  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1191  getStaticUpperBound(),
1192  getDynamicUpperBound())))
1193  return failure();
1194  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1195  getStaticStep(), getDynamicStep())))
1196  return failure();
1197 
1198  return success();
1199 }
1200 
1201 void ForallOp::print(OpAsmPrinter &p) {
1202  Operation *op = getOperation();
1203  p << " (" << getInductionVars();
1204  if (isNormalized()) {
1205  p << ") in ";
1206  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1207  /*valueTypes=*/{}, /*scalables=*/{},
1208  OpAsmParser::Delimiter::Paren);
1209  } else {
1210  p << ") = ";
1211  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1212  /*valueTypes=*/{}, /*scalables=*/{},
1213  OpAsmParser::Delimiter::Paren);
1214  p << " to ";
1215  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1216  /*valueTypes=*/{}, /*scalables=*/{},
1217  OpAsmParser::Delimiter::Paren);
1218  p << " step ";
1219  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1220  /*valueTypes=*/{}, /*scalables=*/{},
1221  OpAsmParser::Delimiter::Paren);
1222  }
1223  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1224  p << " ";
1225  if (!getRegionOutArgs().empty())
1226  p << "-> (" << getResultTypes() << ") ";
1227  p.printRegion(getRegion(),
1228  /*printEntryBlockArgs=*/false,
1229  /*printBlockTerminators=*/getNumResults() > 0);
1230  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1231  getStaticLowerBoundAttrName(),
1232  getStaticUpperBoundAttrName(),
1233  getStaticStepAttrName()});
1234 }
1235 
1236 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1237  OpBuilder b(parser.getContext());
1238  auto indexType = b.getIndexType();
1239 
1240  // Parse an opening `(` followed by thread index variables followed by `)`
1241  // TODO: when we can refer to such "induction variable"-like handles from the
1242  // declarative assembly format, we can implement the parser as a custom hook.
1243  SmallVector<OpAsmParser::Argument, 4> ivs;
1244  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1245  return failure();
1246 
1247  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1248  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1249  dynamicSteps;
1250  if (succeeded(parser.parseOptionalKeyword("in"))) {
1251  // Parse upper bounds.
1252  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1253  /*valueTypes=*/nullptr,
1254  OpAsmParser::Delimiter::Paren) ||
1255  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1256  return failure();
1257 
1258  unsigned numLoops = ivs.size();
1259  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1260  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1261  } else {
1262  // Parse lower bounds.
1263  if (parser.parseEqual() ||
1264  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1265  /*valueTypes=*/nullptr,
1266  OpAsmParser::Delimiter::Paren) ||
1267 
1268  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1269  return failure();
1270 
1271  // Parse upper bounds.
1272  if (parser.parseKeyword("to") ||
1273  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1274  /*valueTypes=*/nullptr,
1275  OpAsmParser::Delimiter::Paren) ||
1276  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1277  return failure();
1278 
1279  // Parse step values.
1280  if (parser.parseKeyword("step") ||
1281  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1282  /*valueTypes=*/nullptr,
1283  OpAsmParser::Delimiter::Paren) ||
1284  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1285  return failure();
1286  }
1287 
1288  // Parse out operands and results.
1289  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1290  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1291  SMLoc outOperandsLoc = parser.getCurrentLocation();
1292  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1293  if (outOperands.size() != result.types.size())
1294  return parser.emitError(outOperandsLoc,
1295  "mismatch between out operands and types");
1296  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1297  parser.parseOptionalArrowTypeList(result.types) ||
1298  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1299  result.operands))
1300  return failure();
1301  }
1302 
1303  // Parse region.
1304  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1305  std::unique_ptr<Region> region = std::make_unique<Region>();
1306  for (auto &iv : ivs) {
1307  iv.type = b.getIndexType();
1308  regionArgs.push_back(iv);
1309  }
1310  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1311  auto &out = it.value();
1312  out.type = result.types[it.index()];
1313  regionArgs.push_back(out);
1314  }
1315  if (parser.parseRegion(*region, regionArgs))
1316  return failure();
1317 
1318  // Ensure terminator and move region.
1319  ForallOp::ensureTerminator(*region, b, result.location);
1320  result.addRegion(std::move(region));
1321 
1322  // Parse the optional attribute list.
1323  if (parser.parseOptionalAttrDict(result.attributes))
1324  return failure();
1325 
1326  result.addAttribute("staticLowerBound", staticLbs);
1327  result.addAttribute("staticUpperBound", staticUbs);
1328  result.addAttribute("staticStep", staticSteps);
1329  result.addAttribute("operandSegmentSizes",
1330  b.getDenseI32ArrayAttr(
1331  {static_cast<int32_t>(dynamicLbs.size()),
1332  static_cast<int32_t>(dynamicUbs.size()),
1333  static_cast<int32_t>(dynamicSteps.size()),
1334  static_cast<int32_t>(outOperands.size())}));
1335  return success();
1336 }
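
// Examples of the custom assembly handled above (editor's sketch, not from the
// upstream file):
//
//   scf.forall (%i, %j) in (%dim, 32)
//       shared_outs(%o = %t) -> (tensor<?x32xf32>) { ... }
//
//   scf.forall (%i) = (0) to (%n) step (8)
//       shared_outs(%o = %t) -> (tensor<?xf32>) { ... }
//
// The first (normalized) form implies zero lower bounds and unit steps.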
1337 
1338 // Builder that takes loop bounds.
1339 void ForallOp::build(
1340  mlir::OpBuilder &b, mlir::OperationState &result,
1341  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1342  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1343  std::optional<ArrayAttr> mapping,
1344  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1345  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1346  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1347  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1348  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1349  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1350 
1351  result.addOperands(dynamicLbs);
1352  result.addOperands(dynamicUbs);
1353  result.addOperands(dynamicSteps);
1354  result.addOperands(outputs);
1355  result.addTypes(TypeRange(outputs));
1356 
1357  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1358  b.getDenseI64ArrayAttr(staticLbs));
1359  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1360  b.getDenseI64ArrayAttr(staticUbs));
1361  result.addAttribute(getStaticStepAttrName(result.name),
1362  b.getDenseI64ArrayAttr(staticSteps));
1363  result.addAttribute(
1364  "operandSegmentSizes",
1365  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1366  static_cast<int32_t>(dynamicUbs.size()),
1367  static_cast<int32_t>(dynamicSteps.size()),
1368  static_cast<int32_t>(outputs.size())}));
1369  if (mapping.has_value()) {
1370  result.addAttribute(getMappingAttrName(result.name),
1371  mapping.value());
1372  }
1373 
1374  Region *bodyRegion = result.addRegion();
1375  OpBuilder::InsertionGuard g(b);
1376  b.createBlock(bodyRegion);
1377  Block &bodyBlock = bodyRegion->front();
1378 
1379  // Add block arguments for indices and outputs.
1380  bodyBlock.addArguments(
1381  SmallVector<Type>(lbs.size(), b.getIndexType()),
1382  SmallVector<Location>(staticLbs.size(), result.location));
1383  bodyBlock.addArguments(
1384  TypeRange(outputs),
1385  SmallVector<Location>(outputs.size(), result.location));
1386 
1387  b.setInsertionPointToStart(&bodyBlock);
1388  if (!bodyBuilderFn) {
1389  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1390  return;
1391  }
1392  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1393 }
1394 
1395 // Builder that takes loop bounds.
1396 void ForallOp::build(
1397  mlir::OpBuilder &b, mlir::OperationState &result,
1398  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1399  std::optional<ArrayAttr> mapping,
1400  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1401  unsigned numLoops = ubs.size();
1402  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1403  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1404  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1405 }
1406 
1407 // Checks if the lbs are zeros and steps are ones.
1408 bool ForallOp::isNormalized() {
1409  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1410  return llvm::all_of(results, [&](OpFoldResult ofr) {
1411  auto intValue = getConstantIntValue(ofr);
1412  return intValue.has_value() && intValue == val;
1413  });
1414  };
1415  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1416 }
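
// For example (editor's note): `scf.forall (%i, %j) in (8, %n)` is normalized
// (implicit zero lower bounds and unit steps), while a loop written as
// `scf.forall (%i) = (1) to (%n) step (2)` is not.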
1417 
1418 InParallelOp ForallOp::getTerminator() {
1419  return cast<InParallelOp>(getBody()->getTerminator());
1420 }
1421 
1422 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1423  SmallVector<Operation *> storeOps;
1424  InParallelOp inParallelOp = getTerminator();
1425  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1426  if (auto parallelInsertSliceOp =
1427  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1428  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1429  storeOps.push_back(parallelInsertSliceOp);
1430  }
1431  }
1432  return storeOps;
1433 }
1434 
1435 SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1436  SmallVector<DeviceMappingAttrInterface> res;
1437  if (!getMapping())
1438  return res;
1439  for (auto attr : getMapping()->getValue()) {
1440  auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1441  if (m)
1442  res.push_back(m);
1443  }
1444  return res;
1445 }
1446 
1447 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1448  DeviceMaskingAttrInterface res;
1449  if (!getMapping())
1450  return res;
1451  for (auto attr : getMapping()->getValue()) {
1452  auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1453  if (m && res)
1454  return failure();
1455  if (m)
1456  res = m;
1457  }
1458  return res;
1459 }
1460 
1461 bool ForallOp::usesLinearMapping() {
1462  SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1463  if (ifaces.empty())
1464  return false;
1465  return ifaces.front().isLinearMapping();
1466 }
1467 
1468 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1469  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1470 }
1471 
1472 // Get lower bounds as OpFoldResult.
1473 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1474  Builder b(getOperation()->getContext());
1475  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1476 }
1477 
1478 // Get upper bounds as OpFoldResult.
1479 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1480  Builder b(getOperation()->getContext());
1481  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1482 }
1483 
1484 // Get steps as OpFoldResult.
1485 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1486  Builder b(getOperation()->getContext());
1487  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1488 }
1489 
1490 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1491  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1492  if (!tidxArg)
1493  return ForallOp();
1494  assert(tidxArg.getOwner() && "unlinked block argument");
1495  auto *containingOp = tidxArg.getOwner()->getParentOp();
1496  return dyn_cast<ForallOp>(containingOp);
1497 }
1498 
1499 namespace {
1500 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1501 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1502  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1503 
1504  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1505  PatternRewriter &rewriter) const final {
1506  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1507  if (!forallOp)
1508  return failure();
1509  Value sharedOut =
1510  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1511  ->get();
1512  rewriter.modifyOpInPlace(
1513  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1514  return success();
1515  }
1516 };
1517 
1518 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1519 public:
1520  using OpRewritePattern<ForallOp>::OpRewritePattern;
1521 
1522  LogicalResult matchAndRewrite(ForallOp op,
1523  PatternRewriter &rewriter) const override {
1524  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1525  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1526  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1527  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1528  failed(foldDynamicIndexList(mixedUpperBound)) &&
1529  failed(foldDynamicIndexList(mixedStep)))
1530  return failure();
1531 
1532  rewriter.modifyOpInPlace(op, [&]() {
1533  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1534  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1535  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1536  staticLowerBound);
1537  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1538  op.setStaticLowerBound(staticLowerBound);
1539 
1540  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1541  staticUpperBound);
1542  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1543  op.setStaticUpperBound(staticUpperBound);
1544 
1545  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1546  op.getDynamicStepMutable().assign(dynamicStep);
1547  op.setStaticStep(staticStep);
1548 
1549  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1550  rewriter.getDenseI32ArrayAttr(
1551  {static_cast<int32_t>(dynamicLowerBound.size()),
1552  static_cast<int32_t>(dynamicUpperBound.size()),
1553  static_cast<int32_t>(dynamicStep.size()),
1554  static_cast<int32_t>(op.getNumResults())}));
1555  });
1556  return success();
1557  }
1558 };
1559 
1560 /// The following canonicalization pattern folds the iter arguments of the
1561 /// scf.forall op if :-
1562 /// 1. The corresponding result has zero uses.
1563 /// 2. The iter argument is NOT being modified within the loop body.
1564 ///
1565 ///
1566 /// Example of first case :-
1567 /// INPUT:
1568 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1569 /// {
1570 /// ...
1571 /// <SOME USE OF %arg0>
1572 /// <SOME USE OF %arg1>
1573 /// <SOME USE OF %arg2>
1574 /// ...
1575 /// scf.forall.in_parallel {
1576 /// <STORE OP WITH DESTINATION %arg1>
1577 /// <STORE OP WITH DESTINATION %arg0>
1578 /// <STORE OP WITH DESTINATION %arg2>
1579 /// }
1580 /// }
1581 /// return %res#1
1582 ///
1583 /// OUTPUT:
1584 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1585 /// {
1586 /// ...
1587 /// <SOME USE OF %a>
1588 /// <SOME USE OF %new_arg0>
1589 /// <SOME USE OF %c>
1590 /// ...
1591 /// scf.forall.in_parallel {
1592 /// <STORE OP WITH DESTINATION %new_arg0>
1593 /// }
1594 /// }
1595 /// return %res
1596 ///
1597 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1598 /// scf.forall are replaced by their corresponding operands.
1599 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1600 /// of the scf.forall outside of the scf.forall.in_parallel terminator,
1601 /// this canonicalization remains valid. For more details, please refer
1602 /// to :
1603 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1604 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1605 /// handles tensor.parallel_insert_slice ops only.
1606 ///
1607 /// Example of second case :-
1608 /// INPUT:
1609 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1610 /// {
1611 /// ...
1612 /// <SOME USE OF %arg0>
1613 /// <SOME USE OF %arg1>
1614 /// ...
1615 /// scf.forall.in_parallel {
1616 /// <STORE OP WITH DESTINATION %arg1>
1617 /// }
1618 /// }
1619 /// return %res#0, %res#1
1620 ///
1621 /// OUTPUT:
1622 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1623 /// {
1624 /// ...
1625 /// <SOME USE OF %a>
1626 /// <SOME USE OF %new_arg0>
1627 /// ...
1628 /// scf.forall.in_parallel {
1629 /// <STORE OP WITH DESTINATION %new_arg0>
1630 /// }
1631 /// }
1632 /// return %a, %res
1633 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1634  using OpRewritePattern<ForallOp>::OpRewritePattern;
1635 
1636  LogicalResult matchAndRewrite(ForallOp forallOp,
1637  PatternRewriter &rewriter) const final {
1638  // Step 1: For a given i-th result of scf.forall, check the following :-
1639  // a. If it has any use.
1640  // b. If the corresponding iter argument is being modified within
1641  // the loop, i.e. has at least one store op with the iter arg as
1642  // its destination operand. For this we use
1643  // ForallOp::getCombiningOps(iter_arg).
1644  //
1645  // Based on the check we maintain the following :-
1646  // a. `resultToDelete` - i-th result of scf.forall that'll be
1647  // deleted.
1648  // b. `resultToReplace` - i-th result of the old scf.forall
1649  // whose uses will be replaced by the new scf.forall.
1650  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1651  // corresponding to the i-th result with at least one use.
1652  SetVector<OpResult> resultToDelete;
1653  SmallVector<Value> resultToReplace;
1654  SmallVector<Value> newOuts;
1655  for (OpResult result : forallOp.getResults()) {
1656  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1657  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1658  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1659  resultToDelete.insert(result);
1660  } else {
1661  resultToReplace.push_back(result);
1662  newOuts.push_back(opOperand->get());
1663  }
1664  }
1665 
1666  // Return early if all results of scf.forall have at least one use and are
1667  // being modified within the loop.
1668  if (resultToDelete.empty())
1669  return failure();
1670 
1671  // Step 2: For the i-th result, do the following :-
1672  // a. Fetch the corresponding BlockArgument.
1673  // b. Look for store ops (currently tensor.parallel_insert_slice)
1674  // with the BlockArgument as its destination operand.
1675  // c. Remove the operations fetched in b.
1676  for (OpResult result : resultToDelete) {
1677  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1678  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1679  SmallVector<Operation *> combiningOps =
1680  forallOp.getCombiningOps(blockArg);
1681  for (Operation *combiningOp : combiningOps)
1682  rewriter.eraseOp(combiningOp);
1683  }
1684 
1685  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1686  // fetched earlier
1687  auto newForallOp = rewriter.create<scf::ForallOp>(
1688  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1689  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1690  forallOp.getMapping(),
1691  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1692 
1693  // Step 4. Merge the block of the old scf.forall into the newly created
1694  // scf.forall using the new set of arguments.
1695  Block *loopBody = forallOp.getBody();
1696  Block *newLoopBody = newForallOp.getBody();
1697  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1698  // Form initial new bbArg list with just the control operands of the new
1699  // scf.forall op.
1700  SmallVector<Value> newBlockArgs =
1701  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1702  [](BlockArgument b) -> Value { return b; });
1703  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1704  unsigned index = 0;
1705  // Take the new corresponding bbArg if the old bbArg was used as a
1706  // destination in the in_parallel op. For all other bbArgs, use the
1707  // corresponding init_arg from the old scf.forall op.
1708  for (OpResult result : forallOp.getResults()) {
1709  if (resultToDelete.count(result)) {
1710  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1711  } else {
1712  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1713  }
1714  }
1715  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1716 
1717  // Step 5. Replace the uses of the results of the old scf.forall with those
1718  // of the new scf.forall.
1719  for (auto &&[oldResult, newResult] :
1720  llvm::zip(resultToReplace, newForallOp->getResults()))
1721  rewriter.replaceAllUsesWith(oldResult, newResult);
1722 
1723  // Step 6. Replace the uses of those results that either have no use or are
1724  // not being modified within the loop with the corresponding
1725  // OpOperand.
1726  for (OpResult oldResult : resultToDelete)
1727  rewriter.replaceAllUsesWith(oldResult,
1728  forallOp.getTiedOpOperand(oldResult)->get());
1729  return success();
1730  }
1731 };
1732 
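/// Fold away scf.forall dimensions whose trip count is statically known to be
/// zero or one. A zero-trip loop is replaced by its shared_outs operands; a
/// single-iteration dimension is dropped and its induction variable is
/// replaced by the lower bound; if every dimension is single-iteration, the
/// whole body is inlined. Illustrative sketch (bounds are made up):
///   scf.forall (%i, %j) in (1, %n) shared_outs(%o = %t) -> (tensor<?xf32>)
/// becomes a loop over %j only, with uses of %i replaced by the constant 0.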
1733 struct ForallOpSingleOrZeroIterationDimsFolder
1734  : public OpRewritePattern<ForallOp> {
1735  using OpRewritePattern<ForallOp>::OpRewritePattern;
1736 
1737  LogicalResult matchAndRewrite(ForallOp op,
1738  PatternRewriter &rewriter) const override {
1739  // Do not fold dimensions if they are mapped to processing units.
1740  if (op.getMapping().has_value() && !op.getMapping()->empty())
1741  return failure();
1742  Location loc = op.getLoc();
1743 
1744  // Compute new loop bounds that omit all single-iteration loop dimensions.
1745  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1746  newMixedSteps;
1747  IRMapping mapping;
1748  for (auto [lb, ub, step, iv] :
1749  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1750  op.getMixedStep(), op.getInductionVars())) {
1751  auto numIterations = constantTripCount(lb, ub, step);
1752  if (numIterations.has_value()) {
1753  // Remove the loop if it performs zero iterations.
1754  if (*numIterations == 0) {
1755  rewriter.replaceOp(op, op.getOutputs());
1756  return success();
1757  }
1758  // Replace the loop induction variable by the lower bound if the loop
1759  // performs a single iteration. Otherwise, copy the loop bounds.
1760  if (*numIterations == 1) {
1761  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1762  continue;
1763  }
1764  }
1765  newMixedLowerBounds.push_back(lb);
1766  newMixedUpperBounds.push_back(ub);
1767  newMixedSteps.push_back(step);
1768  }
1769 
1770  // All of the loop dimensions perform a single iteration. Inline loop body.
1771  if (newMixedLowerBounds.empty()) {
1772  promote(rewriter, op);
1773  return success();
1774  }
1775 
1776  // Exit if none of the loop dimensions perform a single iteration.
1777  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1778  return rewriter.notifyMatchFailure(
1779  op, "no dimensions have 0 or 1 iterations");
1780  }
1781 
1782  // Replace the loop by a lower-dimensional loop.
1783  ForallOp newOp;
1784  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1785  newMixedUpperBounds, newMixedSteps,
1786  op.getOutputs(), std::nullopt, nullptr);
1787  newOp.getBodyRegion().getBlocks().clear();
1788  // The new loop needs to keep all attributes from the old one, except for
1789  // "operandSegmentSizes" and static loop bound attributes which capture
1790  // the outdated information of the old iteration domain.
1791  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1792  newOp.getStaticLowerBoundAttrName(),
1793  newOp.getStaticUpperBoundAttrName(),
1794  newOp.getStaticStepAttrName()};
1795  for (const auto &namedAttr : op->getAttrs()) {
1796  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1797  continue;
1798  rewriter.modifyOpInPlace(newOp, [&]() {
1799  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1800  });
1801  }
1802  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1803  newOp.getRegion().begin(), mapping);
1804  rewriter.replaceOp(op, newOp.getResults());
1805  return success();
1806  }
1807 };
1808 
1809 /// Replace induction variables whose dimension has a trip count of one with their lower bound.
1810 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1811  using OpRewritePattern<ForallOp>::OpRewritePattern;
1812 
1813  LogicalResult matchAndRewrite(ForallOp op,
1814  PatternRewriter &rewriter) const override {
1815  Location loc = op.getLoc();
1816  bool changed = false;
1817  for (auto [lb, ub, step, iv] :
1818  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1819  op.getMixedStep(), op.getInductionVars())) {
1820  if (iv.hasNUses(0))
1821  continue;
1822  auto numIterations = constantTripCount(lb, ub, step);
1823  if (!numIterations.has_value() || numIterations.value() != 1) {
1824  continue;
1825  }
1826  rewriter.replaceAllUsesWith(
1827  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1828  changed = true;
1829  }
1830  return success(changed);
1831  }
1832 };
1833 
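/// Fold a tensor.cast that feeds a shared_outs operand into the scf.forall
/// when the cast source carries more static shape information than the cast
/// result. Illustrative sketch (types and value names are made up):
///   %cast = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
/// becomes a loop over %src that yields tensor<4xf32>, followed by a
/// tensor.cast of the loop result back to tensor<?xf32> for existing users.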
1834 struct FoldTensorCastOfOutputIntoForallOp
1835  : public OpRewritePattern<scf::ForallOp> {
1836  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1837 
1838  struct TypeCast {
1839  Type srcType;
1840  Type dstType;
1841  };
1842 
1843  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1844  PatternRewriter &rewriter) const final {
1845  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1846  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1847  for (auto en : llvm::enumerate(newOutputTensors)) {
1848  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1849  if (!castOp)
1850  continue;
1851 
1852  // Only casts that preserve static information, i.e. that will make the
1853  // loop result type "more" static than before, will be folded.
1854  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1855  castOp.getSource().getType())) {
1856  continue;
1857  }
1858 
1859  tensorCastProducers[en.index()] =
1860  TypeCast{castOp.getSource().getType(), castOp.getType()};
1861  newOutputTensors[en.index()] = castOp.getSource();
1862  }
1863 
1864  if (tensorCastProducers.empty())
1865  return failure();
1866 
1867  // Create new loop.
1868  Location loc = forallOp.getLoc();
1869  auto newForallOp = rewriter.create<ForallOp>(
1870  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1871  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1872  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1873  auto castBlockArgs =
1874  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1875  for (auto [index, cast] : tensorCastProducers) {
1876  Value &oldTypeBBArg = castBlockArgs[index];
1877  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1878  nestedLoc, cast.dstType, oldTypeBBArg);
1879  }
1880 
1881  // Move old body into new parallel loop.
1882  SmallVector<Value> ivsBlockArgs =
1883  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1884  ivsBlockArgs.append(castBlockArgs);
1885  rewriter.mergeBlocks(forallOp.getBody(),
1886  bbArgs.front().getParentBlock(), ivsBlockArgs);
1887  });
1888 
1889  // After `mergeBlocks` happened, the destinations in the terminator were
1890  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1891  // destinations have to be updated to point to the output bbArgs directly.
1892  auto terminator = newForallOp.getTerminator();
1893  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1894  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1895  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1896  insertSliceOp.getDestMutable().assign(outputBlockArg);
1897  }
1898 
1899  // Cast results back to the original types.
1900  rewriter.setInsertionPointAfter(newForallOp);
1901  SmallVector<Value> castResults = newForallOp.getResults();
1902  for (auto &item : tensorCastProducers) {
1903  Value &oldTypeResult = castResults[item.first];
1904  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1905  oldTypeResult);
1906  }
1907  rewriter.replaceOp(forallOp, castResults);
1908  return success();
1909  }
1910 };
1911 
1912 } // namespace
1913 
1914 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1915  MLIRContext *context) {
1916  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1917  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1918  ForallOpSingleOrZeroIterationDimsFolder,
1919  ForallOpReplaceConstantInductionVar>(context);
1920 }
1921 
1922 /// Given the branch point `point`, i.e. the scf.forall operation itself or its
1923 /// body region, return the successor regions. These are the regions that may
1924 /// be selected during the flow of control: the loop body may be entered zero
1925 /// or more times, and control eventually continues after the operation, so
1926 /// both the body region and the parent op are possible successors.
1927 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1928  SmallVectorImpl<RegionSuccessor> &regions) {
1929  // Both the operation itself and the region may be branching into the body or
1930  // back into the operation itself. It is possible for the loop not to enter
1931  // the body.
1932  regions.push_back(RegionSuccessor(&getRegion()));
1933  regions.push_back(RegionSuccessor());
1934 }
1935 
1936 //===----------------------------------------------------------------------===//
1937 // InParallelOp
1938 //===----------------------------------------------------------------------===//
1939 
1940 // Build an InParallelOp, creating its single-block region.
1941 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1942  OpBuilder::InsertionGuard g(b);
1943  Region *bodyRegion = result.addRegion();
1944  b.createBlock(bodyRegion);
1945 }
1946 
1947 LogicalResult InParallelOp::verify() {
1948  scf::ForallOp forallOp =
1949  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1950  if (!forallOp)
1951  return this->emitOpError("expected forall op parent");
1952 
1953  // TODO: InParallelOpInterface.
1954  for (Operation &op : getRegion().front().getOperations()) {
1955  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1956  return this->emitOpError("expected only ")
1957  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1958  }
1959 
1960  // Verify that inserts are into out block arguments.
1961  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1962  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1963  if (!llvm::is_contained(regionOutArgs, dest))
1964  return op.emitOpError("may only insert into an output block argument");
1965  }
1966  return success();
1967 }
1968 
1969 void InParallelOp::print(OpAsmPrinter &p) {
1970  p << " ";
1971  p.printRegion(getRegion(),
1972  /*printEntryBlockArgs=*/false,
1973  /*printBlockTerminators=*/false);
1974  p.printOptionalAttrDict(getOperation()->getAttrs());
1975 }
1976 
1977 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1978  auto &builder = parser.getBuilder();
1979 
1980  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1981  std::unique_ptr<Region> region = std::make_unique<Region>();
1982  if (parser.parseRegion(*region, regionOperands))
1983  return failure();
1984 
1985  if (region->empty())
1986  OpBuilder(builder.getContext()).createBlock(region.get());
1987  result.addRegion(std::move(region));
1988 
1989  // Parse the optional attribute list.
1990  if (parser.parseOptionalAttrDict(result.attributes))
1991  return failure();
1992  return success();
1993 }
1994 
1995 OpResult InParallelOp::getParentResult(int64_t idx) {
1996  return getOperation()->getParentOp()->getResult(idx);
1997 }
1998 
1999 SmallVector<BlockArgument> InParallelOp::getDests() {
2000  return llvm::to_vector<4>(
2001  llvm::map_range(getYieldingOps(), [](Operation &op) {
2002  // Add new ops here as needed.
2003  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
2004  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
2005  }));
2006 }
2007 
2008 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2009  return getRegion().front().getOperations();
2010 }
2011 
2012 //===----------------------------------------------------------------------===//
2013 // IfOp
2014 //===----------------------------------------------------------------------===//
2015 
2016 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
2017  assert(a && "expected non-empty operation");
2018  assert(b && "expected non-empty operation");
2019 
2020  IfOp ifOp = a->getParentOfType<IfOp>();
2021  while (ifOp) {
2022  // Check if b is inside ifOp. (We already know that a is.)
2023  if (ifOp->isProperAncestor(b))
2024  // b is contained in ifOp. a and b are in mutually exclusive branches if
2025  // they are in different blocks of ifOp.
2026  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2027  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2028  // Check next enclosing IfOp.
2029  ifOp = ifOp->getParentOfType<IfOp>();
2030  }
2031 
2032  // Could not find a common IfOp among a's and b's ancestors.
2033  return false;
2034 }
2035 
2036 LogicalResult
2037 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2038  IfOp::Adaptor adaptor,
2039  SmallVectorImpl<Type> &inferredReturnTypes) {
2040  if (adaptor.getRegions().empty())
2041  return failure();
2042  Region *r = &adaptor.getThenRegion();
2043  if (r->empty())
2044  return failure();
2045  Block &b = r->front();
2046  if (b.empty())
2047  return failure();
2048  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2049  if (!yieldOp)
2050  return failure();
2051  TypeRange types = yieldOp.getOperandTypes();
2052  llvm::append_range(inferredReturnTypes, types);
2053  return success();
2054 }
2055 
2056 void IfOp::build(OpBuilder &builder, OperationState &result,
2057  TypeRange resultTypes, Value cond) {
2058  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2059  /*addElseBlock=*/false);
2060 }
2061 
2062 void IfOp::build(OpBuilder &builder, OperationState &result,
2063  TypeRange resultTypes, Value cond, bool addThenBlock,
2064  bool addElseBlock) {
2065  assert((!addElseBlock || addThenBlock) &&
2066  "must not create else block w/o then block");
2067  result.addTypes(resultTypes);
2068  result.addOperands(cond);
2069 
2070  // Add regions and blocks.
2071  OpBuilder::InsertionGuard guard(builder);
2072  Region *thenRegion = result.addRegion();
2073  if (addThenBlock)
2074  builder.createBlock(thenRegion);
2075  Region *elseRegion = result.addRegion();
2076  if (addElseBlock)
2077  builder.createBlock(elseRegion);
2078 }
2079 
2080 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2081  bool withElseRegion) {
2082  build(builder, result, TypeRange{}, cond, withElseRegion);
2083 }
2084 
2085 void IfOp::build(OpBuilder &builder, OperationState &result,
2086  TypeRange resultTypes, Value cond, bool withElseRegion) {
2087  result.addTypes(resultTypes);
2088  result.addOperands(cond);
2089 
2090  // Build then region.
2091  OpBuilder::InsertionGuard guard(builder);
2092  Region *thenRegion = result.addRegion();
2093  builder.createBlock(thenRegion);
2094  if (resultTypes.empty())
2095  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2096 
2097  // Build else region.
2098  Region *elseRegion = result.addRegion();
2099  if (withElseRegion) {
2100  builder.createBlock(elseRegion);
2101  if (resultTypes.empty())
2102  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2103  }
2104 }
2105 
2106 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2107  function_ref<void(OpBuilder &, Location)> thenBuilder,
2108  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2109  assert(thenBuilder && "the builder callback for 'then' must be present");
2110  result.addOperands(cond);
2111 
2112  // Build then region.
2113  OpBuilder::InsertionGuard guard(builder);
2114  Region *thenRegion = result.addRegion();
2115  builder.createBlock(thenRegion);
2116  thenBuilder(builder, result.location);
2117 
2118  // Build else region.
2119  Region *elseRegion = result.addRegion();
2120  if (elseBuilder) {
2121  builder.createBlock(elseRegion);
2122  elseBuilder(builder, result.location);
2123  }
2124 
2125  // Infer result types.
2126  SmallVector<Type> inferredReturnTypes;
2127  MLIRContext *ctx = builder.getContext();
2128  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2129  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2130  /*properties=*/nullptr, result.regions,
2131  inferredReturnTypes))) {
2132  result.addTypes(inferredReturnTypes);
2133  }
2134 }
2135 
2136 LogicalResult IfOp::verify() {
2137  if (getNumResults() != 0 && getElseRegion().empty())
2138  return emitOpError("must have an else block if defining values");
2139  return success();
2140 }
2141 
2142 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2143  // Create the regions for 'then' and 'else'.
2144  result.regions.reserve(2);
2145  Region *thenRegion = result.addRegion();
2146  Region *elseRegion = result.addRegion();
2147 
2148  auto &builder = parser.getBuilder();
2149  OpAsmParser::UnresolvedOperand cond;
2150  Type i1Type = builder.getIntegerType(1);
2151  if (parser.parseOperand(cond) ||
2152  parser.resolveOperand(cond, i1Type, result.operands))
2153  return failure();
2154  // Parse optional results type list.
2155  if (parser.parseOptionalArrowTypeList(result.types))
2156  return failure();
2157  // Parse the 'then' region.
2158  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2159  return failure();
2160  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2161 
2162  // If we find an 'else' keyword then parse the 'else' region.
2163  if (!parser.parseOptionalKeyword("else")) {
2164  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2165  return failure();
2166  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2167  }
2168 
2169  // Parse the optional attribute list.
2170  if (parser.parseOptionalAttrDict(result.attributes))
2171  return failure();
2172  return success();
2173 }
2174 
2175 void IfOp::print(OpAsmPrinter &p) {
2176  bool printBlockTerminators = false;
2177 
2178  p << " " << getCondition();
2179  if (!getResults().empty()) {
2180  p << " -> (" << getResultTypes() << ")";
2181  // Print yield explicitly if the op defines values.
2182  printBlockTerminators = true;
2183  }
2184  p << ' ';
2185  p.printRegion(getThenRegion(),
2186  /*printEntryBlockArgs=*/false,
2187  /*printBlockTerminators=*/printBlockTerminators);
2188 
2189  // Print the 'else' region if it exists and has a block.
2190  auto &elseRegion = getElseRegion();
2191  if (!elseRegion.empty()) {
2192  p << " else ";
2193  p.printRegion(elseRegion,
2194  /*printEntryBlockArgs=*/false,
2195  /*printBlockTerminators=*/printBlockTerminators);
2196  }
2197 
2198  p.printOptionalAttrDict((*this)->getAttrs());
2199 }
2200 
2201 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2202  SmallVectorImpl<RegionSuccessor> &regions) {
2203  // The `then` and the `else` region branch back to the parent operation.
2204  if (!point.isParent()) {
2205  regions.push_back(RegionSuccessor(getResults()));
2206  return;
2207  }
2208 
2209  regions.push_back(RegionSuccessor(&getThenRegion()));
2210 
2211  // Don't consider the else region if it is empty.
2212  Region *elseRegion = &this->getElseRegion();
2213  if (elseRegion->empty())
2214  regions.push_back(RegionSuccessor());
2215  else
2216  regions.push_back(RegionSuccessor(elseRegion));
2217 }
2218 
2219 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2220  SmallVectorImpl<RegionSuccessor> &regions) {
2221  FoldAdaptor adaptor(operands, *this);
2222  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2223  if (!boolAttr || boolAttr.getValue())
2224  regions.emplace_back(&getThenRegion());
2225 
2226  // If the else region is empty, execution continues after the parent op.
2227  if (!boolAttr || !boolAttr.getValue()) {
2228  if (!getElseRegion().empty())
2229  regions.emplace_back(&getElseRegion());
2230  else
2231  regions.emplace_back(getResults());
2232  }
2233 }
2234 
2235 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2236  SmallVectorImpl<OpFoldResult> &results) {
2237  // if (!c) then A() else B() -> if c then B() else A()
2238  if (getElseRegion().empty())
2239  return failure();
2240 
2241  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2242  if (!xorStmt)
2243  return failure();
2244 
2245  if (!matchPattern(xorStmt.getRhs(), m_One()))
2246  return failure();
2247 
2248  getConditionMutable().assign(xorStmt.getLhs());
2249  Block *thenBlock = &getThenRegion().front();
2250  // It would be nicer to use iplist::swap, but that has no implemented
2251  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2252  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2253  getElseRegion().getBlocks());
2254  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2255  getThenRegion().getBlocks(), thenBlock);
2256  return success();
2257 }
2258 
2259 void IfOp::getRegionInvocationBounds(
2260  ArrayRef<Attribute> operands,
2261  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2262  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2263  // If the condition is known, then one region is known to be executed once
2264  // and the other zero times.
2265  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2266  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2267  } else {
2268  // Non-constant condition. Each region may be executed 0 or 1 times.
2269  invocationBounds.assign(2, {0, 1});
2270  }
2271 }
2272 
2273 namespace {
2274 // Pattern to remove unused IfOp results.
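// Illustrative sketch (values and types are made up): given
//   %0:2 = scf.if %cond -> (i32, i32) { ... scf.yield %a, %b : i32, i32 }
//          else { ... scf.yield %c, %d : i32, i32 }
// where only %0#0 has uses, the scf.if is rebuilt to return a single value and
// both yields are trimmed to the used operand.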
2275 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2276  using OpRewritePattern<IfOp>::OpRewritePattern;
2277 
2278  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2279  PatternRewriter &rewriter) const {
2280  // Move all operations to the destination block.
2281  rewriter.mergeBlocks(source, dest);
2282  // Replace the yield op by one that returns only the used values.
2283  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2284  SmallVector<Value, 4> usedOperands;
2285  llvm::transform(usedResults, std::back_inserter(usedOperands),
2286  [&](OpResult result) {
2287  return yieldOp.getOperand(result.getResultNumber());
2288  });
2289  rewriter.modifyOpInPlace(yieldOp,
2290  [&]() { yieldOp->setOperands(usedOperands); });
2291  }
2292 
2293  LogicalResult matchAndRewrite(IfOp op,
2294  PatternRewriter &rewriter) const override {
2295  // Compute the list of used results.
2296  SmallVector<OpResult, 4> usedResults;
2297  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2298  [](OpResult result) { return !result.use_empty(); });
2299 
2300  // Replace the operation if only a subset of its results have uses.
2301  if (usedResults.size() == op.getNumResults())
2302  return failure();
2303 
2304  // Compute the result types of the replacement operation.
2305  SmallVector<Type, 4> newTypes;
2306  llvm::transform(usedResults, std::back_inserter(newTypes),
2307  [](OpResult result) { return result.getType(); });
2308 
2309  // Create a replacement operation with empty then and else regions.
2310  auto newOp =
2311  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2312  rewriter.createBlock(&newOp.getThenRegion());
2313  rewriter.createBlock(&newOp.getElseRegion());
2314 
2315  // Move the bodies and replace the terminators (note there is a then and
2316  // an else region since the operation returns results).
2317  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2318  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2319 
2320  // Replace the operation by the new one.
2321  SmallVector<Value, 4> repResults(op.getNumResults());
2322  for (const auto &en : llvm::enumerate(usedResults))
2323  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2324  rewriter.replaceOp(op, repResults);
2325  return success();
2326  }
2327 };
2328 
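/// Inline the taken region of an scf.if whose condition is a constant. A
/// minimal sketch of the rewrite:
///   scf.if %true { "foo"() : () -> () }
/// becomes
///   "foo"() : () -> ()
/// A constant-false scf.if with no else region is erased entirely.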
2329 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2330  using OpRewritePattern<IfOp>::OpRewritePattern;
2331 
2332  LogicalResult matchAndRewrite(IfOp op,
2333  PatternRewriter &rewriter) const override {
2334  BoolAttr condition;
2335  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2336  return failure();
2337 
2338  if (condition.getValue())
2339  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2340  else if (!op.getElseRegion().empty())
2341  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2342  else
2343  rewriter.eraseOp(op);
2344 
2345  return success();
2346  }
2347 };
2348 
2349 /// Hoist any yielded results whose operands are defined outside
2350 /// the if into an arith.select on the condition.
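/// Illustrative sketch (value names are made up), with %a and %b defined
/// above the scf.if:
///   %r = scf.if %cond -> (i32) { ...; scf.yield %a : i32 }
///        else { ...; scf.yield %b : i32 }
/// The result %r is replaced by %sel = arith.select %cond, %a, %b : i32 and
/// is no longer yielded by the scf.if.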
2351 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2352  using OpRewritePattern<IfOp>::OpRewritePattern;
2353 
2354  LogicalResult matchAndRewrite(IfOp op,
2355  PatternRewriter &rewriter) const override {
2356  if (op->getNumResults() == 0)
2357  return failure();
2358 
2359  auto cond = op.getCondition();
2360  auto thenYieldArgs = op.thenYield().getOperands();
2361  auto elseYieldArgs = op.elseYield().getOperands();
2362 
2363  SmallVector<Type> nonHoistable;
2364  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2365  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2366  &op.getElseRegion() == falseVal.getParentRegion())
2367  nonHoistable.push_back(trueVal.getType());
2368  }
2369  // Early exit if there aren't any yielded values we can
2370  // hoist outside the if.
2371  if (nonHoistable.size() == op->getNumResults())
2372  return failure();
2373 
2374  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2375  /*withElseRegion=*/false);
2376  if (replacement.thenBlock())
2377  rewriter.eraseBlock(replacement.thenBlock());
2378  replacement.getThenRegion().takeBody(op.getThenRegion());
2379  replacement.getElseRegion().takeBody(op.getElseRegion());
2380 
2381  SmallVector<Value> results(op->getNumResults());
2382  assert(thenYieldArgs.size() == results.size());
2383  assert(elseYieldArgs.size() == results.size());
2384 
2385  SmallVector<Value> trueYields;
2386  SmallVector<Value> falseYields;
2387  rewriter.setInsertionPoint(replacement);
2388  for (const auto &it :
2389  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2390  Value trueVal = std::get<0>(it.value());
2391  Value falseVal = std::get<1>(it.value());
2392  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2393  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2394  results[it.index()] = replacement.getResult(trueYields.size());
2395  trueYields.push_back(trueVal);
2396  falseYields.push_back(falseVal);
2397  } else if (trueVal == falseVal)
2398  results[it.index()] = trueVal;
2399  else
2400  results[it.index()] = rewriter.create<arith::SelectOp>(
2401  op.getLoc(), cond, trueVal, falseVal);
2402  }
2403 
2404  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2405  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2406 
2407  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2408  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2409 
2410  rewriter.replaceOp(op, results);
2411  return success();
2412  }
2413 };
2414 
2415 /// Allow the true region of an if to assume the condition is true
2416 /// and vice versa. For example:
2417 ///
2418 /// scf.if %cmp {
2419 /// print(%cmp)
2420 /// }
2421 ///
2422 /// becomes
2423 ///
2424 /// scf.if %cmp {
2425 /// print(true)
2426 /// }
2427 ///
2428 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2429  using OpRewritePattern<IfOp>::OpRewritePattern;
2430 
2431  LogicalResult matchAndRewrite(IfOp op,
2432  PatternRewriter &rewriter) const override {
2433  // Early exit if the condition is constant since replacing a constant
2434  // in the body with another constant isn't a simplification.
2435  if (matchPattern(op.getCondition(), m_Constant()))
2436  return failure();
2437 
2438  bool changed = false;
2439  mlir::Type i1Ty = rewriter.getI1Type();
2440 
2441  // These variables serve to prevent creating duplicate constants
2442  // and hold constant true or false values.
2443  Value constantTrue = nullptr;
2444  Value constantFalse = nullptr;
2445 
2446  for (OpOperand &use :
2447  llvm::make_early_inc_range(op.getCondition().getUses())) {
2448  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2449  changed = true;
2450 
2451  if (!constantTrue)
2452  constantTrue = rewriter.create<arith::ConstantOp>(
2453  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2454 
2455  rewriter.modifyOpInPlace(use.getOwner(),
2456  [&]() { use.set(constantTrue); });
2457  } else if (op.getElseRegion().isAncestor(
2458  use.getOwner()->getParentRegion())) {
2459  changed = true;
2460 
2461  if (!constantFalse)
2462  constantFalse = rewriter.create<arith::ConstantOp>(
2463  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2464 
2465  rewriter.modifyOpInPlace(use.getOwner(),
2466  [&]() { use.set(constantFalse); });
2467  }
2468  }
2469 
2470  return success(changed);
2471  }
2472 };
2473 
2474 /// Remove any statements from an if that are equivalent to the condition
2475 /// or its negation. For example:
2476 ///
2477 /// %res:2 = scf.if %cmp {
2478 /// yield something(), true
2479 /// } else {
2480 /// yield something2(), false
2481 /// }
2482 /// print(%res#1)
2483 ///
2484 /// becomes
2485 /// %res = scf.if %cmp {
2486 /// yield something()
2487 /// } else {
2488 /// yield something2()
2489 /// }
2490 /// print(%cmp)
2491 ///
2492 /// Additionally if both branches yield the same value, replace all uses
2493 /// of the result with the yielded value.
2494 ///
2495 /// %res:2 = scf.if %cmp {
2496 /// yield something(), %arg1
2497 /// } else {
2498 /// yield something2(), %arg1
2499 /// }
2500 /// print(%res#1)
2501 ///
2502 /// becomes
2503 /// %res = scf.if %cmp {
2504 /// yield something()
2505 /// } else {
2506 /// yield something2()
2507 /// }
2508 /// print(%arg1)
2509 ///
2510 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2511  using OpRewritePattern<IfOp>::OpRewritePattern;
2512 
2513  LogicalResult matchAndRewrite(IfOp op,
2514  PatternRewriter &rewriter) const override {
2515  // Early exit if there are no results that could be replaced.
2516  if (op.getNumResults() == 0)
2517  return failure();
2518 
2519  auto trueYield =
2520  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2521  auto falseYield =
2522  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2523 
2524  rewriter.setInsertionPoint(op->getBlock(),
2525  op.getOperation()->getIterator());
2526  bool changed = false;
2527  Type i1Ty = rewriter.getI1Type();
2528  for (auto [trueResult, falseResult, opResult] :
2529  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2530  op.getResults())) {
2531  if (trueResult == falseResult) {
2532  if (!opResult.use_empty()) {
2533  opResult.replaceAllUsesWith(trueResult);
2534  changed = true;
2535  }
2536  continue;
2537  }
2538 
2539  BoolAttr trueYield, falseYield;
2540  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2541  !matchPattern(falseResult, m_Constant(&falseYield)))
2542  continue;
2543 
2544  bool trueVal = trueYield.getValue();
2545  bool falseVal = falseYield.getValue();
2546  if (!trueVal && falseVal) {
2547  if (!opResult.use_empty()) {
2548  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2549  Value notCond = rewriter.create<arith::XOrIOp>(
2550  op.getLoc(), op.getCondition(),
2551  constDialect
2552  ->materializeConstant(rewriter,
2553  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2554  op.getLoc())
2555  ->getResult(0));
2556  opResult.replaceAllUsesWith(notCond);
2557  changed = true;
2558  }
2559  }
2560  if (trueVal && !falseVal) {
2561  if (!opResult.use_empty()) {
2562  opResult.replaceAllUsesWith(op.getCondition());
2563  changed = true;
2564  }
2565  }
2566  }
2567  return success(changed);
2568  }
2569 };
2570 
2571 /// Merge any consecutive scf.if's with the same condition.
2572 ///
2573 /// scf.if %cond {
2574 /// firstCodeTrue();...
2575 /// } else {
2576 /// firstCodeFalse();...
2577 /// }
2578 /// %res = scf.if %cond {
2579 /// secondCodeTrue();...
2580 /// } else {
2581 /// secondCodeFalse();...
2582 /// }
2583 ///
2584 /// becomes
2585 /// %res = scf.if %cmp {
2586 /// firstCodeTrue();...
2587 /// secondCodeTrue();...
2588 /// } else {
2589 /// firstCodeFalse();...
2590 /// secondCodeFalse();...
2591 /// }
2592 struct CombineIfs : public OpRewritePattern<IfOp> {
2593  using OpRewritePattern<IfOp>::OpRewritePattern;
2594 
2595  LogicalResult matchAndRewrite(IfOp nextIf,
2596  PatternRewriter &rewriter) const override {
2597  Block *parent = nextIf->getBlock();
2598  if (nextIf == &parent->front())
2599  return failure();
2600 
2601  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2602  if (!prevIf)
2603  return failure();
2604 
2605  // Determine the logical then/else blocks when prevIf's
2606  // condition is used. Null means the block does not exist
2607  // in that case (e.g. empty else). If neither of these
2608  // are set, the two conditions cannot be compared.
2609  Block *nextThen = nullptr;
2610  Block *nextElse = nullptr;
2611  if (nextIf.getCondition() == prevIf.getCondition()) {
2612  nextThen = nextIf.thenBlock();
2613  if (!nextIf.getElseRegion().empty())
2614  nextElse = nextIf.elseBlock();
2615  }
2616  if (arith::XOrIOp notv =
2617  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2618  if (notv.getLhs() == prevIf.getCondition() &&
2619  matchPattern(notv.getRhs(), m_One())) {
2620  nextElse = nextIf.thenBlock();
2621  if (!nextIf.getElseRegion().empty())
2622  nextThen = nextIf.elseBlock();
2623  }
2624  }
2625  if (arith::XOrIOp notv =
2626  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2627  if (notv.getLhs() == nextIf.getCondition() &&
2628  matchPattern(notv.getRhs(), m_One())) {
2629  nextElse = nextIf.thenBlock();
2630  if (!nextIf.getElseRegion().empty())
2631  nextThen = nextIf.elseBlock();
2632  }
2633  }
2634 
2635  if (!nextThen && !nextElse)
2636  return failure();
2637 
2638  SmallVector<Value> prevElseYielded;
2639  if (!prevIf.getElseRegion().empty())
2640  prevElseYielded = prevIf.elseYield().getOperands();
2641  // Replace all uses of prevIf's result values within nextIf with the
2642  // corresponding yielded values.
2643  for (auto it : llvm::zip(prevIf.getResults(),
2644  prevIf.thenYield().getOperands(), prevElseYielded))
2645  for (OpOperand &use :
2646  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2647  if (nextThen && nextThen->getParent()->isAncestor(
2648  use.getOwner()->getParentRegion())) {
2649  rewriter.startOpModification(use.getOwner());
2650  use.set(std::get<1>(it));
2651  rewriter.finalizeOpModification(use.getOwner());
2652  } else if (nextElse && nextElse->getParent()->isAncestor(
2653  use.getOwner()->getParentRegion())) {
2654  rewriter.startOpModification(use.getOwner());
2655  use.set(std::get<2>(it));
2656  rewriter.finalizeOpModification(use.getOwner());
2657  }
2658  }
2659 
2660  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2661  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2662 
2663  IfOp combinedIf = rewriter.create<IfOp>(
2664  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2665  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2666 
2667  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2668  combinedIf.getThenRegion(),
2669  combinedIf.getThenRegion().begin());
2670 
2671  if (nextThen) {
2672  YieldOp thenYield = combinedIf.thenYield();
2673  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2674  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2675  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2676 
2677  SmallVector<Value> mergedYields(thenYield.getOperands());
2678  llvm::append_range(mergedYields, thenYield2.getOperands());
2679  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2680  rewriter.eraseOp(thenYield);
2681  rewriter.eraseOp(thenYield2);
2682  }
2683 
2684  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2685  combinedIf.getElseRegion(),
2686  combinedIf.getElseRegion().begin());
2687 
2688  if (nextElse) {
2689  if (combinedIf.getElseRegion().empty()) {
2690  rewriter.inlineRegionBefore(*nextElse->getParent(),
2691  combinedIf.getElseRegion(),
2692  combinedIf.getElseRegion().begin());
2693  } else {
2694  YieldOp elseYield = combinedIf.elseYield();
2695  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2696  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2697 
2698  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2699 
2700  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2701  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2702 
2703  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2704  rewriter.eraseOp(elseYield);
2705  rewriter.eraseOp(elseYield2);
2706  }
2707  }
2708 
2709  SmallVector<Value> prevValues;
2710  SmallVector<Value> nextValues;
2711  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2712  if (pair.index() < prevIf.getNumResults())
2713  prevValues.push_back(pair.value());
2714  else
2715  nextValues.push_back(pair.value());
2716  }
2717  rewriter.replaceOp(prevIf, prevValues);
2718  rewriter.replaceOp(nextIf, nextValues);
2719  return success();
2720  }
2721 };
2722 
2723 /// Pattern to remove an empty else branch.
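/// Only applies when the scf.if has no results. Illustrative sketch:
///   scf.if %cond { "foo"() : () -> () } else { }
/// becomes
///   scf.if %cond { "foo"() : () -> () }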
2724 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2725  using OpRewritePattern<IfOp>::OpRewritePattern;
2726 
2727  LogicalResult matchAndRewrite(IfOp ifOp,
2728  PatternRewriter &rewriter) const override {
2729  // Cannot remove else region when there are operation results.
2730  if (ifOp.getNumResults())
2731  return failure();
2732  Block *elseBlock = ifOp.elseBlock();
2733  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2734  return failure();
2735  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2736  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2737  newIfOp.getThenRegion().begin());
2738  rewriter.eraseOp(ifOp);
2739  return success();
2740  }
2741 };
2742 
2743 /// Convert nested `if`s into `arith.andi` + single `if`.
2744 ///
2745 /// scf.if %arg0 {
2746 /// scf.if %arg1 {
2747 /// ...
2748 /// scf.yield
2749 /// }
2750 /// scf.yield
2751 /// }
2752 /// becomes
2753 ///
2754 /// %0 = arith.andi %arg0, %arg1
2755 /// scf.if %0 {
2756 /// ...
2757 /// scf.yield
2758 /// }
2759 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2760  using OpRewritePattern<IfOp>::OpRewritePattern;
2761 
2762  LogicalResult matchAndRewrite(IfOp op,
2763  PatternRewriter &rewriter) const override {
2764  auto nestedOps = op.thenBlock()->without_terminator();
2765  // Nested `if` must be the only op in block.
2766  if (!llvm::hasSingleElement(nestedOps))
2767  return failure();
2768 
2769  // If there is an else block, it can only contain a yield.
2770  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2771  return failure();
2772 
2773  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2774  if (!nestedIf)
2775  return failure();
2776 
2777  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2778  return failure();
2779 
2780  SmallVector<Value> thenYield(op.thenYield().getOperands());
2781  SmallVector<Value> elseYield;
2782  if (op.elseBlock())
2783  llvm::append_range(elseYield, op.elseYield().getOperands());
2784 
2785  // A list of indices for which we should upgrade the value yielded
2786  // in the else to a select.
2787  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2788 
2789  // If the outer scf.if yields a value produced by the inner scf.if,
2790  // only permit combining if the value yielded when the condition
2791  // is false in the outer scf.if is the same value yielded when the
2792  // inner scf.if condition is false.
2793  // Note that the array access to elseYield will not go out of bounds
2794  // since it must have the same length as thenYield, since they both
2795  // come from the same scf.if.
2796  for (const auto &tup : llvm::enumerate(thenYield)) {
2797  if (tup.value().getDefiningOp() == nestedIf) {
2798  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2799  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2800  elseYield[tup.index()]) {
2801  return failure();
2802  }
2803  // If the correctness test passes, we will yield the
2804  // corresponding value from the inner scf.if.
2805  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2806  continue;
2807  }
2808 
2809  // Otherwise, we need to ensure the else block of the combined
2810  // condition still returns the same value when the outer condition is
2811  // true and the inner condition is false. This can be accomplished if
2812  // the then value is defined outside the outer scf.if and we replace the
2813  // value with a select that considers just the outer condition. Since
2814  // the else region contains just the yield, its yielded value is
2815  // defined outside the scf.if, by definition.
2816 
2817  // If the then value is defined within the scf.if, bail.
2818  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2819  return failure();
2820  }
2821  elseYieldsToUpgradeToSelect.push_back(tup.index());
2822  }
2823 
2824  Location loc = op.getLoc();
2825  Value newCondition = rewriter.create<arith::AndIOp>(
2826  loc, op.getCondition(), nestedIf.getCondition());
2827  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2828  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2829 
2830  SmallVector<Value> results;
2831  llvm::append_range(results, newIf.getResults());
2832  rewriter.setInsertionPoint(newIf);
2833 
2834  for (auto idx : elseYieldsToUpgradeToSelect)
2835  results[idx] = rewriter.create<arith::SelectOp>(
2836  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2837 
2838  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2839  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2840  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2841  if (!elseYield.empty()) {
2842  rewriter.createBlock(&newIf.getElseRegion());
2843  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2844  rewriter.create<YieldOp>(loc, elseYield);
2845  }
2846  rewriter.replaceOp(op, results);
2847  return success();
2848  }
2849 };
2850 
2851 } // namespace
2852 
2853 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2854  MLIRContext *context) {
2855  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2856  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2857  RemoveStaticCondition, RemoveUnusedResults,
2858  ReplaceIfYieldWithConditionOrValue>(context);
2859 }
2860 
2861 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2862 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2863 Block *IfOp::elseBlock() {
2864  Region &r = getElseRegion();
2865  if (r.empty())
2866  return nullptr;
2867  return &r.back();
2868 }
2869 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2870 
2871 //===----------------------------------------------------------------------===//
2872 // ParallelOp
2873 //===----------------------------------------------------------------------===//
2874 
2875 void ParallelOp::build(
2876  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2877  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2878  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2879  bodyBuilderFn) {
2880  result.addOperands(lowerBounds);
2881  result.addOperands(upperBounds);
2882  result.addOperands(steps);
2883  result.addOperands(initVals);
2884  result.addAttribute(
2885  ParallelOp::getOperandSegmentSizeAttr(),
2886  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2887  static_cast<int32_t>(upperBounds.size()),
2888  static_cast<int32_t>(steps.size()),
2889  static_cast<int32_t>(initVals.size())}));
2890  result.addTypes(initVals.getTypes());
2891 
2892  OpBuilder::InsertionGuard guard(builder);
2893  unsigned numIVs = steps.size();
2894  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2895  SmallVector<Location, 8> argLocs(numIVs, result.location);
2896  Region *bodyRegion = result.addRegion();
2897  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2898 
2899  if (bodyBuilderFn) {
2900  builder.setInsertionPointToStart(bodyBlock);
2901  bodyBuilderFn(builder, result.location,
2902  bodyBlock->getArguments().take_front(numIVs),
2903  bodyBlock->getArguments().drop_front(numIVs));
2904  }
2905  // Add terminator only if there are no reductions.
2906  if (initVals.empty())
2907  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2908 }
2909 
2910 void ParallelOp::build(
2911  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2912  ValueRange upperBounds, ValueRange steps,
2913  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2914  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2915  // we don't capture a reference to a temporary by constructing the lambda at
2916  // function level.
2917  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2918  Location nestedLoc, ValueRange ivs,
2919  ValueRange) {
2920  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2921  };
2922  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2923  if (bodyBuilderFn)
2924  wrapper = wrappedBuilderFn;
2925 
2926  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2927  wrapper);
2928 }
2929 
2930 LogicalResult ParallelOp::verify() {
2931  // Check that there is at least one value in lowerBound, upperBound and step.
2932  // It is sufficient to test only step, because it is ensured already that the
2933  // number of elements in lowerBound, upperBound and step are the same.
2934  Operation::operand_range stepValues = getStep();
2935  if (stepValues.empty())
2936  return emitOpError(
2937  "needs at least one tuple element for lowerBound, upperBound and step");
2938 
2939  // Check whether all constant step values are positive.
2940  for (Value stepValue : stepValues)
2941  if (auto cst = getConstantIntValue(stepValue))
2942  if (*cst <= 0)
2943  return emitOpError("constant step operand must be positive");
2944 
2945  // Check that the body defines the same number of block arguments as the
2946  // number of tuple elements in step.
2947  Block *body = getBody();
2948  if (body->getNumArguments() != stepValues.size())
2949  return emitOpError() << "expects the same number of induction variables: "
2950  << body->getNumArguments()
2951  << " as bound and step values: " << stepValues.size();
2952  for (auto arg : body->getArguments())
2953  if (!arg.getType().isIndex())
2954  return emitOpError(
2955  "expects arguments for the induction variable to be of index type");
2956 
2957  // Check that the terminator is an scf.reduce op.
2958  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2959  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2960  if (!reduceOp)
2961  return failure();
2962 
2963  // Check that the number of results is the same as the number of reductions.
2964  auto resultsSize = getResults().size();
2965  auto reductionsSize = reduceOp.getReductions().size();
2966  auto initValsSize = getInitVals().size();
2967  if (resultsSize != reductionsSize)
2968  return emitOpError() << "expects number of results: " << resultsSize
2969  << " to be the same as number of reductions: "
2970  << reductionsSize;
2971  if (resultsSize != initValsSize)
2972  return emitOpError() << "expects number of results: " << resultsSize
2973  << " to be the same as number of initial values: "
2974  << initValsSize;
2975 
2976  // Check that the types of the results and reductions are the same.
2977  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2978  auto resultType = getOperation()->getResult(i).getType();
2979  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2980  if (resultType != reductionOperandType)
2981  return reduceOp.emitOpError()
2982  << "expects type of " << i
2983  << "-th reduction operand: " << reductionOperandType
2984  << " to be the same as the " << i
2985  << "-th result type: " << resultType;
2986  }
2987  return success();
2988 }
2989 
2990 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2991  auto &builder = parser.getBuilder();
2992  // Parse an opening `(` followed by induction variables followed by `)`.
2993  SmallVector<OpAsmParser::Argument, 4> ivs;
2994  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2995  return failure();
2996 
2997  // Parse loop bounds.
2998  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2999  if (parser.parseEqual() ||
3000  parser.parseOperandList(lower, ivs.size(),
3001  OpAsmParser::Delimiter::Paren) ||
3002  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
3003  return failure();
3004 
3005  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
3006  if (parser.parseKeyword("to") ||
3007  parser.parseOperandList(upper, ivs.size(),
3008  OpAsmParser::Delimiter::Paren) ||
3009  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
3010  return failure();
3011 
3012  // Parse step values.
3013  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
3014  if (parser.parseKeyword("step") ||
3015  parser.parseOperandList(steps, ivs.size(),
3016  OpAsmParser::Delimiter::Paren) ||
3017  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
3018  return failure();
3019 
3020  // Parse init values.
3021  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
3022  if (succeeded(parser.parseOptionalKeyword("init"))) {
3023  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
3024  return failure();
3025  }
3026 
3027  // Parse optional results in case there is a reduce.
3028  if (parser.parseOptionalArrowTypeList(result.types))
3029  return failure();
3030 
3031  // Now parse the body.
3032  Region *body = result.addRegion();
3033  for (auto &iv : ivs)
3034  iv.type = builder.getIndexType();
3035  if (parser.parseRegion(*body, ivs))
3036  return failure();
3037 
3038  // Set `operandSegmentSizes` attribute.
3039  result.addAttribute(
3040  ParallelOp::getOperandSegmentSizeAttr(),
3041  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3042  static_cast<int32_t>(upper.size()),
3043  static_cast<int32_t>(steps.size()),
3044  static_cast<int32_t>(initVals.size())}));
3045 
3046  // Parse attributes.
3047  if (parser.parseOptionalAttrDict(result.attributes) ||
3048  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3049  result.operands))
3050  return failure();
3051 
3052  // Add a terminator if none was parsed.
3053  ParallelOp::ensureTerminator(*body, builder, result.location);
3054  return success();
3055 }
3056 
3058  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3059  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3060  if (!getInitVals().empty())
3061  p << " init (" << getInitVals() << ")";
3062  p.printOptionalArrowTypeList(getResultTypes());
3063  p << ' ';
3064  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3065  p.printOptionalAttrDictWithKeyword(
3066  (*this)->getAttrs(),
3067  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3068 }
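// NOTE: Illustrative example added for readability; it is not part of the
// original source. A minimal sketch of the textual form produced and consumed
// by the parse/print methods above (the value names %buf, %zero and the index
// constants %c0/%c1/%c8 are hypothetical):
//
//   %sum = scf.parallel (%i, %j) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1)
//            init (%zero) -> f32 {
//     %elem = memref.load %buf[%i, %j] : memref<8x8xf32>
//     scf.reduce(%elem : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %acc = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %acc : f32
//     }
//   }
//
// The `operandSegmentSizes` attribute set during parsing records how the flat
// operand list splits into lower bounds, upper bounds, steps and init values.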
3069 
3070 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3071 
3072 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3073  return SmallVector<Value>{getBody()->getArguments()};
3074 }
3075 
3076 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3077  return getLowerBound();
3078 }
3079 
3080 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3081  return getUpperBound();
3082 }
3083 
3084 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3085  return getStep();
3086 }
3087 
3089  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3090  if (!ivArg)
3091  return ParallelOp();
3092  assert(ivArg.getOwner() && "unlinked block argument");
3093  auto *containingOp = ivArg.getOwner()->getParentOp();
3094  return dyn_cast<ParallelOp>(containingOp);
3095 }
3096 
3097 namespace {
3098 // Collapse loop dimensions that perform a single iteration.
3099 struct ParallelOpSingleOrZeroIterationDimsFolder
3100  : public OpRewritePattern<ParallelOp> {
3101  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3102 
3103  LogicalResult matchAndRewrite(ParallelOp op,
3104  PatternRewriter &rewriter) const override {
3105  Location loc = op.getLoc();
3106 
3107  // Compute new loop bounds that omit all single-iteration loop dimensions.
3108  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3109  IRMapping mapping;
3110  for (auto [lb, ub, step, iv] :
3111  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3112  op.getInductionVars())) {
3113  auto numIterations = constantTripCount(lb, ub, step);
3114  if (numIterations.has_value()) {
3115  // Remove the loop if it performs zero iterations.
3116  if (*numIterations == 0) {
3117  rewriter.replaceOp(op, op.getInitVals());
3118  return success();
3119  }
3120  // Replace the loop induction variable by the lower bound if the loop
3121  // performs a single iteration. Otherwise, copy the loop bounds.
3122  if (*numIterations == 1) {
3123  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3124  continue;
3125  }
3126  }
3127  newLowerBounds.push_back(lb);
3128  newUpperBounds.push_back(ub);
3129  newSteps.push_back(step);
3130  }
3131  // Exit if none of the loop dimensions perform a single iteration.
3132  if (newLowerBounds.size() == op.getLowerBound().size())
3133  return failure();
3134 
3135  if (newLowerBounds.empty()) {
3136  // All of the loop dimensions perform a single iteration. Inline the
3137  // loop body and the nested ReduceOps.
3138  SmallVector<Value> results;
3139  results.reserve(op.getInitVals().size());
3140  for (auto &bodyOp : op.getBody()->without_terminator())
3141  rewriter.clone(bodyOp, mapping);
3142  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3143  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3144  Block &reduceBlock = reduceOp.getReductions()[i].front();
3145  auto initValIndex = results.size();
3146  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3147  mapping.map(reduceBlock.getArgument(1),
3148  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3149  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3150  rewriter.clone(reduceBodyOp, mapping);
3151 
3152  auto result = mapping.lookupOrDefault(
3153  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3154  results.push_back(result);
3155  }
3156 
3157  rewriter.replaceOp(op, results);
3158  return success();
3159  }
3160  // Replace the parallel loop with a lower-dimensional parallel loop.
3161  auto newOp =
3162  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3163  newSteps, op.getInitVals(), nullptr);
3164  // Erase the empty block that was inserted by the builder.
3165  rewriter.eraseBlock(newOp.getBody());
3166  // Clone the loop body and remap the block arguments of the collapsed loops
3167  // (inlining does not support a cancellable block argument mapping).
3168  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3169  newOp.getRegion().begin(), mapping);
3170  rewriter.replaceOp(op, newOp.getResults());
3171  return success();
3172  }
3173 };
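// NOTE: Illustrative sketch added for readability; not part of the original
// source. Assuming %c0/%c1/%c4 are index constants, the %i dimension below
// runs exactly once, so the pattern drops it and replaces %i by its lower
// bound ("test.use" is a hypothetical payload op):
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c4) step (%c1, %c1) {
//     "test.use"(%i, %j) : (index, index) -> ()
//   }
//
// becomes
//
//   scf.parallel (%j) = (%c0) to (%c4) step (%c1) {
//     "test.use"(%c0, %j) : (index, index) -> ()
//   }
//
// If every dimension runs at most once, the body (and any reductions) is
// inlined instead, and a zero-trip loop is replaced by its init values.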
3174 
3175 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3176  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3177 
3178  LogicalResult matchAndRewrite(ParallelOp op,
3179  PatternRewriter &rewriter) const override {
3180  Block &outerBody = *op.getBody();
3181  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3182  return failure();
3183 
3184  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3185  if (!innerOp)
3186  return failure();
3187 
3188  for (auto val : outerBody.getArguments())
3189  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3190  llvm::is_contained(innerOp.getUpperBound(), val) ||
3191  llvm::is_contained(innerOp.getStep(), val))
3192  return failure();
3193 
3194  // Reductions are not supported yet.
3195  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3196  return failure();
3197 
3198  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3199  ValueRange iterVals, ValueRange) {
3200  Block &innerBody = *innerOp.getBody();
3201  assert(iterVals.size() ==
3202  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3203  IRMapping mapping;
3204  mapping.map(outerBody.getArguments(),
3205  iterVals.take_front(outerBody.getNumArguments()));
3206  mapping.map(innerBody.getArguments(),
3207  iterVals.take_back(innerBody.getNumArguments()));
3208  for (Operation &op : innerBody.without_terminator())
3209  builder.clone(op, mapping);
3210  };
3211 
3212  auto concatValues = [](const auto &first, const auto &second) {
3213  SmallVector<Value> ret;
3214  ret.reserve(first.size() + second.size());
3215  ret.assign(first.begin(), first.end());
3216  ret.append(second.begin(), second.end());
3217  return ret;
3218  };
3219 
3220  auto newLowerBounds =
3221  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3222  auto newUpperBounds =
3223  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3224  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3225 
3226  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3227  newSteps, ValueRange(),
3228  bodyBuilder);
3229  return success();
3230  }
3231 };
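// NOTE: Illustrative sketch added for readability; not part of the original
// source. When the outer body contains nothing but a nested scf.parallel
// whose bounds do not use the outer induction variables (and neither loop has
// reductions), the two loops are merged:
//
//   scf.parallel (%i) = (%c0) to (%c8) step (%c1) {
//     scf.parallel (%j) = (%c0) to (%c4) step (%c1) {
//       "test.use"(%i, %j) : (index, index) -> ()
//     }
//   }
//
// becomes
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c8, %c4) step (%c1, %c1) {
//     "test.use"(%i, %j) : (index, index) -> ()
//   }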
3232 
3233 } // namespace
3234 
3235 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3236  MLIRContext *context) {
3237  results
3238  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3239  context);
3240 }
3241 
3242 /// Populate `regions` with the successor regions of this op: starting either
3243 /// from the operation itself or from its body region, control may flow into
3244 /// the body or back out of the loop.
3247 void ParallelOp::getSuccessorRegions(
3248  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3249  // Both the operation itself and the region may branch into the body or back
3250  // into the operation itself. It is possible for the loop not to enter the
3251  // body.
3252  regions.push_back(RegionSuccessor(&getRegion()));
3253  regions.push_back(RegionSuccessor());
3254 }
3255 
3256 //===----------------------------------------------------------------------===//
3257 // ReduceOp
3258 //===----------------------------------------------------------------------===//
3259 
3260 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3261 
3262 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3263  ValueRange operands) {
3264  result.addOperands(operands);
3265  for (Value v : operands) {
3266  OpBuilder::InsertionGuard guard(builder);
3267  Region *bodyRegion = result.addRegion();
3268  builder.createBlock(bodyRegion, {},
3269  ArrayRef<Type>{v.getType(), v.getType()},
3270  {result.location, result.location});
3271  }
3272 }
3273 
3274 LogicalResult ReduceOp::verifyRegions() {
3275  // The region of a ReduceOp has two arguments of the same type as its
3276  // corresponding operand.
3277  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3278  auto type = getOperands()[i].getType();
3279  Block &block = getReductions()[i].front();
3280  if (block.empty())
3281  return emitOpError() << i << "-th reduction has an empty body";
3282  if (block.getNumArguments() != 2 ||
3283  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3284  return arg.getType() != type;
3285  }))
3286  return emitOpError() << "expected two block arguments with type " << type
3287  << " in the " << i << "-th reduction region";
3288 
3289  // Check that the block is terminated by a ReduceReturnOp.
3290  if (!isa<ReduceReturnOp>(block.getTerminator()))
3291  return emitOpError("reduction bodies must be terminated with an "
3292  "'scf.reduce.return' op");
3293  }
3294 
3295  return success();
3296 }
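// NOTE: Illustrative example added for readability; not part of the original
// source. A well-formed reduction region for an f32 operand takes two f32
// block arguments and terminates with scf.reduce.return:
//
//   scf.reduce(%operand : f32) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %0 = arith.maximumf %lhs, %rhs : f32
//     scf.reduce.return %0 : f32
//   }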
3297 
3298 MutableOperandRange
3299 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3300  // No operands are forwarded to the next iteration.
3301  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3302 }
3303 
3304 //===----------------------------------------------------------------------===//
3305 // ReduceReturnOp
3306 //===----------------------------------------------------------------------===//
3307 
3308 LogicalResult ReduceReturnOp::verify() {
3309  // The type of the return value should be the same type as the types of the
3310  // block arguments of the reduction body.
3311  Block *reductionBody = getOperation()->getBlock();
3312  // Should already be verified by an op trait.
3313  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3314  Type expectedResultType = reductionBody->getArgument(0).getType();
3315  if (expectedResultType != getResult().getType())
3316  return emitOpError() << "must have type " << expectedResultType
3317  << " (the type of the reduction inputs)";
3318  return success();
3319 }
3320 
3321 //===----------------------------------------------------------------------===//
3322 // WhileOp
3323 //===----------------------------------------------------------------------===//
3324 
3325 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3326  ::mlir::OperationState &odsState, TypeRange resultTypes,
3327  ValueRange inits, BodyBuilderFn beforeBuilder,
3328  BodyBuilderFn afterBuilder) {
3329  odsState.addOperands(inits);
3330  odsState.addTypes(resultTypes);
3331 
3332  OpBuilder::InsertionGuard guard(odsBuilder);
3333 
3334  // Build before region.
3335  SmallVector<Location, 4> beforeArgLocs;
3336  beforeArgLocs.reserve(inits.size());
3337  for (Value operand : inits) {
3338  beforeArgLocs.push_back(operand.getLoc());
3339  }
3340 
3341  Region *beforeRegion = odsState.addRegion();
3342  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3343  inits.getTypes(), beforeArgLocs);
3344  if (beforeBuilder)
3345  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3346 
3347  // Build after region.
3348  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3349 
3350  Region *afterRegion = odsState.addRegion();
3351  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3352  resultTypes, afterArgLocs);
3353 
3354  if (afterBuilder)
3355  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3356 }
3357 
3358 ConditionOp WhileOp::getConditionOp() {
3359  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3360 }
3361 
3362 YieldOp WhileOp::getYieldOp() {
3363  return cast<YieldOp>(getAfterBody()->getTerminator());
3364 }
3365 
3366 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3367  return getYieldOp().getResultsMutable();
3368 }
3369 
3370 Block::BlockArgListType WhileOp::getBeforeArguments() {
3371  return getBeforeBody()->getArguments();
3372 }
3373 
3374 Block::BlockArgListType WhileOp::getAfterArguments() {
3375  return getAfterBody()->getArguments();
3376 }
3377 
3378 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3379  return getBeforeArguments();
3380 }
3381 
3382 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3383  assert(point == getBefore() &&
3384  "WhileOp is expected to branch only to the first region");
3385  return getInits();
3386 }
3387 
3388 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3389  SmallVectorImpl<RegionSuccessor> &regions) {
3390  // The parent op always branches to the condition region.
3391  if (point.isParent()) {
3392  regions.emplace_back(&getBefore(), getBefore().getArguments());
3393  return;
3394  }
3395 
3396  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3397  "there are only two regions in a WhileOp");
3398  // The body region always branches back to the condition region.
3399  if (point == getAfter()) {
3400  regions.emplace_back(&getBefore(), getBefore().getArguments());
3401  return;
3402  }
3403 
3404  regions.emplace_back(getResults());
3405  regions.emplace_back(&getAfter(), getAfter().getArguments());
3406 }
3407 
3408 SmallVector<Region *> WhileOp::getLoopRegions() {
3409  return {&getBefore(), &getAfter()};
3410 }
3411 
3412 /// Parses a `while` op.
3413 ///
3414 /// op ::= `scf.while` assignments `:` function-type region `do` region
3415 /// `attributes` attribute-dict
3416 /// initializer ::= /* empty */ | `(` assignment-list `)`
3417 /// assignment-list ::= assignment | assignment `,` assignment-list
3418 /// assignment ::= ssa-value `=` ssa-value
3419 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3420  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3421  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3422  Region *before = result.addRegion();
3423  Region *after = result.addRegion();
3424 
3425  OptionalParseResult listResult =
3426  parser.parseOptionalAssignmentList(regionArgs, operands);
3427  if (listResult.has_value() && failed(listResult.value()))
3428  return failure();
3429 
3430  FunctionType functionType;
3431  SMLoc typeLoc = parser.getCurrentLocation();
3432  if (failed(parser.parseColonType(functionType)))
3433  return failure();
3434 
3435  result.addTypes(functionType.getResults());
3436 
3437  if (functionType.getNumInputs() != operands.size()) {
3438  return parser.emitError(typeLoc)
3439  << "expected as many input types as operands "
3440  << "(expected " << operands.size() << " got "
3441  << functionType.getNumInputs() << ")";
3442  }
3443 
3444  // Resolve input operands.
3445  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3446  parser.getCurrentLocation(),
3447  result.operands)))
3448  return failure();
3449 
3450  // Propagate the types into the region arguments.
3451  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3452  regionArgs[i].type = functionType.getInput(i);
3453 
3454  return failure(parser.parseRegion(*before, regionArgs) ||
3455  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3456  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3457 }
3458 
3459 /// Prints a `while` op.
3460 void scf::WhileOp::print(OpAsmPrinter &p) {
3461  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3462  p << " : ";
3463  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3464  p << ' ';
3465  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3466  p << " do ";
3467  p.printRegion(getAfter());
3468  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3469 }
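// NOTE: Illustrative example added for readability; not part of the original
// source. A minimal sketch of the textual form matching the grammar above
// (the "test.*" operations and %init are hypothetical):
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.condition"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.payload"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }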
3470 
3471 /// Verifies that two ranges of types match, i.e. have the same number of
3472 /// entries and that types are pairwise equals. Reports errors on the given
3473 /// operation in case of mismatch.
3474 template <typename OpTy>
3475 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3476  TypeRange right, StringRef message) {
3477  if (left.size() != right.size())
3478  return op.emitOpError("expects the same number of ") << message;
3479 
3480  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3481  if (left[i] != right[i]) {
3482  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3483  << message;
3484  diag.attachNote() << "for argument " << i << ", found " << left[i]
3485  << " and " << right[i];
3486  return diag;
3487  }
3488  }
3489 
3490  return success();
3491 }
3492 
3493 LogicalResult scf::WhileOp::verify() {
3494  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3495  *this, getBefore(),
3496  "expects the 'before' region to terminate with 'scf.condition'");
3497  if (!beforeTerminator)
3498  return failure();
3499 
3500  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3501  *this, getAfter(),
3502  "expects the 'after' region to terminate with 'scf.yield'");
3503  return success(afterTerminator != nullptr);
3504 }
3505 
3506 namespace {
3507 /// Replace uses of the condition within the do block with true, since otherwise
3508 /// the block would not be evaluated.
3509 ///
3510 /// scf.while (..) : (i1, ...) -> ... {
3511 /// %condition = call @evaluate_condition() : () -> i1
3512 /// scf.condition(%condition) %condition : i1, ...
3513 /// } do {
3514 /// ^bb0(%arg0: i1, ...):
3515 /// use(%arg0)
3516 /// ...
3517 ///
3518 /// becomes
3519 /// scf.while (..) : (i1, ...) -> ... {
3520 /// %condition = call @evaluate_condition() : () -> i1
3521 /// scf.condition(%condition) %condition : i1, ...
3522 /// } do {
3523 /// ^bb0(%arg0: i1, ...):
3524 /// use(%true)
3525 /// ...
3526 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3527  using OpRewritePattern<WhileOp>::OpRewritePattern;
3528 
3529  LogicalResult matchAndRewrite(WhileOp op,
3530  PatternRewriter &rewriter) const override {
3531  auto term = op.getConditionOp();
3532 
3533  // These variables serve to prevent creating duplicate constants
3534  // and hold constant true or false values.
3535  Value constantTrue = nullptr;
3536 
3537  bool replaced = false;
3538  for (auto yieldedAndBlockArgs :
3539  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3540  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3541  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3542  if (!constantTrue)
3543  constantTrue = rewriter.create<arith::ConstantOp>(
3544  op.getLoc(), term.getCondition().getType(),
3545  rewriter.getBoolAttr(true));
3546 
3547  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3548  constantTrue);
3549  replaced = true;
3550  }
3551  }
3552  }
3553  return success(replaced);
3554  }
3555 };
3556 
3557 /// Remove loop invariant arguments from `before` block of scf.while.
3558 /// A before block argument is considered loop invariant if:
3559 /// 1. i-th yield operand is equal to the i-th while operand.
3560 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3561 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3562 /// iter argument/while operand.
3563 /// For the arguments which are removed, their uses inside scf.while
3564 /// are replaced with their corresponding initial value.
3565 ///
3566 /// Eg:
3567 /// INPUT :-
3568 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3569 /// ..., %argN_before = %N)
3570 /// {
3571 /// ...
3572 /// scf.condition(%cond) %arg1_before, %arg0_before,
3573 /// %arg2_before, %arg0_before, ...
3574 /// } do {
3575 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3576 /// ..., %argK_after):
3577 /// ...
3578 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3579 /// }
3580 ///
3581 /// OUTPUT :-
3582 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3583 /// %N)
3584 /// {
3585 /// ...
3586 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3587 /// } do {
3588 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3589 /// ..., %argK_after):
3590 /// ...
3591 /// scf.yield %arg1_after, ..., %argN
3592 /// }
3593 ///
3594 /// EXPLANATION:
3595 /// We iterate over each yield operand.
3596 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3597 /// %arg0_before, which in turn is the 0-th iter argument. So we
3598 /// remove 0-th before block argument and yield operand, and replace
3599 /// all uses of the 0-th before block argument with its initial value
3600 /// %a.
3601 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3602 /// value. So we remove this operand and the corresponding before
3603 /// block argument and replace all uses of 1-th before block argument
3604 /// with %b.
3605 struct RemoveLoopInvariantArgsFromBeforeBlock
3606  : public OpRewritePattern<WhileOp> {
3607  using OpRewritePattern<WhileOp>::OpRewritePattern;
3608 
3609  LogicalResult matchAndRewrite(WhileOp op,
3610  PatternRewriter &rewriter) const override {
3611  Block &afterBlock = *op.getAfterBody();
3612  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3613  ConditionOp condOp = op.getConditionOp();
3614  OperandRange condOpArgs = condOp.getArgs();
3615  Operation *yieldOp = afterBlock.getTerminator();
3616  ValueRange yieldOpArgs = yieldOp->getOperands();
3617 
3618  bool canSimplify = false;
3619  for (const auto &it :
3620  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3621  auto index = static_cast<unsigned>(it.index());
3622  auto [initVal, yieldOpArg] = it.value();
3623  // If i-th yield operand is equal to the i-th operand of the scf.while,
3624  // the i-th before block argument is a loop invariant.
3625  if (yieldOpArg == initVal) {
3626  canSimplify = true;
3627  break;
3628  }
3629  // If the i-th yield operand is k-th after block argument, then we check
3630  // if the (k+1)-th condition op operand is equal to either the i-th before
3631  // block argument or the initial value of the i-th before block argument. If
3632  // the comparison yields `true`, the i-th before block argument is a loop
3633  // invariant.
3634  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3635  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3636  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3637  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3638  canSimplify = true;
3639  break;
3640  }
3641  }
3642  }
3643 
3644  if (!canSimplify)
3645  return failure();
3646 
3647  SmallVector<Value> newInitArgs, newYieldOpArgs;
3648  DenseMap<unsigned, Value> beforeBlockInitValMap;
3649  SmallVector<Location> newBeforeBlockArgLocs;
3650  for (const auto &it :
3651  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3652  auto index = static_cast<unsigned>(it.index());
3653  auto [initVal, yieldOpArg] = it.value();
3654 
3655  // If i-th yield operand is equal to the i-th operand of the scf.while,
3656  // the i-th before block argument is a loop invariant.
3657  if (yieldOpArg == initVal) {
3658  beforeBlockInitValMap.insert({index, initVal});
3659  continue;
3660  } else {
3661  // If the i-th yield operand is k-th after block argument, then we check
3662  // if the (k+1)-th condition op operand is equal to either the i-th
3663  // before block argument or the initial value of the i-th before block
3664  // argument. If the comparison yields `true`, the i-th before block
3665  // argument is a loop invariant.
3666  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3667  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3668  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3669  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3670  beforeBlockInitValMap.insert({index, initVal});
3671  continue;
3672  }
3673  }
3674  }
3675  newInitArgs.emplace_back(initVal);
3676  newYieldOpArgs.emplace_back(yieldOpArg);
3677  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3678  }
3679 
3680  {
3681  OpBuilder::InsertionGuard g(rewriter);
3682  rewriter.setInsertionPoint(yieldOp);
3683  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3684  }
3685 
3686  auto newWhile =
3687  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3688 
3689  Block &newBeforeBlock = *rewriter.createBlock(
3690  &newWhile.getBefore(), /*insertPt*/ {},
3691  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3692 
3693  Block &beforeBlock = *op.getBeforeBody();
3694  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3695  // For each i-th before block argument we find its replacement value as:
3696  // 1. If the i-th before block argument is a loop invariant, we fetch its
3697  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3698  // 2. Else we fetch the j-th new before block argument as the replacement
3699  // value of the i-th before block argument.
3700  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3701  // If the index 'i' argument was a loop invariant, we fetch its initial
3702  // value from `beforeBlockInitValMap`.
3703  if (beforeBlockInitValMap.count(i) != 0)
3704  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3705  else
3706  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3707  }
3708 
3709  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3710  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3711  newWhile.getAfter().begin());
3712 
3713  rewriter.replaceOp(op, newWhile.getResults());
3714  return success();
3715  }
3716 };
3717 
3718 /// Remove loop invariant value from result (condition op) of scf.while.
3719 /// A value is considered loop invariant if the final value yielded by
3720 /// scf.condition is defined outside of the `before` block. We remove the
3721 /// corresponding argument in `after` block and replace the use with the value.
3722 /// We also replace the use of the corresponding result of scf.while with the
3723 /// value.
3724 ///
3725 /// Eg:
3726 /// INPUT :-
3727 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3728 /// %argN_before = %N) {
3729 /// ...
3730 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3731 /// } do {
3732 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3733 /// ...
3734 /// some_func(%arg1_after)
3735 /// ...
3736 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3737 /// }
3738 ///
3739 /// OUTPUT :-
3740 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3741 /// ...
3742 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3743 /// } do {
3744 /// ^bb0(%arg0, %arg3, ..., %argM):
3745 /// ...
3746 /// some_func(%a)
3747 /// ...
3748 /// scf.yield %arg0, %b, ..., %argN
3749 /// }
3750 ///
3751 /// EXPLANATION:
3752 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3753 /// before block of scf.while, so they get removed.
3754 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3755 /// replaced by %b.
3756 /// 3. The corresponding after block argument %arg1_after's uses are
3757 /// replaced by %a and %arg2_after's uses are replaced by %b.
3758 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3759  using OpRewritePattern<WhileOp>::OpRewritePattern;
3760 
3761  LogicalResult matchAndRewrite(WhileOp op,
3762  PatternRewriter &rewriter) const override {
3763  Block &beforeBlock = *op.getBeforeBody();
3764  ConditionOp condOp = op.getConditionOp();
3765  OperandRange condOpArgs = condOp.getArgs();
3766 
3767  bool canSimplify = false;
3768  for (Value condOpArg : condOpArgs) {
3769  // Values not defined within the `before` block are considered loop
3770  // invariant; finding one is enough to know that the op can be
3771  // simplified.
3772  if (condOpArg.getParentBlock() != &beforeBlock) {
3773  canSimplify = true;
3774  break;
3775  }
3776  }
3777 
3778  if (!canSimplify)
3779  return failure();
3780 
3781  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3782 
3783  SmallVector<Value> newCondOpArgs;
3784  SmallVector<Type> newAfterBlockType;
3785  DenseMap<unsigned, Value> condOpInitValMap;
3786  SmallVector<Location> newAfterBlockArgLocs;
3787  for (const auto &it : llvm::enumerate(condOpArgs)) {
3788  auto index = static_cast<unsigned>(it.index());
3789  Value condOpArg = it.value();
3790  // Values not defined within the `before` block are considered loop
3791  // invariant; map the corresponding `index` to the value. The remaining
3792  // values stay forwarded through the new condition op.
3793  if (condOpArg.getParentBlock() != &beforeBlock) {
3794  condOpInitValMap.insert({index, condOpArg});
3795  } else {
3796  newCondOpArgs.emplace_back(condOpArg);
3797  newAfterBlockType.emplace_back(condOpArg.getType());
3798  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3799  }
3800  }
3801 
3802  {
3803  OpBuilder::InsertionGuard g(rewriter);
3804  rewriter.setInsertionPoint(condOp);
3805  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3806  newCondOpArgs);
3807  }
3808 
3809  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3810  op.getOperands());
3811 
3812  Block &newAfterBlock =
3813  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3814  newAfterBlockType, newAfterBlockArgLocs);
3815 
3816  Block &afterBlock = *op.getAfterBody();
3817  // Since a new scf.condition op was created, we need to fetch the new
3818  // `after` block arguments which will be used while replacing operations of
3819  // the previous scf.while's `after` block. We also fetch the new result
3820  // values.
3821  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3822  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3823  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3824  Value afterBlockArg, result;
3825  // If the index 'i' argument was loop invariant, we fetch its value from
3826  // the `condOpInitValMap` map.
3827  if (condOpInitValMap.count(i) != 0) {
3828  afterBlockArg = condOpInitValMap[i];
3829  result = afterBlockArg;
3830  } else {
3831  afterBlockArg = newAfterBlock.getArgument(j);
3832  result = newWhile.getResult(j);
3833  j++;
3834  }
3835  newAfterBlockArgs[i] = afterBlockArg;
3836  newWhileResults[i] = result;
3837  }
3838 
3839  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3840  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3841  newWhile.getBefore().begin());
3842 
3843  rewriter.replaceOp(op, newWhileResults);
3844  return success();
3845  }
3846 };
3847 
3848 /// Remove WhileOp results that are also unused in 'after' block.
3849 ///
3850 /// %0:2 = scf.while () : () -> (i32, i64) {
3851 /// %condition = "test.condition"() : () -> i1
3852 /// %v1 = "test.get_some_value"() : () -> i32
3853 /// %v2 = "test.get_some_value"() : () -> i64
3854 /// scf.condition(%condition) %v1, %v2 : i32, i64
3855 /// } do {
3856 /// ^bb0(%arg0: i32, %arg1: i64):
3857 /// "test.use"(%arg0) : (i32) -> ()
3858 /// scf.yield
3859 /// }
3860 /// return %0#0 : i32
3861 ///
3862 /// becomes
3863 /// %0 = scf.while () : () -> (i32) {
3864 /// %condition = "test.condition"() : () -> i1
3865 /// %v1 = "test.get_some_value"() : () -> i32
3866 /// %v2 = "test.get_some_value"() : () -> i64
3867 /// scf.condition(%condition) %v1 : i32
3868 /// } do {
3869 /// ^bb0(%arg0: i32):
3870 /// "test.use"(%arg0) : (i32) -> ()
3871 /// scf.yield
3872 /// }
3873 /// return %0 : i32
3874 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3875  using OpRewritePattern<WhileOp>::OpRewritePattern;
3876 
3877  LogicalResult matchAndRewrite(WhileOp op,
3878  PatternRewriter &rewriter) const override {
3879  auto term = op.getConditionOp();
3880  auto afterArgs = op.getAfterArguments();
3881  auto termArgs = term.getArgs();
3882 
3883  // Collect results mapping, new terminator args and new result types.
3884  SmallVector<unsigned> newResultsIndices;
3885  SmallVector<Type> newResultTypes;
3886  SmallVector<Value> newTermArgs;
3887  SmallVector<Location> newArgLocs;
3888  bool needUpdate = false;
3889  for (const auto &it :
3890  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3891  auto i = static_cast<unsigned>(it.index());
3892  Value result = std::get<0>(it.value());
3893  Value afterArg = std::get<1>(it.value());
3894  Value termArg = std::get<2>(it.value());
3895  if (result.use_empty() && afterArg.use_empty()) {
3896  needUpdate = true;
3897  } else {
3898  newResultsIndices.emplace_back(i);
3899  newTermArgs.emplace_back(termArg);
3900  newResultTypes.emplace_back(result.getType());
3901  newArgLocs.emplace_back(result.getLoc());
3902  }
3903  }
3904 
3905  if (!needUpdate)
3906  return failure();
3907 
3908  {
3909  OpBuilder::InsertionGuard g(rewriter);
3910  rewriter.setInsertionPoint(term);
3911  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3912  newTermArgs);
3913  }
3914 
3915  auto newWhile =
3916  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3917 
3918  Block &newAfterBlock = *rewriter.createBlock(
3919  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3920 
3921  // Build new results list and new after block args (unused entries will be
3922  // null).
3923  SmallVector<Value> newResults(op.getNumResults());
3924  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3925  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3926  newResults[it.value()] = newWhile.getResult(it.index());
3927  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3928  }
3929 
3930  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3931  newWhile.getBefore().begin());
3932 
3933  Block &afterBlock = *op.getAfterBody();
3934  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3935 
3936  rewriter.replaceOp(op, newResults);
3937  return success();
3938  }
3939 };
3940 
3941 /// Replace operations equivalent to the condition in the do block with true,
3942 /// since otherwise the block would not be evaluated.
3943 ///
3944 /// scf.while (..) : (i32, ...) -> ... {
3945 /// %z = ... : i32
3946 /// %condition = cmpi pred %z, %a
3947 /// scf.condition(%condition) %z : i32, ...
3948 /// } do {
3949 /// ^bb0(%arg0: i32, ...):
3950 /// %condition2 = cmpi pred %arg0, %a
3951 /// use(%condition2)
3952 /// ...
3953 ///
3954 /// becomes
3955 /// scf.while (..) : (i32, ...) -> ... {
3956 /// %z = ... : i32
3957 /// %condition = cmpi pred %z, %a
3958 /// scf.condition(%condition) %z : i32, ...
3959 /// } do {
3960 /// ^bb0(%arg0: i32, ...):
3961 /// use(%true)
3962 /// ...
3963 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3964  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3965 
3966  LogicalResult matchAndRewrite(scf::WhileOp op,
3967  PatternRewriter &rewriter) const override {
3968  using namespace scf;
3969  auto cond = op.getConditionOp();
3970  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3971  if (!cmp)
3972  return failure();
3973  bool changed = false;
3974  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3975  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3976  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3977  continue;
3978  for (OpOperand &u :
3979  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3980  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3981  if (!cmp2)
3982  continue;
3983  // For a binary operator 1-opIdx gets the other side.
3984  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3985  continue;
3986  bool samePredicate;
3987  if (cmp2.getPredicate() == cmp.getPredicate())
3988  samePredicate = true;
3989  else if (cmp2.getPredicate() ==
3990  arith::invertPredicate(cmp.getPredicate()))
3991  samePredicate = false;
3992  else
3993  continue;
3994 
3995  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3996  1);
3997  changed = true;
3998  }
3999  }
4000  }
4001  return success(changed);
4002  }
4003 };
4004 
4005 /// Remove unused init/yield args.
4006 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
4007  using OpRewritePattern<WhileOp>::OpRewritePattern;
4008 
4009  LogicalResult matchAndRewrite(WhileOp op,
4010  PatternRewriter &rewriter) const override {
4011 
4012  if (!llvm::any_of(op.getBeforeArguments(),
4013  [](Value arg) { return arg.use_empty(); }))
4014  return rewriter.notifyMatchFailure(op, "No args to remove");
4015 
4016  YieldOp yield = op.getYieldOp();
4017 
4018  // Collect results mapping, new terminator args and new result types.
4019  SmallVector<Value> newYields;
4020  SmallVector<Value> newInits;
4021  llvm::BitVector argsToErase;
4022 
4023  size_t argsCount = op.getBeforeArguments().size();
4024  newYields.reserve(argsCount);
4025  newInits.reserve(argsCount);
4026  argsToErase.reserve(argsCount);
4027  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4028  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4029  if (beforeArg.use_empty()) {
4030  argsToErase.push_back(true);
4031  } else {
4032  argsToErase.push_back(false);
4033  newYields.emplace_back(yieldValue);
4034  newInits.emplace_back(initValue);
4035  }
4036  }
4037 
4038  Block &beforeBlock = *op.getBeforeBody();
4039  Block &afterBlock = *op.getAfterBody();
4040 
4041  beforeBlock.eraseArguments(argsToErase);
4042 
4043  Location loc = op.getLoc();
4044  auto newWhileOp =
4045  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4046  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4047  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4048  Block &newAfterBlock = *newWhileOp.getAfterBody();
4049 
4050  OpBuilder::InsertionGuard g(rewriter);
4051  rewriter.setInsertionPoint(yield);
4052  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4053 
4054  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4055  newBeforeBlock.getArguments());
4056  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4057  newAfterBlock.getArguments());
4058 
4059  rewriter.replaceOp(op, newWhileOp.getResults());
4060  return success();
4061  }
4062 };
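// NOTE: Illustrative sketch added for readability; not part of the original
// source. %arg1 below has no uses in the `before` region, so it is removed
// together with its init value and yield operand ("test.*" ops are
// hypothetical):
//
//   %0 = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i32 {
//     %cond = "test.condition"(%arg0) : (i32) -> i1
//     scf.condition(%cond) %arg0 : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.next"(%arg2) : (i32) -> i32
//     %other = "test.other"() : () -> i64
//     scf.yield %next, %other : i32, i64
//   }
//
// becomes a loop with a single iter_arg:
//
//   %0 = scf.while (%arg0 = %a) : (i32) -> i32 { ... scf.yield %next : i32 }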
4063 
4064 /// Remove duplicated ConditionOp args.
4065 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4066  using OpRewritePattern<WhileOp>::OpRewritePattern;
4067 
4068  LogicalResult matchAndRewrite(WhileOp op,
4069  PatternRewriter &rewriter) const override {
4070  ConditionOp condOp = op.getConditionOp();
4071  ValueRange condOpArgs = condOp.getArgs();
4072 
4073  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4074 
4075  if (argsSet.size() == condOpArgs.size())
4076  return rewriter.notifyMatchFailure(op, "No results to remove");
4077 
4078  llvm::SmallDenseMap<Value, unsigned> argsMap;
4079  SmallVector<Value> newArgs;
4080  argsMap.reserve(condOpArgs.size());
4081  newArgs.reserve(condOpArgs.size());
4082  for (Value arg : condOpArgs) {
4083  if (!argsMap.count(arg)) {
4084  auto pos = static_cast<unsigned>(argsMap.size());
4085  argsMap.insert({arg, pos});
4086  newArgs.emplace_back(arg);
4087  }
4088  }
4089 
4090  ValueRange argsRange(newArgs);
4091 
4092  Location loc = op.getLoc();
4093  auto newWhileOp = rewriter.create<scf::WhileOp>(
4094  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4095  /*afterBody*/ nullptr);
4096  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4097  Block &newAfterBlock = *newWhileOp.getAfterBody();
4098 
4099  SmallVector<Value> afterArgsMapping;
4100  SmallVector<Value> resultsMapping;
4101  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4102  auto it = argsMap.find(arg);
4103  assert(it != argsMap.end());
4104  auto pos = it->second;
4105  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4106  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4107  }
4108 
4109  OpBuilder::InsertionGuard g(rewriter);
4110  rewriter.setInsertionPoint(condOp);
4111  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4112  argsRange);
4113 
4114  Block &beforeBlock = *op.getBeforeBody();
4115  Block &afterBlock = *op.getAfterBody();
4116 
4117  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4118  newBeforeBlock.getArguments());
4119  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4120  rewriter.replaceOp(op, resultsMapping);
4121  return success();
4122  }
4123 };
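// NOTE: Illustrative sketch added for readability; not part of the original
// source. If the `before` region forwards the same value more than once, e.g.
//
//   scf.condition(%cond) %v, %v : i32, i32
//
// the loop is rebuilt to forward it only once,
//
//   scf.condition(%cond) %v : i32
//
// and all original results / `after` block arguments are remapped to the
// single remaining result / argument.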
4124 
4125 /// If both ranges contain the same values, return mapping indices from args2 to
4126 /// args1. Otherwise return std::nullopt.
4127 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4128  ValueRange args2) {
4129  if (args1.size() != args2.size())
4130  return std::nullopt;
4131 
4132  SmallVector<unsigned> ret(args1.size());
4133  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4134  auto it = llvm::find(args2, arg1);
4135  if (it == args2.end())
4136  return std::nullopt;
4137 
4138  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4139  }
4140 
4141  return ret;
4142 }
4143 
4144 static bool hasDuplicates(ValueRange args) {
4145  llvm::SmallDenseSet<Value> set;
4146  for (Value arg : args) {
4147  if (!set.insert(arg).second)
4148  return true;
4149  }
4150  return false;
4151 }
4152 
4153 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4154 /// `scf.condition` args into the same order as the block args. Update `after` block
4155 /// args and op result values accordingly.
4156 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4157 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4158  using OpRewritePattern<WhileOp>::OpRewritePattern;
4159 
4160  LogicalResult matchAndRewrite(WhileOp loop,
4161  PatternRewriter &rewriter) const override {
4162  auto oldBefore = loop.getBeforeBody();
4163  ConditionOp oldTerm = loop.getConditionOp();
4164  ValueRange beforeArgs = oldBefore->getArguments();
4165  ValueRange termArgs = oldTerm.getArgs();
4166  if (beforeArgs == termArgs)
4167  return failure();
4168 
4169  if (hasDuplicates(termArgs))
4170  return failure();
4171 
4172  auto mapping = getArgsMapping(beforeArgs, termArgs);
4173  if (!mapping)
4174  return failure();
4175 
4176  {
4177  OpBuilder::InsertionGuard g(rewriter);
4178  rewriter.setInsertionPoint(oldTerm);
4179  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4180  beforeArgs);
4181  }
4182 
4183  auto oldAfter = loop.getAfterBody();
4184 
4185  SmallVector<Type> newResultTypes(beforeArgs.size());
4186  for (auto &&[i, j] : llvm::enumerate(*mapping))
4187  newResultTypes[j] = loop.getResult(i).getType();
4188 
4189  auto newLoop = rewriter.create<WhileOp>(
4190  loop.getLoc(), newResultTypes, loop.getInits(),
4191  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4192  auto newBefore = newLoop.getBeforeBody();
4193  auto newAfter = newLoop.getAfterBody();
4194 
4195  SmallVector<Value> newResults(beforeArgs.size());
4196  SmallVector<Value> newAfterArgs(beforeArgs.size());
4197  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4198  newResults[i] = newLoop.getResult(j);
4199  newAfterArgs[i] = newAfter->getArgument(j);
4200  }
4201 
4202  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4203  newBefore->getArguments());
4204  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4205  newAfterArgs);
4206 
4207  rewriter.replaceOp(loop, newResults);
4208  return success();
4209  }
4210 };
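// NOTE: Illustrative sketch added for readability; not part of the original
// source. With `before` block arguments (%a : i32, %b : i64) forwarded out of
// order,
//
//   scf.condition(%cond) %b, %a : i64, i32
//
// the terminator is rewritten to forward (%a, %b) in block-argument order,
// and the op results and `after` block arguments are permuted accordingly.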
4211 } // namespace
4212 
4213 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4214  MLIRContext *context) {
4215  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4216  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4217  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4218  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4219 }
4220 
4221 //===----------------------------------------------------------------------===//
4222 // IndexSwitchOp
4223 //===----------------------------------------------------------------------===//
4224 
4225 /// Parse the case regions and values.
4226 static ParseResult
4227 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4228  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4229  SmallVector<int64_t> caseValues;
4230  while (succeeded(p.parseOptionalKeyword("case"))) {
4231  int64_t value;
4232  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4233  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4234  return failure();
4235  caseValues.push_back(value);
4236  }
4237  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4238  return success();
4239 }
4240 
4241 /// Print the case regions and values.
4242 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4243  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4244  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4245  p.printNewline();
4246  p << "case " << value << ' ';
4247  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4248  }
4249 }
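// NOTE: Illustrative example added for readability; not part of the original
// source. The case list printed/parsed above looks like:
//
//   %0 = scf.index_switch %idx -> i32
//   case 2 {
//     %c10 = arith.constant 10 : i32
//     scf.yield %c10 : i32
//   }
//   case 5 {
//     %c20 = arith.constant 20 : i32
//     scf.yield %c20 : i32
//   }
//   default {
//     %c0_i32 = arith.constant 0 : i32
//     scf.yield %c0_i32 : i32
//   }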
4250 
4251 LogicalResult scf::IndexSwitchOp::verify() {
4252  if (getCases().size() != getCaseRegions().size()) {
4253  return emitOpError("has ")
4254  << getCaseRegions().size() << " case regions but "
4255  << getCases().size() << " case values";
4256  }
4257 
4258  DenseSet<int64_t> valueSet;
4259  for (int64_t value : getCases())
4260  if (!valueSet.insert(value).second)
4261  return emitOpError("has duplicate case value: ") << value;
4262  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4263  auto yield = dyn_cast<YieldOp>(region.front().back());
4264  if (!yield)
4265  return emitOpError("expected region to end with scf.yield, but got ")
4266  << region.front().back().getName();
4267 
4268  if (yield.getNumOperands() != getNumResults()) {
4269  return (emitOpError("expected each region to return ")
4270  << getNumResults() << " values, but " << name << " returns "
4271  << yield.getNumOperands())
4272  .attachNote(yield.getLoc())
4273  << "see yield operation here";
4274  }
4275  for (auto [idx, result, operand] :
4276  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4277  yield.getOperandTypes())) {
4278  if (result == operand)
4279  continue;
4280  return (emitOpError("expected result #")
4281  << idx << " of each region to be " << result)
4282  .attachNote(yield.getLoc())
4283  << name << " returns " << operand << " here";
4284  }
4285  return success();
4286  };
4287 
4288  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4289  return failure();
4290  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4291  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4292  return failure();
4293 
4294  return success();
4295 }
4296 
4297 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4298 
4299 Block &scf::IndexSwitchOp::getDefaultBlock() {
4300  return getDefaultRegion().front();
4301 }
4302 
4303 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4304  assert(idx < getNumCases() && "case index out-of-bounds");
4305  return getCaseRegions()[idx].front();
4306 }
4307 
4308 void IndexSwitchOp::getSuccessorRegions(
4309  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4310  // All regions branch back to the parent op.
4311  if (!point.isParent()) {
4312  successors.emplace_back(getResults());
4313  return;
4314  }
4315 
4316  llvm::append_range(successors, getRegions());
4317 }
4318 
4319 void IndexSwitchOp::getEntrySuccessorRegions(
4320  ArrayRef<Attribute> operands,
4321  SmallVectorImpl<RegionSuccessor> &successors) {
4322  FoldAdaptor adaptor(operands, *this);
4323 
4324  // If a constant was not provided, all regions are possible successors.
4325  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4326  if (!arg) {
4327  llvm::append_range(successors, getRegions());
4328  return;
4329  }
4330 
4331  // Otherwise, try to find a case with a matching value. If not, the
4332  // default region is the only successor.
4333  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4334  if (caseValue == arg.getInt()) {
4335  successors.emplace_back(&caseRegion);
4336  return;
4337  }
4338  }
4339  successors.emplace_back(&getDefaultRegion());
4340 }
4341 
4342 void IndexSwitchOp::getRegionInvocationBounds(
4343  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4344  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4345  if (!operandValue) {
4346  // All regions are invoked at most once.
4347  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4348  return;
4349  }
4350 
4351  unsigned liveIndex = getNumRegions() - 1;
4352  const auto *it = llvm::find(getCases(), operandValue.getInt());
4353  if (it != getCases().end())
4354  liveIndex = std::distance(getCases().begin(), it);
4355  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4356  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4357 }
4358 
4359 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4360  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4361 
4362  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4363  PatternRewriter &rewriter) const override {
4364  // If `op.getArg()` is a constant, select the region that matches with
4365  // the constant value. Use the default region if no match is found.
4366  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4367  if (!maybeCst.has_value())
4368  return failure();
4369  int64_t cst = *maybeCst;
4370  int64_t caseIdx, e = op.getNumCases();
4371  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4372  if (cst == op.getCases()[caseIdx])
4373  break;
4374  }
4375 
4376  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4377  : op.getDefaultRegion();
4378  Block &source = r.front();
4379  Operation *terminator = source.getTerminator();
4380  SmallVector<Value> results = terminator->getOperands();
4381 
4382  rewriter.inlineBlockBefore(&source, op);
4383  rewriter.eraseOp(terminator);
4384  // Replace the operation with a potentially empty list of results.
4385  // Fold mechanism doesn't support the case where the result list is empty.
4386  rewriter.replaceOp(op, results);
4387 
4388  return success();
4389  }
4390 };
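// NOTE: Illustrative sketch added for readability; not part of the original
// source. When the switch argument is a known constant, the matching region
// (or the default region if no case matches) is inlined in place of the op
// ("test.compute" is a hypothetical payload op):
//
//   %c2 = arith.constant 2 : index
//   %0 = scf.index_switch %c2 -> i32
//   case 2 {
//     %a = "test.compute"() : () -> i32
//     scf.yield %a : i32
//   }
//   default { ... }
//
// folds to the body of `case 2`, with %0 replaced by %a.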
4391 
4392 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4393  MLIRContext *context) {
4394  results.add<FoldConstantCase>(context);
4395 }
4396 
4397 //===----------------------------------------------------------------------===//
4398 // TableGen'd op method definitions
4399 //===----------------------------------------------------------------------===//
4400 
4401 #define GET_OP_CLASSES
4402 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:756
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:748
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:131
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1290
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:113
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4227
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3475
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:432
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:93
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4242
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
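A minimal sketch of the Builder helpers above (assuming a valid MLIRContext *ctx; variable names are illustrative only):
  Builder b(ctx);
  IntegerAttr idx = b.getIndexAttr(42);                  // index-typed integer attribute
  BoolAttr flag = b.getBoolAttr(true);                   // signless i1 attribute
  DenseI64ArrayAttr sizes = b.getDenseI64ArrayAttr({1, 2, 4});
  IntegerType i32Ty = b.getIntegerType(32);
  IntegerAttr seven = b.getIntegerAttr(i32Ty, 7);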
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
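A sketch of how these parser hooks compose in a custom parser; the directive and helper name are hypothetical, not taken from this file:
  // Parses `%operand : type` and resolves the operand against `result.operands`,
  // then picks up an optional trailing attribute dictionary.
  static ParseResult parseTypedOperand(OpAsmParser &parser,
                                       OperationState &result) {
    OpAsmParser::UnresolvedOperand operand;
    Type type;
    if (parser.parseOperand(operand) || parser.parseColonType(type) ||
        parser.resolveOperand(operand, type, result.operands))
      return failure();
    return parser.parseOptionalAttrDict(result.attributes);
  }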
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attributes'.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:575
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:585
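A small sketch of insertion-point management with OpBuilder; `ctx`, `block`, `loc`, and the constant op chosen are assumptions for illustration:
  OpBuilder b(ctx);
  b.setInsertionPointToStart(block);                  // new ops go at the block start
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
  b.setInsertionPointAfter(zero);                     // later ops follow `zero`
  {
    OpBuilder::InsertionGuard guard(b);               // RAII: restores the point on scope exit
    b.setInsertionPointToEnd(block);
    // ... create ops at the end of the block ...
  }                                                   // insertion point is back after `zero`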
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:767
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:829
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:700
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:620
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:612
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:596
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
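A sketch of a pattern written against the rewriter API above; `SomeOp` and the simplification itself are hypothetical:
  struct ForwardSingleOperand : public OpRewritePattern<SomeOp> {
    using OpRewritePattern<SomeOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(SomeOp op,
                                  PatternRewriter &rewriter) const override {
      if (op->getNumOperands() != 1 || op->getNumResults() != 1)
        return rewriter.notifyMatchFailure(op, "expected single operand/result");
      // Forward the operand to all users of the result and erase the op.
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
  };
  // Registered into a RewritePatternSet the same way as the patterns in this
  // file: results.add<ForwardSingleOperand>(context);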
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop is known to have a single iteration.
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3088
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:86
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e., all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:693
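A hedged usage sketch of buildLoopNest for a 2-D nest carrying one value; `b`, `loc`, and the index-typed bound/step/init Values are assumptions of the example:
  scf::LoopNest nest = scf::buildLoopNest(
      b, loc, ValueRange{lb0, lb1}, ValueRange{ub0, ub1}, ValueRange{s0, s1},
      ValueRange{init},
      [](OpBuilder &nested, Location nestedLoc, ValueRange ivs,
         ValueRange args) -> scf::ValueVector {
        // A real body would compute with `ivs`; here the carried value is
        // yielded unchanged, so the innermost terminator is an scf.yield of it.
        return {args[0]};
      });
  // `nest` bundles the generated scf.for ops and the results of the nest.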
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:2016
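A minimal sketch; `opA` and `opB` are assumed to be nested somewhere below an scf.if:
  if (scf::insideMutuallyExclusiveBranches(opA, opB)) {
    // One op sits in the "then" region and the other in the "else" region, so
    // they can never both execute in a single evaluation of the scf.if.
  }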
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:650
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:603
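A sketch of walking from an induction variable `iv` (an assumed Value) back to its loop; the call yields a null ForOp when `iv` is not an scf.for induction variable:
  if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
    Value lb = forOp.getLowerBound();
    Value step = forOp.getStep();
    (void)lb;
    (void)step;
  }
  // scf::getParallelForInductionVarOwner(iv) plays the same role for scf.parallel.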
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1490
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:782
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:270
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
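A sketch of a static trip-count query; `forOp` is an assumed scf::ForOp, and its Value bounds convert implicitly to OpFoldResult:
  std::optional<int64_t> tripCount = constantTripCount(
      forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
  if (tripCount && *tripCount == 1) {
    // Single-iteration loops are candidates for promotion into the parent block.
  }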
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
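A sketch of constant matching with these binders; `value` is any assumed Value:
  APInt cst;
  if (matchPattern(value, m_ConstantInt(&cst))) {
    // `cst` now holds the constant integer (or splat) value.
  }
  if (matchPattern(value, m_One())) {
    // `value` is a constant one.
  }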
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
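A sketch of bridging OpFoldResult and Value; `ofr`, `b`, and `loc` are assumptions of the example:
  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    // Constant case: no IR needs to be created, use *cst directly.
  }
  // Otherwise materialize an SSA value (an index constant is created if `ofr`
  // holds an attribute).
  Value v = getValueOrCreateConstantIndexOp(b, loc, ofr);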
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with all elements for which ShapedType::isDynamic is true replaced by dynamicValues.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4362
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:234
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:185
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.