MLIR  21.0.0git
SCF.cpp
Go to the documentation of this file.
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "mlir/IR/IRMapping.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 using namespace mlir;
30 using namespace mlir::scf;
31 
32 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // SCFDialect Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SCFInlinerInterface : public DialectInlinerInterface {
41  // We don't have any special restrictions on what can be inlined into
42  // destination regions (e.g. while/conditional bodies). Always allow it.
43  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
44  IRMapping &valueMapping) const final {
45  return true;
46  }
47  // Operations in scf dialect are always legal to inline since they are
48  // pure.
49  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
50  return true;
51  }
52  // Handle the given inlined terminator by replacing it with a new operation
53  // as necessary. Required when the region has only one block.
54  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
55  auto retValOp = dyn_cast<scf::YieldOp>(op);
56  if (!retValOp)
57  return;
58 
59  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
60  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
61  }
62  }
63 };
64 } // namespace
65 
66 //===----------------------------------------------------------------------===//
67 // SCFDialect
68 //===----------------------------------------------------------------------===//
69 
70 void SCFDialect::initialize() {
71  addOperations<
72 #define GET_OP_LIST
73 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
74  >();
75  addInterfaces<SCFInlinerInterface>();
76  declarePromisedInterfaces<bufferization::BufferDeallocationOpInterface,
77  InParallelOp, ReduceReturnOp>();
78  declarePromisedInterfaces<bufferization::BufferizableOpInterface, ConditionOp,
79  ExecuteRegionOp, ForOp, IfOp, IndexSwitchOp,
80  ForallOp, InParallelOp, WhileOp, YieldOp>();
81  declarePromisedInterface<ValueBoundsOpInterface, ForOp>();
82 }
83 
84 /// Default callback for IfOp builders. Inserts a yield without arguments.
86  builder.create<scf::YieldOp>(loc);
87 }
88 
89 /// Verifies that the first block of the given `region` is terminated by a
90 /// TerminatorTy. Reports errors on the given operation if it is not the case.
91 template <typename TerminatorTy>
92 static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
93  StringRef errorMessage) {
94  Operation *terminatorOperation = nullptr;
95  if (!region.empty() && !region.front().empty()) {
96  terminatorOperation = &region.front().back();
97  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
98  return yield;
99  }
100  auto diag = op->emitOpError(errorMessage);
101  if (terminatorOperation)
102  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
103  return nullptr;
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // ExecuteRegionOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Replaces the given op with the contents of the given single-block region,
111 /// using the operands of the block terminator to replace operation results.
112 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
113  Region &region, ValueRange blockArgs = {}) {
114  assert(llvm::hasSingleElement(region) && "expected single-region block");
115  Block *block = &region.front();
116  Operation *terminator = block->getTerminator();
117  ValueRange results = terminator->getOperands();
118  rewriter.inlineBlockBefore(block, op, blockArgs);
119  rewriter.replaceOp(op, results);
120  rewriter.eraseOp(terminator);
121 }
122 
123 ///
124 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
125 /// block+
126 /// `}`
127 ///
128 /// Example:
129 /// scf.execute_region -> i32 {
130 /// %idx = load %rI[%i] : memref<128xi32>
131 /// return %idx : i32
132 /// }
133 ///
134 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
135  OperationState &result) {
136  if (parser.parseOptionalArrowTypeList(result.types))
137  return failure();
138 
139  // Introduce the body region and parse it.
140  Region *body = result.addRegion();
141  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
142  parser.parseOptionalAttrDict(result.attributes))
143  return failure();
144 
145  return success();
146 }
147 
149  p.printOptionalArrowTypeList(getResultTypes());
150 
151  p << ' ';
152  p.printRegion(getRegion(),
153  /*printEntryBlockArgs=*/false,
154  /*printBlockTerminators=*/true);
155 
156  p.printOptionalAttrDict((*this)->getAttrs());
157 }
158 
159 LogicalResult ExecuteRegionOp::verify() {
160  if (getRegion().empty())
161  return emitOpError("region needs to have at least one block");
162  if (getRegion().front().getNumArguments() > 0)
163  return emitOpError("region cannot have any arguments");
164  return success();
165 }
166 
167 // Inline an ExecuteRegionOp if it only contains one block.
168 // "test.foo"() : () -> ()
169 // %v = scf.execute_region -> i64 {
170 // %x = "test.val"() : () -> i64
171 // scf.yield %x : i64
172 // }
173 // "test.bar"(%v) : (i64) -> ()
174 //
175 // becomes
176 //
177 // "test.foo"() : () -> ()
178 // %x = "test.val"() : () -> i64
179 // "test.bar"(%x) : (i64) -> ()
180 //
181 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
183 
184  LogicalResult matchAndRewrite(ExecuteRegionOp op,
185  PatternRewriter &rewriter) const override {
186  if (!llvm::hasSingleElement(op.getRegion()))
187  return failure();
188  replaceOpWithRegion(rewriter, op, op.getRegion());
189  return success();
190  }
191 };
192 
193 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
194 // TODO generalize the conditions for operations which can be inlined into.
195 // func @func_execute_region_elim() {
196 // "test.foo"() : () -> ()
197 // %v = scf.execute_region -> i64 {
198 // %c = "test.cmp"() : () -> i1
199 // cf.cond_br %c, ^bb2, ^bb3
200 // ^bb2:
201 // %x = "test.val1"() : () -> i64
202 // cf.br ^bb4(%x : i64)
203 // ^bb3:
204 // %y = "test.val2"() : () -> i64
205 // cf.br ^bb4(%y : i64)
206 // ^bb4(%z : i64):
207 // scf.yield %z : i64
208 // }
209 // "test.bar"(%v) : (i64) -> ()
210 // return
211 // }
212 //
213 // becomes
214 //
215 // func @func_execute_region_elim() {
216 // "test.foo"() : () -> ()
217 // %c = "test.cmp"() : () -> i1
218 // cf.cond_br %c, ^bb1, ^bb2
219 // ^bb1: // pred: ^bb0
220 // %x = "test.val1"() : () -> i64
221 // cf.br ^bb3(%x : i64)
222 // ^bb2: // pred: ^bb0
223 // %y = "test.val2"() : () -> i64
224 // cf.br ^bb3(%y : i64)
225 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
226 // "test.bar"(%z) : (i64) -> ()
227 // return
228 // }
229 //
230 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
232 
233  LogicalResult matchAndRewrite(ExecuteRegionOp op,
234  PatternRewriter &rewriter) const override {
235  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
236  return failure();
237 
238  Block *prevBlock = op->getBlock();
239  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
240  rewriter.setInsertionPointToEnd(prevBlock);
241 
242  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
243 
244  for (Block &blk : op.getRegion()) {
245  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
246  rewriter.setInsertionPoint(yieldOp);
247  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
248  yieldOp.getResults());
249  rewriter.eraseOp(yieldOp);
250  }
251  }
252 
253  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
254  SmallVector<Value> blockArgs;
255 
256  for (auto res : op.getResults())
257  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
258 
259  rewriter.replaceOp(op, blockArgs);
260  return success();
261  }
262 };
263 
264 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
265  MLIRContext *context) {
267 }
268 
269 void ExecuteRegionOp::getSuccessorRegions(
271  // If the predecessor is the ExecuteRegionOp, branch into the body.
272  if (point.isParent()) {
273  regions.push_back(RegionSuccessor(&getRegion()));
274  return;
275  }
276 
277  // Otherwise, the region branches back to the parent operation.
278  regions.push_back(RegionSuccessor(getResults()));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ConditionOp
283 //===----------------------------------------------------------------------===//
284 
287  assert((point.isParent() || point == getParentOp().getAfter()) &&
288  "condition op can only exit the loop or branch to the after"
289  "region");
290  // Pass all operands except the condition to the successor region.
291  return getArgsMutable();
292 }
293 
294 void ConditionOp::getSuccessorRegions(
296  FoldAdaptor adaptor(operands, *this);
297 
298  WhileOp whileOp = getParentOp();
299 
300  // Condition can either lead to the after region or back to the parent op
301  // depending on whether the condition is true or not.
302  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
303  if (!boolAttr || boolAttr.getValue())
304  regions.emplace_back(&whileOp.getAfter(),
305  whileOp.getAfter().getArguments());
306  if (!boolAttr || !boolAttr.getValue())
307  regions.emplace_back(whileOp.getResults());
308 }
309 
310 //===----------------------------------------------------------------------===//
311 // ForOp
312 //===----------------------------------------------------------------------===//
313 
314 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
315  Value ub, Value step, ValueRange initArgs,
316  BodyBuilderFn bodyBuilder) {
317  OpBuilder::InsertionGuard guard(builder);
318 
319  result.addOperands({lb, ub, step});
320  result.addOperands(initArgs);
321  for (Value v : initArgs)
322  result.addTypes(v.getType());
323  Type t = lb.getType();
324  Region *bodyRegion = result.addRegion();
325  Block *bodyBlock = builder.createBlock(bodyRegion);
326  bodyBlock->addArgument(t, result.location);
327  for (Value v : initArgs)
328  bodyBlock->addArgument(v.getType(), v.getLoc());
329 
330  // Create the default terminator if the builder is not provided and if the
331  // iteration arguments are not provided. Otherwise, leave this to the caller
332  // because we don't know which values to return from the loop.
333  if (initArgs.empty() && !bodyBuilder) {
334  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
335  } else if (bodyBuilder) {
336  OpBuilder::InsertionGuard guard(builder);
337  builder.setInsertionPointToStart(bodyBlock);
338  bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
339  bodyBlock->getArguments().drop_front());
340  }
341 }
342 
343 LogicalResult ForOp::verify() {
344  // Check that the number of init args and op results is the same.
345  if (getInitArgs().size() != getNumResults())
346  return emitOpError(
347  "mismatch in number of loop-carried values and defined values");
348 
349  return success();
350 }
351 
352 LogicalResult ForOp::verifyRegions() {
353  // Check that the body defines as single block argument for the induction
354  // variable.
355  if (getInductionVar().getType() != getLowerBound().getType())
356  return emitOpError(
357  "expected induction variable to be same type as bounds and step");
358 
359  if (getNumRegionIterArgs() != getNumResults())
360  return emitOpError(
361  "mismatch in number of basic block args and defined values");
362 
363  auto initArgs = getInitArgs();
364  auto iterArgs = getRegionIterArgs();
365  auto opResults = getResults();
366  unsigned i = 0;
367  for (auto e : llvm::zip(initArgs, iterArgs, opResults)) {
368  if (std::get<0>(e).getType() != std::get<2>(e).getType())
369  return emitOpError() << "types mismatch between " << i
370  << "th iter operand and defined value";
371  if (std::get<1>(e).getType() != std::get<2>(e).getType())
372  return emitOpError() << "types mismatch between " << i
373  << "th iter region arg and defined value";
374 
375  ++i;
376  }
377  return success();
378 }
379 
380 std::optional<SmallVector<Value>> ForOp::getLoopInductionVars() {
381  return SmallVector<Value>{getInductionVar()};
382 }
383 
384 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopLowerBounds() {
386 }
387 
388 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopSteps() {
389  return SmallVector<OpFoldResult>{OpFoldResult(getStep())};
390 }
391 
392 std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
394 }
395 
396 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397 
398 /// Promotes the loop body of a forOp to its containing block if the forOp
399 /// it can be determined that the loop has a single iteration.
400 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
401  std::optional<int64_t> tripCount =
403  if (!tripCount.has_value() || tripCount != 1)
404  return failure();
405 
406  // Replace all results with the yielded values.
407  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
408  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
409 
410  // Replace block arguments with lower bound (replacement for IV) and
411  // iter_args.
412  SmallVector<Value> bbArgReplacements;
413  bbArgReplacements.push_back(getLowerBound());
414  llvm::append_range(bbArgReplacements, getInitArgs());
415 
416  // Move the loop body operations to the loop's containing block.
417  rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(),
418  getOperation()->getIterator(), bbArgReplacements);
419 
420  // Erase the old terminator and the loop.
421  rewriter.eraseOp(yieldOp);
422  rewriter.eraseOp(*this);
423 
424  return success();
425 }
426 
427 /// Prints the initialization list in the form of
428 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
429 /// where 'inner' values are assumed to be region arguments and 'outer' values
430 /// are regular SSA values.
432  Block::BlockArgListType blocksArgs,
433  ValueRange initializers,
434  StringRef prefix = "") {
435  assert(blocksArgs.size() == initializers.size() &&
436  "expected same length of arguments and initializers");
437  if (initializers.empty())
438  return;
439 
440  p << prefix << '(';
441  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
442  p << std::get<0>(it) << " = " << std::get<1>(it);
443  });
444  p << ")";
445 }
446 
447 void ForOp::print(OpAsmPrinter &p) {
448  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
449  << getUpperBound() << " step " << getStep();
450 
451  printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
452  if (!getInitArgs().empty())
453  p << " -> (" << getInitArgs().getTypes() << ')';
454  p << ' ';
455  if (Type t = getInductionVar().getType(); !t.isIndex())
456  p << " : " << t << ' ';
457  p.printRegion(getRegion(),
458  /*printEntryBlockArgs=*/false,
459  /*printBlockTerminators=*/!getInitArgs().empty());
460  p.printOptionalAttrDict((*this)->getAttrs());
461 }
462 
463 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
464  auto &builder = parser.getBuilder();
465  Type type;
466 
467  OpAsmParser::Argument inductionVariable;
468  OpAsmParser::UnresolvedOperand lb, ub, step;
469 
470  // Parse the induction variable followed by '='.
471  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
472  // Parse loop bounds.
473  parser.parseOperand(lb) || parser.parseKeyword("to") ||
474  parser.parseOperand(ub) || parser.parseKeyword("step") ||
475  parser.parseOperand(step))
476  return failure();
477 
478  // Parse the optional initial iteration arguments.
481  regionArgs.push_back(inductionVariable);
482 
483  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
484  if (hasIterArgs) {
485  // Parse assignment list and results type list.
486  if (parser.parseAssignmentList(regionArgs, operands) ||
487  parser.parseArrowTypeList(result.types))
488  return failure();
489  }
490 
491  if (regionArgs.size() != result.types.size() + 1)
492  return parser.emitError(
493  parser.getNameLoc(),
494  "mismatch in number of loop-carried values and defined values");
495 
496  // Parse optional type, else assume Index.
497  if (parser.parseOptionalColon())
498  type = builder.getIndexType();
499  else if (parser.parseType(type))
500  return failure();
501 
502  // Resolve input operands.
503  regionArgs.front().type = type;
504  if (parser.resolveOperand(lb, type, result.operands) ||
505  parser.resolveOperand(ub, type, result.operands) ||
506  parser.resolveOperand(step, type, result.operands))
507  return failure();
508  if (hasIterArgs) {
509  for (auto argOperandType :
510  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
511  Type type = std::get<2>(argOperandType);
512  std::get<0>(argOperandType).type = type;
513  if (parser.resolveOperand(std::get<1>(argOperandType), type,
514  result.operands))
515  return failure();
516  }
517  }
518 
519  // Parse the body region.
520  Region *body = result.addRegion();
521  if (parser.parseRegion(*body, regionArgs))
522  return failure();
523 
524  ForOp::ensureTerminator(*body, builder, result.location);
525 
526  // Parse the optional attribute list.
527  if (parser.parseOptionalAttrDict(result.attributes))
528  return failure();
529 
530  return success();
531 }
532 
533 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
534 
535 Block::BlockArgListType ForOp::getRegionIterArgs() {
536  return getBody()->getArguments().drop_front(getNumInductionVars());
537 }
538 
539 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
540  return getInitArgsMutable();
541 }
542 
543 FailureOr<LoopLikeOpInterface>
544 ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
545  ValueRange newInitOperands,
546  bool replaceInitOperandUsesInLoop,
547  const NewYieldValuesFn &newYieldValuesFn) {
548  // Create a new loop before the existing one, with the extra operands.
549  OpBuilder::InsertionGuard g(rewriter);
550  rewriter.setInsertionPoint(getOperation());
551  auto inits = llvm::to_vector(getInitArgs());
552  inits.append(newInitOperands.begin(), newInitOperands.end());
553  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
554  getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
555  [](OpBuilder &, Location, Value, ValueRange) {});
556  newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
557 
558  // Generate the new yield values and append them to the scf.yield operation.
559  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
560  ArrayRef<BlockArgument> newIterArgs =
561  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
562  {
563  OpBuilder::InsertionGuard g(rewriter);
564  rewriter.setInsertionPoint(yieldOp);
565  SmallVector<Value> newYieldedValues =
566  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
567  assert(newInitOperands.size() == newYieldedValues.size() &&
568  "expected as many new yield values as new iter operands");
569  rewriter.modifyOpInPlace(yieldOp, [&]() {
570  yieldOp.getResultsMutable().append(newYieldedValues);
571  });
572  }
573 
574  // Move the loop body to the new op.
575  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
576  newLoop.getBody()->getArguments().take_front(
577  getBody()->getNumArguments()));
578 
579  if (replaceInitOperandUsesInLoop) {
580  // Replace all uses of `newInitOperands` with the corresponding basic block
581  // arguments.
582  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
583  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
584  [&](OpOperand &use) {
585  Operation *user = use.getOwner();
586  return newLoop->isProperAncestor(user);
587  });
588  }
589  }
590 
591  // Replace the old loop.
592  rewriter.replaceOp(getOperation(),
593  newLoop->getResults().take_front(getNumResults()));
594  return cast<LoopLikeOpInterface>(newLoop.getOperation());
595 }
596 
598  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
599  if (!ivArg)
600  return ForOp();
601  assert(ivArg.getOwner() && "unlinked block argument");
602  auto *containingOp = ivArg.getOwner()->getParentOp();
603  return dyn_cast_or_null<ForOp>(containingOp);
604 }
605 
606 OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
607  return getInitArgs();
608 }
609 
610 void ForOp::getSuccessorRegions(RegionBranchPoint point,
612  // Both the operation itself and the region may be branching into the body or
613  // back into the operation itself. It is possible for loop not to enter the
614  // body.
615  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
616  regions.push_back(RegionSuccessor(getResults()));
617 }
618 
619 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
620 
621 /// Promotes the loop body of a forallOp to its containing block if it can be
622 /// determined that the loop has a single iteration.
623 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
624  for (auto [lb, ub, step] :
625  llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
626  auto tripCount = constantTripCount(lb, ub, step);
627  if (!tripCount.has_value() || *tripCount != 1)
628  return failure();
629  }
630 
631  promote(rewriter, *this);
632  return success();
633 }
634 
635 Block::BlockArgListType ForallOp::getRegionIterArgs() {
636  return getBody()->getArguments().drop_front(getRank());
637 }
638 
639 MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
640  return getOutputsMutable();
641 }
642 
643 /// Promotes the loop body of a scf::ForallOp to its containing block.
644 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
645  OpBuilder::InsertionGuard g(rewriter);
646  scf::InParallelOp terminator = forallOp.getTerminator();
647 
648  // Replace block arguments with lower bounds (replacements for IVs) and
649  // outputs.
650  SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
651  bbArgReplacements.append(forallOp.getOutputs().begin(),
652  forallOp.getOutputs().end());
653 
654  // Move the loop body operations to the loop's containing block.
655  rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(),
656  forallOp->getIterator(), bbArgReplacements);
657 
658  // Replace the terminator with tensor.insert_slice ops.
659  rewriter.setInsertionPointAfter(forallOp);
660  SmallVector<Value> results;
661  results.reserve(forallOp.getResults().size());
662  for (auto &yieldingOp : terminator.getYieldingOps()) {
663  auto parallelInsertSliceOp =
664  cast<tensor::ParallelInsertSliceOp>(yieldingOp);
665 
666  Value dst = parallelInsertSliceOp.getDest();
667  Value src = parallelInsertSliceOp.getSource();
668  if (llvm::isa<TensorType>(src.getType())) {
669  results.push_back(rewriter.create<tensor::InsertSliceOp>(
670  forallOp.getLoc(), dst.getType(), src, dst,
671  parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
672  parallelInsertSliceOp.getStrides(),
673  parallelInsertSliceOp.getStaticOffsets(),
674  parallelInsertSliceOp.getStaticSizes(),
675  parallelInsertSliceOp.getStaticStrides()));
676  } else {
677  llvm_unreachable("unsupported terminator");
678  }
679  }
680  rewriter.replaceAllUsesWith(forallOp.getResults(), results);
681 
682  // Erase the old terminator and the loop.
683  rewriter.eraseOp(terminator);
684  rewriter.eraseOp(forallOp);
685 }
686 
688  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
689  ValueRange steps, ValueRange iterArgs,
691  bodyBuilder) {
692  assert(lbs.size() == ubs.size() &&
693  "expected the same number of lower and upper bounds");
694  assert(lbs.size() == steps.size() &&
695  "expected the same number of lower bounds and steps");
696 
697  // If there are no bounds, call the body-building function and return early.
698  if (lbs.empty()) {
699  ValueVector results =
700  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
701  : ValueVector();
702  assert(results.size() == iterArgs.size() &&
703  "loop nest body must return as many values as loop has iteration "
704  "arguments");
705  return LoopNest{{}, std::move(results)};
706  }
707 
708  // First, create the loop structure iteratively using the body-builder
709  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
710  OpBuilder::InsertionGuard guard(builder);
713  loops.reserve(lbs.size());
714  ivs.reserve(lbs.size());
715  ValueRange currentIterArgs = iterArgs;
716  Location currentLoc = loc;
717  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
718  auto loop = builder.create<scf::ForOp>(
719  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
720  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
721  ValueRange args) {
722  ivs.push_back(iv);
723  // It is safe to store ValueRange args because it points to block
724  // arguments of a loop operation that we also own.
725  currentIterArgs = args;
726  currentLoc = nestedLoc;
727  });
728  // Set the builder to point to the body of the newly created loop. We don't
729  // do this in the callback because the builder is reset when the callback
730  // returns.
731  builder.setInsertionPointToStart(loop.getBody());
732  loops.push_back(loop);
733  }
734 
735  // For all loops but the innermost, yield the results of the nested loop.
736  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
737  builder.setInsertionPointToEnd(loops[i].getBody());
738  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
739  }
740 
741  // In the body of the innermost loop, call the body building function if any
742  // and yield its results.
743  builder.setInsertionPointToStart(loops.back().getBody());
744  ValueVector results = bodyBuilder
745  ? bodyBuilder(builder, currentLoc, ivs,
746  loops.back().getRegionIterArgs())
747  : ValueVector();
748  assert(results.size() == iterArgs.size() &&
749  "loop nest body must return as many values as loop has iteration "
750  "arguments");
751  builder.setInsertionPointToEnd(loops.back().getBody());
752  builder.create<scf::YieldOp>(loc, results);
753 
754  // Return the loops.
755  ValueVector nestResults;
756  llvm::copy(loops.front().getResults(), std::back_inserter(nestResults));
757  return LoopNest{std::move(loops), std::move(nestResults)};
758 }
759 
761  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
762  ValueRange steps,
763  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
764  // Delegate to the main function by wrapping the body builder.
765  return buildLoopNest(builder, loc, lbs, ubs, steps, std::nullopt,
766  [&bodyBuilder](OpBuilder &nestedBuilder,
767  Location nestedLoc, ValueRange ivs,
768  ValueRange) -> ValueVector {
769  if (bodyBuilder)
770  bodyBuilder(nestedBuilder, nestedLoc, ivs);
771  return {};
772  });
773 }
774 
777  OpOperand &operand, Value replacement,
778  const ValueTypeCastFnTy &castFn) {
779  assert(operand.getOwner() == forOp);
780  Type oldType = operand.get().getType(), newType = replacement.getType();
781 
782  // 1. Create new iter operands, exactly 1 is replaced.
783  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
784  "expected an iter OpOperand");
785  assert(operand.get().getType() != replacement.getType() &&
786  "Expected a different type");
787  SmallVector<Value> newIterOperands;
788  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
789  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
790  newIterOperands.push_back(replacement);
791  continue;
792  }
793  newIterOperands.push_back(opOperand.get());
794  }
795 
796  // 2. Create the new forOp shell.
797  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
798  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
799  forOp.getStep(), newIterOperands);
800  newForOp->setAttrs(forOp->getAttrs());
801  Block &newBlock = newForOp.getRegion().front();
802  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
803  newBlock.getArguments().end());
804 
805  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
806  // corresponding to the `replacement` value.
807  OpBuilder::InsertionGuard g(rewriter);
808  rewriter.setInsertionPointToStart(&newBlock);
809  BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
810  &newForOp->getOpOperand(operand.getOperandNumber()));
811  Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
812  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
813 
814  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
815  Block &oldBlock = forOp.getRegion().front();
816  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
817 
818  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
819  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
820  rewriter.setInsertionPoint(clonedYieldOp);
821  unsigned yieldIdx =
822  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
823  Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
824  clonedYieldOp.getOperand(yieldIdx));
825  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
826  newYieldOperands[yieldIdx] = castOut;
827  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
828  rewriter.eraseOp(clonedYieldOp);
829 
830  // 6. Inject an outgoing cast op after the forOp.
831  rewriter.setInsertionPointAfter(newForOp);
832  SmallVector<Value> newResults = newForOp.getResults();
833  newResults[yieldIdx] =
834  castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
835 
836  return newResults;
837 }
838 
839 namespace {
840 // Fold away ForOp iter arguments when:
841 // 1) The op yields the iter arguments.
842 // 2) The argument's corresponding outer region iterators (inputs) are yielded.
843 // 3) The iter arguments have no use and the corresponding (operation) results
844 // have no use.
845 //
846 // These arguments must be defined outside of the ForOp region and can just be
847 // forwarded after simplifying the op inits, yields and returns.
848 //
849 // The implementation uses `inlineBlockBefore` to steal the content of the
850 // original ForOp and avoid cloning.
851 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
853 
// Returns failure() when no iter argument can be folded. Otherwise creates a
// new scf.for that only carries the kept iter arguments, moves the old body
// into it and replaces `forOp` with the forwarded/new result values.
854  LogicalResult matchAndRewrite(scf::ForOp forOp,
855  PatternRewriter &rewriter) const final {
856  bool canonicalize = false;
857 
858  // An internal flat vector of block transfer
859  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
860  // transformed block argument mappings. This plays the role of a
861  // IRMapping for the particular use case of calling into
862  // `inlineBlockBefore`.
863  int64_t numResults = forOp.getNumResults();
// keepMask[i] is true iff the i-th iter argument survives in the new loop.
864  SmallVector<bool, 4> keepMask;
865  keepMask.reserve(numResults);
866  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
867  newResultValues;
868  newBlockTransferArgs.reserve(1 + numResults);
869  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
870  newIterArgs.reserve(forOp.getInitArgs().size());
871  newYieldValues.reserve(numResults);
872  newResultValues.reserve(numResults);
// Maps an (init, yielded) pair to the first kept (iter arg, result) pair that
// uses it, so that later iter arguments with the same pair can reuse it.
873  DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
874  for (auto [init, arg, result, yielded] :
875  llvm::zip(forOp.getInitArgs(), // iter from outside
876  forOp.getRegionIterArgs(), // iter inside region
877  forOp.getResults(), // op results
878  forOp.getYieldedValues() // iter yield
879  )) {
880  // Forwarded is `true` when:
881  // 1) The region `iter` argument is yielded.
882  // 2) The region `iter` argument the corresponding input is yielded.
883  // 3) The region `iter` argument has no use, and the corresponding op
884  // result has no use.
885  bool forwarded = (arg == yielded) || (init == yielded) ||
886  (arg.use_empty() && result.use_empty());
887  if (forwarded) {
// A forwarded iter argument and its result are both replaced by `init`.
888  canonicalize = true;
889  keepMask.push_back(false);
890  newBlockTransferArgs.push_back(init);
891  newResultValues.push_back(init);
892  continue;
893  }
894 
895  // Check if a previous kept argument always has the same values for init
896  // and yielded values.
897  if (auto it = initYieldToArg.find({init, yielded});
898  it != initYieldToArg.end()) {
899  canonicalize = true;
900  keepMask.push_back(false);
901  auto [sameArg, sameResult] = it->second;
902  rewriter.replaceAllUsesWith(arg, sameArg);
903  rewriter.replaceAllUsesWith(result, sameResult);
904  // The replacement value doesn't matter because there are no uses.
905  newBlockTransferArgs.push_back(init);
906  newResultValues.push_back(init);
907  continue;
908  }
909 
910  // This value is kept.
911  initYieldToArg.insert({{init, yielded}, {arg, result}});
912  keepMask.push_back(true);
913  newIterArgs.push_back(init);
914  newYieldValues.push_back(yielded);
915  newBlockTransferArgs.push_back(Value()); // placeholder with null value
916  newResultValues.push_back(Value()); // placeholder with null value
917  }
918 
919  if (!canonicalize)
920  return failure();
921 
922  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
923  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
924  forOp.getStep(), newIterArgs);
// Copy the attributes of the original loop onto the new one.
925  newForOp->setAttrs(forOp->getAttrs());
926  Block &newBlock = newForOp.getRegion().front();
927 
928  // Replace the null placeholders with newly constructed values.
929  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
// `idx` walks all original iter arguments; `collapsedIdx` only advances for
// kept ones and therefore indexes into the iter arguments of `newForOp`.
930  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
931  idx != e; ++idx) {
932  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
933  Value &newResultVal = newResultValues[idx];
934  assert((blockTransferArg && newResultVal) ||
935  (!blockTransferArg && !newResultVal));
936  if (!blockTransferArg) {
937  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
938  newResultVal = newForOp.getResult(collapsedIdx++);
939  }
940  }
941 
942  Block &oldBlock = forOp.getRegion().front();
943  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
944  "unexpected argument size mismatch");
945 
946  // No results case: the scf::ForOp builder already created a zero
947  // result terminator. Merge before this terminator and just get rid of the
948  // original terminator that has been merged in.
949  if (newIterArgs.empty()) {
950  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
951  rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
// The old terminator was inlined right before the new block's terminator.
952  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
953  rewriter.replaceOp(forOp, newResultValues);
954  return success();
955  }
956 
957  // No terminator case: merge and rewrite the merged terminator.
// The helper creates a new yield in front of `mergedTerminator` that only
// carries the operands whose keepMask entry is set.
958  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
959  OpBuilder::InsertionGuard g(rewriter);
960  rewriter.setInsertionPoint(mergedTerminator);
961  SmallVector<Value, 4> filteredOperands;
962  filteredOperands.reserve(newResultValues.size());
963  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
964  if (keepMask[idx])
965  filteredOperands.push_back(mergedTerminator.getOperand(idx));
966  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
967  filteredOperands);
968  };
969 
970  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
971  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
972  cloneFilteredTerminator(mergedYieldOp);
973  rewriter.eraseOp(mergedYieldOp);
974  rewriter.replaceOp(forOp, newResultValues);
975  return success();
976  }
977 };
978 
979 /// Util function that tries to compute a constant diff between u and l.
980 /// Returns std::nullopt when the difference between two AffineValueMap is
981 /// dynamic.
982 static std::optional<int64_t> computeConstDiff(Value l, Value u) {
983  IntegerAttr clb, cub;
984  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
985  llvm::APInt lbValue = clb.getValue();
986  llvm::APInt ubValue = cub.getValue();
987  return (ubValue - lbValue).getSExtValue();
988  }
989 
990  // Else a simple pattern match for x + c or c + x
991  llvm::APInt diff;
992  if (matchPattern(
993  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
994  matchPattern(
995  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
996  return diff.getSExtValue();
997  return std::nullopt;
998 }
999 
1000 /// Rewriting pattern that erases loops that are known not to iterate, replaces
1001 /// single-iteration loops with their bodies, and removes empty loops that
1002 /// iterate at least once and only return values defined outside of the loop.
1003 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
1005 
// Returns success() after rewriting `op` in one of the ways listed above, and
// failure() when the bounds/step are not known well enough or the loop is not
// trivial.
1006  LogicalResult matchAndRewrite(ForOp op,
1007  PatternRewriter &rewriter) const override {
1008  // If the upper bound is the same as the lower bound, the loop does not
1009  // iterate, just remove it.
1010  if (op.getLowerBound() == op.getUpperBound()) {
1011  rewriter.replaceOp(op, op.getInitArgs());
1012  return success();
1013  }
1014 
// `diff` is the constant value of (upper bound - lower bound), if known.
1015  std::optional<int64_t> diff =
1016  computeConstDiff(op.getLowerBound(), op.getUpperBound());
1017  if (!diff)
1018  return failure();
1019 
1020  // If the loop is known to have 0 iterations, remove it.
1021  if (*diff <= 0) {
1022  rewriter.replaceOp(op, op.getInitArgs());
1023  return success();
1024  }
1025 
1026  std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
1027  if (!maybeStepValue)
1028  return failure();
1029 
1030  // If the loop is known to have 1 iteration, inline its body and remove the
1031  // loop.
1032  llvm::APInt stepValue = *maybeStepValue;
1033  if (stepValue.sge(*diff)) {
// The values substituted for the body's block arguments: the lower bound
// first, followed by the init values.
1034  SmallVector<Value, 4> blockArgs;
1035  blockArgs.reserve(op.getInitArgs().size() + 1);
1036  blockArgs.push_back(op.getLowerBound());
1037  llvm::append_range(blockArgs, op.getInitArgs());
1038  replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
1039  return success();
1040  }
1041 
1042  // Now we are left with loops that have more than 1 iterations.
1043  Block &block = op.getRegion().front();
// A block with a single operation presumably holds only its terminator.
1044  if (!llvm::hasSingleElement(block))
1045  return failure();
1046  // If the loop is empty, iterates at least once, and only returns values
1047  // defined outside of the loop, remove it and replace it with yield values.
1048  if (llvm::any_of(op.getYieldedValues(),
1049  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
1050  return failure();
1051  rewriter.replaceOp(op, op.getYieldedValues());
1052  return success();
1053  }
1054 };
1055 
1056 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
1057 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
1058 ///
1059 /// ```
1060 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1061 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
1062 /// -> (tensor<?x?xf32>) {
1063 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1064 /// scf.yield %2 : tensor<?x?xf32>
1065 /// }
1066 /// use_of(%1)
1067 /// ```
1068 ///
1069 /// folds into:
1070 ///
1071 /// ```
1072 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
1073 /// -> (tensor<32x1024xf32>) {
1074 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
1075 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
1076 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
1077 /// scf.yield %4 : tensor<32x1024xf32>
1078 /// }
1079 /// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
1080 /// use_of(%1)
1081 /// ```
1082 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
1084 
// Looks for the first init operand that is produced by a tensor.cast meeting
// the conditions checked below and rewrites the loop for that operand only.
// Returns failure() when no init operand qualifies.
1085  LogicalResult matchAndRewrite(ForOp op,
1086  PatternRewriter &rewriter) const override {
1087  for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
1088  OpOperand &iterOpOperand = std::get<0>(it);
1089  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
// Skip operands not defined by a tensor.cast, and casts whose source and
// result types are identical.
1090  if (!incomingCast ||
1091  incomingCast.getSource().getType() == incomingCast.getType())
1092  continue;
1093  // If the dest type of the cast does not preserve static information in
1094  // the source type.
1096  incomingCast.getDest().getType(),
1097  incomingCast.getSource().getType()))
1098  continue;
// Only handle loop results that have exactly one use.
1099  if (!std::get<1>(it).hasOneUse())
1100  continue;
1101 
1102  // Create a new ForOp with that iter operand replaced.
1103  rewriter.replaceOp(
1105  rewriter, op, iterOpOperand, incomingCast.getSource(),
1106  [](OpBuilder &b, Location loc, Type type, Value source) {
1107  return b.create<tensor::CastOp>(loc, type, source);
1108  }));
1109  return success();
1110  }
1111  return failure();
1112  }
1113 };
1114 
1115 } // namespace
1116 
1117 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1118  MLIRContext *context) {
1119  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
1120  context);
1121 }
1122 
1123 std::optional<APInt> ForOp::getConstantStep() {
1124  IntegerAttr step;
1125  if (matchPattern(getStep(), m_Constant(&step)))
1126  return step.getValue();
1127  return {};
1128 }
1129 
1130 std::optional<MutableArrayRef<OpOperand>> ForOp::getYieldedValuesMutable() {
1131  return cast<scf::YieldOp>(getBody()->getTerminator()).getResultsMutable();
1132 }
1133 
// Reports the speculatability of the loop. The decision depends on whether
// the step is the constant 1, as explained by the comments below.
1134 Speculation::Speculatability ForOp::getSpeculatability() {
1135  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1136  // and End.
1137  if (auto constantStep = getConstantStep())
1138  if (*constantStep == 1)
1140 
1141  // For Step != 1, the loop may not terminate. We can add more smarts here if
1142  // needed.
1144 }
1145 
1146 //===----------------------------------------------------------------------===//
1147 // ForallOp
1148 //===----------------------------------------------------------------------===//
1149 
1150 LogicalResult ForallOp::verify() {
1151  unsigned numLoops = getRank();
1152  // Check number of outputs.
1153  if (getNumResults() != getOutputs().size())
1154  return emitOpError("produces ")
1155  << getNumResults() << " results, but has only "
1156  << getOutputs().size() << " outputs";
1157 
1158  // Check that the body defines block arguments for thread indices and outputs.
1159  auto *body = getBody();
1160  if (body->getNumArguments() != numLoops + getOutputs().size())
1161  return emitOpError("region expects ") << numLoops << " arguments";
1162  for (int64_t i = 0; i < numLoops; ++i)
1163  if (!body->getArgument(i).getType().isIndex())
1164  return emitOpError("expects ")
1165  << i << "-th block argument to be an index";
1166  for (unsigned i = 0; i < getOutputs().size(); ++i)
1167  if (body->getArgument(i + numLoops).getType() != getOutputs()[i].getType())
1168  return emitOpError("type mismatch between ")
1169  << i << "-th output and corresponding block argument";
1170  if (getMapping().has_value() && !getMapping()->empty()) {
1171  if (static_cast<int64_t>(getMapping()->size()) != numLoops)
1172  return emitOpError() << "mapping attribute size must match op rank";
1173  for (auto map : getMapping()->getValue()) {
1174  if (!isa<DeviceMappingAttrInterface>(map))
1175  return emitOpError()
1176  << getMappingAttrName() << " is not device mapping attribute";
1177  }
1178  }
1179 
1180  // Verify mixed static/dynamic control variables.
1181  Operation *op = getOperation();
1182  if (failed(verifyListOfOperandsOrIntegers(op, "lower bound", numLoops,
1183  getStaticLowerBound(),
1184  getDynamicLowerBound())))
1185  return failure();
1186  if (failed(verifyListOfOperandsOrIntegers(op, "upper bound", numLoops,
1187  getStaticUpperBound(),
1188  getDynamicUpperBound())))
1189  return failure();
1190  if (failed(verifyListOfOperandsOrIntegers(op, "step", numLoops,
1191  getStaticStep(), getDynamicStep())))
1192  return failure();
1193 
1194  return success();
1195 }
1196 
// Prints the custom assembly form of scf.forall: the induction variables,
// the loop bounds, the optional `shared_outs` list with result types, the
// body region and the remaining attributes.
1197 void ForallOp::print(OpAsmPrinter &p) {
1198  Operation *op = getOperation();
1199  p << " (" << getInductionVars();
// Normalized loops use the short `(...) in <upper bounds>` form; all other
// loops print `(...) = <lower bounds> to <upper bounds> step <steps>`.
1200  if (isNormalized()) {
1201  p << ") in ";
1202  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1203  /*valueTypes=*/{}, /*scalables=*/{},
1205  } else {
1206  p << ") = ";
1207  printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
1208  /*valueTypes=*/{}, /*scalables=*/{},
1210  p << " to ";
1211  printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
1212  /*valueTypes=*/{}, /*scalables=*/{},
1214  p << " step ";
1215  printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
1216  /*valueTypes=*/{}, /*scalables=*/{},
1218  }
1219  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1220  p << " ";
1221  if (!getRegionOutArgs().empty())
1222  p << "-> (" << getResultTypes() << ") ";
// The terminator is only printed when the op has results.
1223  p.printRegion(getRegion(),
1224  /*printEntryBlockArgs=*/false,
1225  /*printBlockTerminators=*/getNumResults() > 0);
// These attributes are excluded from the printed attribute dictionary.
1226  p.printOptionalAttrDict(op->getAttrs(), {getOperandSegmentSizesAttrName(),
1227  getStaticLowerBoundAttrName(),
1228  getStaticUpperBoundAttrName(),
1229  getStaticStepAttrName()});
1230 }
1231 
// Parses the custom assembly form of scf.forall into `result`: induction
// variables, loop bounds (short `in` form or explicit `= ... to ... step ...`
// form), the optional `shared_outs` list, the body region and the optional
// attribute dictionary. Returns failure() on any parse error.
1232 ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
1233  OpBuilder b(parser.getContext());
1234  auto indexType = b.getIndexType();
1235 
1236  // Parse an opening `(` followed by thread index variables followed by `)`
1237  // TODO: when we can refer to such "induction variable"-like handles from the
1238  // declarative assembly format, we can implement the parser as a custom hook.
1241  return failure();
1242 
1243  DenseI64ArrayAttr staticLbs, staticUbs, staticSteps;
1244  SmallVector<OpAsmParser::UnresolvedOperand> dynamicLbs, dynamicUbs,
1245  dynamicSteps;
// Short form: only upper bounds are given; lower bounds are set to 0 and
// steps to 1 for every induction variable.
1246  if (succeeded(parser.parseOptionalKeyword("in"))) {
1247  // Parse upper bounds.
1248  if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1249  /*valueTypes=*/nullptr,
1251  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1252  return failure();
1253 
1254  unsigned numLoops = ivs.size();
1255  staticLbs = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 0));
1256  staticSteps = b.getDenseI64ArrayAttr(SmallVector<int64_t>(numLoops, 1));
1257  } else {
1258  // Parse lower bounds.
1259  if (parser.parseEqual() ||
1260  parseDynamicIndexList(parser, dynamicLbs, staticLbs,
1261  /*valueTypes=*/nullptr,
1263 
1264  parser.resolveOperands(dynamicLbs, indexType, result.operands))
1265  return failure();
1266 
1267  // Parse upper bounds.
1268  if (parser.parseKeyword("to") ||
1269  parseDynamicIndexList(parser, dynamicUbs, staticUbs,
1270  /*valueTypes=*/nullptr,
1272  parser.resolveOperands(dynamicUbs, indexType, result.operands))
1273  return failure();
1274 
1275  // Parse step values.
1276  if (parser.parseKeyword("step") ||
1277  parseDynamicIndexList(parser, dynamicSteps, staticSteps,
1278  /*valueTypes=*/nullptr,
1279  parser.resolveOperands(dynamicSteps, indexType, result.operands))
1281  return failure();
1282  }
1283 
1284  // Parse out operands and results.
1287  SMLoc outOperandsLoc = parser.getCurrentLocation();
1288  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
// NOTE(review): this size check runs before the assignment list and the
// result types are parsed below, so it looks like it can never fire here.
// Confirm whether it was meant to come after the parsing calls.
1289  if (outOperands.size() != result.types.size())
1290  return parser.emitError(outOperandsLoc,
1291  "mismatch between out operands and types");
1292  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1293  parser.parseOptionalArrowTypeList(result.types) ||
1294  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1295  result.operands))
1296  return failure();
1297  }
1298 
1299  // Parse region.
1301  std::unique_ptr<Region> region = std::make_unique<Region>();
// Region arguments are the index-typed induction variables followed by the
// shared_outs arguments, which take the corresponding result types.
1302  for (auto &iv : ivs) {
1303  iv.type = b.getIndexType();
1304  regionArgs.push_back(iv);
1305  }
1306  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1307  auto &out = it.value();
1308  out.type = result.types[it.index()];
1309  regionArgs.push_back(out);
1310  }
1311  if (parser.parseRegion(*region, regionArgs))
1312  return failure();
1313 
1314  // Ensure terminator and move region.
1315  ForallOp::ensureTerminator(*region, b, result.location);
1316  result.addRegion(std::move(region));
1317 
1318  // Parse the optional attribute list.
1319  if (parser.parseOptionalAttrDict(result.attributes))
1320  return failure();
1321 
// Record the static bounds and the sizes of the four operand groups
// (dynamic lower bounds, upper bounds, steps, and outputs).
1322  result.addAttribute("staticLowerBound", staticLbs);
1323  result.addAttribute("staticUpperBound", staticUbs);
1324  result.addAttribute("staticStep", staticSteps);
1325  result.addAttribute("operandSegmentSizes",
1327  {static_cast<int32_t>(dynamicLbs.size()),
1328  static_cast<int32_t>(dynamicUbs.size()),
1329  static_cast<int32_t>(dynamicSteps.size()),
1330  static_cast<int32_t>(outOperands.size())}));
1331  return success();
1332 }
1333 
1334 // Builder that takes loop bounds.
// It takes lower bounds, upper bounds and steps (`lbs`, `ubs`, `steps`),
// the shared outputs and an optional mapping. When `bodyBuilderFn` is null
// the body only gets a terminator; otherwise the callback populates the body.
1335 void ForallOp::build(
1338  ArrayRef<OpFoldResult> steps, ValueRange outputs,
1339  std::optional<ArrayAttr> mapping,
1340  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
// Split each mixed list into its dynamic (Value) and static (int64_t) parts.
1341  SmallVector<int64_t> staticLbs, staticUbs, staticSteps;
1342  SmallVector<Value> dynamicLbs, dynamicUbs, dynamicSteps;
1343  dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1344  dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1345  dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1346 
// Operands are added in the same order as the segment sizes recorded below.
// The op has one result per output, with the output's type.
1347  result.addOperands(dynamicLbs);
1348  result.addOperands(dynamicUbs);
1349  result.addOperands(dynamicSteps);
1350  result.addOperands(outputs);
1351  result.addTypes(TypeRange(outputs));
1352 
1353  result.addAttribute(getStaticLowerBoundAttrName(result.name),
1354  b.getDenseI64ArrayAttr(staticLbs));
1355  result.addAttribute(getStaticUpperBoundAttrName(result.name),
1356  b.getDenseI64ArrayAttr(staticUbs));
1357  result.addAttribute(getStaticStepAttrName(result.name),
1358  b.getDenseI64ArrayAttr(staticSteps));
1359  result.addAttribute(
1360  "operandSegmentSizes",
1361  b.getDenseI32ArrayAttr({static_cast<int32_t>(dynamicLbs.size()),
1362  static_cast<int32_t>(dynamicUbs.size()),
1363  static_cast<int32_t>(dynamicSteps.size()),
1364  static_cast<int32_t>(outputs.size())}));
1365  if (mapping.has_value()) {
1367  mapping.value());
1368  }
1369 
1370  Region *bodyRegion = result.addRegion();
1372  b.createBlock(bodyRegion);
1373  Block &bodyBlock = bodyRegion->front();
1374 
1375  // Add block arguments for indices and outputs.
1376  bodyBlock.addArguments(
1377  SmallVector<Type>(lbs.size(), b.getIndexType()),
1378  SmallVector<Location>(staticLbs.size(), result.location));
1379  bodyBlock.addArguments(
1380  TypeRange(outputs),
1381  SmallVector<Location>(outputs.size(), result.location));
1382 
1383  b.setInsertionPointToStart(&bodyBlock);
// Without a body builder, only make sure the region has its terminator.
1384  if (!bodyBuilderFn) {
1385  ForallOp::ensureTerminator(*bodyRegion, b, result.location);
1386  return;
1387  }
1388  bodyBuilderFn(b, result.location, bodyBlock.getArguments());
1389 }
1390 
1391 // Builder that takes only upper bounds: every lower bound is set to the
// index attribute 0 and every step to the index attribute 1, then the
// builder taking explicit lower bounds, upper bounds and steps is called.
1392 void ForallOp::build(
1394  ArrayRef<OpFoldResult> ubs, ValueRange outputs,
1395  std::optional<ArrayAttr> mapping,
1396  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1397  unsigned numLoops = ubs.size();
1398  SmallVector<OpFoldResult> lbs(numLoops, b.getIndexAttr(0));
1399  SmallVector<OpFoldResult> steps(numLoops, b.getIndexAttr(1));
1400  build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
1401 }
1402 
1403 // Checks if the lbs are zeros and steps are ones.
1404 bool ForallOp::isNormalized() {
1405  auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
1406  return llvm::all_of(results, [&](OpFoldResult ofr) {
1407  auto intValue = getConstantIntValue(ofr);
1408  return intValue.has_value() && intValue == val;
1409  });
1410  };
1411  return allEqual(getMixedLowerBound(), 0) && allEqual(getMixedStep(), 1);
1412 }
1413 
1414 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1415 // unaware of the fact that our terminator also needs a region to be
1416 // well-formed. We override it here to ensure that we do the right thing.
// After delegating to the generated implementation, the terminator's region
// gets a block if it has none yet.
1417 void ForallOp::ensureTerminator(Region &region, OpBuilder &builder,
1418  Location loc) {
1420  ForallOp>::ensureTerminator(region, builder, loc);
// NOTE(review): the dyn_cast result is used without a null check. If the
// terminator is always expected to be an InParallelOp, `cast` would state
// that intent — confirm.
1421  auto terminator =
1422  llvm::dyn_cast<InParallelOp>(region.front().getTerminator());
1423  if (terminator.getRegion().empty())
1424  builder.createBlock(&terminator.getRegion());
1425 }
1426 
1427 InParallelOp ForallOp::getTerminator() {
1428  return cast<InParallelOp>(getBody()->getTerminator());
1429 }
1430 
1431 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
1432  SmallVector<Operation *> storeOps;
1433  InParallelOp inParallelOp = getTerminator();
1434  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
1435  if (auto parallelInsertSliceOp =
1436  dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
1437  parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
1438  storeOps.push_back(parallelInsertSliceOp);
1439  }
1440  }
1441  return storeOps;
1442 }
1443 
1444 std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() {
1445  return SmallVector<Value>{getBody()->getArguments().take_front(getRank())};
1446 }
1447 
1448 // Get lower bounds as OpFoldResult.
1449 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopLowerBounds() {
1450  Builder b(getOperation()->getContext());
1451  return getMixedValues(getStaticLowerBound(), getDynamicLowerBound(), b);
1452 }
1453 
1454 // Get upper bounds as OpFoldResult.
1455 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopUpperBounds() {
1456  Builder b(getOperation()->getContext());
1457  return getMixedValues(getStaticUpperBound(), getDynamicUpperBound(), b);
1458 }
1459 
1460 // Get steps as OpFoldResult.
1461 std::optional<SmallVector<OpFoldResult>> ForallOp::getLoopSteps() {
1462  Builder b(getOperation()->getContext());
1463  return getMixedValues(getStaticStep(), getDynamicStep(), b);
1464 }
1465 
// Returns the ForallOp whose block owns `val` as a block argument. A null
// ForallOp is returned when `val` is not a block argument or when the op
// containing the owner block is not a ForallOp.
1467  auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
1468  if (!tidxArg)
1469  return ForallOp();
1470  assert(tidxArg.getOwner() && "unlinked block argument");
// NOTE(review): if the owner block may be detached from any operation,
// `containingOp` could be null here — confirm that dyn_cast is sufficient.
1471  auto *containingOp = tidxArg.getOwner()->getParentOp();
1472  return dyn_cast<ForallOp>(containingOp);
1473 }
1474 
1475 namespace {
1476 /// Fold tensor.dim(forall shared_outs(... = %t)) to tensor.dim(%t).
1477 struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
1479 
// Fails unless the source of `dimOp` is a result of a ForallOp. Otherwise the
// dim op is updated in place to read from the operand tied to that result.
1480  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1481  PatternRewriter &rewriter) const final {
1482  auto forallOp = dimOp.getSource().getDefiningOp<ForallOp>();
1483  if (!forallOp)
1484  return failure();
1485  Value sharedOut =
1486  forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
1487  ->get();
1488  rewriter.modifyOpInPlace(
1489  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1490  return success();
1491  }
1492 };
1493 
// Pattern that tries to fold the mixed lower bounds, upper bounds and steps
// of a ForallOp via foldDynamicIndexList and, on success, rewrites the op's
// dynamic operands, static attributes and operand segment sizes in place.
1494 class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
1495 public:
1497 
1498  LogicalResult matchAndRewrite(ForallOp op,
1499  PatternRewriter &rewriter) const override {
1500  SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
1501  SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
1502  SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
// Bail out only when none of the lists could be folded. Because `&&`
// short-circuits, lists after the first one that folds are not attempted in
// this invocation.
1503  if (failed(foldDynamicIndexList(mixedLowerBound)) &&
1504  failed(foldDynamicIndexList(mixedUpperBound)) &&
1505  failed(foldDynamicIndexList(mixedStep)))
1506  return failure();
1507 
1508  rewriter.modifyOpInPlace(op, [&]() {
1509  SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
1510  SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
1511  dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
1512  staticLowerBound);
1513  op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
1514  op.setStaticLowerBound(staticLowerBound);
1515 
1516  dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
1517  staticUpperBound);
1518  op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
1519  op.setStaticUpperBound(staticUpperBound);
1520 
1521  dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
1522  op.getDynamicStepMutable().assign(dynamicStep);
1523  op.setStaticStep(staticStep);
1524 
// The number of dynamic operands may have changed, so the segment sizes are
// rewritten; the last segment uses the result count.
1525  op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
1526  rewriter.getDenseI32ArrayAttr(
1527  {static_cast<int32_t>(dynamicLowerBound.size()),
1528  static_cast<int32_t>(dynamicUpperBound.size()),
1529  static_cast<int32_t>(dynamicStep.size()),
1530  static_cast<int32_t>(op.getNumResults())}));
1531  });
1532  return success();
1533  }
1534 };
1535 
1536 /// The following canonicalization pattern folds the iter arguments of
1537 /// scf.forall op if :-
1538 /// 1. The corresponding result has zero uses.
1539 /// 2. The iter argument is NOT being modified within the loop body.
1540 ///
1541 ///
1542 /// Example of first case :-
1543 /// INPUT:
1544 /// %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
1545 /// {
1546 /// ...
1547 /// <SOME USE OF %arg0>
1548 /// <SOME USE OF %arg1>
1549 /// <SOME USE OF %arg2>
1550 /// ...
1551 /// scf.forall.in_parallel {
1552 /// <STORE OP WITH DESTINATION %arg1>
1553 /// <STORE OP WITH DESTINATION %arg0>
1554 /// <STORE OP WITH DESTINATION %arg2>
1555 /// }
1556 /// }
1557 /// return %res#1
1558 ///
1559 /// OUTPUT:
1560 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1561 /// {
1562 /// ...
1563 /// <SOME USE OF %a>
1564 /// <SOME USE OF %new_arg0>
1565 /// <SOME USE OF %c>
1566 /// ...
1567 /// scf.forall.in_parallel {
1568 /// <STORE OP WITH DESTINATION %new_arg0>
1569 /// }
1570 /// }
1571 /// return %res
1572 ///
1573 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
1574 /// scf.forall is replaced by their corresponding operands.
1575 /// 2. Even if there are <STORE OP WITH DESTINATION *> ops within the body
1576 /// of the scf.forall besides within scf.forall.in_parallel terminator,
1577 /// this canonicalization remains valid. For more details, please refer
1578 /// to :
1579 /// https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
1580 /// 3. TODO(avarma): Generalize it for other store ops. Currently it
1581 /// handles tensor.parallel_insert_slice ops only.
1582 ///
1583 /// Example of second case :-
1584 /// INPUT:
1585 /// %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
1586 /// {
1587 /// ...
1588 /// <SOME USE OF %arg0>
1589 /// <SOME USE OF %arg1>
1590 /// ...
1591 /// scf.forall.in_parallel {
1592 /// <STORE OP WITH DESTINATION %arg1>
1593 /// }
1594 /// }
1595 /// return %res#0, %res#1
1596 ///
1597 /// OUTPUT:
1598 /// %res = scf.forall ... shared_outs(%new_arg0 = %b)
1599 /// {
1600 /// ...
1601 /// <SOME USE OF %a>
1602 /// <SOME USE OF %new_arg0>
1603 /// ...
1604 /// scf.forall.in_parallel {
1605 /// <STORE OP WITH DESTINATION %new_arg0>
1606 /// }
1607 /// }
1608 /// return %a, %res
1609 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
1611 
// Returns failure() when every result is used and has a combining op.
// Otherwise builds a new scf.forall over the remaining shared_outs, moves
// the old body into it and redirects the uses of the old results.
1612  LogicalResult matchAndRewrite(ForallOp forallOp,
1613  PatternRewriter &rewriter) const final {
1614  // Step 1: For a given i-th result of scf.forall, check the following :-
1615  // a. If it has any use.
1616  // b. If the corresponding iter argument is being modified within
1617  // the loop, i.e. has at least one store op with the iter arg as
1618  // its destination operand. For this we use
1619  // ForallOp::getCombiningOps(iter_arg).
1620  //
1621  // Based on the check we maintain the following :-
1622  // a. `resultToDelete` - i-th result of scf.forall that'll be
1623  // deleted.
1624  // b. `resultToReplace` - i-th result of the old scf.forall
1625  // whose uses will be replaced by the new scf.forall.
1626  // c. `newOuts` - the shared_outs' operand of the new scf.forall
1627  // corresponding to the i-th result with at least one use.
1628  SetVector<OpResult> resultToDelete;
1629  SmallVector<Value> resultToReplace;
1630  SmallVector<Value> newOuts;
1631  for (OpResult result : forallOp.getResults()) {
1632  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1633  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1634  if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
1635  resultToDelete.insert(result);
1636  } else {
1637  resultToReplace.push_back(result);
1638  newOuts.push_back(opOperand->get());
1639  }
1640  }
1641 
1642  // Return early if all results of scf.forall have at least one use and being
1643  // modified within the loop.
1644  if (resultToDelete.empty())
1645  return failure();
1646 
1647  // Step 2: For the i-th result, do the following :-
1648  // a. Fetch the corresponding BlockArgument.
1649  // b. Look for store ops (currently tensor.parallel_insert_slice)
1650  // with the BlockArgument as its destination operand.
1651  // c. Remove the operations fetched in b.
1652  for (OpResult result : resultToDelete) {
1653  OpOperand *opOperand = forallOp.getTiedOpOperand(result);
1654  BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
1655  SmallVector<Operation *> combiningOps =
1656  forallOp.getCombiningOps(blockArg);
1657  for (Operation *combiningOp : combiningOps)
1658  rewriter.eraseOp(combiningOp);
1659  }
1660 
1661  // Step 3. Create a new scf.forall op with the new shared_outs' operands
1662  // fetched earlier
// The body builder callback is intentionally empty: the body is filled in by
// merging the old block in Step 4.
1663  auto newForallOp = rewriter.create<scf::ForallOp>(
1664  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1665  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
1666  forallOp.getMapping(),
1667  /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});
1668 
1669  // Step 4. Merge the block of the old scf.forall into the newly created
1670  // scf.forall using the new set of arguments.
1671  Block *loopBody = forallOp.getBody();
1672  Block *newLoopBody = newForallOp.getBody();
1673  ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
1674  // Form initial new bbArg list with just the control operands of the new
1675  // scf.forall op.
1676  SmallVector<Value> newBlockArgs =
1677  llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
1678  [](BlockArgument b) -> Value { return b; });
1679  Block::BlockArgListType newSharedOutsArgs = newForallOp.getRegionOutArgs();
// `index` counts the kept shared_outs and indexes into `newSharedOutsArgs`.
1680  unsigned index = 0;
1681  // Take the new corresponding bbArg if the old bbArg was used as a
1682  // destination in the in_parallel op. For all other bbArgs, use the
1683  // corresponding init_arg from the old scf.forall op.
1684  for (OpResult result : forallOp.getResults()) {
1685  if (resultToDelete.count(result)) {
1686  newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
1687  } else {
1688  newBlockArgs.push_back(newSharedOutsArgs[index++]);
1689  }
1690  }
1691  rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
1692 
1693  // Step 5. Replace the uses of result of old scf.forall with that of the new
1694  // scf.forall.
1695  for (auto &&[oldResult, newResult] :
1696  llvm::zip(resultToReplace, newForallOp->getResults()))
1697  rewriter.replaceAllUsesWith(oldResult, newResult);
1698 
1699  // Step 6. Replace the uses of those values that either has no use or are
1700  // not being modified within the loop with the corresponding
1701  // OpOperand.
1702  for (OpResult oldResult : resultToDelete)
1703  rewriter.replaceAllUsesWith(oldResult,
1704  forallOp.getTiedOpOperand(oldResult)->get());
1705  return success();
1706  }
1707 };
1708 
/// Canonicalization pattern that shrinks the iteration space of an scf.forall:
///  * if any dimension has a constant trip count of zero, the whole loop is
///    replaced by its outputs;
///  * dimensions with a constant trip count of one are dropped and their
///    induction variable is mapped to the (materialized) lower bound;
///  * if every dimension is dropped, the loop body is inlined via `promote`.
/// Loops carrying a non-empty mapping attribute are not touched.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
1709 struct ForallOpSingleOrZeroIterationDimsFolder
1710  : public OpRewritePattern<ForallOp> {
1712 
1713  LogicalResult matchAndRewrite(ForallOp op,
1714  PatternRewriter &rewriter) const override {
1715  // Do not fold dimensions if they are mapped to processing units.
1716  if (op.getMapping().has_value() && !op.getMapping()->empty())
1717  return failure();
1718  Location loc = op.getLoc();
1719 
1720  // Compute new loop bounds that omit all single-iteration loop dimensions.
1721  SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
1722  newMixedSteps;
  // Records induction-variable -> lower-bound substitutions for the dropped
  // dimensions; it is consumed by the region clone at the end.
1723  IRMapping mapping;
1724  for (auto [lb, ub, step, iv] :
1725  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1726  op.getMixedStep(), op.getInductionVars())) {
1727  auto numIterations = constantTripCount(lb, ub, step);
1728  if (numIterations.has_value()) {
1729  // Remove the loop if it performs zero iterations.
1730  if (*numIterations == 0) {
1731  rewriter.replaceOp(op, op.getOutputs());
1732  return success();
1733  }
1734  // Replace the loop induction variable by the lower bound if the loop
1735  // performs a single iteration. Otherwise, copy the loop bounds.
1736  if (*numIterations == 1) {
1737  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1738  continue;
1739  }
1740  }
  // Unknown trip count, or more than one iteration: keep the dimension.
1741  newMixedLowerBounds.push_back(lb);
1742  newMixedUpperBounds.push_back(ub);
1743  newMixedSteps.push_back(step);
1744  }
1745 
1746  // All of the loop dimensions perform a single iteration. Inline loop body.
1747  if (newMixedLowerBounds.empty()) {
1748  promote(rewriter, op);
1749  return success();
1750  }
1751 
1752  // Exit if none of the loop dimensions perform a single iteration.
1753  if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
1754  return rewriter.notifyMatchFailure(
1755  op, "no dimensions have 0 or 1 iterations");
1756  }
1757 
1758  // Replace the loop by a lower-dimensional loop.
1759  ForallOp newOp;
1760  newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
1761  newMixedUpperBounds, newMixedSteps,
1762  op.getOutputs(), std::nullopt, nullptr);
  // Empty the new op's body region so that the old body can be cloned into it
  // below.
1763  newOp.getBodyRegion().getBlocks().clear();
1764  // The new loop needs to keep all attributes from the old one, except for
1765  // "operandSegmentSizes" and static loop bound attributes which capture
1766  // the outdated information of the old iteration domain.
1767  SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
1768  newOp.getStaticLowerBoundAttrName(),
1769  newOp.getStaticUpperBoundAttrName(),
1770  newOp.getStaticStepAttrName()};
1771  for (const auto &namedAttr : op->getAttrs()) {
1772  if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
1773  continue;
1774  rewriter.modifyOpInPlace(newOp, [&]() {
1775  newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1776  });
1777  }
  // Clone the old body into the new loop; `mapping` substitutes the induction
  // variables of the dropped dimensions.
1778  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1779  newOp.getRegion().begin(), mapping);
1780  rewriter.replaceOp(op, newOp.getResults());
1781  return success();
1782  }
1783 };
1784 
1785 /// Replace all induction vars with a single trip count with their lower bound.
/// Induction variables without uses are skipped, and the pattern reports
/// success only if at least one induction variable was replaced. The loop op
/// itself is left in place.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
1786 struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
1788 
1789  LogicalResult matchAndRewrite(ForallOp op,
1790  PatternRewriter &rewriter) const override {
1791  Location loc = op.getLoc();
1792  bool changed = false;
1793  for (auto [lb, ub, step, iv] :
1794  llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
1795  op.getMixedStep(), op.getInductionVars())) {
  // Nothing to replace for an induction variable that has no uses.
1796  if (iv.getUses().begin() == iv.getUses().end())
1797  continue;
1798  auto numIterations = constantTripCount(lb, ub, step);
  // Only dimensions known to run exactly once qualify.
1799  if (!numIterations.has_value() || numIterations.value() != 1) {
1800  continue;
1801  }
1802  rewriter.replaceAllUsesWith(
1803  iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
1804  changed = true;
1805  }
1806  return success(changed);
1807  }
1808 };
1809 
/// Folds `tensor.cast` producers of scf.forall outputs into the loop. The loop
/// is rebuilt on the cast sources; inside the new body each affected output
/// block argument is cast back to the old type before the old body is merged
/// in, and after the loop the affected results are cast back to the old types
/// so that existing users keep their types.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
1810 struct FoldTensorCastOfOutputIntoForallOp
1811  : public OpRewritePattern<scf::ForallOp> {
1813 
  // Source and destination types of a folded cast. Only `dstType` is read
  // within this pattern.
1814  struct TypeCast {
1815  Type srcType;
1816  Type dstType;
1817  };
1818 
1819  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
1820  PatternRewriter &rewriter) const final {
  // Maps the index of an output to the types of the cast that gets folded.
1821  llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
1822  llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
1823  for (auto en : llvm::enumerate(newOutputTensors)) {
1824  auto castOp = en.value().getDefiningOp<tensor::CastOp>();
1825  if (!castOp)
1826  continue;
1827 
1828  // Only casts that preserve static information, i.e. will make the
1829  // loop result type "more" static than before, will be folded.
1830  if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
1831  castOp.getSource().getType())) {
1832  continue;
1833  }
1834 
1835  tensorCastProducers[en.index()] =
1836  TypeCast{castOp.getSource().getType(), castOp.getType()};
1837  newOutputTensors[en.index()] = castOp.getSource();
1838  }
1839 
  // No output is produced by a foldable cast: nothing to do.
1840  if (tensorCastProducers.empty())
1841  return failure();
1842 
1843  // Create new loop.
1844  Location loc = forallOp.getLoc();
1845  auto newForallOp = rewriter.create<ForallOp>(
1846  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1847  forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
1848  [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
  // The trailing block arguments are the outputs; cast the affected ones
  // back to the type the old body expects.
1849  auto castBlockArgs =
1850  llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
1851  for (auto [index, cast] : tensorCastProducers) {
1852  Value &oldTypeBBArg = castBlockArgs[index];
1853  oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
1854  nestedLoc, cast.dstType, oldTypeBBArg);
1855  }
1856 
1857  // Move old body into new parallel loop.
1858  SmallVector<Value> ivsBlockArgs =
1859  llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
1860  ivsBlockArgs.append(castBlockArgs);
1861  rewriter.mergeBlocks(forallOp.getBody(),
1862  bbArgs.front().getParentBlock(), ivsBlockArgs);
1863  });
1864 
1865  // After `mergeBlocks` happened, the destinations in the terminator were
1866  // mapped to the tensor.cast old-typed results of the output bbArgs. The
1867  // destinations have to be updated to point to the output bbArgs directly.
  // NOTE(review): the zip pairs the i-th yielding op with the i-th region
  // iter arg. This presumes the terminator holds one insert per output, in
  // output order - confirm that this ordering is guaranteed.
1868  auto terminator = newForallOp.getTerminator();
1869  for (auto [yieldingOp, outputBlockArg] : llvm::zip(
1870  terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
1871  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
1872  insertSliceOp.getDestMutable().assign(outputBlockArg);
1873  }
1874 
1875  // Cast results back to the original types.
1876  rewriter.setInsertionPointAfter(newForallOp);
1877  SmallVector<Value> castResults = newForallOp.getResults();
1878  for (auto &item : tensorCastProducers) {
1879  Value &oldTypeResult = castResults[item.first];
1880  oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
1881  oldTypeResult);
1882  }
1883  rewriter.replaceOp(forallOp, castResults);
1884  return success();
1885  }
1886 };
1887 
1888 } // namespace
1889 
1890 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
1891  MLIRContext *context) {
1892  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
1893  ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1894  ForallOpSingleOrZeroIterationDimsFolder,
1895  ForallOpReplaceConstantInductionVar>(context);
1896 }
1897 
1898 /// Populates `regions` with the possible control-flow successors of this op.
1899 /// The same two successors are reported regardless of `point`: the loop body
1900 /// region and a default-constructed RegionSuccessor (described below as
1901 /// branching back into the operation itself).
1902 /// NOTE(review): the line declaring `regions` is missing from this listing
/// (the embedded numbering skips a line inside the signature) - confirm the
/// full signature against the real file.
1903 void ForallOp::getSuccessorRegions(RegionBranchPoint point,
1905  // Both the operation itself and the region may be branching into the body or
1906  // back into the operation itself. It is possible for loop not to enter the
1907  // body.
1908  regions.push_back(RegionSuccessor(&getRegion()));
1909  regions.push_back(RegionSuccessor());
1910 }
1911 
1912 //===----------------------------------------------------------------------===//
1913 // InParallelOp
1914 //===----------------------------------------------------------------------===//
1915 
1916 // Build an InParallelOp: adds one region to `result` and creates a block in it.
// NOTE(review): unlike the IfOp builders in this file, no InsertionGuard is
// used around createBlock - confirm that this is intended. The embedded
// numbering also skips a line at the start of the body.
1917 void InParallelOp::build(OpBuilder &b, OperationState &result) {
1919  Region *bodyRegion = result.addRegion();
1920  b.createBlock(bodyRegion);
1921 }
1922 
1923 LogicalResult InParallelOp::verify() {
1924  scf::ForallOp forallOp =
1925  dyn_cast<scf::ForallOp>(getOperation()->getParentOp());
1926  if (!forallOp)
1927  return this->emitOpError("expected forall op parent");
1928 
1929  // TODO: InParallelOpInterface.
1930  for (Operation &op : getRegion().front().getOperations()) {
1931  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1932  return this->emitOpError("expected only ")
1933  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1934  }
1935 
1936  // Verify that inserts are into out block arguments.
1937  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1938  ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
1939  if (!llvm::is_contained(regionOutArgs, dest))
1940  return op.emitOpError("may only insert into an output block argument");
1941  }
1942  return success();
1943 }
1944 
// Custom printer body: prints the region without entry block arguments and
// without block terminators, followed by the op's attribute dictionary.
// NOTE(review): the signature line of this function is missing from this
// listing (the embedded numbering skips a line). Judging by its position
// between InParallelOp::verify and InParallelOp::parse it is presumably
// InParallelOp::print - confirm against the real file.
1946  p << " ";
1947  p.printRegion(getRegion(),
1948  /*printEntryBlockArgs=*/false,
1949  /*printBlockTerminators=*/false);
1950  p.printOptionalAttrDict(getOperation()->getAttrs());
1951 }
1952 
// Custom parser: parses the body region, guarantees that the region holds a
// block, and then parses the optional attribute dictionary.
1953 ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
1954  auto &builder = parser.getBuilder();
1955 
// NOTE(review): `regionOperands` is used below but its declaration is missing
// from this listing (the embedded numbering skips a line here) - confirm
// against the real file.
1957  std::unique_ptr<Region> region = std::make_unique<Region>();
1958  if (parser.parseRegion(*region, regionOperands))
1959  return failure();
1960 
  // An empty region was parsed: give it a block.
1961  if (region->empty())
1962  OpBuilder(builder.getContext()).createBlock(region.get());
1963  result.addRegion(std::move(region));
1964 
1965  // Parse the optional attribute list.
1966  if (parser.parseOptionalAttrDict(result.attributes))
1967  return failure();
1968  return success();
1969 }
1970 
1971 OpResult InParallelOp::getParentResult(int64_t idx) {
1972  return getOperation()->getParentOp()->getResult(idx);
1973 }
1974 
1975 SmallVector<BlockArgument> InParallelOp::getDests() {
1976  return llvm::to_vector<4>(
1977  llvm::map_range(getYieldingOps(), [](Operation &op) {
1978  // Add new ops here as needed.
1979  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1980  return llvm::cast<BlockArgument>(insertSliceOp.getDest());
1981  }));
1982 }
1983 
1984 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
1985  return getRegion().front().getOperations();
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // IfOp
1990 //===----------------------------------------------------------------------===//
1991 
// Walks the IfOp ancestors of `a`, innermost first. For the first such IfOp
// that also properly contains `b`, returns whether exactly one of `a` and `b`
// has an ancestor in that IfOp's then-block (i.e. they sit in different
// branches). Returns false when no enclosing IfOp of `a` contains `b`.
// NOTE(review): the signature line of this function is missing from this
// listing (the embedded numbering skips a line); from the body, `a` and `b`
// appear to be non-null `Operation *` - confirm against the real file.
1993  assert(a && "expected non-empty operation");
1994  assert(b && "expected non-empty operation");
1995 
1996  IfOp ifOp = a->getParentOfType<IfOp>();
1997  while (ifOp) {
1998  // Check if b is inside ifOp. (We already know that a is.)
1999  if (ifOp->isProperAncestor(b))
2000  // b is contained in ifOp. a and b are in mutually exclusive branches if
2001  // they are in different blocks of ifOp.
2002  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
2003  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
2004  // Check next enclosing IfOp.
2005  ifOp = ifOp->getParentOfType<IfOp>();
2006  }
2007 
2008  // Could not find a common IfOp among a's and b's ancestors.
2009  return false;
2010 }
2011 
2012 LogicalResult
2013 IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2014  IfOp::Adaptor adaptor,
2015  SmallVectorImpl<Type> &inferredReturnTypes) {
2016  if (adaptor.getRegions().empty())
2017  return failure();
2018  Region *r = &adaptor.getThenRegion();
2019  if (r->empty())
2020  return failure();
2021  Block &b = r->front();
2022  if (b.empty())
2023  return failure();
2024  auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
2025  if (!yieldOp)
2026  return failure();
2027  TypeRange types = yieldOp.getOperandTypes();
2028  inferredReturnTypes.insert(inferredReturnTypes.end(), types.begin(),
2029  types.end());
2030  return success();
2031 }
2032 
2033 void IfOp::build(OpBuilder &builder, OperationState &result,
2034  TypeRange resultTypes, Value cond) {
2035  return build(builder, result, resultTypes, cond, /*addThenBlock=*/false,
2036  /*addElseBlock=*/false);
2037 }
2038 
2039 void IfOp::build(OpBuilder &builder, OperationState &result,
2040  TypeRange resultTypes, Value cond, bool addThenBlock,
2041  bool addElseBlock) {
2042  assert((!addElseBlock || addThenBlock) &&
2043  "must not create else block w/o then block");
2044  result.addTypes(resultTypes);
2045  result.addOperands(cond);
2046 
2047  // Add regions and blocks.
2048  OpBuilder::InsertionGuard guard(builder);
2049  Region *thenRegion = result.addRegion();
2050  if (addThenBlock)
2051  builder.createBlock(thenRegion);
2052  Region *elseRegion = result.addRegion();
2053  if (addElseBlock)
2054  builder.createBlock(elseRegion);
2055 }
2056 
2057 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2058  bool withElseRegion) {
2059  build(builder, result, TypeRange{}, cond, withElseRegion);
2060 }
2061 
2062 void IfOp::build(OpBuilder &builder, OperationState &result,
2063  TypeRange resultTypes, Value cond, bool withElseRegion) {
2064  result.addTypes(resultTypes);
2065  result.addOperands(cond);
2066 
2067  // Build then region.
2068  OpBuilder::InsertionGuard guard(builder);
2069  Region *thenRegion = result.addRegion();
2070  builder.createBlock(thenRegion);
2071  if (resultTypes.empty())
2072  IfOp::ensureTerminator(*thenRegion, builder, result.location);
2073 
2074  // Build else region.
2075  Region *elseRegion = result.addRegion();
2076  if (withElseRegion) {
2077  builder.createBlock(elseRegion);
2078  if (resultTypes.empty())
2079  IfOp::ensureTerminator(*elseRegion, builder, result.location);
2080  }
2081 }
2082 
2083 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
2084  function_ref<void(OpBuilder &, Location)> thenBuilder,
2085  function_ref<void(OpBuilder &, Location)> elseBuilder) {
2086  assert(thenBuilder && "the builder callback for 'then' must be present");
2087  result.addOperands(cond);
2088 
2089  // Build then region.
2090  OpBuilder::InsertionGuard guard(builder);
2091  Region *thenRegion = result.addRegion();
2092  builder.createBlock(thenRegion);
2093  thenBuilder(builder, result.location);
2094 
2095  // Build else region.
2096  Region *elseRegion = result.addRegion();
2097  if (elseBuilder) {
2098  builder.createBlock(elseRegion);
2099  elseBuilder(builder, result.location);
2100  }
2101 
2102  // Infer result types.
2103  SmallVector<Type> inferredReturnTypes;
2104  MLIRContext *ctx = builder.getContext();
2105  auto attrDict = DictionaryAttr::get(ctx, result.attributes);
2106  if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
2107  /*properties=*/nullptr, result.regions,
2108  inferredReturnTypes))) {
2109  result.addTypes(inferredReturnTypes);
2110  }
2111 }
2112 
2113 LogicalResult IfOp::verify() {
2114  if (getNumResults() != 0 && getElseRegion().empty())
2115  return emitOpError("must have an else block if defining values");
2116  return success();
2117 }
2118 
// Custom parser for scf.if: parses the i1 condition, an optional
// arrow-prefixed result type list, the 'then' region, an optional
// `else`-introduced region, and an optional attribute dictionary.
// `ensureTerminator` is called on every region that was parsed.
2119 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2120  // Create the 'then' and 'else' regions.
2121  result.regions.reserve(2);
2122  Region *thenRegion = result.addRegion();
2123  Region *elseRegion = result.addRegion();
2124 
2125  auto &builder = parser.getBuilder();
// NOTE(review): `cond` is used below but its declaration is missing from this
// listing (the embedded numbering skips a line here) - confirm against the
// real file.
2127  Type i1Type = builder.getIntegerType(1);
2128  if (parser.parseOperand(cond) ||
2129  parser.resolveOperand(cond, i1Type, result.operands))
2130  return failure();
2131  // Parse optional results type list.
2132  if (parser.parseOptionalArrowTypeList(result.types))
2133  return failure();
2134  // Parse the 'then' region.
2135  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2136  return failure();
2137  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
2138 
2139  // If we find an 'else' keyword then parse the 'else' region.
2140  if (!parser.parseOptionalKeyword("else")) {
2141  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2142  return failure();
2143  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
2144  }
2145 
2146  // Parse the optional attribute list.
2147  if (parser.parseOptionalAttrDict(result.attributes))
2148  return failure();
2149  return success();
2150 }
2151 
2152 void IfOp::print(OpAsmPrinter &p) {
2153  bool printBlockTerminators = false;
2154 
2155  p << " " << getCondition();
2156  if (!getResults().empty()) {
2157  p << " -> (" << getResultTypes() << ")";
2158  // Print yield explicitly if the op defines values.
2159  printBlockTerminators = true;
2160  }
2161  p << ' ';
2162  p.printRegion(getThenRegion(),
2163  /*printEntryBlockArgs=*/false,
2164  /*printBlockTerminators=*/printBlockTerminators);
2165 
2166  // Print the 'else' regions if it exists and has a block.
2167  auto &elseRegion = getElseRegion();
2168  if (!elseRegion.empty()) {
2169  p << " else ";
2170  p.printRegion(elseRegion,
2171  /*printEntryBlockArgs=*/false,
2172  /*printBlockTerminators=*/printBlockTerminators);
2173  }
2174 
2175  p.printOptionalAttrDict((*this)->getAttrs());
2176 }
2177 
// Populates `regions` with the possible successors of `point`. From inside
// either region the only successor is the op's results. From the parent op
// the successors are the 'then' region plus either the 'else' region or, when
// that region is empty, a default-constructed RegionSuccessor.
// NOTE(review): the line declaring `regions` is missing from this listing
// (the embedded numbering skips a line inside the signature) - confirm the
// full signature against the real file.
2178 void IfOp::getSuccessorRegions(RegionBranchPoint point,
2180  // The `then` and the `else` region branch back to the parent operation.
2181  if (!point.isParent()) {
2182  regions.push_back(RegionSuccessor(getResults()));
2183  return;
2184  }
2185 
2186  regions.push_back(RegionSuccessor(&getThenRegion()));
2187 
2188  // Don't consider the else region if it is empty.
2189  Region *elseRegion = &this->getElseRegion();
2190  if (elseRegion->empty())
2191  regions.push_back(RegionSuccessor());
2192  else
2193  regions.push_back(RegionSuccessor(elseRegion));
2194 }
2195 
// Populates `regions` with the regions that may be entered from the op, given
// the operand attributes in `operands`. If the condition folds to a BoolAttr
// only the selected branch is reported; otherwise both are. An empty 'else'
// region is reported as a successor built from the op's results.
// NOTE(review): the line declaring `regions` is missing from this listing
// (the embedded numbering skips a line inside the signature) - confirm the
// full signature against the real file.
2196 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
2198  FoldAdaptor adaptor(operands, *this);
2199  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
  // Unknown or true condition: the 'then' region may be entered.
2200  if (!boolAttr || boolAttr.getValue())
2201  regions.emplace_back(&getThenRegion());
2202 
2203  // If the else region is empty, execution continues after the parent op.
2204  if (!boolAttr || !boolAttr.getValue()) {
2205  if (!getElseRegion().empty())
2206  regions.emplace_back(&getElseRegion());
2207  else
2208  regions.emplace_back(getResults());
2209  }
2210 }
2211 
2212 LogicalResult IfOp::fold(FoldAdaptor adaptor,
2213  SmallVectorImpl<OpFoldResult> &results) {
2214  // if (!c) then A() else B() -> if c then B() else A()
2215  if (getElseRegion().empty())
2216  return failure();
2217 
2218  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
2219  if (!xorStmt)
2220  return failure();
2221 
2222  if (!matchPattern(xorStmt.getRhs(), m_One()))
2223  return failure();
2224 
2225  getConditionMutable().assign(xorStmt.getLhs());
2226  Block *thenBlock = &getThenRegion().front();
2227  // It would be nicer to use iplist::swap, but that has no implemented
2228  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
2229  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
2230  getElseRegion().getBlocks());
2231  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
2232  getThenRegion().getBlocks(), thenBlock);
2233  return success();
2234 }
2235 
2236 void IfOp::getRegionInvocationBounds(
2237  ArrayRef<Attribute> operands,
2238  SmallVectorImpl<InvocationBounds> &invocationBounds) {
2239  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
2240  // If the condition is known, then one region is known to be executed once
2241  // and the other zero times.
2242  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
2243  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
2244  } else {
2245  // Non-constant condition. Each region may be executed 0 or 1 times.
2246  invocationBounds.assign(2, {0, 1});
2247  }
2248 }
2249 
2250 namespace {
2251 // Pattern to remove unused IfOp results. The op is rebuilt with only the
// result types that have uses; both bodies are moved over and their yields
// are trimmed to the matching operands.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
2252 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
2254 
  // Merges `source` into `dest` and rewrites the operands of the scf.yield
  // terminating `dest` so that only the operands at the result numbers in
  // `usedResults` remain, in that order.
2255  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
2256  PatternRewriter &rewriter) const {
2257  // Move all operations to the destination block.
2258  rewriter.mergeBlocks(source, dest);
2259  // Replace the yield op by one that returns only the used values.
2260  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
2261  SmallVector<Value, 4> usedOperands;
2262  llvm::transform(usedResults, std::back_inserter(usedOperands),
2263  [&](OpResult result) {
2264  return yieldOp.getOperand(result.getResultNumber());
2265  });
2266  rewriter.modifyOpInPlace(yieldOp,
2267  [&]() { yieldOp->setOperands(usedOperands); });
2268  }
2269 
2270  LogicalResult matchAndRewrite(IfOp op,
2271  PatternRewriter &rewriter) const override {
2272  // Compute the list of used results.
2273  SmallVector<OpResult, 4> usedResults;
2274  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
2275  [](OpResult result) { return !result.use_empty(); });
2276 
2277  // Replace the operation if only a subset of its results have uses.
2278  if (usedResults.size() == op.getNumResults())
2279  return failure();
2280 
2281  // Compute the result types of the replacement operation.
2282  SmallVector<Type, 4> newTypes;
2283  llvm::transform(usedResults, std::back_inserter(newTypes),
2284  [](OpResult result) { return result.getType(); });
2285 
2286  // Create a replacement operation with empty then and else regions.
2287  auto newOp =
2288  rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition());
2289  rewriter.createBlock(&newOp.getThenRegion());
2290  rewriter.createBlock(&newOp.getElseRegion());
2291 
2292  // Move the bodies and replace the terminators (note there is a then and
2293  // an else region since the operation returns results).
2294  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
2295  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
2296 
2297  // Replace the operation by the new one.
  // Entries for unused results stay null Values; only used results are mapped
  // to the corresponding result of the new op.
2298  SmallVector<Value, 4> repResults(op.getNumResults());
2299  for (const auto &en : llvm::enumerate(usedResults))
2300  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
2301  rewriter.replaceOp(op, repResults);
2302  return success();
2303  }
2304 };
2305 
/// Removes an scf.if whose condition matches a constant BoolAttr: the op is
/// replaced by the contents of the selected region, or erased when the
/// condition is false and there is no 'else' block.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
2306 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
2308 
2309  LogicalResult matchAndRewrite(IfOp op,
2310  PatternRewriter &rewriter) const override {
2311  BoolAttr condition;
2312  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
2313  return failure();
2314 
  // True: inline 'then'. False: inline 'else' if it has a block, otherwise
  // just erase the op.
2315  if (condition.getValue())
2316  replaceOpWithRegion(rewriter, op, op.getThenRegion());
2317  else if (!op.getElseRegion().empty())
2318  replaceOpWithRegion(rewriter, op, op.getElseRegion());
2319  else
2320  rewriter.eraseOp(op);
2321 
2322  return success();
2323  }
2324 };
2325 
2326 /// Hoist any yielded results whose operands are defined outside
2327 /// the if, to a select instruction.
/// Results for which both branches yield the same value are replaced by that
/// value directly. Results whose yielded value is defined directly in one of
/// the two regions stay on a replacement scf.if that takes over both bodies.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
2328 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
2330 
2331  LogicalResult matchAndRewrite(IfOp op,
2332  PatternRewriter &rewriter) const override {
2333  if (op->getNumResults() == 0)
2334  return failure();
2335 
2336  auto cond = op.getCondition();
2337  auto thenYieldArgs = op.thenYield().getOperands();
2338  auto elseYieldArgs = op.elseYield().getOperands();
2339 
  // Types of the results that have to remain on the scf.if because at least
  // one of the two yielded values is defined directly in a branch region.
2340  SmallVector<Type> nonHoistable;
2341  for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
2342  if (&op.getThenRegion() == trueVal.getParentRegion() ||
2343  &op.getElseRegion() == falseVal.getParentRegion())
2344  nonHoistable.push_back(trueVal.getType());
2345  }
2346  // Early exit if there aren't any yielded values we can
2347  // hoist outside the if.
2348  if (nonHoistable.size() == op->getNumResults())
2349  return failure();
2350 
2351  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond,
2352  /*withElseRegion=*/false);
  // If the builder created a 'then' block, drop it: both regions of the old
  // op are taken over wholesale.
2353  if (replacement.thenBlock())
2354  rewriter.eraseBlock(replacement.thenBlock());
2355  replacement.getThenRegion().takeBody(op.getThenRegion());
2356  replacement.getElseRegion().takeBody(op.getElseRegion());
2357 
2358  SmallVector<Value> results(op->getNumResults());
2359  assert(thenYieldArgs.size() == results.size());
2360  assert(elseYieldArgs.size() == results.size());
2361 
  // Operands of the new, shorter yields of the replacement op.
2362  SmallVector<Value> trueYields;
2363  SmallVector<Value> falseYields;
  // Selects are created right before the replacement op.
2364  rewriter.setInsertionPoint(replacement);
2365  for (const auto &it :
2366  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
2367  Value trueVal = std::get<0>(it.value());
2368  Value falseVal = std::get<1>(it.value());
2369  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
2370  &replacement.getElseRegion() == falseVal.getParentRegion()) {
  // Not hoistable: keep as the next result of the replacement op.
2371  results[it.index()] = replacement.getResult(trueYields.size());
2372  trueYields.push_back(trueVal);
2373  falseYields.push_back(falseVal);
2374  } else if (trueVal == falseVal)
2375  results[it.index()] = trueVal;
2376  else
2377  results[it.index()] = rewriter.create<arith::SelectOp>(
2378  op.getLoc(), cond, trueVal, falseVal);
2379  }
2380 
2381  rewriter.setInsertionPointToEnd(replacement.thenBlock());
2382  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
2383 
2384  rewriter.setInsertionPointToEnd(replacement.elseBlock());
2385  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
2386 
2387  rewriter.replaceOp(op, results);
2388  return success();
2389  }
2390 };
2391 
2392 /// Allow the true region of an if to assume the condition is true
2393 /// and vice versa. For example:
2394 ///
2395 /// scf.if %cmp {
2396 /// print(%cmp)
2397 /// }
2398 ///
2399 /// becomes
2400 ///
2401 /// scf.if %cmp {
2402 /// print(true)
2403 /// }
2404 ///
/// Uses of the condition nested anywhere below the 'then' ('else') region are
/// rewritten to an i1 constant 1 (0). At most one constant of each kind is
/// created per application of the pattern.
// NOTE(review): no constructor or inheriting-constructor declaration is
// visible in this listing (the embedded numbering skips a line after the
// struct header) - confirm it is present in the real file.
2405 struct ConditionPropagation : public OpRewritePattern<IfOp> {
2407 
2408  LogicalResult matchAndRewrite(IfOp op,
2409  PatternRewriter &rewriter) const override {
2410  // Early exit if the condition is constant since replacing a constant
2411  // in the body with another constant isn't a simplification.
2412  if (matchPattern(op.getCondition(), m_Constant()))
2413  return failure();
2414 
2415  bool changed = false;
2416  mlir::Type i1Ty = rewriter.getI1Type();
2417 
2418  // These variables serve to prevent creating duplicate constants
2419  // and hold constant true or false values.
2420  Value constantTrue = nullptr;
2421  Value constantFalse = nullptr;
2422 
  // Early-inc iteration: uses are re-pointed while the use list is walked.
2423  for (OpOperand &use :
2424  llvm::make_early_inc_range(op.getCondition().getUses())) {
2425  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
2426  changed = true;
2427 
2428  if (!constantTrue)
2429  constantTrue = rewriter.create<arith::ConstantOp>(
2430  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
2431 
2432  rewriter.modifyOpInPlace(use.getOwner(),
2433  [&]() { use.set(constantTrue); });
2434  } else if (op.getElseRegion().isAncestor(
2435  use.getOwner()->getParentRegion())) {
2436  changed = true;
2437 
2438  if (!constantFalse)
2439  constantFalse = rewriter.create<arith::ConstantOp>(
2440  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
2441 
2442  rewriter.modifyOpInPlace(use.getOwner(),
2443  [&]() { use.set(constantFalse); });
2444  }
2445  }
2446 
2447  return success(changed);
2448  }
2449 };
2450 
2451 /// Remove any statements from an if that are equivalent to the condition
2452 /// or its negation. For example:
2453 ///
2454 /// %res:2 = scf.if %cmp {
2455 /// yield something(), true
2456 /// } else {
2457 /// yield something2(), false
2458 /// }
2459 /// print(%res#1)
2460 ///
2461 /// becomes
2462 /// %res = scf.if %cmp {
2463 /// yield something()
2464 /// } else {
2465 /// yield something2()
2466 /// }
2467 /// print(%cmp)
2468 ///
2469 /// Additionally if both branches yield the same value, replace all uses
2470 /// of the result with the yielded value.
2471 ///
2472 /// %res:2 = scf.if %cmp {
2473 /// yield something(), %arg1
2474 /// } else {
2475 /// yield something2(), %arg1
2476 /// }
2477 /// print(%res#1)
2478 ///
2479 /// becomes
2480 /// %res = scf.if %cmp {
2481 /// yield something()
2482 /// } else {
2483 /// yield something2()
2484 /// }
2485 /// print(%arg1)
2486 ///
2487 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
2489 
2490  LogicalResult matchAndRewrite(IfOp op,
2491  PatternRewriter &rewriter) const override {
2492  // Early exit if there are no results that could be replaced.
2493  if (op.getNumResults() == 0)
2494  return failure();
2495 
2496  auto trueYield =
2497  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
2498  auto falseYield =
2499  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
2500 
2501  rewriter.setInsertionPoint(op->getBlock(),
2502  op.getOperation()->getIterator());
2503  bool changed = false;
2504  Type i1Ty = rewriter.getI1Type();
2505  for (auto [trueResult, falseResult, opResult] :
2506  llvm::zip(trueYield.getResults(), falseYield.getResults(),
2507  op.getResults())) {
2508  if (trueResult == falseResult) {
2509  if (!opResult.use_empty()) {
2510  opResult.replaceAllUsesWith(trueResult);
2511  changed = true;
2512  }
2513  continue;
2514  }
2515 
2516  BoolAttr trueYield, falseYield;
2517  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
2518  !matchPattern(falseResult, m_Constant(&falseYield)))
2519  continue;
2520 
2521  bool trueVal = trueYield.getValue();
2522  bool falseVal = falseYield.getValue();
2523  if (!trueVal && falseVal) {
2524  if (!opResult.use_empty()) {
2525  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
2526  Value notCond = rewriter.create<arith::XOrIOp>(
2527  op.getLoc(), op.getCondition(),
2528  constDialect
2529  ->materializeConstant(rewriter,
2530  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
2531  op.getLoc())
2532  ->getResult(0));
2533  opResult.replaceAllUsesWith(notCond);
2534  changed = true;
2535  }
2536  }
2537  if (trueVal && !falseVal) {
2538  if (!opResult.use_empty()) {
2539  opResult.replaceAllUsesWith(op.getCondition());
2540  changed = true;
2541  }
2542  }
2543  }
2544  return success(changed);
2545  }
2546 };
2547 
2548 /// Merge any consecutive scf.if's with the same condition.
2549 ///
2550 /// scf.if %cond {
2551 /// firstCodeTrue();...
2552 /// } else {
2553 /// firstCodeFalse();...
2554 /// }
2555 /// %res = scf.if %cond {
2556 /// secondCodeTrue();...
2557 /// } else {
2558 /// secondCodeFalse();...
2559 /// }
2560 ///
2561 /// becomes
2562 /// %res = scf.if %cond {
2563 /// firstCodeTrue();...
2564 /// secondCodeTrue();...
2565 /// } else {
2566 /// firstCodeFalse();...
2567 /// secondCodeFalse();...
2568 /// }
/// 
/// The pattern also matches when one condition is an `arith.xori` of the
/// other with a constant one (i.e. its negation); the then/else bodies of
/// the second `if` are then swapped before merging. The combined op uses the
/// first `if`'s condition and yields the first `if`'s results followed by
/// the second `if`'s results.
2569 struct CombineIfs : public OpRewritePattern<IfOp> {
2571 
2572  LogicalResult matchAndRewrite(IfOp nextIf,
2573  PatternRewriter &rewriter) const override {
      // `nextIf` is the first op of its block: there is no preceding op to
      // merge with.
2574  Block *parent = nextIf->getBlock();
2575  if (nextIf == &parent->front())
2576  return failure();
2577 
      // Only an `scf.if` immediately preceding `nextIf` is considered.
2578  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
2579  if (!prevIf)
2580  return failure();
2581 
2582  // Determine the logical then/else blocks when prevIf's
2583  // condition is used. Null means the block does not exist
2584  // in that case (e.g. empty else). If neither of these
2585  // are set, the two conditions cannot be compared.
2586  Block *nextThen = nullptr;
2587  Block *nextElse = nullptr;
      // Case 1: both ops test the very same SSA value.
2588  if (nextIf.getCondition() == prevIf.getCondition()) {
2589  nextThen = nextIf.thenBlock();
2590  if (!nextIf.getElseRegion().empty())
2591  nextElse = nextIf.elseBlock();
2592  }
      // Case 2: nextIf's condition is `xori(prevCond, 1)`, so nextIf's then
      // block runs exactly when prevIf's condition is false.
2593  if (arith::XOrIOp notv =
2594  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2595  if (notv.getLhs() == prevIf.getCondition() &&
2596  matchPattern(notv.getRhs(), m_One())) {
2597  nextElse = nextIf.thenBlock();
2598  if (!nextIf.getElseRegion().empty())
2599  nextThen = nextIf.elseBlock();
2600  }
2601  }
      // Case 3: prevIf's condition is `xori(nextCond, 1)`; the same swap of
      // nextIf's blocks applies because the merged op keeps prevIf's
      // condition.
2602  if (arith::XOrIOp notv =
2603  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
2604  if (notv.getLhs() == nextIf.getCondition() &&
2605  matchPattern(notv.getRhs(), m_One())) {
2606  nextElse = nextIf.thenBlock();
2607  if (!nextIf.getElseRegion().empty())
2608  nextThen = nextIf.elseBlock();
2609  }
2610  }
2611 
2612  if (!nextThen && !nextElse)
2613  return failure();
2614 
      // Values yielded by prevIf's else branch; left empty when prevIf has no
      // else region.
2615  SmallVector<Value> prevElseYielded;
2616  if (!prevIf.getElseRegion().empty())
2617  prevElseYielded = prevIf.elseYield().getOperands();
2618  // Replace all uses of return values of op within nextIf with the
2619  // corresponding yields
2620  for (auto it : llvm::zip(prevIf.getResults(),
2621  prevIf.thenYield().getOperands(), prevElseYielded))
2622  for (OpOperand &use :
2623  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
      // A use inside nextIf's (logical) then region sees prevIf's then-yield;
      // a use inside its (logical) else region sees prevIf's else-yield.
2624  if (nextThen && nextThen->getParent()->isAncestor(
2625  use.getOwner()->getParentRegion())) {
2626  rewriter.startOpModification(use.getOwner());
2627  use.set(std::get<1>(it));
2628  rewriter.finalizeOpModification(use.getOwner());
2629  } else if (nextElse && nextElse->getParent()->isAncestor(
2630  use.getOwner()->getParentRegion())) {
2631  rewriter.startOpModification(use.getOwner());
2632  use.set(std::get<2>(it));
2633  rewriter.finalizeOpModification(use.getOwner());
2634  }
2635  }
2636 
      // Result types of the merged op: prevIf's results first, then nextIf's.
2637  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2638  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2639 
2640  IfOp combinedIf = rewriter.create<IfOp>(
2641  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
      // Remove the then block the new op came with (presumably created by the
      // builder); prevIf's then region is moved in right below.
2642  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2643 
2644  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2645  combinedIf.getThenRegion(),
2646  combinedIf.getThenRegion().begin());
2647 
2648  if (nextThen) {
      // Append nextIf's then body and replace the two yields by a single one
      // carrying both operand lists.
2649  YieldOp thenYield = combinedIf.thenYield();
2650  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2651  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2652  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2653 
2654  SmallVector<Value> mergedYields(thenYield.getOperands());
2655  llvm::append_range(mergedYields, thenYield2.getOperands());
2656  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2657  rewriter.eraseOp(thenYield);
2658  rewriter.eraseOp(thenYield2);
2659  }
2660 
2661  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2662  combinedIf.getElseRegion(),
2663  combinedIf.getElseRegion().begin());
2664 
2665  if (nextElse) {
2666  if (combinedIf.getElseRegion().empty()) {
      // prevIf contributed no else block: take over the region holding
      // nextIf's logical else block as is.
2667  rewriter.inlineRegionBefore(*nextElse->getParent(),
2668  combinedIf.getElseRegion(),
2669  combinedIf.getElseRegion().begin());
2670  } else {
      // Both ops have an else body: concatenate them and merge the yields.
2671  YieldOp elseYield = combinedIf.elseYield();
2672  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2673  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2674 
2675  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2676 
2677  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2678  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2679 
2680  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2681  rewriter.eraseOp(elseYield);
2682  rewriter.eraseOp(elseYield2);
2683  }
2684  }
2685 
      // Split the merged op's results back into the slices that replace
      // prevIf and nextIf respectively.
2686  SmallVector<Value> prevValues;
2687  SmallVector<Value> nextValues;
2688  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2689  if (pair.index() < prevIf.getNumResults())
2690  prevValues.push_back(pair.value());
2691  else
2692  nextValues.push_back(pair.value());
2693  }
2694  rewriter.replaceOp(prevIf, prevValues);
2695  rewriter.replaceOp(nextIf, nextValues);
2696  return success();
2697  }
2698 };
2699 
2700 /// Pattern to remove an empty else branch.
/// 
/// Applies only to result-less `scf.if` ops whose else block holds a single
/// operation. The op is replaced by a region-less clone that receives the
/// original then region; the else region is dropped with the erased op.
2701 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2703 
2704  LogicalResult matchAndRewrite(IfOp ifOp,
2705  PatternRewriter &rewriter) const override {
2706  // Cannot remove else region when there are operation results.
2707  if (ifOp.getNumResults())
2708  return failure();
      // Require an else block containing exactly one op (presumably just the
      // terminator); a missing else block means there is nothing to remove.
2709  Block *elseBlock = ifOp.elseBlock();
2710  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2711  return failure();
      // Clone the op without regions, then move only the then region over.
2712  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2713  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2714  newIfOp.getThenRegion().begin());
2715  rewriter.eraseOp(ifOp);
2716  return success();
2717  }
2718 };
2719 
2720 /// Convert nested `if`s into `arith.andi` + single `if`.
2721 ///
2722 /// scf.if %arg0 {
2723 /// scf.if %arg1 {
2724 /// ...
2725 /// scf.yield
2726 /// }
2727 /// scf.yield
2728 /// }
2729 /// becomes
2730 ///
2731 /// %0 = arith.andi %arg0, %arg1
2732 /// scf.if %0 {
2733 /// ...
2734 /// scf.yield
2735 /// }
/// 
/// The pattern requires the nested `if` to be the only non-terminator op in
/// the outer then block, and any else block (outer or nested) to contain a
/// single op. Results of the outer op that do not come from the nested `if`
/// are rebuilt as `arith.select` on the outer condition.
2736 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2738 
2739  LogicalResult matchAndRewrite(IfOp op,
2740  PatternRewriter &rewriter) const override {
2741  auto nestedOps = op.thenBlock()->without_terminator();
2742  // Nested `if` must be the only op in block.
2743  if (!llvm::hasSingleElement(nestedOps))
2744  return failure();
2745 
2746  // If there is an else block, it can only yield
2747  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2748  return failure();
2749 
2750  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2751  if (!nestedIf)
2752  return failure();
2753 
      // Same restriction for the nested op: its else block may hold one op
      // only.
2754  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2755  return failure();
2756 
      // Snapshot the values yielded by the outer op; `thenYield` entries are
      // rewritten below when they refer to results of the nested `if`.
2757  SmallVector<Value> thenYield(op.thenYield().getOperands());
2758  SmallVector<Value> elseYield;
2759  if (op.elseBlock())
2760  llvm::append_range(elseYield, op.elseYield().getOperands());
2761 
2762  // A list of indices for which we should upgrade the value yielded
2763  // in the else to a select.
2764  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2765 
2766  // If the outer scf.if yields a value produced by the inner scf.if,
2767  // only permit combining if the value yielded when the condition
2768  // is false in the outer scf.if is the same value yielded when the
2769  // inner scf.if condition is false.
2770  // Note that the array access to elseYield will not go out of bounds
2771  // since it must have the same length as thenYield, since they both
2772  // come from the same scf.if.
2773  for (const auto &tup : llvm::enumerate(thenYield)) {
2774  if (tup.value().getDefiningOp() == nestedIf) {
2775  auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
2776  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2777  elseYield[tup.index()]) {
2778  return failure();
2779  }
2780  // If the correctness test passes, we will yield
2781  // corresponding value from the inner scf.if
2782  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2783  continue;
2784  }
2785 
2786  // Otherwise, we need to ensure the else block of the combined
2787  // condition still returns the same value when the outer condition is
2788  // true and the inner condition is false. This can be accomplished if
2789  // the then value is defined outside the outer scf.if and we replace the
2790  // value with a select that considers just the outer condition. Since
2791  // the else region contains just the yield, its yielded value is
2792  // defined outside the scf.if, by definition.
2793 
2794  // If the then value is defined within the scf.if, bail.
2795  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2796  return failure();
2797  }
2798  elseYieldsToUpgradeToSelect.push_back(tup.index());
2799  }
2800 
      // Build the combined condition and the replacement `scf.if`.
2801  Location loc = op.getLoc();
2802  Value newCondition = rewriter.create<arith::AndIOp>(
2803  loc, op.getCondition(), nestedIf.getCondition());
2804  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2805  Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2806 
      // By default the outer op is replaced by the new op's results; selected
      // indices are overridden with selects created right before the new op.
2807  SmallVector<Value> results;
2808  llvm::append_range(results, newIf.getResults());
2809  rewriter.setInsertionPoint(newIf);
2810 
2811  for (auto idx : elseYieldsToUpgradeToSelect)
2812  results[idx] = rewriter.create<arith::SelectOp>(
2813  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2814 
      // Move the nested then body into the new op and re-terminate it with the
      // adjusted then-yield operands.
2815  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2816  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2817  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
      // An else block is only materialized when the outer op yielded values
      // from its else branch.
2818  if (!elseYield.empty()) {
2819  rewriter.createBlock(&newIf.getElseRegion());
2820  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2821  rewriter.create<YieldOp>(loc, elseYield);
2822  }
2823  rewriter.replaceOp(op, results);
2824  return success();
2825  }
2826 };
2828 } // namespace
2829 
2830 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2831  MLIRContext *context) {
2832  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2833  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2834  RemoveStaticCondition, RemoveUnusedResults,
2835  ReplaceIfYieldWithConditionOrValue>(context);
2836 }
2837 
2838 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2839 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2840 Block *IfOp::elseBlock() {
2841  Region &r = getElseRegion();
2842  if (r.empty())
2843  return nullptr;
2844  return &r.back();
2845 }
2846 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2847 
2848 //===----------------------------------------------------------------------===//
2849 // ParallelOp
2850 //===----------------------------------------------------------------------===//
2851 
// Builds an `scf.parallel` op state: operands are appended in the order
// lower bounds, upper bounds, steps, init values (with matching operand
// segment sizes), result types are the init value types, and the body block
// gets one index-typed argument per step. `bodyBuilderFn`, when non-null, is
// invoked with the insertion point at the start of the body block.
2852 void ParallelOp::build(
2853  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2854  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2856  bodyBuilderFn) {
2857  result.addOperands(lowerBounds);
2858  result.addOperands(upperBounds);
2859  result.addOperands(steps);
2860  result.addOperands(initVals);
      // Record how many operands belong to each of the four operand groups.
2861  result.addAttribute(
2862  ParallelOp::getOperandSegmentSizeAttr(),
2863  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2864  static_cast<int32_t>(upperBounds.size()),
2865  static_cast<int32_t>(steps.size()),
2866  static_cast<int32_t>(initVals.size())}));
2867  result.addTypes(initVals.getTypes());
2868 
      // The guard restores the caller's insertion point when this function
      // returns; createBlock below moves it into the new body.
2869  OpBuilder::InsertionGuard guard(builder);
2870  unsigned numIVs = steps.size();
2871  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2872  SmallVector<Location, 8> argLocs(numIVs, result.location);
2873  Region *bodyRegion = result.addRegion();
2874  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2875 
2876  if (bodyBuilderFn) {
2877  builder.setInsertionPointToStart(bodyBlock);
      // The block was created with exactly `numIVs` arguments, so the second
      // range passed to the callback is empty.
2878  bodyBuilderFn(builder, result.location,
2879  bodyBlock->getArguments().take_front(numIVs),
2880  bodyBlock->getArguments().drop_front(numIVs));
2881  }
2882  // Add terminator only if there are no reductions.
2883  if (initVals.empty())
2884  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2885 }
2886 
2887 void ParallelOp::build(
2888  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2889  ValueRange upperBounds, ValueRange steps,
2890  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2891  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2892  // we don't capture a reference to a temporary by constructing the lambda at
2893  // function level.
2894  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2895  Location nestedLoc, ValueRange ivs,
2896  ValueRange) {
2897  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2898  };
2899  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2900  if (bodyBuilderFn)
2901  wrapper = wrappedBuilderFn;
2902 
2903  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2904  wrapper);
2905 }
2906 
2907 LogicalResult ParallelOp::verify() {
2908  // Check that there is at least one value in lowerBound, upperBound and step.
2909  // It is sufficient to test only step, because it is ensured already that the
2910  // number of elements in lowerBound, upperBound and step are the same.
2911  Operation::operand_range stepValues = getStep();
2912  if (stepValues.empty())
2913  return emitOpError(
2914  "needs at least one tuple element for lowerBound, upperBound and step");
2915 
2916  // Check whether all constant step values are positive.
2917  for (Value stepValue : stepValues)
2918  if (auto cst = getConstantIntValue(stepValue))
2919  if (*cst <= 0)
2920  return emitOpError("constant step operand must be positive");
2921 
2922  // Check that the body defines the same number of block arguments as the
2923  // number of tuple elements in step.
2924  Block *body = getBody();
2925  if (body->getNumArguments() != stepValues.size())
2926  return emitOpError() << "expects the same number of induction variables: "
2927  << body->getNumArguments()
2928  << " as bound and step values: " << stepValues.size();
2929  for (auto arg : body->getArguments())
2930  if (!arg.getType().isIndex())
2931  return emitOpError(
2932  "expects arguments for the induction variable to be of index type");
2933 
2934  // Check that the terminator is an scf.reduce op.
2935  auto reduceOp = verifyAndGetTerminator<scf::ReduceOp>(
2936  *this, getRegion(), "expects body to terminate with 'scf.reduce'");
2937  if (!reduceOp)
2938  return failure();
2939 
2940  // Check that the number of results is the same as the number of reductions.
2941  auto resultsSize = getResults().size();
2942  auto reductionsSize = reduceOp.getReductions().size();
2943  auto initValsSize = getInitVals().size();
2944  if (resultsSize != reductionsSize)
2945  return emitOpError() << "expects number of results: " << resultsSize
2946  << " to be the same as number of reductions: "
2947  << reductionsSize;
2948  if (resultsSize != initValsSize)
2949  return emitOpError() << "expects number of results: " << resultsSize
2950  << " to be the same as number of initial values: "
2951  << initValsSize;
2952 
2953  // Check that the types of the results and reductions are the same.
2954  for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
2955  auto resultType = getOperation()->getResult(i).getType();
2956  auto reductionOperandType = reduceOp.getOperands()[i].getType();
2957  if (resultType != reductionOperandType)
2958  return reduceOp.emitOpError()
2959  << "expects type of " << i
2960  << "-th reduction operand: " << reductionOperandType
2961  << " to be the same as the " << i
2962  << "-th result type: " << resultType;
2963  }
2964  return success();
2965 }
2966 
// Parses the custom assembly of `scf.parallel`. From the visible statements
// the expected order is: induction variables, `=` lower bounds, `to` upper
// bounds, `step` steps, optional `init (...)`, optional `-> types`, the body
// region, and an optional attribute dictionary.
// NOTE(review): the declarations of `ivs`, `lower`, `upper`, `steps` and
// `initVals` are not visible in this listing — confirm against the full file.
2967 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2968  auto &builder = parser.getBuilder();
2969  // Parse an opening `(` followed by induction variables followed by `)`
2972  return failure();
2973 
2974  // Parse loop bounds.
      // Lower bounds: `=` followed by as many operands as induction variables,
      // each resolved as an index value.
2976  if (parser.parseEqual() ||
2977  parser.parseOperandList(lower, ivs.size(),
2979  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2980  return failure();
2981 
      // Upper bounds: `to` followed by the same number of index operands.
2983  if (parser.parseKeyword("to") ||
2984  parser.parseOperandList(upper, ivs.size(),
2986  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2987  return failure();
2988 
2989  // Parse step values.
2991  if (parser.parseKeyword("step") ||
2992  parser.parseOperandList(steps, ivs.size(),
2994  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2995  return failure();
2996 
2997  // Parse init values.
      // Only the operand names are collected here; they are resolved further
      // down, once the result types are known.
2999  if (succeeded(parser.parseOptionalKeyword("init"))) {
3000  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
3001  return failure();
3002  }
3003 
3004  // Parse optional results in case there is a reduce.
3005  if (parser.parseOptionalArrowTypeList(result.types))
3006  return failure();
3007 
3008  // Now parse the body.
      // All induction variables are given index type before they are used as
      // the region's entry block arguments.
3009  Region *body = result.addRegion();
3010  for (auto &iv : ivs)
3011  iv.type = builder.getIndexType();
3012  if (parser.parseRegion(*body, ivs))
3013  return failure();
3014 
3015  // Set `operandSegmentSizes` attribute.
3016  result.addAttribute(
3017  ParallelOp::getOperandSegmentSizeAttr(),
3018  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
3019  static_cast<int32_t>(upper.size()),
3020  static_cast<int32_t>(steps.size()),
3021  static_cast<int32_t>(initVals.size())}));
3022 
3023  // Parse attributes.
      // The init values are resolved against the result types and appended
      // after the bound and step operands already in `result.operands`.
3024  if (parser.parseOptionalAttrDict(result.attributes) ||
3025  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
3026  result.operands))
3027  return failure();
3028 
3029  // Add a terminator if none was parsed.
3030  ParallelOp::ensureTerminator(*body, builder, result.location);
3031  return success();
3032 }
3033 
// Prints ` (ivs) = (lower bounds) to (upper bounds) step (steps)`, followed by
// an optional ` init (...)` clause, the optional result type list and the body
// region without its entry block arguments.
// NOTE(review): the enclosing signature line is not visible in this listing.
3035  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
3036  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
      // The `init` clause is omitted entirely when there are no init values.
3037  if (!getInitVals().empty())
3038  p << " init (" << getInitVals() << ")";
3039  p.printOptionalArrowTypeList(getResultTypes());
3040  p << ' ';
3041  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
      // The operand segment size attribute is passed as elided when the
      // remaining attributes are printed.
3043  (*this)->getAttrs(),
3044  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
3045 }
3046 
3047 SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
3048 
// Returns the arguments of the body block as the loop induction variables.
3049 std::optional<SmallVector<Value>> ParallelOp::getLoopInductionVars() {
3050  return SmallVector<Value>{getBody()->getArguments()};
3051 }
3052 
// Returns the lower bound operands, converted to OpFoldResults by the return
// statement.
3053 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopLowerBounds() {
3054  return getLowerBound();
3055 }
3056 
// Returns the upper bound operands, converted to OpFoldResults by the return
// statement.
3057 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopUpperBounds() {
3058  return getUpperBound();
3059 }
3060 
// Returns the step operands, converted to OpFoldResults by the return
// statement.
3061 std::optional<SmallVector<OpFoldResult>> ParallelOp::getLoopSteps() {
3062  return getStep();
3063 }
3064 
// Maps a value to the `scf.parallel` op owning it as a block argument.
// Returns a null ParallelOp when `val` is not a block argument or when the
// operation owning its block is not an `scf.parallel`.
// NOTE(review): the enclosing signature line is not visible in this listing.
3066  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
3067  if (!ivArg)
3068  return ParallelOp();
3069  assert(ivArg.getOwner() && "unlinked block argument");
      // NOTE(review): `containingOp` is passed to dyn_cast unchecked — confirm
      // the owning block is always attached to an operation here.
3070  auto *containingOp = ivArg.getOwner()->getParentOp();
3071  return dyn_cast<ParallelOp>(containingOp);
3072 }
3073 
3074 namespace {
3075 // Collapse loop dimensions that perform a single iteration.
// 
// Additionally, a loop with any dimension of constant trip count zero is
// replaced by its init values. When every dimension has trip count one, the
// body and the reduction blocks are inlined; otherwise a lower-dimensional
// `scf.parallel` is created from the remaining dimensions.
3076 struct ParallelOpSingleOrZeroIterationDimsFolder
3077  : public OpRewritePattern<ParallelOp> {
3079 
3080  LogicalResult matchAndRewrite(ParallelOp op,
3081  PatternRewriter &rewriter) const override {
3082  Location loc = op.getLoc();
3083 
3084  // Compute new loop bounds that omit all single-iteration loop dimensions.
3085  SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
      // Maps the induction variable of each dropped dimension to its lower
      // bound.
3086  IRMapping mapping;
3087  for (auto [lb, ub, step, iv] :
3088  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
3089  op.getInductionVars())) {
3090  auto numIterations = constantTripCount(lb, ub, step);
3091  if (numIterations.has_value()) {
3092  // Remove the loop if it performs zero iterations.
3093  if (*numIterations == 0) {
3094  rewriter.replaceOp(op, op.getInitVals());
3095  return success();
3096  }
3097  // Replace the loop induction variable by the lower bound if the loop
3098  // performs a single iteration. Otherwise, copy the loop bounds.
3099  if (*numIterations == 1) {
3100  mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
3101  continue;
3102  }
3103  }
3104  newLowerBounds.push_back(lb);
3105  newUpperBounds.push_back(ub);
3106  newSteps.push_back(step);
3107  }
3108  // Exit if none of the loop dimensions perform a single iteration.
3109  if (newLowerBounds.size() == op.getLowerBound().size())
3110  return failure();
3111 
3112  if (newLowerBounds.empty()) {
3113  // All of the loop dimensions perform a single iteration. Inline
3114  // loop body and nested ReduceOp's
3115  SmallVector<Value> results;
3116  results.reserve(op.getInitVals().size());
3117  for (auto &bodyOp : op.getBody()->without_terminator())
3118  rewriter.clone(bodyOp, mapping);
3119  auto reduceOp = cast<ReduceOp>(op.getBody()->getTerminator());
3120  for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) {
3121  Block &reduceBlock = reduceOp.getReductions()[i].front();
      // `results` grows by one entry per reduction, so its size selects the
      // init value of the current reduction.
3122  auto initValIndex = results.size();
      // First reduction argument: the init value. Second argument: the
      // (remapped) operand of the reduce op.
3123  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
3124  mapping.map(reduceBlock.getArgument(1),
3125  mapping.lookupOrDefault(reduceOp.getOperands()[i]));
3126  for (auto &reduceBodyOp : reduceBlock.without_terminator())
3127  rewriter.clone(reduceBodyOp, mapping);
3128 
      // The value returned by the reduction block becomes the loop result.
3129  auto result = mapping.lookupOrDefault(
3130  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
3131  results.push_back(result);
3132  }
3133 
3134  rewriter.replaceOp(op, results);
3135  return success();
3136  }
3137  // Replace the parallel loop by lower-dimensional parallel loop.
3138  auto newOp =
3139  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
3140  newSteps, op.getInitVals(), nullptr);
3141  // Erase the empty block that was inserted by the builder.
3142  rewriter.eraseBlock(newOp.getBody());
3143  // Clone the loop body and remap the block arguments of the collapsed loops
3144  // (inlining does not support a cancellable block argument mapping).
3145  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
3146  newOp.getRegion().begin(), mapping);
3147  rewriter.replaceOp(op, newOp.getResults());
3148  return success();
3149  }
3150 };
3151 
// Merges an `scf.parallel` whose body contains, besides the terminator, only
// another `scf.parallel` into one loop over the concatenated bounds and steps.
// The pattern bails out when the inner bounds or steps use an induction
// variable of the outer loop, or when either loop has init values.
3152 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
3154 
3155  LogicalResult matchAndRewrite(ParallelOp op,
3156  PatternRewriter &rewriter) const override {
      // The outer body must hold exactly one op besides its terminator.
3157  Block &outerBody = *op.getBody();
3158  if (!llvm::hasSingleElement(outerBody.without_terminator()))
3159  return failure();
3160 
3161  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
3162  if (!innerOp)
3163  return failure();
3164 
      // Inner bounds/steps that use an outer induction variable cannot become
      // operands of the merged loop.
3165  for (auto val : outerBody.getArguments())
3166  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
3167  llvm::is_contained(innerOp.getUpperBound(), val) ||
3168  llvm::is_contained(innerOp.getStep(), val))
3169  return failure();
3170 
3171  // Reductions are not supported yet.
3172  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
3173  return failure();
3174 
      // Body builder of the merged loop: clones the inner body with the outer
      // induction variables mapped to the leading new ones and the inner
      // induction variables mapped to the trailing new ones.
3175  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
3176  ValueRange iterVals, ValueRange) {
3177  Block &innerBody = *innerOp.getBody();
3178  assert(iterVals.size() ==
3179  (outerBody.getNumArguments() + innerBody.getNumArguments()));
3180  IRMapping mapping;
3181  mapping.map(outerBody.getArguments(),
3182  iterVals.take_front(outerBody.getNumArguments()));
3183  mapping.map(innerBody.getArguments(),
3184  iterVals.take_back(innerBody.getNumArguments()));
3185  for (Operation &op : innerBody.without_terminator())
3186  builder.clone(op, mapping);
3187  };
3188 
      // Returns the values of `first` followed by the values of `second`.
3189  auto concatValues = [](const auto &first, const auto &second) {
3190  SmallVector<Value> ret;
3191  ret.reserve(first.size() + second.size());
3192  ret.assign(first.begin(), first.end());
3193  ret.append(second.begin(), second.end());
3194  return ret;
3195  };
3196 
3197  auto newLowerBounds =
3198  concatValues(op.getLowerBound(), innerOp.getLowerBound());
3199  auto newUpperBounds =
3200  concatValues(op.getUpperBound(), innerOp.getUpperBound());
3201  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
3202 
3203  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
3204  newSteps, std::nullopt,
3205  bodyBuilder);
3206  return success();
3207  }
3208 };
3209 
3210 } // namespace
3211 
3212 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
3213  MLIRContext *context) {
3214  results
3215  .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
3216  context);
3217 }
3218 
3219 /// Given the region at `index`, or the parent operation if `index` is None,
3220 /// return the successor regions. These are the regions that may be selected
3221 /// during the flow of control. `operands` is a set of optional attributes that
3222 /// correspond to a constant value for each operand, or null if that operand is
3223 /// not a constant.
/// 
/// NOTE(review): part of the signature is not visible in this listing, so the
/// parameter names mentioned above (`index`, `operands`) could not be checked
/// against it — confirm they are still current.
3224 void ParallelOp::getSuccessorRegions(
3226  // Both the operation itself and the region may be branching into the body or
3227  // back into the operation itself. It is possible for loop not to enter the
3228  // body.
      // Two successors are always reported: the body region and a
      // default-constructed successor.
3229  regions.push_back(RegionSuccessor(&getRegion()));
3230  regions.push_back(RegionSuccessor());
3231 }
3233 //===----------------------------------------------------------------------===//
3234 // ReduceOp
3235 //===----------------------------------------------------------------------===//
3236 
// Builds an `scf.reduce` without operands or regions: nothing is added to
// `result`.
3237 void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
3238 
3239 void ReduceOp::build(OpBuilder &builder, OperationState &result,
3240  ValueRange operands) {
3241  result.addOperands(operands);
3242  for (Value v : operands) {
3243  OpBuilder::InsertionGuard guard(builder);
3244  Region *bodyRegion = result.addRegion();
3245  builder.createBlock(bodyRegion, {},
3246  ArrayRef<Type>{v.getType(), v.getType()},
3247  {result.location, result.location});
3248  }
3249 }
3250 
3251 LogicalResult ReduceOp::verifyRegions() {
3252  // The region of a ReduceOp has two arguments of the same type as its
3253  // corresponding operand.
3254  for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {
3255  auto type = getOperands()[i].getType();
3256  Block &block = getReductions()[i].front();
3257  if (block.empty())
3258  return emitOpError() << i << "-th reduction has an empty body";
3259  if (block.getNumArguments() != 2 ||
3260  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
3261  return arg.getType() != type;
3262  }))
3263  return emitOpError() << "expected two block arguments with type " << type
3264  << " in the " << i << "-th reduction region";
3265 
3266  // Check that the block is terminated by a ReduceReturnOp.
3267  if (!isa<ReduceReturnOp>(block.getTerminator()))
3268  return emitOpError("reduction bodies must be terminated with an "
3269  "'scf.reduce.return' op");
3270  }
3271 
3272  return success();
3273 }
3274 
// Returns an empty operand range (start 0, length 0) of this operation.
// NOTE(review): the enclosing signature line is not visible in this listing.
3277  // No operands are forwarded to the next iteration.
3278  return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
3279 }
3280 
3281 //===----------------------------------------------------------------------===//
3282 // ReduceReturnOp
3283 //===----------------------------------------------------------------------===//
3284 
3285 LogicalResult ReduceReturnOp::verify() {
3286  // The type of the return value should be the same type as the types of the
3287  // block arguments of the reduction body.
3288  Block *reductionBody = getOperation()->getBlock();
3289  // Should already be verified by an op trait.
3290  assert(isa<ReduceOp>(reductionBody->getParentOp()) && "expected scf.reduce");
3291  Type expectedResultType = reductionBody->getArgument(0).getType();
3292  if (expectedResultType != getResult().getType())
3293  return emitOpError() << "must have type " << expectedResultType
3294  << " (the type of the reduction inputs)";
3295  return success();
3296 }
3297 
3298 //===----------------------------------------------------------------------===//
3299 // WhileOp
3300 //===----------------------------------------------------------------------===//
3301 
3302 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
3303  ::mlir::OperationState &odsState, TypeRange resultTypes,
3304  ValueRange inits, BodyBuilderFn beforeBuilder,
3305  BodyBuilderFn afterBuilder) {
3306  odsState.addOperands(inits);
3307  odsState.addTypes(resultTypes);
3308 
3309  OpBuilder::InsertionGuard guard(odsBuilder);
3310 
3311  // Build before region.
3312  SmallVector<Location, 4> beforeArgLocs;
3313  beforeArgLocs.reserve(inits.size());
3314  for (Value operand : inits) {
3315  beforeArgLocs.push_back(operand.getLoc());
3316  }
3317 
3318  Region *beforeRegion = odsState.addRegion();
3319  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
3320  inits.getTypes(), beforeArgLocs);
3321  if (beforeBuilder)
3322  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
3323 
3324  // Build after region.
3325  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
3326 
3327  Region *afterRegion = odsState.addRegion();
3328  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
3329  resultTypes, afterArgLocs);
3330 
3331  if (afterBuilder)
3332  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
3333 }
3334 
3335 ConditionOp WhileOp::getConditionOp() {
3336  return cast<ConditionOp>(getBeforeBody()->getTerminator());
3337 }
3338 
3339 YieldOp WhileOp::getYieldOp() {
3340  return cast<YieldOp>(getAfterBody()->getTerminator());
3341 }
3342 
3343 std::optional<MutableArrayRef<OpOperand>> WhileOp::getYieldedValuesMutable() {
3344  return getYieldOp().getResultsMutable();
3345 }
3346 
3347 Block::BlockArgListType WhileOp::getBeforeArguments() {
3348  return getBeforeBody()->getArguments();
3349 }
3350 
3351 Block::BlockArgListType WhileOp::getAfterArguments() {
3352  return getAfterBody()->getArguments();
3353 }
3354 
3355 Block::BlockArgListType WhileOp::getRegionIterArgs() {
3356  return getBeforeArguments();
3357 }
3358 
// Returns the operands forwarded when the op is entered: the init values.
// `point` is only inspected by the assertion, which requires it to be the
// "before" region.
3359 OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
3360  assert(point == getBefore() &&
3361  "WhileOp is expected to branch only to the first region");
3362  return getInits();
3363 }
3364 
// Appends to `regions` the possible successors of `point`: the "before"
// region when coming from the parent op or from the "after" region, and both
// the op results and the "after" region when coming from the "before" region.
// NOTE(review): the second signature line is not visible in this listing.
3365 void WhileOp::getSuccessorRegions(RegionBranchPoint point,
3367  // The parent op always branches to the condition region.
3368  if (point.isParent()) {
3369  regions.emplace_back(&getBefore(), getBefore().getArguments());
3370  return;
3371  }
3372 
3373  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
3374  "there are only two regions in a WhileOp");
3375  // The body region always branches back to the condition region.
3376  if (point == getAfter()) {
3377  regions.emplace_back(&getBefore(), getBefore().getArguments());
3378  return;
3379  }
3380 
      // Remaining case: `point` is the "before" region, which may either leave
      // the loop (op results) or continue into the "after" region.
3381  regions.emplace_back(getResults());
3382  regions.emplace_back(&getAfter(), getAfter().getArguments());
3383 }
3384 
3385 SmallVector<Region *> WhileOp::getLoopRegions() {
3386  return {&getBefore(), &getAfter()};
3387 }
3388 
3389 /// Parses a `while` op.
3390 ///
3391 /// op ::= `scf.while` assignments `:` function-type region `do` region
3392 /// `attributes` attribute-dict
3393 /// initializer ::= /* empty */ | `(` assignment-list `)`
3394 /// assignment-list ::= assignment | assignment `,` assignment-list
3395 /// assignment ::= ssa-value `=` ssa-value
/// 
/// NOTE(review): the declarations of `regionArgs` and `operands` and the end
/// of the final return expression are not visible in this listing — confirm
/// against the full file.
3396 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
      // Regions are added up front: "before" first, then "after".
3399  Region *before = result.addRegion();
3400  Region *after = result.addRegion();
3401 
      // The assignment list is optional; only a present-but-malformed list is
      // an error.
3402  OptionalParseResult listResult =
3403  parser.parseOptionalAssignmentList(regionArgs, operands);
3404  if (listResult.has_value() && failed(listResult.value()))
3405  return failure();
3406 
      // Remember where the type starts so a count mismatch is reported there.
3407  FunctionType functionType;
3408  SMLoc typeLoc = parser.getCurrentLocation();
3409  if (failed(parser.parseColonType(functionType)))
3410  return failure();
3411 
      // The function type's results are the op's result types.
3412  result.addTypes(functionType.getResults());
3413 
3414  if (functionType.getNumInputs() != operands.size()) {
3415  return parser.emitError(typeLoc)
3416  << "expected as many input types as operands "
3417  << "(expected " << operands.size() << " got "
3418  << functionType.getNumInputs() << ")";
3419  }
3420 
3421  // Resolve input operands.
3422  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3423  parser.getCurrentLocation(),
3424  result.operands)))
3425  return failure();
3426 
3427  // Propagate the types into the region arguments.
3428  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3429  regionArgs[i].type = functionType.getInput(i);
3430 
      // Parse the "before" region with the typed arguments, the `do` keyword
      // and the "after" region.
3431  return failure(parser.parseRegion(*before, regionArgs) ||
3432  parser.parseKeyword("do") || parser.parseRegion(*after) ||
3434 }
3435 
3436 /// Prints a `while` op.
3438  printInitializationList(p, getBeforeArguments(), getInits(), " ");
3439  p << " : ";
3440  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
3441  p << ' ';
3442  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
3443  p << " do ";
3444  p.printRegion(getAfter());
3445  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3446 }
3447 
3448 /// Verifies that two ranges of types match, i.e. have the same number of
3449 /// entries and that types are pairwise equals. Reports errors on the given
3450 /// operation in case of mismatch.
3451 template <typename OpTy>
3452 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
3453  TypeRange right, StringRef message) {
3454  if (left.size() != right.size())
3455  return op.emitOpError("expects the same number of ") << message;
3456 
3457  for (unsigned i = 0, e = left.size(); i < e; ++i) {
3458  if (left[i] != right[i]) {
3459  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
3460  << message;
3461  diag.attachNote() << "for argument " << i << ", found " << left[i]
3462  << " and " << right[i];
3463  return diag;
3464  }
3465  }
3466 
3467  return success();
3468 }
3469 
3470 LogicalResult scf::WhileOp::verify() {
3471  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
3472  *this, getBefore(),
3473  "expects the 'before' region to terminate with 'scf.condition'");
3474  if (!beforeTerminator)
3475  return failure();
3476 
3477  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
3478  *this, getAfter(),
3479  "expects the 'after' region to terminate with 'scf.yield'");
3480  return success(afterTerminator != nullptr);
3481 }
3482 
3483 namespace {
3484 /// Replace uses of the condition within the do block with true, since otherwise
3485 /// the block would not be evaluated.
3486 ///
3487 /// scf.while (..) : (i1, ...) -> ... {
3488 /// %condition = call @evaluate_condition() : () -> i1
3489 /// scf.condition(%condition) %condition : i1, ...
3490 /// } do {
3491 /// ^bb0(%arg0: i1, ...):
3492 /// use(%arg0)
3493 /// ...
3494 ///
3495 /// becomes
3496 /// scf.while (..) : (i1, ...) -> ... {
3497 /// %condition = call @evaluate_condition() : () -> i1
3498 /// scf.condition(%condition) %condition : i1, ...
3499 /// } do {
3500 /// ^bb0(%arg0: i1, ...):
3501 /// use(%true)
3502 /// ...
3503 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
3505 
3506  LogicalResult matchAndRewrite(WhileOp op,
3507  PatternRewriter &rewriter) const override {
3508  auto term = op.getConditionOp();
3509 
3510  // These variables serve to prevent creating duplicate constants
3511  // and hold constant true or false values.
3512  Value constantTrue = nullptr;
3513 
3514  bool replaced = false;
3515  for (auto yieldedAndBlockArgs :
3516  llvm::zip(term.getArgs(), op.getAfterArguments())) {
3517  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
3518  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
3519  if (!constantTrue)
3520  constantTrue = rewriter.create<arith::ConstantOp>(
3521  op.getLoc(), term.getCondition().getType(),
3522  rewriter.getBoolAttr(true));
3523 
3524  rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
3525  constantTrue);
3526  replaced = true;
3527  }
3528  }
3529  }
3530  return success(replaced);
3531  }
3532 };
3533 
3534 /// Remove loop invariant arguments from `before` block of scf.while.
3535 /// A before block argument is considered loop invariant if :-
3536 /// 1. i-th yield operand is equal to the i-th while operand.
3537 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
3538 /// condition operand AND this (k+1)-th condition operand is equal to i-th
3539 /// iter argument/while operand.
3540 /// For the arguments which are removed, their uses inside scf.while
3541 /// are replaced with their corresponding initial value.
3542 ///
3543 /// Eg:
3544 /// INPUT :-
3545 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
3546 /// ..., %argN_before = %N)
3547 /// {
3548 /// ...
3549 /// scf.condition(%cond) %arg1_before, %arg0_before,
3550 /// %arg2_before, %arg0_before, ...
3551 /// } do {
3552 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3553 /// ..., %argK_after):
3554 /// ...
3555 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
3556 /// }
3557 ///
3558 /// OUTPUT :-
3559 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
3560 /// %N)
3561 /// {
3562 /// ...
3563 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
3564 /// } do {
3565 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
3566 /// ..., %argK_after):
3567 /// ...
3568 /// scf.yield %arg1_after, ..., %argN
3569 /// }
3570 ///
3571 /// EXPLANATION:
3572 /// We iterate over each yield operand.
3573 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
3574 /// %arg0_before, which in turn is the 0-th iter argument. So we
3575 /// remove 0-th before block argument and yield operand, and replace
3576 /// all uses of the 0-th before block argument with its initial value
3577 /// %a.
3578 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
3579 /// value. So we remove this operand and the corresponding before
3580 /// block argument and replace all uses of 1-th before block argument
3581 /// with %b.
3582 struct RemoveLoopInvariantArgsFromBeforeBlock
3583  : public OpRewritePattern<WhileOp> {
3585 
3586  LogicalResult matchAndRewrite(WhileOp op,
3587  PatternRewriter &rewriter) const override {
3588  Block &afterBlock = *op.getAfterBody();
3589  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
3590  ConditionOp condOp = op.getConditionOp();
3591  OperandRange condOpArgs = condOp.getArgs();
3592  Operation *yieldOp = afterBlock.getTerminator();
3593  ValueRange yieldOpArgs = yieldOp->getOperands();
3594 
3595  bool canSimplify = false;
3596  for (const auto &it :
3597  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3598  auto index = static_cast<unsigned>(it.index());
3599  auto [initVal, yieldOpArg] = it.value();
3600  // If i-th yield operand is equal to the i-th operand of the scf.while,
3601  // the i-th before block argument is a loop invariant.
3602  if (yieldOpArg == initVal) {
3603  canSimplify = true;
3604  break;
3605  }
3606  // If the i-th yield operand is k-th after block argument, then we check
3607  // if the (k+1)-th condition op operand is equal to either the i-th before
3608  // block argument or the initial value of i-th before block argument. If
3609  // the comparison results `true`, i-th before block argument is a loop
3610  // invariant.
3611  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3612  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3613  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3614  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3615  canSimplify = true;
3616  break;
3617  }
3618  }
3619  }
3620 
3621  if (!canSimplify)
3622  return failure();
3623 
3624  SmallVector<Value> newInitArgs, newYieldOpArgs;
3625  DenseMap<unsigned, Value> beforeBlockInitValMap;
3626  SmallVector<Location> newBeforeBlockArgLocs;
3627  for (const auto &it :
3628  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3629  auto index = static_cast<unsigned>(it.index());
3630  auto [initVal, yieldOpArg] = it.value();
3631 
3632  // If i-th yield operand is equal to the i-th operand of the scf.while,
3633  // the i-th before block argument is a loop invariant.
3634  if (yieldOpArg == initVal) {
3635  beforeBlockInitValMap.insert({index, initVal});
3636  continue;
3637  } else {
3638  // If the i-th yield operand is k-th after block argument, then we check
3639  // if the (k+1)-th condition op operand is equal to either the i-th
3640  // before block argument or the initial value of i-th before block
3641  // argument. If the comparison results `true`, i-th before block
3642  // argument is a loop invariant.
3643  auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
3644  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3645  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3646  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3647  beforeBlockInitValMap.insert({index, initVal});
3648  continue;
3649  }
3650  }
3651  }
3652  newInitArgs.emplace_back(initVal);
3653  newYieldOpArgs.emplace_back(yieldOpArg);
3654  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3655  }
3656 
3657  {
3658  OpBuilder::InsertionGuard g(rewriter);
3659  rewriter.setInsertionPoint(yieldOp);
3660  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3661  }
3662 
3663  auto newWhile =
3664  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3665 
3666  Block &newBeforeBlock = *rewriter.createBlock(
3667  &newWhile.getBefore(), /*insertPt*/ {},
3668  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3669 
3670  Block &beforeBlock = *op.getBeforeBody();
3671  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3672  // For each i-th before block argument we find it's replacement value as :-
3673  // 1. If i-th before block argument is a loop invariant, we fetch it's
3674  // initial value from `beforeBlockInitValMap` by querying for key `i`.
3675  // 2. Else we fetch j-th new before block argument as the replacement
3676  // value of i-th before block argument.
3677  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3678  // If the index 'i' argument was a loop invariant we fetch it's initial
3679  // value from `beforeBlockInitValMap`.
3680  if (beforeBlockInitValMap.count(i) != 0)
3681  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3682  else
3683  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3684  }
3685 
3686  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3687  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3688  newWhile.getAfter().begin());
3689 
3690  rewriter.replaceOp(op, newWhile.getResults());
3691  return success();
3692  }
3693 };
3694 
3695 /// Remove loop invariant value from result (condition op) of scf.while.
3696 /// A value is considered loop invariant if the final value yielded by
3697 /// scf.condition is defined outside of the `before` block. We remove the
3698 /// corresponding argument in `after` block and replace the use with the value.
3699 /// We also replace the use of the corresponding result of scf.while with the
3700 /// value.
3701 ///
3702 /// Eg:
3703 /// INPUT :-
3704 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3705 /// %argN_before = %N) {
3706 /// ...
3707 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3708 /// } do {
3709 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3710 /// ...
3711 /// some_func(%arg1_after)
3712 /// ...
3713 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3714 /// }
3715 ///
3716 /// OUTPUT :-
3717 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3718 /// ...
3719 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3720 /// } do {
3721 /// ^bb0(%arg0, %arg3, ..., %argM):
3722 /// ...
3723 /// some_func(%a)
3724 /// ...
3725 /// scf.yield %arg0, %b, ..., %argN
3726 /// }
3727 ///
3728 /// EXPLANATION:
3729 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
3730 /// before block of scf.while, so they get removed.
3731 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3732 /// replaced by %b.
3733 /// 3. The corresponding after block argument %arg1_after's uses are
3734 /// replaced by %a and %arg2_after's uses are replaced by %b.
3735 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3737 
3738  LogicalResult matchAndRewrite(WhileOp op,
3739  PatternRewriter &rewriter) const override {
3740  Block &beforeBlock = *op.getBeforeBody();
3741  ConditionOp condOp = op.getConditionOp();
3742  OperandRange condOpArgs = condOp.getArgs();
3743 
3744  bool canSimplify = false;
3745  for (Value condOpArg : condOpArgs) {
3746  // Those values not defined within `before` block will be considered as
3747  // loop invariant values. We map the corresponding `index` with their
3748  // value.
3749  if (condOpArg.getParentBlock() != &beforeBlock) {
3750  canSimplify = true;
3751  break;
3752  }
3753  }
3754 
3755  if (!canSimplify)
3756  return failure();
3757 
3758  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3759 
3760  SmallVector<Value> newCondOpArgs;
3761  SmallVector<Type> newAfterBlockType;
3762  DenseMap<unsigned, Value> condOpInitValMap;
3763  SmallVector<Location> newAfterBlockArgLocs;
3764  for (const auto &it : llvm::enumerate(condOpArgs)) {
3765  auto index = static_cast<unsigned>(it.index());
3766  Value condOpArg = it.value();
3767  // Those values not defined within `before` block will be considered as
3768  // loop invariant values. We map the corresponding `index` with their
3769  // value.
3770  if (condOpArg.getParentBlock() != &beforeBlock) {
3771  condOpInitValMap.insert({index, condOpArg});
3772  } else {
3773  newCondOpArgs.emplace_back(condOpArg);
3774  newAfterBlockType.emplace_back(condOpArg.getType());
3775  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3776  }
3777  }
3778 
3779  {
3780  OpBuilder::InsertionGuard g(rewriter);
3781  rewriter.setInsertionPoint(condOp);
3782  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3783  newCondOpArgs);
3784  }
3785 
3786  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3787  op.getOperands());
3788 
3789  Block &newAfterBlock =
3790  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3791  newAfterBlockType, newAfterBlockArgLocs);
3792 
3793  Block &afterBlock = *op.getAfterBody();
3794  // Since a new scf.condition op was created, we need to fetch the new
3795  // `after` block arguments which will be used while replacing operations of
3796  // previous scf.while's `after` blocks. We'd also be fetching new result
3797  // values too.
3798  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3799  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3800  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3801  Value afterBlockArg, result;
3802  // If index 'i' argument was loop invariant we fetch it's value from the
3803  // `condOpInitMap` map.
3804  if (condOpInitValMap.count(i) != 0) {
3805  afterBlockArg = condOpInitValMap[i];
3806  result = afterBlockArg;
3807  } else {
3808  afterBlockArg = newAfterBlock.getArgument(j);
3809  result = newWhile.getResult(j);
3810  j++;
3811  }
3812  newAfterBlockArgs[i] = afterBlockArg;
3813  newWhileResults[i] = result;
3814  }
3815 
3816  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3817  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3818  newWhile.getBefore().begin());
3819 
3820  rewriter.replaceOp(op, newWhileResults);
3821  return success();
3822  }
3823 };
3824 
3825 /// Remove WhileOp results that are also unused in 'after' block.
3826 ///
3827 /// %0:2 = scf.while () : () -> (i32, i64) {
3828 /// %condition = "test.condition"() : () -> i1
3829 /// %v1 = "test.get_some_value"() : () -> i32
3830 /// %v2 = "test.get_some_value"() : () -> i64
3831 /// scf.condition(%condition) %v1, %v2 : i32, i64
3832 /// } do {
3833 /// ^bb0(%arg0: i32, %arg1: i64):
3834 /// "test.use"(%arg0) : (i32) -> ()
3835 /// scf.yield
3836 /// }
3837 /// return %0#0 : i32
3838 ///
3839 /// becomes
3840 /// %0 = scf.while () : () -> (i32) {
3841 /// %condition = "test.condition"() : () -> i1
3842 /// %v1 = "test.get_some_value"() : () -> i32
3843 /// %v2 = "test.get_some_value"() : () -> i64
3844 /// scf.condition(%condition) %v1 : i32
3845 /// } do {
3846 /// ^bb0(%arg0: i32):
3847 /// "test.use"(%arg0) : (i32) -> ()
3848 /// scf.yield
3849 /// }
3850 /// return %0 : i32
3851 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3853 
3854  LogicalResult matchAndRewrite(WhileOp op,
3855  PatternRewriter &rewriter) const override {
3856  auto term = op.getConditionOp();
3857  auto afterArgs = op.getAfterArguments();
3858  auto termArgs = term.getArgs();
3859 
3860  // Collect results mapping, new terminator args and new result types.
3861  SmallVector<unsigned> newResultsIndices;
3862  SmallVector<Type> newResultTypes;
3863  SmallVector<Value> newTermArgs;
3864  SmallVector<Location> newArgLocs;
3865  bool needUpdate = false;
3866  for (const auto &it :
3867  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3868  auto i = static_cast<unsigned>(it.index());
3869  Value result = std::get<0>(it.value());
3870  Value afterArg = std::get<1>(it.value());
3871  Value termArg = std::get<2>(it.value());
3872  if (result.use_empty() && afterArg.use_empty()) {
3873  needUpdate = true;
3874  } else {
3875  newResultsIndices.emplace_back(i);
3876  newTermArgs.emplace_back(termArg);
3877  newResultTypes.emplace_back(result.getType());
3878  newArgLocs.emplace_back(result.getLoc());
3879  }
3880  }
3881 
3882  if (!needUpdate)
3883  return failure();
3884 
3885  {
3886  OpBuilder::InsertionGuard g(rewriter);
3887  rewriter.setInsertionPoint(term);
3888  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3889  newTermArgs);
3890  }
3891 
3892  auto newWhile =
3893  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3894 
3895  Block &newAfterBlock = *rewriter.createBlock(
3896  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3897 
3898  // Build new results list and new after block args (unused entries will be
3899  // null).
3900  SmallVector<Value> newResults(op.getNumResults());
3901  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3902  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3903  newResults[it.value()] = newWhile.getResult(it.index());
3904  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3905  }
3906 
3907  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3908  newWhile.getBefore().begin());
3909 
3910  Block &afterBlock = *op.getAfterBody();
3911  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3912 
3913  rewriter.replaceOp(op, newResults);
3914  return success();
3915  }
3916 };
3917 
3918 /// Replace operations equivalent to the condition in the do block with true,
3919 /// since otherwise the block would not be evaluated.
3920 ///
3921 /// scf.while (..) : (i32, ...) -> ... {
3922 /// %z = ... : i32
3923 /// %condition = cmpi pred %z, %a
3924 /// scf.condition(%condition) %z : i32, ...
3925 /// } do {
3926 /// ^bb0(%arg0: i32, ...):
3927 /// %condition2 = cmpi pred %arg0, %a
3928 /// use(%condition2)
3929 /// ...
3930 ///
3931 /// becomes
3932 /// scf.while (..) : (i32, ...) -> ... {
3933 /// %z = ... : i32
3934 /// %condition = cmpi pred %z, %a
3935 /// scf.condition(%condition) %z : i32, ...
3936 /// } do {
3937 /// ^bb0(%arg0: i32, ...):
3938 /// use(%true)
3939 /// ...
3940 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3942 
3943  LogicalResult matchAndRewrite(scf::WhileOp op,
3944  PatternRewriter &rewriter) const override {
3945  using namespace scf;
3946  auto cond = op.getConditionOp();
3947  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3948  if (!cmp)
3949  return failure();
3950  bool changed = false;
3951  for (auto tup : llvm::zip(cond.getArgs(), op.getAfterArguments())) {
3952  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3953  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3954  continue;
3955  for (OpOperand &u :
3956  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3957  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3958  if (!cmp2)
3959  continue;
3960  // For a binary operator 1-opIdx gets the other side.
3961  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3962  continue;
3963  bool samePredicate;
3964  if (cmp2.getPredicate() == cmp.getPredicate())
3965  samePredicate = true;
3966  else if (cmp2.getPredicate() ==
3967  arith::invertPredicate(cmp.getPredicate()))
3968  samePredicate = false;
3969  else
3970  continue;
3971 
3972  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3973  1);
3974  changed = true;
3975  }
3976  }
3977  }
3978  return success(changed);
3979  }
3980 };
3981 
3982 /// Remove unused init/yield args.
3983 struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
3985 
3986  LogicalResult matchAndRewrite(WhileOp op,
3987  PatternRewriter &rewriter) const override {
3988 
3989  if (!llvm::any_of(op.getBeforeArguments(),
3990  [](Value arg) { return arg.use_empty(); }))
3991  return rewriter.notifyMatchFailure(op, "No args to remove");
3992 
3993  YieldOp yield = op.getYieldOp();
3994 
3995  // Collect results mapping, new terminator args and new result types.
3996  SmallVector<Value> newYields;
3997  SmallVector<Value> newInits;
3998  llvm::BitVector argsToErase;
3999 
4000  size_t argsCount = op.getBeforeArguments().size();
4001  newYields.reserve(argsCount);
4002  newInits.reserve(argsCount);
4003  argsToErase.reserve(argsCount);
4004  for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
4005  op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
4006  if (beforeArg.use_empty()) {
4007  argsToErase.push_back(true);
4008  } else {
4009  argsToErase.push_back(false);
4010  newYields.emplace_back(yieldValue);
4011  newInits.emplace_back(initValue);
4012  }
4013  }
4014 
4015  Block &beforeBlock = *op.getBeforeBody();
4016  Block &afterBlock = *op.getAfterBody();
4017 
4018  beforeBlock.eraseArguments(argsToErase);
4019 
4020  Location loc = op.getLoc();
4021  auto newWhileOp =
4022  rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits,
4023  /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
4024  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4025  Block &newAfterBlock = *newWhileOp.getAfterBody();
4026 
4027  OpBuilder::InsertionGuard g(rewriter);
4028  rewriter.setInsertionPoint(yield);
4029  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
4030 
4031  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4032  newBeforeBlock.getArguments());
4033  rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
4034  newAfterBlock.getArguments());
4035 
4036  rewriter.replaceOp(op, newWhileOp.getResults());
4037  return success();
4038  }
4039 };
4040 
4041 /// Remove duplicated ConditionOp args.
4042 struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
4044 
4045  LogicalResult matchAndRewrite(WhileOp op,
4046  PatternRewriter &rewriter) const override {
4047  ConditionOp condOp = op.getConditionOp();
4048  ValueRange condOpArgs = condOp.getArgs();
4049 
4051  for (Value arg : condOpArgs)
4052  argsSet.insert(arg);
4053 
4054  if (argsSet.size() == condOpArgs.size())
4055  return rewriter.notifyMatchFailure(op, "No results to remove");
4056 
4057  llvm::SmallDenseMap<Value, unsigned> argsMap;
4058  SmallVector<Value> newArgs;
4059  argsMap.reserve(condOpArgs.size());
4060  newArgs.reserve(condOpArgs.size());
4061  for (Value arg : condOpArgs) {
4062  if (!argsMap.count(arg)) {
4063  auto pos = static_cast<unsigned>(argsMap.size());
4064  argsMap.insert({arg, pos});
4065  newArgs.emplace_back(arg);
4066  }
4067  }
4068 
4069  ValueRange argsRange(newArgs);
4070 
4071  Location loc = op.getLoc();
4072  auto newWhileOp = rewriter.create<scf::WhileOp>(
4073  loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
4074  /*afterBody*/ nullptr);
4075  Block &newBeforeBlock = *newWhileOp.getBeforeBody();
4076  Block &newAfterBlock = *newWhileOp.getAfterBody();
4077 
4078  SmallVector<Value> afterArgsMapping;
4079  SmallVector<Value> resultsMapping;
4080  for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
4081  auto it = argsMap.find(arg);
4082  assert(it != argsMap.end());
4083  auto pos = it->second;
4084  afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
4085  resultsMapping.emplace_back(newWhileOp->getResult(pos));
4086  }
4087 
4088  OpBuilder::InsertionGuard g(rewriter);
4089  rewriter.setInsertionPoint(condOp);
4090  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
4091  argsRange);
4092 
4093  Block &beforeBlock = *op.getBeforeBody();
4094  Block &afterBlock = *op.getAfterBody();
4095 
4096  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
4097  newBeforeBlock.getArguments());
4098  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
4099  rewriter.replaceOp(op, resultsMapping);
4100  return success();
4101  }
4102 };
4103 
4104 /// If both ranges contain same values return mappping indices from args2 to
4105 /// args1. Otherwise return std::nullopt.
4106 static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
4107  ValueRange args2) {
4108  if (args1.size() != args2.size())
4109  return std::nullopt;
4110 
4111  SmallVector<unsigned> ret(args1.size());
4112  for (auto &&[i, arg1] : llvm::enumerate(args1)) {
4113  auto it = llvm::find(args2, arg1);
4114  if (it == args2.end())
4115  return std::nullopt;
4116 
4117  ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
4118  }
4119 
4120  return ret;
4121 }
4122 
4123 static bool hasDuplicates(ValueRange args) {
4124  llvm::SmallDenseSet<Value> set;
4125  for (Value arg : args) {
4126  if (!set.insert(arg).second)
4127  return true;
4128  }
4129  return false;
4130 }
4131 
4132 /// If `before` block args are directly forwarded to `scf.condition`, rearrange
4133 /// `scf.condition` args into same order as block args. Update `after` block
4134 /// args and op result values accordingly.
4135 /// Needed to simplify `scf.while` -> `scf.for` uplifting.
4136 struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
4138 
4139  LogicalResult matchAndRewrite(WhileOp loop,
4140  PatternRewriter &rewriter) const override {
4141  auto oldBefore = loop.getBeforeBody();
4142  ConditionOp oldTerm = loop.getConditionOp();
4143  ValueRange beforeArgs = oldBefore->getArguments();
4144  ValueRange termArgs = oldTerm.getArgs();
4145  if (beforeArgs == termArgs)
4146  return failure();
4147 
4148  if (hasDuplicates(termArgs))
4149  return failure();
4150 
4151  auto mapping = getArgsMapping(beforeArgs, termArgs);
4152  if (!mapping)
4153  return failure();
4154 
4155  {
4156  OpBuilder::InsertionGuard g(rewriter);
4157  rewriter.setInsertionPoint(oldTerm);
4158  rewriter.replaceOpWithNewOp<ConditionOp>(oldTerm, oldTerm.getCondition(),
4159  beforeArgs);
4160  }
4161 
4162  auto oldAfter = loop.getAfterBody();
4163 
4164  SmallVector<Type> newResultTypes(beforeArgs.size());
4165  for (auto &&[i, j] : llvm::enumerate(*mapping))
4166  newResultTypes[j] = loop.getResult(i).getType();
4167 
4168  auto newLoop = rewriter.create<WhileOp>(
4169  loop.getLoc(), newResultTypes, loop.getInits(),
4170  /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr);
4171  auto newBefore = newLoop.getBeforeBody();
4172  auto newAfter = newLoop.getAfterBody();
4173 
4174  SmallVector<Value> newResults(beforeArgs.size());
4175  SmallVector<Value> newAfterArgs(beforeArgs.size());
4176  for (auto &&[i, j] : llvm::enumerate(*mapping)) {
4177  newResults[i] = newLoop.getResult(j);
4178  newAfterArgs[i] = newAfter->getArgument(j);
4179  }
4180 
4181  rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
4182  newBefore->getArguments());
4183  rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
4184  newAfterArgs);
4185 
4186  rewriter.replaceOp(loop, newResults);
4187  return success();
4188  }
4189 };
4190 } // namespace
4191 
4192 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
4193  MLIRContext *context) {
4194  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
4195  RemoveLoopInvariantValueYielded, WhileConditionTruth,
4196  WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4197  WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4198 }
4199 
4200 //===----------------------------------------------------------------------===//
4201 // IndexSwitchOp
4202 //===----------------------------------------------------------------------===//
4203 
4204 /// Parse the case regions and values.
4205 static ParseResult
4207  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
4208  SmallVector<int64_t> caseValues;
4209  while (succeeded(p.parseOptionalKeyword("case"))) {
4210  int64_t value;
4211  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
4212  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
4213  return failure();
4214  caseValues.push_back(value);
4215  }
4216  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
4217  return success();
4218 }
4219 
4220 /// Print the case regions and values.
4222  DenseI64ArrayAttr cases, RegionRange caseRegions) {
4223  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
4224  p.printNewline();
4225  p << "case " << value << ' ';
4226  p.printRegion(*region, /*printEntryBlockArgs=*/false);
4227  }
4228 }
4229 
4230 LogicalResult scf::IndexSwitchOp::verify() {
4231  if (getCases().size() != getCaseRegions().size()) {
4232  return emitOpError("has ")
4233  << getCaseRegions().size() << " case regions but "
4234  << getCases().size() << " case values";
4235  }
4236 
4237  DenseSet<int64_t> valueSet;
4238  for (int64_t value : getCases())
4239  if (!valueSet.insert(value).second)
4240  return emitOpError("has duplicate case value: ") << value;
4241  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
4242  auto yield = dyn_cast<YieldOp>(region.front().back());
4243  if (!yield)
4244  return emitOpError("expected region to end with scf.yield, but got ")
4245  << region.front().back().getName();
4246 
4247  if (yield.getNumOperands() != getNumResults()) {
4248  return (emitOpError("expected each region to return ")
4249  << getNumResults() << " values, but " << name << " returns "
4250  << yield.getNumOperands())
4251  .attachNote(yield.getLoc())
4252  << "see yield operation here";
4253  }
4254  for (auto [idx, result, operand] :
4255  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
4256  yield.getOperandTypes())) {
4257  if (result == operand)
4258  continue;
4259  return (emitOpError("expected result #")
4260  << idx << " of each region to be " << result)
4261  .attachNote(yield.getLoc())
4262  << name << " returns " << operand << " here";
4263  }
4264  return success();
4265  };
4266 
4267  if (failed(verifyRegion(getDefaultRegion(), "default region")))
4268  return failure();
4269  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
4270  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
4271  return failure();
4272 
4273  return success();
4274 }
4275 
4276 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
4277 
4278 Block &scf::IndexSwitchOp::getDefaultBlock() {
4279  return getDefaultRegion().front();
4280 }
4281 
4282 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
4283  assert(idx < getNumCases() && "case index out-of-bounds");
4284  return getCaseRegions()[idx].front();
4285 }
4286 
4287 void IndexSwitchOp::getSuccessorRegions(
4289  // All regions branch back to the parent op.
4290  if (!point.isParent()) {
4291  successors.emplace_back(getResults());
4292  return;
4293  }
4294 
4295  llvm::copy(getRegions(), std::back_inserter(successors));
4296 }
4297 
4298 void IndexSwitchOp::getEntrySuccessorRegions(
4299  ArrayRef<Attribute> operands,
4300  SmallVectorImpl<RegionSuccessor> &successors) {
4301  FoldAdaptor adaptor(operands, *this);
4302 
4303  // If a constant was not provided, all regions are possible successors.
4304  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
4305  if (!arg) {
4306  llvm::copy(getRegions(), std::back_inserter(successors));
4307  return;
4308  }
4309 
4310  // Otherwise, try to find a case with a matching value. If not, the
4311  // default region is the only successor.
4312  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
4313  if (caseValue == arg.getInt()) {
4314  successors.emplace_back(&caseRegion);
4315  return;
4316  }
4317  }
4318  successors.emplace_back(&getDefaultRegion());
4319 }
4320 
4321 void IndexSwitchOp::getRegionInvocationBounds(
4323  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
4324  if (!operandValue) {
4325  // All regions are invoked at most once.
4326  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
4327  return;
4328  }
4329 
4330  unsigned liveIndex = getNumRegions() - 1;
4331  const auto *it = llvm::find(getCases(), operandValue.getInt());
4332  if (it != getCases().end())
4333  liveIndex = std::distance(getCases().begin(), it);
4334  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
4335  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
4336 }
4337 
4338 struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
4340 
4341  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
4342  PatternRewriter &rewriter) const override {
4343  // If `op.getArg()` is a constant, select the region that matches with
4344  // the constant value. Use the default region if no matche is found.
4345  std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
4346  if (!maybeCst.has_value())
4347  return failure();
4348  int64_t cst = *maybeCst;
4349  int64_t caseIdx, e = op.getNumCases();
4350  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
4351  if (cst == op.getCases()[caseIdx])
4352  break;
4353  }
4354 
4355  Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
4356  : op.getDefaultRegion();
4357  Block &source = r.front();
4358  Operation *terminator = source.getTerminator();
4359  SmallVector<Value> results = terminator->getOperands();
4360 
4361  rewriter.inlineBlockBefore(&source, op);
4362  rewriter.eraseOp(terminator);
4363  // Replace the operation with a potentially empty list of results.
4364  // Fold mechanism doesn't support the case where the result list is empty.
4365  rewriter.replaceOp(op, results);
4366 
4367  return success();
4368  }
4369 };
4370 
4371 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4372  MLIRContext *context) {
4373  results.add<FoldConstantCase>(context);
4374 }
4375 
4376 //===----------------------------------------------------------------------===//
4377 // TableGen'd op method definitions
4378 //===----------------------------------------------------------------------===//
4379 
4380 #define GET_OP_CLASSES
4381 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:726
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:718
static MutableOperandRange getMutableSuccessorOperands(Block *block, unsigned successorIndex)
Returns the mutable operand range used to transfer operands from block to its successor with the give...
Definition: CFGToSCF.cpp:133
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:112
static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: SCF.cpp:4206
static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, TypeRange right, StringRef message)
Verifies that two ranges of types match, i.e.
Definition: SCF.cpp:3452
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition: SCF.cpp:431
static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region, StringRef errorMessage)
Verifies that the first block of the given region is terminated by a TerminatorTy.
Definition: SCF.cpp:92
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: SCF.cpp:4221
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
This class represents an argument of a Block.
Definition: Value.h:319
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:56
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:571
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:582
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
iterator begin()
Definition: Region.h:55
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:620
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:45
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
LogicalResult promoteIfSingleIteration(AffineForOp forOp)
Promotes the loop body of a AffineForOp to its containing block if the loop was known to have a singl...
Definition: LoopUtils.cpp:118
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition: ArithOps.cpp:76
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
StringRef getMappingAttrName()
Name of the mapping attribute produced by loop mappers.
auto m_Val(Value v)
Definition: Matchers.h:539
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:3065
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:85
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:687
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp.
Definition: SCF.cpp:1992
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition: SCF.cpp:644
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
ForallOp getForallOpThreadIndexOwner(Value val)
Returns the ForallOp parent of a thread index variable.
Definition: SCF.cpp:1466
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:64
SmallVector< Value > replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, OpOperand &operand, Value replacement, const ValueTypeCastFnTy &castFn)
Definition: SCF.cpp:776
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:269
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::optional< int64_t > constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Return the number of iterations for a loop with a lower bound lb, upper bound ub and step step.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayRef< int64_t > attr, ValueRange values)
Verify that the values have as many elements as the number of entries in attr for which isDynamic ev...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:4341
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:233
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:184
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.