1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Utils/Utils.h"
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Interfaces/FunctionInterfaces.h"
24 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterface<ConvertToEmitCPatternInterface, SCFDialect>();
77  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78  InParallelOp, ReduceReturnOp>();
79  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81  ForallOp, InParallelOp, WhileOp, YieldOp>();
82  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
83 }
84 
85 /// Default callback for IfOp builders. Inserts a yield without arguments.
86 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
87  scf::YieldOp::create(builder, loc);
88 }
89 
90 /// Verifies that the first block of the given `region` is terminated by a
91 /// TerminatorTy. Reports errors on the given operation if it is not the case.
92 template <typename TerminatorTy>
93 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
94  StringRef errorMessage) {
95  Operation *terminatorOperation = nullptr;
96  if (!region.empty() && !region.front().empty()) {
97  terminatorOperation = &region.front().back();
98  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
99  return yield;
100  }
101  auto diag = op->emitOpError(errorMessage);
102  if (terminatorOperation)
103  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
104  return nullptr;
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // ExecuteRegionOp
109 //===----------------------------------------------------------------------===//
110 
111 /// Replaces the given op with the contents of the given single-block region,
112 /// using the operands of the block terminator to replace operation results.
113 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
114  Region &region, ValueRange blockArgs = {}) {
115  assert(region.hasOneBlock() && "expected single-block region");
116  Block *block = &region.front();
117  Operation *terminator = block->getTerminator();
118  ValueRange results = terminator->getOperands();
119  rewriter.inlineBlockBefore(block, op, blockArgs);
120  rewriter.replaceOp(op, results);
121  rewriter.eraseOp(terminator);
122 }
123 
124 ///
125 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
126 /// block+
127 /// `}`
128 ///
129 /// Example:
130 /// scf.execute_region -> i32 {
131 /// %idx = load %rI[%i] : memref<128xi32>
132 /// return %idx : i32
133 /// }
134 ///
135 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
136  OperationState &result) {
137  if (parser.parseOptionalArrowTypeList(result.types))
138  return failure();
139 
140  // Introduce the body region and parse it.
141  Region *body = result.addRegion();
142  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
143  parser.parseOptionalAttrDict(result.attributes))
144  return failure();
145 
146  return success();
147 }
148 
149 void ExecuteRegionOp::print(OpAsmPrinter &p) {
150  p.printOptionalArrowTypeList(getResultTypes());
151 
152  p << ' ';
153  p.printRegion(getRegion(),
154  /*printEntryBlockArgs=*/false,
155  /*printBlockTerminators=*/true);
156 
157  p.printOptionalAttrDict((*this)->getAttrs());
158 }
159 
160 LogicalResult ExecuteRegionOp::verify() {
161  if (getRegion().empty())
162  return emitOpError("region needs to have at least one block");
163  if (getRegion().front().getNumArguments() > 0)
164  return emitOpError("region cannot have any arguments");
165  return success();
166 }
167 
168 // Inline an ExecuteRegionOp if it only contains one block.
169 // "test.foo"() : () -> ()
170 // %v = scf.execute_region -> i64 {
171 // %x = "test.val"() : () -> i64
172 // scf.yield %x : i64
173 // }
174 // "test.bar"(%v) : (i64) -> ()
175 //
176 // becomes
177 //
178 // "test.foo"() : () -> ()
179 // %x = "test.val"() : () -> i64
180 // "test.bar"(%x) : (i64) -> ()
181 //
182 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
184 
185  LogicalResult matchAndRewrite(ExecuteRegionOp op,
186  PatternRewriter &rewriter) const override {
187  if (!op.getRegion().hasOneBlock())
188  return failure();
189  replaceOpWithRegion(rewriter, op, op.getRegion());
190  return success();
191  }
192 };
193 
194 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
195 // TODO generalize the conditions for operations which can be inlined into.
196 // func @func_execute_region_elim() {
197 // "test.foo"() : () -> ()
198 // %v = scf.execute_region -> i64 {
199 // %c = "test.cmp"() : () -> i1
200 // cf.cond_br %c, ^bb2, ^bb3
201 // ^bb2:
202 // %x = "test.val1"() : () -> i64
203 // cf.br ^bb4(%x : i64)
204 // ^bb3:
205 // %y = "test.val2"() : () -> i64
206 // cf.br ^bb4(%y : i64)
207 // ^bb4(%z : i64):
208 // scf.yield %z : i64
209 // }
210 // "test.bar"(%v) : (i64) -> ()
211 // return
212 // }
213 //
214 // becomes
215 //
216 // func @func_execute_region_elim() {
217 // "test.foo"() : () -> ()
218 // %c = "test.cmp"() : () -> i1
219 // cf.cond_br %c, ^bb1, ^bb2
220 // ^bb1: // pred: ^bb0
221 // %x = "test.val1"() : () -> i64
222 // cf.br ^bb3(%x : i64)
223 // ^bb2: // pred: ^bb0
224 // %y = "test.val2"() : () -> i64
225 // cf.br ^bb3(%y : i64)
226 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
227 // "test.bar"(%z) : (i64) -> ()
228 // return
229 // }
230 //
231 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
232  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
233 
234  LogicalResult matchAndRewrite(ExecuteRegionOp op,
235  PatternRewriter &rewriter) const override {
236  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
237  return failure();
238 
239  Block *prevBlock = op->getBlock();
240  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
241  rewriter.setInsertionPointToEnd(prevBlock);
242 
243  cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front());
244 
245  for (Block &blk : op.getRegion()) {
246  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247  rewriter.setInsertionPoint(yieldOp);
248  cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock,
249  yieldOp.getResults());
250  rewriter.eraseOp(yieldOp);
251  }
252  }
253 
254  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
255  SmallVector<Value> blockArgs;
256 
257  for (auto res : op.getResults())
258  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
259 
260  rewriter.replaceOp(op, blockArgs);
261  return success();
262  }
263 };
264 
265 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
266  MLIRContext *context) {
267  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
268 }
269 
270 void ExecuteRegionOp::getSuccessorRegions(
271  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
272  // If the predecessor is the ExecuteRegionOp, branch into the body.
273  if (point.isParent()) {
274  regions.push_back(RegionSuccessor(&getRegion()));
275  return;
276  }
277 
278  // Otherwise, the region branches back to the parent operation.
279  regions.push_back(RegionSuccessor(getResults()));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // ConditionOp
284 //===----------------------------------------------------------------------===//
285 
286 MutableOperandRange
287 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
288  assert((point.isParent() || point == getParentOp().getAfter()) &&
289  "condition op can only exit the loop or branch to the after "
290  "region");
291  // Pass all operands except the condition to the successor region.
292  return getArgsMutable();
293 }
294 
295 void ConditionOp::getSuccessorRegions(
296  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
297  FoldAdaptor adaptor(operands, *this);
298 
299  WhileOp whileOp = getParentOp();
300 
301  // Condition can either lead to the after region or back to the parent op
302  // depending on whether the condition is true or not.
303  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
304  if (!boolAttr || boolAttr.getValue())
305  regions.emplace_back(&whileOp.getAfter(),
306  whileOp.getAfter().getArguments());
307  if (!boolAttr || !boolAttr.getValue())
308  regions.emplace_back(whileOp.getResults());
309 }
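
// For illustration, a sketch of how scf.condition appears as the terminator of
// the `before` region of an scf.while (the "test.*" ops are hypothetical):
//
//   %res = scf.while (%arg0 = %init) : (i32) -> i32 {
//     %cond = "test.make_condition"(%arg0) : (i32) -> i1
//     scf.condition(%cond) %arg0 : i32
//   } do {
//   ^bb0(%arg1: i32):
//     %next = "test.step"(%arg1) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// When %cond is true, the `after` region receives %arg0; otherwise %arg0
// becomes the result of the enclosing scf.while.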
310 
311 //===----------------------------------------------------------------------===//
312 // ForOp
313 //===----------------------------------------------------------------------===//
314 
315 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
316  Value ub, Value step, ValueRange initArgs,
317  BodyBuilderFn bodyBuilder) {
318  OpBuilder::InsertionGuard guard(builder);
319 
320  result.addOperands({lb, ub, step});
321  result.addOperands(initArgs);
322  for (Value v : initArgs)
323  result.addTypes(v.getType());
324  Type t = lb.getType();
325  Region *bodyRegion = result.addRegion();
326  Block *bodyBlock = builder.createBlock(bodyRegion);
327  bodyBlock->addArgument(t, result.location);
328  for (Value v : initArgs)
329  bodyBlock->addArgument(v.getType(), v.getLoc());
330 
331  // Create the default terminator if the builder is not provided and if the
332  // iteration arguments are not provided. Otherwise, leave this to the caller
333  // because we don't know which values to return from the loop.
334  if (initArgs.empty() && !bodyBuilder) {
335  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
336  } else if (bodyBuilder) {
337  OpBuilder::InsertionGuard guard(builder);
338  builder.setInsertionPointToStart(bodyBlock);
339  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
340  bodyBlock->getArguments().drop_front());
341  }
342 }
343 
344 LogicalResult ForOp::verify() {
345  // Check that the number of init args and op results is the same.
346  if (getInitArgs().size() != getNumResults())
347  return emitOpError(
348  "mismatch in number of loop-carried values and defined values");
349 
350  return success();
351 }
352 
353 LogicalResult ForOp::verifyRegions() {
354  // Check that the body defines a single block argument for the induction
355  // variable.
356  if (getInductionVar().getType() != getLowerBound().getType())
357  return emitOpError(
358  "expected induction variable to be same type as bounds and step");
359 
360  if (getNumRegionIterArgs() != getNumResults())
361  return emitOpError(
362  "mismatch in number of basic block args and defined values");
363 
364  auto initArgs = getInitArgs();
365  auto iterArgs = getRegionIterArgs();
366  auto opResults = getResults();
367  unsigned i = 0;
368  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369  if (std::get<0>(e).getType() != std::get<2>(e).getType())
370  return emitOpError() << "types mismatch between " << i
371  << "th iter operand and defined value";
372  if (std::get<1>(e).getType() != std::get<2>(e).getType())
373  return emitOpError() << "types mismatch between " << i
374  << "th iter region arg and defined value";
375 
376  ++i;
377  }
378  return success();
379 }
380 
381 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
382  return SmallVector<Value>{getInductionVar()};
383 }
384 
385 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
386  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
387 }
388 
389 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
390  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
391 }
392 
393 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
394  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
395 }
396 
397 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
398 
399 /// Promotes the loop body of a forOp to its containing block if it can be
400 /// determined that the loop has a single iteration.
401 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
402  std::optional<int64_t> tripCount =
403  constantTripCount(getLowerBound(), getUpperBound(), getStep());
404  if (!tripCount.has_value() || tripCount != 1)
405  return failure();
406 
407  // Replace all results with the yielded values.
408  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
409  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
410 
411  // Replace block arguments with lower bound (replacement for IV) and
412  // iter_args.
413  SmallVector<Value> bbArgReplacements;
414  bbArgReplacements.push_back(getLowerBound());
415  llvm::append_range(bbArgReplacements, getInitArgs());
416 
417  // Move the loop body operations to the loop's containing block.
418  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
419  getOperation()->getIterator(), bbArgReplacements);
420 
421  // Erase the old terminator and the loop.
422  rewriter.eraseOp(yieldOp);
423  rewriter.eraseOp(*this);
424 
425  return success();
426 }
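
// For example (a sketch with a hypothetical "test.use" op and index constants
// %c0, %c1), a loop whose constant trip count is 1:
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%acc = %init) -> (f32) {
//     %0 = "test.use"(%i, %acc) : (index, f32) -> f32
//     scf.yield %0 : f32
//   }
//
// is promoted to
//
//   %0 = "test.use"(%c0, %init) : (index, f32) -> f32
//
// with all uses of %r replaced by %0.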
427 
428 /// Prints the initialization list in the form of
429 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
430 /// where 'inner' values are assumed to be region arguments and 'outer' values
431 /// are regular SSA values.
432 static void printInitializationList(OpAsmPrinter &p,
433  Block::BlockArgListType blocksArgs,
434  ValueRange initializers,
435  StringRef prefix = "") {
436  assert(blocksArgs.size() == initializers.size() &&
437  "expected same length of arguments and initializers");
438  if (initializers.empty())
439  return;
440 
441  p << prefix << '(';
442  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
443  p << std::get<0>(it) << " = " << std::get<1>(it);
444  });
445  p << ")";
446 }
447 
448 void ForOp::print(OpAsmPrinter &p) {
449  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
450  << getUpperBound() << " step " << getStep();
451 
452  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
453  if (!getInitArgs().empty())
454  p << " -> (" << getInitArgs().getTypes() << ')';
455  p << ' ';
456  if (Type t = getInductionVar().getType(); !t.isIndex())
457  p << " : " << t << ' ';
458  p.printRegion(getRegion(),
459  /*printEntryBlockArgs=*/false,
460  /*printBlockTerminators=*/!getInitArgs().empty());
461  p.printOptionalAttrDict((*this)->getAttrs());
462 }
463 
464 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
465  auto &builder = parser.getBuilder();
466  Type type;
467 
468  OpAsmParser::Argument inductionVariable;
469  OpAsmParser::UnresolvedOperand lb, ub, step;
470 
471  // Parse the induction variable followed by '='.
472  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
473  // Parse loop bounds.
474  parser.parseOperand(lb) || parser.parseKeyword("to") ||
475  parser.parseOperand(ub) || parser.parseKeyword("step") ||
476  parser.parseOperand(step))
477  return failure();
478 
479  // Parse the optional initial iteration arguments.
480  SmallVector<OpAsmParser::Argument, 4> regionArgs;
481  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
482  regionArgs.push_back(inductionVariable);
483 
484  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
485  if (hasIterArgs) {
486  // Parse assignment list and results type list.
487  if (parser.parseAssignmentList(regionArgs, operands) ||
488  parser.parseArrowTypeList(result.types))
489  return failure();
490  }
491 
492  if (regionArgs.size() != result.types.size() + 1)
493  return parser.emitError(
494  parser.getNameLoc(),
495  "mismatch in number of loop-carried values and defined values");
496 
497  // Parse optional type, else assume Index.
498  if (parser.parseOptionalColon())
499  type = builder.getIndexType();
500  else if (parser.parseType(type))
501  return failure();
502 
503  // Set block argument types, so that they are known when parsing the region.
504  regionArgs.front().type = type;
505  for (auto [iterArg, type] :
506  llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
507  iterArg.type = type;
508 
509  // Parse the body region.
510  Region *body = result.addRegion();
511  if (parser.parseRegion(*body, regionArgs))
512  return failure();
513  ForOp::ensureTerminator(*body, builder, result.location);
514 
515  // Resolve input operands. This should be done after parsing the region to
516  // catch invalid IR where operands were defined inside of the region.
517  if (parser.resolveOperand(lb, type, result.operands) ||
518  parser.resolveOperand(ub, type, result.operands) ||
519  parser.resolveOperand(step, type, result.operands))
520  return failure();
521  if (hasIterArgs) {
522  for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
523  operands, result.types)) {
524  Type type = std::get<2>(argOperandType);
525  std::get<0>(argOperandType).type = type;
526  if (parser.resolveOperand(std::get<1>(argOperandType), type,
527  result.operands))
528  return failure();
529  }
530  }
531 
532  // Parse the optional attribute list.
533  if (parser.parseOptionalAttrDict(result.attributes))
534  return failure();
535 
536  return success();
537 }
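
// For reference, a hand-written sketch of the syntax handled by the
// print/parse pair above (names such as %lb, %init are placeholders):
//
//   %sum = scf.for %iv = %lb to %ub step %step
//       iter_args(%acc = %init) -> (f32) {
//     %next = "test.accumulate"(%iv, %acc) : (index, f32) -> f32
//     scf.yield %next : f32
//   }
//
// and, with a non-index induction variable type:
//
//   scf.for %iv = %lb32 to %ub32 step %step32 : i32 {
//     "test.payload"(%iv) : (i32) -> ()
//   }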
538 
539 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
540 
541 Block::BlockArgListType ForOp::getRegionIterArgs() {
542  return getBody()->getArguments().drop_front(getNumInductionVars());
543 }
544 
545 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
546  return getInitArgsMutable();
547 }
548 
549 FailureOr<LoopLikeOpInterface>
550 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
551  ValueRange newInitOperands,
552  bool replaceInitOperandUsesInLoop,
553  const NewYieldValuesFn &newYieldValuesFn) {
554  // Create a new loop before the existing one, with the extra operands.
555  OpBuilder::InsertionGuard g(rewriter);
556  rewriter.setInsertionPoint(getOperation());
557  auto inits = llvm::to_vector(getInitArgs());
558  inits.append(newInitOperands.begin(), newInitOperands.end());
559  scf::ForOp newLoop = scf::ForOp::create(
560  rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
561  [](OpBuilder &, Location, Value, ValueRange) {});
562  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
563 
564  // Generate the new yield values and append them to the scf.yield operation.
565  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
566  ArrayRef<BlockArgument> newIterArgs =
567  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
568  {
569  OpBuilder::InsertionGuard g(rewriter);
570  rewriter.setInsertionPoint(yieldOp);
571  SmallVector<Value> newYieldedValues =
572  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
573  assert(newInitOperands.size() == newYieldedValues.size() &&
574  "expected as many new yield values as new iter operands");
575  rewriter.modifyOpInPlace(yieldOp, [&]() {
576  yieldOp.getResultsMutable().append(newYieldedValues);
577  });
578  }
579 
580  // Move the loop body to the new op.
581  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
582  newLoop.getBody()->getArguments().take_front(
583  getBody()->getNumArguments()));
584 
585  if (replaceInitOperandUsesInLoop) {
586  // Replace all uses of `newInitOperands` with the corresponding basic block
587  // arguments.
588  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
589  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
590  [&](OpOperand &use) {
591  Operation *user = use.getOwner();
592  return newLoop->isProperAncestor(user);
593  });
594  }
595  }
596 
597  // Replace the old loop.
598  rewriter.replaceOp(getOperation(),
599  newLoop->getResults().take_front(getNumResults()));
600  return cast<LoopLikeOpInterface>(newLoop.getOperation());
601 }
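
// For example (a sketch), adding one more loop-carried value to
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) { ... }
//
// rebuilds it as
//
//   %r:2 = scf.for %i = %lb to %ub step %s
//       iter_args(%a = %x, %n = %newInit) -> (f32, f32) { ... }
//
// where the value yielded for %n is produced by `newYieldValuesFn` and the
// original results map to the leading results of the new loop.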
602 
603 ForOp mlir::scf::getForInductionVarOwner(Value val) {
604  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
605  if (!ivArg)
606  return ForOp();
607  assert(ivArg.getOwner() && "unlinked block argument");
608  auto *containingOp = ivArg.getOwner()->getParentOp();
609  return dyn_cast_or_null<ForOp>(containingOp);
610 }
611 
612 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
613  return getInitArgs();
614 }
615 
616 void ForOp::getSuccessorRegions(RegionBranchPoint point,
617  SmallVectorImpl<RegionSuccessor> &regions) {
618  // Both the operation itself and the region may be branching into the body
619  // or back into the operation itself. It is possible for the loop not to
620  // enter the body.
621  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
622  regions.push_back(RegionSuccessor(getResults()));
623 }
624 
625 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
626 
627 /// Promotes the loop body of a forallOp to its containing block if it can be
628 /// determined that the loop has a single iteration.
629 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
630  for (auto [lb, ub, step] :
631  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
632  auto tripCount = constantTripCount(lb, ub, step);
633  if (!tripCount.has_value() || *tripCount != 1)
634  return failure();
635  }
636 
637  promote(rewriter, *this);
638  return success();
639 }
640 
641 Block::BlockArgListType ForallOp::getRegionIterArgs() {
642  return getBody()->getArguments().drop_front(getRank());
643 }
644 
645 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
646  return getOutputsMutable();
647 }
648 
649 /// Promotes the loop body of a scf::ForallOp to its containing block.
650 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
651  OpBuilder::InsertionGuard g(rewriter);
652  scf::InParallelOp terminator = forallOp.getTerminator();
653 
654  // Replace block arguments with lower bounds (replacements for IVs) and
655  // outputs.
656  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
657  bbArgReplacements.append(forallOp.getOutputs().begin(),
658  forallOp.getOutputs().end());
659 
660  // Move the loop body operations to the loop's containing block.
661  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
662  forallOp->getIterator(), bbArgReplacements);
663 
664  // Replace the terminator with tensor.insert_slice ops.
665  rewriter.setInsertionPointAfter(forallOp);
666  SmallVector<Value> results;
667  results.reserve(forallOp.getResults().size());
668  for (auto &yieldingOp : terminator.getYieldingOps()) {
669  auto parallelInsertSliceOp =
670  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
671 
672  Value dst = parallelInsertSliceOp.getDest();
673  Value src = parallelInsertSliceOp.getSource();
674  if (llvm::isa<TensorType>(src.getType())) {
675  results.push_back(tensor::InsertSliceOp::create(
676  rewriter, forallOp.getLoc(), dst.getType(), src, dst,
677  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
678  parallelInsertSliceOp.getStrides(),
679  parallelInsertSliceOp.getStaticOffsets(),
680  parallelInsertSliceOp.getStaticSizes(),
681  parallelInsertSliceOp.getStaticStrides()));
682  } else {
683  llvm_unreachable("unsupported terminator");
684  }
685  }
686  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
687 
688  // Erase the old terminator and the loop.
689  rewriter.eraseOp(terminator);
690  rewriter.eraseOp(forallOp);
691 }
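
// For example (a sketch with a hypothetical "test.compute" op), promoting
//
//   %res = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<8xf32>) {
//     %s = "test.compute"() : () -> tensor<8xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %s into %o[0] [8] [1]
//           : tensor<8xf32> into tensor<8xf32>
//     }
//   }
//
// yields
//
//   %s = "test.compute"() : () -> tensor<8xf32>
//   %res = tensor.insert_slice %s into %t[0] [8] [1]
//       : tensor<8xf32> into tensor<8xf32>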
692 
693 LoopNest mlir::scf::buildLoopNest(
694  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
695  ValueRange steps, ValueRange iterArgs,
696  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
697  bodyBuilder) {
698  assert(lbs.size() == ubs.size() &&
699  "expected the same number of lower and upper bounds");
700  assert(lbs.size() == steps.size() &&
701  "expected the same number of lower bounds and steps");
702 
703  // If there are no bounds, call the body-building function and return early.
704  if (lbs.empty()) {
705  ValueVector results =
706  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
707  : ValueVector();
708  assert(results.size() == iterArgs.size() &&
709  "loop nest body must return as many values as loop has iteration "
710  "arguments");
711  return LoopNest{{}, std::move(results)};
712  }
713 
714  // First, create the loop structure iteratively using the body-builder
715  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
716  OpBuilder::InsertionGuard guard(builder);
717  SmallVector<scf::ForOp, 4> loops;
718  SmallVector<Value, 4> ivs;
719  loops.reserve(lbs.size());
720  ivs.reserve(lbs.size());
721  ValueRange currentIterArgs = iterArgs;
722  Location currentLoc = loc;
723  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
724  auto loop = scf::ForOp::create(
725  builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
726  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
727  ValueRange args) {
728  ivs.push_back(iv);
729  // It is safe to store ValueRange args because it points to block
730  // arguments of a loop operation that we also own.
731  currentIterArgs = args;
732  currentLoc = nestedLoc;
733  });
734  // Set the builder to point to the body of the newly created loop. We don't
735  // do this in the callback because the builder is reset when the callback
736  // returns.
737  builder.setInsertionPointToStart(loop.getBody());
738  loops.push_back(loop);
739  }
740 
741  // For all loops but the innermost, yield the results of the nested loop.
742  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
743  builder.setInsertionPointToEnd(loops[i].getBody());
744  scf::YieldOp::create(builder, loc, loops[i + 1].getResults());
745  }
746 
747  // In the body of the innermost loop, call the body building function if any
748  // and yield its results.
749  builder.setInsertionPointToStart(loops.back().getBody());
750  ValueVector results = bodyBuilder
751  ? bodyBuilder(builder, currentLoc, ivs,
752  loops.back().getRegionIterArgs())
753  : ValueVector();
754  assert(results.size() == iterArgs.size() &&
755  "loop nest body must return as many values as loop has iteration "
756  "arguments");
757  builder.setInsertionPointToEnd(loops.back().getBody());
758  scf::YieldOp::create(builder, loc, results);
759 
760  // Return the loops.
761  ValueVector nestResults;
762  llvm::append_range(nestResults, loops.front().getResults());
763  return LoopNest{std::move(loops), std::move(nestResults)};
764 }
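
// A minimal usage sketch (hypothetical caller code, not part of this file):
// build a 2-D nest that carries one running value through both loops. The
// names lb0/ub0/step0/init are assumptions of the sketch.
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, ValueRange{lb0, lb1}, ValueRange{ub0, ub1},
//       ValueRange{step0, step1}, ValueRange{init},
//       [](OpBuilder &b, Location nestedLoc, ValueRange ivs,
//          ValueRange iterArgs) -> scf::ValueVector {
//         Value updated =
//             arith::AddIOp::create(b, nestedLoc, iterArgs[0], ivs[0]);
//         return {updated};
//       });
//   Value result = nest.results.front();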
765 
766 LoopNest mlir::scf::buildLoopNest(
767  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
768  ValueRange steps,
769  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
770  // Delegate to the main function by wrapping the body builder.
771  return buildLoopNest(builder, loc, lbs, ubs, steps, {},
772  [&bodyBuilder](OpBuilder &nestedBuilder,
773  Location nestedLoc, ValueRange ivs,
774  ValueRange) -> ValueVector {
775  if (bodyBuilder)
776  bodyBuilder(nestedBuilder, nestedLoc, ivs);
777  return {};
778  });
779 }
780 
781 SmallVector<Value>
782 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
783  OpOperand &operand, Value replacement,
784  const ValueTypeCastFnTy &castFn) {
785  assert(operand.getOwner() == forOp);
786  Type oldType = operand.get().getType(), newType = replacement.getType();
787 
788  // 1. Create new iter operands, exactly 1 is replaced.
789  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
790  "expected an iter OpOperand");
791  assert(operand.get().getType() != replacement.getType() &&
792  "Expected a different type");
793  SmallVector<Value> newIterOperands;
794  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
795  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
796  newIterOperands.push_back(replacement);
797  continue;
798  }
799  newIterOperands.push_back(opOperand.get());
800  }
801 
802  // 2. Create the new forOp shell.
803  scf::ForOp newForOp = scf::ForOp::create(
804  rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
805  forOp.getStep(), newIterOperands);
806  newForOp->setAttrs(forOp->getAttrs());
807  Block &newBlock = newForOp.getRegion().front();
808  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
809  newBlock.getArguments().end());
810 
811  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
812  // corresponding to the `replacement` value.
813  OpBuilder::InsertionGuard g(rewriter);
814  rewriter.setInsertionPointToStart(&newBlock);
815  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
816  &newForOp->getOpOperand(operand.getOperandNumber()));
817  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
818  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
819 
820  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
821  Block &oldBlock = forOp.getRegion().front();
822  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
823 
824  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
825  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
826  rewriter.setInsertionPoint(clonedYieldOp);
827  unsigned yieldIdx =
828  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
829  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
830  clonedYieldOp.getOperand(yieldIdx));
831  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
832  newYieldOperands[yieldIdx] = castOut;
833  scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands);
834  rewriter.eraseOp(clonedYieldOp);
835 
836  // 6. Inject an outgoing cast op after the forOp.
837  rewriter.setInsertionPointAfter(newForOp);
838  SmallVector<Value> newResults = newForOp.getResults();
839  newResults[yieldIdx] =
840  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
841 
842  return newResults;
843 }
844 
845 namespace {
846 // Fold away ForOp iter arguments when:
847 // 1) The op yields the iter arguments.
848 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
849 // 3) The iter arguments have no use and the corresponding (operation) results
850 // have no use.
851 //
852 // These arguments must be defined outside of the ForOp region and can just be
853 // forwarded after simplifying the op inits, yields and returns.
854 //
855 // The implementation uses `inlineBlockBefore` to steal the content of the
856 // original ForOp and avoid cloning.
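//
// For example (a sketch), the second iter argument below is yielded unchanged:
//
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%a = %x, %b = %y)
//       -> (f32, f32) {
//     %v = "test.compute"(%a) : (f32) -> f32
//     scf.yield %v, %b : f32, f32
//   }
//
// so the pattern folds it away:
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %x) -> (f32) {
//     %v = "test.compute"(%a) : (f32) -> f32
//     scf.yield %v : f32
//   }
//
// and replaces all uses of the dropped result with %y.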
857 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
858  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
859 
860  LogicalResult matchAndRewrite(scf::ForOp forOp,
861  PatternRewriter &rewriter) const final {
862  bool canonicalize = false;
863 
864  // An internal flat vector of block transfer
865  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
866  // transformed block arguments. This plays the role of an
867  // IRMapping for the particular use case of calling into
868  // `inlineBlockBefore`.
869  int64_t numResults = forOp.getNumResults();
870  SmallVector<bool, 4> keepMask;
871  keepMask.reserve(numResults);
872  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
873  newResultValues;
874  newBlockTransferArgs.reserve(1 + numResults);
875  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
876  newIterArgs.reserve(forOp.getInitArgs().size());
877  newYieldValues.reserve(numResults);
878  newResultValues.reserve(numResults);
879  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
880  for (auto [init, arg, result, yielded] :
881  llvm::zip(forOp.getInitArgs(), // iter from outside
882  forOp.getRegionIterArgs(), // iter inside region
883  forOp.getResults(), // op results
884  forOp.getYieldedValues() // iter yield
885  )) {
886  // Forwarded is `true` when:
887  // 1) The region `iter` argument is yielded.
888  // 2) The corresponding input (init value) is yielded.
889  // 3) The region `iter` argument has no use, and the corresponding op
890  // result has no use.
891  bool forwarded = (arg == yielded) || (init == yielded) ||
892  (arg.use_empty() && result.use_empty());
893  if (forwarded) {
894  canonicalize = true;
895  keepMask.push_back(false);
896  newBlockTransferArgs.push_back(init);
897  newResultValues.push_back(init);
898  continue;
899  }
900 
901  // Check if a previous kept argument always has the same values for init
902  // and yielded values.
903  if (auto it = initYieldToArg.find({init, yielded});
904  it != initYieldToArg.end()) {
905  canonicalize = true;
906  keepMask.push_back(false);
907  auto [sameArg, sameResult] = it->second;
908  rewriter.replaceAllUsesWith(arg, sameArg);
909  rewriter.replaceAllUsesWith(result, sameResult);
910  // The replacement value doesn't matter because there are no uses.
911  newBlockTransferArgs.push_back(init);
912  newResultValues.push_back(init);
913  continue;
914  }
915 
916  // This value is kept.
917  initYieldToArg.insert({{init, yielded}, {arg, result}});
918  keepMask.push_back(true);
919  newIterArgs.push_back(init);
920  newYieldValues.push_back(yielded);
921  newBlockTransferArgs.push_back(Value()); // placeholder with null value
922  newResultValues.push_back(Value()); // placeholder with null value
923  }
924 
925  if (!canonicalize)
926  return failure();
927 
928  scf::ForOp newForOp =
929  scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
930  forOp.getUpperBound(), forOp.getStep(), newIterArgs);
931  newForOp->setAttrs(forOp->getAttrs());
932  Block &newBlock = newForOp.getRegion().front();
933 
934  // Replace the null placeholders with newly constructed values.
935  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
936  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
937  idx != e; ++idx) {
938  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
939  Value &newResultVal = newResultValues[idx];
940  assert((blockTransferArg && newResultVal) ||
941  (!blockTransferArg && !newResultVal));
942  if (!blockTransferArg) {
943  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
944  newResultVal = newForOp.getResult(collapsedIdx++);
945  }
946  }
947 
948  Block &oldBlock = forOp.getRegion().front();
949  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
950  "unexpected argument size mismatch");
951 
952  // No results case: the scf::ForOp builder already created a zero
953  // result terminator. Merge before this terminator and just get rid of the
954  // original terminator that has been merged in.
955  if (newIterArgs.empty()) {
956  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
957  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
958  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
959  rewriter.replaceOp(forOp, newResultValues);
960  return success();
961  }
962 
963  // No terminator case: merge and rewrite the merged terminator.
964  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
965  OpBuilder::InsertionGuard g(rewriter);
966  rewriter.setInsertionPoint(mergedTerminator);
967  SmallVector<Value, 4> filteredOperands;
968  filteredOperands.reserve(newResultValues.size());
969  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
970  if (keepMask[idx])
971  filteredOperands.push_back(mergedTerminator.getOperand(idx));
972  scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
973  filteredOperands);
974  };
975 
976  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
977  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
978  cloneFilteredTerminator(mergedYieldOp);
979  rewriter.eraseOp(mergedYieldOp);
980  rewriter.replaceOp(forOp, newResultValues);
981  return success();
982  }
983 };
984 
985 /// Utility function that tries to compute a constant difference between u
986 /// and l. Returns std::nullopt when the difference cannot be determined
987 /// statically.
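/// For example, constant bounds %l = arith.constant 2 and %u = arith.constant
/// 10 give 8, and %u = arith.addi %l, %c4 (with %c4 = arith.constant 4) gives
/// 4; any other relationship between the two values returns std::nullopt.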
988 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
989  IntegerAttr clb, cub;
990  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
991  llvm::APInt lbValue = clb.getValue();
992  llvm::APInt ubValue = cub.getValue();
993  return (ubValue - lbValue).getSExtValue();
994  }
995 
996  // Else a simple pattern match for x + c or c + x
997  llvm::APInt diff;
998  if (matchPattern(
999  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
1000  matchPattern(
1001  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1002  return diff.getSExtValue();
1003  return std::nullopt;
1004 }
1005 
1006 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1007 /// single-iteration loops with their bodies, and removes empty loops that
1008 /// iterate at least once and only return values defined outside of the loop.
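/// For example (a sketch), a loop with equal bounds folds to its init values,
/// and a loop such as
///   %r = scf.for %i = %c0 to %c4 step %c8 iter_args(%a = %init) -> (f32) {
///     %v = "test.compute"(%i, %a) : (index, f32) -> f32
///     scf.yield %v : f32
///   }
/// runs exactly once, so its body is inlined with %i replaced by %c0 and %a by
/// %init, and %r is replaced by %v.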
1009 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1010  using OpRewritePattern<ForOp>::OpRewritePattern;
1011 
1012  LogicalResult matchAndRewrite(ForOp op,
1013  PatternRewriter &rewriter) const override {
1014  // If the upper bound is the same as the lower bound, the loop does not
1015  // iterate, just remove it.
1016  if (op.getLowerBound() == op.getUpperBound()) {
1017  rewriter.replaceOp(op, op.getInitArgs());
1018  return success();
1019  }
1020 
1021  std::optional<int64_t> diff =
1022  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1023  if (!diff)
1024  return failure();
1025 
1026  // If the loop is known to have 0 iterations, remove it.
1027  if (*diff <= 0) {
1028  rewriter.replaceOp(op, op.getInitArgs());
1029  return success();
1030  }
1031 
1032  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1033  if (!maybeStepValue)
1034  return failure();
1035 
1036  // If the loop is known to have 1 iteration, inline its body and remove the
1037  // loop.
1038  llvm::APInt stepValue = *maybeStepValue;
1039  if (stepValue.sge(*diff)) {
1040  SmallVector<Value, 4> blockArgs;
1041  blockArgs.reserve(op.getInitArgs().size() + 1);
1042  blockArgs.push_back(op.getLowerBound());
1043  llvm::append_range(blockArgs, op.getInitArgs());
1044  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1045  return success();
1046  }
1047 
1048  // Now we are left with loops that have more than one iteration.
1049  Block &block = op.getRegion().front();
1050  if (!llvm::hasSingleElement(block))
1051  return failure();
1052  // If the loop is empty, iterates at least once, and only returns values
1053  // defined outside of the loop, remove it and replace it with yield values.
1054  if (llvm::any_of(op.getYieldedValues(),
1055  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1056  return failure();
1057  rewriter.replaceOp(op, op.getYieldedValues());
1058  return success();
1059  }
1060 };
1061 
1062 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1063 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1064 ///
1065 /// ```
1066 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1067 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1068 /// -> (tensor<?x?xf32>) {
1069 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1070 /// scf.yield %2 : tensor<?x?xf32>
1071 /// }
1072 /// use_of(%1)
1073 /// ```
1074 ///
1075 /// folds into:
1076 ///
1077 /// ```
1078 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1079 /// -> (tensor<32x1024xf32>) {
1080 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1081 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1082 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1083 /// scf.yield %4 : tensor<32x1024xf32>
1084 /// }
1085 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1086 /// use_of(%1)
1087 /// ```
1088 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1089  using OpRewritePattern<ForOp>::OpRewritePattern;
1090 
1091  LogicalResult matchAndRewrite(ForOp op,
1092  PatternRewriter &rewriter) const override {
1093  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1094  OpOperand &iterOpOperand = std::get<0>(it);
1095  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1096  if (!incomingCast ||
1097  incomingCast.getSource().getType() == incomingCast.getType())
1098  continue;
1099  // If the dest type of the cast does not preserve static information in
1100  // the source type, skip it.
1101  if (!tensor::preservesStaticInformation(
1102  incomingCast.getDest().getType(),
1103  incomingCast.getSource().getType()))
1104  continue;
1105  if (!std::get<1>(it).hasOneUse())
1106  continue;
1107 
1108  // Create a new ForOp with that iter operand replaced.
1109  rewriter.replaceOp(
1110  op, replaceAndCastForOpIterArg(
1111  rewriter, op, iterOpOperand, incomingCast.getSource(),
1112  [](OpBuilder &b, Location loc, Type type, Value source) {
1113  return tensor::CastOp::create(b, loc, type, source);
1114  }));
1115  return success();
1116  }
1117  return failure();
1118  }
1119 };
1120 
1121 } // namespace
1122 
1123 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1124  MLIRContext *context) {
1125  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1126  context);
1127 }
1128 
1129 std::optional<APInt> ForOp::getConstantStep() {
1130  IntegerAttr step;
1131  if (matchPattern(getStep(), m_Constant(&step)))
1132  return step.getValue();
1133  return {};
1134 }
1135 
1136 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1137  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1138 }
1139 
1140 Speculation::Speculatability ForOp::getSpeculatability() {
1141  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1142  // and End.
1143  if (auto constantStep = getConstantStep())
1144  if (*constantStep == 1)
1145  return Speculation::RecursivelySpeculatable;
1146 
1147  // For Step != 1, the loop may not terminate. We can add more smarts here if
1148  // needed.
1149  return Speculation::NotSpeculatable;
1150 }
1151 
1152 //===----------------------------------------------------------------------===//
1153 // ForallOp
1154 //===----------------------------------------------------------------------===//
1155 
1156 LogicalResult ForallOp::verify() {
1157  unsigned numLoops = getRank();
1158  // Check number of outputs.
1159  if (getNumResults() != getOutputs().size())
1160  return emitOpError("produces ")
1161  << getNumResults() << " results, but has only "
1162  << getOutputs().size() << " outputs";
1163 
1164  // Check that the body defines block arguments for thread indices and outputs.
1165  auto *body = getBody();
1166  if (body->getNumArguments() != numLoops + getOutputs().size())
1167  return emitOpError("region expects ") << numLoops << " arguments";
1168  for (int64_t i = 0; i < numLoops; ++i)
1169  if (!body->getArgument(i).getType().isIndex())
1170  return emitOpError("expects ")
1171  << i << "-th block argument to be an index";
1172  for (unsigned i = 0; i < getOutputs().size(); ++i)
1173  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1174  return emitOpError("type mismatch between ")
1175  << i << "-th output and corresponding block argument";
1176  if (getMapping().has_value() && !getMapping()->empty()) {
1177  if (getDeviceMappingAttrs().size() != numLoops)
1178  return emitOpError() << "mapping attribute size must match op rank";
1179  if (failed(getDeviceMaskingAttr()))
1180  return emitOpError() << getMappingAttrName()
1181  << " supports at most one device masking attribute";
1182  }
1183 
1184  // Verify mixed static/dynamic control variables.
1185  Operation *op = getOperation();
1186  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1187  getStaticLowerBound(),
1188  getDynamicLowerBound())))
1189  return failure();
1190  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1191  getStaticUpperBound(),
1192  getDynamicUpperBound())))
1193  return failure();
1194  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1195  getStaticStep(), getDynamicStep())))
1196  return failure();
1197 
1198  return success();
1199 }
1200 
1201 void ForallOp::print(OpAsmPrinter &p) {
1202  Operation *op = getOperation();
1203  p << " (" << getInductionVars();
1204  if (isNormalized()) {
1205  p << ") in ";
1206  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1207  /*valueTypes=*/{}, /*scalables=*/{},
1208  OpAsmParser::Delimiter::Paren);
1209  } else {
1210  p << ") = ";
1211  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1212  /*valueTypes=*/{}, /*scalables=*/{},
1213  OpAsmParser::Delimiter::Paren);
1214  p << " to ";
1215  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1216  /*valueTypes=*/{}, /*scalables=*/{},
1217  OpAsmParser::Delimiter::Paren);
1218  p << " step ";
1219  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1220  /*valueTypes=*/{}, /*scalables=*/{},
1221  OpAsmParser::Delimiter::Paren);
1222  }
1223  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1224  p << " ";
1225  if (!getRegionOutArgs().empty())
1226  p << "-> (" << getResultTypes() << ") ";
1227  p.printRegion(getRegion(),
1228  /*printEntryBlockArgs=*/false,
1229  /*printBlockTerminators=*/getNumResults() > 0);
1230  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1231  getStaticLowerBoundAttrName(),
1232  getStaticUpperBoundAttrName(),
1233  getStaticStepAttrName()});
1234 }
1235 
1236 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1237  OpBuilder b(parser.getContext());
1238  auto indexType = b.getIndexType();
1239 
1240  // Parse an opening `(` followed by thread index variables followed by `)`
1241  // TODO: when we can refer to such "induction variable"-like handles from the
1242  // declarative assembly format, we can implement the parser as a custom hook.
1243  SmallVector<OpAsmParser::Argument, 4> ivs;
1244  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1245  return failure();
1246 
1247  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1248  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1249  dynamicSteps;
1250  if (succeeded(parser.parseOptionalKeyword("in"))) {
1251  // Parse upper bounds.
1252  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1253  /*valueTypes=*/nullptr,
1254  OpAsmParser::Delimiter::Paren) ||
1255  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1256  return failure();
1257 
1258  unsigned numLoops = ivs.size();
1259  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1260  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1261  } else {
1262  // Parse lower bounds.
1263  if (parser.parseEqual() ||
1264  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1265  /*valueTypes=*/nullptr,
1266  OpAsmParser::Delimiter::Paren) ||
1267 
1268  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1269  return failure();
1270 
1271  // Parse upper bounds.
1272  if (parser.parseKeyword("to") ||
1273  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1274  /*valueTypes=*/nullptr,
1275  OpAsmParser::Delimiter::Paren) ||
1276  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1277  return failure();
1278 
1279  // Parse step values.
1280  if (parser.parseKeyword("step") ||
1281  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1282  /*valueTypes=*/nullptr,
1283  OpAsmParser::Delimiter::Paren) ||
1284  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1285  return failure();
1286  }
1287 
1288  // Parse out operands and results.
1289  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1290  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1291  SMLoc outOperandsLoc = parser.getCurrentLocation();
1292  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1293  if (outOperands.size() != result.types.size())
1294  return parser.emitError(outOperandsLoc,
1295  "mismatch between out operands and types");
1296  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1297  parser.parseOptionalArrowTypeList(result.types) ||
1298  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1299  result.operands))
1300  return failure();
1301  }
1302 
1303  // Parse region.
1304  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1305  std::unique_ptr<Region> region = std::make_unique<Region>();
1306  for (auto &iv : ivs) {
1307  iv.type = b.getIndexType();
1308  regionArgs.push_back(iv);
1309  }
1310  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1311  auto &out = it.value();
1312  out.type = result.types[it.index()];
1313  regionArgs.push_back(out);
1314  }
1315  if (parser.parseRegion(*region, regionArgs))
1316  return failure();
1317 
1318  // Ensure terminator and move region.
1319  ForallOp::ensureTerminator(*region, b, result.location);
1320  result.addRegion(std::move(region));
1321 
1322  // Parse the optional attribute list.
1323  if (parser.parseOptionalAttrDict(result.attributes))
1324  return failure();
1325 
1326  result.addAttribute("staticLowerBound", staticLbs);
1327  result.addAttribute("staticUpperBound", staticUbs);
1328  result.addAttribute("staticStep", staticSteps);
1329  result.addAttribute("operandSegmentSizes",
1330  b.getDenseI32ArrayAttr(
1331  {static_cast<int32_t>(dynamicLbs.size()),
1332  static_cast<int32_t>(dynamicUbs.size()),
1333  static_cast<int32_t>(dynamicSteps.size()),
1334  static_cast<int32_t>(outOperands.size())}));
1335  return success();
1336 }
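
// For reference, a hand-written sketch of the two accepted forms (normalized
// bounds and explicit lower bound / upper bound / step):
//
//   %0 = scf.forall (%i, %j) in (8, 16) shared_outs(%o = %t)
//       -> (tensor<8x16xf32>) {
//     ...
//     scf.forall.in_parallel { ... }
//   }
//
//   %1 = scf.forall (%i) = (%lb) to (%ub) step (%s) shared_outs(%o = %t)
//       -> (tensor<?xf32>) {
//     ...
//     scf.forall.in_parallel { ... }
//   }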
1337 
1338 // Builder that takes loop bounds.
1339 void ForallOp::build(
1340  OpBuilder &b, OperationState &result, ArrayRef<OpFoldResult> lbs,
1341  ArrayRef<OpFoldResult> ubs,
1342  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1343  std::optional<ArrayAttr> mapping,
1344  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1345  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1346  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1347  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1348  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1349  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1350 
1351  result.addOperands(dynamicLbs);
1352  result.addOperands(dynamicUbs);
1353  result.addOperands(dynamicSteps);
1354  result.addOperands(outputs);
1355  result.addTypes(TypeRange(outputs));
1356 
1357  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1358  b.getDenseI64ArrayAttr(staticLbs));
1359  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1360  b.getDenseI64ArrayAttr(staticUbs));
1361  result.addAttribute(getStaticStepAttrName(result.name),
1362  b.getDenseI64ArrayAttr(staticSteps));
1363  result.addAttribute(
1364  "operandSegmentSizes",
1365  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1366  static_cast<int32_t>(dynamicUbs.size()),
1367  static_cast<int32_t>(dynamicSteps.size()),
1368  static_cast<int32_t>(outputs.size())}));
1369  if (mapping.has_value()) {
1370  result.addAttribute(getMappingAttrName(result.name),
1371  mapping.value());
1372  }
1373 
1374  Region *bodyRegion = result.addRegion();
1375  OpBuilder::InsertionGuard g(b);
1376  b.createBlock(bodyRegion);
1377  Block &bodyBlock = bodyRegion->front();
1378 
1379  // Add block arguments for indices and outputs.
1380  bodyBlock.addArguments(
1381  SmallVector<Type>(lbs.size(), b.getIndexType()),
1382  SmallVector<Location>(staticLbs.size(), result.location));
1383  bodyBlock.addArguments(
1384  TypeRange(outputs),
1385  SmallVector<Location>(outputs.size(), result.location));
1386 
1387  b.setInsertionPointToStart(&bodyBlock);
1388  if (!bodyBuilderFn) {
1389  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1390  return;
1391  }
1392  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1393 }
1394 
1395 // Builder that takes loop bounds.
1396 void ForallOp::build(
1397  OpBuilder &b, OperationState &result,
1398  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1399  std::optional<ArrayAttr> mapping,
1400  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1401  unsigned numLoops = ubs.size();
1402  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1403  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1404  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1405 }
1406 
1407 // Checks if the lbs are zeros and steps are ones.
1408 bool ForallOp::isNormalized() {
1409  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1410  return llvm::all_of(results, [&](OpFoldResult ofr) {
1411  auto intValue = getConstantIntValue(ofr);
1412  return intValue.has_value() && intValue == val;
1413  });
1414  };
1415  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1416 }
1417 
1418 InParallelOp ForallOp::getTerminator() {
1419  return cast<InParallelOp>(getBody()->getTerminator());
1420 }
1421 
1422 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1423  SmallVector<Operation *> storeOps;
1424  InParallelOp inParallelOp = getTerminator();
1425  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1426  if (auto parallelInsertSliceOp =
1427  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1428  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1429  storeOps.push_back(parallelInsertSliceOp);
1430  }
1431  }
1432  return storeOps;
1433 }
1434 
1435 SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() {
1436  SmallVector<DeviceMappingAttrInterface> res;
1437  if (!getMapping())
1438  return res;
1439  for (auto attr : getMapping()->getValue()) {
1440  auto m = dyn_cast<DeviceMappingAttrInterface>(attr);
1441  if (m)
1442  res.push_back(m);
1443  }
1444  return res;
1445 }
1446 
1447 FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() {
1448  DeviceMaskingAttrInterface res;
1449  if (!getMapping())
1450  return res;
1451  for (auto attr : getMapping()->getValue()) {
1452  auto m = dyn_cast<DeviceMaskingAttrInterface>(attr);
1453  if (m && res)
1454  return failure();
1455  if (m)
1456  res = m;
1457  }
1458  return res;
1459 }
1460 
1461 bool ForallOp::usesLinearMapping() {
1462  SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs();
1463  if (ifaces.empty())
1464  return false;
1465  return ifaces.front().isLinearMapping();
1466 }
1467 
1468 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1469  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1470 }
1471 
1472 // Get lower bounds as OpFoldResult.
1473 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1474  Builder b(getOperation()->getContext());
1475  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1476 }
1477 
1478 // Get upper bounds as OpFoldResult.
1479 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1480  Builder b(getOperation()->getContext());
1481  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1482 }
1483 
1484 // Get steps as OpFoldResult.
1485 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1486  Builder b(getOperation()->getContext());
1487  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1488 }
1489 
1490 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1491  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1492  if (!tidxArg)
1493  return ForallOp();
1494  assert(tidxArg.getOwner() && "unlinked block argument");
1495  auto *containingOp = tidxArg.getOwner()->getParentOp();
1496  return dyn_cast<ForallOp>(containingOp);
1497 }
1498 
1499 namespace {
1500 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1501 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1502  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1503 
1504  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1505  PatternRewriter &rewriter) const final {
1506  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1507  if (!forallOp)
1508  return failure();
1509  Value sharedOut =
1510  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1511  ->get();
1512  rewriter.modifyOpInPlace(
1513  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1514  return success();
1515  }
1516 };
1517 
1518 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1519 public:
1520  using OpRewritePattern<ForallOp>::OpRewritePattern;
1521 
1522  LogicalResult matchAndRewrite(ForallOp op,
1523  PatternRewriter &rewriter) const override {
1524  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1525  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1526  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1527  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1528  failed(foldDynamicIndexList(mixedUpperBound)) &&
1529  failed(foldDynamicIndexList(mixedStep)))
1530  return failure();
1531 
1532  rewriter.modifyOpInPlace(op, [&]() {
1533  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1534  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1535  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1536  staticLowerBound);
1537  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1538  op.setStaticLowerBound(staticLowerBound);
1539 
1540  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1541  staticUpperBound);
1542  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1543  op.setStaticUpperBound(staticUpperBound);
1544 
1545  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1546  op.getDynamicStepMutable().assign(dynamicStep);
1547  op.setStaticStep(staticStep);
1548 
1549  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1550  rewriter.getDenseI32ArrayAttr(
1551  {static_cast<int32_t>(dynamicLowerBound.size()),
1552  static_cast<int32_t>(dynamicUpperBound.size()),
1553  static_cast<int32_t>(dynamicStep.size()),
1554  static_cast<int32_t>(op.getNumResults())}));
1555  });
1556  return success();
1557  }
1558 };
1559 
1560 /// The following canonicalization pattern folds the iter arguments of an
1561 /// scf.forall op if :-
1562 /// 1. The corresponding result has zero uses, or
1563 /// 2. The iter argument is NOT being modified within the loop body.
1564 ///
1565 ///
1566 /// Example of first case :-
1567 /// INPUT:
1568 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1569 /// {
1570 /// ...
1571 /// <SOME USE OF %arg0>
1572 /// <SOME USE OF %arg1>
1573 /// <SOME USE OF %arg2>
1574 /// ...
1575 /// scf.forall.in_parallel {
1576 /// <STORE OP WITH DESTINATION %arg1>
1577 /// <STORE OP WITH DESTINATION %arg0>
1578 /// <STORE OP WITH DESTINATION %arg2>
1579 /// }
1580 /// }
1581 /// return %res#1
1582 ///
1583 /// OUTPUT:
1584 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1585 /// {
1586 /// ...
1587 /// <SOME USE OF %a>
1588 /// <SOME USE OF %new_arg0>
1589 /// <SOME USE OF %c>
1590 /// ...
1591 /// scf.forall.in_parallel {
1592 /// <STORE OP WITH DESTINATION %new_arg0>
1593 /// }
1594 /// }
1595 /// return %res
1596 ///
1597 /// NOTE: 1. All uses of the folded shared_outs (iter arguments) within the
1598 /// scf.forall are replaced by their corresponding operands.
1599 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1600 /// of the scf.forall besides those within the scf.forall.in_parallel
1601 /// terminator, this canonicalization remains valid. For more details,
1602 /// please refer to:
1603 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1604 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1605 /// handles tensor.parallel_insert_slice ops only.
1606 ///
1607 /// Example of second case :-
1608 /// INPUT:
1609 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1610 /// {
1611 /// ...
1612 /// <SOME USE OF %arg0>
1613 /// <SOME USE OF %arg1>
1614 /// ...
1615 /// scf.forall.in_parallel {
1616 /// <STORE OP WITH DESTINATION %arg1>
1617 /// }
1618 /// }
1619 /// return %res#0, %res#1
1620 ///
1621 /// OUTPUT:
1622 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1623 /// {
1624 /// ...
1625 /// <SOME USE OF %a>
1626 /// <SOME USE OF %new_arg0>
1627 /// ...
1628 /// scf.forall.in_parallel {
1629 /// <STORE OP WITH DESTINATION %new_arg0>
1630 /// }
1631 /// }
1632 /// return %a, %res
1633 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1634  using OpRewritePattern<ForallOp>::OpRewritePattern;
1635 
1636  LogicalResult matchAndRewrite(ForallOp forallOp,
1637  PatternRewriter &rewriter) const final {
1638  // Step 1: For a given i-th result of scf.forall, check the following :-
1639  // a. If it has any use.
1640  // b. If the corresponding iter argument is being modified within
1641  // the loop, i.e. has at least one store op with the iter arg as
1642  // its destination operand. For this we use
1643  // ForallOp::getCombiningOps(iter_arg).
1644  //
1645  // Based on the check we maintain the following :-
1646  // a. `resultToDelete` - i-th result of scf.forall that'll be
1647  // deleted.
1648  // b. `resultToReplace` - i-th result of the old scf.forall
1649  // whose uses will be replaced by the new scf.forall.
1650  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1651  // corresponding to the i-th result with at least one use.
1652  SetVector<OpResult> resultToDelete;
1653  SmallVector<Value> resultToReplace;
1654  SmallVector<Value> newOuts;
1655  for (OpResult result : forallOp.getResults()) {
1656  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1657  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1658  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1659  resultToDelete.insert(result);
1660  } else {
1661  resultToReplace.push_back(result);
1662  newOuts.push_back(opOperand->get());
1663  }
1664  }
1665 
1666  // Return early if all results of scf.forall have at least one use and are
1667  // being modified within the loop.
1668  if (resultToDelete.empty())
1669  return failure();
1670 
1671  // Step 2: For the i-th result, do the following :-
1672  // a. Fetch the corresponding BlockArgument.
1673  // b. Look for store ops (currently tensor.parallel_insert_slice)
1674  // with the BlockArgument as its destination operand.
1675  // c. Remove the operations fetched in b.
1676  for (OpResult result : resultToDelete) {
1677  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1678  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1679  SmallVector<Operation *> combiningOps =
1680  forallOp.getCombiningOps(blockArg);
1681  for (Operation *combiningOp : combiningOps)
1682  rewriter.eraseOp(combiningOp);
1683  }
1684 
1685  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1686  // fetched earlier
1687  auto newForallOp = scf::ForallOp::create(
1688  rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1689  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1690  forallOp.getMapping(),
1691  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1692 
1693  // Step 4. Merge the block of the old scf.forall into the newly created
1694  // scf.forall using the new set of arguments.
1695  Block *loopBody = forallOp.getBody();
1696  Block *newLoopBody = newForallOp.getBody();
1697  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1698  // Form initial new bbArg list with just the control operands of the new
1699  // scf.forall op.
1700  SmallVector<Value> newBlockArgs =
1701  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1702  [](BlockArgument b) -> Value { return b; });
1703  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1704  unsigned index = 0;
1705  // Take the new corresponding bbArg if the old bbArg was used as a
1706  // destination in the in_parallel op. For all other bbArgs, use the
1707  // corresponding init_arg from the old scf.forall op.
1708  for (OpResult result : forallOp.getResults()) {
1709  if (resultToDelete.count(result)) {
1710  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1711  } else {
1712  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1713  }
1714  }
1715  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1716 
1717  // Step 5. Replace the uses of the results of the old scf.forall with those
1718  // of the new scf.forall.
1719  for (auto &&[oldResult, newResult] :
1720  llvm::zip(resultToReplace, newForallOp->getResults()))
1721  rewriter.replaceAllUsesWith(oldResult, newResult);
1722 
1723  // Step 6. Replace the uses of those values that either have no use or are
1724  // not being modified within the loop with the corresponding
1725  // OpOperand.
1726  for (OpResult oldResult : resultToDelete)
1727  rewriter.replaceAllUsesWith(oldResult,
1728  forallOp.getTiedOpOperand(oldResult)->get());
1729  return success();
1730  }
1731 };
1732 
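/// Folds away scf.forall dimensions that perform zero or one iteration.
/// Zero-trip loops are replaced by their shared_outs operands; dimensions
/// with a single iteration are dropped and their induction variables are
/// replaced by the lower bound. Loops with a non-empty mapping attribute are
/// left untouched. A minimal illustrative sketch (hypothetical IR, names are
/// made up):
///
///   scf.forall (%i, %j) = (0, 0) to (1, 128) step (1, 1) { ... }
///
/// becomes
///
///   scf.forall (%j) = (0) to (128) step (1) { ... }
///
/// with uses of %i replaced by a constant 0.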
1733 struct ForallOpSingleOrZeroIterationDimsFolder
1734  : public OpRewritePattern<ForallOp> {
1735  using OpRewritePattern<ForallOp>::OpRewritePattern;
1736 
1737  LogicalResult matchAndRewrite(ForallOp op,
1738  PatternRewriter &rewriter) const override {
1739  // Do not fold dimensions if they are mapped to processing units.
1740  if (op.getMapping().has_value() && !op.getMapping()->empty())
1741  return failure();
1742  Location loc = op.getLoc();
1743 
1744  // Compute new loop bounds that omit all single-iteration loop dimensions.
1745  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1746  newMixedSteps;
1747  IRMapping mapping;
1748  for (auto [lb, ub, step, iv] :
1749  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1750  op.getMixedStep(), op.getInductionVars())) {
1751  auto numIterations = constantTripCount(lb, ub, step);
1752  if (numIterations.has_value()) {
1753  // Remove the loop if it performs zero iterations.
1754  if (*numIterations == 0) {
1755  rewriter.replaceOp(op, op.getOutputs());
1756  return success();
1757  }
1758  // Replace the loop induction variable by the lower bound if the loop
1759  // performs a single iteration. Otherwise, copy the loop bounds.
1760  if (*numIterations == 1) {
1761  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1762  continue;
1763  }
1764  }
1765  newMixedLowerBounds.push_back(lb);
1766  newMixedUpperBounds.push_back(ub);
1767  newMixedSteps.push_back(step);
1768  }
1769 
1770  // All of the loop dimensions perform a single iteration. Inline loop body.
1771  if (newMixedLowerBounds.empty()) {
1772  promote(rewriter, op);
1773  return success();
1774  }
1775 
1776  // Exit if none of the loop dimensions perform a single iteration.
1777  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1778  return rewriter.notifyMatchFailure(
1779  op, "no dimensions have 0 or 1 iterations");
1780  }
1781 
1782  // Replace the loop by a lower-dimensional loop.
1783  ForallOp newOp;
1784  newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds,
1785  newMixedUpperBounds, newMixedSteps,
1786  op.getOutputs(), std::nullopt, nullptr);
1787  newOp.getBodyRegion().getBlocks().clear();
1788  // The new loop needs to keep all attributes from the old one, except for
1789  // "operandSegmentSizes" and static loop bound attributes which capture
1790  // the outdated information of the old iteration domain.
1791  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1792  newOp.getStaticLowerBoundAttrName(),
1793  newOp.getStaticUpperBoundAttrName(),
1794  newOp.getStaticStepAttrName()};
1795  for (const auto &namedAttr : op->getAttrs()) {
1796  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1797  continue;
1798  rewriter.modifyOpInPlace(newOp, [&]() {
1799  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1800  });
1801  }
1802  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1803  newOp.getRegion().begin(), mapping);
1804  rewriter.replaceOp(op, newOp.getResults());
1805  return success();
1806  }
1807 };
1808 
1809 /// Replace induction variables of single-trip-count dimensions with their lower bound.
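/// A minimal illustrative sketch (hypothetical IR, names are made up): in
///
///   scf.forall (%i, %j) = (2, 0) to (3, 64) step (1, 1) { ... use %i ... }
///
/// the first dimension runs exactly once, so uses of %i inside the body are
/// replaced by the constant 2; the loop structure itself is left unchanged.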
1810 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1811  using OpRewritePattern<ForallOp>::OpRewritePattern;
1812 
1813  LogicalResult matchAndRewrite(ForallOp op,
1814  PatternRewriter &rewriter) const override {
1815  Location loc = op.getLoc();
1816  bool changed = false;
1817  for (auto [lb, ub, step, iv] :
1818  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1819  op.getMixedStep(), op.getInductionVars())) {
1820  if (iv.hasNUses(0))
1821  continue;
1822  auto numIterations = constantTripCount(lb, ub, step);
1823  if (!numIterations.has_value() || numIterations.value() != 1) {
1824  continue;
1825  }
1826  rewriter.replaceAllUsesWith(
1827  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1828  changed = true;
1829  }
1830  return success(changed);
1831  }
1832 };
1833 
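/// Folds a tensor.cast of a shared_outs operand into the scf.forall when the
/// cast only discards static information (e.g. tensor<8xf32> -> tensor<?xf32>).
/// A minimal illustrative sketch (hypothetical IR, names are made up):
///
///   %cast = tensor.cast %t : tensor<8xf32> to tensor<?xf32>
///   %res = scf.forall ... shared_outs(%out = %cast) -> (tensor<?xf32>) {...}
///
/// becomes
///
///   %res = scf.forall ... shared_outs(%out = %t) -> (tensor<8xf32>) {...}
///   %cast_back = tensor.cast %res : tensor<8xf32> to tensor<?xf32>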
1834 struct FoldTensorCastOfOutputIntoForallOp
1835  : public OpRewritePattern<scf::ForallOp> {
1836  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1837 
1838  struct TypeCast {
1839  Type srcType;
1840  Type dstType;
1841  };
1842 
1843  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1844  PatternRewriter &rewriter) const final {
1845  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1846  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1847  for (auto en : llvm::enumerate(newOutputTensors)) {
1848  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1849  if (!castOp)
1850  continue;
1851 
1852  // Only casts that preserve static information, i.e. casts that make the
1853  // loop result type "more" static than before, will be folded.
1854  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1855  castOp.getSource().getType())) {
1856  continue;
1857  }
1858 
1859  tensorCastProducers[en.index()] =
1860  TypeCast{castOp.getSource().getType(), castOp.getType()};
1861  newOutputTensors[en.index()] = castOp.getSource();
1862  }
1863 
1864  if (tensorCastProducers.empty())
1865  return failure();
1866 
1867  // Create new loop.
1868  Location loc = forallOp.getLoc();
1869  auto newForallOp = ForallOp::create(
1870  rewriter, loc, forallOp.getMixedLowerBound(),
1871  forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1872  newOutputTensors, forallOp.getMapping(),
1873  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1874  auto castBlockArgs =
1875  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1876  for (auto [index, cast] : tensorCastProducers) {
1877  Value &oldTypeBBArg = castBlockArgs[index];
1878  oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc,
1879  cast.dstType, oldTypeBBArg);
1880  }
1881 
1882  // Move old body into new parallel loop.
1883  SmallVector<Value> ivsBlockArgs =
1884  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1885  ivsBlockArgs.append(castBlockArgs);
1886  rewriter.mergeBlocks(forallOp.getBody(),
1887  bbArgs.front().getParentBlock(), ivsBlockArgs);
1888  });
1889 
1890  // After `mergeBlocks` happened, the destinations in the terminator were
1891  // mapped to the old-typed tensor.cast results of the output bbArgs. The
1892  // destinations have to be updated to point to the output bbArgs directly.
1893  auto terminator = newForallOp.getTerminator();
1894  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1895  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1896  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1897  insertSliceOp.getDestMutable().assign(outputBlockArg);
1898  }
1899 
1900  // Cast results back to the original types.
1901  rewriter.setInsertionPointAfter(newForallOp);
1902  SmallVector<Value> castResults = newForallOp.getResults();
1903  for (auto &item : tensorCastProducers) {
1904  Value &oldTypeResult = castResults[item.first];
1905  oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType,
1906  oldTypeResult);
1907  }
1908  rewriter.replaceOp(forallOp, castResults);
1909  return success();
1910  }
1911 };
1912 
1913 } // namespace
1914 
1915 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1916  MLIRContext *context) {
1917  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1918  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1919  ForallOpSingleOrZeroIterationDimsFolder,
1920  ForallOpReplaceConstantInductionVar>(context);
1921 }
1922 
1923 /// Given a region branch point (either one of this op's regions or the parent
1924 /// operation itself), return the successor regions. These are the regions
1925 /// that may be selected during the flow of control. Since the body of an
1926 /// scf.forall is executed in parallel, control only enters the body from the
1927 /// parent and returns to the parent once the body's execution completes.
1928 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1929  SmallVectorImpl<RegionSuccessor> &regions) {
1930  // In accordance with the semantics of forall, its body is executed in
1931  // parallel by multiple threads. We should not expect to branch back into
1932  // the forall body after the region's execution is complete.
1933  if (point.isParent())
1934  regions.push_back(RegionSuccessor(&getRegion()));
1935  else
1936  regions.push_back(RegionSuccessor());
1937 }
1938 
1939 //===----------------------------------------------------------------------===//
1940 // InParallelOp
1941 //===----------------------------------------------------------------------===//
1942 
1943 // Build an InParallelOp with an empty body block.
1944 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1945  OpBuilder::InsertionGuard g(b);
1946  Region *bodyRegion = result.addRegion();
1947  b.createBlock(bodyRegion);
1948 }
1949 
1950 LogicalResult InParallelOp::verify() {
1951  scf::ForallOp forallOp =
1952  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1953  if (!forallOp)
1954  return this->emitOpError("expected forall op parent");
1955 
1956  // TODO: InParallelOpInterface.
1957  for (Operation &op : getRegion().front().getOperations()) {
1958  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1959  return this->emitOpError("expected only ")
1960  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1961  }
1962 
1963  // Verify that inserts are into out block arguments.
1964  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1965  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1966  if (!llvm::is_contained(regionOutArgs, dest))
1967  return op.emitOpError("may only insert into an output block argument");
1968  }
1969  return success();
1970 }
1971 
1972 void InParallelOp::print(OpAsmPrinter &p) {
1973  p << " ";
1974  p.printRegion(getRegion(),
1975  /*printEntryBlockArgs=*/false,
1976  /*printBlockTerminators=*/false);
1977  p.printOptionalAttrDict(getOperation()->getAttrs());
1978 }
1979 
1980 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1981  auto &builder = parser.getBuilder();
1982 
1983  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1984  std::unique_ptr<Region> region = std::make_unique<Region>();
1985  if (parser.parseRegion(*region, regionOperands))
1986  return failure();
1987 
1988  if (region->empty())
1989  OpBuilder(builder.getContext()).createBlock(region.get());
1990  result.addRegion(std::move(region));
1991 
1992  // Parse the optional attribute list.
1993  if (parser.parseOptionalAttrDict(result.attributes))
1994  return failure();
1995  return success();
1996 }
1997 
1998 OpResult InParallelOp::getParentResult(int64_t idx) {
1999  return getOperation()->getParentOp()->getResult(idx);
2000 }
2001 
2002 SmallVector<BlockArgument> InParallelOp::getDests() {
2003  return llvm::to_vector<4>(
2004  llvm::map_range(getYieldingOps(), [](Operation &op) {
2005  // Add new ops here as needed.
2006  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
2007  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
2008  }));
2009 }
2010 
2011 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
2012  return getRegion().front().getOperations();
2013 }
2014 
2015 //===----------------------------------------------------------------------===//
2016 // IfOp
2017 //===----------------------------------------------------------------------===//
2018 
2019 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
2020  assert(a && "expected non-empty operation");
2021  assert(b && "expected non-empty operation");
2022 
2023  IfOp ifOp = a->getParentOfType<IfOp>();
2024  while (ifOp) {
2025  // Check if b is inside ifOp. (We already know that a is.)
2026  if (ifOp->isProperAncestor(b))
2027  // b is contained in ifOp. a and b are in mutually exclusive branches if
2028  // they are in different blocks of ifOp.
2029  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2030  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2031  // Check next enclosing IfOp.
2032  ifOp = ifOp->getParentOfType<IfOp>();
2033  }
2034 
2035  // Could not find a common IfOp among a's and b's ancestors.
2036  return false;
2037 }
2038 
2039 LogicalResult
2040 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2041  IfOp::Adaptor adaptor,
2042  SmallVectorImpl<Type> &inferredReturnTypes) {
2043  if (adaptor.getRegions().empty())
2044  return failure();
2045  Region *r = &adaptor.getThenRegion();
2046  if (r->empty())
2047  return failure();
2048  Block &b = r->front();
2049  if (b.empty())
2050  return failure();
2051  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2052  if (!yieldOp)
2053  return failure();
2054  TypeRange types = yieldOp.getOperandTypes();
2055  llvm::append_range(inferredReturnTypes, types);
2056  return success();
2057 }
2058 
2059 void IfOp::build(OpBuilder &builder, OperationState &result,
2060  TypeRange resultTypes, Value cond) {
2061  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2062  /*addElseBlock=*/false);
2063 }
2064 
2065 void IfOp::build(OpBuilder &builder, OperationState &result,
2066  TypeRange resultTypes, Value cond, bool addThenBlock,
2067  bool addElseBlock) {
2068  assert((!addElseBlock || addThenBlock) &&
2069  "must not create else block w/o then block");
2070  result.addTypes(resultTypes);
2071  result.addOperands(cond);
2072 
2073  // Add regions and blocks.
2074  OpBuilder::InsertionGuard guard(builder);
2075  Region *thenRegion = result.addRegion();
2076  if (addThenBlock)
2077  builder.createBlock(thenRegion);
2078  Region *elseRegion = result.addRegion();
2079  if (addElseBlock)
2080  builder.createBlock(elseRegion);
2081 }
2082 
2083 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2084  bool withElseRegion) {
2085  build(builder, result, TypeRange{}, cond, withElseRegion);
2086 }
2087 
2088 void IfOp::build(OpBuilder &builder, OperationState &result,
2089  TypeRange resultTypes, Value cond, bool withElseRegion) {
2090  result.addTypes(resultTypes);
2091  result.addOperands(cond);
2092 
2093  // Build then region.
2094  OpBuilder::InsertionGuard guard(builder);
2095  Region *thenRegion = result.addRegion();
2096  builder.createBlock(thenRegion);
2097  if (resultTypes.empty())
2098  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2099 
2100  // Build else region.
2101  Region *elseRegion = result.addRegion();
2102  if (withElseRegion) {
2103  builder.createBlock(elseRegion);
2104  if (resultTypes.empty())
2105  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2106  }
2107 }
2108 
2109 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2110  function_ref<void(OpBuilder &, Location)> thenBuilder,
2111  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2112  assert(thenBuilder && "the builder callback for 'then' must be present");
2113  result.addOperands(cond);
2114 
2115  // Build then region.
2116  OpBuilder::InsertionGuard guard(builder);
2117  Region *thenRegion = result.addRegion();
2118  builder.createBlock(thenRegion);
2119  thenBuilder(builder, result.location);
2120 
2121  // Build else region.
2122  Region *elseRegion = result.addRegion();
2123  if (elseBuilder) {
2124  builder.createBlock(elseRegion);
2125  elseBuilder(builder, result.location);
2126  }
2127 
2128  // Infer result types.
2129  SmallVector<Type> inferredReturnTypes;
2130  MLIRContext *ctx = builder.getContext();
2131  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2132  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2133  /*properties=*/nullptr, result.regions,
2134  inferredReturnTypes))) {
2135  result.addTypes(inferredReturnTypes);
2136  }
2137 }
2138 
2139 LogicalResult IfOp::verify() {
2140  if (getNumResults() != 0 && getElseRegion().empty())
2141  return emitOpError("must have an else block if defining values");
2142  return success();
2143 }
2144 
2145 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2146  // Create the regions for 'then'.
2147  result.regions.reserve(2);
2148  Region *thenRegion = result.addRegion();
2149  Region *elseRegion = result.addRegion();
2150 
2151  auto &builder = parser.getBuilder();
2152  OpAsmParser::UnresolvedOperand cond;
2153  Type i1Type = builder.getIntegerType(1);
2154  if (parser.parseOperand(cond) ||
2155  parser.resolveOperand(cond, i1Type, result.operands))
2156  return failure();
2157  // Parse optional results type list.
2158  if (parser.parseOptionalArrowTypeList(result.types))
2159  return failure();
2160  // Parse the 'then' region.
2161  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2162  return failure();
2163  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2164 
2165  // If we find an 'else' keyword then parse the 'else' region.
2166  if (!parser.parseOptionalKeyword("else")) {
2167  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2168  return failure();
2169  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2170  }
2171 
2172  // Parse the optional attribute list.
2173  if (parser.parseOptionalAttrDict(result.attributes))
2174  return failure();
2175  return success();
2176 }
2177 
2178 void IfOp::print(OpAsmPrinter &p) {
2179  bool printBlockTerminators = false;
2180 
2181  p << " " << getCondition();
2182  if (!getResults().empty()) {
2183  p << " -> (" << getResultTypes() << ")";
2184  // Print yield explicitly if the op defines values.
2185  printBlockTerminators = true;
2186  }
2187  p << ' ';
2188  p.printRegion(getThenRegion(),
2189  /*printEntryBlockArgs=*/false,
2190  /*printBlockTerminators=*/printBlockTerminators);
2191 
2192  // Print the 'else' region if it exists and has a block.
2193  auto &elseRegion = getElseRegion();
2194  if (!elseRegion.empty()) {
2195  p << " else ";
2196  p.printRegion(elseRegion,
2197  /*printEntryBlockArgs=*/false,
2198  /*printBlockTerminators=*/printBlockTerminators);
2199  }
2200 
2201  p.printOptionalAttrDict((*this)->getAttrs());
2202 }
2203 
2204 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2205  SmallVectorImpl<RegionSuccessor> &regions) {
2206  // The `then` and the `else` region branch back to the parent operation.
2207  if (!point.isParent()) {
2208  regions.push_back(RegionSuccessor(getResults()));
2209  return;
2210  }
2211 
2212  regions.push_back(RegionSuccessor(&getThenRegion()));
2213 
2214  // Don't consider the else region if it is empty.
2215  Region *elseRegion = &this->getElseRegion();
2216  if (elseRegion->empty())
2217  regions.push_back(RegionSuccessor());
2218  else
2219  regions.push_back(RegionSuccessor(elseRegion));
2220 }
2221 
2222 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2223  SmallVectorImpl<RegionSuccessor> &regions) {
2224  FoldAdaptor adaptor(operands, *this);
2225  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2226  if (!boolAttr || boolAttr.getValue())
2227  regions.emplace_back(&getThenRegion());
2228 
2229  // If the else region is empty, execution continues after the parent op.
2230  if (!boolAttr || !boolAttr.getValue()) {
2231  if (!getElseRegion().empty())
2232  regions.emplace_back(&getElseRegion());
2233  else
2234  regions.emplace_back(getResults());
2235  }
2236 }
2237 
2238 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2239  SmallVectorImpl<OpFoldResult> &results) {
2240  // if (!c) then A() else B() -> if c then B() else A()
2241  if (getElseRegion().empty())
2242  return failure();
2243 
2244  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2245  if (!xorStmt)
2246  return failure();
2247 
2248  if (!matchPattern(xorStmt.getRhs(), m_One()))
2249  return failure();
2250 
2251  getConditionMutable().assign(xorStmt.getLhs());
2252  Block *thenBlock = &getThenRegion().front();
2253  // It would be nicer to use iplist::swap, but that has no implemented
2254  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2255  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2256  getElseRegion().getBlocks());
2257  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2258  getThenRegion().getBlocks(), thenBlock);
2259  return success();
2260 }
2261 
2262 void IfOp::getRegionInvocationBounds(
2263  ArrayRef<Attribute> operands,
2264  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2265  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2266  // If the condition is known, then one region is known to be executed once
2267  // and the other zero times.
2268  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2269  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2270  } else {
2271  // Non-constant condition. Each region may be executed 0 or 1 times.
2272  invocationBounds.assign(2, {0, 1});
2273  }
2274 }
2275 
2276 namespace {
2277 // Pattern to remove unused IfOp results.
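// A minimal illustrative sketch (hypothetical IR, names are made up): if only
// %r#0 has uses,
//
//   %r:2 = scf.if %cond -> (f32, f32) { ... scf.yield %a, %b : f32, f32 } ...
//
// is rewritten to an scf.if that yields and returns only the used value:
//
//   %r = scf.if %cond -> (f32) { ... scf.yield %a : f32 } ...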
2278 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2279  using OpRewritePattern<IfOp>::OpRewritePattern;
2280 
2281  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2282  PatternRewriter &rewriter) const {
2283  // Move all operations to the destination block.
2284  rewriter.mergeBlocks(source, dest);
2285  // Replace the yield op by one that returns only the used values.
2286  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2287  SmallVector<Value, 4> usedOperands;
2288  llvm::transform(usedResults, std::back_inserter(usedOperands),
2289  [&](OpResult result) {
2290  return yieldOp.getOperand(result.getResultNumber());
2291  });
2292  rewriter.modifyOpInPlace(yieldOp,
2293  [&]() { yieldOp->setOperands(usedOperands); });
2294  }
2295 
2296  LogicalResult matchAndRewrite(IfOp op,
2297  PatternRewriter &rewriter) const override {
2298  // Compute the list of used results.
2299  SmallVector<OpResult, 4> usedResults;
2300  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2301  [](OpResult result) { return !result.use_empty(); });
2302 
2303  // Replace the operation if only a subset of its results have uses.
2304  if (usedResults.size() == op.getNumResults())
2305  return failure();
2306 
2307  // Compute the result types of the replacement operation.
2308  SmallVector<Type, 4> newTypes;
2309  llvm::transform(usedResults, std::back_inserter(newTypes),
2310  [](OpResult result) { return result.getType(); });
2311 
2312  // Create a replacement operation with empty then and else regions.
2313  auto newOp =
2314  IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
2315  rewriter.createBlock(&newOp.getThenRegion());
2316  rewriter.createBlock(&newOp.getElseRegion());
2317 
2318  // Move the bodies and replace the terminators (note there is a then and
2319  // an else region since the operation returns results).
2320  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2321  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2322 
2323  // Replace the operation by the new one.
2324  SmallVector<Value, 4> repResults(op.getNumResults());
2325  for (const auto &en : llvm::enumerate(usedResults))
2326  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2327  rewriter.replaceOp(op, repResults);
2328  return success();
2329  }
2330 };
2331 
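/// Replaces an scf.if whose condition is a constant by the contents of the
/// taken branch. A minimal illustrative sketch (hypothetical IR, names are
/// made up):
///
///   %true = arith.constant true
///   scf.if %true { "test.foo"() : () -> () } else { "test.bar"() : () -> () }
///
/// becomes just the inlined "test.foo"() call. If the condition is false and
/// there is no else region, the op is simply erased.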
2332 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2333  using OpRewritePattern<IfOp>::OpRewritePattern;
2334 
2335  LogicalResult matchAndRewrite(IfOp op,
2336  PatternRewriter &rewriter) const override {
2337  BoolAttr condition;
2338  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2339  return failure();
2340 
2341  if (condition.getValue())
2342  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2343  else if (!op.getElseRegion().empty())
2344  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2345  else
2346  rewriter.eraseOp(op);
2347 
2348  return success();
2349  }
2350 };
2351 
2352 /// Hoist any yielded results whose operands are defined outside
2353 /// the if, to a select instruction.
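/// A minimal illustrative sketch (hypothetical IR, names are made up): when
/// %a and %b are defined above the scf.if,
///
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
///
/// uses of %r are replaced by
///
///   %sel = arith.select %cond, %a, %b : i32
///
/// while the scf.if is kept (without the hoisted results) for whatever else
/// its regions still compute.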
2354 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2355  using OpRewritePattern<IfOp>::OpRewritePattern;
2356 
2357  LogicalResult matchAndRewrite(IfOp op,
2358  PatternRewriter &rewriter) const override {
2359  if (op->getNumResults() == 0)
2360  return failure();
2361 
2362  auto cond = op.getCondition();
2363  auto thenYieldArgs = op.thenYield().getOperands();
2364  auto elseYieldArgs = op.elseYield().getOperands();
2365 
2366  SmallVector<Type> nonHoistable;
2367  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2368  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2369  &op.getElseRegion() == falseVal.getParentRegion())
2370  nonHoistable.push_back(trueVal.getType());
2371  }
2372  // Early exit if there aren't any yielded values we can
2373  // hoist outside the if.
2374  if (nonHoistable.size() == op->getNumResults())
2375  return failure();
2376 
2377  IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond,
2378  /*withElseRegion=*/false);
2379  if (replacement.thenBlock())
2380  rewriter.eraseBlock(replacement.thenBlock());
2381  replacement.getThenRegion().takeBody(op.getThenRegion());
2382  replacement.getElseRegion().takeBody(op.getElseRegion());
2383 
2384  SmallVector<Value> results(op->getNumResults());
2385  assert(thenYieldArgs.size() == results.size());
2386  assert(elseYieldArgs.size() == results.size());
2387 
2388  SmallVector<Value> trueYields;
2389  SmallVector<Value> falseYields;
2390  rewriter.setInsertionPoint(replacement);
2391  for (const auto &it :
2392  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2393  Value trueVal = std::get<0>(it.value());
2394  Value falseVal = std::get<1>(it.value());
2395  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2396  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2397  results[it.index()] = replacement.getResult(trueYields.size());
2398  trueYields.push_back(trueVal);
2399  falseYields.push_back(falseVal);
2400  } else if (trueVal == falseVal)
2401  results[it.index()] = trueVal;
2402  else
2403  results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(),
2404  cond, trueVal, falseVal);
2405  }
2406 
2407  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2408  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2409 
2410  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2411  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2412 
2413  rewriter.replaceOp(op, results);
2414  return success();
2415  }
2416 };
2417 
2418 /// Remove any statements from an if that are equivalent to the condition
2419 /// or its negation. For example:
2420 ///
2421 /// %res:2 = scf.if %cmp {
2422 /// yield something(), true
2423 /// } else {
2424 /// yield something2(), false
2425 /// }
2426 /// print(%res#1)
2427 ///
2428 /// becomes
2429 /// %res = scf.if %cmp {
2430 /// yield something()
2431 /// } else {
2432 /// yield something2()
2433 /// }
2434 /// print(%cmp)
2435 ///
2436 /// Additionally if both branches yield the same value, replace all uses
2437 /// of the result with the yielded value.
2438 ///
2439 /// %res:2 = scf.if %cmp {
2440 /// yield something(), %arg1
2441 /// } else {
2442 /// yield something2(), %arg1
2443 /// }
2444 /// print(%res#1)
2445 ///
2446 /// becomes
2447 /// %res = scf.if %cmp {
2448 /// yield something()
2449 /// } else {
2450 /// yield something2()
2451 /// }
2452 /// print(%arg1)
2453 ///
2454 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2455  using OpRewritePattern<IfOp>::OpRewritePattern;
2456 
2457  LogicalResult matchAndRewrite(IfOp op,
2458  PatternRewriter &rewriter) const override {
2459  // Early exit if there are no results that could be replaced.
2460  if (op.getNumResults() == 0)
2461  return failure();
2462 
2463  auto trueYield =
2464  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2465  auto falseYield =
2466  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2467 
2468  rewriter.setInsertionPoint(op->getBlock(),
2469  op.getOperation()->getIterator());
2470  bool changed = false;
2471  Type i1Ty = rewriter.getI1Type();
2472  for (auto [trueResult, falseResult, opResult] :
2473  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2474  op.getResults())) {
2475  if (trueResult == falseResult) {
2476  if (!opResult.use_empty()) {
2477  opResult.replaceAllUsesWith(trueResult);
2478  changed = true;
2479  }
2480  continue;
2481  }
2482 
2483  BoolAttr trueYield, falseYield;
2484  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2485  !matchPattern(falseResult, m_Constant(&falseYield)))
2486  continue;
2487 
2488  bool trueVal = trueYield.getValue();
2489  bool falseVal = falseYield.getValue();
2490  if (!trueVal && falseVal) {
2491  if (!opResult.use_empty()) {
2492  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2493  Value notCond = arith::XOrIOp::create(
2494  rewriter, op.getLoc(), op.getCondition(),
2495  constDialect
2496  ->materializeConstant(rewriter,
2497  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2498  op.getLoc())
2499  ->getResult(0));
2500  opResult.replaceAllUsesWith(notCond);
2501  changed = true;
2502  }
2503  }
2504  if (trueVal && !falseVal) {
2505  if (!opResult.use_empty()) {
2506  opResult.replaceAllUsesWith(op.getCondition());
2507  changed = true;
2508  }
2509  }
2510  }
2511  return success(changed);
2512  }
2513 };
2514 
2515 /// Merge any consecutive scf.if's with the same condition.
2516 ///
2517 /// scf.if %cond {
2518 /// firstCodeTrue();...
2519 /// } else {
2520 /// firstCodeFalse();...
2521 /// }
2522 /// %res = scf.if %cond {
2523 /// secondCodeTrue();...
2524 /// } else {
2525 /// secondCodeFalse();...
2526 /// }
2527 ///
2528 /// becomes
2529 /// %res = scf.if %cmp {
2530 /// firstCodeTrue();...
2531 /// secondCodeTrue();...
2532 /// } else {
2533 /// firstCodeFalse();...
2534 /// secondCodeFalse();...
2535 /// }
2536 struct CombineIfs : public OpRewritePattern<IfOp> {
2537  using OpRewritePattern<IfOp>::OpRewritePattern;
2538 
2539  LogicalResult matchAndRewrite(IfOp nextIf,
2540  PatternRewriter &rewriter) const override {
2541  Block *parent = nextIf->getBlock();
2542  if (nextIf == &parent->front())
2543  return failure();
2544 
2545  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2546  if (!prevIf)
2547  return failure();
2548 
2549  // Determine the logical then/else blocks when prevIf's
2550  // condition is used. Null means the block does not exist
2551  // in that case (e.g. empty else). If neither of these
2552  // are set, the two conditions cannot be compared.
2553  Block *nextThen = nullptr;
2554  Block *nextElse = nullptr;
2555  if (nextIf.getCondition() == prevIf.getCondition()) {
2556  nextThen = nextIf.thenBlock();
2557  if (!nextIf.getElseRegion().empty())
2558  nextElse = nextIf.elseBlock();
2559  }
2560  if (arith::XOrIOp notv =
2561  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2562  if (notv.getLhs() == prevIf.getCondition() &&
2563  matchPattern(notv.getRhs(), m_One())) {
2564  nextElse = nextIf.thenBlock();
2565  if (!nextIf.getElseRegion().empty())
2566  nextThen = nextIf.elseBlock();
2567  }
2568  }
2569  if (arith::XOrIOp notv =
2570  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2571  if (notv.getLhs() == nextIf.getCondition() &&
2572  matchPattern(notv.getRhs(), m_One())) {
2573  nextElse = nextIf.thenBlock();
2574  if (!nextIf.getElseRegion().empty())
2575  nextThen = nextIf.elseBlock();
2576  }
2577  }
2578 
2579  if (!nextThen && !nextElse)
2580  return failure();
2581 
2582  SmallVector<Value> prevElseYielded;
2583  if (!prevIf.getElseRegion().empty())
2584  prevElseYielded = prevIf.elseYield().getOperands();
2585  // Replace all uses of prevIf's results within nextIf with the
2586  // corresponding yielded values.
2587  for (auto it : llvm::zip(prevIf.getResults(),
2588  prevIf.thenYield().getOperands(), prevElseYielded))
2589  for (OpOperand &use :
2590  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2591  if (nextThen && nextThen->getParent()->isAncestor(
2592  use.getOwner()->getParentRegion())) {
2593  rewriter.startOpModification(use.getOwner());
2594  use.set(std::get<1>(it));
2595  rewriter.finalizeOpModification(use.getOwner());
2596  } else if (nextElse && nextElse->getParent()->isAncestor(
2597  use.getOwner()->getParentRegion())) {
2598  rewriter.startOpModification(use.getOwner());
2599  use.set(std::get<2>(it));
2600  rewriter.finalizeOpModification(use.getOwner());
2601  }
2602  }
2603 
2604  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2605  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2606 
2607  IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes,
2608  prevIf.getCondition(), /*hasElse=*/false);
2609  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2610 
2611  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2612  combinedIf.getThenRegion(),
2613  combinedIf.getThenRegion().begin());
2614 
2615  if (nextThen) {
2616  YieldOp thenYield = combinedIf.thenYield();
2617  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2618  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2619  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2620 
2621  SmallVector<Value> mergedYields(thenYield.getOperands());
2622  llvm::append_range(mergedYields, thenYield2.getOperands());
2623  YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
2624  rewriter.eraseOp(thenYield);
2625  rewriter.eraseOp(thenYield2);
2626  }
2627 
2628  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2629  combinedIf.getElseRegion(),
2630  combinedIf.getElseRegion().begin());
2631 
2632  if (nextElse) {
2633  if (combinedIf.getElseRegion().empty()) {
2634  rewriter.inlineRegionBefore(*nextElse->getParent(),
2635  combinedIf.getElseRegion(),
2636  combinedIf.getElseRegion().begin());
2637  } else {
2638  YieldOp elseYield = combinedIf.elseYield();
2639  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2640  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2641 
2642  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2643 
2644  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2645  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2646 
2647  YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
2648  rewriter.eraseOp(elseYield);
2649  rewriter.eraseOp(elseYield2);
2650  }
2651  }
2652 
2653  SmallVector<Value> prevValues;
2654  SmallVector<Value> nextValues;
2655  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2656  if (pair.index() < prevIf.getNumResults())
2657  prevValues.push_back(pair.value());
2658  else
2659  nextValues.push_back(pair.value());
2660  }
2661  rewriter.replaceOp(prevIf, prevValues);
2662  rewriter.replaceOp(nextIf, nextValues);
2663  return success();
2664  }
2665 };
2666 
2667 /// Pattern to remove an empty else branch.
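/// A minimal illustrative sketch (hypothetical IR): when the scf.if produces
/// no results,
///
///   scf.if %cond {
///     "test.foo"() : () -> ()
///   } else {
///   }
///
/// becomes
///
///   scf.if %cond {
///     "test.foo"() : () -> ()
///   }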
2668 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2669  using OpRewritePattern<IfOp>::OpRewritePattern;
2670 
2671  LogicalResult matchAndRewrite(IfOp ifOp,
2672  PatternRewriter &rewriter) const override {
2673  // Cannot remove else region when there are operation results.
2674  if (ifOp.getNumResults())
2675  return failure();
2676  Block *elseBlock = ifOp.elseBlock();
2677  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2678  return failure();
2679  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2680  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2681  newIfOp.getThenRegion().begin());
2682  rewriter.eraseOp(ifOp);
2683  return success();
2684  }
2685 };
2686 
2687 /// Convert nested `if`s into `arith.andi` + single `if`.
2688 ///
2689 /// scf.if %arg0 {
2690 /// scf.if %arg1 {
2691 /// ...
2692 /// scf.yield
2693 /// }
2694 /// scf.yield
2695 /// }
2696 /// becomes
2697 ///
2698 /// %0 = arith.andi %arg0, %arg1
2699 /// scf.if %0 {
2700 /// ...
2701 /// scf.yield
2702 /// }
2703 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2704  using OpRewritePattern<IfOp>::OpRewritePattern;
2705 
2706  LogicalResult matchAndRewrite(IfOp op,
2707  PatternRewriter &rewriter) const override {
2708  auto nestedOps = op.thenBlock()->without_terminator();
2709  // Nested `if` must be the only op in block.
2710  if (!llvm::hasSingleElement(nestedOps))
2711  return failure();
2712 
2713  // If there is an else block, it can only contain a yield.
2714  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2715  return failure();
2716 
2717  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2718  if (!nestedIf)
2719  return failure();
2720 
2721  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2722  return failure();
2723 
2724  SmallVector<Value> thenYield(op.thenYield().getOperands());
2725  SmallVector<Value> elseYield;
2726  if (op.elseBlock())
2727  llvm::append_range(elseYield, op.elseYield().getOperands());
2728 
2729  // A list of indices for which we should upgrade the value yielded
2730  // in the else to a select.
2731  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2732 
2733  // If the outer scf.if yields a value produced by the inner scf.if,
2734  // only permit combining if the value yielded when the condition
2735  // is false in the outer scf.if is the same value yielded when the
2736  // inner scf.if condition is false.
2737  // Note that the array access to elseYield will not go out of bounds
2738  // since it must have the same length as thenYield, since they both
2739  // come from the same scf.if.
2740  for (const auto &tup : llvm::enumerate(thenYield)) {
2741  if (tup.value().getDefiningOp() == nestedIf) {
2742  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2743  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2744  elseYield[tup.index()]) {
2745  return failure();
2746  }
2747  // If the correctness test passes, we will yield the
2748  // corresponding value from the inner scf.if.
2749  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2750  continue;
2751  }
2752 
2753  // Otherwise, we need to ensure the else block of the combined
2754  // condition still returns the same value when the outer condition is
2755  // true and the inner condition is false. This can be accomplished if
2756  // the then value is defined outside the outer scf.if and we replace the
2757  // value with a select that considers just the outer condition. Since
2758  // the else region contains just the yield, its yielded value is
2759  // defined outside the scf.if, by definition.
2760 
2761  // If the then value is defined within the scf.if, bail.
2762  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2763  return failure();
2764  }
2765  elseYieldsToUpgradeToSelect.push_back(tup.index());
2766  }
2767 
2768  Location loc = op.getLoc();
2769  Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(),
2770  nestedIf.getCondition());
2771  auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition);
2772  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2773 
2774  SmallVector<Value> results;
2775  llvm::append_range(results, newIf.getResults());
2776  rewriter.setInsertionPoint(newIf);
2777 
2778  for (auto idx : elseYieldsToUpgradeToSelect)
2779  results[idx] =
2780  arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
2781  thenYield[idx], elseYield[idx]);
2782 
2783  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2784  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2785  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2786  if (!elseYield.empty()) {
2787  rewriter.createBlock(&newIf.getElseRegion());
2788  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2789  YieldOp::create(rewriter, loc, elseYield);
2790  }
2791  rewriter.replaceOp(op, results);
2792  return success();
2793  }
2794 };
2795 
2796 } // namespace
2797 
2798 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2799  MLIRContext *context) {
2800  results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect,
2801  RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults,
2802  ReplaceIfYieldWithConditionOrValue>(context);
2803 }
2804 
2805 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2806 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2807 Block *IfOp::elseBlock() {
2808  Region &r = getElseRegion();
2809  if (r.empty())
2810  return nullptr;
2811  return &r.back();
2812 }
2813 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2814 
2815 //===----------------------------------------------------------------------===//
2816 // ParallelOp
2817 //===----------------------------------------------------------------------===//
2818 
2819 void ParallelOp::build(
2820  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2821  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2822  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2823  bodyBuilderFn) {
2824  result.addOperands(lowerBounds);
2825  result.addOperands(upperBounds);
2826  result.addOperands(steps);
2827  result.addOperands(initVals);
2828  result.addAttribute(
2829  ParallelOp::getOperandSegmentSizeAttr(),
2830  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2831  static_cast<int32_t>(upperBounds.size()),
2832  static_cast<int32_t>(steps.size()),
2833  static_cast<int32_t>(initVals.size())}));
2834  result.addTypes(initVals.getTypes());
2835 
2836  OpBuilder::InsertionGuard guard(builder);
2837  unsigned numIVs = steps.size();
2838  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2839  SmallVector<Location, 8> argLocs(numIVs, result.location);
2840  Region *bodyRegion = result.addRegion();
2841  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2842 
2843  if (bodyBuilderFn) {
2844  builder.setInsertionPointToStart(bodyBlock);
2845  bodyBuilderFn(builder, result.location,
2846  bodyBlock->getArguments().take_front(numIVs),
2847  bodyBlock->getArguments().drop_front(numIVs));
2848  }
2849  // Add terminator only if there are no reductions.
2850  if (initVals.empty())
2851  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2852 }
2853 
2854 void ParallelOp::build(
2855  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2856  ValueRange upperBounds, ValueRange steps,
2857  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2858  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2859  // we don't capture a reference to a temporary by constructing the lambda at
2860  // function level.
2861  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2862  Location nestedLoc, ValueRange ivs,
2863  ValueRange) {
2864  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2865  };
2866  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2867  if (bodyBuilderFn)
2868  wrapper = wrappedBuilderFn;
2869 
2870  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2871  wrapper);
2872 }
2873 
2874 LogicalResult ParallelOp::verify() {
2875  // Check that there is at least one value in lowerBound, upperBound and step.
2876  // It is sufficient to test only step, because it is ensured already that the
2877  // number of elements in lowerBound, upperBound and step are the same.
2878  Operation::operand_range stepValues = getStep();
2879  if (stepValues.empty())
2880  return emitOpError(
2881  "needs at least one tuple element for lowerBound, upperBound and step");
2882 
2883  // Check whether all constant step values are positive.
2884  for (Value stepValue : stepValues)
2885  if (auto cst = getConstantIntValue(stepValue))
2886  if (*cst <= 0)
2887  return emitOpError("constant step operand must be positive");
2888 
2889  // Check that the body defines the same number of block arguments as the
2890  // number of tuple elements in step.
2891  Block *body = getBody();
2892  if (body->getNumArguments() != stepValues.size())
2893  return emitOpError() << "expects the same number of induction variables: "
2894  << body->getNumArguments()
2895  << " as bound and step values: " << stepValues.size();
2896  for (auto arg : body->getArguments())
2897  if (!arg.getType().isIndex())
2898  return emitOpError(
2899  "expects arguments for the induction variable to be of index type");
2900 
2901  // Check that the terminator is an scf.reduce op.
2902  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2903  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2904  if (!reduceOp)
2905  return failure();
2906 
2907  // Check that the number of results is the same as the number of reductions.
2908  auto resultsSize = getResults().size();
2909  auto reductionsSize = reduceOp.getReductions().size();
2910  auto initValsSize = getInitVals().size();
2911  if (resultsSize != reductionsSize)
2912  return emitOpError() << "expects number of results: " << resultsSize
2913  << " to be the same as number of reductions: "
2914  << reductionsSize;
2915  if (resultsSize != initValsSize)
2916  return emitOpError() << "expects number of results: " << resultsSize
2917  << " to be the same as number of initial values: "
2918  << initValsSize;
2919 
2920  // Check that the types of the results and reductions are the same.
2921  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2922  auto resultType = getOperation()->getResult(i).getType();
2923  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2924  if (resultType != reductionOperandType)
2925  return reduceOp.emitOpError()
2926  << "expects type of " << i
2927  << "-th reduction operand: " << reductionOperandType
2928  << " to be the same as the " << i
2929  << "-th result type: " << resultType;
2930  }
2931  return success();
2932 }
2933 
2934 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2935  auto &builder = parser.getBuilder();
2936  // Parse an opening `(` followed by induction variables followed by `)`
2937  SmallVector<OpAsmParser::Argument, 4> ivs;
2938  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2939  return failure();
2940 
2941  // Parse loop bounds.
2942  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2943  if (parser.parseEqual() ||
2944  parser.parseOperandList(lower, ivs.size(),
2945  OpAsmParser::Delimiter::Paren) ||
2946  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2947  return failure();
2948 
2949  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2950  if (parser.parseKeyword("to") ||
2951  parser.parseOperandList(upper, ivs.size(),
2952  OpAsmParser::Delimiter::Paren) ||
2953  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2954  return failure();
2955 
2956  // Parse step values.
2957  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2958  if (parser.parseKeyword("step") ||
2959  parser.parseOperandList(steps, ivs.size(),
2960  OpAsmParser::Delimiter::Paren) ||
2961  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2962  return failure();
2963 
2964  // Parse init values.
2965  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2966  if (succeeded(parser.parseOptionalKeyword("init"))) {
2967  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2968  return failure();
2969  }
2970 
2971  // Parse optional results in case there is a reduce.
2972  if (parser.parseOptionalArrowTypeList(result.types))
2973  return failure();
2974 
2975  // Now parse the body.
2976  Region *body = result.addRegion();
2977  for (auto &iv : ivs)
2978  iv.type = builder.getIndexType();
2979  if (parser.parseRegion(*body, ivs))
2980  return failure();
2981 
2982  // Set `operandSegmentSizes` attribute.
2983  result.addAttribute(
2984  ParallelOp::getOperandSegmentSizeAttr(),
2985  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2986  static_cast<int32_t>(upper.size()),
2987  static_cast<int32_t>(steps.size()),
2988  static_cast<int32_t>(initVals.size())}));
2989 
2990  // Parse attributes.
2991  if (parser.parseOptionalAttrDict(result.attributes) ||
2992  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2993  result.operands))
2994  return failure();
2995 
2996  // Add a terminator if none was parsed.
2997  ParallelOp::ensureTerminator(*body, builder, result.location);
2998  return success();
2999 }
3000 
3001 void ParallelOp::print(OpAsmPrinter &p) {
3002  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3003  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
3004  if (!getInitVals().empty())
3005  p << " init (" << getInitVals() << ")";
3006  p.printOptionalArrowTypeList(getResultTypes());
3007  p << ' ';
3008  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3009  p.printOptionalAttrDict(
3010  (*this)->getAttrs(),
3011  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3012 }
3013 
3014 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3015 
3016 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3017  return SmallVector<Value>{getBody()->getArguments()};
3018 }
3019 
3020 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3021  return getLowerBound();
3022 }
3023 
3024 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3025  return getUpperBound();
3026 }
3027 
3028 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3029  return getStep();
3030 }
3031 
3032 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3033  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3034  if (!ivArg)
3035  return ParallelOp();
3036  assert(ivArg.getOwner() && "unlinked block argument");
3037  auto *containingOp = ivArg.getOwner()->getParentOp();
3038  return dyn_cast<ParallelOp>(containingOp);
3039 }
3040 
3041 namespace {
3042 // Collapse loop dimensions that perform a single iteration.
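// Illustrative sketch (not part of the original source): with %c0/%c1/%c4
// denoting index constants 0/1/4, a loop such as
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c4) step (%c1, %c1) { ... }
// collapses to a one-dimensional loop over %j, with uses of %i remapped to
// the lower bound %c0; a dimension with a zero trip count removes the loop
// entirely and forwards the init values.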
3043 struct ParallelOpSingleOrZeroIterationDimsFolder
3044  : public OpRewritePattern<ParallelOp> {
3045  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3046 
3047  LogicalResult matchAndRewrite(ParallelOp op,
3048  PatternRewriter &rewriter) const override {
3049  Location loc = op.getLoc();
3050 
3051  // Compute new loop bounds that omit all single-iteration loop dimensions.
3052  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3053  IRMapping mapping;
3054  for (auto [lb, ub, step, iv] :
3055  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3056  op.getInductionVars())) {
3057  auto numIterations = constantTripCount(lb, ub, step);
3058  if (numIterations.has_value()) {
3059  // Remove the loop if it performs zero iterations.
3060  if (*numIterations == 0) {
3061  rewriter.replaceOp(op, op.getInitVals());
3062  return success();
3063  }
3064  // Replace the loop induction variable by the lower bound if the loop
3065  // performs a single iteration. Otherwise, copy the loop bounds.
3066  if (*numIterations == 1) {
3067  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3068  continue;
3069  }
3070  }
3071  newLowerBounds.push_back(lb);
3072  newUpperBounds.push_back(ub);
3073  newSteps.push_back(step);
3074  }
3075  // Exit if none of the loop dimensions perform a single iteration.
3076  if (newLowerBounds.size() == op.getLowerBound().size())
3077  return failure();
3078 
3079  if (newLowerBounds.empty()) {
3080  // All of the loop dimensions perform a single iteration. Inline the
3081  // loop body and the nested ReduceOps.
3082  SmallVector<Value> results;
3083  results.reserve(op.getInitVals().size());
3084  for (auto &bodyOp : op.getBody()->without_terminator())
3085  rewriter.clone(bodyOp, mapping);
3086  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3087  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3088  Block &reduceBlock = reduceOp.getReductions()[i].front();
3089  auto initValIndex = results.size();
3090  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3091  mapping.map(reduceBlock.getArgument(1),
3092  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3093  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3094  rewriter.clone(reduceBodyOp, mapping);
3095 
3096  auto result = mapping.lookupOrDefault(
3097  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3098  results.push_back(result);
3099  }
3100 
3101  rewriter.replaceOp(op, results);
3102  return success();
3103  }
3104  // Replace the parallel loop by lower-dimensional parallel loop.
3105  auto newOp =
3106  ParallelOp::create(rewriter, op.getLoc(), newLowerBounds,
3107  newUpperBounds, newSteps, op.getInitVals(), nullptr);
3108  // Erase the empty block that was inserted by the builder.
3109  rewriter.eraseBlock(newOp.getBody());
3110  // Clone the loop body and remap the block arguments of the collapsed loops
3111  // (inlining does not support a cancellable block argument mapping).
3112  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3113  newOp.getRegion().begin(), mapping);
3114  rewriter.replaceOp(op, newOp.getResults());
3115  return success();
3116  }
3117 };
3118 
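// Merges a perfectly nested pair of parallel loops into a single
// multi-dimensional scf.parallel. Illustrative sketch (not part of the
// original source):
//   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
//     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) { use(%i, %j) }
//   }
// becomes
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
//     use(%i, %j)
//   }
// provided the outer body contains nothing but the inner loop, the inner
// bounds do not depend on %i, and neither loop carries reductions.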
3119 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3120  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3121 
3122  LogicalResult matchAndRewrite(ParallelOp op,
3123  PatternRewriter &rewriter) const override {
3124  Block &outerBody = *op.getBody();
3125  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3126  return failure();
3127 
3128  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3129  if (!innerOp)
3130  return failure();
3131 
3132  for (auto val : outerBody.getArguments())
3133  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3134  llvm::is_contained(innerOp.getUpperBound(), val) ||
3135  llvm::is_contained(innerOp.getStep(), val))
3136  return failure();
3137 
3138  // Reductions are not supported yet.
3139  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3140  return failure();
3141 
3142  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3143  ValueRange iterVals, ValueRange) {
3144  Block &innerBody = *innerOp.getBody();
3145  assert(iterVals.size() ==
3146  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3147  IRMapping mapping;
3148  mapping.map(outerBody.getArguments(),
3149  iterVals.take_front(outerBody.getNumArguments()));
3150  mapping.map(innerBody.getArguments(),
3151  iterVals.take_back(innerBody.getNumArguments()));
3152  for (Operation &op : innerBody.without_terminator())
3153  builder.clone(op, mapping);
3154  };
3155 
3156  auto concatValues = [](const auto &first, const auto &second) {
3157  SmallVector<Value> ret;
3158  ret.reserve(first.size() + second.size());
3159  ret.assign(first.begin(), first.end());
3160  ret.append(second.begin(), second.end());
3161  return ret;
3162  };
3163 
3164  auto newLowerBounds =
3165  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3166  auto newUpperBounds =
3167  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3168  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3169 
3170  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3171  newSteps, ValueRange(),
3172  bodyBuilder);
3173  return success();
3174  }
3175 };
3176 
3177 } // namespace
3178 
3179 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3180  MLIRContext *context) {
3181  results
3182  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3183  context);
3184 }
3185 
3186 /// Given a branch point (the parent operation or its body region), return the
3187 /// successor regions, i.e. the regions that may be selected next during the
3188 /// flow of control: the loop body region and the parent operation itself,
3189 /// since the loop body may execute zero or more times before control returns
3190 /// to the parent operation.
3191 void ParallelOp::getSuccessorRegions(
3192     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3193  // Both the operation itself and the region may be branching into the body or
3194  // back into the operation itself. It is possible for the loop not to enter
3195  // the body.
3196  regions.push_back(RegionSuccessor(&getRegion()));
3197  regions.push_back(RegionSuccessor());
3198 }
3199 
3200 //===----------------------------------------------------------------------===//
3201 // ReduceOp
3202 //===----------------------------------------------------------------------===//
3203 
3204 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3205 
3206 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3207  ValueRange operands) {
3208  result.addOperands(operands);
3209  for (Value v : operands) {
3210  OpBuilder::InsertionGuard guard(builder);
3211  Region *bodyRegion = result.addRegion();
3212  builder.createBlock(bodyRegion, {},
3213  ArrayRef<Type>{v.getType(), v.getType()},
3214  {result.location, result.location});
3215  }
3216 }
3217 
3218 LogicalResult ReduceOp::verifyRegions() {
3219  // The region of a ReduceOp has two arguments of the same type as its
3220  // corresponding operand.
3221  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3222  auto type = getOperands()[i].getType();
3223  Block &block = getReductions()[i].front();
3224  if (block.empty())
3225  return emitOpError() << i << "-th reduction has an empty body";
3226  if (block.getNumArguments() != 2 ||
3227  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3228  return arg.getType() != type;
3229  }))
3230  return emitOpError() << "expected two block arguments with type " << type
3231  << " in the " << i << "-th reduction region";
3232 
3233  // Check that the block is terminated by a ReduceReturnOp.
3234  if (!isa<ReduceReturnOp>(block.getTerminator()))
3235  return emitOpError("reduction bodies must be terminated with an "
3236  "'scf.reduce.return' op");
3237  }
3238 
3239  return success();
3240 }
3241 
3242 MutableOperandRange
3243 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3244  // No operands are forwarded to the next iteration.
3245  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3246 }
3247 
3248 //===----------------------------------------------------------------------===//
3249 // ReduceReturnOp
3250 //===----------------------------------------------------------------------===//
3251 
3252 LogicalResult ReduceReturnOp::verify() {
3253  // The type of the return value should be the same type as the types of the
3254  // block arguments of the reduction body.
3255  Block *reductionBody = getOperation()->getBlock();
3256  // Should already be verified by an op trait.
3257  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3258  Type expectedResultType = reductionBody->getArgument(0).getType();
3259  if (expectedResultType != getResult().getType())
3260  return emitOpError() << "must have type " << expectedResultType
3261  << " (the type of the reduction inputs)";
3262  return success();
3263 }
3264 
3265 //===----------------------------------------------------------------------===//
3266 // WhileOp
3267 //===----------------------------------------------------------------------===//
3268 
3269 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3270  ::mlir::OperationState &odsState, TypeRange resultTypes,
3271  ValueRange inits, BodyBuilderFn beforeBuilder,
3272  BodyBuilderFn afterBuilder) {
3273  odsState.addOperands(inits);
3274  odsState.addTypes(resultTypes);
3275 
3276  OpBuilder::InsertionGuard guard(odsBuilder);
3277 
3278  // Build before region.
3279  SmallVector<Location, 4> beforeArgLocs;
3280  beforeArgLocs.reserve(inits.size());
3281  for (Value operand : inits) {
3282  beforeArgLocs.push_back(operand.getLoc());
3283  }
3284 
3285  Region *beforeRegion = odsState.addRegion();
3286  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3287  inits.getTypes(), beforeArgLocs);
3288  if (beforeBuilder)
3289  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3290 
3291  // Build after region.
3292  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3293 
3294  Region *afterRegion = odsState.addRegion();
3295  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3296  resultTypes, afterArgLocs);
3297 
3298  if (afterBuilder)
3299  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3300 }
3301 
3302 ConditionOp WhileOp::getConditionOp() {
3303  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3304 }
3305 
3306 YieldOp WhileOp::getYieldOp() {
3307  return cast<YieldOp>(getAfterBody()->getTerminator());
3308 }
3309 
3310 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3311  return getYieldOp().getResultsMutable();
3312 }
3313 
3314 Block::BlockArgListType WhileOp::getBeforeArguments() {
3315  return getBeforeBody()->getArguments();
3316 }
3317 
3318 Block::BlockArgListType WhileOp::getAfterArguments() {
3319  return getAfterBody()->getArguments();
3320 }
3321 
3322 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3323  return getBeforeArguments();
3324 }
3325 
3326 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3327  assert(point == getBefore() &&
3328  "WhileOp is expected to branch only to the first region");
3329  return getInits();
3330 }
3331 
3332 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3333                                   SmallVectorImpl<RegionSuccessor> &regions) {
3334  // The parent op always branches to the condition region.
3335  if (point.isParent()) {
3336  regions.emplace_back(&getBefore(), getBefore().getArguments());
3337  return;
3338  }
3339 
3340  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3341  "there are only two regions in a WhileOp");
3342  // The body region always branches back to the condition region.
3343  if (point == getAfter()) {
3344  regions.emplace_back(&getBefore(), getBefore().getArguments());
3345  return;
3346  }
3347 
3348  regions.emplace_back(getResults());
3349  regions.emplace_back(&getAfter(), getAfter().getArguments());
3350 }
3351 
3352 SmallVector<Region *> WhileOp::getLoopRegions() {
3353  return {&getBefore(), &getAfter()};
3354 }
3355 
3356 /// Parses a `while` op.
3357 ///
3358 /// op ::= `scf.while` assignments `:` function-type region `do` region
3359 /// `attributes` attribute-dict
3360 /// initializer ::= /* empty */ | `(` assignment-list `)`
3361 /// assignment-list ::= assignment | assignment `,` assignment-list
3362 /// assignment ::= ssa-value `=` ssa-value
3363 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3364  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3365  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3366  Region *before = result.addRegion();
3367  Region *after = result.addRegion();
3368 
3369  OptionalParseResult listResult =
3370  parser.parseOptionalAssignmentList(regionArgs, operands);
3371  if (listResult.has_value() && failed(listResult.value()))
3372  return failure();
3373 
3374  FunctionType functionType;
3375  SMLoc typeLoc = parser.getCurrentLocation();
3376  if (failed(parser.parseColonType(functionType)))
3377  return failure();
3378 
3379  result.addTypes(functionType.getResults());
3380 
3381  if (functionType.getNumInputs() != operands.size()) {
3382  return parser.emitError(typeLoc)
3383  << "expected as many input types as operands "
3384  << "(expected " << operands.size() << " got "
3385  << functionType.getNumInputs() << ")";
3386  }
3387 
3388  // Resolve input operands.
3389  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3390  parser.getCurrentLocation(),
3391  result.operands)))
3392  return failure();
3393 
3394  // Propagate the types into the region arguments.
3395  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3396  regionArgs[i].type = functionType.getInput(i);
3397 
3398  return failure(parser.parseRegion(*before, regionArgs) ||
3399  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3400  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3401 }
3402 
3403 /// Prints a `while` op.
3404 void scf::WhileOp::print(OpAsmPrinter &p) {
3405  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3406  p << " : ";
3407  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3408  p << ' ';
3409  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3410  p << " do ";
3411  p.printRegion(getAfter());
3412  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3413 }
3414 
3415 /// Verifies that two ranges of types match, i.e. have the same number of
3416 /// entries and that the types are pairwise equal. Reports errors on the given
3417 /// operation in case of mismatch.
3418 template <typename OpTy>
3419 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3420  TypeRange right, StringRef message) {
3421  if (left.size() != right.size())
3422  return op.emitOpError("expects the same number of ") << message;
3423 
3424  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3425  if (left[i] != right[i]) {
3426  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3427  << message;
3428  diag.attachNote() << "for argument " << i << ", found " << left[i]
3429  << " and " << right[i];
3430  return diag;
3431  }
3432  }
3433 
3434  return success();
3435 }
3436 
3437 LogicalResult scf::WhileOp::verify() {
3438  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3439  *this, getBefore(),
3440  "expects the 'before' region to terminate with 'scf.condition'");
3441  if (!beforeTerminator)
3442  return failure();
3443 
3444  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3445  *this, getAfter(),
3446  "expects the 'after' region to terminate with 'scf.yield'");
3447  return success(afterTerminator != nullptr);
3448 }
3449 
3450 namespace {
3451 /// Replace uses of the condition within the do block with true, since otherwise
3452 /// the block would not be evaluated.
3453 ///
3454 /// scf.while (..) : (i1, ...) -> ... {
3455 /// %condition = call @evaluate_condition() : () -> i1
3456 /// scf.condition(%condition) %condition : i1, ...
3457 /// } do {
3458 /// ^bb0(%arg0: i1, ...):
3459 /// use(%arg0)
3460 /// ...
3461 ///
3462 /// becomes
3463 /// scf.while (..) : (i1, ...) -> ... {
3464 /// %condition = call @evaluate_condition() : () -> i1
3465 /// scf.condition(%condition) %condition : i1, ...
3466 /// } do {
3467 /// ^bb0(%arg0: i1, ...):
3468 /// use(%true)
3469 /// ...
3470 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3471  using OpRewritePattern<WhileOp>::OpRewritePattern;
3472 
3473  LogicalResult matchAndRewrite(WhileOp op,
3474  PatternRewriter &rewriter) const override {
3475  auto term = op.getConditionOp();
3476 
3477  // These variables serve to prevent creating duplicate constants
3478  // and hold constant true or false values.
3479  Value constantTrue = nullptr;
3480 
3481  bool replaced = false;
3482  for (auto yieldedAndBlockArgs :
3483  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3484  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3485  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3486  if (!constantTrue)
3487  constantTrue = arith::ConstantOp::create(
3488  rewriter, op.getLoc(), term.getCondition().getType(),
3489  rewriter.getBoolAttr(true));
3490 
3491  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3492  constantTrue);
3493  replaced = true;
3494  }
3495  }
3496  }
3497  return success(replaced);
3498  }
3499 };
3500 
3501 /// Remove loop invariant arguments from `before` block of scf.while.
3502 /// A before block argument is considered loop invariant if :-
3503 /// 1. i-th yield operand is equal to the i-th while operand.
3504 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3505 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3506 /// iter argument/while operand.
3507 /// For the arguments which are removed, their uses inside scf.while
3508 /// are replaced with their corresponding initial value.
3509 ///
3510 /// Eg:
3511 /// INPUT :-
3512 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3513 /// ..., %argN_before = %N)
3514 /// {
3515 /// ...
3516 /// scf.condition(%cond) %arg1_before, %arg0_before,
3517 /// %arg2_before, %arg0_before, ...
3518 /// } do {
3519 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3520 /// ..., %argK_after):
3521 /// ...
3522 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3523 /// }
3524 ///
3525 /// OUTPUT :-
3526 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3527 /// %N)
3528 /// {
3529 /// ...
3530 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3531 /// } do {
3532 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3533 /// ..., %argK_after):
3534 /// ...
3535 /// scf.yield %arg1_after, ..., %argN
3536 /// }
3537 ///
3538 /// EXPLANATION:
3539 /// We iterate over each yield operand.
3540 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3541 /// %arg0_before, which in turn is the 0-th iter argument. So we
3542 /// remove 0-th before block argument and yield operand, and replace
3543 /// all uses of the 0-th before block argument with its initial value
3544 /// %a.
3545 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3546 /// value. So we remove this operand and the corresponding before
3547 /// block argument and replace all uses of 1-th before block argument
3548 /// with %b.
3549 struct RemoveLoopInvariantArgsFromBeforeBlock
3550  : public OpRewritePattern<WhileOp> {
3551  using OpRewritePattern<WhileOp>::OpRewritePattern;
3552 
3553  LogicalResult matchAndRewrite(WhileOp op,
3554  PatternRewriter &rewriter) const override {
3555  Block &afterBlock = *op.getAfterBody();
3556  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3557  ConditionOp condOp = op.getConditionOp();
3558  OperandRange condOpArgs = condOp.getArgs();
3559  Operation *yieldOp = afterBlock.getTerminator();
3560  ValueRange yieldOpArgs = yieldOp->getOperands();
3561 
3562  bool canSimplify = false;
3563  for (const auto &it :
3564  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3565  auto index = static_cast<unsigned>(it.index());
3566  auto [initVal, yieldOpArg] = it.value();
3567  // If i-th yield operand is equal to the i-th operand of the scf.while,
3568  // the i-th before block argument is a loop invariant.
3569  if (yieldOpArg == initVal) {
3570  canSimplify = true;
3571  break;
3572  }
3573  // If the i-th yield operand is k-th after block argument, then we check
3574  // if the (k+1)-th condition op operand is equal to either the i-th before
3575  // block argument or the initial value of the i-th before block argument.
3576  // If either comparison holds, the i-th before block argument is a loop
3577  // invariant.
3578  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3579  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3580  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3581  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3582  canSimplify = true;
3583  break;
3584  }
3585  }
3586  }
3587 
3588  if (!canSimplify)
3589  return failure();
3590 
3591  SmallVector<Value> newInitArgs, newYieldOpArgs;
3592  DenseMap<unsigned, Value> beforeBlockInitValMap;
3593  SmallVector<Location> newBeforeBlockArgLocs;
3594  for (const auto &it :
3595  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3596  auto index = static_cast<unsigned>(it.index());
3597  auto [initVal, yieldOpArg] = it.value();
3598 
3599  // If i-th yield operand is equal to the i-th operand of the scf.while,
3600  // the i-th before block argument is a loop invariant.
3601  if (yieldOpArg == initVal) {
3602  beforeBlockInitValMap.insert({index, initVal});
3603  continue;
3604  } else {
3605  // If the i-th yield operand is k-th after block argument, then we check
3606  // if the (k+1)-th condition op operand is equal to either the i-th
3607  // before block argument or the initial value of the i-th before block
3608  // argument. If either comparison holds, the i-th before block
3609  // argument is a loop invariant.
3610  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3611  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3612  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3613  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3614  beforeBlockInitValMap.insert({index, initVal});
3615  continue;
3616  }
3617  }
3618  }
3619  newInitArgs.emplace_back(initVal);
3620  newYieldOpArgs.emplace_back(yieldOpArg);
3621  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3622  }
3623 
3624  {
3625  OpBuilder::InsertionGuard g(rewriter);
3626  rewriter.setInsertionPoint(yieldOp);
3627  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3628  }
3629 
3630  auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
3631  newInitArgs);
3632 
3633  Block &newBeforeBlock = *rewriter.createBlock(
3634  &newWhile.getBefore(), /*insertPt*/ {},
3635  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3636 
3637  Block &beforeBlock = *op.getBeforeBody();
3638  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3639  // For each i-th before block argument we find its replacement value as:
3640  // 1. If the i-th before block argument is a loop invariant, we fetch its
3641  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3642  // 2. Else we fetch the j-th new before block argument as the replacement
3643  // value of the i-th before block argument.
3644  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3645  // If the index 'i' argument was a loop invariant we fetch its initial
3646  // value from `beforeBlockInitValMap`.
3647  if (beforeBlockInitValMap.count(i) != 0)
3648  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3649  else
3650  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3651  }
3652 
3653  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3654  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3655  newWhile.getAfter().begin());
3656 
3657  rewriter.replaceOp(op, newWhile.getResults());
3658  return success();
3659  }
3660 };
3661 
3662 /// Remove loop invariant value from result (condition op) of scf.while.
3663 /// A value is considered loop invariant if the final value yielded by
3664 /// scf.condition is defined outside of the `before` block. We remove the
3665 /// corresponding argument in `after` block and replace the use with the value.
3666 /// We also replace the use of the corresponding result of scf.while with the
3667 /// value.
3668 ///
3669 /// Eg:
3670 /// INPUT :-
3671 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3672 /// %argN_before = %N) {
3673 /// ...
3674 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3675 /// } do {
3676 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3677 /// ...
3678 /// some_func(%arg1_after)
3679 /// ...
3680 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3681 /// }
3682 ///
3683 /// OUTPUT :-
3684 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3685 /// ...
3686 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3687 /// } do {
3688 /// ^bb0(%arg0, %arg3, ..., %argM):
3689 /// ...
3690 /// some_func(%a)
3691 /// ...
3692 /// scf.yield %arg0, %b, ..., %argN
3693 /// }
3694 ///
3695 /// EXPLANATION:
3696 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3697 /// before block of scf.while, so they get removed.
3698 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3699 /// replaced by %b.
3700 /// 3. The corresponding after block argument %arg1_after's uses are
3701 /// replaced by %a and %arg2_after's uses are replaced by %b.
3702 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3703  using OpRewritePattern<WhileOp>::OpRewritePattern;
3704 
3705  LogicalResult matchAndRewrite(WhileOp op,
3706  PatternRewriter &rewriter) const override {
3707  Block &beforeBlock = *op.getBeforeBody();
3708  ConditionOp condOp = op.getConditionOp();
3709  OperandRange condOpArgs = condOp.getArgs();
3710 
3711  bool canSimplify = false;
3712  for (Value condOpArg : condOpArgs) {
3713  // Those values not defined within `before` block will be considered as
3714  // loop invariant values. We map the corresponding `index` with their
3715  // value.
3716  if (condOpArg.getParentBlock() != &beforeBlock) {
3717  canSimplify = true;
3718  break;
3719  }
3720  }
3721 
3722  if (!canSimplify)
3723  return failure();
3724 
3725  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3726 
3727  SmallVector<Value> newCondOpArgs;
3728  SmallVector<Type> newAfterBlockType;
3729  DenseMap<unsigned, Value> condOpInitValMap;
3730  SmallVector<Location> newAfterBlockArgLocs;
3731  for (const auto &it : llvm::enumerate(condOpArgs)) {
3732  auto index = static_cast<unsigned>(it.index());
3733  Value condOpArg = it.value();
3734  // Those values not defined within `before` block will be considered as
3735  // loop invariant values. We map the corresponding `index` with their
3736  // value.
3737  if (condOpArg.getParentBlock() != &beforeBlock) {
3738  condOpInitValMap.insert({index, condOpArg});
3739  } else {
3740  newCondOpArgs.emplace_back(condOpArg);
3741  newAfterBlockType.emplace_back(condOpArg.getType());
3742  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3743  }
3744  }
3745 
3746  {
3747  OpBuilder::InsertionGuard g(rewriter);
3748  rewriter.setInsertionPoint(condOp);
3749  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3750  newCondOpArgs);
3751  }
3752 
3753  auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
3754  op.getOperands());
3755 
3756  Block &newAfterBlock =
3757  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3758  newAfterBlockType, newAfterBlockArgLocs);
3759 
3760  Block &afterBlock = *op.getAfterBody();
3761  // Since a new scf.condition op was created, we need to fetch the new
3762  // `after` block arguments which will be used while replacing operations of
3763  // the previous scf.while's `after` block. We also fetch the new result
3764  // values.
3765  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3766  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3767  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3768  Value afterBlockArg, result;
3769  // If the index 'i' argument was loop invariant we fetch its value from
3770  // the `condOpInitValMap` map.
3771  if (condOpInitValMap.count(i) != 0) {
3772  afterBlockArg = condOpInitValMap[i];
3773  result = afterBlockArg;
3774  } else {
3775  afterBlockArg = newAfterBlock.getArgument(j);
3776  result = newWhile.getResult(j);
3777  j++;
3778  }
3779  newAfterBlockArgs[i] = afterBlockArg;
3780  newWhileResults[i] = result;
3781  }
3782 
3783  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3784  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3785  newWhile.getBefore().begin());
3786 
3787  rewriter.replaceOp(op, newWhileResults);
3788  return success();
3789  }
3790 };
3791 
3792 /// Remove WhileOp results that are also unused in 'after' block.
3793 ///
3794 /// %0:2 = scf.while () : () -> (i32, i64) {
3795 /// %condition = "test.condition"() : () -> i1
3796 /// %v1 = "test.get_some_value"() : () -> i32
3797 /// %v2 = "test.get_some_value"() : () -> i64
3798 /// scf.condition(%condition) %v1, %v2 : i32, i64
3799 /// } do {
3800 /// ^bb0(%arg0: i32, %arg1: i64):
3801 /// "test.use"(%arg0) : (i32) -> ()
3802 /// scf.yield
3803 /// }
3804 /// return %0#0 : i32
3805 ///
3806 /// becomes
3807 /// %0 = scf.while () : () -> (i32) {
3808 /// %condition = "test.condition"() : () -> i1
3809 /// %v1 = "test.get_some_value"() : () -> i32
3810 /// %v2 = "test.get_some_value"() : () -> i64
3811 /// scf.condition(%condition) %v1 : i32
3812 /// } do {
3813 /// ^bb0(%arg0: i32):
3814 /// "test.use"(%arg0) : (i32) -> ()
3815 /// scf.yield
3816 /// }
3817 /// return %0 : i32
3818 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3819  using OpRewritePattern<WhileOp>::OpRewritePattern;
3820 
3821  LogicalResult matchAndRewrite(WhileOp op,
3822  PatternRewriter &rewriter) const override {
3823  auto term = op.getConditionOp();
3824  auto afterArgs = op.getAfterArguments();
3825  auto termArgs = term.getArgs();
3826 
3827  // Collect results mapping, new terminator args and new result types.
3828  SmallVector<unsigned> newResultsIndices;
3829  SmallVector<Type> newResultTypes;
3830  SmallVector<Value> newTermArgs;
3831  SmallVector<Location> newArgLocs;
3832  bool needUpdate = false;
3833  for (const auto &it :
3834  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3835  auto i = static_cast<unsigned>(it.index());
3836  Value result = std::get<0>(it.value());
3837  Value afterArg = std::get<1>(it.value());
3838  Value termArg = std::get<2>(it.value());
3839  if (result.use_empty() && afterArg.use_empty()) {
3840  needUpdate = true;
3841  } else {
3842  newResultsIndices.emplace_back(i);
3843  newTermArgs.emplace_back(termArg);
3844  newResultTypes.emplace_back(result.getType());
3845  newArgLocs.emplace_back(result.getLoc());
3846  }
3847  }
3848 
3849  if (!needUpdate)
3850  return failure();
3851 
3852  {
3853  OpBuilder::InsertionGuard g(rewriter);
3854  rewriter.setInsertionPoint(term);
3855  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3856  newTermArgs);
3857  }
3858 
3859  auto newWhile =
3860  WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
3861 
3862  Block &newAfterBlock = *rewriter.createBlock(
3863  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3864 
3865  // Build new results list and new after block args (unused entries will be
3866  // null).
3867  SmallVector<Value> newResults(op.getNumResults());
3868  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3869  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3870  newResults[it.value()] = newWhile.getResult(it.index());
3871  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3872  }
3873 
3874  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3875  newWhile.getBefore().begin());
3876 
3877  Block &afterBlock = *op.getAfterBody();
3878  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3879 
3880  rewriter.replaceOp(op, newResults);
3881  return success();
3882  }
3883 };
3884 
3885 /// Replace operations equivalent to the condition in the do block with true,
3886 /// since otherwise the block would not be evaluated.
3887 ///
3888 /// scf.while (..) : (i32, ...) -> ... {
3889 /// %z = ... : i32
3890 /// %condition = cmpi pred %z, %a
3891 /// scf.condition(%condition) %z : i32, ...
3892 /// } do {
3893 /// ^bb0(%arg0: i32, ...):
3894 /// %condition2 = cmpi pred %arg0, %a
3895 /// use(%condition2)
3896 /// ...
3897 ///
3898 /// becomes
3899 /// scf.while (..) : (i32, ...) -> ... {
3900 /// %z = ... : i32
3901 /// %condition = cmpi pred %z, %a
3902 /// scf.condition(%condition) %z : i32, ...
3903 /// } do {
3904 /// ^bb0(%arg0: i32, ...):
3905 /// use(%true)
3906 /// ...
3907 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3908  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3909 
3910  LogicalResult matchAndRewrite(scf::WhileOp op,
3911  PatternRewriter &rewriter) const override {
3912  using namespace scf;
3913  auto cond = op.getConditionOp();
3914  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3915  if (!cmp)
3916  return failure();
3917  bool changed = false;
3918  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3919  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3920  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3921  continue;
3922  for (OpOperand &u :
3923  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3924  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3925  if (!cmp2)
3926  continue;
3927  // For a binary operator 1-opIdx gets the other side.
3928  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3929  continue;
3930  bool samePredicate;
3931  if (cmp2.getPredicate() == cmp.getPredicate())
3932  samePredicate = true;
3933  else if (cmp2.getPredicate() ==
3934  arith::invertPredicate(cmp.getPredicate()))
3935  samePredicate = false;
3936  else
3937  continue;
3938 
3939  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3940  1);
3941  changed = true;
3942  }
3943  }
3944  }
3945  return success(changed);
3946  }
3947 };
3948 
3949 /// Remove unused init/yield args.
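/// Illustrative sketch (not part of the original source): if %arg1 below has
/// no uses in the 'before' region,
///   %res = scf.while (%arg0 = %i0, %arg1 = %i1) : (i32, i64) -> i32 {
///     ...
///     scf.condition(%cond) %arg0 : i32
///   } do {
///   ^bb0(%a: i32):
///     scf.yield %n0, %n1 : i32, i64
///   }
/// is rewritten to drop %arg1 together with its init value %i1 and the
/// corresponding yield operand %n1.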
3950 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3951  using OpRewritePattern<WhileOp>::OpRewritePattern;
3952 
3953  LogicalResult matchAndRewrite(WhileOp op,
3954  PatternRewriter &rewriter) const override {
3955 
3956  if (!llvm::any_of(op.getBeforeArguments(),
3957  [](Value arg) { return arg.use_empty(); }))
3958  return rewriter.notifyMatchFailure(op, "No args to remove");
3959 
3960  YieldOp yield = op.getYieldOp();
3961 
3962  // Collect results mapping, new terminator args and new result types.
3963  SmallVector<Value> newYields;
3964  SmallVector<Value> newInits;
3965  llvm::BitVector argsToErase;
3966 
3967  size_t argsCount = op.getBeforeArguments().size();
3968  newYields.reserve(argsCount);
3969  newInits.reserve(argsCount);
3970  argsToErase.reserve(argsCount);
3971  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3972  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3973  if (beforeArg.use_empty()) {
3974  argsToErase.push_back(true);
3975  } else {
3976  argsToErase.push_back(false);
3977  newYields.emplace_back(yieldValue);
3978  newInits.emplace_back(initValue);
3979  }
3980  }
3981 
3982  Block &beforeBlock = *op.getBeforeBody();
3983  Block &afterBlock = *op.getAfterBody();
3984 
3985  beforeBlock.eraseArguments(argsToErase);
3986 
3987  Location loc = op.getLoc();
3988  auto newWhileOp =
3989  WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
3990  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
3991  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3992  Block &newAfterBlock = *newWhileOp.getAfterBody();
3993 
3994  OpBuilder::InsertionGuard g(rewriter);
3995  rewriter.setInsertionPoint(yield);
3996  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3997 
3998  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
3999  newBeforeBlock.getArguments());
4000  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4001  newAfterBlock.getArguments());
4002 
4003  rewriter.replaceOp(op, newWhileOp.getResults());
4004  return success();
4005  }
4006 };
4007 
4008 /// Remove duplicated ConditionOp args.
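/// Illustrative sketch (not part of the original source): a terminator such as
///   scf.condition(%cond) %v, %v : i32, i32
/// is rewritten to forward %v only once; the second scf.while result and the
/// second 'after' block argument are remapped to the first ones.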
4009 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4010  using OpRewritePattern::OpRewritePattern;
4011 
4012  LogicalResult matchAndRewrite(WhileOp op,
4013  PatternRewriter &rewriter) const override {
4014  ConditionOp condOp = op.getConditionOp();
4015  ValueRange condOpArgs = condOp.getArgs();
4016 
4017  llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
4018 
4019  if (argsSet.size() == condOpArgs.size())
4020  return rewriter.notifyMatchFailure(op, "No results to remove");
4021 
4022  llvm::SmallDenseMap<Value, unsigned> argsMap;
4023  SmallVector<Value> newArgs;
4024  argsMap.reserve(condOpArgs.size());
4025  newArgs.reserve(condOpArgs.size());
4026  for (Value arg : condOpArgs) {
4027  if (!argsMap.count(arg)) {
4028  auto pos = static_cast<unsigned>(argsMap.size());
4029  argsMap.insert({arg, pos});
4030  newArgs.emplace_back(arg);
4031  }
4032  }
4033 
4034  ValueRange argsRange(newArgs);
4035 
4036  Location loc = op.getLoc();
4037  auto newWhileOp =
4038  scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
4039  /*beforeBody*/ nullptr,
4040  /*afterBody*/ nullptr);
4041  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4042  Block &newAfterBlock = *newWhileOp.getAfterBody();
4043 
4044  SmallVector<Value> afterArgsMapping;
4045  SmallVector<Value> resultsMapping;
4046  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4047  auto it = argsMap.find(arg);
4048  assert(it != argsMap.end());
4049  auto pos = it->second;
4050  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4051  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4052  }
4053 
4054  OpBuilder::InsertionGuard g(rewriter);
4055  rewriter.setInsertionPoint(condOp);
4056  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4057  argsRange);
4058 
4059  Block &beforeBlock = *op.getBeforeBody();
4060  Block &afterBlock = *op.getAfterBody();
4061 
4062  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4063  newBeforeBlock.getArguments());
4064  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4065  rewriter.replaceOp(op, resultsMapping);
4066  return success();
4067  }
4068 };
4069 
4070 /// If both ranges contain the same values, return mapping indices from args2
4071 /// to args1. Otherwise return std::nullopt.
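/// Illustrative sketch (not part of the original source): for
/// args1 = [a, b, c] and args2 = [b, c, a] the result is [1, 2, 0], i.e.
/// result[k] is the position in args1 of args2[k].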
4072 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4073  ValueRange args2) {
4074  if (args1.size() != args2.size())
4075  return std::nullopt;
4076 
4077  SmallVector<unsigned> ret(args1.size());
4078  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4079  auto it = llvm::find(args2, arg1);
4080  if (it == args2.end())
4081  return std::nullopt;
4082 
4083  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4084  }
4085 
4086  return ret;
4087 }
4088 
4089 static bool hasDuplicates(ValueRange args) {
4090  llvm::SmallDenseSet<Value> set;
4091  for (Value arg : args) {
4092  if (!set.insert(arg).second)
4093  return true;
4094  }
4095  return false;
4096 }
4097 
4098 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4099 /// `scf.condition` args into the same order as block args. Update `after` block
4100 /// args and op result values accordingly.
4101 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
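/// Illustrative sketch (not part of the original source): for a 'before'
/// block ^bb0(%a, %b) terminated by
///   scf.condition(%cond) %b, %a : ...
/// the terminator is rewritten to
///   scf.condition(%cond) %a, %b : ...
/// and the 'after' block arguments and the scf.while results are permuted to
/// match.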
4102 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4103  using OpRewritePattern::OpRewritePattern;
4104 
4105  LogicalResult matchAndRewrite(WhileOp loop,
4106  PatternRewriter &rewriter) const override {
4107  auto oldBefore = loop.getBeforeBody();
4108  ConditionOp oldTerm = loop.getConditionOp();
4109  ValueRange beforeArgs = oldBefore->getArguments();
4110  ValueRange termArgs = oldTerm.getArgs();
4111  if (beforeArgs == termArgs)
4112  return failure();
4113 
4114  if (hasDuplicates(termArgs))
4115  return failure();
4116 
4117  auto mapping = getArgsMapping(beforeArgs, termArgs);
4118  if (!mapping)
4119  return failure();
4120 
4121  {
4122  OpBuilder::InsertionGuard g(rewriter);
4123  rewriter.setInsertionPoint(oldTerm);
4124  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4125  beforeArgs);
4126  }
4127 
4128  auto oldAfter = loop.getAfterBody();
4129 
4130  SmallVector<Type> newResultTypes(beforeArgs.size());
4131  for (auto &&[i, j] : llvm::enumerate(*mapping))
4132  newResultTypes[j] = loop.getResult(i).getType();
4133 
4134  auto newLoop = WhileOp::create(
4135  rewriter, loop.getLoc(), newResultTypes, loop.getInits(),
4136  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4137  auto newBefore = newLoop.getBeforeBody();
4138  auto newAfter = newLoop.getAfterBody();
4139 
4140  SmallVector<Value> newResults(beforeArgs.size());
4141  SmallVector<Value> newAfterArgs(beforeArgs.size());
4142  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4143  newResults[i] = newLoop.getResult(j);
4144  newAfterArgs[i] = newAfter->getArgument(j);
4145  }
4146 
4147  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4148  newBefore->getArguments());
4149  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4150  newAfterArgs);
4151 
4152  rewriter.replaceOp(loop, newResults);
4153  return success();
4154  }
4155 };
4156 } // namespace
4157 
4158 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4159  MLIRContext *context) {
4160  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4161  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4162  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4163  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4164 }
4165 
4166 //===----------------------------------------------------------------------===//
4167 // IndexSwitchOp
4168 //===----------------------------------------------------------------------===//
4169 
4170 /// Parse the case regions and values.
4171 static ParseResult
4172 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4173  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4174  SmallVector<int64_t> caseValues;
4175  while (succeeded(p.parseOptionalKeyword("case"))) {
4176  int64_t value;
4177  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4178  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4179  return failure();
4180  caseValues.push_back(value);
4181  }
4182  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4183  return success();
4184 }
4185 
4186 /// Print the case regions and values.
4187 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4188  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4189  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4190  p.printNewline();
4191  p << "case " << value << ' ';
4192  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4193  }
4194 }
4195 
4196 LogicalResult scf::IndexSwitchOp::verify() {
4197  if (getCases().size() != getCaseRegions().size()) {
4198  return emitOpError("has ")
4199  << getCaseRegions().size() << " case regions but "
4200  << getCases().size() << " case values";
4201  }
4202 
4203  DenseSet<int64_t> valueSet;
4204  for (int64_t value : getCases())
4205  if (!valueSet.insert(value).second)
4206  return emitOpError("has duplicate case value: ") << value;
4207  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4208  auto yield = dyn_cast<YieldOp>(region.front().back());
4209  if (!yield)
4210  return emitOpError("expected region to end with scf.yield, but got ")
4211  << region.front().back().getName();
4212 
4213  if (yield.getNumOperands() != getNumResults()) {
4214  return (emitOpError("expected each region to return ")
4215  << getNumResults() << " values, but " << name << " returns "
4216  << yield.getNumOperands())
4217  .attachNote(yield.getLoc())
4218  << "see yield operation here";
4219  }
4220  for (auto [idx, result, operand] :
4221  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4222  yield.getOperandTypes())) {
4223  if (result == operand)
4224  continue;
4225  return (emitOpError("expected result #")
4226  << idx << " of each region to be " << result)
4227  .attachNote(yield.getLoc())
4228  << name << " returns " << operand << " here";
4229  }
4230  return success();
4231  };
4232 
4233  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4234  return failure();
4235  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4236  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4237  return failure();
4238 
4239  return success();
4240 }
4241 
4242 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4243 
4244 Block &scf::IndexSwitchOp::getDefaultBlock() {
4245  return getDefaultRegion().front();
4246 }
4247 
4248 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4249  assert(idx < getNumCases() && "case index out-of-bounds");
4250  return getCaseRegions()[idx].front();
4251 }
4252 
4253 void IndexSwitchOp::getSuccessorRegions(
4254     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4255  // All regions branch back to the parent op.
4256  if (!point.isParent()) {
4257  successors.emplace_back(getResults());
4258  return;
4259  }
4260 
4261  llvm::append_range(successors, getRegions());
4262 }
4263 
4264 void IndexSwitchOp::getEntrySuccessorRegions(
4265  ArrayRef<Attribute> operands,
4266  SmallVectorImpl<RegionSuccessor> &successors) {
4267  FoldAdaptor adaptor(operands, *this);
4268 
4269  // If a constant was not provided, all regions are possible successors.
4270  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4271  if (!arg) {
4272  llvm::append_range(successors, getRegions());
4273  return;
4274  }
4275 
4276  // Otherwise, try to find a case with a matching value. If not, the
4277  // default region is the only successor.
4278  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4279  if (caseValue == arg.getInt()) {
4280  successors.emplace_back(&caseRegion);
4281  return;
4282  }
4283  }
4284  successors.emplace_back(&getDefaultRegion());
4285 }
4286 
4287 void IndexSwitchOp::getRegionInvocationBounds(
4288     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4289  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4290  if (!operandValue) {
4291  // All regions are invoked at most once.
4292  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4293  return;
4294  }
4295 
4296  unsigned liveIndex = getNumRegions() - 1;
4297  const auto *it = llvm::find(getCases(), operandValue.getInt());
4298  if (it != getCases().end())
4299  liveIndex = std::distance(getCases().begin(), it);
4300  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4301  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4302 }
4303 
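// Folds scf.index_switch when its argument is a known constant by inlining
// the matching region (or the default region) in place of the switch.
// Illustrative sketch (not part of the original source):
//   %0 = scf.index_switch %c1 -> i32
//   case 1 { scf.yield %a : i32 }
//   default { scf.yield %b : i32 }
// with %c1 a constant 1 folds to %a.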
4304 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4305  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4306 
4307  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4308  PatternRewriter &rewriter) const override {
4309  // If `op.getArg()` is a constant, select the region that matches the
4310  // constant value. Use the default region if no match is found.
4311  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4312  if (!maybeCst.has_value())
4313  return failure();
4314  int64_t cst = *maybeCst;
4315  int64_t caseIdx, e = op.getNumCases();
4316  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4317  if (cst == op.getCases()[caseIdx])
4318  break;
4319  }
4320 
4321  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4322  : op.getDefaultRegion();
4323  Block &source = r.front();
4324  Operation *terminator = source.getTerminator();
4325  SmallVector<Value> results = terminator->getOperands();
4326 
4327  rewriter.inlineBlockBefore(&source, op);
4328  rewriter.eraseOp(terminator);
4329  // Replace the operation with a potentially empty list of results.
4330  // The fold mechanism doesn't support the case where the result list is empty.
4331  rewriter.replaceOp(op, results);
4332 
4333  return success();
4334  }
4335 };
4336 
4337 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4338  MLIRContext *context) {
4339  results.add<FoldConstantCase>(context);
4340 }
4341 
4342 //===----------------------------------------------------------------------===//
4343 // TableGen'd op method definitions
4344 //===----------------------------------------------------------------------===//
4345 
4346 #define GET_OP_CLASSES
4347 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:756
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:748
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:131
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1288
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:113
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4172
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3419
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:432
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:93
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4187
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
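A hedged sketch of IRMapping in practice: map source block arguments to replacement values, then clone the body so that clone() remaps operands via lookupOrDefault. cloneBodyInto is an illustrative helper, not an MLIR API.

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/ADT/STLExtras.h"

// Hypothetical helper: clone the ops of `src` into `dest`, substituting args.
static void cloneBodyInto(mlir::Block *src, mlir::Block *dest,
                          mlir::ValueRange replacements,
                          mlir::OpBuilder &builder) {
  mlir::IRMapping mapping;
  // Map each source block argument to its replacement value.
  for (auto [arg, repl] : llvm::zip(src->getArguments(), replacements))
    mapping.map(arg, repl);
  builder.setInsertionPointToEnd(dest);
  // clone() consults the mapping for every operand it remaps.
  for (mlir::Operation &op : src->without_terminator())
    builder.clone(op, mapping);
}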
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInterface can be invoked.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, reporting errors, etc.
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
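A hedged sketch of how the OpAsmParser hooks above combine in a custom parse method for a hypothetical op of the form `%operand : type` with an optional `->` result type list; the op and helper names are illustrative only.

#include "mlir/IR/OpImplementation.h"

// Hypothetical custom parser, not an existing MLIR op.
static mlir::ParseResult parseHypotheticalOp(mlir::OpAsmParser &parser,
                                             mlir::OperationState &result) {
  mlir::OpAsmParser::UnresolvedOperand operand;
  mlir::Type operandType;
  if (parser.parseOperand(operand) || parser.parseColonType(operandType) ||
      parser.resolveOperand(operand, operandType, result.operands))
    return mlir::failure();
  // An optional `->` introduces the result types.
  if (parser.parseOptionalArrowTypeList(result.types))
    return mlir::failure();
  return mlir::success();
}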
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attributes'.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
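A matching hedged sketch of a custom printer using the OpAsmPrinter hooks above, mirroring the hypothetical parser earlier; not an existing MLIR printer.

// Hypothetical custom printer for the op parsed above.
static void printHypotheticalOp(mlir::OpAsmPrinter &p, mlir::Operation *op) {
  p << " " << op->getOperand(0) << " : " << op->getOperand(0).getType();
  // Print `-> (types)` only when the op produces results.
  p.printOptionalArrowTypeList(op->getResultTypes());
  // Print remaining attributes, if any, as a trailing dictionary.
  p.printOptionalAttrDict(op->getAttrs());
}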
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Builders.cpp:548
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:575
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation regions empty.
Definition: Builders.h:585
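A hedged sketch of the OpBuilder insertion-point management above: create and populate a region's entry block without disturbing the caller's insertion point. populateEntryBlock is an illustrative name.

#include "mlir/IR/Builders.h"

// Hypothetical helper: create an entry block and position the builder in it.
static void populateEntryBlock(mlir::OpBuilder &builder, mlir::Region &region,
                               mlir::TypeRange argTypes,
                               llvm::ArrayRef<mlir::Location> argLocs) {
  // InsertionGuard restores the original insertion point on scope exit.
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::Block *entry =
      builder.createBlock(&region, region.end(), argTypes, argLocs);
  builder.setInsertionPointToStart(entry);
  // ... create ops here; the guard resets the builder afterwards.
}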
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
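A hedged sketch of the Operation accessors above: emit a verifier-style error when an op is not nested inside an scf.for. checkInsideFor is an illustrative helper name.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: require an enclosing scf.for.
static mlir::LogicalResult checkInsideFor(mlir::Operation *op) {
  if (!op->getParentOfType<mlir::scf::ForOp>())
    return op->emitOpError("expected to be nested inside an scf.for");
  llvm::errs() << "op " << op->getName().getStringRef() << " at "
               << op->getLoc() << " is inside an scf.for\n";
  return mlir::success();
}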
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:769
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:831
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:702
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:622
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:614
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:598
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
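A hedged sketch of how the rewriter hooks above compose inside a pattern: a simplified, hypothetical rewrite that inlines a single-block scf.execute_region into its parent block. This illustrates the API only and is not the exact canonicalization implemented in this file.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

struct InlineSingleBlockExecuteRegion
    : public mlir::OpRewritePattern<mlir::scf::ExecuteRegionOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::scf::ExecuteRegionOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (!op.getRegion().hasOneBlock())
      return rewriter.notifyMatchFailure(op, "expected a single block");
    mlir::Block &block = op.getRegion().front();
    auto yield = llvm::cast<mlir::scf::YieldOp>(block.getTerminator());
    auto yielded = llvm::to_vector(yield.getOperands());
    rewriter.eraseOp(yield);
    // Splice the block's remaining ops right before the execute_region.
    rewriter.inlineBlockBefore(&block, op);
    // Forward the yielded values to the op's users and remove the op.
    rewriter.replaceOp(op, yielded);
    return mlir::success();
  }
};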
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
Definition: LoopUtils.cpp:117
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3032
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:86
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e. all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:693
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:2019
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:650
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:603
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1490
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:782
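A hedged sketch of scf::buildLoopNest, the loop-nest builder declared above: build a 2-D nest that accumulates the two induction variables into a single iter_arg. The bounds, steps and init value are assumed to be index-typed Values created by the caller, and buildSumNest is an illustrative name.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

// Hypothetical helper: 2-D loop nest summing both IVs into the iter_arg.
static mlir::Value buildSumNest(mlir::OpBuilder &b, mlir::Location loc,
                                mlir::ValueRange lbs, mlir::ValueRange ubs,
                                mlir::ValueRange steps, mlir::Value init) {
  mlir::scf::LoopNest nest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, mlir::ValueRange{init},
      [](mlir::OpBuilder &nested, mlir::Location nestedLoc,
         mlir::ValueRange ivs, mlir::ValueRange iterArgs) {
        mlir::Value sum = mlir::arith::AddIOp::create(
            nested, nestedLoc, iterArgs.front(), ivs.front());
        sum = mlir::arith::AddIOp::create(nested, nestedLoc, sum, ivs.back());
        // The returned values become the yielded iter_args of the nest.
        return mlir::scf::ValueVector{sum};
      });
  // Final value of the iter_arg after the outermost loop finishes.
  return nest.results.front();
}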
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source ranked tensor type.
Definition: TensorOps.cpp:270
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
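A hedged sketch of the matcher helpers above: test whether a Value is defined by an integer constant and read its value. isConstantOne is an illustrative helper name.

#include "mlir/IR/Matchers.h"
#include "llvm/ADT/APInt.h"

// Hypothetical helper: check for the constant integer one.
static bool isConstantOne(mlir::Value value) {
  llvm::APInt intValue;
  if (mlir::matchPattern(value, mlir::m_ConstantInt(&intValue)))
    return intValue.isOne();
  // Equivalently, matchPattern(value, mlir::m_One()) matches the value one
  // without binding the constant.
  return false;
}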
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult for a single OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with every element for which ShapedType::isDynamic is true replaced by the corresponding entry of dynamicValues.
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4307
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:234
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:185
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.