1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Arith/Utils/Utils.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Interfaces/FunctionInterfaces.h"
22 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 using namespace mlir;
29 using namespace mlir::scf;
30 
31 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // SCFDialect Dialect Interfaces
35 //===----------------------------------------------------------------------===//
36 
37 namespace {
38 struct SCFInlinerInterface : public DialectInlinerInterface {
39  using DialectInlinerInterface::DialectInlinerInterface;
40  // We don't have any special restrictions on what can be inlined into
41  // destination regions (e.g. while/conditional bodies). Always allow it.
42  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
43  IRMapping &valueMapping) const final {
44  return true;
45  }
46  // Operations in scf dialect are always legal to inline since they are
47  // pure.
48  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
49  return true;
50  }
51  // Handle the given inlined terminator by replacing it with a new operation
52  // as necessary. Required when the region has only one block.
53  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
54  auto retValOp = dyn_cast<scf::YieldOp>(op);
55  if (!retValOp)
56  return;
57 
58  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
59  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
60  }
61  }
62 };
63 } // namespace
64 
65 //===----------------------------------------------------------------------===//
66 // SCFDialect
67 //===----------------------------------------------------------------------===//
68 
69 void SCFDialect::initialize() {
70  addOperations<
71 #define GET_OP_LIST
72 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
73  >();
74  addInterfaces<SCFInlinerInterface>();
75 }
76 
77 /// Default callback for IfOp builders. Inserts a yield without arguments.
78 static void buildTerminatedBody(OpBuilder &builder, Location loc) {
79  builder.create<scf::YieldOp>(loc);
80 }
81 
82 /// Verifies that the first block of the given `region` is terminated by a
83 /// TerminatorTy. Reports errors on the given operation if it is not the case.
84 template <typename TerminatorTy>
85 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
86  StringRef errorMessage) {
87  Operation *terminatorOperation = nullptr;
88  if (!region.empty() && !region.front().empty()) {
89  terminatorOperation = &region.front().back();
90  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
91  return yield;
92  }
93  auto diag = op->emitOpError(errorMessage);
94  if (terminatorOperation)
95  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
96  return nullptr;
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // ExecuteRegionOp
101 //===----------------------------------------------------------------------===//
102 
103 /// Replaces the given op with the contents of the given single-block region,
104 /// using the operands of the block terminator to replace operation results.
105 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
106  Region &region, ValueRange blockArgs = {}) {
107  assert(llvm::hasSingleElement(region) && "expected single-block region");
108  Block *block = &region.front();
109  Operation *terminator = block->getTerminator();
110  ValueRange results = terminator->getOperands();
111  rewriter.inlineBlockBefore(block, op, blockArgs);
112  rewriter.replaceOp(op, results);
113  rewriter.eraseOp(terminator);
114 }
115 
116 ///
117 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
118 /// block+
119 /// `}`
120 ///
121 /// Example:
122 /// scf.execute_region -> i32 {
123 /// %idx = load %rI[%i] : memref<128xi32>
124 /// return %idx : i32
125 /// }
126 ///
127 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
128  OperationState &result) {
129  if (parser.parseOptionalArrowTypeList(result.types))
130  return failure();
131 
132  // Introduce the body region and parse it.
133  Region *body = result.addRegion();
134  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
135  parser.parseOptionalAttrDict(result.attributes))
136  return failure();
137 
138  return success();
139 }
140 
141 void ExecuteRegionOp::print(OpAsmPrinter &p) {
142  p.printOptionalArrowTypeList(getResultTypes());
143 
144  p << ' ';
145  p.printRegion(getRegion(),
146  /*printEntryBlockArgs=*/false,
147  /*printBlockTerminators=*/true);
148 
149  p.printOptionalAttrDict((*this)->getAttrs());
150 }
151 
152 LogicalResult ExecuteRegionOp::verify() {
153  if (getRegion().empty())
154  return emitOpError("region needs to have at least one block");
155  if (getRegion().front().getNumArguments() > 0)
156  return emitOpError("region cannot have any arguments");
157  return success();
158 }
159 
160 // Inline an ExecuteRegionOp if it only contains one block.
161 // "test.foo"() : () -> ()
162 // %v = scf.execute_region -> i64 {
163 // %x = "test.val"() : () -> i64
164 // scf.yield %x : i64
165 // }
166 // "test.bar"(%v) : (i64) -> ()
167 //
168 // becomes
169 //
170 // "test.foo"() : () -> ()
171 // %x = "test.val"() : () -> i64
172 // "test.bar"(%x) : (i64) -> ()
173 //
174 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
175  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
176 
177  LogicalResult matchAndRewrite(ExecuteRegionOp op,
178  PatternRewriter &rewriter) const override {
179  if (!llvm::hasSingleElement(op.getRegion()))
180  return failure();
181  replaceOpWithRegion(rewriter, op, op.getRegion());
182  return success();
183  }
184 };
185 
186 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
187 // TODO generalize the conditions for operations which can be inlined into.
188 // func @func_execute_region_elim() {
189 // "test.foo"() : () -> ()
190 // %v = scf.execute_region -> i64 {
191 // %c = "test.cmp"() : () -> i1
192 // cf.cond_br %c, ^bb2, ^bb3
193 // ^bb2:
194 // %x = "test.val1"() : () -> i64
195 // cf.br ^bb4(%x : i64)
196 // ^bb3:
197 // %y = "test.val2"() : () -> i64
198 // cf.br ^bb4(%y : i64)
199 // ^bb4(%z : i64):
200 // scf.yield %z : i64
201 // }
202 // "test.bar"(%v) : (i64) -> ()
203 // return
204 // }
205 //
206 // becomes
207 //
208 // func @func_execute_region_elim() {
209 // "test.foo"() : () -> ()
210 // %c = "test.cmp"() : () -> i1
211 // cf.cond_br %c, ^bb1, ^bb2
212 // ^bb1: // pred: ^bb0
213 // %x = "test.val1"() : () -> i64
214 // cf.br ^bb3(%x : i64)
215 // ^bb2: // pred: ^bb0
216 // %y = "test.val2"() : () -> i64
217 // cf.br ^bb3(%y : i64)
218 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
219 // "test.bar"(%z) : (i64) -> ()
220 // return
221 // }
222 //
223 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
224  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
225 
226  LogicalResult matchAndRewrite(ExecuteRegionOp op,
227  PatternRewriter &rewriter) const override {
228  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
229  return failure();
230 
231  Block *prevBlock = op->getBlock();
232  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
233  rewriter.setInsertionPointToEnd(prevBlock);
234 
235  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
236 
237  for (Block &blk : op.getRegion()) {
238  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
239  rewriter.setInsertionPoint(yieldOp);
240  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
241  yieldOp.getResults());
242  rewriter.eraseOp(yieldOp);
243  }
244  }
245 
246  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
247  SmallVector<Value> blockArgs;
248 
249  for (auto res : op.getResults())
250  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
251 
252  rewriter.replaceOp(op, blockArgs);
253  return success();
254  }
255 };
256 
257 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
258  MLIRContext *context) {
259  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
260 }
261 
262 void ExecuteRegionOp::getSuccessorRegions(
263  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
264  // If the predecessor is the ExecuteRegionOp, branch into the body.
265  if (point.isParent()) {
266  regions.push_back(RegionSuccessor(&getRegion()));
267  return;
268  }
269 
270  // Otherwise, the region branches back to the parent operation.
271  regions.push_back(RegionSuccessor(getResults()));
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // ConditionOp
276 //===----------------------------------------------------------------------===//
277 
278 MutableOperandRange
279 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
280  assert((point.isParent() || point == getParentOp().getAfter()) &&
281  "condition op can only exit the loop or branch to the after "
282  "region");
283  // Pass all operands except the condition to the successor region.
284  return getArgsMutable();
285 }
286 
287 void ConditionOp::getSuccessorRegions(
288  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
289  FoldAdaptor adaptor(operands, *this);
290 
291  WhileOp whileOp = getParentOp();
292 
293  // Condition can either lead to the after region or back to the parent op
294  // depending on whether the condition is true or not.
295  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
296  if (!boolAttr || boolAttr.getValue())
297  regions.emplace_back(&whileOp.getAfter(),
298  whileOp.getAfter().getArguments());
299  if (!boolAttr || !boolAttr.getValue())
300  regions.emplace_back(whileOp.getResults());
301 }
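// For reference, an illustrative (not exhaustive) example of where
// scf.condition appears: it terminates the "before" region of scf.while. The
// "test.*" ops below are placeholders.
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.pred"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.step"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// When %cond is true, %arg is forwarded to the "after" region; otherwise it
// becomes the result of the scf.while.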
302 
303 //===----------------------------------------------------------------------===//
304 // ForOp
305 //===----------------------------------------------------------------------===//
306 
307 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
308  Value ub, Value step, ValueRange iterArgs,
309  BodyBuilderFn bodyBuilder) {
310  result.addOperands({lb, ub, step});
311  result.addOperands(iterArgs);
312  for (Value v : iterArgs)
313  result.addTypes(v.getType());
314  Type t = lb.getType();
315  Region *bodyRegion = result.addRegion();
316  bodyRegion->push_back(new Block);
317  Block &bodyBlock = bodyRegion->front();
318  bodyBlock.addArgument(t, result.location);
319  for (Value v : iterArgs)
320  bodyBlock.addArgument(v.getType(), v.getLoc());
321 
322  // Create the default terminator if the builder is not provided and if the
323  // iteration arguments are not provided. Otherwise, leave this to the caller
324  // because we don't know which values to return from the loop.
325  if (iterArgs.empty() && !bodyBuilder) {
326  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
327  } else if (bodyBuilder) {
328  OpBuilder::InsertionGuard guard(builder);
329  builder.setInsertionPointToStart(&bodyBlock);
330  bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
331  bodyBlock.getArguments().drop_front());
332  }
333 }
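// Illustrative use of the builder above (a sketch; `b`, `loc`, `lb`, `ub`,
// `step` and `init` are assumed to be created by the caller):
//
//   auto loop = b.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
//         // Loop body; a terminator must be created when iter args exist.
//         nested.create<scf::YieldOp>(nestedLoc, args);
//       });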
334 
335 LogicalResult ForOp::verify() {
336  IntegerAttr step;
337  if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
338  return emitOpError("constant step operand must be positive");
339 
340  // Check that the number of init args and op results is the same.
341  if (getInitArgs().size() != getNumResults())
342  return emitOpError(
343  "mismatch in number of loop-carried values and defined values");
344 
345  return success();
346 }
347 
348 LogicalResult ForOp::verifyRegions() {
349  // Check that the body defines a single block argument for the induction
350  // variable.
351  if (getInductionVar().getType() != getLowerBound().getType())
352  return emitOpError(
353  "expected induction variable to be same type as bounds and step");
354 
355  if (getNumRegionIterArgs() != getNumResults())
356  return emitOpError(
357  "mismatch in number of basic block args and defined values");
358 
359  auto initArgs = getInitArgs();
360  auto iterArgs = getRegionIterArgs();
361  auto opResults = getResults();
362  unsigned i = 0;
363  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
364  if (std::get<0>(e).getType() != std::get<2>(e).getType())
365  return emitOpError() << "types mismatch between " << i
366  << "th iter operand and defined value";
367  if (std::get<1>(e).getType() != std::get<2>(e).getType())
368  return emitOpError() << "types mismatch between " << i
369  << "th iter region arg and defined value";
370 
371  ++i;
372  }
373  return success();
374 }
375 
376 std::optional<Value> ForOp::getSingleInductionVar() {
377  return getInductionVar();
378 }
379 
380 std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
381  return OpFoldResult(getLowerBound());
382 }
383 
384 std::optional<OpFoldResult> ForOp::getSingleStep() {
385  return OpFoldResult(getStep());
386 }
387 
388 std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
389  return OpFoldResult(getUpperBound());
390 }
391 
392 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
393 
394 /// Promotes the loop body of a forOp to its containing block if it can be
395 /// determined that the loop has a single iteration.
396 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
397  std::optional<int64_t> tripCount =
398  constantTripCount(getLowerBound(), getUpperBound(), getStep());
399  if (!tripCount.has_value() || tripCount != 1)
400  return failure();
401 
402  // Replace all results with the yielded values.
403  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
404  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
405 
406  // Replace block arguments with lower bound (replacement for IV) and
407  // iter_args.
408  SmallVector<Value> bbArgReplacements;
409  bbArgReplacements.push_back(getLowerBound());
410  llvm::append_range(bbArgReplacements, getInitArgs());
411 
412  // Move the loop body operations to the loop's containing block.
413  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
414  getOperation()->getIterator(), bbArgReplacements);
415 
416  // Erase the old terminator and the loop.
417  rewriter.eraseOp(yieldOp);
418  rewriter.eraseOp(*this);
419 
420  return success();
421 }
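// Illustrative effect of promoteIfSingleIteration (a sketch; "test.compute"
// is a placeholder op and %c0/%c1 are index constants):
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
//     %v = "test.compute"(%i, %a) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
//
// becomes
//
//   %v = "test.compute"(%c0, %init) : (index, f32) -> f32
//
// with all uses of %r replaced by %v.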
422 
423 /// Prints the initialization list in the form of
424 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
425 /// where 'inner' values are assumed to be region arguments and 'outer' values
426 /// are regular SSA values.
427 static void printInitializationList(OpAsmPrinter &p,
428  Block::BlockArgListType blocksArgs,
429  ValueRange initializers,
430  StringRef prefix = "") {
431  assert(blocksArgs.size() == initializers.size() &&
432  "expected same length of arguments and initializers");
433  if (initializers.empty())
434  return;
435 
436  p << prefix << '(';
437  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
438  p << std::get<0>(it) << " = " << std::get<1>(it);
439  });
440  p << ")";
441 }
442 
443 void ForOp::print(OpAsmPrinter &p) {
444  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
445  << getUpperBound() << " step " << getStep();
446 
447  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
448  if (!getInitArgs().empty())
449  p << " -> (" << getInitArgs().getTypes() << ')';
450  p << ' ';
451  if (Type t = getInductionVar().getType(); !t.isIndex())
452  p << " : " << t << ' ';
453  p.printRegion(getRegion(),
454  /*printEntryBlockArgs=*/false,
455  /*printBlockTerminators=*/!getInitArgs().empty());
456  p.printOptionalAttrDict((*this)->getAttrs());
457 }
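// The printed form produced above looks like the following (illustrative):
//
//   scf.for %iv = %lb to %ub step %s iter_args(%acc = %init) -> (f32) {
//     ...
//     scf.yield %next : f32
//   }
//
// with an extra " : <type>" printed before the region when the induction
// variable is not of index type.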
458 
459 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
460  auto &builder = parser.getBuilder();
461  Type type;
462 
463  OpAsmParser::Argument inductionVariable;
464  OpAsmParser::UnresolvedOperand lb, ub, step;
465 
466  // Parse the induction variable followed by '='.
467  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
468  // Parse loop bounds.
469  parser.parseOperand(lb) || parser.parseKeyword("to") ||
470  parser.parseOperand(ub) || parser.parseKeyword("step") ||
471  parser.parseOperand(step))
472  return failure();
473 
474  // Parse the optional initial iteration arguments.
475  SmallVector<OpAsmParser::Argument, 4> regionArgs;
476  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
477  regionArgs.push_back(inductionVariable);
478 
479  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
480  if (hasIterArgs) {
481  // Parse assignment list and results type list.
482  if (parser.parseAssignmentList(regionArgs, operands) ||
483  parser.parseArrowTypeList(result.types))
484  return failure();
485  }
486 
487  if (regionArgs.size() != result.types.size() + 1)
488  return parser.emitError(
489  parser.getNameLoc(),
490  "mismatch in number of loop-carried values and defined values");
491 
492  // Parse optional type, else assume Index.
493  if (parser.parseOptionalColon())
494  type = builder.getIndexType();
495  else if (parser.parseType(type))
496  return failure();
497 
498  // Resolve input operands.
499  regionArgs.front().type = type;
500  if (parser.resolveOperand(lb, type, result.operands) ||
501  parser.resolveOperand(ub, type, result.operands) ||
502  parser.resolveOperand(step, type, result.operands))
503  return failure();
504  if (hasIterArgs) {
505  for (auto argOperandType :
506  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
507  Type type = std::get<2>(argOperandType);
508  std::get<0>(argOperandType).type = type;
509  if (parser.resolveOperand(std::get<1>(argOperandType), type,
510  result.operands))
511  return failure();
512  }
513  }
514 
515  // Parse the body region.
516  Region *body = result.addRegion();
517  if (parser.parseRegion(*body, regionArgs))
518  return failure();
519 
520  ForOp::ensureTerminator(*body, builder, result.location);
521 
522  // Parse the optional attribute list.
523  if (parser.parseOptionalAttrDict(result.attributes))
524  return failure();
525 
526  return success();
527 }
528 
529 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
530 
531 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
532  return getInitArgsMutable();
533 }
534 
535 FailureOr<LoopLikeOpInterface>
536 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
537  ValueRange newInitOperands,
538  bool replaceInitOperandUsesInLoop,
539  const NewYieldValuesFn &newYieldValuesFn) {
540  // Create a new loop before the existing one, with the extra operands.
541  OpBuilder::InsertionGuard g(rewriter);
542  rewriter.setInsertionPoint(getOperation());
543  auto inits = llvm::to_vector(getInitArgs());
544  inits.append(newInitOperands.begin(), newInitOperands.end());
545  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
546  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
547  [](OpBuilder &, Location, Value, ValueRange) {});
548 
549  // Generate the new yield values and append them to the scf.yield operation.
550  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
551  ArrayRef<BlockArgument> newIterArgs =
552  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
553  {
554  OpBuilder::InsertionGuard g(rewriter);
555  rewriter.setInsertionPoint(yieldOp);
556  SmallVector<Value> newYieldedValues =
557  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
558  assert(newInitOperands.size() == newYieldedValues.size() &&
559  "expected as many new yield values as new iter operands");
560  rewriter.updateRootInPlace(yieldOp, [&]() {
561  yieldOp.getResultsMutable().append(newYieldedValues);
562  });
563  }
564 
565  // Move the loop body to the new op.
566  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
567  newLoop.getBody()->getArguments().take_front(
568  getBody()->getNumArguments()));
569 
570  if (replaceInitOperandUsesInLoop) {
571  // Replace all uses of `newInitOperands` with the corresponding basic block
572  // arguments.
573  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
574  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
575  [&](OpOperand &use) {
576  Operation *user = use.getOwner();
577  return newLoop->isProperAncestor(user);
578  });
579  }
580  }
581 
582  // Replace the old loop.
583  rewriter.replaceOp(getOperation(),
584  newLoop->getResults().take_front(getNumResults()));
585  return cast<LoopLikeOpInterface>(newLoop.getOperation());
586 }
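// Illustrative call to the method above (a sketch; `rewriter`, `loop` and
// `extra` are assumed to exist in the caller):
//
//   FailureOr<LoopLikeOpInterface> newLoop =
//       loop.replaceWithAdditionalYields(
//           rewriter, /*newInitOperands=*/ValueRange{extra},
//           /*replaceInitOperandUsesInLoop=*/true,
//           [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
//             // Yield the new iter arg unchanged.
//             return SmallVector<Value>{newBBArgs[0]};
//           });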
587 
588 ForOp mlir::scf::getForInductionVarOwner(Value val) {
589  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
590  if (!ivArg)
591  return ForOp();
592  assert(ivArg.getOwner() && "unlinked block argument");
593  auto *containingOp = ivArg.getOwner()->getParentOp();
594  return dyn_cast_or_null<ForOp>(containingOp);
595 }
596 
597 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
598  return getInitArgs();
599 }
600 
601 void ForOp::getSuccessorRegions(RegionBranchPoint point,
602  SmallVectorImpl<RegionSuccessor> &regions) {
603  // Both the operation itself and the region may be branching into the body or
604  // back into the operation itself. It is possible for the loop not to enter
605  // the body.
606  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
607  regions.push_back(RegionSuccessor(getResults()));
608 }
609 
610 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
611 
612 /// Promotes the loop body of a forallOp to its containing block if it can be
613 /// determined that the loop has a single iteration.
614 LogicalResult ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
615  for (auto [lb, ub, step] :
616  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
617  auto tripCount = constantTripCount(lb, ub, step);
618  if (!tripCount.has_value() || *tripCount != 1)
619  return failure();
620  }
621 
622  promote(rewriter, *this);
623  return success();
624 }
625 
626 /// Promotes the loop body of a scf::ForallOp to its containing block.
627 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
628  OpBuilder::InsertionGuard g(rewriter);
629  scf::InParallelOp terminator = forallOp.getTerminator();
630 
631  // Replace block arguments with lower bounds (replacements for IVs) and
632  // outputs.
633  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
634  bbArgReplacements.append(forallOp.getOutputs().begin(),
635  forallOp.getOutputs().end());
636 
637  // Move the loop body operations to the loop's containing block.
638  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
639  forallOp->getIterator(), bbArgReplacements);
640 
641  // Replace the terminator with tensor.insert_slice ops.
642  rewriter.setInsertionPointAfter(forallOp);
643  SmallVector<Value> results;
644  results.reserve(forallOp.getResults().size());
645  for (auto &yieldingOp : terminator.getYieldingOps()) {
646  auto parallelInsertSliceOp =
647  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
648 
649  Value dst = parallelInsertSliceOp.getDest();
650  Value src = parallelInsertSliceOp.getSource();
651  if (llvm::isa<TensorType>(src.getType())) {
652  results.push_back(rewriter.create<tensor::InsertSliceOp>(
653  forallOp.getLoc(), dst.getType(), src, dst,
654  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
655  parallelInsertSliceOp.getStrides(),
656  parallelInsertSliceOp.getStaticOffsets(),
657  parallelInsertSliceOp.getStaticSizes(),
658  parallelInsertSliceOp.getStaticStrides()));
659  } else {
660  llvm_unreachable("unsupported terminator");
661  }
662  }
663  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
664 
665  // Erase the old terminator and the loop.
666  rewriter.eraseOp(terminator);
667  rewriter.eraseOp(forallOp);
668 }
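// Illustrative effect of promote (a sketch; "test.produce" is a placeholder):
//
//   %r = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<4xf32>) {
//     %s = "test.produce"() : () -> tensor<4xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %s into %o[0] [4] [1]
//           : tensor<4xf32> into tensor<4xf32>
//     }
//   }
//
// becomes
//
//   %s = "test.produce"() : () -> tensor<4xf32>
//   %r = tensor.insert_slice %s into %t[0] [4] [1]
//       : tensor<4xf32> into tensor<4xf32>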
669 
670 LoopNest mlir::scf::buildLoopNest(
671  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
672  ValueRange steps, ValueRange iterArgs,
673  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
674  bodyBuilder) {
675  assert(lbs.size() == ubs.size() &&
676  "expected the same number of lower and upper bounds");
677  assert(lbs.size() == steps.size() &&
678  "expected the same number of lower bounds and steps");
679 
680  // If there are no bounds, call the body-building function and return early.
681  if (lbs.empty()) {
682  ValueVector results =
683  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
684  : ValueVector();
685  assert(results.size() == iterArgs.size() &&
686  "loop nest body must return as many values as loop has iteration "
687  "arguments");
688  return LoopNest{{}, std::move(results)};
689  }
690 
691  // First, create the loop structure iteratively using the body-builder
692  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
693  OpBuilder::InsertionGuard guard(builder);
694  SmallVector<scf::ForOp, 4> loops;
695  SmallVector<Value, 4> ivs;
696  loops.reserve(lbs.size());
697  ivs.reserve(lbs.size());
698  ValueRange currentIterArgs = iterArgs;
699  Location currentLoc = loc;
700  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
701  auto loop = builder.create<scf::ForOp>(
702  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
703  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
704  ValueRange args) {
705  ivs.push_back(iv);
706  // It is safe to store ValueRange args because it points to block
707  // arguments of a loop operation that we also own.
708  currentIterArgs = args;
709  currentLoc = nestedLoc;
710  });
711  // Set the builder to point to the body of the newly created loop. We don't
712  // do this in the callback because the builder is reset when the callback
713  // returns.
714  builder.setInsertionPointToStart(loop.getBody());
715  loops.push_back(loop);
716  }
717 
718  // For all loops but the innermost, yield the results of the nested loop.
719  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
720  builder.setInsertionPointToEnd(loops[i].getBody());
721  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
722  }
723 
724  // In the body of the innermost loop, call the body building function if any
725  // and yield its results.
726  builder.setInsertionPointToStart(loops.back().getBody());
727  ValueVector results = bodyBuilder
728  ? bodyBuilder(builder, currentLoc, ivs,
729  loops.back().getRegionIterArgs())
730  : ValueVector();
731  assert(results.size() == iterArgs.size() &&
732  "loop nest body must return as many values as loop has iteration "
733  "arguments");
734  builder.setInsertionPointToEnd(loops.back().getBody());
735  builder.create<scf::YieldOp>(loc, results);
736 
737  // Return the loops.
738  ValueVector nestResults;
739  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
740  return LoopNest{std::move(loops), std::move(nestResults)};
741 }
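// Illustrative use of buildLoopNest (a sketch; `b`, `loc`, the index-typed
// bounds/steps and the loop-carried value `init` are assumed to be created by
// the caller):
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       b, loc, {lb0, lb1}, {ub0, ub1}, {s0, s1}, ValueRange{init},
//       [](OpBuilder &nb, Location nl, ValueRange ivs, ValueRange args) {
//         // Innermost body; forward the iteration arguments unchanged.
//         return scf::ValueVector(args.begin(), args.end());
//       });
//
// nest.loops.front() is the outermost scf.for and nest.results holds its
// results.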
742 
743 LoopNest mlir::scf::buildLoopNest(
744  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
745  ValueRange steps,
746  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
747  // Delegate to the main function by wrapping the body builder.
748  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
749  [&bodyBuilder](OpBuilder &nestedBuilder,
750  Location nestedLoc, ValueRange ivs,
751  ValueRange) -> ValueVector {
752  if (bodyBuilder)
753  bodyBuilder(nestedBuilder, nestedLoc, ivs);
754  return {};
755  });
756 }
757 
758 namespace {
759 // Fold away ForOp iter arguments when:
760 // 1) The op yields the iter arguments.
761 // 2) The iter arguments have no use and the corresponding outer region
762 // iterators (inputs) are yielded.
763 // 3) The iter arguments have no use and the corresponding (operation) results
764 // have no use.
765 //
766 // These arguments must be defined outside of
767 // the ForOp region and can just be forwarded after simplifying the op inits,
768 // yields and returns.
769 //
770 // The implementation uses `inlineBlockBefore` to steal the content of the
771 // original ForOp and avoid cloning.
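// For example (illustrative), in
//
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%a = %x, %b = %y)
//       -> (f32, f32) {
//     %v = "test.compute"(%a) : (f32) -> f32
//     scf.yield %v, %b : f32, f32
//   }
//
// the second iter argument is yielded unchanged (case 1 above), so it is
// folded away and %r#1 is replaced by %y.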
772 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
773  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
774 
775  LogicalResult matchAndRewrite(scf::ForOp forOp,
776  PatternRewriter &rewriter) const final {
777  bool canonicalize = false;
778 
779  // An internal flat vector of block transfer
780  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
781  // transformed block arguments. This plays the role of an
782  // IRMapping for the particular use case of calling into
783  // `inlineBlockBefore`.
784  int64_t numResults = forOp.getNumResults();
785  SmallVector<bool, 4> keepMask;
786  keepMask.reserve(numResults);
787  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
788  newResultValues;
789  newBlockTransferArgs.reserve(1 + numResults);
790  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
791  newIterArgs.reserve(forOp.getInitArgs().size());
792  newYieldValues.reserve(numResults);
793  newResultValues.reserve(numResults);
794  for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
795  forOp.getRegionIterArgs(), // iter inside region
796  forOp.getResults(), // op results
797  forOp.getYieldedValues() // iter yield
798  )) {
799  // Forwarded is `true` when:
800  // 1) The region `iter` argument is yielded.
801  // 2) The region `iter` argument has no use, and the corresponding iter
802  // operand (input) is yielded.
803  // 3) The region `iter` argument has no use, and the corresponding op
804  // result has no use.
805  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
806  (std::get<1>(it).use_empty() &&
807  (std::get<0>(it) == std::get<3>(it) ||
808  std::get<2>(it).use_empty())));
809  keepMask.push_back(!forwarded);
810  canonicalize |= forwarded;
811  if (forwarded) {
812  newBlockTransferArgs.push_back(std::get<0>(it));
813  newResultValues.push_back(std::get<0>(it));
814  continue;
815  }
816  newIterArgs.push_back(std::get<0>(it));
817  newYieldValues.push_back(std::get<3>(it));
818  newBlockTransferArgs.push_back(Value()); // placeholder with null value
819  newResultValues.push_back(Value()); // placeholder with null value
820  }
821 
822  if (!canonicalize)
823  return failure();
824 
825  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
826  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
827  forOp.getStep(), newIterArgs);
828  newForOp->setAttrs(forOp->getAttrs());
829  Block &newBlock = newForOp.getRegion().front();
830 
831  // Replace the null placeholders with newly constructed values.
832  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
833  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
834  idx != e; ++idx) {
835  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
836  Value &newResultVal = newResultValues[idx];
837  assert((blockTransferArg && newResultVal) ||
838  (!blockTransferArg && !newResultVal));
839  if (!blockTransferArg) {
840  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
841  newResultVal = newForOp.getResult(collapsedIdx++);
842  }
843  }
844 
845  Block &oldBlock = forOp.getRegion().front();
846  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
847  "unexpected argument size mismatch");
848 
849  // No results case: the scf::ForOp builder already created a zero
850  // result terminator. Merge before this terminator and just get rid of the
851  // original terminator that has been merged in.
852  if (newIterArgs.empty()) {
853  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
854  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
855  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
856  rewriter.replaceOp(forOp, newResultValues);
857  return success();
858  }
859 
860  // General case: merge the blocks and rewrite the merged terminator.
861  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
862  OpBuilder::InsertionGuard g(rewriter);
863  rewriter.setInsertionPoint(mergedTerminator);
864  SmallVector<Value, 4> filteredOperands;
865  filteredOperands.reserve(newResultValues.size());
866  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
867  if (keepMask[idx])
868  filteredOperands.push_back(mergedTerminator.getOperand(idx));
869  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
870  filteredOperands);
871  };
872 
873  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
874  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
875  cloneFilteredTerminator(mergedYieldOp);
876  rewriter.eraseOp(mergedYieldOp);
877  rewriter.replaceOp(forOp, newResultValues);
878  return success();
879  }
880 };
881 
882 /// Utility function that tries to compute a constant diff between u and l.
883 /// Returns std::nullopt when the difference between the two values cannot be
884 /// determined statically.
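/// For example (illustrative), if `u` is defined by `arith.addi %l, %c4`
/// where %c4 is the constant 4, the computed difference is 4.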
885 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
886  IntegerAttr clb, cub;
887  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
888  llvm::APInt lbValue = clb.getValue();
889  llvm::APInt ubValue = cub.getValue();
890  return (ubValue - lbValue).getSExtValue();
891  }
892 
893  // Else a simple pattern match for x + c or c + x
894  llvm::APInt diff;
895  if (matchPattern(
896  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
897  matchPattern(
898  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
899  return diff.getSExtValue();
900  return std::nullopt;
901 }
902 
903 /// Rewriting pattern that erases loops that are known not to iterate, replaces
904 /// single-iteration loops with their bodies, and removes empty loops that
905 /// iterate at least once and only return values defined outside of the loop.
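/// For example (illustrative), when the lower and upper bounds are the same
/// value, the loop below never iterates and %r is replaced by %init:
///
///   %r = scf.for %i = %lb to %lb step %c1 iter_args(%a = %init) -> (f32) {
///     ...
///   }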
906 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
907  using OpRewritePattern<ForOp>::OpRewritePattern;
908 
909  LogicalResult matchAndRewrite(ForOp op,
910  PatternRewriter &rewriter) const override {
911  // If the upper bound is the same as the lower bound, the loop does not
912  // iterate, just remove it.
913  if (op.getLowerBound() == op.getUpperBound()) {
914  rewriter.replaceOp(op, op.getInitArgs());
915  return success();
916  }
917 
918  std::optional<int64_t> diff =
919  computeConstDiff(op.getLowerBound(), op.getUpperBound());
920  if (!diff)
921  return failure();
922 
923  // If the loop is known to have 0 iterations, remove it.
924  if (*diff <= 0) {
925  rewriter.replaceOp(op, op.getInitArgs());
926  return success();
927  }
928 
929  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
930  if (!maybeStepValue)
931  return failure();
932 
933  // If the loop is known to have 1 iteration, inline its body and remove the
934  // loop.
935  llvm::APInt stepValue = *maybeStepValue;
936  if (stepValue.sge(*diff)) {
937  SmallVector<Value, 4> blockArgs;
938  blockArgs.reserve(op.getInitArgs().size() + 1);
939  blockArgs.push_back(op.getLowerBound());
940  llvm::append_range(blockArgs, op.getInitArgs());
941  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
942  return success();
943  }
944 
945  // Now we are left with loops that have more than one iteration.
946  Block &block = op.getRegion().front();
947  if (!llvm::hasSingleElement(block))
948  return failure();
949  // If the loop is empty, iterates at least once, and only returns values
950  // defined outside of the loop, remove it and replace it with yield values.
951  if (llvm::any_of(op.getYieldedValues(),
952  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
953  return failure();
954  rewriter.replaceOp(op, op.getYieldedValues());
955  return success();
956  }
957 };
958 
959 /// Replace one iter OpOperand of an scf.for with the `replacement` value,
960 /// which is expected to be the source of a tensor.cast.
961 /// tensor.cast ops are inserted inside the block to account for the type cast.
962 static SmallVector<Value>
963 replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
964  Value replacement) {
965  Type oldType = operand.get().getType(), newType = replacement.getType();
966  assert(llvm::isa<RankedTensorType>(oldType) &&
967  llvm::isa<RankedTensorType>(newType) &&
968  "expected ranked tensor types");
969 
970  // 1. Create new iter operands, exactly 1 is replaced.
971  ForOp forOp = cast<ForOp>(operand.getOwner());
972  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
973  "expected an iter OpOperand");
974  assert(operand.get().getType() != replacement.getType() &&
975  "Expected a different type");
976  SmallVector<Value> newIterOperands;
977  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
978  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
979  newIterOperands.push_back(replacement);
980  continue;
981  }
982  newIterOperands.push_back(opOperand.get());
983  }
984 
985  // 2. Create the new forOp shell.
986  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
987  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
988  forOp.getStep(), newIterOperands);
989  newForOp->setAttrs(forOp->getAttrs());
990  Block &newBlock = newForOp.getRegion().front();
991  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
992  newBlock.getArguments().end());
993 
994  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
995  // corresponding to the `replacement` value.
996  OpBuilder::InsertionGuard g(rewriter);
997  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
998  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
999  &newForOp->getOpOperand(operand.getOperandNumber()));
1000  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
1001  newRegionIterArg);
1002  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
1003 
1004  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1005  Block &oldBlock = forOp.getRegion().front();
1006  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1007 
1008  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
1009  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1010  rewriter.setInsertionPoint(clonedYieldOp);
1011  unsigned yieldIdx =
1012  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
1013  Value castOut = rewriter.create<tensor::CastOp>(
1014  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1015  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
1016  newYieldOperands[yieldIdx] = castOut;
1017  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1018  rewriter.eraseOp(clonedYieldOp);
1019 
1020  // 6. Inject an outgoing cast op after the forOp.
1021  rewriter.setInsertionPointAfter(newForOp);
1022  SmallVector<Value> newResults = newForOp.getResults();
1023  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
1024  newForOp.getLoc(), oldType, newResults[yieldIdx]);
1025 
1026  return newResults;
1027 }
1028 
1029 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1030 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1031 ///
1032 /// ```
1033 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1034 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1035 /// -> (tensor<?x?xf32>) {
1036 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1037 /// scf.yield %2 : tensor<?x?xf32>
1038 /// }
1039 /// use_of(%1)
1040 /// ```
1041 ///
1042 /// folds into:
1043 ///
1044 /// ```
1045 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1046 /// -> (tensor<32x1024xf32>) {
1047 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1048 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1049 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1050 /// scf.yield %4 : tensor<32x1024xf32>
1051 /// }
1052 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1053 /// use_of(%1)
1054 /// ```
1055 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1056  using OpRewritePattern<ForOp>::OpRewritePattern;
1057 
1058  LogicalResult matchAndRewrite(ForOp op,
1059  PatternRewriter &rewriter) const override {
1060  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1061  OpOperand &iterOpOperand = std::get<0>(it);
1062  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1063  if (!incomingCast ||
1064  incomingCast.getSource().getType() == incomingCast.getType())
1065  continue;
1066  // Skip if the dest type of the cast does not preserve static information
1067  // present in the source type.
1068  if (!tensor::preservesStaticInformation(
1069  incomingCast.getDest().getType(),
1070  incomingCast.getSource().getType()))
1071  continue;
1072  if (!std::get<1>(it).hasOneUse())
1073  continue;
1074 
1075  // Create a new ForOp with that iter operand replaced.
1076  rewriter.replaceOp(
1077  op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1078  incomingCast.getSource()));
1079  return success();
1080  }
1081  return failure();
1082  }
1083 };
1084 
1085 /// Canonicalize the iter_args of an scf::ForOp that involve a
1086 /// `bufferization.to_tensor` and for which only the last loop iteration is
1087 /// actually visible outside of the loop. The canonicalization looks for a
1088 /// pattern such as:
1089 /// ```
1090 /// %t0 = ... : tensor_type
1091 /// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
1092 /// ...
1093 /// // %m is either buffer_cast(%bb00) or defined above the loop
1094 /// %m... : memref_type
1095 /// ... // uses of %m with potential inplace updates
1096 /// %new_tensor = bufferization.to_tensor %m : memref_type
1097 /// ...
1098 /// scf.yield %new_tensor : tensor_type
1099 /// }
1100 /// ```
1101 ///
1102 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
1103 /// `%m = buffer_cast %bb0` op that feeds into the yielded
1104 /// `bufferization.to_tensor` op.
1105 ///
1106 /// If no aliasing write to the memref `%m`, from which `%new_tensor` is
1107 /// loaded, occurs between the `bufferization.to_tensor` and the yield, then
1108 /// the value %0 visible outside of the loop is the last
1109 /// `bufferization.to_tensor` produced in the loop.
1110 ///
1111 /// For now, we approximate the absence of aliasing by only supporting the case
1112 /// when the bufferization.to_tensor is the operation immediately preceding
1113 /// the yield.
1114 //
1115 /// The canonicalization rewrites the pattern as:
1116 /// ```
1117 /// // %m is either a buffer_cast or defined above
1118 /// %m... : memref_type
1119 /// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
1120 /// ... // uses of %m with potential inplace updates
1121 /// scf.yield %bb0: tensor_type
1122 /// }
1123 /// %0 = bufferization.to_tensor %m : memref_type
1124 /// ```
1125 ///
1126 /// A later bbArg canonicalization will further rewrite as:
1127 /// ```
1128 /// // %m is either a buffer_cast or defined above
1129 /// %m... : memref_type
1130 /// scf.for ... { // no iter_args
1131 /// ... // uses of %m with potential inplace updates
1132 /// }
1133 /// %0 = bufferization.to_tensor %m : memref_type
1134 /// ```
1135 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
1136  using OpRewritePattern<ForOp>::OpRewritePattern;
1137 
1138  LogicalResult matchAndRewrite(ForOp forOp,
1139  PatternRewriter &rewriter) const override {
1140  assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
1141  "unexpected multiple blocks");
1142 
1143  Location loc = forOp.getLoc();
1144  DenseMap<Value, Value> replacements;
1145  for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
1146  unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
1147  auto yieldOp =
1148  cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
1149  Value yieldVal = yieldOp->getOperand(idx);
1150  auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
1151  bool isTensor = llvm::isa<TensorType>(bbArg.getType());
1152 
1153  bufferization::ToMemrefOp tensorToMemref;
1154  // Either bbArg has no use or it has a single buffer_cast use.
1155  if (bbArg.hasOneUse())
1156  tensorToMemref =
1157  dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
1158  if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
1159  continue;
1160  // If tensorToMemref is present, it must feed into the `ToTensorOp`.
1161  if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
1162  continue;
1163  // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
1164  // must be before `ToTensorOp` in the block so that the lastWrite
1165  // property is not subject to additional side-effects.
1166  // For now, we only support the case when ToTensorOp appears
1167  // immediately before the terminator.
1168  if (tensorLoadOp->getNextNode() != yieldOp)
1169  continue;
1170 
1171  // Clone the optional tensorToMemref before forOp.
1172  if (tensorToMemref) {
1173  rewriter.setInsertionPoint(forOp);
1174  rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1175  tensorToMemref, tensorToMemref.getMemref().getType(),
1176  tensorToMemref.getTensor());
1177  }
1178 
1179  // Clone the tensorLoad after forOp.
1180  rewriter.setInsertionPointAfter(forOp);
1181  Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1182  loc, tensorLoadOp.getMemref());
1183  Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1184  replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1185 
1186  // Make the terminator just yield the bbArg; the old tensorLoadOp and the
1187  // old bbArg (which is now directly yielded) will canonicalize away.
1188  rewriter.startRootUpdate(yieldOp);
1189  yieldOp.setOperand(idx, bbArg);
1190  rewriter.finalizeRootUpdate(yieldOp);
1191  }
1192  if (replacements.empty())
1193  return failure();
1194 
1195  // We want to replace only a subset of the results of `forOp`.
1196  // rewriter.replaceOp replaces the whole op and erases it unconditionally,
1197  // which is wrong for `forOp` as it generally contains ops with side effects.
1198  // Instead, use `rewriter.replaceOpWithIf`.
1199  SmallVector<Value> newResults;
1200  newResults.reserve(forOp.getNumResults());
1201  for (Value v : forOp.getResults()) {
1202  auto it = replacements.find(v);
1203  newResults.push_back((it != replacements.end()) ? it->second : v);
1204  }
1205  unsigned idx = 0;
1206  rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1207  return op.get() != newResults[idx++];
1208  });
1209  return success();
1210  }
1211 };
1212 } // namespace
1213 
1214 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215  MLIRContext *context) {
1216  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1217  LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1218 }
1219 
1220 std::optional<APInt> ForOp::getConstantStep() {
1221  IntegerAttr step;
1222  if (matchPattern(getStep(), m_Constant(&step)))
1223  return step.getValue();
1224  return {};
1225 }
1226 
1227 MutableArrayRef<OpOperand> ForOp::getYieldedValuesMutable() {
1228  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1229 }
1230 
1231 Speculation::Speculatability ForOp::getSpeculatability() {
1232  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1233  // and End.
1234  if (auto constantStep = getConstantStep())
1235  if (*constantStep == 1)
1236  return Speculation::RecursivelySpeculatable;
1237 
1238  // For Step != 1, the loop may not terminate. We can add more smarts here if
1239  // needed.
1240  return Speculation::NotSpeculatable;
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // ForallOp
1245 //===----------------------------------------------------------------------===//
1246 
1247 LogicalResult ForallOp::verify() {
1248  unsigned numLoops = getRank();
1249  // Check number of outputs.
1250  if (getNumResults() != getOutputs().size())
1251  return emitOpError("produces ")
1252  << getNumResults() << " results, but has only "
1253  << getOutputs().size() << " outputs";
1254 
1255  // Check that the body defines block arguments for thread indices and outputs.
1256  auto *body = getBody();
1257  if (body->getNumArguments() != numLoops + getOutputs().size())
1258  return emitOpError("region expects ") << numLoops << " arguments";
1259  for (int64_t i = 0; i < numLoops; ++i)
1260  if (!body->getArgument(i).getType().isIndex())
1261  return emitOpError("expects ")
1262  << i << "-th block argument to be an index";
1263  for (unsigned i = 0; i < getOutputs().size(); ++i)
1264  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1265  return emitOpError("type mismatch between ")
1266  << i << "-th output and corresponding block argument";
1267  if (getMapping().has_value() && !getMapping()->empty()) {
1268  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1269  return emitOpError() << "mapping attribute size must match op rank";
1270  for (auto map : getMapping()->getValue()) {
1271  if (!isa<DeviceMappingAttrInterface>(map))
1272  return emitOpError()
1273  << getMappingAttrName() << " is not device mapping attribute";
1274  }
1275  }
1276 
1277  // Verify mixed static/dynamic control variables.
1278  Operation *op = getOperation();
1279  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1280  getStaticLowerBound(),
1281  getDynamicLowerBound())))
1282  return failure();
1283  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1284  getStaticUpperBound(),
1285  getDynamicUpperBound())))
1286  return failure();
1287  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1288  getStaticStep(), getDynamicStep())))
1289  return failure();
1290 
1291  return success();
1292 }
1293 
1294 void ForallOp::print(OpAsmPrinter &p) {
1295  Operation *op = getOperation();
1296  p << " (" << getInductionVars();
1297  if (isNormalized()) {
1298  p << ") in ";
1299  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1300  /*valueTypes=*/{}, /*scalables=*/{},
1301  OpAsmParser::Delimiter::Paren);
1302  } else {
1303  p << ") = ";
1304  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1305  /*valueTypes=*/{}, /*scalables=*/{},
1306  OpAsmParser::Delimiter::Paren);
1307  p << " to ";
1308  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1309  /*valueTypes=*/{}, /*scalables=*/{},
1310  OpAsmParser::Delimiter::Paren);
1311  p << " step ";
1312  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1313  /*valueTypes=*/{}, /*scalables=*/{},
1314  OpAsmParser::Delimiter::Paren);
1315  }
1316  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1317  p << " ";
1318  if (!getRegionOutArgs().empty())
1319  p << "-> (" << getResultTypes() << ") ";
1320  p.printRegion(getRegion(),
1321  /*printEntryBlockArgs=*/false,
1322  /*printBlockTerminators=*/getNumResults() > 0);
1323  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1324  getStaticLowerBoundAttrName(),
1325  getStaticUpperBoundAttrName(),
1326  getStaticStepAttrName()});
1327 }
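// The printed forms produced above look like the following (illustrative):
//
//   scf.forall (%i, %j) in (8, 8) shared_outs(%o = %t) -> (tensor<8x8xf32>) {
//     ...
//   }
//
// for the normalized case, and
//
//   scf.forall (%i) = (%lb) to (%ub) step (%s) shared_outs(%o = %t)
//       -> (tensor<?xf32>) {
//     ...
//   }
//
// otherwise.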
1328 
1329 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1330  OpBuilder b(parser.getContext());
1331  auto indexType = b.getIndexType();
1332 
1333  // Parse an opening `(` followed by thread index variables followed by `)`
1334  // TODO: when we can refer to such "induction variable"-like handles from the
1335  // declarative assembly format, we can implement the parser as a custom hook.
1336  SmallVector<OpAsmParser::Argument, 4> ivs;
1337  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1338  return failure();
1339 
1340  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1341  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1342  dynamicSteps;
1343  if (succeeded(parser.parseOptionalKeyword("in"))) {
1344  // Parse upper bounds.
1345  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1346  /*valueTypes=*/nullptr,
1347  OpAsmParser::Delimiter::Paren) ||
1348  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1349  return failure();
1350 
1351  unsigned numLoops = ivs.size();
1352  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1353  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1354  } else {
1355  // Parse lower bounds.
1356  if (parser.parseEqual() ||
1357  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1358  /*valueTypes=*/nullptr,
1359  OpAsmParser::Delimiter::Paren) ||
1360 
1361  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1362  return failure();
1363 
1364  // Parse upper bounds.
1365  if (parser.parseKeyword("to") ||
1366  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1367  /*valueTypes=*/nullptr,
1368  OpAsmParser::Delimiter::Paren) ||
1369  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1370  return failure();
1371 
1372  // Parse step values.
1373  if (parser.parseKeyword("step") ||
1374  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1375  /*valueTypes=*/nullptr,
1376  OpAsmParser::Delimiter::Paren) ||
1377  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1378  return failure();
1379  }
1380 
1381  // Parse out operands and results.
1382  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1383  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1384  SMLoc outOperandsLoc = parser.getCurrentLocation();
1385  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1386  if (outOperands.size() != result.types.size())
1387  return parser.emitError(outOperandsLoc,
1388  "mismatch between out operands and types");
1389  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1390  parser.parseOptionalArrowTypeList(result.types) ||
1391  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1392  result.operands))
1393  return failure();
1394  }
1395 
1396  // Parse region.
1397  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1398  std::unique_ptr<Region> region = std::make_unique<Region>();
1399  for (auto &iv : ivs) {
1400  iv.type = b.getIndexType();
1401  regionArgs.push_back(iv);
1402  }
1403  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1404  auto &out = it.value();
1405  out.type = result.types[it.index()];
1406  regionArgs.push_back(out);
1407  }
1408  if (parser.parseRegion(*region, regionArgs))
1409  return failure();
1410 
1411  // Ensure terminator and move region.
1412  ForallOp::ensureTerminator(*region, b, result.location);
1413  result.addRegion(std::move(region));
1414 
1415  // Parse the optional attribute list.
1416  if (parser.parseOptionalAttrDict(result.attributes))
1417  return failure();
1418 
1419  result.addAttribute("staticLowerBound", staticLbs);
1420  result.addAttribute("staticUpperBound", staticUbs);
1421  result.addAttribute("staticStep", staticSteps);
1422  result.addAttribute("operandSegmentSizes",
1423  b.getDenseI32ArrayAttr(
1424  {static_cast<int32_t>(dynamicLbs.size()),
1425  static_cast<int32_t>(dynamicUbs.size()),
1426  static_cast<int32_t>(dynamicSteps.size()),
1427  static_cast<int32_t>(outOperands.size())}));
1428  return success();
1429 }
1430 
1431 // Builder that takes loop bounds.
1432 void ForallOp::build(
1433  mlir::OpBuilder &b, mlir::OperationState &result,
1434  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1435  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1436  std::optional<ArrayAttr> mapping,
1437  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1438  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1439  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1440  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1441  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1442  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1443 
1444  result.addOperands(dynamicLbs);
1445  result.addOperands(dynamicUbs);
1446  result.addOperands(dynamicSteps);
1447  result.addOperands(outputs);
1448  result.addTypes(TypeRange(outputs));
1449 
1450  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1451  b.getDenseI64ArrayAttr(staticLbs));
1452  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1453  b.getDenseI64ArrayAttr(staticUbs));
1454  result.addAttribute(getStaticStepAttrName(result.name),
1455  b.getDenseI64ArrayAttr(staticSteps));
1456  result.addAttribute(
1457  "operandSegmentSizes",
1458  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1459  static_cast<int32_t>(dynamicUbs.size()),
1460  static_cast<int32_t>(dynamicSteps.size()),
1461  static_cast<int32_t>(outputs.size())}));
1462  if (mapping.has_value()) {
1463  result.addAttribute(ForallOp::getMappingAttrName(result.name),
1464  mapping.value());
1465  }
1466 
1467  Region *bodyRegion = result.addRegion();
1468  OpBuilder::InsertionGuard g(b);
1469  b.createBlock(bodyRegion);
1470  Block &bodyBlock = bodyRegion->front();
1471 
1472  // Add block arguments for indices and outputs.
1473  bodyBlock.addArguments(
1474  SmallVector<Type>(lbs.size(), b.getIndexType()),
1475  SmallVector<Location>(staticLbs.size(), result.location));
1476  bodyBlock.addArguments(
1477  TypeRange(outputs),
1478  SmallVector<Location>(outputs.size(), result.location));
1479 
1480  b.setInsertionPointToStart(&bodyBlock);
1481  if (!bodyBuilderFn) {
1482  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1483  return;
1484  }
1485  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1486 #ifndef NDEBUG
1487  auto terminator = llvm::dyn_cast<InParallelOp>(bodyBlock.getTerminator());
1488  assert(terminator &&
1489  "expected bodyBuilderFn to create InParallelOp terminator");
1490 #endif // NDEBUG
1491 }
1492 
1493 // Builder that takes loop bounds.
1494 void ForallOp::build(
1495  mlir::OpBuilder &b, mlir::OperationState &result,
1496  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1497  std::optional<ArrayAttr> mapping,
1498  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1499  unsigned numLoops = ubs.size();
1500  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1501  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1502  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1503 }
1504 
1505 // Checks if the lbs are zeros and steps are ones.
1506 bool ForallOp::isNormalized() {
1507  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1508  return llvm::all_of(results, [&](OpFoldResult ofr) {
1509  auto intValue = getConstantIntValue(ofr);
1510  return intValue.has_value() && intValue == val;
1511  });
1512  };
1513  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1514 }
1515 
1516 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1517 // unaware of the fact that our terminator also needs a region to be
1518 // well-formed. We override it here to ensure that we do the right thing.
1519 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1520  Location loc) {
1521  OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
1522  ForallOp>::ensureTerminator(region, builder, loc);
1523  auto terminator =
1524  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1525  if (terminator.getRegion().empty())
1526  builder.createBlock(&terminator.getRegion());
1527 }
1528 
1529 InParallelOp ForallOp::getTerminator() {
1530  return cast<InParallelOp>(getBody()->getTerminator());
1531 }
1532 
1533 std::optional<Value> ForallOp::getSingleInductionVar() {
1534  if (getRank() != 1)
1535  return std::nullopt;
1536  return getInductionVar(0);
1537 }
1538 
1539 std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1540  if (getRank() != 1)
1541  return std::nullopt;
1542  return getMixedLowerBound()[0];
1543 }
1544 
1545 std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1546  if (getRank() != 1)
1547  return std::nullopt;
1548  return getMixedUpperBound()[0];
1549 }
1550 
1551 std::optional<OpFoldResult> ForallOp::getSingleStep() {
1552  if (getRank() != 1)
1553  return std::nullopt;
1554  return getMixedStep()[0];
1555 }
1556 
1557 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1558  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1559  if (!tidxArg)
1560  return ForallOp();
1561  assert(tidxArg.getOwner() && "unlinked block argument");
1562  auto *containingOp = tidxArg.getOwner()->getParentOp();
1563  return dyn_cast<ForallOp>(containingOp);
1564 }
1565 
1566 namespace {
1567 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
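/// For example (a sketch; %t, %o, and %c0 are illustrative values):
///
///   %r = scf.forall ... shared_outs(%o = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %r, %c0 : tensor<?xf32>
///
/// is rewritten so that the tensor.dim reads the dimension of %t directly and
/// no longer depends on the loop result.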
1568 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1569  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1570 
1571  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1572  PatternRewriter &rewriter) const final {
1573  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1574  if (!forallOp)
1575  return failure();
1576  Value sharedOut =
1577  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1578  ->get();
1579  rewriter.updateRootInPlace(
1580  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1581  return success();
1582  }
1583 };
1584 
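/// Fold dynamic lower bounds, upper bounds, and steps that are in fact
/// constants into the corresponding static attributes. For example (sketch):
///
///   %c16 = arith.constant 16 : index
///   scf.forall (%i) in (%c16) shared_outs(%o = %t) -> (tensor<16xf32>) { ... }
///
/// becomes a forall whose upper bound is the static value 16.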
1585 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1586 public:
1587  using OpRewritePattern<ForallOp>::OpRewritePattern;
1588 
1589  LogicalResult matchAndRewrite(ForallOp op,
1590  PatternRewriter &rewriter) const override {
1591  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1592  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1593  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1594  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1595  failed(foldDynamicIndexList(mixedUpperBound)) &&
1596  failed(foldDynamicIndexList(mixedStep)))
1597  return failure();
1598 
1599  rewriter.updateRootInPlace(op, [&]() {
1600  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1601  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1602  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1603  staticLowerBound);
1604  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1605  op.setStaticLowerBound(staticLowerBound);
1606 
1607  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1608  staticUpperBound);
1609  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1610  op.setStaticUpperBound(staticUpperBound);
1611 
1612  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1613  op.getDynamicStepMutable().assign(dynamicStep);
1614  op.setStaticStep(staticStep);
1615 
1616  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1617  rewriter.getDenseI32ArrayAttr(
1618  {static_cast<int32_t>(dynamicLowerBound.size()),
1619  static_cast<int32_t>(dynamicUpperBound.size()),
1620  static_cast<int32_t>(dynamicStep.size()),
1621  static_cast<int32_t>(op.getNumResults())}));
1622  });
1623  return success();
1624  }
1625 };
1626 
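/// Fold away forall dimensions whose trip count is statically 0 or 1: a loop
/// with a zero-trip dimension is replaced by its outputs, and single-iteration
/// dimensions are dropped with their induction variable replaced by the lower
/// bound. For example (sketch):
///
///   scf.forall (%i, %j) in (1, %n) shared_outs(%o = %t) -> (tensor<?xf32>)
///
/// can be rewritten as a one-dimensional forall over %j only.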
1627 struct ForallOpSingleOrZeroIterationDimsFolder
1628  : public OpRewritePattern<ForallOp> {
1629  using OpRewritePattern<ForallOp>::OpRewritePattern;
1630 
1631  LogicalResult matchAndRewrite(ForallOp op,
1632  PatternRewriter &rewriter) const override {
1633  // Do not fold dimensions if they are mapped to processing units.
1634  if (op.getMapping().has_value())
1635  return failure();
1636  Location loc = op.getLoc();
1637 
1638  // Compute new loop bounds that omit all single-iteration loop dimensions.
1639  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1640  newMixedSteps;
1641  IRMapping mapping;
1642  for (auto [lb, ub, step, iv] :
1643  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1644  op.getMixedStep(), op.getInductionVars())) {
1645  auto numIterations = constantTripCount(lb, ub, step);
1646  if (numIterations.has_value()) {
1647  // Remove the loop if it performs zero iterations.
1648  if (*numIterations == 0) {
1649  rewriter.replaceOp(op, op.getOutputs());
1650  return success();
1651  }
1652  // Replace the loop induction variable by the lower bound if the loop
1653  // performs a single iteration. Otherwise, copy the loop bounds.
1654  if (*numIterations == 1) {
1655  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1656  continue;
1657  }
1658  }
1659  newMixedLowerBounds.push_back(lb);
1660  newMixedUpperBounds.push_back(ub);
1661  newMixedSteps.push_back(step);
1662  }
1663  // Exit if none of the loop dimensions perform zero or a single iteration.
1664  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1665  return rewriter.notifyMatchFailure(
1666  op, "no dimensions have 0 or 1 iterations");
1667  }
1668 
1669  // All of the loop dimensions perform a single iteration. Inline loop body.
1670  if (newMixedLowerBounds.empty()) {
1671  promote(rewriter, op);
1672  return success();
1673  }
1674 
1675  // Replace the loop by a lower-dimensional loop.
1676  ForallOp newOp;
1677  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1678  newMixedUpperBounds, newMixedSteps,
1679  op.getOutputs(), std::nullopt, nullptr);
1680  newOp.getBodyRegion().getBlocks().clear();
1681  // The new loop needs to keep all attributes from the old one, except for
1682  // "operandSegmentSizes" and static loop bound attributes which capture
1683  // the outdated information of the old iteration domain.
1684  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1685  newOp.getStaticLowerBoundAttrName(),
1686  newOp.getStaticUpperBoundAttrName(),
1687  newOp.getStaticStepAttrName()};
1688  for (const auto &namedAttr : op->getAttrs()) {
1689  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1690  continue;
1691  rewriter.updateRootInPlace(newOp, [&]() {
1692  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1693  });
1694  }
1695  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1696  newOp.getRegion().begin(), mapping);
1697  rewriter.replaceOp(op, newOp.getResults());
1698  return success();
1699  }
1700 };
1701 
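/// Fold a tensor.cast feeding a shared output into the forall when the cast
/// source carries more static information than its result. For example
/// (sketch):
///
///   %cast = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
///
/// becomes a forall iterating directly on %t (with result type tensor<4xf32>),
/// followed by a tensor.cast of the loop result back to tensor<?xf32>.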
1702 struct FoldTensorCastOfOutputIntoForallOp
1703  : public OpRewritePattern<scf::ForallOp> {
1704  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1705 
1706  struct TypeCast {
1707  Type srcType;
1708  Type dstType;
1709  };
1710 
1711  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1712  PatternRewriter &rewriter) const final {
1713  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1714  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1715  for (auto en : llvm::enumerate(newOutputTensors)) {
1716  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1717  if (!castOp)
1718  continue;
1719 
1720  // Only casts that preserve static information, i.e. casts that make the
1721  // loop result type "more" static than before, are folded.
1722  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1723  castOp.getSource().getType())) {
1724  continue;
1725  }
1726 
1727  tensorCastProducers[en.index()] =
1728  TypeCast{castOp.getSource().getType(), castOp.getType()};
1729  newOutputTensors[en.index()] = castOp.getSource();
1730  }
1731 
1732  if (tensorCastProducers.empty())
1733  return failure();
1734 
1735  // Create new loop.
1736  Location loc = forallOp.getLoc();
1737  auto newForallOp = rewriter.create<ForallOp>(
1738  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1739  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1740  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1741  auto castBlockArgs =
1742  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1743  for (auto [index, cast] : tensorCastProducers) {
1744  Value &oldTypeBBArg = castBlockArgs[index];
1745  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1746  nestedLoc, cast.dstType, oldTypeBBArg);
1747  }
1748 
1749  // Move old body into new parallel loop.
1750  SmallVector<Value> ivsBlockArgs =
1751  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1752  ivsBlockArgs.append(castBlockArgs);
1753  rewriter.mergeBlocks(forallOp.getBody(),
1754  bbArgs.front().getParentBlock(), ivsBlockArgs);
1755  });
1756 
1757  // After `mergeBlocks` happened, the destinations in the terminator were
1758  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1759  // destinations have to be updated to point to the output bbArgs directly.
1760  auto terminator = newForallOp.getTerminator();
1761  for (auto [yieldingOp, outputBlockArg] :
1762  llvm::zip(terminator.getYieldingOps(),
1763  newForallOp.getOutputBlockArguments())) {
1764  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1765  insertSliceOp.getDestMutable().assign(outputBlockArg);
1766  }
1767 
1768  // Cast results back to the original types.
1769  rewriter.setInsertionPointAfter(newForallOp);
1770  SmallVector<Value> castResults = newForallOp.getResults();
1771  for (auto &item : tensorCastProducers) {
1772  Value &oldTypeResult = castResults[item.first];
1773  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1774  oldTypeResult);
1775  }
1776  rewriter.replaceOp(forallOp, castResults);
1777  return success();
1778  }
1779 };
1780 
1781 } // namespace
1782 
1783 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1784  MLIRContext *context) {
1785  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1786  ForallOpControlOperandsFolder,
1787  ForallOpSingleOrZeroIterationDimsFolder>(context);
1788 }
1789 
1790 /// Given the branch point `point`, i.e. either this operation itself or its
1791 /// body region, return the successor regions. These are the regions that may
1792 /// be selected next during the flow of control: control may branch into the
1793 /// loop body or to the code following the loop, so both are reported as
1794 /// successors unconditionally.
1795 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1796  SmallVectorImpl<RegionSuccessor> &regions) {
1797  // Both the operation itself and the region may be branching into the body or
1798  // back into the operation itself. It is possible for the loop not to enter
1799  // the body.
1800  regions.push_back(RegionSuccessor(&getRegion()));
1801  regions.push_back(RegionSuccessor());
1802 }
1803 
1804 //===----------------------------------------------------------------------===//
1805 // InParallelOp
1806 //===----------------------------------------------------------------------===//
1807 
1808 // Build an InParallelOp with an empty body block.
1809 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1810  OpBuilder::InsertionGuard g(b);
1811  Region *bodyRegion = result.addRegion();
1812  b.createBlock(bodyRegion);
1813 }
1814 
1815 LogicalResult InParallelOp::verify() {
1816  scf::ForallOp forallOp =
1817  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1818  if (!forallOp)
1819  return this->emitOpError("expected forall op parent");
1820 
1821  // TODO: InParallelOpInterface.
1822  for (Operation &op : getRegion().front().getOperations()) {
1823  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1824  return this->emitOpError("expected only ")
1825  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1826  }
1827 
1828  // Verify that inserts are into out block arguments.
1829  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1830  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1831  if (!llvm::is_contained(regionOutArgs, dest))
1832  return op.emitOpError("may only insert into an output block argument");
1833  }
1834  return success();
1835 }
1836 
1837 void InParallelOp::print(OpAsmPrinter &p) {
1838  p << " ";
1839  p.printRegion(getRegion(),
1840  /*printEntryBlockArgs=*/false,
1841  /*printBlockTerminators=*/false);
1842  p.printOptionalAttrDict(getOperation()->getAttrs());
1843 }
1844 
1845 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1846  auto &builder = parser.getBuilder();
1847 
1848  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1849  std::unique_ptr<Region> region = std::make_unique<Region>();
1850  if (parser.parseRegion(*region, regionOperands))
1851  return failure();
1852 
1853  if (region->empty())
1854  OpBuilder(builder.getContext()).createBlock(region.get());
1855  result.addRegion(std::move(region));
1856 
1857  // Parse the optional attribute list.
1858  if (parser.parseOptionalAttrDict(result.attributes))
1859  return failure();
1860  return success();
1861 }
1862 
1863 OpResult InParallelOp::getParentResult(int64_t idx) {
1864  return getOperation()->getParentOp()->getResult(idx);
1865 }
1866 
1867 SmallVector<BlockArgument> InParallelOp::getDests() {
1868  return llvm::to_vector<4>(
1869  llvm::map_range(getYieldingOps(), [](Operation &op) {
1870  // Add new ops here as needed.
1871  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1872  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1873  }));
1874 }
1875 
1876 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1877  return getRegion().front().getOperations();
1878 }
1879 
1880 //===----------------------------------------------------------------------===//
1881 // IfOp
1882 //===----------------------------------------------------------------------===//
1883 
1884 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1885  assert(a && "expected non-empty operation");
1886  assert(b && "expected non-empty operation");
1887 
1888  IfOp ifOp = a->getParentOfType<IfOp>();
1889  while (ifOp) {
1890  // Check if b is inside ifOp. (We already know that a is.)
1891  if (ifOp->isProperAncestor(b))
1892  // b is contained in ifOp. a and b are in mutually exclusive branches if
1893  // they are in different blocks of ifOp.
1894  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1895  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1896  // Check next enclosing IfOp.
1897  ifOp = ifOp->getParentOfType<IfOp>();
1898  }
1899 
1900  // Could not find a common IfOp among a's and b's ancestors.
1901  return false;
1902 }
1903 
1904 LogicalResult
1905 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1906  IfOp::Adaptor adaptor,
1907  SmallVectorImpl<Type> &inferredReturnTypes) {
1908  if (adaptor.getRegions().empty())
1909  return failure();
1910  Region *r = &adaptor.getThenRegion();
1911  if (r->empty())
1912  return failure();
1913  Block &b = r->front();
1914  if (b.empty())
1915  return failure();
1916  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1917  if (!yieldOp)
1918  return failure();
1919  TypeRange types = yieldOp.getOperandTypes();
1920  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1921  types.end());
1922  return success();
1923 }
1924 
1925 void IfOp::build(OpBuilder &builder, OperationState &result,
1926  TypeRange resultTypes, Value cond) {
1927  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1928  /*addElseBlock=*/false);
1929 }
1930 
1931 void IfOp::build(OpBuilder &builder, OperationState &result,
1932  TypeRange resultTypes, Value cond, bool addThenBlock,
1933  bool addElseBlock) {
1934  assert((!addElseBlock || addThenBlock) &&
1935  "must not create else block w/o then block");
1936  result.addTypes(resultTypes);
1937  result.addOperands(cond);
1938 
1939  // Add regions and blocks.
1940  OpBuilder::InsertionGuard guard(builder);
1941  Region *thenRegion = result.addRegion();
1942  if (addThenBlock)
1943  builder.createBlock(thenRegion);
1944  Region *elseRegion = result.addRegion();
1945  if (addElseBlock)
1946  builder.createBlock(elseRegion);
1947 }
1948 
1949 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1950  bool withElseRegion) {
1951  build(builder, result, TypeRange{}, cond, withElseRegion);
1952 }
1953 
1954 void IfOp::build(OpBuilder &builder, OperationState &result,
1955  TypeRange resultTypes, Value cond, bool withElseRegion) {
1956  result.addTypes(resultTypes);
1957  result.addOperands(cond);
1958 
1959  // Build then region.
1960  OpBuilder::InsertionGuard guard(builder);
1961  Region *thenRegion = result.addRegion();
1962  builder.createBlock(thenRegion);
1963  if (resultTypes.empty())
1964  IfOp::ensureTerminator(*thenRegion, builder, result.location);
1965 
1966  // Build else region.
1967  Region *elseRegion = result.addRegion();
1968  if (withElseRegion) {
1969  builder.createBlock(elseRegion);
1970  if (resultTypes.empty())
1971  IfOp::ensureTerminator(*elseRegion, builder, result.location);
1972  }
1973 }
1974 
1975 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1976  function_ref<void(OpBuilder &, Location)> thenBuilder,
1977  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1978  assert(thenBuilder && "the builder callback for 'then' must be present");
1979  result.addOperands(cond);
1980 
1981  // Build then region.
1982  OpBuilder::InsertionGuard guard(builder);
1983  Region *thenRegion = result.addRegion();
1984  builder.createBlock(thenRegion);
1985  thenBuilder(builder, result.location);
1986 
1987  // Build else region.
1988  Region *elseRegion = result.addRegion();
1989  if (elseBuilder) {
1990  builder.createBlock(elseRegion);
1991  elseBuilder(builder, result.location);
1992  }
1993 
1994  // Infer result types.
1995  SmallVector<Type> inferredReturnTypes;
1996  MLIRContext *ctx = builder.getContext();
1997  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
1998  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
1999  /*properties=*/nullptr, result.regions,
2000  inferredReturnTypes))) {
2001  result.addTypes(inferredReturnTypes);
2002  }
2003 }
2004 
2005 LogicalResult IfOp::verify() {
2006  if (getNumResults() != 0 && getElseRegion().empty())
2007  return emitOpError("must have an else block if defining values");
2008  return success();
2009 }
2010 
2011 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2012  // Create the regions for 'then'.
2013  result.regions.reserve(2);
2014  Region *thenRegion = result.addRegion();
2015  Region *elseRegion = result.addRegion();
2016 
2017  auto &builder = parser.getBuilder();
2018  OpAsmParser::UnresolvedOperand cond;
2019  Type i1Type = builder.getIntegerType(1);
2020  if (parser.parseOperand(cond) ||
2021  parser.resolveOperand(cond, i1Type, result.operands))
2022  return failure();
2023  // Parse optional results type list.
2024  if (parser.parseOptionalArrowTypeList(result.types))
2025  return failure();
2026  // Parse the 'then' region.
2027  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2028  return failure();
2029  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2030 
2031  // If we find an 'else' keyword then parse the 'else' region.
2032  if (!parser.parseOptionalKeyword("else")) {
2033  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2034  return failure();
2035  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2036  }
2037 
2038  // Parse the optional attribute list.
2039  if (parser.parseOptionalAttrDict(result.attributes))
2040  return failure();
2041  return success();
2042 }
2043 
2044 void IfOp::print(OpAsmPrinter &p) {
2045  bool printBlockTerminators = false;
2046 
2047  p << " " << getCondition();
2048  if (!getResults().empty()) {
2049  p << " -> (" << getResultTypes() << ")";
2050  // Print yield explicitly if the op defines values.
2051  printBlockTerminators = true;
2052  }
2053  p << ' ';
2054  p.printRegion(getThenRegion(),
2055  /*printEntryBlockArgs=*/false,
2056  /*printBlockTerminators=*/printBlockTerminators);
2057 
2058  // Print the 'else' regions if it exists and has a block.
2059  auto &elseRegion = getElseRegion();
2060  if (!elseRegion.empty()) {
2061  p << " else ";
2062  p.printRegion(elseRegion,
2063  /*printEntryBlockArgs=*/false,
2064  /*printBlockTerminators=*/printBlockTerminators);
2065  }
2066 
2067  p.printOptionalAttrDict((*this)->getAttrs());
2068 }
2069 
2070 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2071  SmallVectorImpl<RegionSuccessor> &regions) {
2072  // The `then` and the `else` region branch back to the parent operation.
2073  if (!point.isParent()) {
2074  regions.push_back(RegionSuccessor(getResults()));
2075  return;
2076  }
2077 
2078  regions.push_back(RegionSuccessor(&getThenRegion()));
2079 
2080  // Don't consider the else region if it is empty.
2081  Region *elseRegion = &this->getElseRegion();
2082  if (elseRegion->empty())
2083  regions.push_back(RegionSuccessor());
2084  else
2085  regions.push_back(RegionSuccessor(elseRegion));
2086 }
2087 
2088 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2089  SmallVectorImpl<RegionSuccessor> &regions) {
2090  FoldAdaptor adaptor(operands, *this);
2091  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2092  if (!boolAttr || boolAttr.getValue())
2093  regions.emplace_back(&getThenRegion());
2094 
2095  // If the else region is empty, execution continues after the parent op.
2096  if (!boolAttr || !boolAttr.getValue()) {
2097  if (!getElseRegion().empty())
2098  regions.emplace_back(&getElseRegion());
2099  else
2100  regions.emplace_back(getResults());
2101  }
2102 }
2103 
2104 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2105  SmallVectorImpl<OpFoldResult> &results) {
2106  // if (!c) then A() else B() -> if c then B() else A()
2107  if (getElseRegion().empty())
2108  return failure();
2109 
2110  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2111  if (!xorStmt)
2112  return failure();
2113 
2114  if (!matchPattern(xorStmt.getRhs(), m_One()))
2115  return failure();
2116 
2117  getConditionMutable().assign(xorStmt.getLhs());
2118  Block *thenBlock = &getThenRegion().front();
2119  // It would be nicer to use iplist::swap, but that has no implemented
2120  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2121  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2122  getElseRegion().getBlocks());
2123  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2124  getThenRegion().getBlocks(), thenBlock);
2125  return success();
2126 }
2127 
2128 void IfOp::getRegionInvocationBounds(
2129  ArrayRef<Attribute> operands,
2130  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2131  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2132  // If the condition is known, then one region is known to be executed once
2133  // and the other zero times.
2134  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2135  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2136  } else {
2137  // Non-constant condition. Each region may be executed 0 or 1 times.
2138  invocationBounds.assign(2, {0, 1});
2139  }
2140 }
2141 
2142 namespace {
2143 // Pattern to remove unused IfOp results.
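// For example (a sketch): if only %r#0 of
//   %r:2 = scf.if %cond -> (i32, i32) { ... } else { ... }
// has uses, the op is rebuilt as a single-result scf.if that yields only the
// first value in each branch.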
2144 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2145  using OpRewritePattern<IfOp>::OpRewritePattern;
2146 
2147  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2148  PatternRewriter &rewriter) const {
2149  // Move all operations to the destination block.
2150  rewriter.mergeBlocks(source, dest);
2151  // Replace the yield op by one that returns only the used values.
2152  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2153  SmallVector<Value, 4> usedOperands;
2154  llvm::transform(usedResults, std::back_inserter(usedOperands),
2155  [&](OpResult result) {
2156  return yieldOp.getOperand(result.getResultNumber());
2157  });
2158  rewriter.updateRootInPlace(yieldOp,
2159  [&]() { yieldOp->setOperands(usedOperands); });
2160  }
2161 
2162  LogicalResult matchAndRewrite(IfOp op,
2163  PatternRewriter &rewriter) const override {
2164  // Compute the list of used results.
2165  SmallVector<OpResult, 4> usedResults;
2166  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2167  [](OpResult result) { return !result.use_empty(); });
2168 
2169  // Replace the operation if only a subset of its results have uses.
2170  if (usedResults.size() == op.getNumResults())
2171  return failure();
2172 
2173  // Compute the result types of the replacement operation.
2174  SmallVector<Type, 4> newTypes;
2175  llvm::transform(usedResults, std::back_inserter(newTypes),
2176  [](OpResult result) { return result.getType(); });
2177 
2178  // Create a replacement operation with empty then and else regions.
2179  auto newOp =
2180  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2181  rewriter.createBlock(&newOp.getThenRegion());
2182  rewriter.createBlock(&newOp.getElseRegion());
2183 
2184  // Move the bodies and replace the terminators (note there is a then and
2185  // an else region since the operation returns results).
2186  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2187  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2188 
2189  // Replace the operation by the new one.
2190  SmallVector<Value, 4> repResults(op.getNumResults());
2191  for (const auto &en : llvm::enumerate(usedResults))
2192  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2193  rewriter.replaceOp(op, repResults);
2194  return success();
2195  }
2196 };
2197 
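/// Replace an scf.if whose condition is a constant by the contents of the
/// branch that is statically taken. E.g. (sketch) `scf.if %true { ... }` is
/// inlined as its then-body; with a constant false condition the else-body is
/// inlined instead, or the op is erased if there is no else region.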
2198 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2199  using OpRewritePattern<IfOp>::OpRewritePattern;
2200 
2201  LogicalResult matchAndRewrite(IfOp op,
2202  PatternRewriter &rewriter) const override {
2203  BoolAttr condition;
2204  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2205  return failure();
2206 
2207  if (condition.getValue())
2208  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2209  else if (!op.getElseRegion().empty())
2210  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2211  else
2212  rewriter.eraseOp(op);
2213 
2214  return success();
2215  }
2216 };
2217 
2218 /// Hoist any yielded results whose operands are defined outside
2219 /// the if, to a select instruction.
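///
/// For example (a sketch), for
///
///   %r = scf.if %cond -> (i32) {
///     "some.op"() : () -> ()
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
///
/// where %a and %b are defined above the scf.if, %r is replaced by
/// `arith.select %cond, %a, %b : i32` and the remaining (now resultless)
/// scf.if is kept only for its side effects.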
2220 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2221  using OpRewritePattern<IfOp>::OpRewritePattern;
2222 
2223  LogicalResult matchAndRewrite(IfOp op,
2224  PatternRewriter &rewriter) const override {
2225  if (op->getNumResults() == 0)
2226  return failure();
2227 
2228  auto cond = op.getCondition();
2229  auto thenYieldArgs = op.thenYield().getOperands();
2230  auto elseYieldArgs = op.elseYield().getOperands();
2231 
2232  SmallVector<Type> nonHoistable;
2233  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2234  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2235  &op.getElseRegion() == falseVal.getParentRegion())
2236  nonHoistable.push_back(trueVal.getType());
2237  }
2238  // Early exit if there aren't any yielded values we can
2239  // hoist outside the if.
2240  if (nonHoistable.size() == op->getNumResults())
2241  return failure();
2242 
2243  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2244  /*withElseRegion=*/false);
2245  if (replacement.thenBlock())
2246  rewriter.eraseBlock(replacement.thenBlock());
2247  replacement.getThenRegion().takeBody(op.getThenRegion());
2248  replacement.getElseRegion().takeBody(op.getElseRegion());
2249 
2250  SmallVector<Value> results(op->getNumResults());
2251  assert(thenYieldArgs.size() == results.size());
2252  assert(elseYieldArgs.size() == results.size());
2253 
2254  SmallVector<Value> trueYields;
2255  SmallVector<Value> falseYields;
2256  rewriter.setInsertionPoint(replacement);
2257  for (const auto &it :
2258  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2259  Value trueVal = std::get<0>(it.value());
2260  Value falseVal = std::get<1>(it.value());
2261  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2262  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2263  results[it.index()] = replacement.getResult(trueYields.size());
2264  trueYields.push_back(trueVal);
2265  falseYields.push_back(falseVal);
2266  } else if (trueVal == falseVal)
2267  results[it.index()] = trueVal;
2268  else
2269  results[it.index()] = rewriter.create<arith::SelectOp>(
2270  op.getLoc(), cond, trueVal, falseVal);
2271  }
2272 
2273  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2274  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2275 
2276  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2277  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2278 
2279  rewriter.replaceOp(op, results);
2280  return success();
2281  }
2282 };
2283 
2284 /// Allow the true region of an if to assume the condition is true
2285 /// and vice versa. For example:
2286 ///
2287 /// scf.if %cmp {
2288 /// print(%cmp)
2289 /// }
2290 ///
2291 /// becomes
2292 ///
2293 /// scf.if %cmp {
2294 /// print(true)
2295 /// }
2296 ///
2297 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2298  using OpRewritePattern<IfOp>::OpRewritePattern;
2299 
2300  LogicalResult matchAndRewrite(IfOp op,
2301  PatternRewriter &rewriter) const override {
2302  // Early exit if the condition is constant since replacing a constant
2303  // in the body with another constant isn't a simplification.
2304  if (matchPattern(op.getCondition(), m_Constant()))
2305  return failure();
2306 
2307  bool changed = false;
2308  mlir::Type i1Ty = rewriter.getI1Type();
2309 
2310  // These variables serve to prevent creating duplicate constants
2311  // and hold constant true or false values.
2312  Value constantTrue = nullptr;
2313  Value constantFalse = nullptr;
2314 
2315  for (OpOperand &use :
2316  llvm::make_early_inc_range(op.getCondition().getUses())) {
2317  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2318  changed = true;
2319 
2320  if (!constantTrue)
2321  constantTrue = rewriter.create<arith::ConstantOp>(
2322  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2323 
2324  rewriter.updateRootInPlace(use.getOwner(),
2325  [&]() { use.set(constantTrue); });
2326  } else if (op.getElseRegion().isAncestor(
2327  use.getOwner()->getParentRegion())) {
2328  changed = true;
2329 
2330  if (!constantFalse)
2331  constantFalse = rewriter.create<arith::ConstantOp>(
2332  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2333 
2334  rewriter.updateRootInPlace(use.getOwner(),
2335  [&]() { use.set(constantFalse); });
2336  }
2337  }
2338 
2339  return success(changed);
2340  }
2341 };
2342 
2343 /// Remove any statements from an if that are equivalent to the condition
2344 /// or its negation. For example:
2345 ///
2346 /// %res:2 = scf.if %cmp {
2347 /// yield something(), true
2348 /// } else {
2349 /// yield something2(), false
2350 /// }
2351 /// print(%res#1)
2352 ///
2353 /// becomes
2354 /// %res = scf.if %cmp {
2355 /// yield something()
2356 /// } else {
2357 /// yield something2()
2358 /// }
2359 /// print(%cmp)
2360 ///
2361 /// Additionally if both branches yield the same value, replace all uses
2362 /// of the result with the yielded value.
2363 ///
2364 /// %res:2 = scf.if %cmp {
2365 /// yield something(), %arg1
2366 /// } else {
2367 /// yield something2(), %arg1
2368 /// }
2369 /// print(%res#1)
2370 ///
2371 /// becomes
2372 /// %res = scf.if %cmp {
2373 /// yield something()
2374 /// } else {
2375 /// yield something2()
2376 /// }
2377 /// print(%arg1)
2378 ///
2379 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2380  using OpRewritePattern<IfOp>::OpRewritePattern;
2381 
2382  LogicalResult matchAndRewrite(IfOp op,
2383  PatternRewriter &rewriter) const override {
2384  // Early exit if there are no results that could be replaced.
2385  if (op.getNumResults() == 0)
2386  return failure();
2387 
2388  auto trueYield =
2389  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2390  auto falseYield =
2391  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2392 
2393  rewriter.setInsertionPoint(op->getBlock(),
2394  op.getOperation()->getIterator());
2395  bool changed = false;
2396  Type i1Ty = rewriter.getI1Type();
2397  for (auto [trueResult, falseResult, opResult] :
2398  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2399  op.getResults())) {
2400  if (trueResult == falseResult) {
2401  if (!opResult.use_empty()) {
2402  opResult.replaceAllUsesWith(trueResult);
2403  changed = true;
2404  }
2405  continue;
2406  }
2407 
2408  BoolAttr trueYield, falseYield;
2409  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2410  !matchPattern(falseResult, m_Constant(&falseYield)))
2411  continue;
2412 
2413  bool trueVal = trueYield.getValue();
2414  bool falseVal = falseYield.getValue();
2415  if (!trueVal && falseVal) {
2416  if (!opResult.use_empty()) {
2417  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2418  Value notCond = rewriter.create<arith::XOrIOp>(
2419  op.getLoc(), op.getCondition(),
2420  constDialect
2421  ->materializeConstant(rewriter,
2422  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2423  op.getLoc())
2424  ->getResult(0));
2425  opResult.replaceAllUsesWith(notCond);
2426  changed = true;
2427  }
2428  }
2429  if (trueVal && !falseVal) {
2430  if (!opResult.use_empty()) {
2431  opResult.replaceAllUsesWith(op.getCondition());
2432  changed = true;
2433  }
2434  }
2435  }
2436  return success(changed);
2437  }
2438 };
2439 
2440 /// Merge any consecutive scf.if's with the same condition.
2441 ///
2442 /// scf.if %cond {
2443 /// firstCodeTrue();...
2444 /// } else {
2445 /// firstCodeFalse();...
2446 /// }
2447 /// %res = scf.if %cond {
2448 /// secondCodeTrue();...
2449 /// } else {
2450 /// secondCodeFalse();...
2451 /// }
2452 ///
2453 /// becomes
2454 /// %res = scf.if %cmp {
2455 /// firstCodeTrue();...
2456 /// secondCodeTrue();...
2457 /// } else {
2458 /// firstCodeFalse();...
2459 /// secondCodeFalse();...
2460 /// }
2461 struct CombineIfs : public OpRewritePattern<IfOp> {
2462  using OpRewritePattern<IfOp>::OpRewritePattern;
2463 
2464  LogicalResult matchAndRewrite(IfOp nextIf,
2465  PatternRewriter &rewriter) const override {
2466  Block *parent = nextIf->getBlock();
2467  if (nextIf == &parent->front())
2468  return failure();
2469 
2470  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2471  if (!prevIf)
2472  return failure();
2473 
2474  // Determine the logical then/else blocks when prevIf's
2475  // condition is used. Null means the block does not exist
2476  // in that case (e.g. empty else). If neither of these
2477  // are set, the two conditions cannot be compared.
2478  Block *nextThen = nullptr;
2479  Block *nextElse = nullptr;
2480  if (nextIf.getCondition() == prevIf.getCondition()) {
2481  nextThen = nextIf.thenBlock();
2482  if (!nextIf.getElseRegion().empty())
2483  nextElse = nextIf.elseBlock();
2484  }
2485  if (arith::XOrIOp notv =
2486  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2487  if (notv.getLhs() == prevIf.getCondition() &&
2488  matchPattern(notv.getRhs(), m_One())) {
2489  nextElse = nextIf.thenBlock();
2490  if (!nextIf.getElseRegion().empty())
2491  nextThen = nextIf.elseBlock();
2492  }
2493  }
2494  if (arith::XOrIOp notv =
2495  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2496  if (notv.getLhs() == nextIf.getCondition() &&
2497  matchPattern(notv.getRhs(), m_One())) {
2498  nextElse = nextIf.thenBlock();
2499  if (!nextIf.getElseRegion().empty())
2500  nextThen = nextIf.elseBlock();
2501  }
2502  }
2503 
2504  if (!nextThen && !nextElse)
2505  return failure();
2506 
2507  SmallVector<Value> prevElseYielded;
2508  if (!prevIf.getElseRegion().empty())
2509  prevElseYielded = prevIf.elseYield().getOperands();
2510  // Replace all uses of prevIf's return values within nextIf with the
2511  // corresponding yielded values.
2512  for (auto it : llvm::zip(prevIf.getResults(),
2513  prevIf.thenYield().getOperands(), prevElseYielded))
2514  for (OpOperand &use :
2515  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2516  if (nextThen && nextThen->getParent()->isAncestor(
2517  use.getOwner()->getParentRegion())) {
2518  rewriter.startRootUpdate(use.getOwner());
2519  use.set(std::get<1>(it));
2520  rewriter.finalizeRootUpdate(use.getOwner());
2521  } else if (nextElse && nextElse->getParent()->isAncestor(
2522  use.getOwner()->getParentRegion())) {
2523  rewriter.startRootUpdate(use.getOwner());
2524  use.set(std::get<2>(it));
2525  rewriter.finalizeRootUpdate(use.getOwner());
2526  }
2527  }
2528 
2529  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2530  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2531 
2532  IfOp combinedIf = rewriter.create<IfOp>(
2533  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2534  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2535 
2536  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2537  combinedIf.getThenRegion(),
2538  combinedIf.getThenRegion().begin());
2539 
2540  if (nextThen) {
2541  YieldOp thenYield = combinedIf.thenYield();
2542  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2543  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2544  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2545 
2546  SmallVector<Value> mergedYields(thenYield.getOperands());
2547  llvm::append_range(mergedYields, thenYield2.getOperands());
2548  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2549  rewriter.eraseOp(thenYield);
2550  rewriter.eraseOp(thenYield2);
2551  }
2552 
2553  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2554  combinedIf.getElseRegion(),
2555  combinedIf.getElseRegion().begin());
2556 
2557  if (nextElse) {
2558  if (combinedIf.getElseRegion().empty()) {
2559  rewriter.inlineRegionBefore(*nextElse->getParent(),
2560  combinedIf.getElseRegion(),
2561  combinedIf.getElseRegion().begin());
2562  } else {
2563  YieldOp elseYield = combinedIf.elseYield();
2564  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2565  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2566 
2567  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2568 
2569  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2570  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2571 
2572  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2573  rewriter.eraseOp(elseYield);
2574  rewriter.eraseOp(elseYield2);
2575  }
2576  }
2577 
2578  SmallVector<Value> prevValues;
2579  SmallVector<Value> nextValues;
2580  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2581  if (pair.index() < prevIf.getNumResults())
2582  prevValues.push_back(pair.value());
2583  else
2584  nextValues.push_back(pair.value());
2585  }
2586  rewriter.replaceOp(prevIf, prevValues);
2587  rewriter.replaceOp(nextIf, nextValues);
2588  return success();
2589  }
2590 };
2591 
2592 /// Pattern to remove an empty else branch.
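/// E.g. (sketch) `scf.if %cond { ... } else { }` with no results becomes
/// `scf.if %cond { ... }`.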
2593 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2594  using OpRewritePattern<IfOp>::OpRewritePattern;
2595 
2596  LogicalResult matchAndRewrite(IfOp ifOp,
2597  PatternRewriter &rewriter) const override {
2598  // Cannot remove else region when there are operation results.
2599  if (ifOp.getNumResults())
2600  return failure();
2601  Block *elseBlock = ifOp.elseBlock();
2602  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2603  return failure();
2604  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2605  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2606  newIfOp.getThenRegion().begin());
2607  rewriter.eraseOp(ifOp);
2608  return success();
2609  }
2610 };
2611 
2612 /// Convert nested `if`s into `arith.andi` + single `if`.
2613 ///
2614 /// scf.if %arg0 {
2615 /// scf.if %arg1 {
2616 /// ...
2617 /// scf.yield
2618 /// }
2619 /// scf.yield
2620 /// }
2621 /// becomes
2622 ///
2623 /// %0 = arith.andi %arg0, %arg1
2624 /// scf.if %0 {
2625 /// ...
2626 /// scf.yield
2627 /// }
2628 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2629  using OpRewritePattern<IfOp>::OpRewritePattern;
2630 
2631  LogicalResult matchAndRewrite(IfOp op,
2632  PatternRewriter &rewriter) const override {
2633  auto nestedOps = op.thenBlock()->without_terminator();
2634  // Nested `if` must be the only op in block.
2635  if (!llvm::hasSingleElement(nestedOps))
2636  return failure();
2637 
2638  // If there is an else block, it may only contain a yield.
2639  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2640  return failure();
2641 
2642  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2643  if (!nestedIf)
2644  return failure();
2645 
2646  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2647  return failure();
2648 
2649  SmallVector<Value> thenYield(op.thenYield().getOperands());
2650  SmallVector<Value> elseYield;
2651  if (op.elseBlock())
2652  llvm::append_range(elseYield, op.elseYield().getOperands());
2653 
2654  // A list of indices for which we should upgrade the value yielded
2655  // in the else to a select.
2656  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2657 
2658  // If the outer scf.if yields a value produced by the inner scf.if,
2659  // only permit combining if the value yielded when the condition
2660  // is false in the outer scf.if is the same value yielded when the
2661  // inner scf.if condition is false.
2662  // Note that the array access to elseYield will not go out of bounds
2663  // since it must have the same length as thenYield, since they both
2664  // come from the same scf.if.
2665  for (const auto &tup : llvm::enumerate(thenYield)) {
2666  if (tup.value().getDefiningOp() == nestedIf) {
2667  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2668  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2669  elseYield[tup.index()]) {
2670  return failure();
2671  }
2672  // If the correctness test passes, we will yield
2673  // corresponding value from the inner scf.if
2674  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2675  continue;
2676  }
2677 
2678  // Otherwise, we need to ensure the else block of the combined
2679  // condition still returns the same value when the outer condition is
2680  // true and the inner condition is false. This can be accomplished if
2681  // the then value is defined outside the outer scf.if and we replace the
2682  // value with a select that considers just the outer condition. Since
2683  // the else region contains just the yield, its yielded value is
2684  // defined outside the scf.if, by definition.
2685 
2686  // If the then value is defined within the scf.if, bail.
2687  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2688  return failure();
2689  }
2690  elseYieldsToUpgradeToSelect.push_back(tup.index());
2691  }
2692 
2693  Location loc = op.getLoc();
2694  Value newCondition = rewriter.create<arith::AndIOp>(
2695  loc, op.getCondition(), nestedIf.getCondition());
2696  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2697  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2698 
2699  SmallVector<Value> results;
2700  llvm::append_range(results, newIf.getResults());
2701  rewriter.setInsertionPoint(newIf);
2702 
2703  for (auto idx : elseYieldsToUpgradeToSelect)
2704  results[idx] = rewriter.create<arith::SelectOp>(
2705  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2706 
2707  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2708  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2709  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2710  if (!elseYield.empty()) {
2711  rewriter.createBlock(&newIf.getElseRegion());
2712  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2713  rewriter.create<YieldOp>(loc, elseYield);
2714  }
2715  rewriter.replaceOp(op, results);
2716  return success();
2717  }
2718 };
2719 
2720 } // namespace
2721 
2722 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2723  MLIRContext *context) {
2724  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2725  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2726  RemoveStaticCondition, RemoveUnusedResults,
2727  ReplaceIfYieldWithConditionOrValue>(context);
2728 }
2729 
2730 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2731 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2732 Block *IfOp::elseBlock() {
2733  Region &r = getElseRegion();
2734  if (r.empty())
2735  return nullptr;
2736  return &r.back();
2737 }
2738 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2739 
2740 //===----------------------------------------------------------------------===//
2741 // ParallelOp
2742 //===----------------------------------------------------------------------===//
2743 
2744 void ParallelOp::build(
2745  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2746  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2747  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2748  bodyBuilderFn) {
2749  result.addOperands(lowerBounds);
2750  result.addOperands(upperBounds);
2751  result.addOperands(steps);
2752  result.addOperands(initVals);
2753  result.addAttribute(
2754  ParallelOp::getOperandSegmentSizeAttr(),
2755  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2756  static_cast<int32_t>(upperBounds.size()),
2757  static_cast<int32_t>(steps.size()),
2758  static_cast<int32_t>(initVals.size())}));
2759  result.addTypes(initVals.getTypes());
2760 
2761  OpBuilder::InsertionGuard guard(builder);
2762  unsigned numIVs = steps.size();
2763  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2764  SmallVector<Location, 8> argLocs(numIVs, result.location);
2765  Region *bodyRegion = result.addRegion();
2766  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2767 
2768  if (bodyBuilderFn) {
2769  builder.setInsertionPointToStart(bodyBlock);
2770  bodyBuilderFn(builder, result.location,
2771  bodyBlock->getArguments().take_front(numIVs),
2772  bodyBlock->getArguments().drop_front(numIVs));
2773  }
2774  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2775 }
2776 
2777 void ParallelOp::build(
2778  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2779  ValueRange upperBounds, ValueRange steps,
2780  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2781  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2782  // we don't capture a reference to a temporary by constructing the lambda at
2783  // function level.
2784  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2785  Location nestedLoc, ValueRange ivs,
2786  ValueRange) {
2787  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2788  };
2789  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2790  if (bodyBuilderFn)
2791  wrapper = wrappedBuilderFn;
2792 
2793  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2794  wrapper);
2795 }
2796 
2797 LogicalResult ParallelOp::verify() {
2798  // Check that there is at least one value in lowerBound, upperBound and step.
2799  // It is sufficient to test only step, because it is ensured already that the
2800  // number of elements in lowerBound, upperBound and step are the same.
2801  Operation::operand_range stepValues = getStep();
2802  if (stepValues.empty())
2803  return emitOpError(
2804  "needs at least one tuple element for lowerBound, upperBound and step");
2805 
2806  // Check whether all constant step values are positive.
2807  for (Value stepValue : stepValues)
2808  if (auto cst = getConstantIntValue(stepValue))
2809  if (*cst <= 0)
2810  return emitOpError("constant step operand must be positive");
2811 
2812  // Check that the body defines the same number of block arguments as the
2813  // number of tuple elements in step.
2814  Block *body = getBody();
2815  if (body->getNumArguments() != stepValues.size())
2816  return emitOpError() << "expects the same number of induction variables: "
2817  << body->getNumArguments()
2818  << " as bound and step values: " << stepValues.size();
2819  for (auto arg : body->getArguments())
2820  if (!arg.getType().isIndex())
2821  return emitOpError(
2822  "expects arguments for the induction variable to be of index type");
2823 
2824  // Check that the yield has no results
2825  auto yield = verifyAndGetTerminator<scf::YieldOp>(
2826  *this, getRegion(), "expects body to terminate with 'scf.yield'");
2827  if (!yield)
2828  return failure();
2829  if (yield->getNumOperands() != 0)
2830  return yield.emitOpError() << "not allowed to have operands inside '"
2831  << ParallelOp::getOperationName() << "'";
2832 
2833  // Check that the number of results is the same as the number of ReduceOps.
2834  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
2835  auto resultsSize = getResults().size();
2836  auto reductionsSize = reductions.size();
2837  auto initValsSize = getInitVals().size();
2838  if (resultsSize != reductionsSize)
2839  return emitOpError() << "expects number of results: " << resultsSize
2840  << " to be the same as number of reductions: "
2841  << reductionsSize;
2842  if (resultsSize != initValsSize)
2843  return emitOpError() << "expects number of results: " << resultsSize
2844  << " to be the same as number of initial values: "
2845  << initValsSize;
2846 
2847  // Check that the types of the results and reductions are the same.
2848  for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2849  auto resultType = std::get<0>(resultAndReduce).getType();
2850  auto reduceOp = std::get<1>(resultAndReduce);
2851  auto reduceType = reduceOp.getOperand().getType();
2852  if (resultType != reduceType)
2853  return reduceOp.emitOpError()
2854  << "expects type of reduce: " << reduceType
2855  << " to be the same as result type: " << resultType;
2856  }
2857  return success();
2858 }
2859 
2860 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2861  auto &builder = parser.getBuilder();
2862  // Parse an opening `(` followed by induction variables followed by `)`
2863  SmallVector<OpAsmParser::Argument, 4> ivs;
2864  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2865  return failure();
2866 
2867  // Parse loop bounds.
2868  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2869  if (parser.parseEqual() ||
2870  parser.parseOperandList(lower, ivs.size(),
2871  OpAsmParser::Delimiter::Paren) ||
2872  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2873  return failure();
2874 
2875  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2876  if (parser.parseKeyword("to") ||
2877  parser.parseOperandList(upper, ivs.size(),
2878  OpAsmParser::Delimiter::Paren) ||
2879  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2880  return failure();
2881 
2882  // Parse step values.
2883  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2884  if (parser.parseKeyword("step") ||
2885  parser.parseOperandList(steps, ivs.size(),
2886  OpAsmParser::Delimiter::Paren) ||
2887  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2888  return failure();
2889 
2890  // Parse init values.
2891  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2892  if (succeeded(parser.parseOptionalKeyword("init"))) {
2893  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2894  return failure();
2895  }
2896 
2897  // Parse optional results in case there is a reduce.
2898  if (parser.parseOptionalArrowTypeList(result.types))
2899  return failure();
2900 
2901  // Now parse the body.
2902  Region *body = result.addRegion();
2903  for (auto &iv : ivs)
2904  iv.type = builder.getIndexType();
2905  if (parser.parseRegion(*body, ivs))
2906  return failure();
2907 
2908  // Set `operandSegmentSizes` attribute.
2909  result.addAttribute(
2910  ParallelOp::getOperandSegmentSizeAttr(),
2911  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2912  static_cast<int32_t>(upper.size()),
2913  static_cast<int32_t>(steps.size()),
2914  static_cast<int32_t>(initVals.size())}));
2915 
2916  // Parse attributes.
2917  if (parser.parseOptionalAttrDict(result.attributes) ||
2918  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2919  result.operands))
2920  return failure();
2921 
2922  // Add a terminator if none was parsed.
2923  ForOp::ensureTerminator(*body, builder, result.location);
2924  return success();
2925 }
2926 
2927 void ParallelOp::print(OpAsmPrinter &p) {
2928  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2929  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2930  if (!getInitVals().empty())
2931  p << " init (" << getInitVals() << ")";
2932  p.printOptionalArrowTypeList(getResultTypes());
2933  p << ' ';
2934  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2935  p.printOptionalAttrDict(
2936  (*this)->getAttrs(),
2937  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2938 }
2939 
2940 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
2941 
2942 std::optional<Value> ParallelOp::getSingleInductionVar() {
2943  if (getNumLoops() != 1)
2944  return std::nullopt;
2945  return getBody()->getArgument(0);
2946 }
2947 
2948 std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
2949  if (getNumLoops() != 1)
2950  return std::nullopt;
2951  return getLowerBound()[0];
2952 }
2953 
2954 std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
2955  if (getNumLoops() != 1)
2956  return std::nullopt;
2957  return getUpperBound()[0];
2958 }
2959 
2960 std::optional<OpFoldResult> ParallelOp::getSingleStep() {
2961  if (getNumLoops() != 1)
2962  return std::nullopt;
2963  return getStep()[0];
2964 }
2965 
2966 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
2967  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2968  if (!ivArg)
2969  return ParallelOp();
2970  assert(ivArg.getOwner() && "unlinked block argument");
2971  auto *containingOp = ivArg.getOwner()->getParentOp();
2972  return dyn_cast<ParallelOp>(containingOp);
2973 }
2974 
2975 namespace {
2976 // Collapse loop dimensions that perform zero or a single iteration.
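// For example (a sketch; %c0 and %c1 are index constants):
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %n) step (%c1, %c1) { ... }
// performs a single iteration along %i, so it can be rewritten as a 1-D
// scf.parallel over %j with %i replaced by %c0; a dimension with a zero trip
// count causes the whole loop to be replaced by its init values.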
2977 struct ParallelOpSingleOrZeroIterationDimsFolder
2978  : public OpRewritePattern<ParallelOp> {
2979  using OpRewritePattern<ParallelOp>::OpRewritePattern;
2980 
2981  LogicalResult matchAndRewrite(ParallelOp op,
2982  PatternRewriter &rewriter) const override {
2983  Location loc = op.getLoc();
2984 
2985  // Compute new loop bounds that omit all single-iteration loop dimensions.
2986  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
2987  IRMapping mapping;
2988  for (auto [lb, ub, step, iv] :
2989  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2990  op.getInductionVars())) {
2991  auto numIterations = constantTripCount(lb, ub, step);
2992  if (numIterations.has_value()) {
2993  // Remove the loop if it performs zero iterations.
2994  if (*numIterations == 0) {
2995  rewriter.replaceOp(op, op.getInitVals());
2996  return success();
2997  }
2998  // Replace the loop induction variable by the lower bound if the loop
2999  // performs a single iteration. Otherwise, copy the loop bounds.
3000  if (*numIterations == 1) {
3001  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3002  continue;
3003  }
3004  }
3005  newLowerBounds.push_back(lb);
3006  newUpperBounds.push_back(ub);
3007  newSteps.push_back(step);
3008  }
3009  // Exit if none of the loop dimensions perform zero or a single iteration.
3010  if (newLowerBounds.size() == op.getLowerBound().size())
3011  return failure();
3012 
3013  if (newLowerBounds.empty()) {
3014  // All of the loop dimensions perform a single iteration. Inline
3015  // loop body and nested ReduceOp's
3016  SmallVector<Value> results;
3017  results.reserve(op.getInitVals().size());
3018  for (auto &bodyOp : op.getBody()->without_terminator()) {
3019  auto reduce = dyn_cast<ReduceOp>(bodyOp);
3020  if (!reduce) {
3021  rewriter.clone(bodyOp, mapping);
3022  continue;
3023  }
3024  Block &reduceBlock = reduce.getReductionOperator().front();
3025  auto initValIndex = results.size();
3026  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3027  mapping.map(reduceBlock.getArgument(1),
3028  mapping.lookupOrDefault(reduce.getOperand()));
3029  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3030  rewriter.clone(reduceBodyOp, mapping);
3031 
3032  auto result = mapping.lookupOrDefault(
3033  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3034  results.push_back(result);
3035  }
3036  rewriter.replaceOp(op, results);
3037  return success();
3038  }
3039  // Replace the parallel loop by lower-dimensional parallel loop.
3040  auto newOp =
3041  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3042  newSteps, op.getInitVals(), nullptr);
3043  // Clone the loop body and remap the block arguments of the collapsed loops
3044  // (inlining does not support a cancellable block argument mapping).
3045  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3046  newOp.getRegion().begin(), mapping);
3047  rewriter.replaceOp(op, newOp.getResults());
3048  return success();
3049  }
3050 };
3051 
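// Merge an scf.parallel whose body contains nothing but another scf.parallel
// into a single multi-dimensional scf.parallel. Illustrative sketch (bounds
// and steps are placeholders):
//
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) { use(%i, %j) }
//   }
//   ==>
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
//     use(%i, %j)
//   }
//
// This only applies when the inner bounds/steps do not use the outer
// induction variables and neither loop carries reductions.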
3052 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3053  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3054 
3055  LogicalResult matchAndRewrite(ParallelOp op,
3056  PatternRewriter &rewriter) const override {
3057  Block &outerBody = *op.getBody();
3058  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3059  return failure();
3060 
3061  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3062  if (!innerOp)
3063  return failure();
3064 
3065  for (auto val : outerBody.getArguments())
3066  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3067  llvm::is_contained(innerOp.getUpperBound(), val) ||
3068  llvm::is_contained(innerOp.getStep(), val))
3069  return failure();
3070 
3071  // Reductions are not supported yet.
3072  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3073  return failure();
3074 
3075  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3076  ValueRange iterVals, ValueRange) {
3077  Block &innerBody = *innerOp.getBody();
3078  assert(iterVals.size() ==
3079  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3080  IRMapping mapping;
3081  mapping.map(outerBody.getArguments(),
3082  iterVals.take_front(outerBody.getNumArguments()));
3083  mapping.map(innerBody.getArguments(),
3084  iterVals.take_back(innerBody.getNumArguments()));
3085  for (Operation &op : innerBody.without_terminator())
3086  builder.clone(op, mapping);
3087  };
3088 
3089  auto concatValues = [](const auto &first, const auto &second) {
3090  SmallVector<Value> ret;
3091  ret.reserve(first.size() + second.size());
3092  ret.assign(first.begin(), first.end());
3093  ret.append(second.begin(), second.end());
3094  return ret;
3095  };
3096 
3097  auto newLowerBounds =
3098  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3099  auto newUpperBounds =
3100  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3101  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3102 
3103  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3104  newSteps, std::nullopt,
3105  bodyBuilder);
3106  return success();
3107  }
3108 };
3109 
3110 } // namespace
3111 
3112 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3113  MLIRContext *context) {
3114  results
3115  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3116  context);
3117 }
3118 
3119 /// Given the region at `index`, or the parent operation if `index` is None,
3120 /// return the successor regions. These are the regions that may be selected
3121 /// during the flow of control. `operands` is a set of optional attributes that
3122 /// correspond to a constant value for each operand, or null if that operand is
3123 /// not a constant.
3124 void ParallelOp::getSuccessorRegions(
3125  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3126  // Both the operation itself and the region may be branching into the body or
3127  // back into the operation itself. It is possible for the loop not to enter the
3128  // body.
3129  regions.push_back(RegionSuccessor(&getRegion()));
3130  regions.push_back(RegionSuccessor());
3131 }
3132 
3133 //===----------------------------------------------------------------------===//
3134 // ReduceOp
3135 //===----------------------------------------------------------------------===//
3136 
3137 void ReduceOp::build(
3138  OpBuilder &builder, OperationState &result, Value operand,
3139  function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
3140  auto type = operand.getType();
3141  result.addOperands(operand);
3142 
3143  OpBuilder::InsertionGuard guard(builder);
3144  Region *bodyRegion = result.addRegion();
3145  Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
3146  {result.location, result.location});
3147  if (bodyBuilderFn)
3148  bodyBuilderFn(builder, result.location, body->getArgument(0),
3149  body->getArgument(1));
3150 }
3151 
3152 LogicalResult ReduceOp::verifyRegions() {
3153  // The region of a ReduceOp has two arguments of the same type as its operand.
3154  auto type = getOperand().getType();
3155  Block &block = getReductionOperator().front();
3156  if (block.empty())
3157  return emitOpError("the block inside reduce should not be empty");
3158  if (block.getNumArguments() != 2 ||
3159  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3160  return arg.getType() != type;
3161  }))
3162  return emitOpError() << "expects two arguments to reduce block of type "
3163  << type;
3164 
3165  // Check that the block is terminated by a ReduceReturnOp.
3166  if (!isa<ReduceReturnOp>(block.getTerminator()))
3167  return emitOpError("the block inside reduce should be terminated with a "
3168  "'scf.reduce.return' op");
3169 
3170  return success();
3171 }
3172 
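// Illustrative form handled by the parser/printer below (operand and type are
// placeholders):
//
//   scf.reduce(%operand) : f32 {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %res = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %res : f32
//   }
//
// The two block arguments and the returned value all have the operand's type.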
3173 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
3174  // Parse an opening `(` followed by the reduced value followed by `)`
3175  OpAsmParser::UnresolvedOperand operand;
3176  if (parser.parseLParen() || parser.parseOperand(operand) ||
3177  parser.parseRParen())
3178  return failure();
3179 
3180  Type resultType;
3181  // Parse the type of the operand (and also what reduce computes on).
3182  if (parser.parseColonType(resultType) ||
3183  parser.resolveOperand(operand, resultType, result.operands))
3184  return failure();
3185 
3186  // Now parse the body.
3187  Region *body = result.addRegion();
3188  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
3189  return failure();
3190 
3191  return success();
3192 }
3193 
3194 void ReduceOp::print(OpAsmPrinter &p) {
3195  p << "(" << getOperand() << ") ";
3196  p << " : " << getOperand().getType() << ' ';
3197  p.printRegion(getReductionOperator());
3198 }
3199 
3200 //===----------------------------------------------------------------------===//
3201 // ReduceReturnOp
3202 //===----------------------------------------------------------------------===//
3203 
3204 LogicalResult ReduceReturnOp::verify() {
3205  // The type of the return value should be the same type as the type of the
3206  // operand of the enclosing ReduceOp.
3207  auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
3208  Type reduceType = reduceOp.getOperand().getType();
3209  if (reduceType != getResult().getType())
3210  return emitOpError() << "needs to have type " << reduceType
3211  << " (the type of the enclosing ReduceOp)";
3212  return success();
3213 }
3214 
3215 //===----------------------------------------------------------------------===//
3216 // WhileOp
3217 //===----------------------------------------------------------------------===//
3218 
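// Illustrative usage sketch for the builder below (names and body contents
// are placeholders):
//
//   auto whileOp = builder.create<scf::WhileOp>(
//       loc, resultTypes, initOperands,
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         Value cond = ...; // compute the continuation condition
//         b.create<scf::ConditionOp>(l, cond, args);
//       },
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         b.create<scf::YieldOp>(l, args); // values carried to the next trip
//       });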
3219 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3220  ::mlir::OperationState &odsState, TypeRange resultTypes,
3221  ValueRange operands, BodyBuilderFn beforeBuilder,
3222  BodyBuilderFn afterBuilder) {
3223  odsState.addOperands(operands);
3224  odsState.addTypes(resultTypes);
3225 
3226  OpBuilder::InsertionGuard guard(odsBuilder);
3227 
3228  // Build before region.
3229  SmallVector<Location, 4> beforeArgLocs;
3230  beforeArgLocs.reserve(operands.size());
3231  for (Value operand : operands) {
3232  beforeArgLocs.push_back(operand.getLoc());
3233  }
3234 
3235  Region *beforeRegion = odsState.addRegion();
3236  Block *beforeBlock = odsBuilder.createBlock(
3237  beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
3238  if (beforeBuilder)
3239  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3240 
3241  // Build after region.
3242  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3243 
3244  Region *afterRegion = odsState.addRegion();
3245  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3246  resultTypes, afterArgLocs);
3247 
3248  if (afterBuilder)
3249  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3250 }
3251 
3252 ConditionOp WhileOp::getConditionOp() {
3253  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3254 }
3255 
3256 YieldOp WhileOp::getYieldOp() {
3257  return cast<YieldOp>(getAfterBody()->getTerminator());
3258 }
3259 
3260 MutableArrayRef<OpOperand> WhileOp::getYieldedValuesMutable() {
3261  return getYieldOp().getResultsMutable();
3262 }
3263 
3264 Block::BlockArgListType WhileOp::getBeforeArguments() {
3265  return getBeforeBody()->getArguments();
3266 }
3267 
3268 Block::BlockArgListType WhileOp::getAfterArguments() {
3269  return getAfterBody()->getArguments();
3270 }
3271 
3272 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3273  return getBeforeArguments();
3274 }
3275 
3276 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3277  assert(point == getBefore() &&
3278  "WhileOp is expected to branch only to the first region");
3279  return getInits();
3280 }
3281 
3282 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3283  SmallVectorImpl<RegionSuccessor> &regions) {
3284  // The parent op always branches to the condition region.
3285  if (point.isParent()) {
3286  regions.emplace_back(&getBefore(), getBefore().getArguments());
3287  return;
3288  }
3289 
3290  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3291  "there are only two regions in a WhileOp");
3292  // The body region always branches back to the condition region.
3293  if (point == getAfter()) {
3294  regions.emplace_back(&getBefore(), getBefore().getArguments());
3295  return;
3296  }
3297 
3298  regions.emplace_back(getResults());
3299  regions.emplace_back(&getAfter(), getAfter().getArguments());
3300 }
3301 
3302 SmallVector<Region *> WhileOp::getLoopRegions() {
3303  return {&getBefore(), &getAfter()};
3304 }
3305 
3306 /// Parses a `while` op.
3307 ///
3308 /// op ::= `scf.while` assignments `:` function-type region `do` region
3309 /// `attributes` attribute-dict
3310 /// initializer ::= /* empty */ | `(` assignment-list `)`
3311 /// assignment-list ::= assignment | assignment `,` assignment-list
3312 /// assignment ::= ssa-value `=` ssa-value
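///
/// Illustrative example (all names are placeholders):
///   %res = scf.while (%arg = %init) : (i32) -> i32 {
///     %cond = ...
///     scf.condition(%cond) %arg : i32
///   } do {
///   ^bb0(%arg2: i32):
///     %next = ...
///     scf.yield %next : i32
///   }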
3313 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3314  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3315  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3316  Region *before = result.addRegion();
3317  Region *after = result.addRegion();
3318 
3319  OptionalParseResult listResult =
3320  parser.parseOptionalAssignmentList(regionArgs, operands);
3321  if (listResult.has_value() && failed(listResult.value()))
3322  return failure();
3323 
3324  FunctionType functionType;
3325  SMLoc typeLoc = parser.getCurrentLocation();
3326  if (failed(parser.parseColonType(functionType)))
3327  return failure();
3328 
3329  result.addTypes(functionType.getResults());
3330 
3331  if (functionType.getNumInputs() != operands.size()) {
3332  return parser.emitError(typeLoc)
3333  << "expected as many input types as operands "
3334  << "(expected " << operands.size() << " got "
3335  << functionType.getNumInputs() << ")";
3336  }
3337 
3338  // Resolve input operands.
3339  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3340  parser.getCurrentLocation(),
3341  result.operands)))
3342  return failure();
3343 
3344  // Propagate the types into the region arguments.
3345  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3346  regionArgs[i].type = functionType.getInput(i);
3347 
3348  return failure(parser.parseRegion(*before, regionArgs) ||
3349  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3350  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3351 }
3352 
3353 /// Prints a `while` op.
3354 void scf::WhileOp::print(OpAsmPrinter &p) {
3355  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3356  p << " : ";
3357  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3358  p << ' ';
3359  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3360  p << " do ";
3361  p.printRegion(getAfter());
3362  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3363 }
3364 
3365 /// Verifies that two ranges of types match, i.e. have the same number of
3366 /// entries and that types are pairwise equals. Reports errors on the given
3367 /// operation in case of mismatch.
3368 template <typename OpTy>
3369 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3370  TypeRange right, StringRef message) {
3371  if (left.size() != right.size())
3372  return op.emitOpError("expects the same number of ") << message;
3373 
3374  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3375  if (left[i] != right[i]) {
3376  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3377  << message;
3378  diag.attachNote() << "for argument " << i << ", found " << left[i]
3379  << " and " << right[i];
3380  return diag;
3381  }
3382  }
3383 
3384  return success();
3385 }
3386 
3387 LogicalResult scf::WhileOp::verify() {
3388  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3389  *this, getBefore(),
3390  "expects the 'before' region to terminate with 'scf.condition'");
3391  if (!beforeTerminator)
3392  return failure();
3393 
3394  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3395  *this, getAfter(),
3396  "expects the 'after' region to terminate with 'scf.yield'");
3397  return success(afterTerminator != nullptr);
3398 }
3399 
3400 namespace {
3401 /// Replace uses of the condition within the do block with true, since otherwise
3402 /// the block would not be evaluated.
3403 ///
3404 /// scf.while (..) : (i1, ...) -> ... {
3405 /// %condition = call @evaluate_condition() : () -> i1
3406 /// scf.condition(%condition) %condition : i1, ...
3407 /// } do {
3408 /// ^bb0(%arg0: i1, ...):
3409 /// use(%arg0)
3410 /// ...
3411 ///
3412 /// becomes
3413 /// scf.while (..) : (i1, ...) -> ... {
3414 /// %condition = call @evaluate_condition() : () -> i1
3415 /// scf.condition(%condition) %condition : i1, ...
3416 /// } do {
3417 /// ^bb0(%arg0: i1, ...):
3418 /// use(%true)
3419 /// ...
3420 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3421  using OpRewritePattern<WhileOp>::OpRewritePattern;
3422 
3423  LogicalResult matchAndRewrite(WhileOp op,
3424  PatternRewriter &rewriter) const override {
3425  auto term = op.getConditionOp();
3426 
3427  // This variable caches the constant true value so that it is created at
3428  // most once while replacing uses below.
3429  Value constantTrue = nullptr;
3430 
3431  bool replaced = false;
3432  for (auto yieldedAndBlockArgs :
3433  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3434  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3435  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3436  if (!constantTrue)
3437  constantTrue = rewriter.create<arith::ConstantOp>(
3438  op.getLoc(), term.getCondition().getType(),
3439  rewriter.getBoolAttr(true));
3440 
3441  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3442  constantTrue);
3443  replaced = true;
3444  }
3445  }
3446  }
3447  return success(replaced);
3448  }
3449 };
3450 
3451 /// Remove loop invariant arguments from `before` block of scf.while.
3452 /// A before block argument is considered loop invariant if:
3453 /// 1. i-th yield operand is equal to the i-th while operand.
3454 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3455 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3456 /// iter argument/while operand.
3457 /// For the arguments which are removed, their uses inside scf.while
3458 /// are replaced with their corresponding initial value.
3459 ///
3460 /// Eg:
3461 /// INPUT :-
3462 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3463 /// ..., %argN_before = %N)
3464 /// {
3465 /// ...
3466 /// scf.condition(%cond) %arg1_before, %arg0_before,
3467 /// %arg2_before, %arg0_before, ...
3468 /// } do {
3469 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3470 /// ..., %argK_after):
3471 /// ...
3472 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3473 /// }
3474 ///
3475 /// OUTPUT :-
3476 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3477 /// %N)
3478 /// {
3479 /// ...
3480 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3481 /// } do {
3482 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3483 /// ..., %argK_after):
3484 /// ...
3485 /// scf.yield %arg1_after, ..., %argN
3486 /// }
3487 ///
3488 /// EXPLANATION:
3489 /// We iterate over each yield operand.
3490 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3491 /// %arg0_before, which in turn is the 0-th iter argument. So we
3492 /// remove 0-th before block argument and yield operand, and replace
3493 /// all uses of the 0-th before block argument with its initial value
3494 /// %a.
3495 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3496 /// value. So we remove this operand and the corresponding before
3497 /// block argument and replace all uses of 1-th before block argument
3498 /// with %b.
3499 struct RemoveLoopInvariantArgsFromBeforeBlock
3500  : public OpRewritePattern<WhileOp> {
3501  using OpRewritePattern<WhileOp>::OpRewritePattern;
3502 
3503  LogicalResult matchAndRewrite(WhileOp op,
3504  PatternRewriter &rewriter) const override {
3505  Block &afterBlock = *op.getAfterBody();
3506  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3507  ConditionOp condOp = op.getConditionOp();
3508  OperandRange condOpArgs = condOp.getArgs();
3509  Operation *yieldOp = afterBlock.getTerminator();
3510  ValueRange yieldOpArgs = yieldOp->getOperands();
3511 
3512  bool canSimplify = false;
3513  for (const auto &it :
3514  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3515  auto index = static_cast<unsigned>(it.index());
3516  auto [initVal, yieldOpArg] = it.value();
3517  // If i-th yield operand is equal to the i-th operand of the scf.while,
3518  // the i-th before block argument is a loop invariant.
3519  if (yieldOpArg == initVal) {
3520  canSimplify = true;
3521  break;
3522  }
3523  // If the i-th yield operand is k-th after block argument, then we check
3524  // if the (k+1)-th condition op operand is equal to either the i-th before
3525  // block argument or the initial value of the i-th before block argument. If
3526  // either comparison holds, the i-th before block argument is a loop
3527  // invariant.
3528  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3529  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3530  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3531  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3532  canSimplify = true;
3533  break;
3534  }
3535  }
3536  }
3537 
3538  if (!canSimplify)
3539  return failure();
3540 
3541  SmallVector<Value> newInitArgs, newYieldOpArgs;
3542  DenseMap<unsigned, Value> beforeBlockInitValMap;
3543  SmallVector<Location> newBeforeBlockArgLocs;
3544  for (const auto &it :
3545  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3546  auto index = static_cast<unsigned>(it.index());
3547  auto [initVal, yieldOpArg] = it.value();
3548 
3549  // If i-th yield operand is equal to the i-th operand of the scf.while,
3550  // the i-th before block argument is a loop invariant.
3551  if (yieldOpArg == initVal) {
3552  beforeBlockInitValMap.insert({index, initVal});
3553  continue;
3554  } else {
3555  // If the i-th yield operand is k-th after block argument, then we check
3556  // if the (k+1)-th condition op operand is equal to either the i-th
3557  // before block argument or the initial value of i-th before block
3558  // argument. If either comparison holds, the i-th before block
3559  // argument is a loop invariant.
3560  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3561  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3562  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3563  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3564  beforeBlockInitValMap.insert({index, initVal});
3565  continue;
3566  }
3567  }
3568  }
3569  newInitArgs.emplace_back(initVal);
3570  newYieldOpArgs.emplace_back(yieldOpArg);
3571  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3572  }
3573 
3574  {
3575  OpBuilder::InsertionGuard g(rewriter);
3576  rewriter.setInsertionPoint(yieldOp);
3577  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3578  }
3579 
3580  auto newWhile =
3581  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3582 
3583  Block &newBeforeBlock = *rewriter.createBlock(
3584  &newWhile.getBefore(), /*insertPt*/ {},
3585  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3586 
3587  Block &beforeBlock = *op.getBeforeBody();
3588  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3589  // For each i-th before block argument we find its replacement value as:
3590  // 1. If the i-th before block argument is a loop invariant, we fetch its
3591  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3592  // 2. Else we fetch j-th new before block argument as the replacement
3593  // value of i-th before block argument.
3594  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3595  // If the index 'i' argument was a loop invariant, we fetch its initial
3596  // value from `beforeBlockInitValMap`.
3597  if (beforeBlockInitValMap.count(i) != 0)
3598  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3599  else
3600  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3601  }
3602 
3603  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3604  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3605  newWhile.getAfter().begin());
3606 
3607  rewriter.replaceOp(op, newWhile.getResults());
3608  return success();
3609  }
3610 };
3611 
3612 /// Remove loop invariant value from result (condition op) of scf.while.
3613 /// A value is considered loop invariant if the final value yielded by
3614 /// scf.condition is defined outside of the `before` block. We remove the
3615 /// corresponding argument in `after` block and replace the use with the value.
3616 /// We also replace the use of the corresponding result of scf.while with the
3617 /// value.
3618 ///
3619 /// Eg:
3620 /// INPUT :-
3621 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3622 /// %argN_before = %N) {
3623 /// ...
3624 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3625 /// } do {
3626 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3627 /// ...
3628 /// some_func(%arg1_after)
3629 /// ...
3630 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3631 /// }
3632 ///
3633 /// OUTPUT :-
3634 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3635 /// ...
3636 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3637 /// } do {
3638 /// ^bb0(%arg0, %arg3, ..., %argM):
3639 /// ...
3640 /// some_func(%a)
3641 /// ...
3642 /// scf.yield %arg0, %b, ..., %argN
3643 /// }
3644 ///
3645 /// EXPLANATION:
3646 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3647 /// before block of scf.while, so they get removed.
3648 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3649 /// replaced by %b.
3650 /// 3. The corresponding after block argument %arg1_after's uses are
3651 /// replaced by %a and %arg2_after's uses are replaced by %b.
3652 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3653  using OpRewritePattern<WhileOp>::OpRewritePattern;
3654 
3655  LogicalResult matchAndRewrite(WhileOp op,
3656  PatternRewriter &rewriter) const override {
3657  Block &beforeBlock = *op.getBeforeBody();
3658  ConditionOp condOp = op.getConditionOp();
3659  OperandRange condOpArgs = condOp.getArgs();
3660 
3661  bool canSimplify = false;
3662  for (Value condOpArg : condOpArgs) {
3663  // Values not defined within the `before` block are considered loop
3664  // invariant. Here we only check whether any such value exists; the
3665  // actual index-to-value mapping is built below.
3666  if (condOpArg.getParentBlock() != &beforeBlock) {
3667  canSimplify = true;
3668  break;
3669  }
3670  }
3671 
3672  if (!canSimplify)
3673  return failure();
3674 
3675  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3676 
3677  SmallVector<Value> newCondOpArgs;
3678  SmallVector<Type> newAfterBlockType;
3679  DenseMap<unsigned, Value> condOpInitValMap;
3680  SmallVector<Location> newAfterBlockArgLocs;
3681  for (const auto &it : llvm::enumerate(condOpArgs)) {
3682  auto index = static_cast<unsigned>(it.index());
3683  Value condOpArg = it.value();
3684  // Values not defined within the `before` block are considered loop
3685  // invariant. We map the corresponding `index` to the invariant
3686  // value.
3687  if (condOpArg.getParentBlock() != &beforeBlock) {
3688  condOpInitValMap.insert({index, condOpArg});
3689  } else {
3690  newCondOpArgs.emplace_back(condOpArg);
3691  newAfterBlockType.emplace_back(condOpArg.getType());
3692  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3693  }
3694  }
3695 
3696  {
3697  OpBuilder::InsertionGuard g(rewriter);
3698  rewriter.setInsertionPoint(condOp);
3699  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3700  newCondOpArgs);
3701  }
3702 
3703  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3704  op.getOperands());
3705 
3706  Block &newAfterBlock =
3707  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3708  newAfterBlockType, newAfterBlockArgLocs);
3709 
3710  Block &afterBlock = *op.getAfterBody();
3711  // Since a new scf.condition op was created, we need to fetch the new
3712  // `after` block arguments which will be used while replacing operations of
3713  // the previous scf.while's `after` block. We also fetch the new result
3714  // values.
3715  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3716  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3717  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3718  Value afterBlockArg, result;
3719  // If the index 'i' argument was loop invariant, we fetch its value from the
3720  // `condOpInitMap` map.
3721  if (condOpInitValMap.count(i) != 0) {
3722  afterBlockArg = condOpInitValMap[i];
3723  result = afterBlockArg;
3724  } else {
3725  afterBlockArg = newAfterBlock.getArgument(j);
3726  result = newWhile.getResult(j);
3727  j++;
3728  }
3729  newAfterBlockArgs[i] = afterBlockArg;
3730  newWhileResults[i] = result;
3731  }
3732 
3733  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3734  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3735  newWhile.getBefore().begin());
3736 
3737  rewriter.replaceOp(op, newWhileResults);
3738  return success();
3739  }
3740 };
3741 
3742 /// Remove WhileOp results that are unused and whose corresponding 'after' block arguments are also unused.
3743 ///
3744 /// %0:2 = scf.while () : () -> (i32, i64) {
3745 /// %condition = "test.condition"() : () -> i1
3746 /// %v1 = "test.get_some_value"() : () -> i32
3747 /// %v2 = "test.get_some_value"() : () -> i64
3748 /// scf.condition(%condition) %v1, %v2 : i32, i64
3749 /// } do {
3750 /// ^bb0(%arg0: i32, %arg1: i64):
3751 /// "test.use"(%arg0) : (i32) -> ()
3752 /// scf.yield
3753 /// }
3754 /// return %0#0 : i32
3755 ///
3756 /// becomes
3757 /// %0 = scf.while () : () -> (i32) {
3758 /// %condition = "test.condition"() : () -> i1
3759 /// %v1 = "test.get_some_value"() : () -> i32
3760 /// %v2 = "test.get_some_value"() : () -> i64
3761 /// scf.condition(%condition) %v1 : i32
3762 /// } do {
3763 /// ^bb0(%arg0: i32):
3764 /// "test.use"(%arg0) : (i32) -> ()
3765 /// scf.yield
3766 /// }
3767 /// return %0 : i32
3768 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3769  using OpRewritePattern<WhileOp>::OpRewritePattern;
3770 
3771  LogicalResult matchAndRewrite(WhileOp op,
3772  PatternRewriter &rewriter) const override {
3773  auto term = op.getConditionOp();
3774  auto afterArgs = op.getAfterArguments();
3775  auto termArgs = term.getArgs();
3776 
3777  // Collect results mapping, new terminator args and new result types.
3778  SmallVector<unsigned> newResultsIndices;
3779  SmallVector<Type> newResultTypes;
3780  SmallVector<Value> newTermArgs;
3781  SmallVector<Location> newArgLocs;
3782  bool needUpdate = false;
3783  for (const auto &it :
3784  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3785  auto i = static_cast<unsigned>(it.index());
3786  Value result = std::get<0>(it.value());
3787  Value afterArg = std::get<1>(it.value());
3788  Value termArg = std::get<2>(it.value());
3789  if (result.use_empty() && afterArg.use_empty()) {
3790  needUpdate = true;
3791  } else {
3792  newResultsIndices.emplace_back(i);
3793  newTermArgs.emplace_back(termArg);
3794  newResultTypes.emplace_back(result.getType());
3795  newArgLocs.emplace_back(result.getLoc());
3796  }
3797  }
3798 
3799  if (!needUpdate)
3800  return failure();
3801 
3802  {
3803  OpBuilder::InsertionGuard g(rewriter);
3804  rewriter.setInsertionPoint(term);
3805  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3806  newTermArgs);
3807  }
3808 
3809  auto newWhile =
3810  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3811 
3812  Block &newAfterBlock = *rewriter.createBlock(
3813  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3814 
3815  // Build new results list and new after block args (unused entries will be
3816  // null).
3817  SmallVector<Value> newResults(op.getNumResults());
3818  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3819  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3820  newResults[it.value()] = newWhile.getResult(it.index());
3821  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3822  }
3823 
3824  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3825  newWhile.getBefore().begin());
3826 
3827  Block &afterBlock = *op.getAfterBody();
3828  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3829 
3830  rewriter.replaceOp(op, newResults);
3831  return success();
3832  }
3833 };
3834 
3835 /// Replace operations equivalent to the condition in the do block with true,
3836 /// since otherwise the block would not be evaluated.
3837 ///
3838 /// scf.while (..) : (i32, ...) -> ... {
3839 /// %z = ... : i32
3840 /// %condition = cmpi pred %z, %a
3841 /// scf.condition(%condition) %z : i32, ...
3842 /// } do {
3843 /// ^bb0(%arg0: i32, ...):
3844 /// %condition2 = cmpi pred %arg0, %a
3845 /// use(%condition2)
3846 /// ...
3847 ///
3848 /// becomes
3849 /// scf.while (..) : (i32, ...) -> ... {
3850 /// %z = ... : i32
3851 /// %condition = cmpi pred %z, %a
3852 /// scf.condition(%condition) %z : i32, ...
3853 /// } do {
3854 /// ^bb0(%arg0: i32, ...):
3855 /// use(%true)
3856 /// ...
3857 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3858  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3859 
3860  LogicalResult matchAndRewrite(scf::WhileOp op,
3861  PatternRewriter &rewriter) const override {
3862  using namespace scf;
3863  auto cond = op.getConditionOp();
3864  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3865  if (!cmp)
3866  return failure();
3867  bool changed = false;
3868  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3869  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3870  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3871  continue;
3872  for (OpOperand &u :
3873  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3874  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3875  if (!cmp2)
3876  continue;
3877  // For a binary operator, 1 - opIdx selects the other operand.
3878  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3879  continue;
3880  bool samePredicate;
3881  if (cmp2.getPredicate() == cmp.getPredicate())
3882  samePredicate = true;
3883  else if (cmp2.getPredicate() ==
3884  arith::invertPredicate(cmp.getPredicate()))
3885  samePredicate = false;
3886  else
3887  continue;
3888 
3889  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3890  1);
3891  changed = true;
3892  }
3893  }
3894  }
3895  return success(changed);
3896  }
3897 };
3898 
3899 /// Remove unused init/yield args.
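///
/// Illustrative sketch (placeholders): if a `before` block argument has no
/// uses, its init operand and yielded value are dropped.
///
///   scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> ... {
///     ... // only %arg0 is used
///   } do {
///     ...
///     scf.yield %v0, %v1 : i32, i64
///   }
///
/// becomes a loop that carries only %arg0 = %a and yields only %v0.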
3900 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3901  using OpRewritePattern<WhileOp>::OpRewritePattern;
3902 
3903  LogicalResult matchAndRewrite(WhileOp op,
3904  PatternRewriter &rewriter) const override {
3905 
3906  if (!llvm::any_of(op.getBeforeArguments(),
3907  [](Value arg) { return arg.use_empty(); }))
3908  return rewriter.notifyMatchFailure(op, "No args to remove");
3909 
3910  YieldOp yield = op.getYieldOp();
3911 
3912  // Collect the new yield/init operands and mark unused arguments for erasure.
3913  SmallVector<Value> newYields;
3914  SmallVector<Value> newInits;
3915  llvm::BitVector argsToErase;
3916 
3917  size_t argsCount = op.getBeforeArguments().size();
3918  newYields.reserve(argsCount);
3919  newInits.reserve(argsCount);
3920  argsToErase.reserve(argsCount);
3921  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3922  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3923  if (beforeArg.use_empty()) {
3924  argsToErase.push_back(true);
3925  } else {
3926  argsToErase.push_back(false);
3927  newYields.emplace_back(yieldValue);
3928  newInits.emplace_back(initValue);
3929  }
3930  }
3931 
3932  Block &beforeBlock = *op.getBeforeBody();
3933  Block &afterBlock = *op.getAfterBody();
3934 
3935  beforeBlock.eraseArguments(argsToErase);
3936 
3937  Location loc = op.getLoc();
3938  auto newWhileOp =
3939  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
3940  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
3941  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3942  Block &newAfterBlock = *newWhileOp.getAfterBody();
3943 
3944  OpBuilder::InsertionGuard g(rewriter);
3945  rewriter.setInsertionPoint(yield);
3946  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3947 
3948  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
3949  newBeforeBlock.getArguments());
3950  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
3951  newAfterBlock.getArguments());
3952 
3953  rewriter.replaceOp(op, newWhileOp.getResults());
3954  return success();
3955  }
3956 };
3957 
3958 /// Remove duplicated ConditionOp args.
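///
/// Illustrative sketch (placeholders): if scf.condition forwards the same
/// value more than once, e.g.
///
///   scf.condition(%cond) %v, %v : i32, i32
///
/// only one copy is kept; the extra scf.while results and the extra `after`
/// block arguments are remapped to the remaining copy.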
3959 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
3960  using OpRewritePattern<WhileOp>::OpRewritePattern;
3961 
3962  LogicalResult matchAndRewrite(WhileOp op,
3963  PatternRewriter &rewriter) const override {
3964  ConditionOp condOp = op.getConditionOp();
3965  ValueRange condOpArgs = condOp.getArgs();
3966 
3967  llvm::SmallSetVector<Value, 8> argsSet;
3968  for (Value arg : condOpArgs)
3969  argsSet.insert(arg);
3970 
3971  if (argsSet.size() == condOpArgs.size())
3972  return rewriter.notifyMatchFailure(op, "No results to remove");
3973 
3974  llvm::SmallDenseMap<Value, unsigned> argsMap;
3975  SmallVector<Value> newArgs;
3976  argsMap.reserve(condOpArgs.size());
3977  newArgs.reserve(condOpArgs.size());
3978  for (Value arg : condOpArgs) {
3979  if (!argsMap.count(arg)) {
3980  auto pos = static_cast<unsigned>(argsMap.size());
3981  argsMap.insert({arg, pos});
3982  newArgs.emplace_back(arg);
3983  }
3984  }
3985 
3986  ValueRange argsRange(newArgs);
3987 
3988  Location loc = op.getLoc();
3989  auto newWhileOp = rewriter.create<scf::WhileOp>(
3990  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
3991  /*afterBody*/ nullptr);
3992  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3993  Block &newAfterBlock = *newWhileOp.getAfterBody();
3994 
3995  SmallVector<Value> afterArgsMapping;
3996  SmallVector<Value> resultsMapping;
3997  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
3998  auto it = argsMap.find(arg);
3999  assert(it != argsMap.end());
4000  auto pos = it->second;
4001  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4002  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4003  }
4004 
4005  OpBuilder::InsertionGuard g(rewriter);
4006  rewriter.setInsertionPoint(condOp);
4007  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4008  argsRange);
4009 
4010  Block &beforeBlock = *op.getBeforeBody();
4011  Block &afterBlock = *op.getAfterBody();
4012 
4013  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4014  newBeforeBlock.getArguments());
4015  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4016  rewriter.replaceOp(op, resultsMapping);
4017  return success();
4018  }
4019 };
4020 } // namespace
4021 
4022 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4023  MLIRContext *context) {
4024  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4025  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4026  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4027  WhileRemoveUnusedArgs>(context);
4028 }
4029 
4030 //===----------------------------------------------------------------------===//
4031 // IndexSwitchOp
4032 //===----------------------------------------------------------------------===//
4033 
4034 /// Parse the case regions and values.
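///
/// Illustrative form of the case list handled here (values and bodies are
/// placeholders):
///
///   case 2 {
///     ...
///     scf.yield %v : i32
///   }
///   case 5 {
///     ...
///   }
///
/// i.e. a sequence of `case <int> <region>` entries; the integers are
/// collected into a DenseI64ArrayAttr and each body becomes one region.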
4035 static ParseResult
4036 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4037  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4038  SmallVector<int64_t> caseValues;
4039  while (succeeded(p.parseOptionalKeyword("case"))) {
4040  int64_t value;
4041  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4042  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4043  return failure();
4044  caseValues.push_back(value);
4045  }
4046  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4047  return success();
4048 }
4049 
4050 /// Print the case regions and values.
4051 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4052  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4053  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4054  p.printNewline();
4055  p << "case " << value << ' ';
4056  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4057  }
4058 }
4059 
4060 LogicalResult scf::IndexSwitchOp::verify() {
4061  if (getCases().size() != getCaseRegions().size()) {
4062  return emitOpError("has ")
4063  << getCaseRegions().size() << " case regions but "
4064  << getCases().size() << " case values";
4065  }
4066 
4067  DenseSet<int64_t> valueSet;
4068  for (int64_t value : getCases())
4069  if (!valueSet.insert(value).second)
4070  return emitOpError("has duplicate case value: ") << value;
4071  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4072  auto yield = dyn_cast<YieldOp>(region.front().back());
4073  if (!yield)
4074  return emitOpError("expected region to end with scf.yield, but got ")
4075  << region.front().back().getName();
4076 
4077  if (yield.getNumOperands() != getNumResults()) {
4078  return (emitOpError("expected each region to return ")
4079  << getNumResults() << " values, but " << name << " returns "
4080  << yield.getNumOperands())
4081  .attachNote(yield.getLoc())
4082  << "see yield operation here";
4083  }
4084  for (auto [idx, result, operand] :
4085  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4086  yield.getOperandTypes())) {
4087  if (result == operand)
4088  continue;
4089  return (emitOpError("expected result #")
4090  << idx << " of each region to be " << result)
4091  .attachNote(yield.getLoc())
4092  << name << " returns " << operand << " here";
4093  }
4094  return success();
4095  };
4096 
4097  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4098  return failure();
4099  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4100  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4101  return failure();
4102 
4103  return success();
4104 }
4105 
4106 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4107 
4108 Block &scf::IndexSwitchOp::getDefaultBlock() {
4109  return getDefaultRegion().front();
4110 }
4111 
4112 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4113  assert(idx < getNumCases() && "case index out-of-bounds");
4114  return getCaseRegions()[idx].front();
4115 }
4116 
4117 void IndexSwitchOp::getSuccessorRegions(
4118  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4119  // All regions branch back to the parent op.
4120  if (!point.isParent()) {
4121  successors.emplace_back(getResults());
4122  return;
4123  }
4124 
4125  llvm::copy(getRegions(), std::back_inserter(successors));
4126 }
4127 
4128 void IndexSwitchOp::getEntrySuccessorRegions(
4129  ArrayRef<Attribute> operands,
4130  SmallVectorImpl<RegionSuccessor> &successors) {
4131  FoldAdaptor adaptor(operands, *this);
4132 
4133  // If a constant was not provided, all regions are possible successors.
4134  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4135  if (!arg) {
4136  llvm::copy(getRegions(), std::back_inserter(successors));
4137  return;
4138  }
4139 
4140  // Otherwise, try to find a case with a matching value. If not, the
4141  // default region is the only successor.
4142  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4143  if (caseValue == arg.getInt()) {
4144  successors.emplace_back(&caseRegion);
4145  return;
4146  }
4147  }
4148  successors.emplace_back(&getDefaultRegion());
4149 }
4150 
4151 void IndexSwitchOp::getRegionInvocationBounds(
4152  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4153  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4154  if (!operandValue) {
4155  // All regions are invoked at most once.
4156  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4157  return;
4158  }
4159 
4160  unsigned liveIndex = getNumRegions() - 1;
4161  const auto *it = llvm::find(getCases(), operandValue.getInt());
4162  if (it != getCases().end())
4163  liveIndex = std::distance(getCases().begin(), it);
4164  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4165  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4166 }
4167 
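// Illustrative note: when the switch argument folds to a constant, the fold
// below picks the matching case region (or the default region if no case
// value matches), returns that region's yielded values as the fold results,
// and splices the region's body (minus its terminator) in front of the
// scf.index_switch.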
4168 LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
4169  SmallVectorImpl<OpFoldResult> &results) {
4170  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
4171  if (!maybeCst.has_value())
4172  return failure();
4173  int64_t cst = *maybeCst;
4174  int64_t caseIdx, e = getNumCases();
4175  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4176  if (cst == getCases()[caseIdx])
4177  break;
4178  }
4179 
4180  Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4181  : getDefaultRegion();
4182  Block &source = r.front();
4183  results.assign(source.getTerminator()->getOperands().begin(),
4184  source.getTerminator()->getOperands().end());
4185 
4186  Block *pDestination = (*this)->getBlock();
4187  if (!pDestination)
4188  return failure();
4189  Block::iterator insertionPoint = (*this)->getIterator();
4190  pDestination->getOperations().splice(insertionPoint, source.getOperations(),
4191  source.getOperations().begin(),
4192  std::prev(source.getOperations().end()));
4193 
4194  return success();
4195 }
4196 
4197 //===----------------------------------------------------------------------===//
4198 // TableGen'd op method definitions
4199 //===----------------------------------------------------------------------===//
4200 
4201 #define GET_OP_CLASSES
4202 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:702
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:694
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:105
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4036
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3369
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:427
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:85
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4051
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2471
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:68
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:315
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:133
bool empty()
Definition: Block.h:141
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation & back()
Definition: Block.h:145
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:154
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:147
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:195
OpListType & getOperations()
Definition: Block.h:130
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
iterator begin()
Definition: Block.h:136
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:186
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:202
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:91
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:528
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:561
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:304
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:560
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:825
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
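Taken together, the Operation accessors above support the usual inspect-and-diagnose idiom. A small sketch using only the APIs listed here; the helper checkAllResultsUsed is made up:

#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Illustrative verifier-style check: report the first unused result of `op`.
static LogicalResult checkAllResultsUsed(Operation *op) {
  for (OpResult result : op->getResults()) {
    if (result.use_empty())
      return op->emitOpError("result #")
             << result.getResultNumber() << " is never used";
  }
  return success();
}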
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:334
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
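The Region accessors above (empty, front, back) are typically used together when inspecting single-block regions. A sketch with an invented helper name:

#include "mlir/IR/Region.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Illustrative helper: return the terminator of a single-block region, or
// nullptr if the region is empty, has multiple blocks, or its block is empty.
static Operation *getSingleBlockTerminator(Region &region) {
  if (region.empty() || !llvm::hasSingleElement(region))
    return nullptr;
  Block &block = region.front();
  return block.empty() ? nullptr : &block.back();
}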
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:591
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
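The rewriter entries above (notifyMatchFailure, replaceOp, OpRewritePattern) compose into the standard pattern skeleton. A hedged sketch of a pattern that rewrites an arith.select with a constant-true condition to its "true" operand; the pattern is illustrative and not part of this dialect:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Illustrative pattern: select %true, %a, %b  ->  %a.
struct SimplifyTrueSelect : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // m_One matches a constant integer 1, i.e. a true i1 condition.
    if (!matchPattern(op.getCondition(), m_One()))
      return rewriter.notifyMatchFailure(op, "condition is not constant true");
    rewriter.replaceOp(op, op.getTrueValue());
    return success();
  }
};
} // namespace

Such a pattern would be registered through the RewritePatternSet::add entry above, e.g. patterns.add<SimplifyTrueSelect>(context).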
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
Type getType() const
Return the type of this value.
Definition: Value.h:125
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
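The Value accessors above (getType, use_empty, getDefiningOp, getParentRegion) are commonly combined when classifying a value. A small sketch; the query name and its exact semantics are invented for illustration:

#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Illustrative query: is `value` an unused, index-typed result of an operation
// whose region is nested under (or equal to) `region`?
static bool isDeadIndexResultIn(Value value, Region &region) {
  if (!value.getType().isIndex() || !value.use_empty())
    return false;
  Operation *def = value.getDefiningOp();
  return def && region.isAncestor(value.getParentRegion());
}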
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:131
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:65
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:450
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:2966
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:78
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:670
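As a usage sketch for the buildLoopNest entry above: the overload without iter-args builds a perfect scf.for nest and hands the induction variables to a body callback. The wrapper below is hypothetical and assumes lbs/ubs/steps are index-typed values created by the caller:

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Illustrative helper: store `value` into every element addressed by the nest.
static void fillMemref(OpBuilder &builder, Location loc, Value memref,
                       Value value, ValueRange lbs, ValueRange ubs,
                       ValueRange steps) {
  scf::buildLoopNest(
      builder, loc, lbs, ubs, steps,
      [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
        nested.create<memref::StoreOp>(nestedLoc, value, memref, ivs);
      });
}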
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1884
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:627
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:588
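A sketch of the typical use of getForInductionVarOwner above: given a value, decide whether it is an scf.for induction variable and, if so, recover a property of the owning loop (the helper name is invented):

#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Illustrative query: the step of the loop owning `value`, or null if `value`
// is not an scf.for induction variable.
static Value getStepOfOwningLoop(Value value) {
  if (scf::ForOp forOp = scf::getForInductionVarOwner(value))
    return forOp.getStep();
  return Value();
}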
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1557
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:228
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
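The matcher and constant-query entries above serve the same purpose: extracting a compile-time integer. A minimal sketch of the Value-based path, with an invented helper name; for OpFoldResult inputs, getConstantIntValue above performs the equivalent check directly:

#include "mlir/IR/Matchers.h"
#include "llvm/ADT/APInt.h"

#include <optional>

using namespace mlir;

// Illustrative helper: the signed integer behind a constant-producing value,
// e.g. an arith.constant, bound via m_ConstantInt.
static std::optional<int64_t> getConstantValue(Value value) {
  APInt intValue;
  if (matchPattern(value, m_ConstantInt(&intValue)))
    return intValue.getSExtValue();
  return std::nullopt;
}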
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false)
Returns "success" when any of the elements in ofrs is a constant value.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
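The LogicalResult helpers above (success, failure, succeeded, failed) form the usual error-propagation idiom; a minimal sketch with invented validation rules:

#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Illustrative checks: a zero step is invalid, and bounds must be ordered.
static LogicalResult verifyStep(int64_t step) { return failure(step == 0); }

static LogicalResult verifyLoopParams(int64_t lb, int64_t ub, int64_t step) {
  if (failed(verifyStep(step)))
    return failure();
  return success(lb <= ub);
}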
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
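A sketch for the getValueOrCreateConstantIndexOp entry above: it is the usual way to materialize an OpFoldResult as an index-typed SSA value before building an op that needs operands. The wrapper name is invented, and the include path is assumed to be the Arith utils header that declares this function:

#include "mlir/Dialect/Arith/Utils/Utils.h"

using namespace mlir;

// Illustrative helper: turn mixed static/dynamic bounds into SSA index values.
static SmallVector<Value> materializeIndices(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> result;
  for (OpFoldResult ofr : ofrs)
    result.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
  return result;
}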
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:226
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:177
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.