SCF.cpp
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/FunctionInterfaces.h"
23 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
40  using DialectInlinerInterface::DialectInlinerInterface;
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
85 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-block region");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 /// block+
126 /// `}`
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// return %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
148 void ExecuteRegionOp::print(OpAsmPrinter &p) {
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
182  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
231  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
266  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
270  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
285 MutableOperandRange
286 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after "
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
295  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
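// For illustration, in a loop such as
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.make_condition"() : () -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.step"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// a true %cond forwards %arg to the "after" region, while a false %cond makes
// %arg the result of the enclosing scf.while, matching the two region
// successors computed above ("test.*" ops are placeholders).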
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange initArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
319  result.addOperands({lb, ub, step});
320  result.addOperands(initArgs);
321  for (Value v : initArgs)
322  result.addTypes(v.getType());
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : initArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (initArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
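// For illustration, a caller might invoke this builder as follows (assuming
// `builder`, `loc`, `lb`, `ub`, `step` and `init` are already in scope):
//
//   auto loop = builder.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [](OpBuilder &b, Location nestedLoc, Value iv, ValueRange iterArgs) {
//         // Loop body; yield the carried value unchanged.
//         b.create<scf::YieldOp>(nestedLoc, iterArgs);
//       });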
342 
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the body defines a single block argument for the induction
354  // variable.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
385  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
393  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if it can be
399 /// determined that the loop has a single iteration.
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401  std::optional<int64_t> tripCount =
402  constantTripCount(getLowerBound(), getUpperBound(), getStep());
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
431 static void printInitializationList(OpAsmPrinter &p,
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
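// For illustration, with prefix " iter_args" this prints a clause such as
// " iter_args(%arg1 = %init1, %arg2 = %init2)".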
446 
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
479  SmallVector<OpAsmParser::Argument, 4> regionArgs;
480  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Resolve input operands.
503  regionArgs.front().type = type;
504  if (parser.resolveOperand(lb, type, result.operands) ||
505  parser.resolveOperand(ub, type, result.operands) ||
506  parser.resolveOperand(step, type, result.operands))
507  return failure();
508  if (hasIterArgs) {
509  for (auto argOperandType :
510  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
511  Type type = std::get<2>(argOperandType);
512  std::get<0>(argOperandType).type = type;
513  if (parser.resolveOperand(std::get<1>(argOperandType), type,
514  result.operands))
515  return failure();
516  }
517  }
518 
519  // Parse the body region.
520  Region *body = result.addRegion();
521  if (parser.parseRegion(*body, regionArgs))
522  return failure();
523 
524  ForOp::ensureTerminator(*body, builder, result.location);
525 
526  // Parse the optional attribute list.
527  if (parser.parseOptionalAttrDict(result.attributes))
528  return failure();
529 
530  return success();
531 }
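// For illustration, the custom syntax handled by this parser/printer pair
// looks like (values assumed to be defined earlier):
//
//   %sum = scf.for %iv = %lb to %ub step %c1
//       iter_args(%acc = %init) -> (f32) {
//     %next = arith.addf %acc, %acc : f32
//     scf.yield %next : f32
//   }
//
// with the optional trailing ':' type for the induction variable defaulting
// to 'index'.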
532 
533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534 
535 Block::BlockArgListType ForOp::getRegionIterArgs() {
536  return getBody()->getArguments().drop_front(getNumInductionVars());
537 }
538 
539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
540  return getInitArgsMutable();
541 }
542 
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
545  ValueRange newInitOperands,
546  bool replaceInitOperandUsesInLoop,
547  const NewYieldValuesFn &newYieldValuesFn) {
548  // Create a new loop before the existing one, with the extra operands.
549  OpBuilder::InsertionGuard g(rewriter);
550  rewriter.setInsertionPoint(getOperation());
551  auto inits = llvm::to_vector(getInitArgs());
552  inits.append(newInitOperands.begin(), newInitOperands.end());
553  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
554  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
555  [](OpBuilder &, Location, Value, ValueRange) {});
556  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
596 
597 ForOp mlir::scf::getForInductionVarOwner(Value val) {
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
611  SmallVectorImpl<RegionSuccessor> &regions) {
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for the loop not to enter
614  // the body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
686 
687 LoopNest mlir::scf::buildLoopNest(
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
690  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
711  SmallVector<scf::ForOp, 4> loops;
712  SmallVector<Value, 4> ivs;
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
760 LoopNest mlir::scf::buildLoopNest(
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
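// For illustration, a 2-D nest without iteration arguments could be built
// with this overload as follows (assuming index-typed values `zero`, `one`,
// `n` and `m` are in scope):
//
//   SmallVector<Value> lbs = {zero, zero}, ubs = {n, m}, steps = {one, one};
//   scf::buildLoopNest(builder, loc, lbs, ubs, steps,
//                      [](OpBuilder &b, Location nestedLoc, ValueRange ivs) {
//                        // Emit the innermost body using ivs[0] and ivs[1].
//                      });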
774 
775 SmallVector<Value>
776 mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
777  OpOperand &operand, Value replacement,
778  const ValueTypeCastFnTy &castFn) {
779  assert(operand.getOwner() == forOp);
780  Type oldType = operand.get().getType(), newType = replacement.getType();
781 
782  // 1. Create new iter operands, exactly 1 is replaced.
783  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
784  "expected an iter OpOperand");
785  assert(operand.get().getType() != replacement.getType() &&
786  "Expected a different type");
787  SmallVector<Value> newIterOperands;
788  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
789  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
790  newIterOperands.push_back(replacement);
791  continue;
792  }
793  newIterOperands.push_back(opOperand.get());
794  }
795 
796  // 2. Create the new forOp shell.
797  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
798  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799  forOp.getStep(), newIterOperands);
800  newForOp->setAttrs(forOp->getAttrs());
801  Block &newBlock = newForOp.getRegion().front();
802  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
803  newBlock.getArguments().end());
804 
805  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806  // corresponding to the `replacement` value.
807  OpBuilder::InsertionGuard g(rewriter);
808  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
809  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
810  &newForOp->getOpOperand(operand.getOperandNumber()));
811  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
813 
814  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815  Block &oldBlock = forOp.getRegion().front();
816  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
817 
818  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
820  rewriter.setInsertionPoint(clonedYieldOp);
821  unsigned yieldIdx =
822  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
823  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824  clonedYieldOp.getOperand(yieldIdx));
825  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
826  newYieldOperands[yieldIdx] = castOut;
827  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828  rewriter.eraseOp(clonedYieldOp);
829 
830  // 6. Inject an outgoing cast op after the forOp.
831  rewriter.setInsertionPointAfter(newForOp);
832  SmallVector<Value> newResults = newForOp.getResults();
833  newResults[yieldIdx] =
834  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
835 
836  return newResults;
837 }
838 
839 namespace {
840 // Fold away ForOp iter arguments when:
841 // 1) The op yields the iter arguments.
842 // 2) The iter arguments have no use and the corresponding outer region
843 // iterators (inputs) are yielded.
844 // 3) The iter arguments have no use and the corresponding (operation) results
845 // have no use.
846 //
847 // These arguments must be defined outside of
848 // the ForOp region and can just be forwarded after simplifying the op inits,
849 // yields and returns.
850 //
851 // The implementation uses `inlineBlockBefore` to steal the content of the
852 // original ForOp and avoid cloning.
853 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
854  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
855 
856  LogicalResult matchAndRewrite(scf::ForOp forOp,
857  PatternRewriter &rewriter) const final {
858  bool canonicalize = false;
859 
860  // An internal flat vector of block transfer
861  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
862  // transformed block argument mappings. This plays the role of an
863  // IRMapping for the particular use case of calling into
864  // `inlineBlockBefore`.
865  int64_t numResults = forOp.getNumResults();
866  SmallVector<bool, 4> keepMask;
867  keepMask.reserve(numResults);
868  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
869  newResultValues;
870  newBlockTransferArgs.reserve(1 + numResults);
871  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
872  newIterArgs.reserve(forOp.getInitArgs().size());
873  newYieldValues.reserve(numResults);
874  newResultValues.reserve(numResults);
875  for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
876  forOp.getRegionIterArgs(), // iter inside region
877  forOp.getResults(), // op results
878  forOp.getYieldedValues() // iter yield
879  )) {
880  // Forwarded is `true` when:
881  // 1) The region `iter` argument is yielded.
882  // 2) The region `iter` argument has no use, and the corresponding iter
883  // operand (input) is yielded.
884  // 3) The region `iter` argument has no use, and the corresponding op
885  // result has no use.
886  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
887  (std::get<1>(it).use_empty() &&
888  (std::get<0>(it) == std::get<3>(it) ||
889  std::get<2>(it).use_empty())));
890  keepMask.push_back(!forwarded);
891  canonicalize |= forwarded;
892  if (forwarded) {
893  newBlockTransferArgs.push_back(std::get<0>(it));
894  newResultValues.push_back(std::get<0>(it));
895  continue;
896  }
897  newIterArgs.push_back(std::get<0>(it));
898  newYieldValues.push_back(std::get<3>(it));
899  newBlockTransferArgs.push_back(Value()); // placeholder with null value
900  newResultValues.push_back(Value()); // placeholder with null value
901  }
902 
903  if (!canonicalize)
904  return failure();
905 
906  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
907  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
908  forOp.getStep(), newIterArgs);
909  newForOp->setAttrs(forOp->getAttrs());
910  Block &newBlock = newForOp.getRegion().front();
911 
912  // Replace the null placeholders with newly constructed values.
913  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
914  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
915  idx != e; ++idx) {
916  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
917  Value &newResultVal = newResultValues[idx];
918  assert((blockTransferArg && newResultVal) ||
919  (!blockTransferArg && !newResultVal));
920  if (!blockTransferArg) {
921  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
922  newResultVal = newForOp.getResult(collapsedIdx++);
923  }
924  }
925 
926  Block &oldBlock = forOp.getRegion().front();
927  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
928  "unexpected argument size mismatch");
929 
930  // No results case: the scf::ForOp builder already created a zero
931  // result terminator. Merge before this terminator and just get rid of the
932  // original terminator that has been merged in.
933  if (newIterArgs.empty()) {
934  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
935  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
936  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
937  rewriter.replaceOp(forOp, newResultValues);
938  return success();
939  }
940 
941  // No terminator case: merge and rewrite the merged terminator.
942  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
943  OpBuilder::InsertionGuard g(rewriter);
944  rewriter.setInsertionPoint(mergedTerminator);
945  SmallVector<Value, 4> filteredOperands;
946  filteredOperands.reserve(newResultValues.size());
947  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
948  if (keepMask[idx])
949  filteredOperands.push_back(mergedTerminator.getOperand(idx));
950  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
951  filteredOperands);
952  };
953 
954  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
955  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
956  cloneFilteredTerminator(mergedYieldOp);
957  rewriter.eraseOp(mergedYieldOp);
958  rewriter.replaceOp(forOp, newResultValues);
959  return success();
960  }
961 };
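// For illustration, case 1) above rewrites
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %init) -> (f32) {
//     ...
//     scf.yield %a : f32
//   }
//
// into an scf.for without iter_args, replacing uses of %a inside the body and
// uses of %r outside the loop with %init.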
962 
963 /// Utility function that tries to compute a constant diff between u and l.
964 /// Returns std::nullopt when the difference cannot be determined
965 /// statically.
966 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
967  IntegerAttr clb, cub;
968  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
969  llvm::APInt lbValue = clb.getValue();
970  llvm::APInt ubValue = cub.getValue();
971  return (ubValue - lbValue).getSExtValue();
972  }
973 
974  // Else a simple pattern match for x + c or c + x
975  llvm::APInt diff;
976  if (matchPattern(
977  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
978  matchPattern(
979  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
980  return diff.getSExtValue();
981  return std::nullopt;
982 }
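// For illustration, computeConstDiff returns 16 both when %l and %u are
// constants 0 and 16, and when %u is defined as `arith.addi %l, %c16`.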
983 
984 /// Rewriting pattern that erases loops that are known not to iterate, replaces
985 /// single-iteration loops with their bodies, and removes empty loops that
986 /// iterate at least once and only return values defined outside of the loop.
987 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
988  using OpRewritePattern<ForOp>::OpRewritePattern;
989 
990  LogicalResult matchAndRewrite(ForOp op,
991  PatternRewriter &rewriter) const override {
992  // If the upper bound is the same as the lower bound, the loop does not
993  // iterate, just remove it.
994  if (op.getLowerBound() == op.getUpperBound()) {
995  rewriter.replaceOp(op, op.getInitArgs());
996  return success();
997  }
998 
999  std::optional<int64_t> diff =
1000  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1001  if (!diff)
1002  return failure();
1003 
1004  // If the loop is known to have 0 iterations, remove it.
1005  if (*diff <= 0) {
1006  rewriter.replaceOp(op, op.getInitArgs());
1007  return success();
1008  }
1009 
1010  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1011  if (!maybeStepValue)
1012  return failure();
1013 
1014  // If the loop is known to have 1 iteration, inline its body and remove the
1015  // loop.
1016  llvm::APInt stepValue = *maybeStepValue;
1017  if (stepValue.sge(*diff)) {
1018  SmallVector<Value, 4> blockArgs;
1019  blockArgs.reserve(op.getInitArgs().size() + 1);
1020  blockArgs.push_back(op.getLowerBound());
1021  llvm::append_range(blockArgs, op.getInitArgs());
1022  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1023  return success();
1024  }
1025 
1026  // Now we are left with loops that have more than 1 iteration.
1027  Block &block = op.getRegion().front();
1028  if (!llvm::hasSingleElement(block))
1029  return failure();
1030  // If the loop is empty, iterates at least once, and only returns values
1031  // defined outside of the loop, remove it and replace it with yield values.
1032  if (llvm::any_of(op.getYieldedValues(),
1033  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1034  return failure();
1035  rewriter.replaceOp(op, op.getYieldedValues());
1036  return success();
1037  }
1038 };
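// For illustration, the loop below has a constant trip count of one, so the
// pattern above inlines its body, substituting %c0 for %i and %init for %a,
// and replaces %r with the inlined yielded value ("test.compute" is a
// placeholder):
//
//   %r = scf.for %i = %c0 to %c1 step %c1 iter_args(%a = %init) -> (f32) {
//     %v = "test.compute"(%i, %a) : (index, f32) -> f32
//     scf.yield %v : f32
//   }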
1039 
1040 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1041 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1042 ///
1043 /// ```
1044 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1045 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1046 /// -> (tensor<?x?xf32>) {
1047 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1048 /// scf.yield %2 : tensor<?x?xf32>
1049 /// }
1050 /// use_of(%1)
1051 /// ```
1052 ///
1053 /// folds into:
1054 ///
1055 /// ```
1056 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1057 /// -> (tensor<32x1024xf32>) {
1058 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1059 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1060 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1061 /// scf.yield %4 : tensor<32x1024xf32>
1062 /// }
1063 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1064 /// use_of(%1)
1065 /// ```
1066 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1067  using OpRewritePattern<ForOp>::OpRewritePattern;
1068 
1069  LogicalResult matchAndRewrite(ForOp op,
1070  PatternRewriter &rewriter) const override {
1071  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1072  OpOperand &iterOpOperand = std::get<0>(it);
1073  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1074  if (!incomingCast ||
1075  incomingCast.getSource().getType() == incomingCast.getType())
1076  continue;
1077  // If the dest type of the cast does not preserve static information in
1078  // the source type.
1079  if (!tensor::preservesStaticInformation(
1080  incomingCast.getDest().getType(),
1081  incomingCast.getSource().getType()))
1082  continue;
1083  if (!std::get<1>(it).hasOneUse())
1084  continue;
1085 
1086  // Create a new ForOp with that iter operand replaced.
1087  ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
1088  Value source) {
1089  return b.create<tensor::CastOp>(loc, type, source);
1090  };
1091  rewriter.replaceOp(
1092  op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
1093  incomingCast.getSource(), castFn));
1094  return success();
1095  }
1096  return failure();
1097  }
1098 };
1099 
1100 } // namespace
1101 
1102 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1103  MLIRContext *context) {
1104  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1105  context);
1106 }
1107 
1108 std::optional<APInt> ForOp::getConstantStep() {
1109  IntegerAttr step;
1110  if (matchPattern(getStep(), m_Constant(&step)))
1111  return step.getValue();
1112  return {};
1113 }
1114 
1115 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1116  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1117 }
1118 
1119 Speculation::Speculatability ForOp::getSpeculatability() {
1120  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1121  // and End.
1122  if (auto constantStep = getConstantStep())
1123  if (*constantStep == 1)
1124  return Speculation::RecursivelySpeculatable;
1125 
1126  // For Step != 1, the loop may not terminate. We can add more smarts here if
1127  // needed.
1128  return Speculation::NotSpeculatable;
1129 }
1130 
1131 //===----------------------------------------------------------------------===//
1132 // ForallOp
1133 //===----------------------------------------------------------------------===//
1134 
1135 LogicalResult ForallOp::verify() {
1136  unsigned numLoops = getRank();
1137  // Check number of outputs.
1138  if (getNumResults() != getOutputs().size())
1139  return emitOpError("produces ")
1140  << getNumResults() << " results, but has only "
1141  << getOutputs().size() << " outputs";
1142 
1143  // Check that the body defines block arguments for thread indices and outputs.
1144  auto *body = getBody();
1145  if (body->getNumArguments() != numLoops + getOutputs().size())
1146  return emitOpError("region expects ") << numLoops << " arguments";
1147  for (int64_t i = 0; i < numLoops; ++i)
1148  if (!body->getArgument(i).getType().isIndex())
1149  return emitOpError("expects ")
1150  << i << "-th block argument to be an index";
1151  for (unsigned i = 0; i < getOutputs().size(); ++i)
1152  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1153  return emitOpError("type mismatch between ")
1154  << i << "-th output and corresponding block argument";
1155  if (getMapping().has_value() && !getMapping()->empty()) {
1156  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1157  return emitOpError() << "mapping attribute size must match op rank";
1158  for (auto map : getMapping()->getValue()) {
1159  if (!isa<DeviceMappingAttrInterface>(map))
1160  return emitOpError()
1161  << getMappingAttrName() << " is not device mapping attribute";
1162  }
1163  }
1164 
1165  // Verify mixed static/dynamic control variables.
1166  Operation *op = getOperation();
1167  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1168  getStaticLowerBound(),
1169  getDynamicLowerBound())))
1170  return failure();
1171  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1172  getStaticUpperBound(),
1173  getDynamicUpperBound())))
1174  return failure();
1175  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1176  getStaticStep(), getDynamicStep())))
1177  return failure();
1178 
1179  return success();
1180 }
1181 
1182 void ForallOp::print(OpAsmPrinter &p) {
1183  Operation *op = getOperation();
1184  p << " (" << getInductionVars();
1185  if (isNormalized()) {
1186  p << ") in ";
1187  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1188  /*valueTypes=*/{}, /*scalables=*/{},
1189  OpAsmParser::Delimiter::Paren);
1190  } else {
1191  p << ") = ";
1192  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1193  /*valueTypes=*/{}, /*scalables=*/{},
1194  OpAsmParser::Delimiter::Paren);
1195  p << " to ";
1196  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1197  /*valueTypes=*/{}, /*scalables=*/{},
1198  OpAsmParser::Delimiter::Paren);
1199  p << " step ";
1200  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1201  /*valueTypes=*/{}, /*scalables=*/{},
1202  OpAsmParser::Delimiter::Paren);
1203  }
1204  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1205  p << " ";
1206  if (!getRegionOutArgs().empty())
1207  p << "-> (" << getResultTypes() << ") ";
1208  p.printRegion(getRegion(),
1209  /*printEntryBlockArgs=*/false,
1210  /*printBlockTerminators=*/getNumResults() > 0);
1211  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1212  getStaticLowerBoundAttrName(),
1213  getStaticUpperBoundAttrName(),
1214  getStaticStepAttrName()});
1215 }
1216 
1217 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1218  OpBuilder b(parser.getContext());
1219  auto indexType = b.getIndexType();
1220 
1221  // Parse an opening `(` followed by thread index variables followed by `)`
1222  // TODO: when we can refer to such "induction variable"-like handles from the
1223  // declarative assembly format, we can implement the parser as a custom hook.
1224  SmallVector<OpAsmParser::Argument, 4> ivs;
1225  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1226  return failure();
1227 
1228  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1229  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1230  dynamicSteps;
1231  if (succeeded(parser.parseOptionalKeyword("in"))) {
1232  // Parse upper bounds.
1233  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1234  /*valueTypes=*/nullptr,
1235  OpAsmParser::Delimiter::Paren) ||
1236  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1237  return failure();
1238 
1239  unsigned numLoops = ivs.size();
1240  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1241  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1242  } else {
1243  // Parse lower bounds.
1244  if (parser.parseEqual() ||
1245  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1246  /*valueTypes=*/nullptr,
1247  OpAsmParser::Delimiter::Paren) ||
1248 
1249  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1250  return failure();
1251 
1252  // Parse upper bounds.
1253  if (parser.parseKeyword("to") ||
1254  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1255  /*valueTypes=*/nullptr,
1256  OpAsmParser::Delimiter::Paren) ||
1257  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1258  return failure();
1259 
1260  // Parse step values.
1261  if (parser.parseKeyword("step") ||
1262  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1263  /*valueTypes=*/nullptr,
1264  OpAsmParser::Delimiter::Paren) ||
1265  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1266  return failure();
1267  }
1268 
1269  // Parse out operands and results.
1270  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1271  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1272  SMLoc outOperandsLoc = parser.getCurrentLocation();
1273  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1274  if (outOperands.size() != result.types.size())
1275  return parser.emitError(outOperandsLoc,
1276  "mismatch between out operands and types");
1277  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1278  parser.parseOptionalArrowTypeList(result.types) ||
1279  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1280  result.operands))
1281  return failure();
1282  }
1283 
1284  // Parse region.
1285  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1286  std::unique_ptr<Region> region = std::make_unique<Region>();
1287  for (auto &iv : ivs) {
1288  iv.type = b.getIndexType();
1289  regionArgs.push_back(iv);
1290  }
1291  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1292  auto &out = it.value();
1293  out.type = result.types[it.index()];
1294  regionArgs.push_back(out);
1295  }
1296  if (parser.parseRegion(*region, regionArgs))
1297  return failure();
1298 
1299  // Ensure terminator and move region.
1300  ForallOp::ensureTerminator(*region, b, result.location);
1301  result.addRegion(std::move(region));
1302 
1303  // Parse the optional attribute list.
1304  if (parser.parseOptionalAttrDict(result.attributes))
1305  return failure();
1306 
1307  result.addAttribute("staticLowerBound", staticLbs);
1308  result.addAttribute("staticUpperBound", staticUbs);
1309  result.addAttribute("staticStep", staticSteps);
1310  result.addAttribute("operandSegmentSizes",
1311  b.getDenseI32ArrayAttr(
1312  {static_cast<int32_t>(dynamicLbs.size()),
1313  static_cast<int32_t>(dynamicUbs.size()),
1314  static_cast<int32_t>(dynamicSteps.size()),
1315  static_cast<int32_t>(outOperands.size())}));
1316  return success();
1317 }
1318 
1319 // Builder that takes loop bounds.
1320 void ForallOp::build(
1321  mlir::OpBuilder &b, mlir::OperationState &result,
1322  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1323  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1324  std::optional<ArrayAttr> mapping,
1325  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1326  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1327  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1328  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1329  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1330  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1331 
1332  result.addOperands(dynamicLbs);
1333  result.addOperands(dynamicUbs);
1334  result.addOperands(dynamicSteps);
1335  result.addOperands(outputs);
1336  result.addTypes(TypeRange(outputs));
1337 
1338  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1339  b.getDenseI64ArrayAttr(staticLbs));
1340  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1341  b.getDenseI64ArrayAttr(staticUbs));
1342  result.addAttribute(getStaticStepAttrName(result.name),
1343  b.getDenseI64ArrayAttr(staticSteps));
1344  result.addAttribute(
1345  "operandSegmentSizes",
1346  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1347  static_cast<int32_t>(dynamicUbs.size()),
1348  static_cast<int32_t>(dynamicSteps.size()),
1349  static_cast<int32_t>(outputs.size())}));
1350  if (mapping.has_value()) {
1351  result.addAttribute(getMappingAttrName(result.name),
1352  mapping.value());
1353  }
1354 
1355  Region *bodyRegion = result.addRegion();
1356  OpBuilder::InsertionGuard g(b);
1357  b.createBlock(bodyRegion);
1358  Block &bodyBlock = bodyRegion->front();
1359 
1360  // Add block arguments for indices and outputs.
1361  bodyBlock.addArguments(
1362  SmallVector<Type>(lbs.size(), b.getIndexType()),
1363  SmallVector<Location>(staticLbs.size(), result.location));
1364  bodyBlock.addArguments(
1365  TypeRange(outputs),
1366  SmallVector<Location>(outputs.size(), result.location));
1367 
1368  b.setInsertionPointToStart(&bodyBlock);
1369  if (!bodyBuilderFn) {
1370  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1371  return;
1372  }
1373  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1374 }
1375 
1376 // Builder that takes loop bounds.
1377 void ForallOp::build(
1378  mlir::OpBuilder &b, mlir::OperationState &result,
1379  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1380  std::optional<ArrayAttr> mapping,
1381  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1382  unsigned numLoops = ubs.size();
1383  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1384  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1385  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1386 }
1387 
1388 // Checks if the lbs are zeros and steps are ones.
1389 bool ForallOp::isNormalized() {
1390  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1391  return llvm::all_of(results, [&](OpFoldResult ofr) {
1392  auto intValue = getConstantIntValue(ofr);
1393  return intValue.has_value() && intValue == val;
1394  });
1395  };
1396  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1397 }
1398 
1399 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1400 // unaware of the fact that our terminator also needs a region to be
1401 // well-formed. We override it here to ensure that we do the right thing.
1402 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1403  Location loc) {
1404  OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
1405  ForallOp>::ensureTerminator(region, builder, loc);
1406  auto terminator =
1407  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1408  if (terminator.getRegion().empty())
1409  builder.createBlock(&terminator.getRegion());
1410 }
1411 
1412 InParallelOp ForallOp::getTerminator() {
1413  return cast<InParallelOp>(getBody()->getTerminator());
1414 }
1415 
1416 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1417  SmallVector<Operation *> storeOps;
1418  InParallelOp inParallelOp = getTerminator();
1419  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1420  if (auto parallelInsertSliceOp =
1421  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1422  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1423  storeOps.push_back(parallelInsertSliceOp);
1424  }
1425  }
1426  return storeOps;
1427 }
1428 
1429 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1430  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1431 }
1432 
1433 // Get lower bounds as OpFoldResult.
1434 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1435  Builder b(getOperation()->getContext());
1436  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1437 }
1438 
1439 // Get upper bounds as OpFoldResult.
1440 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1441  Builder b(getOperation()->getContext());
1442  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1443 }
1444 
1445 // Get steps as OpFoldResult.
1446 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1447  Builder b(getOperation()->getContext());
1448  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1449 }
1450 
1451 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1452  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1453  if (!tidxArg)
1454  return ForallOp();
1455  assert(tidxArg.getOwner() && "unlinked block argument");
1456  auto *containingOp = tidxArg.getOwner()->getParentOp();
1457  return dyn_cast<ForallOp>(containingOp);
1458 }
1459 
1460 namespace {
1461 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1462 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1463  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1464 
1465  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1466  PatternRewriter &rewriter) const final {
1467  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1468  if (!forallOp)
1469  return failure();
1470  Value sharedOut =
1471  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1472  ->get();
1473  rewriter.modifyOpInPlace(
1474  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1475  return success();
1476  }
1477 };
1478 
1479 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1480 public:
1481  using OpRewritePattern<ForallOp>::OpRewritePattern;
1482 
1483  LogicalResult matchAndRewrite(ForallOp op,
1484  PatternRewriter &rewriter) const override {
1485  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1486  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1487  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1488  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1489  failed(foldDynamicIndexList(mixedUpperBound)) &&
1490  failed(foldDynamicIndexList(mixedStep)))
1491  return failure();
1492 
1493  rewriter.modifyOpInPlace(op, [&]() {
1494  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1495  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1496  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1497  staticLowerBound);
1498  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1499  op.setStaticLowerBound(staticLowerBound);
1500 
1501  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1502  staticUpperBound);
1503  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1504  op.setStaticUpperBound(staticUpperBound);
1505 
1506  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1507  op.getDynamicStepMutable().assign(dynamicStep);
1508  op.setStaticStep(staticStep);
1509 
1510  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1511  rewriter.getDenseI32ArrayAttr(
1512  {static_cast<int32_t>(dynamicLowerBound.size()),
1513  static_cast<int32_t>(dynamicUpperBound.size()),
1514  static_cast<int32_t>(dynamicStep.size()),
1515  static_cast<int32_t>(op.getNumResults())}));
1516  });
1517  return success();
1518  }
1519 };
1520 
1521 /// The following canonicalization pattern folds the iter arguments of
1522 /// scf.forall op if :-
1523 /// 1. The corresponding result has zero uses.
1524 /// 2. The iter argument is NOT being modified within the loop body.
1526 ///
1527 /// Example of first case :-
1528 /// INPUT:
1529 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1530 /// {
1531 /// ...
1532 /// <SOME USE OF %arg0>
1533 /// <SOME USE OF %arg1>
1534 /// <SOME USE OF %arg2>
1535 /// ...
1536 /// scf.forall.in_parallel {
1537 /// <STORE OP WITH DESTINATION %arg1>
1538 /// <STORE OP WITH DESTINATION %arg0>
1539 /// <STORE OP WITH DESTINATION %arg2>
1540 /// }
1541 /// }
1542 /// return %res#1
1543 ///
1544 /// OUTPUT:
1545 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1546 /// {
1547 /// ...
1548 /// <SOME USE OF %a>
1549 /// <SOME USE OF %new_arg0>
1550 /// <SOME USE OF %c>
1551 /// ...
1552 /// scf.forall.in_parallel {
1553 /// <STORE OP WITH DESTINATION %new_arg0>
1554 /// }
1555 /// }
1556 /// return %res
1557 ///
1558 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1559 /// scf.forall are replaced by their corresponding operands.
1560 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1561 /// of the scf.forall besides those within the scf.forall.in_parallel terminator,
1562 /// this canonicalization remains valid. For more details, please refer
1563 /// to :
1564 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1565 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1566 /// handles tensor.parallel_insert_slice ops only.
1567 ///
1568 /// Example of second case :-
1569 /// INPUT:
1570 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1571 /// {
1572 /// ...
1573 /// <SOME USE OF %arg0>
1574 /// <SOME USE OF %arg1>
1575 /// ...
1576 /// scf.forall.in_parallel {
1577 /// <STORE OP WITH DESTINATION %arg1>
1578 /// }
1579 /// }
1580 /// return %res#0, %res#1
1581 ///
1582 /// OUTPUT:
1583 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1584 /// {
1585 /// ...
1586 /// <SOME USE OF %a>
1587 /// <SOME USE OF %new_arg0>
1588 /// ...
1589 /// scf.forall.in_parallel {
1590 /// <STORE OP WITH DESTINATION %new_arg0>
1591 /// }
1592 /// }
1593 /// return %a, %res
1594 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1596 
1597  LogicalResult matchAndRewrite(ForallOp forallOp,
1598  PatternRewriter &rewriter) const final {
1599  // Step 1: For a given i-th result of scf.forall, check the following :-
1600  // a. If it has any use.
1601  // b. If the corresponding iter argument is being modified within
1602  // the loop, i.e. has at least one store op with the iter arg as
1603  // its destination operand. For this we use
1604  // ForallOp::getCombiningOps(iter_arg).
1605  //
1606  // Based on the check we maintain the following :-
1607  // a. `resultToDelete` - i-th result of scf.forall that'll be
1608  // deleted.
1609  // b. `resultToReplace` - i-th result of the old scf.forall
1610  // whose uses will be replaced by the new scf.forall.
1611  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1612  // corresponding to the i-th result with at least one use.
1613  SetVector<OpResult> resultToDelete;
1614  SmallVector<Value> resultToReplace;
1615  SmallVector<Value> newOuts;
1616  for (OpResult result : forallOp.getResults()) {
1617  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1618  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1619  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1620  resultToDelete.insert(result);
1621  } else {
1622  resultToReplace.push_back(result);
1623  newOuts.push_back(opOperand->get());
1624  }
1625  }
1626 
1627  // Return early if all results of scf.forall have at least one use and are
1628  // being modified within the loop.
1629  if (resultToDelete.empty())
1630  return failure();
1631 
1632  // Step 2: For the i-th result, do the following :-
1633  // a. Fetch the corresponding BlockArgument.
1634  // b. Look for store ops (currently tensor.parallel_insert_slice)
1635  // with the BlockArgument as its destination operand.
1636  // c. Remove the operations fetched in b.
1637  for (OpResult result : resultToDelete) {
1638  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1639  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1640  SmallVector<Operation *> combiningOps =
1641  forallOp.getCombiningOps(blockArg);
1642  for (Operation *combiningOp : combiningOps)
1643  rewriter.eraseOp(combiningOp);
1644  }
1645 
1646  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1647  // fetched earlier
1648  auto newForallOp = rewriter.create<scf::ForallOp>(
1649  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1650  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1651  forallOp.getMapping(),
1652  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1653 
1654  // Step 4. Merge the block of the old scf.forall into the newly created
1655  // scf.forall using the new set of arguments.
1656  Block *loopBody = forallOp.getBody();
1657  Block *newLoopBody = newForallOp.getBody();
1658  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1659  // Form initial new bbArg list with just the control operands of the new
1660  // scf.forall op.
1661  SmallVector<Value> newBlockArgs =
1662  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1663  [](BlockArgument b) -> Value { return b; });
1664  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1665  unsigned index = 0;
1666  // Take the new corresponding bbArg if the old bbArg was used as a
1667  // destination in the in_parallel op. For all other bbArgs, use the
1668  // corresponding init_arg from the old scf.forall op.
1669  for (OpResult result : forallOp.getResults()) {
1670  if (resultToDelete.count(result)) {
1671  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1672  } else {
1673  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1674  }
1675  }
1676  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1677 
1678  // Step 5. Replace the uses of the results of the old scf.forall with those
1679  // of the new scf.forall.
1680  for (auto &&[oldResult, newResult] :
1681  llvm::zip(resultToReplace, newForallOp->getResults()))
1682  rewriter.replaceAllUsesWith(oldResult, newResult);
1683 
1684  // Step 6. Replace the uses of those values that either have no use or are
1685  // not being modified within the loop with the corresponding
1686  // OpOperand.
1687  for (OpResult oldResult : resultToDelete)
1688  rewriter.replaceAllUsesWith(oldResult,
1689  forallOp.getTiedOpOperand(oldResult)->get());
1690  return success();
1691  }
1692 };
1693 
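/// Folds away scf.forall dimensions whose trip count is statically 0 or 1.
/// A minimal sketch (hand-written; names are assumptions): in
///
///   scf.forall (%i, %j) in (%n, 1) shared_outs(%o = %t) -> (tensor<?xf32>)
///
/// the second dimension performs a single iteration, so the loop is replaced
/// by a 1-D scf.forall over %i with uses of %j remapped to the constant lower
/// bound 0. If every dimension has trip count 1 the body is inlined, and a
/// zero-trip-count dimension erases the loop in favor of its outputs.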
1694 struct ForallOpSingleOrZeroIterationDimsFolder
1695  : public OpRewritePattern<ForallOp> {
1697 
1698  LogicalResult matchAndRewrite(ForallOp op,
1699  PatternRewriter &rewriter) const override {
1700  // Do not fold dimensions if they are mapped to processing units.
1701  if (op.getMapping().has_value() && !op.getMapping()->empty())
1702  return failure();
1703  Location loc = op.getLoc();
1704 
1705  // Compute new loop bounds that omit all single-iteration loop dimensions.
1706  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1707  newMixedSteps;
1708  IRMapping mapping;
1709  for (auto [lb, ub, step, iv] :
1710  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1711  op.getMixedStep(), op.getInductionVars())) {
1712  auto numIterations = constantTripCount(lb, ub, step);
1713  if (numIterations.has_value()) {
1714  // Remove the loop if it performs zero iterations.
1715  if (*numIterations == 0) {
1716  rewriter.replaceOp(op, op.getOutputs());
1717  return success();
1718  }
1719  // Replace the loop induction variable by the lower bound if the loop
1720  // performs a single iteration. Otherwise, copy the loop bounds.
1721  if (*numIterations == 1) {
1722  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1723  continue;
1724  }
1725  }
1726  newMixedLowerBounds.push_back(lb);
1727  newMixedUpperBounds.push_back(ub);
1728  newMixedSteps.push_back(step);
1729  }
1730 
1731  // All of the loop dimensions perform a single iteration. Inline loop body.
1732  if (newMixedLowerBounds.empty()) {
1733  promote(rewriter, op);
1734  return success();
1735  }
1736 
1737  // Exit if none of the loop dimensions perform a single iteration.
1738  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1739  return rewriter.notifyMatchFailure(
1740  op, "no dimensions have 0 or 1 iterations");
1741  }
1742 
1743  // Replace the loop by a lower-dimensional loop.
1744  ForallOp newOp;
1745  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1746  newMixedUpperBounds, newMixedSteps,
1747  op.getOutputs(), std::nullopt, nullptr);
1748  newOp.getBodyRegion().getBlocks().clear();
1749  // The new loop needs to keep all attributes from the old one, except for
1750  // "operandSegmentSizes" and static loop bound attributes which capture
1751  // the outdated information of the old iteration domain.
1752  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1753  newOp.getStaticLowerBoundAttrName(),
1754  newOp.getStaticUpperBoundAttrName(),
1755  newOp.getStaticStepAttrName()};
1756  for (const auto &namedAttr : op->getAttrs()) {
1757  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1758  continue;
1759  rewriter.modifyOpInPlace(newOp, [&]() {
1760  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1761  });
1762  }
1763  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1764  newOp.getRegion().begin(), mapping);
1765  rewriter.replaceOp(op, newOp.getResults());
1766  return success();
1767  }
1768 };
1769 
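/// Folds a tensor.cast feeding a shared_outs operand into the scf.forall when
/// the cast's source type is "more static" than its destination type. A
/// minimal sketch (hand-written; names and types are assumptions):
///
///   %cast = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %0 = scf.forall ... shared_outs(%o = %cast) -> (tensor<?xf32>) { ... }
///
/// becomes a forall operating directly on tensor<4xf32>, with a tensor.cast
/// re-inserted inside the body (for the block argument) and after the loop
/// (for the result) so existing ?-typed uses keep their types.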
1770 struct FoldTensorCastOfOutputIntoForallOp
1771  : public OpRewritePattern<scf::ForallOp> {
1773 
1774  struct TypeCast {
1775  Type srcType;
1776  Type dstType;
1777  };
1778 
1779  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1780  PatternRewriter &rewriter) const final {
1781  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1782  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1783  for (auto en : llvm::enumerate(newOutputTensors)) {
1784  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1785  if (!castOp)
1786  continue;
1787 
1788  // Only casts that preserve static information, i.e. will make the
1789  // loop result type "more" static than before, will be folded.
1790  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1791  castOp.getSource().getType())) {
1792  continue;
1793  }
1794 
1795  tensorCastProducers[en.index()] =
1796  TypeCast{castOp.getSource().getType(), castOp.getType()};
1797  newOutputTensors[en.index()] = castOp.getSource();
1798  }
1799 
1800  if (tensorCastProducers.empty())
1801  return failure();
1802 
1803  // Create new loop.
1804  Location loc = forallOp.getLoc();
1805  auto newForallOp = rewriter.create<ForallOp>(
1806  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1807  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1808  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1809  auto castBlockArgs =
1810  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1811  for (auto [index, cast] : tensorCastProducers) {
1812  Value &oldTypeBBArg = castBlockArgs[index];
1813  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1814  nestedLoc, cast.dstType, oldTypeBBArg);
1815  }
1816 
1817  // Move the old body into the new forall loop.
1818  SmallVector<Value> ivsBlockArgs =
1819  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1820  ivsBlockArgs.append(castBlockArgs);
1821  rewriter.mergeBlocks(forallOp.getBody(),
1822  bbArgs.front().getParentBlock(), ivsBlockArgs);
1823  });
1824 
1825  // After `mergeBlocks` happened, the destinations in the terminator were
1826  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1827  // destinations have to be updated to point to the output bbArgs directly.
1828  auto terminator = newForallOp.getTerminator();
1829  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1830  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1831  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1832  insertSliceOp.getDestMutable().assign(outputBlockArg);
1833  }
1834 
1835  // Cast results back to the original types.
1836  rewriter.setInsertionPointAfter(newForallOp);
1837  SmallVector<Value> castResults = newForallOp.getResults();
1838  for (auto &item : tensorCastProducers) {
1839  Value &oldTypeResult = castResults[item.first];
1840  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1841  oldTypeResult);
1842  }
1843  rewriter.replaceOp(forallOp, castResults);
1844  return success();
1845  }
1846 };
1847 
1848 } // namespace
1849 
1850 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1851  MLIRContext *context) {
1852  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1853  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1854  ForallOpSingleOrZeroIterationDimsFolder>(context);
1855 }
1856 
1857 /// Given a region branch point (the parent operation itself or one of its
1858 /// regions), return the successor regions. These are the regions that may be
1859 /// selected next during the flow of control.
1862 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1863  SmallVectorImpl<RegionSuccessor> &regions) {
1864  // Both the operation itself and the region may be branching into the body or
1865  // back into the operation itself. It is possible for the loop not to enter the
1866  // body.
1867  regions.push_back(RegionSuccessor(&getRegion()));
1868  regions.push_back(RegionSuccessor());
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // InParallelOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 // Build an InParallelOp with a single empty block in its body region.
1876 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1878  Region *bodyRegion = result.addRegion();
1879  b.createBlock(bodyRegion);
1880 }
1881 
1882 LogicalResult InParallelOp::verify() {
1883  scf::ForallOp forallOp =
1884  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1885  if (!forallOp)
1886  return this->emitOpError("expected forall op parent");
1887 
1888  // TODO: InParallelOpInterface.
1889  for (Operation &op : getRegion().front().getOperations()) {
1890  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1891  return this->emitOpError("expected only ")
1892  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1893  }
1894 
1895  // Verify that inserts are into out block arguments.
1896  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1897  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1898  if (!llvm::is_contained(regionOutArgs, dest))
1899  return op.emitOpError("may only insert into an output block argument");
1900  }
1901  return success();
1902 }
1903 
1904 void InParallelOp::print(OpAsmPrinter &p) {
1905  p << " ";
1906  p.printRegion(getRegion(),
1907  /*printEntryBlockArgs=*/false,
1908  /*printBlockTerminators=*/false);
1909  p.printOptionalAttrDict(getOperation()->getAttrs());
1910 }
1911 
1912 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1913  auto &builder = parser.getBuilder();
1914 
1915  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1916  std::unique_ptr<Region> region = std::make_unique<Region>();
1917  if (parser.parseRegion(*region, regionOperands))
1918  return failure();
1919 
1920  if (region->empty())
1921  OpBuilder(builder.getContext()).createBlock(region.get());
1922  result.addRegion(std::move(region));
1923 
1924  // Parse the optional attribute list.
1925  if (parser.parseOptionalAttrDict(result.attributes))
1926  return failure();
1927  return success();
1928 }
1929 
1930 OpResult InParallelOp::getParentResult(int64_t idx) {
1931  return getOperation()->getParentOp()->getResult(idx);
1932 }
1933 
1934 SmallVector<BlockArgument> InParallelOp::getDests() {
1935  return llvm::to_vector<4>(
1936  llvm::map_range(getYieldingOps(), [](Operation &op) {
1937  // Add new ops here as needed.
1938  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1939  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1940  }));
1941 }
1942 
1943 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1944  return getRegion().front().getOperations();
1945 }
1946 
1947 //===----------------------------------------------------------------------===//
1948 // IfOp
1949 //===----------------------------------------------------------------------===//
1950 
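/// Returns true if operations `a` and `b` are located in mutually exclusive
/// branches (then vs. else) of a common enclosing scf.if, determined by
/// walking a's enclosing scf.if ops and checking which block contains each
/// operation.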
1951 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1952  assert(a && "expected non-empty operation");
1953  assert(b && "expected non-empty operation");
1954 
1955  IfOp ifOp = a->getParentOfType<IfOp>();
1956  while (ifOp) {
1957  // Check if b is inside ifOp. (We already know that a is.)
1958  if (ifOp->isProperAncestor(b))
1959  // b is contained in ifOp. a and b are in mutually exclusive branches if
1960  // they are in different blocks of ifOp.
1961  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1962  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1963  // Check next enclosing IfOp.
1964  ifOp = ifOp->getParentOfType<IfOp>();
1965  }
1966 
1967  // Could not find a common IfOp among a's and b's ancestors.
1968  return false;
1969 }
1970 
1971 LogicalResult
1972 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1973  IfOp::Adaptor adaptor,
1974  SmallVectorImpl<Type> &inferredReturnTypes) {
1975  if (adaptor.getRegions().empty())
1976  return failure();
1977  Region *r = &adaptor.getThenRegion();
1978  if (r->empty())
1979  return failure();
1980  Block &b = r->front();
1981  if (b.empty())
1982  return failure();
1983  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1984  if (!yieldOp)
1985  return failure();
1986  TypeRange types = yieldOp.getOperandTypes();
1987  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1988  types.end());
1989  return success();
1990 }
1991 
1992 void IfOp::build(OpBuilder &builder, OperationState &result,
1993  TypeRange resultTypes, Value cond) {
1994  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1995  /*addElseBlock=*/false);
1996 }
1997 
1998 void IfOp::build(OpBuilder &builder, OperationState &result,
1999  TypeRange resultTypes, Value cond, bool addThenBlock,
2000  bool addElseBlock) {
2001  assert((!addElseBlock || addThenBlock) &&
2002  "must not create else block w/o then block");
2003  result.addTypes(resultTypes);
2004  result.addOperands(cond);
2005 
2006  // Add regions and blocks.
2007  OpBuilder::InsertionGuard guard(builder);
2008  Region *thenRegion = result.addRegion();
2009  if (addThenBlock)
2010  builder.createBlock(thenRegion);
2011  Region *elseRegion = result.addRegion();
2012  if (addElseBlock)
2013  builder.createBlock(elseRegion);
2014 }
2015 
2016 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2017  bool withElseRegion) {
2018  build(builder, result, TypeRange{}, cond, withElseRegion);
2019 }
2020 
2021 void IfOp::build(OpBuilder &builder, OperationState &result,
2022  TypeRange resultTypes, Value cond, bool withElseRegion) {
2023  result.addTypes(resultTypes);
2024  result.addOperands(cond);
2025 
2026  // Build then region.
2027  OpBuilder::InsertionGuard guard(builder);
2028  Region *thenRegion = result.addRegion();
2029  builder.createBlock(thenRegion);
2030  if (resultTypes.empty())
2031  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2032 
2033  // Build else region.
2034  Region *elseRegion = result.addRegion();
2035  if (withElseRegion) {
2036  builder.createBlock(elseRegion);
2037  if (resultTypes.empty())
2038  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2039  }
2040 }
2041 
2042 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2043  function_ref<void(OpBuilder &, Location)> thenBuilder,
2044  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2045  assert(thenBuilder && "the builder callback for 'then' must be present");
2046  result.addOperands(cond);
2047 
2048  // Build then region.
2049  OpBuilder::InsertionGuard guard(builder);
2050  Region *thenRegion = result.addRegion();
2051  builder.createBlock(thenRegion);
2052  thenBuilder(builder, result.location);
2053 
2054  // Build else region.
2055  Region *elseRegion = result.addRegion();
2056  if (elseBuilder) {
2057  builder.createBlock(elseRegion);
2058  elseBuilder(builder, result.location);
2059  }
2060 
2061  // Infer result types.
2062  SmallVector<Type> inferredReturnTypes;
2063  MLIRContext *ctx = builder.getContext();
2064  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2065  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2066  /*properties=*/nullptr, result.regions,
2067  inferredReturnTypes))) {
2068  result.addTypes(inferredReturnTypes);
2069  }
2070 }
2071 
2072 LogicalResult IfOp::verify() {
2073  if (getNumResults() != 0 && getElseRegion().empty())
2074  return emitOpError("must have an else block if defining values");
2075  return success();
2076 }
2077 
2078 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2079  // Create the regions for 'then'.
2080  result.regions.reserve(2);
2081  Region *thenRegion = result.addRegion();
2082  Region *elseRegion = result.addRegion();
2083 
2084  auto &builder = parser.getBuilder();
2085  OpAsmParser::UnresolvedOperand cond;
2086  Type i1Type = builder.getIntegerType(1);
2087  if (parser.parseOperand(cond) ||
2088  parser.resolveOperand(cond, i1Type, result.operands))
2089  return failure();
2090  // Parse optional results type list.
2091  if (parser.parseOptionalArrowTypeList(result.types))
2092  return failure();
2093  // Parse the 'then' region.
2094  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2095  return failure();
2096  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2097 
2098  // If we find an 'else' keyword then parse the 'else' region.
2099  if (!parser.parseOptionalKeyword("else")) {
2100  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2101  return failure();
2102  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2103  }
2104 
2105  // Parse the optional attribute list.
2106  if (parser.parseOptionalAttrDict(result.attributes))
2107  return failure();
2108  return success();
2109 }
2110 
2111 void IfOp::print(OpAsmPrinter &p) {
2112  bool printBlockTerminators = false;
2113 
2114  p << " " << getCondition();
2115  if (!getResults().empty()) {
2116  p << " -> (" << getResultTypes() << ")";
2117  // Print yield explicitly if the op defines values.
2118  printBlockTerminators = true;
2119  }
2120  p << ' ';
2121  p.printRegion(getThenRegion(),
2122  /*printEntryBlockArgs=*/false,
2123  /*printBlockTerminators=*/printBlockTerminators);
2124 
2125  // Print the 'else' region if it exists and has a block.
2126  auto &elseRegion = getElseRegion();
2127  if (!elseRegion.empty()) {
2128  p << " else ";
2129  p.printRegion(elseRegion,
2130  /*printEntryBlockArgs=*/false,
2131  /*printBlockTerminators=*/printBlockTerminators);
2132  }
2133 
2134  p.printOptionalAttrDict((*this)->getAttrs());
2135 }
2136 
2137 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2138  SmallVectorImpl<RegionSuccessor> &regions) {
2139  // The `then` and the `else` region branch back to the parent operation.
2140  if (!point.isParent()) {
2141  regions.push_back(RegionSuccessor(getResults()));
2142  return;
2143  }
2144 
2145  regions.push_back(RegionSuccessor(&getThenRegion()));
2146 
2147  // Don't consider the else region if it is empty.
2148  Region *elseRegion = &this->getElseRegion();
2149  if (elseRegion->empty())
2150  regions.push_back(RegionSuccessor());
2151  else
2152  regions.push_back(RegionSuccessor(elseRegion));
2153 }
2154 
2155 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2156  SmallVectorImpl<RegionSuccessor> &regions) {
2157  FoldAdaptor adaptor(operands, *this);
2158  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2159  if (!boolAttr || boolAttr.getValue())
2160  regions.emplace_back(&getThenRegion());
2161 
2162  // If the else region is empty, execution continues after the parent op.
2163  if (!boolAttr || !boolAttr.getValue()) {
2164  if (!getElseRegion().empty())
2165  regions.emplace_back(&getElseRegion());
2166  else
2167  regions.emplace_back(getResults());
2168  }
2169 }
2170 
2171 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2172  SmallVectorImpl<OpFoldResult> &results) {
2173  // if (!c) then A() else B() -> if c then B() else A()
2174  if (getElseRegion().empty())
2175  return failure();
2176 
2177  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2178  if (!xorStmt)
2179  return failure();
2180 
2181  if (!matchPattern(xorStmt.getRhs(), m_One()))
2182  return failure();
2183 
2184  getConditionMutable().assign(xorStmt.getLhs());
2185  Block *thenBlock = &getThenRegion().front();
2186  // It would be nicer to use iplist::swap, but that has no implemented
2187  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2188  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2189  getElseRegion().getBlocks());
2190  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2191  getThenRegion().getBlocks(), thenBlock);
2192  return success();
2193 }
2194 
2195 void IfOp::getRegionInvocationBounds(
2196  ArrayRef<Attribute> operands,
2197  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2198  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2199  // If the condition is known, then one region is known to be executed once
2200  // and the other zero times.
2201  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2202  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2203  } else {
2204  // Non-constant condition. Each region may be executed 0 or 1 times.
2205  invocationBounds.assign(2, {0, 1});
2206  }
2207 }
2208 
2209 namespace {
2210 // Pattern to remove unused IfOp results.
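// Illustrative sketch (hand-written; names are assumptions):
//
//   %r:2 = scf.if %cond -> (i32, i32) {
//     scf.yield %a, %b : i32, i32
//   } else {
//     scf.yield %c, %d : i32, i32
//   }
//   use(%r#0)   // %r#1 is never used
//
// is rewritten into an scf.if that yields and returns only the first value.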
2211 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2213 
2214  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2215  PatternRewriter &rewriter) const {
2216  // Move all operations to the destination block.
2217  rewriter.mergeBlocks(source, dest);
2218  // Replace the yield op by one that returns only the used values.
2219  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2220  SmallVector<Value, 4> usedOperands;
2221  llvm::transform(usedResults, std::back_inserter(usedOperands),
2222  [&](OpResult result) {
2223  return yieldOp.getOperand(result.getResultNumber());
2224  });
2225  rewriter.modifyOpInPlace(yieldOp,
2226  [&]() { yieldOp->setOperands(usedOperands); });
2227  }
2228 
2229  LogicalResult matchAndRewrite(IfOp op,
2230  PatternRewriter &rewriter) const override {
2231  // Compute the list of used results.
2232  SmallVector<OpResult, 4> usedResults;
2233  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2234  [](OpResult result) { return !result.use_empty(); });
2235 
2236  // Replace the operation if only a subset of its results have uses.
2237  if (usedResults.size() == op.getNumResults())
2238  return failure();
2239 
2240  // Compute the result types of the replacement operation.
2241  SmallVector<Type, 4> newTypes;
2242  llvm::transform(usedResults, std::back_inserter(newTypes),
2243  [](OpResult result) { return result.getType(); });
2244 
2245  // Create a replacement operation with empty then and else regions.
2246  auto newOp =
2247  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2248  rewriter.createBlock(&newOp.getThenRegion());
2249  rewriter.createBlock(&newOp.getElseRegion());
2250 
2251  // Move the bodies and replace the terminators (note there is a then and
2252  // an else region since the operation returns results).
2253  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2254  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2255 
2256  // Replace the operation by the new one.
2257  SmallVector<Value, 4> repResults(op.getNumResults());
2258  for (const auto &en : llvm::enumerate(usedResults))
2259  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2260  rewriter.replaceOp(op, repResults);
2261  return success();
2262  }
2263 };
2264 
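/// Folds an scf.if whose condition is a compile-time constant: the taken
/// region is inlined into the parent block, and the op is erased outright
/// when the constant is false and there is no else region. A minimal sketch
/// (hand-written):
///
///   %true = arith.constant true
///   scf.if %true {
///     "some.op"() : () -> ()
///   }
///
/// becomes just
///
///   "some.op"() : () -> ()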
2265 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2267 
2268  LogicalResult matchAndRewrite(IfOp op,
2269  PatternRewriter &rewriter) const override {
2270  BoolAttr condition;
2271  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2272  return failure();
2273 
2274  if (condition.getValue())
2275  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2276  else if (!op.getElseRegion().empty())
2277  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2278  else
2279  rewriter.eraseOp(op);
2280 
2281  return success();
2282  }
2283 };
2284 
2285 /// Hoist any yielded results whose operands are defined outside
2286 /// the if, to a select instruction.
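///
/// Illustrative sketch (hand-written; %a and %b are assumed to be defined
/// above the scf.if):
///
///   %r = scf.if %cond -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
///
/// becomes
///
///   %r = arith.select %cond, %a, %b : i32
///
/// (an scf.if without results is kept for any remaining non-hoistable yields
/// or side-effecting ops in the branches).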
2287 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2289 
2290  LogicalResult matchAndRewrite(IfOp op,
2291  PatternRewriter &rewriter) const override {
2292  if (op->getNumResults() == 0)
2293  return failure();
2294 
2295  auto cond = op.getCondition();
2296  auto thenYieldArgs = op.thenYield().getOperands();
2297  auto elseYieldArgs = op.elseYield().getOperands();
2298 
2299  SmallVector<Type> nonHoistable;
2300  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2301  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2302  &op.getElseRegion() == falseVal.getParentRegion())
2303  nonHoistable.push_back(trueVal.getType());
2304  }
2305  // Early exit if there aren't any yielded values we can
2306  // hoist outside the if.
2307  if (nonHoistable.size() == op->getNumResults())
2308  return failure();
2309 
2310  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2311  /*withElseRegion=*/false);
2312  if (replacement.thenBlock())
2313  rewriter.eraseBlock(replacement.thenBlock());
2314  replacement.getThenRegion().takeBody(op.getThenRegion());
2315  replacement.getElseRegion().takeBody(op.getElseRegion());
2316 
2317  SmallVector<Value> results(op->getNumResults());
2318  assert(thenYieldArgs.size() == results.size());
2319  assert(elseYieldArgs.size() == results.size());
2320 
2321  SmallVector<Value> trueYields;
2322  SmallVector<Value> falseYields;
2323  rewriter.setInsertionPoint(replacement);
2324  for (const auto &it :
2325  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2326  Value trueVal = std::get<0>(it.value());
2327  Value falseVal = std::get<1>(it.value());
2328  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2329  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2330  results[it.index()] = replacement.getResult(trueYields.size());
2331  trueYields.push_back(trueVal);
2332  falseYields.push_back(falseVal);
2333  } else if (trueVal == falseVal)
2334  results[it.index()] = trueVal;
2335  else
2336  results[it.index()] = rewriter.create<arith::SelectOp>(
2337  op.getLoc(), cond, trueVal, falseVal);
2338  }
2339 
2340  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2341  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2342 
2343  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2344  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2345 
2346  rewriter.replaceOp(op, results);
2347  return success();
2348  }
2349 };
2350 
2351 /// Allow the true region of an if to assume the condition is true
2352 /// and vice versa. For example:
2353 ///
2354 /// scf.if %cmp {
2355 /// print(%cmp)
2356 /// }
2357 ///
2358 /// becomes
2359 ///
2360 /// scf.if %cmp {
2361 /// print(true)
2362 /// }
2363 ///
2364 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2366 
2367  LogicalResult matchAndRewrite(IfOp op,
2368  PatternRewriter &rewriter) const override {
2369  // Early exit if the condition is constant since replacing a constant
2370  // in the body with another constant isn't a simplification.
2371  if (matchPattern(op.getCondition(), m_Constant()))
2372  return failure();
2373 
2374  bool changed = false;
2375  mlir::Type i1Ty = rewriter.getI1Type();
2376 
2377  // These variables serve to prevent creating duplicate constants
2378  // and hold constant true or false values.
2379  Value constantTrue = nullptr;
2380  Value constantFalse = nullptr;
2381 
2382  for (OpOperand &use :
2383  llvm::make_early_inc_range(op.getCondition().getUses())) {
2384  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2385  changed = true;
2386 
2387  if (!constantTrue)
2388  constantTrue = rewriter.create<arith::ConstantOp>(
2389  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2390 
2391  rewriter.modifyOpInPlace(use.getOwner(),
2392  [&]() { use.set(constantTrue); });
2393  } else if (op.getElseRegion().isAncestor(
2394  use.getOwner()->getParentRegion())) {
2395  changed = true;
2396 
2397  if (!constantFalse)
2398  constantFalse = rewriter.create<arith::ConstantOp>(
2399  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2400 
2401  rewriter.modifyOpInPlace(use.getOwner(),
2402  [&]() { use.set(constantFalse); });
2403  }
2404  }
2405 
2406  return success(changed);
2407  }
2408 };
2409 
2410 /// Remove any statements from an if that are equivalent to the condition
2411 /// or its negation. For example:
2412 ///
2413 /// %res:2 = scf.if %cmp {
2414 /// yield something(), true
2415 /// } else {
2416 /// yield something2(), false
2417 /// }
2418 /// print(%res#1)
2419 ///
2420 /// becomes
2421 /// %res = scf.if %cmp {
2422 /// yield something()
2423 /// } else {
2424 /// yield something2()
2425 /// }
2426 /// print(%cmp)
2427 ///
2428 /// Additionally if both branches yield the same value, replace all uses
2429 /// of the result with the yielded value.
2430 ///
2431 /// %res:2 = scf.if %cmp {
2432 /// yield something(), %arg1
2433 /// } else {
2434 /// yield something2(), %arg1
2435 /// }
2436 /// print(%res#1)
2437 ///
2438 /// becomes
2439 /// %res = scf.if %cmp {
2440 /// yield something()
2441 /// } else {
2442 /// yield something2()
2443 /// }
2444 /// print(%arg1)
2445 ///
2446 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2448 
2449  LogicalResult matchAndRewrite(IfOp op,
2450  PatternRewriter &rewriter) const override {
2451  // Early exit if there are no results that could be replaced.
2452  if (op.getNumResults() == 0)
2453  return failure();
2454 
2455  auto trueYield =
2456  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2457  auto falseYield =
2458  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2459 
2460  rewriter.setInsertionPoint(op->getBlock(),
2461  op.getOperation()->getIterator());
2462  bool changed = false;
2463  Type i1Ty = rewriter.getI1Type();
2464  for (auto [trueResult, falseResult, opResult] :
2465  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2466  op.getResults())) {
2467  if (trueResult == falseResult) {
2468  if (!opResult.use_empty()) {
2469  opResult.replaceAllUsesWith(trueResult);
2470  changed = true;
2471  }
2472  continue;
2473  }
2474 
2475  BoolAttr trueYield, falseYield;
2476  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2477  !matchPattern(falseResult, m_Constant(&falseYield)))
2478  continue;
2479 
2480  bool trueVal = trueYield.getValue();
2481  bool falseVal = falseYield.getValue();
2482  if (!trueVal && falseVal) {
2483  if (!opResult.use_empty()) {
2484  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2485  Value notCond = rewriter.create<arith::XOrIOp>(
2486  op.getLoc(), op.getCondition(),
2487  constDialect
2488  ->materializeConstant(rewriter,
2489  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2490  op.getLoc())
2491  ->getResult(0));
2492  opResult.replaceAllUsesWith(notCond);
2493  changed = true;
2494  }
2495  }
2496  if (trueVal && !falseVal) {
2497  if (!opResult.use_empty()) {
2498  opResult.replaceAllUsesWith(op.getCondition());
2499  changed = true;
2500  }
2501  }
2502  }
2503  return success(changed);
2504  }
2505 };
2506 
2507 /// Merge any consecutive scf.if's with the same condition.
2508 ///
2509 /// scf.if %cond {
2510 /// firstCodeTrue();...
2511 /// } else {
2512 /// firstCodeFalse();...
2513 /// }
2514 /// %res = scf.if %cond {
2515 /// secondCodeTrue();...
2516 /// } else {
2517 /// secondCodeFalse();...
2518 /// }
2519 ///
2520 /// becomes
2521 /// %res = scf.if %cmp {
2522 /// firstCodeTrue();...
2523 /// secondCodeTrue();...
2524 /// } else {
2525 /// firstCodeFalse();...
2526 /// secondCodeFalse();...
2527 /// }
2528 struct CombineIfs : public OpRewritePattern<IfOp> {
2530 
2531  LogicalResult matchAndRewrite(IfOp nextIf,
2532  PatternRewriter &rewriter) const override {
2533  Block *parent = nextIf->getBlock();
2534  if (nextIf == &parent->front())
2535  return failure();
2536 
2537  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2538  if (!prevIf)
2539  return failure();
2540 
2541  // Determine the logical then/else blocks when prevIf's
2542  // condition is used. Null means the block does not exist
2543  // in that case (e.g. empty else). If neither of these
2544  // are set, the two conditions cannot be compared.
2545  Block *nextThen = nullptr;
2546  Block *nextElse = nullptr;
2547  if (nextIf.getCondition() == prevIf.getCondition()) {
2548  nextThen = nextIf.thenBlock();
2549  if (!nextIf.getElseRegion().empty())
2550  nextElse = nextIf.elseBlock();
2551  }
2552  if (arith::XOrIOp notv =
2553  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2554  if (notv.getLhs() == prevIf.getCondition() &&
2555  matchPattern(notv.getRhs(), m_One())) {
2556  nextElse = nextIf.thenBlock();
2557  if (!nextIf.getElseRegion().empty())
2558  nextThen = nextIf.elseBlock();
2559  }
2560  }
2561  if (arith::XOrIOp notv =
2562  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2563  if (notv.getLhs() == nextIf.getCondition() &&
2564  matchPattern(notv.getRhs(), m_One())) {
2565  nextElse = nextIf.thenBlock();
2566  if (!nextIf.getElseRegion().empty())
2567  nextThen = nextIf.elseBlock();
2568  }
2569  }
2570 
2571  if (!nextThen && !nextElse)
2572  return failure();
2573 
2574  SmallVector<Value> prevElseYielded;
2575  if (!prevIf.getElseRegion().empty())
2576  prevElseYielded = prevIf.elseYield().getOperands();
2577  // Replace all uses of the return values of prevIf within nextIf with the
2578  // corresponding yields.
2579  for (auto it : llvm::zip(prevIf.getResults(),
2580  prevIf.thenYield().getOperands(), prevElseYielded))
2581  for (OpOperand &use :
2582  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2583  if (nextThen && nextThen->getParent()->isAncestor(
2584  use.getOwner()->getParentRegion())) {
2585  rewriter.startOpModification(use.getOwner());
2586  use.set(std::get<1>(it));
2587  rewriter.finalizeOpModification(use.getOwner());
2588  } else if (nextElse && nextElse->getParent()->isAncestor(
2589  use.getOwner()->getParentRegion())) {
2590  rewriter.startOpModification(use.getOwner());
2591  use.set(std::get<2>(it));
2592  rewriter.finalizeOpModification(use.getOwner());
2593  }
2594  }
2595 
2596  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2597  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2598 
2599  IfOp combinedIf = rewriter.create<IfOp>(
2600  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2601  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2602 
2603  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2604  combinedIf.getThenRegion(),
2605  combinedIf.getThenRegion().begin());
2606 
2607  if (nextThen) {
2608  YieldOp thenYield = combinedIf.thenYield();
2609  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2610  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2611  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2612 
2613  SmallVector<Value> mergedYields(thenYield.getOperands());
2614  llvm::append_range(mergedYields, thenYield2.getOperands());
2615  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2616  rewriter.eraseOp(thenYield);
2617  rewriter.eraseOp(thenYield2);
2618  }
2619 
2620  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2621  combinedIf.getElseRegion(),
2622  combinedIf.getElseRegion().begin());
2623 
2624  if (nextElse) {
2625  if (combinedIf.getElseRegion().empty()) {
2626  rewriter.inlineRegionBefore(*nextElse->getParent(),
2627  combinedIf.getElseRegion(),
2628  combinedIf.getElseRegion().begin());
2629  } else {
2630  YieldOp elseYield = combinedIf.elseYield();
2631  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2632  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2633 
2634  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2635 
2636  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2637  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2638 
2639  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2640  rewriter.eraseOp(elseYield);
2641  rewriter.eraseOp(elseYield2);
2642  }
2643  }
2644 
2645  SmallVector<Value> prevValues;
2646  SmallVector<Value> nextValues;
2647  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2648  if (pair.index() < prevIf.getNumResults())
2649  prevValues.push_back(pair.value());
2650  else
2651  nextValues.push_back(pair.value());
2652  }
2653  rewriter.replaceOp(prevIf, prevValues);
2654  rewriter.replaceOp(nextIf, nextValues);
2655  return success();
2656  }
2657 };
2658 
2659 /// Pattern to remove an empty else branch.
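/// A minimal sketch (hand-written):
///
///   scf.if %cond {
///     "some.op"() : () -> ()
///   } else {
///   }
///
/// becomes
///
///   scf.if %cond {
///     "some.op"() : () -> ()
///   }
///
/// This only applies when the scf.if has no results and the else block holds
/// nothing but its terminator.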
2660 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2662 
2663  LogicalResult matchAndRewrite(IfOp ifOp,
2664  PatternRewriter &rewriter) const override {
2665  // Cannot remove else region when there are operation results.
2666  if (ifOp.getNumResults())
2667  return failure();
2668  Block *elseBlock = ifOp.elseBlock();
2669  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2670  return failure();
2671  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2672  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2673  newIfOp.getThenRegion().begin());
2674  rewriter.eraseOp(ifOp);
2675  return success();
2676  }
2677 };
2678 
2679 /// Convert nested `if`s into `arith.andi` + single `if`.
2680 ///
2681 /// scf.if %arg0 {
2682 /// scf.if %arg1 {
2683 /// ...
2684 /// scf.yield
2685 /// }
2686 /// scf.yield
2687 /// }
2688 /// becomes
2689 ///
2690 /// %0 = arith.andi %arg0, %arg1
2691 /// scf.if %0 {
2692 /// ...
2693 /// scf.yield
2694 /// }
2695 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2697 
2698  LogicalResult matchAndRewrite(IfOp op,
2699  PatternRewriter &rewriter) const override {
2700  auto nestedOps = op.thenBlock()->without_terminator();
2701  // Nested `if` must be the only op in block.
2702  if (!llvm::hasSingleElement(nestedOps))
2703  return failure();
2704 
2705  // If there is an else block, it can only yield
2706  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2707  return failure();
2708 
2709  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2710  if (!nestedIf)
2711  return failure();
2712 
2713  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2714  return failure();
2715 
2716  SmallVector<Value> thenYield(op.thenYield().getOperands());
2717  SmallVector<Value> elseYield;
2718  if (op.elseBlock())
2719  llvm::append_range(elseYield, op.elseYield().getOperands());
2720 
2721  // A list of indices for which we should upgrade the value yielded
2722  // in the else to a select.
2723  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2724 
2725  // If the outer scf.if yields a value produced by the inner scf.if,
2726  // only permit combining if the value yielded when the condition
2727  // is false in the outer scf.if is the same value yielded when the
2728  // inner scf.if condition is false.
2729  // Note that the array access to elseYield will not go out of bounds
2730  // since it must have the same length as thenYield, since they both
2731  // come from the same scf.if.
2732  for (const auto &tup : llvm::enumerate(thenYield)) {
2733  if (tup.value().getDefiningOp() == nestedIf) {
2734  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2735  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2736  elseYield[tup.index()]) {
2737  return failure();
2738  }
2739  // If the correctness test passes, we will yield the
2740  // corresponding value from the inner scf.if.
2741  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2742  continue;
2743  }
2744 
2745  // Otherwise, we need to ensure the else block of the combined
2746  // condition still returns the same value when the outer condition is
2747  // true and the inner condition is false. This can be accomplished if
2748  // the then value is defined outside the outer scf.if and we replace the
2749  // value with a select that considers just the outer condition. Since
2750  // the else region contains just the yield, its yielded value is
2751  // defined outside the scf.if, by definition.
2752 
2753  // If the then value is defined within the scf.if, bail.
2754  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2755  return failure();
2756  }
2757  elseYieldsToUpgradeToSelect.push_back(tup.index());
2758  }
2759 
2760  Location loc = op.getLoc();
2761  Value newCondition = rewriter.create<arith::AndIOp>(
2762  loc, op.getCondition(), nestedIf.getCondition());
2763  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2764  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2765 
2766  SmallVector<Value> results;
2767  llvm::append_range(results, newIf.getResults());
2768  rewriter.setInsertionPoint(newIf);
2769 
2770  for (auto idx : elseYieldsToUpgradeToSelect)
2771  results[idx] = rewriter.create<arith::SelectOp>(
2772  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2773 
2774  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2775  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2776  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2777  if (!elseYield.empty()) {
2778  rewriter.createBlock(&newIf.getElseRegion());
2779  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2780  rewriter.create<YieldOp>(loc, elseYield);
2781  }
2782  rewriter.replaceOp(op, results);
2783  return success();
2784  }
2785 };
2786 
2787 } // namespace
2788 
2789 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2790  MLIRContext *context) {
2791  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2792  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2793  RemoveStaticCondition, RemoveUnusedResults,
2794  ReplaceIfYieldWithConditionOrValue>(context);
2795 }
2796 
2797 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2798 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2799 Block *IfOp::elseBlock() {
2800  Region &r = getElseRegion();
2801  if (r.empty())
2802  return nullptr;
2803  return &r.back();
2804 }
2805 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2806 
2807 //===----------------------------------------------------------------------===//
2808 // ParallelOp
2809 //===----------------------------------------------------------------------===//
2810 
2811 void ParallelOp::build(
2812  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2813  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2814  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2815  bodyBuilderFn) {
2816  result.addOperands(lowerBounds);
2817  result.addOperands(upperBounds);
2818  result.addOperands(steps);
2819  result.addOperands(initVals);
2820  result.addAttribute(
2821  ParallelOp::getOperandSegmentSizeAttr(),
2822  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2823  static_cast<int32_t>(upperBounds.size()),
2824  static_cast<int32_t>(steps.size()),
2825  static_cast<int32_t>(initVals.size())}));
2826  result.addTypes(initVals.getTypes());
2827 
2828  OpBuilder::InsertionGuard guard(builder);
2829  unsigned numIVs = steps.size();
2830  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2831  SmallVector<Location, 8> argLocs(numIVs, result.location);
2832  Region *bodyRegion = result.addRegion();
2833  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2834 
2835  if (bodyBuilderFn) {
2836  builder.setInsertionPointToStart(bodyBlock);
2837  bodyBuilderFn(builder, result.location,
2838  bodyBlock->getArguments().take_front(numIVs),
2839  bodyBlock->getArguments().drop_front(numIVs));
2840  }
2841  // Add terminator only if there are no reductions.
2842  if (initVals.empty())
2843  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2844 }
2845 
2846 void ParallelOp::build(
2847  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2848  ValueRange upperBounds, ValueRange steps,
2849  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2850  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2851  // we don't capture a reference to a temporary by constructing the lambda at
2852  // function level.
2853  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2854  Location nestedLoc, ValueRange ivs,
2855  ValueRange) {
2856  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2857  };
2858  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2859  if (bodyBuilderFn)
2860  wrapper = wrappedBuilderFn;
2861 
2862  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2863  wrapper);
2864 }
2865 
2866 LogicalResult ParallelOp::verify() {
2867  // Check that there is at least one value in lowerBound, upperBound and step.
2868  // It is sufficient to test only step, because it is ensured already that the
2869  // number of elements in lowerBound, upperBound and step is the same.
2870  Operation::operand_range stepValues = getStep();
2871  if (stepValues.empty())
2872  return emitOpError(
2873  "needs at least one tuple element for lowerBound, upperBound and step");
2874 
2875  // Check whether all constant step values are positive.
2876  for (Value stepValue : stepValues)
2877  if (auto cst = getConstantIntValue(stepValue))
2878  if (*cst <= 0)
2879  return emitOpError("constant step operand must be positive");
2880 
2881  // Check that the body defines the same number of block arguments as the
2882  // number of tuple elements in step.
2883  Block *body = getBody();
2884  if (body->getNumArguments() != stepValues.size())
2885  return emitOpError() << "expects the same number of induction variables: "
2886  << body->getNumArguments()
2887  << " as bound and step values: " << stepValues.size();
2888  for (auto arg : body->getArguments())
2889  if (!arg.getType().isIndex())
2890  return emitOpError(
2891  "expects arguments for the induction variable to be of index type");
2892 
2893  // Check that the terminator is an scf.reduce op.
2894  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2895  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2896  if (!reduceOp)
2897  return failure();
2898 
2899  // Check that the number of results is the same as the number of reductions.
2900  auto resultsSize = getResults().size();
2901  auto reductionsSize = reduceOp.getReductions().size();
2902  auto initValsSize = getInitVals().size();
2903  if (resultsSize != reductionsSize)
2904  return emitOpError() << "expects number of results: " << resultsSize
2905  << " to be the same as number of reductions: "
2906  << reductionsSize;
2907  if (resultsSize != initValsSize)
2908  return emitOpError() << "expects number of results: " << resultsSize
2909  << " to be the same as number of initial values: "
2910  << initValsSize;
2911 
2912  // Check that the types of the results and reductions are the same.
2913  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2914  auto resultType = getOperation()->getResult(i).getType();
2915  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2916  if (resultType != reductionOperandType)
2917  return reduceOp.emitOpError()
2918  << "expects type of " << i
2919  << "-th reduction operand: " << reductionOperandType
2920  << " to be the same as the " << i
2921  << "-th result type: " << resultType;
2922  }
2923  return success();
2924 }
2925 
2926 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2927  auto &builder = parser.getBuilder();
2928  // Parse an opening `(` followed by induction variables followed by `)`.
2929  SmallVector<OpAsmParser::Argument, 4> ivs;
2930  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2931  return failure();
2932 
2933  // Parse loop bounds.
2934  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2935  if (parser.parseEqual() ||
2936  parser.parseOperandList(lower, ivs.size(),
2937  OpAsmParser::Delimiter::Paren) ||
2938  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2939  return failure();
2940 
2941  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2942  if (parser.parseKeyword("to") ||
2943  parser.parseOperandList(upper, ivs.size(),
2944  OpAsmParser::Delimiter::Paren) ||
2945  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2946  return failure();
2947 
2948  // Parse step values.
2949  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2950  if (parser.parseKeyword("step") ||
2951  parser.parseOperandList(steps, ivs.size(),
2952  OpAsmParser::Delimiter::Paren) ||
2953  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2954  return failure();
2955 
2956  // Parse init values.
2957  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2958  if (succeeded(parser.parseOptionalKeyword("init"))) {
2959  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2960  return failure();
2961  }
2962 
2963  // Parse optional results in case there is a reduce.
2964  if (parser.parseOptionalArrowTypeList(result.types))
2965  return failure();
2966 
2967  // Now parse the body.
2968  Region *body = result.addRegion();
2969  for (auto &iv : ivs)
2970  iv.type = builder.getIndexType();
2971  if (parser.parseRegion(*body, ivs))
2972  return failure();
2973 
2974  // Set `operandSegmentSizes` attribute.
2975  result.addAttribute(
2976  ParallelOp::getOperandSegmentSizeAttr(),
2977  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2978  static_cast<int32_t>(upper.size()),
2979  static_cast<int32_t>(steps.size()),
2980  static_cast<int32_t>(initVals.size())}));
2981 
2982  // Parse attributes.
2983  if (parser.parseOptionalAttrDict(result.attributes) ||
2984  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2985  result.operands))
2986  return failure();
2987 
2988  // Add a terminator if none was parsed.
2989  ParallelOp::ensureTerminator(*body, builder, result.location);
2990  return success();
2991 }
2992 
2994  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2995  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2996  if (!getInitVals().empty())
2997  p << " init (" << getInitVals() << ")";
2998  p.printOptionalArrowTypeList(getResultTypes());
2999  p << ' ';
3000  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3001  p.printOptionalAttrDict(
3002  (*this)->getAttrs(),
3003  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3004 }
3005 
3006 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3007 
3008 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3009  return SmallVector<Value>{getBody()->getArguments()};
3010 }
3011 
3012 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3013  return getLowerBound();
3014 }
3015 
3016 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3017  return getUpperBound();
3018 }
3019 
3020 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3021  return getStep();
3022 }
3023 
3024 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3025  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3026  if (!ivArg)
3027  return ParallelOp();
3028  assert(ivArg.getOwner() && "unlinked block argument");
3029  auto *containingOp = ivArg.getOwner()->getParentOp();
3030  return dyn_cast<ParallelOp>(containingOp);
3031 }
3032 
3033 namespace {
3034 // Collapse loop dimensions that perform a single iteration, and remove the
3034 // loop entirely if any dimension performs zero iterations.
3035 struct ParallelOpSingleOrZeroIterationDimsFolder
3036  : public OpRewritePattern<ParallelOp> {
3037  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3038 
3039  LogicalResult matchAndRewrite(ParallelOp op,
3040  PatternRewriter &rewriter) const override {
3041  Location loc = op.getLoc();
3042 
3043  // Compute new loop bounds that omit all single-iteration loop dimensions.
3044  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3045  IRMapping mapping;
3046  for (auto [lb, ub, step, iv] :
3047  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3048  op.getInductionVars())) {
3049  auto numIterations = constantTripCount(lb, ub, step);
3050  if (numIterations.has_value()) {
3051  // Remove the loop if it performs zero iterations.
3052  if (*numIterations == 0) {
3053  rewriter.replaceOp(op, op.getInitVals());
3054  return success();
3055  }
3056  // Replace the loop induction variable by the lower bound if the loop
3057  // performs a single iteration. Otherwise, copy the loop bounds.
3058  if (*numIterations == 1) {
3059  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3060  continue;
3061  }
3062  }
3063  newLowerBounds.push_back(lb);
3064  newUpperBounds.push_back(ub);
3065  newSteps.push_back(step);
3066  }
3067  // Exit if none of the loop dimensions perform a single iteration.
3068  if (newLowerBounds.size() == op.getLowerBound().size())
3069  return failure();
3070 
3071  if (newLowerBounds.empty()) {
3072  // All of the loop dimensions perform a single iteration. Inline the
3073  // loop body and the nested ReduceOps.
3074  SmallVector<Value> results;
3075  results.reserve(op.getInitVals().size());
3076  for (auto &bodyOp : op.getBody()->without_terminator())
3077  rewriter.clone(bodyOp, mapping);
3078  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3079  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3080  Block &reduceBlock = reduceOp.getReductions()[i].front();
3081  auto initValIndex = results.size();
3082  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3083  mapping.map(reduceBlock.getArgument(1),
3084  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3085  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3086  rewriter.clone(reduceBodyOp, mapping);
3087 
3088  auto result = mapping.lookupOrDefault(
3089  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3090  results.push_back(result);
3091  }
3092 
3093  rewriter.replaceOp(op, results);
3094  return success();
3095  }
3096  // Replace the parallel loop with a lower-dimensional parallel loop.
3097  auto newOp =
3098  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3099  newSteps, op.getInitVals(), nullptr);
3100  // Erase the empty block that was inserted by the builder.
3101  rewriter.eraseBlock(newOp.getBody());
3102  // Clone the loop body and remap the block arguments of the collapsed loops
3103  // (inlining does not support a cancellable block argument mapping).
3104  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3105  newOp.getRegion().begin(), mapping);
3106  rewriter.replaceOp(op, newOp.getResults());
3107  return success();
3108  }
3109 };
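As an illustrative sketch (not taken from this file), assuming %c0, %c1, and %c8 are index constants, a loop whose first dimension has a trip count of one

    scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c8) step (%c1, %c1) {
      "test.use"(%i, %j) : (index, index) -> ()
    }

is rewritten by this pattern into a lower-dimensional loop in which %i is replaced by its lower bound:

    scf.parallel (%j) = (%c0) to (%c8) step (%c1) {
      "test.use"(%c0, %j) : (index, index) -> ()
    }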
3110 
3111 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3112  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3113 
3114  LogicalResult matchAndRewrite(ParallelOp op,
3115  PatternRewriter &rewriter) const override {
3116  Block &outerBody = *op.getBody();
3117  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3118  return failure();
3119 
3120  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3121  if (!innerOp)
3122  return failure();
3123 
3124  for (auto val : outerBody.getArguments())
3125  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3126  llvm::is_contained(innerOp.getUpperBound(), val) ||
3127  llvm::is_contained(innerOp.getStep(), val))
3128  return failure();
3129 
3130  // Reductions are not supported yet.
3131  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3132  return failure();
3133 
3134  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3135  ValueRange iterVals, ValueRange) {
3136  Block &innerBody = *innerOp.getBody();
3137  assert(iterVals.size() ==
3138  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3139  IRMapping mapping;
3140  mapping.map(outerBody.getArguments(),
3141  iterVals.take_front(outerBody.getNumArguments()));
3142  mapping.map(innerBody.getArguments(),
3143  iterVals.take_back(innerBody.getNumArguments()));
3144  for (Operation &op : innerBody.without_terminator())
3145  builder.clone(op, mapping);
3146  };
3147 
3148  auto concatValues = [](const auto &first, const auto &second) {
3149  SmallVector<Value> ret;
3150  ret.reserve(first.size() + second.size());
3151  ret.assign(first.begin(), first.end());
3152  ret.append(second.begin(), second.end());
3153  return ret;
3154  };
3155 
3156  auto newLowerBounds =
3157  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3158  auto newUpperBounds =
3159  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3160  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3161 
3162  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3163  newSteps, std::nullopt,
3164  bodyBuilder);
3165  return success();
3166  }
3167 };
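A sketch of the effect of this pattern, assuming the outer body contains only the inner loop and the inner bounds do not use the outer induction variable (illustrative IR, not from this file):

    scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
      scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) {
        "test.use"(%i, %j) : (index, index) -> ()
      }
    }

becomes

    scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
      "test.use"(%i, %j) : (index, index) -> ()
    }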
3168 
3169 } // namespace
3170 
3171 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3172  MLIRContext *context) {
3173  results
3174  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3175  context);
3176 }
3177 
3178 /// Given a branch point, which is either a region of this operation or the
3179 /// operation itself, return the successor regions, i.e. the regions that may
3180 /// be selected during the flow of control.
3183 void ParallelOp::getSuccessorRegions(
3184  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3185  // Both the operation itself and the region may be branching into the body or
3186  // back into the operation itself. It is possible for the loop not to enter
3187  // the body.
3188  regions.push_back(RegionSuccessor(&getRegion()));
3189  regions.push_back(RegionSuccessor());
3190 }
3191 
3192 //===----------------------------------------------------------------------===//
3193 // ReduceOp
3194 //===----------------------------------------------------------------------===//
3195 
3196 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3197 
3198 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3199  ValueRange operands) {
3200  result.addOperands(operands);
3201  for (Value v : operands) {
3202  OpBuilder::InsertionGuard guard(builder);
3203  Region *bodyRegion = result.addRegion();
3204  builder.createBlock(bodyRegion, {},
3205  ArrayRef<Type>{v.getType(), v.getType()},
3206  {result.location, result.location});
3207  }
3208 }
3209 
3210 LogicalResult ReduceOp::verifyRegions() {
3211  // Each region of a ReduceOp has two arguments of the same type as the
3212  // corresponding operand.
3213  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3214  auto type = getOperands()[i].getType();
3215  Block &block = getReductions()[i].front();
3216  if (block.empty())
3217  return emitOpError() << i << "-th reduction has an empty body";
3218  if (block.getNumArguments() != 2 ||
3219  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3220  return arg.getType() != type;
3221  }))
3222  return emitOpError() << "expected two block arguments with type " << type
3223  << " in the " << i << "-th reduction region";
3224 
3225  // Check that the block is terminated by a ReduceReturnOp.
3226  if (!isa<ReduceReturnOp>(block.getTerminator()))
3227  return emitOpError("reduction bodies must be terminated with an "
3228  "'scf.reduce.return' op");
3229  }
3230 
3231  return success();
3232 }
3233 
3234 MutableOperandRange
3235 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3236  // No operands are forwarded to the next iteration.
3237  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3238 }
3239 
3240 //===----------------------------------------------------------------------===//
3241 // ReduceReturnOp
3242 //===----------------------------------------------------------------------===//
3243 
3244 LogicalResult ReduceReturnOp::verify() {
3245  // The type of the return value should be the same type as the types of the
3246  // block arguments of the reduction body.
3247  Block *reductionBody = getOperation()->getBlock();
3248  // Should already be verified by an op trait.
3249  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3250  Type expectedResultType = reductionBody->getArgument(0).getType();
3251  if (expectedResultType != getResult().getType())
3252  return emitOpError() << "must have type " << expectedResultType
3253  << " (the type of the reduction inputs)";
3254  return success();
3255 }
3256 
3257 //===----------------------------------------------------------------------===//
3258 // WhileOp
3259 //===----------------------------------------------------------------------===//
3260 
3261 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3262  ::mlir::OperationState &odsState, TypeRange resultTypes,
3263  ValueRange inits, BodyBuilderFn beforeBuilder,
3264  BodyBuilderFn afterBuilder) {
3265  odsState.addOperands(inits);
3266  odsState.addTypes(resultTypes);
3267 
3268  OpBuilder::InsertionGuard guard(odsBuilder);
3269 
3270  // Build before region.
3271  SmallVector<Location, 4> beforeArgLocs;
3272  beforeArgLocs.reserve(inits.size());
3273  for (Value operand : inits) {
3274  beforeArgLocs.push_back(operand.getLoc());
3275  }
3276 
3277  Region *beforeRegion = odsState.addRegion();
3278  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3279  inits.getTypes(), beforeArgLocs);
3280  if (beforeBuilder)
3281  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3282 
3283  // Build after region.
3284  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3285 
3286  Region *afterRegion = odsState.addRegion();
3287  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3288  resultTypes, afterArgLocs);
3289 
3290  if (afterBuilder)
3291  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3292 }
3293 
3294 ConditionOp WhileOp::getConditionOp() {
3295  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3296 }
3297 
3298 YieldOp WhileOp::getYieldOp() {
3299  return cast<YieldOp>(getAfterBody()->getTerminator());
3300 }
3301 
3302 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3303  return getYieldOp().getResultsMutable();
3304 }
3305 
3306 Block::BlockArgListType WhileOp::getBeforeArguments() {
3307  return getBeforeBody()->getArguments();
3308 }
3309 
3310 Block::BlockArgListType WhileOp::getAfterArguments() {
3311  return getAfterBody()->getArguments();
3312 }
3313 
3314 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3315  return getBeforeArguments();
3316 }
3317 
3318 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3319  assert(point == getBefore() &&
3320  "WhileOp is expected to branch only to the first region");
3321  return getInits();
3322 }
3323 
3324 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3325  SmallVectorImpl<RegionSuccessor> &regions) {
3326  // The parent op always branches to the condition region.
3327  if (point.isParent()) {
3328  regions.emplace_back(&getBefore(), getBefore().getArguments());
3329  return;
3330  }
3331 
3332  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3333  "there are only two regions in a WhileOp");
3334  // The body region always branches back to the condition region.
3335  if (point == getAfter()) {
3336  regions.emplace_back(&getBefore(), getBefore().getArguments());
3337  return;
3338  }
3339 
3340  regions.emplace_back(getResults());
3341  regions.emplace_back(&getAfter(), getAfter().getArguments());
3342 }
3343 
3344 SmallVector<Region *> WhileOp::getLoopRegions() {
3345  return {&getBefore(), &getAfter()};
3346 }
3347 
3348 /// Parses a `while` op.
3349 ///
3350 /// op ::= `scf.while` assignments `:` function-type region `do` region
3351 /// `attributes` attribute-dict
3352 /// initializer ::= /* empty */ | `(` assignment-list `)`
3353 /// assignment-list ::= assignment | assignment `,` assignment-list
3354 /// assignment ::= ssa-value `=` ssa-value
3355 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3356  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3357  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3358  Region *before = result.addRegion();
3359  Region *after = result.addRegion();
3360 
3361  OptionalParseResult listResult =
3362  parser.parseOptionalAssignmentList(regionArgs, operands);
3363  if (listResult.has_value() && failed(listResult.value()))
3364  return failure();
3365 
3366  FunctionType functionType;
3367  SMLoc typeLoc = parser.getCurrentLocation();
3368  if (failed(parser.parseColonType(functionType)))
3369  return failure();
3370 
3371  result.addTypes(functionType.getResults());
3372 
3373  if (functionType.getNumInputs() != operands.size()) {
3374  return parser.emitError(typeLoc)
3375  << "expected as many input types as operands "
3376  << "(expected " << operands.size() << " got "
3377  << functionType.getNumInputs() << ")";
3378  }
3379 
3380  // Resolve input operands.
3381  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3382  parser.getCurrentLocation(),
3383  result.operands)))
3384  return failure();
3385 
3386  // Propagate the types into the region arguments.
3387  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3388  regionArgs[i].type = functionType.getInput(i);
3389 
3390  return failure(parser.parseRegion(*before, regionArgs) ||
3391  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3392  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3393 }
3394 
3395 /// Prints a `while` op.
3396 void scf::WhileOp::print(OpAsmPrinter &p) {
3397  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3398  p << " : ";
3399  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3400  p << ' ';
3401  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3402  p << " do ";
3403  p.printRegion(getAfter());
3404  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3405 }
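A concrete instance of the grammar accepted by the parser above might look as follows (illustrative IR; %init and the "test" ops are placeholders):

    %res = scf.while (%arg0 = %init) : (i32) -> i32 {
      %cond = "test.condition"(%arg0) : (i32) -> i1
      scf.condition(%cond) %arg0 : i32
    } do {
    ^bb0(%arg1: i32):
      %next = "test.next"(%arg1) : (i32) -> i32
      scf.yield %next : i32
    }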
3406 
3407 /// Verifies that two ranges of types match, i.e. have the same number of
3408 /// entries and that the types are pairwise equal. Reports errors on the given
3409 /// operation in case of mismatch.
3410 template <typename OpTy>
3411 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3412  TypeRange right, StringRef message) {
3413  if (left.size() != right.size())
3414  return op.emitOpError("expects the same number of ") << message;
3415 
3416  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3417  if (left[i] != right[i]) {
3418  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3419  << message;
3420  diag.attachNote() << "for argument " << i << ", found " << left[i]
3421  << " and " << right[i];
3422  return diag;
3423  }
3424  }
3425 
3426  return success();
3427 }
3428 
3429 LogicalResult scf::WhileOp::verify() {
3430  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3431  *this, getBefore(),
3432  "expects the 'before' region to terminate with 'scf.condition'");
3433  if (!beforeTerminator)
3434  return failure();
3435 
3436  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3437  *this, getAfter(),
3438  "expects the 'after' region to terminate with 'scf.yield'");
3439  return success(afterTerminator != nullptr);
3440 }
3441 
3442 namespace {
3443 /// Replace uses of the condition within the do block with true, since otherwise
3444 /// the block would not be evaluated.
3445 ///
3446 /// scf.while (..) : (i1, ...) -> ... {
3447 /// %condition = call @evaluate_condition() : () -> i1
3448 /// scf.condition(%condition) %condition : i1, ...
3449 /// } do {
3450 /// ^bb0(%arg0: i1, ...):
3451 /// use(%arg0)
3452 /// ...
3453 ///
3454 /// becomes
3455 /// scf.while (..) : (i1, ...) -> ... {
3456 /// %condition = call @evaluate_condition() : () -> i1
3457 /// scf.condition(%condition) %condition : i1, ...
3458 /// } do {
3459 /// ^bb0(%arg0: i1, ...):
3460 /// use(%true)
3461 /// ...
3462 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3463  using OpRewritePattern<WhileOp>::OpRewritePattern;
3464 
3465  LogicalResult matchAndRewrite(WhileOp op,
3466  PatternRewriter &rewriter) const override {
3467  auto term = op.getConditionOp();
3468 
3469  // This variable holds the constant true value, once created, so that we
3470  // do not create duplicate constants.
3471  Value constantTrue = nullptr;
3472 
3473  bool replaced = false;
3474  for (auto yieldedAndBlockArgs :
3475  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3476  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3477  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3478  if (!constantTrue)
3479  constantTrue = rewriter.create<arith::ConstantOp>(
3480  op.getLoc(), term.getCondition().getType(),
3481  rewriter.getBoolAttr(true));
3482 
3483  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3484  constantTrue);
3485  replaced = true;
3486  }
3487  }
3488  }
3489  return success(replaced);
3490  }
3491 };
3492 
3493 /// Remove loop invariant arguments from `before` block of scf.while.
3494 /// A before block argument is considered loop invariant if:
3495 /// 1. i-th yield operand is equal to the i-th while operand.
3496 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3497 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3498 /// iter argument/while operand.
3499 /// For the arguments which are removed, their uses inside scf.while
3500 /// are replaced with their corresponding initial value.
3501 ///
3502 /// Eg:
3503 /// INPUT :-
3504 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3505 /// ..., %argN_before = %N)
3506 /// {
3507 /// ...
3508 /// scf.condition(%cond) %arg1_before, %arg0_before,
3509 /// %arg2_before, %arg0_before, ...
3510 /// } do {
3511 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3512 /// ..., %argK_after):
3513 /// ...
3514 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3515 /// }
3516 ///
3517 /// OUTPUT :-
3518 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3519 /// %N)
3520 /// {
3521 /// ...
3522 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3523 /// } do {
3524 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3525 /// ..., %argK_after):
3526 /// ...
3527 /// scf.yield %arg1_after, ..., %argN
3528 /// }
3529 ///
3530 /// EXPLANATION:
3531 /// We iterate over each yield operand.
3532 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3533 /// %arg0_before, which in turn is the 0-th iter argument. So we
3534 /// remove 0-th before block argument and yield operand, and replace
3535 /// all uses of the 0-th before block argument with its initial value
3536 /// %a.
3537 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3538 /// value. So we remove this operand and the corresponding before
3539 /// block argument and replace all uses of 1-th before block argument
3540 /// with %b.
3541 struct RemoveLoopInvariantArgsFromBeforeBlock
3542  : public OpRewritePattern<WhileOp> {
3543  using OpRewritePattern<WhileOp>::OpRewritePattern;
3544 
3545  LogicalResult matchAndRewrite(WhileOp op,
3546  PatternRewriter &rewriter) const override {
3547  Block &afterBlock = *op.getAfterBody();
3548  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3549  ConditionOp condOp = op.getConditionOp();
3550  OperandRange condOpArgs = condOp.getArgs();
3551  Operation *yieldOp = afterBlock.getTerminator();
3552  ValueRange yieldOpArgs = yieldOp->getOperands();
3553 
3554  bool canSimplify = false;
3555  for (const auto &it :
3556  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3557  auto index = static_cast<unsigned>(it.index());
3558  auto [initVal, yieldOpArg] = it.value();
3559  // If i-th yield operand is equal to the i-th operand of the scf.while,
3560  // the i-th before block argument is a loop invariant.
3561  if (yieldOpArg == initVal) {
3562  canSimplify = true;
3563  break;
3564  }
3565  // If the i-th yield operand is the k-th after block argument, then we check
3566  // whether the (k+1)-th condition op operand is equal to either the i-th before
3567  // block argument or the initial value of the i-th before block argument. If
3568  // the comparison is `true`, the i-th before block argument is a loop
3569  // invariant.
3570  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3571  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3572  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3573  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3574  canSimplify = true;
3575  break;
3576  }
3577  }
3578  }
3579 
3580  if (!canSimplify)
3581  return failure();
3582 
3583  SmallVector<Value> newInitArgs, newYieldOpArgs;
3584  DenseMap<unsigned, Value> beforeBlockInitValMap;
3585  SmallVector<Location> newBeforeBlockArgLocs;
3586  for (const auto &it :
3587  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3588  auto index = static_cast<unsigned>(it.index());
3589  auto [initVal, yieldOpArg] = it.value();
3590 
3591  // If i-th yield operand is equal to the i-th operand of the scf.while,
3592  // the i-th before block argument is a loop invariant.
3593  if (yieldOpArg == initVal) {
3594  beforeBlockInitValMap.insert({index, initVal});
3595  continue;
3596  } else {
3597  // If the i-th yield operand is the k-th after block argument, then we check
3598  // whether the (k+1)-th condition op operand is equal to either the i-th
3599  // before block argument or the initial value of the i-th before block
3600  // argument. If the comparison is `true`, the i-th before block
3601  // argument is a loop invariant.
3602  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3603  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3604  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3605  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3606  beforeBlockInitValMap.insert({index, initVal});
3607  continue;
3608  }
3609  }
3610  }
3611  newInitArgs.emplace_back(initVal);
3612  newYieldOpArgs.emplace_back(yieldOpArg);
3613  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3614  }
3615 
3616  {
3617  OpBuilder::InsertionGuard g(rewriter);
3618  rewriter.setInsertionPoint(yieldOp);
3619  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3620  }
3621 
3622  auto newWhile =
3623  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3624 
3625  Block &newBeforeBlock = *rewriter.createBlock(
3626  &newWhile.getBefore(), /*insertPt*/ {},
3627  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3628 
3629  Block &beforeBlock = *op.getBeforeBody();
3630  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3631  // For each i-th before block argument we find its replacement value as:
3632  // 1. If the i-th before block argument is a loop invariant, we fetch its
3633  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3634  // 2. Else we fetch j-th new before block argument as the replacement
3635  // value of i-th before block argument.
3636  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3637  // If the index 'i' argument was a loop invariant, we fetch its initial
3638  // value from `beforeBlockInitValMap`.
3639  if (beforeBlockInitValMap.count(i) != 0)
3640  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3641  else
3642  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3643  }
3644 
3645  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3646  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3647  newWhile.getAfter().begin());
3648 
3649  rewriter.replaceOp(op, newWhile.getResults());
3650  return success();
3651  }
3652 };
3653 
3654 /// Remove loop-invariant values from the results (scf.condition operands) of scf.while.
3655 /// A value is considered loop invariant if the final value yielded by
3656 /// scf.condition is defined outside of the `before` block. We remove the
3657 /// corresponding argument in `after` block and replace the use with the value.
3658 /// We also replace the use of the corresponding result of scf.while with the
3659 /// value.
3660 ///
3661 /// Eg:
3662 /// INPUT :-
3663 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3664 /// %argN_before = %N) {
3665 /// ...
3666 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3667 /// } do {
3668 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3669 /// ...
3670 /// some_func(%arg1_after)
3671 /// ...
3672 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3673 /// }
3674 ///
3675 /// OUTPUT :-
3676 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3677 /// ...
3678 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3679 /// } do {
3680 /// ^bb0(%arg0, %arg3, ..., %argM):
3681 /// ...
3682 /// some_func(%a)
3683 /// ...
3684 /// scf.yield %arg0, %b, ..., %argN
3685 /// }
3686 ///
3687 /// EXPLANATION:
3688 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3689 /// before block of scf.while, so they get removed.
3690 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3691 /// replaced by %b.
3692 /// 3. The corresponding after block argument %arg1_after's uses are
3693 /// replaced by %a and %arg2_after's uses are replaced by %b.
3694 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3695  using OpRewritePattern<WhileOp>::OpRewritePattern;
3696 
3697  LogicalResult matchAndRewrite(WhileOp op,
3698  PatternRewriter &rewriter) const override {
3699  Block &beforeBlock = *op.getBeforeBody();
3700  ConditionOp condOp = op.getConditionOp();
3701  OperandRange condOpArgs = condOp.getArgs();
3702 
3703  bool canSimplify = false;
3704  for (Value condOpArg : condOpArgs) {
3705  // Values not defined within the `before` block are considered loop
3706  // invariant values. Finding at least one such value means the op can be
3707  // simplified.
3708  if (condOpArg.getParentBlock() != &beforeBlock) {
3709  canSimplify = true;
3710  break;
3711  }
3712  }
3713 
3714  if (!canSimplify)
3715  return failure();
3716 
3717  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3718 
3719  SmallVector<Value> newCondOpArgs;
3720  SmallVector<Type> newAfterBlockType;
3721  DenseMap<unsigned, Value> condOpInitValMap;
3722  SmallVector<Location> newAfterBlockArgLocs;
3723  for (const auto &it : llvm::enumerate(condOpArgs)) {
3724  auto index = static_cast<unsigned>(it.index());
3725  Value condOpArg = it.value();
3726  // Values not defined within the `before` block are considered loop
3727  // invariant values. We map the corresponding `index` to the invariant
3728  // value.
3729  if (condOpArg.getParentBlock() != &beforeBlock) {
3730  condOpInitValMap.insert({index, condOpArg});
3731  } else {
3732  newCondOpArgs.emplace_back(condOpArg);
3733  newAfterBlockType.emplace_back(condOpArg.getType());
3734  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3735  }
3736  }
3737 
3738  {
3739  OpBuilder::InsertionGuard g(rewriter);
3740  rewriter.setInsertionPoint(condOp);
3741  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3742  newCondOpArgs);
3743  }
3744 
3745  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3746  op.getOperands());
3747 
3748  Block &newAfterBlock =
3749  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3750  newAfterBlockType, newAfterBlockArgLocs);
3751 
3752  Block &afterBlock = *op.getAfterBody();
3753  // Since a new scf.condition op was created, we need to fetch the new
3754  // `after` block arguments which will be used while replacing operations of
3755  // the previous scf.while's `after` block. We also fetch the new result
3756  // values.
3757  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3758  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3759  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3760  Value afterBlockArg, result;
3761  // If the index 'i' argument was loop invariant, we fetch its value from
3762  // the `condOpInitValMap` map.
3763  if (condOpInitValMap.count(i) != 0) {
3764  afterBlockArg = condOpInitValMap[i];
3765  result = afterBlockArg;
3766  } else {
3767  afterBlockArg = newAfterBlock.getArgument(j);
3768  result = newWhile.getResult(j);
3769  j++;
3770  }
3771  newAfterBlockArgs[i] = afterBlockArg;
3772  newWhileResults[i] = result;
3773  }
3774 
3775  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3776  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3777  newWhile.getBefore().begin());
3778 
3779  rewriter.replaceOp(op, newWhileResults);
3780  return success();
3781  }
3782 };
3783 
3784 /// Remove WhileOp results that are also unused in 'after' block.
3785 ///
3786 /// %0:2 = scf.while () : () -> (i32, i64) {
3787 /// %condition = "test.condition"() : () -> i1
3788 /// %v1 = "test.get_some_value"() : () -> i32
3789 /// %v2 = "test.get_some_value"() : () -> i64
3790 /// scf.condition(%condition) %v1, %v2 : i32, i64
3791 /// } do {
3792 /// ^bb0(%arg0: i32, %arg1: i64):
3793 /// "test.use"(%arg0) : (i32) -> ()
3794 /// scf.yield
3795 /// }
3796 /// return %0#0 : i32
3797 ///
3798 /// becomes
3799 /// %0 = scf.while () : () -> (i32) {
3800 /// %condition = "test.condition"() : () -> i1
3801 /// %v1 = "test.get_some_value"() : () -> i32
3802 /// %v2 = "test.get_some_value"() : () -> i64
3803 /// scf.condition(%condition) %v1 : i32
3804 /// } do {
3805 /// ^bb0(%arg0: i32):
3806 /// "test.use"(%arg0) : (i32) -> ()
3807 /// scf.yield
3808 /// }
3809 /// return %0 : i32
3810 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3811  using OpRewritePattern<WhileOp>::OpRewritePattern;
3812 
3813  LogicalResult matchAndRewrite(WhileOp op,
3814  PatternRewriter &rewriter) const override {
3815  auto term = op.getConditionOp();
3816  auto afterArgs = op.getAfterArguments();
3817  auto termArgs = term.getArgs();
3818 
3819  // Collect results mapping, new terminator args and new result types.
3820  SmallVector<unsigned> newResultsIndices;
3821  SmallVector<Type> newResultTypes;
3822  SmallVector<Value> newTermArgs;
3823  SmallVector<Location> newArgLocs;
3824  bool needUpdate = false;
3825  for (const auto &it :
3826  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3827  auto i = static_cast<unsigned>(it.index());
3828  Value result = std::get<0>(it.value());
3829  Value afterArg = std::get<1>(it.value());
3830  Value termArg = std::get<2>(it.value());
3831  if (result.use_empty() && afterArg.use_empty()) {
3832  needUpdate = true;
3833  } else {
3834  newResultsIndices.emplace_back(i);
3835  newTermArgs.emplace_back(termArg);
3836  newResultTypes.emplace_back(result.getType());
3837  newArgLocs.emplace_back(result.getLoc());
3838  }
3839  }
3840 
3841  if (!needUpdate)
3842  return failure();
3843 
3844  {
3845  OpBuilder::InsertionGuard g(rewriter);
3846  rewriter.setInsertionPoint(term);
3847  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3848  newTermArgs);
3849  }
3850 
3851  auto newWhile =
3852  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3853 
3854  Block &newAfterBlock = *rewriter.createBlock(
3855  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3856 
3857  // Build new results list and new after block args (unused entries will be
3858  // null).
3859  SmallVector<Value> newResults(op.getNumResults());
3860  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3861  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3862  newResults[it.value()] = newWhile.getResult(it.index());
3863  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3864  }
3865 
3866  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3867  newWhile.getBefore().begin());
3868 
3869  Block &afterBlock = *op.getAfterBody();
3870  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3871 
3872  rewriter.replaceOp(op, newResults);
3873  return success();
3874  }
3875 };
3876 
3877 /// Replace operations equivalent to the condition in the do block with true,
3878 /// since otherwise the block would not be evaluated.
3879 ///
3880 /// scf.while (..) : (i32, ...) -> ... {
3881 /// %z = ... : i32
3882 /// %condition = cmpi pred %z, %a
3883 /// scf.condition(%condition) %z : i32, ...
3884 /// } do {
3885 /// ^bb0(%arg0: i32, ...):
3886 /// %condition2 = cmpi pred %arg0, %a
3887 /// use(%condition2)
3888 /// ...
3889 ///
3890 /// becomes
3891 /// scf.while (..) : (i32, ...) -> ... {
3892 /// %z = ... : i32
3893 /// %condition = cmpi pred %z, %a
3894 /// scf.condition(%condition) %z : i32, ...
3895 /// } do {
3896 /// ^bb0(%arg0: i32, ...):
3897 /// use(%true)
3898 /// ...
3899 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3900  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3901 
3902  LogicalResult matchAndRewrite(scf::WhileOp op,
3903  PatternRewriter &rewriter) const override {
3904  using namespace scf;
3905  auto cond = op.getConditionOp();
3906  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3907  if (!cmp)
3908  return failure();
3909  bool changed = false;
3910  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3911  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3912  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3913  continue;
3914  for (OpOperand &u :
3915  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3916  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3917  if (!cmp2)
3918  continue;
3919  // For a binary operator 1-opIdx gets the other side.
3920  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3921  continue;
3922  bool samePredicate;
3923  if (cmp2.getPredicate() == cmp.getPredicate())
3924  samePredicate = true;
3925  else if (cmp2.getPredicate() ==
3926  arith::invertPredicate(cmp.getPredicate()))
3927  samePredicate = false;
3928  else
3929  continue;
3930 
3931  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3932  1);
3933  changed = true;
3934  }
3935  }
3936  }
3937  return success(changed);
3938  }
3939 };
3940 
3941 /// Remove unused init/yield args.
3942 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3943  using OpRewritePattern<WhileOp>::OpRewritePattern;
3944 
3945  LogicalResult matchAndRewrite(WhileOp op,
3946  PatternRewriter &rewriter) const override {
3947 
3948  if (!llvm::any_of(op.getBeforeArguments(),
3949  [](Value arg) { return arg.use_empty(); }))
3950  return rewriter.notifyMatchFailure(op, "No args to remove");
3951 
3952  YieldOp yield = op.getYieldOp();
3953 
3954  // Collect the new yield values and init values, and mark which arguments to erase.
3955  SmallVector<Value> newYields;
3956  SmallVector<Value> newInits;
3957  llvm::BitVector argsToErase;
3958 
3959  size_t argsCount = op.getBeforeArguments().size();
3960  newYields.reserve(argsCount);
3961  newInits.reserve(argsCount);
3962  argsToErase.reserve(argsCount);
3963  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3964  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3965  if (beforeArg.use_empty()) {
3966  argsToErase.push_back(true);
3967  } else {
3968  argsToErase.push_back(false);
3969  newYields.emplace_back(yieldValue);
3970  newInits.emplace_back(initValue);
3971  }
3972  }
3973 
3974  Block &beforeBlock = *op.getBeforeBody();
3975  Block &afterBlock = *op.getAfterBody();
3976 
3977  beforeBlock.eraseArguments(argsToErase);
3978 
3979  Location loc = op.getLoc();
3980  auto newWhileOp =
3981  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
3982  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
3983  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3984  Block &newAfterBlock = *newWhileOp.getAfterBody();
3985 
3986  OpBuilder::InsertionGuard g(rewriter);
3987  rewriter.setInsertionPoint(yield);
3988  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3989 
3990  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
3991  newBeforeBlock.getArguments());
3992  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
3993  newAfterBlock.getArguments());
3994 
3995  rewriter.replaceOp(op, newWhileOp.getResults());
3996  return success();
3997  }
3998 };
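A minimal sketch of this rewrite, assuming %arg1 is unused in the `before` region (illustrative IR, not from this file):

    %0 = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> i32 {
      %cond = "test.condition"(%arg0) : (i32) -> i1
      scf.condition(%cond) %arg0 : i32
    } do {
    ^bb0(%x: i32):
      %next = "test.next"(%x) : (i32) -> i32
      %other = "test.other"() : () -> i64
      scf.yield %next, %other : i32, i64
    }

becomes

    %0 = scf.while (%arg0 = %a) : (i32) -> i32 {
      %cond = "test.condition"(%arg0) : (i32) -> i1
      scf.condition(%cond) %arg0 : i32
    } do {
    ^bb0(%x: i32):
      %next = "test.next"(%x) : (i32) -> i32
      scf.yield %next : i32
    }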
3999 
4000 /// Remove duplicated ConditionOp args.
4001 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4002  using OpRewritePattern<WhileOp>::OpRewritePattern;
4003 
4004  LogicalResult matchAndRewrite(WhileOp op,
4005  PatternRewriter &rewriter) const override {
4006  ConditionOp condOp = op.getConditionOp();
4007  ValueRange condOpArgs = condOp.getArgs();
4008 
4009  llvm::SmallPtrSet<Value, 8> argsSet;
4010  for (Value arg : condOpArgs)
4011  argsSet.insert(arg);
4012 
4013  if (argsSet.size() == condOpArgs.size())
4014  return rewriter.notifyMatchFailure(op, "No results to remove");
4015 
4016  llvm::SmallDenseMap<Value, unsigned> argsMap;
4017  SmallVector<Value> newArgs;
4018  argsMap.reserve(condOpArgs.size());
4019  newArgs.reserve(condOpArgs.size());
4020  for (Value arg : condOpArgs) {
4021  if (!argsMap.count(arg)) {
4022  auto pos = static_cast<unsigned>(argsMap.size());
4023  argsMap.insert({arg, pos});
4024  newArgs.emplace_back(arg);
4025  }
4026  }
4027 
4028  ValueRange argsRange(newArgs);
4029 
4030  Location loc = op.getLoc();
4031  auto newWhileOp = rewriter.create<scf::WhileOp>(
4032  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4033  /*afterBody*/ nullptr);
4034  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4035  Block &newAfterBlock = *newWhileOp.getAfterBody();
4036 
4037  SmallVector<Value> afterArgsMapping;
4038  SmallVector<Value> resultsMapping;
4039  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4040  auto it = argsMap.find(arg);
4041  assert(it != argsMap.end());
4042  auto pos = it->second;
4043  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4044  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4045  }
4046 
4047  OpBuilder::InsertionGuard g(rewriter);
4048  rewriter.setInsertionPoint(condOp);
4049  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4050  argsRange);
4051 
4052  Block &beforeBlock = *op.getBeforeBody();
4053  Block &afterBlock = *op.getAfterBody();
4054 
4055  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4056  newBeforeBlock.getArguments());
4057  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4058  rewriter.replaceOp(op, resultsMapping);
4059  return success();
4060  }
4061 };
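A minimal sketch of deduplicating scf.condition args (illustrative IR, not from this file):

    %0:2 = scf.while (%arg0 = %init) : (i32) -> (i32, i32) {
      %cond = "test.condition"(%arg0) : (i32) -> i1
      scf.condition(%cond) %arg0, %arg0 : i32, i32
    } do {
    ^bb0(%a: i32, %b: i32):
      %next = arith.addi %a, %b : i32
      scf.yield %next : i32
    }

becomes

    %0 = scf.while (%arg0 = %init) : (i32) -> i32 {
      %cond = "test.condition"(%arg0) : (i32) -> i1
      scf.condition(%cond) %arg0 : i32
    } do {
    ^bb0(%a: i32):
      %next = arith.addi %a, %a : i32
      scf.yield %next : i32
    }

with uses of %0#0 and %0#1 both replaced by the single remaining result.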
4062 
4063 /// If both ranges contain the same values, return the mapping of indices from
4064 /// args2 to args1. Otherwise return std::nullopt.
4065 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4066  ValueRange args2) {
4067  if (args1.size() != args2.size())
4068  return std::nullopt;
4069 
4070  SmallVector<unsigned> ret(args1.size());
4071  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4072  auto it = llvm::find(args2, arg1);
4073  if (it == args2.end())
4074  return std::nullopt;
4075 
4076  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4077  }
4078 
4079  return ret;
4080 }
4081 
4082 static bool hasDuplicates(ValueRange args) {
4083  llvm::SmallDenseSet<Value> set;
4084  for (Value arg : args) {
4085  if (!set.insert(arg).second)
4086  return true;
4087  }
4088  return false;
4089 }
4090 
4091 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4092 /// `scf.condition` args into the same order as the block args. Update `after`
4093 /// block args and op result values accordingly.
4093 /// args and op result values accordingly.
4094 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4095 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4096  using OpRewritePattern<WhileOp>::OpRewritePattern;
4097 
4098  LogicalResult matchAndRewrite(WhileOp loop,
4099  PatternRewriter &rewriter) const override {
4100  auto oldBefore = loop.getBeforeBody();
4101  ConditionOp oldTerm = loop.getConditionOp();
4102  ValueRange beforeArgs = oldBefore->getArguments();
4103  ValueRange termArgs = oldTerm.getArgs();
4104  if (beforeArgs == termArgs)
4105  return failure();
4106 
4107  if (hasDuplicates(termArgs))
4108  return failure();
4109 
4110  auto mapping = getArgsMapping(beforeArgs, termArgs);
4111  if (!mapping)
4112  return failure();
4113 
4114  {
4115  OpBuilder::InsertionGuard g(rewriter);
4116  rewriter.setInsertionPoint(oldTerm);
4117  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4118  beforeArgs);
4119  }
4120 
4121  auto oldAfter = loop.getAfterBody();
4122 
4123  SmallVector<Type> newResultTypes(beforeArgs.size());
4124  for (auto &&[i, j] : llvm::enumerate(*mapping))
4125  newResultTypes[j] = loop.getResult(i).getType();
4126 
4127  auto newLoop = rewriter.create<WhileOp>(
4128  loop.getLoc(), newResultTypes, loop.getInits(),
4129  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4130  auto newBefore = newLoop.getBeforeBody();
4131  auto newAfter = newLoop.getAfterBody();
4132 
4133  SmallVector<Value> newResults(beforeArgs.size());
4134  SmallVector<Value> newAfterArgs(beforeArgs.size());
4135  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4136  newResults[i] = newLoop.getResult(j);
4137  newAfterArgs[i] = newAfter->getArgument(j);
4138  }
4139 
4140  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4141  newBefore->getArguments());
4142  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4143  newAfterArgs);
4144 
4145  rewriter.replaceOp(loop, newResults);
4146  return success();
4147  }
4148 };
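A minimal sketch, assuming the scf.condition operands are a permutation of the `before` block arguments (illustrative IR, not from this file):

    %0:2 = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> (i64, i32) {
      %cond = "test.condition"() : () -> i1
      scf.condition(%cond) %arg1, %arg0 : i64, i32
    } do {
    ^bb0(%x: i64, %y: i32):
      scf.yield %y, %x : i32, i64
    }

is rearranged so that scf.condition forwards the block arguments in their own order, with the `after` block arguments and op results permuted to match (uses of the old %0#0 map to the new %0#1 and vice versa):

    %0:2 = scf.while (%arg0 = %a, %arg1 = %b) : (i32, i64) -> (i32, i64) {
      %cond = "test.condition"() : () -> i1
      scf.condition(%cond) %arg0, %arg1 : i32, i64
    } do {
    ^bb0(%y: i32, %x: i64):
      scf.yield %y, %x : i32, i64
    }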
4149 } // namespace
4150 
4151 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4152  MLIRContext *context) {
4153  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4154  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4155  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4156  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4157 }
4158 
4159 //===----------------------------------------------------------------------===//
4160 // IndexSwitchOp
4161 //===----------------------------------------------------------------------===//
4162 
4163 /// Parse the case regions and values.
4164 static ParseResult
4165 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4166  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4167  SmallVector<int64_t> caseValues;
4168  while (succeeded(p.parseOptionalKeyword("case"))) {
4169  int64_t value;
4170  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4171  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4172  return failure();
4173  caseValues.push_back(value);
4174  }
4175  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4176  return success();
4177 }
4178 
4179 /// Print the case regions and values.
4180 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4181  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4182  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4183  p.printNewline();
4184  p << "case " << value << ' ';
4185  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4186  }
4187 }
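For reference, the textual form produced and consumed by these helpers looks roughly like the following (illustrative IR; %idx is assumed to be an index value):

    %0 = scf.index_switch %idx -> i32
    case 2 {
      %two = arith.constant 2 : i32
      scf.yield %two : i32
    }
    case 5 {
      %five = arith.constant 5 : i32
      scf.yield %five : i32
    }
    default {
      %zero = arith.constant 0 : i32
      scf.yield %zero : i32
    }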
4188 
4189 LogicalResult scf::IndexSwitchOp::verify() {
4190  if (getCases().size() != getCaseRegions().size()) {
4191  return emitOpError("has ")
4192  << getCaseRegions().size() << " case regions but "
4193  << getCases().size() << " case values";
4194  }
4195 
4196  DenseSet<int64_t> valueSet;
4197  for (int64_t value : getCases())
4198  if (!valueSet.insert(value).second)
4199  return emitOpError("has duplicate case value: ") << value;
4200  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4201  auto yield = dyn_cast<YieldOp>(region.front().back());
4202  if (!yield)
4203  return emitOpError("expected region to end with scf.yield, but got ")
4204  << region.front().back().getName();
4205 
4206  if (yield.getNumOperands() != getNumResults()) {
4207  return (emitOpError("expected each region to return ")
4208  << getNumResults() << " values, but " << name << " returns "
4209  << yield.getNumOperands())
4210  .attachNote(yield.getLoc())
4211  << "see yield operation here";
4212  }
4213  for (auto [idx, result, operand] :
4214  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4215  yield.getOperandTypes())) {
4216  if (result == operand)
4217  continue;
4218  return (emitOpError("expected result #")
4219  << idx << " of each region to be " << result)
4220  .attachNote(yield.getLoc())
4221  << name << " returns " << operand << " here";
4222  }
4223  return success();
4224  };
4225 
4226  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4227  return failure();
4228  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4229  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4230  return failure();
4231 
4232  return success();
4233 }
4234 
4235 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4236 
4237 Block &scf::IndexSwitchOp::getDefaultBlock() {
4238  return getDefaultRegion().front();
4239 }
4240 
4241 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4242  assert(idx < getNumCases() && "case index out-of-bounds");
4243  return getCaseRegions()[idx].front();
4244 }
4245 
4246 void IndexSwitchOp::getSuccessorRegions(
4247  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4248  // All regions branch back to the parent op.
4249  if (!point.isParent()) {
4250  successors.emplace_back(getResults());
4251  return;
4252  }
4253 
4254  llvm::copy(getRegions(), std::back_inserter(successors));
4255 }
4256 
4257 void IndexSwitchOp::getEntrySuccessorRegions(
4258  ArrayRef<Attribute> operands,
4259  SmallVectorImpl<RegionSuccessor> &successors) {
4260  FoldAdaptor adaptor(operands, *this);
4261 
4262  // If a constant was not provided, all regions are possible successors.
4263  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4264  if (!arg) {
4265  llvm::copy(getRegions(), std::back_inserter(successors));
4266  return;
4267  }
4268 
4269  // Otherwise, try to find a case with a matching value. If none is found,
4270  // the default region is the only successor.
4271  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4272  if (caseValue == arg.getInt()) {
4273  successors.emplace_back(&caseRegion);
4274  return;
4275  }
4276  }
4277  successors.emplace_back(&getDefaultRegion());
4278 }
4279 
4280 void IndexSwitchOp::getRegionInvocationBounds(
4281  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4282  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4283  if (!operandValue) {
4284  // All regions are invoked at most once.
4285  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4286  return;
4287  }
4288 
4289  unsigned liveIndex = getNumRegions() - 1;
4290  const auto *it = llvm::find(getCases(), operandValue.getInt());
4291  if (it != getCases().end())
4292  liveIndex = std::distance(getCases().begin(), it);
4293  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4294  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4295 }
4296 
4297 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4298  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4299 
4300  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4301  PatternRewriter &rewriter) const override {
4302  // If `op.getArg()` is a constant, select the region that matches the
4303  // constant value. Use the default region if no match is found.
4304  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4305  if (!maybeCst.has_value())
4306  return failure();
4307  int64_t cst = *maybeCst;
4308  int64_t caseIdx, e = op.getNumCases();
4309  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4310  if (cst == op.getCases()[caseIdx])
4311  break;
4312  }
4313 
4314  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4315  : op.getDefaultRegion();
4316  Block &source = r.front();
4317  Operation *terminator = source.getTerminator();
4318  SmallVector<Value> results = terminator->getOperands();
4319 
4320  rewriter.inlineBlockBefore(&source, op);
4321  rewriter.eraseOp(terminator);
4322  // Replace the operation with a potentially empty list of results.
4323  // The fold mechanism doesn't support the case where the result list is empty.
4324  rewriter.replaceOp(op, results);
4325 
4326  return success();
4327  }
4328 };
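A minimal sketch of this fold, assuming the switch argument is the constant 2 (illustrative IR, not from this file):

    %c2 = arith.constant 2 : index
    %0 = scf.index_switch %c2 -> i32
    case 2 {
      %v = arith.constant 10 : i32
      scf.yield %v : i32
    }
    default {
      %d = arith.constant 0 : i32
      scf.yield %d : i32
    }

The body of `case 2` is inlined before the switch and uses of %0 are replaced with %v.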
4329 
4330 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4331  MLIRContext *context) {
4332  results.add<FoldConstantCase>(context);
4333 }
4334 
4335 //===----------------------------------------------------------------------===//
4336 // TableGen'd op method definitions
4337 //===----------------------------------------------------------------------===//
4338 
4339 #define GET_OP_CLASSES
4340 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
Definition: Builders.h:444
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:606
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:461
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation regions empty.
Definition: Builders.h:592
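A minimal sketch (not part of SCF.cpp) of typical OpBuilder usage: create a block, position the builder, emit an op, and restore the insertion point with an InsertionGuard. The helper name and the use of arith.constant are illustrative assumptions, and `region` is assumed to be empty and to accept an argument-free block.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper: create an entry block and emit a constant 0 : index.
mlir::Value emitZeroIndexInto(mlir::OpBuilder &builder, mlir::Location loc,
                              mlir::Region &region) {
  mlir::OpBuilder::InsertionGuard guard(builder); // restores the caller's point
  mlir::Block *block = builder.createBlock(&region);
  builder.setInsertionPointToStart(block);
  return builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
}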
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
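A minimal sketch (not part of SCF.cpp) combining a few of the Operation accessors above; the helper name describeOp is made up for illustration.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: report basic facts about an operation.
void describeOp(mlir::Operation *op) {
  llvm::outs() << op->getName() << " at " << op->getLoc() << " produces "
               << op->getNumResults() << " result(s)\n";
  if (op->getParentOfType<mlir::scf::ForOp>())
    llvm::outs() << "nested inside an scf.for\n";
}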
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
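A minimal sketch (not part of SCF.cpp) of driving IR changes through a RewriterBase so attached listeners stay informed; the helper name replaceAndErase and the single-result assumption are illustrative.

#include "mlir/IR/PatternMatch.h"

// Hypothetical helper: redirect uses of a single-result op and remove it.
void replaceAndErase(mlir::RewriterBase &rewriter, mlir::Operation *op,
                     mlir::Value replacement) {
  rewriter.replaceAllUsesWith(op->getResult(0), replacement);
  rewriter.eraseOp(op); // op must have no remaining uses here
}

For this single-result case, rewriter.replaceOp(op, replacement) performs both steps in one call.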
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:534
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3024
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
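A minimal sketch (not part of SCF.cpp) of calling scf::buildLoopNest; the helper name is made up, and it assumes index-typed bounds and steps of equal rank plus exactly one index-typed iter_arg so the body can add an induction variable to it.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

// Hypothetical helper: build a nest that accumulates the outermost IV.
mlir::scf::LoopNest buildAccumulatingNest(mlir::OpBuilder &b, mlir::Location loc,
                                          mlir::ValueRange lbs,
                                          mlir::ValueRange ubs,
                                          mlir::ValueRange steps,
                                          mlir::ValueRange initArgs) {
  return mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, initArgs,
      [](mlir::OpBuilder &nested, mlir::Location nestedLoc,
         mlir::ValueRange ivs,
         mlir::ValueRange iterArgs) -> mlir::scf::ValueVector {
        // Innermost body: add the outermost induction variable to the
        // carried value; the returned values are yielded outwards.
        mlir::Value sum = nested.create<mlir::arith::AddIOp>(
            nestedLoc, iterArgs.front(), ivs.front());
        return {sum};
      });
}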
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1951
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
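A minimal sketch (not part of SCF.cpp) using getForInductionVarOwner to walk from an induction variable back to its loop; the helper name is made up for illustration.

#include "mlir/Dialect/SCF/IR/SCF.h"

// Hypothetical helper: return the step of the loop owning `iv`, if any.
mlir::Value getOwningLoopStep(mlir::Value iv) {
  if (mlir::scf::ForOp forOp = mlir::scf::getForInductionVarOwner(iv))
    return forOp.getStep();
  return mlir::Value(); // `iv` is not an scf.for induction variable
}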
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1451
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:776
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:265
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:522
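A minimal sketch (not part of SCF.cpp) of the matcher entry points above; the helper name isConstantOne is made up for illustration.

#include "mlir/IR/Matchers.h"
#include "llvm/ADT/APInt.h"

// Hypothetical helper: check whether `value` is the constant integer 1.
bool isConstantOne(mlir::Value value) {
  llvm::APInt cst;
  // m_ConstantInt binds the matched integer to `cst` on success.
  return mlir::matchPattern(value, mlir::m_ConstantInt(&cst)) && cst == 1;
}

For this particular test, matchPattern(value, m_One()) gives the same answer without binding the constant.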
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:473
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
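A minimal sketch (not part of SCF.cpp) that chains getMixedValues with getValueOrCreateConstantIndexOp; the helper name materializeSizes is made up, and staticSizes/dynamicSizes are assumed to follow the usual dynamic-marker convention.

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

// Hypothetical helper: turn mixed static/dynamic sizes into index Values.
mlir::SmallVector<mlir::Value>
materializeSizes(mlir::OpBuilder &b, mlir::Location loc,
                 llvm::ArrayRef<int64_t> staticSizes,
                 mlir::ValueRange dynamicSizes) {
  mlir::SmallVector<mlir::OpFoldResult> mixed =
      mlir::getMixedValues(staticSizes, dynamicSizes, b);
  mlir::SmallVector<mlir::Value> values;
  for (mlir::OpFoldResult ofr : mixed)
    // Static entries become arith.constant index ops; dynamic entries are
    // passed through as their existing Values.
    values.push_back(mlir::getValueOrCreateConstantIndexOp(b, loc, ofr));
  return values;
}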
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4300
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
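A minimal sketch (not part of SCF.cpp) of declaring and registering an OpRewritePattern; the pattern, its name, and the "scf.visited" attribute are all made up for illustration.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical pattern: tag every scf.if once with a made-up attribute.
struct TagIfOps : public mlir::OpRewritePattern<mlir::scf::IfOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::scf::IfOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (op->hasAttr("scf.visited"))
      return rewriter.notifyMatchFailure(op, "already tagged");
    rewriter.modifyOpInPlace(
        op, [&] { op->setAttr("scf.visited", rewriter.getUnitAttr()); });
    return mlir::success();
  }
};

// Registration into a RewritePatternSet, as with the add() method above.
void registerTagIfOps(mlir::RewritePatternSet &patterns) {
  patterns.add<TagIfOps>(patterns.getContext());
}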
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.