1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // SCFDialect Dialect Interfaces
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 struct SCFInlinerInterface : public DialectInlinerInterface {
38  // We don't have any special restrictions on what can be inlined into
39  // destination regions (e.g. while/conditional bodies). Always allow it.
40  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
41  BlockAndValueMapping &valueMapping) const final {
42  return true;
43  }
44  // Operations in scf dialect are always legal to inline since they are
45  // pure.
46  bool isLegalToInline(Operation *, Region *, bool,
47  BlockAndValueMapping &) const final {
48  return true;
49  }
50  // Handle the given inlined terminator by replacing it with a new operation
51  // as necessary. Required when the region has only one block.
52  void handleTerminator(Operation *op,
53  ArrayRef<Value> valuesToRepl) const final {
54  auto retValOp = dyn_cast<scf::YieldOp>(op);
55  if (!retValOp)
56  return;
57 
58  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
59  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
60  }
61  }
62 };
63 } // namespace
64 
65 //===----------------------------------------------------------------------===//
66 // SCFDialect
67 //===----------------------------------------------------------------------===//
68 
69 void SCFDialect::initialize() {
70  addOperations<
71 #define GET_OP_LIST
72 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
73  >();
74  addInterfaces<SCFInlinerInterface>();
75 }
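// Illustrative usage sketch (not part of the upstream file): clients normally
// load the dialect into a context before creating scf ops, for example:
//
//   mlir::MLIRContext context;
//   context.loadDialect<mlir::scf::SCFDialect>();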
76 
77 /// Default callback for IfOp builders. Inserts a yield without arguments.
78 static void buildTerminatedBody(OpBuilder &builder, Location loc) {
79  builder.create<scf::YieldOp>(loc);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // ExecuteRegionOp
84 //===----------------------------------------------------------------------===//
85 
86 /// Replaces the given op with the contents of the given single-block region,
87 /// using the operands of the block terminator to replace operation results.
88 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
89  Region &region, ValueRange blockArgs = {}) {
90  assert(llvm::hasSingleElement(region) && "expected single-block region");
91  Block *block = &region.front();
92  Operation *terminator = block->getTerminator();
93  ValueRange results = terminator->getOperands();
94  rewriter.mergeBlockBefore(block, op, blockArgs);
95  rewriter.replaceOp(op, results);
96  rewriter.eraseOp(terminator);
97 }
98 
99 ///
100 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
101 /// block+
102 /// `}`
103 ///
104 /// Example:
105 /// scf.execute_region -> i32 {
106 /// %idx = load %rI[%i] : memref<128xi32>
107 /// return %idx : i32
108 /// }
109 ///
110 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
111  OperationState &result) {
112  if (parser.parseOptionalArrowTypeList(result.types))
113  return failure();
114 
115  // Introduce the body region and parse it.
116  Region *body = result.addRegion();
117  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
118  parser.parseOptionalAttrDict(result.attributes))
119  return failure();
120 
121  return success();
122 }
123 
124 void ExecuteRegionOp::print(OpAsmPrinter &p) {
125  p.printOptionalArrowTypeList(getResultTypes());
126 
127  p << ' ';
128  p.printRegion(getRegion(),
129  /*printEntryBlockArgs=*/false,
130  /*printBlockTerminators=*/true);
131 
132  p.printOptionalAttrDict((*this)->getAttrs());
133 }
134 
135 LogicalResult ExecuteRegionOp::verify() {
136  if (getRegion().empty())
137  return emitOpError("region needs to have at least one block");
138  if (getRegion().front().getNumArguments() > 0)
139  return emitOpError("region cannot have any arguments");
140  return success();
141 }
142 
143 // Inline an ExecuteRegionOp if it only contains one block.
144 // "test.foo"() : () -> ()
145 // %v = scf.execute_region -> i64 {
146 // %x = "test.val"() : () -> i64
147 // scf.yield %x : i64
148 // }
149 // "test.bar"(%v) : (i64) -> ()
150 //
151 // becomes
152 //
153 // "test.foo"() : () -> ()
154 // %x = "test.val"() : () -> i64
155 // "test.bar"(%x) : (i64) -> ()
156 //
157 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
158  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
159 
160  LogicalResult matchAndRewrite(ExecuteRegionOp op,
161  PatternRewriter &rewriter) const override {
162  if (!llvm::hasSingleElement(op.getRegion()))
163  return failure();
164  replaceOpWithRegion(rewriter, op, op.getRegion());
165  return success();
166  }
167 };
168 
169 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
170 // TODO generalize the conditions for operations which can be inlined into.
171 // func @func_execute_region_elim() {
172 // "test.foo"() : () -> ()
173 // %v = scf.execute_region -> i64 {
174 // %c = "test.cmp"() : () -> i1
175 // cf.cond_br %c, ^bb2, ^bb3
176 // ^bb2:
177 // %x = "test.val1"() : () -> i64
178 // cf.br ^bb4(%x : i64)
179 // ^bb3:
180 // %y = "test.val2"() : () -> i64
181 // cf.br ^bb4(%y : i64)
182 // ^bb4(%z : i64):
183 // scf.yield %z : i64
184 // }
185 // "test.bar"(%v) : (i64) -> ()
186 // return
187 // }
188 //
189 // becomes
190 //
191 // func @func_execute_region_elim() {
192 // "test.foo"() : () -> ()
193 // %c = "test.cmp"() : () -> i1
194 // cf.cond_br %c, ^bb1, ^bb2
195 // ^bb1: // pred: ^bb0
196 // %x = "test.val1"() : () -> i64
197 // cf.br ^bb3(%x : i64)
198 // ^bb2: // pred: ^bb0
199 // %y = "test.val2"() : () -> i64
200 // cf.br ^bb3(%y : i64)
201 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
202 // "test.bar"(%z) : (i64) -> ()
203 // return
204 // }
205 //
206 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
207  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
208 
209  LogicalResult matchAndRewrite(ExecuteRegionOp op,
210  PatternRewriter &rewriter) const override {
211  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
212  return failure();
213 
214  Block *prevBlock = op->getBlock();
215  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
216  rewriter.setInsertionPointToEnd(prevBlock);
217 
218  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
219 
220  for (Block &blk : op.getRegion()) {
221  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
222  rewriter.setInsertionPoint(yieldOp);
223  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
224  yieldOp.getResults());
225  rewriter.eraseOp(yieldOp);
226  }
227  }
228 
229  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
230  SmallVector<Value> blockArgs;
231 
232  for (auto res : op.getResults())
233  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
234 
235  rewriter.replaceOp(op, blockArgs);
236  return success();
237  }
238 };
239 
240 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
241  MLIRContext *context) {
242  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
243 }
244 
245 /// Given the region at `index`, or the parent operation if `index` is None,
246 /// return the successor regions. These are the regions that may be selected
247 /// during the flow of control. `operands` is a set of optional attributes that
248 /// correspond to a constant value for each operand, or null if that operand is
249 /// not a constant.
250 void ExecuteRegionOp::getSuccessorRegions(
251  Optional<unsigned> index, ArrayRef<Attribute> operands,
252  SmallVectorImpl<RegionSuccessor> &regions) {
253  // If the predecessor is the ExecuteRegionOp, branch into the body.
254  if (!index) {
255  regions.push_back(RegionSuccessor(&getRegion()));
256  return;
257  }
258 
259  // Otherwise, the region branches back to the parent operation.
260  regions.push_back(RegionSuccessor(getResults()));
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // ConditionOp
265 //===----------------------------------------------------------------------===//
266 
267 MutableOperandRange
268 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
269  // Pass all operands except the condition to the successor region.
270  return getArgsMutable();
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // ForOp
275 //===----------------------------------------------------------------------===//
276 
277 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
278  Value ub, Value step, ValueRange iterArgs,
279  BodyBuilderFn bodyBuilder) {
280  result.addOperands({lb, ub, step});
281  result.addOperands(iterArgs);
282  for (Value v : iterArgs)
283  result.addTypes(v.getType());
284  Region *bodyRegion = result.addRegion();
285  bodyRegion->push_back(new Block);
286  Block &bodyBlock = bodyRegion->front();
287  bodyBlock.addArgument(builder.getIndexType(), result.location);
288  for (Value v : iterArgs)
289  bodyBlock.addArgument(v.getType(), v.getLoc());
290 
291  // Create the default terminator if the builder is not provided and if the
292  // iteration arguments are not provided. Otherwise, leave this to the caller
293  // because we don't know which values to return from the loop.
294  if (iterArgs.empty() && !bodyBuilder) {
295  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
296  } else if (bodyBuilder) {
297  OpBuilder::InsertionGuard guard(builder);
298  builder.setInsertionPointToStart(&bodyBlock);
299  bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
300  bodyBlock.getArguments().drop_front());
301  }
302 }
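// Illustrative usage sketch (hypothetical values, not part of the upstream
// file): building a loop that accumulates one iter_arg via the body-builder
// callback, assuming `b`, `loc`, `lb`, `ub`, `step`, `init` and `operand`
// already exist and the arith dialect is loaded:
//
//   auto loop = b.create<scf::ForOp>(
//       loc, lb, ub, step, ValueRange{init},
//       [&](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
//         Value sum = nested.create<arith::AddFOp>(nestedLoc, args[0], operand);
//         nested.create<scf::YieldOp>(nestedLoc, sum);
//       });
//   Value result = loop.getResult(0);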
303 
304 LogicalResult ForOp::verify() {
305  IntegerAttr step;
306  if (matchPattern(getStep(), m_Constant(&step)) && step.getInt() <= 0)
307  return emitOpError("constant step operand must be positive");
308 
309  auto opNumResults = getNumResults();
310  if (opNumResults == 0)
311  return success();
312  // If ForOp defines values, check that the number and types of
313  // the defined values match ForOp initial iter operands and backedge
314  // basic block arguments.
315  if (getNumIterOperands() != opNumResults)
316  return emitOpError(
317  "mismatch in number of loop-carried values and defined values");
318  return success();
319 }
320 
321 LogicalResult ForOp::verifyRegions() {
322  // Check that the body defines a single block argument for the induction
323  // variable.
324  auto *body = getBody();
325  if (!body->getArgument(0).getType().isIndex())
326  return emitOpError(
327  "expected body first argument to be an index argument for "
328  "the induction variable");
329 
330  auto opNumResults = getNumResults();
331  if (opNumResults == 0)
332  return success();
333 
334  if (getNumRegionIterArgs() != opNumResults)
335  return emitOpError(
336  "mismatch in number of basic block args and defined values");
337 
338  auto iterOperands = getIterOperands();
339  auto iterArgs = getRegionIterArgs();
340  auto opResults = getResults();
341  unsigned i = 0;
342  for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
343  if (std::get<0>(e).getType() != std::get<2>(e).getType())
344  return emitOpError() << "types mismatch between " << i
345  << "th iter operand and defined value";
346  if (std::get<1>(e).getType() != std::get<2>(e).getType())
347  return emitOpError() << "types mismatch between " << i
348  << "th iter region arg and defined value";
349 
350  i++;
351  }
352  return success();
353 }
354 
355 Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }
356 
357 Optional<OpFoldResult> ForOp::getSingleLowerBound() {
358  return OpFoldResult(getLowerBound());
359 }
360 
361 Optional<OpFoldResult> ForOp::getSingleStep() {
362  return OpFoldResult(getStep());
363 }
364 
365 Optional<OpFoldResult> ForOp::getSingleUpperBound() {
366  return OpFoldResult(getUpperBound());
367 }
368 
369 /// Prints the initialization list in the form of
370 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
371 /// where 'inner' values are assumed to be region arguments and 'outer' values
372 /// are regular SSA values.
373 static void printInitializationList(OpAsmPrinter &p,
374  Block::BlockArgListType blocksArgs,
375  ValueRange initializers,
376  StringRef prefix = "") {
377  assert(blocksArgs.size() == initializers.size() &&
378  "expected same length of arguments and initializers");
379  if (initializers.empty())
380  return;
381 
382  p << prefix << '(';
383  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
384  p << std::get<0>(it) << " = " << std::get<1>(it);
385  });
386  p << ")";
387 }
388 
389 void ForOp::print(OpAsmPrinter &p) {
390  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
391  << getUpperBound() << " step " << getStep();
392 
393  printInitializationList(p, getRegionIterArgs(), getIterOperands(),
394  " iter_args");
395  if (!getIterOperands().empty())
396  p << " -> (" << getIterOperands().getTypes() << ')';
397  p << ' ';
398  p.printRegion(getRegion(),
399  /*printEntryBlockArgs=*/false,
400  /*printBlockTerminators=*/hasIterOperands());
401  p.printOptionalAttrDict((*this)->getAttrs());
402 }
403 
404 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
405  auto &builder = parser.getBuilder();
406  Type indexType = builder.getIndexType();
407 
408  OpAsmParser::Argument inductionVariable;
409  inductionVariable.type = indexType;
410  OpAsmParser::UnresolvedOperand lb, ub, step;
411 
412  // Parse the induction variable followed by '='.
413  if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
414  // Parse loop bounds.
415  parser.parseOperand(lb) ||
416  parser.resolveOperand(lb, indexType, result.operands) ||
417  parser.parseKeyword("to") || parser.parseOperand(ub) ||
418  parser.resolveOperand(ub, indexType, result.operands) ||
419  parser.parseKeyword("step") || parser.parseOperand(step) ||
420  parser.resolveOperand(step, indexType, result.operands))
421  return failure();
422 
423  // Parse the optional initial iteration arguments.
424  SmallVector<OpAsmParser::Argument, 4> regionArgs;
425  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
426  regionArgs.push_back(inductionVariable);
427 
428  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
429  // Parse assignment list and results type list.
430  if (parser.parseAssignmentList(regionArgs, operands) ||
431  parser.parseArrowTypeList(result.types))
432  return failure();
433 
434  // Resolve input operands.
435  for (auto argOperandType :
436  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
437  Type type = std::get<2>(argOperandType);
438  std::get<0>(argOperandType).type = type;
439  if (parser.resolveOperand(std::get<1>(argOperandType), type,
440  result.operands))
441  return failure();
442  }
443  }
444 
445  if (regionArgs.size() != result.types.size() + 1)
446  return parser.emitError(
447  parser.getNameLoc(),
448  "mismatch in number of loop-carried values and defined values");
449 
450  // Parse the body region.
451  Region *body = result.addRegion();
452  if (parser.parseRegion(*body, regionArgs))
453  return failure();
454 
455  ForOp::ensureTerminator(*body, builder, result.location);
456 
457  // Parse the optional attribute list.
458  if (parser.parseOptionalAttrDict(result.attributes))
459  return failure();
460 
461  return success();
462 }
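// For reference, the custom form produced/consumed by print/parse above looks
// roughly like this (illustrative example, names invented):
//
//   %sum = scf.for %iv = %lb to %ub step %c1
//       iter_args(%acc = %init) -> (f32) {
//     %next = arith.addf %acc, %val : f32
//     scf.yield %next : f32
//   }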
463 
464 Region &ForOp::getLoopBody() { return getRegion(); }
465 
466 ForOp mlir::scf::getForInductionVarOwner(Value val) {
467  auto ivArg = val.dyn_cast<BlockArgument>();
468  if (!ivArg)
469  return ForOp();
470  assert(ivArg.getOwner() && "unlinked block argument");
471  auto *containingOp = ivArg.getOwner()->getParentOp();
472  return dyn_cast_or_null<ForOp>(containingOp);
473 }
474 
475 /// Return operands used when entering the region at 'index'. These operands
476 /// correspond to the loop iterator operands, i.e., those excluding the
477 /// induction variable. ForOp only has one region, so 0 is the only valid value
478 /// for `index`.
479 OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
480  assert(index && *index == 0 && "invalid region index");
481 
482  // The initial operands map to the loop arguments after the induction
483  // variable.
484  return getInitArgs();
485 }
486 
487 /// Given the region at `index`, or the parent operation if `index` is None,
488 /// return the successor regions. These are the regions that may be selected
489 /// during the flow of control. `operands` is a set of optional attributes that
490 /// correspond to a constant value for each operand, or null if that operand is
491 /// not a constant.
492 void ForOp::getSuccessorRegions(Optional<unsigned> index,
493  ArrayRef<Attribute> operands,
494  SmallVectorImpl<RegionSuccessor> &regions) {
495  // If the predecessor is the ForOp, branch into the body using the iterator
496  // arguments.
497  if (!index) {
498  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
499  return;
500  }
501 
502  // Otherwise, the loop may branch back to itself or the parent operation.
503  assert(*index == 0 && "expected loop region");
504  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
505  regions.push_back(RegionSuccessor(getResults()));
506 }
507 
508 LoopNest mlir::scf::buildLoopNest(
509  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
510  ValueRange steps, ValueRange iterArgs,
511  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
512  bodyBuilder) {
513  assert(lbs.size() == ubs.size() &&
514  "expected the same number of lower and upper bounds");
515  assert(lbs.size() == steps.size() &&
516  "expected the same number of lower bounds and steps");
517 
518  // If there are no bounds, call the body-building function and return early.
519  if (lbs.empty()) {
520  ValueVector results =
521  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
522  : ValueVector();
523  assert(results.size() == iterArgs.size() &&
524  "loop nest body must return as many values as loop has iteration "
525  "arguments");
526  return LoopNest();
527  }
528 
529  // First, create the loop structure iteratively using the body-builder
530  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
531  OpBuilder::InsertionGuard guard(builder);
532  SmallVector<scf::ForOp, 4> loops;
533  SmallVector<Value, 4> ivs;
534  loops.reserve(lbs.size());
535  ivs.reserve(lbs.size());
536  ValueRange currentIterArgs = iterArgs;
537  Location currentLoc = loc;
538  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
539  auto loop = builder.create<scf::ForOp>(
540  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
541  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
542  ValueRange args) {
543  ivs.push_back(iv);
544  // It is safe to store ValueRange args because it points to block
545  // arguments of a loop operation that we also own.
546  currentIterArgs = args;
547  currentLoc = nestedLoc;
548  });
549  // Set the builder to point to the body of the newly created loop. We don't
550  // do this in the callback because the builder is reset when the callback
551  // returns.
552  builder.setInsertionPointToStart(loop.getBody());
553  loops.push_back(loop);
554  }
555 
556  // For all loops but the innermost, yield the results of the nested loop.
557  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
558  builder.setInsertionPointToEnd(loops[i].getBody());
559  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
560  }
561 
562  // In the body of the innermost loop, call the body building function if any
563  // and yield its results.
564  builder.setInsertionPointToStart(loops.back().getBody());
565  ValueVector results = bodyBuilder
566  ? bodyBuilder(builder, currentLoc, ivs,
567  loops.back().getRegionIterArgs())
568  : ValueVector();
569  assert(results.size() == iterArgs.size() &&
570  "loop nest body must return as many values as loop has iteration "
571  "arguments");
572  builder.setInsertionPointToEnd(loops.back().getBody());
573  builder.create<scf::YieldOp>(loc, results);
574 
575  // Return the loops.
576  LoopNest res;
577  res.loops.assign(loops.begin(), loops.end());
578  return res;
579 }
580 
581 LoopNest mlir::scf::buildLoopNest(
582  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
583  ValueRange steps,
584  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
585  // Delegate to the main function by wrapping the body builder.
586  return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
587  [&bodyBuilder](OpBuilder &nestedBuilder,
588  Location nestedLoc, ValueRange ivs,
589  ValueRange) -> ValueVector {
590  if (bodyBuilder)
591  bodyBuilder(nestedBuilder, nestedLoc, ivs);
592  return {};
593  });
594 }
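// Illustrative usage sketch (hypothetical values, not part of the upstream
// file): creating a loop nest over `lbs`/`ubs`/`steps` whose body stores into
// a memref, assuming those ranges and `value`/`memref` exist and the memref
// dialect is loaded:
//
//   scf::buildLoopNest(
//       builder, loc, lbs, ubs, steps,
//       [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
//         nested.create<memref::StoreOp>(nestedLoc, value, memref, ivs);
//       });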
595 
596 namespace {
597 // Fold away ForOp iter arguments when:
598 // 1) The op yields the iter arguments.
599 // 2) The iter arguments have no use and the corresponding outer region
600 // iterators (inputs) are yielded.
601 // 3) The iter arguments have no use and the corresponding (operation) results
602 // have no use.
603 //
604 // These arguments must be defined outside of
605 // the ForOp region and can just be forwarded after simplifying the op inits,
606 // yields and returns.
607 //
608 // The implementation uses `mergeBlockBefore` to steal the content of the
609 // original ForOp and avoid cloning.
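// An illustrative case handled by this pattern (example only, not from the
// upstream file): the second iter_arg is yielded back unchanged, so the
// iter_arg/result pair folds away and any use of %r#1 is replaced by %b:
//
//   %r:2 = scf.for %i = %lb to %ub step %s
//       iter_args(%acc = %a, %same = %b) -> (f32, f32) {
//     %n = arith.addf %acc, %acc : f32
//     scf.yield %n, %same : f32, f32
//   }
//
// folds to:
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%acc = %a) -> (f32) {
//     %n = arith.addf %acc, %acc : f32
//     scf.yield %n : f32
//   }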
610 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
611  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
612 
613  LogicalResult matchAndRewrite(scf::ForOp forOp,
614  PatternRewriter &rewriter) const final {
615  bool canonicalize = false;
616  Block &block = forOp.getRegion().front();
617  auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
618 
619  // An internal flat vector of block transfer
620  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
621  // transformed block arguments. This plays the role of a
622  // BlockAndValueMapping for the particular use case of calling into
623  // `mergeBlockBefore`.
624  SmallVector<bool, 4> keepMask;
625  keepMask.reserve(yieldOp.getNumOperands());
626  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
627  newResultValues;
628  newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
629  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
630  newIterArgs.reserve(forOp.getNumIterOperands());
631  newYieldValues.reserve(yieldOp.getNumOperands());
632  newResultValues.reserve(forOp.getNumResults());
633  for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
634  forOp.getRegionIterArgs(), // iter inside region
635  forOp.getResults(), // op results
636  yieldOp.getOperands() // iter yield
637  )) {
638  // Forwarded is `true` when:
639  // 1) The region `iter` argument is yielded.
640  // 2) The region `iter` argument has no use, and the corresponding iter
641  // operand (input) is yielded.
642  // 3) The region `iter` argument has no use, and the corresponding op
643  // result has no use.
644  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
645  (std::get<1>(it).use_empty() &&
646  (std::get<0>(it) == std::get<3>(it) ||
647  std::get<2>(it).use_empty())));
648  keepMask.push_back(!forwarded);
649  canonicalize |= forwarded;
650  if (forwarded) {
651  newBlockTransferArgs.push_back(std::get<0>(it));
652  newResultValues.push_back(std::get<0>(it));
653  continue;
654  }
655  newIterArgs.push_back(std::get<0>(it));
656  newYieldValues.push_back(std::get<3>(it));
657  newBlockTransferArgs.push_back(Value()); // placeholder with null value
658  newResultValues.push_back(Value()); // placeholder with null value
659  }
660 
661  if (!canonicalize)
662  return failure();
663 
664  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
665  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
666  forOp.getStep(), newIterArgs);
667  newForOp->setAttrs(forOp->getAttrs());
668  Block &newBlock = newForOp.getRegion().front();
669 
670  // Replace the null placeholders with newly constructed values.
671  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
672  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
673  idx != e; ++idx) {
674  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
675  Value &newResultVal = newResultValues[idx];
676  assert((blockTransferArg && newResultVal) ||
677  (!blockTransferArg && !newResultVal));
678  if (!blockTransferArg) {
679  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
680  newResultVal = newForOp.getResult(collapsedIdx++);
681  }
682  }
683 
684  Block &oldBlock = forOp.getRegion().front();
685  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
686  "unexpected argument size mismatch");
687 
688  // No results case: the scf::ForOp builder already created a zero
689  // result terminator. Merge before this terminator and just get rid of the
690  // original terminator that has been merged in.
691  if (newIterArgs.empty()) {
692  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
693  rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
694  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
695  rewriter.replaceOp(forOp, newResultValues);
696  return success();
697  }
698 
699  // No terminator case: merge and rewrite the merged terminator.
700  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
701  OpBuilder::InsertionGuard g(rewriter);
702  rewriter.setInsertionPoint(mergedTerminator);
703  SmallVector<Value, 4> filteredOperands;
704  filteredOperands.reserve(newResultValues.size());
705  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
706  if (keepMask[idx])
707  filteredOperands.push_back(mergedTerminator.getOperand(idx));
708  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
709  filteredOperands);
710  };
711 
712  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
713  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
714  cloneFilteredTerminator(mergedYieldOp);
715  rewriter.eraseOp(mergedYieldOp);
716  rewriter.replaceOp(forOp, newResultValues);
717  return success();
718  }
719 };
720 
721 /// Util function that tries to compute a constant diff between u and l.
722 /// Returns llvm::None when the difference between the two values cannot be
723 /// determined statically.
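/// For example (illustrative): with l = %c2 and u = %c10 this returns 8, and
/// with l = %x and u = (arith.addi %x, %c4) it returns 4; for unrelated
/// dynamic values it returns llvm::None.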
724 static Optional<int64_t> computeConstDiff(Value l, Value u) {
725  IntegerAttr clb, cub;
726  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
727  llvm::APInt lbValue = clb.getValue();
728  llvm::APInt ubValue = cub.getValue();
729  return (ubValue - lbValue).getSExtValue();
730  }
731 
732  // Else a simple pattern match for x + c or c + x
733  llvm::APInt diff;
734  if (matchPattern(
735  u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
736  matchPattern(
737  u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
738  return diff.getSExtValue();
739  return llvm::None;
740 }
741 
742 /// Rewriting pattern that erases loops that are known not to iterate, replaces
743 /// single-iteration loops with their bodies, and removes empty loops that
744 /// iterate at least once and only return values defined outside of the loop.
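/// An illustrative single-iteration case (example only, not from the upstream
/// file): with constant %lb = 0, %ub = 4 and %step = 8 the loop below runs
/// exactly once, so it is replaced by its body with %i = %lb and %acc = %init:
///
///   %r = scf.for %i = %lb to %ub step %step
///       iter_args(%acc = %init) -> (f32) { ... }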
745 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
746  using OpRewritePattern<ForOp>::OpRewritePattern;
747 
748  LogicalResult matchAndRewrite(ForOp op,
749  PatternRewriter &rewriter) const override {
750  // If the upper bound is the same as the lower bound, the loop does not
751  // iterate, just remove it.
752  if (op.getLowerBound() == op.getUpperBound()) {
753  rewriter.replaceOp(op, op.getIterOperands());
754  return success();
755  }
756 
757  Optional<int64_t> diff =
758  computeConstDiff(op.getLowerBound(), op.getUpperBound());
759  if (!diff)
760  return failure();
761 
762  // If the loop is known to have 0 iterations, remove it.
763  if (*diff <= 0) {
764  rewriter.replaceOp(op, op.getIterOperands());
765  return success();
766  }
767 
768  llvm::Optional<llvm::APInt> maybeStepValue = op.getConstantStep();
769  if (!maybeStepValue)
770  return failure();
771 
772  // If the loop is known to have 1 iteration, inline its body and remove the
773  // loop.
774  llvm::APInt stepValue = *maybeStepValue;
775  if (stepValue.sge(*diff)) {
776  SmallVector<Value, 4> blockArgs;
777  blockArgs.reserve(op.getNumIterOperands() + 1);
778  blockArgs.push_back(op.getLowerBound());
779  llvm::append_range(blockArgs, op.getIterOperands());
780  replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
781  return success();
782  }
783 
784  // Now we are left with loops that have more than one iteration.
785  Block &block = op.getRegion().front();
786  if (!llvm::hasSingleElement(block))
787  return failure();
788  // If the loop is empty, iterates at least once, and only returns values
789  // defined outside of the loop, remove it and replace it with yield values.
790  auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
791  auto yieldOperands = yieldOp.getOperands();
792  if (llvm::any_of(yieldOperands,
793  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
794  return failure();
795  rewriter.replaceOp(op, yieldOperands);
796  return success();
797  }
798 };
799 
800 /// Replace one iter OpOperand of an scf.for with the
801 /// `replacement` value which is expected to be the source of a tensor.cast.
802 /// tensor.cast ops are inserted inside the block to account for the type cast.
803 static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
804  OpOperand &operand,
805  Value replacement) {
806  Type oldType = operand.get().getType(), newType = replacement.getType();
807  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
808  "expected ranked tensor types");
809 
810  // 1. Create new iter operands, exactly 1 is replaced.
811  ForOp forOp = cast<ForOp>(operand.getOwner());
812  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
813  "expected an iter OpOperand");
814  if (operand.get().getType() == replacement.getType())
815  return forOp;
816  SmallVector<Value> newIterOperands;
817  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
818  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
819  newIterOperands.push_back(replacement);
820  continue;
821  }
822  newIterOperands.push_back(opOperand.get());
823  }
824 
825  // 2. Create the new forOp shell.
826  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
827  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
828  forOp.getStep(), newIterOperands);
829  newForOp->setAttrs(forOp->getAttrs());
830  Block &newBlock = newForOp.getRegion().front();
831  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
832  newBlock.getArguments().end());
833 
834  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
835  // corresponding to the `replacement` value.
836  OpBuilder::InsertionGuard g(rewriter);
837  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
838  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
839  newForOp->getOpOperand(operand.getOperandNumber()));
840  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
841  newRegionIterArg);
842  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
843 
844  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
845  Block &oldBlock = forOp.getRegion().front();
846  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
847 
848  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
849  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
850  rewriter.setInsertionPoint(clonedYieldOp);
851  unsigned yieldIdx =
852  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
853  Value castOut = rewriter.create<tensor::CastOp>(
854  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
855  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
856  newYieldOperands[yieldIdx] = castOut;
857  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
858  rewriter.eraseOp(clonedYieldOp);
859 
860  // 6. Inject an outgoing cast op after the forOp.
861  rewriter.setInsertionPointAfter(newForOp);
862  SmallVector<Value> newResults = newForOp.getResults();
863  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
864  newForOp.getLoc(), oldType, newResults[yieldIdx]);
865 
866  return newForOp;
867 }
868 
869 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
870 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
871 ///
872 /// ```
873 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
874 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
875 /// -> (tensor<?x?xf32>) {
876 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
877 /// scf.yield %2 : tensor<?x?xf32>
878 /// }
879 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
880 /// use_of(%2)
881 /// ```
882 ///
883 /// folds into:
884 ///
885 /// ```
886 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
887 /// -> (tensor<32x1024xf32>) {
888 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
889 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
890 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
891 /// scf.yield %4 : tensor<32x1024xf32>
892 /// }
893 /// use_of(%0)
894 /// ```
895 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
896  using OpRewritePattern<ForOp>::OpRewritePattern;
897 
898  LogicalResult matchAndRewrite(ForOp op,
899  PatternRewriter &rewriter) const override {
900  for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
901  OpOperand &iterOpOperand = std::get<0>(it);
902  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
903  if (!incomingCast)
904  continue;
905  if (!std::get<1>(it).hasOneUse())
906  continue;
907  auto outgoingCastOp =
908  dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
909  if (!outgoingCastOp)
910  continue;
911 
912  // Must be a tensor.cast op pair with matching types.
913  if (outgoingCastOp.getResult().getType() !=
914  incomingCast.getSource().getType())
915  continue;
916 
917  // Create a new ForOp with that iter operand replaced.
918  auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
919  incomingCast.getSource());
920 
921  // Insert outgoing cast and use it to replace the corresponding result.
922  rewriter.setInsertionPointAfter(newForOp);
923  SmallVector<Value> replacements = newForOp.getResults();
924  unsigned returnIdx =
925  iterOpOperand.getOperandNumber() - op.getNumControlOperands();
926  replacements[returnIdx] = rewriter.create<tensor::CastOp>(
927  op.getLoc(), incomingCast.getDest().getType(),
928  replacements[returnIdx]);
929  rewriter.replaceOp(op, replacements);
930  return success();
931  }
932  return failure();
933  }
934 };
935 
936 /// Canonicalize the iter_args of an scf::ForOp that involve a
937 /// `bufferization.to_tensor` and for which only the last loop iteration is
938 /// actually visible outside of the loop. The canonicalization looks for a
939 /// pattern such as:
940 /// ```
941 /// %t0 = ... : tensor_type
942 /// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
943 /// ...
944 /// // %m is either buffer_cast(%bb0) or defined above the loop
945 /// %m... : memref_type
946 /// ... // uses of %m with potential inplace updates
947 /// %new_tensor = bufferization.to_tensor %m : memref_type
948 /// ...
949 /// scf.yield %new_tensor : tensor_type
950 /// }
951 /// ```
952 ///
953 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
954 /// `%m = buffer_cast %bb0` op that feeds into the yielded
955 /// `bufferization.to_tensor` op.
956 ///
957 /// If no aliasing write to the memref `%m`, from which `%new_tensor` is
958 /// loaded, occurs between the `bufferization.to_tensor` and the yield, then
959 /// the value %0 visible outside of the loop is the last
960 /// `bufferization.to_tensor` produced in the loop.
961 ///
962 /// For now, we approximate the absence of aliasing by only supporting the case
963 /// when the bufferization.to_tensor is the operation immediately preceding
964 /// the yield.
965 //
966 /// The canonicalization rewrites the pattern as:
967 /// ```
968 /// // %m is either a buffer_cast or defined above
969 /// %m... : memref_type
970 /// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
971 /// ... // uses of %m with potential inplace updates
972 /// scf.yield %bb0: tensor_type
973 /// }
974 /// %0 = bufferization.to_tensor %m : memref_type
975 /// ```
976 ///
977 /// A later bbArg canonicalization will further rewrite as:
978 /// ```
979 /// // %m is either a buffer_cast or defined above
980 /// %m... : memref_type
981 /// scf.for ... { // no iter_args
982 /// ... // uses of %m with potential inplace updates
983 /// }
984 /// %0 = bufferization.to_tensor %m : memref_type
985 /// ```
986 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
987  using OpRewritePattern<ForOp>::OpRewritePattern;
988 
989  LogicalResult matchAndRewrite(ForOp forOp,
990  PatternRewriter &rewriter) const override {
991  assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
992  "unexpected multiple blocks");
993 
994  Location loc = forOp.getLoc();
995  DenseMap<Value, Value> replacements;
996  for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
997  unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
998  auto yieldOp =
999  cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
1000  Value yieldVal = yieldOp->getOperand(idx);
1001  auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
1002  bool isTensor = bbArg.getType().isa<TensorType>();
1003 
1004  bufferization::ToMemrefOp tensorToMemref;
1005  // Either bbArg has no use or it has a single buffer_cast use.
1006  if (bbArg.hasOneUse())
1007  tensorToMemref =
1008  dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
1009  if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
1010  continue;
1011  // If tensorToMemref is present, it must feed into the `ToTensorOp`.
1012  if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
1013  continue;
1014  // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
1015  // must be before `ToTensorOp` in the block so that the lastWrite
1016  // property is not subject to additional side-effects.
1017  // For now, we only support the case when ToTensorOp appears
1018  // immediately before the terminator.
1019  if (tensorLoadOp->getNextNode() != yieldOp)
1020  continue;
1021 
1022  // Clone the optional tensorToMemref before forOp.
1023  if (tensorToMemref) {
1024  rewriter.setInsertionPoint(forOp);
1025  rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1026  tensorToMemref, tensorToMemref.getMemref().getType(),
1027  tensorToMemref.getTensor());
1028  }
1029 
1030  // Clone the tensorLoad after forOp.
1031  rewriter.setInsertionPointAfter(forOp);
1032  Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1033  loc, tensorLoadOp.getMemref());
1034  Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1035  replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1036 
1037  // Make the terminator just yield the bbArg; the old tensorLoadOp and the
1038  // old bbArg (that is now directly yielded) will canonicalize away.
1039  rewriter.startRootUpdate(yieldOp);
1040  yieldOp.setOperand(idx, bbArg);
1041  rewriter.finalizeRootUpdate(yieldOp);
1042  }
1043  if (replacements.empty())
1044  return failure();
1045 
1046  // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1047  // replaces the whole op and erases it unconditionally. This is wrong for
1048  // `forOp` as it generally contains ops with side effects.
1049  // Instead, use `rewriter.replaceOpWithIf`.
1050  SmallVector<Value> newResults;
1051  newResults.reserve(forOp.getNumResults());
1052  for (Value v : forOp.getResults()) {
1053  auto it = replacements.find(v);
1054  newResults.push_back((it != replacements.end()) ? it->second : v);
1055  }
1056  unsigned idx = 0;
1057  rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1058  return op.get() != newResults[idx++];
1059  });
1060  return success();
1061  }
1062 };
1063 } // namespace
1064 
1065 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1066  MLIRContext *context) {
1067  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1068  LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1069 }
1070 
1071 Optional<APInt> ForOp::getConstantStep() {
1072  IntegerAttr step;
1073  if (matchPattern(getStep(), m_Constant(&step)))
1074  return step.getValue();
1075  return {};
1076 }
1077 
1078 Speculation::Speculatability ForOp::getSpeculatability() {
1079  // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
1080  // and End.
1081  if (auto constantStep = getConstantStep())
1082  if (*constantStep == 1)
1083  return Speculation::RecursivelySpeculatable;
1084 
1085  // For Step != 1, the loop may not terminate. We can add more smarts here if
1086  // needed.
1087  return Speculation::NotSpeculatable;
1088 }
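// For example (illustrative): a loop `scf.for %i = %lb to %ub step %c1` is
// recursively speculatable regardless of %lb and %ub, while the same loop
// with step %c2 is conservatively treated as not speculatable.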
1089 
1090 //===----------------------------------------------------------------------===//
1091 // ForeachThreadOp
1092 //===----------------------------------------------------------------------===//
1093 
1094 LogicalResult ForeachThreadOp::verify() {
1095  // Call terminator's verify to produce most informative error messages.
1096  if (failed(getTerminator().verify()))
1097  return failure();
1098 
1099  // Check number of outputs.
1100  if (getNumResults() != getOutputs().size())
1101  return emitOpError("produces ")
1102  << getNumResults() << " results, but has only "
1103  << getOutputs().size() << " outputs";
1104 
1105  // Check that the body defines block arguments for thread indices and outputs.
1106  auto *body = getBody();
1107  if (body->getNumArguments() != getRank() + getOutputs().size())
1108  return emitOpError("region expects ") << getRank() << " arguments";
1109  for (int64_t i = 0; i < getRank(); ++i)
1110  if (!body->getArgument(i).getType().isIndex())
1111  return emitOpError("expects ")
1112  << i << "-th block argument to be an index";
1113  for (unsigned i = 0; i < getOutputs().size(); ++i)
1114  if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
1115  return emitOpError("type mismatch between ")
1116  << i << "-th output and corresponding block argument";
1117  if (getMapping().has_value() && !getMapping()->empty()) {
1118  if (static_cast<int64_t>(getMapping()->size()) != getRank())
1119  return emitOpError() << "mapping attribute size must match op rank";
1120  for (auto map : getMapping()->getValue()) {
1121  if (!isa<DeviceMappingAttrInterface>(map))
1122  return emitOpError()
1123  << getMappingAttrName() << " is not device mapping attribute";
1124  }
1125  }
1126 
1127  return success();
1128 }
1129 
1130 void ForeachThreadOp::print(OpAsmPrinter &p) {
1131  p << " (";
1132  llvm::interleaveComma(getThreadIndices(), p);
1133  p << ") in (";
1134  llvm::interleaveComma(getNumThreads(), p);
1135  p << ")";
1136  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
1137  p << " ";
1138  if (!getRegionOutArgs().empty())
1139  p << "-> (" << getResultTypes() << ") ";
1140  p.printRegion(getRegion(),
1141  /*printEntryBlockArgs=*/false,
1142  /*printBlockTerminators=*/getNumResults() > 0);
1143  p.printOptionalAttrDict(getOperation()->getAttrs(),
1144  {"operand_segment_sizes"});
1145 }
1146 
1147 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
1148  OperationState &result) {
1149  auto &builder = parser.getBuilder();
1150  // Parse an opening `(` followed by thread index variables followed by `)`
1151  // TODO: when we can refer to such "induction variable"-like handles from the
1152  // declarative assembly format, we can implement the parser as a custom hook.
1153  SmallVector<OpAsmParser::Argument, 4> threadIndices;
1154  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
1155  return failure();
1156 
1157  // Parse `in` threadNums.
1158  SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
1159  if (parser.parseKeyword("in") ||
1160  parser.parseOperandList(threadNums, threadIndices.size(),
1161  OpAsmParser::Delimiter::Paren) ||
1162  parser.resolveOperands(threadNums, builder.getIndexType(),
1163  result.operands))
1164  return failure();
1165 
1166  // Parse out operands and results.
1167  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
1168  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
1169  SMLoc outOperandsLoc = parser.getCurrentLocation();
1170  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
1171  if (outOperands.size() != result.types.size())
1172  return parser.emitError(outOperandsLoc,
1173  "mismatch between out operands and types");
1174  if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
1175  parser.parseOptionalArrowTypeList(result.types) ||
1176  parser.resolveOperands(outOperands, result.types, outOperandsLoc,
1177  result.operands))
1178  return failure();
1179  }
1180 
1181  // Parse region.
1182  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1183  std::unique_ptr<Region> region = std::make_unique<Region>();
1184  for (auto &idx : threadIndices) {
1185  idx.type = builder.getIndexType();
1186  regionArgs.push_back(idx);
1187  }
1188  for (const auto &it : llvm::enumerate(regionOutArgs)) {
1189  auto &out = it.value();
1190  out.type = result.types[it.index()];
1191  regionArgs.push_back(out);
1192  }
1193  if (parser.parseRegion(*region, regionArgs))
1194  return failure();
1195 
1196  // Ensure terminator and move region.
1197  OpBuilder b(builder.getContext());
1198  ForeachThreadOp::ensureTerminator(*region, b, result.location);
1199  result.addRegion(std::move(region));
1200 
1201  // Parse the optional attribute list.
1202  if (parser.parseOptionalAttrDict(result.attributes))
1203  return failure();
1204  result.addAttribute("operand_segment_sizes",
1205  builder.getDenseI32ArrayAttr(
1206  {static_cast<int32_t>(threadNums.size()),
1207  static_cast<int32_t>(outOperands.size())}));
1208  return success();
1209 }
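// For reference, the custom form handled by print/parse above looks roughly
// like this (illustrative example, names invented):
//
//   %res = scf.foreach_thread (%i, %j) in (%c4, %c8)
//       shared_outs(%o = %t) -> (tensor<4x8xf32>) {
//     %slice = ...
//     scf.foreach_thread.perform_concurrently {
//       tensor.parallel_insert_slice %slice into %o[%i, %j] [1, 1] [1, 1]
//         : tensor<1x1xf32> into tensor<4x8xf32>
//     }
//   }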
1210 
1211 // Bodyless builder, outputs must be specified.
1212 void ForeachThreadOp::build(mlir::OpBuilder &builder,
1213  mlir::OperationState &result, ValueRange outputs,
1214  ValueRange numThreads,
1215  Optional<ArrayAttr> mapping) {
1216  result.addOperands(numThreads);
1217  result.addOperands(outputs);
1218  if (mapping.has_value()) {
1220  mapping.value());
1221  }
1222 
1223  result.addAttribute(
1224  "operand_segment_sizes",
1225  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
1226  static_cast<int32_t>(outputs.size())}));
1227  result.addTypes(TypeRange(outputs));
1228 
1229  Region *bodyRegion = result.addRegion();
1230  OpBuilder::InsertionGuard g(builder);
1231  // createBlock sets the IP inside the block.
1232  // Generally we would guard against that but the default ensureTerminator impl
1233  // expects it.
1234  builder.createBlock(bodyRegion);
1235  Block &bodyBlock = bodyRegion->front();
1236  // Add block arguments for indices and outputs.
1237  bodyBlock.addArguments(
1238  SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1239  SmallVector<Location>(numThreads.size(), result.location));
1240  bodyBlock.addArguments(
1241  TypeRange(outputs),
1242  SmallVector<Location>(outputs.size(), result.location));
1243  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
1244 }
1245 
1246 // Builder that takes a bodyBuilder lambda.
1247 void ForeachThreadOp::build(
1248  mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
1249  ValueRange numThreads, ArrayRef<Attribute> mapping,
1250  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1251  result.addOperands(numThreads);
1252  result.addOperands(outputs);
1254  builder.getArrayAttr(mapping));
1255  result.addAttribute(
1256  "operand_segment_sizes",
1257  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
1258  static_cast<int32_t>(outputs.size())}));
1259  result.addTypes(TypeRange(outputs));
1260 
1261  Region *bodyRegion = result.addRegion();
1262  OpBuilder::InsertionGuard g(builder);
1263  builder.createBlock(bodyRegion);
1264  Block &bodyBlock = bodyRegion->front();
1265  // Add block arguments for indices and outputs.
1266  bodyBlock.addArguments(
1267  SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1268  SmallVector<Location>(numThreads.size(), result.location));
1269  bodyBlock.addArguments(
1270  TypeRange(outputs),
1271  SmallVector<Location>(outputs.size(), result.location));
1272 
1273  builder.setInsertionPointToStart(&bodyBlock);
1274  bodyBuilder(builder, result.location, bodyBlock.getArguments());
1275 #ifndef NDEBUG
1276  auto terminator =
1277  llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1278  assert(terminator &&
1279  "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1280 #endif // NDEBUG
1281 }
1282 
1283 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1284 // unaware of the fact that our terminator also needs a region to be
1285 // well-formed. We override it here to ensure that we do the right thing.
1286 void ForeachThreadOp::ensureTerminator(Region &region, OpBuilder &builder,
1287  Location loc) {
1288  OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
1289  ForeachThreadOp>::ensureTerminator(region, builder, loc);
1290  auto terminator =
1291  llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
1292  if (terminator.getRegion().empty())
1293  builder.createBlock(&terminator.getRegion());
1294 }
1295 
1296 PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
1297  return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
1298 }
1299 
1300 /// Helper to sort `values` according to matching `keys`.
1301 SmallVector<Value> ForeachThreadOp::getValuesSortedByKey(
1302  ArrayRef<Attribute> keys, ValueRange values,
1303  llvm::function_ref<bool(Attribute, Attribute)> compare) {
1304  if (keys.empty())
1305  return values;
1306  assert(keys.size() == values.size() && "unexpected mismatching sizes");
1307  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
1308  std::sort(indices.begin(), indices.end(),
1309  [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
1310  SmallVector<Value> res;
1311  res.reserve(values.size());
1312  for (int64_t i = 0, e = indices.size(); i < e; ++i)
1313  res.push_back(values[indices[i]]);
1314  return res;
1315 }
1316 
1317 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
1318  auto tidxArg = val.dyn_cast<BlockArgument>();
1319  if (!tidxArg)
1320  return ForeachThreadOp();
1321  assert(tidxArg.getOwner() && "unlinked block argument");
1322  auto *containingOp = tidxArg.getOwner()->getParentOp();
1323  return dyn_cast<ForeachThreadOp>(containingOp);
1324 }
1325 
1326 namespace {
1327 /// Fold tensor.dim(foreach_thread shared_outs(... = %t)) to tensor.dim(%t).
1328 struct DimOfForeachThreadOp : public OpRewritePattern<tensor::DimOp> {
1329  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
1330 
1331  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1332  PatternRewriter &rewriter) const final {
1333  auto foreachThreadOp = dimOp.getSource().getDefiningOp<ForeachThreadOp>();
1334  if (!foreachThreadOp)
1335  return failure();
1336  Value sharedOut =
1337  foreachThreadOp.getTiedOpOperand(dimOp.getSource().cast<OpResult>())
1338  ->get();
1339  rewriter.updateRootInPlace(
1340  dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
1341  return success();
1342  }
1343 };
1344 } // namespace
1345 
1346 void ForeachThreadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1347  MLIRContext *context) {
1348  results.add<DimOfForeachThreadOp>(context);
1349 }
1350 
1351 //===----------------------------------------------------------------------===//
1352 // PerformConcurrentlyOp
1353 //===----------------------------------------------------------------------===//
1354 
1355 // Build a PerformConcurrentlyOp with an empty body region.
1356 void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
1357  OpBuilder::InsertionGuard g(b);
1358  Region *bodyRegion = result.addRegion();
1359  b.createBlock(bodyRegion);
1360 }
1361 
1362 LogicalResult PerformConcurrentlyOp::verify() {
1363  scf::ForeachThreadOp foreachThreadOp =
1364  dyn_cast<scf::ForeachThreadOp>(getOperation()->getParentOp());
1365  if (!foreachThreadOp)
1366  return this->emitOpError("expected foreach_thread op parent");
1367 
1368  // TODO: PerformConcurrentlyOpInterface.
1369  for (Operation &op : getRegion().front().getOperations()) {
1370  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1371  return this->emitOpError("expected only ")
1372  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1373  }
1374 
1375  // Verify that inserts are into out block arguments.
1376  Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
1377  ArrayRef<BlockArgument> regionOutArgs = foreachThreadOp.getRegionOutArgs();
1378  if (!llvm::is_contained(regionOutArgs, dest))
1379  return op.emitOpError("may only insert into an output block argument");
1380  }
1381  return success();
1382 }
1383 
1384 void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
1385  p << " ";
1386  p.printRegion(getRegion(),
1387  /*printEntryBlockArgs=*/false,
1388  /*printBlockTerminators=*/false);
1389  p.printOptionalAttrDict(getOperation()->getAttrs());
1390 }
1391 
1392 ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
1393  OperationState &result) {
1394  auto &builder = parser.getBuilder();
1395 
1396  SmallVector<OpAsmParser::Argument, 4> regionOperands;
1397  std::unique_ptr<Region> region = std::make_unique<Region>();
1398  if (parser.parseRegion(*region, regionOperands))
1399  return failure();
1400 
1401  if (region->empty())
1402  OpBuilder(builder.getContext()).createBlock(region.get());
1403  result.addRegion(std::move(region));
1404 
1405  // Parse the optional attribute list.
1406  if (parser.parseOptionalAttrDict(result.attributes))
1407  return failure();
1408  return success();
1409 }
1410 
1411 OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
1412  return getOperation()->getParentOp()->getResult(idx);
1413 }
1414 
1415 SmallVector<BlockArgument> PerformConcurrentlyOp::getDests() {
1416  return llvm::to_vector<4>(
1417  llvm::map_range(getYieldingOps(), [](Operation &op) {
1418  // Add new ops here as needed.
1419  auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
1420  return insertSliceOp.getDest().cast<BlockArgument>();
1421  }));
1422 }
1423 
1424 llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
1425  return getRegion().front().getOperations();
1426 }
1427 
1428 //===----------------------------------------------------------------------===//
1429 // IfOp
1430 //===----------------------------------------------------------------------===//
1431 
1432 bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
1433  assert(a && "expected non-empty operation");
1434  assert(b && "expected non-empty operation");
1435 
1436  IfOp ifOp = a->getParentOfType<IfOp>();
1437  while (ifOp) {
1438  // Check if b is inside ifOp. (We already know that a is.)
1439  if (ifOp->isProperAncestor(b))
1440  // b is contained in ifOp. a and b are in mutually exclusive branches if
1441  // they are in different blocks of ifOp.
1442  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1443  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1444  // Check next enclosing IfOp.
1445  ifOp = ifOp->getParentOfType<IfOp>();
1446  }
1447 
1448  // Could not find a common IfOp among a's and b's ancestors.
1449  return false;
1450 }
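// For example (illustrative, not from the upstream file): in
//
//   scf.if %cond {
//     "test.a"() : () -> ()
//   } else {
//     "test.b"() : () -> ()
//   }
//   "test.c"() : () -> ()
//
// "test.a" and "test.b" are in mutually exclusive branches, while "test.a"
// and "test.c" are not ("test.c" is not nested in the scf.if at all).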
1451 
1452 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1453  bool withElseRegion) {
1454  build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1455 }
1456 
1457 void IfOp::build(OpBuilder &builder, OperationState &result,
1458  TypeRange resultTypes, Value cond, bool withElseRegion) {
1459  auto addTerminator = [&](OpBuilder &nested, Location loc) {
1460  if (resultTypes.empty())
1461  IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1462  loc);
1463  };
1464 
1465  build(builder, result, resultTypes, cond, addTerminator,
1466  withElseRegion ? addTerminator
1467  : function_ref<void(OpBuilder &, Location)>());
1468 }
1469 
1470 void IfOp::build(OpBuilder &builder, OperationState &result,
1471  TypeRange resultTypes, Value cond,
1472  function_ref<void(OpBuilder &, Location)> thenBuilder,
1473  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1474  assert(thenBuilder && "the builder callback for 'then' must be present");
1475 
1476  result.addOperands(cond);
1477  result.addTypes(resultTypes);
1478 
1479  OpBuilder::InsertionGuard guard(builder);
1480  Region *thenRegion = result.addRegion();
1481  builder.createBlock(thenRegion);
1482  thenBuilder(builder, result.location);
1483 
1484  Region *elseRegion = result.addRegion();
1485  if (!elseBuilder)
1486  return;
1487 
1488  builder.createBlock(elseRegion);
1489  elseBuilder(builder, result.location);
1490 }
1491 
1492 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1493  function_ref<void(OpBuilder &, Location)> thenBuilder,
1494  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1495  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1496 }
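// Illustrative usage sketch of the callback-based builder above. Assumes the
// caller already has an OpBuilder `builder`, a Location `loc`, an i1 Value
// `cond`, and two f32 Values `lhs` and `rhs` (all hypothetical names):
//
//   auto ifOp = builder.create<scf::IfOp>(
//       loc, TypeRange{builder.getF32Type()}, cond,
//       [&](OpBuilder &b, Location l) { b.create<scf::YieldOp>(l, lhs); },
//       [&](OpBuilder &b, Location l) { b.create<scf::YieldOp>(l, rhs); });
//   Value merged = ifOp.getResult(0);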
1497 
1498 LogicalResult IfOp::verify() {
1499  if (getNumResults() != 0 && getElseRegion().empty())
1500  return emitOpError("must have an else block if defining values");
1501  return success();
1502 }
1503 
1504 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1505  // Create the regions for 'then'.
1506  result.regions.reserve(2);
1507  Region *thenRegion = result.addRegion();
1508  Region *elseRegion = result.addRegion();
1509 
1510  auto &builder = parser.getBuilder();
1511  OpAsmParser::UnresolvedOperand cond;
1512  Type i1Type = builder.getIntegerType(1);
1513  if (parser.parseOperand(cond) ||
1514  parser.resolveOperand(cond, i1Type, result.operands))
1515  return failure();
1516  // Parse optional results type list.
1517  if (parser.parseOptionalArrowTypeList(result.types))
1518  return failure();
1519  // Parse the 'then' region.
1520  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1521  return failure();
1522  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1523 
1524  // If we find an 'else' keyword then parse the 'else' region.
1525  if (!parser.parseOptionalKeyword("else")) {
1526  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1527  return failure();
1528  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1529  }
1530 
1531  // Parse the optional attribute list.
1532  if (parser.parseOptionalAttrDict(result.attributes))
1533  return failure();
1534  return success();
1535 }
1536 
1537 void IfOp::print(OpAsmPrinter &p) {
1538  bool printBlockTerminators = false;
1539 
1540  p << " " << getCondition();
1541  if (!getResults().empty()) {
1542  p << " -> (" << getResultTypes() << ")";
1543  // Print yield explicitly if the op defines values.
1544  printBlockTerminators = true;
1545  }
1546  p << ' ';
1547  p.printRegion(getThenRegion(),
1548  /*printEntryBlockArgs=*/false,
1549  /*printBlockTerminators=*/printBlockTerminators);
1550 
1551  // Print the 'else' region if it exists and has a block.
1552  auto &elseRegion = getElseRegion();
1553  if (!elseRegion.empty()) {
1554  p << " else ";
1555  p.printRegion(elseRegion,
1556  /*printEntryBlockArgs=*/false,
1557  /*printBlockTerminators=*/printBlockTerminators);
1558  }
1559 
1560  p.printOptionalAttrDict((*this)->getAttrs());
1561 }
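// For reference, an illustrative example of the custom syntax handled by the
// parse/print methods above (value names are placeholders):
//
//   %res = scf.if %cond -> (f32) {
//     scf.yield %a : f32
//   } else {
//     scf.yield %b : f32
//   }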
1562 
1563 /// Given the region at `index`, or the parent operation if `index` is None,
1564 /// return the successor regions. These are the regions that may be selected
1565 /// during the flow of control. `operands` is a set of optional attributes that
1566 /// correspond to a constant value for each operand, or null if that operand is
1567 /// not a constant.
1568 void IfOp::getSuccessorRegions(Optional<unsigned> index,
1569  ArrayRef<Attribute> operands,
1570  SmallVectorImpl<RegionSuccessor> &regions) {
1571  // The `then` and the `else` region branch back to the parent operation.
1572  if (index) {
1573  regions.push_back(RegionSuccessor(getResults()));
1574  return;
1575  }
1576 
1577  // Don't consider the else region if it is empty.
1578  Region *elseRegion = &this->getElseRegion();
1579  if (elseRegion->empty())
1580  elseRegion = nullptr;
1581 
1582  // Otherwise, the successor is dependent on the condition.
1583  bool condition;
1584  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1585  condition = condAttr.getValue().isOneValue();
1586  } else {
1587  // If the condition isn't constant, both regions may be executed.
1588  regions.push_back(RegionSuccessor(&getThenRegion()));
1589  // If the else region does not exist, it is not a viable successor.
1590  if (elseRegion)
1591  regions.push_back(RegionSuccessor(elseRegion));
1592  return;
1593  }
1594 
1595  // Add the successor regions using the condition.
1596  regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1597 }
1598 
1599 LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
1600  SmallVectorImpl<OpFoldResult> &results) {
1601  // if (!c) then A() else B() -> if c then B() else A()
1602  if (getElseRegion().empty())
1603  return failure();
1604 
1605  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1606  if (!xorStmt)
1607  return failure();
1608 
1609  if (!matchPattern(xorStmt.getRhs(), m_One()))
1610  return failure();
1611 
1612  getConditionMutable().assign(xorStmt.getLhs());
1613  Block *thenBlock = &getThenRegion().front();
1614  // It would be nicer to use iplist::swap, but that has no implemented
1615  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
1616  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
1617  getElseRegion().getBlocks());
1618  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
1619  getThenRegion().getBlocks(), thenBlock);
1620  return success();
1621 }
1622 
1623 void IfOp::getRegionInvocationBounds(
1624  ArrayRef<Attribute> operands,
1625  SmallVectorImpl<InvocationBounds> &invocationBounds) {
1626  if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1627  // If the condition is known, then one region is known to be executed once
1628  // and the other zero times.
1629  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1630  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1631  } else {
1632  // Non-constant condition. Each region may be executed 0 or 1 times.
1633  invocationBounds.assign(2, {0, 1});
1634  }
1635 }
1636 
1637 namespace {
1638 // Pattern to remove unused IfOp results.
1639 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
1640  using OpRewritePattern<IfOp>::OpRewritePattern;
1641
1642  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
1643  PatternRewriter &rewriter) const {
1644  // Move all operations to the destination block.
1645  rewriter.mergeBlocks(source, dest);
1646  // Replace the yield op by one that returns only the used values.
1647  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
1648  SmallVector<Value, 4> usedOperands;
1649  llvm::transform(usedResults, std::back_inserter(usedOperands),
1650  [&](OpResult result) {
1651  return yieldOp.getOperand(result.getResultNumber());
1652  });
1653  rewriter.updateRootInPlace(yieldOp,
1654  [&]() { yieldOp->setOperands(usedOperands); });
1655  }
1656 
1657  LogicalResult matchAndRewrite(IfOp op,
1658  PatternRewriter &rewriter) const override {
1659  // Compute the list of used results.
1660  SmallVector<OpResult, 4> usedResults;
1661  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1662  [](OpResult result) { return !result.use_empty(); });
1663 
1664  // Replace the operation if only a subset of its results have uses.
1665  if (usedResults.size() == op.getNumResults())
1666  return failure();
1667 
1668  // Compute the result types of the replacement operation.
1669  SmallVector<Type, 4> newTypes;
1670  llvm::transform(usedResults, std::back_inserter(newTypes),
1671  [](OpResult result) { return result.getType(); });
1672 
1673  // Create a replacement operation with empty then and else regions.
1674  auto emptyBuilder = [](OpBuilder &, Location) {};
1675  auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1676  emptyBuilder, emptyBuilder);
1677 
1678  // Move the bodies and replace the terminators (note there is a then and
1679  // an else region since the operation returns results).
1680  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1681  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1682 
1683  // Replace the operation by the new one.
1684  SmallVector<Value, 4> repResults(op.getNumResults());
1685  for (const auto &en : llvm::enumerate(usedResults))
1686  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1687  rewriter.replaceOp(op, repResults);
1688  return success();
1689  }
1690 };
1691 
1692 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1693  using OpRewritePattern<IfOp>::OpRewritePattern;
1694
1695  LogicalResult matchAndRewrite(IfOp op,
1696  PatternRewriter &rewriter) const override {
1697  BoolAttr condition;
1698  if (!matchPattern(op.getCondition(), m_Constant(&condition)))
1699  return failure();
1700 
1701  if (condition.getValue())
1702  replaceOpWithRegion(rewriter, op, op.getThenRegion());
1703  else if (!op.getElseRegion().empty())
1704  replaceOpWithRegion(rewriter, op, op.getElseRegion());
1705  else
1706  rewriter.eraseOp(op);
1707 
1708  return success();
1709  }
1710 };
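// Illustrative effect of RemoveStaticCondition when the condition is a
// constant (sketch; the leftover constant is cleaned up separately):
//
//   %true = arith.constant true
//   scf.if %true {
//     "some.op"() : () -> ()
//   }
//
// folds to just:
//
//   "some.op"() : () -> ()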
1711 
1712 /// Hoist any yielded results whose operands are defined outside
1713 /// the if, to a select instruction.
1714 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
1715  using OpRewritePattern<IfOp>::OpRewritePattern;
1716
1717  LogicalResult matchAndRewrite(IfOp op,
1718  PatternRewriter &rewriter) const override {
1719  if (op->getNumResults() == 0)
1720  return failure();
1721 
1722  auto cond = op.getCondition();
1723  auto thenYieldArgs = op.thenYield().getOperands();
1724  auto elseYieldArgs = op.elseYield().getOperands();
1725 
1726  SmallVector<Type> nonHoistable;
1727  for (const auto &it :
1728  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1729  Value trueVal = std::get<0>(it.value());
1730  Value falseVal = std::get<1>(it.value());
1731  if (&op.getThenRegion() == trueVal.getParentRegion() ||
1732  &op.getElseRegion() == falseVal.getParentRegion())
1733  nonHoistable.push_back(trueVal.getType());
1734  }
1735  // Early exit if there aren't any yielded values we can
1736  // hoist outside the if.
1737  if (nonHoistable.size() == op->getNumResults())
1738  return failure();
1739 
1740  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
1741  if (replacement.thenBlock())
1742  rewriter.eraseBlock(replacement.thenBlock());
1743  replacement.getThenRegion().takeBody(op.getThenRegion());
1744  replacement.getElseRegion().takeBody(op.getElseRegion());
1745 
1746  SmallVector<Value> results(op->getNumResults());
1747  assert(thenYieldArgs.size() == results.size());
1748  assert(elseYieldArgs.size() == results.size());
1749 
1750  SmallVector<Value> trueYields;
1751  SmallVector<Value> falseYields;
1752  rewriter.setInsertionPoint(replacement);
1753  for (const auto &it :
1754  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1755  Value trueVal = std::get<0>(it.value());
1756  Value falseVal = std::get<1>(it.value());
1757  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
1758  &replacement.getElseRegion() == falseVal.getParentRegion()) {
1759  results[it.index()] = replacement.getResult(trueYields.size());
1760  trueYields.push_back(trueVal);
1761  falseYields.push_back(falseVal);
1762  } else if (trueVal == falseVal)
1763  results[it.index()] = trueVal;
1764  else
1765  results[it.index()] = rewriter.create<arith::SelectOp>(
1766  op.getLoc(), cond, trueVal, falseVal);
1767  }
1768 
1769  rewriter.setInsertionPointToEnd(replacement.thenBlock());
1770  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
1771 
1772  rewriter.setInsertionPointToEnd(replacement.elseBlock());
1773  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
1774 
1775  rewriter.replaceOp(op, results);
1776  return success();
1777  }
1778 };
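// Illustrative effect of ConvertTrivialIfToSelect when both yielded values
// are defined outside the scf.if (sketch; together with the other patterns in
// this file, the residual result-less scf.if is also removed):
//
//   %r = scf.if %c -> (i32) {
//     scf.yield %a : i32
//   } else {
//     scf.yield %b : i32
//   }
//
// becomes:
//
//   %r = arith.select %c, %a, %b : i32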
1779 
1780 /// Allow the true region of an if to assume the condition is true
1781 /// and vice versa. For example:
1782 ///
1783 /// scf.if %cmp {
1784 /// print(%cmp)
1785 /// }
1786 ///
1787 /// becomes
1788 ///
1789 /// scf.if %cmp {
1790 /// print(true)
1791 /// }
1792 ///
1793 struct ConditionPropagation : public OpRewritePattern<IfOp> {
1794  using OpRewritePattern<IfOp>::OpRewritePattern;
1795
1796  LogicalResult matchAndRewrite(IfOp op,
1797  PatternRewriter &rewriter) const override {
1798  // Early exit if the condition is constant since replacing a constant
1799  // in the body with another constant isn't a simplification.
1800  if (matchPattern(op.getCondition(), m_Constant()))
1801  return failure();
1802 
1803  bool changed = false;
1804  mlir::Type i1Ty = rewriter.getI1Type();
1805 
1806  // These variables serve to prevent creating duplicate constants
1807  // and hold constant true or false values.
1808  Value constantTrue = nullptr;
1809  Value constantFalse = nullptr;
1810 
1811  for (OpOperand &use :
1812  llvm::make_early_inc_range(op.getCondition().getUses())) {
1813  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1814  changed = true;
1815 
1816  if (!constantTrue)
1817  constantTrue = rewriter.create<arith::ConstantOp>(
1818  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1819 
1820  rewriter.updateRootInPlace(use.getOwner(),
1821  [&]() { use.set(constantTrue); });
1822  } else if (op.getElseRegion().isAncestor(
1823  use.getOwner()->getParentRegion())) {
1824  changed = true;
1825 
1826  if (!constantFalse)
1827  constantFalse = rewriter.create<arith::ConstantOp>(
1828  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1829 
1830  rewriter.updateRootInPlace(use.getOwner(),
1831  [&]() { use.set(constantFalse); });
1832  }
1833  }
1834 
1835  return success(changed);
1836  }
1837 };
1838 
1839 /// Remove any statements from an if that are equivalent to the condition
1840 /// or its negation. For example:
1841 ///
1842 /// %res:2 = scf.if %cmp {
1843 /// yield something(), true
1844 /// } else {
1845 /// yield something2(), false
1846 /// }
1847 /// print(%res#1)
1848 ///
1849 /// becomes
1850 /// %res = scf.if %cmp {
1851 /// yield something()
1852 /// } else {
1853 /// yield something2()
1854 /// }
1855 /// print(%cmp)
1856 ///
1857 /// Additionally if both branches yield the same value, replace all uses
1858 /// of the result with the yielded value.
1859 ///
1860 /// %res:2 = scf.if %cmp {
1861 /// yield something(), %arg1
1862 /// } else {
1863 /// yield something2(), %arg1
1864 /// }
1865 /// print(%res#1)
1866 ///
1867 /// becomes
1868 /// %res = scf.if %cmp {
1869 /// yield something()
1870 /// } else {
1871 /// yield something2()
1872 /// }
1873 /// print(%arg1)
1874 ///
1875 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1876  using OpRewritePattern<IfOp>::OpRewritePattern;
1877
1878  LogicalResult matchAndRewrite(IfOp op,
1879  PatternRewriter &rewriter) const override {
1880  // Early exit if there are no results that could be replaced.
1881  if (op.getNumResults() == 0)
1882  return failure();
1883 
1884  auto trueYield =
1885  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1886  auto falseYield =
1887  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1888 
1889  rewriter.setInsertionPoint(op->getBlock(),
1890  op.getOperation()->getIterator());
1891  bool changed = false;
1892  Type i1Ty = rewriter.getI1Type();
1893  for (auto [trueResult, falseResult, opResult] :
1894  llvm::zip(trueYield.getResults(), falseYield.getResults(),
1895  op.getResults())) {
1896  if (trueResult == falseResult) {
1897  if (!opResult.use_empty()) {
1898  opResult.replaceAllUsesWith(trueResult);
1899  changed = true;
1900  }
1901  continue;
1902  }
1903 
1904  BoolAttr trueYield, falseYield;
1905  if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
1906  !matchPattern(falseResult, m_Constant(&falseYield)))
1907  continue;
1908 
1909  bool trueVal = trueYield.getValue();
1910  bool falseVal = falseYield.getValue();
1911  if (!trueVal && falseVal) {
1912  if (!opResult.use_empty()) {
1913  Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
1914  Value notCond = rewriter.create<arith::XOrIOp>(
1915  op.getLoc(), op.getCondition(),
1916  constDialect
1917  ->materializeConstant(rewriter,
1918  rewriter.getIntegerAttr(i1Ty, 1), i1Ty,
1919  op.getLoc())
1920  ->getResult(0));
1921  opResult.replaceAllUsesWith(notCond);
1922  changed = true;
1923  }
1924  }
1925  if (trueVal && !falseVal) {
1926  if (!opResult.use_empty()) {
1927  opResult.replaceAllUsesWith(op.getCondition());
1928  changed = true;
1929  }
1930  }
1931  }
1932  return success(changed);
1933  }
1934 };
1935 
1936 /// Merge any consecutive scf.if's with the same condition.
1937 ///
1938 /// scf.if %cond {
1939 /// firstCodeTrue();...
1940 /// } else {
1941 /// firstCodeFalse();...
1942 /// }
1943 /// %res = scf.if %cond {
1944 /// secondCodeTrue();...
1945 /// } else {
1946 /// secondCodeFalse();...
1947 /// }
1948 ///
1949 /// becomes
1950 /// %res = scf.if %cmp {
1951 /// firstCodeTrue();...
1952 /// secondCodeTrue();...
1953 /// } else {
1954 /// firstCodeFalse();...
1955 /// secondCodeFalse();...
1956 /// }
1957 struct CombineIfs : public OpRewritePattern<IfOp> {
1958  using OpRewritePattern<IfOp>::OpRewritePattern;
1959
1960  LogicalResult matchAndRewrite(IfOp nextIf,
1961  PatternRewriter &rewriter) const override {
1962  Block *parent = nextIf->getBlock();
1963  if (nextIf == &parent->front())
1964  return failure();
1965 
1966  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
1967  if (!prevIf)
1968  return failure();
1969 
1970  // Determine the logical then/else blocks when prevIf's
1971  // condition is used. Null means the block does not exist
1972  // in that case (e.g. empty else). If neither of these
1973  // are set, the two conditions cannot be compared.
1974  Block *nextThen = nullptr;
1975  Block *nextElse = nullptr;
1976  if (nextIf.getCondition() == prevIf.getCondition()) {
1977  nextThen = nextIf.thenBlock();
1978  if (!nextIf.getElseRegion().empty())
1979  nextElse = nextIf.elseBlock();
1980  }
1981  if (arith::XOrIOp notv =
1982  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
1983  if (notv.getLhs() == prevIf.getCondition() &&
1984  matchPattern(notv.getRhs(), m_One())) {
1985  nextElse = nextIf.thenBlock();
1986  if (!nextIf.getElseRegion().empty())
1987  nextThen = nextIf.elseBlock();
1988  }
1989  }
1990  if (arith::XOrIOp notv =
1991  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
1992  if (notv.getLhs() == nextIf.getCondition() &&
1993  matchPattern(notv.getRhs(), m_One())) {
1994  nextElse = nextIf.thenBlock();
1995  if (!nextIf.getElseRegion().empty())
1996  nextThen = nextIf.elseBlock();
1997  }
1998  }
1999 
2000  if (!nextThen && !nextElse)
2001  return failure();
2002 
2003  SmallVector<Value> prevElseYielded;
2004  if (!prevIf.getElseRegion().empty())
2005  prevElseYielded = prevIf.elseYield().getOperands();
2006  // Replace all uses of return values of op within nextIf with the
2007  // corresponding yields
2008  for (auto it : llvm::zip(prevIf.getResults(),
2009  prevIf.thenYield().getOperands(), prevElseYielded))
2010  for (OpOperand &use :
2011  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
2012  if (nextThen && nextThen->getParent()->isAncestor(
2013  use.getOwner()->getParentRegion())) {
2014  rewriter.startRootUpdate(use.getOwner());
2015  use.set(std::get<1>(it));
2016  rewriter.finalizeRootUpdate(use.getOwner());
2017  } else if (nextElse && nextElse->getParent()->isAncestor(
2018  use.getOwner()->getParentRegion())) {
2019  rewriter.startRootUpdate(use.getOwner());
2020  use.set(std::get<2>(it));
2021  rewriter.finalizeRootUpdate(use.getOwner());
2022  }
2023  }
2024 
2025  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
2026  llvm::append_range(mergedTypes, nextIf.getResultTypes());
2027 
2028  IfOp combinedIf = rewriter.create<IfOp>(
2029  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
2030  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
2031 
2032  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
2033  combinedIf.getThenRegion(),
2034  combinedIf.getThenRegion().begin());
2035 
2036  if (nextThen) {
2037  YieldOp thenYield = combinedIf.thenYield();
2038  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
2039  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
2040  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
2041 
2042  SmallVector<Value> mergedYields(thenYield.getOperands());
2043  llvm::append_range(mergedYields, thenYield2.getOperands());
2044  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
2045  rewriter.eraseOp(thenYield);
2046  rewriter.eraseOp(thenYield2);
2047  }
2048 
2049  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
2050  combinedIf.getElseRegion(),
2051  combinedIf.getElseRegion().begin());
2052 
2053  if (nextElse) {
2054  if (combinedIf.getElseRegion().empty()) {
2055  rewriter.inlineRegionBefore(*nextElse->getParent(),
2056  combinedIf.getElseRegion(),
2057  combinedIf.getElseRegion().begin());
2058  } else {
2059  YieldOp elseYield = combinedIf.elseYield();
2060  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
2061  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
2062 
2063  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
2064 
2065  SmallVector<Value> mergedElseYields(elseYield.getOperands());
2066  llvm::append_range(mergedElseYields, elseYield2.getOperands());
2067 
2068  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
2069  rewriter.eraseOp(elseYield);
2070  rewriter.eraseOp(elseYield2);
2071  }
2072  }
2073 
2074  SmallVector<Value> prevValues;
2075  SmallVector<Value> nextValues;
2076  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
2077  if (pair.index() < prevIf.getNumResults())
2078  prevValues.push_back(pair.value());
2079  else
2080  nextValues.push_back(pair.value());
2081  }
2082  rewriter.replaceOp(prevIf, prevValues);
2083  rewriter.replaceOp(nextIf, nextValues);
2084  return success();
2085  }
2086 };
2087 
2088 /// Pattern to remove an empty else branch.
2089 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
2090  using OpRewritePattern<IfOp>::OpRewritePattern;
2091
2092  LogicalResult matchAndRewrite(IfOp ifOp,
2093  PatternRewriter &rewriter) const override {
2094  // Cannot remove else region when there are operation results.
2095  if (ifOp.getNumResults())
2096  return failure();
2097  Block *elseBlock = ifOp.elseBlock();
2098  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
2099  return failure();
2100  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
2101  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
2102  newIfOp.getThenRegion().begin());
2103  rewriter.eraseOp(ifOp);
2104  return success();
2105  }
2106 };
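// Illustrative effect of RemoveEmptyElseBranch (sketch):
//
//   scf.if %c {
//     "some.op"() : () -> ()
//   } else {
//   }
//
// becomes:
//
//   scf.if %c {
//     "some.op"() : () -> ()
//   }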
2107 
2108 /// Convert nested `if`s into `arith.andi` + single `if`.
2109 ///
2110 /// scf.if %arg0 {
2111 /// scf.if %arg1 {
2112 /// ...
2113 /// scf.yield
2114 /// }
2115 /// scf.yield
2116 /// }
2117 /// becomes
2118 ///
2119 /// %0 = arith.andi %arg0, %arg1
2120 /// scf.if %0 {
2121 /// ...
2122 /// scf.yield
2123 /// }
2124 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
2125  using OpRewritePattern<IfOp>::OpRewritePattern;
2126
2127  LogicalResult matchAndRewrite(IfOp op,
2128  PatternRewriter &rewriter) const override {
2129  auto nestedOps = op.thenBlock()->without_terminator();
2130  // Nested `if` must be the only op in block.
2131  if (!llvm::hasSingleElement(nestedOps))
2132  return failure();
2133 
2134  // If there is an else block, it can only yield
2135  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
2136  return failure();
2137 
2138  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
2139  if (!nestedIf)
2140  return failure();
2141 
2142  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
2143  return failure();
2144 
2145  SmallVector<Value> thenYield(op.thenYield().getOperands());
2146  SmallVector<Value> elseYield;
2147  if (op.elseBlock())
2148  llvm::append_range(elseYield, op.elseYield().getOperands());
2149 
2150  // A list of indices for which we should upgrade the value yielded
2151  // in the else to a select.
2152  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2153 
2154  // If the outer scf.if yields a value produced by the inner scf.if,
2155  // only permit combining if the value yielded when the condition
2156  // is false in the outer scf.if is the same value yielded when the
2157  // inner scf.if condition is false.
2158  // Note that the array access to elseYield will not go out of bounds
2159  // since it must have the same length as thenYield: both are yields of
2160  // the same scf.if.
2161  for (const auto &tup : llvm::enumerate(thenYield)) {
2162  if (tup.value().getDefiningOp() == nestedIf) {
2163  auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
2164  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2165  elseYield[tup.index()]) {
2166  return failure();
2167  }
2168  // If the correctness test passes, we will yield
2169  // corresponding value from the inner scf.if
2170  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2171  continue;
2172  }
2173 
2174  // Otherwise, we need to ensure the else block of the combined
2175  // condition still returns the same value when the outer condition is
2176  // true and the inner condition is false. This can be accomplished if
2177  // the then value is defined outside the outer scf.if and we replace the
2178  // value with a select that considers just the outer condition. Since
2179  // the else region contains just the yield, its yielded value is
2180  // defined outside the scf.if, by definition.
2181 
2182  // If the then value is defined within the scf.if, bail.
2183  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2184  return failure();
2185  }
2186  elseYieldsToUpgradeToSelect.push_back(tup.index());
2187  }
2188 
2189  Location loc = op.getLoc();
2190  Value newCondition = rewriter.create<arith::AndIOp>(
2191  loc, op.getCondition(), nestedIf.getCondition());
2192  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2193 
2194  SmallVector<Value> results;
2195  llvm::append_range(results, newIf.getResults());
2196  rewriter.setInsertionPoint(newIf);
2197 
2198  for (auto idx : elseYieldsToUpgradeToSelect)
2199  results[idx] = rewriter.create<arith::SelectOp>(
2200  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2201 
2202  Block *newIfBlock = newIf.thenBlock();
2203  if (newIfBlock)
2204  rewriter.eraseOp(newIfBlock->getTerminator());
2205  else
2206  newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2207  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2208  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2209  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2210  if (!elseYield.empty()) {
2211  rewriter.createBlock(&newIf.getElseRegion());
2212  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2213  rewriter.create<YieldOp>(loc, elseYield);
2214  }
2215  rewriter.replaceOp(op, results);
2216  return success();
2217  }
2218 };
2219 
2220 } // namespace
2221 
2222 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2223  MLIRContext *context) {
2224  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2225  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2226  RemoveStaticCondition, RemoveUnusedResults,
2227  ReplaceIfYieldWithConditionOrValue>(context);
2228 }
2229 
2230 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2231 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2232 Block *IfOp::elseBlock() {
2233  Region &r = getElseRegion();
2234  if (r.empty())
2235  return nullptr;
2236  return &r.back();
2237 }
2238 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2239 
2240 //===----------------------------------------------------------------------===//
2241 // ParallelOp
2242 //===----------------------------------------------------------------------===//
2243 
2244 void ParallelOp::build(
2245  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2246  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2247  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
2248  bodyBuilderFn) {
2249  result.addOperands(lowerBounds);
2250  result.addOperands(upperBounds);
2251  result.addOperands(steps);
2252  result.addOperands(initVals);
2253  result.addAttribute(
2254  ParallelOp::getOperandSegmentSizeAttr(),
2255  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2256  static_cast<int32_t>(upperBounds.size()),
2257  static_cast<int32_t>(steps.size()),
2258  static_cast<int32_t>(initVals.size())}));
2259  result.addTypes(initVals.getTypes());
2260 
2261  OpBuilder::InsertionGuard guard(builder);
2262  unsigned numIVs = steps.size();
2263  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2264  SmallVector<Location, 8> argLocs(numIVs, result.location);
2265  Region *bodyRegion = result.addRegion();
2266  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2267 
2268  if (bodyBuilderFn) {
2269  builder.setInsertionPointToStart(bodyBlock);
2270  bodyBuilderFn(builder, result.location,
2271  bodyBlock->getArguments().take_front(numIVs),
2272  bodyBlock->getArguments().drop_front(numIVs));
2273  }
2274  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2275 }
2276 
2277 void ParallelOp::build(
2278  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2279  ValueRange upperBounds, ValueRange steps,
2280  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2281  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2282  // we don't capture a reference to a temporary by constructing the lambda at
2283  // function level.
2284  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2285  Location nestedLoc, ValueRange ivs,
2286  ValueRange) {
2287  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2288  };
2289  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
2290  if (bodyBuilderFn)
2291  wrapper = wrappedBuilderFn;
2292 
2293  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2294  wrapper);
2295 }
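// Illustrative usage sketch of the builder above. Assumes index-typed Values
// `lb`, `ub` and `step`, plus an OpBuilder `builder` and Location `loc`, exist
// in the caller (hypothetical names):
//
//   builder.create<scf::ParallelOp>(
//       loc, ValueRange{lb}, ValueRange{ub}, ValueRange{step},
//       [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
//         // Emit the loop body here, using ivs[0] as the induction variable.
//       });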
2296 
2297 LogicalResult ParallelOp::verify() {
2298  // Check that there is at least one value in lowerBound, upperBound and step.
2299  // It is sufficient to test only step, because it is already ensured that
2300  // lowerBound, upperBound and step have the same number of elements.
2301  Operation::operand_range stepValues = getStep();
2302  if (stepValues.empty())
2303  return emitOpError(
2304  "needs at least one tuple element for lowerBound, upperBound and step");
2305 
2306  // Check whether all constant step values are positive.
2307  for (Value stepValue : stepValues)
2308  if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
2309  if (cst.value() <= 0)
2310  return emitOpError("constant step operand must be positive");
2311 
2312  // Check that the body defines the same number of block arguments as the
2313  // number of tuple elements in step.
2314  Block *body = getBody();
2315  if (body->getNumArguments() != stepValues.size())
2316  return emitOpError() << "expects the same number of induction variables: "
2317  << body->getNumArguments()
2318  << " as bound and step values: " << stepValues.size();
2319  for (auto arg : body->getArguments())
2320  if (!arg.getType().isIndex())
2321  return emitOpError(
2322  "expects arguments for the induction variable to be of index type");
2323 
2324  // Check that the yield has no results
2325  Operation *yield = body->getTerminator();
2326  if (yield->getNumOperands() != 0)
2327  return yield->emitOpError() << "not allowed to have operands inside '"
2328  << ParallelOp::getOperationName() << "'";
2329 
2330  // Check that the number of results is the same as the number of ReduceOps.
2331  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
2332  auto resultsSize = getResults().size();
2333  auto reductionsSize = reductions.size();
2334  auto initValsSize = getInitVals().size();
2335  if (resultsSize != reductionsSize)
2336  return emitOpError() << "expects number of results: " << resultsSize
2337  << " to be the same as number of reductions: "
2338  << reductionsSize;
2339  if (resultsSize != initValsSize)
2340  return emitOpError() << "expects number of results: " << resultsSize
2341  << " to be the same as number of initial values: "
2342  << initValsSize;
2343 
2344  // Check that the types of the results and reductions are the same.
2345  for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2346  auto resultType = std::get<0>(resultAndReduce).getType();
2347  auto reduceOp = std::get<1>(resultAndReduce);
2348  auto reduceType = reduceOp.getOperand().getType();
2349  if (resultType != reduceType)
2350  return reduceOp.emitOpError()
2351  << "expects type of reduce: " << reduceType
2352  << " to be the same as result type: " << resultType;
2353  }
2354  return success();
2355 }
2356 
2357 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2358  auto &builder = parser.getBuilder();
2359  // Parse an opening `(` followed by induction variables followed by `)`
2360  SmallVector<OpAsmParser::Argument, 4> ivs;
2361  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
2362  return failure();
2363 
2364  // Parse loop bounds.
2365  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
2366  if (parser.parseEqual() ||
2367  parser.parseOperandList(lower, ivs.size(),
2368  OpAsmParser::Delimiter::Paren) ||
2369  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2370  return failure();
2371 
2372  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
2373  if (parser.parseKeyword("to") ||
2374  parser.parseOperandList(upper, ivs.size(),
2375  OpAsmParser::Delimiter::Paren) ||
2376  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2377  return failure();
2378 
2379  // Parse step values.
2380  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
2381  if (parser.parseKeyword("step") ||
2382  parser.parseOperandList(steps, ivs.size(),
2383  OpAsmParser::Delimiter::Paren) ||
2384  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2385  return failure();
2386 
2387  // Parse init values.
2388  SmallVector<OpAsmParser::UnresolvedOperand, 4> initVals;
2389  if (succeeded(parser.parseOptionalKeyword("init"))) {
2390  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2391  return failure();
2392  }
2393 
2394  // Parse optional results in case there is a reduce.
2395  if (parser.parseOptionalArrowTypeList(result.types))
2396  return failure();
2397 
2398  // Now parse the body.
2399  Region *body = result.addRegion();
2400  for (auto &iv : ivs)
2401  iv.type = builder.getIndexType();
2402  if (parser.parseRegion(*body, ivs))
2403  return failure();
2404 
2405  // Set `operand_segment_sizes` attribute.
2406  result.addAttribute(
2407  ParallelOp::getOperandSegmentSizeAttr(),
2408  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2409  static_cast<int32_t>(upper.size()),
2410  static_cast<int32_t>(steps.size()),
2411  static_cast<int32_t>(initVals.size())}));
2412 
2413  // Parse attributes.
2414  if (parser.parseOptionalAttrDict(result.attributes) ||
2415  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2416  result.operands))
2417  return failure();
2418 
2419  // Add a terminator if none was parsed.
2420  ForOp::ensureTerminator(*body, builder, result.location);
2421  return success();
2422 }
2423 
2425  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2426  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2427  if (!getInitVals().empty())
2428  p << " init (" << getInitVals() << ")";
2429  p.printOptionalArrowTypeList(getResultTypes());
2430  p << ' ';
2431  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2432  p.printOptionalAttrDict(
2433  (*this)->getAttrs(),
2434  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2435 }
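// For reference, an illustrative example of the custom syntax handled by the
// parse/print methods above (value names are placeholders):
//
//   scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
//     ...
//     scf.yield
//   }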
2436 
2437 Region &ParallelOp::getLoopBody() { return getRegion(); }
2438 
2439 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
2440  auto ivArg = val.dyn_cast<BlockArgument>();
2441  if (!ivArg)
2442  return ParallelOp();
2443  assert(ivArg.getOwner() && "unlinked block argument");
2444  auto *containingOp = ivArg.getOwner()->getParentOp();
2445  return dyn_cast<ParallelOp>(containingOp);
2446 }
2447 
2448 namespace {
2449 // Collapse loop dimensions that perform a single iteration.
2450 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
2451  using OpRewritePattern<ParallelOp>::OpRewritePattern;
2452
2453  LogicalResult matchAndRewrite(ParallelOp op,
2454  PatternRewriter &rewriter) const override {
2455  BlockAndValueMapping mapping;
2456  // Compute new loop bounds that omit all single-iteration loop dimensions.
2457  SmallVector<Value, 2> newLowerBounds;
2458  SmallVector<Value, 2> newUpperBounds;
2459  SmallVector<Value, 2> newSteps;
2460  newLowerBounds.reserve(op.getLowerBound().size());
2461  newUpperBounds.reserve(op.getUpperBound().size());
2462  newSteps.reserve(op.getStep().size());
2463  for (auto [lowerBound, upperBound, step, iv] :
2464  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2465  op.getInductionVars())) {
2466  // Collect the statically known loop bounds.
2467  auto lowerBoundConstant =
2468  dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
2469  auto upperBoundConstant =
2470  dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
2471  auto stepConstant =
2472  dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
2473  // Replace the loop induction variable by the lower bound if the loop
2474  // performs a single iteration. Otherwise, copy the loop bounds.
2475  if (lowerBoundConstant && upperBoundConstant && stepConstant &&
2476  (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
2477  (upperBoundConstant.value() - lowerBoundConstant.value()) <=
2478  stepConstant.value()) {
2479  mapping.map(iv, lowerBound);
2480  } else {
2481  newLowerBounds.push_back(lowerBound);
2482  newUpperBounds.push_back(upperBound);
2483  newSteps.push_back(step);
2484  }
2485  }
2486  // Exit if none of the loop dimensions perform a single iteration.
2487  if (newLowerBounds.size() == op.getLowerBound().size())
2488  return failure();
2489 
2490  if (newLowerBounds.empty()) {
2491  // All of the loop dimensions perform a single iteration. Inline
2492  // loop body and nested ReduceOp's
2493  SmallVector<Value> results;
2494  results.reserve(op.getInitVals().size());
2495  for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2496  auto reduce = dyn_cast<ReduceOp>(bodyOp);
2497  if (!reduce) {
2498  rewriter.clone(bodyOp, mapping);
2499  continue;
2500  }
2501  Block &reduceBlock = reduce.getReductionOperator().front();
2502  auto initValIndex = results.size();
2503  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2504  mapping.map(reduceBlock.getArgument(1),
2505  mapping.lookupOrDefault(reduce.getOperand()));
2506  for (auto &reduceBodyOp : reduceBlock.without_terminator())
2507  rewriter.clone(reduceBodyOp, mapping);
2508 
2509  auto result = mapping.lookupOrDefault(
2510  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
2511  results.push_back(result);
2512  }
2513  rewriter.replaceOp(op, results);
2514  return success();
2515  }
2516  // Replace the parallel loop by lower-dimensional parallel loop.
2517  auto newOp =
2518  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2519  newSteps, op.getInitVals(), nullptr);
2520  // Clone the loop body and remap the block arguments of the collapsed loops
2521  // (inlining does not support a cancellable block argument mapping).
2522  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
2523  newOp.getRegion().begin(), mapping);
2524  rewriter.replaceOp(op, newOp.getResults());
2525  return success();
2526  }
2527 };
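// Illustrative effect of CollapseSingleIterationLoops (sketch): with %c0, %c1
// and %cN constant index values, the %i dimension below runs exactly once and
// is folded away:
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c1, %cN) step (%c1, %c1) { ... }
//
// becomes:
//
//   scf.parallel (%j) = (%c0) to (%cN) step (%c1) { ... }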
2528 
2529 /// Removes parallel loops in which at least one lower/upper bound pair consists
2530 /// of the same values - such loops have an empty iteration domain.
2531 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
2532  using OpRewritePattern<ParallelOp>::OpRewritePattern;
2533
2534  LogicalResult matchAndRewrite(ParallelOp op,
2535  PatternRewriter &rewriter) const override {
2536  for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2537  if (std::get<0>(dim) == std::get<1>(dim)) {
2538  rewriter.replaceOp(op, op.getInitVals());
2539  return success();
2540  }
2541  }
2542  return failure();
2543  }
2544 };
2545 
2546 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
2547  using OpRewritePattern<ParallelOp>::OpRewritePattern;
2548
2549  LogicalResult matchAndRewrite(ParallelOp op,
2550  PatternRewriter &rewriter) const override {
2551  Block &outerBody = op.getLoopBody().front();
2552  if (!llvm::hasSingleElement(outerBody.without_terminator()))
2553  return failure();
2554 
2555  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
2556  if (!innerOp)
2557  return failure();
2558 
2559  for (auto val : outerBody.getArguments())
2560  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2561  llvm::is_contained(innerOp.getUpperBound(), val) ||
2562  llvm::is_contained(innerOp.getStep(), val))
2563  return failure();
2564 
2565  // Reductions are not supported yet.
2566  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2567  return failure();
2568 
2569  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2570  ValueRange iterVals, ValueRange) {
2571  Block &innerBody = innerOp.getLoopBody().front();
2572  assert(iterVals.size() ==
2573  (outerBody.getNumArguments() + innerBody.getNumArguments()));
2574  BlockAndValueMapping mapping;
2575  mapping.map(outerBody.getArguments(),
2576  iterVals.take_front(outerBody.getNumArguments()));
2577  mapping.map(innerBody.getArguments(),
2578  iterVals.take_back(innerBody.getNumArguments()));
2579  for (Operation &op : innerBody.without_terminator())
2580  builder.clone(op, mapping);
2581  };
2582 
2583  auto concatValues = [](const auto &first, const auto &second) {
2584  SmallVector<Value> ret;
2585  ret.reserve(first.size() + second.size());
2586  ret.assign(first.begin(), first.end());
2587  ret.append(second.begin(), second.end());
2588  return ret;
2589  };
2590 
2591  auto newLowerBounds =
2592  concatValues(op.getLowerBound(), innerOp.getLowerBound());
2593  auto newUpperBounds =
2594  concatValues(op.getUpperBound(), innerOp.getUpperBound());
2595  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2596 
2597  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2598  newSteps, llvm::None, bodyBuilder);
2599  return success();
2600  }
2601 };
2602 
2603 } // namespace
2604 
2605 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2606  MLIRContext *context) {
2607  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2608  MergeNestedParallelLoops>(context);
2609 }
2610 
2611 //===----------------------------------------------------------------------===//
2612 // ReduceOp
2613 //===----------------------------------------------------------------------===//
2614 
2615 void ReduceOp::build(
2616  OpBuilder &builder, OperationState &result, Value operand,
2617  function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2618  auto type = operand.getType();
2619  result.addOperands(operand);
2620 
2621  OpBuilder::InsertionGuard guard(builder);
2622  Region *bodyRegion = result.addRegion();
2623  Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2624  {result.location, result.location});
2625  if (bodyBuilderFn)
2626  bodyBuilderFn(builder, result.location, body->getArgument(0),
2627  body->getArgument(1));
2628 }
2629 
2630 LogicalResult ReduceOp::verifyRegions() {
2631  // The region of a ReduceOp has two arguments of the same type as its operand.
2632  auto type = getOperand().getType();
2633  Block &block = getReductionOperator().front();
2634  if (block.empty())
2635  return emitOpError("the block inside reduce should not be empty");
2636  if (block.getNumArguments() != 2 ||
2637  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2638  return arg.getType() != type;
2639  }))
2640  return emitOpError() << "expects two arguments to reduce block of type "
2641  << type;
2642 
2643  // Check that the block is terminated by a ReduceReturnOp.
2644  if (!isa<ReduceReturnOp>(block.getTerminator()))
2645  return emitOpError("the block inside reduce should be terminated with a "
2646  "'scf.reduce.return' op");
2647 
2648  return success();
2649 }
2650 
2651 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2652  // Parse an opening `(` followed by the reduced value followed by `)`
2653  OpAsmParser::UnresolvedOperand operand;
2654  if (parser.parseLParen() || parser.parseOperand(operand) ||
2655  parser.parseRParen())
2656  return failure();
2657 
2658  Type resultType;
2659  // Parse the type of the operand (and also what reduce computes on).
2660  if (parser.parseColonType(resultType) ||
2661  parser.resolveOperand(operand, resultType, result.operands))
2662  return failure();
2663 
2664  // Now parse the body.
2665  Region *body = result.addRegion();
2666  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2667  return failure();
2668 
2669  return success();
2670 }
2671 
2672 void ReduceOp::print(OpAsmPrinter &p) {
2673  p << "(" << getOperand() << ") ";
2674  p << " : " << getOperand().getType() << ' ';
2675  p.printRegion(getReductionOperator());
2676 }
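// For reference, an illustrative scf.reduce in the custom syntax handled
// above (an f32 sum reduction; value names are placeholders):
//
//   scf.reduce(%operand) : f32 {
//   ^bb0(%lhs: f32, %rhs: f32):
//     %sum = arith.addf %lhs, %rhs : f32
//     scf.reduce.return %sum : f32
//   }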
2677 
2678 //===----------------------------------------------------------------------===//
2679 // ReduceReturnOp
2680 //===----------------------------------------------------------------------===//
2681 
2682 LogicalResult ReduceReturnOp::verify() {
2683  // The type of the return value should be the same type as the type of the
2684  // operand of the enclosing ReduceOp.
2685  auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2686  Type reduceType = reduceOp.getOperand().getType();
2687  if (reduceType != getResult().getType())
2688  return emitOpError() << "needs to have type " << reduceType
2689  << " (the type of the enclosing ReduceOp)";
2690  return success();
2691 }
2692 
2693 //===----------------------------------------------------------------------===//
2694 // WhileOp
2695 //===----------------------------------------------------------------------===//
2696 
2697 void WhileOp::build(::mlir::OpBuilder &odsBuilder,
2698  ::mlir::OperationState &odsState, TypeRange resultTypes,
2699  ValueRange operands, BodyBuilderFn beforeBuilder,
2700  BodyBuilderFn afterBuilder) {
2701  assert(beforeBuilder && "the builder callback for 'before' must be present");
2702  assert(afterBuilder && "the builder callback for 'after' must be present");
2703 
2704  odsState.addOperands(operands);
2705  odsState.addTypes(resultTypes);
2706 
2707  OpBuilder::InsertionGuard guard(odsBuilder);
2708 
2709  SmallVector<Location, 4> blockArgLocs;
2710  for (Value operand : operands) {
2711  blockArgLocs.push_back(operand.getLoc());
2712  }
2713 
2714  Region *beforeRegion = odsState.addRegion();
2715  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
2716  resultTypes, blockArgLocs);
2717  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
2718 
2719  Region *afterRegion = odsState.addRegion();
2720  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
2721  resultTypes, blockArgLocs);
2722  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
2723 }
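// Illustrative usage sketch of the builder above. Assumes the caller already
// has an OpBuilder `builder`, a Location `loc` and an index-typed Value
// `init` (hypothetical names); the before-region computes the continuation
// condition and the after-region produces the next iteration's values:
//
//   auto whileOp = builder.create<scf::WhileOp>(
//       loc, TypeRange{builder.getIndexType()}, ValueRange{init},
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         Value cond = ...; // i1 condition computed from args
//         b.create<scf::ConditionOp>(l, cond, args);
//       },
//       [&](OpBuilder &b, Location l, ValueRange args) {
//         b.create<scf::YieldOp>(l, args);
//       });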
2724 
2725 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
2726  assert(index && *index == 0 &&
2727  "WhileOp is expected to branch only to the first region");
2728 
2729  return getInits();
2730 }
2731 
2732 ConditionOp WhileOp::getConditionOp() {
2733  return cast<ConditionOp>(getBefore().front().getTerminator());
2734 }
2735 
2736 YieldOp WhileOp::getYieldOp() {
2737  return cast<YieldOp>(getAfter().front().getTerminator());
2738 }
2739 
2740 Block::BlockArgListType WhileOp::getBeforeArguments() {
2741  return getBefore().front().getArguments();
2742 }
2743 
2744 Block::BlockArgListType WhileOp::getAfterArguments() {
2745  return getAfter().front().getArguments();
2746 }
2747 
2748 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2749  ArrayRef<Attribute> operands,
2750  SmallVectorImpl<RegionSuccessor> &regions) {
2751  // The parent op always branches to the condition region.
2752  if (!index) {
2753  regions.emplace_back(&getBefore(), getBefore().getArguments());
2754  return;
2755  }
2756 
2757  assert(*index < 2 && "there are only two regions in a WhileOp");
2758  // The body region always branches back to the condition region.
2759  if (*index == 1) {
2760  regions.emplace_back(&getBefore(), getBefore().getArguments());
2761  return;
2762  }
2763 
2764  // Try to narrow the successor to the condition region.
2765  assert(!operands.empty() && "expected at least one operand");
2766  auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
2767  if (!cond || !cond.getValue())
2768  regions.emplace_back(getResults());
2769  if (!cond || cond.getValue())
2770  regions.emplace_back(&getAfter(), getAfter().getArguments());
2771 }
2772 
2773 /// Parses a `while` op.
2774 ///
2775 /// op ::= `scf.while` assignments `:` function-type region `do` region
2776 /// `attributes` attribute-dict
2777 /// initializer ::= /* empty */ | `(` assignment-list `)`
2778 /// assignment-list ::= assignment | assignment `,` assignment-list
2779 /// assignment ::= ssa-value `=` ssa-value
2780 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2781  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2782  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2783  Region *before = result.addRegion();
2784  Region *after = result.addRegion();
2785 
2786  OptionalParseResult listResult =
2787  parser.parseOptionalAssignmentList(regionArgs, operands);
2788  if (listResult.has_value() && failed(listResult.value()))
2789  return failure();
2790 
2791  FunctionType functionType;
2792  SMLoc typeLoc = parser.getCurrentLocation();
2793  if (failed(parser.parseColonType(functionType)))
2794  return failure();
2795 
2796  result.addTypes(functionType.getResults());
2797 
2798  if (functionType.getNumInputs() != operands.size()) {
2799  return parser.emitError(typeLoc)
2800  << "expected as many input types as operands "
2801  << "(expected " << operands.size() << " got "
2802  << functionType.getNumInputs() << ")";
2803  }
2804 
2805  // Resolve input operands.
2806  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2807  parser.getCurrentLocation(),
2808  result.operands)))
2809  return failure();
2810 
2811  // Propagate the types into the region arguments.
2812  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2813  regionArgs[i].type = functionType.getInput(i);
2814 
2815  return failure(parser.parseRegion(*before, regionArgs) ||
2816  parser.parseKeyword("do") || parser.parseRegion(*after) ||
2817  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2818 }
2819 
2820 /// Prints a `while` op.
2821 void scf::WhileOp::print(OpAsmPrinter &p) {
2822  printInitializationList(p, getBefore().front().getArguments(), getInits(),
2823  " ");
2824  p << " : ";
2825  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
2826  p << ' ';
2827  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
2828  p << " do ";
2829  p.printRegion(getAfter());
2830  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2831 }
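// For reference, a concrete (illustrative) scf.while in the custom syntax
// described by the grammar above; "test.condition" and "test.next" are
// placeholder operations:
//
//   %res = scf.while (%arg = %init) : (i32) -> i32 {
//     %cond = "test.condition"(%arg) : (i32) -> i1
//     scf.condition(%cond) %arg : i32
//   } do {
//   ^bb0(%arg2: i32):
//     %next = "test.next"(%arg2) : (i32) -> i32
//     scf.yield %next : i32
//   }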
2832 
2833 /// Verifies that two ranges of types match, i.e. have the same number of
2834 /// entries and that types are pairwise equals. Reports errors on the given
2835 /// operation in case of mismatch.
2836 template <typename OpTy>
2837 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2838  TypeRange right, StringRef message) {
2839  if (left.size() != right.size())
2840  return op.emitOpError("expects the same number of ") << message;
2841 
2842  for (unsigned i = 0, e = left.size(); i < e; ++i) {
2843  if (left[i] != right[i]) {
2844  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2845  << message;
2846  diag.attachNote() << "for argument " << i << ", found " << left[i]
2847  << " and " << right[i];
2848  return diag;
2849  }
2850  }
2851 
2852  return success();
2853 }
2854 
2855 /// Verifies that the first block of the given `region` is terminated by a
2856 /// YieldOp. Reports errors on the given operation if it is not the case.
2857 template <typename TerminatorTy>
2858 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2859  StringRef errorMessage) {
2860  Operation *terminatorOperation = region.front().getTerminator();
2861  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2862  return yield;
2863 
2864  auto diag = op.emitOpError(errorMessage);
2865  if (terminatorOperation)
2866  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2867  return nullptr;
2868 }
2869 
2870 LogicalResult scf::WhileOp::verify() {
2871  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2872  *this, getBefore(),
2873  "expects the 'before' region to terminate with 'scf.condition'");
2874  if (!beforeTerminator)
2875  return failure();
2876 
2877  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2878  *this, getAfter(),
2879  "expects the 'after' region to terminate with 'scf.yield'");
2880  return success(afterTerminator != nullptr);
2881 }
2882 
2883 namespace {
2884 /// Replace uses of the condition within the do block with true, since otherwise
2885 /// the block would not be evaluated.
2886 ///
2887 /// scf.while (..) : (i1, ...) -> ... {
2888 /// %condition = call @evaluate_condition() : () -> i1
2889 /// scf.condition(%condition) %condition : i1, ...
2890 /// } do {
2891 /// ^bb0(%arg0: i1, ...):
2892 /// use(%arg0)
2893 /// ...
2894 ///
2895 /// becomes
2896 /// scf.while (..) : (i1, ...) -> ... {
2897 /// %condition = call @evaluate_condition() : () -> i1
2898 /// scf.condition(%condition) %condition : i1, ...
2899 /// } do {
2900 /// ^bb0(%arg0: i1, ...):
2901 /// use(%true)
2902 /// ...
2903 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2904  using OpRewritePattern<WhileOp>::OpRewritePattern;
2905
2906  LogicalResult matchAndRewrite(WhileOp op,
2907  PatternRewriter &rewriter) const override {
2908  auto term = op.getConditionOp();
2909 
2910  // These variables serve to prevent creating duplicate constants
2911  // and hold constant true or false values.
2912  Value constantTrue = nullptr;
2913 
2914  bool replaced = false;
2915  for (auto yieldedAndBlockArgs :
2916  llvm::zip(term.getArgs(), op.getAfterArguments())) {
2917  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2918  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2919  if (!constantTrue)
2920  constantTrue = rewriter.create<arith::ConstantOp>(
2921  op.getLoc(), term.getCondition().getType(),
2922  rewriter.getBoolAttr(true));
2923 
2924  std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2925  replaced = true;
2926  }
2927  }
2928  }
2929  return success(replaced);
2930  }
2931 };
2932 
2933 /// Remove loop invariant arguments from `before` block of scf.while.
2934 /// A before block argument is considered loop invariant if :-
2935 /// 1. i-th yield operand is equal to the i-th while operand.
2936 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
2937 /// condition operand AND this (k+1)-th condition operand is equal to i-th
2938 /// iter argument/while operand.
2939 /// For the arguments which are removed, their uses inside scf.while
2940 /// are replaced with their corresponding initial value.
2941 ///
2942 /// Eg:
2943 /// INPUT :-
2944 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
2945 /// ..., %argN_before = %N)
2946 /// {
2947 /// ...
2948 /// scf.condition(%cond) %arg1_before, %arg0_before,
2949 /// %arg2_before, %arg0_before, ...
2950 /// } do {
2951 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2952 /// ..., %argK_after):
2953 /// ...
2954 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
2955 /// }
2956 ///
2957 /// OUTPUT :-
2958 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
2959 /// %N)
2960 /// {
2961 /// ...
2962 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
2963 /// } do {
2964 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2965 /// ..., %argK_after):
2966 /// ...
2967 /// scf.yield %arg1_after, ..., %argN
2968 /// }
2969 ///
2970 /// EXPLANATION:
2971 /// We iterate over each yield operand.
2972 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
2973 /// %arg0_before, which in turn is the 0-th iter argument. So we
2974 /// remove 0-th before block argument and yield operand, and replace
2975 /// all uses of the 0-th before block argument with its initial value
2976 /// %a.
2977 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
2978 /// value. So we remove this operand and the corresponding before
2979 /// block argument and replace all uses of 1-th before block argument
2980 /// with %b.
2981 struct RemoveLoopInvariantArgsFromBeforeBlock
2982  : public OpRewritePattern<WhileOp> {
2983  using OpRewritePattern<WhileOp>::OpRewritePattern;
2984
2985  LogicalResult matchAndRewrite(WhileOp op,
2986  PatternRewriter &rewriter) const override {
2987  Block &afterBlock = op.getAfter().front();
2988  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
2989  ConditionOp condOp = op.getConditionOp();
2990  OperandRange condOpArgs = condOp.getArgs();
2991  Operation *yieldOp = afterBlock.getTerminator();
2992  ValueRange yieldOpArgs = yieldOp->getOperands();
2993 
2994  bool canSimplify = false;
2995  for (const auto &it :
2996  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2997  auto index = static_cast<unsigned>(it.index());
2998  auto [initVal, yieldOpArg] = it.value();
2999  // If i-th yield operand is equal to the i-th operand of the scf.while,
3000  // the i-th before block argument is a loop invariant.
3001  if (yieldOpArg == initVal) {
3002  canSimplify = true;
3003  break;
3004  }
3005  // If the i-th yield operand is the k-th after block argument, check
3006  // whether the (k+1)-th condition op operand is equal to either the i-th
3007  // before block argument or the initial value of the i-th before block
3008  // argument. If either comparison holds, the i-th before block argument
3009  // is a loop invariant.
3010  auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
3011  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3012  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3013  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3014  canSimplify = true;
3015  break;
3016  }
3017  }
3018  }
3019 
3020  if (!canSimplify)
3021  return failure();
3022 
3023  SmallVector<Value> newInitArgs, newYieldOpArgs;
3024  DenseMap<unsigned, Value> beforeBlockInitValMap;
3025  SmallVector<Location> newBeforeBlockArgLocs;
3026  for (const auto &it :
3027  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
3028  auto index = static_cast<unsigned>(it.index());
3029  auto [initVal, yieldOpArg] = it.value();
3030 
3031  // If i-th yield operand is equal to the i-th operand of the scf.while,
3032  // the i-th before block argument is a loop invariant.
3033  if (yieldOpArg == initVal) {
3034  beforeBlockInitValMap.insert({index, initVal});
3035  continue;
3036  } else {
3037  // If the i-th yield operand is the k-th after block argument, check
3038  // whether the (k+1)-th condition op operand is equal to either the i-th
3039  // before block argument or the initial value of the i-th before block
3040  // argument. If either comparison holds, the i-th before block
3041  // argument is a loop invariant.
3042  auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
3043  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
3044  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
3045  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
3046  beforeBlockInitValMap.insert({index, initVal});
3047  continue;
3048  }
3049  }
3050  }
3051  newInitArgs.emplace_back(initVal);
3052  newYieldOpArgs.emplace_back(yieldOpArg);
3053  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
3054  }
3055 
3056  {
3057  OpBuilder::InsertionGuard g(rewriter);
3058  rewriter.setInsertionPoint(yieldOp);
3059  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
3060  }
3061 
3062  auto newWhile =
3063  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
3064 
3065  Block &newBeforeBlock = *rewriter.createBlock(
3066  &newWhile.getBefore(), /*insertPt*/ {},
3067  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
3068 
3069  Block &beforeBlock = op.getBefore().front();
3070  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
3071  // For each i-th before block argument, find its replacement value:
3072  // 1. If the i-th before block argument is a loop invariant, fetch its
3073  // initial value from `beforeBlockInitValMap` using key `i`.
3074  // 2. Otherwise, use the j-th new before block argument as the
3075  // replacement value of the i-th before block argument.
3076  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
3077  // If the index 'i' argument was a loop invariant, fetch its initial
3078  // value from `beforeBlockInitValMap`.
3079  if (beforeBlockInitValMap.count(i) != 0)
3080  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
3081  else
3082  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
3083  }
3084 
3085  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
3086  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
3087  newWhile.getAfter().begin());
3088 
3089  rewriter.replaceOp(op, newWhile.getResults());
3090  return success();
3091  }
3092 };
3093 
3094 /// Remove loop invariant values from the results (scf.condition operands)
3095 /// of scf.while. A value is considered loop invariant if it is passed to
3096 /// scf.condition but defined outside of the `before` block. We remove the
3097 /// corresponding `after` block argument and replace its uses with the value,
3098 /// and we also replace the uses of the corresponding scf.while result with
3099 /// the value.
3100 ///
3101 /// Eg:
3102 /// INPUT :-
3103 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
3104 /// %argN_before = %N) {
3105 /// ...
3106 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
3107 /// } do {
3108 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
3109 /// ...
3110 /// some_func(%arg1_after)
3111 /// ...
3112 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
3113 /// }
3114 ///
3115 /// OUTPUT :-
3116 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
3117 /// ...
3118 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
3119 /// } do {
3120 /// ^bb0(%arg0, %arg3, ..., %argM):
3121 /// ...
3122 /// some_func(%a)
3123 /// ...
3124 /// scf.yield %arg0, %b, ..., %argN
3125 /// }
3126 ///
3127 /// EXPLANATION:
3128 /// 1. The 1st and 2nd operands of scf.condition are defined outside the
3129 /// before block of scf.while, so they get removed.
3130 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
3131 /// replaced by %b.
3132 /// 3. The corresponding after block argument %arg1_after's uses are
3133 /// replaced by %a and %arg2_after's uses are replaced by %b.
3134 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
3135   using OpRewritePattern<WhileOp>::OpRewritePattern;
3136 
3137  LogicalResult matchAndRewrite(WhileOp op,
3138  PatternRewriter &rewriter) const override {
3139  Block &beforeBlock = op.getBefore().front();
3140  ConditionOp condOp = op.getConditionOp();
3141  OperandRange condOpArgs = condOp.getArgs();
3142 
3143  bool canSimplify = false;
3144  for (Value condOpArg : condOpArgs) {
3145  // A scf.condition operand defined outside the `before` block is a loop
3146  // invariant value; finding at least one such operand means the op can
3147  // be simplified.
3148  if (condOpArg.getParentBlock() != &beforeBlock) {
3149  canSimplify = true;
3150  break;
3151  }
3152  }
3153 
3154  if (!canSimplify)
3155  return failure();
3156 
3157  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
3158 
3159  SmallVector<Value> newCondOpArgs;
3160  SmallVector<Type> newAfterBlockType;
3161  DenseMap<unsigned, Value> condOpInitValMap;
3162  SmallVector<Location> newAfterBlockArgLocs;
3163  for (const auto &it : llvm::enumerate(condOpArgs)) {
3164  auto index = static_cast<unsigned>(it.index());
3165  Value condOpArg = it.value();
3166  // Values defined outside the `before` block are loop invariant; map
3167  // their `index` to the value so it can later replace the matching
3168  // `after` block argument and scf.while result.
3169  if (condOpArg.getParentBlock() != &beforeBlock) {
3170  condOpInitValMap.insert({index, condOpArg});
3171  } else {
3172  newCondOpArgs.emplace_back(condOpArg);
3173  newAfterBlockType.emplace_back(condOpArg.getType());
3174  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3175  }
3176  }
3177 
3178  {
3179  OpBuilder::InsertionGuard g(rewriter);
3180  rewriter.setInsertionPoint(condOp);
3181  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3182  newCondOpArgs);
3183  }
3184 
3185  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3186  op.getOperands());
3187 
3188  Block &newAfterBlock =
3189  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3190  newAfterBlockType, newAfterBlockArgLocs);
3191 
3192  Block &afterBlock = op.getAfter().front();
3193  // Since a new scf.condition op was created, we need to fetch the new
3194  // `after` block arguments that will be used when merging the previous
3195  // scf.while's `after` block. We also compute the new result values that
3196  // replace the old results.
3197  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3198  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3199  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3200  Value afterBlockArg, result;
3201  // If the index 'i' argument was loop invariant, fetch its value from
3202  // the `condOpInitValMap` map.
3203  if (condOpInitValMap.count(i) != 0) {
3204  afterBlockArg = condOpInitValMap[i];
3205  result = afterBlockArg;
3206  } else {
3207  afterBlockArg = newAfterBlock.getArgument(j);
3208  result = newWhile.getResult(j);
3209  j++;
3210  }
3211  newAfterBlockArgs[i] = afterBlockArg;
3212  newWhileResults[i] = result;
3213  }
3214 
3215  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3216  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3217  newWhile.getBefore().begin());
3218 
3219  rewriter.replaceOp(op, newWhileResults);
3220  return success();
3221  }
3222 };
3223 
3224 /// Remove WhileOp results that are unused and whose corresponding 'after' block argument is also unused.
3225 ///
3226 /// %0:2 = scf.while () : () -> (i32, i64) {
3227 /// %condition = "test.condition"() : () -> i1
3228 /// %v1 = "test.get_some_value"() : () -> i32
3229 /// %v2 = "test.get_some_value"() : () -> i64
3230 /// scf.condition(%condition) %v1, %v2 : i32, i64
3231 /// } do {
3232 /// ^bb0(%arg0: i32, %arg1: i64):
3233 /// "test.use"(%arg0) : (i32) -> ()
3234 /// scf.yield
3235 /// }
3236 /// return %0#0 : i32
3237 ///
3238 /// becomes
3239 /// %0 = scf.while () : () -> (i32) {
3240 /// %condition = "test.condition"() : () -> i1
3241 /// %v1 = "test.get_some_value"() : () -> i32
3242 /// %v2 = "test.get_some_value"() : () -> i64
3243 /// scf.condition(%condition) %v1 : i32
3244 /// } do {
3245 /// ^bb0(%arg0: i32):
3246 /// "test.use"(%arg0) : (i32) -> ()
3247 /// scf.yield
3248 /// }
3249 /// return %0 : i32
3250 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3251   using OpRewritePattern<WhileOp>::OpRewritePattern;
3252 
3253  LogicalResult matchAndRewrite(WhileOp op,
3254  PatternRewriter &rewriter) const override {
3255  auto term = op.getConditionOp();
3256  auto afterArgs = op.getAfterArguments();
3257  auto termArgs = term.getArgs();
3258 
3259  // Collect results mapping, new terminator args and new result types.
3260  SmallVector<unsigned> newResultsIndices;
3261  SmallVector<Type> newResultTypes;
3262  SmallVector<Value> newTermArgs;
3263  SmallVector<Location> newArgLocs;
3264  bool needUpdate = false;
3265  for (const auto &it :
3266  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3267  auto i = static_cast<unsigned>(it.index());
3268  Value result = std::get<0>(it.value());
3269  Value afterArg = std::get<1>(it.value());
3270  Value termArg = std::get<2>(it.value());
3271  if (result.use_empty() && afterArg.use_empty()) {
3272  needUpdate = true;
3273  } else {
3274  newResultsIndices.emplace_back(i);
3275  newTermArgs.emplace_back(termArg);
3276  newResultTypes.emplace_back(result.getType());
3277  newArgLocs.emplace_back(result.getLoc());
3278  }
3279  }
3280 
3281  if (!needUpdate)
3282  return failure();
3283 
3284  {
3285  OpBuilder::InsertionGuard g(rewriter);
3286  rewriter.setInsertionPoint(term);
3287  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3288  newTermArgs);
3289  }
3290 
3291  auto newWhile =
3292  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3293 
3294  Block &newAfterBlock = *rewriter.createBlock(
3295  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3296 
3297  // Build new results list and new after block args (unused entries will be
3298  // null).
3299  SmallVector<Value> newResults(op.getNumResults());
3300  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3301  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3302  newResults[it.value()] = newWhile.getResult(it.index());
3303  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3304  }
3305 
3306  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3307  newWhile.getBefore().begin());
3308 
3309  Block &afterBlock = op.getAfter().front();
3310  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3311 
3312  rewriter.replaceOp(op, newResults);
3313  return success();
3314  }
3315 };
3316 
3317 /// Replace operations equivalent to the condition in the do block with true,
3318 /// since otherwise the block would not be evaluated.
3319 ///
3320 /// scf.while (..) : (i32, ...) -> ... {
3321 /// %z = ... : i32
3322 /// %condition = cmpi pred %z, %a
3323 /// scf.condition(%condition) %z : i32, ...
3324 /// } do {
3325 /// ^bb0(%arg0: i32, ...):
3326 /// %condition2 = cmpi pred %arg0, %a
3327 /// use(%condition2)
3328 /// ...
3329 ///
3330 /// becomes
3331 /// scf.while (..) : (i32, ...) -> ... {
3332 /// %z = ... : i32
3333 /// %condition = cmpi pred %z, %a
3334 /// scf.condition(%condition) %z : i32, ...
3335 /// } do {
3336 /// ^bb0(%arg0: i32, ...):
3337 /// use(%true)
3338 /// ...
3339 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3340   using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
3341 
3342  LogicalResult matchAndRewrite(scf::WhileOp op,
3343  PatternRewriter &rewriter) const override {
3344  using namespace scf;
3345  auto cond = op.getConditionOp();
3346  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3347  if (!cmp)
3348  return failure();
3349  bool changed = false;
3350  for (auto tup :
3351  llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3352  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3353  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3354  continue;
3355  for (OpOperand &u :
3356  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3357  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3358  if (!cmp2)
3359  continue;
3360  // For a binary operator, 1 - opIdx selects the other operand.
3361  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3362  continue;
3363  bool samePredicate;
3364  if (cmp2.getPredicate() == cmp.getPredicate())
3365  samePredicate = true;
3366  else if (cmp2.getPredicate() ==
3367  arith::invertPredicate(cmp.getPredicate()))
3368  samePredicate = false;
3369  else
3370  continue;
3371 
3372  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3373  1);
3374  changed = true;
3375  }
3376  }
3377  }
3378  return success(changed);
3379  }
3380 };
3381 
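/// Erase `before` block arguments of scf.while that have no uses, together
/// with their corresponding init and yield operands. An illustrative sketch;
/// the test.* ops and value names below are hypothetical:
///
///  %0 = scf.while (%arg0 = %init0, %arg1 = %init1) : (i32, i64) -> i32 {
///    %cond = "test.condition"(%arg0) : (i32) -> i1
///    scf.condition(%cond) %arg0 : i32
///  } do {
///  ^bb0(%a: i32):
///    %next = "test.next"(%a) : (i32) -> i32
///    %other = "test.other"() : () -> i64
///    scf.yield %next, %other : i32, i64
///  }
///
/// %arg1 has no uses in the `before` block, so it is erased and %init1 and
/// %other are dropped from the init and yield operand lists.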
3382 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
3383   using OpRewritePattern<WhileOp>::OpRewritePattern;
3384 
3385  LogicalResult matchAndRewrite(WhileOp op,
3386  PatternRewriter &rewriter) const override {
3387 
3388  if (!llvm::any_of(op.getBeforeArguments(),
3389  [](Value arg) { return arg.use_empty(); }))
3390  return failure();
3391 
3392  YieldOp yield = op.getYieldOp();
3393 
3394  // Collect new yield and init values for the used arguments; mark unused ones for erasure.
3395  SmallVector<Value> newYields;
3396  SmallVector<Value> newInits;
3397  llvm::BitVector argsToErase(op.getBeforeArguments().size());
3398  for (const auto &it : llvm::enumerate(llvm::zip(
3399  op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3400  Value beforeArg = std::get<0>(it.value());
3401  Value yieldValue = std::get<1>(it.value());
3402  Value initValue = std::get<2>(it.value());
3403  if (beforeArg.use_empty()) {
3404  argsToErase.set(it.index());
3405  } else {
3406  newYields.emplace_back(yieldValue);
3407  newInits.emplace_back(initValue);
3408  }
3409  }
3410 
3411  if (argsToErase.none())
3412  return failure();
3413 
3414  rewriter.startRootUpdate(op);
3415  op.getBefore().front().eraseArguments(argsToErase);
3416  rewriter.finalizeRootUpdate(op);
3417 
3418  WhileOp replacement =
3419  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3420  replacement.getBefore().takeBody(op.getBefore());
3421  replacement.getAfter().takeBody(op.getAfter());
3422  rewriter.replaceOp(op, replacement.getResults());
3423 
3424  rewriter.setInsertionPoint(yield);
3425  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3426  return success();
3427  }
3428 };
3429 } // namespace
3430 
3431 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3432  MLIRContext *context) {
3433  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3434  RemoveLoopInvariantValueYielded, WhileConditionTruth,
3435  WhileCmpCond, WhileUnusedResult>(context);
3436 }
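// A minimal usage sketch: these patterns are normally picked up by the
// generic canonicalizer pass, but a transform that wants only the scf.while
// canonicalizations could populate and run them directly (the enclosing `op`
// here is an assumption of the sketch):
//
//   RewritePatternSet patterns(op->getContext());
//   scf::WhileOp::getCanonicalizationPatterns(patterns, op->getContext());
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
// applyPatternsAndFoldGreedily is declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h.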
3437 
3438 //===----------------------------------------------------------------------===//
3439 // IndexSwitchOp
3440 //===----------------------------------------------------------------------===//
3441 
3442 /// Parse the case regions and values.
3443 static ParseResult
3444 parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
3445  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
3446  SmallVector<int64_t> caseValues;
3447  while (succeeded(p.parseOptionalKeyword("case"))) {
3448  int64_t value;
3449  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
3450  if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
3451  return failure();
3452  caseValues.push_back(value);
3453  }
3454  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
3455  return success();
3456 }
3457 
3458 /// Print the case regions and values.
3459 static void printSwitchCases(OpAsmPrinter &p, Operation *op,
3460  DenseI64ArrayAttr cases, RegionRange caseRegions) {
3461  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
3462  p.printNewline();
3463  p << "case " << value << ' ';
3464  p.printRegion(*region, /*printEntryBlockArgs=*/false);
3465  }
3466 }
3467 
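// For reference, a sketch of the textual form these case parse/print hooks
// handle (the switch operand and constants are hypothetical):
//
//   %0 = scf.index_switch %index -> i32
//   case 2 {
//     %1 = arith.constant 10 : i32
//     scf.yield %1 : i32
//   }
//   default {
//     %2 = arith.constant 50 : i32
//     scf.yield %2 : i32
//   }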
3468 LogicalResult scf::IndexSwitchOp::verify() {
3469  if (getCases().size() != getCaseRegions().size()) {
3470  return emitOpError("has ")
3471  << getCaseRegions().size() << " case regions but "
3472  << getCases().size() << " case values";
3473  }
3474 
3475  DenseSet<int64_t> valueSet;
3476  for (int64_t value : getCases())
3477  if (!valueSet.insert(value).second)
3478  return emitOpError("has duplicate case value: ") << value;
3479 
3480  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
3481  auto yield = cast<YieldOp>(region.front().getTerminator());
3482  if (yield.getNumOperands() != getNumResults()) {
3483  return (emitOpError("expected each region to return ")
3484  << getNumResults() << " values, but " << name << " returns "
3485  << yield.getNumOperands())
3486  .attachNote(yield.getLoc())
3487  << "see yield operation here";
3488  }
3489  for (auto [idx, result, operand] :
3490  llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
3491  yield.getOperandTypes())) {
3492  if (result == operand)
3493  continue;
3494  return (emitOpError("expected result #")
3495  << idx << " of each region to be " << result)
3496  .attachNote(yield.getLoc())
3497  << name << " returns " << operand << " here";
3498  }
3499  return success();
3500  };
3501 
3502  if (failed(verifyRegion(getDefaultRegion(), "default region")))
3503  return failure();
3504  for (auto &[idx, caseRegion] : llvm::enumerate(getCaseRegions()))
3505  if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
3506  return failure();
3507 
3508  return success();
3509 }
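// For example, an scf.index_switch carrying three case values but only two
// case regions would be rejected by the verifier above with a diagnostic of
// the form "has 2 case regions but 3 case values"; a region whose yield arity
// or types do not match the op's results is reported with a note attached to
// the offending scf.yield.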
3510 
3511 unsigned scf::IndexSwitchOp::getNumCases() { return getCases().size(); }
3512 
3513 Block &scf::IndexSwitchOp::getDefaultBlock() {
3514  return getDefaultRegion().front();
3515 }
3516 
3517 Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
3518  assert(idx < getNumCases() && "case index out-of-bounds");
3519  return getCaseRegions()[idx].front();
3520 }
3521 
3522 void IndexSwitchOp::getSuccessorRegions(
3523  Optional<unsigned> index, ArrayRef<Attribute> operands,
3524  SmallVectorImpl<RegionSuccessor> &successors) {
3525  // All regions branch back to the parent op.
3526  if (index) {
3527  successors.emplace_back(getResults());
3528  return;
3529  }
3530 
3531  // If a constant was not provided, all regions are possible successors.
3532  auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
3533  if (!operandValue) {
3534  for (Region &caseRegion : getCaseRegions())
3535  successors.emplace_back(&caseRegion);
3536  successors.emplace_back(&getDefaultRegion());
3537  return;
3538  }
3539 
3540  // Otherwise, try to find a case with a matching value. If not, the default
3541  // region is the only successor.
3542  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
3543  if (caseValue == operandValue.getInt()) {
3544  successors.emplace_back(&caseRegion);
3545  return;
3546  }
3547  }
3548  successors.emplace_back(&getDefaultRegion());
3549 }
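// For example, with case values [2, 5]: if the switch operand is unknown, all
// case regions plus the default region are possible successors; if it is a
// known constant 5, only that case region is returned; and for a known
// constant 7 (matching no case), only the default region is returned.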
3550 
3551 void IndexSwitchOp::getRegionInvocationBounds(
3552  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
3553  auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
3554  if (!operandValue) {
3555  // All regions are invoked at most once.
3556  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
3557  return;
3558  }
3559 
3560  unsigned liveIndex = getNumRegions() - 1;
3561  const auto *it = llvm::find(getCases(), operandValue.getInt());
3562  if (it != getCases().end())
3563  liveIndex = std::distance(getCases().begin(), it);
3564  for (unsigned i = 0, e = getNumRegions(); i < e; ++i)
3565  bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
3566 }
3567 
3568 //===----------------------------------------------------------------------===//
3569 // TableGen'd op method definitions
3570 //===----------------------------------------------------------------------===//
3571 
3572 #define GET_OP_CLASSES
3573 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"