1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::scf;
32 
33 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // SCFDialect Dialect Interfaces
37 //===----------------------------------------------------------------------===//
38 
39 namespace {
40 struct SCFInlinerInterface : public DialectInlinerInterface {
41  using DialectInlinerInterface::DialectInlinerInterface;
42  // We don't have any special restrictions on what can be inlined into
43  // destination regions (e.g. while/conditional bodies). Always allow it.
44  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
45  IRMapping &valueMapping) const final {
46  return true;
47  }
48  // Operations in scf dialect are always legal to inline since they are
49  // pure.
50  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
51  return true;
52  }
53  // Handle the given inlined terminator by replacing it with a new operation
54  // as necessary. Required when the region has only one block.
55  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
56  auto retValOp = dyn_cast<scf::YieldOp>(op);
57  if (!retValOp)
58  return;
59 
60  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
61  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
62  }
63  }
64 };
65 } // namespace
66 
67 //===----------------------------------------------------------------------===//
68 // SCFDialect
69 //===----------------------------------------------------------------------===//
70 
71 void SCFDialect::initialize() {
72  addOperations<
73 #define GET_OP_LIST
74 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
75  >();
76  addInterfaces<SCFInlinerInterface>();
77  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
78  InParallelOp, ReduceReturnOp>();
79  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
80  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
81  ForallOp, InParallelOp, WhileOp, YieldOp>();
82  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
83 }
84 
85 /// Default callback for IfOp builders. Inserts a yield without arguments.
86 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
87  builder.create<scf::YieldOp>(loc);
88 }
89 
90 /// Verifies that the first block of the given `region` is terminated by a
91 /// TerminatorTy. Reports errors on the given operation if it is not the case.
92 template <typename TerminatorTy>
93 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
94  StringRef errorMessage) {
95  Operation *terminatorOperation = nullptr;
96  if (!region.empty() && !region.front().empty()) {
97  terminatorOperation = &region.front().back();
98  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
99  return yield;
100  }
101  auto diag = op->emitOpError(errorMessage);
102  if (terminatorOperation)
103  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
104  return nullptr;
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // ExecuteRegionOp
109 //===----------------------------------------------------------------------===//
110 
111 /// Replaces the given op with the contents of the given single-block region,
112 /// using the operands of the block terminator to replace operation results.
113 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
114  Region &region, ValueRange blockArgs = {}) {
115  assert(llvm::hasSingleElement(region) && "expected single-region block");
116  Block *block = &region.front();
117  Operation *terminator = block->getTerminator();
118  ValueRange results = terminator->getOperands();
119  rewriter.inlineBlockBefore(block, op, blockArgs);
120  rewriter.replaceOp(op, results);
121  rewriter.eraseOp(terminator);
122 }
123 
124 ///
125 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
126 /// block+
127 /// `}`
128 ///
129 /// Example:
130 /// scf.execute_region -> i32 {
131 /// %idx = load %rI[%i] : memref<128xi32>
132 /// return %idx : i32
133 /// }
134 ///
135 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
136  OperationState &result) {
137  if (parser.parseOptionalArrowTypeList(result.types))
138  return failure();
139 
140  // Introduce the body region and parse it.
141  Region *body = result.addRegion();
142  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
143  parser.parseOptionalAttrDict(result.attributes))
144  return failure();
145 
146  return success();
147 }
148 
149 void ExecuteRegionOp::print(OpAsmPrinter &p) {
150  p.printOptionalArrowTypeList(getResultTypes());
151 
152  p << ' ';
153  p.printRegion(getRegion(),
154  /*printEntryBlockArgs=*/false,
155  /*printBlockTerminators=*/true);
156 
157  p.printOptionalAttrDict((*this)->getAttrs());
158 }
159 
160 LogicalResult ExecuteRegionOp::verify() {
161  if (getRegion().empty())
162  return emitOpError("region needs to have at least one block");
163  if (getRegion().front().getNumArguments() > 0)
164  return emitOpError("region cannot have any arguments");
165  return success();
166 }
167 
168 // Inline an ExecuteRegionOp if it only contains one block.
169 // "test.foo"() : () -> ()
170 // %v = scf.execute_region -> i64 {
171 // %x = "test.val"() : () -> i64
172 // scf.yield %x : i64
173 // }
174 // "test.bar"(%v) : (i64) -> ()
175 //
176 // becomes
177 //
178 // "test.foo"() : () -> ()
179 // %x = "test.val"() : () -> i64
180 // "test.bar"(%x) : (i64) -> ()
181 //
182 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
184 
185  LogicalResult matchAndRewrite(ExecuteRegionOp op,
186  PatternRewriter &rewriter) const override {
187  if (!llvm::hasSingleElement(op.getRegion()))
188  return failure();
189  replaceOpWithRegion(rewriter, op, op.getRegion());
190  return success();
191  }
192 };
193 
194 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
195 // TODO generalize the conditions for operations which can be inlined into.
196 // func @func_execute_region_elim() {
197 // "test.foo"() : () -> ()
198 // %v = scf.execute_region -> i64 {
199 // %c = "test.cmp"() : () -> i1
200 // cf.cond_br %c, ^bb2, ^bb3
201 // ^bb2:
202 // %x = "test.val1"() : () -> i64
203 // cf.br ^bb4(%x : i64)
204 // ^bb3:
205 // %y = "test.val2"() : () -> i64
206 // cf.br ^bb4(%y : i64)
207 // ^bb4(%z : i64):
208 // scf.yield %z : i64
209 // }
210 // "test.bar"(%v) : (i64) -> ()
211 // return
212 // }
213 //
214 // becomes
215 //
216 // func @func_execute_region_elim() {
217 // "test.foo"() : () -> ()
218 // %c = "test.cmp"() : () -> i1
219 // cf.cond_br %c, ^bb1, ^bb2
220 // ^bb1: // pred: ^bb0
221 // %x = "test.val1"() : () -> i64
222 // cf.br ^bb3(%x : i64)
223 // ^bb2: // pred: ^bb0
224 // %y = "test.val2"() : () -> i64
225 // cf.br ^bb3(%y : i64)
226 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
227 // "test.bar"(%z) : (i64) -> ()
228 // return
229 // }
230 //
231 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
232  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
233 
234  LogicalResult matchAndRewrite(ExecuteRegionOp op,
235  PatternRewriter &rewriter) const override {
236  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
237  return failure();
238 
239  Block *prevBlock = op->getBlock();
240  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
241  rewriter.setInsertionPointToEnd(prevBlock);
242 
243  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
244 
245  for (Block &blk : op.getRegion()) {
246  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
247  rewriter.setInsertionPoint(yieldOp);
248  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
249  yieldOp.getResults());
250  rewriter.eraseOp(yieldOp);
251  }
252  }
253 
254  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
255  SmallVector<Value> blockArgs;
256 
257  for (auto res : op.getResults())
258  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
259 
260  rewriter.replaceOp(op, blockArgs);
261  return success();
262  }
263 };
264 
265 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
266  MLIRContext *context) {
267  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
268 }
269 
270 void ExecuteRegionOp::getSuccessorRegions(
271  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
272  // If the predecessor is the ExecuteRegionOp, branch into the body.
273  if (point.isParent()) {
274  regions.push_back(RegionSuccessor(&getRegion()));
275  return;
276  }
277 
278  // Otherwise, the region branches back to the parent operation.
279  regions.push_back(RegionSuccessor(getResults()));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // ConditionOp
284 //===----------------------------------------------------------------------===//
285 
286 MutableOperandRange
287 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
288  assert((point.isParent() || point == getParentOp().getAfter()) &&
289  "condition op can only exit the loop or branch to the after"
290  " region");
291  // Pass all operands except the condition to the successor region.
292  return getArgsMutable();
293 }
294 
295 void ConditionOp::getSuccessorRegions(
296  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
297  FoldAdaptor adaptor(operands, *this);
298 
299  WhileOp whileOp = getParentOp();
300 
301  // Condition can either lead to the after region or back to the parent op
302  // depending on whether the condition is true or not.
303  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
304  if (!boolAttr || boolAttr.getValue())
305  regions.emplace_back(&whileOp.getAfter(),
306  whileOp.getAfter().getArguments());
307  if (!boolAttr || !boolAttr.getValue())
308  regions.emplace_back(whileOp.getResults());
309 }
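// Illustrative example (not from the original source; the "test.*" ops are
// placeholders): in a loop such as
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.keep_going"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.step"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// a true condition forwards %arg to the "after" region, while a false
// condition forwards it to the scf.while results, matching the two
// RegionSuccessor entries constructed above.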
310 
311 //===----------------------------------------------------------------------===//
312 // ForOp
313 //===----------------------------------------------------------------------===//
314 
315 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
316  Value ub, Value step, ValueRange iterArgs,
317  BodyBuilderFn bodyBuilder) {
318  OpBuilder::InsertionGuard guard(builder);
319 
320  result.addOperands({lb, ub, step});
321  result.addOperands(iterArgs);
322  for (Value v : iterArgs)
323  result.addTypes(v.getType());
324  Type t = lb.getType();
325  Region *bodyRegion = result.addRegion();
326  Block *bodyBlock = builder.createBlock(bodyRegion);
327  bodyBlock->addArgument(t, result.location);
328  for (Value v : iterArgs)
329  bodyBlock->addArgument(v.getType(), v.getLoc());
330 
331  // Create the default terminator if the builder is not provided and if the
332  // iteration arguments are not provided. Otherwise, leave this to the caller
333  // because we don't know which values to return from the loop.
334  if (iterArgs.empty() && !bodyBuilder) {
335  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
336  } else if (bodyBuilder) {
337  OpBuilder::InsertionGuard guard(builder);
338  builder.setInsertionPointToStart(bodyBlock);
339  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
340  bodyBlock->getArguments().drop_front());
341  }
342 }
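// A minimal usage sketch of the builder above (illustrative only, not part of
// the original file). `b`, `loc`, `lb`, `ub`, `step`, and `init` are assumed
// to be an OpBuilder, a Location, and suitable SSA values at the call site.
//
//   auto loop = b.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
//         // Yield the carried value unchanged; a real body would compute a
//         // new value from `iv` and `args` before yielding it.
//         nested.create<scf::YieldOp>(nestedLoc, args);
//       });
//   Value result = loop.getResult(0);
//
// Note that when iter_args or a body builder are supplied, the callback is
// responsible for creating the scf.yield terminator, as documented above.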
343 
344 LogicalResult ForOp::verify() {
345  // Check that the number of init args and op results is the same.
346  if (getInitArgs().size() != getNumResults())
347  return emitOpError(
348  "mismatch in number of loop-carried values and defined values");
349 
350  return success();
351 }
352 
353 LogicalResult ForOp::verifyRegions() {
354  // Check that the body defines a single block argument for the induction
355  // variable.
356  if (getInductionVar().getType() != getLowerBound().getType())
357  return emitOpError(
358  "expected induction variable to be same type as bounds and step");
359 
360  if (getNumRegionIterArgs() != getNumResults())
361  return emitOpError(
362  "mismatch in number of basic block args and defined values");
363 
364  auto initArgs = getInitArgs();
365  auto iterArgs = getRegionIterArgs();
366  auto opResults = getResults();
367  unsigned i = 0;
368  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
369  if (std::get<0>(e).getType() != std::get<2>(e).getType())
370  return emitOpError() << "types mismatch between " << i
371  << "th iter operand and defined value";
372  if (std::get<1>(e).getType() != std::get<2>(e).getType())
373  return emitOpError() << "types mismatch between " << i
374  << "th iter region arg and defined value";
375 
376  ++i;
377  }
378  return success();
379 }
380 
381 std::optional<Value> ForOp::getSingleInductionVar() {
382  return getInductionVar();
383 }
384 
385 std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
386  return OpFoldResult(getLowerBound());
387 }
388 
389 std::optional<OpFoldResult> ForOp::getSingleStep() {
390  return OpFoldResult(getStep());
391 }
392 
393 std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
394  return OpFoldResult(getUpperBound());
395 }
396 
397 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
398 
399 /// Promotes the loop body of a forOp to its containing block if it can be
400 /// determined that the loop has a single iteration.
401 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
402  std::optional<int64_t> tripCount =
403  constantTripCount(getLowerBound(), getUpperBound(), getStep());
404  if (!tripCount.has_value() || tripCount != 1)
405  return failure();
406 
407  // Replace all results with the yielded values.
408  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
409  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
410 
411  // Replace block arguments with lower bound (replacement for IV) and
412  // iter_args.
413  SmallVector<Value> bbArgReplacements;
414  bbArgReplacements.push_back(getLowerBound());
415  llvm::append_range(bbArgReplacements, getInitArgs());
416 
417  // Move the loop body operations to the loop's containing block.
418  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
419  getOperation()->getIterator(), bbArgReplacements);
420 
421  // Erase the old terminator and the loop.
422  rewriter.eraseOp(yieldOp);
423  rewriter.eraseOp(*this);
424 
425  return success();
426 }
427 
428 /// Prints the initialization list in the form of
429 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
430 /// where 'inner' values are assumed to be region arguments and 'outer' values
431 /// are regular SSA values.
432 static void printInitializationList(OpAsmPrinter &p,
433  Block::BlockArgListType blocksArgs,
434  ValueRange initializers,
435  StringRef prefix = "") {
436  assert(blocksArgs.size() == initializers.size() &&
437  "expected same length of arguments and initializers");
438  if (initializers.empty())
439  return;
440 
441  p << prefix << '(';
442  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
443  p << std::get<0>(it) << " = " << std::get<1>(it);
444  });
445  p << ")";
446 }
447 
448 void ForOp::print(OpAsmPrinter &p) {
449  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
450  << getUpperBound() << " step " << getStep();
451 
452  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
453  if (!getInitArgs().empty())
454  p << " -> (" << getInitArgs().getTypes() << ')';
455  p << ' ';
456  if (Type t = getInductionVar().getType(); !t.isIndex())
457  p << " : " << t << ' ';
458  p.printRegion(getRegion(),
459  /*printEntryBlockArgs=*/false,
460  /*printBlockTerminators=*/!getInitArgs().empty());
461  p.printOptionalAttrDict((*this)->getAttrs());
462 }
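// For reference, the custom form printed above looks like this (illustrative
// example, not from the original source):
//
//   scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (f32) {
//     ...
//     scf.yield %next : f32
//   }
//
// When the induction variable is not of index type, a trailing ` : i32`
// (or similar) is printed before the region.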
463 
464 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
465  auto &builder = parser.getBuilder();
466  Type type;
467 
468  OpAsmParser::Argument inductionVariable;
469  OpAsmParser::UnresolvedOperand lb, ub, step;
470 
471  // Parse the induction variable followed by '='.
472  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
473  // Parse loop bounds.
474  parser.parseOperand(lb) || parser.parseKeyword("to") ||
475  parser.parseOperand(ub) || parser.parseKeyword("step") ||
476  parser.parseOperand(step))
477  return failure();
478 
479  // Parse the optional initial iteration arguments.
480  SmallVector<OpAsmParser::Argument, 4> regionArgs;
481  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
482  regionArgs.push_back(inductionVariable);
483 
484  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
485  if (hasIterArgs) {
486  // Parse assignment list and results type list.
487  if (parser.parseAssignmentList(regionArgs, operands) ||
488  parser.parseArrowTypeList(result.types))
489  return failure();
490  }
491 
492  if (regionArgs.size() != result.types.size() + 1)
493  return parser.emitError(
494  parser.getNameLoc(),
495  "mismatch in number of loop-carried values and defined values");
496 
497  // Parse optional type, else assume Index.
498  if (parser.parseOptionalColon())
499  type = builder.getIndexType();
500  else if (parser.parseType(type))
501  return failure();
502 
503  // Resolve input operands.
504  regionArgs.front().type = type;
505  if (parser.resolveOperand(lb, type, result.operands) ||
506  parser.resolveOperand(ub, type, result.operands) ||
507  parser.resolveOperand(step, type, result.operands))
508  return failure();
509  if (hasIterArgs) {
510  for (auto argOperandType :
511  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
512  Type type = std::get<2>(argOperandType);
513  std::get<0>(argOperandType).type = type;
514  if (parser.resolveOperand(std::get<1>(argOperandType), type,
515  result.operands))
516  return failure();
517  }
518  }
519 
520  // Parse the body region.
521  Region *body = result.addRegion();
522  if (parser.parseRegion(*body, regionArgs))
523  return failure();
524 
525  ForOp::ensureTerminator(*body, builder, result.location);
526 
527  // Parse the optional attribute list.
528  if (parser.parseOptionalAttrDict(result.attributes))
529  return failure();
530 
531  return success();
532 }
533 
534 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
535 
536 Block::BlockArgListType ForOp::getRegionIterArgs() {
537  return getBody()->getArguments().drop_front(getNumInductionVars());
538 }
539 
540 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
541  return getInitArgsMutable();
542 }
543 
544 FailureOr<LoopLikeOpInterface>
545 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
546  ValueRange newInitOperands,
547  bool replaceInitOperandUsesInLoop,
548  const NewYieldValuesFn &newYieldValuesFn) {
549  // Create a new loop before the existing one, with the extra operands.
550  OpBuilder::InsertionGuard g(rewriter);
551  rewriter.setInsertionPoint(getOperation());
552  auto inits = llvm::to_vector(getInitArgs());
553  inits.append(newInitOperands.begin(), newInitOperands.end());
554  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
555  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
556  [](OpBuilder &, Location, Value, ValueRange) {});
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
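// A hypothetical usage sketch (not part of the original source): appending one
// extra loop-carried value to an existing `forOp`, assuming `rewriter` and
// `extraInit` are available in the calling pattern.
//
//   FailureOr<LoopLikeOpInterface> maybeNewLoop =
//       forOp.replaceWithAdditionalYields(
//           rewriter, ValueRange{extraInit},
//           /*replaceInitOperandUsesInLoop=*/true,
//           [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
//             // Yield the new region iter_arg unchanged.
//             return SmallVector<Value>{newBBArgs[0]};
//           });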
596 
597 ForOp mlir::scf::getForInductionVarOwner(Value val) {
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
611  SmallVectorImpl<RegionSuccessor> &regions) {
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for the loop not to enter the
614  // body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
623 LogicalResult ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
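// Illustrative effect of `promote` on a single-iteration forall (example IR,
// not taken from the original source):
//
//   %r = scf.forall (%i) in (1) shared_outs(%o = %t) -> (tensor<4xf32>) {
//     %s = "test.compute"() : () -> tensor<4xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %s into %o[0] [4] [1]
//           : tensor<4xf32> into tensor<4xf32>
//     }
//   }
//
// becomes, roughly:
//
//   %s = "test.compute"() : () -> tensor<4xf32>
//   %r = tensor.insert_slice %s into %t[0] [4] [1]
//       : tensor<4xf32> into tensor<4xf32>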
686 
687 LoopNest mlir::scf::buildLoopNest(
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
690  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
711  SmallVector<scf::ForOp, 4> loops;
712  SmallVector<Value, 4> ivs;
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
760 LoopNest mlir::scf::buildLoopNest(
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
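// A small usage sketch for the helper above (illustrative; `builder`, `loc`,
// and the index values `lb`, `ub`, and `step` are assumed to exist at the
// call site): building a 2-D loop nest with no iteration arguments.
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       builder, loc, ValueRange{lb, lb}, ValueRange{ub, ub},
//       ValueRange{step, step},
//       [](OpBuilder &b, Location nestedLoc, ValueRange ivs) {
//         // `ivs` holds one induction variable per loop, outermost first;
//         // the per-iteration work would be emitted here.
//       });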
774 
775 namespace {
776 // Fold away ForOp iter arguments when:
777 // 1) The op yields the iter arguments.
778 // 2) The iter arguments have no use and the corresponding outer region
779 // iterators (inputs) are yielded.
780 // 3) The iter arguments have no use and the corresponding (operation) results
781 // have no use.
782 //
783 // These arguments must be defined outside of
784 // the ForOp region and can just be forwarded after simplifying the op inits,
785 // yields and returns.
786 //
787 // The implementation uses `inlineBlockBefore` to steal the content of the
788 // original ForOp and avoid cloning.
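//
// For example (illustrative IR, not from the original source):
//
//   %r:2 = scf.for %i = %lb to %ub step %c1 iter_args(%a = %x, %b = %y)
//       -> (f32, f32) {
//     %v = "test.use"(%a) : (f32) -> f32
//     scf.yield %v, %b : f32, f32
//   }
//
// Here %b is yielded unchanged, so the second iter_arg is forwarded and the
// pattern rewrites this into a loop carrying only %a, replacing %r#1 with %y.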
789 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
790  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
791 
792  LogicalResult matchAndRewrite(scf::ForOp forOp,
793  PatternRewriter &rewriter) const final {
794  bool canonicalize = false;
795 
796  // An internal flat vector of block transfer
797  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
798  // transformed block argument mappings. This plays the role of an
799  // IRMapping for the particular use case of calling into
800  // `inlineBlockBefore`.
801  int64_t numResults = forOp.getNumResults();
802  SmallVector<bool, 4> keepMask;
803  keepMask.reserve(numResults);
804  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
805  newResultValues;
806  newBlockTransferArgs.reserve(1 + numResults);
807  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
808  newIterArgs.reserve(forOp.getInitArgs().size());
809  newYieldValues.reserve(numResults);
810  newResultValues.reserve(numResults);
811  for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
812  forOp.getRegionIterArgs(), // iter inside region
813  forOp.getResults(), // op results
814  forOp.getYieldedValues() // iter yield
815  )) {
816  // Forwarded is `true` when:
817  // 1) The region `iter` argument is yielded.
818  // 2) The region `iter` argument has no use, and the corresponding iter
819  // operand (input) is yielded.
820  // 3) The region `iter` argument has no use, and the corresponding op
821  // result has no use.
822  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
823  (std::get<1>(it).use_empty() &&
824  (std::get<0>(it) == std::get<3>(it) ||
825  std::get<2>(it).use_empty())));
826  keepMask.push_back(!forwarded);
827  canonicalize |= forwarded;
828  if (forwarded) {
829  newBlockTransferArgs.push_back(std::get<0>(it));
830  newResultValues.push_back(std::get<0>(it));
831  continue;
832  }
833  newIterArgs.push_back(std::get<0>(it));
834  newYieldValues.push_back(std::get<3>(it));
835  newBlockTransferArgs.push_back(Value()); // placeholder with null value
836  newResultValues.push_back(Value()); // placeholder with null value
837  }
838 
839  if (!canonicalize)
840  return failure();
841 
842  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
843  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
844  forOp.getStep(), newIterArgs);
845  newForOp->setAttrs(forOp->getAttrs());
846  Block &newBlock = newForOp.getRegion().front();
847 
848  // Replace the null placeholders with newly constructed values.
849  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
850  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
851  idx != e; ++idx) {
852  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
853  Value &newResultVal = newResultValues[idx];
854  assert((blockTransferArg && newResultVal) ||
855  (!blockTransferArg && !newResultVal));
856  if (!blockTransferArg) {
857  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
858  newResultVal = newForOp.getResult(collapsedIdx++);
859  }
860  }
861 
862  Block &oldBlock = forOp.getRegion().front();
863  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
864  "unexpected argument size mismatch");
865 
866  // No results case: the scf::ForOp builder already created a zero
867  // result terminator. Merge before this terminator and just get rid of the
868  // original terminator that has been merged in.
869  if (newIterArgs.empty()) {
870  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
871  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
872  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
873  rewriter.replaceOp(forOp, newResultValues);
874  return success();
875  }
876 
877  // No terminator case: merge and rewrite the merged terminator.
878  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
879  OpBuilder::InsertionGuard g(rewriter);
880  rewriter.setInsertionPoint(mergedTerminator);
881  SmallVector<Value, 4> filteredOperands;
882  filteredOperands.reserve(newResultValues.size());
883  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
884  if (keepMask[idx])
885  filteredOperands.push_back(mergedTerminator.getOperand(idx));
886  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
887  filteredOperands);
888  };
889 
890  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
891  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
892  cloneFilteredTerminator(mergedYieldOp);
893  rewriter.eraseOp(mergedYieldOp);
894  rewriter.replaceOp(forOp, newResultValues);
895  return success();
896  }
897 };
898 
899 /// Util function that tries to compute a constant diff between u and l.
900 /// Returns std::nullopt when the difference between the two values cannot be
901 /// determined statically.
902 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
903  IntegerAttr clb, cub;
904  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
905  llvm::APInt lbValue = clb.getValue();
906  llvm::APInt ubValue = cub.getValue();
907  return (ubValue - lbValue).getSExtValue();
908  }
909 
910  // Else a simple pattern match for x + c or c + x
911  llvm::APInt diff;
912  if (matchPattern(
913  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
914  matchPattern(
915  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
916  return diff.getSExtValue();
917  return std::nullopt;
918 }
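// For instance, with `l = arith.constant 2` and `u = arith.constant 10` the
// helper above returns 8, and with `u = arith.addi %l, %c4` it returns 4; any
// other relationship between the bounds yields std::nullopt.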
919 
920 /// Rewriting pattern that erases loops that are known not to iterate, replaces
921 /// single-iteration loops with their bodies, and removes empty loops that
922 /// iterate at least once and only return values defined outside of the loop.
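///
/// For example (illustrative): `scf.for %i = %c0 to %c1 step %c1` has a
/// constant trip count of one, so its body is inlined with %c0 substituted
/// for %i; a loop whose lower and upper bounds are the same value is erased
/// and its results are replaced by the corresponding init values.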
923 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
924  using OpRewritePattern<ForOp>::OpRewritePattern;
925 
926  LogicalResult matchAndRewrite(ForOp op,
927  PatternRewriter &rewriter) const override {
928  // If the upper bound is the same as the lower bound, the loop does not
929  // iterate, just remove it.
930  if (op.getLowerBound() == op.getUpperBound()) {
931  rewriter.replaceOp(op, op.getInitArgs());
932  return success();
933  }
934 
935  std::optional<int64_t> diff =
936  computeConstDiff(op.getLowerBound(), op.getUpperBound());
937  if (!diff)
938  return failure();
939 
940  // If the loop is known to have 0 iterations, remove it.
941  if (*diff <= 0) {
942  rewriter.replaceOp(op, op.getInitArgs());
943  return success();
944  }
945 
946  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
947  if (!maybeStepValue)
948  return failure();
949 
950  // If the loop is known to have 1 iteration, inline its body and remove the
951  // loop.
952  llvm::APInt stepValue = *maybeStepValue;
953  if (stepValue.sge(*diff)) {
954  SmallVector<Value, 4> blockArgs;
955  blockArgs.reserve(op.getInitArgs().size() + 1);
956  blockArgs.push_back(op.getLowerBound());
957  llvm::append_range(blockArgs, op.getInitArgs());
958  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
959  return success();
960  }
961 
962  // Now we are left with loops that have more than 1 iteration.
963  Block &block = op.getRegion().front();
964  if (!llvm::hasSingleElement(block))
965  return failure();
966  // If the loop is empty, iterates at least once, and only returns values
967  // defined outside of the loop, remove it and replace it with yield values.
968  if (llvm::any_of(op.getYieldedValues(),
969  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
970  return failure();
971  rewriter.replaceOp(op, op.getYieldedValues());
972  return success();
973  }
974 };
975 
976 /// Perform a replacement of one iter OpOperand of an scf.for to the
977 /// `replacement` value which is expected to be the source of a tensor.cast.
978 /// tensor.cast ops are inserted inside the block to account for the type cast.
979 static SmallVector<Value>
980 replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
981  Value replacement) {
982  Type oldType = operand.get().getType(), newType = replacement.getType();
983  assert(llvm::isa<RankedTensorType>(oldType) &&
984  llvm::isa<RankedTensorType>(newType) &&
985  "expected ranked tensor types");
986 
987  // 1. Create new iter operands, exactly 1 is replaced.
988  ForOp forOp = cast<ForOp>(operand.getOwner());
989  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
990  "expected an iter OpOperand");
991  assert(operand.get().getType() != replacement.getType() &&
992  "Expected a different type");
993  SmallVector<Value> newIterOperands;
994  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
995  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
996  newIterOperands.push_back(replacement);
997  continue;
998  }
999  newIterOperands.push_back(opOperand.get());
1000  }
1001 
1002  // 2. Create the new forOp shell.
1003  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
1004  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1005  forOp.getStep(), newIterOperands);
1006  newForOp->setAttrs(forOp->getAttrs());
1007  Block &newBlock = newForOp.getRegion().front();
1008  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
1009  newBlock.getArguments().end());
1010 
1011  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
1012  // corresponding to the `replacement` value.
1013  OpBuilder::InsertionGuard g(rewriter);
1014  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
1015  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1016  &newForOp->getOpOperand(operand.getOperandNumber()));
1017  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
1018  newRegionIterArg);
1019  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
1020 
1021  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1022  Block &oldBlock = forOp.getRegion().front();
1023  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1024 
1025  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
1026  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1027  rewriter.setInsertionPoint(clonedYieldOp);
1028  unsigned yieldIdx =
1029  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
1030  Value castOut = rewriter.create<tensor::CastOp>(
1031  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1032  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
1033  newYieldOperands[yieldIdx] = castOut;
1034  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1035  rewriter.eraseOp(clonedYieldOp);
1036 
1037  // 6. Inject an outgoing cast op after the forOp.
1038  rewriter.setInsertionPointAfter(newForOp);
1039  SmallVector<Value> newResults = newForOp.getResults();
1040  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
1041  newForOp.getLoc(), oldType, newResults[yieldIdx]);
1042 
1043  return newResults;
1044 }
1045 
1046 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1047 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1048 ///
1049 /// ```
1050 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1051 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1052 /// -> (tensor<?x?xf32>) {
1053 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1054 /// scf.yield %2 : tensor<?x?xf32>
1055 /// }
1056 /// use_of(%1)
1057 /// ```
1058 ///
1059 /// folds into:
1060 ///
1061 /// ```
1062 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1063 /// -> (tensor<32x1024xf32>) {
1064 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1065 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1066 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1067 /// scf.yield %4 : tensor<32x1024xf32>
1068 /// }
1069 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1070 /// use_of(%1)
1071 /// ```
1072 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1073  using OpRewritePattern<ForOp>::OpRewritePattern;
1074 
1075  LogicalResult matchAndRewrite(ForOp op,
1076  PatternRewriter &rewriter) const override {
1077  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1078  OpOperand &iterOpOperand = std::get<0>(it);
1079  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1080  if (!incomingCast ||
1081  incomingCast.getSource().getType() == incomingCast.getType())
1082  continue;
1083  // If the dest type of the cast does not preserve static information in
1084  // the source type.
1085  if (!tensor::preservesStaticInformation(
1086  incomingCast.getDest().getType(),
1087  incomingCast.getSource().getType()))
1088  continue;
1089  if (!std::get<1>(it).hasOneUse())
1090  continue;
1091 
1092  // Create a new ForOp with that iter operand replaced.
1093  rewriter.replaceOp(
1094  op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1095  incomingCast.getSource()));
1096  return success();
1097  }
1098  return failure();
1099  }
1100 };
1101 
1102 } // namespace
1103 
1104 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1105  MLIRContext *context) {
1106  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1107  context);
1108 }
1109 
1110 std::optional<APInt> ForOp::getConstantStep() {
1111  IntegerAttr step;
1112  if (matchPattern(getStep(), m_Constant(&step)))
1113  return step.getValue();
1114  return {};
1115 }
1116 
1117 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1118  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1119 }
1120 
1121 Speculation::Speculatability ForOp::getSpeculatability() {
1122  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1123  // and End.
1124  if (auto constantStep = getConstantStep())
1125  if (*constantStep == 1)
1126  return Speculation::RecursivelySpeculatable;
1127 
1128  // For Step != 1, the loop may not terminate. We can add more smarts here if
1129  // needed.
1130  return Speculation::NotSpeculatable;
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // ForallOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 LogicalResult ForallOp::verify() {
1138  unsigned numLoops = getRank();
1139  // Check number of outputs.
1140  if (getNumResults() != getOutputs().size())
1141  return emitOpError("produces ")
1142  << getNumResults() << " results, but has only "
1143  << getOutputs().size() << " outputs";
1144 
1145  // Check that the body defines block arguments for thread indices and outputs.
1146  auto *body = getBody();
1147  if (body->getNumArguments() != numLoops + getOutputs().size())
1148  return emitOpError("region expects ") << numLoops << " arguments";
1149  for (int64_t i = 0; i < numLoops; ++i)
1150  if (!body->getArgument(i).getType().isIndex())
1151  return emitOpError("expects ")
1152  << i << "-th block argument to be an index";
1153  for (unsigned i = 0; i < getOutputs().size(); ++i)
1154  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1155  return emitOpError("type mismatch between ")
1156  << i << "-th output and corresponding block argument";
1157  if (getMapping().has_value() && !getMapping()->empty()) {
1158  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1159  return emitOpError() << "mapping attribute size must match op rank";
1160  for (auto map : getMapping()->getValue()) {
1161  if (!isa<DeviceMappingAttrInterface>(map))
1162  return emitOpError()
1163  << getMappingAttrName() << " is not device mapping attribute";
1164  }
1165  }
1166 
1167  // Verify mixed static/dynamic control variables.
1168  Operation *op = getOperation();
1169  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1170  getStaticLowerBound(),
1171  getDynamicLowerBound())))
1172  return failure();
1173  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1174  getStaticUpperBound(),
1175  getDynamicUpperBound())))
1176  return failure();
1177  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1178  getStaticStep(), getDynamicStep())))
1179  return failure();
1180 
1181  return success();
1182 }
1183 
1184 void ForallOp::print(OpAsmPrinter &p) {
1185  Operation *op = getOperation();
1186  p << " (" << getInductionVars();
1187  if (isNormalized()) {
1188  p << ") in ";
1189  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1190  /*valueTypes=*/{}, /*scalables=*/{},
1191  OpAsmParser::Delimiter::Paren);
1192  } else {
1193  p << ") = ";
1194  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1195  /*valueTypes=*/{}, /*scalables=*/{},
1196  OpAsmParser::Delimiter::Paren);
1197  p << " to ";
1198  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1199  /*valueTypes=*/{}, /*scalables=*/{},
1200  OpAsmParser::Delimiter::Paren);
1201  p << " step ";
1202  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1203  /*valueTypes=*/{}, /*scalables=*/{},
1204  OpAsmParser::Delimiter::Paren);
1205  }
1206  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1207  p << " ";
1208  if (!getRegionOutArgs().empty())
1209  p << "-> (" << getResultTypes() << ") ";
1210  p.printRegion(getRegion(),
1211  /*printEntryBlockArgs=*/false,
1212  /*printBlockTerminators=*/getNumResults() > 0);
1213  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1214  getStaticLowerBoundAttrName(),
1215  getStaticUpperBoundAttrName(),
1216  getStaticStepAttrName()});
1217 }
1218 
1219 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1220  OpBuilder b(parser.getContext());
1221  auto indexType = b.getIndexType();
1222 
1223  // Parse an opening `(` followed by thread index variables followed by `)`
1224  // TODO: when we can refer to such "induction variable"-like handles from the
1225  // declarative assembly format, we can implement the parser as a custom hook.
1226  SmallVector<OpAsmParser::Argument, 4> ivs;
1227  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1228  return failure();
1229 
1230  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1231  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1232  dynamicSteps;
1233  if (succeeded(parser.parseOptionalKeyword("in"))) {
1234  // Parse upper bounds.
1235  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1236  /*valueTypes=*/nullptr,
1237  OpAsmParser::Delimiter::Paren) ||
1238  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1239  return failure();
1240 
1241  unsigned numLoops = ivs.size();
1242  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1243  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1244  } else {
1245  // Parse lower bounds.
1246  if (parser.parseEqual() ||
1247  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1248  /*valueTypes=*/nullptr,
1249  OpAsmParser::Delimiter::Paren) ||
1250 
1251  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1252  return failure();
1253 
1254  // Parse upper bounds.
1255  if (parser.parseKeyword("to") ||
1256  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1257  /*valueTypes=*/nullptr,
1258  OpAsmParser::Delimiter::Paren) ||
1259  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1260  return failure();
1261 
1262  // Parse step values.
1263  if (parser.parseKeyword("step") ||
1264  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1265  /*valueTypes=*/nullptr,
1266  OpAsmParser::Delimiter::Paren) ||
1267  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1268  return failure();
1269  }
1270 
1271  // Parse out operands and results.
1272  SmallVector<OpAsmParser::UnresolvedOperand> outOperands;
1273  SmallVector<OpAsmParser::Argument> regionOutArgs;
1274  SMLoc outOperandsLoc = parser.getCurrentLocation();
1275  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1276  if (outOperands.size() != result.types.size())
1277  return parser.emitError(outOperandsLoc,
1278  "mismatch between out operands and types");
1279  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1280  parser.parseOptionalArrowTypeList(result.types) ||
1281  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1282  result.operands))
1283  return failure();
1284  }
1285 
1286  // Parse region.
1287  SmallVector<OpAsmParser::Argument> regionArgs;
1288  std::unique_ptr<Region> region = std::make_unique<Region>();
1289  for (auto &iv : ivs) {
1290  iv.type = b.getIndexType();
1291  regionArgs.push_back(iv);
1292  }
1293  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1294  auto &out = it.value();
1295  out.type = result.types[it.index()];
1296  regionArgs.push_back(out);
1297  }
1298  if (parser.parseRegion(*region, regionArgs))
1299  return failure();
1300 
1301  // Ensure terminator and move region.
1302  ForallOp::ensureTerminator(*region, b, result.location);
1303  result.addRegion(std::move(region));
1304 
1305  // Parse the optional attribute list.
1306  if (parser.parseOptionalAttrDict(result.attributes))
1307  return failure();
1308 
1309  result.addAttribute("staticLowerBound", staticLbs);
1310  result.addAttribute("staticUpperBound", staticUbs);
1311  result.addAttribute("staticStep", staticSteps);
1312  result.addAttribute("operandSegmentSizes",
1313  b.getDenseI32ArrayAttr(
1314  {static_cast<int32_t>(dynamicLbs.size()),
1315  static_cast<int32_t>(dynamicUbs.size()),
1316  static_cast<int32_t>(dynamicSteps.size()),
1317  static_cast<int32_t>(outOperands.size())}));
1318  return success();
1319 }
1320 
1321 // Builder that takes loop bounds.
1322 void ForallOp::build(
1323  OpBuilder &b, OperationState &result,
1324  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1325  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1326  std::optional<ArrayAttr> mapping,
1327  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1328  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1329  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1330  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1331  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1332  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1333 
1334  result.addOperands(dynamicLbs);
1335  result.addOperands(dynamicUbs);
1336  result.addOperands(dynamicSteps);
1337  result.addOperands(outputs);
1338  result.addTypes(TypeRange(outputs));
1339 
1340  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1341  b.getDenseI64ArrayAttr(staticLbs));
1342  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1343  b.getDenseI64ArrayAttr(staticUbs));
1344  result.addAttribute(getStaticStepAttrName(result.name),
1345  b.getDenseI64ArrayAttr(staticSteps));
1346  result.addAttribute(
1347  "operandSegmentSizes",
1348  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1349  static_cast<int32_t>(dynamicUbs.size()),
1350  static_cast<int32_t>(dynamicSteps.size()),
1351  static_cast<int32_t>(outputs.size())}));
1352  if (mapping.has_value()) {
1353  result.addAttribute(ForallOp::getMappingAttrName(result.name),
1354  mapping.value());
1355  }
1356 
1357  Region *bodyRegion = result.addRegion();
1358  OpBuilder::InsertionGuard g(b);
1359  b.createBlock(bodyRegion);
1360  Block &bodyBlock = bodyRegion->front();
1361 
1362  // Add block arguments for indices and outputs.
1363  bodyBlock.addArguments(
1364  SmallVector<Type>(lbs.size(), b.getIndexType()),
1365  SmallVector<Location>(staticLbs.size(), result.location));
1366  bodyBlock.addArguments(
1367  TypeRange(outputs),
1368  SmallVector<Location>(outputs.size(), result.location));
1369 
1370  b.setInsertionPointToStart(&bodyBlock);
1371  if (!bodyBuilderFn) {
1372  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1373  return;
1374  }
1375  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1376 }
1377 
1378 // Builder that takes loop bounds.
1379 void ForallOp::build(
1380  OpBuilder &b, OperationState &result,
1381  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1382  std::optional<ArrayAttr> mapping,
1383  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1384  unsigned numLoops = ubs.size();
1385  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1386  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1387  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1388 }
1389 
1390 // Checks if the lbs are zeros and steps are ones.
1391 bool ForallOp::isNormalized() {
1392  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1393  return llvm::all_of(results, [&](OpFoldResult ofr) {
1394  auto intValue = getConstantIntValue(ofr);
1395  return intValue.has_value() && intValue == val;
1396  });
1397  };
1398  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1399 }
1400 
1401 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1402 // unaware of the fact that our terminator also needs a region to be
1403 // well-formed. We override it here to ensure that we do the right thing.
1404 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1405  Location loc) {
1406  OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
1407  ForallOp>::ensureTerminator(region, builder, loc);
1408  auto terminator =
1409  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1410  if (terminator.getRegion().empty())
1411  builder.createBlock(&terminator.getRegion());
1412 }
1413 
1414 InParallelOp ForallOp::getTerminator() {
1415  return cast<InParallelOp>(getBody()->getTerminator());
1416 }
1417 
1418 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1419  SmallVector<Operation *> storeOps;
1420  InParallelOp inParallelOp = getTerminator();
1421  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1422  if (auto parallelInsertSliceOp =
1423  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1424  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1425  storeOps.push_back(parallelInsertSliceOp);
1426  }
1427  }
1428  return storeOps;
1429 }
1430 
1431 std::optional<Value> ForallOp::getSingleInductionVar() {
1432  if (getRank() != 1)
1433  return std::nullopt;
1434  return getInductionVar(0);
1435 }
1436 
1437 std::optional<OpFoldResult> ForallOp::getSingleLowerBound() {
1438  if (getRank() != 1)
1439  return std::nullopt;
1440  return getMixedLowerBound()[0];
1441 }
1442 
1443 std::optional<OpFoldResult> ForallOp::getSingleUpperBound() {
1444  if (getRank() != 1)
1445  return std::nullopt;
1446  return getMixedUpperBound()[0];
1447 }
1448 
1449 std::optional<OpFoldResult> ForallOp::getSingleStep() {
1450  if (getRank() != 1)
1451  return std::nullopt;
1452  return getMixedStep()[0];
1453 }
1454 
1455 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1456  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1457  if (!tidxArg)
1458  return ForallOp();
1459  assert(tidxArg.getOwner() && "unlinked block argument");
1460  auto *containingOp = tidxArg.getOwner()->getParentOp();
1461  return dyn_cast<ForallOp>(containingOp);
1462 }
1463 
1464 namespace {
1465 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1466 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1467  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1468 
1469  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1470  PatternRewriter &rewriter) const final {
1471  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1472  if (!forallOp)
1473  return failure();
1474  Value sharedOut =
1475  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1476  ->get();
1477  rewriter.modifyOpInPlace(
1478  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1479  return success();
1480  }
1481 };
1482 
1483 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1484 public:
1485  using OpRewritePattern<ForallOp>::OpRewritePattern;
1486 
1487  LogicalResult matchAndRewrite(ForallOp op,
1488  PatternRewriter &rewriter) const override {
1489  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1490  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1491  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1492  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1493  failed(foldDynamicIndexList(mixedUpperBound)) &&
1494  failed(foldDynamicIndexList(mixedStep)))
1495  return failure();
1496 
1497  rewriter.modifyOpInPlace(op, [&]() {
1498  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1499  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1500  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1501  staticLowerBound);
1502  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1503  op.setStaticLowerBound(staticLowerBound);
1504 
1505  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1506  staticUpperBound);
1507  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1508  op.setStaticUpperBound(staticUpperBound);
1509 
1510  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1511  op.getDynamicStepMutable().assign(dynamicStep);
1512  op.setStaticStep(staticStep);
1513 
1514  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1515  rewriter.getDenseI32ArrayAttr(
1516  {static_cast<int32_t>(dynamicLowerBound.size()),
1517  static_cast<int32_t>(dynamicUpperBound.size()),
1518  static_cast<int32_t>(dynamicStep.size()),
1519  static_cast<int32_t>(op.getNumResults())}));
1520  });
1521  return success();
1522  }
1523 };
1524 
1525 /// The following canonicalization pattern folds the iter arguments of an
1526 /// scf.forall op if:
1527 /// 1. The corresponding result has zero uses.
1528 /// 2. The iter argument is NOT being modified within the loop body.
1529 ///
1530 ///
1531 /// Example of first case :-
1532 /// INPUT:
1533 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1534 /// {
1535 /// ...
1536 /// <SOME USE OF %arg0>
1537 /// <SOME USE OF %arg1>
1538 /// <SOME USE OF %arg2>
1539 /// ...
1540 /// scf.forall.in_parallel {
1541 /// <STORE OP WITH DESTINATION %arg1>
1542 /// <STORE OP WITH DESTINATION %arg0>
1543 /// <STORE OP WITH DESTINATION %arg2>
1544 /// }
1545 /// }
1546 /// return %res#1
1547 ///
1548 /// OUTPUT:
1549 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1550 /// {
1551 /// ...
1552 /// <SOME USE OF %a>
1553 /// <SOME USE OF %new_arg0>
1554 /// <SOME USE OF %c>
1555 /// ...
1556 /// scf.forall.in_parallel {
1557 /// <STORE OP WITH DESTINATION %new_arg0>
1558 /// }
1559 /// }
1560 /// return %res
1561 ///
1562 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1563 /// scf.forall are replaced by their corresponding operands.
1564 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1565 /// of the scf.forall besides those within the scf.forall.in_parallel
1566 /// terminator, this canonicalization remains valid. For more details,
1567 /// please refer to:
1568 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1569 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1570 /// handles tensor.parallel_insert_slice ops only.
1571 ///
1572 /// Example of second case :-
1573 /// INPUT:
1574 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1575 /// {
1576 /// ...
1577 /// <SOME USE OF %arg0>
1578 /// <SOME USE OF %arg1>
1579 /// ...
1580 /// scf.forall.in_parallel {
1581 /// <STORE OP WITH DESTINATION %arg1>
1582 /// }
1583 /// }
1584 /// return %res#0, %res#1
1585 ///
1586 /// OUTPUT:
1587 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1588 /// {
1589 /// ...
1590 /// <SOME USE OF %a>
1591 /// <SOME USE OF %new_arg0>
1592 /// ...
1593 /// scf.forall.in_parallel {
1594 /// <STORE OP WITH DESTINATION %new_arg0>
1595 /// }
1596 /// }
1597 /// return %a, %res
1598 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1599  using OpRewritePattern<ForallOp>::OpRewritePattern;
1600 
1601  LogicalResult matchAndRewrite(ForallOp forallOp,
1602  PatternRewriter &rewriter) const final {
1603  // Step 1: For a given i-th result of scf.forall, check the following :-
1604  // a. If it has any use.
1605  // b. If the corresponding iter argument is being modified within
1606  // the loop, i.e. has at least one store op with the iter arg as
1607  // its destination operand. For this we use
1608  // ForallOp::getCombiningOps(iter_arg).
1609  //
1610  // Based on the check we maintain the following :-
1611  // a. `resultToDelete` - i-th result of scf.forall that'll be
1612  // deleted.
1613  // b. `resultToReplace` - i-th result of the old scf.forall
1614  // whose uses will be replaced by the new scf.forall.
1615  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1616  // corresponding to the i-th result with at least one use.
1617  SetVector<OpResult> resultToDelete;
1618  SmallVector<Value> resultToReplace;
1619  SmallVector<Value> newOuts;
1620  for (OpResult result : forallOp.getResults()) {
1621  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1622  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1623  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1624  resultToDelete.insert(result);
1625  } else {
1626  resultToReplace.push_back(result);
1627  newOuts.push_back(opOperand->get());
1628  }
1629  }
1630 
1631  // Return early if all results of scf.forall have at least one use and are
1632  // being modified within the loop.
1633  if (resultToDelete.empty())
1634  return failure();
1635 
1636  // Step 2: For the i-th result, do the following :-
1637  // a. Fetch the corresponding BlockArgument.
1638  // b. Look for store ops (currently tensor.parallel_insert_slice)
1639  // with the BlockArgument as its destination operand.
1640  // c. Remove the operations fetched in b.
1641  for (OpResult result : resultToDelete) {
1642  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1643  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1644  SmallVector<Operation *> combiningOps =
1645  forallOp.getCombiningOps(blockArg);
1646  for (Operation *combiningOp : combiningOps)
1647  rewriter.eraseOp(combiningOp);
1648  }
1649 
1650  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1651  // fetched earlier
1652  auto newForallOp = rewriter.create<scf::ForallOp>(
1653  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1654  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1655  forallOp.getMapping(),
1656  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1657 
1658  // Step 4. Merge the block of the old scf.forall into the newly created
1659  // scf.forall using the new set of arguments.
1660  Block *loopBody = forallOp.getBody();
1661  Block *newLoopBody = newForallOp.getBody();
1662  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1663  // Form initial new bbArg list with just the control operands of the new
1664  // scf.forall op.
1665  SmallVector<Value> newBlockArgs =
1666  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1667  [](BlockArgument b) -> Value { return b; });
1668  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1669  unsigned index = 0;
1670  // Take the new corresponding bbArg if the old bbArg was used as a
1671  // destination in the in_parallel op. For all other bbArgs, use the
1672  // corresponding init_arg from the old scf.forall op.
1673  for (OpResult result : forallOp.getResults()) {
1674  if (resultToDelete.count(result)) {
1675  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1676  } else {
1677  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1678  }
1679  }
1680  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1681 
1682  // Step 5. Replace the uses of the results of the old scf.forall with those
1683  // of the new scf.forall.
1684  for (auto &&[oldResult, newResult] :
1685  llvm::zip(resultToReplace, newForallOp->getResults()))
1686  rewriter.replaceAllUsesWith(oldResult, newResult);
1687 
1688  // Step 6. Replace the uses of those values that either have no use or are
1689  // not being modified within the loop with the corresponding
1690  // OpOperand.
1691  for (OpResult oldResult : resultToDelete)
1692  rewriter.replaceAllUsesWith(oldResult,
1693  forallOp.getTiedOpOperand(oldResult)->get());
1694  return success();
1695  }
1696 };
1697 
1698 struct ForallOpSingleOrZeroIterationDimsFolder
1699  : public OpRewritePattern<ForallOp> {
1700  using OpRewritePattern<ForallOp>::OpRewritePattern;
1701 
1702  LogicalResult matchAndRewrite(ForallOp op,
1703  PatternRewriter &rewriter) const override {
1704  // Do not fold dimensions if they are mapped to processing units.
1705  if (op.getMapping().has_value())
1706  return failure();
1707  Location loc = op.getLoc();
1708 
1709  // Compute new loop bounds that omit all single-iteration loop dimensions.
1710  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1711  newMixedSteps;
1712  IRMapping mapping;
1713  for (auto [lb, ub, step, iv] :
1714  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1715  op.getMixedStep(), op.getInductionVars())) {
1716  auto numIterations = constantTripCount(lb, ub, step);
1717  if (numIterations.has_value()) {
1718  // Remove the loop if it performs zero iterations.
1719  if (*numIterations == 0) {
1720  rewriter.replaceOp(op, op.getOutputs());
1721  return success();
1722  }
1723  // Replace the loop induction variable by the lower bound if the loop
1724  // performs a single iteration. Otherwise, copy the loop bounds.
1725  if (*numIterations == 1) {
1726  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1727  continue;
1728  }
1729  }
1730  newMixedLowerBounds.push_back(lb);
1731  newMixedUpperBounds.push_back(ub);
1732  newMixedSteps.push_back(step);
1733  }
1734  // Exit if none of the loop dimensions perform a single iteration.
1735  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1736  return rewriter.notifyMatchFailure(
1737  op, "no dimensions have 0 or 1 iterations");
1738  }
1739 
1740  // All of the loop dimensions perform a single iteration. Inline loop body.
1741  if (newMixedLowerBounds.empty()) {
1742  promote(rewriter, op);
1743  return success();
1744  }
1745 
1746  // Replace the loop by a lower-dimensional loop.
1747  ForallOp newOp;
1748  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1749  newMixedUpperBounds, newMixedSteps,
1750  op.getOutputs(), std::nullopt, nullptr);
1751  newOp.getBodyRegion().getBlocks().clear();
1752  // The new loop needs to keep all attributes from the old one, except for
1753  // "operandSegmentSizes" and static loop bound attributes which capture
1754  // the outdated information of the old iteration domain.
1755  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1756  newOp.getStaticLowerBoundAttrName(),
1757  newOp.getStaticUpperBoundAttrName(),
1758  newOp.getStaticStepAttrName()};
1759  for (const auto &namedAttr : op->getAttrs()) {
1760  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1761  continue;
1762  rewriter.modifyOpInPlace(newOp, [&]() {
1763  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1764  });
1765  }
1766  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1767  newOp.getRegion().begin(), mapping);
1768  rewriter.replaceOp(op, newOp.getResults());
1769  return success();
1770  }
1771 };
1772 
1773 struct FoldTensorCastOfOutputIntoForallOp
1774  : public OpRewritePattern<scf::ForallOp> {
1775  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1776 
1777  struct TypeCast {
1778  Type srcType;
1779  Type dstType;
1780  };
1781 
1782  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1783  PatternRewriter &rewriter) const final {
1784  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1785  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1786  for (auto en : llvm::enumerate(newOutputTensors)) {
1787  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1788  if (!castOp)
1789  continue;
1790 
1791  // Only casts that preserve static information, i.e. will make the
1792  // loop result type "more" static than before, will be folded.
1793  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1794  castOp.getSource().getType())) {
1795  continue;
1796  }
1797 
1798  tensorCastProducers[en.index()] =
1799  TypeCast{castOp.getSource().getType(), castOp.getType()};
1800  newOutputTensors[en.index()] = castOp.getSource();
1801  }
1802 
1803  if (tensorCastProducers.empty())
1804  return failure();
1805 
1806  // Create new loop.
1807  Location loc = forallOp.getLoc();
1808  auto newForallOp = rewriter.create<ForallOp>(
1809  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1810  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1811  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1812  auto castBlockArgs =
1813  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1814  for (auto [index, cast] : tensorCastProducers) {
1815  Value &oldTypeBBArg = castBlockArgs[index];
1816  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1817  nestedLoc, cast.dstType, oldTypeBBArg);
1818  }
1819 
1820  // Move old body into new parallel loop.
1821  SmallVector<Value> ivsBlockArgs =
1822  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1823  ivsBlockArgs.append(castBlockArgs);
1824  rewriter.mergeBlocks(forallOp.getBody(),
1825  bbArgs.front().getParentBlock(), ivsBlockArgs);
1826  });
1827 
1828  // After `mergeBlocks` has happened, the destinations in the terminator were
1829  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1830  // destinations have to be updated to point to the output bbArgs directly.
1831  auto terminator = newForallOp.getTerminator();
1832  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1833  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1834  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1835  insertSliceOp.getDestMutable().assign(outputBlockArg);
1836  }
1837 
1838  // Cast results back to the original types.
1839  rewriter.setInsertionPointAfter(newForallOp);
1840  SmallVector<Value> castResults = newForallOp.getResults();
1841  for (auto &item : tensorCastProducers) {
1842  Value &oldTypeResult = castResults[item.first];
1843  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1844  oldTypeResult);
1845  }
1846  rewriter.replaceOp(forallOp, castResults);
1847  return success();
1848  }
1849 };
1850 
1851 } // namespace
1852 
1853 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1854  MLIRContext *context) {
1855  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1856  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1857  ForallOpSingleOrZeroIterationDimsFolder>(context);
1858 }
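// Illustrative sketch (not part of the upstream file): these patterns are
// normally applied by the generic canonicalizer pass, but they can also be
// driven directly, assuming GreedyPatternRewriteDriver.h is included:
//
//   RewritePatternSet patterns(ctx);
//   scf::ForallOp::getCanonicalizationPatterns(patterns, ctx);
//   if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
//     return failure();
//
// `ctx` and `rootOp` are placeholders for the MLIRContext and the operation
// whose nested scf.forall ops should be canonicalized.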
1859 
1860 /// Given a region branch point (the body region or the parent operation
1861 /// itself), return the successor regions. These are the regions that may be
1862 /// selected during the flow of control.
1865 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1866  SmallVectorImpl<RegionSuccessor> &regions) {
1867  // Both the operation itself and the region may be branching into the body or
1868  // back into the operation itself. It is possible for the loop not to enter
1869  // body.
1870  regions.push_back(RegionSuccessor(&getRegion()));
1871  regions.push_back(RegionSuccessor());
1872 }
1873 
1874 //===----------------------------------------------------------------------===//
1875 // InParallelOp
1876 //===----------------------------------------------------------------------===//
1877 
1878 // Build an InParallelOp with mixed static and dynamic entries.
1879 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1880  OpBuilder::InsertionGuard g(b);
1881  Region *bodyRegion = result.addRegion();
1882  b.createBlock(bodyRegion);
1883 }
1884 
1885 LogicalResult InParallelOp::verify() {
1886  scf::ForallOp forallOp =
1887  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1888  if (!forallOp)
1889  return this->emitOpError("expected forall op parent");
1890 
1891  // TODO: InParallelOpInterface.
1892  for (Operation &op : getRegion().front().getOperations()) {
1893  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1894  return this->emitOpError("expected only ")
1895  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1896  }
1897 
1898  // Verify that inserts are into out block arguments.
1899  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1900  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1901  if (!llvm::is_contained(regionOutArgs, dest))
1902  return op.emitOpError("may only insert into an output block argument");
1903  }
1904  return success();
1905 }
1906 
1907 void InParallelOp::print(OpAsmPrinter &p) {
1908  p << " ";
1909  p.printRegion(getRegion(),
1910  /*printEntryBlockArgs=*/false,
1911  /*printBlockTerminators=*/false);
1912  p.printOptionalAttrDict(getOperation()->getAttrs());
1913 }
1914 
1915 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1916  auto &builder = parser.getBuilder();
1917 
1918  SmallVector<OpAsmParser::Argument> regionOperands;
1919  std::unique_ptr<Region> region = std::make_unique<Region>();
1920  if (parser.parseRegion(*region, regionOperands))
1921  return failure();
1922 
1923  if (region->empty())
1924  OpBuilder(builder.getContext()).createBlock(region.get());
1925  result.addRegion(std::move(region));
1926 
1927  // Parse the optional attribute list.
1928  if (parser.parseOptionalAttrDict(result.attributes))
1929  return failure();
1930  return success();
1931 }
1932 
1933 OpResult InParallelOp::getParentResult(int64_t idx) {
1934  return getOperation()->getParentOp()->getResult(idx);
1935 }
1936 
1937 SmallVector<BlockArgument> InParallelOp::getDests() {
1938  return llvm::to_vector<4>(
1939  llvm::map_range(getYieldingOps(), [](Operation &op) {
1940  // Add new ops here as needed.
1941  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1942  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1943  }));
1944 }
1945 
1946 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1947  return getRegion().front().getOperations();
1948 }
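// Illustrative usage sketch (not part of the upstream file): the destinations
// of the terminator's yielding ops can be walked via getDests(), e.g.:
//
//   for (BlockArgument dest : inParallelOp.getDests()) {
//     // Each `dest` is a shared_outs block argument of the parent scf.forall.
//   }
//
// `inParallelOp` is a placeholder for an scf::InParallelOp.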
1949 
1950 //===----------------------------------------------------------------------===//
1951 // IfOp
1952 //===----------------------------------------------------------------------===//
1953 
1954 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1955  assert(a && "expected non-empty operation");
1956  assert(b && "expected non-empty operation");
1957 
1958  IfOp ifOp = a->getParentOfType<IfOp>();
1959  while (ifOp) {
1960  // Check if b is inside ifOp. (We already know that a is.)
1961  if (ifOp->isProperAncestor(b))
1962  // b is contained in ifOp. a and b are in mutually exclusive branches if
1963  // they are in different blocks of ifOp.
1964  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1965  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1966  // Check next enclosing IfOp.
1967  ifOp = ifOp->getParentOfType<IfOp>();
1968  }
1969 
1970  // Could not find a common IfOp among a's and b's ancestors.
1971  return false;
1972 }
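// Illustrative usage sketch (not part of the upstream file): this helper can
// answer "can these two operations ever execute in the same run?", e.g.:
//
//   if (mlir::scf::insideMutuallyExclusiveBranches(opA, opB)) {
//     // opA and opB sit in opposite branches of some scf.if, so at most one
//     // of them executes.
//   }
//
// `opA` and `opB` are placeholder Operation pointers.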
1973 
1974 LogicalResult
1975 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1976  IfOp::Adaptor adaptor,
1977  SmallVectorImpl<Type> &inferredReturnTypes) {
1978  if (adaptor.getRegions().empty())
1979  return failure();
1980  Region *r = &adaptor.getThenRegion();
1981  if (r->empty())
1982  return failure();
1983  Block &b = r->front();
1984  if (b.empty())
1985  return failure();
1986  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1987  if (!yieldOp)
1988  return failure();
1989  TypeRange types = yieldOp.getOperandTypes();
1990  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1991  types.end());
1992  return success();
1993 }
1994 
1995 void IfOp::build(OpBuilder &builder, OperationState &result,
1996  TypeRange resultTypes, Value cond) {
1997  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1998  /*addElseBlock=*/false);
1999 }
2000 
2001 void IfOp::build(OpBuilder &builder, OperationState &result,
2002  TypeRange resultTypes, Value cond, bool addThenBlock,
2003  bool addElseBlock) {
2004  assert((!addElseBlock || addThenBlock) &&
2005  "must not create else block w/o then block");
2006  result.addTypes(resultTypes);
2007  result.addOperands(cond);
2008 
2009  // Add regions and blocks.
2010  OpBuilder::InsertionGuard guard(builder);
2011  Region *thenRegion = result.addRegion();
2012  if (addThenBlock)
2013  builder.createBlock(thenRegion);
2014  Region *elseRegion = result.addRegion();
2015  if (addElseBlock)
2016  builder.createBlock(elseRegion);
2017 }
2018 
2019 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2020  bool withElseRegion) {
2021  build(builder, result, TypeRange{}, cond, withElseRegion);
2022 }
2023 
2024 void IfOp::build(OpBuilder &builder, OperationState &result,
2025  TypeRange resultTypes, Value cond, bool withElseRegion) {
2026  result.addTypes(resultTypes);
2027  result.addOperands(cond);
2028 
2029  // Build then region.
2030  OpBuilder::InsertionGuard guard(builder);
2031  Region *thenRegion = result.addRegion();
2032  builder.createBlock(thenRegion);
2033  if (resultTypes.empty())
2034  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2035 
2036  // Build else region.
2037  Region *elseRegion = result.addRegion();
2038  if (withElseRegion) {
2039  builder.createBlock(elseRegion);
2040  if (resultTypes.empty())
2041  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2042  }
2043 }
2044 
2045 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2046  function_ref<void(OpBuilder &, Location)> thenBuilder,
2047  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2048  assert(thenBuilder && "the builder callback for 'then' must be present");
2049  result.addOperands(cond);
2050 
2051  // Build then region.
2052  OpBuilder::InsertionGuard guard(builder);
2053  Region *thenRegion = result.addRegion();
2054  builder.createBlock(thenRegion);
2055  thenBuilder(builder, result.location);
2056 
2057  // Build else region.
2058  Region *elseRegion = result.addRegion();
2059  if (elseBuilder) {
2060  builder.createBlock(elseRegion);
2061  elseBuilder(builder, result.location);
2062  }
2063 
2064  // Infer result types.
2065  SmallVector<Type> inferredReturnTypes;
2066  MLIRContext *ctx = builder.getContext();
2067  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2068  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2069  /*properties=*/nullptr, result.regions,
2070  inferredReturnTypes))) {
2071  result.addTypes(inferredReturnTypes);
2072  }
2073 }
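// Illustrative usage sketch (not part of the upstream file): this overload
// infers the result types from the values yielded by the callbacks, e.g.:
//
//   auto ifOp = builder.create<scf::IfOp>(
//       loc, cond,
//       [&](OpBuilder &b, Location l) { b.create<scf::YieldOp>(l, valueA); },
//       [&](OpBuilder &b, Location l) { b.create<scf::YieldOp>(l, valueB); });
//   Value result = ifOp.getResult(0);
//
// `cond`, `valueA` and `valueB` are placeholders; the two yielded values must
// have matching types for the inference to succeed.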
2074 
2075 LogicalResult IfOp::verify() {
2076  if (getNumResults() != 0 && getElseRegion().empty())
2077  return emitOpError("must have an else block if defining values");
2078  return success();
2079 }
2080 
2081 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2082  // Create the regions for 'then'.
2083  result.regions.reserve(2);
2084  Region *thenRegion = result.addRegion();
2085  Region *elseRegion = result.addRegion();
2086 
2087  auto &builder = parser.getBuilder();
2088  OpAsmParser::UnresolvedOperand cond;
2089  Type i1Type = builder.getIntegerType(1);
2090  if (parser.parseOperand(cond) ||
2091  parser.resolveOperand(cond, i1Type, result.operands))
2092  return failure();
2093  // Parse optional results type list.
2094  if (parser.parseOptionalArrowTypeList(result.types))
2095  return failure();
2096  // Parse the 'then' region.
2097  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2098  return failure();
2099  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2100 
2101  // If we find an 'else' keyword then parse the 'else' region.
2102  if (!parser.parseOptionalKeyword("else")) {
2103  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2104  return failure();
2105  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2106  }
2107 
2108  // Parse the optional attribute list.
2109  if (parser.parseOptionalAttrDict(result.attributes))
2110  return failure();
2111  return success();
2112 }
2113 
2114 void IfOp::print(OpAsmPrinter &p) {
2115  bool printBlockTerminators = false;
2116 
2117  p << " " << getCondition();
2118  if (!getResults().empty()) {
2119  p << " -> (" << getResultTypes() << ")";
2120  // Print yield explicitly if the op defines values.
2121  printBlockTerminators = true;
2122  }
2123  p << ' ';
2124  p.printRegion(getThenRegion(),
2125  /*printEntryBlockArgs=*/false,
2126  /*printBlockTerminators=*/printBlockTerminators);
2127 
2128  // Print the 'else' region if it exists and has a block.
2129  auto &elseRegion = getElseRegion();
2130  if (!elseRegion.empty()) {
2131  p << " else ";
2132  p.printRegion(elseRegion,
2133  /*printEntryBlockArgs=*/false,
2134  /*printBlockTerminators=*/printBlockTerminators);
2135  }
2136 
2137  p.printOptionalAttrDict((*this)->getAttrs());
2138 }
2139 
2140 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2141  SmallVectorImpl<RegionSuccessor> &regions) {
2142  // The `then` and the `else` region branch back to the parent operation.
2143  if (!point.isParent()) {
2144  regions.push_back(RegionSuccessor(getResults()));
2145  return;
2146  }
2147 
2148  regions.push_back(RegionSuccessor(&getThenRegion()));
2149 
2150  // Don't consider the else region if it is empty.
2151  Region *elseRegion = &this->getElseRegion();
2152  if (elseRegion->empty())
2153  regions.push_back(RegionSuccessor());
2154  else
2155  regions.push_back(RegionSuccessor(elseRegion));
2156 }
2157 
2158 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2159  SmallVectorImpl<RegionSuccessor> &regions) {
2160  FoldAdaptor adaptor(operands, *this);
2161  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2162  if (!boolAttr || boolAttr.getValue())
2163  regions.emplace_back(&getThenRegion());
2164 
2165  // If the else region is empty, execution continues after the parent op.
2166  if (!boolAttr || !boolAttr.getValue()) {
2167  if (!getElseRegion().empty())
2168  regions.emplace_back(&getElseRegion());
2169  else
2170  regions.emplace_back(getResults());
2171  }
2172 }
2173 
2174 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2175  SmallVectorImpl<OpFoldResult> &results) {
2176  // if (!c) then A() else B() -> if c then B() else A()
2177  if (getElseRegion().empty())
2178  return failure();
2179 
2180  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2181  if (!xorStmt)
2182  return failure();
2183 
2184  if (!matchPattern(xorStmt.getRhs(), m_One()))
2185  return failure();
2186 
2187  getConditionMutable().assign(xorStmt.getLhs());
2188  Block *thenBlock = &getThenRegion().front();
2189  // It would be nicer to use iplist::swap, but that has no implemented
2190  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2191  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2192  getElseRegion().getBlocks());
2193  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2194  getThenRegion().getBlocks(), thenBlock);
2195  return success();
2196 }
2197 
2198 void IfOp::getRegionInvocationBounds(
2199  ArrayRef<Attribute> operands,
2200  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2201  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2202  // If the condition is known, then one region is known to be executed once
2203  // and the other zero times.
2204  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2205  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2206  } else {
2207  // Non-constant condition. Each region may be executed 0 or 1 times.
2208  invocationBounds.assign(2, {0, 1});
2209  }
2210 }
2211 
2212 namespace {
2213 // Pattern to remove unused IfOp results.
2214 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2215  using OpRewritePattern<IfOp>::OpRewritePattern;
2216 
2217  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2218  PatternRewriter &rewriter) const {
2219  // Move all operations to the destination block.
2220  rewriter.mergeBlocks(source, dest);
2221  // Replace the yield op by one that returns only the used values.
2222  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2223  SmallVector<Value, 4> usedOperands;
2224  llvm::transform(usedResults, std::back_inserter(usedOperands),
2225  [&](OpResult result) {
2226  return yieldOp.getOperand(result.getResultNumber());
2227  });
2228  rewriter.modifyOpInPlace(yieldOp,
2229  [&]() { yieldOp->setOperands(usedOperands); });
2230  }
2231 
2232  LogicalResult matchAndRewrite(IfOp op,
2233  PatternRewriter &rewriter) const override {
2234  // Compute the list of used results.
2235  SmallVector<OpResult, 4> usedResults;
2236  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2237  [](OpResult result) { return !result.use_empty(); });
2238 
2239  // Replace the operation if only a subset of its results have uses.
2240  if (usedResults.size() == op.getNumResults())
2241  return failure();
2242 
2243  // Compute the result types of the replacement operation.
2244  SmallVector<Type, 4> newTypes;
2245  llvm::transform(usedResults, std::back_inserter(newTypes),
2246  [](OpResult result) { return result.getType(); });
2247 
2248  // Create a replacement operation with empty then and else regions.
2249  auto newOp =
2250  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2251  rewriter.createBlock(&newOp.getThenRegion());
2252  rewriter.createBlock(&newOp.getElseRegion());
2253 
2254  // Move the bodies and replace the terminators (note there is a then and
2255  // an else region since the operation returns results).
2256  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2257  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2258 
2259  // Replace the operation by the new one.
2260  SmallVector<Value, 4> repResults(op.getNumResults());
2261  for (const auto &en : llvm::enumerate(usedResults))
2262  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2263  rewriter.replaceOp(op, repResults);
2264  return success();
2265  }
2266 };
2267 
2268 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2269  using OpRewritePattern<IfOp>::OpRewritePattern;
2270 
2271  LogicalResult matchAndRewrite(IfOp op,
2272  PatternRewriter &rewriter) const override {
2273  BoolAttr condition;
2274  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2275  return failure();
2276 
2277  if (condition.getValue())
2278  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2279  else if (!op.getElseRegion().empty())
2280  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2281  else
2282  rewriter.eraseOp(op);
2283 
2284  return success();
2285  }
2286 };
2287 
2288 /// Hoist any yielded results whose operands are defined outside
2289 /// the if, to a select instruction.
2290 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2291  using OpRewritePattern<IfOp>::OpRewritePattern;
2292 
2293  LogicalResult matchAndRewrite(IfOp op,
2294  PatternRewriter &rewriter) const override {
2295  if (op->getNumResults() == 0)
2296  return failure();
2297 
2298  auto cond = op.getCondition();
2299  auto thenYieldArgs = op.thenYield().getOperands();
2300  auto elseYieldArgs = op.elseYield().getOperands();
2301 
2302  SmallVector<Type> nonHoistable;
2303  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2304  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2305  &op.getElseRegion() == falseVal.getParentRegion())
2306  nonHoistable.push_back(trueVal.getType());
2307  }
2308  // Early exit if there aren't any yielded values we can
2309  // hoist outside the if.
2310  if (nonHoistable.size() == op->getNumResults())
2311  return failure();
2312 
2313  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2314  /*withElseRegion=*/false);
2315  if (replacement.thenBlock())
2316  rewriter.eraseBlock(replacement.thenBlock());
2317  replacement.getThenRegion().takeBody(op.getThenRegion());
2318  replacement.getElseRegion().takeBody(op.getElseRegion());
2319 
2320  SmallVector<Value> results(op->getNumResults());
2321  assert(thenYieldArgs.size() == results.size());
2322  assert(elseYieldArgs.size() == results.size());
2323 
2324  SmallVector<Value> trueYields;
2325  SmallVector<Value> falseYields;
2326  rewriter.setInsertionPoint(replacement);
2327  for (const auto &it :
2328  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2329  Value trueVal = std::get<0>(it.value());
2330  Value falseVal = std::get<1>(it.value());
2331  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2332  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2333  results[it.index()] = replacement.getResult(trueYields.size());
2334  trueYields.push_back(trueVal);
2335  falseYields.push_back(falseVal);
2336  } else if (trueVal == falseVal)
2337  results[it.index()] = trueVal;
2338  else
2339  results[it.index()] = rewriter.create<arith::SelectOp>(
2340  op.getLoc(), cond, trueVal, falseVal);
2341  }
2342 
2343  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2344  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2345 
2346  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2347  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2348 
2349  rewriter.replaceOp(op, results);
2350  return success();
2351  }
2352 };
2353 
2354 /// Allow the true region of an if to assume the condition is true
2355 /// and vice versa. For example:
2356 ///
2357 /// scf.if %cmp {
2358 /// print(%cmp)
2359 /// }
2360 ///
2361 /// becomes
2362 ///
2363 /// scf.if %cmp {
2364 /// print(true)
2365 /// }
2366 ///
2367 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2368  using OpRewritePattern<IfOp>::OpRewritePattern;
2369 
2370  LogicalResult matchAndRewrite(IfOp op,
2371  PatternRewriter &rewriter) const override {
2372  // Early exit if the condition is constant since replacing a constant
2373  // in the body with another constant isn't a simplification.
2374  if (matchPattern(op.getCondition(), m_Constant()))
2375  return failure();
2376 
2377  bool changed = false;
2378  mlir::Type i1Ty = rewriter.getI1Type();
2379 
2380  // These variables serve to prevent creating duplicate constants
2381  // and hold constant true or false values.
2382  Value constantTrue = nullptr;
2383  Value constantFalse = nullptr;
2384 
2385  for (OpOperand &use :
2386  llvm::make_early_inc_range(op.getCondition().getUses())) {
2387  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2388  changed = true;
2389 
2390  if (!constantTrue)
2391  constantTrue = rewriter.create<arith::ConstantOp>(
2392  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2393 
2394  rewriter.modifyOpInPlace(use.getOwner(),
2395  [&]() { use.set(constantTrue); });
2396  } else if (op.getElseRegion().isAncestor(
2397  use.getOwner()->getParentRegion())) {
2398  changed = true;
2399 
2400  if (!constantFalse)
2401  constantFalse = rewriter.create<arith::ConstantOp>(
2402  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2403 
2404  rewriter.modifyOpInPlace(use.getOwner(),
2405  [&]() { use.set(constantFalse); });
2406  }
2407  }
2408 
2409  return success(changed);
2410  }
2411 };
2412 
2413 /// Remove any statements from an if that are equivalent to the condition
2414 /// or its negation. For example:
2415 ///
2416 /// %res:2 = scf.if %cmp {
2417 /// yield something(), true
2418 /// } else {
2419 /// yield something2(), false
2420 /// }
2421 /// print(%res#1)
2422 ///
2423 /// becomes
2424 /// %res = scf.if %cmp {
2425 /// yield something()
2426 /// } else {
2427 /// yield something2()
2428 /// }
2429 /// print(%cmp)
2430 ///
2431 /// Additionally if both branches yield the same value, replace all uses
2432 /// of the result with the yielded value.
2433 ///
2434 /// %res:2 = scf.if %cmp {
2435 /// yield something(), %arg1
2436 /// } else {
2437 /// yield something2(), %arg1
2438 /// }
2439 /// print(%res#1)
2440 ///
2441 /// becomes
2442 /// %res = scf.if %cmp {
2443 /// yield something()
2444 /// } else {
2445 /// yield something2()
2446 /// }
2447 /// print(%arg1)
2448 ///
2449 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2450  using OpRewritePattern<IfOp>::OpRewritePattern;
2451 
2452  LogicalResult matchAndRewrite(IfOp op,
2453  PatternRewriter &rewriter) const override {
2454  // Early exit if there are no results that could be replaced.
2455  if (op.getNumResults() == 0)
2456  return failure();
2457 
2458  auto trueYield =
2459  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2460  auto falseYield =
2461  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2462 
2463  rewriter.setInsertionPoint(op->getBlock(),
2464  op.getOperation()->getIterator());
2465  bool changed = false;
2466  Type i1Ty = rewriter.getI1Type();
2467  for (auto [trueResult, falseResult, opResult] :
2468  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2469  op.getResults())) {
2470  if (trueResult == falseResult) {
2471  if (!opResult.use_empty()) {
2472  opResult.replaceAllUsesWith(trueResult);
2473  changed = true;
2474  }
2475  continue;
2476  }
2477 
2478  BoolAttr trueYield, falseYield;
2479  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2480  !matchPattern(falseResult, m_Constant(&falseYield)))
2481  continue;
2482 
2483  bool trueVal = trueYield.getValue();
2484  bool falseVal = falseYield.getValue();
2485  if (!trueVal && falseVal) {
2486  if (!opResult.use_empty()) {
2487  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2488  Value notCond = rewriter.create<arith::XOrIOp>(
2489  op.getLoc(), op.getCondition(),
2490  constDialect
2491  ->materializeConstant(rewriter,
2492  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2493  op.getLoc())
2494  ->getResult(0));
2495  opResult.replaceAllUsesWith(notCond);
2496  changed = true;
2497  }
2498  }
2499  if (trueVal && !falseVal) {
2500  if (!opResult.use_empty()) {
2501  opResult.replaceAllUsesWith(op.getCondition());
2502  changed = true;
2503  }
2504  }
2505  }
2506  return success(changed);
2507  }
2508 };
2509 
2510 /// Merge any consecutive scf.if's with the same condition.
2511 ///
2512 /// scf.if %cond {
2513 /// firstCodeTrue();...
2514 /// } else {
2515 /// firstCodeFalse();...
2516 /// }
2517 /// %res = scf.if %cond {
2518 /// secondCodeTrue();...
2519 /// } else {
2520 /// secondCodeFalse();...
2521 /// }
2522 ///
2523 /// becomes
2524 /// %res = scf.if %cmp {
2525 /// firstCodeTrue();...
2526 /// secondCodeTrue();...
2527 /// } else {
2528 /// firstCodeFalse();...
2529 /// secondCodeFalse();...
2530 /// }
2531 struct CombineIfs : public OpRewritePattern<IfOp> {
2532  using OpRewritePattern<IfOp>::OpRewritePattern;
2533 
2534  LogicalResult matchAndRewrite(IfOp nextIf,
2535  PatternRewriter &rewriter) const override {
2536  Block *parent = nextIf->getBlock();
2537  if (nextIf == &parent->front())
2538  return failure();
2539 
2540  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2541  if (!prevIf)
2542  return failure();
2543 
2544  // Determine the logical then/else blocks when prevIf's
2545  // condition is used. Null means the block does not exist
2546  // in that case (e.g. an empty else). If neither of these
2547  // is set, the two conditions cannot be compared.
2548  Block *nextThen = nullptr;
2549  Block *nextElse = nullptr;
2550  if (nextIf.getCondition() == prevIf.getCondition()) {
2551  nextThen = nextIf.thenBlock();
2552  if (!nextIf.getElseRegion().empty())
2553  nextElse = nextIf.elseBlock();
2554  }
2555  if (arith::XOrIOp notv =
2556  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2557  if (notv.getLhs() == prevIf.getCondition() &&
2558  matchPattern(notv.getRhs(), m_One())) {
2559  nextElse = nextIf.thenBlock();
2560  if (!nextIf.getElseRegion().empty())
2561  nextThen = nextIf.elseBlock();
2562  }
2563  }
2564  if (arith::XOrIOp notv =
2565  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2566  if (notv.getLhs() == nextIf.getCondition() &&
2567  matchPattern(notv.getRhs(), m_One())) {
2568  nextElse = nextIf.thenBlock();
2569  if (!nextIf.getElseRegion().empty())
2570  nextThen = nextIf.elseBlock();
2571  }
2572  }
2573 
2574  if (!nextThen && !nextElse)
2575  return failure();
2576 
2577  SmallVector<Value> prevElseYielded;
2578  if (!prevIf.getElseRegion().empty())
2579  prevElseYielded = prevIf.elseYield().getOperands();
2580  // Replace all uses of prevIf's return values within nextIf with the
2581  // corresponding yielded values.
2582  for (auto it : llvm::zip(prevIf.getResults(),
2583  prevIf.thenYield().getOperands(), prevElseYielded))
2584  for (OpOperand &use :
2585  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2586  if (nextThen && nextThen->getParent()->isAncestor(
2587  use.getOwner()->getParentRegion())) {
2588  rewriter.startOpModification(use.getOwner());
2589  use.set(std::get<1>(it));
2590  rewriter.finalizeOpModification(use.getOwner());
2591  } else if (nextElse && nextElse->getParent()->isAncestor(
2592  use.getOwner()->getParentRegion())) {
2593  rewriter.startOpModification(use.getOwner());
2594  use.set(std::get<2>(it));
2595  rewriter.finalizeOpModification(use.getOwner());
2596  }
2597  }
2598 
2599  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2600  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2601 
2602  IfOp combinedIf = rewriter.create<IfOp>(
2603  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2604  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2605 
2606  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2607  combinedIf.getThenRegion(),
2608  combinedIf.getThenRegion().begin());
2609 
2610  if (nextThen) {
2611  YieldOp thenYield = combinedIf.thenYield();
2612  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2613  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2614  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2615 
2616  SmallVector<Value> mergedYields(thenYield.getOperands());
2617  llvm::append_range(mergedYields, thenYield2.getOperands());
2618  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2619  rewriter.eraseOp(thenYield);
2620  rewriter.eraseOp(thenYield2);
2621  }
2622 
2623  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2624  combinedIf.getElseRegion(),
2625  combinedIf.getElseRegion().begin());
2626 
2627  if (nextElse) {
2628  if (combinedIf.getElseRegion().empty()) {
2629  rewriter.inlineRegionBefore(*nextElse->getParent(),
2630  combinedIf.getElseRegion(),
2631  combinedIf.getElseRegion().begin());
2632  } else {
2633  YieldOp elseYield = combinedIf.elseYield();
2634  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2635  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2636 
2637  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2638 
2639  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2640  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2641 
2642  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2643  rewriter.eraseOp(elseYield);
2644  rewriter.eraseOp(elseYield2);
2645  }
2646  }
2647 
2648  SmallVector<Value> prevValues;
2649  SmallVector<Value> nextValues;
2650  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2651  if (pair.index() < prevIf.getNumResults())
2652  prevValues.push_back(pair.value());
2653  else
2654  nextValues.push_back(pair.value());
2655  }
2656  rewriter.replaceOp(prevIf, prevValues);
2657  rewriter.replaceOp(nextIf, nextValues);
2658  return success();
2659  }
2660 };
2661 
2662 /// Pattern to remove an empty else branch.
2663 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2664  using OpRewritePattern<IfOp>::OpRewritePattern;
2665 
2666  LogicalResult matchAndRewrite(IfOp ifOp,
2667  PatternRewriter &rewriter) const override {
2668  // Cannot remove else region when there are operation results.
2669  if (ifOp.getNumResults())
2670  return failure();
2671  Block *elseBlock = ifOp.elseBlock();
2672  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2673  return failure();
2674  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2675  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2676  newIfOp.getThenRegion().begin());
2677  rewriter.eraseOp(ifOp);
2678  return success();
2679  }
2680 };
2681 
2682 /// Convert nested `if`s into `arith.andi` + single `if`.
2683 ///
2684 /// scf.if %arg0 {
2685 /// scf.if %arg1 {
2686 /// ...
2687 /// scf.yield
2688 /// }
2689 /// scf.yield
2690 /// }
2691 /// becomes
2692 ///
2693 /// %0 = arith.andi %arg0, %arg1
2694 /// scf.if %0 {
2695 /// ...
2696 /// scf.yield
2697 /// }
2698 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2699  using OpRewritePattern<IfOp>::OpRewritePattern;
2700 
2701  LogicalResult matchAndRewrite(IfOp op,
2702  PatternRewriter &rewriter) const override {
2703  auto nestedOps = op.thenBlock()->without_terminator();
2704  // Nested `if` must be the only op in block.
2705  if (!llvm::hasSingleElement(nestedOps))
2706  return failure();
2707 
2708  // If there is an else block, it can only yield
2709  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2710  return failure();
2711 
2712  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2713  if (!nestedIf)
2714  return failure();
2715 
2716  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2717  return failure();
2718 
2719  SmallVector<Value> thenYield(op.thenYield().getOperands());
2720  SmallVector<Value> elseYield;
2721  if (op.elseBlock())
2722  llvm::append_range(elseYield, op.elseYield().getOperands());
2723 
2724  // A list of indices for which we should upgrade the value yielded
2725  // in the else to a select.
2726  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2727 
2728  // If the outer scf.if yields a value produced by the inner scf.if,
2729  // only permit combining if the value yielded when the condition
2730  // is false in the outer scf.if is the same value yielded when the
2731  // inner scf.if condition is false.
2732  // Note that the array access to elseYield will not go out of bounds
2733  // since it must have the same length as thenYield, since they both
2734  // come from the same scf.if.
2735  for (const auto &tup : llvm::enumerate(thenYield)) {
2736  if (tup.value().getDefiningOp() == nestedIf) {
2737  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2738  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2739  elseYield[tup.index()]) {
2740  return failure();
2741  }
2742  // If the correctness test passes, we will yield
2743  // corresponding value from the inner scf.if
2744  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2745  continue;
2746  }
2747 
2748  // Otherwise, we need to ensure the else block of the combined
2749  // condition still returns the same value when the outer condition is
2750  // true and the inner condition is false. This can be accomplished if
2751  // the then value is defined outside the outer scf.if and we replace the
2752  // value with a select that considers just the outer condition. Since
2753  // the else region contains just the yield, its yielded value is
2754  // defined outside the scf.if, by definition.
2755 
2756  // If the then value is defined within the scf.if, bail.
2757  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2758  return failure();
2759  }
2760  elseYieldsToUpgradeToSelect.push_back(tup.index());
2761  }
2762 
2763  Location loc = op.getLoc();
2764  Value newCondition = rewriter.create<arith::AndIOp>(
2765  loc, op.getCondition(), nestedIf.getCondition());
2766  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2767  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2768 
2769  SmallVector<Value> results;
2770  llvm::append_range(results, newIf.getResults());
2771  rewriter.setInsertionPoint(newIf);
2772 
2773  for (auto idx : elseYieldsToUpgradeToSelect)
2774  results[idx] = rewriter.create<arith::SelectOp>(
2775  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2776 
2777  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2778  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2779  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2780  if (!elseYield.empty()) {
2781  rewriter.createBlock(&newIf.getElseRegion());
2782  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2783  rewriter.create<YieldOp>(loc, elseYield);
2784  }
2785  rewriter.replaceOp(op, results);
2786  return success();
2787  }
2788 };
2789 
2790 } // namespace
2791 
2792 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2793  MLIRContext *context) {
2794  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2795  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2796  RemoveStaticCondition, RemoveUnusedResults,
2797  ReplaceIfYieldWithConditionOrValue>(context);
2798 }
2799 
2800 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2801 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2802 Block *IfOp::elseBlock() {
2803  Region &r = getElseRegion();
2804  if (r.empty())
2805  return nullptr;
2806  return &r.back();
2807 }
2808 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
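// Illustrative usage sketch (not part of the upstream file): thenBlock() is
// always present while elseBlock() may be null, so the else accessors are
// typically guarded, e.g.:
//
//   YieldOp thenYield = ifOp.thenYield();
//   if (ifOp.elseBlock()) {
//     YieldOp elseYield = ifOp.elseYield();
//   }
//
// `ifOp` is a placeholder scf::IfOp.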
2809 
2810 //===----------------------------------------------------------------------===//
2811 // ParallelOp
2812 //===----------------------------------------------------------------------===//
2813 
2814 void ParallelOp::build(
2815  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2816  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2817  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2818  bodyBuilderFn) {
2819  result.addOperands(lowerBounds);
2820  result.addOperands(upperBounds);
2821  result.addOperands(steps);
2822  result.addOperands(initVals);
2823  result.addAttribute(
2824  ParallelOp::getOperandSegmentSizeAttr(),
2825  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2826  static_cast<int32_t>(upperBounds.size()),
2827  static_cast<int32_t>(steps.size()),
2828  static_cast<int32_t>(initVals.size())}));
2829  result.addTypes(initVals.getTypes());
2830 
2831  OpBuilder::InsertionGuard guard(builder);
2832  unsigned numIVs = steps.size();
2833  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2834  SmallVector<Location, 8> argLocs(numIVs, result.location);
2835  Region *bodyRegion = result.addRegion();
2836  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2837 
2838  if (bodyBuilderFn) {
2839  builder.setInsertionPointToStart(bodyBlock);
2840  bodyBuilderFn(builder, result.location,
2841  bodyBlock->getArguments().take_front(numIVs),
2842  bodyBlock->getArguments().drop_front(numIVs));
2843  }
2844  // Add terminator only if there are no reductions.
2845  if (initVals.empty())
2846  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2847 }
2848 
2849 void ParallelOp::build(
2850  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2851  ValueRange upperBounds, ValueRange steps,
2852  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2853  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2854  // we don't capture a reference to a temporary by constructing the lambda at
2855  // function level.
2856  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2857  Location nestedLoc, ValueRange ivs,
2858  ValueRange) {
2859  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2860  };
2861  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2862  if (bodyBuilderFn)
2863  wrapper = wrappedBuilderFn;
2864 
2865  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2866  wrapper);
2867 }
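// Illustrative usage sketch (not part of the upstream file): the
// reduction-free overload above is the common way to create an scf.parallel
// loop programmatically, e.g.:
//
//   builder.create<scf::ParallelOp>(
//       loc, lowerBounds, upperBounds, steps,
//       [&](OpBuilder &b, Location l, ValueRange ivs) {
//         // Loop body; `ivs` holds one index-typed induction variable per
//         // dimension. The scf.reduce terminator is added automatically.
//       });
//
// `lowerBounds`, `upperBounds` and `steps` are placeholder ranges of
// index-typed values of equal size.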
2868 
2869 LogicalResult ParallelOp::verify() {
2870  // Check that there is at least one value in lowerBound, upperBound and step.
2871  // It is sufficient to test only step, because it is already ensured that the
2872  // number of elements in lowerBound, upperBound and step is the same.
2873  Operation::operand_range stepValues = getStep();
2874  if (stepValues.empty())
2875  return emitOpError(
2876  "needs at least one tuple element for lowerBound, upperBound and step");
2877 
2878  // Check whether all constant step values are positive.
2879  for (Value stepValue : stepValues)
2880  if (auto cst = getConstantIntValue(stepValue))
2881  if (*cst <= 0)
2882  return emitOpError("constant step operand must be positive");
2883 
2884  // Check that the body defines the same number of block arguments as the
2885  // number of tuple elements in step.
2886  Block *body = getBody();
2887  if (body->getNumArguments() != stepValues.size())
2888  return emitOpError() << "expects the same number of induction variables: "
2889  << body->getNumArguments()
2890  << " as bound and step values: " << stepValues.size();
2891  for (auto arg : body->getArguments())
2892  if (!arg.getType().isIndex())
2893  return emitOpError(
2894  "expects arguments for the induction variable to be of index type");
2895 
2896  // Check that the terminator is an scf.reduce op.
2897  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2898  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2899  if (!reduceOp)
2900  return failure();
2901 
2902  // Check that the number of results is the same as the number of reductions.
2903  auto resultsSize = getResults().size();
2904  auto reductionsSize = reduceOp.getReductions().size();
2905  auto initValsSize = getInitVals().size();
2906  if (resultsSize != reductionsSize)
2907  return emitOpError() << "expects number of results: " << resultsSize
2908  << " to be the same as number of reductions: "
2909  << reductionsSize;
2910  if (resultsSize != initValsSize)
2911  return emitOpError() << "expects number of results: " << resultsSize
2912  << " to be the same as number of initial values: "
2913  << initValsSize;
2914 
2915  // Check that the types of the results and reductions are the same.
2916  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2917  auto resultType = getOperation()->getResult(i).getType();
2918  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2919  if (resultType != reductionOperandType)
2920  return reduceOp.emitOpError()
2921  << "expects type of " << i
2922  << "-th reduction operand: " << reductionOperandType
2923  << " to be the same as the " << i
2924  << "-th result type: " << resultType;
2925  }
2926  return success();
2927 }
2928 
2929 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2930  auto &builder = parser.getBuilder();
2931  // Parse an opening `(` followed by induction variables followed by `)`
2932  SmallVector<OpAsmParser::Argument, 4> ivs;
2933  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2934  return failure();
2935 
2936  // Parse loop bounds.
2937  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2938  if (parser.parseEqual() ||
2939  parser.parseOperandList(lower, ivs.size(),
2940  OpAsmParser::Delimiter::Paren) ||
2941  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2942  return failure();
2943 
2945  if (parser.parseKeyword("to") ||
2946  parser.parseOperandList(upper, ivs.size(),
2948  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2949  return failure();
2950 
2951  // Parse step values.
2953  if (parser.parseKeyword("step") ||
2954  parser.parseOperandList(steps, ivs.size(),
2956  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2957  return failure();
2958 
2959  // Parse init values.
2960  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2961  if (succeeded(parser.parseOptionalKeyword("init"))) {
2962  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2963  return failure();
2964  }
2965 
2966  // Parse optional results in case there is a reduce.
2967  if (parser.parseOptionalArrowTypeList(result.types))
2968  return failure();
2969 
2970  // Now parse the body.
2971  Region *body = result.addRegion();
2972  for (auto &iv : ivs)
2973  iv.type = builder.getIndexType();
2974  if (parser.parseRegion(*body, ivs))
2975  return failure();
2976 
2977  // Set `operandSegmentSizes` attribute.
2978  result.addAttribute(
2979  ParallelOp::getOperandSegmentSizeAttr(),
2980  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2981  static_cast<int32_t>(upper.size()),
2982  static_cast<int32_t>(steps.size()),
2983  static_cast<int32_t>(initVals.size())}));
2984 
2985  // Parse attributes.
2986  if (parser.parseOptionalAttrDict(result.attributes) ||
2987  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2988  result.operands))
2989  return failure();
2990 
2991  // Add a terminator if none was parsed.
2992  ParallelOp::ensureTerminator(*body, builder, result.location);
2993  return success();
2994 }
2995 
2996 void ParallelOp::print(OpAsmPrinter &p) {
2997  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2998  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2999  if (!getInitVals().empty())
3000  p << " init (" << getInitVals() << ")";
3001  p.printOptionalArrowTypeList(getResultTypes());
3002  p << ' ';
3003  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3004  p.printOptionalAttrDict(
3005  (*this)->getAttrs(),
3006  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3007 }
3008 
3009 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3010 
3011 std::optional<Value> ParallelOp::getSingleInductionVar() {
3012  if (getNumLoops() != 1)
3013  return std::nullopt;
3014  return getBody()->getArgument(0);
3015 }
3016 
3017 std::optional<OpFoldResult> ParallelOp::getSingleLowerBound() {
3018  if (getNumLoops() != 1)
3019  return std::nullopt;
3020  return getLowerBound()[0];
3021 }
3022 
3023 std::optional<OpFoldResult> ParallelOp::getSingleUpperBound() {
3024  if (getNumLoops() != 1)
3025  return std::nullopt;
3026  return getUpperBound()[0];
3027 }
3028 
3029 std::optional<OpFoldResult> ParallelOp::getSingleStep() {
3030  if (getNumLoops() != 1)
3031  return std::nullopt;
3032  return getStep()[0];
3033 }
3034 
3035 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3036  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3037  if (!ivArg)
3038  return ParallelOp();
3039  assert(ivArg.getOwner() && "unlinked block argument");
3040  auto *containingOp = ivArg.getOwner()->getParentOp();
3041  return dyn_cast<ParallelOp>(containingOp);
3042 }
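// Illustrative usage sketch (not part of the upstream file): recover the
// enclosing loop from one of its induction variables, e.g.:
//
//   if (scf::ParallelOp loop = scf::getParallelForInductionVarOwner(iv)) {
//     // `iv` (placeholder) is an induction variable of `loop`; a null op is
//     // returned for any other value.
//   }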
3043 
3044 namespace {
3045 // Collapse loop dimensions that perform a single iteration.
3046 struct ParallelOpSingleOrZeroIterationDimsFolder
3047  : public OpRewritePattern<ParallelOp> {
3048  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3049 
3050  LogicalResult matchAndRewrite(ParallelOp op,
3051  PatternRewriter &rewriter) const override {
3052  Location loc = op.getLoc();
3053 
3054  // Compute new loop bounds that omit all single-iteration loop dimensions.
3055  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3056  IRMapping mapping;
3057  for (auto [lb, ub, step, iv] :
3058  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3059  op.getInductionVars())) {
3060  auto numIterations = constantTripCount(lb, ub, step);
3061  if (numIterations.has_value()) {
3062  // Remove the loop if it performs zero iterations.
3063  if (*numIterations == 0) {
3064  rewriter.replaceOp(op, op.getInitVals());
3065  return success();
3066  }
3067  // Replace the loop induction variable by the lower bound if the loop
3068  // performs a single iteration. Otherwise, copy the loop bounds.
3069  if (*numIterations == 1) {
3070  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3071  continue;
3072  }
3073  }
3074  newLowerBounds.push_back(lb);
3075  newUpperBounds.push_back(ub);
3076  newSteps.push_back(step);
3077  }
3078  // Exit if none of the loop dimensions perform a single iteration.
3079  if (newLowerBounds.size() == op.getLowerBound().size())
3080  return failure();
3081 
3082  if (newLowerBounds.empty()) {
3083  // All of the loop dimensions perform a single iteration. Inline the
3084  // loop body and the nested ReduceOps.
3085  SmallVector<Value> results;
3086  results.reserve(op.getInitVals().size());
3087  for (auto &bodyOp : op.getBody()->without_terminator())
3088  rewriter.clone(bodyOp, mapping);
3089  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3090  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3091  Block &reduceBlock = reduceOp.getReductions()[i].front();
3092  auto initValIndex = results.size();
3093  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3094  mapping.map(reduceBlock.getArgument(1),
3095  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3096  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3097  rewriter.clone(reduceBodyOp, mapping);
3098 
3099  auto result = mapping.lookupOrDefault(
3100  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3101  results.push_back(result);
3102  }
3103 
3104  rewriter.replaceOp(op, results);
3105  return success();
3106  }
3107  // Replace the parallel loop with a lower-dimensional parallel loop.
3108  auto newOp =
3109  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3110  newSteps, op.getInitVals(), nullptr);
3111  // Erase the empty block that was inserted by the builder.
3112  rewriter.eraseBlock(newOp.getBody());
3113  // Clone the loop body and remap the block arguments of the collapsed loops
3114  // (inlining does not support a cancellable block argument mapping).
3115  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3116  newOp.getRegion().begin(), mapping);
3117  rewriter.replaceOp(op, newOp.getResults());
3118  return success();
3119  }
3120 };
3121 
3122 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3123  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3124 
3125  LogicalResult matchAndRewrite(ParallelOp op,
3126  PatternRewriter &rewriter) const override {
3127  Block &outerBody = *op.getBody();
3128  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3129  return failure();
3130 
3131  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3132  if (!innerOp)
3133  return failure();
3134 
3135  for (auto val : outerBody.getArguments())
3136  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3137  llvm::is_contained(innerOp.getUpperBound(), val) ||
3138  llvm::is_contained(innerOp.getStep(), val))
3139  return failure();
3140 
3141  // Reductions are not supported yet.
3142  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3143  return failure();
3144 
3145  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3146  ValueRange iterVals, ValueRange) {
3147  Block &innerBody = *innerOp.getBody();
3148  assert(iterVals.size() ==
3149  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3150  IRMapping mapping;
3151  mapping.map(outerBody.getArguments(),
3152  iterVals.take_front(outerBody.getNumArguments()));
3153  mapping.map(innerBody.getArguments(),
3154  iterVals.take_back(innerBody.getNumArguments()));
3155  for (Operation &op : innerBody.without_terminator())
3156  builder.clone(op, mapping);
3157  };
3158 
3159  auto concatValues = [](const auto &first, const auto &second) {
3160  SmallVector<Value> ret;
3161  ret.reserve(first.size() + second.size());
3162  ret.assign(first.begin(), first.end());
3163  ret.append(second.begin(), second.end());
3164  return ret;
3165  };
3166 
3167  auto newLowerBounds =
3168  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3169  auto newUpperBounds =
3170  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3171  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3172 
3173  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3174  newSteps, std::nullopt,
3175  bodyBuilder);
3176  return success();
3177  }
3178 };
3179 
3180 } // namespace
3181 
3182 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3183  MLIRContext *context) {
3184  results
3185  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3186  context);
3187 }
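// Illustrative sketch (not part of the upstream file): one way a pass might
// exercise the canonicalization patterns registered above. It assumes the
// greedy driver from "mlir/Transforms/GreedyPatternRewriteDriver.h" is
// available; `containerOp` is an assumed enclosing operation (e.g. a module).
static LogicalResult canonicalizeParallelLoops(Operation *containerOp) {
  MLIRContext *ctx = containerOp->getContext();
  RewritePatternSet patterns(ctx);
  // Collect the two ParallelOp patterns defined above.
  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
  // Apply them to a fixed point over everything nested under `containerOp`.
  return applyPatternsAndFoldGreedily(containerOp, std::move(patterns));
}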
3188 
3189 /// Given the region at `index`, or the parent operation if `index` is None,
3190 /// return the successor regions. These are the regions that may be selected
3191 /// during the flow of control. `operands` is a set of optional attributes that
3192 /// correspond to a constant value for each operand, or null if that operand is
3193 /// not a constant.
3194 void ParallelOp::getSuccessorRegions(
3195  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3196  // Both the operation itself and the region may be branching into the body or
3197  // back into the operation itself. It is possible for the loop not to enter the
3198  // body.
3199  regions.push_back(RegionSuccessor(&getRegion()));
3200  regions.push_back(RegionSuccessor());
3201 }
3202 
3203 //===----------------------------------------------------------------------===//
3204 // ReduceOp
3205 //===----------------------------------------------------------------------===//
3206 
3207 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3208 
3209 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3210  ValueRange operands) {
3211  result.addOperands(operands);
3212  for (Value v : operands) {
3213  OpBuilder::InsertionGuard guard(builder);
3214  Region *bodyRegion = result.addRegion();
3215  builder.createBlock(bodyRegion, {},
3216  ArrayRef<Type>{v.getType(), v.getType()},
3217  {result.location, result.location});
3218  }
3219 }
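// Illustrative sketch (not part of the upstream file): the builder above only
// creates an empty two-argument block per reduction region; the caller is
// expected to fill in the combiner and terminate it with scf.reduce.return.
// `partial` is an assumed f32 value and `builder` is assumed to be inserting
// inside the body of an scf.parallel loop (scf.reduce must be its terminator).
static void buildExampleAddReduction(OpBuilder &builder, Location loc,
                                     Value partial) {
  auto reduceOp = builder.create<scf::ReduceOp>(loc, ValueRange{partial});
  Block &combiner = reduceOp.getReductions().front().front();
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToEnd(&combiner);
  // Combine the two block arguments and return the reduced value.
  Value sum = builder.create<arith::AddFOp>(loc, combiner.getArgument(0),
                                            combiner.getArgument(1));
  builder.create<scf::ReduceReturnOp>(loc, sum);
}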
3220 
3221 LogicalResult ReduceOp::verifyRegions() {
3222  // Each reduction region of a ReduceOp must have a block with two arguments
3223  // of the same type as its corresponding operand.
3224  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3225  auto type = getOperands()[i].getType();
3226  Block &block = getReductions()[i].front();
3227  if (block.empty())
3228  return emitOpError() << i << "-th reduction has an empty body";
3229  if (block.getNumArguments() != 2 ||
3230  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3231  return arg.getType() != type;
3232  }))
3233  return emitOpError() << "expected two block arguments with type " << type
3234  << " in the " << i << "-th reduction region";
3235 
3236  // Check that the block is terminated by a ReduceReturnOp.
3237  if (!isa<ReduceReturnOp>(block.getTerminator()))
3238  return emitOpError("reduction bodies must be terminated with an "
3239  "'scf.reduce.return' op");
3240  }
3241 
3242  return success();
3243 }
3244 
3245 MutableOperandRange
3246 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3247  // No operands are forwarded to the next iteration.
3248  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3249 }
3250 
3251 //===----------------------------------------------------------------------===//
3252 // ReduceReturnOp
3253 //===----------------------------------------------------------------------===//
3254 
3255 LogicalResult ReduceReturnOp::verify() {
3256  // The type of the return value must match the type of the reduction body's
3257  // block arguments.
3258  Block *reductionBody = getOperation()->getBlock();
3259  // Should already be verified by an op trait.
3260  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3261  Type expectedResultType = reductionBody->getArgument(0).getType();
3262  if (expectedResultType != getResult().getType())
3263  return emitOpError() << "must have type " << expectedResultType
3264  << " (the type of the reduction inputs)";
3265  return success();
3266 }
3267 
3268 //===----------------------------------------------------------------------===//
3269 // WhileOp
3270 //===----------------------------------------------------------------------===//
3271 
3272 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3273  ::mlir::OperationState &odsState, TypeRange resultTypes,
3274  ValueRange operands, BodyBuilderFn beforeBuilder,
3275  BodyBuilderFn afterBuilder) {
3276  odsState.addOperands(operands);
3277  odsState.addTypes(resultTypes);
3278 
3279  OpBuilder::InsertionGuard guard(odsBuilder);
3280 
3281  // Build before region.
3282  SmallVector<Location, 4> beforeArgLocs;
3283  beforeArgLocs.reserve(operands.size());
3284  for (Value operand : operands) {
3285  beforeArgLocs.push_back(operand.getLoc());
3286  }
3287 
3288  Region *beforeRegion = odsState.addRegion();
3289  Block *beforeBlock = odsBuilder.createBlock(
3290  beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
3291  if (beforeBuilder)
3292  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3293 
3294  // Build after region.
3295  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3296 
3297  Region *afterRegion = odsState.addRegion();
3298  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3299  resultTypes, afterArgLocs);
3300 
3301  if (afterBuilder)
3302  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3303 }
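// Illustrative sketch (not part of the upstream file): building a simple
// counting scf.while with the before/after builder callbacks accepted above.
// `init`, `upperBound` and `step` are assumed index-typed SSA values supplied
// by the caller.
static scf::WhileOp buildExampleWhileLoop(OpBuilder &b, Location loc,
                                          Value init, Value upperBound,
                                          Value step) {
  return b.create<scf::WhileOp>(
      loc, TypeRange{b.getIndexType()}, ValueRange{init},
      /*beforeBuilder=*/
      [&](OpBuilder &nested, Location nestedLoc, ValueRange args) {
        // Keep iterating while the current value is below the upper bound;
        // forward the value to the "after" region.
        Value cond = nested.create<arith::CmpIOp>(
            nestedLoc, arith::CmpIPredicate::slt, args[0], upperBound);
        nested.create<scf::ConditionOp>(nestedLoc, cond, args);
      },
      /*afterBuilder=*/
      [&](OpBuilder &nested, Location nestedLoc, ValueRange args) {
        // Advance by `step` and yield back to the "before" region.
        Value next = nested.create<arith::AddIOp>(nestedLoc, args[0], step);
        nested.create<scf::YieldOp>(nestedLoc, next);
      });
}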
3304 
3305 ConditionOp WhileOp::getConditionOp() {
3306  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3307 }
3308 
3309 YieldOp WhileOp::getYieldOp() {
3310  return cast<YieldOp>(getAfterBody()->getTerminator());
3311 }
3312 
3313 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3314  return getYieldOp().getResultsMutable();
3315 }
3316 
3317 Block::BlockArgListType WhileOp::getBeforeArguments() {
3318  return getBeforeBody()->getArguments();
3319 }
3320 
3321 Block::BlockArgListType WhileOp::getAfterArguments() {
3322  return getAfterBody()->getArguments();
3323 }
3324 
3325 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3326  return getBeforeArguments();
3327 }
3328 
3329 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3330  assert(point == getBefore() &&
3331  "WhileOp is expected to branch only to the first region");
3332  return getInits();
3333 }
3334 
3335 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3336  SmallVectorImpl<RegionSuccessor> &regions) {
3337  // The parent op always branches to the condition region.
3338  if (point.isParent()) {
3339  regions.emplace_back(&getBefore(), getBefore().getArguments());
3340  return;
3341  }
3342 
3343  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3344  "there are only two regions in a WhileOp");
3345  // The body region always branches back to the condition region.
3346  if (point == getAfter()) {
3347  regions.emplace_back(&getBefore(), getBefore().getArguments());
3348  return;
3349  }
3350 
3351  regions.emplace_back(getResults());
3352  regions.emplace_back(&getAfter(), getAfter().getArguments());
3353 }
3354 
3355 SmallVector<Region *> WhileOp::getLoopRegions() {
3356  return {&getBefore(), &getAfter()};
3357 }
3358 
3359 /// Parses a `while` op.
3360 ///
3361 /// op ::= `scf.while` assignments `:` function-type region `do` region
3362 /// `attributes` attribute-dict
3363 /// initializer ::= /* empty */ | `(` assignment-list `)`
3364 /// assignment-list ::= assignment | assignment `,` assignment-list
3365 /// assignment ::= ssa-value `=` ssa-value
3366 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3367  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3368  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3369  Region *before = result.addRegion();
3370  Region *after = result.addRegion();
3371 
3372  OptionalParseResult listResult =
3373  parser.parseOptionalAssignmentList(regionArgs, operands);
3374  if (listResult.has_value() && failed(listResult.value()))
3375  return failure();
3376 
3377  FunctionType functionType;
3378  SMLoc typeLoc = parser.getCurrentLocation();
3379  if (failed(parser.parseColonType(functionType)))
3380  return failure();
3381 
3382  result.addTypes(functionType.getResults());
3383 
3384  if (functionType.getNumInputs() != operands.size()) {
3385  return parser.emitError(typeLoc)
3386  << "expected as many input types as operands "
3387  << "(expected " << operands.size() << " got "
3388  << functionType.getNumInputs() << ")";
3389  }
3390 
3391  // Resolve input operands.
3392  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3393  parser.getCurrentLocation(),
3394  result.operands)))
3395  return failure();
3396 
3397  // Propagate the types into the region arguments.
3398  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3399  regionArgs[i].type = functionType.getInput(i);
3400 
3401  return failure(parser.parseRegion(*before, regionArgs) ||
3402  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3403  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3404 }
3405 
3406 /// Prints a `while` op.
3407 void scf::WhileOp::print(OpAsmPrinter &p) {
3408  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3409  p << " : ";
3410  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3411  p << ' ';
3412  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3413  p << " do ";
3414  p.printRegion(getAfter());
3415  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3416 }
3417 
3418 /// Verifies that two ranges of types match, i.e. have the same number of
3419 /// entries and that types are pairwise equals. Reports errors on the given
3420 /// operation in case of mismatch.
3421 template <typename OpTy>
3422 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3423  TypeRange right, StringRef message) {
3424  if (left.size() != right.size())
3425  return op.emitOpError("expects the same number of ") << message;
3426 
3427  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3428  if (left[i] != right[i]) {
3429  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3430  << message;
3431  diag.attachNote() << "for argument " << i << ", found " << left[i]
3432  << " and " << right[i];
3433  return diag;
3434  }
3435  }
3436 
3437  return success();
3438 }
3439 
3440 LogicalResult scf::WhileOp::verify() {
3441  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3442  *this, getBefore(),
3443  "expects the 'before' region to terminate with 'scf.condition'");
3444  if (!beforeTerminator)
3445  return failure();
3446 
3447  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3448  *this, getAfter(),
3449  "expects the 'after' region to terminate with 'scf.yield'");
3450  return success(afterTerminator != nullptr);
3451 }
3452 
3453 namespace {
3454 /// Replace uses of the condition within the do block with true, since otherwise
3455 /// the block would not be evaluated.
3456 ///
3457 /// scf.while (..) : (i1, ...) -> ... {
3458 /// %condition = call @evaluate_condition() : () -> i1
3459 /// scf.condition(%condition) %condition : i1, ...
3460 /// } do {
3461 /// ^bb0(%arg0: i1, ...):
3462 /// use(%arg0)
3463 /// ...
3464 ///
3465 /// becomes
3466 /// scf.while (..) : (i1, ...) -> ... {
3467 /// %condition = call @evaluate_condition() : () -> i1
3468 /// scf.condition(%condition) %condition : i1, ...
3469 /// } do {
3470 /// ^bb0(%arg0: i1, ...):
3471 /// use(%true)
3472 /// ...
3473 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3474  using OpRewritePattern<WhileOp>::OpRewritePattern;
3475 
3476  LogicalResult matchAndRewrite(WhileOp op,
3477  PatternRewriter &rewriter) const override {
3478  auto term = op.getConditionOp();
3479 
3480  // These variables serve to prevent creating duplicate constants
3481  // and hold constant true or false values.
3482  Value constantTrue = nullptr;
3483 
3484  bool replaced = false;
3485  for (auto yieldedAndBlockArgs :
3486  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3487  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3488  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3489  if (!constantTrue)
3490  constantTrue = rewriter.create<arith::ConstantOp>(
3491  op.getLoc(), term.getCondition().getType(),
3492  rewriter.getBoolAttr(true));
3493 
3494  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3495  constantTrue);
3496  replaced = true;
3497  }
3498  }
3499  }
3500  return success(replaced);
3501  }
3502 };
3503 
3504 /// Remove loop invariant arguments from `before` block of scf.while.
3505 /// A before block argument is considered loop invariant if :-
3506 /// 1. i-th yield operand is equal to the i-th while operand.
3507 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3508 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3509 /// iter argument/while operand.
3510 /// For the arguments which are removed, their uses inside scf.while
3511 /// are replaced with their corresponding initial value.
3512 ///
3513 /// Eg:
3514 /// INPUT :-
3515 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3516 /// ..., %argN_before = %N)
3517 /// {
3518 /// ...
3519 /// scf.condition(%cond) %arg1_before, %arg0_before,
3520 /// %arg2_before, %arg0_before, ...
3521 /// } do {
3522 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3523 /// ..., %argK_after):
3524 /// ...
3525 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3526 /// }
3527 ///
3528 /// OUTPUT :-
3529 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3530 /// %N)
3531 /// {
3532 /// ...
3533 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3534 /// } do {
3535 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3536 /// ..., %argK_after):
3537 /// ...
3538 /// scf.yield %arg1_after, ..., %argN
3539 /// }
3540 ///
3541 /// EXPLANATION:
3542 /// We iterate over each yield operand.
3543 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3544 /// %arg0_before, which in turn is the 0-th iter argument. So we
3545 /// remove 0-th before block argument and yield operand, and replace
3546 /// all uses of the 0-th before block argument with its initial value
3547 /// %a.
3548 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3549 /// value. So we remove this operand and the corresponding before
3550 /// block argument and replace all uses of 1-th before block argument
3551 /// with %b.
3552 struct RemoveLoopInvariantArgsFromBeforeBlock
3553  : public OpRewritePattern<WhileOp> {
3554  using OpRewritePattern<WhileOp>::OpRewritePattern;
3555 
3556  LogicalResult matchAndRewrite(WhileOp op,
3557  PatternRewriter &rewriter) const override {
3558  Block &afterBlock = *op.getAfterBody();
3559  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3560  ConditionOp condOp = op.getConditionOp();
3561  OperandRange condOpArgs = condOp.getArgs();
3562  Operation *yieldOp = afterBlock.getTerminator();
3563  ValueRange yieldOpArgs = yieldOp->getOperands();
3564 
3565  bool canSimplify = false;
3566  for (const auto &it :
3567  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3568  auto index = static_cast<unsigned>(it.index());
3569  auto [initVal, yieldOpArg] = it.value();
3570  // If i-th yield operand is equal to the i-th operand of the scf.while,
3571  // the i-th before block argument is a loop invariant.
3572  if (yieldOpArg == initVal) {
3573  canSimplify = true;
3574  break;
3575  }
3576  // If the i-th yield operand is k-th after block argument, then we check
3577  // if the (k+1)-th condition op operand is equal to either the i-th before
3578  // block argument or the initial value of i-th before block argument. If
3579  // the comparison results `true`, i-th before block argument is a loop
3580  // invariant.
3581  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3582  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3583  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3584  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3585  canSimplify = true;
3586  break;
3587  }
3588  }
3589  }
3590 
3591  if (!canSimplify)
3592  return failure();
3593 
3594  SmallVector<Value> newInitArgs, newYieldOpArgs;
3595  DenseMap<unsigned, Value> beforeBlockInitValMap;
3596  SmallVector<Location> newBeforeBlockArgLocs;
3597  for (const auto &it :
3598  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3599  auto index = static_cast<unsigned>(it.index());
3600  auto [initVal, yieldOpArg] = it.value();
3601 
3602  // If i-th yield operand is equal to the i-th operand of the scf.while,
3603  // the i-th before block argument is a loop invariant.
3604  if (yieldOpArg == initVal) {
3605  beforeBlockInitValMap.insert({index, initVal});
3606  continue;
3607  } else {
3608  // If the i-th yield operand is k-th after block argument, then we check
3609  // if the (k+1)-th condition op operand is equal to either the i-th
3610  // before block argument or the initial value of i-th before block
3611  // argument. If the comparison results `true`, i-th before block
3612  // argument is a loop invariant.
3613  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3614  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3615  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3616  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3617  beforeBlockInitValMap.insert({index, initVal});
3618  continue;
3619  }
3620  }
3621  }
3622  newInitArgs.emplace_back(initVal);
3623  newYieldOpArgs.emplace_back(yieldOpArg);
3624  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3625  }
3626 
3627  {
3628  OpBuilder::InsertionGuard g(rewriter);
3629  rewriter.setInsertionPoint(yieldOp);
3630  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3631  }
3632 
3633  auto newWhile =
3634  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3635 
3636  Block &newBeforeBlock = *rewriter.createBlock(
3637  &newWhile.getBefore(), /*insertPt*/ {},
3638  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3639 
3640  Block &beforeBlock = *op.getBeforeBody();
3641  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3642  // For each i-th before block argument we find its replacement value as follows:
3643  // 1. If the i-th before block argument is a loop invariant, we fetch its
3644  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3645  // 2. Else we fetch j-th new before block argument as the replacement
3646  // value of i-th before block argument.
3647  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3648  // If the index 'i' argument was a loop invariant, we fetch its initial
3649  // value from `beforeBlockInitValMap`.
3650  if (beforeBlockInitValMap.count(i) != 0)
3651  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3652  else
3653  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3654  }
3655 
3656  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3657  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3658  newWhile.getAfter().begin());
3659 
3660  rewriter.replaceOp(op, newWhile.getResults());
3661  return success();
3662  }
3663 };
3664 
3665 /// Remove loop invariant value from result (condition op) of scf.while.
3666 /// A value is considered loop invariant if the final value yielded by
3667 /// scf.condition is defined outside of the `before` block. We remove the
3668 /// corresponding argument in `after` block and replace the use with the value.
3669 /// We also replace the use of the corresponding result of scf.while with the
3670 /// value.
3671 ///
3672 /// Eg:
3673 /// INPUT :-
3674 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3675 /// %argN_before = %N) {
3676 /// ...
3677 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3678 /// } do {
3679 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3680 /// ...
3681 /// some_func(%arg1_after)
3682 /// ...
3683 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3684 /// }
3685 ///
3686 /// OUTPUT :-
3687 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3688 /// ...
3689 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3690 /// } do {
3691 /// ^bb0(%arg0, %arg3, ..., %argM):
3692 /// ...
3693 /// some_func(%a)
3694 /// ...
3695 /// scf.yield %arg0, %b, ..., %argN
3696 /// }
3697 ///
3698 /// EXPLANATION:
3699 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3700 /// before block of scf.while, so they get removed.
3701 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3702 /// replaced by %b.
3703 /// 3. The corresponding after block argument %arg1_after's uses are
3704 /// replaced by %a and %arg2_after's uses are replaced by %b.
3705 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3706  using OpRewritePattern<WhileOp>::OpRewritePattern;
3707 
3708  LogicalResult matchAndRewrite(WhileOp op,
3709  PatternRewriter &rewriter) const override {
3710  Block &beforeBlock = *op.getBeforeBody();
3711  ConditionOp condOp = op.getConditionOp();
3712  OperandRange condOpArgs = condOp.getArgs();
3713 
3714  bool canSimplify = false;
3715  for (Value condOpArg : condOpArgs) {
3716  // Values not defined within the `before` block are considered loop
3717  // invariant; finding one means this pattern can simplify the op, so we
3718  // can stop scanning.
3719  if (condOpArg.getParentBlock() != &beforeBlock) {
3720  canSimplify = true;
3721  break;
3722  }
3723  }
3724 
3725  if (!canSimplify)
3726  return failure();
3727 
3728  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3729 
3730  SmallVector<Value> newCondOpArgs;
3731  SmallVector<Type> newAfterBlockType;
3732  DenseMap<unsigned, Value> condOpInitValMap;
3733  SmallVector<Location> newAfterBlockArgLocs;
3734  for (const auto &it : llvm::enumerate(condOpArgs)) {
3735  auto index = static_cast<unsigned>(it.index());
3736  Value condOpArg = it.value();
3737  // Values not defined within the `before` block are considered loop
3738  // invariant; we record them in `condOpInitValMap`, keyed by their
3739  // `index`.
3740  if (condOpArg.getParentBlock() != &beforeBlock) {
3741  condOpInitValMap.insert({index, condOpArg});
3742  } else {
3743  newCondOpArgs.emplace_back(condOpArg);
3744  newAfterBlockType.emplace_back(condOpArg.getType());
3745  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3746  }
3747  }
3748 
3749  {
3750  OpBuilder::InsertionGuard g(rewriter);
3751  rewriter.setInsertionPoint(condOp);
3752  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3753  newCondOpArgs);
3754  }
3755 
3756  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3757  op.getOperands());
3758 
3759  Block &newAfterBlock =
3760  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3761  newAfterBlockType, newAfterBlockArgLocs);
3762 
3763  Block &afterBlock = *op.getAfterBody();
3764  // Since a new scf.condition op was created, we need to fetch the new
3765  // `after` block arguments, which will be used when replacing operations of
3766  // the previous scf.while's `after` block. We also fetch the new result
3767  // values.
3768  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3769  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3770  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3771  Value afterBlockArg, result;
3772  // If the index 'i' argument was loop invariant, we fetch its value from
3773  // the `condOpInitValMap` map.
3774  if (condOpInitValMap.count(i) != 0) {
3775  afterBlockArg = condOpInitValMap[i];
3776  result = afterBlockArg;
3777  } else {
3778  afterBlockArg = newAfterBlock.getArgument(j);
3779  result = newWhile.getResult(j);
3780  j++;
3781  }
3782  newAfterBlockArgs[i] = afterBlockArg;
3783  newWhileResults[i] = result;
3784  }
3785 
3786  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3787  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3788  newWhile.getBefore().begin());
3789 
3790  rewriter.replaceOp(op, newWhileResults);
3791  return success();
3792  }
3793 };
3794 
3795 /// Remove WhileOp results that are unused and whose corresponding 'after' block argument is also unused.
3796 ///
3797 /// %0:2 = scf.while () : () -> (i32, i64) {
3798 /// %condition = "test.condition"() : () -> i1
3799 /// %v1 = "test.get_some_value"() : () -> i32
3800 /// %v2 = "test.get_some_value"() : () -> i64
3801 /// scf.condition(%condition) %v1, %v2 : i32, i64
3802 /// } do {
3803 /// ^bb0(%arg0: i32, %arg1: i64):
3804 /// "test.use"(%arg0) : (i32) -> ()
3805 /// scf.yield
3806 /// }
3807 /// return %0#0 : i32
3808 ///
3809 /// becomes
3810 /// %0 = scf.while () : () -> (i32) {
3811 /// %condition = "test.condition"() : () -> i1
3812 /// %v1 = "test.get_some_value"() : () -> i32
3813 /// %v2 = "test.get_some_value"() : () -> i64
3814 /// scf.condition(%condition) %v1 : i32
3815 /// } do {
3816 /// ^bb0(%arg0: i32):
3817 /// "test.use"(%arg0) : (i32) -> ()
3818 /// scf.yield
3819 /// }
3820 /// return %0 : i32
3821 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3822  using OpRewritePattern<WhileOp>::OpRewritePattern;
3823 
3824  LogicalResult matchAndRewrite(WhileOp op,
3825  PatternRewriter &rewriter) const override {
3826  auto term = op.getConditionOp();
3827  auto afterArgs = op.getAfterArguments();
3828  auto termArgs = term.getArgs();
3829 
3830  // Collect results mapping, new terminator args and new result types.
3831  SmallVector<unsigned> newResultsIndices;
3832  SmallVector<Type> newResultTypes;
3833  SmallVector<Value> newTermArgs;
3834  SmallVector<Location> newArgLocs;
3835  bool needUpdate = false;
3836  for (const auto &it :
3837  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3838  auto i = static_cast<unsigned>(it.index());
3839  Value result = std::get<0>(it.value());
3840  Value afterArg = std::get<1>(it.value());
3841  Value termArg = std::get<2>(it.value());
3842  if (result.use_empty() && afterArg.use_empty()) {
3843  needUpdate = true;
3844  } else {
3845  newResultsIndices.emplace_back(i);
3846  newTermArgs.emplace_back(termArg);
3847  newResultTypes.emplace_back(result.getType());
3848  newArgLocs.emplace_back(result.getLoc());
3849  }
3850  }
3851 
3852  if (!needUpdate)
3853  return failure();
3854 
3855  {
3856  OpBuilder::InsertionGuard g(rewriter);
3857  rewriter.setInsertionPoint(term);
3858  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3859  newTermArgs);
3860  }
3861 
3862  auto newWhile =
3863  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3864 
3865  Block &newAfterBlock = *rewriter.createBlock(
3866  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3867 
3868  // Build new results list and new after block args (unused entries will be
3869  // null).
3870  SmallVector<Value> newResults(op.getNumResults());
3871  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3872  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3873  newResults[it.value()] = newWhile.getResult(it.index());
3874  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3875  }
3876 
3877  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3878  newWhile.getBefore().begin());
3879 
3880  Block &afterBlock = *op.getAfterBody();
3881  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3882 
3883  rewriter.replaceOp(op, newResults);
3884  return success();
3885  }
3886 };
3887 
3888 /// Replace comparisons in the `do` block that match the loop condition (or its
3889 /// inverse) with constant true (or false), since the block is only entered when the condition holds.
3890 ///
3891 /// scf.while (..) : (i32, ...) -> ... {
3892 /// %z = ... : i32
3893 /// %condition = cmpi pred %z, %a
3894 /// scf.condition(%condition) %z : i32, ...
3895 /// } do {
3896 /// ^bb0(%arg0: i32, ...):
3897 /// %condition2 = cmpi pred %arg0, %a
3898 /// use(%condition2)
3899 /// ...
3900 ///
3901 /// becomes
3902 /// scf.while (..) : (i32, ...) -> ... {
3903 /// %z = ... : i32
3904 /// %condition = cmpi pred %z, %a
3905 /// scf.condition(%condition) %z : i32, ...
3906 /// } do {
3907 /// ^bb0(%arg0: i32, ...):
3908 /// use(%true)
3909 /// ...
3910 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3911  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3912 
3913  LogicalResult matchAndRewrite(scf::WhileOp op,
3914  PatternRewriter &rewriter) const override {
3915  using namespace scf;
3916  auto cond = op.getConditionOp();
3917  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3918  if (!cmp)
3919  return failure();
3920  bool changed = false;
3921  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3922  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3923  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3924  continue;
3925  for (OpOperand &u :
3926  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3927  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3928  if (!cmp2)
3929  continue;
3930  // For a binary operator 1-opIdx gets the other side.
3931  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3932  continue;
3933  bool samePredicate;
3934  if (cmp2.getPredicate() == cmp.getPredicate())
3935  samePredicate = true;
3936  else if (cmp2.getPredicate() ==
3937  arith::invertPredicate(cmp.getPredicate()))
3938  samePredicate = false;
3939  else
3940  continue;
3941 
3942  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3943  1);
3944  changed = true;
3945  }
3946  }
3947  }
3948  return success(changed);
3949  }
3950 };
3951 
3952 /// Remove unused init/yield args.
3953 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3954  using OpRewritePattern<WhileOp>::OpRewritePattern;
3955 
3956  LogicalResult matchAndRewrite(WhileOp op,
3957  PatternRewriter &rewriter) const override {
3958 
3959  if (!llvm::any_of(op.getBeforeArguments(),
3960  [](Value arg) { return arg.use_empty(); }))
3961  return rewriter.notifyMatchFailure(op, "No args to remove");
3962 
3963  YieldOp yield = op.getYieldOp();
3964 
3965  // Collect the new yield and init values, and mark which args to erase.
3966  SmallVector<Value> newYields;
3967  SmallVector<Value> newInits;
3968  llvm::BitVector argsToErase;
3969 
3970  size_t argsCount = op.getBeforeArguments().size();
3971  newYields.reserve(argsCount);
3972  newInits.reserve(argsCount);
3973  argsToErase.reserve(argsCount);
3974  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3975  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3976  if (beforeArg.use_empty()) {
3977  argsToErase.push_back(true);
3978  } else {
3979  argsToErase.push_back(false);
3980  newYields.emplace_back(yieldValue);
3981  newInits.emplace_back(initValue);
3982  }
3983  }
3984 
3985  Block &beforeBlock = *op.getBeforeBody();
3986  Block &afterBlock = *op.getAfterBody();
3987 
3988  beforeBlock.eraseArguments(argsToErase);
3989 
3990  Location loc = op.getLoc();
3991  auto newWhileOp =
3992  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
3993  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
3994  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3995  Block &newAfterBlock = *newWhileOp.getAfterBody();
3996 
3997  OpBuilder::InsertionGuard g(rewriter);
3998  rewriter.setInsertionPoint(yield);
3999  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4000 
4001  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4002  newBeforeBlock.getArguments());
4003  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4004  newAfterBlock.getArguments());
4005 
4006  rewriter.replaceOp(op, newWhileOp.getResults());
4007  return success();
4008  }
4009 };
4010 
4011 /// Remove duplicated ConditionOp args.
4012 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4013  using OpRewritePattern<WhileOp>::OpRewritePattern;
4014 
4015  LogicalResult matchAndRewrite(WhileOp op,
4016  PatternRewriter &rewriter) const override {
4017  ConditionOp condOp = op.getConditionOp();
4018  ValueRange condOpArgs = condOp.getArgs();
4019 
4021  for (Value arg : condOpArgs)
4022  argsSet.insert(arg);
4023 
4024  if (argsSet.size() == condOpArgs.size())
4025  return rewriter.notifyMatchFailure(op, "No results to remove");
4026 
4027  llvm::SmallDenseMap<Value, unsigned> argsMap;
4028  SmallVector<Value> newArgs;
4029  argsMap.reserve(condOpArgs.size());
4030  newArgs.reserve(condOpArgs.size());
4031  for (Value arg : condOpArgs) {
4032  if (!argsMap.count(arg)) {
4033  auto pos = static_cast<unsigned>(argsMap.size());
4034  argsMap.insert({arg, pos});
4035  newArgs.emplace_back(arg);
4036  }
4037  }
4038 
4039  ValueRange argsRange(newArgs);
4040 
4041  Location loc = op.getLoc();
4042  auto newWhileOp = rewriter.create<scf::WhileOp>(
4043  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4044  /*afterBody*/ nullptr);
4045  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4046  Block &newAfterBlock = *newWhileOp.getAfterBody();
4047 
4048  SmallVector<Value> afterArgsMapping;
4049  SmallVector<Value> resultsMapping;
4050  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4051  auto it = argsMap.find(arg);
4052  assert(it != argsMap.end());
4053  auto pos = it->second;
4054  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4055  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4056  }
4057 
4058  OpBuilder::InsertionGuard g(rewriter);
4059  rewriter.setInsertionPoint(condOp);
4060  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4061  argsRange);
4062 
4063  Block &beforeBlock = *op.getBeforeBody();
4064  Block &afterBlock = *op.getAfterBody();
4065 
4066  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4067  newBeforeBlock.getArguments());
4068  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4069  rewriter.replaceOp(op, resultsMapping);
4070  return success();
4071  }
4072 };
4073 
4074 /// If both ranges contain the same values, return the mapping of indices from
4075 /// args2 to args1. Otherwise return std::nullopt.
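/// For example (illustrative, not in the original comment): args1 = [a, b, c]
/// and args2 = [c, a, b] yield {2, 0, 1}, because args2[0] == args1[2],
/// args2[1] == args1[0] and args2[2] == args1[1].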
4076 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4077  ValueRange args2) {
4078  if (args1.size() != args2.size())
4079  return std::nullopt;
4080 
4081  SmallVector<unsigned> ret(args1.size());
4082  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4083  auto it = llvm::find(args2, arg1);
4084  if (it == args2.end())
4085  return std::nullopt;
4086 
4087  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4088  }
4089 
4090  return ret;
4091 }
4092 
4093 static bool hasDuplicates(ValueRange args) {
4094  llvm::SmallDenseSet<Value> set;
4095  for (Value arg : args) {
4096  if (set.contains(arg))
4097  return true;
4098 
4099  set.insert(arg);
4100  }
4101  return false;
4102 }
4103 
4104 /// If the `before` block args are directly forwarded to `scf.condition`,
4105 /// rearrange the `scf.condition` args into the same order as the block args.
4106 /// Update the `after` block args and op result values accordingly.
4107 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4108 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4109  using OpRewritePattern<WhileOp>::OpRewritePattern;
4110 
4111  LogicalResult matchAndRewrite(WhileOp loop,
4112  PatternRewriter &rewriter) const override {
4113  auto oldBefore = loop.getBeforeBody();
4114  ConditionOp oldTerm = loop.getConditionOp();
4115  ValueRange beforeArgs = oldBefore->getArguments();
4116  ValueRange termArgs = oldTerm.getArgs();
4117  if (beforeArgs == termArgs)
4118  return failure();
4119 
4120  if (hasDuplicates(termArgs))
4121  return failure();
4122 
4123  auto mapping = getArgsMapping(beforeArgs, termArgs);
4124  if (!mapping)
4125  return failure();
4126 
4127  {
4128  OpBuilder::InsertionGuard g(rewriter);
4129  rewriter.setInsertionPoint(oldTerm);
4130  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4131  beforeArgs);
4132  }
4133 
4134  auto oldAfter = loop.getAfterBody();
4135 
4136  SmallVector<Type> newResultTypes(beforeArgs.size());
4137  for (auto &&[i, j] : llvm::enumerate(*mapping))
4138  newResultTypes[j] = loop.getResult(i).getType();
4139 
4140  auto newLoop = rewriter.create<WhileOp>(
4141  loop.getLoc(), newResultTypes, loop.getInits(),
4142  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4143  auto newBefore = newLoop.getBeforeBody();
4144  auto newAfter = newLoop.getAfterBody();
4145 
4146  SmallVector<Value> newResults(beforeArgs.size());
4147  SmallVector<Value> newAfterArgs(beforeArgs.size());
4148  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4149  newResults[i] = newLoop.getResult(j);
4150  newAfterArgs[i] = newAfter->getArgument(j);
4151  }
4152 
4153  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4154  newBefore->getArguments());
4155  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4156  newAfterArgs);
4157 
4158  rewriter.replaceOp(loop, newResults);
4159  return success();
4160  }
4161 };
4162 } // namespace
4163 
4164 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4165  MLIRContext *context) {
4166  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4167  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4168  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4169  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4170 }
4171 
4172 //===----------------------------------------------------------------------===//
4173 // IndexSwitchOp
4174 //===----------------------------------------------------------------------===//
4175 
4176 /// Parse the case regions and values.
4177 static ParseResult
4178 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4179  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4180  SmallVector<int64_t> caseValues;
4181  while (succeeded(p.parseOptionalKeyword("case"))) {
4182  int64_t value;
4183  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4184  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4185  return failure();
4186  caseValues.push_back(value);
4187  }
4188  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4189  return success();
4190 }
4191 
4192 /// Print the case regions and values.
4193 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4194  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4195  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4196  p.printNewline();
4197  p << "case " << value << ' ';
4198  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4199  }
4200 }
4201 
4202 LogicalResult scf::IndexSwitchOp::verify() {
4203  if (getCases().size() != getCaseRegions().size()) {
4204  return emitOpError("has ")
4205  << getCaseRegions().size() << " case regions but "
4206  << getCases().size() << " case values";
4207  }
4208 
4209  DenseSet<int64_t> valueSet;
4210  for (int64_t value : getCases())
4211  if (!valueSet.insert(value).second)
4212  return emitOpError("has duplicate case value: ") << value;
4213  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4214  auto yield = dyn_cast<YieldOp>(region.front().back());
4215  if (!yield)
4216  return emitOpError("expected region to end with scf.yield, but got ")
4217  << region.front().back().getName();
4218 
4219  if (yield.getNumOperands() != getNumResults()) {
4220  return (emitOpError("expected each region to return ")
4221  << getNumResults() << " values, but " << name << " returns "
4222  << yield.getNumOperands())
4223  .attachNote(yield.getLoc())
4224  << "see yield operation here";
4225  }
4226  for (auto [idx, result, operand] :
4227  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4228  yield.getOperandTypes())) {
4229  if (result == operand)
4230  continue;
4231  return (emitOpError("expected result #")
4232  << idx << " of each region to be " << result)
4233  .attachNote(yield.getLoc())
4234  << name << " returns " << operand << " here";
4235  }
4236  return success();
4237  };
4238 
4239  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4240  return failure();
4241  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4242  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4243  return failure();
4244 
4245  return success();
4246 }
4247 
4248 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4249 
4250 Block &scf::IndexSwitchOp::getDefaultBlock() {
4251  return getDefaultRegion().front();
4252 }
4253 
4254 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4255  assert(idx < getNumCases() && "case index out-of-bounds");
4256  return getCaseRegions()[idx].front();
4257 }
4258 
4259 void IndexSwitchOp::getSuccessorRegions(
4260  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4261  // All regions branch back to the parent op.
4262  if (!point.isParent()) {
4263  successors.emplace_back(getResults());
4264  return;
4265  }
4266 
4267  llvm::copy(getRegions(), std::back_inserter(successors));
4268 }
4269 
4270 void IndexSwitchOp::getEntrySuccessorRegions(
4271  ArrayRef<Attribute> operands,
4272  SmallVectorImpl<RegionSuccessor> &successors) {
4273  FoldAdaptor adaptor(operands, *this);
4274 
4275  // If a constant was not provided, all regions are possible successors.
4276  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4277  if (!arg) {
4278  llvm::copy(getRegions(), std::back_inserter(successors));
4279  return;
4280  }
4281 
4282  // Otherwise, try to find a case with a matching value. If not, the
4283  // default region is the only successor.
4284  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4285  if (caseValue == arg.getInt()) {
4286  successors.emplace_back(&caseRegion);
4287  return;
4288  }
4289  }
4290  successors.emplace_back(&getDefaultRegion());
4291 }
4292 
4293 void IndexSwitchOp::getRegionInvocationBounds(
4294  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4295  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4296  if (!operandValue) {
4297  // All regions are invoked at most once.
4298  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4299  return;
4300  }
4301 
4302  unsigned liveIndex = getNumRegions() - 1;
4303  const auto *it = llvm::find(getCases(), operandValue.getInt());
4304  if (it != getCases().end())
4305  liveIndex = std::distance(getCases().begin(), it);
4306  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4307  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4308 }
4309 
4310 LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
4311  SmallVectorImpl<OpFoldResult> &results) {
4312  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
4313  if (!maybeCst.has_value())
4314  return failure();
4315  int64_t cst = *maybeCst;
4316  int64_t caseIdx, e = getNumCases();
4317  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4318  if (cst == getCases()[caseIdx])
4319  break;
4320  }
4321 
4322  Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
4323  : getDefaultRegion();
4324  Block &source = r.front();
4325  results.assign(source.getTerminator()->getOperands().begin(),
4326  source.getTerminator()->getOperands().end());
4327 
4328  Block *pDestination = (*this)->getBlock();
4329  if (!pDestination)
4330  return failure();
4331  Block::iterator insertionPoint = (*this)->getIterator();
4332  pDestination->getOperations().splice(insertionPoint, source.getOperations(),
4333  source.getOperations().begin(),
4334  std::prev(source.getOperations().end()));
4335 
4336  return success();
4337 }
4338 
4339 //===----------------------------------------------------------------------===//
4340 // TableGen'd op method definitions
4341 //===----------------------------------------------------------------------===//
4342 
4343 #define GET_OP_CLASSES
4344 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:716
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:708
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:113
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4178
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3422
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:432
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:93
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4193
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:159
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator begin()
Definition: Block.h:141
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:73
IndexType getIndexType()
Definition: Builders.cpp:71
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
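A hedged sketch of how these parser hooks compose inside a custom parse method (MyOp is a hypothetical single-operand op; parseColonType comes from the AsmParser base class):
  ParseResult MyOp::parse(OpAsmParser &parser, OperationState &result) {
    OpAsmParser::UnresolvedOperand operand;
    Type type;
    // Parse "%value : type" and resolve the operand against the parsed type.
    if (parser.parseOperand(operand) || parser.parseColonType(type) ||
        parser.resolveOperand(operand, type, result.operands))
      return failure();
    return success();
  }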
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:93
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
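A matching printer sketch using the hooks above (again for the hypothetical MyOp; getValue() stands in for whatever accessor the op would provide):
  void MyOp::print(OpAsmPrinter &p) {
    // Print "%value : type" followed by any attributes.
    p << " " << getValue() << " : " << getValue().getType();
    p.printOptionalAttrDict((*this)->getAttrs());
  }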
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:582
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:586
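A short sketch of the usual insertion-point discipline with OpBuilder and InsertionGuard (op, someBlock, and loc are hypothetical; assumes the Arith dialect is available):
  OpBuilder builder(op);                       // insert right before 'op'
  {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(someBlock);
    builder.create<arith::ConstantIndexOp>(loc, 0);
  }                                            // insertion point restored here
  builder.setInsertionPointAfter(op);          // continue inserting after 'op'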
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.

Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
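A hedged sketch combining a few of the Operation queries above (the helper and the op pointer are hypothetical):
  static bool hasDeadResultInsideFor(Operation *op) {
    // Only consider ops nested inside an scf.for.
    if (!op->getParentOfType<scf::ForOp>())
      return false;
    // Report whether any result of 'op' is unused.
    return llvm::any_of(op->getResults(),
                        [](OpResult result) { return result.use_empty(); });
  }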
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
Block & back()
Definition: Region.h:64
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
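A minimal sketch of how these rewriter hooks combine when flattening a region into the parent block (a hypothetical helper, loosely modeled on the if-folding patterns in this file; it assumes the condition is known to be true and that 'op' yields no results):
  static LogicalResult inlineThenBody(scf::IfOp op, PatternRewriter &rewriter) {
    if (op->getNumResults() != 0)
      return rewriter.notifyMatchFailure(op, "expected zero results");
    Block *thenBlock = op.thenBlock();
    // Drop the scf.yield terminator before splicing the block contents.
    rewriter.eraseOp(thenBlock->getTerminator());
    rewriter.inlineBlockBefore(thenBlock, op);
    rewriter.eraseOp(op);
    return success();
  }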
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:131
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:450
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3035
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:86
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
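A hedged usage sketch of buildLoopNest building a 2-D nest with one index-typed accumulator (zero, one, dimX, dimY, and init are hypothetical index Values; assumes the Arith dialect is loaded):
  scf::LoopNest nest = scf::buildLoopNest(
      builder, loc, /*lbs=*/ValueRange{zero, zero}, /*ubs=*/ValueRange{dimX, dimY},
      /*steps=*/ValueRange{one, one}, /*iterArgs=*/init,
      [](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        // Accumulate i + j into the single iter_arg.
        Value partial = b.create<arith::AddIOp>(loc, ivs[0], ivs[1]);
        Value acc = b.create<arith::AddIOp>(loc, iterArgs[0], partial);
        return scf::ValueVector{acc};
      });
  Value total = nest.results.front();          // final accumulator value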
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1954
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
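A small sketch of recovering the enclosing loop from its induction variable (iv is a hypothetical Value):
  if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
    // 'iv' is the induction variable; the owner exposes bounds, step and iter_args.
    Value step = forOp.getStep();
    (void)step;
  }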
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1455
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:263
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
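A sketch of the constant-matching helpers above (value is a hypothetical Value and ofr a hypothetical OpFoldResult):
  APInt constant;
  if (matchPattern(value, m_ConstantInt(&constant)) && constant.isOne()) {
    // 'value' is produced by an integer constant equal to one; the same check
    // can also be written as matchPattern(value, m_One()).
  }
  std::optional<int64_t> maybeInt = getConstantIntValue(ofr);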
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:234
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:185
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
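A hedged sketch of declaring and registering a pattern with these classes (the pattern name, its logic, and the populate function are hypothetical):
  struct CollapseSingleIterationFor : public OpRewritePattern<scf::ForOp> {
    using OpRewritePattern<scf::ForOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(scf::ForOp op,
                                  PatternRewriter &rewriter) const override {
      // ... inspect 'op' and rewrite it, or report why the pattern did not apply:
      return rewriter.notifyMatchFailure(op, "pattern did not apply");
    }
  };

  void populateMyPatterns(RewritePatternSet &patterns) {
    patterns.add<CollapseSingleIterationFor>(patterns.getContext());
  }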
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
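A sketch of generic op construction through OperationState (yieldedValues is hypothetical; in most code builder.create<OpTy>(...) is preferred over this generic path):
  OperationState state(loc, scf::YieldOp::getOperationName());
  state.addOperands(yieldedValues);
  Operation *yield = builder.create(state);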
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.