MLIR  19.0.0git
SCF.cpp
Go to the documentation of this file.
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // SCFDialect Dialect Interfaces
37 //===----------------------------------------------------------------------===//
38 
39 namespace {
40 struct SCFInlinerInterface : public DialectInlinerInterface {
42  // We don't have any special restrictions on what can be inlined into
43  // destination regions (e.g. while/conditional bodies). Always allow it.
44  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
45  IRMapping &valueMapping) const final {
46  return true;
47  }
48  // Operations in scf dialect are always legal to inline since they are
49  // pure.
50  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
51  return true;
52  }
53  // Handle the given inlined terminator by replacing it with a new operation
54  // as necessary. Required when the region has only one block.
55  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
56  auto retValOp = dyn_cast<scf::YieldOp>(op);
57  if (!retValOp)
58  return;
59 
60  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
62  }
63  }
64 };
65 } // namespace
66 
67 //===----------------------------------------------------------------------===//
68 // SCFDialect
69 //===----------------------------------------------------------------------===//
70 
71 void SCFDialect::initialize() {
72  addOperations<
73 #define GET_OP_LIST
74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75  >();
76  addInterfaces<SCFInlinerInterface>();
77  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78  InParallelOp, ReduceReturnOp>();
79  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81  ForallOp, InParallelOp, WhileOp, YieldOp>();
82  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
83 }
84 
85 /// Default callback for IfOp builders. Inserts a yield without arguments.
87  builder.create<scf::YieldOp>(loc);
88 }
89 
90 /// Verifies that the first block of the given `region` is terminated by a
91 /// TerminatorTy. Reports errors on the given operation if it is not the case.
92 template <typename TerminatorTy>
93 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
94  StringRef errorMessage) {
95  Operation *terminatorOperation = nullptr;
96  if (!region.empty() && !region.front().empty()) {
97  terminatorOperation = &region.front().back();
98  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
99  return yield;
100  }
101  auto diag = op->emitOpError(errorMessage);
102  if (terminatorOperation)
103  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
104  return nullptr;
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // ExecuteRegionOp
109 //===----------------------------------------------------------------------===//
110 
111 /// Replaces the given op with the contents of the given single-block region,
112 /// using the operands of the block terminator to replace operation results.
113 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
114  Region &region, ValueRange blockArgs = {}) {
115  assert(llvm::hasSingleElement(region) && "expected single-region block");
116  Block *block = &region.front();
117  Operation *terminator = block->getTerminator();
118  ValueRange results = terminator->getOperands();
119  rewriter.inlineBlockBefore(block, op, blockArgs);
120  rewriter.replaceOp(op, results);
121  rewriter.eraseOp(terminator);
122 }
123 
124 ///
125 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
126 /// block+
127 /// `}`
128 ///
129 /// Example:
130 /// scf.execute_region -> i32 {
131 /// %idx = load %rI[%i] : memref<128xi32>
132 /// return %idx : i32
133 /// }
134 ///
136  OperationState &result) {
137  if (parser.parseOptionalArrowTypeList(result.types))
138  return failure();
139 
140  // Introduce the body region and parse it.
141  Region *body = result.addRegion();
142  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
143  parser.parseOptionalAttrDict(result.attributes))
144  return failure();
145 
146  return success();
147 }
148 
150  p.printOptionalArrowTypeList(getResultTypes());
151 
152  p << ' ';
153  p.printRegion(getRegion(),
154  /*printEntryBlockArgs=*/false,
155  /*printBlockTerminators=*/true);
156 
157  p.printOptionalAttrDict((*this)->getAttrs());
158 }
159 
161  if (getRegion().empty())
162  return emitOpError("region needs to have at least one block");
163  if (getRegion().front().getNumArguments() > 0)
164  return emitOpError("region cannot have any arguments");
165  return success();
166 }
167 
168 // Inline an ExecuteRegionOp if it only contains one block.
169 // "test.foo"() : () -> ()
170 // %v = scf.execute_region -> i64 {
171 // %x = "test.val"() : () -> i64
172 // scf.yield %x : i64
173 // }
174 // "test.bar"(%v) : (i64) -> ()
175 //
176 // becomes
177 //
178 // "test.foo"() : () -> ()
179 // %x = "test.val"() : () -> i64
180 // "test.bar"(%x) : (i64) -> ()
181 //
182 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
184 
185  LogicalResult matchAndRewrite(ExecuteRegionOp op,
186  PatternRewriter &rewriter) const override {
187  if (!llvm::hasSingleElement(op.getRegion()))
188  return failure();
189  replaceOpWithRegion(rewriter, op, op.getRegion());
190  return success();
191  }
192 };
193 
194 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
195 // TODO generalize the conditions for operations which can be inlined into.
196 // func @func_execute_region_elim() {
197 // "test.foo"() : () -> ()
198 // %v = scf.execute_region -> i64 {
199 // %c = "test.cmp"() : () -> i1
200 // cf.cond_br %c, ^bb2, ^bb3
201 // ^bb2:
202 // %x = "test.val1"() : () -> i64
203 // cf.br ^bb4(%x : i64)
204 // ^bb3:
205 // %y = "test.val2"() : () -> i64
206 // cf.br ^bb4(%y : i64)
207 // ^bb4(%z : i64):
208 // scf.yield %z : i64
209 // }
210 // "test.bar"(%v) : (i64) -> ()
211 // return
212 // }
213 //
214 // becomes
215 //
216 // func @func_execute_region_elim() {
217 // "test.foo"() : () -> ()
218 // %c = "test.cmp"() : () -> i1
219 // cf.cond_br %c, ^bb1, ^bb2
220 // ^bb1: // pred: ^bb0
221 // %x = "test.val1"() : () -> i64
222 // cf.br ^bb3(%x : i64)
223 // ^bb2: // pred: ^bb0
224 // %y = "test.val2"() : () -> i64
225 // cf.br ^bb3(%y : i64)
226 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
227 // "test.bar"(%z) : (i64) -> ()
228 // return
229 // }
230 //
231 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
233 
234  LogicalResult matchAndRewrite(ExecuteRegionOp op,
235  PatternRewriter &rewriter) const override {
236  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
237  return failure();
238 
239  Block *prevBlock = op->getBlock();
240  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
241  rewriter.setInsertionPointToEnd(prevBlock);
242 
243  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 
245  for (Block &blk : op.getRegion()) {
246  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247  rewriter.setInsertionPoint(yieldOp);
248  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
249  yieldOp.getResults());
250  rewriter.eraseOp(yieldOp);
251  }
252  }
253 
254  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
255  SmallVector<Value> blockArgs;
256 
257  for (auto res : op.getResults())
258  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
259 
260  rewriter.replaceOp(op, blockArgs);
261  return success();
262  }
263 };
264 
265 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
266  MLIRContext *context) {
268 }
269 
270 void ExecuteRegionOp::getSuccessorRegions(
272  // If the predecessor is the ExecuteRegionOp, branch into the body.
273  if (point.isParent()) {
274  regions.push_back(RegionSuccessor(&getRegion()));
275  return;
276  }
277 
278  // Otherwise, the region branches back to the parent operation.
279  regions.push_back(RegionSuccessor(getResults()));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // ConditionOp
284 //===----------------------------------------------------------------------===//
285 
288  assert((point.isParent() || point == getParentOp().getAfter()) &&
289  "condition op can only exit the loop or branch to the after"
290  "region");
291  // Pass all operands except the condition to the successor region.
292  return getArgsMutable();
293 }
294 
295 void ConditionOp::getSuccessorRegions(
297  FoldAdaptor adaptor(operands, *this);
298 
299  WhileOp whileOp = getParentOp();
300 
301  // Condition can either lead to the after region or back to the parent op
302  // depending on whether the condition is true or not.
303  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
304  if (!boolAttr || boolAttr.getValue())
305  regions.emplace_back(&whileOp.getAfter(),
306  whileOp.getAfter().getArguments());
307  if (!boolAttr || !boolAttr.getValue())
308  regions.emplace_back(whileOp.getResults());
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // ForOp
313 //===----------------------------------------------------------------------===//
314 
315 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
316  Value ub, Value step, ValueRange iterArgs,
317  BodyBuilderFn bodyBuilder) {
318  OpBuilder::InsertionGuard guard(builder);
319 
320  result.addOperands({lb, ub, step});
321  result.addOperands(iterArgs);
322  for (Value v : iterArgs)
323  result.addTypes(v.getType());
324  Type t = lb.getType();
325  Region *bodyRegion = result.addRegion();
326  Block *bodyBlock = builder.createBlock(bodyRegion);
327  bodyBlock->addArgument(t, result.location);
328  for (Value v : iterArgs)
329  bodyBlock->addArgument(v.getType(), v.getLoc());
330 
331  // Create the default terminator if the builder is not provided and if the
332  // iteration arguments are not provided. Otherwise, leave this to the caller
333  // because we don't know which values to return from the loop.
334  if (iterArgs.empty() && !bodyBuilder) {
335  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
336  } else if (bodyBuilder) {
337  OpBuilder::InsertionGuard guard(builder);
338  builder.setInsertionPointToStart(bodyBlock);
339  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
340  bodyBlock->getArguments().drop_front());
341  }
342 }
343 
345  // Check that the number of init args and op results is the same.
346  if (getInitArgs().size() != getNumResults())
347  return emitOpError(
348  "mismatch in number of loop-carried values and defined values");
349 
350  return success();
351 }
352 
353 LogicalResult ForOp::verifyRegions() {
354  // Check that the body defines as single block argument for the induction
355  // variable.
356  if (getInductionVar().getType() != getLowerBound().getType())
357  return emitOpError(
358  "expected induction variable to be same type as bounds and step");
359 
360  if (getNumRegionIterArgs() != getNumResults())
361  return emitOpError(
362  "mismatch in number of basic block args and defined values");
363 
364  auto initArgs = getInitArgs();
365  auto iterArgs = getRegionIterArgs();
366  auto opResults = getResults();
367  unsigned i = 0;
368  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369  if (std::get<0>(e).getType() != std::get<2>(e).getType())
370  return emitOpError() << "types mismatch between " << i
371  << "th iter operand and defined value";
372  if (std::get<1>(e).getType() != std::get<2>(e).getType())
373  return emitOpError() << "types mismatch between " << i
374  << "th iter region arg and defined value";
375 
376  ++i;
377  }
378  return success();
379 }
380 
381 std::optional<Value> ForOp::getSingleInductionVar() {
382  return getInductionVar();
383 }
384 
385 std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
386  return OpFoldResult(getLowerBound());
387 }
388 
389 std::optional<OpFoldResult> ForOp::getSingleStep() {
390  return OpFoldResult(getStep());
391 }
392 
393 std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
394  return OpFoldResult(getUpperBound());
395 }
396 
397 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
398 
399 /// Promotes the loop body of a forOp to its containing block if the forOp
400 /// it can be determined that the loop has a single iteration.
402  std::optional<int64_t> tripCount =
404  if (!tripCount.has_value() || tripCount != 1)
405  return failure();
406 
407  // Replace all results with the yielded values.
408  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
409  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
410 
411  // Replace block arguments with lower bound (replacement for IV) and
412  // iter_args.
413  SmallVector<Value> bbArgReplacements;
414  bbArgReplacements.push_back(getLowerBound());
415  llvm::append_range(bbArgReplacements, getInitArgs());
416 
417  // Move the loop body operations to the loop's containing block.
418  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
419  getOperation()->getIterator(), bbArgReplacements);
420 
421  // Erase the old terminator and the loop.
422  rewriter.eraseOp(yieldOp);
423  rewriter.eraseOp(*this);
424 
425  return success();
426 }
427 
428 /// Prints the initialization list in the form of
429 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
430 /// where 'inner' values are assumed to be region arguments and 'outer' values
431 /// are regular SSA values.
433  Block::BlockArgListType blocksArgs,
434  ValueRange initializers,
435  StringRef prefix = "") {
436  assert(blocksArgs.size() == initializers.size() &&
437  "expected same length of arguments and initializers");
438  if (initializers.empty())
439  return;
440 
441  p << prefix << '(';
442  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
443  p << std::get<0>(it) << " = " << std::get<1>(it);
444  });
445  p << ")";
446 }
447 
448 void ForOp::print(OpAsmPrinter &p) {
449  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
450  << getUpperBound() << " step " << getStep();
451 
452  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
453  if (!getInitArgs().empty())
454  p << " -> (" << getInitArgs().getTypes() << ')';
455  p << ' ';
456  if (Type t = getInductionVar().getType(); !t.isIndex())
457  p << " : " << t << ' ';
458  p.printRegion(getRegion(),
459  /*printEntryBlockArgs=*/false,
460  /*printBlockTerminators=*/!getInitArgs().empty());
461  p.printOptionalAttrDict((*this)->getAttrs());
462 }
463 
465  auto &builder = parser.getBuilder();
466  Type type;
467 
468  OpAsmParser::Argument inductionVariable;
469  OpAsmParser::UnresolvedOperand lb, ub, step;
470 
471  // Parse the induction variable followed by '='.
472  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
473  // Parse loop bounds.
474  parser.parseOperand(lb) || parser.parseKeyword("to") ||
475  parser.parseOperand(ub) || parser.parseKeyword("step") ||
476  parser.parseOperand(step))
477  return failure();
478 
479  // Parse the optional initial iteration arguments.
482  regionArgs.push_back(inductionVariable);
483 
484  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
485  if (hasIterArgs) {
486  // Parse assignment list and results type list.
487  if (parser.parseAssignmentList(regionArgs, operands) ||
488  parser.parseArrowTypeList(result.types))
489  return failure();
490  }
491 
492  if (regionArgs.size() != result.types.size() + 1)
493  return parser.emitError(
494  parser.getNameLoc(),
495  "mismatch in number of loop-carried values and defined values");
496 
497  // Parse optional type, else assume Index.
498  if (parser.parseOptionalColon())
499  type = builder.getIndexType();
500  else if (parser.parseType(type))
501  return failure();
502 
503  // Resolve input operands.
504  regionArgs.front().type = type;
505  if (parser.resolveOperand(lb, type, result.operands) ||
506  parser.resolveOperand(ub, type, result.operands) ||
507  parser.resolveOperand(step, type, result.operands))
508  return failure();
509  if (hasIterArgs) {
510  for (auto argOperandType :
511  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
512  Type type = std::get<2>(argOperandType);
513  std::get<0>(argOperandType).type = type;
514  if (parser.resolveOperand(std::get<1>(argOperandType), type,
515  result.operands))
516  return failure();
517  }
518  }
519 
520  // Parse the body region.
521  Region *body = result.addRegion();
522  if (parser.parseRegion(*body, regionArgs))
523  return failure();
524 
525  ForOp::ensureTerminator(*body, builder, result.location);
526 
527  // Parse the optional attribute list.
528  if (parser.parseOptionalAttrDict(result.attributes))
529  return failure();
530 
531  return success();
532 }
533 
534 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
535 
536 Block::BlockArgListType ForOp::getRegionIterArgs() {
537  return getBody()->getArguments().drop_front(getNumInductionVars());
538 }
539 
540 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
541  return getInitArgsMutable();
542 }
543 
545 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
546  ValueRange newInitOperands,
547  bool replaceInitOperandUsesInLoop,
548  const NewYieldValuesFn &newYieldValuesFn) {
549  // Create a new loop before the existing one, with the extra operands.
550  OpBuilder::InsertionGuard g(rewriter);
551  rewriter.setInsertionPoint(getOperation());
552  auto inits = llvm::to_vector(getInitArgs());
553  inits.append(newInitOperands.begin(), newInitOperands.end());
554  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
555  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
556  [](OpBuilder &, Location, Value, ValueRange) {});
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
596 
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for loop not to enter the
614  // body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
686 
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
774 
775 namespace {
776 // Fold away ForOp iter arguments when:
777 // 1) The op yields the iter arguments.
778 // 2) The iter arguments have no use and the corresponding outer region
779 // iterators (inputs) are yielded.
780 // 3) The iter arguments have no use and the corresponding (operation) results
781 // have no use.
782 //
783 // These arguments must be defined outside of
784 // the ForOp region and can just be forwarded after simplifying the op inits,
785 // yields and returns.
786 //
787 // The implementation uses `inlineBlockBefore` to steal the content of the
788 // original ForOp and avoid cloning.
789 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
791 
792  LogicalResult matchAndRewrite(scf::ForOp forOp,
793  PatternRewriter &rewriter) const final {
794  bool canonicalize = false;
795 
796  // An internal flat vector of block transfer
797  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
798  // transformed block argument mappings. This plays the role of a
799  // IRMapping for the particular use case of calling into
800  // `inlineBlockBefore`.
801  int64_t numResults = forOp.getNumResults();
802  SmallVector<bool, 4> keepMask;
803  keepMask.reserve(numResults);
804  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
805  newResultValues;
806  newBlockTransferArgs.reserve(1 + numResults);
807  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
808  newIterArgs.reserve(forOp.getInitArgs().size());
809  newYieldValues.reserve(numResults);
810  newResultValues.reserve(numResults);
811  for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
812  forOp.getRegionIterArgs(), // iter inside region
813  forOp.getResults(), // op results
814  forOp.getYieldedValues() // iter yield
815  )) {
816  // Forwarded is `true` when:
817  // 1) The region `iter` argument is yielded.
818  // 2) The region `iter` argument has no use, and the corresponding iter
819  // operand (input) is yielded.
820  // 3) The region `iter` argument has no use, and the corresponding op
821  // result has no use.
822  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
823  (std::get<1>(it).use_empty() &&
824  (std::get<0>(it) == std::get<3>(it) ||
825  std::get<2>(it).use_empty())));
826  keepMask.push_back(!forwarded);
827  canonicalize |= forwarded;
828  if (forwarded) {
829  newBlockTransferArgs.push_back(std::get<0>(it));
830  newResultValues.push_back(std::get<0>(it));
831  continue;
832  }
833  newIterArgs.push_back(std::get<0>(it));
834  newYieldValues.push_back(std::get<3>(it));
835  newBlockTransferArgs.push_back(Value()); // placeholder with null value
836  newResultValues.push_back(Value()); // placeholder with null value
837  }
838 
839  if (!canonicalize)
840  return failure();
841 
842  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
843  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
844  forOp.getStep(), newIterArgs);
845  newForOp->setAttrs(forOp->getAttrs());
846  Block &newBlock = newForOp.getRegion().front();
847 
848  // Replace the null placeholders with newly constructed values.
849  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
850  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
851  idx != e; ++idx) {
852  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
853  Value &newResultVal = newResultValues[idx];
854  assert((blockTransferArg && newResultVal) ||
855  (!blockTransferArg && !newResultVal));
856  if (!blockTransferArg) {
857  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
858  newResultVal = newForOp.getResult(collapsedIdx++);
859  }
860  }
861 
862  Block &oldBlock = forOp.getRegion().front();
863  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
864  "unexpected argument size mismatch");
865 
866  // No results case: the scf::ForOp builder already created a zero
867  // result terminator. Merge before this terminator and just get rid of the
868  // original terminator that has been merged in.
869  if (newIterArgs.empty()) {
870  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
871  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
872  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
873  rewriter.replaceOp(forOp, newResultValues);
874  return success();
875  }
876 
877  // No terminator case: merge and rewrite the merged terminator.
878  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
879  OpBuilder::InsertionGuard g(rewriter);
880  rewriter.setInsertionPoint(mergedTerminator);
881  SmallVector<Value, 4> filteredOperands;
882  filteredOperands.reserve(newResultValues.size());
883  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
884  if (keepMask[idx])
885  filteredOperands.push_back(mergedTerminator.getOperand(idx));
886  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
887  filteredOperands);
888  };
889 
890  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
891  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
892  cloneFilteredTerminator(mergedYieldOp);
893  rewriter.eraseOp(mergedYieldOp);
894  rewriter.replaceOp(forOp, newResultValues);
895  return success();
896  }
897 };
898 
899 /// Util function that tries to compute a constant diff between u and l.
900 /// Returns std::nullopt when the difference between two AffineValueMap is
901 /// dynamic.
902 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
903  IntegerAttr clb, cub;
904  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
905  llvm::APInt lbValue = clb.getValue();
906  llvm::APInt ubValue = cub.getValue();
907  return (ubValue - lbValue).getSExtValue();
908  }
909 
910  // Else a simple pattern match for x + c or c + x
911  llvm::APInt diff;
912  if (matchPattern(
913  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
914  matchPattern(
915  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
916  return diff.getSExtValue();
917  return std::nullopt;
918 }
919 
920 /// Rewriting pattern that erases loops that are known not to iterate, replaces
921 /// single-iteration loops with their bodies, and removes empty loops that
922 /// iterate at least once and only return values defined outside of the loop.
923 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
925 
926  LogicalResult matchAndRewrite(ForOp op,
927  PatternRewriter &rewriter) const override {
928  // If the upper bound is the same as the lower bound, the loop does not
929  // iterate, just remove it.
930  if (op.getLowerBound() == op.getUpperBound()) {
931  rewriter.replaceOp(op, op.getInitArgs());
932  return success();
933  }
934 
935  std::optional<int64_t> diff =
936  computeConstDiff(op.getLowerBound(), op.getUpperBound());
937  if (!diff)
938  return failure();
939 
940  // If the loop is known to have 0 iterations, remove it.
941  if (*diff <= 0) {
942  rewriter.replaceOp(op, op.getInitArgs());
943  return success();
944  }
945 
946  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
947  if (!maybeStepValue)
948  return failure();
949 
950  // If the loop is known to have 1 iteration, inline its body and remove the
951  // loop.
952  llvm::APInt stepValue = *maybeStepValue;
953  if (stepValue.sge(*diff)) {
954  SmallVector<Value, 4> blockArgs;
955  blockArgs.reserve(op.getInitArgs().size() + 1);
956  blockArgs.push_back(op.getLowerBound());
957  llvm::append_range(blockArgs, op.getInitArgs());
958  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
959  return success();
960  }
961 
962  // Now we are left with loops that have more than 1 iterations.
963  Block &block = op.getRegion().front();
964  if (!llvm::hasSingleElement(block))
965  return failure();
966  // If the loop is empty, iterates at least once, and only returns values
967  // defined outside of the loop, remove it and replace it with yield values.
968  if (llvm::any_of(op.getYieldedValues(),
969  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
970  return failure();
971  rewriter.replaceOp(op, op.getYieldedValues());
972  return success();
973  }
974 };
975 
976 /// Perform a replacement of one iter OpOperand of an scf.for to the
977 /// `replacement` value which is expected to be the source of a tensor.cast.
978 /// tensor.cast ops are inserted inside the block to account for the type cast.
979 static SmallVector<Value>
980 replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
981  Value replacement) {
982  Type oldType = operand.get().getType(), newType = replacement.getType();
983  assert(llvm::isa<RankedTensorType>(oldType) &&
984  llvm::isa<RankedTensorType>(newType) &&
985  "expected ranked tensor types");
986 
987  // 1. Create new iter operands, exactly 1 is replaced.
988  ForOp forOp = cast<ForOp>(operand.getOwner());
989  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
990  "expected an iter OpOperand");
991  assert(operand.get().getType() != replacement.getType() &&
992  "Expected a different type");
993  SmallVector<Value> newIterOperands;
994  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
995  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
996  newIterOperands.push_back(replacement);
997  continue;
998  }
999  newIterOperands.push_back(opOperand.get());
1000  }
1001 
1002  // 2. Create the new forOp shell.
1003  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
1004  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1005  forOp.getStep(), newIterOperands);
1006  newForOp->setAttrs(forOp->getAttrs());
1007  Block &newBlock = newForOp.getRegion().front();
1008  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
1009  newBlock.getArguments().end());
1010 
1011  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
1012  // corresponding to the `replacement` value.
1013  OpBuilder::InsertionGuard g(rewriter);
1014  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
1015  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1016  &newForOp->getOpOperand(operand.getOperandNumber()));
1017  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
1018  newRegionIterArg);
1019  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
1020 
1021  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1022  Block &oldBlock = forOp.getRegion().front();
1023  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1024 
1025  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
1026  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1027  rewriter.setInsertionPoint(clonedYieldOp);
1028  unsigned yieldIdx =
1029  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
1030  Value castOut = rewriter.create<tensor::CastOp>(
1031  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1032  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
1033  newYieldOperands[yieldIdx] = castOut;
1034  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1035  rewriter.eraseOp(clonedYieldOp);
1036 
1037  // 6. Inject an outgoing cast op after the forOp.
1038  rewriter.setInsertionPointAfter(newForOp);
1039  SmallVector<Value> newResults = newForOp.getResults();
1040  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
1041  newForOp.getLoc(), oldType, newResults[yieldIdx]);
1042 
1043  return newResults;
1044 }
1045 
1046 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
1047 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1048 ///
1049 /// ```
1050 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1051 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1052 /// -> (tensor<?x?xf32>) {
1053 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1054 /// scf.yield %2 : tensor<?x?xf32>
1055 /// }
1056 /// use_of(%1)
1057 /// ```
1058 ///
1059 /// folds into:
1060 ///
1061 /// ```
1062 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1063 /// -> (tensor<32x1024xf32>) {
1064 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1065 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1066 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1067 /// scf.yield %4 : tensor<32x1024xf32>
1068 /// }
1069 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1070 /// use_of(%1)
1071 /// ```
1072 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1074 
1075  LogicalResult matchAndRewrite(ForOp op,
1076  PatternRewriter &rewriter) const override {
1077  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1078  OpOperand &iterOpOperand = std::get<0>(it);
1079  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1080  if (!incomingCast ||
1081  incomingCast.getSource().getType() == incomingCast.getType())
1082  continue;
1083  // If the dest type of the cast does not preserve static information in
1084  // the source type.
1086  incomingCast.getDest().getType(),
1087  incomingCast.getSource().getType()))
1088  continue;
1089  if (!std::get<1>(it).hasOneUse())
1090  continue;
1091 
1092  // Create a new ForOp with that iter operand replaced.
1093  rewriter.replaceOp(
1094  op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1095  incomingCast.getSource()));
1096  return success();
1097  }
1098  return failure();
1099  }
1100 };
1101 
1102 } // namespace
1103 
1104 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1105  MLIRContext *context) {
1106  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1107  context);
1108 }
1109 
1110 std::optional<APInt> ForOp::getConstantStep() {
1111  IntegerAttr step;
1112  if (matchPattern(getStep(), m_Constant(&step)))
1113  return step.getValue();
1114  return {};
1115 }
1116 
1117 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1118  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1119 }
1120 
1121 Speculation::Speculatability ForOp::getSpeculatability() {
1122  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1123  // and End.
1124  if (auto constantStep = getConstantStep())
1125  if (*constantStep == 1)
1127 
1128  // For Step != 1, the loop may not terminate. We can add more smarts here if
1129  // needed.
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // ForallOp
1135 //===----------------------------------------------------------------------===//
1136 
1138  unsigned numLoops = getRank();
1139  // Check number of outputs.
1140  if (getNumResults() != getOutputs().size())
1141  return emitOpError("produces ")
1142  << getNumResults() << " results, but has only "
1143  << getOutputs().size() << " outputs";
1144 
1145  // Check that the body defines block arguments for thread indices and outputs.
1146  auto *body = getBody();
1147  if (body->getNumArguments() != numLoops + getOutputs().size())
1148  return emitOpError("region expects ") << numLoops << " arguments";
1149  for (int64_t i = 0; i < numLoops; ++i)
1150  if (!body->getArgument(i).getType().isIndex())
1151  return emitOpError("expects ")
1152  << i << "-th block argument to be an index";
1153  for (unsigned i = 0; i < getOutputs().size(); ++i)
1154  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1155  return emitOpError("type mismatch between ")
1156  << i << "-th output and corresponding block argument";
1157  if (getMapping().has_value() && !getMapping()->empty()) {
1158  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1159  return emitOpError() << "mapping attribute size must match op rank";
1160  for (auto map : getMapping()->getValue()) {
1161  if (!isa<DeviceMappingAttrInterface>(map))
1162  return emitOpError()
1163  << getMappingAttrName() << " is not device mapping attribute";
1164  }
1165  }
1166 
1167  // Verify mixed static/dynamic control variables.
1168  Operation *op = getOperation();
1169  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1170  getStaticLowerBound(),
1171  getDynamicLowerBound())))
1172  return failure();
1173  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1174  getStaticUpperBound(),
1175  getDynamicUpperBound())))
1176  return failure();
1177  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1178  getStaticStep(), getDynamicStep())))
1179  return failure();
1180 
1181  return success();
1182 }
1183 
1184 void ForallOp::print(OpAsmPrinter &p) {
1185  Operation *op = getOperation();
1186  p << " (" << getInductionVars();
1187  if (isNormalized()) {
1188  p << ") in ";
1189  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1190  /*valueTypes=*/{}, /*scalables=*/{},
1192  } else {
1193  p << ") = ";
1194  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1195  /*valueTypes=*/{}, /*scalables=*/{},
1197  p << " to ";
1198  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1199  /*valueTypes=*/{}, /*scalables=*/{},
1201  p << " step ";
1202  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1203  /*valueTypes=*/{}, /*scalables=*/{},
1205  }
1206  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1207  p << " ";
1208  if (!getRegionOutArgs().empty())
1209  p << "-> (" << getResultTypes() << ") ";
1210  p.printRegion(getRegion(),
1211  /*printEntryBlockArgs=*/false,
1212  /*printBlockTerminators=*/getNumResults() > 0);
1213  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1214  getStaticLowerBoundAttrName(),
1215  getStaticUpperBoundAttrName(),
1216  getStaticStepAttrName()});
1217 }
1218 
1220  OpBuilder b(parser.getContext());
1221  auto indexType = b.getIndexType();
1222 
1223  // Parse an opening `(` followed by thread index variables followed by `)`
1224  // TODO: when we can refer to such "induction variable"-like handles from the
1225  // declarative assembly format, we can implement the parser as a custom hook.
1228  return failure();
1229 
1230  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1231  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1232  dynamicSteps;
1233  if (succeeded(parser.parseOptionalKeyword("in"))) {
1234  // Parse upper bounds.
1235  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1236  /*valueTypes=*/nullptr,
1238  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1239  return failure();
1240 
1241  unsigned numLoops = ivs.size();
1242  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1243  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1244  } else {
1245  // Parse lower bounds.
1246  if (parser.parseEqual() ||
1247  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1248  /*valueTypes=*/nullptr,
1250 
1251  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1252  return failure();
1253 
1254  // Parse upper bounds.
1255  if (parser.parseKeyword("to") ||
1256  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1257  /*valueTypes=*/nullptr,
1259  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1260  return failure();
1261 
1262  // Parse step values.
1263  if (parser.parseKeyword("step") ||
1264  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1265  /*valueTypes=*/nullptr,
1267  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1268  return failure();
1269  }
1270 
1271  // Parse out operands and results.
1274  SMLoc outOperandsLoc = parser.getCurrentLocation();
1275  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1276  if (outOperands.size() != result.types.size())
1277  return parser.emitError(outOperandsLoc,
1278  "mismatch between out operands and types");
1279  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1280  parser.parseOptionalArrowTypeList(result.types) ||
1281  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1282  result.operands))
1283  return failure();
1284  }
1285 
1286  // Parse region.
1288  std::unique_ptr<Region> region = std::make_unique<Region>();
1289  for (auto &iv : ivs) {
1290  iv.type = b.getIndexType();
1291  regionArgs.push_back(iv);
1292  }
1293  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1294  auto &out = it.value();
1295  out.type = result.types[it.index()];
1296  regionArgs.push_back(out);
1297  }
1298  if (parser.parseRegion(*region, regionArgs))
1299  return failure();
1300 
1301  // Ensure terminator and move region.
1302  ForallOp::ensureTerminator(*region, b, result.location);
1303  result.addRegion(std::move(region));
1304 
1305  // Parse the optional attribute list.
1306  if (parser.parseOptionalAttrDict(result.attributes))
1307  return failure();
1308 
1309  result.addAttribute("staticLowerBound", staticLbs);
1310  result.addAttribute("staticUpperBound", staticUbs);
1311  result.addAttribute("staticStep", staticSteps);
1312  result.addAttribute("operandSegmentSizes",
1314  {static_cast<int32_t>(dynamicLbs.size()),
1315  static_cast<int32_t>(dynamicUbs.size()),
1316  static_cast<int32_t>(dynamicSteps.size()),
1317  static_cast<int32_t>(outOperands.size())}));
1318  return success();
1319 }
1320 
1321 // Builder that takes loop bounds.
1322 void ForallOp::build(
1325  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1326  std::optional<ArrayAttr> mapping,
1327  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1328  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1329  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1330  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1331  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1332  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1333 
1334  result.addOperands(dynamicLbs);
1335  result.addOperands(dynamicUbs);
1336  result.addOperands(dynamicSteps);
1337  result.addOperands(outputs);
1338  result.addTypes(TypeRange(outputs));
1339 
1340  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1341  b.getDenseI64ArrayAttr(staticLbs));
1342  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1343  b.getDenseI64ArrayAttr(staticUbs));
1344  result.addAttribute(getStaticStepAttrName(result.name),
1345  b.getDenseI64ArrayAttr(staticSteps));
1346  result.addAttribute(
1347  "operandSegmentSizes",
1348  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1349  static_cast<int32_t>(dynamicUbs.size()),
1350  static_cast<int32_t>(dynamicSteps.size()),
1351  static_cast<int32_t>(outputs.size())}));
1352  if (mapping.has_value()) {
1354  mapping.value());
1355  }
1356 
1357  Region *bodyRegion = result.addRegion();
1359  b.createBlock(bodyRegion);
1360  Block &bodyBlock = bodyRegion->front();
1361 
1362  // Add block arguments for indices and outputs.
1363  bodyBlock.addArguments(
1364  SmallVector<Type>(lbs.size(), b.getIndexType()),
1365  SmallVector<Location>(staticLbs.size(), result.location));
1366  bodyBlock.addArguments(
1367  TypeRange(outputs),
1368  SmallVector<Location>(outputs.size(), result.location));
1369 
1370  b.setInsertionPointToStart(&bodyBlock);
1371  if (!bodyBuilderFn) {
1372  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1373  return;
1374  }
1375  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1376 }
1377 
1378 // Builder that takes loop bounds.
1379 void ForallOp::build(
1381  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1382  std::optional<ArrayAttr> mapping,
1383  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1384  unsigned numLoops = ubs.size();
1385  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1386  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1387  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1388 }
1389 
1390 // Checks if the lbs are zeros and steps are ones.
1391 bool ForallOp::isNormalized() {
1392  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1393  return llvm::all_of(results, [&](OpFoldResult ofr) {
1394  auto intValue = getConstantIntValue(ofr);
1395  return intValue.has_value() && intValue == val;
1396  });
1397  };
1398  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1399 }
1400 
1401 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1402 // unaware of the fact that our terminator also needs a region to be
1403 // well-formed. We override it here to ensure that we do the right thing.
1404 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1405  Location loc) {
1407  ForallOp>::ensureTerminator(region, builder, loc);
1408  auto terminator =
1409  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1410  if (terminator.getRegion().empty())
1411  builder.createBlock(&terminator.getRegion());
1412 }
1413 
1414 InParallelOp ForallOp::getTerminator() {
1415  return cast<InParallelOp>(getBody()->getTerminator());
1416 }
1417 
1418 std::optional<Value> ForallOp::getSingleInductionVar() {
1419  if (getRank() != 1)
1420  return std::nullopt;
1421  return getInductionVar(0);
1422 }
1423 
1424 std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1425  if (getRank() != 1)
1426  return std::nullopt;
1427  return getMixedLowerBound()[0];
1428 }
1429 
1430 std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1431  if (getRank() != 1)
1432  return std::nullopt;
1433  return getMixedUpperBound()[0];
1434 }
1435 
1436 std::optional<OpFoldResult> ForallOp::getSingleStep() {
1437  if (getRank() != 1)
1438  return std::nullopt;
1439  return getMixedStep()[0];
1440 }
1441 
1443  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1444  if (!tidxArg)
1445  return ForallOp();
1446  assert(tidxArg.getOwner() && "unlinked block argument");
1447  auto *containingOp = tidxArg.getOwner()->getParentOp();
1448  return dyn_cast<ForallOp>(containingOp);
1449 }
1450 
1451 namespace {
1452 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1453 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1455 
1456  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1457  PatternRewriter &rewriter) const final {
1458  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1459  if (!forallOp)
1460  return failure();
1461  Value sharedOut =
1462  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1463  ->get();
1464  rewriter.modifyOpInPlace(
1465  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1466  return success();
1467  }
1468 };
1469 
1470 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1471 public:
1473 
1474  LogicalResult matchAndRewrite(ForallOp op,
1475  PatternRewriter &rewriter) const override {
1476  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1477  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1478  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1479  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1480  failed(foldDynamicIndexList(mixedUpperBound)) &&
1481  failed(foldDynamicIndexList(mixedStep)))
1482  return failure();
1483 
1484  rewriter.modifyOpInPlace(op, [&]() {
1485  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1486  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1487  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1488  staticLowerBound);
1489  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1490  op.setStaticLowerBound(staticLowerBound);
1491 
1492  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1493  staticUpperBound);
1494  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1495  op.setStaticUpperBound(staticUpperBound);
1496 
1497  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1498  op.getDynamicStepMutable().assign(dynamicStep);
1499  op.setStaticStep(staticStep);
1500 
1501  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1502  rewriter.getDenseI32ArrayAttr(
1503  {static_cast<int32_t>(dynamicLowerBound.size()),
1504  static_cast<int32_t>(dynamicUpperBound.size()),
1505  static_cast<int32_t>(dynamicStep.size()),
1506  static_cast<int32_t>(op.getNumResults())}));
1507  });
1508  return success();
1509  }
1510 };
1511 
1512 struct ForallOpSingleOrZeroIterationDimsFolder
1513  : public OpRewritePattern<ForallOp> {
1515 
1516  LogicalResult matchAndRewrite(ForallOp op,
1517  PatternRewriter &rewriter) const override {
1518  // Do not fold dimensions if they are mapped to processing units.
1519  if (op.getMapping().has_value())
1520  return failure();
1521  Location loc = op.getLoc();
1522 
1523  // Compute new loop bounds that omit all single-iteration loop dimensions.
1524  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1525  newMixedSteps;
1526  IRMapping mapping;
1527  for (auto [lb, ub, step, iv] :
1528  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1529  op.getMixedStep(), op.getInductionVars())) {
1530  auto numIterations = constantTripCount(lb, ub, step);
1531  if (numIterations.has_value()) {
1532  // Remove the loop if it performs zero iterations.
1533  if (*numIterations == 0) {
1534  rewriter.replaceOp(op, op.getOutputs());
1535  return success();
1536  }
1537  // Replace the loop induction variable by the lower bound if the loop
1538  // performs a single iteration. Otherwise, copy the loop bounds.
1539  if (*numIterations == 1) {
1540  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1541  continue;
1542  }
1543  }
1544  newMixedLowerBounds.push_back(lb);
1545  newMixedUpperBounds.push_back(ub);
1546  newMixedSteps.push_back(step);
1547  }
1548  // Exit if none of the loop dimensions perform a single iteration.
1549  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1550  return rewriter.notifyMatchFailure(
1551  op, "no dimensions have 0 or 1 iterations");
1552  }
1553 
1554  // All of the loop dimensions perform a single iteration. Inline loop body.
1555  if (newMixedLowerBounds.empty()) {
1556  promote(rewriter, op);
1557  return success();
1558  }
1559 
1560  // Replace the loop by a lower-dimensional loop.
1561  ForallOp newOp;
1562  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1563  newMixedUpperBounds, newMixedSteps,
1564  op.getOutputs(), std::nullopt, nullptr);
1565  newOp.getBodyRegion().getBlocks().clear();
1566  // The new loop needs to keep all attributes from the old one, except for
1567  // "operandSegmentSizes" and static loop bound attributes which capture
1568  // the outdated information of the old iteration domain.
1569  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1570  newOp.getStaticLowerBoundAttrName(),
1571  newOp.getStaticUpperBoundAttrName(),
1572  newOp.getStaticStepAttrName()};
1573  for (const auto &namedAttr : op->getAttrs()) {
1574  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1575  continue;
1576  rewriter.modifyOpInPlace(newOp, [&]() {
1577  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1578  });
1579  }
1580  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1581  newOp.getRegion().begin(), mapping);
1582  rewriter.replaceOp(op, newOp.getResults());
1583  return success();
1584  }
1585 };
1586 
1587 struct FoldTensorCastOfOutputIntoForallOp
1588  : public OpRewritePattern<scf::ForallOp> {
1590 
1591  struct TypeCast {
1592  Type srcType;
1593  Type dstType;
1594  };
1595 
1596  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1597  PatternRewriter &rewriter) const final {
1598  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1599  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1600  for (auto en : llvm::enumerate(newOutputTensors)) {
1601  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1602  if (!castOp)
1603  continue;
1604 
1605  // Only casts that that preserve static information, i.e. will make the
1606  // loop result type "more" static than before, will be folded.
1607  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1608  castOp.getSource().getType())) {
1609  continue;
1610  }
1611 
1612  tensorCastProducers[en.index()] =
1613  TypeCast{castOp.getSource().getType(), castOp.getType()};
1614  newOutputTensors[en.index()] = castOp.getSource();
1615  }
1616 
1617  if (tensorCastProducers.empty())
1618  return failure();
1619 
1620  // Create new loop.
1621  Location loc = forallOp.getLoc();
1622  auto newForallOp = rewriter.create<ForallOp>(
1623  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1624  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1625  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1626  auto castBlockArgs =
1627  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1628  for (auto [index, cast] : tensorCastProducers) {
1629  Value &oldTypeBBArg = castBlockArgs[index];
1630  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1631  nestedLoc, cast.dstType, oldTypeBBArg);
1632  }
1633 
1634  // Move old body into new parallel loop.
1635  SmallVector<Value> ivsBlockArgs =
1636  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1637  ivsBlockArgs.append(castBlockArgs);
1638  rewriter.mergeBlocks(forallOp.getBody(),
1639  bbArgs.front().getParentBlock(), ivsBlockArgs);
1640  });
1641 
1642  // After `mergeBlocks` happened, the destinations in the terminator were
1643  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1644  // destination have to be updated to point to the output bbArgs directly.
1645  auto terminator = newForallOp.getTerminator();
1646  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1647  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1648  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1649  insertSliceOp.getDestMutable().assign(outputBlockArg);
1650  }
1651 
1652  // Cast results back to the original types.
1653  rewriter.setInsertionPointAfter(newForallOp);
1654  SmallVector<Value> castResults = newForallOp.getResults();
1655  for (auto &item : tensorCastProducers) {
1656  Value &oldTypeResult = castResults[item.first];
1657  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1658  oldTypeResult);
1659  }
1660  rewriter.replaceOp(forallOp, castResults);
1661  return success();
1662  }
1663 };
1664 
1665 } // namespace
1666 
1667 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1668  MLIRContext *context) {
1669  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1670  ForallOpControlOperandsFolder,
1671  ForallOpSingleOrZeroIterationDimsFolder>(context);
1672 }
1673 
1674 /// Given the region at `index`, or the parent operation if `index` is None,
1675 /// return the successor regions. These are the regions that may be selected
1676 /// during the flow of control. `operands` is a set of optional attributes that
1677 /// correspond to a constant value for each operand, or null if that operand is
1678 /// not a constant.
1679 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1681  // Both the operation itself and the region may be branching into the body or
1682  // back into the operation itself. It is possible for loop not to enter the
1683  // body.
1684  regions.push_back(RegionSuccessor(&getRegion()));
1685  regions.push_back(RegionSuccessor());
1686 }
1687 
1688 //===----------------------------------------------------------------------===//
1689 // InParallelOp
1690 //===----------------------------------------------------------------------===//
1691 
1692 // Build a InParallelOp with mixed static and dynamic entries.
1693 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1695  Region *bodyRegion = result.addRegion();
1696  b.createBlock(bodyRegion);
1697 }
1698 
// Verification body for InParallelOp. It checks that:
//   * the parent operation is an scf.forall;
//   * every operation in the first block of this op's region is a
//     tensor.parallel_insert_slice;
//   * each such insert targets one of the parent's region output arguments.
// NOTE(review): the line carrying this definition's signature is absent from
// this listing (the embedded numbering skips 1699) -- confirm against the
// original file.
1700  scf::ForallOp forallOp =
1701  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1702  if (!forallOp)
1703  return this->emitOpError("expected forall op parent");
1704 
1705  // TODO: InParallelOpInterface.
1706  for (Operation &op : getRegion().front().getOperations()) {
1707  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1708  return this->emitOpError("expected only ")
1709  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1710  }
1711 
1712  // Verify that inserts are into out block arguments.
1713  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1714  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
 // The error is attached to the offending insert op, not to this op.
1715  if (!llvm::is_contained(regionOutArgs, dest))
1716  return op.emitOpError("may only insert into an output block argument");
1717  }
1718  return success();
1719 }
1720 
// Printer body for InParallelOp: emits a space, then the region without its
// entry block arguments and without block terminators, then the op's
// attribute dictionary if there is one.
// NOTE(review): the signature line is absent from this listing (the embedded
// numbering skips 1721) -- confirm against the original file.
1722  p << " ";
1723  p.printRegion(getRegion(),
1724  /*printEntryBlockArgs=*/false,
1725  /*printBlockTerminators=*/false);
1726  p.printOptionalAttrDict(getOperation()->getAttrs());
1727 }
1728 
// Parser body for InParallelOp: parses one region followed by an optional
// attribute dictionary, and returns failure as soon as either parse fails.
// NOTE(review): the signature line and the declaration of `regionOperands`
// are absent from this listing (the embedded numbering skips 1729 and 1732)
// -- confirm against the original file.
1730  auto &builder = parser.getBuilder();
1731 
1733  std::unique_ptr<Region> region = std::make_unique<Region>();
1734  if (parser.parseRegion(*region, regionOperands))
1735  return failure();
1736 
 // A region parsed without any block still gets one (empty) block, so the
 // op always carries a block in its region.
1737  if (region->empty())
1738  OpBuilder(builder.getContext()).createBlock(region.get());
1739  result.addRegion(std::move(region));
1740 
1741  // Parse the optional attribute list.
1742  if (parser.parseOptionalAttrDict(result.attributes))
1743  return failure();
1744  return success();
1745 }
1746 
1747 OpResult InParallelOp::getParentResult(int64_t idx) {
1748  return getOperation()->getParentOp()->getResult(idx);
1749 }
1750 
1751 SmallVector<BlockArgument> InParallelOp::getDests() {
1752  return llvm::to_vector<4>(
1753  llvm::map_range(getYieldingOps(), [](Operation &op) {
1754  // Add new ops here as needed.
1755  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1756  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1757  }));
1758 }
1759 
1760 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1761  return getRegion().front().getOperations();
1762 }
1763 
1764 //===----------------------------------------------------------------------===//
1765 // IfOp
1766 //===----------------------------------------------------------------------===//
1767 
// Returns true when `a` and `b` sit in different branches of a common
// enclosing scf.if, and false when no scf.if ancestor of `a` also contains
// `b`. The search walks outward from `a`, so the innermost scf.if that
// contains both operations decides the answer.
// NOTE(review): the signature line is absent from this listing (the embedded
// numbering skips 1768) -- confirm name and parameter types against the
// original file.
1769  assert(a && "expected non-empty operation");
1770  assert(b && "expected non-empty operation");
1771 
1772  IfOp ifOp = a->getParentOfType<IfOp>();
1773  while (ifOp) {
1774  // Check if b is inside ifOp. (We already know that a is.)
1775  if (ifOp->isProperAncestor(b))
1776  // b is contained in ifOp. a and b are in mutually exclusive branches if
1777  // they are in different blocks of ifOp.
 // Each side is true iff the op has an ancestor in the `then` block; the
 // branches differ exactly when the two flags differ.
1778  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1779  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1780  // Check next enclosing IfOp.
1781  ifOp = ifOp->getParentOfType<IfOp>();
1782  }
1783 
1784  // Could not find a common IfOp among a's and b's ancestors.
1785  return false;
1786 }
1787 
1789 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1790  IfOp::Adaptor adaptor,
1791  SmallVectorImpl<Type> &inferredReturnTypes) {
1792  if (adaptor.getRegions().empty())
1793  return failure();
1794  Region *r = &adaptor.getThenRegion();
1795  if (r->empty())
1796  return failure();
1797  Block &b = r->front();
1798  if (b.empty())
1799  return failure();
1800  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1801  if (!yieldOp)
1802  return failure();
1803  TypeRange types = yieldOp.getOperandTypes();
1804  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1805  types.end());
1806  return success();
1807 }
1808 
1809 void IfOp::build(OpBuilder &builder, OperationState &result,
1810  TypeRange resultTypes, Value cond) {
1811  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1812  /*addElseBlock=*/false);
1813 }
1814 
1815 void IfOp::build(OpBuilder &builder, OperationState &result,
1816  TypeRange resultTypes, Value cond, bool addThenBlock,
1817  bool addElseBlock) {
1818  assert((!addElseBlock || addThenBlock) &&
1819  "must not create else block w/o then block");
1820  result.addTypes(resultTypes);
1821  result.addOperands(cond);
1822 
1823  // Add regions and blocks.
1824  OpBuilder::InsertionGuard guard(builder);
1825  Region *thenRegion = result.addRegion();
1826  if (addThenBlock)
1827  builder.createBlock(thenRegion);
1828  Region *elseRegion = result.addRegion();
1829  if (addElseBlock)
1830  builder.createBlock(elseRegion);
1831 }
1832 
1833 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1834  bool withElseRegion) {
1835  build(builder, result, TypeRange{}, cond, withElseRegion);
1836 }
1837 
1838 void IfOp::build(OpBuilder &builder, OperationState &result,
1839  TypeRange resultTypes, Value cond, bool withElseRegion) {
1840  result.addTypes(resultTypes);
1841  result.addOperands(cond);
1842 
1843  // Build then region.
1844  OpBuilder::InsertionGuard guard(builder);
1845  Region *thenRegion = result.addRegion();
1846  builder.createBlock(thenRegion);
1847  if (resultTypes.empty())
1848  IfOp::ensureTerminator(*thenRegion, builder, result.location);
1849 
1850  // Build else region.
1851  Region *elseRegion = result.addRegion();
1852  if (withElseRegion) {
1853  builder.createBlock(elseRegion);
1854  if (resultTypes.empty())
1855  IfOp::ensureTerminator(*elseRegion, builder, result.location);
1856  }
1857 }
1858 
1859 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1860  function_ref<void(OpBuilder &, Location)> thenBuilder,
1861  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1862  assert(thenBuilder && "the builder callback for 'then' must be present");
1863  result.addOperands(cond);
1864 
1865  // Build then region.
1866  OpBuilder::InsertionGuard guard(builder);
1867  Region *thenRegion = result.addRegion();
1868  builder.createBlock(thenRegion);
1869  thenBuilder(builder, result.location);
1870 
1871  // Build else region.
1872  Region *elseRegion = result.addRegion();
1873  if (elseBuilder) {
1874  builder.createBlock(elseRegion);
1875  elseBuilder(builder, result.location);
1876  }
1877 
1878  // Infer result types.
1879  SmallVector<Type> inferredReturnTypes;
1880  MLIRContext *ctx = builder.getContext();
1881  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
1882  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
1883  /*properties=*/nullptr, result.regions,
1884  inferredReturnTypes))) {
1885  result.addTypes(inferredReturnTypes);
1886  }
1887 }
1888 
// Verification body for IfOp: an op that defines results must have a
// non-empty `else` region.
// NOTE(review): the signature line is absent from this listing (the embedded
// numbering skips 1889) -- confirm against the original file.
1890  if (getNumResults() != 0 && getElseRegion().empty())
1891  return emitOpError("must have an else block if defining values");
1892  return success();
1893 }
1894 
// Parser body for IfOp. Parses, in order: the condition operand (resolved as
// a 1-bit integer), an optional arrow type list for the results, the `then`
// region, an optional `else` keyword followed by the `else` region, and an
// optional attribute dictionary. Every parsed region is passed through
// ensureTerminator().
// NOTE(review): the signature line and the declaration of `cond` are absent
// from this listing (the embedded numbering skips 1895 and 1902) -- confirm
// against the original file.
1896  // Create the 'then' and 'else' regions.
1897  result.regions.reserve(2);
1898  Region *thenRegion = result.addRegion();
1899  Region *elseRegion = result.addRegion();
1900 
1901  auto &builder = parser.getBuilder();
1903  Type i1Type = builder.getIntegerType(1);
1904  if (parser.parseOperand(cond) ||
1905  parser.resolveOperand(cond, i1Type, result.operands))
1906  return failure();
1907  // Parse optional results type list.
1908  if (parser.parseOptionalArrowTypeList(result.types))
1909  return failure();
1910  // Parse the 'then' region.
1911  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1912  return failure();
1913  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1914 
1915  // If we find an 'else' keyword then parse the 'else' region.
 // Without the keyword the 'else' region stays empty (it was still added
 // to `result` above).
1916  if (!parser.parseOptionalKeyword("else")) {
1917  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1918  return failure();
1919  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1920  }
1921 
1922  // Parse the optional attribute list.
1923  if (parser.parseOptionalAttrDict(result.attributes))
1924  return failure();
1925  return success();
1926 }
1927 
1928 void IfOp::print(OpAsmPrinter &p) {
1929  bool printBlockTerminators = false;
1930 
1931  p << " " << getCondition();
1932  if (!getResults().empty()) {
1933  p << " -> (" << getResultTypes() << ")";
1934  // Print yield explicitly if the op defines values.
1935  printBlockTerminators = true;
1936  }
1937  p << ' ';
1938  p.printRegion(getThenRegion(),
1939  /*printEntryBlockArgs=*/false,
1940  /*printBlockTerminators=*/printBlockTerminators);
1941 
1942  // Print the 'else' regions if it exists and has a block.
1943  auto &elseRegion = getElseRegion();
1944  if (!elseRegion.empty()) {
1945  p << " else ";
1946  p.printRegion(elseRegion,
1947  /*printEntryBlockArgs=*/false,
1948  /*printBlockTerminators=*/printBlockTerminators);
1949  }
1950 
1951  p.printOptionalAttrDict((*this)->getAttrs());
1952 }
1953 
1954 void IfOp::getSuccessorRegions(RegionBranchPoint point,
1956  // The `then` and the `else` region branch back to the parent operation.
1957  if (!point.isParent()) {
1958  regions.push_back(RegionSuccessor(getResults()));
1959  return;
1960  }
1961 
1962  regions.push_back(RegionSuccessor(&getThenRegion()));
1963 
1964  // Don't consider the else region if it is empty.
1965  Region *elseRegion = &this->getElseRegion();
1966  if (elseRegion->empty())
1967  regions.push_back(RegionSuccessor());
1968  else
1969  regions.push_back(RegionSuccessor(elseRegion));
1970 }
1971 
1972 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
1974  FoldAdaptor adaptor(operands, *this);
1975  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
1976  if (!boolAttr || boolAttr.getValue())
1977  regions.emplace_back(&getThenRegion());
1978 
1979  // If the else region is empty, execution continues after the parent op.
1980  if (!boolAttr || !boolAttr.getValue()) {
1981  if (!getElseRegion().empty())
1982  regions.emplace_back(&getElseRegion());
1983  else
1984  regions.emplace_back(getResults());
1985  }
1986 }
1987 
1988 LogicalResult IfOp::fold(FoldAdaptor adaptor,
1989  SmallVectorImpl<OpFoldResult> &results) {
1990  // if (!c) then A() else B() -> if c then B() else A()
1991  if (getElseRegion().empty())
1992  return failure();
1993 
1994  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1995  if (!xorStmt)
1996  return failure();
1997 
1998  if (!matchPattern(xorStmt.getRhs(), m_One()))
1999  return failure();
2000 
2001  getConditionMutable().assign(xorStmt.getLhs());
2002  Block *thenBlock = &getThenRegion().front();
2003  // It would be nicer to use iplist::swap, but that has no implemented
2004  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2005  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2006  getElseRegion().getBlocks());
2007  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2008  getThenRegion().getBlocks(), thenBlock);
2009  return success();
2010 }
2011 
2012 void IfOp::getRegionInvocationBounds(
2013  ArrayRef<Attribute> operands,
2014  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2015  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2016  // If the condition is known, then one region is known to be executed once
2017  // and the other zero times.
2018  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2019  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2020  } else {
2021  // Non-constant condition. Each region may be executed 0 or 1 times.
2022  invocationBounds.assign(2, {0, 1});
2023  }
2024 }
2025 
2026 namespace {
2027 // Pattern to remove unused IfOp results.
2028 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2030 
2031  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2032  PatternRewriter &rewriter) const {
2033  // Move all operations to the destination block.
2034  rewriter.mergeBlocks(source, dest);
2035  // Replace the yield op by one that returns only the used values.
2036  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2037  SmallVector<Value, 4> usedOperands;
2038  llvm::transform(usedResults, std::back_inserter(usedOperands),
2039  [&](OpResult result) {
2040  return yieldOp.getOperand(result.getResultNumber());
2041  });
2042  rewriter.modifyOpInPlace(yieldOp,
2043  [&]() { yieldOp->setOperands(usedOperands); });
2044  }
2045 
2046  LogicalResult matchAndRewrite(IfOp op,
2047  PatternRewriter &rewriter) const override {
2048  // Compute the list of used results.
2049  SmallVector<OpResult, 4> usedResults;
2050  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2051  [](OpResult result) { return !result.use_empty(); });
2052 
2053  // Replace the operation if only a subset of its results have uses.
2054  if (usedResults.size() == op.getNumResults())
2055  return failure();
2056 
2057  // Compute the result types of the replacement operation.
2058  SmallVector<Type, 4> newTypes;
2059  llvm::transform(usedResults, std::back_inserter(newTypes),
2060  [](OpResult result) { return result.getType(); });
2061 
2062  // Create a replacement operation with empty then and else regions.
2063  auto newOp =
2064  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2065  rewriter.createBlock(&newOp.getThenRegion());
2066  rewriter.createBlock(&newOp.getElseRegion());
2067 
2068  // Move the bodies and replace the terminators (note there is a then and
2069  // an else region since the operation returns results).
2070  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2071  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2072 
2073  // Replace the operation by the new one.
2074  SmallVector<Value, 4> repResults(op.getNumResults());
2075  for (const auto &en : llvm::enumerate(usedResults))
2076  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2077  rewriter.replaceOp(op, repResults);
2078  return success();
2079  }
2080 };
2081 
2082 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2084 
2085  LogicalResult matchAndRewrite(IfOp op,
2086  PatternRewriter &rewriter) const override {
2087  BoolAttr condition;
2088  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2089  return failure();
2090 
2091  if (condition.getValue())
2092  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2093  else if (!op.getElseRegion().empty())
2094  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2095  else
2096  rewriter.eraseOp(op);
2097 
2098  return success();
2099  }
2100 };
2101 
2102 /// Hoist any yielded results whose operands are defined outside
2103 /// the if, to a select instruction.
2104 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2106 
2107  LogicalResult matchAndRewrite(IfOp op,
2108  PatternRewriter &rewriter) const override {
2109  if (op->getNumResults() == 0)
2110  return failure();
2111 
2112  auto cond = op.getCondition();
2113  auto thenYieldArgs = op.thenYield().getOperands();
2114  auto elseYieldArgs = op.elseYield().getOperands();
2115 
2116  SmallVector<Type> nonHoistable;
2117  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2118  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2119  &op.getElseRegion() == falseVal.getParentRegion())
2120  nonHoistable.push_back(trueVal.getType());
2121  }
2122  // Early exit if there aren't any yielded values we can
2123  // hoist outside the if.
2124  if (nonHoistable.size() == op->getNumResults())
2125  return failure();
2126 
2127  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2128  /*withElseRegion=*/false);
2129  if (replacement.thenBlock())
2130  rewriter.eraseBlock(replacement.thenBlock());
2131  replacement.getThenRegion().takeBody(op.getThenRegion());
2132  replacement.getElseRegion().takeBody(op.getElseRegion());
2133 
2134  SmallVector<Value> results(op->getNumResults());
2135  assert(thenYieldArgs.size() == results.size());
2136  assert(elseYieldArgs.size() == results.size());
2137 
2138  SmallVector<Value> trueYields;
2139  SmallVector<Value> falseYields;
2140  rewriter.setInsertionPoint(replacement);
2141  for (const auto &it :
2142  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2143  Value trueVal = std::get<0>(it.value());
2144  Value falseVal = std::get<1>(it.value());
2145  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2146  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2147  results[it.index()] = replacement.getResult(trueYields.size());
2148  trueYields.push_back(trueVal);
2149  falseYields.push_back(falseVal);
2150  } else if (trueVal == falseVal)
2151  results[it.index()] = trueVal;
2152  else
2153  results[it.index()] = rewriter.create<arith::SelectOp>(
2154  op.getLoc(), cond, trueVal, falseVal);
2155  }
2156 
2157  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2158  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2159 
2160  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2161  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2162 
2163  rewriter.replaceOp(op, results);
2164  return success();
2165  }
2166 };
2167 
2168 /// Allow the true region of an if to assume the condition is true
2169 /// and vice versa. For example:
2170 ///
2171 /// scf.if %cmp {
2172 /// print(%cmp)
2173 /// }
2174 ///
2175 /// becomes
2176 ///
2177 /// scf.if %cmp {
2178 /// print(true)
2179 /// }
2180 ///
2181 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2183 
2184  LogicalResult matchAndRewrite(IfOp op,
2185  PatternRewriter &rewriter) const override {
2186  // Early exit if the condition is constant since replacing a constant
2187  // in the body with another constant isn't a simplification.
2188  if (matchPattern(op.getCondition(), m_Constant()))
2189  return failure();
2190 
2191  bool changed = false;
2192  mlir::Type i1Ty = rewriter.getI1Type();
2193 
2194  // These variables serve to prevent creating duplicate constants
2195  // and hold constant true or false values.
2196  Value constantTrue = nullptr;
2197  Value constantFalse = nullptr;
2198 
2199  for (OpOperand &use :
2200  llvm::make_early_inc_range(op.getCondition().getUses())) {
2201  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2202  changed = true;
2203 
2204  if (!constantTrue)
2205  constantTrue = rewriter.create<arith::ConstantOp>(
2206  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2207 
2208  rewriter.modifyOpInPlace(use.getOwner(),
2209  [&]() { use.set(constantTrue); });
2210  } else if (op.getElseRegion().isAncestor(
2211  use.getOwner()->getParentRegion())) {
2212  changed = true;
2213 
2214  if (!constantFalse)
2215  constantFalse = rewriter.create<arith::ConstantOp>(
2216  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2217 
2218  rewriter.modifyOpInPlace(use.getOwner(),
2219  [&]() { use.set(constantFalse); });
2220  }
2221  }
2222 
2223  return success(changed);
2224  }
2225 };
2226 
2227 /// Remove any statements from an if that are equivalent to the condition
2228 /// or its negation. For example:
2229 ///
2230 /// %res:2 = scf.if %cmp {
2231 /// yield something(), true
2232 /// } else {
2233 /// yield something2(), false
2234 /// }
2235 /// print(%res#1)
2236 ///
2237 /// becomes
2238 /// %res = scf.if %cmp {
2239 /// yield something()
2240 /// } else {
2241 /// yield something2()
2242 /// }
2243 /// print(%cmp)
2244 ///
2245 /// Additionally if both branches yield the same value, replace all uses
2246 /// of the result with the yielded value.
2247 ///
2248 /// %res:2 = scf.if %cmp {
2249 /// yield something(), %arg1
2250 /// } else {
2251 /// yield something2(), %arg1
2252 /// }
2253 /// print(%res#1)
2254 ///
2255 /// becomes
2256 /// %res = scf.if %cmp {
2257 /// yield something()
2258 /// } else {
2259 /// yield something2()
2260 /// }
2261 /// print(%arg1)
2262 ///
2263 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2265 
2266  LogicalResult matchAndRewrite(IfOp op,
2267  PatternRewriter &rewriter) const override {
2268  // Early exit if there are no results that could be replaced.
2269  if (op.getNumResults() == 0)
2270  return failure();
2271 
2272  auto trueYield =
2273  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2274  auto falseYield =
2275  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2276 
2277  rewriter.setInsertionPoint(op->getBlock(),
2278  op.getOperation()->getIterator());
2279  bool changed = false;
2280  Type i1Ty = rewriter.getI1Type();
2281  for (auto [trueResult, falseResult, opResult] :
2282  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2283  op.getResults())) {
2284  if (trueResult == falseResult) {
2285  if (!opResult.use_empty()) {
2286  opResult.replaceAllUsesWith(trueResult);
2287  changed = true;
2288  }
2289  continue;
2290  }
2291 
2292  BoolAttr trueYield, falseYield;
2293  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2294  !matchPattern(falseResult, m_Constant(&falseYield)))
2295  continue;
2296 
2297  bool trueVal = trueYield.getValue();
2298  bool falseVal = falseYield.getValue();
2299  if (!trueVal && falseVal) {
2300  if (!opResult.use_empty()) {
2301  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2302  Value notCond = rewriter.create<arith::XOrIOp>(
2303  op.getLoc(), op.getCondition(),
2304  constDialect
2305  ->materializeConstant(rewriter,
2306  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2307  op.getLoc())
2308  ->getResult(0));
2309  opResult.replaceAllUsesWith(notCond);
2310  changed = true;
2311  }
2312  }
2313  if (trueVal && !falseVal) {
2314  if (!opResult.use_empty()) {
2315  opResult.replaceAllUsesWith(op.getCondition());
2316  changed = true;
2317  }
2318  }
2319  }
2320  return success(changed);
2321  }
2322 };
2323 
2324 /// Merge any consecutive scf.if's with the same condition.
2325 ///
2326 /// scf.if %cond {
2327 /// firstCodeTrue();...
2328 /// } else {
2329 /// firstCodeFalse();...
2330 /// }
2331 /// %res = scf.if %cond {
2332 /// secondCodeTrue();...
2333 /// } else {
2334 /// secondCodeFalse();...
2335 /// }
2336 ///
2337 /// becomes
2338 /// %res = scf.if %cmp {
2339 /// firstCodeTrue();...
2340 /// secondCodeTrue();...
2341 /// } else {
2342 /// firstCodeFalse();...
2343 /// secondCodeFalse();...
2344 /// }
2345 struct CombineIfs : public OpRewritePattern<IfOp> {
2347 
2348  LogicalResult matchAndRewrite(IfOp nextIf,
2349  PatternRewriter &rewriter) const override {
2350  Block *parent = nextIf->getBlock();
2351  if (nextIf == &parent->front())
2352  return failure();
2353 
2354  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2355  if (!prevIf)
2356  return failure();
2357 
2358  // Determine the logical then/else blocks when prevIf's
2359  // condition is used. Null means the block does not exist
2360  // in that case (e.g. empty else). If neither of these
2361  // are set, the two conditions cannot be compared.
2362  Block *nextThen = nullptr;
2363  Block *nextElse = nullptr;
2364  if (nextIf.getCondition() == prevIf.getCondition()) {
2365  nextThen = nextIf.thenBlock();
2366  if (!nextIf.getElseRegion().empty())
2367  nextElse = nextIf.elseBlock();
2368  }
2369  if (arith::XOrIOp notv =
2370  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2371  if (notv.getLhs() == prevIf.getCondition() &&
2372  matchPattern(notv.getRhs(), m_One())) {
2373  nextElse = nextIf.thenBlock();
2374  if (!nextIf.getElseRegion().empty())
2375  nextThen = nextIf.elseBlock();
2376  }
2377  }
2378  if (arith::XOrIOp notv =
2379  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2380  if (notv.getLhs() == nextIf.getCondition() &&
2381  matchPattern(notv.getRhs(), m_One())) {
2382  nextElse = nextIf.thenBlock();
2383  if (!nextIf.getElseRegion().empty())
2384  nextThen = nextIf.elseBlock();
2385  }
2386  }
2387 
2388  if (!nextThen && !nextElse)
2389  return failure();
2390 
2391  SmallVector<Value> prevElseYielded;
2392  if (!prevIf.getElseRegion().empty())
2393  prevElseYielded = prevIf.elseYield().getOperands();
2394  // Replace all uses of return values of op within nextIf with the
2395  // corresponding yields
2396  for (auto it : llvm::zip(prevIf.getResults(),
2397  prevIf.thenYield().getOperands(), prevElseYielded))
2398  for (OpOperand &use :
2399  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2400  if (nextThen && nextThen->getParent()->isAncestor(
2401  use.getOwner()->getParentRegion())) {
2402  rewriter.startOpModification(use.getOwner());
2403  use.set(std::get<1>(it));
2404  rewriter.finalizeOpModification(use.getOwner());
2405  } else if (nextElse && nextElse->getParent()->isAncestor(
2406  use.getOwner()->getParentRegion())) {
2407  rewriter.startOpModification(use.getOwner());
2408  use.set(std::get<2>(it));
2409  rewriter.finalizeOpModification(use.getOwner());
2410  }
2411  }
2412 
2413  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2414  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2415 
2416  IfOp combinedIf = rewriter.create<IfOp>(
2417  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2418  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2419 
2420  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2421  combinedIf.getThenRegion(),
2422  combinedIf.getThenRegion().begin());
2423 
2424  if (nextThen) {
2425  YieldOp thenYield = combinedIf.thenYield();
2426  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2427  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2428  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2429 
2430  SmallVector<Value> mergedYields(thenYield.getOperands());
2431  llvm::append_range(mergedYields, thenYield2.getOperands());
2432  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2433  rewriter.eraseOp(thenYield);
2434  rewriter.eraseOp(thenYield2);
2435  }
2436 
2437  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2438  combinedIf.getElseRegion(),
2439  combinedIf.getElseRegion().begin());
2440 
2441  if (nextElse) {
2442  if (combinedIf.getElseRegion().empty()) {
2443  rewriter.inlineRegionBefore(*nextElse->getParent(),
2444  combinedIf.getElseRegion(),
2445  combinedIf.getElseRegion().begin());
2446  } else {
2447  YieldOp elseYield = combinedIf.elseYield();
2448  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2449  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2450 
2451  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2452 
2453  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2454  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2455 
2456  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2457  rewriter.eraseOp(elseYield);
2458  rewriter.eraseOp(elseYield2);
2459  }
2460  }
2461 
2462  SmallVector<Value> prevValues;
2463  SmallVector<Value> nextValues;
2464  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2465  if (pair.index() < prevIf.getNumResults())
2466  prevValues.push_back(pair.value());
2467  else
2468  nextValues.push_back(pair.value());
2469  }
2470  rewriter.replaceOp(prevIf, prevValues);
2471  rewriter.replaceOp(nextIf, nextValues);
2472  return success();
2473  }
2474 };
2475 
2476 /// Pattern to remove an empty else branch.
2477 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2479 
2480  LogicalResult matchAndRewrite(IfOp ifOp,
2481  PatternRewriter &rewriter) const override {
2482  // Cannot remove else region when there are operation results.
2483  if (ifOp.getNumResults())
2484  return failure();
2485  Block *elseBlock = ifOp.elseBlock();
2486  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2487  return failure();
2488  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2489  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2490  newIfOp.getThenRegion().begin());
2491  rewriter.eraseOp(ifOp);
2492  return success();
2493  }
2494 };
2495 
2496 /// Convert nested `if`s into `arith.andi` + single `if`.
2497 ///
2498 /// scf.if %arg0 {
2499 /// scf.if %arg1 {
2500 /// ...
2501 /// scf.yield
2502 /// }
2503 /// scf.yield
2504 /// }
2505 /// becomes
2506 ///
2507 /// %0 = arith.andi %arg0, %arg1
2508 /// scf.if %0 {
2509 /// ...
2510 /// scf.yield
2511 /// }
2512 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2514 
2515  LogicalResult matchAndRewrite(IfOp op,
2516  PatternRewriter &rewriter) const override {
2517  auto nestedOps = op.thenBlock()->without_terminator();
2518  // Nested `if` must be the only op in block.
2519  if (!llvm::hasSingleElement(nestedOps))
2520  return failure();
2521 
2522  // If there is an else block, it can only yield
2523  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2524  return failure();
2525 
2526  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2527  if (!nestedIf)
2528  return failure();
2529 
2530  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2531  return failure();
2532 
2533  SmallVector<Value> thenYield(op.thenYield().getOperands());
2534  SmallVector<Value> elseYield;
2535  if (op.elseBlock())
2536  llvm::append_range(elseYield, op.elseYield().getOperands());
2537 
2538  // A list of indices for which we should upgrade the value yielded
2539  // in the else to a select.
2540  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2541 
2542  // If the outer scf.if yields a value produced by the inner scf.if,
2543  // only permit combining if the value yielded when the condition
2544  // is false in the outer scf.if is the same value yielded when the
2545  // inner scf.if condition is false.
2546  // Note that the array access to elseYield will not go out of bounds
2547  // since it must have the same length as thenYield, since they both
2548  // come from the same scf.if.
2549  for (const auto &tup : llvm::enumerate(thenYield)) {
2550  if (tup.value().getDefiningOp() == nestedIf) {
2551  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2552  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2553  elseYield[tup.index()]) {
2554  return failure();
2555  }
2556  // If the correctness test passes, we will yield
2557  // corresponding value from the inner scf.if
2558  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2559  continue;
2560  }
2561 
2562  // Otherwise, we need to ensure the else block of the combined
2563  // condition still returns the same value when the outer condition is
2564  // true and the inner condition is false. This can be accomplished if
2565  // the then value is defined outside the outer scf.if and we replace the
2566  // value with a select that considers just the outer condition. Since
2567  // the else region contains just the yield, its yielded value is
2568  // defined outside the scf.if, by definition.
2569 
2570  // If the then value is defined within the scf.if, bail.
2571  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2572  return failure();
2573  }
2574  elseYieldsToUpgradeToSelect.push_back(tup.index());
2575  }
2576 
2577  Location loc = op.getLoc();
2578  Value newCondition = rewriter.create<arith::AndIOp>(
2579  loc, op.getCondition(), nestedIf.getCondition());
2580  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2581  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2582 
2583  SmallVector<Value> results;
2584  llvm::append_range(results, newIf.getResults());
2585  rewriter.setInsertionPoint(newIf);
2586 
2587  for (auto idx : elseYieldsToUpgradeToSelect)
2588  results[idx] = rewriter.create<arith::SelectOp>(
2589  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2590 
2591  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2592  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2593  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2594  if (!elseYield.empty()) {
2595  rewriter.createBlock(&newIf.getElseRegion());
2596  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2597  rewriter.create<YieldOp>(loc, elseYield);
2598  }
2599  rewriter.replaceOp(op, results);
2600  return success();
2601  }
2602 };
2603 
2604 } // namespace
2605 
2606 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2607  MLIRContext *context) {
2608  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2609  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2610  RemoveStaticCondition, RemoveUnusedResults,
2611  ReplaceIfYieldWithConditionOrValue>(context);
2612 }
2613 
2614 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2615 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2616 Block *IfOp::elseBlock() {
2617  Region &r = getElseRegion();
2618  if (r.empty())
2619  return nullptr;
2620  return &r.back();
2621 }
2622 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2623 
2624 //===----------------------------------------------------------------------===//
2625 // ParallelOp
2626 //===----------------------------------------------------------------------===//
2627 
2628 void ParallelOp::build(
2629  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2630  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2632  bodyBuilderFn) {
2633  result.addOperands(lowerBounds);
2634  result.addOperands(upperBounds);
2635  result.addOperands(steps);
2636  result.addOperands(initVals);
2637  result.addAttribute(
2638  ParallelOp::getOperandSegmentSizeAttr(),
2639  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2640  static_cast<int32_t>(upperBounds.size()),
2641  static_cast<int32_t>(steps.size()),
2642  static_cast<int32_t>(initVals.size())}));
2643  result.addTypes(initVals.getTypes());
2644 
2645  OpBuilder::InsertionGuard guard(builder);
2646  unsigned numIVs = steps.size();
2647  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2648  SmallVector<Location, 8> argLocs(numIVs, result.location);
2649  Region *bodyRegion = result.addRegion();
2650  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2651 
2652  if (bodyBuilderFn) {
2653  builder.setInsertionPointToStart(bodyBlock);
2654  bodyBuilderFn(builder, result.location,
2655  bodyBlock->getArguments().take_front(numIVs),
2656  bodyBlock->getArguments().drop_front(numIVs));
2657  }
2658  // Add terminator only if there are no reductions.
2659  if (initVals.empty())
2660  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2661 }
2662 
2663 void ParallelOp::build(
2664  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2665  ValueRange upperBounds, ValueRange steps,
2666  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2667  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2668  // we don't capture a reference to a temporary by constructing the lambda at
2669  // function level.
2670  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2671  Location nestedLoc, ValueRange ivs,
2672  ValueRange) {
2673  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2674  };
2675  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2676  if (bodyBuilderFn)
2677  wrapper = wrappedBuilderFn;
2678 
2679  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2680  wrapper);
2681 }
2682 
// Verifier body for the parallel loop op. In order, it checks that:
//   * there is at least one step operand,
//   * every step with a known constant value is strictly positive,
//   * the body block has one `index`-typed argument per step,
//   * the body region ends with `scf.reduce`,
//   * the number of results equals both the number of reductions and the
//     number of initial values,
//   * each result type equals the type of the matching reduction operand.
// The first violated check is reported as an op error; otherwise success.
// NOTE(review): the line carrying this function's signature is absent from
// this listing (the embedded numbering jumps from 2682 to 2684) — confirm it
// against the original file.
2684  // Check that there is at least one value in lowerBound, upperBound and step.
2685  // It is sufficient to test only step, because it is ensured already that the
2686  // number of elements in lowerBound, upperBound and step are the same.
2687  Operation::operand_range stepValues = getStep();
2688  if (stepValues.empty())
2689  return emitOpError(
2690  "needs at least one tuple element for lowerBound, upperBound and step");
2691
2692  // Check whether all constant step values are positive.
   // Steps without a known constant value are not checked here.
2693  for (Value stepValue : stepValues)
2694  if (auto cst = getConstantIntValue(stepValue))
2695  if (*cst <= 0)
2696  return emitOpError("constant step operand must be positive");
2697
2698  // Check that the body defines the same number of block arguments as the
2699  // number of tuple elements in step.
2700  Block *body = getBody();
2701  if (body->getNumArguments() != stepValues.size())
2702  return emitOpError() << "expects the same number of induction variables: "
2703  << body->getNumArguments()
2704  << " as bound and step values: " << stepValues.size();
2705  for (auto arg : body->getArguments())
2706  if (!arg.getType().isIndex())
2707  return emitOpError(
2708  "expects arguments for the induction variable to be of index type");
2709
2710  // Check that the terminator is an scf.reduce op.
2711  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2712  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2713  if (!reduceOp)
2714  return failure();
2715
2716  // Check that the number of results is the same as the number of reductions.
2717  auto resultsSize = getResults().size();
2718  auto reductionsSize = reduceOp.getReductions().size();
2719  auto initValsSize = getInitVals().size();
2720  if (resultsSize != reductionsSize)
2721  return emitOpError() << "expects number of results: " << resultsSize
2722  << " to be the same as number of reductions: "
2723  << reductionsSize;
2724  if (resultsSize != initValsSize)
2725  return emitOpError() << "expects number of results: " << resultsSize
2726  << " to be the same as number of initial values: "
2727  << initValsSize;
2728
2729  // Check that the types of the results and reductions are the same.
   // The mismatch is reported on the scf.reduce terminator, not on this op.
2730  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2731  auto resultType = getOperation()->getResult(i).getType();
2732  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2733  if (resultType != reductionOperandType)
2734  return reduceOp.emitOpError()
2735  << "expects type of " << i
2736  << "-th reduction operand: " << reductionOperandType
2737  << " to be the same as the " << i
2738  << "-th result type: " << resultType;
2739  }
2740  return success();
2741 }
2742 
// Custom assembly parser body for the parallel loop op. It reads, in order:
// the parenthesized induction variables, `=` and the lower bounds, `to` and
// the upper bounds, `step` and the steps, an optional `init (...)` list, an
// optional arrow type list for the results, the body region and an optional
// attribute dictionary. Bounds and steps are resolved as `index` operands;
// init values are resolved against the parsed result types after the region
// has been parsed. Any parse or resolution error returns failure().
// NOTE(review): several lines are absent from this listing — the function
// signature (numbering jumps 2742 -> 2744), the induction-variable parsing
// (2745 -> 2748), and the declarations of `lower`, `upper`, `steps`,
// `initVals` together with the delimiter argument of the bound/step
// `parseOperandList` calls. Confirm against the original file.
2744  auto &builder = parser.getBuilder();
2745  // Parse an opening `(` followed by induction variables followed by `)`
2748  return failure();
2749
2750  // Parse loop bounds.
2752  if (parser.parseEqual() ||
2753  parser.parseOperandList(lower, ivs.size(),
2755  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2756  return failure();
2757
2759  if (parser.parseKeyword("to") ||
2760  parser.parseOperandList(upper, ivs.size(),
2762  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2763  return failure();
2764
2765  // Parse step values.
2767  if (parser.parseKeyword("step") ||
2768  parser.parseOperandList(steps, ivs.size(),
2770  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2771  return failure();
2772
2773  // Parse init values.
2775  if (succeeded(parser.parseOptionalKeyword("init"))) {
2776  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2777  return failure();
2778  }
2779
2780  // Parse optional results in case there is a reduce.
2781  if (parser.parseOptionalArrowTypeList(result.types))
2782  return failure();
2783
2784  // Now parse the body.
   // Every induction variable becomes an `index`-typed region argument.
2785  Region *body = result.addRegion();
2786  for (auto &iv : ivs)
2787  iv.type = builder.getIndexType();
2788  if (parser.parseRegion(*body, ivs))
2789  return failure();
2790
2791  // Set `operandSegmentSizes` attribute.
2792  result.addAttribute(
2793  ParallelOp::getOperandSegmentSizeAttr(),
2794  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2795  static_cast<int32_t>(upper.size()),
2796  static_cast<int32_t>(steps.size()),
2797  static_cast<int32_t>(initVals.size())}));
2798
2799  // Parse attributes.
   // The init values are resolved here, so they are appended to the operand
   // list after the bounds and steps resolved above.
2800  if (parser.parseOptionalAttrDict(result.attributes) ||
2801  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2802  result.operands))
2803  return failure();
2804
2805  // Add a terminator if none was parsed.
2806  ParallelOp::ensureTerminator(*body, builder, result.location);
2807  return success();
2808 }
2809 
// Custom assembly printer body for the parallel loop op. Prints the body
// block arguments, the lower bounds, upper bounds and steps, then the
// `init (...)` list when there are init values, the optional arrow result
// type list and the body region without its entry block arguments.
// NOTE(review): the function signature (numbering jumps 2809 -> 2811) and the
// call that receives the two arguments on the lines numbered 2819-2820
// (2817 -> 2819) are absent from this listing — confirm against the original
// file.
2811  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2812  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2813  if (!getInitVals().empty())
2814  p << " init (" << getInitVals() << ")";
2815  p.printOptionalArrowTypeList(getResultTypes());
2816  p << ' ';
2817  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2819  (*this)->getAttrs(),
2820  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2821 }
2822 
2823 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
2824 
2825 std::optional<Value> ParallelOp::getSingleInductionVar() {
2826  if (getNumLoops() != 1)
2827  return std::nullopt;
2828  return getBody()->getArgument(0);
2829 }
2830 
2831 std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
2832  if (getNumLoops() != 1)
2833  return std::nullopt;
2834  return getLowerBound()[0];
2835 }
2836 
2837 std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
2838  if (getNumLoops() != 1)
2839  return std::nullopt;
2840  return getUpperBound()[0];
2841 }
2842 
2843 std::optional<OpFoldResult> ParallelOp::getSingleStep() {
2844  if (getNumLoops() != 1)
2845  return std::nullopt;
2846  return getStep()[0];
2847 }
2848 
// Maps a value `val` to the parallel loop that owns it as a block argument.
// Returns a null ParallelOp when `val` is not a block argument; otherwise
// returns the operation that contains the argument's block if that operation
// is a ParallelOp, and a null ParallelOp if it is some other operation.
// NOTE(review): the line carrying this function's signature is absent from
// this listing (the embedded numbering jumps from 2848 to 2850) — confirm it
// against the original file.
2850  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2851  if (!ivArg)
2852  return ParallelOp();
2853  assert(ivArg.getOwner() && "unlinked block argument");
2854  auto *containingOp = ivArg.getOwner()->getParentOp();
2855  return dyn_cast<ParallelOp>(containingOp);
2856 }
2857 
2858 namespace {
2859 // Collapse loop dimensions that perform a single iteration.
2860 struct ParallelOpSingleOrZeroIterationDimsFolder
2861  : public OpRewritePattern<ParallelOp> {
2863 
2864  LogicalResult matchAndRewrite(ParallelOp op,
2865  PatternRewriter &rewriter) const override {
2866  Location loc = op.getLoc();
2867 
2868  // Compute new loop bounds that omit all single-iteration loop dimensions.
2869  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
2870  IRMapping mapping;
2871  for (auto [lb, ub, step, iv] :
2872  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2873  op.getInductionVars())) {
2874  auto numIterations = constantTripCount(lb, ub, step);
2875  if (numIterations.has_value()) {
2876  // Remove the loop if it performs zero iterations.
2877  if (*numIterations == 0) {
2878  rewriter.replaceOp(op, op.getInitVals());
2879  return success();
2880  }
2881  // Replace the loop induction variable by the lower bound if the loop
2882  // performs a single iteration. Otherwise, copy the loop bounds.
2883  if (*numIterations == 1) {
2884  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
2885  continue;
2886  }
2887  }
2888  newLowerBounds.push_back(lb);
2889  newUpperBounds.push_back(ub);
2890  newSteps.push_back(step);
2891  }
2892  // Exit if none of the loop dimensions perform a single iteration.
2893  if (newLowerBounds.size() == op.getLowerBound().size())
2894  return failure();
2895 
2896  if (newLowerBounds.empty()) {
2897  // All of the loop dimensions perform a single iteration. Inline
2898  // loop body and nested ReduceOp's
2899  SmallVector<Value> results;
2900  results.reserve(op.getInitVals().size());
2901  for (auto &bodyOp : op.getBody()->without_terminator())
2902  rewriter.clone(bodyOp, mapping);
2903  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
2904  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
2905  Block &reduceBlock = reduceOp.getReductions()[i].front();
2906  auto initValIndex = results.size();
2907  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2908  mapping.map(reduceBlock.getArgument(1),
2909  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
2910  for (auto &reduceBodyOp : reduceBlock.without_terminator())
2911  rewriter.clone(reduceBodyOp, mapping);
2912 
2913  auto result = mapping.lookupOrDefault(
2914  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
2915  results.push_back(result);
2916  }
2917 
2918  rewriter.replaceOp(op, results);
2919  return success();
2920  }
2921  // Replace the parallel loop by lower-dimensional parallel loop.
2922  auto newOp =
2923  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2924  newSteps, op.getInitVals(), nullptr);
2925  // Erase the empty block that was inserted by the builder.
2926  rewriter.eraseBlock(newOp.getBody());
2927  // Clone the loop body and remap the block arguments of the collapsed loops
2928  // (inlining does not support a cancellable block argument mapping).
2929  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
2930  newOp.getRegion().begin(), mapping);
2931  rewriter.replaceOp(op, newOp.getResults());
2932  return success();
2933  }
2934 };
2935 
2936 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
2938 
2939  LogicalResult matchAndRewrite(ParallelOp op,
2940  PatternRewriter &rewriter) const override {
2941  Block &outerBody = *op.getBody();
2942  if (!llvm::hasSingleElement(outerBody.without_terminator()))
2943  return failure();
2944 
2945  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
2946  if (!innerOp)
2947  return failure();
2948 
2949  for (auto val : outerBody.getArguments())
2950  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2951  llvm::is_contained(innerOp.getUpperBound(), val) ||
2952  llvm::is_contained(innerOp.getStep(), val))
2953  return failure();
2954 
2955  // Reductions are not supported yet.
2956  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2957  return failure();
2958 
2959  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2960  ValueRange iterVals, ValueRange) {
2961  Block &innerBody = *innerOp.getBody();
2962  assert(iterVals.size() ==
2963  (outerBody.getNumArguments() + innerBody.getNumArguments()));
2964  IRMapping mapping;
2965  mapping.map(outerBody.getArguments(),
2966  iterVals.take_front(outerBody.getNumArguments()));
2967  mapping.map(innerBody.getArguments(),
2968  iterVals.take_back(innerBody.getNumArguments()));
2969  for (Operation &op : innerBody.without_terminator())
2970  builder.clone(op, mapping);
2971  };
2972 
2973  auto concatValues = [](const auto &first, const auto &second) {
2974  SmallVector<Value> ret;
2975  ret.reserve(first.size() + second.size());
2976  ret.assign(first.begin(), first.end());
2977  ret.append(second.begin(), second.end());
2978  return ret;
2979  };
2980 
2981  auto newLowerBounds =
2982  concatValues(op.getLowerBound(), innerOp.getLowerBound());
2983  auto newUpperBounds =
2984  concatValues(op.getUpperBound(), innerOp.getUpperBound());
2985  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2986 
2987  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2988  newSteps, std::nullopt,
2989  bodyBuilder);
2990  return success();
2991  }
2992 };
2993 
2994 } // namespace
2995 
2996 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2997  MLIRContext *context) {
2998  results
2999  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3000  context);
3001 }
3002 
3003 /// Given the region at `index`, or the parent operation if `index` is None,
3004 /// return the successor regions. These are the regions that may be selected
3005 /// during the flow of control. `operands` is a set of optional attributes that
3006 /// correspond to a constant value for each operand, or null if that operand is
3007 /// not a constant.
// The body below always reports the same two successors: the loop body
// region and the parent operation.
// NOTE(review): the line declaring this function's parameters is absent from
// this listing (the embedded numbering jumps from 3008 to 3010), so the
// `index` and `operands` names used in the comment above cannot be checked
// against the signature here — confirm against the original file.
3008 void ParallelOp::getSuccessorRegions(
3010  // Both the operation itself and the region may be branching into the body or
3011  // back into the operation itself. It is possible for loop not to enter the
3012  // body.
3013  regions.push_back(RegionSuccessor(&getRegion()));
3014  regions.push_back(RegionSuccessor());
3015 }
3016 
3017 //===----------------------------------------------------------------------===//
3018 // ReduceOp
3019 //===----------------------------------------------------------------------===//
3020 
3021 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3022 
3023 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3024  ValueRange operands) {
3025  result.addOperands(operands);
3026  for (Value v : operands) {
3027  OpBuilder::InsertionGuard guard(builder);
3028  Region *bodyRegion = result.addRegion();
3029  builder.createBlock(bodyRegion, {},
3030  ArrayRef<Type>{v.getType(), v.getType()},
3031  {result.location, result.location});
3032  }
3033 }
3034 
3035 LogicalResult ReduceOp::verifyRegions() {
3036  // The region of a ReduceOp has two arguments of the same type as its
3037  // corresponding operand.
3038  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3039  auto type = getOperands()[i].getType();
3040  Block &block = getReductions()[i].front();
3041  if (block.empty())
3042  return emitOpError() << i << "-th reduction has an empty body";
3043  if (block.getNumArguments() != 2 ||
3044  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3045  return arg.getType() != type;
3046  }))
3047  return emitOpError() << "expected two block arguments with type " << type
3048  << " in the " << i << "-th reduction region";
3049 
3050  // Check that the block is terminated by a ReduceReturnOp.
3051  if (!isa<ReduceReturnOp>(block.getTerminator()))
3052  return emitOpError("reduction bodies must be terminated with an "
3053  "'scf.reduce.return' op");
3054  }
3055 
3056  return success();
3057 }
3058 
// Returns an empty mutable operand range (start 0, length 0) over this
// operation.
// NOTE(review): the lines carrying this function's signature are absent from
// this listing (the embedded numbering jumps from 3058 to 3061) — confirm
// against the original file.
3061  // No operands are forwarded to the next iteration.
3062  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3063 }
3064 
3065 //===----------------------------------------------------------------------===//
3066 // ReduceReturnOp
3067 //===----------------------------------------------------------------------===//
3068 
// Verifier body for the reduction return op: the type of the returned value
// must equal the type of the first argument of the enclosing reduction block;
// otherwise an op error is emitted.
// NOTE(review): the line carrying this function's signature is absent from
// this listing (the embedded numbering jumps from 3068 to 3070) — confirm
// against the original file.
3070  // The type of the return value should be the same type as the types of the
3071  // block arguments of the reduction body.
3072  Block *reductionBody = getOperation()->getBlock();
3073  // Should already be verified by an op trait.
3074  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3075  Type expectedResultType = reductionBody->getArgument(0).getType();
3076  if (expectedResultType != getResult().getType())
3077  return emitOpError() << "must have type " << expectedResultType
3078  << " (the type of the reduction inputs)";
3079  return success();
3080 }
3081 
3082 //===----------------------------------------------------------------------===//
3083 // WhileOp
3084 //===----------------------------------------------------------------------===//
3085 
3086 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3087  ::mlir::OperationState &odsState, TypeRange resultTypes,
3088  ValueRange operands, BodyBuilderFn beforeBuilder,
3089  BodyBuilderFn afterBuilder) {
3090  odsState.addOperands(operands);
3091  odsState.addTypes(resultTypes);
3092 
3093  OpBuilder::InsertionGuard guard(odsBuilder);
3094 
3095  // Build before region.
3096  SmallVector<Location, 4> beforeArgLocs;
3097  beforeArgLocs.reserve(operands.size());
3098  for (Value operand : operands) {
3099  beforeArgLocs.push_back(operand.getLoc());
3100  }
3101 
3102  Region *beforeRegion = odsState.addRegion();
3103  Block *beforeBlock = odsBuilder.createBlock(
3104  beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
3105  if (beforeBuilder)
3106  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3107 
3108  // Build after region.
3109  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3110 
3111  Region *afterRegion = odsState.addRegion();
3112  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3113  resultTypes, afterArgLocs);
3114 
3115  if (afterBuilder)
3116  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3117 }
3118 
3119 ConditionOp WhileOp::getConditionOp() {
3120  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3121 }
3122 
3123 YieldOp WhileOp::getYieldOp() {
3124  return cast<YieldOp>(getAfterBody()->getTerminator());
3125 }
3126 
3127 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3128  return getYieldOp().getResultsMutable();
3129 }
3130 
3131 Block::BlockArgListType WhileOp::getBeforeArguments() {
3132  return getBeforeBody()->getArguments();
3133 }
3134 
3135 Block::BlockArgListType WhileOp::getAfterArguments() {
3136  return getAfterBody()->getArguments();
3137 }
3138 
3139 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3140  return getBeforeArguments();
3141 }
3142 
3143 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3144  assert(point == getBefore() &&
3145  "WhileOp is expected to branch only to the first region");
3146  return getInits();
3147 }
3148 
// Appends the possible control-flow successors of `point` to `regions`:
//   * from the parent op: the `before` region;
//   * from the `after` region: the `before` region;
//   * from the `before` region: the op results and the `after` region.
// NOTE(review): the line declaring the `regions` parameter is absent from
// this listing (the embedded numbering jumps from 3149 to 3151) — confirm its
// type against the original file.
3149 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3151  // The parent op always branches to the condition region.
3152  if (point.isParent()) {
3153  regions.emplace_back(&getBefore(), getBefore().getArguments());
3154  return;
3155  }
3156
3157  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3158  "there are only two regions in a WhileOp");
3159  // The body region always branches back to the condition region.
3160  if (point == getAfter()) {
3161  regions.emplace_back(&getBefore(), getBefore().getArguments());
3162  return;
3163  }
3164
   // Remaining case: `point` is the `before` region.
3165  regions.emplace_back(getResults());
3166  regions.emplace_back(&getAfter(), getAfter().getArguments());
3167 }
3168 
3169 SmallVector<Region *> WhileOp::getLoopRegions() {
3170  return {&getBefore(), &getAfter()};
3171 }
3172 
3173 /// Parses a `while` op.
3174 ///
3175 /// op ::= `scf.while` assignments `:` function-type region `do` region
3176 /// `attributes` attribute-dict
3177 /// initializer ::= /* empty */ | `(` assignment-list `)`
3178 /// assignment-list ::= assignment | assignment `,` assignment-list
3179 /// assignment ::= ssa-value `=` ssa-value
// The body parses an optional assignment list, a function type, the `before`
// region, the keyword `do` and the `after` region. The function type's
// results become the op's result types; its inputs must match the number of
// parsed operands and are used both to resolve the operands and to type the
// `before` region arguments.
// NOTE(review): the function signature and the declarations of `regionArgs`
// and `operands` (embedded numbering jumps 3179 -> 3183), as well as the last
// operand of the final `failure(...)` expression (3216 -> 3218), are absent
// from this listing — confirm against the original file.
3183  Region *before = result.addRegion();
3184  Region *after = result.addRegion();
3185
   // The assignment list is optional; only a present-but-malformed list fails.
3186  OptionalParseResult listResult =
3187  parser.parseOptionalAssignmentList(regionArgs, operands);
3188  if (listResult.has_value() && failed(listResult.value()))
3189  return failure();
3190
3191  FunctionType functionType;
3192  SMLoc typeLoc = parser.getCurrentLocation();
3193  if (failed(parser.parseColonType(functionType)))
3194  return failure();
3195
3196  result.addTypes(functionType.getResults());
3197
3198  if (functionType.getNumInputs() != operands.size()) {
3199  return parser.emitError(typeLoc)
3200  << "expected as many input types as operands "
3201  << "(expected " << operands.size() << " got "
3202  << functionType.getNumInputs() << ")";
3203  }
3204
3205  // Resolve input operands.
3206  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3207  parser.getCurrentLocation(),
3208  result.operands)))
3209  return failure();
3210
3211  // Propagate the types into the region arguments.
3212  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3213  regionArgs[i].type = functionType.getInput(i);
3214
3215  return failure(parser.parseRegion(*before, regionArgs) ||
3216  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3218 }
3219 
3220 /// Prints a `while` op.
// Output order: the initialization list pairing the `before` block arguments
// with the init values, ` : ` and the functional type (init types to result
// types), the `before` region without its entry block arguments, ` do `, the
// `after` region, and the optional attribute dictionary with keyword.
// NOTE(review): the line carrying this function's signature is absent from
// this listing (the embedded numbering jumps from 3220 to 3222) — confirm
// against the original file.
3222  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3223  p << " : ";
3224  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3225  p << ' ';
3226  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3227  p << " do ";
3228  p.printRegion(getAfter());
3229  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3230 }
3231 
3232 /// Verifies that two ranges of types match, i.e. have the same number of
3233 /// entries and that types are pairwise equal. Reports errors on the given
3234 /// operation in case of mismatch.
// `message` names what is being compared and is appended to the error text.
// For a type mismatch, a note identifying the index and both types is
// attached to the error; only the first mismatch is reported.
// NOTE(review): the first line of this function's signature (return type,
// name and leading parameters) is absent from this listing (the embedded
// numbering jumps from 3235 to 3237) — confirm against the original file.
3235 template <typename OpTy>
3237  TypeRange right, StringRef message) {
3238  if (left.size() != right.size())
3239  return op.emitOpError("expects the same number of ") << message;
3240
3241  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3242  if (left[i] != right[i]) {
3243  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3244  << message;
3245  diag.attachNote() << "for argument " << i << ", found " << left[i]
3246  << " and " << right[i];
3247  return diag;
3248  }
3249  }
3250
3251  return success();
3252 }
3253 
// Verifier body for the while op: the `before` region must terminate with
// `scf.condition` and the `after` region with `scf.yield`. Returns failure
// when either terminator lookup yields null.
// NOTE(review): the line carrying this function's signature is absent from
// this listing (the embedded numbering jumps from 3253 to 3255) — confirm
// against the original file.
3255  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3256  *this, getBefore(),
3257  "expects the 'before' region to terminate with 'scf.condition'");
3258  if (!beforeTerminator)
3259  return failure();
3260
3261  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3262  *this, getAfter(),
3263  "expects the 'after' region to terminate with 'scf.yield'");
3264  return success(afterTerminator != nullptr);
3265 }
3266 
3267 namespace {
3268 /// Replace uses of the condition within the do block with true, since otherwise
3269 /// the block would not be evaluated.
3270 ///
3271 /// scf.while (..) : (i1, ...) -> ... {
3272 /// %condition = call @evaluate_condition() : () -> i1
3273 /// scf.condition(%condition) %condition : i1, ...
3274 /// } do {
3275 /// ^bb0(%arg0: i1, ...):
3276 /// use(%arg0)
3277 /// ...
3278 ///
3279 /// becomes
3280 /// scf.while (..) : (i1, ...) -> ... {
3281 /// %condition = call @evaluate_condition() : () -> i1
3282 /// scf.condition(%condition) %condition : i1, ...
3283 /// } do {
3284 /// ^bb0(%arg0: i1, ...):
3285 /// use(%true)
3286 /// ...
3287 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3289 
3290  LogicalResult matchAndRewrite(WhileOp op,
3291  PatternRewriter &rewriter) const override {
3292  auto term = op.getConditionOp();
3293 
3294  // These variables serve to prevent creating duplicate constants
3295  // and hold constant true or false values.
3296  Value constantTrue = nullptr;
3297 
3298  bool replaced = false;
3299  for (auto yieldedAndBlockArgs :
3300  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3301  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3302  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3303  if (!constantTrue)
3304  constantTrue = rewriter.create<arith::ConstantOp>(
3305  op.getLoc(), term.getCondition().getType(),
3306  rewriter.getBoolAttr(true));
3307 
3308  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3309  constantTrue);
3310  replaced = true;
3311  }
3312  }
3313  }
3314  return success(replaced);
3315  }
3316 };
3317 
3318 /// Remove loop invariant arguments from `before` block of scf.while.
3319 /// A before block argument is considered loop invariant if :-
3320 /// 1. i-th yield operand is equal to the i-th while operand.
3321 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3322 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3323 /// iter argument/while operand.
3324 /// For the arguments which are removed, their uses inside scf.while
3325 /// are replaced with their corresponding initial value.
3326 ///
3327 /// Eg:
3328 /// INPUT :-
3329 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3330 /// ..., %argN_before = %N)
3331 /// {
3332 /// ...
3333 /// scf.condition(%cond) %arg1_before, %arg0_before,
3334 /// %arg2_before, %arg0_before, ...
3335 /// } do {
3336 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3337 /// ..., %argK_after):
3338 /// ...
3339 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3340 /// }
3341 ///
3342 /// OUTPUT :-
3343 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3344 /// %N)
3345 /// {
3346 /// ...
3347 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3348 /// } do {
3349 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3350 /// ..., %argK_after):
3351 /// ...
3352 /// scf.yield %arg1_after, ..., %argN
3353 /// }
3354 ///
3355 /// EXPLANATION:
3356 /// We iterate over each yield operand.
3357 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3358 /// %arg0_before, which in turn is the 0-th iter argument. So we
3359 /// remove 0-th before block argument and yield operand, and replace
3360 /// all uses of the 0-th before block argument with its initial value
3361 /// %a.
3362 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3363 /// value. So we remove this operand and the corresponding before
3364 /// block argument and replace all uses of 1-th before block argument
3365 /// with %b.
3366 struct RemoveLoopInvariantArgsFromBeforeBlock
3367  : public OpRewritePattern<WhileOp> {
3369 
3370  LogicalResult matchAndRewrite(WhileOp op,
3371  PatternRewriter &rewriter) const override {
3372  Block &afterBlock = *op.getAfterBody();
3373  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3374  ConditionOp condOp = op.getConditionOp();
3375  OperandRange condOpArgs = condOp.getArgs();
3376  Operation *yieldOp = afterBlock.getTerminator();
3377  ValueRange yieldOpArgs = yieldOp->getOperands();
3378 
3379  bool canSimplify = false;
3380  for (const auto &it :
3381  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3382  auto index = static_cast<unsigned>(it.index());
3383  auto [initVal, yieldOpArg] = it.value();
3384  // If i-th yield operand is equal to the i-th operand of the scf.while,
3385  // the i-th before block argument is a loop invariant.
3386  if (yieldOpArg == initVal) {
3387  canSimplify = true;
3388  break;
3389  }
3390  // If the i-th yield operand is k-th after block argument, then we check
3391  // if the (k+1)-th condition op operand is equal to either the i-th before
3392  // block argument or the initial value of i-th before block argument. If
3393  // the comparison results `true`, i-th before block argument is a loop
3394  // invariant.
3395  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3396  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3397  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3398  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3399  canSimplify = true;
3400  break;
3401  }
3402  }
3403  }
3404 
3405  if (!canSimplify)
3406  return failure();
3407 
3408  SmallVector<Value> newInitArgs, newYieldOpArgs;
3409  DenseMap<unsigned, Value> beforeBlockInitValMap;
3410  SmallVector<Location> newBeforeBlockArgLocs;
3411  for (const auto &it :
3412  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3413  auto index = static_cast<unsigned>(it.index());
3414  auto [initVal, yieldOpArg] = it.value();
3415 
3416  // If i-th yield operand is equal to the i-th operand of the scf.while,
3417  // the i-th before block argument is a loop invariant.
3418  if (yieldOpArg == initVal) {
3419  beforeBlockInitValMap.insert({index, initVal});
3420  continue;
3421  } else {
3422  // If the i-th yield operand is k-th after block argument, then we check
3423  // if the (k+1)-th condition op operand is equal to either the i-th
3424  // before block argument or the initial value of i-th before block
3425  // argument. If the comparison results `true`, i-th before block
3426  // argument is a loop invariant.
3427  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3428  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3429  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3430  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3431  beforeBlockInitValMap.insert({index, initVal});
3432  continue;
3433  }
3434  }
3435  }
3436  newInitArgs.emplace_back(initVal);
3437  newYieldOpArgs.emplace_back(yieldOpArg);
3438  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3439  }
3440 
3441  {
3442  OpBuilder::InsertionGuard g(rewriter);
3443  rewriter.setInsertionPoint(yieldOp);
3444  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3445  }
3446 
3447  auto newWhile =
3448  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3449 
3450  Block &newBeforeBlock = *rewriter.createBlock(
3451  &newWhile.getBefore(), /*insertPt*/ {},
3452  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3453 
3454  Block &beforeBlock = *op.getBeforeBody();
3455  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3456  // For each i-th before block argument we find it's replacement value as :-
3457  // 1. If i-th before block argument is a loop invariant, we fetch it's
3458  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3459  // 2. Else we fetch j-th new before block argument as the replacement
3460  // value of i-th before block argument.
3461  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3462  // If the index 'i' argument was a loop invariant we fetch it's initial
3463  // value from `beforeBlockInitValMap`.
3464  if (beforeBlockInitValMap.count(i) != 0)
3465  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3466  else
3467  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3468  }
3469 
3470  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3471  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3472  newWhile.getAfter().begin());
3473 
3474  rewriter.replaceOp(op, newWhile.getResults());
3475  return success();
3476  }
3477 };
3478 
3479 /// Remove loop invariant value from result (condition op) of scf.while.
3480 /// A value is considered loop invariant if the final value yielded by
3481 /// scf.condition is defined outside of the `before` block. We remove the
3482 /// corresponding argument in `after` block and replace the use with the value.
3483 /// We also replace the use of the corresponding result of scf.while with the
3484 /// value.
3485 ///
3486 /// Eg:
3487 /// INPUT :-
3488 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3489 /// %argN_before = %N) {
3490 /// ...
3491 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3492 /// } do {
3493 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3494 /// ...
3495 /// some_func(%arg1_after)
3496 /// ...
3497 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3498 /// }
3499 ///
3500 /// OUTPUT :-
3501 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3502 /// ...
3503 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3504 /// } do {
3505 /// ^bb0(%arg0, %arg3, ..., %argM):
3506 /// ...
3507 /// some_func(%a)
3508 /// ...
3509 /// scf.yield %arg0, %b, ..., %argN
3510 /// }
3511 ///
3512 /// EXPLANATION:
3513 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3514 /// before block of scf.while, so they get removed.
3515 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3516 /// replaced by %b.
3517 /// 3. The corresponding after block argument %arg1_after's uses are
3518 /// replaced by %a and %arg2_after's uses are replaced by %b.
3519 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3521 
3522  LogicalResult matchAndRewrite(WhileOp op,
3523  PatternRewriter &rewriter) const override {
3524  Block &beforeBlock = *op.getBeforeBody();
3525  ConditionOp condOp = op.getConditionOp();
3526  OperandRange condOpArgs = condOp.getArgs();
3527 
3528  bool canSimplify = false;
3529  for (Value condOpArg : condOpArgs) {
3530  // Those values not defined within `before` block will be considered as
3531  // loop invariant values. We map the corresponding `index` with their
3532  // value.
3533  if (condOpArg.getParentBlock() != &beforeBlock) {
3534  canSimplify = true;
3535  break;
3536  }
3537  }
3538 
3539  if (!canSimplify)
3540  return failure();
3541 
3542  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3543 
3544  SmallVector<Value> newCondOpArgs;
3545  SmallVector<Type> newAfterBlockType;
3546  DenseMap<unsigned, Value> condOpInitValMap;
3547  SmallVector<Location> newAfterBlockArgLocs;
3548  for (const auto &it : llvm::enumerate(condOpArgs)) {
3549  auto index = static_cast<unsigned>(it.index());
3550  Value condOpArg = it.value();
3551  // Those values not defined within `before` block will be considered as
3552  // loop invariant values. We map the corresponding `index` with their
3553  // value.
3554  if (condOpArg.getParentBlock() != &beforeBlock) {
3555  condOpInitValMap.insert({index, condOpArg});
3556  } else {
3557  newCondOpArgs.emplace_back(condOpArg);
3558  newAfterBlockType.emplace_back(condOpArg.getType());
3559  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3560  }
3561  }
3562 
3563  {
3564  OpBuilder::InsertionGuard g(rewriter);
3565  rewriter.setInsertionPoint(condOp);
3566  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3567  newCondOpArgs);
3568  }
3569 
3570  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3571  op.getOperands());
3572 
3573  Block &newAfterBlock =
3574  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3575  newAfterBlockType, newAfterBlockArgLocs);
3576 
3577  Block &afterBlock = *op.getAfterBody();
3578  // Since a new scf.condition op was created, we need to fetch the new
3579  // `after` block arguments which will be used while replacing operations of
3580  // previous scf.while's `after` blocks. We'd also be fetching new result
3581  // values too.
3582  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3583  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3584  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3585  Value afterBlockArg, result;
3586  // If index 'i' argument was loop invariant we fetch it's value from the
3587  // `condOpInitMap` map.
3588  if (condOpInitValMap.count(i) != 0) {
3589  afterBlockArg = condOpInitValMap[i];
3590  result = afterBlockArg;
3591  } else {
3592  afterBlockArg = newAfterBlock.getArgument(j);
3593  result = newWhile.getResult(j);
3594  j++;
3595  }
3596  newAfterBlockArgs[i] = afterBlockArg;
3597  newWhileResults[i] = result;
3598  }
3599 
3600  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3601  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3602  newWhile.getBefore().begin());
3603 
3604  rewriter.replaceOp(op, newWhileResults);
3605  return success();
3606  }
3607 };
3608 
3609 /// Remove WhileOp results that are also unused in 'after' block.
3610 ///
3611 /// %0:2 = scf.while () : () -> (i32, i64) {
3612 /// %condition = "test.condition"() : () -> i1
3613 /// %v1 = "test.get_some_value"() : () -> i32
3614 /// %v2 = "test.get_some_value"() : () -> i64
3615 /// scf.condition(%condition) %v1, %v2 : i32, i64
3616 /// } do {
3617 /// ^bb0(%arg0: i32, %arg1: i64):
3618 /// "test.use"(%arg0) : (i32) -> ()
3619 /// scf.yield
3620 /// }
3621 /// return %0#0 : i32
3622 ///
3623 /// becomes
3624 /// %0 = scf.while () : () -> (i32) {
3625 /// %condition = "test.condition"() : () -> i1
3626 /// %v1 = "test.get_some_value"() : () -> i32
3627 /// %v2 = "test.get_some_value"() : () -> i64
3628 /// scf.condition(%condition) %v1 : i32
3629 /// } do {
3630 /// ^bb0(%arg0: i32):
3631 /// "test.use"(%arg0) : (i32) -> ()
3632 /// scf.yield
3633 /// }
3634 /// return %0 : i32
3635 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3637 
3638  LogicalResult matchAndRewrite(WhileOp op,
3639  PatternRewriter &rewriter) const override {
3640  auto term = op.getConditionOp();
3641  auto afterArgs = op.getAfterArguments();
3642  auto termArgs = term.getArgs();
3643 
3644  // Collect results mapping, new terminator args and new result types.
3645  SmallVector<unsigned> newResultsIndices;
3646  SmallVector<Type> newResultTypes;
3647  SmallVector<Value> newTermArgs;
3648  SmallVector<Location> newArgLocs;
3649  bool needUpdate = false;
3650  for (const auto &it :
3651  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3652  auto i = static_cast<unsigned>(it.index());
3653  Value result = std::get<0>(it.value());
3654  Value afterArg = std::get<1>(it.value());
3655  Value termArg = std::get<2>(it.value());
3656  if (result.use_empty() && afterArg.use_empty()) {
3657  needUpdate = true;
3658  } else {
3659  newResultsIndices.emplace_back(i);
3660  newTermArgs.emplace_back(termArg);
3661  newResultTypes.emplace_back(result.getType());
3662  newArgLocs.emplace_back(result.getLoc());
3663  }
3664  }
3665 
3666  if (!needUpdate)
3667  return failure();
3668 
3669  {
3670  OpBuilder::InsertionGuard g(rewriter);
3671  rewriter.setInsertionPoint(term);
3672  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3673  newTermArgs);
3674  }
3675 
3676  auto newWhile =
3677  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3678 
3679  Block &newAfterBlock = *rewriter.createBlock(
3680  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3681 
3682  // Build new results list and new after block args (unused entries will be
3683  // null).
3684  SmallVector<Value> newResults(op.getNumResults());
3685  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3686  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3687  newResults[it.value()] = newWhile.getResult(it.index());
3688  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3689  }
3690 
3691  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3692  newWhile.getBefore().begin());
3693 
3694  Block &afterBlock = *op.getAfterBody();
3695  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3696 
3697  rewriter.replaceOp(op, newResults);
3698  return success();
3699  }
3700 };
3701 
3702 /// Replace operations equivalent to the condition in the do block with true,
3703 /// since otherwise the block would not be evaluated.
3704 ///
3705 /// scf.while (..) : (i32, ...) -> ... {
3706 /// %z = ... : i32
3707 /// %condition = cmpi pred %z, %a
3708 /// scf.condition(%condition) %z : i32, ...
3709 /// } do {
3710 /// ^bb0(%arg0: i32, ...):
3711 /// %condition2 = cmpi pred %arg0, %a
3712 /// use(%condition2)
3713 /// ...
3714 ///
3715 /// becomes
3716 /// scf.while (..) : (i32, ...) -> ... {
3717 /// %z = ... : i32
3718 /// %condition = cmpi pred %z, %a
3719 /// scf.condition(%condition) %z : i32, ...
3720 /// } do {
3721 /// ^bb0(%arg0: i32, ...):
3722 /// use(%true)
3723 /// ...
3724 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3726 
3727  LogicalResult matchAndRewrite(scf::WhileOp op,
3728  PatternRewriter &rewriter) const override {
3729  using namespace scf;
3730  auto cond = op.getConditionOp();
3731  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3732  if (!cmp)
3733  return failure();
3734  bool changed = false;
3735  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3736  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3737  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3738  continue;
3739  for (OpOperand &u :
3740  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3741  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3742  if (!cmp2)
3743  continue;
3744  // For a binary operator 1-opIdx gets the other side.
3745  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3746  continue;
3747  bool samePredicate;
3748  if (cmp2.getPredicate() == cmp.getPredicate())
3749  samePredicate = true;
3750  else if (cmp2.getPredicate() ==
3751  arith::invertPredicate(cmp.getPredicate()))
3752  samePredicate = false;
3753  else
3754  continue;
3755 
3756  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3757  1);
3758  changed = true;
3759  }
3760  }
3761  }
3762  return success(changed);
3763  }
3764 };
3765 
3766 /// Remove unused init/yield args.
3767 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3769 
3770  LogicalResult matchAndRewrite(WhileOp op,
3771  PatternRewriter &rewriter) const override {
3772 
3773  if (!llvm::any_of(op.getBeforeArguments(),
3774  [](Value arg) { return arg.use_empty(); }))
3775  return rewriter.notifyMatchFailure(op, "No args to remove");
3776 
3777  YieldOp yield = op.getYieldOp();
3778 
3779  // Collect results mapping, new terminator args and new result types.
3780  SmallVector<Value> newYields;
3781  SmallVector<Value> newInits;
3782  llvm::BitVector argsToErase;
3783 
3784  size_t argsCount = op.getBeforeArguments().size();
3785  newYields.reserve(argsCount);
3786  newInits.reserve(argsCount);
3787  argsToErase.reserve(argsCount);
3788  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3789  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3790  if (beforeArg.use_empty()) {
3791  argsToErase.push_back(true);
3792  } else {
3793  argsToErase.push_back(false);
3794  newYields.emplace_back(yieldValue);
3795  newInits.emplace_back(initValue);
3796  }
3797  }
3798 
3799  Block &beforeBlock = *op.getBeforeBody();
3800  Block &afterBlock = *op.getAfterBody();
3801 
3802  beforeBlock.eraseArguments(argsToErase);
3803 
3804  Location loc = op.getLoc();
3805  auto newWhileOp =
3806  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
3807  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
3808  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3809  Block &newAfterBlock = *newWhileOp.getAfterBody();
3810 
3811  OpBuilder::InsertionGuard g(rewriter);
3812  rewriter.setInsertionPoint(yield);
3813  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3814 
3815  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
3816  newBeforeBlock.getArguments());
3817  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
3818  newAfterBlock.getArguments());
3819 
3820  rewriter.replaceOp(op, newWhileOp.getResults());
3821  return success();
3822  }
3823 };
3824 
3825 /// Remove duplicated ConditionOp args.
3826 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
3828 
3829  LogicalResult matchAndRewrite(WhileOp op,
3830  PatternRewriter &rewriter) const override {
3831  ConditionOp condOp = op.getConditionOp();
3832  ValueRange condOpArgs = condOp.getArgs();
3833 
3835  for (Value arg : condOpArgs)
3836  argsSet.insert(arg);
3837 
3838  if (argsSet.size() == condOpArgs.size())
3839  return rewriter.notifyMatchFailure(op, "No results to remove");
3840 
3841  llvm::SmallDenseMap<Value, unsigned> argsMap;
3842  SmallVector<Value> newArgs;
3843  argsMap.reserve(condOpArgs.size());
3844  newArgs.reserve(condOpArgs.size());
3845  for (Value arg : condOpArgs) {
3846  if (!argsMap.count(arg)) {
3847  auto pos = static_cast<unsigned>(argsMap.size());
3848  argsMap.insert({arg, pos});
3849  newArgs.emplace_back(arg);
3850  }
3851  }
3852 
3853  ValueRange argsRange(newArgs);
3854 
3855  Location loc = op.getLoc();
3856  auto newWhileOp = rewriter.create<scf::WhileOp>(
3857  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
3858  /*afterBody*/ nullptr);
3859  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3860  Block &newAfterBlock = *newWhileOp.getAfterBody();
3861 
3862  SmallVector<Value> afterArgsMapping;
3863  SmallVector<Value> resultsMapping;
3864  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
3865  auto it = argsMap.find(arg);
3866  assert(it != argsMap.end());
3867  auto pos = it->second;
3868  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
3869  resultsMapping.emplace_back(newWhileOp->getResult(pos));
3870  }
3871 
3872  OpBuilder::InsertionGuard g(rewriter);
3873  rewriter.setInsertionPoint(condOp);
3874  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3875  argsRange);
3876 
3877  Block &beforeBlock = *op.getBeforeBody();
3878  Block &afterBlock = *op.getAfterBody();
3879 
3880  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
3881  newBeforeBlock.getArguments());
3882  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
3883  rewriter.replaceOp(op, resultsMapping);
3884  return success();
3885  }
3886 };
3887 
3888 /// If both ranges contain same values return mappping indices from args2 to
3889 /// args1. Otherwise return std::nullopt.
3890 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
3891  ValueRange args2) {
3892  if (args1.size() != args2.size())
3893  return std::nullopt;
3894 
3895  SmallVector<unsigned> ret(args1.size());
3896  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
3897  auto it = llvm::find(args2, arg1);
3898  if (it == args2.end())
3899  return std::nullopt;
3900 
3901  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
3902  }
3903 
3904  return ret;
3905 }
3906 
3907 static bool hasDuplicates(ValueRange args) {
3908  llvm::SmallDenseSet<Value> set;
3909  for (Value arg : args) {
3910  if (set.contains(arg))
3911  return true;
3912 
3913  set.insert(arg);
3914  }
3915  return false;
3916 }
3917 
3918 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
3919 /// `scf.condition` args into same order as block args. Update `after` block
3920 /// args and op result values accordingly.
3921 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
3922 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
3924 
3925  LogicalResult matchAndRewrite(WhileOp loop,
3926  PatternRewriter &rewriter) const override {
3927  auto oldBefore = loop.getBeforeBody();
3928  ConditionOp oldTerm = loop.getConditionOp();
3929  ValueRange beforeArgs = oldBefore->getArguments();
3930  ValueRange termArgs = oldTerm.getArgs();
3931  if (beforeArgs == termArgs)
3932  return failure();
3933 
3934  if (hasDuplicates(termArgs))
3935  return failure();
3936 
3937  auto mapping = getArgsMapping(beforeArgs, termArgs);
3938  if (!mapping)
3939  return failure();
3940 
3941  {
3942  OpBuilder::InsertionGuard g(rewriter);
3943  rewriter.setInsertionPoint(oldTerm);
3944  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
3945  beforeArgs);
3946  }
3947 
3948  auto oldAfter = loop.getAfterBody();
3949 
3950  SmallVector<Type> newResultTypes(beforeArgs.size());
3951  for (auto &&[i, j] : llvm::enumerate(*mapping))
3952  newResultTypes[j] = loop.getResult(i).getType();
3953 
3954  auto newLoop = rewriter.create<WhileOp>(
3955  loop.getLoc(), newResultTypes, loop.getInits(),
3956  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
3957  auto newBefore = newLoop.getBeforeBody();
3958  auto newAfter = newLoop.getAfterBody();
3959 
3960  SmallVector<Value> newResults(beforeArgs.size());
3961  SmallVector<Value> newAfterArgs(beforeArgs.size());
3962  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
3963  newResults[i] = newLoop.getResult(j);
3964  newAfterArgs[i] = newAfter->getArgument(j);
3965  }
3966 
3967  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
3968  newBefore->getArguments());
3969  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
3970  newAfterArgs);
3971 
3972  rewriter.replaceOp(loop, newResults);
3973  return success();
3974  }
3975 };
3976 } // namespace
3977 
3978 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3979  MLIRContext *context) {
3980  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3981  RemoveLoopInvariantValueYielded, WhileConditionTruth,
3982  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3983  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
3984 }
3985 
3986 //===----------------------------------------------------------------------===//
3987 // IndexSwitchOp
3988 //===----------------------------------------------------------------------===//
3989 
3990 /// Parse the case regions and values.
3991 static ParseResult
3993  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
3994  SmallVector<int64_t> caseValues;
3995  while (succeeded(p.parseOptionalKeyword("case"))) {
3996  int64_t value;
3997  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
3998  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
3999  return failure();
4000  caseValues.push_back(value);
4001  }
4002  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4003  return success();
4004 }
4005 
4006 /// Print the case regions and values.
4008  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4009  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4010  p.printNewline();
4011  p << "case " << value << ' ';
4012  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4013  }
4014 }
4015 
4017  if (getCases().size() != getCaseRegions().size()) {
4018  return emitOpError("has ")
4019  << getCaseRegions().size() << " case regions but "
4020  << getCases().size() << " case values";
4021  }
4022 
4023  DenseSet<int64_t> valueSet;
4024  for (int64_t value : getCases())
4025  if (!valueSet.insert(value).second)
4026  return emitOpError("has duplicate case value: ") << value;
4027  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4028  auto yield = dyn_cast<YieldOp>(region.front().back());
4029  if (!yield)
4030  return emitOpError("expected region to end with scf.yield, but got ")
4031  << region.front().back().getName();
4032 
4033  if (yield.getNumOperands() != getNumResults()) {
4034  return (emitOpError("expected each region to return ")
4035  << getNumResults() << " values, but " << name << " returns "
4036  << yield.getNumOperands())
4037  .attachNote(yield.getLoc())
4038  << "see yield operation here";
4039  }
4040  for (auto [idx, result, operand] :
4041  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4042  yield.getOperandTypes())) {
4043  if (result == operand)
4044  continue;
4045  return (emitOpError("expected result #")
4046  << idx << " of each region to be " << result)
4047  .attachNote(yield.getLoc())
4048  << name << " returns " << operand << " here";
4049  }
4050  return success();
4051  };
4052 
4053  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4054  return failure();
4055  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4056  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4057  return failure();
4058 
4059  return success();
4060 }
4061 
4062 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4063 
4064 Block &scf::IndexSwitchOp::getDefaultBlock() {
4065  return getDefaultRegion().front();
4066 }
4067 
4068 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4069  assert(idx < getNumCases() && "case index out-of-bounds");
4070  return getCaseRegions()[idx].front();
4071 }
4072 
4073 void IndexSwitchOp::getSuccessorRegions(
4075  // All regions branch back to the parent op.
4076  if (!point.isParent()) {
4077  successors.emplace_back(getResults());
4078  return;
4079  }
4080 
4081  llvm::copy(getRegions(), std::back_inserter(successors));
4082 }
4083 
4084 void IndexSwitchOp::getEntrySuccessorRegions(
4085  ArrayRef<Attribute> operands,
4086  SmallVectorImpl<RegionSuccessor> &successors) {
4087  FoldAdaptor adaptor(operands, *this);
4088 
4089  // If a constant was not provided, all regions are possible successors.
4090  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4091  if (!arg) {
4092  llvm::copy(getRegions(), std::back_inserter(successors));
4093  return;
4094  }
4095 
4096  // Otherwise, try to find a case with a matching value. If not, the
4097  // default region is the only successor.
4098  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4099  if (caseValue == arg.getInt()) {
4100  successors.emplace_back(&caseRegion);
4101  return;
4102  }
4103  }
4104  successors.emplace_back(&getDefaultRegion());
4105 }
4106 
4107 void IndexSwitchOp::getRegionInvocationBounds(
4109  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4110  if (!operandValue) {
4111  // All regions are invoked at most once.
4112  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4113  return;
4114  }
4115 
4116  unsigned liveIndex = getNumRegions() - 1;
4117  const auto *it = llvm::find(getCases(), operandValue.getInt());
4118  if (it != getCases().end())
4119  liveIndex = std::distance(getCases().begin(), it);
4120  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4121  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4122 }
4123 
4124 LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
4125  SmallVectorImpl<OpFoldResult> &results) {
4126  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
4127  if (!maybeCst.has_value())
4128  return failure();
4129  int64_t cst = *maybeCst;
4130  int64_t caseIdx, e = getNumCases();
4131  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4132  if (cst == getCases()[caseIdx])
4133  break;
4134  }
4135 
4136  Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4137  : getDefaultRegion();
4138  Block &source = r.front();
4139  results.assign(source.getTerminator()->getOperands().begin(),
4140  source.getTerminator()->getOperands().end());
4141 
4142  Block *pDestination = (*this)->getBlock();
4143  if (!pDestination)
4144  return failure();
4145  Block::iterator insertionPoint = (*this)->getIterator();
4146  pDestination->getOperations().splice(insertionPoint, source.getOperations(),
4147  source.getOperations().begin(),
4148  std::prev(source.getOperations().end()));
4149 
4150  return success();
4151 }
4152 
4153 //===----------------------------------------------------------------------===//
4154 // TableGen'd op method definitions
4155 //===----------------------------------------------------------------------===//
4156 
4157 #define GET_OP_CLASSES
4158 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:716
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:708
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the block terminator to replace operation results.
Definition: SCF.cpp:113
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:3992
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3236
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:432
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:93
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4007
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:137
bool empty()
Definition: Block.h:145
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation & back()
Definition: Block.h:149
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:159
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator begin()
Definition: Block.h:140
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at the end.
Definition: Block.h:206
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:93
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:582
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation regions empty.
Definition: Builders.h:586
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:131
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:68
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:450
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:2849
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:86
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops.
Definition: SCF.cpp:687
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1768
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1442
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:263
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:234
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:185
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.