SCF.cpp
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // SCFDialect Dialect Interfaces
37 //===----------------------------------------------------------------------===//
38 
39 namespace {
40 struct SCFInlinerInterface : public DialectInlinerInterface {
41   using DialectInlinerInterface::DialectInlinerInterface;
42   // We don't have any special restrictions on what can be inlined into
43  // destination regions (e.g. while/conditional bodies). Always allow it.
44  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
45  IRMapping &valueMapping) const final {
46  return true;
47  }
48  // Operations in scf dialect are always legal to inline since they are
49  // pure.
50  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
51  return true;
52  }
53  // Handle the given inlined terminator by replacing it with a new operation
54  // as necessary. Required when the region has only one block.
55  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
56  auto retValOp = dyn_cast<scf::YieldOp>(op);
57  if (!retValOp)
58  return;
59 
60  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
62  }
63  }
64 };
65 } // namespace
66 
67 //===----------------------------------------------------------------------===//
68 // SCFDialect
69 //===----------------------------------------------------------------------===//
70 
71 void SCFDialect::initialize() {
72  addOperations<
73 #define GET_OP_LIST
74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75  >();
76  addInterfaces<SCFInlinerInterface>();
77  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
78  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
79  InParallelOp, ReduceReturnOp>();
80  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
81  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
82  ForallOp, InParallelOp, WhileOp, YieldOp>();
83  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
84 }
85 
86 /// Default callback for IfOp builders. Inserts a yield without arguments.
87 static void buildTerminatedBody(OpBuilder &builder, Location loc) {
88   builder.create<scf::YieldOp>(loc);
89 }
90 
91 /// Verifies that the first block of the given `region` is terminated by a
92 /// TerminatorTy. Reports errors on the given operation if it is not the case.
93 template <typename TerminatorTy>
94 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
95  StringRef errorMessage) {
96  Operation *terminatorOperation = nullptr;
97  if (!region.empty() && !region.front().empty()) {
98  terminatorOperation = &region.front().back();
99  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
100  return yield;
101  }
102  auto diag = op->emitOpError(errorMessage);
103  if (terminatorOperation)
104  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
105  return nullptr;
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // ExecuteRegionOp
110 //===----------------------------------------------------------------------===//
111 
112 /// Replaces the given op with the contents of the given single-block region,
113 /// using the operands of the block terminator to replace operation results.
114 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
115  Region &region, ValueRange blockArgs = {}) {
116   assert(llvm::hasSingleElement(region) && "expected single-block region");
117  Block *block = &region.front();
118  Operation *terminator = block->getTerminator();
119  ValueRange results = terminator->getOperands();
120  rewriter.inlineBlockBefore(block, op, blockArgs);
121  rewriter.replaceOp(op, results);
122  rewriter.eraseOp(terminator);
123 }
124 
125 ///
126 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
127 /// block+
128 /// `}`
129 ///
130 /// Example:
131 /// scf.execute_region -> i32 {
132 /// %idx = load %rI[%i] : memref<128xi32>
133 /// return %idx : i32
134 /// }
135 ///
136 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
137  OperationState &result) {
138  if (parser.parseOptionalArrowTypeList(result.types))
139  return failure();
140 
141  // Introduce the body region and parse it.
142  Region *body = result.addRegion();
143  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
144  parser.parseOptionalAttrDict(result.attributes))
145  return failure();
146 
147  return success();
148 }
149 
150 void ExecuteRegionOp::print(OpAsmPrinter &p) {
151   p.printOptionalArrowTypeList(getResultTypes());
152 
153  p << ' ';
154  p.printRegion(getRegion(),
155  /*printEntryBlockArgs=*/false,
156  /*printBlockTerminators=*/true);
157 
158  p.printOptionalAttrDict((*this)->getAttrs());
159 }
160 
161 LogicalResult ExecuteRegionOp::verify() {
162  if (getRegion().empty())
163  return emitOpError("region needs to have at least one block");
164  if (getRegion().front().getNumArguments() > 0)
165  return emitOpError("region cannot have any arguments");
166  return success();
167 }
168 
169 // Inline an ExecuteRegionOp if it only contains one block.
170 // "test.foo"() : () -> ()
171 // %v = scf.execute_region -> i64 {
172 // %x = "test.val"() : () -> i64
173 // scf.yield %x : i64
174 // }
175 // "test.bar"(%v) : (i64) -> ()
176 //
177 // becomes
178 //
179 // "test.foo"() : () -> ()
180 // %x = "test.val"() : () -> i64
181 // "test.bar"(%x) : (i64) -> ()
182 //
183 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
184   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
185 
186  LogicalResult matchAndRewrite(ExecuteRegionOp op,
187  PatternRewriter &rewriter) const override {
188  if (!llvm::hasSingleElement(op.getRegion()))
189  return failure();
190  replaceOpWithRegion(rewriter, op, op.getRegion());
191  return success();
192  }
193 };
194 
195 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
196 // TODO generalize the conditions for operations which can be inlined into.
197 // func @func_execute_region_elim() {
198 // "test.foo"() : () -> ()
199 // %v = scf.execute_region -> i64 {
200 // %c = "test.cmp"() : () -> i1
201 // cf.cond_br %c, ^bb2, ^bb3
202 // ^bb2:
203 // %x = "test.val1"() : () -> i64
204 // cf.br ^bb4(%x : i64)
205 // ^bb3:
206 // %y = "test.val2"() : () -> i64
207 // cf.br ^bb4(%y : i64)
208 // ^bb4(%z : i64):
209 // scf.yield %z : i64
210 // }
211 // "test.bar"(%v) : (i64) -> ()
212 // return
213 // }
214 //
215 // becomes
216 //
217 // func @func_execute_region_elim() {
218 // "test.foo"() : () -> ()
219 // %c = "test.cmp"() : () -> i1
220 // cf.cond_br %c, ^bb1, ^bb2
221 // ^bb1: // pred: ^bb0
222 // %x = "test.val1"() : () -> i64
223 // cf.br ^bb3(%x : i64)
224 // ^bb2: // pred: ^bb0
225 // %y = "test.val2"() : () -> i64
226 // cf.br ^bb3(%y : i64)
227 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
228 // "test.bar"(%z) : (i64) -> ()
229 // return
230 // }
231 //
232 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
233   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
234 
235  LogicalResult matchAndRewrite(ExecuteRegionOp op,
236  PatternRewriter &rewriter) const override {
237  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
238  return failure();
239 
240  Block *prevBlock = op->getBlock();
241  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
242  rewriter.setInsertionPointToEnd(prevBlock);
243 
244  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
245 
246  for (Block &blk : op.getRegion()) {
247  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
248  rewriter.setInsertionPoint(yieldOp);
249  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
250  yieldOp.getResults());
251  rewriter.eraseOp(yieldOp);
252  }
253  }
254 
255  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
256  SmallVector<Value> blockArgs;
257 
258  for (auto res : op.getResults())
259  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
260 
261  rewriter.replaceOp(op, blockArgs);
262  return success();
263  }
264 };
265 
266 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
267                                                   MLIRContext *context) {
268   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
269 }
270 
271 void ExecuteRegionOp::getSuccessorRegions(
272     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
273   // If the predecessor is the ExecuteRegionOp, branch into the body.
274  if (point.isParent()) {
275  regions.push_back(RegionSuccessor(&getRegion()));
276  return;
277  }
278 
279  // Otherwise, the region branches back to the parent operation.
280  regions.push_back(RegionSuccessor(getResults()));
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // ConditionOp
285 //===----------------------------------------------------------------------===//
286 
287 MutableOperandRange
288 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
289   assert((point.isParent() || point == getParentOp().getAfter()) &&
290          "condition op can only exit the loop or branch to the after "
291          "region");
292  // Pass all operands except the condition to the successor region.
293  return getArgsMutable();
294 }
295 
296 void ConditionOp::getSuccessorRegions(
297     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
298   FoldAdaptor adaptor(operands, *this);
299 
300  WhileOp whileOp = getParentOp();
301 
302  // Condition can either lead to the after region or back to the parent op
303  // depending on whether the condition is true or not.
304  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
305  if (!boolAttr || boolAttr.getValue())
306  regions.emplace_back(&whileOp.getAfter(),
307  whileOp.getAfter().getArguments());
308  if (!boolAttr || !boolAttr.getValue())
309  regions.emplace_back(whileOp.getResults());
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // ForOp
314 //===----------------------------------------------------------------------===//
315 
316 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
317  Value ub, Value step, ValueRange initArgs,
318  BodyBuilderFn bodyBuilder) {
319  OpBuilder::InsertionGuard guard(builder);
320 
321  result.addOperands({lb, ub, step});
322  result.addOperands(initArgs);
323  for (Value v : initArgs)
324  result.addTypes(v.getType());
325  Type t = lb.getType();
326  Region *bodyRegion = result.addRegion();
327  Block *bodyBlock = builder.createBlock(bodyRegion);
328  bodyBlock->addArgument(t, result.location);
329  for (Value v : initArgs)
330  bodyBlock->addArgument(v.getType(), v.getLoc());
331 
332  // Create the default terminator if the builder is not provided and if the
333  // iteration arguments are not provided. Otherwise, leave this to the caller
334  // because we don't know which values to return from the loop.
335  if (initArgs.empty() && !bodyBuilder) {
336  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
337  } else if (bodyBuilder) {
338  OpBuilder::InsertionGuard guard(builder);
339  builder.setInsertionPointToStart(bodyBlock);
340  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
341  bodyBlock->getArguments().drop_front());
342  }
343 }
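// Illustrative use of the builder above (a sketch, not part of the upstream
// sources; `b`, `loc`, `lb`, `ub`, `step` and `init` are assumed to exist in
// the caller):
//   auto loop = b.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
//         // Loop-carried values must be yielded explicitly.
//         nested.create<scf::YieldOp>(nestedLoc, args);
//       });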
344 
345 LogicalResult ForOp::verify() {
346  // Check that the number of init args and op results is the same.
347  if (getInitArgs().size() != getNumResults())
348  return emitOpError(
349  "mismatch in number of loop-carried values and defined values");
350 
351  return success();
352 }
353 
354 LogicalResult ForOp::verifyRegions() {
355   // Check that the body defines a single block argument for the induction
356   // variable.
357  if (getInductionVar().getType() != getLowerBound().getType())
358  return emitOpError(
359  "expected induction variable to be same type as bounds and step");
360 
361  if (getNumRegionIterArgs() != getNumResults())
362  return emitOpError(
363  "mismatch in number of basic block args and defined values");
364 
365  auto initArgs = getInitArgs();
366  auto iterArgs = getRegionIterArgs();
367  auto opResults = getResults();
368  unsigned i = 0;
369  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
370  if (std::get<0>(e).getType() != std::get<2>(e).getType())
371  return emitOpError() << "types mismatch between " << i
372  << "th iter operand and defined value";
373  if (std::get<1>(e).getType() != std::get<2>(e).getType())
374  return emitOpError() << "types mismatch between " << i
375  << "th iter region arg and defined value";
376 
377  ++i;
378  }
379  return success();
380 }
381 
382 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
383  return SmallVector<Value>{getInductionVar()};
384 }
385 
386 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
387   return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
388 }
389 
390 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
391  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
392 }
393 
394 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
395   return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
396 }
397 
398 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
399 
400 /// Promotes the loop body of a forOp to its containing block if it can be
401 /// determined that the loop has a single iteration.
402 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
403   std::optional<int64_t> tripCount =
404       constantTripCount(getLowerBound(), getUpperBound(), getStep());
405   if (!tripCount.has_value() || tripCount != 1)
406  return failure();
407 
408  // Replace all results with the yielded values.
409  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
410  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
411 
412  // Replace block arguments with lower bound (replacement for IV) and
413  // iter_args.
414  SmallVector<Value> bbArgReplacements;
415  bbArgReplacements.push_back(getLowerBound());
416  llvm::append_range(bbArgReplacements, getInitArgs());
417 
418  // Move the loop body operations to the loop's containing block.
419  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
420  getOperation()->getIterator(), bbArgReplacements);
421 
422  // Erase the old terminator and the loop.
423  rewriter.eraseOp(yieldOp);
424  rewriter.eraseOp(*this);
425 
426  return success();
427 }
428 
429 /// Prints the initialization list in the form of
430 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
431 /// where 'inner' values are assumed to be region arguments and 'outer' values
432 /// are regular SSA values.
433 static void printInitializationList(OpAsmPrinter &p,
434                                     Block::BlockArgListType blocksArgs,
435  ValueRange initializers,
436  StringRef prefix = "") {
437  assert(blocksArgs.size() == initializers.size() &&
438  "expected same length of arguments and initializers");
439  if (initializers.empty())
440  return;
441 
442  p << prefix << '(';
443  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
444  p << std::get<0>(it) << " = " << std::get<1>(it);
445  });
446  p << ")";
447 }
448 
449 void ForOp::print(OpAsmPrinter &p) {
450  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
451  << getUpperBound() << " step " << getStep();
452 
453  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
454  if (!getInitArgs().empty())
455  p << " -> (" << getInitArgs().getTypes() << ')';
456  p << ' ';
457  if (Type t = getInductionVar().getType(); !t.isIndex())
458  p << " : " << t << ' ';
459  p.printRegion(getRegion(),
460  /*printEntryBlockArgs=*/false,
461  /*printBlockTerminators=*/!getInitArgs().empty());
462  p.printOptionalAttrDict((*this)->getAttrs());
463 }
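// For illustration (a sketch of the printed form, assuming index-typed bounds
// and a single iter_arg):
//   scf.for %iv = %lb to %ub step %step iter_args(%sum = %init) -> (f32) {
//     ...
//     scf.yield %next : f32
//   }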
464 
465 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
466  auto &builder = parser.getBuilder();
467  Type type;
468 
469  OpAsmParser::Argument inductionVariable;
470  OpAsmParser::UnresolvedOperand lb, ub, step;
471 
472  // Parse the induction variable followed by '='.
473  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
474  // Parse loop bounds.
475  parser.parseOperand(lb) || parser.parseKeyword("to") ||
476  parser.parseOperand(ub) || parser.parseKeyword("step") ||
477  parser.parseOperand(step))
478  return failure();
479 
480   // Parse the optional initial iteration arguments.
481   SmallVector<OpAsmParser::Argument, 4> regionArgs;
482   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
483   regionArgs.push_back(inductionVariable);
484 
485  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
486  if (hasIterArgs) {
487  // Parse assignment list and results type list.
488  if (parser.parseAssignmentList(regionArgs, operands) ||
489  parser.parseArrowTypeList(result.types))
490  return failure();
491  }
492 
493  if (regionArgs.size() != result.types.size() + 1)
494  return parser.emitError(
495  parser.getNameLoc(),
496  "mismatch in number of loop-carried values and defined values");
497 
498  // Parse optional type, else assume Index.
499  if (parser.parseOptionalColon())
500  type = builder.getIndexType();
501  else if (parser.parseType(type))
502  return failure();
503 
504  // Set block argument types, so that they are known when parsing the region.
505  regionArgs.front().type = type;
506  for (auto [iterArg, type] :
507  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
508  iterArg.type = type;
509 
510  // Parse the body region.
511  Region *body = result.addRegion();
512  if (parser.parseRegion(*body, regionArgs))
513  return failure();
514  ForOp::ensureTerminator(*body, builder, result.location);
515 
516  // Resolve input operands. This should be done after parsing the region to
517  // catch invalid IR where operands were defined inside of the region.
518  if (parser.resolveOperand(lb, type, result.operands) ||
519  parser.resolveOperand(ub, type, result.operands) ||
520  parser.resolveOperand(step, type, result.operands))
521  return failure();
522  if (hasIterArgs) {
523  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
524  operands, result.types)) {
525  Type type = std::get<2>(argOperandType);
526  std::get<0>(argOperandType).type = type;
527  if (parser.resolveOperand(std::get<1>(argOperandType), type,
528  result.operands))
529  return failure();
530  }
531  }
532 
533  // Parse the optional attribute list.
534  if (parser.parseOptionalAttrDict(result.attributes))
535  return failure();
536 
537  return success();
538 }
539 
540 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
541 
542 Block::BlockArgListType ForOp::getRegionIterArgs() {
543  return getBody()->getArguments().drop_front(getNumInductionVars());
544 }
545 
546 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
547  return getInitArgsMutable();
548 }
549 
550 FailureOr<LoopLikeOpInterface>
551 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
552  ValueRange newInitOperands,
553  bool replaceInitOperandUsesInLoop,
554  const NewYieldValuesFn &newYieldValuesFn) {
555  // Create a new loop before the existing one, with the extra operands.
556  OpBuilder::InsertionGuard g(rewriter);
557  rewriter.setInsertionPoint(getOperation());
558  auto inits = llvm::to_vector(getInitArgs());
559  inits.append(newInitOperands.begin(), newInitOperands.end());
560  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
561  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
562  [](OpBuilder &, Location, Value, ValueRange) {});
563  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
564 
565  // Generate the new yield values and append them to the scf.yield operation.
566  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
567  ArrayRef<BlockArgument> newIterArgs =
568  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
569  {
570  OpBuilder::InsertionGuard g(rewriter);
571  rewriter.setInsertionPoint(yieldOp);
572  SmallVector<Value> newYieldedValues =
573  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
574  assert(newInitOperands.size() == newYieldedValues.size() &&
575  "expected as many new yield values as new iter operands");
576  rewriter.modifyOpInPlace(yieldOp, [&]() {
577  yieldOp.getResultsMutable().append(newYieldedValues);
578  });
579  }
580 
581  // Move the loop body to the new op.
582  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
583  newLoop.getBody()->getArguments().take_front(
584  getBody()->getNumArguments()));
585 
586  if (replaceInitOperandUsesInLoop) {
587  // Replace all uses of `newInitOperands` with the corresponding basic block
588  // arguments.
589  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
590  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
591  [&](OpOperand &use) {
592  Operation *user = use.getOwner();
593  return newLoop->isProperAncestor(user);
594  });
595  }
596  }
597 
598  // Replace the old loop.
599  rewriter.replaceOp(getOperation(),
600  newLoop->getResults().take_front(getNumResults()));
601  return cast<LoopLikeOpInterface>(newLoop.getOperation());
602 }
603 
604 ForOp mlir::scf::getForInductionVarOwner(Value val) {
605   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
606  if (!ivArg)
607  return ForOp();
608  assert(ivArg.getOwner() && "unlinked block argument");
609  auto *containingOp = ivArg.getOwner()->getParentOp();
610  return dyn_cast_or_null<ForOp>(containingOp);
611 }
612 
613 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
614  return getInitArgs();
615 }
616 
617 void ForOp::getSuccessorRegions(RegionBranchPoint point,
618                                 SmallVectorImpl<RegionSuccessor> &regions) {
619   // Both the operation itself and the region may be branching into the body
620   // or back into the operation itself. It is possible for the loop not to
621   // enter the body.
622  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
623  regions.push_back(RegionSuccessor(getResults()));
624 }
625 
626 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
627 
628 /// Promotes the loop body of a forallOp to its containing block if it can be
629 /// determined that the loop has a single iteration.
630 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
631  for (auto [lb, ub, step] :
632  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
633  auto tripCount = constantTripCount(lb, ub, step);
634  if (!tripCount.has_value() || *tripCount != 1)
635  return failure();
636  }
637 
638  promote(rewriter, *this);
639  return success();
640 }
641 
642 Block::BlockArgListType ForallOp::getRegionIterArgs() {
643  return getBody()->getArguments().drop_front(getRank());
644 }
645 
646 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
647  return getOutputsMutable();
648 }
649 
650 /// Promotes the loop body of a scf::ForallOp to its containing block.
651 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
652  OpBuilder::InsertionGuard g(rewriter);
653  scf::InParallelOp terminator = forallOp.getTerminator();
654 
655  // Replace block arguments with lower bounds (replacements for IVs) and
656  // outputs.
657  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
658  bbArgReplacements.append(forallOp.getOutputs().begin(),
659  forallOp.getOutputs().end());
660 
661  // Move the loop body operations to the loop's containing block.
662  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
663  forallOp->getIterator(), bbArgReplacements);
664 
665  // Replace the terminator with tensor.insert_slice ops.
666  rewriter.setInsertionPointAfter(forallOp);
667  SmallVector<Value> results;
668  results.reserve(forallOp.getResults().size());
669  for (auto &yieldingOp : terminator.getYieldingOps()) {
670  auto parallelInsertSliceOp =
671  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
672 
673  Value dst = parallelInsertSliceOp.getDest();
674  Value src = parallelInsertSliceOp.getSource();
675  if (llvm::isa<TensorType>(src.getType())) {
676  results.push_back(rewriter.create<tensor::InsertSliceOp>(
677  forallOp.getLoc(), dst.getType(), src, dst,
678  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
679  parallelInsertSliceOp.getStrides(),
680  parallelInsertSliceOp.getStaticOffsets(),
681  parallelInsertSliceOp.getStaticSizes(),
682  parallelInsertSliceOp.getStaticStrides()));
683  } else {
684  llvm_unreachable("unsupported terminator");
685  }
686  }
687  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
688 
689  // Erase the old terminator and the loop.
690  rewriter.eraseOp(terminator);
691  rewriter.eraseOp(forallOp);
692 }
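// For illustration (a sketch): after promotion, a combining terminator op such
// as
//   tensor.parallel_insert_slice %src into %dest[...] [...] [...]
// becomes a regular
//   %r = tensor.insert_slice %src into %dest[...] [...] [...]
// inserted after the loop, and %r replaces the corresponding forall result.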
693 
694 LoopNest mlir::scf::buildLoopNest(
695     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
696     ValueRange steps, ValueRange iterArgs,
697     function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
698         bodyBuilder) {
699  assert(lbs.size() == ubs.size() &&
700  "expected the same number of lower and upper bounds");
701  assert(lbs.size() == steps.size() &&
702  "expected the same number of lower bounds and steps");
703 
704  // If there are no bounds, call the body-building function and return early.
705  if (lbs.empty()) {
706  ValueVector results =
707  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
708  : ValueVector();
709  assert(results.size() == iterArgs.size() &&
710  "loop nest body must return as many values as loop has iteration "
711  "arguments");
712  return LoopNest{{}, std::move(results)};
713  }
714 
715  // First, create the loop structure iteratively using the body-builder
716  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
717   OpBuilder::InsertionGuard guard(builder);
718   SmallVector<scf::ForOp, 4> loops;
719   SmallVector<Value, 4> ivs;
720   loops.reserve(lbs.size());
721  ivs.reserve(lbs.size());
722  ValueRange currentIterArgs = iterArgs;
723  Location currentLoc = loc;
724  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
725  auto loop = builder.create<scf::ForOp>(
726  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
727  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
728  ValueRange args) {
729  ivs.push_back(iv);
730  // It is safe to store ValueRange args because it points to block
731  // arguments of a loop operation that we also own.
732  currentIterArgs = args;
733  currentLoc = nestedLoc;
734  });
735  // Set the builder to point to the body of the newly created loop. We don't
736  // do this in the callback because the builder is reset when the callback
737  // returns.
738  builder.setInsertionPointToStart(loop.getBody());
739  loops.push_back(loop);
740  }
741 
742  // For all loops but the innermost, yield the results of the nested loop.
743  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
744  builder.setInsertionPointToEnd(loops[i].getBody());
745  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
746  }
747 
748  // In the body of the innermost loop, call the body building function if any
749  // and yield its results.
750  builder.setInsertionPointToStart(loops.back().getBody());
751  ValueVector results = bodyBuilder
752  ? bodyBuilder(builder, currentLoc, ivs,
753  loops.back().getRegionIterArgs())
754  : ValueVector();
755  assert(results.size() == iterArgs.size() &&
756  "loop nest body must return as many values as loop has iteration "
757  "arguments");
758  builder.setInsertionPointToEnd(loops.back().getBody());
759  builder.create<scf::YieldOp>(loc, results);
760 
761  // Return the loops.
762  ValueVector nestResults;
763  llvm::append_range(nestResults, loops.front().getResults());
764  return LoopNest{std::move(loops), std::move(nestResults)};
765 }
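// Illustrative use of buildLoopNest (a sketch, not part of the upstream
// sources; `b`, `loc`, `lb`, `ub`, `step` and `init` are assumed to be defined
// by the caller):
//   scf::LoopNest nest = scf::buildLoopNest(
//       b, loc, /*lbs=*/ValueRange{lb, lb}, /*ubs=*/ValueRange{ub, ub},
//       /*steps=*/ValueRange{step, step}, /*iterArgs=*/ValueRange{init},
//       [](OpBuilder &nested, Location nestedLoc, ValueRange ivs,
//          ValueRange args) -> scf::ValueVector {
//         // Forward the loop-carried value unchanged.
//         return {args.front()};
//       });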
766 
767 LoopNest mlir::scf::buildLoopNest(
768     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
769  ValueRange steps,
770  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
771  // Delegate to the main function by wrapping the body builder.
772  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
773  [&bodyBuilder](OpBuilder &nestedBuilder,
774  Location nestedLoc, ValueRange ivs,
775  ValueRange) -> ValueVector {
776  if (bodyBuilder)
777  bodyBuilder(nestedBuilder, nestedLoc, ivs);
778  return {};
779  });
780 }
781 
782 SmallVector<Value>
783 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
784                                       OpOperand &operand, Value replacement,
785                                       const ValueTypeCastFnTy &castFn) {
786  assert(operand.getOwner() == forOp);
787  Type oldType = operand.get().getType(), newType = replacement.getType();
788 
789  // 1. Create new iter operands, exactly 1 is replaced.
790  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791  "expected an iter OpOperand");
792  assert(operand.get().getType() != replacement.getType() &&
793  "Expected a different type");
794  SmallVector<Value> newIterOperands;
795  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
796  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797  newIterOperands.push_back(replacement);
798  continue;
799  }
800  newIterOperands.push_back(opOperand.get());
801  }
802 
803  // 2. Create the new forOp shell.
804  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
805  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806  forOp.getStep(), newIterOperands);
807  newForOp->setAttrs(forOp->getAttrs());
808  Block &newBlock = newForOp.getRegion().front();
809  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
810  newBlock.getArguments().end());
811 
812  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
813  // corresponding to the `replacement` value.
814  OpBuilder::InsertionGuard g(rewriter);
815  rewriter.setInsertionPointToStart(&newBlock);
816  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
817  &newForOp->getOpOperand(operand.getOperandNumber()));
818  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
819  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
820 
821  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
822  Block &oldBlock = forOp.getRegion().front();
823  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
824 
825  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
826  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
827  rewriter.setInsertionPoint(clonedYieldOp);
828  unsigned yieldIdx =
829  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
830  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
831  clonedYieldOp.getOperand(yieldIdx));
832  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
833  newYieldOperands[yieldIdx] = castOut;
834  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
835  rewriter.eraseOp(clonedYieldOp);
836 
837  // 6. Inject an outgoing cast op after the forOp.
838  rewriter.setInsertionPointAfter(newForOp);
839  SmallVector<Value> newResults = newForOp.getResults();
840  newResults[yieldIdx] =
841  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
842 
843  return newResults;
844 }
845 
846 namespace {
847 // Fold away ForOp iter arguments when:
848 // 1) The op yields the iter arguments.
849 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
850 // 3) The iter arguments have no use and the corresponding (operation) results
851 // have no use.
852 //
853 // These arguments must be defined outside of the ForOp region and can just be
854 // forwarded after simplifying the op inits, yields and returns.
855 //
856 // The implementation uses `inlineBlockBefore` to steal the content of the
857 // original ForOp and avoid cloning.
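// For illustration (a sketch): in
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     ...
//     scf.yield %a : f32
//   }
// the iter argument is forwarded unchanged, so uses of %r can be replaced by
// %init and the iter_args entry dropped.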
858 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
859   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
860 
861  LogicalResult matchAndRewrite(scf::ForOp forOp,
862  PatternRewriter &rewriter) const final {
863  bool canonicalize = false;
864 
865  // An internal flat vector of block transfer
866  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
867     // transformed block argument mappings. This plays the role of an
868     // IRMapping for the particular use case of calling into
869  // `inlineBlockBefore`.
870  int64_t numResults = forOp.getNumResults();
871  SmallVector<bool, 4> keepMask;
872  keepMask.reserve(numResults);
873  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
874  newResultValues;
875  newBlockTransferArgs.reserve(1 + numResults);
876  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
877  newIterArgs.reserve(forOp.getInitArgs().size());
878  newYieldValues.reserve(numResults);
879  newResultValues.reserve(numResults);
880  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
881  for (auto [init, arg, result, yielded] :
882  llvm::zip(forOp.getInitArgs(), // iter from outside
883  forOp.getRegionIterArgs(), // iter inside region
884  forOp.getResults(), // op results
885  forOp.getYieldedValues() // iter yield
886  )) {
887  // Forwarded is `true` when:
888  // 1) The region `iter` argument is yielded.
889       // 2) The input corresponding to the region `iter` argument is yielded.
890  // 3) The region `iter` argument has no use, and the corresponding op
891  // result has no use.
892  bool forwarded = (arg == yielded) || (init == yielded) ||
893  (arg.use_empty() && result.use_empty());
894  if (forwarded) {
895  canonicalize = true;
896  keepMask.push_back(false);
897  newBlockTransferArgs.push_back(init);
898  newResultValues.push_back(init);
899  continue;
900  }
901 
902  // Check if a previous kept argument always has the same values for init
903  // and yielded values.
904  if (auto it = initYieldToArg.find({init, yielded});
905  it != initYieldToArg.end()) {
906  canonicalize = true;
907  keepMask.push_back(false);
908  auto [sameArg, sameResult] = it->second;
909  rewriter.replaceAllUsesWith(arg, sameArg);
910  rewriter.replaceAllUsesWith(result, sameResult);
911  // The replacement value doesn't matter because there are no uses.
912  newBlockTransferArgs.push_back(init);
913  newResultValues.push_back(init);
914  continue;
915  }
916 
917  // This value is kept.
918  initYieldToArg.insert({{init, yielded}, {arg, result}});
919  keepMask.push_back(true);
920  newIterArgs.push_back(init);
921  newYieldValues.push_back(yielded);
922  newBlockTransferArgs.push_back(Value()); // placeholder with null value
923  newResultValues.push_back(Value()); // placeholder with null value
924  }
925 
926  if (!canonicalize)
927  return failure();
928 
929  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
930  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
931  forOp.getStep(), newIterArgs);
932  newForOp->setAttrs(forOp->getAttrs());
933  Block &newBlock = newForOp.getRegion().front();
934 
935  // Replace the null placeholders with newly constructed values.
936  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
937  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
938  idx != e; ++idx) {
939  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
940  Value &newResultVal = newResultValues[idx];
941  assert((blockTransferArg && newResultVal) ||
942  (!blockTransferArg && !newResultVal));
943  if (!blockTransferArg) {
944  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
945  newResultVal = newForOp.getResult(collapsedIdx++);
946  }
947  }
948 
949  Block &oldBlock = forOp.getRegion().front();
950  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
951  "unexpected argument size mismatch");
952 
953  // No results case: the scf::ForOp builder already created a zero
954  // result terminator. Merge before this terminator and just get rid of the
955  // original terminator that has been merged in.
956  if (newIterArgs.empty()) {
957  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
958  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
959  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
960  rewriter.replaceOp(forOp, newResultValues);
961  return success();
962  }
963 
964  // No terminator case: merge and rewrite the merged terminator.
965  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
966  OpBuilder::InsertionGuard g(rewriter);
967  rewriter.setInsertionPoint(mergedTerminator);
968  SmallVector<Value, 4> filteredOperands;
969  filteredOperands.reserve(newResultValues.size());
970  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
971  if (keepMask[idx])
972  filteredOperands.push_back(mergedTerminator.getOperand(idx));
973  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
974  filteredOperands);
975  };
976 
977  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
978  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
979  cloneFilteredTerminator(mergedYieldOp);
980  rewriter.eraseOp(mergedYieldOp);
981  rewriter.replaceOp(forOp, newResultValues);
982  return success();
983  }
984 };
985 
986 /// Utility function that tries to compute a constant difference between u
987 /// and l. Returns std::nullopt when the difference cannot be determined
988 /// statically.
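/// For illustration (a sketch): with %l = arith.constant 0 and
/// %u = arith.constant 8 this returns 8; with %u = arith.addi %l, %c4 (for a
/// constant %c4 = 4) it returns 4; otherwise it returns std::nullopt.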
989 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
990  IntegerAttr clb, cub;
991  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
992  llvm::APInt lbValue = clb.getValue();
993  llvm::APInt ubValue = cub.getValue();
994  return (ubValue - lbValue).getSExtValue();
995  }
996 
997  // Else a simple pattern match for x + c or c + x
998  llvm::APInt diff;
999  if (matchPattern(
1000  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
1001  matchPattern(
1002  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1003  return diff.getSExtValue();
1004  return std::nullopt;
1005 }
1006 
1007 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1008 /// single-iteration loops with their bodies, and removes empty loops that
1009 /// iterate at least once and only return values defined outside of the loop.
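/// For illustration (a sketch): with constants %c0 < %c1 and step %c1, the loop
///   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
///     ... scf.yield %next : f32
///   }
/// runs exactly once, so its body is inlined with %i replaced by %c0 and %r
/// replaced by the yielded value.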
1010 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1011   using OpRewritePattern<ForOp>::OpRewritePattern;
1012 
1013  LogicalResult matchAndRewrite(ForOp op,
1014  PatternRewriter &rewriter) const override {
1015  // If the upper bound is the same as the lower bound, the loop does not
1016  // iterate, just remove it.
1017  if (op.getLowerBound() == op.getUpperBound()) {
1018  rewriter.replaceOp(op, op.getInitArgs());
1019  return success();
1020  }
1021 
1022  std::optional<int64_t> diff =
1023  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1024  if (!diff)
1025  return failure();
1026 
1027  // If the loop is known to have 0 iterations, remove it.
1028  if (*diff <= 0) {
1029  rewriter.replaceOp(op, op.getInitArgs());
1030  return success();
1031  }
1032 
1033  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1034  if (!maybeStepValue)
1035  return failure();
1036 
1037  // If the loop is known to have 1 iteration, inline its body and remove the
1038  // loop.
1039  llvm::APInt stepValue = *maybeStepValue;
1040  if (stepValue.sge(*diff)) {
1041  SmallVector<Value, 4> blockArgs;
1042  blockArgs.reserve(op.getInitArgs().size() + 1);
1043  blockArgs.push_back(op.getLowerBound());
1044  llvm::append_range(blockArgs, op.getInitArgs());
1045  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1046  return success();
1047  }
1048 
1049     // Now we are left with loops that have more than one iteration.
1050  Block &block = op.getRegion().front();
1051  if (!llvm::hasSingleElement(block))
1052  return failure();
1053  // If the loop is empty, iterates at least once, and only returns values
1054  // defined outside of the loop, remove it and replace it with yield values.
1055  if (llvm::any_of(op.getYieldedValues(),
1056  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1057  return failure();
1058  rewriter.replaceOp(op, op.getYieldedValues());
1059  return success();
1060  }
1061 };
1062 
1063 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1064 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1065 ///
1066 /// ```
1067 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1068 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1069 /// -> (tensor<?x?xf32>) {
1070 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1071 /// scf.yield %2 : tensor<?x?xf32>
1072 /// }
1073 /// use_of(%1)
1074 /// ```
1075 ///
1076 /// folds into:
1077 ///
1078 /// ```
1079 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1080 /// -> (tensor<32x1024xf32>) {
1081 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1082 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1083 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1084 /// scf.yield %4 : tensor<32x1024xf32>
1085 /// }
1086 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1087 /// use_of(%1)
1088 /// ```
1089 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1090   using OpRewritePattern<ForOp>::OpRewritePattern;
1091 
1092  LogicalResult matchAndRewrite(ForOp op,
1093  PatternRewriter &rewriter) const override {
1094  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1095  OpOperand &iterOpOperand = std::get<0>(it);
1096  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1097  if (!incomingCast ||
1098  incomingCast.getSource().getType() == incomingCast.getType())
1099  continue;
1100       // Skip if the dest type of the cast does not preserve the static
1101       // information present in the source type.
1102       if (!tensor::preservesStaticInformation(
1103               incomingCast.getDest().getType(),
1104               incomingCast.getSource().getType()))
1105         continue;
1106  if (!std::get<1>(it).hasOneUse())
1107  continue;
1108 
1109       // Create a new ForOp with that iter operand replaced.
1110       rewriter.replaceOp(
1111           op, replaceAndCastForOpIterArg(
1112                   rewriter, op, iterOpOperand, incomingCast.getSource(),
1113  [](OpBuilder &b, Location loc, Type type, Value source) {
1114  return b.create<tensor::CastOp>(loc, type, source);
1115  }));
1116  return success();
1117  }
1118  return failure();
1119  }
1120 };
1121 
1122 } // namespace
1123 
1124 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1125  MLIRContext *context) {
1126  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1127  context);
1128 }
1129 
1130 std::optional<APInt> ForOp::getConstantStep() {
1131  IntegerAttr step;
1132  if (matchPattern(getStep(), m_Constant(&step)))
1133  return step.getValue();
1134  return {};
1135 }
1136 
1137 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1138  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1139 }
1140 
1141 Speculation::Speculatability ForOp::getSpeculatability() {
1142  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1143  // and End.
1144   if (auto constantStep = getConstantStep())
1145     if (*constantStep == 1)
1146       return Speculation::RecursivelySpeculatable;
1147 
1148   // For Step != 1, the loop may not terminate. We can add more smarts here if
1149   // needed.
1150   return Speculation::NotSpeculatable;
1151 }
1152 
1153 //===----------------------------------------------------------------------===//
1154 // ForallOp
1155 //===----------------------------------------------------------------------===//
1156 
1157 LogicalResult ForallOp::verify() {
1158  unsigned numLoops = getRank();
1159  // Check number of outputs.
1160  if (getNumResults() != getOutputs().size())
1161  return emitOpError("produces ")
1162  << getNumResults() << " results, but has only "
1163  << getOutputs().size() << " outputs";
1164 
1165  // Check that the body defines block arguments for thread indices and outputs.
1166  auto *body = getBody();
1167  if (body->getNumArguments() != numLoops + getOutputs().size())
1168  return emitOpError("region expects ") << numLoops << " arguments";
1169  for (int64_t i = 0; i < numLoops; ++i)
1170  if (!body->getArgument(i).getType().isIndex())
1171  return emitOpError("expects ")
1172  << i << "-th block argument to be an index";
1173  for (unsigned i = 0; i < getOutputs().size(); ++i)
1174  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1175  return emitOpError("type mismatch between ")
1176  << i << "-th output and corresponding block argument";
1177  if (getMapping().has_value() && !getMapping()->empty()) {
1178  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1179  return emitOpError() << "mapping attribute size must match op rank";
1180  for (auto map : getMapping()->getValue()) {
1181  if (!isa<DeviceMappingAttrInterface>(map))
1182  return emitOpError()
1183  << getMappingAttrName() << " is not device mapping attribute";
1184  }
1185  }
1186 
1187  // Verify mixed static/dynamic control variables.
1188  Operation *op = getOperation();
1189  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1190  getStaticLowerBound(),
1191  getDynamicLowerBound())))
1192  return failure();
1193  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1194  getStaticUpperBound(),
1195  getDynamicUpperBound())))
1196  return failure();
1197  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1198  getStaticStep(), getDynamicStep())))
1199  return failure();
1200 
1201  return success();
1202 }
1203 
1204 void ForallOp::print(OpAsmPrinter &p) {
1205  Operation *op = getOperation();
1206  p << " (" << getInductionVars();
1207   if (isNormalized()) {
1208     p << ") in ";
1209     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1210                           /*valueTypes=*/{}, /*scalables=*/{},
1211                           OpAsmParser::Delimiter::Paren);
1212   } else {
1213     p << ") = ";
1214     printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1215                           /*valueTypes=*/{}, /*scalables=*/{},
1216                           OpAsmParser::Delimiter::Paren);
1217     p << " to ";
1218     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1219                           /*valueTypes=*/{}, /*scalables=*/{},
1220                           OpAsmParser::Delimiter::Paren);
1221     p << " step ";
1222     printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1223                           /*valueTypes=*/{}, /*scalables=*/{},
1224                           OpAsmParser::Delimiter::Paren);
1225   }
1226  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1227  p << " ";
1228  if (!getRegionOutArgs().empty())
1229  p << "-> (" << getResultTypes() << ") ";
1230  p.printRegion(getRegion(),
1231  /*printEntryBlockArgs=*/false,
1232  /*printBlockTerminators=*/getNumResults() > 0);
1233  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1234  getStaticLowerBoundAttrName(),
1235  getStaticUpperBoundAttrName(),
1236  getStaticStepAttrName()});
1237 }
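// For illustration (a sketch of the printed forms): a normalized loop prints as
//   scf.forall (%i, %j) in (%ub0, %ub1) shared_outs(%o = %t) -> (tensor<...>)
// while a non-normalized loop spells out "= (lbs) to (ubs) step (steps)".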
1238 
1239 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1240  OpBuilder b(parser.getContext());
1241  auto indexType = b.getIndexType();
1242 
1243   // Parse an opening `(` followed by thread index variables followed by `)`.
1244   // TODO: when we can refer to such "induction variable"-like handles from the
1245   // declarative assembly format, we can implement the parser as a custom hook.
1246   SmallVector<OpAsmParser::Argument, 4> ivs;
1247   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1248     return failure();
1249 
1250  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1251  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1252  dynamicSteps;
1253   if (succeeded(parser.parseOptionalKeyword("in"))) {
1254     // Parse upper bounds.
1255     if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1256                               /*valueTypes=*/nullptr,
1257                               OpAsmParser::Delimiter::Paren) ||
1258         parser.resolveOperands(dynamicUbs, indexType, result.operands))
1259       return failure();
1260 
1261  unsigned numLoops = ivs.size();
1262  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1263  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1264  } else {
1265     // Parse lower bounds.
1266     if (parser.parseEqual() ||
1267         parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1268                               /*valueTypes=*/nullptr,
1269                               OpAsmParser::Delimiter::Paren) ||
1270 
1271         parser.resolveOperands(dynamicLbs, indexType, result.operands))
1272       return failure();
1273 
1274     // Parse upper bounds.
1275     if (parser.parseKeyword("to") ||
1276         parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1277                               /*valueTypes=*/nullptr,
1278                               OpAsmParser::Delimiter::Paren) ||
1279         parser.resolveOperands(dynamicUbs, indexType, result.operands))
1280       return failure();
1281 
1282     // Parse step values.
1283     if (parser.parseKeyword("step") ||
1284         parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1285                               /*valueTypes=*/nullptr,
1286                               OpAsmParser::Delimiter::Paren) ||
1287         parser.resolveOperands(dynamicSteps, indexType, result.operands))
1288       return failure();
1289  }
1290 
1291   // Parse out operands and results.
1292   SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1293   SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1294   SMLoc outOperandsLoc = parser.getCurrentLocation();
1295  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1296  if (outOperands.size() != result.types.size())
1297  return parser.emitError(outOperandsLoc,
1298  "mismatch between out operands and types");
1299  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1300  parser.parseOptionalArrowTypeList(result.types) ||
1301  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1302  result.operands))
1303  return failure();
1304  }
1305 
1306   // Parse region.
1307   SmallVector<OpAsmParser::Argument, 4> regionArgs;
1308   std::unique_ptr<Region> region = std::make_unique<Region>();
1309  for (auto &iv : ivs) {
1310  iv.type = b.getIndexType();
1311  regionArgs.push_back(iv);
1312  }
1313  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1314  auto &out = it.value();
1315  out.type = result.types[it.index()];
1316  regionArgs.push_back(out);
1317  }
1318  if (parser.parseRegion(*region, regionArgs))
1319  return failure();
1320 
1321  // Ensure terminator and move region.
1322  ForallOp::ensureTerminator(*region, b, result.location);
1323  result.addRegion(std::move(region));
1324 
1325  // Parse the optional attribute list.
1326  if (parser.parseOptionalAttrDict(result.attributes))
1327  return failure();
1328 
1329  result.addAttribute("staticLowerBound", staticLbs);
1330  result.addAttribute("staticUpperBound", staticUbs);
1331  result.addAttribute("staticStep", staticSteps);
1332   result.addAttribute("operandSegmentSizes",
1333                       b.getDenseI32ArrayAttr(
1334                           {static_cast<int32_t>(dynamicLbs.size()),
1335  static_cast<int32_t>(dynamicUbs.size()),
1336  static_cast<int32_t>(dynamicSteps.size()),
1337  static_cast<int32_t>(outOperands.size())}));
1338  return success();
1339 }
1340 
1341 // Builder that takes loop bounds.
1342 void ForallOp::build(
1343     mlir::OpBuilder &b, mlir::OperationState &result,
1344     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1345     ArrayRef<OpFoldResult> steps, ValueRange outputs,
1346  std::optional<ArrayAttr> mapping,
1347  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1348  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1349  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1350  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1351  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1352  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1353 
1354  result.addOperands(dynamicLbs);
1355  result.addOperands(dynamicUbs);
1356  result.addOperands(dynamicSteps);
1357  result.addOperands(outputs);
1358  result.addTypes(TypeRange(outputs));
1359 
1360  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1361  b.getDenseI64ArrayAttr(staticLbs));
1362  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1363  b.getDenseI64ArrayAttr(staticUbs));
1364  result.addAttribute(getStaticStepAttrName(result.name),
1365  b.getDenseI64ArrayAttr(staticSteps));
1366  result.addAttribute(
1367  "operandSegmentSizes",
1368  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1369  static_cast<int32_t>(dynamicUbs.size()),
1370  static_cast<int32_t>(dynamicSteps.size()),
1371  static_cast<int32_t>(outputs.size())}));
1372   if (mapping.has_value()) {
1373     result.addAttribute(getMappingAttrName(result.name),
1374                         mapping.value());
1375   }
1376 
1377   Region *bodyRegion = result.addRegion();
1378   OpBuilder::InsertionGuard g(b);
1379   b.createBlock(bodyRegion);
1380  Block &bodyBlock = bodyRegion->front();
1381 
1382  // Add block arguments for indices and outputs.
1383  bodyBlock.addArguments(
1384  SmallVector<Type>(lbs.size(), b.getIndexType()),
1385  SmallVector<Location>(staticLbs.size(), result.location));
1386  bodyBlock.addArguments(
1387  TypeRange(outputs),
1388  SmallVector<Location>(outputs.size(), result.location));
1389 
1390  b.setInsertionPointToStart(&bodyBlock);
1391  if (!bodyBuilderFn) {
1392  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1393  return;
1394  }
1395  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1396 }
1397 
1398 // Builder that takes loop bounds.
1399 void ForallOp::build(
1400     mlir::OpBuilder &b, mlir::OperationState &result,
1401     ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1402  std::optional<ArrayAttr> mapping,
1403  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1404  unsigned numLoops = ubs.size();
1405  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1406  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1407  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1408 }
1409 
1410 // Checks if the lbs are zeros and steps are ones.
1411 bool ForallOp::isNormalized() {
1412  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1413  return llvm::all_of(results, [&](OpFoldResult ofr) {
1414  auto intValue = getConstantIntValue(ofr);
1415  return intValue.has_value() && intValue == val;
1416  });
1417  };
1418  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1419 }
1420 
1421 InParallelOp ForallOp::getTerminator() {
1422  return cast<InParallelOp>(getBody()->getTerminator());
1423 }
1424 
1425 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1426  SmallVector<Operation *> storeOps;
1427  InParallelOp inParallelOp = getTerminator();
1428  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1429  if (auto parallelInsertSliceOp =
1430  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1431  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1432  storeOps.push_back(parallelInsertSliceOp);
1433  }
1434  }
1435  return storeOps;
1436 }
1437 
1438 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1439  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1440 }
1441 
1442 // Get lower bounds as OpFoldResult.
1443 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1444  Builder b(getOperation()->getContext());
1445  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1446 }
1447 
1448 // Get upper bounds as OpFoldResult.
1449 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1450  Builder b(getOperation()->getContext());
1451  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1452 }
1453 
1454 // Get steps as OpFoldResult.
1455 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1456  Builder b(getOperation()->getContext());
1457  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1458 }
1459 
1460 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1461   auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1462  if (!tidxArg)
1463  return ForallOp();
1464  assert(tidxArg.getOwner() && "unlinked block argument");
1465  auto *containingOp = tidxArg.getOwner()->getParentOp();
1466  return dyn_cast<ForallOp>(containingOp);
1467 }
1468 
1469 namespace {
1470 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1471 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1472   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1473 
1474  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1475  PatternRewriter &rewriter) const final {
1476  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1477  if (!forallOp)
1478  return failure();
1479  Value sharedOut =
1480  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1481  ->get();
1482  rewriter.modifyOpInPlace(
1483  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1484  return success();
1485  }
1486 };
1487 
1488 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1489 public:
1490   using OpRewritePattern<ForallOp>::OpRewritePattern;
1491 
1492  LogicalResult matchAndRewrite(ForallOp op,
1493  PatternRewriter &rewriter) const override {
1494  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1495  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1496  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1497  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1498  failed(foldDynamicIndexList(mixedUpperBound)) &&
1499  failed(foldDynamicIndexList(mixedStep)))
1500  return failure();
1501 
1502  rewriter.modifyOpInPlace(op, [&]() {
1503  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1504  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1505  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1506  staticLowerBound);
1507  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1508  op.setStaticLowerBound(staticLowerBound);
1509 
1510  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1511  staticUpperBound);
1512  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1513  op.setStaticUpperBound(staticUpperBound);
1514 
1515  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1516  op.getDynamicStepMutable().assign(dynamicStep);
1517  op.setStaticStep(staticStep);
1518 
1519  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1520  rewriter.getDenseI32ArrayAttr(
1521  {static_cast<int32_t>(dynamicLowerBound.size()),
1522  static_cast<int32_t>(dynamicUpperBound.size()),
1523  static_cast<int32_t>(dynamicStep.size()),
1524  static_cast<int32_t>(op.getNumResults())}));
1525  });
1526  return success();
1527  }
1528 };
1529 
1530 /// The following canonicalization pattern folds the iter arguments of
1531 /// scf.forall op if :-
1532 /// 1. The corresponding result has zero uses.
1533 /// 2. The iter argument is NOT being modified within the loop body
1534 /// (i.e. it has no combining store ops).
1535 ///
1536 /// Example of first case :-
1537 /// INPUT:
1538 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1539 /// {
1540 /// ...
1541 /// <SOME USE OF %arg0>
1542 /// <SOME USE OF %arg1>
1543 /// <SOME USE OF %arg2>
1544 /// ...
1545 /// scf.forall.in_parallel {
1546 /// <STORE OP WITH DESTINATION %arg1>
1547 /// <STORE OP WITH DESTINATION %arg0>
1548 /// <STORE OP WITH DESTINATION %arg2>
1549 /// }
1550 /// }
1551 /// return %res#1
1552 ///
1553 /// OUTPUT:
1554 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1555 /// {
1556 /// ...
1557 /// <SOME USE OF %a>
1558 /// <SOME USE OF %new_arg0>
1559 /// <SOME USE OF %c>
1560 /// ...
1561 /// scf.forall.in_parallel {
1562 /// <STORE OP WITH DESTINATION %new_arg0>
1563 /// }
1564 /// }
1565 /// return %res
1566 ///
1567 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1568 /// scf.forall are replaced by their corresponding operands.
1569 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1570 /// of the scf.forall outside the scf.forall.in_parallel terminator,
1571 /// this canonicalization remains valid. For more details, please refer
1572 /// to:
1573 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1574 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1575 /// handles tensor.parallel_insert_slice ops only.
1576 ///
1577 /// Example of second case :-
1578 /// INPUT:
1579 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1580 /// {
1581 /// ...
1582 /// <SOME USE OF %arg0>
1583 /// <SOME USE OF %arg1>
1584 /// ...
1585 /// scf.forall.in_parallel {
1586 /// <STORE OP WITH DESTINATION %arg1>
1587 /// }
1588 /// }
1589 /// return %res#0, %res#1
1590 ///
1591 /// OUTPUT:
1592 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1593 /// {
1594 /// ...
1595 /// <SOME USE OF %a>
1596 /// <SOME USE OF %new_arg0>
1597 /// ...
1598 /// scf.forall.in_parallel {
1599 /// <STORE OP WITH DESTINATION %new_arg0>
1600 /// }
1601 /// }
1602 /// return %a, %res
1603 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1604  using OpRewritePattern<ForallOp>::OpRewritePattern;
1605 
1606  LogicalResult matchAndRewrite(ForallOp forallOp,
1607  PatternRewriter &rewriter) const final {
1608  // Step 1: For a given i-th result of scf.forall, check the following :-
1609  // a. If it has any use.
1610  // b. If the corresponding iter argument is being modified within
1611  // the loop, i.e. has at least one store op with the iter arg as
1612  // its destination operand. For this we use
1613  // ForallOp::getCombiningOps(iter_arg).
1614  //
1615  // Based on the check we maintain the following :-
1616  // a. `resultToDelete` - i-th result of scf.forall that'll be
1617  // deleted.
1618  // b. `resultToReplace` - i-th result of the old scf.forall
1619  // whose uses will be replaced by the new scf.forall.
1620  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1621  // corresponding to the i-th result with at least one use.
1622  SetVector<OpResult> resultToDelete;
1623  SmallVector<Value> resultToReplace;
1624  SmallVector<Value> newOuts;
1625  for (OpResult result : forallOp.getResults()) {
1626  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1627  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1628  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1629  resultToDelete.insert(result);
1630  } else {
1631  resultToReplace.push_back(result);
1632  newOuts.push_back(opOperand->get());
1633  }
1634  }
1635 
1636  // Return early if all results of scf.forall have at least one use and are
1637  // being modified within the loop.
1638  if (resultToDelete.empty())
1639  return failure();
1640 
1641  // Step 2: For the i-th result, do the following :-
1642  // a. Fetch the corresponding BlockArgument.
1643  // b. Look for store ops (currently tensor.parallel_insert_slice)
1644  // with the BlockArgument as its destination operand.
1645  // c. Remove the operations fetched in b.
1646  for (OpResult result : resultToDelete) {
1647  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1648  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1649  SmallVector<Operation *> combiningOps =
1650  forallOp.getCombiningOps(blockArg);
1651  for (Operation *combiningOp : combiningOps)
1652  rewriter.eraseOp(combiningOp);
1653  }
1654 
1655  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1656  // fetched earlier.
1657  auto newForallOp = rewriter.create<scf::ForallOp>(
1658  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1659  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1660  forallOp.getMapping(),
1661  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1662 
1663  // Step 4. Merge the block of the old scf.forall into the newly created
1664  // scf.forall using the new set of arguments.
1665  Block *loopBody = forallOp.getBody();
1666  Block *newLoopBody = newForallOp.getBody();
1667  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1668  // Form initial new bbArg list with just the control operands of the new
1669  // scf.forall op.
1670  SmallVector<Value> newBlockArgs =
1671  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1672  [](BlockArgument b) -> Value { return b; });
1673  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1674  unsigned index = 0;
1675  // Take the new corresponding bbArg if the old bbArg was used as a
1676  // destination in the in_parallel op. For all other bbArgs, use the
1677  // corresponding init_arg from the old scf.forall op.
1678  for (OpResult result : forallOp.getResults()) {
1679  if (resultToDelete.count(result)) {
1680  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1681  } else {
1682  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1683  }
1684  }
1685  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1686 
1687  // Step 5. Replace the uses of the results of the old scf.forall with those
1688  // of the new scf.forall.
1689  for (auto &&[oldResult, newResult] :
1690  llvm::zip(resultToReplace, newForallOp->getResults()))
1691  rewriter.replaceAllUsesWith(oldResult, newResult);
1692 
1693  // Step 6. Replace the uses of those values that either have no use or are
1694  // not being modified within the loop with the corresponding
1695  // OpOperand.
1696  for (OpResult oldResult : resultToDelete)
1697  rewriter.replaceAllUsesWith(oldResult,
1698  forallOp.getTiedOpOperand(oldResult)->get());
1699  return success();
1700  }
1701 };
1702 
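/// Fold away scf.forall dimensions that perform at most one iteration. A
/// hypothetical sketch (bounds and names are illustrative only):
///
///   scf.forall (%i, %j) = (0, 0) to (1, 128) step (1, 1) { <use of %i, %j> }
///
/// becomes a loop over %j only, with uses of %i remapped to the constant lower
/// bound 0. If every dimension runs exactly once, the body is promoted into
/// the parent block; if any dimension runs zero times, the whole loop is
/// replaced by its shared_outs operands.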
1703 struct ForallOpSingleOrZeroIterationDimsFolder
1704  : public OpRewritePattern<ForallOp> {
1705  using OpRewritePattern<ForallOp>::OpRewritePattern;
1706 
1707  LogicalResult matchAndRewrite(ForallOp op,
1708  PatternRewriter &rewriter) const override {
1709  // Do not fold dimensions if they are mapped to processing units.
1710  if (op.getMapping().has_value() && !op.getMapping()->empty())
1711  return failure();
1712  Location loc = op.getLoc();
1713 
1714  // Compute new loop bounds that omit all single-iteration loop dimensions.
1715  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1716  newMixedSteps;
1717  IRMapping mapping;
1718  for (auto [lb, ub, step, iv] :
1719  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1720  op.getMixedStep(), op.getInductionVars())) {
1721  auto numIterations = constantTripCount(lb, ub, step);
1722  if (numIterations.has_value()) {
1723  // Remove the loop if it performs zero iterations.
1724  if (*numIterations == 0) {
1725  rewriter.replaceOp(op, op.getOutputs());
1726  return success();
1727  }
1728  // Replace the loop induction variable by the lower bound if the loop
1729  // performs a single iteration. Otherwise, copy the loop bounds.
1730  if (*numIterations == 1) {
1731  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1732  continue;
1733  }
1734  }
1735  newMixedLowerBounds.push_back(lb);
1736  newMixedUpperBounds.push_back(ub);
1737  newMixedSteps.push_back(step);
1738  }
1739 
1740  // All of the loop dimensions perform a single iteration. Inline loop body.
1741  if (newMixedLowerBounds.empty()) {
1742  promote(rewriter, op);
1743  return success();
1744  }
1745 
1746  // Exit if none of the loop dimensions perform a single iteration.
1747  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1748  return rewriter.notifyMatchFailure(
1749  op, "no dimensions have 0 or 1 iterations");
1750  }
1751 
1752  // Replace the loop by a lower-dimensional loop.
1753  ForallOp newOp;
1754  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1755  newMixedUpperBounds, newMixedSteps,
1756  op.getOutputs(), std::nullopt, nullptr);
1757  newOp.getBodyRegion().getBlocks().clear();
1758  // The new loop needs to keep all attributes from the old one, except for
1759  // "operandSegmentSizes" and static loop bound attributes which capture
1760  // the outdated information of the old iteration domain.
1761  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1762  newOp.getStaticLowerBoundAttrName(),
1763  newOp.getStaticUpperBoundAttrName(),
1764  newOp.getStaticStepAttrName()};
1765  for (const auto &namedAttr : op->getAttrs()) {
1766  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1767  continue;
1768  rewriter.modifyOpInPlace(newOp, [&]() {
1769  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1770  });
1771  }
1772  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1773  newOp.getRegion().begin(), mapping);
1774  rewriter.replaceOp(op, newOp.getResults());
1775  return success();
1776  }
1777 };
1778 
1779 /// Replace induction variables of single-trip-count dimensions with their lower bound.
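/// Illustrative sketch (hypothetical IR, pseudo-ops in angle brackets): in
///
///   scf.forall (%i, %j) = (4, 0) to (5, %n) step (1, 1) { <use of %i> }
///
/// the first dimension has a constant trip count of 1, so every use of %i is
/// replaced by the constant lower bound 4 while the loop itself is left in
/// place.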
1780 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1781  using OpRewritePattern<ForallOp>::OpRewritePattern;
1782 
1783  LogicalResult matchAndRewrite(ForallOp op,
1784  PatternRewriter &rewriter) const override {
1785  Location loc = op.getLoc();
1786  bool changed = false;
1787  for (auto [lb, ub, step, iv] :
1788  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1789  op.getMixedStep(), op.getInductionVars())) {
1790  if (iv.getUses().begin() == iv.getUses().end())
1791  continue;
1792  auto numIterations = constantTripCount(lb, ub, step);
1793  if (!numIterations.has_value() || numIterations.value() != 1) {
1794  continue;
1795  }
1796  rewriter.replaceAllUsesWith(
1797  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1798  changed = true;
1799  }
1800  return success(changed);
1801  }
1802 };
1803 
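/// Fold a tensor.cast that feeds a shared_outs operand into the loop when the
/// cast only adds static information. A hypothetical sketch (types and names
/// are illustrative only):
///
///   %cast = tensor.cast %t : tensor<32xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
///
/// becomes a loop that takes %t : tensor<32xf32> directly, with a tensor.cast
/// of the corresponding block argument inside the body and a cast of the loop
/// result back to tensor<?xf32> after the loop.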
1804 struct FoldTensorCastOfOutputIntoForallOp
1805  : public OpRewritePattern<scf::ForallOp> {
1806  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1807 
1808  struct TypeCast {
1809  Type srcType;
1810  Type dstType;
1811  };
1812 
1813  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1814  PatternRewriter &rewriter) const final {
1815  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1816  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1817  for (auto en : llvm::enumerate(newOutputTensors)) {
1818  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1819  if (!castOp)
1820  continue;
1821 
1822  // Only casts that preserve static information, i.e. that make the
1823  // loop result type "more" static than before, will be folded.
1824  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1825  castOp.getSource().getType())) {
1826  continue;
1827  }
1828 
1829  tensorCastProducers[en.index()] =
1830  TypeCast{castOp.getSource().getType(), castOp.getType()};
1831  newOutputTensors[en.index()] = castOp.getSource();
1832  }
1833 
1834  if (tensorCastProducers.empty())
1835  return failure();
1836 
1837  // Create new loop.
1838  Location loc = forallOp.getLoc();
1839  auto newForallOp = rewriter.create<ForallOp>(
1840  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1841  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1842  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1843  auto castBlockArgs =
1844  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1845  for (auto [index, cast] : tensorCastProducers) {
1846  Value &oldTypeBBArg = castBlockArgs[index];
1847  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1848  nestedLoc, cast.dstType, oldTypeBBArg);
1849  }
1850 
1851  // Move old body into new parallel loop.
1852  SmallVector<Value> ivsBlockArgs =
1853  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1854  ivsBlockArgs.append(castBlockArgs);
1855  rewriter.mergeBlocks(forallOp.getBody(),
1856  bbArgs.front().getParentBlock(), ivsBlockArgs);
1857  });
1858 
1859  // After `mergeBlocks` happened, the destinations in the terminator were
1860  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1861  // destinations have to be updated to point to the output bbArgs directly.
1862  auto terminator = newForallOp.getTerminator();
1863  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1864  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1865  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1866  insertSliceOp.getDestMutable().assign(outputBlockArg);
1867  }
1868 
1869  // Cast results back to the original types.
1870  rewriter.setInsertionPointAfter(newForallOp);
1871  SmallVector<Value> castResults = newForallOp.getResults();
1872  for (auto &item : tensorCastProducers) {
1873  Value &oldTypeResult = castResults[item.first];
1874  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1875  oldTypeResult);
1876  }
1877  rewriter.replaceOp(forallOp, castResults);
1878  return success();
1879  }
1880 };
1881 
1882 } // namespace
1883 
1884 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1885  MLIRContext *context) {
1886  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1887  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1888  ForallOpSingleOrZeroIterationDimsFolder,
1889  ForallOpReplaceConstantInductionVar>(context);
1890 }
1891 
1892 /// Given the region at `index`, or the parent operation if `index` is None,
1893 /// return the successor regions. These are the regions that may be selected
1894 /// during the flow of control. `operands` is a set of optional attributes that
1895 /// correspond to a constant value for each operand, or null if that operand is
1896 /// not a constant.
1897 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1898                                    SmallVectorImpl<RegionSuccessor> &regions) {
1899  // Both the operation itself and the region may be branching into the body or
1900  // back into the operation itself. It is possible for the loop not to enter the
1901  // body.
1902  regions.push_back(RegionSuccessor(&getRegion()));
1903  regions.push_back(RegionSuccessor());
1904 }
1905 
1906 //===----------------------------------------------------------------------===//
1907 // InParallelOp
1908 //===----------------------------------------------------------------------===//
1909 
1910 // Build an InParallelOp with mixed static and dynamic entries.
1911 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1912  OpBuilder::InsertionGuard g(b);
1913  Region *bodyRegion = result.addRegion();
1914  b.createBlock(bodyRegion);
1915 }
1916 
1917 LogicalResult InParallelOp::verify() {
1918  scf::ForallOp forallOp =
1919  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1920  if (!forallOp)
1921  return this->emitOpError("expected forall op parent");
1922 
1923  // TODO: InParallelOpInterface.
1924  for (Operation &op : getRegion().front().getOperations()) {
1925  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1926  return this->emitOpError("expected only ")
1927  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1928  }
1929 
1930  // Verify that inserts are into out block arguments.
1931  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1932  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1933  if (!llvm::is_contained(regionOutArgs, dest))
1934  return op.emitOpError("may only insert into an output block argument");
1935  }
1936  return success();
1937 }
1938 
1939 void InParallelOp::print(OpAsmPrinter &p) {
1940  p << " ";
1941  p.printRegion(getRegion(),
1942  /*printEntryBlockArgs=*/false,
1943  /*printBlockTerminators=*/false);
1944  p.printOptionalAttrDict(getOperation()->getAttrs());
1945 }
1946 
1947 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1948  auto &builder = parser.getBuilder();
1949 
1950  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1951  std::unique_ptr<Region> region = std::make_unique<Region>();
1952  if (parser.parseRegion(*region, regionOperands))
1953  return failure();
1954 
1955  if (region->empty())
1956  OpBuilder(builder.getContext()).createBlock(region.get());
1957  result.addRegion(std::move(region));
1958 
1959  // Parse the optional attribute list.
1960  if (parser.parseOptionalAttrDict(result.attributes))
1961  return failure();
1962  return success();
1963 }
1964 
1965 OpResult InParallelOp::getParentResult(int64_t idx) {
1966  return getOperation()->getParentOp()->getResult(idx);
1967 }
1968 
1969 SmallVector<BlockArgument> InParallelOp::getDests() {
1970  return llvm::to_vector<4>(
1971  llvm::map_range(getYieldingOps(), [](Operation &op) {
1972  // Add new ops here as needed.
1973  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1974  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1975  }));
1976 }
1977 
1978 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1979  return getRegion().front().getOperations();
1980 }
1981 
1982 //===----------------------------------------------------------------------===//
1983 // IfOp
1984 //===----------------------------------------------------------------------===//
1985 
1986 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1987  assert(a && "expected non-empty operation");
1988  assert(b && "expected non-empty operation");
1989 
1990  IfOp ifOp = a->getParentOfType<IfOp>();
1991  while (ifOp) {
1992  // Check if b is inside ifOp. (We already know that a is.)
1993  if (ifOp->isProperAncestor(b))
1994  // b is contained in ifOp. a and b are in mutually exclusive branches if
1995  // they are in different blocks of ifOp.
1996  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1997  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1998  // Check next enclosing IfOp.
1999  ifOp = ifOp->getParentOfType<IfOp>();
2000  }
2001 
2002  // Could not find a common IfOp among a's and b's ancestors.
2003  return false;
2004 }
2005 
2006 LogicalResult
2007 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2008  IfOp::Adaptor adaptor,
2009  SmallVectorImpl<Type> &inferredReturnTypes) {
2010  if (adaptor.getRegions().empty())
2011  return failure();
2012  Region *r = &adaptor.getThenRegion();
2013  if (r->empty())
2014  return failure();
2015  Block &b = r->front();
2016  if (b.empty())
2017  return failure();
2018  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2019  if (!yieldOp)
2020  return failure();
2021  TypeRange types = yieldOp.getOperandTypes();
2022  llvm::append_range(inferredReturnTypes, types);
2023  return success();
2024 }
2025 
2026 void IfOp::build(OpBuilder &builder, OperationState &result,
2027  TypeRange resultTypes, Value cond) {
2028  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2029  /*addElseBlock=*/false);
2030 }
2031 
2032 void IfOp::build(OpBuilder &builder, OperationState &result,
2033  TypeRange resultTypes, Value cond, bool addThenBlock,
2034  bool addElseBlock) {
2035  assert((!addElseBlock || addThenBlock) &&
2036  "must not create else block w/o then block");
2037  result.addTypes(resultTypes);
2038  result.addOperands(cond);
2039 
2040  // Add regions and blocks.
2041  OpBuilder::InsertionGuard guard(builder);
2042  Region *thenRegion = result.addRegion();
2043  if (addThenBlock)
2044  builder.createBlock(thenRegion);
2045  Region *elseRegion = result.addRegion();
2046  if (addElseBlock)
2047  builder.createBlock(elseRegion);
2048 }
2049 
2050 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2051  bool withElseRegion) {
2052  build(builder, result, TypeRange{}, cond, withElseRegion);
2053 }
2054 
2055 void IfOp::build(OpBuilder &builder, OperationState &result,
2056  TypeRange resultTypes, Value cond, bool withElseRegion) {
2057  result.addTypes(resultTypes);
2058  result.addOperands(cond);
2059 
2060  // Build then region.
2061  OpBuilder::InsertionGuard guard(builder);
2062  Region *thenRegion = result.addRegion();
2063  builder.createBlock(thenRegion);
2064  if (resultTypes.empty())
2065  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2066 
2067  // Build else region.
2068  Region *elseRegion = result.addRegion();
2069  if (withElseRegion) {
2070  builder.createBlock(elseRegion);
2071  if (resultTypes.empty())
2072  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2073  }
2074 }
2075 
2076 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2077  function_ref<void(OpBuilder &, Location)> thenBuilder,
2078  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2079  assert(thenBuilder && "the builder callback for 'then' must be present");
2080  result.addOperands(cond);
2081 
2082  // Build then region.
2083  OpBuilder::InsertionGuard guard(builder);
2084  Region *thenRegion = result.addRegion();
2085  builder.createBlock(thenRegion);
2086  thenBuilder(builder, result.location);
2087 
2088  // Build else region.
2089  Region *elseRegion = result.addRegion();
2090  if (elseBuilder) {
2091  builder.createBlock(elseRegion);
2092  elseBuilder(builder, result.location);
2093  }
2094 
2095  // Infer result types.
2096  SmallVector<Type> inferredReturnTypes;
2097  MLIRContext *ctx = builder.getContext();
2098  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2099  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2100  /*properties=*/nullptr, result.regions,
2101  inferredReturnTypes))) {
2102  result.addTypes(inferredReturnTypes);
2103  }
2104 }
2105 
2106 LogicalResult IfOp::verify() {
2107  if (getNumResults() != 0 && getElseRegion().empty())
2108  return emitOpError("must have an else block if defining values");
2109  return success();
2110 }
2111 
2112 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2113  // Create the regions for 'then'.
2114  result.regions.reserve(2);
2115  Region *thenRegion = result.addRegion();
2116  Region *elseRegion = result.addRegion();
2117 
2118  auto &builder = parser.getBuilder();
2119  OpAsmParser::UnresolvedOperand cond;
2120  Type i1Type = builder.getIntegerType(1);
2121  if (parser.parseOperand(cond) ||
2122  parser.resolveOperand(cond, i1Type, result.operands))
2123  return failure();
2124  // Parse optional results type list.
2125  if (parser.parseOptionalArrowTypeList(result.types))
2126  return failure();
2127  // Parse the 'then' region.
2128  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2129  return failure();
2130  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2131 
2132  // If we find an 'else' keyword then parse the 'else' region.
2133  if (!parser.parseOptionalKeyword("else")) {
2134  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2135  return failure();
2136  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2137  }
2138 
2139  // Parse the optional attribute list.
2140  if (parser.parseOptionalAttrDict(result.attributes))
2141  return failure();
2142  return success();
2143 }
2144 
2145 void IfOp::print(OpAsmPrinter &p) {
2146  bool printBlockTerminators = false;
2147 
2148  p << " " << getCondition();
2149  if (!getResults().empty()) {
2150  p << " -> (" << getResultTypes() << ")";
2151  // Print yield explicitly if the op defines values.
2152  printBlockTerminators = true;
2153  }
2154  p << ' ';
2155  p.printRegion(getThenRegion(),
2156  /*printEntryBlockArgs=*/false,
2157  /*printBlockTerminators=*/printBlockTerminators);
2158 
2159  // Print the 'else' region if it exists and has a block.
2160  auto &elseRegion = getElseRegion();
2161  if (!elseRegion.empty()) {
2162  p << " else ";
2163  p.printRegion(elseRegion,
2164  /*printEntryBlockArgs=*/false,
2165  /*printBlockTerminators=*/printBlockTerminators);
2166  }
2167 
2168  p.printOptionalAttrDict((*this)->getAttrs());
2169 }
2170 
2171 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2172                                SmallVectorImpl<RegionSuccessor> &regions) {
2173  // The `then` and the `else` region branch back to the parent operation.
2174  if (!point.isParent()) {
2175  regions.push_back(RegionSuccessor(getResults()));
2176  return;
2177  }
2178 
2179  regions.push_back(RegionSuccessor(&getThenRegion()));
2180 
2181  // Don't consider the else region if it is empty.
2182  Region *elseRegion = &this->getElseRegion();
2183  if (elseRegion->empty())
2184  regions.push_back(RegionSuccessor());
2185  else
2186  regions.push_back(RegionSuccessor(elseRegion));
2187 }
2188 
2189 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2190                                     SmallVectorImpl<RegionSuccessor> &regions) {
2191  FoldAdaptor adaptor(operands, *this);
2192  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2193  if (!boolAttr || boolAttr.getValue())
2194  regions.emplace_back(&getThenRegion());
2195 
2196  // If the else region is empty, execution continues after the parent op.
2197  if (!boolAttr || !boolAttr.getValue()) {
2198  if (!getElseRegion().empty())
2199  regions.emplace_back(&getElseRegion());
2200  else
2201  regions.emplace_back(getResults());
2202  }
2203 }
2204 
2205 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2206  SmallVectorImpl<OpFoldResult> &results) {
2207  // if (!c) then A() else B() -> if c then B() else A()
2208  if (getElseRegion().empty())
2209  return failure();
2210 
2211  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2212  if (!xorStmt)
2213  return failure();
2214 
2215  if (!matchPattern(xorStmt.getRhs(), m_One()))
2216  return failure();
2217 
2218  getConditionMutable().assign(xorStmt.getLhs());
2219  Block *thenBlock = &getThenRegion().front();
2220  // It would be nicer to use iplist::swap, but that has no implemented
2221  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2222  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2223  getElseRegion().getBlocks());
2224  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2225  getThenRegion().getBlocks(), thenBlock);
2226  return success();
2227 }
2228 
2229 void IfOp::getRegionInvocationBounds(
2230  ArrayRef<Attribute> operands,
2231  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2232  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2233  // If the condition is known, then one region is known to be executed once
2234  // and the other zero times.
2235  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2236  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2237  } else {
2238  // Non-constant condition. Each region may be executed 0 or 1 times.
2239  invocationBounds.assign(2, {0, 1});
2240  }
2241 }
2242 
2243 namespace {
2244 // Pattern to remove unused IfOp results.
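// For illustration (hypothetical IR): if only %r#0 of
//   %r:2 = scf.if %cond -> (i32, i32) { ... } else { ... }
// is used, the op is rebuilt as a single-result scf.if and each branch's
// scf.yield keeps only the operand feeding the used result.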
2245 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2246  using OpRewritePattern<IfOp>::OpRewritePattern;
2247 
2248  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2249  PatternRewriter &rewriter) const {
2250  // Move all operations to the destination block.
2251  rewriter.mergeBlocks(source, dest);
2252  // Replace the yield op by one that returns only the used values.
2253  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2254  SmallVector<Value, 4> usedOperands;
2255  llvm::transform(usedResults, std::back_inserter(usedOperands),
2256  [&](OpResult result) {
2257  return yieldOp.getOperand(result.getResultNumber());
2258  });
2259  rewriter.modifyOpInPlace(yieldOp,
2260  [&]() { yieldOp->setOperands(usedOperands); });
2261  }
2262 
2263  LogicalResult matchAndRewrite(IfOp op,
2264  PatternRewriter &rewriter) const override {
2265  // Compute the list of used results.
2266  SmallVector<OpResult, 4> usedResults;
2267  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2268  [](OpResult result) { return !result.use_empty(); });
2269 
2270  // Replace the operation if only a subset of its results have uses.
2271  if (usedResults.size() == op.getNumResults())
2272  return failure();
2273 
2274  // Compute the result types of the replacement operation.
2275  SmallVector<Type, 4> newTypes;
2276  llvm::transform(usedResults, std::back_inserter(newTypes),
2277  [](OpResult result) { return result.getType(); });
2278 
2279  // Create a replacement operation with empty then and else regions.
2280  auto newOp =
2281  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2282  rewriter.createBlock(&newOp.getThenRegion());
2283  rewriter.createBlock(&newOp.getElseRegion());
2284 
2285  // Move the bodies and replace the terminators (note there is a then and
2286  // an else region since the operation returns results).
2287  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2288  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2289 
2290  // Replace the operation by the new one.
2291  SmallVector<Value, 4> repResults(op.getNumResults());
2292  for (const auto &en : llvm::enumerate(usedResults))
2293  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2294  rewriter.replaceOp(op, repResults);
2295  return success();
2296  }
2297 };
2298 
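/// Fold an scf.if whose condition is a constant. As a hypothetical sketch:
///
///   %true = arith.constant true
///   scf.if %true { <then ops> } else { <else ops> }
///
/// is replaced by the inlined then region (or the else region for a constant
/// false condition; the op is simply erased when that region is absent).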
2299 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2300  using OpRewritePattern<IfOp>::OpRewritePattern;
2301 
2302  LogicalResult matchAndRewrite(IfOp op,
2303  PatternRewriter &rewriter) const override {
2304  BoolAttr condition;
2305  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2306  return failure();
2307 
2308  if (condition.getValue())
2309  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2310  else if (!op.getElseRegion().empty())
2311  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2312  else
2313  rewriter.eraseOp(op);
2314 
2315  return success();
2316  }
2317 };
2318 
2319 /// Hoist any yielded results whose operands are defined outside
2320 /// the if, to a select instruction.
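/// Illustrative sketch (hypothetical IR): when %a and %b are both defined
/// above the scf.if,
///
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
///
/// becomes
///
///   %r = arith.select %cond, %a, %b : i32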
2321 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2322  using OpRewritePattern<IfOp>::OpRewritePattern;
2323 
2324  LogicalResult matchAndRewrite(IfOp op,
2325  PatternRewriter &rewriter) const override {
2326  if (op->getNumResults() == 0)
2327  return failure();
2328 
2329  auto cond = op.getCondition();
2330  auto thenYieldArgs = op.thenYield().getOperands();
2331  auto elseYieldArgs = op.elseYield().getOperands();
2332 
2333  SmallVector<Type> nonHoistable;
2334  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2335  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2336  &op.getElseRegion() == falseVal.getParentRegion())
2337  nonHoistable.push_back(trueVal.getType());
2338  }
2339  // Early exit if there aren't any yielded values we can
2340  // hoist outside the if.
2341  if (nonHoistable.size() == op->getNumResults())
2342  return failure();
2343 
2344  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2345  /*withElseRegion=*/false);
2346  if (replacement.thenBlock())
2347  rewriter.eraseBlock(replacement.thenBlock());
2348  replacement.getThenRegion().takeBody(op.getThenRegion());
2349  replacement.getElseRegion().takeBody(op.getElseRegion());
2350 
2351  SmallVector<Value> results(op->getNumResults());
2352  assert(thenYieldArgs.size() == results.size());
2353  assert(elseYieldArgs.size() == results.size());
2354 
2355  SmallVector<Value> trueYields;
2356  SmallVector<Value> falseYields;
2357  rewriter.setInsertionPoint(replacement);
2358  for (const auto &it :
2359  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2360  Value trueVal = std::get<0>(it.value());
2361  Value falseVal = std::get<1>(it.value());
2362  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2363  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2364  results[it.index()] = replacement.getResult(trueYields.size());
2365  trueYields.push_back(trueVal);
2366  falseYields.push_back(falseVal);
2367  } else if (trueVal == falseVal)
2368  results[it.index()] = trueVal;
2369  else
2370  results[it.index()] = rewriter.create<arith::SelectOp>(
2371  op.getLoc(), cond, trueVal, falseVal);
2372  }
2373 
2374  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2375  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2376 
2377  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2378  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2379 
2380  rewriter.replaceOp(op, results);
2381  return success();
2382  }
2383 };
2384 
2385 /// Allow the true region of an if to assume the condition is true
2386 /// and vice versa. For example:
2387 ///
2388 /// scf.if %cmp {
2389 /// print(%cmp)
2390 /// }
2391 ///
2392 /// becomes
2393 ///
2394 /// scf.if %cmp {
2395 /// print(true)
2396 /// }
2397 ///
2398 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2399  using OpRewritePattern<IfOp>::OpRewritePattern;
2400 
2401  LogicalResult matchAndRewrite(IfOp op,
2402  PatternRewriter &rewriter) const override {
2403  // Early exit if the condition is constant since replacing a constant
2404  // in the body with another constant isn't a simplification.
2405  if (matchPattern(op.getCondition(), m_Constant()))
2406  return failure();
2407 
2408  bool changed = false;
2409  mlir::Type i1Ty = rewriter.getI1Type();
2410 
2411  // These variables serve to prevent creating duplicate constants
2412  // and hold constant true or false values.
2413  Value constantTrue = nullptr;
2414  Value constantFalse = nullptr;
2415 
2416  for (OpOperand &use :
2417  llvm::make_early_inc_range(op.getCondition().getUses())) {
2418  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2419  changed = true;
2420 
2421  if (!constantTrue)
2422  constantTrue = rewriter.create<arith::ConstantOp>(
2423  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2424 
2425  rewriter.modifyOpInPlace(use.getOwner(),
2426  [&]() { use.set(constantTrue); });
2427  } else if (op.getElseRegion().isAncestor(
2428  use.getOwner()->getParentRegion())) {
2429  changed = true;
2430 
2431  if (!constantFalse)
2432  constantFalse = rewriter.create<arith::ConstantOp>(
2433  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2434 
2435  rewriter.modifyOpInPlace(use.getOwner(),
2436  [&]() { use.set(constantFalse); });
2437  }
2438  }
2439 
2440  return success(changed);
2441  }
2442 };
2443 
2444 /// Remove any statements from an if that are equivalent to the condition
2445 /// or its negation. For example:
2446 ///
2447 /// %res:2 = scf.if %cmp {
2448 /// yield something(), true
2449 /// } else {
2450 /// yield something2(), false
2451 /// }
2452 /// print(%res#1)
2453 ///
2454 /// becomes
2455 /// %res = scf.if %cmp {
2456 /// yield something()
2457 /// } else {
2458 /// yield something2()
2459 /// }
2460 /// print(%cmp)
2461 ///
2462 /// Additionally if both branches yield the same value, replace all uses
2463 /// of the result with the yielded value.
2464 ///
2465 /// %res:2 = scf.if %cmp {
2466 /// yield something(), %arg1
2467 /// } else {
2468 /// yield something2(), %arg1
2469 /// }
2470 /// print(%res#1)
2471 ///
2472 /// becomes
2473 /// %res = scf.if %cmp {
2474 /// yield something()
2475 /// } else {
2476 /// yield something2()
2477 /// }
2478 /// print(%arg1)
2479 ///
2480 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2481  using OpRewritePattern<IfOp>::OpRewritePattern;
2482 
2483  LogicalResult matchAndRewrite(IfOp op,
2484  PatternRewriter &rewriter) const override {
2485  // Early exit if there are no results that could be replaced.
2486  if (op.getNumResults() == 0)
2487  return failure();
2488 
2489  auto trueYield =
2490  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2491  auto falseYield =
2492  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2493 
2494  rewriter.setInsertionPoint(op->getBlock(),
2495  op.getOperation()->getIterator());
2496  bool changed = false;
2497  Type i1Ty = rewriter.getI1Type();
2498  for (auto [trueResult, falseResult, opResult] :
2499  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2500  op.getResults())) {
2501  if (trueResult == falseResult) {
2502  if (!opResult.use_empty()) {
2503  opResult.replaceAllUsesWith(trueResult);
2504  changed = true;
2505  }
2506  continue;
2507  }
2508 
2509  BoolAttr trueYield, falseYield;
2510  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2511  !matchPattern(falseResult, m_Constant(&falseYield)))
2512  continue;
2513 
2514  bool trueVal = trueYield.getValue();
2515  bool falseVal = falseYield.getValue();
2516  if (!trueVal && falseVal) {
2517  if (!opResult.use_empty()) {
2518  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2519  Value notCond = rewriter.create<arith::XOrIOp>(
2520  op.getLoc(), op.getCondition(),
2521  constDialect
2522  ->materializeConstant(rewriter,
2523  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2524  op.getLoc())
2525  ->getResult(0));
2526  opResult.replaceAllUsesWith(notCond);
2527  changed = true;
2528  }
2529  }
2530  if (trueVal && !falseVal) {
2531  if (!opResult.use_empty()) {
2532  opResult.replaceAllUsesWith(op.getCondition());
2533  changed = true;
2534  }
2535  }
2536  }
2537  return success(changed);
2538  }
2539 };
2540 
2541 /// Merge any consecutive scf.if's with the same condition.
2542 ///
2543 /// scf.if %cond {
2544 /// firstCodeTrue();...
2545 /// } else {
2546 /// firstCodeFalse();...
2547 /// }
2548 /// %res = scf.if %cond {
2549 /// secondCodeTrue();...
2550 /// } else {
2551 /// secondCodeFalse();...
2552 /// }
2553 ///
2554 /// becomes
2555 /// %res = scf.if %cond {
2556 /// firstCodeTrue();...
2557 /// secondCodeTrue();...
2558 /// } else {
2559 /// firstCodeFalse();...
2560 /// secondCodeFalse();...
2561 /// }
2562 struct CombineIfs : public OpRewritePattern<IfOp> {
2563  using OpRewritePattern<IfOp>::OpRewritePattern;
2564 
2565  LogicalResult matchAndRewrite(IfOp nextIf,
2566  PatternRewriter &rewriter) const override {
2567  Block *parent = nextIf->getBlock();
2568  if (nextIf == &parent->front())
2569  return failure();
2570 
2571  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2572  if (!prevIf)
2573  return failure();
2574 
2575  // Determine the logical then/else blocks when prevIf's
2576  // condition is used. Null means the block does not exist
2577  // in that case (e.g. empty else). If neither of these
2578  // are set, the two conditions cannot be compared.
2579  Block *nextThen = nullptr;
2580  Block *nextElse = nullptr;
2581  if (nextIf.getCondition() == prevIf.getCondition()) {
2582  nextThen = nextIf.thenBlock();
2583  if (!nextIf.getElseRegion().empty())
2584  nextElse = nextIf.elseBlock();
2585  }
2586  if (arith::XOrIOp notv =
2587  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2588  if (notv.getLhs() == prevIf.getCondition() &&
2589  matchPattern(notv.getRhs(), m_One())) {
2590  nextElse = nextIf.thenBlock();
2591  if (!nextIf.getElseRegion().empty())
2592  nextThen = nextIf.elseBlock();
2593  }
2594  }
2595  if (arith::XOrIOp notv =
2596  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2597  if (notv.getLhs() == nextIf.getCondition() &&
2598  matchPattern(notv.getRhs(), m_One())) {
2599  nextElse = nextIf.thenBlock();
2600  if (!nextIf.getElseRegion().empty())
2601  nextThen = nextIf.elseBlock();
2602  }
2603  }
2604 
2605  if (!nextThen && !nextElse)
2606  return failure();
2607 
2608  SmallVector<Value> prevElseYielded;
2609  if (!prevIf.getElseRegion().empty())
2610  prevElseYielded = prevIf.elseYield().getOperands();
2611  // Replace all uses of return values of prevIf within nextIf with the
2612  // corresponding yields.
2613  for (auto it : llvm::zip(prevIf.getResults(),
2614  prevIf.thenYield().getOperands(), prevElseYielded))
2615  for (OpOperand &use :
2616  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2617  if (nextThen && nextThen->getParent()->isAncestor(
2618  use.getOwner()->getParentRegion())) {
2619  rewriter.startOpModification(use.getOwner());
2620  use.set(std::get<1>(it));
2621  rewriter.finalizeOpModification(use.getOwner());
2622  } else if (nextElse && nextElse->getParent()->isAncestor(
2623  use.getOwner()->getParentRegion())) {
2624  rewriter.startOpModification(use.getOwner());
2625  use.set(std::get<2>(it));
2626  rewriter.finalizeOpModification(use.getOwner());
2627  }
2628  }
2629 
2630  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2631  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2632 
2633  IfOp combinedIf = rewriter.create<IfOp>(
2634  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2635  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2636 
2637  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2638  combinedIf.getThenRegion(),
2639  combinedIf.getThenRegion().begin());
2640 
2641  if (nextThen) {
2642  YieldOp thenYield = combinedIf.thenYield();
2643  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2644  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2645  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2646 
2647  SmallVector<Value> mergedYields(thenYield.getOperands());
2648  llvm::append_range(mergedYields, thenYield2.getOperands());
2649  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2650  rewriter.eraseOp(thenYield);
2651  rewriter.eraseOp(thenYield2);
2652  }
2653 
2654  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2655  combinedIf.getElseRegion(),
2656  combinedIf.getElseRegion().begin());
2657 
2658  if (nextElse) {
2659  if (combinedIf.getElseRegion().empty()) {
2660  rewriter.inlineRegionBefore(*nextElse->getParent(),
2661  combinedIf.getElseRegion(),
2662  combinedIf.getElseRegion().begin());
2663  } else {
2664  YieldOp elseYield = combinedIf.elseYield();
2665  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2666  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2667 
2668  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2669 
2670  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2671  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2672 
2673  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2674  rewriter.eraseOp(elseYield);
2675  rewriter.eraseOp(elseYield2);
2676  }
2677  }
2678 
2679  SmallVector<Value> prevValues;
2680  SmallVector<Value> nextValues;
2681  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2682  if (pair.index() < prevIf.getNumResults())
2683  prevValues.push_back(pair.value());
2684  else
2685  nextValues.push_back(pair.value());
2686  }
2687  rewriter.replaceOp(prevIf, prevValues);
2688  rewriter.replaceOp(nextIf, nextValues);
2689  return success();
2690  }
2691 };
2692 
2693 /// Pattern to remove an empty else branch.
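/// For example (hypothetical IR), a zero-result op such as
///   scf.if %cond { <then ops> } else { }
/// is rebuilt without the empty else region.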
2694 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2695  using OpRewritePattern<IfOp>::OpRewritePattern;
2696 
2697  LogicalResult matchAndRewrite(IfOp ifOp,
2698  PatternRewriter &rewriter) const override {
2699  // Cannot remove else region when there are operation results.
2700  if (ifOp.getNumResults())
2701  return failure();
2702  Block *elseBlock = ifOp.elseBlock();
2703  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2704  return failure();
2705  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2706  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2707  newIfOp.getThenRegion().begin());
2708  rewriter.eraseOp(ifOp);
2709  return success();
2710  }
2711 };
2712 
2713 /// Convert nested `if`s into `arith.andi` + single `if`.
2714 ///
2715 /// scf.if %arg0 {
2716 /// scf.if %arg1 {
2717 /// ...
2718 /// scf.yield
2719 /// }
2720 /// scf.yield
2721 /// }
2722 /// becomes
2723 ///
2724 /// %0 = arith.andi %arg0, %arg1
2725 /// scf.if %0 {
2726 /// ...
2727 /// scf.yield
2728 /// }
2729 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2730  using OpRewritePattern<IfOp>::OpRewritePattern;
2731 
2732  LogicalResult matchAndRewrite(IfOp op,
2733  PatternRewriter &rewriter) const override {
2734  auto nestedOps = op.thenBlock()->without_terminator();
2735  // Nested `if` must be the only op in block.
2736  if (!llvm::hasSingleElement(nestedOps))
2737  return failure();
2738 
2739  // If there is an else block, it can only yield.
2740  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2741  return failure();
2742 
2743  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2744  if (!nestedIf)
2745  return failure();
2746 
2747  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2748  return failure();
2749 
2750  SmallVector<Value> thenYield(op.thenYield().getOperands());
2751  SmallVector<Value> elseYield;
2752  if (op.elseBlock())
2753  llvm::append_range(elseYield, op.elseYield().getOperands());
2754 
2755  // A list of indices for which we should upgrade the value yielded
2756  // in the else to a select.
2757  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2758 
2759  // If the outer scf.if yields a value produced by the inner scf.if,
2760  // only permit combining if the value yielded when the condition
2761  // is false in the outer scf.if is the same value yielded when the
2762  // inner scf.if condition is false.
2763  // Note that the array access to elseYield will not go out of bounds
2764  // since it must have the same length as thenYield, since they both
2765  // come from the same scf.if.
2766  for (const auto &tup : llvm::enumerate(thenYield)) {
2767  if (tup.value().getDefiningOp() == nestedIf) {
2768  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2769  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2770  elseYield[tup.index()]) {
2771  return failure();
2772  }
2773  // If the correctness test passes, we will yield
2774  // corresponding value from the inner scf.if
2775  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2776  continue;
2777  }
2778 
2779  // Otherwise, we need to ensure the else block of the combined
2780  // condition still returns the same value when the outer condition is
2781  // true and the inner condition is false. This can be accomplished if
2782  // the then value is defined outside the outer scf.if and we replace the
2783  // value with a select that considers just the outer condition. Since
2784  // the else region contains just the yield, its yielded value is
2785  // defined outside the scf.if, by definition.
2786 
2787  // If the then value is defined within the scf.if, bail.
2788  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2789  return failure();
2790  }
2791  elseYieldsToUpgradeToSelect.push_back(tup.index());
2792  }
2793 
2794  Location loc = op.getLoc();
2795  Value newCondition = rewriter.create<arith::AndIOp>(
2796  loc, op.getCondition(), nestedIf.getCondition());
2797  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2798  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2799 
2800  SmallVector<Value> results;
2801  llvm::append_range(results, newIf.getResults());
2802  rewriter.setInsertionPoint(newIf);
2803 
2804  for (auto idx : elseYieldsToUpgradeToSelect)
2805  results[idx] = rewriter.create<arith::SelectOp>(
2806  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2807 
2808  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2809  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2810  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2811  if (!elseYield.empty()) {
2812  rewriter.createBlock(&newIf.getElseRegion());
2813  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2814  rewriter.create<YieldOp>(loc, elseYield);
2815  }
2816  rewriter.replaceOp(op, results);
2817  return success();
2818  }
2819 };
2820 
2821 } // namespace
2822 
2823 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2824  MLIRContext *context) {
2825  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2826  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2827  RemoveStaticCondition, RemoveUnusedResults,
2828  ReplaceIfYieldWithConditionOrValue>(context);
2829 }
2830 
2831 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2832 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2833 Block *IfOp::elseBlock() {
2834  Region &r = getElseRegion();
2835  if (r.empty())
2836  return nullptr;
2837  return &r.back();
2838 }
2839 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2840 
2841 //===----------------------------------------------------------------------===//
2842 // ParallelOp
2843 //===----------------------------------------------------------------------===//
2844 
2845 void ParallelOp::build(
2846  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2847  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2848     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2849         bodyBuilderFn) {
2850  result.addOperands(lowerBounds);
2851  result.addOperands(upperBounds);
2852  result.addOperands(steps);
2853  result.addOperands(initVals);
2854  result.addAttribute(
2855  ParallelOp::getOperandSegmentSizeAttr(),
2856  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2857  static_cast<int32_t>(upperBounds.size()),
2858  static_cast<int32_t>(steps.size()),
2859  static_cast<int32_t>(initVals.size())}));
2860  result.addTypes(initVals.getTypes());
2861 
2862  OpBuilder::InsertionGuard guard(builder);
2863  unsigned numIVs = steps.size();
2864  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2865  SmallVector<Location, 8> argLocs(numIVs, result.location);
2866  Region *bodyRegion = result.addRegion();
2867  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2868 
2869  if (bodyBuilderFn) {
2870  builder.setInsertionPointToStart(bodyBlock);
2871  bodyBuilderFn(builder, result.location,
2872  bodyBlock->getArguments().take_front(numIVs),
2873  bodyBlock->getArguments().drop_front(numIVs));
2874  }
2875  // Add terminator only if there are no reductions.
2876  if (initVals.empty())
2877  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2878 }
2879 
2880 void ParallelOp::build(
2881  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2882  ValueRange upperBounds, ValueRange steps,
2883  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2884  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2885  // we don't capture a reference to a temporary by constructing the lambda at
2886  // function level.
2887  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2888  Location nestedLoc, ValueRange ivs,
2889  ValueRange) {
2890  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2891  };
2892  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2893  if (bodyBuilderFn)
2894  wrapper = wrappedBuilderFn;
2895 
2896  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2897  wrapper);
2898 }
2899 
2900 LogicalResult ParallelOp::verify() {
2901  // Check that there is at least one value in lowerBound, upperBound and step.
2902  // It is sufficient to test only step, because it is already ensured that the
2903  // numbers of elements in lowerBound, upperBound and step are the same.
2904  Operation::operand_range stepValues = getStep();
2905  if (stepValues.empty())
2906  return emitOpError(
2907  "needs at least one tuple element for lowerBound, upperBound and step");
2908 
2909  // Check whether all constant step values are positive.
2910  for (Value stepValue : stepValues)
2911  if (auto cst = getConstantIntValue(stepValue))
2912  if (*cst <= 0)
2913  return emitOpError("constant step operand must be positive");
2914 
2915  // Check that the body defines the same number of block arguments as the
2916  // number of tuple elements in step.
2917  Block *body = getBody();
2918  if (body->getNumArguments() != stepValues.size())
2919  return emitOpError() << "expects the same number of induction variables: "
2920  << body->getNumArguments()
2921  << " as bound and step values: " << stepValues.size();
2922  for (auto arg : body->getArguments())
2923  if (!arg.getType().isIndex())
2924  return emitOpError(
2925  "expects arguments for the induction variable to be of index type");
2926 
2927  // Check that the terminator is an scf.reduce op.
2928  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2929  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2930  if (!reduceOp)
2931  return failure();
2932 
2933  // Check that the number of results is the same as the number of reductions.
2934  auto resultsSize = getResults().size();
2935  auto reductionsSize = reduceOp.getReductions().size();
2936  auto initValsSize = getInitVals().size();
2937  if (resultsSize != reductionsSize)
2938  return emitOpError() << "expects number of results: " << resultsSize
2939  << " to be the same as number of reductions: "
2940  << reductionsSize;
2941  if (resultsSize != initValsSize)
2942  return emitOpError() << "expects number of results: " << resultsSize
2943  << " to be the same as number of initial values: "
2944  << initValsSize;
2945 
2946  // Check that the types of the results and reductions are the same.
2947  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2948  auto resultType = getOperation()->getResult(i).getType();
2949  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2950  if (resultType != reductionOperandType)
2951  return reduceOp.emitOpError()
2952  << "expects type of " << i
2953  << "-th reduction operand: " << reductionOperandType
2954  << " to be the same as the " << i
2955  << "-th result type: " << resultType;
2956  }
2957  return success();
2958 }
2959 
2960 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2961  auto &builder = parser.getBuilder();
2962  // Parse an opening `(` followed by induction variables followed by `)`
2963  SmallVector<OpAsmParser::Argument, 4> ivs;
2964  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2965  return failure();
2966 
2967  // Parse loop bounds.
2968  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2969  if (parser.parseEqual() ||
2970  parser.parseOperandList(lower, ivs.size(),
2971  OpAsmParser::Delimiter::Paren) ||
2972  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2973  return failure();
2974 
2975  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2976  if (parser.parseKeyword("to") ||
2977  parser.parseOperandList(upper, ivs.size(),
2978  OpAsmParser::Delimiter::Paren) ||
2979  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2980  return failure();
2981 
2982  // Parse step values.
2983  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2984  if (parser.parseKeyword("step") ||
2985  parser.parseOperandList(steps, ivs.size(),
2986  OpAsmParser::Delimiter::Paren) ||
2987  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2988  return failure();
2989 
2990  // Parse init values.
2991  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2992  if (succeeded(parser.parseOptionalKeyword("init"))) {
2993  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2994  return failure();
2995  }
2996 
2997  // Parse optional results in case there is a reduce.
2998  if (parser.parseOptionalArrowTypeList(result.types))
2999  return failure();
3000 
3001  // Now parse the body.
3002  Region *body = result.addRegion();
3003  for (auto &iv : ivs)
3004  iv.type = builder.getIndexType();
3005  if (parser.parseRegion(*body, ivs))
3006  return failure();
3007 
3008  // Set `operandSegmentSizes` attribute.
3009  result.addAttribute(
3010  ParallelOp::getOperandSegmentSizeAttr(),
3011  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3012  static_cast<int32_t>(upper.size()),
3013  static_cast<int32_t>(steps.size()),
3014  static_cast<int32_t>(initVals.size())}));
3015 
3016  // Parse attributes.
3017  if (parser.parseOptionalAttrDict(result.attributes) ||
3018  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3019  result.operands))
3020  return failure();
3021 
3022  // Add a terminator if none was parsed.
3023  ParallelOp::ensureTerminator(*body, builder, result.location);
3024  return success();
3025 }
3026 
3027 void ParallelOp::print(OpAsmPrinter &p) {
3028  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3029  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3030  if (!getInitVals().empty())
3031  p << " init (" << getInitVals() << ")";
3032  p.printOptionalArrowTypeList(getResultTypes());
3033  p << ' ';
3034  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3035  p.printOptionalAttrDictWithKeyword(
3036  (*this)->getAttrs(),
3037  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3038 }
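
// Illustrative sketch of the custom assembly form handled by the parse and
// print methods above; the SSA names, types, and reduction body below are
// made up for illustration only.
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1)
//       init (%init) -> f32 {
//     %elem = memref.load %buf[%i, %j] : memref<?x?xf32>
//     scf.reduce(%elem : f32) {
//     ^bb0(%lhs: f32, %rhs: f32):
//       %sum = arith.addf %lhs, %rhs : f32
//       scf.reduce.return %sum : f32
//     }
//   }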
3039 
3040 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3041 
3042 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3043  return SmallVector<Value>{getBody()->getArguments()};
3044 }
3045 
3046 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3047  return getLowerBound();
3048 }
3049 
3050 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3051  return getUpperBound();
3052 }
3053 
3054 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3055  return getStep();
3056 }
3057 
3058 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3059  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3060  if (!ivArg)
3061  return ParallelOp();
3062  assert(ivArg.getOwner() && "unlinked block argument");
3063  auto *containingOp = ivArg.getOwner()->getParentOp();
3064  return dyn_cast<ParallelOp>(containingOp);
3065 }
3066 
3067 namespace {
3068 // Collapse loop dimensions that perform a single iteration.
3069 struct ParallelOpSingleOrZeroIterationDimsFolder
3070  : public OpRewritePattern<ParallelOp> {
3071  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3072 
3073  LogicalResult matchAndRewrite(ParallelOp op,
3074  PatternRewriter &rewriter) const override {
3075  Location loc = op.getLoc();
3076 
3077  // Compute new loop bounds that omit all single-iteration loop dimensions.
3078  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3079  IRMapping mapping;
3080  for (auto [lb, ub, step, iv] :
3081  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3082  op.getInductionVars())) {
3083  auto numIterations = constantTripCount(lb, ub, step);
3084  if (numIterations.has_value()) {
3085  // Remove the loop if it performs zero iterations.
3086  if (*numIterations == 0) {
3087  rewriter.replaceOp(op, op.getInitVals());
3088  return success();
3089  }
3090  // Replace the loop induction variable by the lower bound if the loop
3091  // performs a single iteration. Otherwise, copy the loop bounds.
3092  if (*numIterations == 1) {
3093  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3094  continue;
3095  }
3096  }
3097  newLowerBounds.push_back(lb);
3098  newUpperBounds.push_back(ub);
3099  newSteps.push_back(step);
3100  }
3101  // Exit if none of the loop dimensions perform a single iteration.
3102  if (newLowerBounds.size() == op.getLowerBound().size())
3103  return failure();
3104 
3105  if (newLowerBounds.empty()) {
3106  // All of the loop dimensions perform a single iteration. Inline
3107  // loop body and nested ReduceOp's
3108  SmallVector<Value> results;
3109  results.reserve(op.getInitVals().size());
3110  for (auto &bodyOp : op.getBody()->without_terminator())
3111  rewriter.clone(bodyOp, mapping);
3112  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3113  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3114  Block &reduceBlock = reduceOp.getReductions()[i].front();
3115  auto initValIndex = results.size();
3116  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3117  mapping.map(reduceBlock.getArgument(1),
3118  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3119  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3120  rewriter.clone(reduceBodyOp, mapping);
3121 
3122  auto result = mapping.lookupOrDefault(
3123  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3124  results.push_back(result);
3125  }
3126 
3127  rewriter.replaceOp(op, results);
3128  return success();
3129  }
3130  // Replace the parallel loop by lower-dimensional parallel loop.
3131  auto newOp =
3132  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3133  newSteps, op.getInitVals(), nullptr);
3134  // Erase the empty block that was inserted by the builder.
3135  rewriter.eraseBlock(newOp.getBody());
3136  // Clone the loop body and remap the block arguments of the collapsed loops
3137  // (inlining does not support a cancellable block argument mapping).
3138  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3139  newOp.getRegion().begin(), mapping);
3140  rewriter.replaceOp(op, newOp.getResults());
3141  return success();
3142  }
3143 };
3144 
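/// Fuse two perfectly nested scf.parallel ops into a single multi-dimensional
/// scf.parallel: the inner loop must be the only operation in the outer body
/// (besides the terminator), its bounds and steps must not depend on the outer
/// induction variables, and neither loop may carry reductions.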
3145 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3146  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3147 
3148  LogicalResult matchAndRewrite(ParallelOp op,
3149  PatternRewriter &rewriter) const override {
3150  Block &outerBody = *op.getBody();
3151  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3152  return failure();
3153 
3154  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3155  if (!innerOp)
3156  return failure();
3157 
3158  for (auto val : outerBody.getArguments())
3159  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3160  llvm::is_contained(innerOp.getUpperBound(), val) ||
3161  llvm::is_contained(innerOp.getStep(), val))
3162  return failure();
3163 
3164  // Reductions are not supported yet.
3165  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3166  return failure();
3167 
3168  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3169  ValueRange iterVals, ValueRange) {
3170  Block &innerBody = *innerOp.getBody();
3171  assert(iterVals.size() ==
3172  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3173  IRMapping mapping;
3174  mapping.map(outerBody.getArguments(),
3175  iterVals.take_front(outerBody.getNumArguments()));
3176  mapping.map(innerBody.getArguments(),
3177  iterVals.take_back(innerBody.getNumArguments()));
3178  for (Operation &op : innerBody.without_terminator())
3179  builder.clone(op, mapping);
3180  };
3181 
3182  auto concatValues = [](const auto &first, const auto &second) {
3183  SmallVector<Value> ret;
3184  ret.reserve(first.size() + second.size());
3185  ret.assign(first.begin(), first.end());
3186  ret.append(second.begin(), second.end());
3187  return ret;
3188  };
3189 
3190  auto newLowerBounds =
3191  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3192  auto newUpperBounds =
3193  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3194  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3195 
3196  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3197  newSteps, std::nullopt,
3198  bodyBuilder);
3199  return success();
3200  }
3201 };
3202 
3203 } // namespace
3204 
3205 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3206  MLIRContext *context) {
3207  results
3208  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3209  context);
3210 }
3211 
3212 /// Given the region at `index`, or the parent operation if `index` is None,
3213 /// return the successor regions. These are the regions that may be selected
3214 /// during the flow of control. `operands` is a set of optional attributes that
3215 /// correspond to a constant value for each operand, or null if that operand is
3216 /// not a constant.
3217 void ParallelOp::getSuccessorRegions(
3218  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3219  // Both the operation itself and the region may be branching into the body or
3220  // back into the operation itself. It is possible for the loop not to enter the
3221  // body.
3222  regions.push_back(RegionSuccessor(&getRegion()));
3223  regions.push_back(RegionSuccessor());
3224 }
3225 
3226 //===----------------------------------------------------------------------===//
3227 // ReduceOp
3228 //===----------------------------------------------------------------------===//
3229 
3230 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3231 
3232 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3233  ValueRange operands) {
3234  result.addOperands(operands);
3235  for (Value v : operands) {
3236  OpBuilder::InsertionGuard guard(builder);
3237  Region *bodyRegion = result.addRegion();
3238  builder.createBlock(bodyRegion, {},
3239  ArrayRef<Type>{v.getType(), v.getType()},
3240  {result.location, result.location});
3241  }
3242 }
3243 
3244 LogicalResult ReduceOp::verifyRegions() {
3245  // The region of a ReduceOp has two arguments of the same type as its
3246  // corresponding operand.
3247  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3248  auto type = getOperands()[i].getType();
3249  Block &block = getReductions()[i].front();
3250  if (block.empty())
3251  return emitOpError() << i << "-th reduction has an empty body";
3252  if (block.getNumArguments() != 2 ||
3253  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3254  return arg.getType() != type;
3255  }))
3256  return emitOpError() << "expected two block arguments with type " << type
3257  << " in the " << i << "-th reduction region";
3258 
3259  // Check that the block is terminated by a ReduceReturnOp.
3260  if (!isa<ReduceReturnOp>(block.getTerminator()))
3261  return emitOpError("reduction bodies must be terminated with an "
3262  "'scf.reduce.return' op");
3263  }
3264 
3265  return success();
3266 }
3267 
3268 MutableOperandRange
3269 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3270  // No operands are forwarded to the next iteration.
3271  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3272 }
3273 
3274 //===----------------------------------------------------------------------===//
3275 // ReduceReturnOp
3276 //===----------------------------------------------------------------------===//
3277 
3278 LogicalResult ReduceReturnOp::verify() {
3279  // The type of the return value should be the same type as the types of the
3280  // block arguments of the reduction body.
3281  Block *reductionBody = getOperation()->getBlock();
3282  // Should already be verified by an op trait.
3283  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3284  Type expectedResultType = reductionBody->getArgument(0).getType();
3285  if (expectedResultType != getResult().getType())
3286  return emitOpError() << "must have type " << expectedResultType
3287  << " (the type of the reduction inputs)";
3288  return success();
3289 }
3290 
3291 //===----------------------------------------------------------------------===//
3292 // WhileOp
3293 //===----------------------------------------------------------------------===//
3294 
3295 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3296  ::mlir::OperationState &odsState, TypeRange resultTypes,
3297  ValueRange inits, BodyBuilderFn beforeBuilder,
3298  BodyBuilderFn afterBuilder) {
3299  odsState.addOperands(inits);
3300  odsState.addTypes(resultTypes);
3301 
3302  OpBuilder::InsertionGuard guard(odsBuilder);
3303 
3304  // Build before region.
3305  SmallVector<Location, 4> beforeArgLocs;
3306  beforeArgLocs.reserve(inits.size());
3307  for (Value operand : inits) {
3308  beforeArgLocs.push_back(operand.getLoc());
3309  }
3310 
3311  Region *beforeRegion = odsState.addRegion();
3312  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3313  inits.getTypes(), beforeArgLocs);
3314  if (beforeBuilder)
3315  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3316 
3317  // Build after region.
3318  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3319 
3320  Region *afterRegion = odsState.addRegion();
3321  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3322  resultTypes, afterArgLocs);
3323 
3324  if (afterBuilder)
3325  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3326 }
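
// A minimal usage sketch of the builder above, assuming a surrounding
// OpBuilder `builder`, a Location `loc`, an i32 type `i32Ty`, and an init
// Value `initValue` (all illustrative, not part of this file):
//
//   auto whileOp = builder.create<scf::WhileOp>(
//       loc, TypeRange{i32Ty}, ValueRange{initValue},
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         Value cond = ...; // compute the continuation condition from `args`
//         b.create<scf::ConditionOp>(l, cond, args);
//       },
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         b.create<scf::YieldOp>(l, args);
//       });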
3327 
3328 ConditionOp WhileOp::getConditionOp() {
3329  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3330 }
3331 
3332 YieldOp WhileOp::getYieldOp() {
3333  return cast<YieldOp>(getAfterBody()->getTerminator());
3334 }
3335 
3336 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3337  return getYieldOp().getResultsMutable();
3338 }
3339 
3340 Block::BlockArgListType WhileOp::getBeforeArguments() {
3341  return getBeforeBody()->getArguments();
3342 }
3343 
3344 Block::BlockArgListType WhileOp::getAfterArguments() {
3345  return getAfterBody()->getArguments();
3346 }
3347 
3348 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3349  return getBeforeArguments();
3350 }
3351 
3352 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3353  assert(point == getBefore() &&
3354  "WhileOp is expected to branch only to the first region");
3355  return getInits();
3356 }
3357 
3358 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3359  SmallVectorImpl<RegionSuccessor> &regions) {
3360  // The parent op always branches to the condition region.
3361  if (point.isParent()) {
3362  regions.emplace_back(&getBefore(), getBefore().getArguments());
3363  return;
3364  }
3365 
3366  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3367  "there are only two regions in a WhileOp");
3368  // The body region always branches back to the condition region.
3369  if (point == getAfter()) {
3370  regions.emplace_back(&getBefore(), getBefore().getArguments());
3371  return;
3372  }
3373 
3374  regions.emplace_back(getResults());
3375  regions.emplace_back(&getAfter(), getAfter().getArguments());
3376 }
3377 
3378 SmallVector<Region *> WhileOp::getLoopRegions() {
3379  return {&getBefore(), &getAfter()};
3380 }
3381 
3382 /// Parses a `while` op.
3383 ///
3384 /// op ::= `scf.while` assignments `:` function-type region `do` region
3385 /// `attributes` attribute-dict
3386 /// initializer ::= /* empty */ | `(` assignment-list `)`
3387 /// assignment-list ::= assignment | assignment `,` assignment-list
3388 /// assignment ::= ssa-value `=` ssa-value
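///
/// An illustrative example of this syntax (the operations inside the regions
/// are placeholders, not ops defined in this dialect):
///
///   %res = scf.while (%arg0 = %init) : (i32) -> i32 {
///     %cond = "test.condition"(%arg0) : (i32) -> i1
///     scf.condition(%cond) %arg0 : i32
///   } do {
///   ^bb0(%arg1: i32):
///     %next = "test.payload"(%arg1) : (i32) -> i32
///     scf.yield %next : i32
///   }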
3389 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3390  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3391  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3392  Region *before = result.addRegion();
3393  Region *after = result.addRegion();
3394 
3395  OptionalParseResult listResult =
3396  parser.parseOptionalAssignmentList(regionArgs, operands);
3397  if (listResult.has_value() && failed(listResult.value()))
3398  return failure();
3399 
3400  FunctionType functionType;
3401  SMLoc typeLoc = parser.getCurrentLocation();
3402  if (failed(parser.parseColonType(functionType)))
3403  return failure();
3404 
3405  result.addTypes(functionType.getResults());
3406 
3407  if (functionType.getNumInputs() != operands.size()) {
3408  return parser.emitError(typeLoc)
3409  << "expected as many input types as operands "
3410  << "(expected " << operands.size() << " got "
3411  << functionType.getNumInputs() << ")";
3412  }
3413 
3414  // Resolve input operands.
3415  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3416  parser.getCurrentLocation(),
3417  result.operands)))
3418  return failure();
3419 
3420  // Propagate the types into the region arguments.
3421  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3422  regionArgs[i].type = functionType.getInput(i);
3423 
3424  return failure(parser.parseRegion(*before, regionArgs) ||
3425  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3426  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3427 }
3428 
3429 /// Prints a `while` op.
3430 void scf::WhileOp::print(OpAsmPrinter &p) {
3431  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3432  p << " : ";
3433  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3434  p << ' ';
3435  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3436  p << " do ";
3437  p.printRegion(getAfter());
3438  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3439 }
3440 
3441 /// Verifies that two ranges of types match, i.e. have the same number of
3442 /// entries and that types are pairwise equals. Reports errors on the given
3443 /// operation in case of mismatch.
3444 template <typename OpTy>
3445 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3446  TypeRange right, StringRef message) {
3447  if (left.size() != right.size())
3448  return op.emitOpError("expects the same number of ") << message;
3449 
3450  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3451  if (left[i] != right[i]) {
3452  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3453  << message;
3454  diag.attachNote() << "for argument " << i << ", found " << left[i]
3455  << " and " << right[i];
3456  return diag;
3457  }
3458  }
3459 
3460  return success();
3461 }
3462 
3463 LogicalResult scf::WhileOp::verify() {
3464  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3465  *this, getBefore(),
3466  "expects the 'before' region to terminate with 'scf.condition'");
3467  if (!beforeTerminator)
3468  return failure();
3469 
3470  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3471  *this, getAfter(),
3472  "expects the 'after' region to terminate with 'scf.yield'");
3473  return success(afterTerminator != nullptr);
3474 }
3475 
3476 namespace {
3477 /// Replace uses of the condition within the do block with true, since otherwise
3478 /// the block would not be evaluated.
3479 ///
3480 /// scf.while (..) : (i1, ...) -> ... {
3481 /// %condition = call @evaluate_condition() : () -> i1
3482 /// scf.condition(%condition) %condition : i1, ...
3483 /// } do {
3484 /// ^bb0(%arg0: i1, ...):
3485 /// use(%arg0)
3486 /// ...
3487 ///
3488 /// becomes
3489 /// scf.while (..) : (i1, ...) -> ... {
3490 /// %condition = call @evaluate_condition() : () -> i1
3491 /// scf.condition(%condition) %condition : i1, ...
3492 /// } do {
3493 /// ^bb0(%arg0: i1, ...):
3494 /// use(%true)
3495 /// ...
3496 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3497  using OpRewritePattern<WhileOp>::OpRewritePattern;
3498 
3499  LogicalResult matchAndRewrite(WhileOp op,
3500  PatternRewriter &rewriter) const override {
3501  auto term = op.getConditionOp();
3502 
3503  // These variables serve to prevent creating duplicate constants
3504  // and hold constant true or false values.
3505  Value constantTrue = nullptr;
3506 
3507  bool replaced = false;
3508  for (auto yieldedAndBlockArgs :
3509  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3510  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3511  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3512  if (!constantTrue)
3513  constantTrue = rewriter.create<arith::ConstantOp>(
3514  op.getLoc(), term.getCondition().getType(),
3515  rewriter.getBoolAttr(true));
3516 
3517  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3518  constantTrue);
3519  replaced = true;
3520  }
3521  }
3522  }
3523  return success(replaced);
3524  }
3525 };
3526 
3527 /// Remove loop invariant arguments from `before` block of scf.while.
3528 /// A before block argument is considered loop invariant if :-
3529 /// 1. i-th yield operand is equal to the i-th while operand.
3530 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3531 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3532 /// iter argument/while operand.
3533 /// For the arguments which are removed, their uses inside scf.while
3534 /// are replaced with their corresponding initial value.
3535 ///
3536 /// Eg:
3537 /// INPUT :-
3538 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3539 /// ..., %argN_before = %N)
3540 /// {
3541 /// ...
3542 /// scf.condition(%cond) %arg1_before, %arg0_before,
3543 /// %arg2_before, %arg0_before, ...
3544 /// } do {
3545 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3546 /// ..., %argK_after):
3547 /// ...
3548 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3549 /// }
3550 ///
3551 /// OUTPUT :-
3552 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3553 /// %N)
3554 /// {
3555 /// ...
3556 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3557 /// } do {
3558 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3559 /// ..., %argK_after):
3560 /// ...
3561 /// scf.yield %arg1_after, ..., %argN
3562 /// }
3563 ///
3564 /// EXPLANATION:
3565 /// We iterate over each yield operand.
3566 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3567 /// %arg0_before, which in turn is the 0-th iter argument. So we
3568 /// remove 0-th before block argument and yield operand, and replace
3569 /// all uses of the 0-th before block argument with its initial value
3570 /// %a.
3571 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3572 /// value. So we remove this operand and the corresponding before
3573 /// block argument and replace all uses of 1-th before block argument
3574 /// with %b.
3575 struct RemoveLoopInvariantArgsFromBeforeBlock
3576  : public OpRewritePattern<WhileOp> {
3577  using OpRewritePattern<WhileOp>::OpRewritePattern;
3578 
3579  LogicalResult matchAndRewrite(WhileOp op,
3580  PatternRewriter &rewriter) const override {
3581  Block &afterBlock = *op.getAfterBody();
3582  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3583  ConditionOp condOp = op.getConditionOp();
3584  OperandRange condOpArgs = condOp.getArgs();
3585  Operation *yieldOp = afterBlock.getTerminator();
3586  ValueRange yieldOpArgs = yieldOp->getOperands();
3587 
3588  bool canSimplify = false;
3589  for (const auto &it :
3590  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3591  auto index = static_cast<unsigned>(it.index());
3592  auto [initVal, yieldOpArg] = it.value();
3593  // If i-th yield operand is equal to the i-th operand of the scf.while,
3594  // the i-th before block argument is a loop invariant.
3595  if (yieldOpArg == initVal) {
3596  canSimplify = true;
3597  break;
3598  }
3599  // If the i-th yield operand is k-th after block argument, then we check
3600  // if the (k+1)-th condition op operand is equal to either the i-th before
3601  // block argument or the initial value of i-th before block argument. If
3602  // the comparison results `true`, i-th before block argument is a loop
3603  // invariant.
3604  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3605  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3606  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3607  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3608  canSimplify = true;
3609  break;
3610  }
3611  }
3612  }
3613 
3614  if (!canSimplify)
3615  return failure();
3616 
3617  SmallVector<Value> newInitArgs, newYieldOpArgs;
3618  DenseMap<unsigned, Value> beforeBlockInitValMap;
3619  SmallVector<Location> newBeforeBlockArgLocs;
3620  for (const auto &it :
3621  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3622  auto index = static_cast<unsigned>(it.index());
3623  auto [initVal, yieldOpArg] = it.value();
3624 
3625  // If i-th yield operand is equal to the i-th operand of the scf.while,
3626  // the i-th before block argument is a loop invariant.
3627  if (yieldOpArg == initVal) {
3628  beforeBlockInitValMap.insert({index, initVal});
3629  continue;
3630  } else {
3631  // If the i-th yield operand is k-th after block argument, then we check
3632  // if the (k+1)-th condition op operand is equal to either the i-th
3633  // before block argument or the initial value of i-th before block
3634  // argument. If the comparison results `true`, i-th before block
3635  // argument is a loop invariant.
3636  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3637  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3638  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3639  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3640  beforeBlockInitValMap.insert({index, initVal});
3641  continue;
3642  }
3643  }
3644  }
3645  newInitArgs.emplace_back(initVal);
3646  newYieldOpArgs.emplace_back(yieldOpArg);
3647  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3648  }
3649 
3650  {
3651  OpBuilder::InsertionGuard g(rewriter);
3652  rewriter.setInsertionPoint(yieldOp);
3653  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3654  }
3655 
3656  auto newWhile =
3657  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3658 
3659  Block &newBeforeBlock = *rewriter.createBlock(
3660  &newWhile.getBefore(), /*insertPt*/ {},
3661  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3662 
3663  Block &beforeBlock = *op.getBeforeBody();
3664  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3665  // For each i-th before block argument we find its replacement value as :-
3666  // 1. If the i-th before block argument is a loop invariant, we fetch its
3667  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3668  // 2. Else we fetch j-th new before block argument as the replacement
3669  // value of i-th before block argument.
3670  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3671  // If the index 'i' argument was a loop invariant we fetch its initial
3672  // value from `beforeBlockInitValMap`.
3673  if (beforeBlockInitValMap.count(i) != 0)
3674  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3675  else
3676  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3677  }
3678 
3679  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3680  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3681  newWhile.getAfter().begin());
3682 
3683  rewriter.replaceOp(op, newWhile.getResults());
3684  return success();
3685  }
3686 };
3687 
3688 /// Remove loop invariant value from result (condition op) of scf.while.
3689 /// A value is considered loop invariant if the final value yielded by
3690 /// scf.condition is defined outside of the `before` block. We remove the
3691 /// corresponding argument in `after` block and replace the use with the value.
3692 /// We also replace the use of the corresponding result of scf.while with the
3693 /// value.
3694 ///
3695 /// Eg:
3696 /// INPUT :-
3697 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3698 /// %argN_before = %N) {
3699 /// ...
3700 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3701 /// } do {
3702 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3703 /// ...
3704 /// some_func(%arg1_after)
3705 /// ...
3706 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3707 /// }
3708 ///
3709 /// OUTPUT :-
3710 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3711 /// ...
3712 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3713 /// } do {
3714 /// ^bb0(%arg0, %arg3, ..., %argM):
3715 /// ...
3716 /// some_func(%a)
3717 /// ...
3718 /// scf.yield %arg0, %b, ..., %argN
3719 /// }
3720 ///
3721 /// EXPLANATION:
3722 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3723 /// before block of scf.while, so they get removed.
3724 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3725 /// replaced by %b.
3726 /// 3. The corresponding after block argument %arg1_after's uses are
3727 /// replaced by %a and %arg2_after's uses are replaced by %b.
3728 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3729  using OpRewritePattern<WhileOp>::OpRewritePattern;
3730 
3731  LogicalResult matchAndRewrite(WhileOp op,
3732  PatternRewriter &rewriter) const override {
3733  Block &beforeBlock = *op.getBeforeBody();
3734  ConditionOp condOp = op.getConditionOp();
3735  OperandRange condOpArgs = condOp.getArgs();
3736 
3737  bool canSimplify = false;
3738  for (Value condOpArg : condOpArgs) {
3739  // Those values not defined within `before` block will be considered as
3740  // loop invariant values. We map the corresponding `index` with their
3741  // value.
3742  if (condOpArg.getParentBlock() != &beforeBlock) {
3743  canSimplify = true;
3744  break;
3745  }
3746  }
3747 
3748  if (!canSimplify)
3749  return failure();
3750 
3751  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3752 
3753  SmallVector<Value> newCondOpArgs;
3754  SmallVector<Type> newAfterBlockType;
3755  DenseMap<unsigned, Value> condOpInitValMap;
3756  SmallVector<Location> newAfterBlockArgLocs;
3757  for (const auto &it : llvm::enumerate(condOpArgs)) {
3758  auto index = static_cast<unsigned>(it.index());
3759  Value condOpArg = it.value();
3760  // Those values not defined within `before` block will be considered as
3761  // loop invariant values. We map the corresponding `index` with their
3762  // value.
3763  if (condOpArg.getParentBlock() != &beforeBlock) {
3764  condOpInitValMap.insert({index, condOpArg});
3765  } else {
3766  newCondOpArgs.emplace_back(condOpArg);
3767  newAfterBlockType.emplace_back(condOpArg.getType());
3768  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3769  }
3770  }
3771 
3772  {
3773  OpBuilder::InsertionGuard g(rewriter);
3774  rewriter.setInsertionPoint(condOp);
3775  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3776  newCondOpArgs);
3777  }
3778 
3779  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3780  op.getOperands());
3781 
3782  Block &newAfterBlock =
3783  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3784  newAfterBlockType, newAfterBlockArgLocs);
3785 
3786  Block &afterBlock = *op.getAfterBody();
3787  // Since a new scf.condition op was created, we need to fetch the new
3788  // `after` block arguments, which are used when replacing operations of
3789  // the previous scf.while's `after` block. We also fetch the new result
3790  // values.
3791  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3792  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3793  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3794  Value afterBlockArg, result;
3795  // If the index 'i' argument was loop invariant, we fetch its value from
3796  // the `condOpInitValMap` map.
3797  if (condOpInitValMap.count(i) != 0) {
3798  afterBlockArg = condOpInitValMap[i];
3799  result = afterBlockArg;
3800  } else {
3801  afterBlockArg = newAfterBlock.getArgument(j);
3802  result = newWhile.getResult(j);
3803  j++;
3804  }
3805  newAfterBlockArgs[i] = afterBlockArg;
3806  newWhileResults[i] = result;
3807  }
3808 
3809  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3810  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3811  newWhile.getBefore().begin());
3812 
3813  rewriter.replaceOp(op, newWhileResults);
3814  return success();
3815  }
3816 };
3817 
3818 /// Remove WhileOp results that are unused and whose corresponding 'after' block arguments are also unused.
3819 ///
3820 /// %0:2 = scf.while () : () -> (i32, i64) {
3821 /// %condition = "test.condition"() : () -> i1
3822 /// %v1 = "test.get_some_value"() : () -> i32
3823 /// %v2 = "test.get_some_value"() : () -> i64
3824 /// scf.condition(%condition) %v1, %v2 : i32, i64
3825 /// } do {
3826 /// ^bb0(%arg0: i32, %arg1: i64):
3827 /// "test.use"(%arg0) : (i32) -> ()
3828 /// scf.yield
3829 /// }
3830 /// return %0#0 : i32
3831 ///
3832 /// becomes
3833 /// %0 = scf.while () : () -> (i32) {
3834 /// %condition = "test.condition"() : () -> i1
3835 /// %v1 = "test.get_some_value"() : () -> i32
3836 /// %v2 = "test.get_some_value"() : () -> i64
3837 /// scf.condition(%condition) %v1 : i32
3838 /// } do {
3839 /// ^bb0(%arg0: i32):
3840 /// "test.use"(%arg0) : (i32) -> ()
3841 /// scf.yield
3842 /// }
3843 /// return %0 : i32
3844 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3845  using OpRewritePattern<WhileOp>::OpRewritePattern;
3846 
3847  LogicalResult matchAndRewrite(WhileOp op,
3848  PatternRewriter &rewriter) const override {
3849  auto term = op.getConditionOp();
3850  auto afterArgs = op.getAfterArguments();
3851  auto termArgs = term.getArgs();
3852 
3853  // Collect results mapping, new terminator args and new result types.
3854  SmallVector<unsigned> newResultsIndices;
3855  SmallVector<Type> newResultTypes;
3856  SmallVector<Value> newTermArgs;
3857  SmallVector<Location> newArgLocs;
3858  bool needUpdate = false;
3859  for (const auto &it :
3860  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3861  auto i = static_cast<unsigned>(it.index());
3862  Value result = std::get<0>(it.value());
3863  Value afterArg = std::get<1>(it.value());
3864  Value termArg = std::get<2>(it.value());
3865  if (result.use_empty() && afterArg.use_empty()) {
3866  needUpdate = true;
3867  } else {
3868  newResultsIndices.emplace_back(i);
3869  newTermArgs.emplace_back(termArg);
3870  newResultTypes.emplace_back(result.getType());
3871  newArgLocs.emplace_back(result.getLoc());
3872  }
3873  }
3874 
3875  if (!needUpdate)
3876  return failure();
3877 
3878  {
3879  OpBuilder::InsertionGuard g(rewriter);
3880  rewriter.setInsertionPoint(term);
3881  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3882  newTermArgs);
3883  }
3884 
3885  auto newWhile =
3886  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3887 
3888  Block &newAfterBlock = *rewriter.createBlock(
3889  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3890 
3891  // Build new results list and new after block args (unused entries will be
3892  // null).
3893  SmallVector<Value> newResults(op.getNumResults());
3894  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3895  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3896  newResults[it.value()] = newWhile.getResult(it.index());
3897  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3898  }
3899 
3900  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3901  newWhile.getBefore().begin());
3902 
3903  Block &afterBlock = *op.getAfterBody();
3904  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3905 
3906  rewriter.replaceOp(op, newResults);
3907  return success();
3908  }
3909 };
3910 
3911 /// Replace operations equivalent to the condition in the do block with true,
3912 /// since otherwise the block would not be evaluated.
3913 ///
3914 /// scf.while (..) : (i32, ...) -> ... {
3915 /// %z = ... : i32
3916 /// %condition = cmpi pred %z, %a
3917 /// scf.condition(%condition) %z : i32, ...
3918 /// } do {
3919 /// ^bb0(%arg0: i32, ...):
3920 /// %condition2 = cmpi pred %arg0, %a
3921 /// use(%condition2)
3922 /// ...
3923 ///
3924 /// becomes
3925 /// scf.while (..) : (i32, ...) -> ... {
3926 /// %z = ... : i32
3927 /// %condition = cmpi pred %z, %a
3928 /// scf.condition(%condition) %z : i32, ...
3929 /// } do {
3930 /// ^bb0(%arg0: i32, ...):
3931 /// use(%true)
3932 /// ...
3933 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3934  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3935 
3936  LogicalResult matchAndRewrite(scf::WhileOp op,
3937  PatternRewriter &rewriter) const override {
3938  using namespace scf;
3939  auto cond = op.getConditionOp();
3940  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3941  if (!cmp)
3942  return failure();
3943  bool changed = false;
3944  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3945  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3946  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3947  continue;
3948  for (OpOperand &u :
3949  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3950  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3951  if (!cmp2)
3952  continue;
3953  // For a binary operator 1-opIdx gets the other side.
3954  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3955  continue;
3956  bool samePredicate;
3957  if (cmp2.getPredicate() == cmp.getPredicate())
3958  samePredicate = true;
3959  else if (cmp2.getPredicate() ==
3960  arith::invertPredicate(cmp.getPredicate()))
3961  samePredicate = false;
3962  else
3963  continue;
3964 
3965  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3966  1);
3967  changed = true;
3968  }
3969  }
3970  }
3971  return success(changed);
3972  }
3973 };
3974 
3975 /// Remove unused init/yield args.
3976 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3977  using OpRewritePattern<WhileOp>::OpRewritePattern;
3978 
3979  LogicalResult matchAndRewrite(WhileOp op,
3980  PatternRewriter &rewriter) const override {
3981 
3982  if (!llvm::any_of(op.getBeforeArguments(),
3983  [](Value arg) { return arg.use_empty(); }))
3984  return rewriter.notifyMatchFailure(op, "No args to remove");
3985 
3986  YieldOp yield = op.getYieldOp();
3987 
3988  // Collect results mapping, new terminator args and new result types.
3989  SmallVector<Value> newYields;
3990  SmallVector<Value> newInits;
3991  llvm::BitVector argsToErase;
3992 
3993  size_t argsCount = op.getBeforeArguments().size();
3994  newYields.reserve(argsCount);
3995  newInits.reserve(argsCount);
3996  argsToErase.reserve(argsCount);
3997  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3998  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3999  if (beforeArg.use_empty()) {
4000  argsToErase.push_back(true);
4001  } else {
4002  argsToErase.push_back(false);
4003  newYields.emplace_back(yieldValue);
4004  newInits.emplace_back(initValue);
4005  }
4006  }
4007 
4008  Block &beforeBlock = *op.getBeforeBody();
4009  Block &afterBlock = *op.getAfterBody();
4010 
4011  beforeBlock.eraseArguments(argsToErase);
4012 
4013  Location loc = op.getLoc();
4014  auto newWhileOp =
4015  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4016  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4017  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4018  Block &newAfterBlock = *newWhileOp.getAfterBody();
4019 
4020  OpBuilder::InsertionGuard g(rewriter);
4021  rewriter.setInsertionPoint(yield);
4022  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4023 
4024  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4025  newBeforeBlock.getArguments());
4026  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4027  newAfterBlock.getArguments());
4028 
4029  rewriter.replaceOp(op, newWhileOp.getResults());
4030  return success();
4031  }
4032 };
4033 
4034 /// Remove duplicated ConditionOp args.
4035 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4036  using OpRewritePattern<WhileOp>::OpRewritePattern;
4037 
4038  LogicalResult matchAndRewrite(WhileOp op,
4039  PatternRewriter &rewriter) const override {
4040  ConditionOp condOp = op.getConditionOp();
4041  ValueRange condOpArgs = condOp.getArgs();
4042 
4043  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4044 
4045  if (argsSet.size() == condOpArgs.size())
4046  return rewriter.notifyMatchFailure(op, "No results to remove");
4047 
4048  llvm::SmallDenseMap<Value, unsigned> argsMap;
4049  SmallVector<Value> newArgs;
4050  argsMap.reserve(condOpArgs.size());
4051  newArgs.reserve(condOpArgs.size());
4052  for (Value arg : condOpArgs) {
4053  if (!argsMap.count(arg)) {
4054  auto pos = static_cast<unsigned>(argsMap.size());
4055  argsMap.insert({arg, pos});
4056  newArgs.emplace_back(arg);
4057  }
4058  }
4059 
4060  ValueRange argsRange(newArgs);
4061 
4062  Location loc = op.getLoc();
4063  auto newWhileOp = rewriter.create<scf::WhileOp>(
4064  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4065  /*afterBody*/ nullptr);
4066  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4067  Block &newAfterBlock = *newWhileOp.getAfterBody();
4068 
4069  SmallVector<Value> afterArgsMapping;
4070  SmallVector<Value> resultsMapping;
4071  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4072  auto it = argsMap.find(arg);
4073  assert(it != argsMap.end());
4074  auto pos = it->second;
4075  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4076  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4077  }
4078 
4079  OpBuilder::InsertionGuard g(rewriter);
4080  rewriter.setInsertionPoint(condOp);
4081  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4082  argsRange);
4083 
4084  Block &beforeBlock = *op.getBeforeBody();
4085  Block &afterBlock = *op.getAfterBody();
4086 
4087  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4088  newBeforeBlock.getArguments());
4089  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4090  rewriter.replaceOp(op, resultsMapping);
4091  return success();
4092  }
4093 };
4094 
4095 /// If both ranges contain the same values, return mapping indices from args2
4096 /// to args1. Otherwise return std::nullopt.
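/// For example (illustrative values): with args1 = [%a, %b] and
/// args2 = [%b, %a], the result is [1, 0], i.e. args2[0] corresponds to
/// args1[1] and args2[1] corresponds to args1[0].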
4097 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4098  ValueRange args2) {
4099  if (args1.size() != args2.size())
4100  return std::nullopt;
4101 
4102  SmallVector<unsigned> ret(args1.size());
4103  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4104  auto it = llvm::find(args2, arg1);
4105  if (it == args2.end())
4106  return std::nullopt;
4107 
4108  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4109  }
4110 
4111  return ret;
4112 }
4113 
4114 static bool hasDuplicates(ValueRange args) {
4115  llvm::SmallDenseSet<Value> set;
4116  for (Value arg : args) {
4117  if (!set.insert(arg).second)
4118  return true;
4119  }
4120  return false;
4121 }
4122 
4123 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4124 /// `scf.condition` args into the same order as block args. Update `after` block
4125 /// args and op result values accordingly.
4126 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
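///
/// For illustration (made-up values): with `before` block arguments
/// (%arg0: i32, %arg1: f32), a terminator
///
///   scf.condition(%cond) %arg1, %arg0 : f32, i32
///
/// is rewritten to
///
///   scf.condition(%cond) %arg0, %arg1 : i32, f32
///
/// and the `after` block arguments and op results are permuted to match.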
4127 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4128  using OpRewritePattern<WhileOp>::OpRewritePattern;
4129 
4130  LogicalResult matchAndRewrite(WhileOp loop,
4131  PatternRewriter &rewriter) const override {
4132  auto oldBefore = loop.getBeforeBody();
4133  ConditionOp oldTerm = loop.getConditionOp();
4134  ValueRange beforeArgs = oldBefore->getArguments();
4135  ValueRange termArgs = oldTerm.getArgs();
4136  if (beforeArgs == termArgs)
4137  return failure();
4138 
4139  if (hasDuplicates(termArgs))
4140  return failure();
4141 
4142  auto mapping = getArgsMapping(beforeArgs, termArgs);
4143  if (!mapping)
4144  return failure();
4145 
4146  {
4147  OpBuilder::InsertionGuard g(rewriter);
4148  rewriter.setInsertionPoint(oldTerm);
4149  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4150  beforeArgs);
4151  }
4152 
4153  auto oldAfter = loop.getAfterBody();
4154 
4155  SmallVector<Type> newResultTypes(beforeArgs.size());
4156  for (auto &&[i, j] : llvm::enumerate(*mapping))
4157  newResultTypes[j] = loop.getResult(i).getType();
4158 
4159  auto newLoop = rewriter.create<WhileOp>(
4160  loop.getLoc(), newResultTypes, loop.getInits(),
4161  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4162  auto newBefore = newLoop.getBeforeBody();
4163  auto newAfter = newLoop.getAfterBody();
4164 
4165  SmallVector<Value> newResults(beforeArgs.size());
4166  SmallVector<Value> newAfterArgs(beforeArgs.size());
4167  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4168  newResults[i] = newLoop.getResult(j);
4169  newAfterArgs[i] = newAfter->getArgument(j);
4170  }
4171 
4172  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4173  newBefore->getArguments());
4174  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4175  newAfterArgs);
4176 
4177  rewriter.replaceOp(loop, newResults);
4178  return success();
4179  }
4180 };
4181 } // namespace
4182 
4183 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4184  MLIRContext *context) {
4185  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4186  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4187  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4188  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4189 }
4190 
4191 //===----------------------------------------------------------------------===//
4192 // IndexSwitchOp
4193 //===----------------------------------------------------------------------===//
4194 
4195 /// Parse the case regions and values.
4196 static ParseResult
4197 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4198  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4199  SmallVector<int64_t> caseValues;
4200  while (succeeded(p.parseOptionalKeyword("case"))) {
4201  int64_t value;
4202  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4203  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4204  return failure();
4205  caseValues.push_back(value);
4206  }
4207  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4208  return success();
4209 }
4210 
4211 /// Print the case regions and values.
4212 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4213  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4214  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4215  p.printNewline();
4216  p << "case " << value << ' ';
4217  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4218  }
4219 }
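
// Illustrative sketch of the case syntax handled by the parse/print helpers
// above; the values and region bodies are made up for illustration only.
//
//   %0 = scf.index_switch %arg0 -> i32
//   case 2 {
//     %1 = arith.constant 10 : i32
//     scf.yield %1 : i32
//   }
//   default {
//     %2 = arith.constant 42 : i32
//     scf.yield %2 : i32
//   }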
4220 
4221 LogicalResult scf::IndexSwitchOp::verify() {
4222  if (getCases().size() != getCaseRegions().size()) {
4223  return emitOpError("has ")
4224  << getCaseRegions().size() << " case regions but "
4225  << getCases().size() << " case values";
4226  }
4227 
4228  DenseSet<int64_t> valueSet;
4229  for (int64_t value : getCases())
4230  if (!valueSet.insert(value).second)
4231  return emitOpError("has duplicate case value: ") << value;
4232  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4233  auto yield = dyn_cast<YieldOp>(region.front().back());
4234  if (!yield)
4235  return emitOpError("expected region to end with scf.yield, but got ")
4236  << region.front().back().getName();
4237 
4238  if (yield.getNumOperands() != getNumResults()) {
4239  return (emitOpError("expected each region to return ")
4240  << getNumResults() << " values, but " << name << " returns "
4241  << yield.getNumOperands())
4242  .attachNote(yield.getLoc())
4243  << "see yield operation here";
4244  }
4245  for (auto [idx, result, operand] :
4246  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4247  yield.getOperandTypes())) {
4248  if (result == operand)
4249  continue;
4250  return (emitOpError("expected result #")
4251  << idx << " of each region to be " << result)
4252  .attachNote(yield.getLoc())
4253  << name << " returns " << operand << " here";
4254  }
4255  return success();
4256  };
4257 
4258  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4259  return failure();
4260  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4261  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4262  return failure();
4263 
4264  return success();
4265 }
4266 
4267 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4268 
4269 Block &scf::IndexSwitchOp::getDefaultBlock() {
4270  return getDefaultRegion().front();
4271 }
4272 
4273 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4274  assert(idx < getNumCases() && "case index out-of-bounds");
4275  return getCaseRegions()[idx].front();
4276 }
4277 
4278 void IndexSwitchOp::getSuccessorRegions(
4279  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4280  // All regions branch back to the parent op.
4281  if (!point.isParent()) {
4282  successors.emplace_back(getResults());
4283  return;
4284  }
4285 
4286  llvm::append_range(successors, getRegions());
4287 }
4288 
4289 void IndexSwitchOp::getEntrySuccessorRegions(
4290  ArrayRef<Attribute> operands,
4291  SmallVectorImpl<RegionSuccessor> &successors) {
4292  FoldAdaptor adaptor(operands, *this);
4293 
4294  // If a constant was not provided, all regions are possible successors.
4295  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4296  if (!arg) {
4297  llvm::append_range(successors, getRegions());
4298  return;
4299  }
4300 
4301  // Otherwise, try to find a case with a matching value. If not, the
4302  // default region is the only successor.
4303  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4304  if (caseValue == arg.getInt()) {
4305  successors.emplace_back(&caseRegion);
4306  return;
4307  }
4308  }
4309  successors.emplace_back(&getDefaultRegion());
4310 }
4311 
4312 void IndexSwitchOp::getRegionInvocationBounds(
4313  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4314  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4315  if (!operandValue) {
4316  // All regions are invoked at most once.
4317  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4318  return;
4319  }
4320 
4321  unsigned liveIndex = getNumRegions() - 1;
4322  const auto *it = llvm::find(getCases(), operandValue.getInt());
4323  if (it != getCases().end())
4324  liveIndex = std::distance(getCases().begin(), it);
4325  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4326  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4327 }
4328 
4329 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4330  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4331 
4332  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4333  PatternRewriter &rewriter) const override {
4334  // If `op.getArg()` is a constant, select the region that matches the
4335  // constant value. Use the default region if no match is found.
4336  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4337  if (!maybeCst.has_value())
4338  return failure();
4339  int64_t cst = *maybeCst;
4340  int64_t caseIdx, e = op.getNumCases();
4341  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4342  if (cst == op.getCases()[caseIdx])
4343  break;
4344  }
4345 
4346  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4347  : op.getDefaultRegion();
4348  Block &source = r.front();
4349  Operation *terminator = source.getTerminator();
4350  SmallVector<Value> results = terminator->getOperands();
4351 
4352  rewriter.inlineBlockBefore(&source, op);
4353  rewriter.eraseOp(terminator);
4354  // Replace the operation with a potentially empty list of results.
4355  // The fold mechanism doesn't support the case where the result list is empty.
4356  rewriter.replaceOp(op, results);
4357 
4358  return success();
4359  }
4360 };
4361 
4362 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4363  MLIRContext *context) {
4364  results.add<FoldConstantCase>(context);
4365 }
4366 
4367 //===----------------------------------------------------------------------===//
4368 // TableGen'd op method definitions
4369 //===----------------------------------------------------------------------===//
4370 
4371 #define GET_OP_CLASSES
4372 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:746
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:738
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1287
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:114
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4197
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3445
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:433
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:94
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4212
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:295
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:307
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
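Usage sketch for IRMapping (illustrative; `builder`, `original`, `oldValue`, and `newValue` are hypothetical placeholders):
  IRMapping mapping;
  mapping.map(oldValue, newValue);                      // substitute newValue wherever oldValue is used
  Operation *copy = builder.clone(*original, mapping);  // clone remaps operands through `mapping`
  Value mapped = mapping.lookupOrDefault(oldValue);     // == newValue; unmapped keys return themselves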
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:576
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:586
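A minimal sketch of how the insertion-point helpers above combine (assumes an in-scope OpBuilder `builder`, Region `region`, and Location `loc`; all placeholders):
  OpBuilder::InsertionGuard guard(builder);      // restores the insertion point when `guard` goes out of scope
  Block *body = builder.createBlock(&region);    // new empty block; the builder now points into it
  builder.setInsertionPointToEnd(body);          // redundant right after createBlock, shown for the API
  builder.create<scf::YieldOp>(loc);             // SCF regions must end in scf.yield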
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:433
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:445
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:578
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
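The rewriter entries above combine in the single-block region inlining helper used elsewhere in this file; a condensed sketch (placeholder names `rewriter` and `executeRegionOp`, the latter assumed to hold a single-block region whose terminator yields values; this uses the Operation-anchored overload of inlineBlockBefore):
  Block *block = &executeRegionOp.getRegion().front();
  Operation *terminator = block->getTerminator();
  ValueRange yielded = terminator->getOperands();
  rewriter.inlineBlockBefore(block, executeRegionOp);  // splice the body in front of the op
  rewriter.replaceOp(executeRegionOp, yielded);        // forward the values the terminator yielded
  rewriter.eraseOp(terminator);                        // the spliced scf.yield is now dead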
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:194
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3058
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:87
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:694
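A sketch of buildLoopNest (placeholders: index Values `lb`, `ub`, `step`, a loop-carried Value `init`, plus `builder` and `loc`; assumes LoopNest's `results` member carries the outermost loop's results):
  scf::LoopNest nest = scf::buildLoopNest(
      builder, loc, lb, ub, step, ValueRange(init),
      [](OpBuilder &b, Location l, ValueRange ivs, ValueRange iterArgs)
          -> scf::ValueVector {
        return {iterArgs[0]};                 // yield the carried value unchanged
      });
  Value carried = nest.results.front();       // result of the outermost scf.for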
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1986
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:651
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:604
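Typical use of getForInductionVarOwner (illustrative; `iv` and `builder` are placeholders):
  if (scf::ForOp forOp = scf::getForInductionVarOwner(iv))
    builder.setInsertionPoint(forOp);   // e.g. hoist newly created IR right before the loop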
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1460
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:783
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:269
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
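Matcher sketch (illustrative; `bound` is a hypothetical integer-typed Value):
  APInt cst;
  if (matchPattern(bound, m_ConstantInt(&cst)) && cst.isOne()) {
    // `bound` is known to be the constant 1 at compile time.
  }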
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
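Sketch combining getConstantIntValue with getValueOrCreateConstantIndexOp (listed below); `ofr`, `builder`, and `loc` are placeholders:
  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    // `ofr` is statically known; use *cst directly (e.g. to detect a unit step).
  } else {
    Value dynamic = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    // fall back to an SSA index value
  }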
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
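Sketch (hypothetical OpFoldResults `lb`, `ub`, `step` of some loop; assumes the free function is visible in the mlir namespace as listed above):
  std::optional<int64_t> tripCount = constantTripCount(lb, ub, step);
  bool singleIteration = tripCount && *tripCount == 1;   // candidate for promoteIfSingleIteration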
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4332
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:235
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:186
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
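A purely illustrative OpRewritePattern (not one defined in this file; SCF's own canonicalizations already cover the case): it erases an scf.if whose condition is a constant false and which has no results and no else region.
  struct EraseDeadIf : public OpRewritePattern<scf::IfOp> {
    using OpRewritePattern<scf::IfOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(scf::IfOp op,
                                  PatternRewriter &rewriter) const override {
      if (op->getNumResults() != 0 || !op.getElseRegion().empty())
        return rewriter.notifyMatchFailure(op, "has results or an else region");
      if (!matchPattern(op.getCondition(), m_Zero()))
        return rewriter.notifyMatchFailure(op, "condition is not constant false");
      rewriter.eraseOp(op);   // nothing in the then-region can ever execute
      return success();
    }
  };
  // Registered like any other pattern: patterns.add<EraseDeadIf>(context);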
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.