SCF.cpp
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/IR/SCF.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Arith/Utils/Utils.h"
13 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/FunctionInterfaces.h"
23 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
40  using DialectInlinerInterface::DialectInlinerInterface;
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
85 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-block region");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 /// block+
126 /// `}`
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// return %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
148 void ExecuteRegionOp::print(OpAsmPrinter &p) {
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
182  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
231  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
266  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
270  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
285 MutableOperandRange
286 ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after "
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
295  ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
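// Illustrative example (editorial sketch, not part of the upstream file); the
// "test.*" ops are placeholders. In a loop such as
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.keep_going"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.step"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }
//
// a true %cond forwards %arg to the block argument of the "after" region,
// while a false %cond forwards it to the scf.while results.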
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange iterArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
319  result.addOperands({lb, ub, step});
320  result.addOperands(iterArgs);
321  for (Value v : iterArgs)
322  result.addTypes(v.getType());
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : iterArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (iterArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
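// Illustrative usage of this builder (editorial sketch, not part of the
// upstream file); `b`, `loc`, `lb`, `ub`, `step` and `init` are assumed to
// exist in the caller:
//
//   auto loop = b.create<scf::ForOp>(
//       loc, lb, ub, step, /*iterArgs=*/ValueRange{init},
//       [](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
//         // With iter args present, the body builder must create the
//         // terminator itself.
//         nested.create<scf::YieldOp>(nestedLoc, args);
//       });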
342 
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the body defines a single block argument for the induction
354  // variable.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
385  return SmallVector<OpFoldResult>{OpFoldResult(getLowerBound())};
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
393  return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if it can be
399 /// determined that the loop has a single iteration.
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401  std::optional<int64_t> tripCount =
402  constantTripCount(getLowerBound(), getUpperBound(), getStep());
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
431 static void printInitializationList(OpAsmPrinter &p,
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
446 
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
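// For reference (editorial sketch, not part of the upstream file), the custom
// form printed above looks like:
//
//   %sum = scf.for %iv = %lb to %ub step %c1
//       iter_args(%acc = %init) -> (f32) {
//     %next = arith.addf %acc, %elem : f32
//     scf.yield %next : f32
//   }
//
// where %elem stands for some value computed in the body. A trailing
// `: <type>` is printed just before the region only when the induction
// variable is not of index type.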
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
479  SmallVector<OpAsmParser::Argument, 4> regionArgs;
480  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Resolve input operands.
503  regionArgs.front().type = type;
504  if (parser.resolveOperand(lb, type, result.operands) ||
505  parser.resolveOperand(ub, type, result.operands) ||
506  parser.resolveOperand(step, type, result.operands))
507  return failure();
508  if (hasIterArgs) {
509  for (auto argOperandType :
510  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
511  Type type = std::get<2>(argOperandType);
512  std::get<0>(argOperandType).type = type;
513  if (parser.resolveOperand(std::get<1>(argOperandType), type,
514  result.operands))
515  return failure();
516  }
517  }
518 
519  // Parse the body region.
520  Region *body = result.addRegion();
521  if (parser.parseRegion(*body, regionArgs))
522  return failure();
523 
524  ForOp::ensureTerminator(*body, builder, result.location);
525 
526  // Parse the optional attribute list.
527  if (parser.parseOptionalAttrDict(result.attributes))
528  return failure();
529 
530  return success();
531 }
532 
533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534 
535 Block::BlockArgListType ForOp::getRegionIterArgs() {
536  return getBody()->getArguments().drop_front(getNumInductionVars());
537 }
538 
539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
540  return getInitArgsMutable();
541 }
542 
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
545  ValueRange newInitOperands,
546  bool replaceInitOperandUsesInLoop,
547  const NewYieldValuesFn &newYieldValuesFn) {
548  // Create a new loop before the existing one, with the extra operands.
549  OpBuilder::InsertionGuard g(rewriter);
550  rewriter.setInsertionPoint(getOperation());
551  auto inits = llvm::to_vector(getInitArgs());
552  inits.append(newInitOperands.begin(), newInitOperands.end());
553  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
554  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
555  [](OpBuilder &, Location, Value, ValueRange) {});
556  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
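// Illustrative usage (editorial sketch, not part of the upstream file);
// `forOp`, `rewriter` and `extra` are assumed to exist in the caller, and one
// extra loop-carried value is added:
//
//   FailureOr<LoopLikeOpInterface> maybeNewLoop =
//       forOp.replaceWithAdditionalYields(
//           rewriter, /*newInitOperands=*/ValueRange{extra},
//           /*replaceInitOperandUsesInLoop=*/true,
//           [](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
//             // Yield the new region iter_arg unchanged.
//             return SmallVector<Value>{newBBArgs[0]};
//           });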
596 
597 ForOp mlir::scf::getForInductionVarOwner(Value val) {
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
611  SmallVectorImpl<RegionSuccessor> &regions) {
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for the loop not to enter
614  // the body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
686 
687 LoopNest mlir::scf::buildLoopNest(
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
690  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
711  SmallVector<scf::ForOp, 4> loops;
712  SmallVector<Value, 4> ivs;
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
760 LoopNest mlir::scf::buildLoopNest(
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
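// Illustrative usage of the helper above (editorial sketch, not part of the
// upstream file); `b`, `loc`, `lbs`, `ubs` and `steps` are assumed to be
// matching ranges of index values:
//
//   scf::LoopNest nest = scf::buildLoopNest(
//       b, loc, lbs, ubs, steps,
//       [](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
//         // Emit the innermost body here using the induction variables.
//       });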
774 
775 namespace {
776 // Fold away ForOp iter arguments when:
777 // 1) The op yields the iter arguments.
778 // 2) The iter arguments have no use and the corresponding outer region
779 // iterators (inputs) are yielded.
780 // 3) The iter arguments have no use and the corresponding (operation) results
781 // have no use.
782 //
783 // These arguments must be defined outside of
784 // the ForOp region and can just be forwarded after simplifying the op inits,
785 // yields and returns.
786 //
787 // The implementation uses `inlineBlockBefore` to steal the content of the
788 // original ForOp and avoid cloning.
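// For illustration (editorial sketch, not part of the upstream file), a loop
// that yields its first region iter argument unchanged:
//
//   %r:2 = scf.for %i = %lb to %ub step %s
//       iter_args(%a = %x, %b = %y) -> (f32, f32) {
//     %v = "test.use"(%b) : (f32) -> f32
//     scf.yield %a, %v : f32, f32
//   }
//
// canonicalizes so that %r#0 is replaced by %x and the loop only carries %b.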
789 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
790  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
791 
792  LogicalResult matchAndRewrite(scf::ForOp forOp,
793  PatternRewriter &rewriter) const final {
794  bool canonicalize = false;
795 
796  // An internal flat vector of block transfer
797  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
798  // transformed block argument mappings. This plays the role of a
799  // IRMapping for the particular use case of calling into
800  // `inlineBlockBefore`.
801  int64_t numResults = forOp.getNumResults();
802  SmallVector<bool, 4> keepMask;
803  keepMask.reserve(numResults);
804  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
805  newResultValues;
806  newBlockTransferArgs.reserve(1 + numResults);
807  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
808  newIterArgs.reserve(forOp.getInitArgs().size());
809  newYieldValues.reserve(numResults);
810  newResultValues.reserve(numResults);
811  for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
812  forOp.getRegionIterArgs(), // iter inside region
813  forOp.getResults(), // op results
814  forOp.getYieldedValues() // iter yield
815  )) {
816  // Forwarded is `true` when:
817  // 1) The region `iter` argument is yielded.
818  // 2) The region `iter` argument has no use, and the corresponding iter
819  // operand (input) is yielded.
820  // 3) The region `iter` argument has no use, and the corresponding op
821  // result has no use.
822  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
823  (std::get<1>(it).use_empty() &&
824  (std::get<0>(it) == std::get<3>(it) ||
825  std::get<2>(it).use_empty())));
826  keepMask.push_back(!forwarded);
827  canonicalize |= forwarded;
828  if (forwarded) {
829  newBlockTransferArgs.push_back(std::get<0>(it));
830  newResultValues.push_back(std::get<0>(it));
831  continue;
832  }
833  newIterArgs.push_back(std::get<0>(it));
834  newYieldValues.push_back(std::get<3>(it));
835  newBlockTransferArgs.push_back(Value()); // placeholder with null value
836  newResultValues.push_back(Value()); // placeholder with null value
837  }
838 
839  if (!canonicalize)
840  return failure();
841 
842  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
843  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
844  forOp.getStep(), newIterArgs);
845  newForOp->setAttrs(forOp->getAttrs());
846  Block &newBlock = newForOp.getRegion().front();
847 
848  // Replace the null placeholders with newly constructed values.
849  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
850  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
851  idx != e; ++idx) {
852  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
853  Value &newResultVal = newResultValues[idx];
854  assert((blockTransferArg && newResultVal) ||
855  (!blockTransferArg && !newResultVal));
856  if (!blockTransferArg) {
857  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
858  newResultVal = newForOp.getResult(collapsedIdx++);
859  }
860  }
861 
862  Block &oldBlock = forOp.getRegion().front();
863  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
864  "unexpected argument size mismatch");
865 
866  // No results case: the scf::ForOp builder already created a zero
867  // result terminator. Merge before this terminator and just get rid of the
868  // original terminator that has been merged in.
869  if (newIterArgs.empty()) {
870  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
871  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
872  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
873  rewriter.replaceOp(forOp, newResultValues);
874  return success();
875  }
876 
877  // No terminator case: merge and rewrite the merged terminator.
878  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
879  OpBuilder::InsertionGuard g(rewriter);
880  rewriter.setInsertionPoint(mergedTerminator);
881  SmallVector<Value, 4> filteredOperands;
882  filteredOperands.reserve(newResultValues.size());
883  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
884  if (keepMask[idx])
885  filteredOperands.push_back(mergedTerminator.getOperand(idx));
886  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
887  filteredOperands);
888  };
889 
890  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
891  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
892  cloneFilteredTerminator(mergedYieldOp);
893  rewriter.eraseOp(mergedYieldOp);
894  rewriter.replaceOp(forOp, newResultValues);
895  return success();
896  }
897 };
898 
899 /// Utility function that tries to compute a constant difference between u
900 /// and l. Returns std::nullopt when the difference cannot be determined
901 /// statically.
902 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
903  IntegerAttr clb, cub;
904  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
905  llvm::APInt lbValue = clb.getValue();
906  llvm::APInt ubValue = cub.getValue();
907  return (ubValue - lbValue).getSExtValue();
908  }
909 
910  // Else a simple pattern match for x + c or c + x
911  llvm::APInt diff;
912  if (matchPattern(
913  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
914  matchPattern(
915  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
916  return diff.getSExtValue();
917  return std::nullopt;
918 }
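// For example (editorial sketch, not part of the upstream file), given
//
//   %l = arith.constant 2 : index
//   %u = arith.addi %l, %c8 : index   // %c8 = arith.constant 8 : index
//
// computeConstDiff(%l, %u) matches the x + c pattern and returns 8; if
// neither both values are constants nor the addition pattern applies, it
// returns std::nullopt.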
919 
920 /// Rewriting pattern that erases loops that are known not to iterate, replaces
921 /// single-iteration loops with their bodies, and removes empty loops that
922 /// iterate at least once and only return values defined outside of the loop.
923 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
924  using OpRewritePattern<ForOp>::OpRewritePattern;
925 
926  LogicalResult matchAndRewrite(ForOp op,
927  PatternRewriter &rewriter) const override {
928  // If the upper bound is the same as the lower bound, the loop does not
929  // iterate, just remove it.
930  if (op.getLowerBound() == op.getUpperBound()) {
931  rewriter.replaceOp(op, op.getInitArgs());
932  return success();
933  }
934 
935  std::optional<int64_t> diff =
936  computeConstDiff(op.getLowerBound(), op.getUpperBound());
937  if (!diff)
938  return failure();
939 
940  // If the loop is known to have 0 iterations, remove it.
941  if (*diff <= 0) {
942  rewriter.replaceOp(op, op.getInitArgs());
943  return success();
944  }
945 
946  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
947  if (!maybeStepValue)
948  return failure();
949 
950  // If the loop is known to have 1 iteration, inline its body and remove the
951  // loop.
952  llvm::APInt stepValue = *maybeStepValue;
953  if (stepValue.sge(*diff)) {
954  SmallVector<Value, 4> blockArgs;
955  blockArgs.reserve(op.getInitArgs().size() + 1);
956  blockArgs.push_back(op.getLowerBound());
957  llvm::append_range(blockArgs, op.getInitArgs());
958  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
959  return success();
960  }
961 
962  // Now we are left with loops that have more than one iteration.
963  Block &block = op.getRegion().front();
964  if (!llvm::hasSingleElement(block))
965  return failure();
966  // If the loop is empty, iterates at least once, and only returns values
967  // defined outside of the loop, remove it and replace it with yield values.
968  if (llvm::any_of(op.getYieldedValues(),
969  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
970  return failure();
971  rewriter.replaceOp(op, op.getYieldedValues());
972  return success();
973  }
974 };
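// For illustration (editorial sketch, not part of the upstream file), with a
// constant trip distance of 4 and a constant step of 8 the loop below runs
// exactly once:
//
//   %r = scf.for %i = %c0 to %c4 step %c8 iter_args(%a = %init) -> (f32) {
//     %v = "test.body"(%i, %a) : (index, f32) -> f32
//     scf.yield %v : f32
//   }
//
// and is folded by inlining the body once with %i = %c0 and %a = %init.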
975 
976 /// Perform a replacement of one iter OpOperand of an scf.for to the
977 /// `replacement` value which is expected to be the source of a tensor.cast.
978 /// tensor.cast ops are inserted inside the block to account for the type cast.
979 static SmallVector<Value>
980 replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
981  Value replacement) {
982  Type oldType = operand.get().getType(), newType = replacement.getType();
983  assert(llvm::isa<RankedTensorType>(oldType) &&
984  llvm::isa<RankedTensorType>(newType) &&
985  "expected ranked tensor types");
986 
987  // 1. Create new iter operands, exactly 1 is replaced.
988  ForOp forOp = cast<ForOp>(operand.getOwner());
989  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
990  "expected an iter OpOperand");
991  assert(operand.get().getType() != replacement.getType() &&
992  "Expected a different type");
993  SmallVector<Value> newIterOperands;
994  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
995  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
996  newIterOperands.push_back(replacement);
997  continue;
998  }
999  newIterOperands.push_back(opOperand.get());
1000  }
1001 
1002  // 2. Create the new forOp shell.
1003  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
1004  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1005  forOp.getStep(), newIterOperands);
1006  newForOp->setAttrs(forOp->getAttrs());
1007  Block &newBlock = newForOp.getRegion().front();
1008  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
1009  newBlock.getArguments().end());
1010 
1011  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
1012  // corresponding to the `replacement` value.
1013  OpBuilder::InsertionGuard g(rewriter);
1014  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
1015  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
1016  &newForOp->getOpOperand(operand.getOperandNumber()));
1017  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
1018  newRegionIterArg);
1019  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
1020 
1021  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
1022  Block &oldBlock = forOp.getRegion().front();
1023  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
1024 
1025  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
1026  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
1027  rewriter.setInsertionPoint(clonedYieldOp);
1028  unsigned yieldIdx =
1029  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
1030  Value castOut = rewriter.create<tensor::CastOp>(
1031  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
1032  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
1033  newYieldOperands[yieldIdx] = castOut;
1034  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
1035  rewriter.eraseOp(clonedYieldOp);
1036 
1037  // 6. Inject an outgoing cast op after the forOp.
1038  rewriter.setInsertionPointAfter(newForOp);
1039  SmallVector<Value> newResults = newForOp.getResults();
1040  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
1041  newForOp.getLoc(), oldType, newResults[yieldIdx]);
1042 
1043  return newResults;
1044 }
1045 
1046 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1047 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1048 ///
1049 /// ```
1050 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1051 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1052 /// -> (tensor<?x?xf32>) {
1053 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1054 /// scf.yield %2 : tensor<?x?xf32>
1055 /// }
1056 /// use_of(%1)
1057 /// ```
1058 ///
1059 /// folds into:
1060 ///
1061 /// ```
1062 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1063 /// -> (tensor<32x1024xf32>) {
1064 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1065 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1066 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1067 /// scf.yield %4 : tensor<32x1024xf32>
1068 /// }
1069 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1070 /// use_of(%1)
1071 /// ```
1072 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1073  using OpRewritePattern<ForOp>::OpRewritePattern;
1074 
1075  LogicalResult matchAndRewrite(ForOp op,
1076  PatternRewriter &rewriter) const override {
1077  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1078  OpOperand &iterOpOperand = std::get<0>(it);
1079  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
1080  if (!incomingCast ||
1081  incomingCast.getSource().getType() == incomingCast.getType())
1082  continue;
1083  // Skip if the dest type of the cast does not preserve static information
1084  // in the source type.
1085  if (!tensor::preservesStaticInformation(
1086  incomingCast.getDest().getType(),
1087  incomingCast.getSource().getType()))
1088  continue;
1089  if (!std::get<1>(it).hasOneUse())
1090  continue;
1091 
1092  // Create a new ForOp with that iter operand replaced.
1093  rewriter.replaceOp(
1094  op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
1095  incomingCast.getSource()));
1096  return success();
1097  }
1098  return failure();
1099  }
1100 };
1101 
1102 } // namespace
1103 
1104 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1105  MLIRContext *context) {
1106  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1107  context);
1108 }
1109 
1110 std::optional<APInt> ForOp::getConstantStep() {
1111  IntegerAttr step;
1112  if (matchPattern(getStep(), m_Constant(&step)))
1113  return step.getValue();
1114  return {};
1115 }
1116 
1117 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1118  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1119 }
1120 
1121 Speculation::Speculatability ForOp::getSpeculatability() {
1122  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1123  // and End.
1124  if (auto constantStep = getConstantStep())
1125  if (*constantStep == 1)
1126  return Speculation::RecursivelySpeculatable;
1127 
1128  // For Step != 1, the loop may not terminate. We can add more smarts here if
1129  // needed.
1130  return Speculation::NotSpeculatable;
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // ForallOp
1135 //===----------------------------------------------------------------------===//
1136 
1137 LogicalResult ForallOp::verify() {
1138  unsigned numLoops = getRank();
1139  // Check number of outputs.
1140  if (getNumResults() != getOutputs().size())
1141  return emitOpError("produces ")
1142  << getNumResults() << " results, but has only "
1143  << getOutputs().size() << " outputs";
1144 
1145  // Check that the body defines block arguments for thread indices and outputs.
1146  auto *body = getBody();
1147  if (body->getNumArguments() != numLoops + getOutputs().size())
1148  return emitOpError("region expects ") << numLoops << " arguments";
1149  for (int64_t i = 0; i < numLoops; ++i)
1150  if (!body->getArgument(i).getType().isIndex())
1151  return emitOpError("expects ")
1152  << i << "-th block argument to be an index";
1153  for (unsigned i = 0; i < getOutputs().size(); ++i)
1154  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1155  return emitOpError("type mismatch between ")
1156  << i << "-th output and corresponding block argument";
1157  if (getMapping().has_value() && !getMapping()->empty()) {
1158  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1159  return emitOpError() << "mapping attribute size must match op rank";
1160  for (auto map : getMapping()->getValue()) {
1161  if (!isa<DeviceMappingAttrInterface>(map))
1162  return emitOpError()
1163  << getMappingAttrName() << " is not device mapping attribute";
1164  }
1165  }
1166 
1167  // Verify mixed static/dynamic control variables.
1168  Operation *op = getOperation();
1169  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1170  getStaticLowerBound(),
1171  getDynamicLowerBound())))
1172  return failure();
1173  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1174  getStaticUpperBound(),
1175  getDynamicUpperBound())))
1176  return failure();
1177  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1178  getStaticStep(), getDynamicStep())))
1179  return failure();
1180 
1181  return success();
1182 }
1183 
1184 void ForallOp::print(OpAsmPrinter &p) {
1185  Operation *op = getOperation();
1186  p << " (" << getInductionVars();
1187  if (isNormalized()) {
1188  p << ") in ";
1189  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1190  /*valueTypes=*/{}, /*scalables=*/{},
1191  OpAsmParser::Delimiter::Paren);
1192  } else {
1193  p << ") = ";
1194  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1195  /*valueTypes=*/{}, /*scalables=*/{},
1196  OpAsmParser::Delimiter::Paren);
1197  p << " to ";
1198  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1199  /*valueTypes=*/{}, /*scalables=*/{},
1200  OpAsmParser::Delimiter::Paren);
1201  p << " step ";
1202  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1203  /*valueTypes=*/{}, /*scalables=*/{},
1204  OpAsmParser::Delimiter::Paren);
1205  }
1206  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1207  p << " ";
1208  if (!getRegionOutArgs().empty())
1209  p << "-> (" << getResultTypes() << ") ";
1210  p.printRegion(getRegion(),
1211  /*printEntryBlockArgs=*/false,
1212  /*printBlockTerminators=*/getNumResults() > 0);
1213  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1214  getStaticLowerBoundAttrName(),
1215  getStaticUpperBoundAttrName(),
1216  getStaticStepAttrName()});
1217 }
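// For reference (editorial sketch, not part of the upstream file), the
// normalized custom form printed above looks like:
//
//   %out = scf.forall (%i, %j) in (%dim0, %dim1)
//       shared_outs(%o = %dest) -> (tensor<?x?xf32>) {
//     ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tile into %o[%i, %j] [1, 1] [1, 1]
//           : tensor<1x1xf32> into tensor<?x?xf32>
//     }
//   }
//
// Non-normalized loops print explicit `= (lbs) to (ubs) step (steps)` lists
// instead of the `in (...)` form.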
1218 
1219 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1220  OpBuilder b(parser.getContext());
1221  auto indexType = b.getIndexType();
1222 
1223  // Parse an opening `(` followed by thread index variables followed by `)`
1224  // TODO: when we can refer to such "induction variable"-like handles from the
1225  // declarative assembly format, we can implement the parser as a custom hook.
1226  SmallVector<OpAsmParser::Argument, 4> ivs;
1227  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
1228  return failure();
1229 
1230  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1231  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1232  dynamicSteps;
1233  if (succeeded(parser.parseOptionalKeyword("in"))) {
1234  // Parse upper bounds.
1235  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1236  /*valueTypes=*/nullptr,
1237  OpAsmParser::Delimiter::Paren) ||
1238  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1239  return failure();
1240 
1241  unsigned numLoops = ivs.size();
1242  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1243  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1244  } else {
1245  // Parse lower bounds.
1246  if (parser.parseEqual() ||
1247  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1248  /*valueTypes=*/nullptr,
1249  OpAsmParser::Delimiter::Paren) ||
1250 
1251  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1252  return failure();
1253 
1254  // Parse upper bounds.
1255  if (parser.parseKeyword("to") ||
1256  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1257  /*valueTypes=*/nullptr,
1258  OpAsmParser::Delimiter::Paren) ||
1259  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1260  return failure();
1261 
1262  // Parse step values.
1263  if (parser.parseKeyword("step") ||
1264  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1265  /*valueTypes=*/nullptr,
1266  OpAsmParser::Delimiter::Paren) ||
1267  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1268  return failure();
1269  }
1270 
1271  // Parse out operands and results.
1272  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1273  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1274  SMLoc outOperandsLoc = parser.getCurrentLocation();
1275  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1276  if (outOperands.size() != result.types.size())
1277  return parser.emitError(outOperandsLoc,
1278  "mismatch between out operands and types");
1279  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1280  parser.parseOptionalArrowTypeList(result.types) ||
1281  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1282  result.operands))
1283  return failure();
1284  }
1285 
1286  // Parse region.
1287  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1288  std::unique_ptr<Region> region = std::make_unique<Region>();
1289  for (auto &iv : ivs) {
1290  iv.type = b.getIndexType();
1291  regionArgs.push_back(iv);
1292  }
1293  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1294  auto &out = it.value();
1295  out.type = result.types[it.index()];
1296  regionArgs.push_back(out);
1297  }
1298  if (parser.parseRegion(*region, regionArgs))
1299  return failure();
1300 
1301  // Ensure terminator and move region.
1302  ForallOp::ensureTerminator(*region, b, result.location);
1303  result.addRegion(std::move(region));
1304 
1305  // Parse the optional attribute list.
1306  if (parser.parseOptionalAttrDict(result.attributes))
1307  return failure();
1308 
1309  result.addAttribute("staticLowerBound", staticLbs);
1310  result.addAttribute("staticUpperBound", staticUbs);
1311  result.addAttribute("staticStep", staticSteps);
1312  result.addAttribute("operandSegmentSizes",
1313  b.getDenseI32ArrayAttr(
1314  {static_cast<int32_t>(dynamicLbs.size()),
1315  static_cast<int32_t>(dynamicUbs.size()),
1316  static_cast<int32_t>(dynamicSteps.size()),
1317  static_cast<int32_t>(outOperands.size())}));
1318  return success();
1319 }
1320 
1321 // Builder that takes loop bounds.
1322 void ForallOp::build(
1323  OpBuilder &b, OperationState &result, ArrayRef<OpFoldResult> lbs,
1324  ArrayRef<OpFoldResult> ubs,
1325  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1326  std::optional<ArrayAttr> mapping,
1327  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1328  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1329  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1330  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1331  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1332  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1333 
1334  result.addOperands(dynamicLbs);
1335  result.addOperands(dynamicUbs);
1336  result.addOperands(dynamicSteps);
1337  result.addOperands(outputs);
1338  result.addTypes(TypeRange(outputs));
1339 
1340  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1341  b.getDenseI64ArrayAttr(staticLbs));
1342  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1343  b.getDenseI64ArrayAttr(staticUbs));
1344  result.addAttribute(getStaticStepAttrName(result.name),
1345  b.getDenseI64ArrayAttr(staticSteps));
1346  result.addAttribute(
1347  "operandSegmentSizes",
1348  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1349  static_cast<int32_t>(dynamicUbs.size()),
1350  static_cast<int32_t>(dynamicSteps.size()),
1351  static_cast<int32_t>(outputs.size())}));
1352  if (mapping.has_value()) {
1353  result.addAttribute(ForallOp::getMappingAttrName(result.name),
1354  mapping.value());
1355  }
1356 
1357  Region *bodyRegion = result.addRegion();
1358  OpBuilder::InsertionGuard g(b);
1359  b.createBlock(bodyRegion);
1360  Block &bodyBlock = bodyRegion->front();
1361 
1362  // Add block arguments for indices and outputs.
1363  bodyBlock.addArguments(
1364  SmallVector<Type>(lbs.size(), b.getIndexType()),
1365  SmallVector<Location>(staticLbs.size(), result.location));
1366  bodyBlock.addArguments(
1367  TypeRange(outputs),
1368  SmallVector<Location>(outputs.size(), result.location));
1369 
1370  b.setInsertionPointToStart(&bodyBlock);
1371  if (!bodyBuilderFn) {
1372  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1373  return;
1374  }
1375  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1376 }
1377 
1378 // Builder that takes loop bounds.
1379 void ForallOp::build(
1380  OpBuilder &b, OperationState &result,
1381  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1382  std::optional<ArrayAttr> mapping,
1383  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1384  unsigned numLoops = ubs.size();
1385  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1386  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1387  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1388 }
1389 
1390 // Checks if the lbs are zeros and steps are ones.
1391 bool ForallOp::isNormalized() {
1392  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1393  return llvm::all_of(results, [&](OpFoldResult ofr) {
1394  auto intValue = getConstantIntValue(ofr);
1395  return intValue.has_value() && intValue == val;
1396  });
1397  };
1398  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1399 }
1400 
1401 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1402 // unaware of the fact that our terminator also needs a region to be
1403 // well-formed. We override it here to ensure that we do the right thing.
1404 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1405  Location loc) {
1406  OpTrait::SingleBlockImplicitTerminator<InParallelOp>::Impl<
1407  ForallOp>::ensureTerminator(region, builder, loc);
1408  auto terminator =
1409  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1410  if (terminator.getRegion().empty())
1411  builder.createBlock(&terminator.getRegion());
1412 }
1413 
1414 InParallelOp ForallOp::getTerminator() {
1415  return cast<InParallelOp>(getBody()->getTerminator());
1416 }
1417 
1418 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1419  SmallVector<Operation *> storeOps;
1420  InParallelOp inParallelOp = getTerminator();
1421  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1422  if (auto parallelInsertSliceOp =
1423  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1424  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1425  storeOps.push_back(parallelInsertSliceOp);
1426  }
1427  }
1428  return storeOps;
1429 }
1430 
1431 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1432  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1433 }
1434 
1435 // Get lower bounds as OpFoldResult.
1436 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1437  Builder b(getOperation()->getContext());
1438  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1439 }
1440 
1441 // Get upper bounds as OpFoldResult.
1442 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1443  Builder b(getOperation()->getContext());
1444  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1445 }
1446 
1447 // Get steps as OpFoldResult.
1448 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1449  Builder b(getOperation()->getContext());
1450  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1451 }
1452 
1453 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
1454  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1455  if (!tidxArg)
1456  return ForallOp();
1457  assert(tidxArg.getOwner() && "unlinked block argument");
1458  auto *containingOp = tidxArg.getOwner()->getParentOp();
1459  return dyn_cast<ForallOp>(containingOp);
1460 }
1461 
1462 namespace {
1463 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
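/// For illustration (editorial sketch, not part of the upstream file):
///
///   %r = scf.forall ... shared_outs(%o = %t) -> (tensor<?xf32>) { ... }
///   %d = tensor.dim %r, %c0 : tensor<?xf32>
///
/// is rewritten so that tensor.dim reads from %t directly, since the forall
/// result has the same shape as its tied shared_outs operand.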
1464 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1465  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1466 
1467  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1468  PatternRewriter &rewriter) const final {
1469  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1470  if (!forallOp)
1471  return failure();
1472  Value sharedOut =
1473  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1474  ->get();
1475  rewriter.modifyOpInPlace(
1476  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1477  return success();
1478  }
1479 };
1480 
1481 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1482 public:
1483  using OpRewritePattern<ForallOp>::OpRewritePattern;
1484 
1485  LogicalResult matchAndRewrite(ForallOp op,
1486  PatternRewriter &rewriter) const override {
1487  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1488  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1489  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
1490  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1491  failed(foldDynamicIndexList(mixedUpperBound)) &&
1492  failed(foldDynamicIndexList(mixedStep)))
1493  return failure();
1494 
1495  rewriter.modifyOpInPlace(op, [&]() {
1496  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1497  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1498  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1499  staticLowerBound);
1500  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1501  op.setStaticLowerBound(staticLowerBound);
1502 
1503  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1504  staticUpperBound);
1505  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1506  op.setStaticUpperBound(staticUpperBound);
1507 
1508  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1509  op.getDynamicStepMutable().assign(dynamicStep);
1510  op.setStaticStep(staticStep);
1511 
1512  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1513  rewriter.getDenseI32ArrayAttr(
1514  {static_cast<int32_t>(dynamicLowerBound.size()),
1515  static_cast<int32_t>(dynamicUpperBound.size()),
1516  static_cast<int32_t>(dynamicStep.size()),
1517  static_cast<int32_t>(op.getNumResults())}));
1518  });
1519  return success();
1520  }
1521 };
1522 
1523 /// The following canonicalization pattern folds the iter arguments of
1524 /// scf.forall op if :-
1525 /// 1. The corresponding result has zero uses.
1526 /// 2. The iter argument is NOT being modified within the loop body.
1527 ///
1528 ///
1529 /// Example of first case :-
1530 /// INPUT:
1531 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1532 /// {
1533 /// ...
1534 /// <SOME USE OF %arg0>
1535 /// <SOME USE OF %arg1>
1536 /// <SOME USE OF %arg2>
1537 /// ...
1538 /// scf.forall.in_parallel {
1539 /// <STORE OP WITH DESTINATION %arg1>
1540 /// <STORE OP WITH DESTINATION %arg0>
1541 /// <STORE OP WITH DESTINATION %arg2>
1542 /// }
1543 /// }
1544 /// return %res#1
1545 ///
1546 /// OUTPUT:
1547 /// %res:3 = scf.forall ... shared_outs(%new_arg0 = %b)
1548 /// {
1549 /// ...
1550 /// <SOME USE OF %a>
1551 /// <SOME USE OF %new_arg0>
1552 /// <SOME USE OF %c>
1553 /// ...
1554 /// scf.forall.in_parallel {
1555 /// <STORE OP WITH DESTINATION %new_arg0>
1556 /// }
1557 /// }
1558 /// return %res
1559 ///
1560 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1561 /// scf.forall are replaced by their corresponding operands.
1562 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1563 /// of the scf.forall besides those within the scf.forall.in_parallel terminator,
1564 /// this canonicalization remains valid. For more details, please refer
1565 /// to :
1566 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1567 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1568 /// handles tensor.parallel_insert_slice ops only.
1569 ///
1570 /// Example of second case :-
1571 /// INPUT:
1572 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1573 /// {
1574 /// ...
1575 /// <SOME USE OF %arg0>
1576 /// <SOME USE OF %arg1>
1577 /// ...
1578 /// scf.forall.in_parallel {
1579 /// <STORE OP WITH DESTINATION %arg1>
1580 /// }
1581 /// }
1582 /// return %res#0, %res#1
1583 ///
1584 /// OUTPUT:
1585 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1586 /// {
1587 /// ...
1588 /// <SOME USE OF %a>
1589 /// <SOME USE OF %new_arg0>
1590 /// ...
1591 /// scf.forall.in_parallel {
1592 /// <STORE OP WITH DESTINATION %new_arg0>
1593 /// }
1594 /// }
1595 /// return %a, %res
1596 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1597  using OpRewritePattern<ForallOp>::OpRewritePattern;
1598 
1599  LogicalResult matchAndRewrite(ForallOp forallOp,
1600  PatternRewriter &rewriter) const final {
1601  // Step 1: For a given i-th result of scf.forall, check the following :-
1602  // a. If it has any use.
1603  // b. If the corresponding iter argument is being modified within
1604  // the loop, i.e. has at least one store op with the iter arg as
1605  // its destination operand. For this we use
1606  // ForallOp::getCombiningOps(iter_arg).
1607  //
1608  // Based on the check we maintain the following :-
1609  // a. `resultToDelete` - i-th result of scf.forall that'll be
1610  // deleted.
1611  // b. `resultToReplace` - i-th result of the old scf.forall
1612  // whose uses will be replaced by the new scf.forall.
1613  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1614  // corresponding to the i-th result with at least one use.
1615  SetVector<OpResult> resultToDelete;
1616  SmallVector<Value> resultToReplace;
1617  SmallVector<Value> newOuts;
1618  for (OpResult result : forallOp.getResults()) {
1619  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1620  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1621  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1622  resultToDelete.insert(result);
1623  } else {
1624  resultToReplace.push_back(result);
1625  newOuts.push_back(opOperand->get());
1626  }
1627  }
1628 
1629  // Return early if all results of scf.forall have at least one use and are
1630  // being modified within the loop.
1631  if (resultToDelete.empty())
1632  return failure();
1633 
1634  // Step 2: For the i-th result, do the following :-
1635  // a. Fetch the corresponding BlockArgument.
1636  // b. Look for store ops (currently tensor.parallel_insert_slice)
1637  // with the BlockArgument as its destination operand.
1638  // c. Remove the operations fetched in b.
1639  for (OpResult result : resultToDelete) {
1640  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1641  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1642  SmallVector<Operation *> combiningOps =
1643  forallOp.getCombiningOps(blockArg);
1644  for (Operation *combiningOp : combiningOps)
1645  rewriter.eraseOp(combiningOp);
1646  }
1647 
1648  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1649  // fetched earlier
1650  auto newForallOp = rewriter.create<scf::ForallOp>(
1651  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1652  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1653  forallOp.getMapping(),
1654  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1655 
1656  // Step 4. Merge the block of the old scf.forall into the newly created
1657  // scf.forall using the new set of arguments.
1658  Block *loopBody = forallOp.getBody();
1659  Block *newLoopBody = newForallOp.getBody();
1660  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1661  // Form initial new bbArg list with just the control operands of the new
1662  // scf.forall op.
1663  SmallVector<Value> newBlockArgs =
1664  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1665  [](BlockArgument b) -> Value { return b; });
1666  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
1667  unsigned index = 0;
1668  // Take the new corresponding bbArg if the old bbArg was used as a
1669  // destination in the in_parallel op. For all other bbArgs, use the
1670  // corresponding init_arg from the old scf.forall op.
1671  for (OpResult result : forallOp.getResults()) {
1672  if (resultToDelete.count(result)) {
1673  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1674  } else {
1675  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1676  }
1677  }
1678  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1679 
1680  // Step 5. Replace the uses of the results of the old scf.forall with those
1681  // of the new scf.forall.
1682  for (auto &&[oldResult, newResult] :
1683  llvm::zip(resultToReplace, newForallOp->getResults()))
1684  rewriter.replaceAllUsesWith(oldResult, newResult);
1685 
1686  // Step 6. Replace the uses of those values that either have no use or are
1687  // not being modified within the loop with the corresponding
1688  // OpOperand.
1689  for (OpResult oldResult : resultToDelete)
1690  rewriter.replaceAllUsesWith(oldResult,
1691  forallOp.getTiedOpOperand(oldResult)->get());
1692  return success();
1693  }
1694 };
1695 
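/// (Note added for exposition; not part of the original source.) This pattern
/// drops scf.forall dimensions whose trip count is statically 0 or 1.
/// A hypothetical sketch:
///
///   scf.forall (%i, %j) in (1, %n) {...}
/// becomes
///   scf.forall (%j) in (%n) {...}
///
/// with %i replaced by its lower bound. If every dimension has trip count 1,
/// the body is inlined; a loop with a zero trip count is replaced by its
/// outputs.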
1696 struct ForallOpSingleOrZeroIterationDimsFolder
1697  : public OpRewritePattern<ForallOp> {
1698  using OpRewritePattern<ForallOp>::OpRewritePattern;
1699 
1700  LogicalResult matchAndRewrite(ForallOp op,
1701  PatternRewriter &rewriter) const override {
1702  // Do not fold dimensions if they are mapped to processing units.
1703  if (op.getMapping().has_value())
1704  return failure();
1705  Location loc = op.getLoc();
1706 
1707  // Compute new loop bounds that omit all single-iteration loop dimensions.
1708  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1709  newMixedSteps;
1710  IRMapping mapping;
1711  for (auto [lb, ub, step, iv] :
1712  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1713  op.getMixedStep(), op.getInductionVars())) {
1714  auto numIterations = constantTripCount(lb, ub, step);
1715  if (numIterations.has_value()) {
1716  // Remove the loop if it performs zero iterations.
1717  if (*numIterations == 0) {
1718  rewriter.replaceOp(op, op.getOutputs());
1719  return success();
1720  }
1721  // Replace the loop induction variable by the lower bound if the loop
1722  // performs a single iteration. Otherwise, copy the loop bounds.
1723  if (*numIterations == 1) {
1724  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1725  continue;
1726  }
1727  }
1728  newMixedLowerBounds.push_back(lb);
1729  newMixedUpperBounds.push_back(ub);
1730  newMixedSteps.push_back(step);
1731  }
1732  // Exit if none of the loop dimensions perform zero or a single iteration.
1733  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1734  return rewriter.notifyMatchFailure(
1735  op, "no dimensions have 0 or 1 iterations");
1736  }
1737 
1738  // All of the loop dimensions perform a single iteration. Inline loop body.
1739  if (newMixedLowerBounds.empty()) {
1740  promote(rewriter, op);
1741  return success();
1742  }
1743 
1744  // Replace the loop by a lower-dimensional loop.
1745  ForallOp newOp;
1746  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1747  newMixedUpperBounds, newMixedSteps,
1748  op.getOutputs(), std::nullopt, nullptr);
1749  newOp.getBodyRegion().getBlocks().clear();
1750  // The new loop needs to keep all attributes from the old one, except for
1751  // "operandSegmentSizes" and static loop bound attributes which capture
1752  // the outdated information of the old iteration domain.
1753  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1754  newOp.getStaticLowerBoundAttrName(),
1755  newOp.getStaticUpperBoundAttrName(),
1756  newOp.getStaticStepAttrName()};
1757  for (const auto &namedAttr : op->getAttrs()) {
1758  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1759  continue;
1760  rewriter.modifyOpInPlace(newOp, [&]() {
1761  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1762  });
1763  }
1764  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1765  newOp.getRegion().begin(), mapping);
1766  rewriter.replaceOp(op, newOp.getResults());
1767  return success();
1768  }
1769 };
1770 
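/// (Note added for exposition; not part of the original source.) This pattern
/// folds a tensor.cast feeding a shared_outs operand into the scf.forall when
/// the cast source carries more static information, re-casting the results
/// back afterwards. A hypothetical sketch:
///
///   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %r = scf.forall ... shared_outs(%o = %c) -> (tensor<?xf32>) {...}
/// becomes
///   %r0 = scf.forall ... shared_outs(%o = %t) -> (tensor<4xf32>) {...}
///   %r = tensor.cast %r0 : tensor<4xf32> to tensor<?xf32>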
1771 struct FoldTensorCastOfOutputIntoForallOp
1772  : public OpRewritePattern<scf::ForallOp> {
1773  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
1774 
1775  struct TypeCast {
1776  Type srcType;
1777  Type dstType;
1778  };
1779 
1780  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1781  PatternRewriter &rewriter) const final {
1782  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1783  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1784  for (auto en : llvm::enumerate(newOutputTensors)) {
1785  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1786  if (!castOp)
1787  continue;
1788 
1789  // Only casts that preserve static information, i.e. will make the
1790  // loop result type "more" static than before, will be folded.
1791  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1792  castOp.getSource().getType())) {
1793  continue;
1794  }
1795 
1796  tensorCastProducers[en.index()] =
1797  TypeCast{castOp.getSource().getType(), castOp.getType()};
1798  newOutputTensors[en.index()] = castOp.getSource();
1799  }
1800 
1801  if (tensorCastProducers.empty())
1802  return failure();
1803 
1804  // Create new loop.
1805  Location loc = forallOp.getLoc();
1806  auto newForallOp = rewriter.create<ForallOp>(
1807  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1808  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1809  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
1810  auto castBlockArgs =
1811  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1812  for (auto [index, cast] : tensorCastProducers) {
1813  Value &oldTypeBBArg = castBlockArgs[index];
1814  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1815  nestedLoc, cast.dstType, oldTypeBBArg);
1816  }
1817 
1818  // Move old body into new parallel loop.
1819  SmallVector<Value> ivsBlockArgs =
1820  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1821  ivsBlockArgs.append(castBlockArgs);
1822  rewriter.mergeBlocks(forallOp.getBody(),
1823  bbArgs.front().getParentBlock(), ivsBlockArgs);
1824  });
1825 
1826  // After `mergeBlocks` happened, the destinations in the terminator were
1827  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1828  // destinations have to be updated to point to the output bbArgs directly.
1829  auto terminator = newForallOp.getTerminator();
1830  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1831  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1832  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1833  insertSliceOp.getDestMutable().assign(outputBlockArg);
1834  }
1835 
1836  // Cast results back to the original types.
1837  rewriter.setInsertionPointAfter(newForallOp);
1838  SmallVector<Value> castResults = newForallOp.getResults();
1839  for (auto &item : tensorCastProducers) {
1840  Value &oldTypeResult = castResults[item.first];
1841  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1842  oldTypeResult);
1843  }
1844  rewriter.replaceOp(forallOp, castResults);
1845  return success();
1846  }
1847 };
1848 
1849 } // namespace
1850 
1851 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1852  MLIRContext *context) {
1853  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1854  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1855  ForallOpSingleOrZeroIterationDimsFolder>(context);
1856 }
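// (Usage sketch added for exposition; not part of the original source.)
// These patterns are typically collected and driven greedily, e.g. by the
// canonicalizer pass:
//   RewritePatternSet patterns(ctx);
//   scf::ForallOp::getCanonicalizationPatterns(patterns, ctx);
//   // ... hand `patterns` to a greedy pattern rewrite driver.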
1857 
1858 /// Given the region at `index`, or the parent operation if `index` is None,
1859 /// return the successor regions. These are the regions that may be selected
1860 /// during the flow of control. `operands` is a set of optional attributes that
1861 /// correspond to a constant value for each operand, or null if that operand is
1862 /// not a constant.
1863 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1864  SmallVectorImpl<RegionSuccessor> &regions) {
1865  // Both the operation itself and the region may be branching into the body or
1866  // back into the operation itself. It is possible for the loop not to enter the
1867  // body.
1868  regions.push_back(RegionSuccessor(&getRegion()));
1869  regions.push_back(RegionSuccessor());
1870 }
1871 
1872 //===----------------------------------------------------------------------===//
1873 // InParallelOp
1874 //===----------------------------------------------------------------------===//
1875 
1876 // Build an InParallelOp with an empty body block.
1877 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1878  OpBuilder::InsertionGuard g(b);
1879  Region *bodyRegion = result.addRegion();
1880  b.createBlock(bodyRegion);
1881 }
1882 
1883 LogicalResult InParallelOp::verify() {
1884  scf::ForallOp forallOp =
1885  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1886  if (!forallOp)
1887  return this->emitOpError("expected forall op parent");
1888 
1889  // TODO: InParallelOpInterface.
1890  for (Operation &op : getRegion().front().getOperations()) {
1891  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1892  return this->emitOpError("expected only ")
1893  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1894  }
1895 
1896  // Verify that the inserts are into the output block arguments.
1897  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1898  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1899  if (!llvm::is_contained(regionOutArgs, dest))
1900  return op.emitOpError("may only insert into an output block argument");
1901  }
1902  return success();
1903 }
1904 
1905 void InParallelOp::print(OpAsmPrinter &p) {
1906  p << " ";
1907  p.printRegion(getRegion(),
1908  /*printEntryBlockArgs=*/false,
1909  /*printBlockTerminators=*/false);
1910  p.printOptionalAttrDict(getOperation()->getAttrs());
1911 }
1912 
1913 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1914  auto &builder = parser.getBuilder();
1915 
1916  SmallVector<OpAsmParser::Argument, 8> regionOperands;
1917  std::unique_ptr<Region> region = std::make_unique<Region>();
1918  if (parser.parseRegion(*region, regionOperands))
1919  return failure();
1920 
1921  if (region->empty())
1922  OpBuilder(builder.getContext()).createBlock(region.get());
1923  result.addRegion(std::move(region));
1924 
1925  // Parse the optional attribute list.
1926  if (parser.parseOptionalAttrDict(result.attributes))
1927  return failure();
1928  return success();
1929 }
1930 
1931 OpResult InParallelOp::getParentResult(int64_t idx) {
1932  return getOperation()->getParentOp()->getResult(idx);
1933 }
1934 
1935 SmallVector<BlockArgument> InParallelOp::getDests() {
1936  return llvm::to_vector<4>(
1937  llvm::map_range(getYieldingOps(), [](Operation &op) {
1938  // Add new ops here as needed.
1939  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1940  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1941  }));
1942 }
1943 
1944 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1945  return getRegion().front().getOperations();
1946 }
1947 
1948 //===----------------------------------------------------------------------===//
1949 // IfOp
1950 //===----------------------------------------------------------------------===//
1951 
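// (Note added for exposition; not part of the original source.) The helper
// below reports whether two operations can never execute together because
// they live in opposite branches of a common scf.if. Hypothetical sketch:
//   scf.if %cond {
//     "a"() : () -> ()
//   } else {
//     "b"() : () -> ()
//   }
// Here "a" and "b" are in mutually exclusive branches.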
1952 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1953  assert(a && "expected non-empty operation");
1954  assert(b && "expected non-empty operation");
1955 
1956  IfOp ifOp = a->getParentOfType<IfOp>();
1957  while (ifOp) {
1958  // Check if b is inside ifOp. (We already know that a is.)
1959  if (ifOp->isProperAncestor(b))
1960  // b is contained in ifOp. a and b are in mutually exclusive branches if
1961  // they are in different blocks of ifOp.
1962  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1963  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1964  // Check next enclosing IfOp.
1965  ifOp = ifOp->getParentOfType<IfOp>();
1966  }
1967 
1968  // Could not find a common IfOp among a's and b's ancestors.
1969  return false;
1970 }
1971 
1972 LogicalResult
1973 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1974  IfOp::Adaptor adaptor,
1975  SmallVectorImpl<Type> &inferredReturnTypes) {
1976  if (adaptor.getRegions().empty())
1977  return failure();
1978  Region *r = &adaptor.getThenRegion();
1979  if (r->empty())
1980  return failure();
1981  Block &b = r->front();
1982  if (b.empty())
1983  return failure();
1984  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
1985  if (!yieldOp)
1986  return failure();
1987  TypeRange types = yieldOp.getOperandTypes();
1988  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
1989  types.end());
1990  return success();
1991 }
1992 
1993 void IfOp::build(OpBuilder &builder, OperationState &result,
1994  TypeRange resultTypes, Value cond) {
1995  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
1996  /*addElseBlock=*/false);
1997 }
1998 
1999 void IfOp::build(OpBuilder &builder, OperationState &result,
2000  TypeRange resultTypes, Value cond, bool addThenBlock,
2001  bool addElseBlock) {
2002  assert((!addElseBlock || addThenBlock) &&
2003  "must not create else block w/o then block");
2004  result.addTypes(resultTypes);
2005  result.addOperands(cond);
2006 
2007  // Add regions and blocks.
2008  OpBuilder::InsertionGuard guard(builder);
2009  Region *thenRegion = result.addRegion();
2010  if (addThenBlock)
2011  builder.createBlock(thenRegion);
2012  Region *elseRegion = result.addRegion();
2013  if (addElseBlock)
2014  builder.createBlock(elseRegion);
2015 }
2016 
2017 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2018  bool withElseRegion) {
2019  build(builder, result, TypeRange{}, cond, withElseRegion);
2020 }
2021 
2022 void IfOp::build(OpBuilder &builder, OperationState &result,
2023  TypeRange resultTypes, Value cond, bool withElseRegion) {
2024  result.addTypes(resultTypes);
2025  result.addOperands(cond);
2026 
2027  // Build then region.
2028  OpBuilder::InsertionGuard guard(builder);
2029  Region *thenRegion = result.addRegion();
2030  builder.createBlock(thenRegion);
2031  if (resultTypes.empty())
2032  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2033 
2034  // Build else region.
2035  Region *elseRegion = result.addRegion();
2036  if (withElseRegion) {
2037  builder.createBlock(elseRegion);
2038  if (resultTypes.empty())
2039  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2040  }
2041 }
2042 
2043 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2044  function_ref<void(OpBuilder &, Location)> thenBuilder,
2045  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2046  assert(thenBuilder && "the builder callback for 'then' must be present");
2047  result.addOperands(cond);
2048 
2049  // Build then region.
2050  OpBuilder::InsertionGuard guard(builder);
2051  Region *thenRegion = result.addRegion();
2052  builder.createBlock(thenRegion);
2053  thenBuilder(builder, result.location);
2054 
2055  // Build else region.
2056  Region *elseRegion = result.addRegion();
2057  if (elseBuilder) {
2058  builder.createBlock(elseRegion);
2059  elseBuilder(builder, result.location);
2060  }
2061 
2062  // Infer result types.
2063  SmallVector<Type> inferredReturnTypes;
2064  MLIRContext *ctx = builder.getContext();
2065  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2066  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2067  /*properties=*/nullptr, result.regions,
2068  inferredReturnTypes))) {
2069  result.addTypes(inferredReturnTypes);
2070  }
2071 }
2072 
2073 LogicalResult IfOp::verify() {
2074  if (getNumResults() != 0 && getElseRegion().empty())
2075  return emitOpError("must have an else block if defining values");
2076  return success();
2077 }
2078 
2079 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2080  // Create the regions for 'then'.
2081  result.regions.reserve(2);
2082  Region *thenRegion = result.addRegion();
2083  Region *elseRegion = result.addRegion();
2084 
2085  auto &builder = parser.getBuilder();
2086  OpAsmParser::UnresolvedOperand cond;
2087  Type i1Type = builder.getIntegerType(1);
2088  if (parser.parseOperand(cond) ||
2089  parser.resolveOperand(cond, i1Type, result.operands))
2090  return failure();
2091  // Parse optional results type list.
2092  if (parser.parseOptionalArrowTypeList(result.types))
2093  return failure();
2094  // Parse the 'then' region.
2095  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2096  return failure();
2097  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2098 
2099  // If we find an 'else' keyword then parse the 'else' region.
2100  if (!parser.parseOptionalKeyword("else")) {
2101  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2102  return failure();
2103  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2104  }
2105 
2106  // Parse the optional attribute list.
2107  if (parser.parseOptionalAttrDict(result.attributes))
2108  return failure();
2109  return success();
2110 }
2111 
2112 void IfOp::print(OpAsmPrinter &p) {
2113  bool printBlockTerminators = false;
2114 
2115  p << " " << getCondition();
2116  if (!getResults().empty()) {
2117  p << " -> (" << getResultTypes() << ")";
2118  // Print yield explicitly if the op defines values.
2119  printBlockTerminators = true;
2120  }
2121  p << ' ';
2122  p.printRegion(getThenRegion(),
2123  /*printEntryBlockArgs=*/false,
2124  /*printBlockTerminators=*/printBlockTerminators);
2125 
2126  // Print the 'else' region if it exists and has a block.
2127  auto &elseRegion = getElseRegion();
2128  if (!elseRegion.empty()) {
2129  p << " else ";
2130  p.printRegion(elseRegion,
2131  /*printEntryBlockArgs=*/false,
2132  /*printBlockTerminators=*/printBlockTerminators);
2133  }
2134 
2135  p.printOptionalAttrDict((*this)->getAttrs());
2136 }
2137 
2138 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2139  SmallVectorImpl<RegionSuccessor> &regions) {
2140  // The `then` and the `else` region branch back to the parent operation.
2141  if (!point.isParent()) {
2142  regions.push_back(RegionSuccessor(getResults()));
2143  return;
2144  }
2145 
2146  regions.push_back(RegionSuccessor(&getThenRegion()));
2147 
2148  // Don't consider the else region if it is empty.
2149  Region *elseRegion = &this->getElseRegion();
2150  if (elseRegion->empty())
2151  regions.push_back(RegionSuccessor());
2152  else
2153  regions.push_back(RegionSuccessor(elseRegion));
2154 }
2155 
2156 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2157  SmallVectorImpl<RegionSuccessor> &regions) {
2158  FoldAdaptor adaptor(operands, *this);
2159  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
2160  if (!boolAttr || boolAttr.getValue())
2161  regions.emplace_back(&getThenRegion());
2162 
2163  // If the else region is empty, execution continues after the parent op.
2164  if (!boolAttr || !boolAttr.getValue()) {
2165  if (!getElseRegion().empty())
2166  regions.emplace_back(&getElseRegion());
2167  else
2168  regions.emplace_back(getResults());
2169  }
2170 }
2171 
2172 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2173  SmallVectorImpl<OpFoldResult> &results) {
2174  // if (!c) then A() else B() -> if c then B() else A()
2175  if (getElseRegion().empty())
2176  return failure();
2177 
2178  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2179  if (!xorStmt)
2180  return failure();
2181 
2182  if (!matchPattern(xorStmt.getRhs(), m_One()))
2183  return failure();
2184 
2185  getConditionMutable().assign(xorStmt.getLhs());
2186  Block *thenBlock = &getThenRegion().front();
2187  // It would be nicer to use iplist::swap, but that has no implemented
2188  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2189  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2190  getElseRegion().getBlocks());
2191  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2192  getThenRegion().getBlocks(), thenBlock);
2193  return success();
2194 }
2195 
2196 void IfOp::getRegionInvocationBounds(
2197  ArrayRef<Attribute> operands,
2198  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2199  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2200  // If the condition is known, then one region is known to be executed once
2201  // and the other zero times.
2202  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2203  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2204  } else {
2205  // Non-constant condition. Each region may be executed 0 or 1 times.
2206  invocationBounds.assign(2, {0, 1});
2207  }
2208 }
2209 
2210 namespace {
2211 // Pattern to remove unused IfOp results.
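// (Illustrative sketch added for exposition; not part of the original source;
// %v0 and %v1 are hypothetical.) If only %r#0 of
//   %r:2 = scf.if %c -> (i32, i32) {
//     scf.yield %v0, %v1 : i32, i32
//   } else { ... }
// has uses, the op is rebuilt to yield and return only the first value.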
2212 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2213  using OpRewritePattern<IfOp>::OpRewritePattern;
2214 
2215  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2216  PatternRewriter &rewriter) const {
2217  // Move all operations to the destination block.
2218  rewriter.mergeBlocks(source, dest);
2219  // Replace the yield op by one that returns only the used values.
2220  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2221  SmallVector<Value, 4> usedOperands;
2222  llvm::transform(usedResults, std::back_inserter(usedOperands),
2223  [&](OpResult result) {
2224  return yieldOp.getOperand(result.getResultNumber());
2225  });
2226  rewriter.modifyOpInPlace(yieldOp,
2227  [&]() { yieldOp->setOperands(usedOperands); });
2228  }
2229 
2230  LogicalResult matchAndRewrite(IfOp op,
2231  PatternRewriter &rewriter) const override {
2232  // Compute the list of used results.
2233  SmallVector<OpResult, 4> usedResults;
2234  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2235  [](OpResult result) { return !result.use_empty(); });
2236 
2237  // Replace the operation if only a subset of its results have uses.
2238  if (usedResults.size() == op.getNumResults())
2239  return failure();
2240 
2241  // Compute the result types of the replacement operation.
2242  SmallVector<Type, 4> newTypes;
2243  llvm::transform(usedResults, std::back_inserter(newTypes),
2244  [](OpResult result) { return result.getType(); });
2245 
2246  // Create a replacement operation with empty then and else regions.
2247  auto newOp =
2248  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2249  rewriter.createBlock(&newOp.getThenRegion());
2250  rewriter.createBlock(&newOp.getElseRegion());
2251 
2252  // Move the bodies and replace the terminators (note there is a then and
2253  // an else region since the operation returns results).
2254  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2255  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2256 
2257  // Replace the operation by the new one.
2258  SmallVector<Value, 4> repResults(op.getNumResults());
2259  for (const auto &en : llvm::enumerate(usedResults))
2260  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2261  rewriter.replaceOp(op, repResults);
2262  return success();
2263  }
2264 };
2265 
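/// (Note added for exposition; not part of the original source.) Replaces an
/// scf.if whose condition is a known constant by the region that is taken.
/// A hypothetical sketch:
///   %true = arith.constant true
///   scf.if %true {
///     "foo"() : () -> ()
///   }
/// becomes just
///   "foo"() : () -> ()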
2266 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2267  using OpRewritePattern<IfOp>::OpRewritePattern;
2268 
2269  LogicalResult matchAndRewrite(IfOp op,
2270  PatternRewriter &rewriter) const override {
2271  BoolAttr condition;
2272  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2273  return failure();
2274 
2275  if (condition.getValue())
2276  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2277  else if (!op.getElseRegion().empty())
2278  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2279  else
2280  rewriter.eraseOp(op);
2281 
2282  return success();
2283  }
2284 };
2285 
2286 /// Hoist any yielded results whose operands are defined outside
2287 /// the if, to a select instruction.
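/// (Illustrative sketch added for exposition; not part of the original
/// source; %a and %b are hypothetical values defined above the scf.if.)
///   %r = scf.if %c -> (i32) {
///     scf.yield %a : i32
///   } else {
///     scf.yield %b : i32
///   }
/// becomes
///   %r = arith.select %c, %a, %b : i32
/// while the scf.if itself is kept (without results) for any remaining
/// non-hoistable yields or side effects.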
2288 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2289  using OpRewritePattern<IfOp>::OpRewritePattern;
2290 
2291  LogicalResult matchAndRewrite(IfOp op,
2292  PatternRewriter &rewriter) const override {
2293  if (op->getNumResults() == 0)
2294  return failure();
2295 
2296  auto cond = op.getCondition();
2297  auto thenYieldArgs = op.thenYield().getOperands();
2298  auto elseYieldArgs = op.elseYield().getOperands();
2299 
2300  SmallVector<Type> nonHoistable;
2301  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2302  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2303  &op.getElseRegion() == falseVal.getParentRegion())
2304  nonHoistable.push_back(trueVal.getType());
2305  }
2306  // Early exit if there aren't any yielded values we can
2307  // hoist outside the if.
2308  if (nonHoistable.size() == op->getNumResults())
2309  return failure();
2310 
2311  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2312  /*withElseRegion=*/false);
2313  if (replacement.thenBlock())
2314  rewriter.eraseBlock(replacement.thenBlock());
2315  replacement.getThenRegion().takeBody(op.getThenRegion());
2316  replacement.getElseRegion().takeBody(op.getElseRegion());
2317 
2318  SmallVector<Value> results(op->getNumResults());
2319  assert(thenYieldArgs.size() == results.size());
2320  assert(elseYieldArgs.size() == results.size());
2321 
2322  SmallVector<Value> trueYields;
2323  SmallVector<Value> falseYields;
2324  rewriter.setInsertionPoint(replacement);
2325  for (const auto &it :
2326  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2327  Value trueVal = std::get<0>(it.value());
2328  Value falseVal = std::get<1>(it.value());
2329  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2330  &replacement.getElseRegion() == falseVal.getParentRegion()) {
2331  results[it.index()] = replacement.getResult(trueYields.size());
2332  trueYields.push_back(trueVal);
2333  falseYields.push_back(falseVal);
2334  } else if (trueVal == falseVal)
2335  results[it.index()] = trueVal;
2336  else
2337  results[it.index()] = rewriter.create<arith::SelectOp>(
2338  op.getLoc(), cond, trueVal, falseVal);
2339  }
2340 
2341  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2342  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2343 
2344  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2345  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2346 
2347  rewriter.replaceOp(op, results);
2348  return success();
2349  }
2350 };
2351 
2352 /// Allow the true region of an if to assume the condition is true
2353 /// and vice versa. For example:
2354 ///
2355 /// scf.if %cmp {
2356 /// print(%cmp)
2357 /// }
2358 ///
2359 /// becomes
2360 ///
2361 /// scf.if %cmp {
2362 /// print(true)
2363 /// }
2364 ///
2365 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2366  using OpRewritePattern<IfOp>::OpRewritePattern;
2367 
2368  LogicalResult matchAndRewrite(IfOp op,
2369  PatternRewriter &rewriter) const override {
2370  // Early exit if the condition is constant since replacing a constant
2371  // in the body with another constant isn't a simplification.
2372  if (matchPattern(op.getCondition(), m_Constant()))
2373  return failure();
2374 
2375  bool changed = false;
2376  mlir::Type i1Ty = rewriter.getI1Type();
2377 
2378  // These variables serve to prevent creating duplicate constants
2379  // and hold constant true or false values.
2380  Value constantTrue = nullptr;
2381  Value constantFalse = nullptr;
2382 
2383  for (OpOperand &use :
2384  llvm::make_early_inc_range(op.getCondition().getUses())) {
2385  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2386  changed = true;
2387 
2388  if (!constantTrue)
2389  constantTrue = rewriter.create<arith::ConstantOp>(
2390  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2391 
2392  rewriter.modifyOpInPlace(use.getOwner(),
2393  [&]() { use.set(constantTrue); });
2394  } else if (op.getElseRegion().isAncestor(
2395  use.getOwner()->getParentRegion())) {
2396  changed = true;
2397 
2398  if (!constantFalse)
2399  constantFalse = rewriter.create<arith::ConstantOp>(
2400  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2401 
2402  rewriter.modifyOpInPlace(use.getOwner(),
2403  [&]() { use.set(constantFalse); });
2404  }
2405  }
2406 
2407  return success(changed);
2408  }
2409 };
2410 
2411 /// Remove any statements from an if that are equivalent to the condition
2412 /// or its negation. For example:
2413 ///
2414 /// %res:2 = scf.if %cmp {
2415 /// yield something(), true
2416 /// } else {
2417 /// yield something2(), false
2418 /// }
2419 /// print(%res#1)
2420 ///
2421 /// becomes
2422 /// %res = scf.if %cmp {
2423 /// yield something()
2424 /// } else {
2425 /// yield something2()
2426 /// }
2427 /// print(%cmp)
2428 ///
2429 /// Additionally if both branches yield the same value, replace all uses
2430 /// of the result with the yielded value.
2431 ///
2432 /// %res:2 = scf.if %cmp {
2433 /// yield something(), %arg1
2434 /// } else {
2435 /// yield something2(), %arg1
2436 /// }
2437 /// print(%res#1)
2438 ///
2439 /// becomes
2440 /// %res = scf.if %cmp {
2441 /// yield something()
2442 /// } else {
2443 /// yield something2()
2444 /// }
2445 /// print(%arg1)
2446 ///
2447 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2448  using OpRewritePattern<IfOp>::OpRewritePattern;
2449 
2450  LogicalResult matchAndRewrite(IfOp op,
2451  PatternRewriter &rewriter) const override {
2452  // Early exit if there are no results that could be replaced.
2453  if (op.getNumResults() == 0)
2454  return failure();
2455 
2456  auto trueYield =
2457  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2458  auto falseYield =
2459  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2460 
2461  rewriter.setInsertionPoint(op->getBlock(),
2462  op.getOperation()->getIterator());
2463  bool changed = false;
2464  Type i1Ty = rewriter.getI1Type();
2465  for (auto [trueResult, falseResult, opResult] :
2466  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2467  op.getResults())) {
2468  if (trueResult == falseResult) {
2469  if (!opResult.use_empty()) {
2470  opResult.replaceAllUsesWith(trueResult);
2471  changed = true;
2472  }
2473  continue;
2474  }
2475 
2476  BoolAttr trueYield, falseYield;
2477  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2478  !matchPattern(falseResult, m_Constant(&falseYield)))
2479  continue;
2480 
2481  bool trueVal = trueYield.getValue();
2482  bool falseVal = falseYield.getValue();
2483  if (!trueVal && falseVal) {
2484  if (!opResult.use_empty()) {
2485  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2486  Value notCond = rewriter.create<arith::XOrIOp>(
2487  op.getLoc(), op.getCondition(),
2488  constDialect
2489  ->materializeConstant(rewriter,
2490  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2491  op.getLoc())
2492  ->getResult(0));
2493  opResult.replaceAllUsesWith(notCond);
2494  changed = true;
2495  }
2496  }
2497  if (trueVal && !falseVal) {
2498  if (!opResult.use_empty()) {
2499  opResult.replaceAllUsesWith(op.getCondition());
2500  changed = true;
2501  }
2502  }
2503  }
2504  return success(changed);
2505  }
2506 };
2507 
2508 /// Merge any consecutive scf.if's with the same condition.
2509 ///
2510 /// scf.if %cond {
2511 /// firstCodeTrue();...
2512 /// } else {
2513 /// firstCodeFalse();...
2514 /// }
2515 /// %res = scf.if %cond {
2516 /// secondCodeTrue();...
2517 /// } else {
2518 /// secondCodeFalse();...
2519 /// }
2520 ///
2521 /// becomes
2522 /// %res = scf.if %cmp {
2523 /// firstCodeTrue();...
2524 /// secondCodeTrue();...
2525 /// } else {
2526 /// firstCodeFalse();...
2527 /// secondCodeFalse();...
2528 /// }
2529 struct CombineIfs : public OpRewritePattern<IfOp> {
2530  using OpRewritePattern<IfOp>::OpRewritePattern;
2531 
2532  LogicalResult matchAndRewrite(IfOp nextIf,
2533  PatternRewriter &rewriter) const override {
2534  Block *parent = nextIf->getBlock();
2535  if (nextIf == &parent->front())
2536  return failure();
2537 
2538  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2539  if (!prevIf)
2540  return failure();
2541 
2542  // Determine the logical then/else blocks when prevIf's
2543  // condition is used. Null means the block does not exist
2544  // in that case (e.g. empty else). If neither of these
2545  // is set, the two conditions cannot be compared.
2546  Block *nextThen = nullptr;
2547  Block *nextElse = nullptr;
2548  if (nextIf.getCondition() == prevIf.getCondition()) {
2549  nextThen = nextIf.thenBlock();
2550  if (!nextIf.getElseRegion().empty())
2551  nextElse = nextIf.elseBlock();
2552  }
2553  if (arith::XOrIOp notv =
2554  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2555  if (notv.getLhs() == prevIf.getCondition() &&
2556  matchPattern(notv.getRhs(), m_One())) {
2557  nextElse = nextIf.thenBlock();
2558  if (!nextIf.getElseRegion().empty())
2559  nextThen = nextIf.elseBlock();
2560  }
2561  }
2562  if (arith::XOrIOp notv =
2563  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2564  if (notv.getLhs() == nextIf.getCondition() &&
2565  matchPattern(notv.getRhs(), m_One())) {
2566  nextElse = nextIf.thenBlock();
2567  if (!nextIf.getElseRegion().empty())
2568  nextThen = nextIf.elseBlock();
2569  }
2570  }
2571 
2572  if (!nextThen && !nextElse)
2573  return failure();
2574 
2575  SmallVector<Value> prevElseYielded;
2576  if (!prevIf.getElseRegion().empty())
2577  prevElseYielded = prevIf.elseYield().getOperands();
2578  // Replace all uses of the return values of prevIf within nextIf with the
2579  // corresponding yields.
2580  for (auto it : llvm::zip(prevIf.getResults(),
2581  prevIf.thenYield().getOperands(), prevElseYielded))
2582  for (OpOperand &use :
2583  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2584  if (nextThen && nextThen->getParent()->isAncestor(
2585  use.getOwner()->getParentRegion())) {
2586  rewriter.startOpModification(use.getOwner());
2587  use.set(std::get<1>(it));
2588  rewriter.finalizeOpModification(use.getOwner());
2589  } else if (nextElse && nextElse->getParent()->isAncestor(
2590  use.getOwner()->getParentRegion())) {
2591  rewriter.startOpModification(use.getOwner());
2592  use.set(std::get<2>(it));
2593  rewriter.finalizeOpModification(use.getOwner());
2594  }
2595  }
2596 
2597  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2598  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2599 
2600  IfOp combinedIf = rewriter.create<IfOp>(
2601  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2602  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2603 
2604  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2605  combinedIf.getThenRegion(),
2606  combinedIf.getThenRegion().begin());
2607 
2608  if (nextThen) {
2609  YieldOp thenYield = combinedIf.thenYield();
2610  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2611  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2612  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2613 
2614  SmallVector<Value> mergedYields(thenYield.getOperands());
2615  llvm::append_range(mergedYields, thenYield2.getOperands());
2616  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2617  rewriter.eraseOp(thenYield);
2618  rewriter.eraseOp(thenYield2);
2619  }
2620 
2621  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2622  combinedIf.getElseRegion(),
2623  combinedIf.getElseRegion().begin());
2624 
2625  if (nextElse) {
2626  if (combinedIf.getElseRegion().empty()) {
2627  rewriter.inlineRegionBefore(*nextElse->getParent(),
2628  combinedIf.getElseRegion(),
2629  combinedIf.getElseRegion().begin());
2630  } else {
2631  YieldOp elseYield = combinedIf.elseYield();
2632  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2633  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2634 
2635  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2636 
2637  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2638  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2639 
2640  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2641  rewriter.eraseOp(elseYield);
2642  rewriter.eraseOp(elseYield2);
2643  }
2644  }
2645 
2646  SmallVector<Value> prevValues;
2647  SmallVector<Value> nextValues;
2648  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2649  if (pair.index() < prevIf.getNumResults())
2650  prevValues.push_back(pair.value());
2651  else
2652  nextValues.push_back(pair.value());
2653  }
2654  rewriter.replaceOp(prevIf, prevValues);
2655  rewriter.replaceOp(nextIf, nextValues);
2656  return success();
2657  }
2658 };
2659 
2660 /// Pattern to remove an empty else branch.
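/// (Illustrative sketch added for exposition; not part of the original
/// source.)
///   scf.if %c {
///     "foo"() : () -> ()
///   } else {
///   }
/// becomes
///   scf.if %c {
///     "foo"() : () -> ()
///   }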
2661 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2662  using OpRewritePattern<IfOp>::OpRewritePattern;
2663 
2664  LogicalResult matchAndRewrite(IfOp ifOp,
2665  PatternRewriter &rewriter) const override {
2666  // Cannot remove else region when there are operation results.
2667  if (ifOp.getNumResults())
2668  return failure();
2669  Block *elseBlock = ifOp.elseBlock();
2670  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2671  return failure();
2672  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2673  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2674  newIfOp.getThenRegion().begin());
2675  rewriter.eraseOp(ifOp);
2676  return success();
2677  }
2678 };
2679 
2680 /// Convert nested `if`s into `arith.andi` + single `if`.
2681 ///
2682 /// scf.if %arg0 {
2683 /// scf.if %arg1 {
2684 /// ...
2685 /// scf.yield
2686 /// }
2687 /// scf.yield
2688 /// }
2689 /// becomes
2690 ///
2691 /// %0 = arith.andi %arg0, %arg1
2692 /// scf.if %0 {
2693 /// ...
2694 /// scf.yield
2695 /// }
2696 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2697  using OpRewritePattern<IfOp>::OpRewritePattern;
2698 
2699  LogicalResult matchAndRewrite(IfOp op,
2700  PatternRewriter &rewriter) const override {
2701  auto nestedOps = op.thenBlock()->without_terminator();
2702  // Nested `if` must be the only op in block.
2703  if (!llvm::hasSingleElement(nestedOps))
2704  return failure();
2705 
2706  // If there is an else block, it can only yield.
2707  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2708  return failure();
2709 
2710  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2711  if (!nestedIf)
2712  return failure();
2713 
2714  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2715  return failure();
2716 
2717  SmallVector<Value> thenYield(op.thenYield().getOperands());
2718  SmallVector<Value> elseYield;
2719  if (op.elseBlock())
2720  llvm::append_range(elseYield, op.elseYield().getOperands());
2721 
2722  // A list of indices for which we should upgrade the value yielded
2723  // in the else to a select.
2724  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2725 
2726  // If the outer scf.if yields a value produced by the inner scf.if,
2727  // only permit combining if the value yielded when the condition
2728  // is false in the outer scf.if is the same value yielded when the
2729  // inner scf.if condition is false.
2730  // Note that the array access to elseYield will not go out of bounds
2731  // since it must have the same length as thenYield, since they both
2732  // come from the same scf.if.
2733  for (const auto &tup : llvm::enumerate(thenYield)) {
2734  if (tup.value().getDefiningOp() == nestedIf) {
2735  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2736  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2737  elseYield[tup.index()]) {
2738  return failure();
2739  }
2740  // If the correctness test passes, we will yield
2741  // the corresponding value from the inner scf.if.
2742  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2743  continue;
2744  }
2745 
2746  // Otherwise, we need to ensure the else block of the combined
2747  // condition still returns the same value when the outer condition is
2748  // true and the inner condition is false. This can be accomplished if
2749  // the then value is defined outside the outer scf.if and we replace the
2750  // value with a select that considers just the outer condition. Since
2751  // the else region contains just the yield, its yielded value is
2752  // defined outside the scf.if, by definition.
2753 
2754  // If the then value is defined within the scf.if, bail.
2755  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2756  return failure();
2757  }
2758  elseYieldsToUpgradeToSelect.push_back(tup.index());
2759  }
2760 
2761  Location loc = op.getLoc();
2762  Value newCondition = rewriter.create<arith::AndIOp>(
2763  loc, op.getCondition(), nestedIf.getCondition());
2764  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2765  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2766 
2767  SmallVector<Value> results;
2768  llvm::append_range(results, newIf.getResults());
2769  rewriter.setInsertionPoint(newIf);
2770 
2771  for (auto idx : elseYieldsToUpgradeToSelect)
2772  results[idx] = rewriter.create<arith::SelectOp>(
2773  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2774 
2775  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2776  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2777  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2778  if (!elseYield.empty()) {
2779  rewriter.createBlock(&newIf.getElseRegion());
2780  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2781  rewriter.create<YieldOp>(loc, elseYield);
2782  }
2783  rewriter.replaceOp(op, results);
2784  return success();
2785  }
2786 };
2787 
2788 } // namespace
2789 
2790 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2791  MLIRContext *context) {
2792  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2793  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2794  RemoveStaticCondition, RemoveUnusedResults,
2795  ReplaceIfYieldWithConditionOrValue>(context);
2796 }
2797 
2798 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2799 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2800 Block *IfOp::elseBlock() {
2801  Region &r = getElseRegion();
2802  if (r.empty())
2803  return nullptr;
2804  return &r.back();
2805 }
2806 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2807 
2808 //===----------------------------------------------------------------------===//
2809 // ParallelOp
2810 //===----------------------------------------------------------------------===//
2811 
2812 void ParallelOp::build(
2813  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2814  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2815  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2816  bodyBuilderFn) {
2817  result.addOperands(lowerBounds);
2818  result.addOperands(upperBounds);
2819  result.addOperands(steps);
2820  result.addOperands(initVals);
2821  result.addAttribute(
2822  ParallelOp::getOperandSegmentSizeAttr(),
2823  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2824  static_cast<int32_t>(upperBounds.size()),
2825  static_cast<int32_t>(steps.size()),
2826  static_cast<int32_t>(initVals.size())}));
2827  result.addTypes(initVals.getTypes());
2828 
2829  OpBuilder::InsertionGuard guard(builder);
2830  unsigned numIVs = steps.size();
2831  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2832  SmallVector<Location, 8> argLocs(numIVs, result.location);
2833  Region *bodyRegion = result.addRegion();
2834  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2835 
2836  if (bodyBuilderFn) {
2837  builder.setInsertionPointToStart(bodyBlock);
2838  bodyBuilderFn(builder, result.location,
2839  bodyBlock->getArguments().take_front(numIVs),
2840  bodyBlock->getArguments().drop_front(numIVs));
2841  }
2842  // Add terminator only if there are no reductions.
2843  if (initVals.empty())
2844  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2845 }
2846 
2847 void ParallelOp::build(
2848  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2849  ValueRange upperBounds, ValueRange steps,
2850  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2851  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2852  // we don't capture a reference to a temporary by constructing the lambda at
2853  // function level.
2854  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2855  Location nestedLoc, ValueRange ivs,
2856  ValueRange) {
2857  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2858  };
2859  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2860  if (bodyBuilderFn)
2861  wrapper = wrappedBuilderFn;
2862 
2863  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2864  wrapper);
2865 }
2866 
2867 LogicalResult ParallelOp::verify() {
2868  // Check that there is at least one value in lowerBound, upperBound and step.
2869  // It is sufficient to test only step, because it is already ensured that the
2870  // number of elements in lowerBound, upperBound and step is the same.
2871  Operation::operand_range stepValues = getStep();
2872  if (stepValues.empty())
2873  return emitOpError(
2874  "needs at least one tuple element for lowerBound, upperBound and step");
2875 
2876  // Check whether all constant step values are positive.
2877  for (Value stepValue : stepValues)
2878  if (auto cst = getConstantIntValue(stepValue))
2879  if (*cst <= 0)
2880  return emitOpError("constant step operand must be positive");
2881 
2882  // Check that the body defines the same number of block arguments as the
2883  // number of tuple elements in step.
2884  Block *body = getBody();
2885  if (body->getNumArguments() != stepValues.size())
2886  return emitOpError() << "expects the same number of induction variables: "
2887  << body->getNumArguments()
2888  << " as bound and step values: " << stepValues.size();
2889  for (auto arg : body->getArguments())
2890  if (!arg.getType().isIndex())
2891  return emitOpError(
2892  "expects arguments for the induction variable to be of index type");
2893 
2894  // Check that the terminator is an scf.reduce op.
2895  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2896  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2897  if (!reduceOp)
2898  return failure();
2899 
2900  // Check that the number of results is the same as the number of reductions.
2901  auto resultsSize = getResults().size();
2902  auto reductionsSize = reduceOp.getReductions().size();
2903  auto initValsSize = getInitVals().size();
2904  if (resultsSize != reductionsSize)
2905  return emitOpError() << "expects number of results: " << resultsSize
2906  << " to be the same as number of reductions: "
2907  << reductionsSize;
2908  if (resultsSize != initValsSize)
2909  return emitOpError() << "expects number of results: " << resultsSize
2910  << " to be the same as number of initial values: "
2911  << initValsSize;
2912 
2913  // Check that the types of the results and reductions are the same.
2914  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2915  auto resultType = getOperation()->getResult(i).getType();
2916  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2917  if (resultType != reductionOperandType)
2918  return reduceOp.emitOpError()
2919  << "expects type of " << i
2920  << "-th reduction operand: " << reductionOperandType
2921  << " to be the same as the " << i
2922  << "-th result type: " << resultType;
2923  }
2924  return success();
2925 }
2926 
2927 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2928  auto &builder = parser.getBuilder();
2929  // Parse an opening `(` followed by induction variables followed by `)`
2930  SmallVector<OpAsmParser::Argument, 4> ivs;
2931  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2932  return failure();
2933 
2934  // Parse loop bounds.
2935  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2936  if (parser.parseEqual() ||
2937  parser.parseOperandList(lower, ivs.size(),
2938  OpAsmParser::Delimiter::Paren) ||
2939  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2940  return failure();
2941 
2942  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2943  if (parser.parseKeyword("to") ||
2944  parser.parseOperandList(upper, ivs.size(),
2945  OpAsmParser::Delimiter::Paren) ||
2946  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2947  return failure();
2948 
2949  // Parse step values.
2950  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2951  if (parser.parseKeyword("step") ||
2952  parser.parseOperandList(steps, ivs.size(),
2953  OpAsmParser::Delimiter::Paren) ||
2954  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2955  return failure();
2956 
2957  // Parse init values.
2958  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2959  if (succeeded(parser.parseOptionalKeyword("init"))) {
2960  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2961  return failure();
2962  }
2963 
2964  // Parse optional results in case there is a reduce.
2965  if (parser.parseOptionalArrowTypeList(result.types))
2966  return failure();
2967 
2968  // Now parse the body.
2969  Region *body = result.addRegion();
2970  for (auto &iv : ivs)
2971  iv.type = builder.getIndexType();
2972  if (parser.parseRegion(*body, ivs))
2973  return failure();
2974 
2975  // Set `operandSegmentSizes` attribute.
2976  result.addAttribute(
2977  ParallelOp::getOperandSegmentSizeAttr(),
2978  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2979  static_cast<int32_t>(upper.size()),
2980  static_cast<int32_t>(steps.size()),
2981  static_cast<int32_t>(initVals.size())}));
2982 
2983  // Parse attributes.
2984  if (parser.parseOptionalAttrDict(result.attributes) ||
2985  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2986  result.operands))
2987  return failure();
2988 
2989  // Add a terminator if none was parsed.
2990  ParallelOp::ensureTerminator(*body, builder, result.location);
2991  return success();
2992 }
2993 
2994 void ParallelOp::print(OpAsmPrinter &p) {
2995  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2996  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2997  if (!getInitVals().empty())
2998  p << " init (" << getInitVals() << ")";
2999  p.printOptionalArrowTypeList(getResultTypes());
3000  p << ' ';
3001  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
3002  p.printOptionalAttrDictWithKeyword(
3003  (*this)->getAttrs(),
3004  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3005 }
3006 
3007 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3008 
3009 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3010  return SmallVector<Value>{getBody()->getArguments()};
3011 }
3012 
3013 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3014  return getLowerBound();
3015 }
3016 
3017 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3018  return getUpperBound();
3019 }
3020 
3021 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3022  return getStep();
3023 }
3024 
3025 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
3026  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3027  if (!ivArg)
3028  return ParallelOp();
3029  assert(ivArg.getOwner() && "unlinked block argument");
3030  auto *containingOp = ivArg.getOwner()->getParentOp();
3031  return dyn_cast<ParallelOp>(containingOp);
3032 }
3033 
3034 namespace {
3035 // Collapse loop dimensions that perform zero or a single iteration.
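//
// An illustrative sketch of the effect (assumed IR, not from this file):
// given constant bounds
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %c8) step (%c1, %c1) { ... }
//
// the %i dimension has a trip count of 1, so the loop is rewritten to iterate
// over %j only and %i is replaced by %c0 inside the body. A dimension with a
// trip count of 0 folds the whole op to its init values, and if every
// dimension runs exactly once the body (and any nested reduce) is inlined.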
3036 struct ParallelOpSingleOrZeroIterationDimsFolder
3037  : public OpRewritePattern<ParallelOp> {
3038  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3039 
3040  LogicalResult matchAndRewrite(ParallelOp op,
3041  PatternRewriter &rewriter) const override {
3042  Location loc = op.getLoc();
3043 
3044  // Compute new loop bounds that omit all single-iteration loop dimensions.
3045  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
3046  IRMapping mapping;
3047  for (auto [lb, ub, step, iv] :
3048  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3049  op.getInductionVars())) {
3050  auto numIterations = constantTripCount(lb, ub, step);
3051  if (numIterations.has_value()) {
3052  // Remove the loop if it performs zero iterations.
3053  if (*numIterations == 0) {
3054  rewriter.replaceOp(op, op.getInitVals());
3055  return success();
3056  }
3057  // Replace the loop induction variable by the lower bound if the loop
3058  // performs a single iteration. Otherwise, copy the loop bounds.
3059  if (*numIterations == 1) {
3060  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3061  continue;
3062  }
3063  }
3064  newLowerBounds.push_back(lb);
3065  newUpperBounds.push_back(ub);
3066  newSteps.push_back(step);
3067  }
3068  // Exit if none of the loop dimensions perform a single iteration.
3069  if (newLowerBounds.size() == op.getLowerBound().size())
3070  return failure();
3071 
3072  if (newLowerBounds.empty()) {
3073  // All of the loop dimensions perform a single iteration. Inline
3074  // loop body and nested ReduceOp's
3075  SmallVector<Value> results;
3076  results.reserve(op.getInitVals().size());
3077  for (auto &bodyOp : op.getBody()->without_terminator())
3078  rewriter.clone(bodyOp, mapping);
3079  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3080  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3081  Block &reduceBlock = reduceOp.getReductions()[i].front();
3082  auto initValIndex = results.size();
3083  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3084  mapping.map(reduceBlock.getArgument(1),
3085  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3086  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3087  rewriter.clone(reduceBodyOp, mapping);
3088 
3089  auto result = mapping.lookupOrDefault(
3090  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3091  results.push_back(result);
3092  }
3093 
3094  rewriter.replaceOp(op, results);
3095  return success();
3096  }
3097  // Replace the parallel loop by lower-dimensional parallel loop.
3098  auto newOp =
3099  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3100  newSteps, op.getInitVals(), nullptr);
3101  // Erase the empty block that was inserted by the builder.
3102  rewriter.eraseBlock(newOp.getBody());
3103  // Clone the loop body and remap the block arguments of the collapsed loops
3104  // (inlining does not support a cancellable block argument mapping).
3105  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3106  newOp.getRegion().begin(), mapping);
3107  rewriter.replaceOp(op, newOp.getResults());
3108  return success();
3109  }
3110 };
3111 
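/// Merge a perfectly nested scf.parallel into its parent by concatenating the
/// induction variables, bounds and steps. A hedged sketch of the rewrite
/// (assumed IR, not from this file):
///
///   scf.parallel (%i) = (%lb0) to (%ub0) step (%s0) {
///     scf.parallel (%j) = (%lb1) to (%ub1) step (%s1) {
///       "test.use"(%i, %j) : (index, index) -> ()
///     }
///   }
/// ->
///   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
///     "test.use"(%i, %j) : (index, index) -> ()
///   }
///
/// The pattern below only fires when the outer body contains nothing but the
/// inner loop, the inner bounds and steps do not use the outer induction
/// variables, and neither loop carries reductions.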
3112 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3113  using OpRewritePattern<ParallelOp>::OpRewritePattern;
3114 
3115  LogicalResult matchAndRewrite(ParallelOp op,
3116  PatternRewriter &rewriter) const override {
3117  Block &outerBody = *op.getBody();
3118  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3119  return failure();
3120 
3121  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3122  if (!innerOp)
3123  return failure();
3124 
3125  for (auto val : outerBody.getArguments())
3126  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3127  llvm::is_contained(innerOp.getUpperBound(), val) ||
3128  llvm::is_contained(innerOp.getStep(), val))
3129  return failure();
3130 
3131  // Reductions are not supported yet.
3132  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3133  return failure();
3134 
3135  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3136  ValueRange iterVals, ValueRange) {
3137  Block &innerBody = *innerOp.getBody();
3138  assert(iterVals.size() ==
3139  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3140  IRMapping mapping;
3141  mapping.map(outerBody.getArguments(),
3142  iterVals.take_front(outerBody.getNumArguments()));
3143  mapping.map(innerBody.getArguments(),
3144  iterVals.take_back(innerBody.getNumArguments()));
3145  for (Operation &op : innerBody.without_terminator())
3146  builder.clone(op, mapping);
3147  };
3148 
3149  auto concatValues = [](const auto &first, const auto &second) {
3150  SmallVector<Value> ret;
3151  ret.reserve(first.size() + second.size());
3152  ret.assign(first.begin(), first.end());
3153  ret.append(second.begin(), second.end());
3154  return ret;
3155  };
3156 
3157  auto newLowerBounds =
3158  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3159  auto newUpperBounds =
3160  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3161  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3162 
3163  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3164  newSteps, std::nullopt,
3165  bodyBuilder);
3166  return success();
3167  }
3168 };
3169 
3170 } // namespace
3171 
3172 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3173  MLIRContext *context) {
3174  results
3175  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3176  context);
3177 }
3178 
3179 /// Given a region branch point, return the successor regions: the regions
3180 /// that may be selected during the flow of control. For scf.parallel these
3181 /// are the loop body region and the parent operation itself, since the loop
3182 /// may execute its body zero or more times before control returns to the
3183 /// parent operation.
3184 void ParallelOp::getSuccessorRegions(
3185  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3186  // Both the operation itself and the region may be branching into the body or
3187  // back into the operation itself. It is possible for the loop not to
3188  // enter the body.
3189  regions.push_back(RegionSuccessor(&getRegion()));
3190  regions.push_back(RegionSuccessor());
3191 }
3192 
3193 //===----------------------------------------------------------------------===//
3194 // ReduceOp
3195 //===----------------------------------------------------------------------===//
3196 
3197 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3198 
3199 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3200  ValueRange operands) {
3201  result.addOperands(operands);
3202  for (Value v : operands) {
3203  OpBuilder::InsertionGuard guard(builder);
3204  Region *bodyRegion = result.addRegion();
3205  builder.createBlock(bodyRegion, {},
3206  ArrayRef<Type>{v.getType(), v.getType()},
3207  {result.location, result.location});
3208  }
3209 }
3210 
3211 LogicalResult ReduceOp::verifyRegions() {
3212  // The region of a ReduceOp has two arguments of the same type as its
3213  // corresponding operand.
3214  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3215  auto type = getOperands()[i].getType();
3216  Block &block = getReductions()[i].front();
3217  if (block.empty())
3218  return emitOpError() << i << "-th reduction has an empty body";
3219  if (block.getNumArguments() != 2 ||
3220  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3221  return arg.getType() != type;
3222  }))
3223  return emitOpError() << "expected two block arguments with type " << type
3224  << " in the " << i << "-th reduction region";
3225 
3226  // Check that the block is terminated by a ReduceReturnOp.
3227  if (!isa<ReduceReturnOp>(block.getTerminator()))
3228  return emitOpError("reduction bodies must be terminated with an "
3229  "'scf.reduce.return' op");
3230  }
3231 
3232  return success();
3233 }
3234 
3235 MutableOperandRange
3236 ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
3237  // No operands are forwarded to the next iteration.
3238  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3239 }
3240 
3241 //===----------------------------------------------------------------------===//
3242 // ReduceReturnOp
3243 //===----------------------------------------------------------------------===//
3244 
3245 LogicalResult ReduceReturnOp::verify() {
3246  // The type of the return value should be the same type as the types of the
3247  // block arguments of the reduction body.
3248  Block *reductionBody = getOperation()->getBlock();
3249  // Should already be verified by an op trait.
3250  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3251  Type expectedResultType = reductionBody->getArgument(0).getType();
3252  if (expectedResultType != getResult().getType())
3253  return emitOpError() << "must have type " << expectedResultType
3254  << " (the type of the reduction inputs)";
3255  return success();
3256 }
3257 
3258 //===----------------------------------------------------------------------===//
3259 // WhileOp
3260 //===----------------------------------------------------------------------===//
3261 
3262 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3263  ::mlir::OperationState &odsState, TypeRange resultTypes,
3264  ValueRange operands, BodyBuilderFn beforeBuilder,
3265  BodyBuilderFn afterBuilder) {
3266  odsState.addOperands(operands);
3267  odsState.addTypes(resultTypes);
3268 
3269  OpBuilder::InsertionGuard guard(odsBuilder);
3270 
3271  // Build before region.
3272  SmallVector<Location, 4> beforeArgLocs;
3273  beforeArgLocs.reserve(operands.size());
3274  for (Value operand : operands) {
3275  beforeArgLocs.push_back(operand.getLoc());
3276  }
3277 
3278  Region *beforeRegion = odsState.addRegion();
3279  Block *beforeBlock = odsBuilder.createBlock(
3280  beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
3281  if (beforeBuilder)
3282  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3283 
3284  // Build after region.
3285  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3286 
3287  Region *afterRegion = odsState.addRegion();
3288  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3289  resultTypes, afterArgLocs);
3290 
3291  if (afterBuilder)
3292  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3293 }
3294 
3295 ConditionOp WhileOp::getConditionOp() {
3296  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3297 }
3298 
3299 YieldOp WhileOp::getYieldOp() {
3300  return cast<YieldOp>(getAfterBody()->getTerminator());
3301 }
3302 
3303 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3304  return getYieldOp().getResultsMutable();
3305 }
3306 
3307 Block::BlockArgListType WhileOp::getBeforeArguments() {
3308  return getBeforeBody()->getArguments();
3309 }
3310 
3311 Block::BlockArgListType WhileOp::getAfterArguments() {
3312  return getAfterBody()->getArguments();
3313 }
3314 
3315 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3316  return getBeforeArguments();
3317 }
3318 
3319 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3320  assert(point == getBefore() &&
3321  "WhileOp is expected to branch only to the first region");
3322  return getInits();
3323 }
3324 
3325 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3326  SmallVectorImpl<RegionSuccessor> &regions) {
3327  // The parent op always branches to the condition region.
3328  if (point.isParent()) {
3329  regions.emplace_back(&getBefore(), getBefore().getArguments());
3330  return;
3331  }
3332 
3333  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3334  "there are only two regions in a WhileOp");
3335  // The body region always branches back to the condition region.
3336  if (point == getAfter()) {
3337  regions.emplace_back(&getBefore(), getBefore().getArguments());
3338  return;
3339  }
3340 
3341  regions.emplace_back(getResults());
3342  regions.emplace_back(&getAfter(), getAfter().getArguments());
3343 }
3344 
3345 SmallVector<Region *> WhileOp::getLoopRegions() {
3346  return {&getBefore(), &getAfter()};
3347 }
3348 
3349 /// Parses a `while` op.
3350 ///
3351 /// op ::= `scf.while` initializer `:` function-type region `do` region
3352 ///        (`attributes` attribute-dict)?
3353 /// initializer ::= /* empty */ | `(` assignment-list `)`
3354 /// assignment-list ::= assignment | assignment `,` assignment-list
3355 /// assignment ::= ssa-value `=` ssa-value
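///
/// A concrete example of this form (illustrative only, not quoted from the
/// upstream documentation; the value names are assumed):
///
///   %res = scf.while (%arg = %init) : (i32) -> i32 {
///     %cond = arith.cmpi slt, %arg, %bound : i32
///     scf.condition(%cond) %arg : i32
///   } do {
///   ^bb0(%arg2: i32):
///     %next = arith.addi %arg2, %c1 : i32
///     scf.yield %next : i32
///   }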
3356 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3357  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3358  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3359  Region *before = result.addRegion();
3360  Region *after = result.addRegion();
3361 
3362  OptionalParseResult listResult =
3363  parser.parseOptionalAssignmentList(regionArgs, operands);
3364  if (listResult.has_value() && failed(listResult.value()))
3365  return failure();
3366 
3367  FunctionType functionType;
3368  SMLoc typeLoc = parser.getCurrentLocation();
3369  if (failed(parser.parseColonType(functionType)))
3370  return failure();
3371 
3372  result.addTypes(functionType.getResults());
3373 
3374  if (functionType.getNumInputs() != operands.size()) {
3375  return parser.emitError(typeLoc)
3376  << "expected as many input types as operands "
3377  << "(expected " << operands.size() << " got "
3378  << functionType.getNumInputs() << ")";
3379  }
3380 
3381  // Resolve input operands.
3382  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3383  parser.getCurrentLocation(),
3384  result.operands)))
3385  return failure();
3386 
3387  // Propagate the types into the region arguments.
3388  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3389  regionArgs[i].type = functionType.getInput(i);
3390 
3391  return failure(parser.parseRegion(*before, regionArgs) ||
3392  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3393  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3394 }
3395 
3396 /// Prints a `while` op.
3397 void scf::WhileOp::print(OpAsmPrinter &p) {
3398  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3399  p << " : ";
3400  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3401  p << ' ';
3402  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3403  p << " do ";
3404  p.printRegion(getAfter());
3405  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3406 }
3407 
3408 /// Verifies that two ranges of types match, i.e. have the same number of
3409 /// entries and that the types are pairwise equal. Reports errors on the given
3410 /// operation in case of mismatch.
3411 template <typename OpTy>
3412 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3413  TypeRange right, StringRef message) {
3414  if (left.size() != right.size())
3415  return op.emitOpError("expects the same number of ") << message;
3416 
3417  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3418  if (left[i] != right[i]) {
3419  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3420  << message;
3421  diag.attachNote() << "for argument " << i << ", found " << left[i]
3422  << " and " << right[i];
3423  return diag;
3424  }
3425  }
3426 
3427  return success();
3428 }
3429 
3430 LogicalResult scf::WhileOp::verify() {
3431  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3432  *this, getBefore(),
3433  "expects the 'before' region to terminate with 'scf.condition'");
3434  if (!beforeTerminator)
3435  return failure();
3436 
3437  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3438  *this, getAfter(),
3439  "expects the 'after' region to terminate with 'scf.yield'");
3440  return success(afterTerminator != nullptr);
3441 }
3442 
3443 namespace {
3444 /// Replace uses of the condition within the do block with true, since otherwise
3445 /// the block would not be evaluated.
3446 ///
3447 /// scf.while (..) : (i1, ...) -> ... {
3448 /// %condition = call @evaluate_condition() : () -> i1
3449 /// scf.condition(%condition) %condition : i1, ...
3450 /// } do {
3451 /// ^bb0(%arg0: i1, ...):
3452 /// use(%arg0)
3453 /// ...
3454 ///
3455 /// becomes
3456 /// scf.while (..) : (i1, ...) -> ... {
3457 /// %condition = call @evaluate_condition() : () -> i1
3458 /// scf.condition(%condition) %condition : i1, ...
3459 /// } do {
3460 /// ^bb0(%arg0: i1, ...):
3461 /// use(%true)
3462 /// ...
3463 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3464  using OpRewritePattern<WhileOp>::OpRewritePattern;
3465 
3466  LogicalResult matchAndRewrite(WhileOp op,
3467  PatternRewriter &rewriter) const override {
3468  auto term = op.getConditionOp();
3469 
3470  // This variable serves to prevent creating duplicate constants and holds
3471  // the constant `true` value once it has been created.
3472  Value constantTrue = nullptr;
3473 
3474  bool replaced = false;
3475  for (auto yieldedAndBlockArgs :
3476  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3477  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3478  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3479  if (!constantTrue)
3480  constantTrue = rewriter.create<arith::ConstantOp>(
3481  op.getLoc(), term.getCondition().getType(),
3482  rewriter.getBoolAttr(true));
3483 
3484  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3485  constantTrue);
3486  replaced = true;
3487  }
3488  }
3489  }
3490  return success(replaced);
3491  }
3492 };
3493 
3494 /// Remove loop invariant arguments from `before` block of scf.while.
3495 /// A before block argument is considered loop invariant if :-
3496 /// 1. i-th yield operand is equal to the i-th while operand.
3497 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3498 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3499 /// iter argument/while operand.
3500 /// For the arguments which are removed, their uses inside scf.while
3501 /// are replaced with their corresponding initial value.
3502 ///
3503 /// Eg:
3504 /// INPUT :-
3505 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3506 /// ..., %argN_before = %N)
3507 /// {
3508 /// ...
3509 /// scf.condition(%cond) %arg1_before, %arg0_before,
3510 /// %arg2_before, %arg0_before, ...
3511 /// } do {
3512 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3513 /// ..., %argK_after):
3514 /// ...
3515 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3516 /// }
3517 ///
3518 /// OUTPUT :-
3519 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3520 /// %N)
3521 /// {
3522 /// ...
3523 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3524 /// } do {
3525 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3526 /// ..., %argK_after):
3527 /// ...
3528 /// scf.yield %arg1_after, ..., %argN
3529 /// }
3530 ///
3531 /// EXPLANATION:
3532 /// We iterate over each yield operand.
3533 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3534 /// %arg0_before, which in turn is the 0-th iter argument. So we
3535 /// remove 0-th before block argument and yield operand, and replace
3536 /// all uses of the 0-th before block argument with its initial value
3537 /// %a.
3538 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3539 /// value. So we remove this operand and the corresponding before
3540 /// block argument and replace all uses of 1-th before block argument
3541 /// with %b.
3542 struct RemoveLoopInvariantArgsFromBeforeBlock
3543  : public OpRewritePattern<WhileOp> {
3544  using OpRewritePattern<WhileOp>::OpRewritePattern;
3545 
3546  LogicalResult matchAndRewrite(WhileOp op,
3547  PatternRewriter &rewriter) const override {
3548  Block &afterBlock = *op.getAfterBody();
3549  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3550  ConditionOp condOp = op.getConditionOp();
3551  OperandRange condOpArgs = condOp.getArgs();
3552  Operation *yieldOp = afterBlock.getTerminator();
3553  ValueRange yieldOpArgs = yieldOp->getOperands();
3554 
3555  bool canSimplify = false;
3556  for (const auto &it :
3557  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3558  auto index = static_cast<unsigned>(it.index());
3559  auto [initVal, yieldOpArg] = it.value();
3560  // If i-th yield operand is equal to the i-th operand of the scf.while,
3561  // the i-th before block argument is a loop invariant.
3562  if (yieldOpArg == initVal) {
3563  canSimplify = true;
3564  break;
3565  }
3566  // If the i-th yield operand is the k-th after block argument, then we
3567  // check whether the (k+1)-th condition op operand is equal to either the
3568  // i-th before block argument or the initial value of the i-th before
3569  // block argument. If either comparison holds, the i-th before block
3570  // argument is a loop invariant.
3571  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3572  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3573  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3574  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3575  canSimplify = true;
3576  break;
3577  }
3578  }
3579  }
3580 
3581  if (!canSimplify)
3582  return failure();
3583 
3584  SmallVector<Value> newInitArgs, newYieldOpArgs;
3585  DenseMap<unsigned, Value> beforeBlockInitValMap;
3586  SmallVector<Location> newBeforeBlockArgLocs;
3587  for (const auto &it :
3588  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3589  auto index = static_cast<unsigned>(it.index());
3590  auto [initVal, yieldOpArg] = it.value();
3591 
3592  // If i-th yield operand is equal to the i-th operand of the scf.while,
3593  // the i-th before block argument is a loop invariant.
3594  if (yieldOpArg == initVal) {
3595  beforeBlockInitValMap.insert({index, initVal});
3596  continue;
3597  } else {
3598  // If the i-th yield operand is the k-th after block argument, then we
3599  // check whether the (k+1)-th condition op operand is equal to either
3600  // the i-th before block argument or the initial value of the i-th
3601  // before block argument. If either comparison holds, the i-th before
3602  // block argument is a loop invariant.
3603  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3604  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3605  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3606  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3607  beforeBlockInitValMap.insert({index, initVal});
3608  continue;
3609  }
3610  }
3611  }
3612  newInitArgs.emplace_back(initVal);
3613  newYieldOpArgs.emplace_back(yieldOpArg);
3614  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3615  }
3616 
3617  {
3618  OpBuilder::InsertionGuard g(rewriter);
3619  rewriter.setInsertionPoint(yieldOp);
3620  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3621  }
3622 
3623  auto newWhile =
3624  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3625 
3626  Block &newBeforeBlock = *rewriter.createBlock(
3627  &newWhile.getBefore(), /*insertPt*/ {},
3628  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3629 
3630  Block &beforeBlock = *op.getBeforeBody();
3631  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3632  // For each i-th before block argument we find its replacement value as :-
3633  // 1. If the i-th before block argument is a loop invariant, we fetch its
3634  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3635  // 2. Else we fetch j-th new before block argument as the replacement
3636  // value of i-th before block argument.
3637  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3638  // If the index 'i' argument was a loop invariant we fetch its initial
3639  // value from `beforeBlockInitValMap`.
3640  if (beforeBlockInitValMap.count(i) != 0)
3641  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3642  else
3643  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3644  }
3645 
3646  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3647  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3648  newWhile.getAfter().begin());
3649 
3650  rewriter.replaceOp(op, newWhile.getResults());
3651  return success();
3652  }
3653 };
3654 
3655 /// Remove loop invariant value from result (condition op) of scf.while.
3656 /// A value is considered loop invariant if the final value yielded by
3657 /// scf.condition is defined outside of the `before` block. We remove the
3658 /// corresponding argument in `after` block and replace the use with the value.
3659 /// We also replace the use of the corresponding result of scf.while with the
3660 /// value.
3661 ///
3662 /// Eg:
3663 /// INPUT :-
3664 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3665 /// %argN_before = %N) {
3666 /// ...
3667 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3668 /// } do {
3669 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3670 /// ...
3671 /// some_func(%arg1_after)
3672 /// ...
3673 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3674 /// }
3675 ///
3676 /// OUTPUT :-
3677 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3678 /// ...
3679 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3680 /// } do {
3681 /// ^bb0(%arg0, %arg3, ..., %argM):
3682 /// ...
3683 /// some_func(%a)
3684 /// ...
3685 /// scf.yield %arg0, %b, ..., %argN
3686 /// }
3687 ///
3688 /// EXPLANATION:
3689 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3690 /// before block of scf.while, so they get removed.
3691 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3692 /// replaced by %b.
3693 /// 3. The corresponding after block argument %arg1_after's uses are
3694 /// replaced by %a and %arg2_after's uses are replaced by %b.
3695 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3696  using OpRewritePattern<WhileOp>::OpRewritePattern;
3697 
3698  LogicalResult matchAndRewrite(WhileOp op,
3699  PatternRewriter &rewriter) const override {
3700  Block &beforeBlock = *op.getBeforeBody();
3701  ConditionOp condOp = op.getConditionOp();
3702  OperandRange condOpArgs = condOp.getArgs();
3703 
3704  bool canSimplify = false;
3705  for (Value condOpArg : condOpArgs) {
3706  // Values not defined within the `before` block are considered loop
3707  // invariant. If any such value is forwarded by scf.condition, the op can
3708  // be simplified.
3709  if (condOpArg.getParentBlock() != &beforeBlock) {
3710  canSimplify = true;
3711  break;
3712  }
3713  }
3714 
3715  if (!canSimplify)
3716  return failure();
3717 
3718  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3719 
3720  SmallVector<Value> newCondOpArgs;
3721  SmallVector<Type> newAfterBlockType;
3722  DenseMap<unsigned, Value> condOpInitValMap;
3723  SmallVector<Location> newAfterBlockArgLocs;
3724  for (const auto &it : llvm::enumerate(condOpArgs)) {
3725  auto index = static_cast<unsigned>(it.index());
3726  Value condOpArg = it.value();
3727  // Values not defined within the `before` block are considered loop
3728  // invariant; map the corresponding `index` to the value so it can be
3729  // substituted for the after block argument and the op result later on.
3730  if (condOpArg.getParentBlock() != &beforeBlock) {
3731  condOpInitValMap.insert({index, condOpArg});
3732  } else {
3733  newCondOpArgs.emplace_back(condOpArg);
3734  newAfterBlockType.emplace_back(condOpArg.getType());
3735  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3736  }
3737  }
3738 
3739  {
3740  OpBuilder::InsertionGuard g(rewriter);
3741  rewriter.setInsertionPoint(condOp);
3742  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3743  newCondOpArgs);
3744  }
3745 
3746  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3747  op.getOperands());
3748 
3749  Block &newAfterBlock =
3750  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3751  newAfterBlockType, newAfterBlockArgLocs);
3752 
3753  Block &afterBlock = *op.getAfterBody();
3754  // Since a new scf.condition op was created, we need to fetch the new
3755  // `after` block arguments, which are used when merging the previous
3756  // scf.while's `after` block into the new one. We also fetch the new
3757  // result values that replace the old op results.
3758  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3759  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3760  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3761  Value afterBlockArg, result;
3762  // If the index 'i' argument was loop invariant we fetch its value from
3763  // the `condOpInitValMap` map.
3764  if (condOpInitValMap.count(i) != 0) {
3765  afterBlockArg = condOpInitValMap[i];
3766  result = afterBlockArg;
3767  } else {
3768  afterBlockArg = newAfterBlock.getArgument(j);
3769  result = newWhile.getResult(j);
3770  j++;
3771  }
3772  newAfterBlockArgs[i] = afterBlockArg;
3773  newWhileResults[i] = result;
3774  }
3775 
3776  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3777  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3778  newWhile.getBefore().begin());
3779 
3780  rewriter.replaceOp(op, newWhileResults);
3781  return success();
3782  }
3783 };
3784 
3785 /// Remove WhileOp results that are also unused in 'after' block.
3786 ///
3787 /// %0:2 = scf.while () : () -> (i32, i64) {
3788 /// %condition = "test.condition"() : () -> i1
3789 /// %v1 = "test.get_some_value"() : () -> i32
3790 /// %v2 = "test.get_some_value"() : () -> i64
3791 /// scf.condition(%condition) %v1, %v2 : i32, i64
3792 /// } do {
3793 /// ^bb0(%arg0: i32, %arg1: i64):
3794 /// "test.use"(%arg0) : (i32) -> ()
3795 /// scf.yield
3796 /// }
3797 /// return %0#0 : i32
3798 ///
3799 /// becomes
3800 /// %0 = scf.while () : () -> (i32) {
3801 /// %condition = "test.condition"() : () -> i1
3802 /// %v1 = "test.get_some_value"() : () -> i32
3803 /// %v2 = "test.get_some_value"() : () -> i64
3804 /// scf.condition(%condition) %v1 : i32
3805 /// } do {
3806 /// ^bb0(%arg0: i32):
3807 /// "test.use"(%arg0) : (i32) -> ()
3808 /// scf.yield
3809 /// }
3810 /// return %0 : i32
3811 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3812  using OpRewritePattern<WhileOp>::OpRewritePattern;
3813 
3814  LogicalResult matchAndRewrite(WhileOp op,
3815  PatternRewriter &rewriter) const override {
3816  auto term = op.getConditionOp();
3817  auto afterArgs = op.getAfterArguments();
3818  auto termArgs = term.getArgs();
3819 
3820  // Collect results mapping, new terminator args and new result types.
3821  SmallVector<unsigned> newResultsIndices;
3822  SmallVector<Type> newResultTypes;
3823  SmallVector<Value> newTermArgs;
3824  SmallVector<Location> newArgLocs;
3825  bool needUpdate = false;
3826  for (const auto &it :
3827  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3828  auto i = static_cast<unsigned>(it.index());
3829  Value result = std::get<0>(it.value());
3830  Value afterArg = std::get<1>(it.value());
3831  Value termArg = std::get<2>(it.value());
3832  if (result.use_empty() && afterArg.use_empty()) {
3833  needUpdate = true;
3834  } else {
3835  newResultsIndices.emplace_back(i);
3836  newTermArgs.emplace_back(termArg);
3837  newResultTypes.emplace_back(result.getType());
3838  newArgLocs.emplace_back(result.getLoc());
3839  }
3840  }
3841 
3842  if (!needUpdate)
3843  return failure();
3844 
3845  {
3846  OpBuilder::InsertionGuard g(rewriter);
3847  rewriter.setInsertionPoint(term);
3848  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3849  newTermArgs);
3850  }
3851 
3852  auto newWhile =
3853  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3854 
3855  Block &newAfterBlock = *rewriter.createBlock(
3856  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3857 
3858  // Build new results list and new after block args (unused entries will be
3859  // null).
3860  SmallVector<Value> newResults(op.getNumResults());
3861  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3862  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3863  newResults[it.value()] = newWhile.getResult(it.index());
3864  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3865  }
3866 
3867  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3868  newWhile.getBefore().begin());
3869 
3870  Block &afterBlock = *op.getAfterBody();
3871  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3872 
3873  rewriter.replaceOp(op, newResults);
3874  return success();
3875  }
3876 };
3877 
3878 /// Replace operations equivalent to the condition in the do block with true,
3879 /// since otherwise the block would not be evaluated.
3880 ///
3881 /// scf.while (..) : (i32, ...) -> ... {
3882 /// %z = ... : i32
3883 /// %condition = cmpi pred %z, %a
3884 /// scf.condition(%condition) %z : i32, ...
3885 /// } do {
3886 /// ^bb0(%arg0: i32, ...):
3887 /// %condition2 = cmpi pred %arg0, %a
3888 /// use(%condition2)
3889 /// ...
3890 ///
3891 /// becomes
3892 /// scf.while (..) : (i32, ...) -> ... {
3893 /// %z = ... : i32
3894 /// %condition = cmpi pred %z, %a
3895 /// scf.condition(%condition) %z : i32, ...
3896 /// } do {
3897 /// ^bb0(%arg0: i32, ...):
3898 /// use(%true)
3899 /// ...
3900 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3901  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3902 
3903  LogicalResult matchAndRewrite(scf::WhileOp op,
3904  PatternRewriter &rewriter) const override {
3905  using namespace scf;
3906  auto cond = op.getConditionOp();
3907  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3908  if (!cmp)
3909  return failure();
3910  bool changed = false;
3911  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3912  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3913  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3914  continue;
3915  for (OpOperand &u :
3916  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3917  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3918  if (!cmp2)
3919  continue;
3920  // For a binary operator 1-opIdx gets the other side.
3921  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3922  continue;
3923  bool samePredicate;
3924  if (cmp2.getPredicate() == cmp.getPredicate())
3925  samePredicate = true;
3926  else if (cmp2.getPredicate() ==
3927  arith::invertPredicate(cmp.getPredicate()))
3928  samePredicate = false;
3929  else
3930  continue;
3931 
3932  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3933  1);
3934  changed = true;
3935  }
3936  }
3937  }
3938  return success(changed);
3939  }
3940 };
3941 
3942 /// Remove unused init/yield args.
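///
/// Illustrative sketch (assumed IR, not from this file): if the before-region
/// argument %unused has no uses,
///
///   scf.while (%unused = %a, %n = %b) : (i32, i32) -> i32 { ... }
///
/// is rebuilt as
///
///   scf.while (%n = %b) : (i32) -> i32 { ... }
///
/// with the corresponding init and yield operands dropped.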
3943 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3944  using OpRewritePattern<WhileOp>::OpRewritePattern;
3945 
3946  LogicalResult matchAndRewrite(WhileOp op,
3947  PatternRewriter &rewriter) const override {
3948 
3949  if (!llvm::any_of(op.getBeforeArguments(),
3950  [](Value arg) { return arg.use_empty(); }))
3951  return rewriter.notifyMatchFailure(op, "No args to remove");
3952 
3953  YieldOp yield = op.getYieldOp();
3954 
3955  // Collect the new init and yield values and mark which arguments to erase.
3956  SmallVector<Value> newYields;
3957  SmallVector<Value> newInits;
3958  llvm::BitVector argsToErase;
3959 
3960  size_t argsCount = op.getBeforeArguments().size();
3961  newYields.reserve(argsCount);
3962  newInits.reserve(argsCount);
3963  argsToErase.reserve(argsCount);
3964  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
3965  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
3966  if (beforeArg.use_empty()) {
3967  argsToErase.push_back(true);
3968  } else {
3969  argsToErase.push_back(false);
3970  newYields.emplace_back(yieldValue);
3971  newInits.emplace_back(initValue);
3972  }
3973  }
3974 
3975  Block &beforeBlock = *op.getBeforeBody();
3976  Block &afterBlock = *op.getAfterBody();
3977 
3978  beforeBlock.eraseArguments(argsToErase);
3979 
3980  Location loc = op.getLoc();
3981  auto newWhileOp =
3982  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
3983  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
3984  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
3985  Block &newAfterBlock = *newWhileOp.getAfterBody();
3986 
3987  OpBuilder::InsertionGuard g(rewriter);
3988  rewriter.setInsertionPoint(yield);
3989  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3990 
3991  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
3992  newBeforeBlock.getArguments());
3993  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
3994  newAfterBlock.getArguments());
3995 
3996  rewriter.replaceOp(op, newWhileOp.getResults());
3997  return success();
3998  }
3999 };
4000 
4001 /// Remove duplicated ConditionOp args.
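///
/// Illustrative sketch (assumed IR, not from this file): if the before region
/// ends with
///
///   scf.condition(%cond) %v, %v : i32, i32
///
/// the loop is rebuilt so that the condition forwards %v only once; both of
/// the original results and both after-block arguments are remapped to the
/// single remaining result/argument.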
4002 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4003  using OpRewritePattern<WhileOp>::OpRewritePattern;
4004 
4005  LogicalResult matchAndRewrite(WhileOp op,
4006  PatternRewriter &rewriter) const override {
4007  ConditionOp condOp = op.getConditionOp();
4008  ValueRange condOpArgs = condOp.getArgs();
4009 
4010  llvm::SmallDenseSet<Value> argsSet;
4011  for (Value arg : condOpArgs)
4012  argsSet.insert(arg);
4013 
4014  if (argsSet.size() == condOpArgs.size())
4015  return rewriter.notifyMatchFailure(op, "No results to remove");
4016 
4017  llvm::SmallDenseMap<Value, unsigned> argsMap;
4018  SmallVector<Value> newArgs;
4019  argsMap.reserve(condOpArgs.size());
4020  newArgs.reserve(condOpArgs.size());
4021  for (Value arg : condOpArgs) {
4022  if (!argsMap.count(arg)) {
4023  auto pos = static_cast<unsigned>(argsMap.size());
4024  argsMap.insert({arg, pos});
4025  newArgs.emplace_back(arg);
4026  }
4027  }
4028 
4029  ValueRange argsRange(newArgs);
4030 
4031  Location loc = op.getLoc();
4032  auto newWhileOp = rewriter.create<scf::WhileOp>(
4033  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4034  /*afterBody*/ nullptr);
4035  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4036  Block &newAfterBlock = *newWhileOp.getAfterBody();
4037 
4038  SmallVector<Value> afterArgsMapping;
4039  SmallVector<Value> resultsMapping;
4040  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4041  auto it = argsMap.find(arg);
4042  assert(it != argsMap.end());
4043  auto pos = it->second;
4044  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4045  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4046  }
4047 
4048  OpBuilder::InsertionGuard g(rewriter);
4049  rewriter.setInsertionPoint(condOp);
4050  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4051  argsRange);
4052 
4053  Block &beforeBlock = *op.getBeforeBody();
4054  Block &afterBlock = *op.getAfterBody();
4055 
4056  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4057  newBeforeBlock.getArguments());
4058  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4059  rewriter.replaceOp(op, resultsMapping);
4060  return success();
4061  }
4062 };
4063 
4064 /// If both ranges contain the same values, return mapping indices from args2
4065 /// to args1. Otherwise return std::nullopt.
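/// For example, for args1 = [%a, %b, %c] and args2 = [%c, %a, %b] the result
/// is [2, 0, 1]: ret[j] holds the position of args2[j] within args1.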
4066 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4067  ValueRange args2) {
4068  if (args1.size() != args2.size())
4069  return std::nullopt;
4070 
4071  SmallVector<unsigned> ret(args1.size());
4072  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4073  auto it = llvm::find(args2, arg1);
4074  if (it == args2.end())
4075  return std::nullopt;
4076 
4077  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4078  }
4079 
4080  return ret;
4081 }
4082 
4083 static bool hasDuplicates(ValueRange args) {
4084  llvm::SmallDenseSet<Value> set;
4085  for (Value arg : args) {
4086  if (set.contains(arg))
4087  return true;
4088 
4089  set.insert(arg);
4090  }
4091  return false;
4092 }
4093 
4094 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4095 /// `scf.condition` args into the same order as the block args. Update the
4096 /// `after` block args and op result values accordingly.
4097 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
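///
/// Illustrative sketch (assumed IR, not from this file): a before region
/// ending in
///
///   ^bb0(%a: i32, %b: i32):
///     ...
///     scf.condition(%cond) %b, %a : i32, i32
///
/// is rewritten so that the condition forwards %a, %b in block-argument
/// order, with the after-block arguments and the op results permuted to
/// match.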
4098 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4099  using OpRewritePattern<WhileOp>::OpRewritePattern;
4100 
4101  LogicalResult matchAndRewrite(WhileOp loop,
4102  PatternRewriter &rewriter) const override {
4103  auto oldBefore = loop.getBeforeBody();
4104  ConditionOp oldTerm = loop.getConditionOp();
4105  ValueRange beforeArgs = oldBefore->getArguments();
4106  ValueRange termArgs = oldTerm.getArgs();
4107  if (beforeArgs == termArgs)
4108  return failure();
4109 
4110  if (hasDuplicates(termArgs))
4111  return failure();
4112 
4113  auto mapping = getArgsMapping(beforeArgs, termArgs);
4114  if (!mapping)
4115  return failure();
4116 
4117  {
4118  OpBuilder::InsertionGuard g(rewriter);
4119  rewriter.setInsertionPoint(oldTerm);
4120  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4121  beforeArgs);
4122  }
4123 
4124  auto oldAfter = loop.getAfterBody();
4125 
4126  SmallVector<Type> newResultTypes(beforeArgs.size());
4127  for (auto &&[i, j] : llvm::enumerate(*mapping))
4128  newResultTypes[j] = loop.getResult(i).getType();
4129 
4130  auto newLoop = rewriter.create<WhileOp>(
4131  loop.getLoc(), newResultTypes, loop.getInits(),
4132  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4133  auto newBefore = newLoop.getBeforeBody();
4134  auto newAfter = newLoop.getAfterBody();
4135 
4136  SmallVector<Value> newResults(beforeArgs.size());
4137  SmallVector<Value> newAfterArgs(beforeArgs.size());
4138  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4139  newResults[i] = newLoop.getResult(j);
4140  newAfterArgs[i] = newAfter->getArgument(j);
4141  }
4142 
4143  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4144  newBefore->getArguments());
4145  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4146  newAfterArgs);
4147 
4148  rewriter.replaceOp(loop, newResults);
4149  return success();
4150  }
4151 };
4152 } // namespace
4153 
4154 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4155  MLIRContext *context) {
4156  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4157  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4158  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4159  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4160 }
4161 
4162 //===----------------------------------------------------------------------===//
4163 // IndexSwitchOp
4164 //===----------------------------------------------------------------------===//
4165 
4166 /// Parse the case regions and values.
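///
/// The parsed form looks like the following sketch (illustrative only, based
/// on the scf.index_switch custom syntax; the value names are assumed):
///
///   %0 = scf.index_switch %index -> i32
///   case 2 {
///     scf.yield %a : i32
///   }
///   default {
///     scf.yield %b : i32
///   }
///
/// Only the `case <value> <region>` parts are handled here; the rest of the
/// syntax is handled by the op's declarative assembly format.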
4167 static ParseResult
4168 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
4169  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4170  SmallVector<int64_t> caseValues;
4171  while (succeeded(p.parseOptionalKeyword("case"))) {
4172  int64_t value;
4173  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4174  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4175  return failure();
4176  caseValues.push_back(value);
4177  }
4178  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4179  return success();
4180 }
4181 
4182 /// Print the case regions and values.
4183 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
4184  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4185  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4186  p.printNewline();
4187  p << "case " << value << ' ';
4188  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4189  }
4190 }
4191 
4192 LogicalResult scf::IndexSwitchOp::verify() {
4193  if (getCases().size() != getCaseRegions().size()) {
4194  return emitOpError("has ")
4195  << getCaseRegions().size() << " case regions but "
4196  << getCases().size() << " case values";
4197  }
4198 
4199  DenseSet<int64_t> valueSet;
4200  for (int64_t value : getCases())
4201  if (!valueSet.insert(value).second)
4202  return emitOpError("has duplicate case value: ") << value;
4203  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4204  auto yield = dyn_cast<YieldOp>(region.front().back());
4205  if (!yield)
4206  return emitOpError("expected region to end with scf.yield, but got ")
4207  << region.front().back().getName();
4208 
4209  if (yield.getNumOperands() != getNumResults()) {
4210  return (emitOpError("expected each region to return ")
4211  << getNumResults() << " values, but " << name << " returns "
4212  << yield.getNumOperands())
4213  .attachNote(yield.getLoc())
4214  << "see yield operation here";
4215  }
4216  for (auto [idx, result, operand] :
4217  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4218  yield.getOperandTypes())) {
4219  if (result == operand)
4220  continue;
4221  return (emitOpError("expected result #")
4222  << idx << " of each region to be " << result)
4223  .attachNote(yield.getLoc())
4224  << name << " returns " << operand << " here";
4225  }
4226  return success();
4227  };
4228 
4229  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4230  return failure();
4231  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4232  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4233  return failure();
4234 
4235  return success();
4236 }
4237 
4238 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4239 
4240 Block &scf::IndexSwitchOp::getDefaultBlock() {
4241  return getDefaultRegion().front();
4242 }
4243 
4244 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4245  assert(idx < getNumCases() && "case index out-of-bounds");
4246  return getCaseRegions()[idx].front();
4247 }
4248 
4249 void IndexSwitchOp::getSuccessorRegions(
4250  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
4251  // All regions branch back to the parent op.
4252  if (!point.isParent()) {
4253  successors.emplace_back(getResults());
4254  return;
4255  }
4256 
4257  llvm::copy(getRegions(), std::back_inserter(successors));
4258 }
4259 
4260 void IndexSwitchOp::getEntrySuccessorRegions(
4261  ArrayRef<Attribute> operands,
4262  SmallVectorImpl<RegionSuccessor> &successors) {
4263  FoldAdaptor adaptor(operands, *this);
4264 
4265  // If a constant was not provided, all regions are possible successors.
4266  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4267  if (!arg) {
4268  llvm::copy(getRegions(), std::back_inserter(successors));
4269  return;
4270  }
4271 
4272  // Otherwise, try to find a case with a matching value. If not, the
4273  // default region is the only successor.
4274  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4275  if (caseValue == arg.getInt()) {
4276  successors.emplace_back(&caseRegion);
4277  return;
4278  }
4279  }
4280  successors.emplace_back(&getDefaultRegion());
4281 }
4282 
4283 void IndexSwitchOp::getRegionInvocationBounds(
4284  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
4285  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4286  if (!operandValue) {
4287  // All regions are invoked at most once.
4288  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4289  return;
4290  }
4291 
4292  unsigned liveIndex = getNumRegions() - 1;
4293  const auto *it = llvm::find(getCases(), operandValue.getInt());
4294  if (it != getCases().end())
4295  liveIndex = std::distance(getCases().begin(), it);
4296  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4297  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4298 }
4299 
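/// Fold scf.index_switch when its argument is a known constant by inlining
/// the matching case region (or the default region when no case value
/// matches) in place of the switch. A hedged sketch (assumed IR, not from
/// this file):
///
///   %c2 = arith.constant 2 : index
///   %r = scf.index_switch %c2 -> i32
///   case 2 { scf.yield %a : i32 }
///   default { scf.yield %b : i32 }
///
/// The case-2 block is inlined before the op, the op is erased, and uses of
/// %r are replaced by %a.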
4300 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4301  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
4302 
4303  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4304  PatternRewriter &rewriter) const override {
4305  // If `op.getArg()` is a constant, select the region that matches the
4306  // constant value. Use the default region if no match is found.
4307  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4308  if (!maybeCst.has_value())
4309  return failure();
4310  int64_t cst = *maybeCst;
4311  int64_t caseIdx, e = op.getNumCases();
4312  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4313  if (cst == op.getCases()[caseIdx])
4314  break;
4315  }
4316 
4317  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4318  : op.getDefaultRegion();
4319  Block &source = r.front();
4320  Operation *terminator = source.getTerminator();
4321  SmallVector<Value> results = terminator->getOperands();
4322 
4323  rewriter.inlineBlockBefore(&source, op);
4324  rewriter.eraseOp(terminator);
4325  // Replace the operation with a potentially empty list of results.
4326  // The fold mechanism doesn't support the case of an empty result list.
4327  rewriter.replaceOp(op, results);
4328 
4329  return success();
4330  }
4331 };
4332 
4333 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4334  MLIRContext *context) {
4335  results.add<FoldConstantCase>(context);
4336 }
4337 
4338 //===----------------------------------------------------------------------===//
4339 // TableGen'd op method definitions
4340 //===----------------------------------------------------------------------===//
4341 
4342 #define GET_OP_CLASSES
4343 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:720
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:712
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4168
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3412
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:431
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:92
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4183
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:159
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator begin()
Definition: Block.h:141
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:128
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:183
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:187
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:120
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:77
IndexType getIndexType()
Definition: Builders.cpp:75
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInterface can be invoked.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the given list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attributes'.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the provided map (operands without a mapping are left unchanged).
Definition: Builders.cpp:559
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:439
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:586
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
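A sketch (helper name illustrative) pairing createBlock with an InsertionGuard so the caller's insertion point is restored afterwards.
  // Create an entry block with a single index argument in `region`.
  static Block *createIndexEntryBlock(OpBuilder &b, Region &region,
                                      Location loc) {
    OpBuilder::InsertionGuard guard(b);
    SmallVector<Type> argTypes{b.getIndexType()};
    SmallVector<Location> argLocs{loc};
    return b.createBlock(&region, region.end(), argTypes, argLocs);
  }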
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:415
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keeps the operation's regions empty.
Definition: Builders.h:587
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns a range over the underlying operand Values.
Definition: Operation.h:373
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:842
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is an ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
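For context, a typical populate function registers several patterns at once; MyPatternA/MyPatternB below are placeholders for OpRewritePattern subclasses defined elsewhere, not patterns from this file.
  static void populateMyCanonicalizationPatterns(RewritePatternSet &patterns) {
    patterns.add<MyPatternA, MyPatternB>(patterns.getContext());
  }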
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, with a callback that populates a diagnostic with the reason for the failure.
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
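A hedged sketch combining replaceUsesWithIf with modifyOpInPlace (the helper and the "visited" attribute are illustrative).
  // Redirect uses of `from` to `to`, but only inside `scope`, then tag the
  // parent op in place so rewriter listeners observe the modification.
  static void replaceUsesInBlock(RewriterBase &rewriter, Value from, Value to,
                                 Block *scope) {
    rewriter.replaceUsesWithIf(from, to, [&](OpOperand &use) {
      return use.getOwner()->getBlock() == scope;
    });
    if (Operation *parent = scope->getParentOp())
      rewriter.modifyOpInPlace(parent, [&] {
        parent->setAttr(rewriter.getStringAttr("visited"),
                        rewriter.getUnitAttr());
      });
  }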
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
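A hedged sketch of the single-block inlining shape used by patterns in this file (assumes `body` has exactly one block, no block arguments, and a yield-like terminator whose operands match the results of `op`).
  // Splice the ops of a one-block region in front of `op`, forward the
  // terminator operands as `op`'s replacement values, and clean up.
  static void inlineSingleBlockBody(RewriterBase &rewriter, Operation *op,
                                    Region &body) {
    Block *block = &body.front();
    Operation *terminator = block->getTerminator();
    SmallVector<Value> yielded(terminator->operand_begin(),
                               terminator->operand_end());
    rewriter.inlineBlockBefore(block, op->getBlock(), op->getIterator());
    rewriter.eraseOp(terminator);
    rewriter.replaceOp(op, yielded);
  }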
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
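A hedged sketch of a pattern built on this hook; the pattern is illustrative, assumes the arith dialect is available, and is not one of the patterns defined in this file.
  // Fold x - x to the constant 0 of the matching integer type.
  struct SimplifySelfSubtraction : OpRewritePattern<arith::SubIOp> {
    using OpRewritePattern<arith::SubIOp>::OpRewritePattern;
    LogicalResult matchAndRewrite(arith::SubIOp op,
                                  PatternRewriter &rewriter) const override {
      if (op.getLhs() != op.getRhs())
        return failure();
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, rewriter.getZeroAttr(op.getType()));
      return success();
    }
  };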
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:57
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
constexpr auto NotSpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of an AffineForOp to its containing block if the loop was known to have a single iteration.
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:75
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:450
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3025
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e., all loops but the innermost contain only another loop and a terminator.
Definition: SCF.cpp:687
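A hedged usage sketch (the wrapper is hypothetical; lb/ub/step are assumed to be index-typed values).
  // Build a 2-D nest of scf.for loops over [lb, ub) with the given step; the
  // body callback receives one induction variable per loop, outermost first.
  static scf::LoopNest buildSquareNest(OpBuilder &b, Location loc, Value lb,
                                       Value ub, Value step) {
    SmallVector<Value> lbs{lb, lb}, ubs{ub, ub}, steps{step, step};
    return scf::buildLoopNest(
        b, loc, lbs, ubs, steps,
        [](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
          // Loop body would be built here using `nested` and `ivs`.
        });
  }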
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1952
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
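A sketch of the typical query shape (the helper is hypothetical).
  // If `v` is the induction variable of an scf.for, return that loop's step;
  // otherwise return a null Value.
  static Value getOwningLoopStep(Value v) {
    if (scf::ForOp forOp = scf::getForInductionVarOwner(v))
      return forOp.getStep();
    return Value();
  }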
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1453
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:70
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source tensor type.
Definition: TensorOps.cpp:267
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:438
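A small sketch of how matchPattern and m_ConstantInt compose (helper hypothetical).
  // Return the constant integer bound to `v`, if any.
  static std::optional<int64_t> getBoundConstant(Value v) {
    APInt constValue;
    if (matchPattern(v, m_ConstantInt(&constValue)))
      return constValue.getSExtValue();
    return std::nullopt;
  }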
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalables, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
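A hedged sketch, assuming the function above is visible to the caller; the threshold is an arbitrary illustration.
  // Decide whether a loop's trip count is statically known and small.
  static bool hasSmallStaticTripCount(scf::ForOp forOp, int64_t limit) {
    std::optional<int64_t> tripCount = constantTripCount(
        forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
    return tripCount && *tripCount <= limit;
  }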
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but with every element for which ShapedType::isDynamic is true replaced by the corresponding value in dynamicValues.
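A sketch of the usual round trip between the static/dynamic split and the flat OpFoldResult form (helper hypothetical), pairing getMixedValues with dispatchIndexOpFoldResults above.
  // Merge static sizes and dynamic size values into OpFoldResults, then split
  // them back into separate static and dynamic vectors.
  static void roundTripMixedSizes(Builder &b, ArrayRef<int64_t> staticSizes,
                                  ValueRange dynamicSizes) {
    SmallVector<OpFoldResult> mixed =
        getMixedValues(staticSizes, dynamicSizes, b);
    SmallVector<Value> newDynamic;
    SmallVector<int64_t> newStatic;
    dispatchIndexOpFoldResults(mixed, newDynamic, newStatic);
  }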
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that values has as many elements as the number of entries in attr for which isDynamic evaluates to true.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:421
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4303
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class rather than a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.