MLIR  14.0.0git
SCF.cpp
Go to the documentation of this file.
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::scf;
21 
22 #include "mlir/Dialect/SCF/SCFOpsDialect.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // SCFDialect Dialect Interfaces
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
// Inliner hooks for the SCF dialect: every inlining request is accepted, and
// scf.yield terminators of inlined single-block regions are resolved by
// forwarding their operands.
29 struct SCFInlinerInterface : public DialectInlinerInterface {
31  // Region-into-region query: we don't have any special restrictions on what
32  // can be inlined into destination regions (e.g. while/conditional bodies).
 // Always allow it.
33  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
34  BlockAndValueMapping &valueMapping) const final {
35  return true;
36  }
37  // Operation-into-region query: operations in the scf dialect are always
38  // legal to inline since they are pure.
39  bool isLegalToInline(Operation *, Region *, bool,
40  BlockAndValueMapping &) const final {
41  return true;
42  }
43  // Handle the given inlined terminator. Required when the inlined region has
44  // only one block. If the terminator is an scf.yield, all uses of each value
 // in `valuesToRepl` are redirected to the yield operand at the same position;
 // any other terminator is left untouched.
45  void handleTerminator(Operation *op,
46  ArrayRef<Value> valuesToRepl) const final {
47  auto retValOp = dyn_cast<scf::YieldOp>(op);
48  if (!retValOp)
49  return;
50 
 // zip pairs each value to replace with the yielded operand at the same index.
51  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
52  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
53  }
54  }
55 };
56 } // namespace
57 
58 //===----------------------------------------------------------------------===//
59 // SCFDialect
60 //===----------------------------------------------------------------------===//
61 
// Registers the dialect's operations and interfaces. The operation list passed
// to addOperations<> is pulled in from SCFOps.cpp.inc with GET_OP_LIST
// defined; the only interface registered here is SCFInlinerInterface.
62 void SCFDialect::initialize() {
63  addOperations<
64 #define GET_OP_LIST
65 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
66  >();
67  addInterfaces<SCFInlinerInterface>();
68 }
69 
70 /// Default callback for IfOp builders. Inserts a yield without arguments.
72  builder.create<scf::YieldOp>(loc);
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // ExecuteRegionOp
77 //===----------------------------------------------------------------------===//
78 
79 /// Replaces the given op with the contents of the given single-block region,
80 /// using the operands of the block terminator to replace operation results.
81 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
82  Region &region, ValueRange blockArgs = {}) {
83  assert(llvm::hasSingleElement(region) && "expected single-region block");
84  Block *block = &region.front();
85  Operation *terminator = block->getTerminator();
86  ValueRange results = terminator->getOperands();
87  rewriter.mergeBlockBefore(block, op, blockArgs);
88  rewriter.replaceOp(op, results);
89  rewriter.eraseOp(terminator);
90 }
91 
92 ///
93 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
94 /// block+
95 /// `}`
96 ///
97 /// Example:
98 /// scf.execute_region -> i32 {
99 /// %idx = load %rI[%i] : memref<128xi32>
100 /// scf.yield %idx : i32
101 /// }
102 ///
104  OperationState &result) {
105  if (parser.parseOptionalArrowTypeList(result.types))
106  return failure();
107 
108  // Introduce the body region and parse it.
109  Region *body = result.addRegion();
110  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
111  parser.parseOptionalAttrDict(result.attributes))
112  return failure();
113 
114  return success();
115 }
116 
117 static void print(OpAsmPrinter &p, ExecuteRegionOp op) {
118  p.printOptionalArrowTypeList(op.getResultTypes());
119 
120  p << ' ';
121  p.printRegion(op.getRegion(),
122  /*printEntryBlockArgs=*/false,
123  /*printBlockTerminators=*/true);
124 
125  p.printOptionalAttrDict(op->getAttrs());
126 }
127 
128 static LogicalResult verify(ExecuteRegionOp op) {
129  if (op.getRegion().empty())
130  return op.emitOpError("region needs to have at least one block");
131  if (op.getRegion().front().getNumArguments() > 0)
132  return op.emitOpError("region cannot have any arguments");
133  return success();
134 }
135 
136 // Inline an ExecuteRegionOp if it only contains one block.
137 // "test.foo"() : () -> ()
138 // %v = scf.execute_region -> i64 {
139 // %x = "test.val"() : () -> i64
140 // scf.yield %x : i64
141 // }
142 // "test.bar"(%v) : (i64) -> ()
143 //
144 // becomes
145 //
146 // "test.foo"() : () -> ()
147 // %x = "test.val"() : () -> i64
148 // "test.bar"(%x) : (i64) -> ()
149 //
150 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
152 
153  LogicalResult matchAndRewrite(ExecuteRegionOp op,
154  PatternRewriter &rewriter) const override {
155  if (!llvm::hasSingleElement(op.getRegion()))
156  return failure();
157  replaceOpWithRegion(rewriter, op, op.getRegion());
158  return success();
159  }
160 };
161 
162 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
163 // TODO generalize the conditions for operations which can be inlined into.
164 // func @func_execute_region_elim() {
165 // "test.foo"() : () -> ()
166 // %v = scf.execute_region -> i64 {
167 // %c = "test.cmp"() : () -> i1
168 // cond_br %c, ^bb2, ^bb3
169 // ^bb2:
170 // %x = "test.val1"() : () -> i64
171 // br ^bb4(%x : i64)
172 // ^bb3:
173 // %y = "test.val2"() : () -> i64
174 // br ^bb4(%y : i64)
175 // ^bb4(%z : i64):
176 // scf.yield %z : i64
177 // }
178 // "test.bar"(%v) : (i64) -> ()
179 // return
180 // }
181 //
182 // becomes
183 //
184 // func @func_execute_region_elim() {
185 // "test.foo"() : () -> ()
186 // %c = "test.cmp"() : () -> i1
187 // cond_br %c, ^bb1, ^bb2
188 // ^bb1: // pred: ^bb0
189 // %x = "test.val1"() : () -> i64
190 // br ^bb3(%x : i64)
191 // ^bb2: // pred: ^bb0
192 // %y = "test.val2"() : () -> i64
193 // br ^bb3(%y : i64)
194 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
195 // "test.bar"(%z) : (i64) -> ()
196 // return
197 // }
198 //
199 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
201 
202  LogicalResult matchAndRewrite(ExecuteRegionOp op,
203  PatternRewriter &rewriter) const override {
204  if (!isa<FuncOp, ExecuteRegionOp>(op->getParentOp()))
205  return failure();
206 
207  Block *prevBlock = op->getBlock();
208  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
209  rewriter.setInsertionPointToEnd(prevBlock);
210 
211  rewriter.create<BranchOp>(op.getLoc(), &op.getRegion().front());
212 
213  for (Block &blk : op.getRegion()) {
214  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
215  rewriter.setInsertionPoint(yieldOp);
216  rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
217  yieldOp.getResults());
218  rewriter.eraseOp(yieldOp);
219  }
220  }
221 
222  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
223  SmallVector<Value> blockArgs;
224 
225  for (auto res : op.getResults())
226  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
227 
228  rewriter.replaceOp(op, blockArgs);
229  return success();
230  }
231 };
232 
233 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
234  MLIRContext *context) {
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // ConditionOp
240 //===----------------------------------------------------------------------===//
241 
243 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
244  // Pass all operands except the condition to the successor region.
245  return getArgsMutable();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // ForOp
250 //===----------------------------------------------------------------------===//
251 
252 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
253  Value ub, Value step, ValueRange iterArgs,
254  BodyBuilderFn bodyBuilder) {
255  result.addOperands({lb, ub, step});
256  result.addOperands(iterArgs);
257  for (Value v : iterArgs)
258  result.addTypes(v.getType());
259  Region *bodyRegion = result.addRegion();
260  bodyRegion->push_back(new Block);
261  Block &bodyBlock = bodyRegion->front();
262  bodyBlock.addArgument(builder.getIndexType(), result.location);
263  for (Value v : iterArgs)
264  bodyBlock.addArgument(v.getType(), v.getLoc());
265 
266  // Create the default terminator if the builder is not provided and if the
267  // iteration arguments are not provided. Otherwise, leave this to the caller
268  // because we don't know which values to return from the loop.
269  if (iterArgs.empty() && !bodyBuilder) {
270  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
271  } else if (bodyBuilder) {
272  OpBuilder::InsertionGuard guard(builder);
273  builder.setInsertionPointToStart(&bodyBlock);
274  bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
275  bodyBlock.getArguments().drop_front());
276  }
277 }
278 
279 static LogicalResult verify(ForOp op) {
280  if (auto cst = op.getStep().getDefiningOp<arith::ConstantIndexOp>())
281  if (cst.value() <= 0)
282  return op.emitOpError("constant step operand must be positive");
283 
284  // Check that the body defines as single block argument for the induction
285  // variable.
286  auto *body = op.getBody();
287  if (!body->getArgument(0).getType().isIndex())
288  return op.emitOpError(
289  "expected body first argument to be an index argument for "
290  "the induction variable");
291 
292  auto opNumResults = op.getNumResults();
293  if (opNumResults == 0)
294  return success();
295  // If ForOp defines values, check that the number and types of
296  // the defined values match ForOp initial iter operands and backedge
297  // basic block arguments.
298  if (op.getNumIterOperands() != opNumResults)
299  return op.emitOpError(
300  "mismatch in number of loop-carried values and defined values");
301  if (op.getNumRegionIterArgs() != opNumResults)
302  return op.emitOpError(
303  "mismatch in number of basic block args and defined values");
304  auto iterOperands = op.getIterOperands();
305  auto iterArgs = op.getRegionIterArgs();
306  auto opResults = op.getResults();
307  unsigned i = 0;
308  for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
309  if (std::get<0>(e).getType() != std::get<2>(e).getType())
310  return op.emitOpError() << "types mismatch between " << i
311  << "th iter operand and defined value";
312  if (std::get<1>(e).getType() != std::get<2>(e).getType())
313  return op.emitOpError() << "types mismatch between " << i
314  << "th iter region arg and defined value";
315 
316  i++;
317  }
318 
319  return RegionBranchOpInterface::verifyTypes(op);
320 }
321 
322 /// Prints the initialization list in the form of
323 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
324 /// where 'inner' values are assumed to be region arguments and 'outer' values
325 /// are regular SSA values.
327  Block::BlockArgListType blocksArgs,
328  ValueRange initializers,
329  StringRef prefix = "") {
330  assert(blocksArgs.size() == initializers.size() &&
331  "expected same length of arguments and initializers");
332  if (initializers.empty())
333  return;
334 
335  p << prefix << '(';
336  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
337  p << std::get<0>(it) << " = " << std::get<1>(it);
338  });
339  p << ")";
340 }
341 
342 static void print(OpAsmPrinter &p, ForOp op) {
343  p << " " << op.getInductionVar() << " = " << op.getLowerBound() << " to "
344  << op.getUpperBound() << " step " << op.getStep();
345 
346  printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
347  " iter_args");
348  if (!op.getIterOperands().empty())
349  p << " -> (" << op.getIterOperands().getTypes() << ')';
350  p << ' ';
351  p.printRegion(op.getRegion(),
352  /*printEntryBlockArgs=*/false,
353  /*printBlockTerminators=*/op.hasIterOperands());
354  p.printOptionalAttrDict(op->getAttrs());
355 }
356 
358  auto &builder = parser.getBuilder();
359  OpAsmParser::OperandType inductionVariable, lb, ub, step;
360  // Parse the induction variable followed by '='.
361  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
362  return failure();
363 
364  // Parse loop bounds.
365  Type indexType = builder.getIndexType();
366  if (parser.parseOperand(lb) ||
367  parser.resolveOperand(lb, indexType, result.operands) ||
368  parser.parseKeyword("to") || parser.parseOperand(ub) ||
369  parser.resolveOperand(ub, indexType, result.operands) ||
370  parser.parseKeyword("step") || parser.parseOperand(step) ||
371  parser.resolveOperand(step, indexType, result.operands))
372  return failure();
373 
374  // Parse the optional initial iteration arguments.
375  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
376  SmallVector<Type, 4> argTypes;
377  regionArgs.push_back(inductionVariable);
378 
379  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
380  // Parse assignment list and results type list.
381  if (parser.parseAssignmentList(regionArgs, operands) ||
382  parser.parseArrowTypeList(result.types))
383  return failure();
384  // Resolve input operands.
385  for (auto operandType : llvm::zip(operands, result.types))
386  if (parser.resolveOperand(std::get<0>(operandType),
387  std::get<1>(operandType), result.operands))
388  return failure();
389  }
390  // Induction variable.
391  argTypes.push_back(indexType);
392  // Loop carried variables
393  argTypes.append(result.types.begin(), result.types.end());
394  // Parse the body region.
395  Region *body = result.addRegion();
396  if (regionArgs.size() != argTypes.size())
397  return parser.emitError(
398  parser.getNameLoc(),
399  "mismatch in number of loop-carried values and defined values");
400 
401  if (parser.parseRegion(*body, regionArgs, argTypes))
402  return failure();
403 
404  ForOp::ensureTerminator(*body, builder, result.location);
405 
406  // Parse the optional attribute list.
407  if (parser.parseOptionalAttrDict(result.attributes))
408  return failure();
409 
410  return success();
411 }
412 
413 Region &ForOp::getLoopBody() { return getRegion(); }
414 
415 bool ForOp::isDefinedOutsideOfLoop(Value value) {
416  return !getRegion().isAncestor(value.getParentRegion());
417 }
418 
419 LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
420  for (auto *op : ops)
421  op->moveBefore(*this);
422  return success();
423 }
424 
426  auto ivArg = val.dyn_cast<BlockArgument>();
427  if (!ivArg)
428  return ForOp();
429  assert(ivArg.getOwner() && "unlinked block argument");
430  auto *containingOp = ivArg.getOwner()->getParentOp();
431  return dyn_cast_or_null<ForOp>(containingOp);
432 }
433 
434 /// Return operands used when entering the region at 'index'. These operands
435 /// correspond to the loop iterator operands, i.e., those excluding the
436 /// induction variable. LoopOp only has one region, so 0 is the only valid value
437 /// for `index`.
438 OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
439  assert(index == 0 && "invalid region index");
440 
441  // The initial operands map to the loop arguments after the induction
442  // variable.
443  return getInitArgs();
444 }
445 
446 /// Given the region at `index`, or the parent operation if `index` is None,
447 /// return the successor regions. These are the regions that may be selected
448 /// during the flow of control. `operands` is a set of optional attributes that
449 /// correspond to a constant value for each operand, or null if that operand is
450 /// not a constant.
451 void ForOp::getSuccessorRegions(Optional<unsigned> index,
452  ArrayRef<Attribute> operands,
454  // If the predecessor is the ForOp, branch into the body using the iterator
455  // arguments.
456  if (!index.hasValue()) {
457  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
458  return;
459  }
460 
461  // Otherwise, the loop may branch back to itself or the parent operation.
462  assert(index.getValue() == 0 && "expected loop region");
463  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
464  regions.push_back(RegionSuccessor(getResults()));
465 }
466 
468  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
469  ValueRange steps, ValueRange iterArgs,
471  bodyBuilder) {
472  assert(lbs.size() == ubs.size() &&
473  "expected the same number of lower and upper bounds");
474  assert(lbs.size() == steps.size() &&
475  "expected the same number of lower bounds and steps");
476 
477  // If there are no bounds, call the body-building function and return early.
478  if (lbs.empty()) {
479  ValueVector results =
480  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
481  : ValueVector();
482  assert(results.size() == iterArgs.size() &&
483  "loop nest body must return as many values as loop has iteration "
484  "arguments");
485  return LoopNest();
486  }
487 
488  // First, create the loop structure iteratively using the body-builder
489  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
490  OpBuilder::InsertionGuard guard(builder);
493  loops.reserve(lbs.size());
494  ivs.reserve(lbs.size());
495  ValueRange currentIterArgs = iterArgs;
496  Location currentLoc = loc;
497  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
498  auto loop = builder.create<scf::ForOp>(
499  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
500  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
501  ValueRange args) {
502  ivs.push_back(iv);
503  // It is safe to store ValueRange args because it points to block
504  // arguments of a loop operation that we also own.
505  currentIterArgs = args;
506  currentLoc = nestedLoc;
507  });
508  // Set the builder to point to the body of the newly created loop. We don't
509  // do this in the callback because the builder is reset when the callback
510  // returns.
511  builder.setInsertionPointToStart(loop.getBody());
512  loops.push_back(loop);
513  }
514 
515  // For all loops but the innermost, yield the results of the nested loop.
516  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
517  builder.setInsertionPointToEnd(loops[i].getBody());
518  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
519  }
520 
521  // In the body of the innermost loop, call the body building function if any
522  // and yield its results.
523  builder.setInsertionPointToStart(loops.back().getBody());
524  ValueVector results = bodyBuilder
525  ? bodyBuilder(builder, currentLoc, ivs,
526  loops.back().getRegionIterArgs())
527  : ValueVector();
528  assert(results.size() == iterArgs.size() &&
529  "loop nest body must return as many values as loop has iteration "
530  "arguments");
531  builder.setInsertionPointToEnd(loops.back().getBody());
532  builder.create<scf::YieldOp>(loc, results);
533 
534  // Return the loops.
535  LoopNest res;
536  res.loops.assign(loops.begin(), loops.end());
537  return res;
538 }
539 
541  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
542  ValueRange steps,
543  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
544  // Delegate to the main function by wrapping the body builder.
545  return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
546  [&bodyBuilder](OpBuilder &nestedBuilder,
547  Location nestedLoc, ValueRange ivs,
548  ValueRange) -> ValueVector {
549  if (bodyBuilder)
550  bodyBuilder(nestedBuilder, nestedLoc, ivs);
551  return {};
552  });
553 }
554 
555 namespace {
556 // Fold away ForOp iter arguments when:
557 // 1) The op yields the iter arguments.
558 // 2) The iter arguments have no use and the corresponding outer region
559 // iterators (inputs) are yielded.
560 // 3) The iter arguments have no use and the corresponding (operation) results
561 // have no use.
562 //
563 // These arguments must be defined outside of
564 // the ForOp region and can just be forwarded after simplifying the op inits,
565 // yields and returns.
566 //
567 // The implementation uses `mergeBlockBefore` to steal the content of the
568 // original ForOp and avoid cloning.
569 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
571 
 // Classifies every iter argument as "forwarded" (foldable) or "kept", builds
 // a new scf.for that carries only the kept ones, moves the old body into it
 // and replaces the old op. Fails when no iter argument is foldable.
572  LogicalResult matchAndRewrite(scf::ForOp forOp,
573  PatternRewriter &rewriter) const final {
574  bool canonicalize = false;
575  Block &block = forOp.getRegion().front();
576  auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
577 
578  // An internal flat vector of block transfer
579  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
580  // transformed block argument mappings. This plays the role of a
581  // BlockAndValueMapping for the particular use case of calling into
582  // `mergeBlockBefore`.
 // `keepMask[i]` is true when the i-th iter argument survives in the new loop.
583  SmallVector<bool, 4> keepMask;
584  keepMask.reserve(yieldOp.getNumOperands());
585  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
586  newResultValues;
587  newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
588  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
589  newIterArgs.reserve(forOp.getNumIterOperands());
590  newYieldValues.reserve(yieldOp.getNumOperands());
591  newResultValues.reserve(forOp.getNumResults());
592  for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
593  forOp.getRegionIterArgs(), // iter inside region
594  forOp.getResults(), // op results
595  yieldOp.getOperands() // iter yield
596  )) {
597  // Forwarded is `true` when:
598  // 1) The region `iter` argument is yielded.
599  // 2) The region `iter` argument has no use, and the corresponding iter
600  // operand (input) is yielded.
601  // 3) The region `iter` argument has no use, and the corresponding op
602  // result has no use.
603  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
604  (std::get<1>(it).use_empty() &&
605  (std::get<0>(it) == std::get<3>(it) ||
606  std::get<2>(it).use_empty())));
607  keepMask.push_back(!forwarded);
608  canonicalize |= forwarded;
609  if (forwarded) {
 // A forwarded iter argument and its result are both replaced by the
 // iter operand coming from outside the loop.
610  newBlockTransferArgs.push_back(std::get<0>(it));
611  newResultValues.push_back(std::get<0>(it));
612  continue;
613  }
614  newIterArgs.push_back(std::get<0>(it));
615  newYieldValues.push_back(std::get<3>(it));
616  newBlockTransferArgs.push_back(Value()); // placeholder with null value
617  newResultValues.push_back(Value()); // placeholder with null value
618  }
619 
 // Nothing was forwarded: the loop is left unchanged.
620  if (!canonicalize)
621  return failure();
622 
 // The replacement loop carries only the kept iter operands. No body builder
 // is passed here.
623  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
624  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
625  forOp.getStep(), newIterArgs);
626  Block &newBlock = newForOp.getRegion().front();
627 
628  // Replace the null placeholders with newly constructed values.
 // `collapsedIdx` counts the kept iter arguments seen so far, i.e. it is the
 // position of the current kept argument/result in the new loop.
629  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
630  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
631  idx != e; ++idx) {
632  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
633  Value &newResultVal = newResultValues[idx];
634  assert((blockTransferArg && newResultVal) ||
635  (!blockTransferArg && !newResultVal));
636  if (!blockTransferArg) {
637  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
638  newResultVal = newForOp.getResult(collapsedIdx++);
639  }
640  }
641 
642  Block &oldBlock = forOp.getRegion().front();
643  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
644  "unexpected argument size mismatch");
645 
646  // No results case: the scf::ForOp builder already created a zero
647  // result terminator. Merge before this terminator and just get rid of the
648  // original terminator that has been merged in.
649  if (newIterArgs.empty()) {
650  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
651  rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
 // The op right before the new terminator is the old, merged-in terminator.
652  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
653  rewriter.replaceOp(forOp, newResultValues);
654  return success();
655  }
656 
657  // No terminator case: merge and rewrite the merged terminator.
 // The helper creates, right before `mergedTerminator`, a new scf.yield
 // that keeps only the operands whose `keepMask` entry is set.
658  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
659  OpBuilder::InsertionGuard g(rewriter);
660  rewriter.setInsertionPoint(mergedTerminator);
661  SmallVector<Value, 4> filteredOperands;
662  filteredOperands.reserve(newResultValues.size());
663  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
664  if (keepMask[idx])
665  filteredOperands.push_back(mergedTerminator.getOperand(idx));
666  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
667  filteredOperands);
668  };
669 
670  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
671  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
672  cloneFilteredTerminator(mergedYieldOp);
673  rewriter.eraseOp(mergedYieldOp);
674  rewriter.replaceOp(forOp, newResultValues);
675  return success();
676  }
677 };
678 
679 /// Rewriting pattern that erases loops that are known not to iterate and
680 /// replaces single-iteration loops with their bodies.
681 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
683 
684  LogicalResult matchAndRewrite(ForOp op,
685  PatternRewriter &rewriter) const override {
686  // If the upper bound is the same as the lower bound, the loop does not
687  // iterate, just remove it.
688  if (op.getLowerBound() == op.getUpperBound()) {
689  rewriter.replaceOp(op, op.getIterOperands());
690  return success();
691  }
692 
693  auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
694  auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
695  if (!lb || !ub)
696  return failure();
697 
698  // If the loop is known to have 0 iterations, remove it.
699  llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
700  llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
701  if (lbValue.sge(ubValue)) {
702  rewriter.replaceOp(op, op.getIterOperands());
703  return success();
704  }
705 
706  auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
707  if (!step)
708  return failure();
709 
710  // If the loop is known to have 1 iteration, inline its body and remove the
711  // loop.
712  llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
713  if ((lbValue + stepValue).sge(ubValue)) {
714  SmallVector<Value, 4> blockArgs;
715  blockArgs.reserve(op.getNumIterOperands() + 1);
716  blockArgs.push_back(op.getLowerBound());
717  llvm::append_range(blockArgs, op.getIterOperands());
718  replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
719  return success();
720  }
721 
722  return failure();
723  }
724 };
725 
726 /// Perform a replacement of one iter OpOperand of an scf.for to the
727 /// `replacement` value which is expected to be the source of a tensor.cast.
728 /// tensor.cast ops are inserted inside the block to account for the type cast.
729 static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
730  OpOperand &operand,
731  Value replacement) {
732  Type oldType = operand.get().getType(), newType = replacement.getType();
733  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
734  "expected ranked tensor types");
735 
736  // 1. Create new iter operands, exactly 1 is replaced.
737  ForOp forOp = cast<ForOp>(operand.getOwner());
738  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
739  "expected an iter OpOperand");
740  if (operand.get().getType() == replacement.getType())
741  return forOp;
742  SmallVector<Value> newIterOperands;
743  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
744  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
745  newIterOperands.push_back(replacement);
746  continue;
747  }
748  newIterOperands.push_back(opOperand.get());
749  }
750 
751  // 2. Create the new forOp shell.
752  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
753  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
754  forOp.getStep(), newIterOperands);
755  Block &newBlock = newForOp.getRegion().front();
756  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
757  newBlock.getArguments().end());
758 
759  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
760  // corresponding to the `replacement` value.
761  OpBuilder::InsertionGuard g(rewriter);
762  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
763  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
764  newForOp->getOpOperand(operand.getOperandNumber()));
765  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
766  newRegionIterArg);
767  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
768 
769  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
770  Block &oldBlock = forOp.getRegion().front();
771  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
772 
773  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
774  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
775  rewriter.setInsertionPoint(clonedYieldOp);
776  unsigned yieldIdx =
777  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
778  Value castOut = rewriter.create<tensor::CastOp>(
779  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
780  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
781  newYieldOperands[yieldIdx] = castOut;
782  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
783  rewriter.eraseOp(clonedYieldOp);
784 
785  // 6. Inject an outgoing cast op after the forOp.
786  rewriter.setInsertionPointAfter(newForOp);
787  SmallVector<Value> newResults = newForOp.getResults();
788  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
789  newForOp.getLoc(), oldType, newResults[yieldIdx]);
790 
791  return newForOp;
792 }
793 
794 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
795 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
796 ///
797 /// ```
798 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
799 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
800 /// -> (tensor<?x?xf32>) {
801 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
802 /// scf.yield %2 : tensor<?x?xf32>
803 /// }
804 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
805 /// use_of(%2)
806 /// ```
807 ///
808 /// folds into:
809 ///
810 /// ```
811 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
812 /// -> (tensor<32x1024xf32>) {
813 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
814 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
815 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
816 /// scf.yield %4 : tensor<32x1024xf32>
817 /// }
818 /// use_of(%0)
819 /// ```
820 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
822 
823  LogicalResult matchAndRewrite(ForOp op,
824  PatternRewriter &rewriter) const override {
825  for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
826  OpOperand &iterOpOperand = std::get<0>(it);
827  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
828  if (!incomingCast)
829  continue;
830  if (!std::get<1>(it).hasOneUse())
831  continue;
832  auto outgoingCastOp =
833  dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
834  if (!outgoingCastOp)
835  continue;
836 
837  // Must be a tensor.cast op pair with matching types.
838  if (outgoingCastOp.getResult().getType() !=
839  incomingCast.source().getType())
840  continue;
841 
842  // Create a new ForOp with that iter operand replaced.
843  auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
844  incomingCast.source());
845 
846  // Insert outgoing cast and use it to replace the corresponding result.
847  rewriter.setInsertionPointAfter(newForOp);
848  SmallVector<Value> replacements = newForOp.getResults();
849  unsigned returnIdx =
850  iterOpOperand.getOperandNumber() - op.getNumControlOperands();
851  replacements[returnIdx] = rewriter.create<tensor::CastOp>(
852  op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
853  rewriter.replaceOp(op, replacements);
854  return success();
855  }
856  return failure();
857  }
858 };
859 
860 /// Canonicalize the iter_args of an scf::ForOp that involve a
861 /// `bufferization.to_tensor` and for which only the last loop iteration is
862 /// actually visible outside of the loop. The canonicalization looks for a
863 /// pattern such as:
864 /// ```
865 /// %t0 = ... : tensor_type
866 /// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
867 /// ...
868 /// // %m is either buffer_cast(%bb00) or defined above the loop
869 /// %m... : memref_type
870 /// ... // uses of %m with potential inplace updates
871 /// %new_tensor = bufferization.to_tensor %m : memref_type
872 /// ...
873 /// scf.yield %new_tensor : tensor_type
874 /// }
875 /// ```
876 ///
877 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
878 /// `%m = buffer_cast %bb0` op that feeds into the yielded
879 /// `bufferization.to_tensor` op.
880 ///
881 /// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded,
882 /// occurs between `bufferization.to_tensor and yield then the value %0
883 /// visible outside of the loop is the last `bufferization.to_tensor`
884 /// produced in the loop.
885 ///
886 /// For now, we approximate the absence of aliasing by only supporting the case
887 /// when the bufferization.to_tensor is the operation immediately preceding
888 /// the yield.
889 //
890 /// The canonicalization rewrites the pattern as:
891 /// ```
892 /// // %m is either a buffer_cast or defined above
893 /// %m... : memref_type
894 /// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
895 /// ... // uses of %m with potential inplace updates
896 /// scf.yield %bb0: tensor_type
897 /// }
898 /// %0 = bufferization.to_tensor %m : memref_type
899 /// ```
900 ///
901 /// A later bbArg canonicalization will further rewrite as:
902 /// ```
903 /// // %m is either a buffer_cast or defined above
904 /// %m... : memref_type
905 /// scf.for ... { // no iter_args
906 /// ... // uses of %m with potential inplace updates
907 /// }
908 /// %0 = bufferization.to_tensor %m : memref_type
909 /// ```
910 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
912 
913  LogicalResult matchAndRewrite(ForOp forOp,
914  PatternRewriter &rewriter) const override {
915  assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
916  "unexpected multiple blocks");
917 
918  Location loc = forOp.getLoc();
919  DenseMap<Value, Value> replacements;
920  for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
921  unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
922  auto yieldOp =
923  cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
924  Value yieldVal = yieldOp->getOperand(idx);
925  auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
926  bool isTensor = bbArg.getType().isa<TensorType>();
927 
928  bufferization::ToMemrefOp tensorToMemref;
929  // Either bbArg has no use or it has a single buffer_cast use.
930  if (bbArg.hasOneUse())
931  tensorToMemref =
932  dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
933  if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
934  continue;
935  // If tensorToMemref is present, it must feed into the `ToTensorOp`.
936  if (tensorToMemref && tensorLoadOp.memref() != tensorToMemref)
937  continue;
938  // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
939  // must be before `ToTensorOp` in the block so that the lastWrite
940  // property is not subject to additional side-effects.
941  // For now, we only support the case when ToTensorOp appears
942  // immediately before the terminator.
943  if (tensorLoadOp->getNextNode() != yieldOp)
944  continue;
945 
946  // Clone the optional tensorToMemref before forOp.
947  if (tensorToMemref) {
948  rewriter.setInsertionPoint(forOp);
949  rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
950  tensorToMemref, tensorToMemref.memref().getType(),
951  tensorToMemref.tensor());
952  }
953 
954  // Clone the tensorLoad after forOp.
955  rewriter.setInsertionPointAfter(forOp);
956  Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
957  loc, tensorLoadOp.memref());
958  Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
959  replacements.insert(std::make_pair(forOpResult, newTensorLoad));
960 
961  // Make the terminator just yield the bbArg, the old tensorLoadOp + the
962  // old bbArg (that is now directly yielded) will canonicalize away.
963  rewriter.startRootUpdate(yieldOp);
964  yieldOp.setOperand(idx, bbArg);
965  rewriter.finalizeRootUpdate(yieldOp);
966  }
967  if (replacements.empty())
968  return failure();
969 
970  // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
971  // replaces the whole op and erase it unconditionally. This is wrong for
972  // `forOp` as it generally contains ops with side effects.
973  // Instead, use `rewriter.replaceOpWithIf`.
974  SmallVector<Value> newResults;
975  newResults.reserve(forOp.getNumResults());
976  for (Value v : forOp.getResults()) {
977  auto it = replacements.find(v);
978  newResults.push_back((it != replacements.end()) ? it->second : v);
979  }
980  unsigned idx = 0;
981  rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
982  return op.get() != newResults[idx++];
983  });
984  return success();
985  }
986 };
987 } // namespace
988 
989 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
990  MLIRContext *context) {
991  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
992  LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
993 }
994 
995 //===----------------------------------------------------------------------===//
996 // IfOp
997 //===----------------------------------------------------------------------===//
998 
1000  assert(a && "expected non-empty operation");
1001  assert(b && "expected non-empty operation");
1002 
1003  IfOp ifOp = a->getParentOfType<IfOp>();
1004  while (ifOp) {
1005  // Check if b is inside ifOp. (We already know that a is.)
1006  if (ifOp->isProperAncestor(b))
1007  // b is contained in ifOp. a and b are in mutually exclusive branches if
1008  // they are in different blocks of ifOp.
1009  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1010  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1011  // Check next enclosing IfOp.
1012  ifOp = ifOp->getParentOfType<IfOp>();
1013  }
1014 
1015  // Could not find a common IfOp among a's and b's ancestors.
1016  return false;
1017 }
1018 
1019 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1020  bool withElseRegion) {
1021  build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1022 }
1023 
1024 void IfOp::build(OpBuilder &builder, OperationState &result,
1025  TypeRange resultTypes, Value cond, bool withElseRegion) {
1026  auto addTerminator = [&](OpBuilder &nested, Location loc) {
1027  if (resultTypes.empty())
1028  IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1029  loc);
1030  };
1031 
1032  build(builder, result, resultTypes, cond, addTerminator,
1033  withElseRegion ? addTerminator
1034  : function_ref<void(OpBuilder &, Location)>());
1035 }
1036 
1037 void IfOp::build(OpBuilder &builder, OperationState &result,
1038  TypeRange resultTypes, Value cond,
1039  function_ref<void(OpBuilder &, Location)> thenBuilder,
1040  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1041  assert(thenBuilder && "the builder callback for 'then' must be present");
1042 
1043  result.addOperands(cond);
1044  result.addTypes(resultTypes);
1045 
1046  OpBuilder::InsertionGuard guard(builder);
1047  Region *thenRegion = result.addRegion();
1048  builder.createBlock(thenRegion);
1049  thenBuilder(builder, result.location);
1050 
1051  Region *elseRegion = result.addRegion();
1052  if (!elseBuilder)
1053  return;
1054 
1055  builder.createBlock(elseRegion);
1056  elseBuilder(builder, result.location);
1057 }
1058 
1059 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1060  function_ref<void(OpBuilder &, Location)> thenBuilder,
1061  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1062  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1063 }
1064 
1065 static LogicalResult verify(IfOp op) {
1066  if (op.getNumResults() != 0 && op.getElseRegion().empty())
1067  return op.emitOpError("must have an else block if defining values");
1068 
1069  return RegionBranchOpInterface::verifyTypes(op);
1070 }
1071 
1073  // Create the regions for 'then'.
1074  result.regions.reserve(2);
1075  Region *thenRegion = result.addRegion();
1076  Region *elseRegion = result.addRegion();
1077 
1078  auto &builder = parser.getBuilder();
1080  Type i1Type = builder.getIntegerType(1);
1081  if (parser.parseOperand(cond) ||
1082  parser.resolveOperand(cond, i1Type, result.operands))
1083  return failure();
1084  // Parse optional results type list.
1085  if (parser.parseOptionalArrowTypeList(result.types))
1086  return failure();
1087  // Parse the 'then' region.
1088  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1089  return failure();
1090  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1091 
1092  // If we find an 'else' keyword then parse the 'else' region.
1093  if (!parser.parseOptionalKeyword("else")) {
1094  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1095  return failure();
1096  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1097  }
1098 
1099  // Parse the optional attribute list.
1100  if (parser.parseOptionalAttrDict(result.attributes))
1101  return failure();
1102  return success();
1103 }
1104 
1105 static void print(OpAsmPrinter &p, IfOp op) {
1106  bool printBlockTerminators = false;
1107 
1108  p << " " << op.getCondition();
1109  if (!op.getResults().empty()) {
1110  p << " -> (" << op.getResultTypes() << ")";
1111  // Print yield explicitly if the op defines values.
1112  printBlockTerminators = true;
1113  }
1114  p << ' ';
1115  p.printRegion(op.getThenRegion(),
1116  /*printEntryBlockArgs=*/false,
1117  /*printBlockTerminators=*/printBlockTerminators);
1118 
1119  // Print the 'else' regions if it exists and has a block.
1120  auto &elseRegion = op.getElseRegion();
1121  if (!elseRegion.empty()) {
1122  p << " else ";
1123  p.printRegion(elseRegion,
1124  /*printEntryBlockArgs=*/false,
1125  /*printBlockTerminators=*/printBlockTerminators);
1126  }
1127 
1128  p.printOptionalAttrDict(op->getAttrs());
1129 }
1130 
1131 /// Given the region at `index`, or the parent operation if `index` is None,
1132 /// return the successor regions. These are the regions that may be selected
1133 /// during the flow of control. `operands` is a set of optional attributes that
1134 /// correspond to a constant value for each operand, or null if that operand is
1135 /// not a constant.
1136 void IfOp::getSuccessorRegions(Optional<unsigned> index,
1137  ArrayRef<Attribute> operands,
1139  // The `then` and the `else` region branch back to the parent operation.
1140  if (index.hasValue()) {
1141  regions.push_back(RegionSuccessor(getResults()));
1142  return;
1143  }
1144 
1145  // Don't consider the else region if it is empty.
1146  Region *elseRegion = &this->getElseRegion();
1147  if (elseRegion->empty())
1148  elseRegion = nullptr;
1149 
1150  // Otherwise, the successor is dependent on the condition.
1151  bool condition;
1152  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1153  condition = condAttr.getValue().isOneValue();
1154  } else {
1155  // If the condition isn't constant, both regions may be executed.
1156  regions.push_back(RegionSuccessor(&getThenRegion()));
1157  // If the else region does not exist, it is not a viable successor.
1158  if (elseRegion)
1159  regions.push_back(RegionSuccessor(elseRegion));
1160  return;
1161  }
1162 
1163  // Add the successor regions using the condition.
1164  regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1165 }
1166 
1167 LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
1168  SmallVectorImpl<OpFoldResult> &results) {
1169  // if (!c) then A() else B() -> if c then B() else A()
1170  if (getElseRegion().empty())
1171  return failure();
1172 
1173  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1174  if (!xorStmt)
1175  return failure();
1176 
1177  if (!matchPattern(xorStmt.getRhs(), m_One()))
1178  return failure();
1179 
1180  getConditionMutable().assign(xorStmt.getLhs());
1181  Block *thenBlock = &getThenRegion().front();
1182  // It would be nicer to use iplist::swap, but that has no implemented
1183  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
1184  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
1185  getElseRegion().getBlocks());
1186  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
1187  getThenRegion().getBlocks(), thenBlock);
1188  return success();
1189 }
1190 
1191 namespace {
1192 // Pattern to remove unused IfOp results.
1193 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
1195 
1196  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
1197  PatternRewriter &rewriter) const {
1198  // Move all operations to the destination block.
1199  rewriter.mergeBlocks(source, dest);
1200  // Replace the yield op by one that returns only the used values.
1201  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
1202  SmallVector<Value, 4> usedOperands;
1203  llvm::transform(usedResults, std::back_inserter(usedOperands),
1204  [&](OpResult result) {
1205  return yieldOp.getOperand(result.getResultNumber());
1206  });
1207  rewriter.updateRootInPlace(yieldOp,
1208  [&]() { yieldOp->setOperands(usedOperands); });
1209  }
1210 
1211  LogicalResult matchAndRewrite(IfOp op,
1212  PatternRewriter &rewriter) const override {
1213  // Compute the list of used results.
1214  SmallVector<OpResult, 4> usedResults;
1215  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1216  [](OpResult result) { return !result.use_empty(); });
1217 
1218  // Replace the operation if only a subset of its results have uses.
1219  if (usedResults.size() == op.getNumResults())
1220  return failure();
1221 
1222  // Compute the result types of the replacement operation.
1223  SmallVector<Type, 4> newTypes;
1224  llvm::transform(usedResults, std::back_inserter(newTypes),
1225  [](OpResult result) { return result.getType(); });
1226 
1227  // Create a replacement operation with empty then and else regions.
1228  auto emptyBuilder = [](OpBuilder &, Location) {};
1229  auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1230  emptyBuilder, emptyBuilder);
1231 
1232  // Move the bodies and replace the terminators (note there is a then and
1233  // an else region since the operation returns results).
1234  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1235  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1236 
1237  // Replace the operation by the new one.
1238  SmallVector<Value, 4> repResults(op.getNumResults());
1239  for (const auto &en : llvm::enumerate(usedResults))
1240  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1241  rewriter.replaceOp(op, repResults);
1242  return success();
1243  }
1244 };
1245 
1246 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1248 
1249  LogicalResult matchAndRewrite(IfOp op,
1250  PatternRewriter &rewriter) const override {
1251  auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1252  if (!constant)
1253  return failure();
1254 
1255  if (constant.getValue().cast<BoolAttr>().getValue())
1256  replaceOpWithRegion(rewriter, op, op.getThenRegion());
1257  else if (!op.getElseRegion().empty())
1258  replaceOpWithRegion(rewriter, op, op.getElseRegion());
1259  else
1260  rewriter.eraseOp(op);
1261 
1262  return success();
1263  }
1264 };
1265 
1266 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
1268 
1269  LogicalResult matchAndRewrite(IfOp op,
1270  PatternRewriter &rewriter) const override {
1271  if (op->getNumResults() == 0)
1272  return failure();
1273 
1274  if (!llvm::hasSingleElement(op.getThenRegion().front()) ||
1275  !llvm::hasSingleElement(op.getElseRegion().front()))
1276  return failure();
1277 
1278  auto cond = op.getCondition();
1279  auto thenYieldArgs =
1280  cast<scf::YieldOp>(op.getThenRegion().front().getTerminator())
1281  .getOperands();
1282  auto elseYieldArgs =
1283  cast<scf::YieldOp>(op.getElseRegion().front().getTerminator())
1284  .getOperands();
1285  SmallVector<Value> results(op->getNumResults());
1286  assert(thenYieldArgs.size() == results.size());
1287  assert(elseYieldArgs.size() == results.size());
1288  for (const auto &it :
1289  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1290  Value trueVal = std::get<0>(it.value());
1291  Value falseVal = std::get<1>(it.value());
1292  if (trueVal == falseVal)
1293  results[it.index()] = trueVal;
1294  else
1295  results[it.index()] =
1296  rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
1297  }
1298 
1299  rewriter.replaceOp(op, results);
1300  return success();
1301  }
1302 };
1303 
1304 /// Allow the true region of an if to assume the condition is true
1305 /// and vice versa. For example:
1306 ///
1307 /// scf.if %cmp {
1308 /// print(%cmp)
1309 /// }
1310 ///
1311 /// becomes
1312 ///
1313 /// scf.if %cmp {
1314 /// print(true)
1315 /// }
1316 ///
1317 struct ConditionPropagation : public OpRewritePattern<IfOp> {
1319 
1320  LogicalResult matchAndRewrite(IfOp op,
1321  PatternRewriter &rewriter) const override {
1322  // Early exit if the condition is constant since replacing a constant
1323  // in the body with another constant isn't a simplification.
1324  if (op.getCondition().getDefiningOp<arith::ConstantOp>())
1325  return failure();
1326 
1327  bool changed = false;
1328  mlir::Type i1Ty = rewriter.getI1Type();
1329 
1330  // These variables serve to prevent creating duplicate constants
1331  // and hold constant true or false values.
1332  Value constantTrue = nullptr;
1333  Value constantFalse = nullptr;
1334 
1335  for (OpOperand &use :
1336  llvm::make_early_inc_range(op.getCondition().getUses())) {
1337  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1338  changed = true;
1339 
1340  if (!constantTrue)
1341  constantTrue = rewriter.create<arith::ConstantOp>(
1342  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1343 
1344  rewriter.updateRootInPlace(use.getOwner(),
1345  [&]() { use.set(constantTrue); });
1346  } else if (op.getElseRegion().isAncestor(
1347  use.getOwner()->getParentRegion())) {
1348  changed = true;
1349 
1350  if (!constantFalse)
1351  constantFalse = rewriter.create<arith::ConstantOp>(
1352  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1353 
1354  rewriter.updateRootInPlace(use.getOwner(),
1355  [&]() { use.set(constantFalse); });
1356  }
1357  }
1358 
1359  return success(changed);
1360  }
1361 };
1362 
1363 /// Remove any statements from an if that are equivalent to the condition
1364 /// or its negation. For example:
1365 ///
1366 /// %res:2 = scf.if %cmp {
1367 /// yield something(), true
1368 /// } else {
1369 /// yield something2(), false
1370 /// }
1371 /// print(%res#1)
1372 ///
1373 /// becomes
1374 /// %res = scf.if %cmp {
1375 /// yield something()
1376 /// } else {
1377 /// yield something2()
1378 /// }
1379 /// print(%cmp)
1380 ///
1381 /// Additionally if both branches yield the same value, replace all uses
1382 /// of the result with the yielded value.
1383 ///
1384 /// %res:2 = scf.if %cmp {
1385 /// yield something(), %arg1
1386 /// } else {
1387 /// yield something2(), %arg1
1388 /// }
1389 /// print(%res#1)
1390 ///
1391 /// becomes
1392 /// %res = scf.if %cmp {
1393 /// yield something()
1394 /// } else {
1395 /// yield something2()
1396 /// }
1397 /// print(%arg1)
1398 ///
1399 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1401 
1402  LogicalResult matchAndRewrite(IfOp op,
1403  PatternRewriter &rewriter) const override {
1404  // Early exit if there are no results that could be replaced.
1405  if (op.getNumResults() == 0)
1406  return failure();
1407 
1408  auto trueYield =
1409  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1410  auto falseYield =
1411  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1412 
1413  rewriter.setInsertionPoint(op->getBlock(),
1414  op.getOperation()->getIterator());
1415  bool changed = false;
1416  Type i1Ty = rewriter.getI1Type();
1417  for (auto tup : llvm::zip(trueYield.getResults(), falseYield.getResults(),
1418  op.getResults())) {
1419  Value trueResult, falseResult, opResult;
1420  std::tie(trueResult, falseResult, opResult) = tup;
1421 
1422  if (trueResult == falseResult) {
1423  if (!opResult.use_empty()) {
1424  opResult.replaceAllUsesWith(trueResult);
1425  changed = true;
1426  }
1427  continue;
1428  }
1429 
1430  auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1431  if (!trueYield)
1432  continue;
1433 
1434  if (!trueYield.getType().isInteger(1))
1435  continue;
1436 
1437  auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1438  if (!falseYield)
1439  continue;
1440 
1441  bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1442  bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1443  if (!trueVal && falseVal) {
1444  if (!opResult.use_empty()) {
1445  Value notCond = rewriter.create<arith::XOrIOp>(
1446  op.getLoc(), op.getCondition(),
1447  rewriter.create<arith::ConstantOp>(
1448  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1449  opResult.replaceAllUsesWith(notCond);
1450  changed = true;
1451  }
1452  }
1453  if (trueVal && !falseVal) {
1454  if (!opResult.use_empty()) {
1455  opResult.replaceAllUsesWith(op.getCondition());
1456  changed = true;
1457  }
1458  }
1459  }
1460  return success(changed);
1461  }
1462 };
1463 
1464 /// Merge any consecutive scf.if's with the same condition.
1465 ///
1466 /// scf.if %cond {
1467 /// firstCodeTrue();...
1468 /// } else {
1469 /// firstCodeFalse();...
1470 /// }
1471 /// %res = scf.if %cond {
1472 /// secondCodeTrue();...
1473 /// } else {
1474 /// secondCodeFalse();...
1475 /// }
1476 ///
1477 /// becomes
1478 /// %res = scf.if %cmp {
1479 /// firstCodeTrue();...
1480 /// secondCodeTrue();...
1481 /// } else {
1482 /// firstCodeFalse();...
1483 /// secondCodeFalse();...
1484 /// }
1485 struct CombineIfs : public OpRewritePattern<IfOp> {
1487 
1488  LogicalResult matchAndRewrite(IfOp nextIf,
1489  PatternRewriter &rewriter) const override {
1490  Block *parent = nextIf->getBlock();
1491  if (nextIf == &parent->front())
1492  return failure();
1493 
1494  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
1495  if (!prevIf)
1496  return failure();
1497 
1498  if (nextIf.getCondition() != prevIf.getCondition())
1499  return failure();
1500 
1501  // Don't permit merging if a result of the first if is used
1502  // within the second.
1503  if (llvm::any_of(prevIf->getUsers(),
1504  [&](Operation *user) { return nextIf->isAncestor(user); }))
1505  return failure();
1506 
1507  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
1508  llvm::append_range(mergedTypes, nextIf.getResultTypes());
1509 
1510  IfOp combinedIf = rewriter.create<IfOp>(
1511  nextIf.getLoc(), mergedTypes, nextIf.getCondition(), /*hasElse=*/false);
1512  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
1513 
1514  YieldOp thenYield = prevIf.thenYield();
1515  YieldOp thenYield2 = nextIf.thenYield();
1516 
1517  combinedIf.getThenRegion().getBlocks().splice(
1518  combinedIf.getThenRegion().getBlocks().begin(),
1519  prevIf.getThenRegion().getBlocks());
1520 
1521  rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
1522  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
1523 
1524  SmallVector<Value> mergedYields(thenYield.getOperands());
1525  llvm::append_range(mergedYields, thenYield2.getOperands());
1526  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
1527  rewriter.eraseOp(thenYield);
1528  rewriter.eraseOp(thenYield2);
1529 
1530  combinedIf.getElseRegion().getBlocks().splice(
1531  combinedIf.getElseRegion().getBlocks().begin(),
1532  prevIf.getElseRegion().getBlocks());
1533 
1534  if (!nextIf.getElseRegion().empty()) {
1535  if (combinedIf.getElseRegion().empty()) {
1536  combinedIf.getElseRegion().getBlocks().splice(
1537  combinedIf.getElseRegion().getBlocks().begin(),
1538  nextIf.getElseRegion().getBlocks());
1539  } else {
1540  YieldOp elseYield = combinedIf.elseYield();
1541  YieldOp elseYield2 = nextIf.elseYield();
1542  rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());
1543 
1544  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
1545 
1546  SmallVector<Value> mergedElseYields(elseYield.getOperands());
1547  llvm::append_range(mergedElseYields, elseYield2.getOperands());
1548 
1549  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
1550  rewriter.eraseOp(elseYield);
1551  rewriter.eraseOp(elseYield2);
1552  }
1553  }
1554 
1555  SmallVector<Value> prevValues;
1556  SmallVector<Value> nextValues;
1557  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
1558  if (pair.index() < prevIf.getNumResults())
1559  prevValues.push_back(pair.value());
1560  else
1561  nextValues.push_back(pair.value());
1562  }
1563  rewriter.replaceOp(prevIf, prevValues);
1564  rewriter.replaceOp(nextIf, nextValues);
1565  return success();
1566  }
1567 };
1568 
1569 /// Pattern to remove an empty else branch.
1570 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
1572 
1573  LogicalResult matchAndRewrite(IfOp ifOp,
1574  PatternRewriter &rewriter) const override {
1575  // Cannot remove else region when there are operation results.
1576  if (ifOp.getNumResults())
1577  return failure();
1578  Block *elseBlock = ifOp.elseBlock();
1579  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
1580  return failure();
1581  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
1582  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
1583  newIfOp.getThenRegion().begin());
1584  rewriter.eraseOp(ifOp);
1585  return success();
1586  }
1587 };
1588 
1589 /// Convert nested `if`s into `arith.andi` + single `if`.
1590 ///
1591 /// scf.if %arg0 {
1592 /// scf.if %arg1 {
1593 /// ...
1594 /// scf.yield
1595 /// }
1596 /// scf.yield
1597 /// }
1598 /// becomes
1599 ///
1600 /// %0 = arith.andi %arg0, %arg1
1601 /// scf.if %0 {
1602 /// ...
1603 /// scf.yield
1604 /// }
1605 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
1607 
1608  LogicalResult matchAndRewrite(IfOp op,
1609  PatternRewriter &rewriter) const override {
1610  // Both `if` ops must not yield results and have only `then` block.
1611  if (op->getNumResults() != 0 || op.elseBlock())
1612  return failure();
1613 
1614  auto nestedOps = op.thenBlock()->without_terminator();
1615  // Nested `if` must be the only op in block.
1616  if (!llvm::hasSingleElement(nestedOps))
1617  return failure();
1618 
1619  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
1620  if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
1621  return failure();
1622 
1623  Location loc = op.getLoc();
1624  Value newCondition = rewriter.create<arith::AndIOp>(
1625  loc, op.getCondition(), nestedIf.getCondition());
1626  auto newIf = rewriter.create<IfOp>(loc, newCondition);
1627  Block *newIfBlock = newIf.thenBlock();
1628  rewriter.eraseOp(newIfBlock->getTerminator());
1629  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
1630  rewriter.eraseOp(op);
1631  return success();
1632  }
1633 };
1634 
1635 } // namespace
1636 
1637 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
1638  MLIRContext *context) {
1639  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
1640  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
1641  RemoveStaticCondition, RemoveUnusedResults,
1642  ReplaceIfYieldWithConditionOrValue>(context);
1643 }
1644 
1645 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
1646 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
1647 Block *IfOp::elseBlock() {
1648  Region &r = getElseRegion();
1649  if (r.empty())
1650  return nullptr;
1651  return &r.back();
1652 }
1653 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // ParallelOp
1657 //===----------------------------------------------------------------------===//
1658 
1659 void ParallelOp::build(
1660  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
1661  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
1663  bodyBuilderFn) {
1664  result.addOperands(lowerBounds);
1665  result.addOperands(upperBounds);
1666  result.addOperands(steps);
1667  result.addOperands(initVals);
1668  result.addAttribute(
1669  ParallelOp::getOperandSegmentSizeAttr(),
1670  builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
1671  static_cast<int32_t>(upperBounds.size()),
1672  static_cast<int32_t>(steps.size()),
1673  static_cast<int32_t>(initVals.size())}));
1674  result.addTypes(initVals.getTypes());
1675 
1676  OpBuilder::InsertionGuard guard(builder);
1677  unsigned numIVs = steps.size();
1678  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
1679  SmallVector<Location, 8> argLocs(numIVs, result.location);
1680  Region *bodyRegion = result.addRegion();
1681  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1682 
1683  if (bodyBuilderFn) {
1684  builder.setInsertionPointToStart(bodyBlock);
1685  bodyBuilderFn(builder, result.location,
1686  bodyBlock->getArguments().take_front(numIVs),
1687  bodyBlock->getArguments().drop_front(numIVs));
1688  }
1689  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
1690 }
1691 
1692 void ParallelOp::build(
1693  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
1694  ValueRange upperBounds, ValueRange steps,
1695  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1696  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
1697  // we don't capture a reference to a temporary by constructing the lambda at
1698  // function level.
1699  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
1700  Location nestedLoc, ValueRange ivs,
1701  ValueRange) {
1702  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
1703  };
1705  if (bodyBuilderFn)
1706  wrapper = wrappedBuilderFn;
1707 
1708  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
1709  wrapper);
1710 }
1711 
1712 static LogicalResult verify(ParallelOp op) {
1713  // Check that there is at least one value in lowerBound, upperBound and step.
1714  // It is sufficient to test only step, because it is ensured already that the
1715  // number of elements in lowerBound, upperBound and step are the same.
1716  Operation::operand_range stepValues = op.getStep();
1717  if (stepValues.empty())
1718  return op.emitOpError(
1719  "needs at least one tuple element for lowerBound, upperBound and step");
1720 
1721  // Check whether all constant step values are positive.
1722  for (Value stepValue : stepValues)
1723  if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
1724  if (cst.value() <= 0)
1725  return op.emitOpError("constant step operand must be positive");
1726 
1727  // Check that the body defines the same number of block arguments as the
1728  // number of tuple elements in step.
1729  Block *body = op.getBody();
1730  if (body->getNumArguments() != stepValues.size())
1731  return op.emitOpError()
1732  << "expects the same number of induction variables: "
1733  << body->getNumArguments()
1734  << " as bound and step values: " << stepValues.size();
1735  for (auto arg : body->getArguments())
1736  if (!arg.getType().isIndex())
1737  return op.emitOpError(
1738  "expects arguments for the induction variable to be of index type");
1739 
1740  // Check that the yield has no results
1741  Operation *yield = body->getTerminator();
1742  if (yield->getNumOperands() != 0)
1743  return yield->emitOpError() << "not allowed to have operands inside '"
1744  << ParallelOp::getOperationName() << "'";
1745 
1746  // Check that the number of results is the same as the number of ReduceOps.
1747  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
1748  auto resultsSize = op.getResults().size();
1749  auto reductionsSize = reductions.size();
1750  auto initValsSize = op.getInitVals().size();
1751  if (resultsSize != reductionsSize)
1752  return op.emitOpError()
1753  << "expects number of results: " << resultsSize
1754  << " to be the same as number of reductions: " << reductionsSize;
1755  if (resultsSize != initValsSize)
1756  return op.emitOpError()
1757  << "expects number of results: " << resultsSize
1758  << " to be the same as number of initial values: " << initValsSize;
1759 
1760  // Check that the types of the results and reductions are the same.
1761  for (auto resultAndReduce : llvm::zip(op.getResults(), reductions)) {
1762  auto resultType = std::get<0>(resultAndReduce).getType();
1763  auto reduceOp = std::get<1>(resultAndReduce);
1764  auto reduceType = reduceOp.getOperand().getType();
1765  if (resultType != reduceType)
1766  return reduceOp.emitOpError()
1767  << "expects type of reduce: " << reduceType
1768  << " to be the same as result type: " << resultType;
1769  }
1770  return success();
1771 }
1772 
1774  OperationState &result) {
1775  auto &builder = parser.getBuilder();
1776  // Parse an opening `(` followed by induction variables followed by `)`
1778  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1780  return failure();
1781 
1782  // Parse loop bounds.
1784  if (parser.parseEqual() ||
1785  parser.parseOperandList(lower, ivs.size(),
1787  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
1788  return failure();
1789 
1791  if (parser.parseKeyword("to") ||
1792  parser.parseOperandList(upper, ivs.size(),
1794  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
1795  return failure();
1796 
1797  // Parse step values.
1799  if (parser.parseKeyword("step") ||
1800  parser.parseOperandList(steps, ivs.size(),
1802  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
1803  return failure();
1804 
1805  // Parse init values.
1807  if (succeeded(parser.parseOptionalKeyword("init"))) {
1808  if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
1810  return failure();
1811  }
1812 
1813  // Parse optional results in case there is a reduce.
1814  if (parser.parseOptionalArrowTypeList(result.types))
1815  return failure();
1816 
1817  // Now parse the body.
1818  Region *body = result.addRegion();
1819  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
1820  if (parser.parseRegion(*body, ivs, types))
1821  return failure();
1822 
1823  // Set `operand_segment_sizes` attribute.
1824  result.addAttribute(
1825  ParallelOp::getOperandSegmentSizeAttr(),
1826  builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
1827  static_cast<int32_t>(upper.size()),
1828  static_cast<int32_t>(steps.size()),
1829  static_cast<int32_t>(initVals.size())}));
1830 
1831  // Parse attributes.
1832  if (parser.parseOptionalAttrDict(result.attributes))
1833  return failure();
1834 
1835  if (!initVals.empty())
1836  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1837  result.operands);
1838  // Add a terminator if none was parsed.
1839  ForOp::ensureTerminator(*body, builder, result.location);
1840 
1841  return success();
1842 }
1843 
1844 static void print(OpAsmPrinter &p, ParallelOp op) {
1845  p << " (" << op.getBody()->getArguments() << ") = (" << op.getLowerBound()
1846  << ") to (" << op.getUpperBound() << ") step (" << op.getStep() << ")";
1847  if (!op.getInitVals().empty())
1848  p << " init (" << op.getInitVals() << ")";
1849  p.printOptionalArrowTypeList(op.getResultTypes());
1850  p << ' ';
1851  p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/false);
1853  op->getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
1854 }
1855 
1856 Region &ParallelOp::getLoopBody() { return getRegion(); }
1857 
1858 bool ParallelOp::isDefinedOutsideOfLoop(Value value) {
1859  return !getRegion().isAncestor(value.getParentRegion());
1860 }
1861 
1862 LogicalResult ParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1863  for (auto *op : ops)
1864  op->moveBefore(*this);
1865  return success();
1866 }
1867 
1869  auto ivArg = val.dyn_cast<BlockArgument>();
1870  if (!ivArg)
1871  return ParallelOp();
1872  assert(ivArg.getOwner() && "unlinked block argument");
1873  auto *containingOp = ivArg.getOwner()->getParentOp();
1874  return dyn_cast<ParallelOp>(containingOp);
1875 }
1876 
1877 namespace {
1878 // Collapse loop dimensions that perform a single iteration.
1879 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
1881 
1882  LogicalResult matchAndRewrite(ParallelOp op,
1883  PatternRewriter &rewriter) const override {
1884  BlockAndValueMapping mapping;
1885  // Compute new loop bounds that omit all single-iteration loop dimensions.
1886  SmallVector<Value, 2> newLowerBounds;
1887  SmallVector<Value, 2> newUpperBounds;
1888  SmallVector<Value, 2> newSteps;
1889  newLowerBounds.reserve(op.getLowerBound().size());
1890  newUpperBounds.reserve(op.getUpperBound().size());
1891  newSteps.reserve(op.getStep().size());
1892  for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound(),
1893  op.getStep(), op.getInductionVars())) {
1894  Value lowerBound, upperBound, step, iv;
1895  std::tie(lowerBound, upperBound, step, iv) = dim;
1896  // Collect the statically known loop bounds.
1897  auto lowerBoundConstant =
1898  dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
1899  auto upperBoundConstant =
1900  dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
1901  auto stepConstant =
1902  dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
1903  // Replace the loop induction variable by the lower bound if the loop
1904  // performs a single iteration. Otherwise, copy the loop bounds.
1905  if (lowerBoundConstant && upperBoundConstant && stepConstant &&
1906  (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
1907  (upperBoundConstant.value() - lowerBoundConstant.value()) <=
1908  stepConstant.value()) {
1909  mapping.map(iv, lowerBound);
1910  } else {
1911  newLowerBounds.push_back(lowerBound);
1912  newUpperBounds.push_back(upperBound);
1913  newSteps.push_back(step);
1914  }
1915  }
1916  // Exit if none of the loop dimensions perform a single iteration.
1917  if (newLowerBounds.size() == op.getLowerBound().size())
1918  return failure();
1919 
1920  if (newLowerBounds.empty()) {
1921  // All of the loop dimensions perform a single iteration. Inline
1922  // loop body and nested ReduceOp's
1923  SmallVector<Value> results;
1924  results.reserve(op.getInitVals().size());
1925  for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
1926  auto reduce = dyn_cast<ReduceOp>(bodyOp);
1927  if (!reduce) {
1928  rewriter.clone(bodyOp, mapping);
1929  continue;
1930  }
1931  Block &reduceBlock = reduce.getReductionOperator().front();
1932  auto initValIndex = results.size();
1933  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
1934  mapping.map(reduceBlock.getArgument(1),
1935  mapping.lookupOrDefault(reduce.getOperand()));
1936  for (auto &reduceBodyOp : reduceBlock.without_terminator())
1937  rewriter.clone(reduceBodyOp, mapping);
1938 
1939  auto result = mapping.lookupOrDefault(
1940  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
1941  results.push_back(result);
1942  }
1943  rewriter.replaceOp(op, results);
1944  return success();
1945  }
1946  // Replace the parallel loop by lower-dimensional parallel loop.
1947  auto newOp =
1948  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
1949  newSteps, op.getInitVals(), nullptr);
1950  // Clone the loop body and remap the block arguments of the collapsed loops
1951  // (inlining does not support a cancellable block argument mapping).
1952  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
1953  newOp.getRegion().begin(), mapping);
1954  rewriter.replaceOp(op, newOp.getResults());
1955  return success();
1956  }
1957 };
1958 
1959 /// Removes parallel loops in which at least one lower/upper bound pair consists
1960 /// of the same values - such loops have an empty iteration domain.
1961 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
1963 
1964  LogicalResult matchAndRewrite(ParallelOp op,
1965  PatternRewriter &rewriter) const override {
1966  for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
1967  if (std::get<0>(dim) == std::get<1>(dim)) {
1968  rewriter.replaceOp(op, op.getInitVals());
1969  return success();
1970  }
1971  }
1972  return failure();
1973  }
1974 };
1975 
1976 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
1978 
1979  LogicalResult matchAndRewrite(ParallelOp op,
1980  PatternRewriter &rewriter) const override {
1981  Block &outerBody = op.getLoopBody().front();
1982  if (!llvm::hasSingleElement(outerBody.without_terminator()))
1983  return failure();
1984 
1985  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
1986  if (!innerOp)
1987  return failure();
1988 
1989  auto hasVal = [](const auto &range, Value val) {
1990  return llvm::find(range, val) != range.end();
1991  };
1992 
1993  for (auto val : outerBody.getArguments())
1994  if (hasVal(innerOp.getLowerBound(), val) ||
1995  hasVal(innerOp.getUpperBound(), val) ||
1996  hasVal(innerOp.getStep(), val))
1997  return failure();
1998 
1999  // Reductions are not supported yet.
2000  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2001  return failure();
2002 
2003  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2004  ValueRange iterVals, ValueRange) {
2005  Block &innerBody = innerOp.getLoopBody().front();
2006  assert(iterVals.size() ==
2007  (outerBody.getNumArguments() + innerBody.getNumArguments()));
2008  BlockAndValueMapping mapping;
2009  mapping.map(outerBody.getArguments(),
2010  iterVals.take_front(outerBody.getNumArguments()));
2011  mapping.map(innerBody.getArguments(),
2012  iterVals.take_back(innerBody.getNumArguments()));
2013  for (Operation &op : innerBody.without_terminator())
2014  builder.clone(op, mapping);
2015  };
2016 
2017  auto concatValues = [](const auto &first, const auto &second) {
2018  SmallVector<Value> ret;
2019  ret.reserve(first.size() + second.size());
2020  ret.assign(first.begin(), first.end());
2021  ret.append(second.begin(), second.end());
2022  return ret;
2023  };
2024 
2025  auto newLowerBounds =
2026  concatValues(op.getLowerBound(), innerOp.getLowerBound());
2027  auto newUpperBounds =
2028  concatValues(op.getUpperBound(), innerOp.getUpperBound());
2029  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2030 
2031  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2032  newSteps, llvm::None, bodyBuilder);
2033  return success();
2034  }
2035 };
2036 
2037 } // namespace
2038 
2039 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2040  MLIRContext *context) {
2041  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2042  MergeNestedParallelLoops>(context);
2043 }
2044 
2045 //===----------------------------------------------------------------------===//
2046 // ReduceOp
2047 //===----------------------------------------------------------------------===//
2048 
2049 void ReduceOp::build(
2050  OpBuilder &builder, OperationState &result, Value operand,
2051  function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2052  auto type = operand.getType();
2053  result.addOperands(operand);
2054 
2055  OpBuilder::InsertionGuard guard(builder);
2056  Region *bodyRegion = result.addRegion();
2057  Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2058  {result.location, result.location});
2059  if (bodyBuilderFn)
2060  bodyBuilderFn(builder, result.location, body->getArgument(0),
2061  body->getArgument(1));
2062 }
2063 
2064 static LogicalResult verify(ReduceOp op) {
2065  // The region of a ReduceOp has two arguments of the same type as its operand.
2066  auto type = op.getOperand().getType();
2067  Block &block = op.getReductionOperator().front();
2068  if (block.empty())
2069  return op.emitOpError("the block inside reduce should not be empty");
2070  if (block.getNumArguments() != 2 ||
2071  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2072  return arg.getType() != type;
2073  }))
2074  return op.emitOpError()
2075  << "expects two arguments to reduce block of type " << type;
2076 
2077  // Check that the block is terminated by a ReduceReturnOp.
2078  if (!isa<ReduceReturnOp>(block.getTerminator()))
2079  return op.emitOpError("the block inside reduce should be terminated with a "
2080  "'scf.reduce.return' op");
2081 
2082  return success();
2083 }
2084 
2086  // Parse an opening `(` followed by the reduced value followed by `)`
2087  OpAsmParser::OperandType operand;
2088  if (parser.parseLParen() || parser.parseOperand(operand) ||
2089  parser.parseRParen())
2090  return failure();
2091 
2092  Type resultType;
2093  // Parse the type of the operand (and also what reduce computes on).
2094  if (parser.parseColonType(resultType) ||
2095  parser.resolveOperand(operand, resultType, result.operands))
2096  return failure();
2097 
2098  // Now parse the body.
2099  Region *body = result.addRegion();
2100  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2101  return failure();
2102 
2103  return success();
2104 }
2105 
2106 static void print(OpAsmPrinter &p, ReduceOp op) {
2107  p << "(" << op.getOperand() << ") ";
2108  p << " : " << op.getOperand().getType() << ' ';
2109  p.printRegion(op.getReductionOperator());
2110 }
2111 
2112 //===----------------------------------------------------------------------===//
2113 // ReduceReturnOp
2114 //===----------------------------------------------------------------------===//
2115 
2116 static LogicalResult verify(ReduceReturnOp op) {
2117  // The type of the return value should be the same type as the type of the
2118  // operand of the enclosing ReduceOp.
2119  auto reduceOp = cast<ReduceOp>(op->getParentOp());
2120  Type reduceType = reduceOp.getOperand().getType();
2121  if (reduceType != op.getResult().getType())
2122  return op.emitOpError() << "needs to have type " << reduceType
2123  << " (the type of the enclosing ReduceOp)";
2124  return success();
2125 }
2126 
2127 //===----------------------------------------------------------------------===//
2128 // WhileOp
2129 //===----------------------------------------------------------------------===//
2130 
2131 OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
2132  assert(index == 0 &&
2133  "WhileOp is expected to branch only to the first region");
2134 
2135  return getInits();
2136 }
2137 
2138 ConditionOp WhileOp::getConditionOp() {
2139  return cast<ConditionOp>(getBefore().front().getTerminator());
2140 }
2141 
2142 YieldOp WhileOp::getYieldOp() {
2143  return cast<YieldOp>(getAfter().front().getTerminator());
2144 }
2145 
2146 Block::BlockArgListType WhileOp::getBeforeArguments() {
2147  return getBefore().front().getArguments();
2148 }
2149 
2150 Block::BlockArgListType WhileOp::getAfterArguments() {
2151  return getAfter().front().getArguments();
2152 }
2153 
2154 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2155  ArrayRef<Attribute> operands,
2157  (void)operands;
2158 
2159  if (!index.hasValue()) {
2160  regions.emplace_back(&getBefore(), getBefore().getArguments());
2161  return;
2162  }
2163 
2164  assert(*index < 2 && "there are only two regions in a WhileOp");
2165  if (*index == 0) {
2166  regions.emplace_back(&getAfter(), getAfter().getArguments());
2167  regions.emplace_back(getResults());
2168  return;
2169  }
2170 
2171  regions.emplace_back(&getBefore(), getBefore().getArguments());
2172 }
2173 
2174 /// Parses a `while` op.
2175 ///
2176 /// op ::= `scf.while` assignments `:` function-type region `do` region
2177 /// `attributes` attribute-dict
2178 /// initializer ::= /* empty */ | `(` assignment-list `)`
2179 /// assignment-list ::= assignment | assignment `,` assignment-list
2180 /// assignment ::= ssa-value `=` ssa-value
2181 static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
2182  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
2183  Region *before = result.addRegion();
2184  Region *after = result.addRegion();
2185 
2186  OptionalParseResult listResult =
2187  parser.parseOptionalAssignmentList(regionArgs, operands);
2188  if (listResult.hasValue() && failed(listResult.getValue()))
2189  return failure();
2190 
2191  FunctionType functionType;
2192  llvm::SMLoc typeLoc = parser.getCurrentLocation();
2193  if (failed(parser.parseColonType(functionType)))
2194  return failure();
2195 
2196  result.addTypes(functionType.getResults());
2197 
2198  if (functionType.getNumInputs() != operands.size()) {
2199  return parser.emitError(typeLoc)
2200  << "expected as many input types as operands "
2201  << "(expected " << operands.size() << " got "
2202  << functionType.getNumInputs() << ")";
2203  }
2204 
2205  // Resolve input operands.
2206  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2207  parser.getCurrentLocation(),
2208  result.operands)))
2209  return failure();
2210 
2211  return failure(
2212  parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
2213  parser.parseKeyword("do") || parser.parseRegion(*after) ||
2215 }
2216 
2217 /// Prints a `while` op.
2218 static void print(OpAsmPrinter &p, scf::WhileOp op) {
2219  printInitializationList(p, op.getBefore().front().getArguments(),
2220  op.getInits(), " ");
2221  p << " : ";
2222  p.printFunctionalType(op.getInits().getTypes(), op.getResults().getTypes());
2223  p << ' ';
2224  p.printRegion(op.getBefore(), /*printEntryBlockArgs=*/false);
2225  p << " do ";
2226  p.printRegion(op.getAfter());
2227  p.printOptionalAttrDictWithKeyword(op->getAttrs());
2228 }
2229 
2230 /// Verifies that two ranges of types match, i.e. have the same number of
2231 /// entries and that types are pairwise equals. Reports errors on the given
2232 /// operation in case of mismatch.
2233 template <typename OpTy>
2234 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2235  TypeRange right, StringRef message) {
2236  if (left.size() != right.size())
2237  return op.emitOpError("expects the same number of ") << message;
2238 
2239  for (unsigned i = 0, e = left.size(); i < e; ++i) {
2240  if (left[i] != right[i]) {
2241  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2242  << message;
2243  diag.attachNote() << "for argument " << i << ", found " << left[i]
2244  << " and " << right[i];
2245  return diag;
2246  }
2247  }
2248 
2249  return success();
2250 }
2251 
2252 /// Verifies that the first block of the given `region` is terminated by a
2253 /// YieldOp. Reports errors on the given operation if it is not the case.
2254 template <typename TerminatorTy>
2255 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2256  StringRef errorMessage) {
2257  Operation *terminatorOperation = region.front().getTerminator();
2258  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2259  return yield;
2260 
2261  auto diag = op.emitOpError(errorMessage);
2262  if (terminatorOperation)
2263  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2264  return nullptr;
2265 }
2266 
2267 static LogicalResult verify(scf::WhileOp op) {
2268  if (failed(RegionBranchOpInterface::verifyTypes(op)))
2269  return failure();
2270 
2271  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2272  op, op.getBefore(),
2273  "expects the 'before' region to terminate with 'scf.condition'");
2274  if (!beforeTerminator)
2275  return failure();
2276 
2277  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2278  op, op.getAfter(),
2279  "expects the 'after' region to terminate with 'scf.yield'");
2280  return success(afterTerminator != nullptr);
2281 }
2282 
2283 namespace {
2284 /// Replace uses of the condition within the do block with true, since otherwise
2285 /// the block would not be evaluated.
2286 ///
2287 /// scf.while (..) : (i1, ...) -> ... {
2288 /// %condition = call @evaluate_condition() : () -> i1
2289 /// scf.condition(%condition) %condition : i1, ...
2290 /// } do {
2291 /// ^bb0(%arg0: i1, ...):
2292 /// use(%arg0)
2293 /// ...
2294 ///
2295 /// becomes
2296 /// scf.while (..) : (i1, ...) -> ... {
2297 /// %condition = call @evaluate_condition() : () -> i1
2298 /// scf.condition(%condition) %condition : i1, ...
2299 /// } do {
2300 /// ^bb0(%arg0: i1, ...):
2301 /// use(%true)
2302 /// ...
2303 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2305 
2306  LogicalResult matchAndRewrite(WhileOp op,
2307  PatternRewriter &rewriter) const override {
2308  auto term = op.getConditionOp();
2309 
2310  // These variables serve to prevent creating duplicate constants
2311  // and hold constant true or false values.
2312  Value constantTrue = nullptr;
2313 
2314  bool replaced = false;
2315  for (auto yieldedAndBlockArgs :
2316  llvm::zip(term.getArgs(), op.getAfterArguments())) {
2317  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2318  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2319  if (!constantTrue)
2320  constantTrue = rewriter.create<arith::ConstantOp>(
2321  op.getLoc(), term.getCondition().getType(),
2322  rewriter.getBoolAttr(true));
2323 
2324  std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2325  replaced = true;
2326  }
2327  }
2328  }
2329  return success(replaced);
2330  }
2331 };
2332 
2333 /// Remove WhileOp results that are also unused in 'after' block.
2334 ///
2335 /// %0:2 = scf.while () : () -> (i32, i64) {
2336 /// %condition = "test.condition"() : () -> i1
2337 /// %v1 = "test.get_some_value"() : () -> i32
2338 /// %v2 = "test.get_some_value"() : () -> i64
2339 /// scf.condition(%condition) %v1, %v2 : i32, i64
2340 /// } do {
2341 /// ^bb0(%arg0: i32, %arg1: i64):
2342 /// "test.use"(%arg0) : (i32) -> ()
2343 /// scf.yield
2344 /// }
2345 /// return %0#0 : i32
2346 ///
2347 /// becomes
2348 /// %0 = scf.while () : () -> (i32) {
2349 /// %condition = "test.condition"() : () -> i1
2350 /// %v1 = "test.get_some_value"() : () -> i32
2351 /// %v2 = "test.get_some_value"() : () -> i64
2352 /// scf.condition(%condition) %v1 : i32
2353 /// } do {
2354 /// ^bb0(%arg0: i32):
2355 /// "test.use"(%arg0) : (i32) -> ()
2356 /// scf.yield
2357 /// }
2358 /// return %0 : i32
2359 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
2361 
2362  LogicalResult matchAndRewrite(WhileOp op,
2363  PatternRewriter &rewriter) const override {
2364  auto term = op.getConditionOp();
2365  auto afterArgs = op.getAfterArguments();
2366  auto termArgs = term.getArgs();
2367 
2368  // Collect results mapping, new terminator args and new result types.
2369  SmallVector<unsigned> newResultsIndices;
2370  SmallVector<Type> newResultTypes;
2371  SmallVector<Value> newTermArgs;
2372  SmallVector<Location> newArgLocs;
2373  bool needUpdate = false;
2374  for (const auto &it :
2375  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
2376  auto i = static_cast<unsigned>(it.index());
2377  Value result = std::get<0>(it.value());
2378  Value afterArg = std::get<1>(it.value());
2379  Value termArg = std::get<2>(it.value());
2380  if (result.use_empty() && afterArg.use_empty()) {
2381  needUpdate = true;
2382  } else {
2383  newResultsIndices.emplace_back(i);
2384  newTermArgs.emplace_back(termArg);
2385  newResultTypes.emplace_back(result.getType());
2386  newArgLocs.emplace_back(result.getLoc());
2387  }
2388  }
2389 
2390  if (!needUpdate)
2391  return failure();
2392 
2393  {
2394  OpBuilder::InsertionGuard g(rewriter);
2395  rewriter.setInsertionPoint(term);
2396  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
2397  newTermArgs);
2398  }
2399 
2400  auto newWhile =
2401  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
2402 
2403  Block &newAfterBlock = *rewriter.createBlock(
2404  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
2405 
2406  // Build new results list and new after block args (unused entries will be
2407  // null).
2408  SmallVector<Value> newResults(op.getNumResults());
2409  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
2410  for (const auto &it : llvm::enumerate(newResultsIndices)) {
2411  newResults[it.value()] = newWhile.getResult(it.index());
2412  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
2413  }
2414 
2415  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
2416  newWhile.getBefore().begin());
2417 
2418  Block &afterBlock = op.getAfter().front();
2419  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
2420 
2421  rewriter.replaceOp(op, newResults);
2422  return success();
2423  }
2424 };
2425 
2426 /// Replace operations equivalent to the condition in the do block with true,
2427 /// since otherwise the block would not be evaluated.
2428 ///
2429 /// scf.while (..) : (i32, ...) -> ... {
2430 /// %z = ... : i32
2431 /// %condition = cmpi pred %z, %a
2432 /// scf.condition(%condition) %z : i32, ...
2433 /// } do {
2434 /// ^bb0(%arg0: i32, ...):
2435 /// %condition2 = cmpi pred %arg0, %a
2436 /// use(%condition2)
2437 /// ...
2438 ///
2439 /// becomes
2440 /// scf.while (..) : (i32, ...) -> ... {
2441 /// %z = ... : i32
2442 /// %condition = cmpi pred %z, %a
2443 /// scf.condition(%condition) %z : i32, ...
2444 /// } do {
2445 /// ^bb0(%arg0: i32, ...):
2446 /// use(%true)
2447 /// ...
2448 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
2450 
2451  LogicalResult matchAndRewrite(scf::WhileOp op,
2452  PatternRewriter &rewriter) const override {
2453  using namespace scf;
2454  auto cond = op.getConditionOp();
2455  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
2456  if (!cmp)
2457  return failure();
2458  bool changed = false;
2459  for (auto tup :
2460  llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
2461  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
2462  if (std::get<0>(tup) != cmp.getOperand(opIdx))
2463  continue;
2464  for (OpOperand &u :
2465  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
2466  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
2467  if (!cmp2)
2468  continue;
2469  // For a binary operator 1-opIdx gets the other side.
2470  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
2471  continue;
2472  bool samePredicate;
2473  if (cmp2.getPredicate() == cmp.getPredicate())
2474  samePredicate = true;
2475  else if (cmp2.getPredicate() ==
2476  arith::invertPredicate(cmp.getPredicate()))
2477  samePredicate = false;
2478  else
2479  continue;
2480 
2481  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
2482  1);
2483  changed = true;
2484  }
2485  }
2486  }
2487  return success(changed);
2488  }
2489 };
2490 
2491 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
2493 
2494  LogicalResult matchAndRewrite(WhileOp op,
2495  PatternRewriter &rewriter) const override {
2496 
2497  if (!llvm::any_of(op.getBeforeArguments(),
2498  [](Value arg) { return arg.use_empty(); }))
2499  return failure();
2500 
2501  YieldOp yield = op.getYieldOp();
2502 
2503  // Collect results mapping, new terminator args and new result types.
2504  SmallVector<Value> newYields;
2505  SmallVector<Value> newInits;
2506  SmallVector<unsigned> argsToErase;
2507  for (const auto &it : llvm::enumerate(llvm::zip(
2508  op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
2509  Value beforeArg = std::get<0>(it.value());
2510  Value yieldValue = std::get<1>(it.value());
2511  Value initValue = std::get<2>(it.value());
2512  if (beforeArg.use_empty()) {
2513  argsToErase.push_back(it.index());
2514  } else {
2515  newYields.emplace_back(yieldValue);
2516  newInits.emplace_back(initValue);
2517  }
2518  }
2519 
2520  if (argsToErase.empty())
2521  return failure();
2522 
2523  rewriter.startRootUpdate(op);
2524  op.getBefore().front().eraseArguments(argsToErase);
2525  rewriter.finalizeRootUpdate(op);
2526 
2527  WhileOp replacement =
2528  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
2529  replacement.getBefore().takeBody(op.getBefore());
2530  replacement.getAfter().takeBody(op.getAfter());
2531  rewriter.replaceOp(op, replacement.getResults());
2532 
2533  rewriter.setInsertionPoint(yield);
2534  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
2535  return success();
2536  }
2537 };
2538 } // namespace
2539 
2540 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2541  MLIRContext *context) {
2542  results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
2543  WhileUnusedArg>(context);
2544 }
2545 
2546 //===----------------------------------------------------------------------===//
2547 // TableGen'd op method definitions
2548 //===----------------------------------------------------------------------===//
2549 
2550 #define GET_OP_CLASSES
2551 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
virtual ParseResult parseOperand(OperandType &result)=0
Parse a single operand.
This is the representation of an operand reference.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:440
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
static ParseResult parseForOp(OpAsmParser &parser, OperationState &result)
Definition: SCF.cpp:357
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static std::string diag(llvm::Value &v)
virtual ParseResult parseLParen()=0
Parse a ( token.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:373
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation & back()
Definition: Block.h:143
This is a value defined by a result of an operation.
Definition: Value.h:423
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:41
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:247
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:301
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:457
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
virtual ParseResult parseRegionArgument(OperandType &argument)=0
Parse a region argument, this argument is resolved when calling 'parseRegion'.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:500
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.
Definition: SCF.cpp:326
std::vector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:59
unsigned getNumOperands()
Definition: Operation.h:215
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp...
Definition: SCF.cpp:999
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:202
ParseResult parseAssignmentList(SmallVectorImpl< OperandType > &lhs, SmallVectorImpl< OperandType > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
detail::constant_int_value_matcher< 1 > m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:243
Operation & front()
Definition: Block.h:144
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:200
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:81
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:774
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:310
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:120
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:161
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
virtual llvm::SMLoc getNameLoc() const =0
Return the location of the original name token.
static constexpr const bool value
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
static void print(OpAsmPrinter &p, ExecuteRegionOp op)
Definition: SCF.cpp:117
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:343
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:75
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual ParseResult resolveOperand(const OperandType &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool empty()
Definition: Region.h:60
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
void addOperands(ValueRange newOperands)
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:335
unsigned getNumArguments()
Definition: Block.h:119
static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result)
Definition: SCF.cpp:1072
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
U dyn_cast() const
Definition: Value.h:99
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:44
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:435
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
virtual ParseResult parseRParen()=0
Parse a ) token.
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:41
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Block & back()
Definition: Region.h:64
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:1868
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
This class provides a mutable adaptor for a range of operands.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:133
This represents an operation in an abstracted form, suitable for use with the builder APIs...
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
Parens surrounding zero or more operands.
BlockArgListType getArguments()
Definition: Block.h:76
This class represents an argument of a Block.
Definition: Value.h:298
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
static ParseResult parseExecuteRegionOp(OpAsmParser &parser, OperationState &result)
(ssa-id =)? execute_region -> function-result-type { block+ }
Definition: SCF.cpp:103
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
Definition: PatternMatch.h:779
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result)
Definition: SCF.cpp:2085
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:202
ParseResult getValue() const
Access the internal ParseResult value.
Definition: OpDefinition.h:65
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:467
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< OperandType > &lhs, SmallVectorImpl< OperandType > &rhs)=0
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
IndexType getIndexType()
Definition: Builders.cpp:48
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:266
type_range getTypes() const
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Definition: Block.h:184
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:71
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:37
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Block * lookupOrDefault(Block *from) const
Lookup a mapped value within the map.
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:109
This class represents an operand of an operation.
Definition: Value.h:249
This class implements the operand iterators for the Operation class.
U cast() const
Definition: Value.h:107
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
virtual ParseResult parseEqual()=0
Parse a = token.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:518
bool isa() const
Definition: Types.h:234
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
virtual ParseResult parseRegionArgumentList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more region arguments with a specified surrounding delimiter, and an optional required ...
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:153
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
static ParseResult parseParallelOp(OpAsmParser &parser, OperationState &result)
Definition: SCF.cpp:1773
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.