MLIR  16.0.0git
SCF.cpp
Go to the documentation of this file.
1 //===- SCF.cpp - Structured Control Flow Operations -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
22 
23 using namespace mlir;
24 using namespace mlir::scf;
25 
26 #include "mlir/Dialect/SCF/IR/SCFOpsDialect.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // SCFDialect Dialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 struct SCFInlinerInterface : public DialectInlinerInterface {
35  // We don't have any special restrictions on what can be inlined into
36  // destination regions (e.g. while/conditional bodies). Always allow it.
37  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
38  BlockAndValueMapping &valueMapping) const final {
39  return true;
40  }
41  // Operations in scf dialect are always legal to inline since they are
42  // pure.
43  bool isLegalToInline(Operation *, Region *, bool,
44  BlockAndValueMapping &) const final {
45  return true;
46  }
47  // Handle the given inlined terminator by replacing it with a new operation
48  // as necessary. Required when the region has only one block.
49  void handleTerminator(Operation *op,
50  ArrayRef<Value> valuesToRepl) const final {
51  auto retValOp = dyn_cast<scf::YieldOp>(op);
52  if (!retValOp)
53  return;
54 
55  for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
56  std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
57  }
58  }
59 };
60 } // namespace
61 
62 //===----------------------------------------------------------------------===//
63 // SCFDialect
64 //===----------------------------------------------------------------------===//
65 
66 void SCFDialect::initialize() {
67  addOperations<
68 #define GET_OP_LIST
69 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
70  >();
71  addInterfaces<SCFInlinerInterface>();
72 }
73 
74 /// Default callback for IfOp builders. Inserts a yield without arguments.
76  builder.create<scf::YieldOp>(loc);
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // ExecuteRegionOp
81 //===----------------------------------------------------------------------===//
82 
83 /// Replaces the given op with the contents of the given single-block region,
84 /// using the operands of the block terminator to replace operation results.
85 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
86  Region &region, ValueRange blockArgs = {}) {
87  assert(llvm::hasSingleElement(region) && "expected single-region block");
88  Block *block = &region.front();
89  Operation *terminator = block->getTerminator();
90  ValueRange results = terminator->getOperands();
91  rewriter.mergeBlockBefore(block, op, blockArgs);
92  rewriter.replaceOp(op, results);
93  rewriter.eraseOp(terminator);
94 }
95 
96 ///
97 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
98 /// block+
99 /// `}`
100 ///
101 /// Example:
102 /// scf.execute_region -> i32 {
103 /// %idx = load %rI[%i] : memref<128xi32>
104 /// return %idx : i32
105 /// }
106 ///
107 ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
108  OperationState &result) {
109  if (parser.parseOptionalArrowTypeList(result.types))
110  return failure();
111 
112  // Introduce the body region and parse it.
113  Region *body = result.addRegion();
114  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
115  parser.parseOptionalAttrDict(result.attributes))
116  return failure();
117 
118  return success();
119 }
120 
122  p.printOptionalArrowTypeList(getResultTypes());
123 
124  p << ' ';
125  p.printRegion(getRegion(),
126  /*printEntryBlockArgs=*/false,
127  /*printBlockTerminators=*/true);
128 
129  p.printOptionalAttrDict((*this)->getAttrs());
130 }
131 
133  if (getRegion().empty())
134  return emitOpError("region needs to have at least one block");
135  if (getRegion().front().getNumArguments() > 0)
136  return emitOpError("region cannot have any arguments");
137  return success();
138 }
139 
140 // Inline an ExecuteRegionOp if it only contains one block.
141 // "test.foo"() : () -> ()
142 // %v = scf.execute_region -> i64 {
143 // %x = "test.val"() : () -> i64
144 // scf.yield %x : i64
145 // }
146 // "test.bar"(%v) : (i64) -> ()
147 //
148 // becomes
149 //
150 // "test.foo"() : () -> ()
151 // %x = "test.val"() : () -> i64
152 // "test.bar"(%x) : (i64) -> ()
153 //
154 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
156 
157  LogicalResult matchAndRewrite(ExecuteRegionOp op,
158  PatternRewriter &rewriter) const override {
159  if (!llvm::hasSingleElement(op.getRegion()))
160  return failure();
161  replaceOpWithRegion(rewriter, op, op.getRegion());
162  return success();
163  }
164 };
165 
166 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
167 // TODO generalize the conditions for operations which can be inlined into.
168 // func @func_execute_region_elim() {
169 // "test.foo"() : () -> ()
170 // %v = scf.execute_region -> i64 {
171 // %c = "test.cmp"() : () -> i1
172 // cf.cond_br %c, ^bb2, ^bb3
173 // ^bb2:
174 // %x = "test.val1"() : () -> i64
175 // cf.br ^bb4(%x : i64)
176 // ^bb3:
177 // %y = "test.val2"() : () -> i64
178 // cf.br ^bb4(%y : i64)
179 // ^bb4(%z : i64):
180 // scf.yield %z : i64
181 // }
182 // "test.bar"(%v) : (i64) -> ()
183 // return
184 // }
185 //
186 // becomes
187 //
188 // func @func_execute_region_elim() {
189 // "test.foo"() : () -> ()
190 // %c = "test.cmp"() : () -> i1
191 // cf.cond_br %c, ^bb1, ^bb2
192 // ^bb1: // pred: ^bb0
193 // %x = "test.val1"() : () -> i64
194 // cf.br ^bb3(%x : i64)
195 // ^bb2: // pred: ^bb0
196 // %y = "test.val2"() : () -> i64
197 // cf.br ^bb3(%y : i64)
198 // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2
199 // "test.bar"(%z) : (i64) -> ()
200 // return
201 // }
202 //
203 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
205 
206  LogicalResult matchAndRewrite(ExecuteRegionOp op,
207  PatternRewriter &rewriter) const override {
208  if (!isa<FunctionOpInterface, ExecuteRegionOp>(op->getParentOp()))
209  return failure();
210 
211  Block *prevBlock = op->getBlock();
212  Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
213  rewriter.setInsertionPointToEnd(prevBlock);
214 
215  rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front());
216 
217  for (Block &blk : op.getRegion()) {
218  if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
219  rewriter.setInsertionPoint(yieldOp);
220  rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock,
221  yieldOp.getResults());
222  rewriter.eraseOp(yieldOp);
223  }
224  }
225 
226  rewriter.inlineRegionBefore(op.getRegion(), postBlock);
227  SmallVector<Value> blockArgs;
228 
229  for (auto res : op.getResults())
230  blockArgs.push_back(postBlock->addArgument(res.getType(), res.getLoc()));
231 
232  rewriter.replaceOp(op, blockArgs);
233  return success();
234  }
235 };
236 
237 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
238  MLIRContext *context) {
240 }
241 
242 /// Given the region at `index`, or the parent operation if `index` is None,
243 /// return the successor regions. These are the regions that may be selected
244 /// during the flow of control. `operands` is a set of optional attributes that
245 /// correspond to a constant value for each operand, or null if that operand is
246 /// not a constant.
247 void ExecuteRegionOp::getSuccessorRegions(
248  Optional<unsigned> index, ArrayRef<Attribute> operands,
250  // If the predecessor is the ExecuteRegionOp, branch into the body.
251  if (!index) {
252  regions.push_back(RegionSuccessor(&getRegion()));
253  return;
254  }
255 
256  // Otherwise, the region branches back to the parent operation.
257  regions.push_back(RegionSuccessor(getResults()));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // ConditionOp
262 //===----------------------------------------------------------------------===//
263 
265 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
266  // Pass all operands except the condition to the successor region.
267  return getArgsMutable();
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // ForOp
272 //===----------------------------------------------------------------------===//
273 
274 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
275  Value ub, Value step, ValueRange iterArgs,
276  BodyBuilderFn bodyBuilder) {
277  result.addOperands({lb, ub, step});
278  result.addOperands(iterArgs);
279  for (Value v : iterArgs)
280  result.addTypes(v.getType());
281  Region *bodyRegion = result.addRegion();
282  bodyRegion->push_back(new Block);
283  Block &bodyBlock = bodyRegion->front();
284  bodyBlock.addArgument(builder.getIndexType(), result.location);
285  for (Value v : iterArgs)
286  bodyBlock.addArgument(v.getType(), v.getLoc());
287 
288  // Create the default terminator if the builder is not provided and if the
289  // iteration arguments are not provided. Otherwise, leave this to the caller
290  // because we don't know which values to return from the loop.
291  if (iterArgs.empty() && !bodyBuilder) {
292  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
293  } else if (bodyBuilder) {
294  OpBuilder::InsertionGuard guard(builder);
295  builder.setInsertionPointToStart(&bodyBlock);
296  bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
297  bodyBlock.getArguments().drop_front());
298  }
299 }
300 
302  if (auto cst = getStep().getDefiningOp<arith::ConstantIndexOp>())
303  if (cst.value() <= 0)
304  return emitOpError("constant step operand must be positive");
305 
306  auto opNumResults = getNumResults();
307  if (opNumResults == 0)
308  return success();
309  // If ForOp defines values, check that the number and types of
310  // the defined values match ForOp initial iter operands and backedge
311  // basic block arguments.
312  if (getNumIterOperands() != opNumResults)
313  return emitOpError(
314  "mismatch in number of loop-carried values and defined values");
315  return success();
316 }
317 
318 LogicalResult ForOp::verifyRegions() {
319  // Check that the body defines as single block argument for the induction
320  // variable.
321  auto *body = getBody();
322  if (!body->getArgument(0).getType().isIndex())
323  return emitOpError(
324  "expected body first argument to be an index argument for "
325  "the induction variable");
326 
327  auto opNumResults = getNumResults();
328  if (opNumResults == 0)
329  return success();
330 
331  if (getNumRegionIterArgs() != opNumResults)
332  return emitOpError(
333  "mismatch in number of basic block args and defined values");
334 
335  auto iterOperands = getIterOperands();
336  auto iterArgs = getRegionIterArgs();
337  auto opResults = getResults();
338  unsigned i = 0;
339  for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
340  if (std::get<0>(e).getType() != std::get<2>(e).getType())
341  return emitOpError() << "types mismatch between " << i
342  << "th iter operand and defined value";
343  if (std::get<1>(e).getType() != std::get<2>(e).getType())
344  return emitOpError() << "types mismatch between " << i
345  << "th iter region arg and defined value";
346 
347  i++;
348  }
349  return success();
350 }
351 
352 Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); }
353 
354 Optional<OpFoldResult> ForOp::getSingleLowerBound() {
355  return OpFoldResult(getLowerBound());
356 }
357 
358 Optional<OpFoldResult> ForOp::getSingleStep() {
359  return OpFoldResult(getStep());
360 }
361 
362 Optional<OpFoldResult> ForOp::getSingleUpperBound() {
363  return OpFoldResult(getUpperBound());
364 }
365 
366 /// Prints the initialization list in the form of
367 /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
368 /// where 'inner' values are assumed to be region arguments and 'outer' values
369 /// are regular SSA values.
371  Block::BlockArgListType blocksArgs,
372  ValueRange initializers,
373  StringRef prefix = "") {
374  assert(blocksArgs.size() == initializers.size() &&
375  "expected same length of arguments and initializers");
376  if (initializers.empty())
377  return;
378 
379  p << prefix << '(';
380  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
381  p << std::get<0>(it) << " = " << std::get<1>(it);
382  });
383  p << ")";
384 }
385 
386 void ForOp::print(OpAsmPrinter &p) {
387  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
388  << getUpperBound() << " step " << getStep();
389 
390  printInitializationList(p, getRegionIterArgs(), getIterOperands(),
391  " iter_args");
392  if (!getIterOperands().empty())
393  p << " -> (" << getIterOperands().getTypes() << ')';
394  p << ' ';
395  p.printRegion(getRegion(),
396  /*printEntryBlockArgs=*/false,
397  /*printBlockTerminators=*/hasIterOperands());
398  p.printOptionalAttrDict((*this)->getAttrs());
399 }
400 
401 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
402  auto &builder = parser.getBuilder();
403  Type indexType = builder.getIndexType();
404 
405  OpAsmParser::Argument inductionVariable;
406  inductionVariable.type = indexType;
407  OpAsmParser::UnresolvedOperand lb, ub, step;
408 
409  // Parse the induction variable followed by '='.
410  if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
411  // Parse loop bounds.
412  parser.parseOperand(lb) ||
413  parser.resolveOperand(lb, indexType, result.operands) ||
414  parser.parseKeyword("to") || parser.parseOperand(ub) ||
415  parser.resolveOperand(ub, indexType, result.operands) ||
416  parser.parseKeyword("step") || parser.parseOperand(step) ||
417  parser.resolveOperand(step, indexType, result.operands))
418  return failure();
419 
420  // Parse the optional initial iteration arguments.
423  regionArgs.push_back(inductionVariable);
424 
425  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
426  // Parse assignment list and results type list.
427  if (parser.parseAssignmentList(regionArgs, operands) ||
428  parser.parseArrowTypeList(result.types))
429  return failure();
430 
431  // Resolve input operands.
432  for (auto argOperandType :
433  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
434  Type type = std::get<2>(argOperandType);
435  std::get<0>(argOperandType).type = type;
436  if (parser.resolveOperand(std::get<1>(argOperandType), type,
437  result.operands))
438  return failure();
439  }
440  }
441 
442  if (regionArgs.size() != result.types.size() + 1)
443  return parser.emitError(
444  parser.getNameLoc(),
445  "mismatch in number of loop-carried values and defined values");
446 
447  // Parse the body region.
448  Region *body = result.addRegion();
449  if (parser.parseRegion(*body, regionArgs))
450  return failure();
451 
452  ForOp::ensureTerminator(*body, builder, result.location);
453 
454  // Parse the optional attribute list.
455  if (parser.parseOptionalAttrDict(result.attributes))
456  return failure();
457 
458  return success();
459 }
460 
461 Region &ForOp::getLoopBody() { return getRegion(); }
462 
464  auto ivArg = val.dyn_cast<BlockArgument>();
465  if (!ivArg)
466  return ForOp();
467  assert(ivArg.getOwner() && "unlinked block argument");
468  auto *containingOp = ivArg.getOwner()->getParentOp();
469  return dyn_cast_or_null<ForOp>(containingOp);
470 }
471 
472 /// Return operands used when entering the region at 'index'. These operands
473 /// correspond to the loop iterator operands, i.e., those excluding the
474 /// induction variable. LoopOp only has one region, so 0 is the only valid value
475 /// for `index`.
476 OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
477  assert(index && *index == 0 && "invalid region index");
478 
479  // The initial operands map to the loop arguments after the induction
480  // variable.
481  return getInitArgs();
482 }
483 
484 /// Given the region at `index`, or the parent operation if `index` is None,
485 /// return the successor regions. These are the regions that may be selected
486 /// during the flow of control. `operands` is a set of optional attributes that
487 /// correspond to a constant value for each operand, or null if that operand is
488 /// not a constant.
489 void ForOp::getSuccessorRegions(Optional<unsigned> index,
490  ArrayRef<Attribute> operands,
492  // If the predecessor is the ForOp, branch into the body using the iterator
493  // arguments.
494  if (!index) {
495  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
496  return;
497  }
498 
499  // Otherwise, the loop may branch back to itself or the parent operation.
500  assert(*index == 0 && "expected loop region");
501  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
502  regions.push_back(RegionSuccessor(getResults()));
503 }
504 
506  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
507  ValueRange steps, ValueRange iterArgs,
509  bodyBuilder) {
510  assert(lbs.size() == ubs.size() &&
511  "expected the same number of lower and upper bounds");
512  assert(lbs.size() == steps.size() &&
513  "expected the same number of lower bounds and steps");
514 
515  // If there are no bounds, call the body-building function and return early.
516  if (lbs.empty()) {
517  ValueVector results =
518  bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
519  : ValueVector();
520  assert(results.size() == iterArgs.size() &&
521  "loop nest body must return as many values as loop has iteration "
522  "arguments");
523  return LoopNest();
524  }
525 
526  // First, create the loop structure iteratively using the body-builder
527  // callback of `ForOp::build`. Do not create `YieldOp`s yet.
528  OpBuilder::InsertionGuard guard(builder);
531  loops.reserve(lbs.size());
532  ivs.reserve(lbs.size());
533  ValueRange currentIterArgs = iterArgs;
534  Location currentLoc = loc;
535  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
536  auto loop = builder.create<scf::ForOp>(
537  currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
538  [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
539  ValueRange args) {
540  ivs.push_back(iv);
541  // It is safe to store ValueRange args because it points to block
542  // arguments of a loop operation that we also own.
543  currentIterArgs = args;
544  currentLoc = nestedLoc;
545  });
546  // Set the builder to point to the body of the newly created loop. We don't
547  // do this in the callback because the builder is reset when the callback
548  // returns.
549  builder.setInsertionPointToStart(loop.getBody());
550  loops.push_back(loop);
551  }
552 
553  // For all loops but the innermost, yield the results of the nested loop.
554  for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
555  builder.setInsertionPointToEnd(loops[i].getBody());
556  builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
557  }
558 
559  // In the body of the innermost loop, call the body building function if any
560  // and yield its results.
561  builder.setInsertionPointToStart(loops.back().getBody());
562  ValueVector results = bodyBuilder
563  ? bodyBuilder(builder, currentLoc, ivs,
564  loops.back().getRegionIterArgs())
565  : ValueVector();
566  assert(results.size() == iterArgs.size() &&
567  "loop nest body must return as many values as loop has iteration "
568  "arguments");
569  builder.setInsertionPointToEnd(loops.back().getBody());
570  builder.create<scf::YieldOp>(loc, results);
571 
572  // Return the loops.
573  LoopNest res;
574  res.loops.assign(loops.begin(), loops.end());
575  return res;
576 }
577 
579  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
580  ValueRange steps,
581  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
582  // Delegate to the main function by wrapping the body builder.
583  return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
584  [&bodyBuilder](OpBuilder &nestedBuilder,
585  Location nestedLoc, ValueRange ivs,
586  ValueRange) -> ValueVector {
587  if (bodyBuilder)
588  bodyBuilder(nestedBuilder, nestedLoc, ivs);
589  return {};
590  });
591 }
592 
593 namespace {
594 // Fold away ForOp iter arguments when:
595 // 1) The op yields the iter arguments.
596 // 2) The iter arguments have no use and the corresponding outer region
597 // iterators (inputs) are yielded.
598 // 3) The iter arguments have no use and the corresponding (operation) results
599 // have no use.
600 //
601 // These arguments must be defined outside of
602 // the ForOp region and can just be forwarded after simplifying the op inits,
603 // yields and returns.
604 //
605 // The implementation uses `mergeBlockBefore` to steal the content of the
606 // original ForOp and avoid cloning.
607 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
609 
610  LogicalResult matchAndRewrite(scf::ForOp forOp,
611  PatternRewriter &rewriter) const final {
612  bool canonicalize = false;
613  Block &block = forOp.getRegion().front();
614  auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
615 
616  // An internal flat vector of block transfer
617  // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
618  // transformed block argument mappings. This plays the role of a
619  // BlockAndValueMapping for the particular use case of calling into
620  // `mergeBlockBefore`.
621  SmallVector<bool, 4> keepMask;
622  keepMask.reserve(yieldOp.getNumOperands());
623  SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
624  newResultValues;
625  newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
626  newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
627  newIterArgs.reserve(forOp.getNumIterOperands());
628  newYieldValues.reserve(yieldOp.getNumOperands());
629  newResultValues.reserve(forOp.getNumResults());
630  for (auto it : llvm::zip(forOp.getIterOperands(), // iter from outside
631  forOp.getRegionIterArgs(), // iter inside region
632  forOp.getResults(), // op results
633  yieldOp.getOperands() // iter yield
634  )) {
635  // Forwarded is `true` when:
636  // 1) The region `iter` argument is yielded.
637  // 2) The region `iter` argument has no use, and the corresponding iter
638  // operand (input) is yielded.
639  // 3) The region `iter` argument has no use, and the corresponding op
640  // result has no use.
641  bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
642  (std::get<1>(it).use_empty() &&
643  (std::get<0>(it) == std::get<3>(it) ||
644  std::get<2>(it).use_empty())));
645  keepMask.push_back(!forwarded);
646  canonicalize |= forwarded;
647  if (forwarded) {
648  newBlockTransferArgs.push_back(std::get<0>(it));
649  newResultValues.push_back(std::get<0>(it));
650  continue;
651  }
652  newIterArgs.push_back(std::get<0>(it));
653  newYieldValues.push_back(std::get<3>(it));
654  newBlockTransferArgs.push_back(Value()); // placeholder with null value
655  newResultValues.push_back(Value()); // placeholder with null value
656  }
657 
658  if (!canonicalize)
659  return failure();
660 
661  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
662  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
663  forOp.getStep(), newIterArgs);
664  newForOp->setAttrs(forOp->getAttrs());
665  Block &newBlock = newForOp.getRegion().front();
666 
667  // Replace the null placeholders with newly constructed values.
668  newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
669  for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
670  idx != e; ++idx) {
671  Value &blockTransferArg = newBlockTransferArgs[1 + idx];
672  Value &newResultVal = newResultValues[idx];
673  assert((blockTransferArg && newResultVal) ||
674  (!blockTransferArg && !newResultVal));
675  if (!blockTransferArg) {
676  blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
677  newResultVal = newForOp.getResult(collapsedIdx++);
678  }
679  }
680 
681  Block &oldBlock = forOp.getRegion().front();
682  assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
683  "unexpected argument size mismatch");
684 
685  // No results case: the scf::ForOp builder already created a zero
686  // result terminator. Merge before this terminator and just get rid of the
687  // original terminator that has been merged in.
688  if (newIterArgs.empty()) {
689  auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
690  rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
691  rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
692  rewriter.replaceOp(forOp, newResultValues);
693  return success();
694  }
695 
696  // No terminator case: merge and rewrite the merged terminator.
697  auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
698  OpBuilder::InsertionGuard g(rewriter);
699  rewriter.setInsertionPoint(mergedTerminator);
700  SmallVector<Value, 4> filteredOperands;
701  filteredOperands.reserve(newResultValues.size());
702  for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
703  if (keepMask[idx])
704  filteredOperands.push_back(mergedTerminator.getOperand(idx));
705  rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
706  filteredOperands);
707  };
708 
709  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
710  auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
711  cloneFilteredTerminator(mergedYieldOp);
712  rewriter.eraseOp(mergedYieldOp);
713  rewriter.replaceOp(forOp, newResultValues);
714  return success();
715  }
716 };
717 
718 /// Rewriting pattern that erases loops that are known not to iterate, replaces
719 /// single-iteration loops with their bodies, and removes empty loops that
720 /// iterate at least once and only return values defined outside of the loop.
721 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
723 
724  LogicalResult matchAndRewrite(ForOp op,
725  PatternRewriter &rewriter) const override {
726  // If the upper bound is the same as the lower bound, the loop does not
727  // iterate, just remove it.
728  if (op.getLowerBound() == op.getUpperBound()) {
729  rewriter.replaceOp(op, op.getIterOperands());
730  return success();
731  }
732 
733  auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
734  auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
735  if (!lb || !ub)
736  return failure();
737 
738  // If the loop is known to have 0 iterations, remove it.
739  llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
740  llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
741  if (lbValue.sge(ubValue)) {
742  rewriter.replaceOp(op, op.getIterOperands());
743  return success();
744  }
745 
746  auto step = op.getStep().getDefiningOp<arith::ConstantOp>();
747  if (!step)
748  return failure();
749 
750  // If the loop is known to have 1 iteration, inline its body and remove the
751  // loop.
752  llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
753  if ((lbValue + stepValue).sge(ubValue)) {
754  SmallVector<Value, 4> blockArgs;
755  blockArgs.reserve(op.getNumIterOperands() + 1);
756  blockArgs.push_back(op.getLowerBound());
757  llvm::append_range(blockArgs, op.getIterOperands());
758  replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
759  return success();
760  }
761 
762  // Now we are left with loops that have more than 1 iterations.
763  Block &block = op.getRegion().front();
764  if (!llvm::hasSingleElement(block))
765  return failure();
766  // If the loop is empty, iterates at least once, and only returns values
767  // defined outside of the loop, remove it and replace it with yield values.
768  auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
769  auto yieldOperands = yieldOp.getOperands();
770  if (llvm::any_of(yieldOperands,
771  [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
772  return failure();
773  rewriter.replaceOp(op, yieldOperands);
774  return success();
775  }
776 };
777 
778 /// Perform a replacement of one iter OpOperand of an scf.for to the
779 /// `replacement` value which is expected to be the source of a tensor.cast.
780 /// tensor.cast ops are inserted inside the block to account for the type cast.
781 static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
782  OpOperand &operand,
783  Value replacement) {
784  Type oldType = operand.get().getType(), newType = replacement.getType();
785  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
786  "expected ranked tensor types");
787 
788  // 1. Create new iter operands, exactly 1 is replaced.
789  ForOp forOp = cast<ForOp>(operand.getOwner());
790  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
791  "expected an iter OpOperand");
792  if (operand.get().getType() == replacement.getType())
793  return forOp;
794  SmallVector<Value> newIterOperands;
795  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
796  if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
797  newIterOperands.push_back(replacement);
798  continue;
799  }
800  newIterOperands.push_back(opOperand.get());
801  }
802 
803  // 2. Create the new forOp shell.
804  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
805  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
806  forOp.getStep(), newIterOperands);
807  newForOp->setAttrs(forOp->getAttrs());
808  Block &newBlock = newForOp.getRegion().front();
809  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
810  newBlock.getArguments().end());
811 
812  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
813  // corresponding to the `replacement` value.
814  OpBuilder::InsertionGuard g(rewriter);
815  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
816  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
817  newForOp->getOpOperand(operand.getOperandNumber()));
818  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
819  newRegionIterArg);
820  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
821 
822  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
823  Block &oldBlock = forOp.getRegion().front();
824  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
825 
826  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
827  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
828  rewriter.setInsertionPoint(clonedYieldOp);
829  unsigned yieldIdx =
830  newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
831  Value castOut = rewriter.create<tensor::CastOp>(
832  newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
833  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
834  newYieldOperands[yieldIdx] = castOut;
835  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
836  rewriter.eraseOp(clonedYieldOp);
837 
838  // 6. Inject an outgoing cast op after the forOp.
839  rewriter.setInsertionPointAfter(newForOp);
840  SmallVector<Value> newResults = newForOp.getResults();
841  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
842  newForOp.getLoc(), oldType, newResults[yieldIdx]);
843 
844  return newForOp;
845 }
846 
847 /// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
848 /// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
849 ///
850 /// ```
851 /// %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
852 /// %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
853 /// -> (tensor<?x?xf32>) {
854 /// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
855 /// scf.yield %2 : tensor<?x?xf32>
856 /// }
857 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
858 /// use_of(%2)
859 /// ```
860 ///
861 /// folds into:
862 ///
863 /// ```
864 /// %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
865 /// -> (tensor<32x1024xf32>) {
866 /// %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
867 /// %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
868 /// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
869 /// scf.yield %4 : tensor<32x1024xf32>
870 /// }
871 /// use_of(%0)
872 /// ```
873 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
875 
876  LogicalResult matchAndRewrite(ForOp op,
877  PatternRewriter &rewriter) const override {
878  for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
879  OpOperand &iterOpOperand = std::get<0>(it);
880  auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
881  if (!incomingCast)
882  continue;
883  if (!std::get<1>(it).hasOneUse())
884  continue;
885  auto outgoingCastOp =
886  dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
887  if (!outgoingCastOp)
888  continue;
889 
890  // Must be a tensor.cast op pair with matching types.
891  if (outgoingCastOp.getResult().getType() !=
892  incomingCast.getSource().getType())
893  continue;
894 
895  // Create a new ForOp with that iter operand replaced.
896  auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
897  incomingCast.getSource());
898 
899  // Insert outgoing cast and use it to replace the corresponding result.
900  rewriter.setInsertionPointAfter(newForOp);
901  SmallVector<Value> replacements = newForOp.getResults();
902  unsigned returnIdx =
903  iterOpOperand.getOperandNumber() - op.getNumControlOperands();
904  replacements[returnIdx] = rewriter.create<tensor::CastOp>(
905  op.getLoc(), incomingCast.getDest().getType(),
906  replacements[returnIdx]);
907  rewriter.replaceOp(op, replacements);
908  return success();
909  }
910  return failure();
911  }
912 };
913 
914 /// Canonicalize the iter_args of an scf::ForOp that involve a
915 /// `bufferization.to_tensor` and for which only the last loop iteration is
916 /// actually visible outside of the loop. The canonicalization looks for a
917 /// pattern such as:
918 /// ```
919 /// %t0 = ... : tensor_type
920 /// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
921 /// ...
922 /// // %m is either buffer_cast(%bb00) or defined above the loop
923 /// %m... : memref_type
924 /// ... // uses of %m with potential inplace updates
925 /// %new_tensor = bufferization.to_tensor %m : memref_type
926 /// ...
927 /// scf.yield %new_tensor : tensor_type
928 /// }
929 /// ```
930 ///
931 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
932 /// `%m = buffer_cast %bb0` op that feeds into the yielded
933 /// `bufferization.to_tensor` op.
934 ///
935 /// If no aliasing write to the memref `%m`, from which `%new_tensor` is loaded,
936 /// occurs between `bufferization.to_tensor` and yield, then the value %0
937 /// visible outside of the loop is the last `bufferization.to_tensor`
938 /// produced in the loop.
939 ///
940 /// For now, we approximate the absence of aliasing by only supporting the case
941 /// when the bufferization.to_tensor is the operation immediately preceding
942 /// the yield.
943 //
944 /// The canonicalization rewrites the pattern as:
945 /// ```
946 /// // %m is either a buffer_cast or defined above
947 /// %m... : memref_type
948 /// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
949 /// ... // uses of %m with potential inplace updates
950 /// scf.yield %bb0: tensor_type
951 /// }
952 /// %0 = bufferization.to_tensor %m : memref_type
953 /// ```
954 ///
955 /// A later bbArg canonicalization will further rewrite as:
956 /// ```
957 /// // %m is either a buffer_cast or defined above
958 /// %m... : memref_type
959 /// scf.for ... { // no iter_args
960 /// ... // uses of %m with potential inplace updates
961 /// }
962 /// %0 = bufferization.to_tensor %m : memref_type
963 /// ```
964 struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
966 
967  LogicalResult matchAndRewrite(ForOp forOp,
968  PatternRewriter &rewriter) const override {
969  assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
970  "unexpected multiple blocks");
971 
972  Location loc = forOp.getLoc();
973  DenseMap<Value, Value> replacements;
974  for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
975  unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
976  auto yieldOp =
977  cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
978  Value yieldVal = yieldOp->getOperand(idx);
979  auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
980  bool isTensor = bbArg.getType().isa<TensorType>();
981 
982  bufferization::ToMemrefOp tensorToMemref;
983  // Either bbArg has no use or it has a single buffer_cast use.
984  if (bbArg.hasOneUse())
985  tensorToMemref =
986  dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
987  if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
988  continue;
989  // If tensorToMemref is present, it must feed into the `ToTensorOp`.
990  if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
991  continue;
992  // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
993  // must be before `ToTensorOp` in the block so that the lastWrite
994  // property is not subject to additional side-effects.
995  // For now, we only support the case when ToTensorOp appears
996  // immediately before the terminator.
997  if (tensorLoadOp->getNextNode() != yieldOp)
998  continue;
999 
1000  // Clone the optional tensorToMemref before forOp.
1001  if (tensorToMemref) {
1002  rewriter.setInsertionPoint(forOp);
1003  rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
1004  tensorToMemref, tensorToMemref.getMemref().getType(),
1005  tensorToMemref.getTensor());
1006  }
1007 
1008  // Clone the tensorLoad after forOp.
1009  rewriter.setInsertionPointAfter(forOp);
1010  Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
1011  loc, tensorLoadOp.getMemref());
1012  Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
1013  replacements.insert(std::make_pair(forOpResult, newTensorLoad));
1014 
1015  // Make the terminator just yield the bbArg, the old tensorLoadOp + the
1016  // old bbArg (that is now directly yielded) will canonicalize away.
1017  rewriter.startRootUpdate(yieldOp);
1018  yieldOp.setOperand(idx, bbArg);
1019  rewriter.finalizeRootUpdate(yieldOp);
1020  }
1021  if (replacements.empty())
1022  return failure();
1023 
1024  // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
1025  // replaces the whole op and erase it unconditionally. This is wrong for
1026  // `forOp` as it generally contains ops with side effects.
1027  // Instead, use `rewriter.replaceOpWithIf`.
1028  SmallVector<Value> newResults;
1029  newResults.reserve(forOp.getNumResults());
1030  for (Value v : forOp.getResults()) {
1031  auto it = replacements.find(v);
1032  newResults.push_back((it != replacements.end()) ? it->second : v);
1033  }
1034  unsigned idx = 0;
1035  rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
1036  return op.get() != newResults[idx++];
1037  });
1038  return success();
1039  }
1040 };
1041 } // namespace
1042 
1043 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1044  MLIRContext *context) {
1045  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1046  LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // ForeachThreadOp
1051 //===----------------------------------------------------------------------===//
1052 
// NOTE(review): the line carrying this definition's signature is absent from
// this listing. The body emits op errors and returns success()/failure(), so
// it is presumably the ForeachThreadOp verifier — confirm against the original
// file.
1054  // Call terminator's verify to produce most informative error messages.
1055  if (failed(getTerminator().verify()))
1056  return failure();
1057 
1058  // Check that the body defines exactly getRank() block arguments.
1059  auto *body = getBody();
1060  if (body->getNumArguments() != getRank())
1061  return emitOpError("region expects ") << getRank() << " arguments";
1062 
1063  // Verify consistency between the result types and the terminator.
1064  auto terminatorTypes = getTerminator().getYieldedTypes();
1065  auto opResults = getResults();
1066  if (opResults.size() != terminatorTypes.size())
1067  return emitOpError("produces ")
1068  << opResults.size() << " results, but its terminator yields "
1069  << terminatorTypes.size() << " value(s)";
// `i` tracks the result position solely for the diagnostic below.
1070  unsigned i = 0;
1071  for (auto e : llvm::zip(terminatorTypes, opResults)) {
1072  if (std::get<0>(e) != std::get<1>(e).getType())
1073  return emitOpError() << "type mismatch between result " << i << " ("
1074  << std::get<1>(e).getType() << ") and terminator ("
1075  << std::get<0>(e) << ")";
1076  i++;
1077  }
1078  return success();
1079 }
1080 
// NOTE(review): the signature line is absent from this listing. The body
// streams into a printer named `p`, so this is presumably the custom assembly
// printer of ForeachThreadOp — confirm against the original file.
// Printed form: ` (<thread indices>) in (<num threads>) -> (<result types>) `
// followed by the region and the optional attribute dictionary.
1082  p << " (";
1083  llvm::interleaveComma(getThreadIndices(), p);
1084  p << ") in (";
1085  llvm::interleaveComma(getNumThreads(), p);
1086  p << ") -> (" << getResultTypes() << ") ";
// Entry block arguments are not printed with the region; the terminator is
// printed only when the op has results.
1087  p.printRegion(getRegion(),
1088  /*printEntryBlockArgs=*/false,
1089  /*printBlockTerminators=*/getNumResults() > 0);
1090  p.printOptionalAttrDict(getOperation()->getAttrs());
1091 }
1092 
// Parses the custom assembly of ForeachThreadOp: a parenthesized list of
// thread index arguments, the `in` keyword with the thread-count operands,
// an optional `->` result type list, the body region and an optional
// attribute dictionary. Returns failure() as soon as any step fails.
1093 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
1094  OperationState &result) {
1095  auto &builder = parser.getBuilder();
1096  // Parse an opening `(` followed by thread index variables followed by `)`
1097  // TODO: when we can refer to such "induction variable"-like handles from the
1098  // declarative assembly format, we can implement the parser as a custom hook.
// NOTE(review): the declaration of `threadIndices` (listing line 1099) is
// absent from this listing.
1100  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
1101  return failure();
1102 
1103  // Parse `in` threadNums.
// NOTE(review): the declaration of `threadNums` (listing line 1104) and the
// continuation of the parseOperandList call (listing line 1107) are absent
// from this listing.
1105  if (parser.parseKeyword("in") ||
1106  parser.parseOperandList(threadNums, threadIndices.size(),
1108  parser.resolveOperands(threadNums, builder.getIndexType(),
1109  result.operands))
1110  return failure();
1111 
1112  // Parse optional results.
1113  if (parser.parseOptionalArrowTypeList(result.types))
1114  return failure();
1115 
1116  // Parse region.
1117  std::unique_ptr<Region> region = std::make_unique<Region>();
// Every thread index argument is given the index type before the region is
// parsed with those arguments.
1118  for (auto &idx : threadIndices)
1119  idx.type = builder.getIndexType();
1120  if (parser.parseRegion(*region, threadIndices))
1121  return failure();
1122 
1123  // Ensure terminator and move region.
1124  OpBuilder b(builder.getContext());
1125  ForeachThreadOp::ensureTerminator(*region, b, result.location);
1126  result.addRegion(std::move(region));
1127 
1128  // Parse the optional attribute list.
1129  if (parser.parseOptionalAttrDict(result.attributes))
1130  return failure();
1131 
1132  return success();
1133 }
1134 
1135 // Bodyless builder, result types must be specified.
1136 void ForeachThreadOp::build(mlir::OpBuilder &builder,
1137  mlir::OperationState &result, TypeRange resultTypes,
1138  ValueRange numThreads,
1139  ArrayRef<int64_t> threadDimMapping) {
1140  result.addOperands(numThreads);
1141  result.addAttribute(
1142  // TODO: getThreadDimMappingAttrName() but it is not a static member.
1143  "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1144 
1145  Region *bodyRegion = result.addRegion();
1146  OpBuilder::InsertionGuard g(builder);
1147  // createBlock sets the IP inside the block.
1148  // Generally we would guard against that but the default ensureTerminator impl
1149  // expects it ..
1150  builder.createBlock(bodyRegion);
1151  Block &bodyBlock = bodyRegion->front();
1152  bodyBlock.addArguments(
1153  SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1154  SmallVector<Location>(numThreads.size(), result.location));
1155  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
1156  result.addTypes(resultTypes);
1157 }
1158 
1159 // Builder that takes a bodyBuilder lambda, result types are inferred from
1160 // the terminator.
1161 void ForeachThreadOp::build(
1162  mlir::OpBuilder &builder, mlir::OperationState &result,
1163  ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
1164  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
1165  result.addOperands(numThreads);
1166  result.addAttribute(
1167  // TODO: getThreadDimMappingAttrName() but it is not a static member.
1168  "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
1169 
1170  OpBuilder::InsertionGuard g(builder);
1171  Region *bodyRegion = result.addRegion();
1172  builder.createBlock(bodyRegion);
1173  Block &bodyBlock = bodyRegion->front();
1174  bodyBlock.addArguments(
1175  SmallVector<Type>(numThreads.size(), builder.getIndexType()),
1176  SmallVector<Location>(numThreads.size(), result.location));
1177 
1178  OpBuilder::InsertionGuard guard(builder);
1179  builder.setInsertionPointToStart(&bodyBlock);
1180  bodyBuilder(builder, result.location, bodyBlock.getArguments());
1181  auto terminator =
1182  llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
1183  assert(terminator &&
1184  "expected bodyBuilder to create PerformConcurrentlyOp terminator");
1185  result.addTypes(terminator.getYieldedTypes());
1186 }
1187 
1188 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
1189 // unaware of the fact that our terminator also needs a region to be
1190 // well-formed. We override it here to ensure that we do the right thing.
1191 void ForeachThreadOp::ensureTerminator(Region &region, OpBuilder &builder,
1192  Location loc) {
// NOTE(review): listing line 1193 is absent, so the start of the qualified
// call that ends on the next line is not visible here.
1194  ForeachThreadOp>::ensureTerminator(region, builder, loc);
// NOTE(review): the dyn_cast result is dereferenced without a null check;
// `cast` would state that expectation explicitly — confirm intent.
1195  auto terminator =
1196  llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
// Give the terminator a block when its own region has none yet.
1197  if (terminator.getRegion().empty())
1198  builder.createBlock(&terminator.getRegion());
1199 }
1200 
1201 PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
1202  return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
1203 }
1204 
// NOTE(review): the signature line is absent from this listing; the body
// reads a value named `val` and returns a ForeachThreadOp.
// Returns a null ForeachThreadOp when `val` is not a block argument.
// Otherwise returns the result of dyn_cast-ing the parent op of the block
// that owns the argument.
1206  auto tidxArg = val.dyn_cast<BlockArgument>();
1207  if (!tidxArg)
1208  return ForeachThreadOp();
1209  assert(tidxArg.getOwner() && "unlinked block argument");
1210  auto *containingOp = tidxArg.getOwner()->getParentOp();
1211  return dyn_cast<ForeachThreadOp>(containingOp);
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // PerformConcurrentlyOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 // Build a PerformConcurrentlyOp: adds one region to `result` and creates a
// block in it. No operands, attributes or result types are added here.
1219 void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
// NOTE(review): listing line 1220 is absent; a statement may be missing here.
1221  Region *bodyRegion = result.addRegion();
1222  b.createBlock(bodyRegion);
1223 }
1224 
// NOTE(review): the signature line is absent from this listing. The body
// emits an op error or returns success(), so it is presumably the
// PerformConcurrentlyOp verifier — confirm against the original file.
// Rejects any operation in the first block of the region that is not a
// tensor::ParallelInsertSliceOp.
1226  // TODO: PerformConcurrentlyOpInterface.
1227  for (const Operation &op : getRegion().front().getOperations()) {
1228  if (!isa<tensor::ParallelInsertSliceOp>(op)) {
1229  return this->emitOpError("expected only ")
1230  << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
1231  }
1232  }
1233  return success();
1234 }
1235 
// NOTE(review): the signature line is absent from this listing. The body
// streams into a printer named `p`, so this is presumably the custom assembly
// printer of PerformConcurrentlyOp — confirm against the original file.
// Prints a space, the region (without entry block arguments or block
// terminators) and the optional attribute dictionary.
1237  p << " ";
1238  p.printRegion(getRegion(),
1239  /*printEntryBlockArgs=*/false,
1240  /*printBlockTerminators=*/false);
1241  p.printOptionalAttrDict(getOperation()->getAttrs());
1242 }
1243 
// Parses the custom assembly of PerformConcurrentlyOp: a region followed by
// an optional attribute dictionary. Returns failure() if either step fails.
1244 ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
1245  OperationState &result) {
1246  auto &builder = parser.getBuilder();
1247 
// NOTE(review): the declaration of `regionOperands` (listing line 1248) is
// absent from this listing.
1249  std::unique_ptr<Region> region = std::make_unique<Region>();
1250  if (parser.parseRegion(*region, regionOperands))
1251  return failure();
1252 
// A region parsed without any block still gets one block.
1253  if (region->empty())
1254  OpBuilder(builder.getContext()).createBlock(region.get());
1255  result.addRegion(std::move(region));
1256 
1257  // Parse the optional attribute list.
1258  if (parser.parseOptionalAttrDict(result.attributes))
1259  return failure();
1260  return success();
1261 }
1262 
1263 OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
1264  return getOperation()->getParentOp()->getResult(idx);
1265 }
1266 
1267 SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
1268  return llvm::to_vector<4>(
1269  llvm::map_range(getYieldingOps(), [](Operation &op) {
1270  auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
1271  return insertSliceOp ? insertSliceOp.yieldedType() : Type();
1272  }));
1273 }
1274 
1275 llvm::iterator_range<Block::iterator> PerformConcurrentlyOp::getYieldingOps() {
1276  return getRegion().front().getOperations();
1277 }
1278 
1279 //===----------------------------------------------------------------------===//
1280 // IfOp
1281 //===----------------------------------------------------------------------===//
1282 
// NOTE(review): the signature line is absent from this listing; the body
// takes two operation pointers `a` and `b` and returns a bool.
// Walks the IfOps enclosing `a` from the innermost outwards. For the first
// one that is also a proper ancestor of `b`, returns true iff exactly one of
// `a` and `b` has an ancestor in that IfOp's then block. Returns false when
// no such IfOp exists.
1284  assert(a && "expected non-empty operation");
1285  assert(b && "expected non-empty operation");
1286 
1287  IfOp ifOp = a->getParentOfType<IfOp>();
1288  while (ifOp) {
1289  // Check if b is inside ifOp. (We already know that a is.)
1290  if (ifOp->isProperAncestor(b))
1291  // b is contained in ifOp. a and b are in mutually exclusive branches if
1292  // they are in different blocks of ifOp.
1293  return static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*a)) !=
1294  static_cast<bool>(ifOp.thenBlock()->findAncestorOpInBlock(*b));
1295  // Check next enclosing IfOp.
1296  ifOp = ifOp->getParentOfType<IfOp>();
1297  }
1298 
1299  // Could not find a common IfOp among a's and b's ancestors.
1300  return false;
1301 }
1302 
1303 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1304  bool withElseRegion) {
1305  build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1306 }
1307 
1308 void IfOp::build(OpBuilder &builder, OperationState &result,
1309  TypeRange resultTypes, Value cond, bool withElseRegion) {
1310  auto addTerminator = [&](OpBuilder &nested, Location loc) {
1311  if (resultTypes.empty())
1312  IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1313  loc);
1314  };
1315 
1316  build(builder, result, resultTypes, cond, addTerminator,
1317  withElseRegion ? addTerminator
1318  : function_ref<void(OpBuilder &, Location)>());
1319 }
1320 
1321 void IfOp::build(OpBuilder &builder, OperationState &result,
1322  TypeRange resultTypes, Value cond,
1323  function_ref<void(OpBuilder &, Location)> thenBuilder,
1324  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1325  assert(thenBuilder && "the builder callback for 'then' must be present");
1326 
1327  result.addOperands(cond);
1328  result.addTypes(resultTypes);
1329 
1330  OpBuilder::InsertionGuard guard(builder);
1331  Region *thenRegion = result.addRegion();
1332  builder.createBlock(thenRegion);
1333  thenBuilder(builder, result.location);
1334 
1335  Region *elseRegion = result.addRegion();
1336  if (!elseBuilder)
1337  return;
1338 
1339  builder.createBlock(elseRegion);
1340  elseBuilder(builder, result.location);
1341 }
1342 
1343 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1344  function_ref<void(OpBuilder &, Location)> thenBuilder,
1345  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1346  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1347 }
1348 
// NOTE(review): the signature line is absent from this listing. The body
// emits an op error or returns success(), so it is presumably the IfOp
// verifier — confirm against the original file.
// An op that defines results must have a non-empty else region.
1350  if (getNumResults() != 0 && getElseRegion().empty())
1351  return emitOpError("must have an else block if defining values");
1352  return success();
1353 }
1354 
// Parses the custom assembly of scf.if: the condition operand (resolved as
// i1), an optional `->` result type list, the 'then' region, an optional
// `else` keyword with its region, and an optional attribute dictionary.
// Returns failure() as soon as any step fails.
1355 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1356  // Create the regions for 'then' and 'else'.
1357  result.regions.reserve(2);
1358  Region *thenRegion = result.addRegion();
1359  Region *elseRegion = result.addRegion();
1360 
1361  auto &builder = parser.getBuilder();
// NOTE(review): the declaration of `cond` (listing line 1362) is absent from
// this listing.
1363  Type i1Type = builder.getIntegerType(1);
1364  if (parser.parseOperand(cond) ||
1365  parser.resolveOperand(cond, i1Type, result.operands))
1366  return failure();
1367  // Parse optional results type list.
1368  if (parser.parseOptionalArrowTypeList(result.types))
1369  return failure();
1370  // Parse the 'then' region.
1371  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1372  return failure();
1373  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1374 
1375  // If we find an 'else' keyword then parse the 'else' region.
1376  if (!parser.parseOptionalKeyword("else")) {
1377  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1378  return failure();
1379  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1380  }
1381 
1382  // Parse the optional attribute list.
1383  if (parser.parseOptionalAttrDict(result.attributes))
1384  return failure();
1385  return success();
1386 }
1387 
1388 void IfOp::print(OpAsmPrinter &p) {
1389  bool printBlockTerminators = false;
1390 
1391  p << " " << getCondition();
1392  if (!getResults().empty()) {
1393  p << " -> (" << getResultTypes() << ")";
1394  // Print yield explicitly if the op defines values.
1395  printBlockTerminators = true;
1396  }
1397  p << ' ';
1398  p.printRegion(getThenRegion(),
1399  /*printEntryBlockArgs=*/false,
1400  /*printBlockTerminators=*/printBlockTerminators);
1401 
1402  // Print the 'else' regions if it exists and has a block.
1403  auto &elseRegion = getElseRegion();
1404  if (!elseRegion.empty()) {
1405  p << " else ";
1406  p.printRegion(elseRegion,
1407  /*printEntryBlockArgs=*/false,
1408  /*printBlockTerminators=*/printBlockTerminators);
1409  }
1410 
1411  p.printOptionalAttrDict((*this)->getAttrs());
1412 }
1413 
1414 /// Given the region at `index`, or the parent operation if `index` is None,
1415 /// return the successor regions. These are the regions that may be selected
1416 /// during the flow of control. `operands` is a set of optional attributes that
1417 /// correspond to a constant value for each operand, or null if that operand is
1418 /// not a constant.
1419 void IfOp::getSuccessorRegions(Optional<unsigned> index,
1420  ArrayRef<Attribute> operands,
// NOTE(review): listing line 1421 is absent; it presumably declares the
// `regions` output parameter that the body appends to.
1422  // The `then` and the `else` region branch back to the parent operation.
1423  if (index) {
1424  regions.push_back(RegionSuccessor(getResults()));
1425  return;
1426  }
1427 
1428  // Don't consider the else region if it is empty.
1429  Region *elseRegion = &this->getElseRegion();
1430  if (elseRegion->empty())
1431  elseRegion = nullptr;
1432 
1433  // Otherwise, the successor is dependent on the condition.
1434  bool condition;
1435  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
1436  condition = condAttr.getValue().isOneValue();
1437  } else {
1438  // If the condition isn't constant, both regions may be executed.
1439  regions.push_back(RegionSuccessor(&getThenRegion()));
1440  // If the else region does not exist, it is not a viable successor.
1441  if (elseRegion)
1442  regions.push_back(RegionSuccessor(elseRegion));
1443  return;
1444  }
1445 
1446  // Add the successor regions using the condition.
// NOTE(review): when the condition is a constant false and there is no else
// region, `elseRegion` is null here and a RegionSuccessor is built from a
// null region pointer — confirm that this is the intended encoding.
1447  regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion));
1448 }
1449 
1450 LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
1451  SmallVectorImpl<OpFoldResult> &results) {
1452  // if (!c) then A() else B() -> if c then B() else A()
1453  if (getElseRegion().empty())
1454  return failure();
1455 
1456  arith::XOrIOp xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
1457  if (!xorStmt)
1458  return failure();
1459 
1460  if (!matchPattern(xorStmt.getRhs(), m_One()))
1461  return failure();
1462 
1463  getConditionMutable().assign(xorStmt.getLhs());
1464  Block *thenBlock = &getThenRegion().front();
1465  // It would be nicer to use iplist::swap, but that has no implemented
1466  // callbacks See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
1467  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
1468  getElseRegion().getBlocks());
1469  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
1470  getThenRegion().getBlocks(), thenBlock);
1471  return success();
1472 }
1473 
1474 void IfOp::getRegionInvocationBounds(
1475  ArrayRef<Attribute> operands,
1476  SmallVectorImpl<InvocationBounds> &invocationBounds) {
1477  if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
1478  // If the condition is known, then one region is known to be executed once
1479  // and the other zero times.
1480  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1481  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1482  } else {
1483  // Non-constant condition. Each region may be executed 0 or 1 times.
1484  invocationBounds.assign(2, {0, 1});
1485  }
1486 }
1487 
1488 namespace {
1489 // Pattern to remove unused IfOp results.
1490 struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
1492 
1493  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
1494  PatternRewriter &rewriter) const {
1495  // Move all operations to the destination block.
1496  rewriter.mergeBlocks(source, dest);
1497  // Replace the yield op by one that returns only the used values.
1498  auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
1499  SmallVector<Value, 4> usedOperands;
1500  llvm::transform(usedResults, std::back_inserter(usedOperands),
1501  [&](OpResult result) {
1502  return yieldOp.getOperand(result.getResultNumber());
1503  });
1504  rewriter.updateRootInPlace(yieldOp,
1505  [&]() { yieldOp->setOperands(usedOperands); });
1506  }
1507 
1508  LogicalResult matchAndRewrite(IfOp op,
1509  PatternRewriter &rewriter) const override {
1510  // Compute the list of used results.
1511  SmallVector<OpResult, 4> usedResults;
1512  llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
1513  [](OpResult result) { return !result.use_empty(); });
1514 
1515  // Replace the operation if only a subset of its results have uses.
1516  if (usedResults.size() == op.getNumResults())
1517  return failure();
1518 
1519  // Compute the result types of the replacement operation.
1520  SmallVector<Type, 4> newTypes;
1521  llvm::transform(usedResults, std::back_inserter(newTypes),
1522  [](OpResult result) { return result.getType(); });
1523 
1524  // Create a replacement operation with empty then and else regions.
1525  auto emptyBuilder = [](OpBuilder &, Location) {};
1526  auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition(),
1527  emptyBuilder, emptyBuilder);
1528 
1529  // Move the bodies and replace the terminators (note there is a then and
1530  // an else region since the operation returns results).
1531  transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
1532  transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
1533 
1534  // Replace the operation by the new one.
1535  SmallVector<Value, 4> repResults(op.getNumResults());
1536  for (const auto &en : llvm::enumerate(usedResults))
1537  repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
1538  rewriter.replaceOp(op, repResults);
1539  return success();
1540  }
1541 };
1542 
1543 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1545 
1546  LogicalResult matchAndRewrite(IfOp op,
1547  PatternRewriter &rewriter) const override {
1548  auto constant = op.getCondition().getDefiningOp<arith::ConstantOp>();
1549  if (!constant)
1550  return failure();
1551 
1552  if (constant.getValue().cast<BoolAttr>().getValue())
1553  replaceOpWithRegion(rewriter, op, op.getThenRegion());
1554  else if (!op.getElseRegion().empty())
1555  replaceOpWithRegion(rewriter, op, op.getElseRegion());
1556  else
1557  rewriter.eraseOp(op);
1558 
1559  return success();
1560  }
1561 };
1562 
1563 /// Hoist any yielded results whose operands are defined outside
1564 /// the if, to a select instruction.
1565 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
1567 
1568  LogicalResult matchAndRewrite(IfOp op,
1569  PatternRewriter &rewriter) const override {
1570  if (op->getNumResults() == 0)
1571  return failure();
1572 
1573  auto cond = op.getCondition();
1574  auto thenYieldArgs = op.thenYield().getOperands();
1575  auto elseYieldArgs = op.elseYield().getOperands();
1576 
1577  SmallVector<Type> nonHoistable;
1578  for (const auto &it :
1579  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1580  Value trueVal = std::get<0>(it.value());
1581  Value falseVal = std::get<1>(it.value());
1582  if (&op.getThenRegion() == trueVal.getParentRegion() ||
1583  &op.getElseRegion() == falseVal.getParentRegion())
1584  nonHoistable.push_back(trueVal.getType());
1585  }
1586  // Early exit if there aren't any yielded values we can
1587  // hoist outside the if.
1588  if (nonHoistable.size() == op->getNumResults())
1589  return failure();
1590 
1591  IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
1592  if (replacement.thenBlock())
1593  rewriter.eraseBlock(replacement.thenBlock());
1594  replacement.getThenRegion().takeBody(op.getThenRegion());
1595  replacement.getElseRegion().takeBody(op.getElseRegion());
1596 
1597  SmallVector<Value> results(op->getNumResults());
1598  assert(thenYieldArgs.size() == results.size());
1599  assert(elseYieldArgs.size() == results.size());
1600 
1601  SmallVector<Value> trueYields;
1602  SmallVector<Value> falseYields;
1603  rewriter.setInsertionPoint(replacement);
1604  for (const auto &it :
1605  llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1606  Value trueVal = std::get<0>(it.value());
1607  Value falseVal = std::get<1>(it.value());
1608  if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
1609  &replacement.getElseRegion() == falseVal.getParentRegion()) {
1610  results[it.index()] = replacement.getResult(trueYields.size());
1611  trueYields.push_back(trueVal);
1612  falseYields.push_back(falseVal);
1613  } else if (trueVal == falseVal)
1614  results[it.index()] = trueVal;
1615  else
1616  results[it.index()] = rewriter.create<arith::SelectOp>(
1617  op.getLoc(), cond, trueVal, falseVal);
1618  }
1619 
1620  rewriter.setInsertionPointToEnd(replacement.thenBlock());
1621  rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
1622 
1623  rewriter.setInsertionPointToEnd(replacement.elseBlock());
1624  rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
1625 
1626  rewriter.replaceOp(op, results);
1627  return success();
1628  }
1629 };
1630 
1631 /// Allow the true region of an if to assume the condition is true
1632 /// and vice versa. For example:
1633 ///
1634 /// scf.if %cmp {
1635 /// print(%cmp)
1636 /// }
1637 ///
1638 /// becomes
1639 ///
1640 /// scf.if %cmp {
1641 /// print(true)
1642 /// }
1643 ///
1644 struct ConditionPropagation : public OpRewritePattern<IfOp> {
1646 
1647  LogicalResult matchAndRewrite(IfOp op,
1648  PatternRewriter &rewriter) const override {
1649  // Early exit if the condition is constant since replacing a constant
1650  // in the body with another constant isn't a simplification.
1651  if (op.getCondition().getDefiningOp<arith::ConstantOp>())
1652  return failure();
1653 
1654  bool changed = false;
1655  mlir::Type i1Ty = rewriter.getI1Type();
1656 
1657  // These variables serve to prevent creating duplicate constants
1658  // and hold constant true or false values.
1659  Value constantTrue = nullptr;
1660  Value constantFalse = nullptr;
1661 
1662  for (OpOperand &use :
1663  llvm::make_early_inc_range(op.getCondition().getUses())) {
1664  if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) {
1665  changed = true;
1666 
1667  if (!constantTrue)
1668  constantTrue = rewriter.create<arith::ConstantOp>(
1669  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
1670 
1671  rewriter.updateRootInPlace(use.getOwner(),
1672  [&]() { use.set(constantTrue); });
1673  } else if (op.getElseRegion().isAncestor(
1674  use.getOwner()->getParentRegion())) {
1675  changed = true;
1676 
1677  if (!constantFalse)
1678  constantFalse = rewriter.create<arith::ConstantOp>(
1679  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
1680 
1681  rewriter.updateRootInPlace(use.getOwner(),
1682  [&]() { use.set(constantFalse); });
1683  }
1684  }
1685 
1686  return success(changed);
1687  }
1688 };
1689 
1690 /// Remove any statements from an if that are equivalent to the condition
1691 /// or its negation. For example:
1692 ///
1693 /// %res:2 = scf.if %cmp {
1694 /// yield something(), true
1695 /// } else {
1696 /// yield something2(), false
1697 /// }
1698 /// print(%res#1)
1699 ///
1700 /// becomes
1701 /// %res = scf.if %cmp {
1702 /// yield something()
1703 /// } else {
1704 /// yield something2()
1705 /// }
1706 /// print(%cmp)
1707 ///
1708 /// Additionally if both branches yield the same value, replace all uses
1709 /// of the result with the yielded value.
1710 ///
1711 /// %res:2 = scf.if %cmp {
1712 /// yield something(), %arg1
1713 /// } else {
1714 /// yield something2(), %arg1
1715 /// }
1716 /// print(%res#1)
1717 ///
1718 /// becomes
1719 /// %res = scf.if %cmp {
1720 /// yield something()
1721 /// } else {
1722 /// yield something2()
1723 /// }
1724 /// print(%arg1)
1725 ///
1726 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1728 
1729  LogicalResult matchAndRewrite(IfOp op,
1730  PatternRewriter &rewriter) const override {
1731  // Early exit if there are no results that could be replaced.
1732  if (op.getNumResults() == 0)
1733  return failure();
1734 
1735  auto trueYield =
1736  cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
1737  auto falseYield =
1738  cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
1739 
1740  rewriter.setInsertionPoint(op->getBlock(),
1741  op.getOperation()->getIterator());
1742  bool changed = false;
1743  Type i1Ty = rewriter.getI1Type();
1744  for (auto [trueResult, falseResult, opResult] :
1745  llvm::zip(trueYield.getResults(), falseYield.getResults(),
1746  op.getResults())) {
1747  if (trueResult == falseResult) {
1748  if (!opResult.use_empty()) {
1749  opResult.replaceAllUsesWith(trueResult);
1750  changed = true;
1751  }
1752  continue;
1753  }
1754 
1755  auto trueYield = trueResult.getDefiningOp<arith::ConstantOp>();
1756  if (!trueYield)
1757  continue;
1758 
1759  if (!trueYield.getType().isInteger(1))
1760  continue;
1761 
1762  auto falseYield = falseResult.getDefiningOp<arith::ConstantOp>();
1763  if (!falseYield)
1764  continue;
1765 
1766  bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1767  bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1768  if (!trueVal && falseVal) {
1769  if (!opResult.use_empty()) {
1770  Value notCond = rewriter.create<arith::XOrIOp>(
1771  op.getLoc(), op.getCondition(),
1772  rewriter.create<arith::ConstantOp>(
1773  op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1774  opResult.replaceAllUsesWith(notCond);
1775  changed = true;
1776  }
1777  }
1778  if (trueVal && !falseVal) {
1779  if (!opResult.use_empty()) {
1780  opResult.replaceAllUsesWith(op.getCondition());
1781  changed = true;
1782  }
1783  }
1784  }
1785  return success(changed);
1786  }
1787 };
1788 
1789 /// Merge any consecutive scf.if's with the same condition.
1790 ///
1791 /// scf.if %cond {
1792 /// firstCodeTrue();...
1793 /// } else {
1794 /// firstCodeFalse();...
1795 /// }
1796 /// %res = scf.if %cond {
1797 /// secondCodeTrue();...
1798 /// } else {
1799 /// secondCodeFalse();...
1800 /// }
1801 ///
1802 /// becomes
1803 /// %res = scf.if %cmp {
1804 /// firstCodeTrue();...
1805 /// secondCodeTrue();...
1806 /// } else {
1807 /// firstCodeFalse();...
1808 /// secondCodeFalse();...
1809 /// }
1810 struct CombineIfs : public OpRewritePattern<IfOp> {
1812 
1813  LogicalResult matchAndRewrite(IfOp nextIf,
1814  PatternRewriter &rewriter) const override {
1815  Block *parent = nextIf->getBlock();
1816  if (nextIf == &parent->front())
1817  return failure();
1818 
1819  auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
1820  if (!prevIf)
1821  return failure();
1822 
1823  // Determine the logical then/else blocks when prevIf's
1824  // condition is used. Null means the block does not exist
1825  // in that case (e.g. empty else). If neither of these
1826  // are set, the two conditions cannot be compared.
1827  Block *nextThen = nullptr;
1828  Block *nextElse = nullptr;
1829  if (nextIf.getCondition() == prevIf.getCondition()) {
1830  nextThen = nextIf.thenBlock();
1831  if (!nextIf.getElseRegion().empty())
1832  nextElse = nextIf.elseBlock();
1833  }
1834  if (arith::XOrIOp notv =
1835  nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
1836  if (notv.getLhs() == prevIf.getCondition() &&
1837  matchPattern(notv.getRhs(), m_One())) {
1838  nextElse = nextIf.thenBlock();
1839  if (!nextIf.getElseRegion().empty())
1840  nextThen = nextIf.elseBlock();
1841  }
1842  }
1843  if (arith::XOrIOp notv =
1844  prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
1845  if (notv.getLhs() == nextIf.getCondition() &&
1846  matchPattern(notv.getRhs(), m_One())) {
1847  nextElse = nextIf.thenBlock();
1848  if (!nextIf.getElseRegion().empty())
1849  nextThen = nextIf.elseBlock();
1850  }
1851  }
1852 
1853  if (!nextThen && !nextElse)
1854  return failure();
1855 
1856  SmallVector<Value> prevElseYielded;
1857  if (!prevIf.getElseRegion().empty())
1858  prevElseYielded = prevIf.elseYield().getOperands();
1859  // Replace all uses of return values of op within nextIf with the
1860  // corresponding yields
1861  for (auto it : llvm::zip(prevIf.getResults(),
1862  prevIf.thenYield().getOperands(), prevElseYielded))
1863  for (OpOperand &use :
1864  llvm::make_early_inc_range(std::get<0>(it).getUses())) {
1865  if (nextThen && nextThen->getParent()->isAncestor(
1866  use.getOwner()->getParentRegion())) {
1867  rewriter.startRootUpdate(use.getOwner());
1868  use.set(std::get<1>(it));
1869  rewriter.finalizeRootUpdate(use.getOwner());
1870  } else if (nextElse && nextElse->getParent()->isAncestor(
1871  use.getOwner()->getParentRegion())) {
1872  rewriter.startRootUpdate(use.getOwner());
1873  use.set(std::get<2>(it));
1874  rewriter.finalizeRootUpdate(use.getOwner());
1875  }
1876  }
1877 
1878  SmallVector<Type> mergedTypes(prevIf.getResultTypes());
1879  llvm::append_range(mergedTypes, nextIf.getResultTypes());
1880 
1881  IfOp combinedIf = rewriter.create<IfOp>(
1882  nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false);
1883  rewriter.eraseBlock(&combinedIf.getThenRegion().back());
1884 
1885  rewriter.inlineRegionBefore(prevIf.getThenRegion(),
1886  combinedIf.getThenRegion(),
1887  combinedIf.getThenRegion().begin());
1888 
1889  if (nextThen) {
1890  YieldOp thenYield = combinedIf.thenYield();
1891  YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
1892  rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
1893  rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
1894 
1895  SmallVector<Value> mergedYields(thenYield.getOperands());
1896  llvm::append_range(mergedYields, thenYield2.getOperands());
1897  rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
1898  rewriter.eraseOp(thenYield);
1899  rewriter.eraseOp(thenYield2);
1900  }
1901 
1902  rewriter.inlineRegionBefore(prevIf.getElseRegion(),
1903  combinedIf.getElseRegion(),
1904  combinedIf.getElseRegion().begin());
1905 
1906  if (nextElse) {
1907  if (combinedIf.getElseRegion().empty()) {
1908  rewriter.inlineRegionBefore(*nextElse->getParent(),
1909  combinedIf.getElseRegion(),
1910  combinedIf.getElseRegion().begin());
1911  } else {
1912  YieldOp elseYield = combinedIf.elseYield();
1913  YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
1914  rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
1915 
1916  rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
1917 
1918  SmallVector<Value> mergedElseYields(elseYield.getOperands());
1919  llvm::append_range(mergedElseYields, elseYield2.getOperands());
1920 
1921  rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
1922  rewriter.eraseOp(elseYield);
1923  rewriter.eraseOp(elseYield2);
1924  }
1925  }
1926 
1927  SmallVector<Value> prevValues;
1928  SmallVector<Value> nextValues;
1929  for (const auto &pair : llvm::enumerate(combinedIf.getResults())) {
1930  if (pair.index() < prevIf.getNumResults())
1931  prevValues.push_back(pair.value());
1932  else
1933  nextValues.push_back(pair.value());
1934  }
1935  rewriter.replaceOp(prevIf, prevValues);
1936  rewriter.replaceOp(nextIf, nextValues);
1937  return success();
1938  }
1939 };
1940 
1941 /// Pattern to remove an empty else branch.
1942 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
1944 
1945  LogicalResult matchAndRewrite(IfOp ifOp,
1946  PatternRewriter &rewriter) const override {
1947  // Cannot remove else region when there are operation results.
1948  if (ifOp.getNumResults())
1949  return failure();
1950  Block *elseBlock = ifOp.elseBlock();
1951  if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
1952  return failure();
1953  auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
1954  rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
1955  newIfOp.getThenRegion().begin());
1956  rewriter.eraseOp(ifOp);
1957  return success();
1958  }
1959 };
1960 
1961 /// Convert nested `if`s into `arith.andi` + single `if`.
1962 ///
1963 /// scf.if %arg0 {
1964 /// scf.if %arg1 {
1965 /// ...
1966 /// scf.yield
1967 /// }
1968 /// scf.yield
1969 /// }
1970 /// becomes
1971 ///
1972 /// %0 = arith.andi %arg0, %arg1
1973 /// scf.if %0 {
1974 /// ...
1975 /// scf.yield
1976 /// }
1977 struct CombineNestedIfs : public OpRewritePattern<IfOp> {
1979 
1980  LogicalResult matchAndRewrite(IfOp op,
1981  PatternRewriter &rewriter) const override {
1982  auto nestedOps = op.thenBlock()->without_terminator();
1983  // Nested `if` must be the only op in block.
1984  if (!llvm::hasSingleElement(nestedOps))
1985  return failure();
1986 
1987  // If there is an else block, it can only yield
1988  if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
1989  return failure();
1990 
1991  auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
1992  if (!nestedIf)
1993  return failure();
1994 
1995  if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
1996  return failure();
1997 
1998  SmallVector<Value> thenYield(op.thenYield().getOperands());
1999  SmallVector<Value> elseYield;
2000  if (op.elseBlock())
2001  llvm::append_range(elseYield, op.elseYield().getOperands());
2002 
2003  // A list of indices for which we should upgrade the value yielded
2004  // in the else to a select.
2005  SmallVector<unsigned> elseYieldsToUpgradeToSelect;
2006 
2007  // If the outer scf.if yields a value produced by the inner scf.if,
2008  // only permit combining if the value yielded when the condition
2009  // is false in the outer scf.if is the same value yielded when the
2010  // inner scf.if condition is false.
2011  // Note that the array access to elseYield will not go out of bounds
2012  // since it must have the same length as thenYield, since they both
2013  // come from the same scf.if.
2014  for (const auto &tup : llvm::enumerate(thenYield)) {
2015  if (tup.value().getDefiningOp() == nestedIf) {
2016  auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
2017  if (nestedIf.elseYield().getOperand(nestedIdx) !=
2018  elseYield[tup.index()]) {
2019  return failure();
2020  }
2021  // If the correctness test passes, we will yield
2022  // corresponding value from the inner scf.if
2023  thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
2024  continue;
2025  }
2026 
2027  // Otherwise, we need to ensure the else block of the combined
2028  // condition still returns the same value when the outer condition is
2029  // true and the inner condition is false. This can be accomplished if
2030  // the then value is defined outside the outer scf.if and we replace the
2031  // value with a select that considers just the outer condition. Since
2032  // the else region contains just the yield, its yielded value is
2033  // defined outside the scf.if, by definition.
2034 
2035  // If the then value is defined within the scf.if, bail.
2036  if (tup.value().getParentRegion() == &op.getThenRegion()) {
2037  return failure();
2038  }
2039  elseYieldsToUpgradeToSelect.push_back(tup.index());
2040  }
2041 
2042  Location loc = op.getLoc();
2043  Value newCondition = rewriter.create<arith::AndIOp>(
2044  loc, op.getCondition(), nestedIf.getCondition());
2045  auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
2046 
2047  SmallVector<Value> results;
2048  llvm::append_range(results, newIf.getResults());
2049  rewriter.setInsertionPoint(newIf);
2050 
2051  for (auto idx : elseYieldsToUpgradeToSelect)
2052  results[idx] = rewriter.create<arith::SelectOp>(
2053  op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
2054 
2055  Block *newIfBlock = newIf.thenBlock();
2056  if (newIfBlock)
2057  rewriter.eraseOp(newIfBlock->getTerminator());
2058  else
2059  newIfBlock = rewriter.createBlock(&newIf.getThenRegion());
2060  rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
2061  rewriter.setInsertionPointToEnd(newIf.thenBlock());
2062  rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
2063  if (!elseYield.empty()) {
2064  rewriter.createBlock(&newIf.getElseRegion());
2065  rewriter.setInsertionPointToEnd(newIf.elseBlock());
2066  rewriter.create<YieldOp>(loc, elseYield);
2067  }
2068  rewriter.replaceOp(op, results);
2069  return success();
2070  }
2071 };
2072 
2073 } // namespace
2074 
2075 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2076  MLIRContext *context) {
2077  results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
2078  ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
2079  RemoveStaticCondition, RemoveUnusedResults,
2080  ReplaceIfYieldWithConditionOrValue>(context);
2081 }
2082 
2083 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
2084 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
2085 Block *IfOp::elseBlock() {
2086  Region &r = getElseRegion();
2087  if (r.empty())
2088  return nullptr;
2089  return &r.back();
2090 }
2091 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
2092 
2093 //===----------------------------------------------------------------------===//
2094 // ParallelOp
2095 //===----------------------------------------------------------------------===//
2096 
2097 void ParallelOp::build(
2098  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2099  ValueRange upperBounds, ValueRange steps, ValueRange initVals,
2101  bodyBuilderFn) {
2102  result.addOperands(lowerBounds);
2103  result.addOperands(upperBounds);
2104  result.addOperands(steps);
2105  result.addOperands(initVals);
2106  result.addAttribute(
2107  ParallelOp::getOperandSegmentSizeAttr(),
2108  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
2109  static_cast<int32_t>(upperBounds.size()),
2110  static_cast<int32_t>(steps.size()),
2111  static_cast<int32_t>(initVals.size())}));
2112  result.addTypes(initVals.getTypes());
2113 
2114  OpBuilder::InsertionGuard guard(builder);
2115  unsigned numIVs = steps.size();
2116  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
2117  SmallVector<Location, 8> argLocs(numIVs, result.location);
2118  Region *bodyRegion = result.addRegion();
2119  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
2120 
2121  if (bodyBuilderFn) {
2122  builder.setInsertionPointToStart(bodyBlock);
2123  bodyBuilderFn(builder, result.location,
2124  bodyBlock->getArguments().take_front(numIVs),
2125  bodyBlock->getArguments().drop_front(numIVs));
2126  }
2127  ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
2128 }
2129 
2130 void ParallelOp::build(
2131  OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
2132  ValueRange upperBounds, ValueRange steps,
2133  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2134  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
2135  // we don't capture a reference to a temporary by constructing the lambda at
2136  // function level.
2137  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
2138  Location nestedLoc, ValueRange ivs,
2139  ValueRange) {
2140  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2141  };
2143  if (bodyBuilderFn)
2144  wrapper = wrappedBuilderFn;
2145 
2146  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
2147  wrapper);
2148 }
2149 
// Body of a verifier in the ParallelOp section: every failed check returns an
// op error, otherwise success() is returned.
// NOTE(review): the line carrying the function signature is not present in
// this listing; presumably this is the ParallelOp verifier -- confirm against
// the original file.
2151  // Check that there is at least one value in lowerBound, upperBound and step.
2152  // It is sufficient to test only step, because it is ensured already that the
2153  // number of elements in lowerBound, upperBound and step are the same.
2154  Operation::operand_range stepValues = getStep();
2155  if (stepValues.empty())
2156  return emitOpError(
2157  "needs at least one tuple element for lowerBound, upperBound and step");
2158 
2159  // Check whether all constant step values are positive.
// Steps that are not defined by an arith::ConstantIndexOp are not checked.
2160  for (Value stepValue : stepValues)
2161  if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
2162  if (cst.value() <= 0)
2163  return emitOpError("constant step operand must be positive");
2164 
2165  // Check that the body defines the same number of block arguments as the
2166  // number of tuple elements in step.
2167  Block *body = getBody();
2168  if (body->getNumArguments() != stepValues.size())
2169  return emitOpError() << "expects the same number of induction variables: "
2170  << body->getNumArguments()
2171  << " as bound and step values: " << stepValues.size();
2172  for (auto arg : body->getArguments())
2173  if (!arg.getType().isIndex())
2174  return emitOpError(
2175  "expects arguments for the induction variable to be of index type");
2176 
2177  // Check that the yield has no results
2178  Operation *yield = body->getTerminator();
2179  if (yield->getNumOperands() != 0)
2180  return yield->emitOpError() << "not allowed to have operands inside '"
2181  << ParallelOp::getOperationName() << "'";
2182 
2183  // Check that the number of results is the same as the number of ReduceOps.
// Only ReduceOps directly inside the body block are collected here.
2184  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
2185  auto resultsSize = getResults().size();
2186  auto reductionsSize = reductions.size();
2187  auto initValsSize = getInitVals().size();
2188  if (resultsSize != reductionsSize)
2189  return emitOpError() << "expects number of results: " << resultsSize
2190  << " to be the same as number of reductions: "
2191  << reductionsSize;
2192  if (resultsSize != initValsSize)
2193  return emitOpError() << "expects number of results: " << resultsSize
2194  << " to be the same as number of initial values: "
2195  << initValsSize;
2196 
2197  // Check that the types of the results and reductions are the same.
// Results and reductions are paired by position.
2198  for (auto resultAndReduce : llvm::zip(getResults(), reductions)) {
2199  auto resultType = std::get<0>(resultAndReduce).getType();
2200  auto reduceOp = std::get<1>(resultAndReduce);
2201  auto reduceType = reduceOp.getOperand().getType();
2202  if (resultType != reduceType)
2203  return reduceOp.emitOpError()
2204  << "expects type of reduce: " << reduceType
2205  << " to be the same as result type: " << resultType;
2206  }
2207  return success();
2208 }
2209 
2210 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
2211  auto &builder = parser.getBuilder();
2212  // Parse an opening `(` followed by induction variables followed by `)`
2215  return failure();
2216 
2217  // Parse loop bounds.
2219  if (parser.parseEqual() ||
2220  parser.parseOperandList(lower, ivs.size(),
2222  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
2223  return failure();
2224 
2226  if (parser.parseKeyword("to") ||
2227  parser.parseOperandList(upper, ivs.size(),
2229  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
2230  return failure();
2231 
2232  // Parse step values.
2234  if (parser.parseKeyword("step") ||
2235  parser.parseOperandList(steps, ivs.size(),
2237  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
2238  return failure();
2239 
2240  // Parse init values.
2242  if (succeeded(parser.parseOptionalKeyword("init"))) {
2243  if (parser.parseOperandList(initVals, OpAsmParser::Delimiter::Paren))
2244  return failure();
2245  }
2246 
2247  // Parse optional results in case there is a reduce.
2248  if (parser.parseOptionalArrowTypeList(result.types))
2249  return failure();
2250 
2251  // Now parse the body.
2252  Region *body = result.addRegion();
2253  for (auto &iv : ivs)
2254  iv.type = builder.getIndexType();
2255  if (parser.parseRegion(*body, ivs))
2256  return failure();
2257 
2258  // Set `operand_segment_sizes` attribute.
2259  result.addAttribute(
2260  ParallelOp::getOperandSegmentSizeAttr(),
2261  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
2262  static_cast<int32_t>(upper.size()),
2263  static_cast<int32_t>(steps.size()),
2264  static_cast<int32_t>(initVals.size())}));
2265 
2266  // Parse attributes.
2267  if (parser.parseOptionalAttrDict(result.attributes) ||
2268  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2269  result.operands))
2270  return failure();
2271 
2272  // Add a terminator if none was parsed.
2273  ForOp::ensureTerminator(*body, builder, result.location);
2274  return success();
2275 }
2276 
// Body of a printer in the ParallelOp section: writes the induction
// variables, bounds, steps, optional init values, result types and the region.
// NOTE(review): the signature line and the line opening the final call (its
// arguments are the attribute list and the elided segment-size attribute) are
// not present in this listing -- confirm against the original file.
2278  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
2279  << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
2280  if (!getInitVals().empty())
2281  p << " init (" << getInitVals() << ")";
2282  p.printOptionalArrowTypeList(getResultTypes());
2283  p << ' ';
// The entry block arguments were already printed above as induction variables.
2284  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
2286  (*this)->getAttrs(),
2287  /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
2288 }
2289 
2290 Region &ParallelOp::getLoopBody() { return getRegion(); }
2291 
// Returns the ParallelOp that is the parent op of the block owning `val` when
// `val` is a block argument; returns a null ParallelOp when `val` is not a
// block argument or the parent op is of a different kind.
// NOTE(review): the signature line is not present in this listing -- confirm
// the function name and parameter against the original file.
2293  auto ivArg = val.dyn_cast<BlockArgument>();
2294  if (!ivArg)
2295  return ParallelOp();
2296  assert(ivArg.getOwner() && "unlinked block argument");
2297  auto *containingOp = ivArg.getOwner()->getParentOp();
2298  return dyn_cast<ParallelOp>(containingOp);
2299 }
2300 
2301 namespace {
2302 // Collapse loop dimensions that perform a single iteration.
2303 struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
2305 
2306  LogicalResult matchAndRewrite(ParallelOp op,
2307  PatternRewriter &rewriter) const override {
2308  BlockAndValueMapping mapping;
2309  // Compute new loop bounds that omit all single-iteration loop dimensions.
2310  SmallVector<Value, 2> newLowerBounds;
2311  SmallVector<Value, 2> newUpperBounds;
2312  SmallVector<Value, 2> newSteps;
2313  newLowerBounds.reserve(op.getLowerBound().size());
2314  newUpperBounds.reserve(op.getUpperBound().size());
2315  newSteps.reserve(op.getStep().size());
2316  for (auto [lowerBound, upperBound, step, iv] :
2317  llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
2318  op.getInductionVars())) {
2319  // Collect the statically known loop bounds.
2320  auto lowerBoundConstant =
2321  dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
2322  auto upperBoundConstant =
2323  dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
2324  auto stepConstant =
2325  dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
2326  // Replace the loop induction variable by the lower bound if the loop
2327  // performs a single iteration. Otherwise, copy the loop bounds.
2328  if (lowerBoundConstant && upperBoundConstant && stepConstant &&
2329  (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
2330  (upperBoundConstant.value() - lowerBoundConstant.value()) <=
2331  stepConstant.value()) {
2332  mapping.map(iv, lowerBound);
2333  } else {
2334  newLowerBounds.push_back(lowerBound);
2335  newUpperBounds.push_back(upperBound);
2336  newSteps.push_back(step);
2337  }
2338  }
2339  // Exit if none of the loop dimensions perform a single iteration.
2340  if (newLowerBounds.size() == op.getLowerBound().size())
2341  return failure();
2342 
2343  if (newLowerBounds.empty()) {
2344  // All of the loop dimensions perform a single iteration. Inline
2345  // loop body and nested ReduceOp's
2346  SmallVector<Value> results;
2347  results.reserve(op.getInitVals().size());
2348  for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
2349  auto reduce = dyn_cast<ReduceOp>(bodyOp);
2350  if (!reduce) {
2351  rewriter.clone(bodyOp, mapping);
2352  continue;
2353  }
2354  Block &reduceBlock = reduce.getReductionOperator().front();
2355  auto initValIndex = results.size();
2356  mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]);
2357  mapping.map(reduceBlock.getArgument(1),
2358  mapping.lookupOrDefault(reduce.getOperand()));
2359  for (auto &reduceBodyOp : reduceBlock.without_terminator())
2360  rewriter.clone(reduceBodyOp, mapping);
2361 
2362  auto result = mapping.lookupOrDefault(
2363  cast<ReduceReturnOp>(reduceBlock.getTerminator()).getResult());
2364  results.push_back(result);
2365  }
2366  rewriter.replaceOp(op, results);
2367  return success();
2368  }
2369  // Replace the parallel loop by lower-dimensional parallel loop.
2370  auto newOp =
2371  rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
2372  newSteps, op.getInitVals(), nullptr);
2373  // Clone the loop body and remap the block arguments of the collapsed loops
2374  // (inlining does not support a cancellable block argument mapping).
2375  rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
2376  newOp.getRegion().begin(), mapping);
2377  rewriter.replaceOp(op, newOp.getResults());
2378  return success();
2379  }
2380 };
2381 
2382 /// Removes parallel loops in which at least one lower/upper bound pair consists
2383 /// of the same values - such loops have an empty iteration domain.
2384 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
2386 
2387  LogicalResult matchAndRewrite(ParallelOp op,
2388  PatternRewriter &rewriter) const override {
2389  for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
2390  if (std::get<0>(dim) == std::get<1>(dim)) {
2391  rewriter.replaceOp(op, op.getInitVals());
2392  return success();
2393  }
2394  }
2395  return failure();
2396  }
2397 };
2398 
2399 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
2401 
2402  LogicalResult matchAndRewrite(ParallelOp op,
2403  PatternRewriter &rewriter) const override {
2404  Block &outerBody = op.getLoopBody().front();
2405  if (!llvm::hasSingleElement(outerBody.without_terminator()))
2406  return failure();
2407 
2408  auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
2409  if (!innerOp)
2410  return failure();
2411 
2412  for (auto val : outerBody.getArguments())
2413  if (llvm::is_contained(innerOp.getLowerBound(), val) ||
2414  llvm::is_contained(innerOp.getUpperBound(), val) ||
2415  llvm::is_contained(innerOp.getStep(), val))
2416  return failure();
2417 
2418  // Reductions are not supported yet.
2419  if (!op.getInitVals().empty() || !innerOp.getInitVals().empty())
2420  return failure();
2421 
2422  auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
2423  ValueRange iterVals, ValueRange) {
2424  Block &innerBody = innerOp.getLoopBody().front();
2425  assert(iterVals.size() ==
2426  (outerBody.getNumArguments() + innerBody.getNumArguments()));
2427  BlockAndValueMapping mapping;
2428  mapping.map(outerBody.getArguments(),
2429  iterVals.take_front(outerBody.getNumArguments()));
2430  mapping.map(innerBody.getArguments(),
2431  iterVals.take_back(innerBody.getNumArguments()));
2432  for (Operation &op : innerBody.without_terminator())
2433  builder.clone(op, mapping);
2434  };
2435 
2436  auto concatValues = [](const auto &first, const auto &second) {
2437  SmallVector<Value> ret;
2438  ret.reserve(first.size() + second.size());
2439  ret.assign(first.begin(), first.end());
2440  ret.append(second.begin(), second.end());
2441  return ret;
2442  };
2443 
2444  auto newLowerBounds =
2445  concatValues(op.getLowerBound(), innerOp.getLowerBound());
2446  auto newUpperBounds =
2447  concatValues(op.getUpperBound(), innerOp.getUpperBound());
2448  auto newSteps = concatValues(op.getStep(), innerOp.getStep());
2449 
2450  rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
2451  newSteps, llvm::None, bodyBuilder);
2452  return success();
2453  }
2454 };
2455 
2456 } // namespace
2457 
2458 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
2459  MLIRContext *context) {
2460  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
2461  MergeNestedParallelLoops>(context);
2462 }
2463 
2464 //===----------------------------------------------------------------------===//
2465 // ReduceOp
2466 //===----------------------------------------------------------------------===//
2467 
2468 void ReduceOp::build(
2469  OpBuilder &builder, OperationState &result, Value operand,
2470  function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
2471  auto type = operand.getType();
2472  result.addOperands(operand);
2473 
2474  OpBuilder::InsertionGuard guard(builder);
2475  Region *bodyRegion = result.addRegion();
2476  Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type},
2477  {result.location, result.location});
2478  if (bodyBuilderFn)
2479  bodyBuilderFn(builder, result.location, body->getArgument(0),
2480  body->getArgument(1));
2481 }
2482 
2483 LogicalResult ReduceOp::verifyRegions() {
2484  // The region of a ReduceOp has two arguments of the same type as its operand.
2485  auto type = getOperand().getType();
2486  Block &block = getReductionOperator().front();
2487  if (block.empty())
2488  return emitOpError("the block inside reduce should not be empty");
2489  if (block.getNumArguments() != 2 ||
2490  llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
2491  return arg.getType() != type;
2492  }))
2493  return emitOpError() << "expects two arguments to reduce block of type "
2494  << type;
2495 
2496  // Check that the block is terminated by a ReduceReturnOp.
2497  if (!isa<ReduceReturnOp>(block.getTerminator()))
2498  return emitOpError("the block inside reduce should be terminated with a "
2499  "'scf.reduce.return' op");
2500 
2501  return success();
2502 }
2503 
2504 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2505  // Parse an opening `(` followed by the reduced value followed by `)`
2507  if (parser.parseLParen() || parser.parseOperand(operand) ||
2508  parser.parseRParen())
2509  return failure();
2510 
2511  Type resultType;
2512  // Parse the type of the operand (and also what reduce computes on).
2513  if (parser.parseColonType(resultType) ||
2514  parser.resolveOperand(operand, resultType, result.operands))
2515  return failure();
2516 
2517  // Now parse the body.
2518  Region *body = result.addRegion();
2519  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2520  return failure();
2521 
2522  return success();
2523 }
2524 
2525 void ReduceOp::print(OpAsmPrinter &p) {
2526  p << "(" << getOperand() << ") ";
2527  p << " : " << getOperand().getType() << ' ';
2528  p.printRegion(getReductionOperator());
2529 }
2530 
2531 //===----------------------------------------------------------------------===//
2532 // ReduceReturnOp
2533 //===----------------------------------------------------------------------===//
2534 
2536  // The type of the return value should be the same type as the type of the
2537  // operand of the enclosing ReduceOp.
2538  auto reduceOp = cast<ReduceOp>((*this)->getParentOp());
2539  Type reduceType = reduceOp.getOperand().getType();
2540  if (reduceType != getResult().getType())
2541  return emitOpError() << "needs to have type " << reduceType
2542  << " (the type of the enclosing ReduceOp)";
2543  return success();
2544 }
2545 
2546 //===----------------------------------------------------------------------===//
2547 // WhileOp
2548 //===----------------------------------------------------------------------===//
2549 
2550 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
2551  assert(index && *index == 0 &&
2552  "WhileOp is expected to branch only to the first region");
2553 
2554  return getInits();
2555 }
2556 
2557 ConditionOp WhileOp::getConditionOp() {
2558  return cast<ConditionOp>(getBefore().front().getTerminator());
2559 }
2560 
2561 YieldOp WhileOp::getYieldOp() {
2562  return cast<YieldOp>(getAfter().front().getTerminator());
2563 }
2564 
2565 Block::BlockArgListType WhileOp::getBeforeArguments() {
2566  return getBefore().front().getArguments();
2567 }
2568 
2569 Block::BlockArgListType WhileOp::getAfterArguments() {
2570  return getAfter().front().getArguments();
2571 }
2572 
2573 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2574  ArrayRef<Attribute> operands,
2576  // The parent op always branches to the condition region.
2577  if (!index) {
2578  regions.emplace_back(&getBefore(), getBefore().getArguments());
2579  return;
2580  }
2581 
2582  assert(*index < 2 && "there are only two regions in a WhileOp");
2583  // The body region always branches back to the condition region.
2584  if (*index == 1) {
2585  regions.emplace_back(&getBefore(), getBefore().getArguments());
2586  return;
2587  }
2588 
2589  // Try to narrow the successor to the condition region.
2590  assert(!operands.empty() && "expected at least one operand");
2591  auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
2592  if (!cond || !cond.getValue())
2593  regions.emplace_back(getResults());
2594  if (!cond || cond.getValue())
2595  regions.emplace_back(&getAfter(), getAfter().getArguments());
2596 }
2597 
2598 /// Parses a `while` op.
2599 ///
2600 /// op ::= `scf.while` assignments `:` function-type region `do` region
2601 /// `attributes` attribute-dict
2602 /// initializer ::= /* empty */ | `(` assignment-list `)`
2603 /// assignment-list ::= assignment | assignment `,` assignment-list
2604 /// assignment ::= ssa-value `=` ssa-value
2605 ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2608  Region *before = result.addRegion();
2609  Region *after = result.addRegion();
2610 
2611  OptionalParseResult listResult =
2612  parser.parseOptionalAssignmentList(regionArgs, operands);
2613  if (listResult.has_value() && failed(listResult.value()))
2614  return failure();
2615 
2616  FunctionType functionType;
2617  SMLoc typeLoc = parser.getCurrentLocation();
2618  if (failed(parser.parseColonType(functionType)))
2619  return failure();
2620 
2621  result.addTypes(functionType.getResults());
2622 
2623  if (functionType.getNumInputs() != operands.size()) {
2624  return parser.emitError(typeLoc)
2625  << "expected as many input types as operands "
2626  << "(expected " << operands.size() << " got "
2627  << functionType.getNumInputs() << ")";
2628  }
2629 
2630  // Resolve input operands.
2631  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2632  parser.getCurrentLocation(),
2633  result.operands)))
2634  return failure();
2635 
2636  // Propagate the types into the region arguments.
2637  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2638  regionArgs[i].type = functionType.getInput(i);
2639 
2640  return failure(parser.parseRegion(*before, regionArgs) ||
2641  parser.parseKeyword("do") || parser.parseRegion(*after) ||
2643 }
2644 
2645 /// Prints a `while` op.
2647  printInitializationList(p, getBefore().front().getArguments(), getInits(),
2648  " ");
2649  p << " : ";
2650  p.printFunctionalType(getInits().getTypes(), getResults().getTypes());
2651  p << ' ';
2652  p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
2653  p << " do ";
2654  p.printRegion(getAfter());
2655  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2656 }
2657 
2658 /// Verifies that two ranges of types match, i.e. have the same number of
2659 /// entries and that types are pairwise equals. Reports errors on the given
2660 /// operation in case of mismatch.
2661 template <typename OpTy>
2662 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2663  TypeRange right, StringRef message) {
2664  if (left.size() != right.size())
2665  return op.emitOpError("expects the same number of ") << message;
2666 
2667  for (unsigned i = 0, e = left.size(); i < e; ++i) {
2668  if (left[i] != right[i]) {
2669  InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2670  << message;
2671  diag.attachNote() << "for argument " << i << ", found " << left[i]
2672  << " and " << right[i];
2673  return diag;
2674  }
2675  }
2676 
2677  return success();
2678 }
2679 
2680 /// Verifies that the first block of the given `region` is terminated by a
2681 /// YieldOp. Reports errors on the given operation if it is not the case.
2682 template <typename TerminatorTy>
2683 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2684  StringRef errorMessage) {
2685  Operation *terminatorOperation = region.front().getTerminator();
2686  if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2687  return yield;
2688 
2689  auto diag = op.emitOpError(errorMessage);
2690  if (terminatorOperation)
2691  diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2692  return nullptr;
2693 }
2694 
2696  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2697  *this, getBefore(),
2698  "expects the 'before' region to terminate with 'scf.condition'");
2699  if (!beforeTerminator)
2700  return failure();
2701 
2702  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2703  *this, getAfter(),
2704  "expects the 'after' region to terminate with 'scf.yield'");
2705  return success(afterTerminator != nullptr);
2706 }
2707 
2708 namespace {
2709 /// Replace uses of the condition within the do block with true, since otherwise
2710 /// the block would not be evaluated.
2711 ///
2712 /// scf.while (..) : (i1, ...) -> ... {
2713 /// %condition = call @evaluate_condition() : () -> i1
2714 /// scf.condition(%condition) %condition : i1, ...
2715 /// } do {
2716 /// ^bb0(%arg0: i1, ...):
2717 /// use(%arg0)
2718 /// ...
2719 ///
2720 /// becomes
2721 /// scf.while (..) : (i1, ...) -> ... {
2722 /// %condition = call @evaluate_condition() : () -> i1
2723 /// scf.condition(%condition) %condition : i1, ...
2724 /// } do {
2725 /// ^bb0(%arg0: i1, ...):
2726 /// use(%true)
2727 /// ...
2728 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2730 
2731  LogicalResult matchAndRewrite(WhileOp op,
2732  PatternRewriter &rewriter) const override {
2733  auto term = op.getConditionOp();
2734 
2735  // These variables serve to prevent creating duplicate constants
2736  // and hold constant true or false values.
2737  Value constantTrue = nullptr;
2738 
2739  bool replaced = false;
2740  for (auto yieldedAndBlockArgs :
2741  llvm::zip(term.getArgs(), op.getAfterArguments())) {
2742  if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) {
2743  if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2744  if (!constantTrue)
2745  constantTrue = rewriter.create<arith::ConstantOp>(
2746  op.getLoc(), term.getCondition().getType(),
2747  rewriter.getBoolAttr(true));
2748 
2749  std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2750  replaced = true;
2751  }
2752  }
2753  }
2754  return success(replaced);
2755  }
2756 };
2757 
2758 /// Remove loop invariant arguments from `before` block of scf.while.
2759 /// A before block argument is considered loop invariant if :-
2760 /// 1. i-th yield operand is equal to the i-th while operand.
2761 /// 2. i-th yield operand is k-th after block argument which is (k+1)-th
2762 /// condition operand AND this (k+1)-th condition operand is equal to i-th
2763 /// iter argument/while operand.
2764 /// For the arguments which are removed, their uses inside scf.while
2765 /// are replaced with their corresponding initial value.
2766 ///
2767 /// Eg:
2768 /// INPUT :-
2769 /// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
2770 /// ..., %argN_before = %N)
2771 /// {
2772 /// ...
2773 /// scf.condition(%cond) %arg1_before, %arg0_before,
2774 /// %arg2_before, %arg0_before, ...
2775 /// } do {
2776 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2777 /// ..., %argK_after):
2778 /// ...
2779 /// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
2780 /// }
2781 ///
2782 /// OUTPUT :-
2783 /// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
2784 /// %N)
2785 /// {
2786 /// ...
2787 /// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
2788 /// } do {
2789 /// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
2790 /// ..., %argK_after):
2791 /// ...
2792 /// scf.yield %arg1_after, ..., %argN
2793 /// }
2794 ///
2795 /// EXPLANATION:
2796 /// We iterate over each yield operand.
2797 /// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
2798 /// %arg0_before, which in turn is the 0-th iter argument. So we
2799 /// remove 0-th before block argument and yield operand, and replace
2800 /// all uses of the 0-th before block argument with its initial value
2801 /// %a.
2802 /// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
2803 /// value. So we remove this operand and the corresponding before
2804 /// block argument and replace all uses of 1-th before block argument
2805 /// with %b.
2806 struct RemoveLoopInvariantArgsFromBeforeBlock
2807  : public OpRewritePattern<WhileOp> {
2809 
2810  LogicalResult matchAndRewrite(WhileOp op,
2811  PatternRewriter &rewriter) const override {
2812  Block &afterBlock = op.getAfter().front();
2813  Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
2814  ConditionOp condOp = op.getConditionOp();
2815  OperandRange condOpArgs = condOp.getArgs();
2816  Operation *yieldOp = afterBlock.getTerminator();
2817  ValueRange yieldOpArgs = yieldOp->getOperands();
2818 
2819  bool canSimplify = false;
2820  for (const auto &it :
2821  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2822  auto index = static_cast<unsigned>(it.index());
2823  auto [initVal, yieldOpArg] = it.value();
2824  // If i-th yield operand is equal to the i-th operand of the scf.while,
2825  // the i-th before block argument is a loop invariant.
2826  if (yieldOpArg == initVal) {
2827  canSimplify = true;
2828  break;
2829  }
2830  // If the i-th yield operand is k-th after block argument, then we check
2831  // if the (k+1)-th condition op operand is equal to either the i-th before
2832  // block argument or the initial value of i-th before block argument. If
2833  // the comparison results `true`, i-th before block argument is a loop
2834  // invariant.
2835  auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2836  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2837  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2838  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2839  canSimplify = true;
2840  break;
2841  }
2842  }
2843  }
2844 
2845  if (!canSimplify)
2846  return failure();
2847 
2848  SmallVector<Value> newInitArgs, newYieldOpArgs;
2849  DenseMap<unsigned, Value> beforeBlockInitValMap;
2850  SmallVector<Location> newBeforeBlockArgLocs;
2851  for (const auto &it :
2852  llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
2853  auto index = static_cast<unsigned>(it.index());
2854  auto [initVal, yieldOpArg] = it.value();
2855 
2856  // If i-th yield operand is equal to the i-th operand of the scf.while,
2857  // the i-th before block argument is a loop invariant.
2858  if (yieldOpArg == initVal) {
2859  beforeBlockInitValMap.insert({index, initVal});
2860  continue;
2861  } else {
2862  // If the i-th yield operand is k-th after block argument, then we check
2863  // if the (k+1)-th condition op operand is equal to either the i-th
2864  // before block argument or the initial value of i-th before block
2865  // argument. If the comparison results `true`, i-th before block
2866  // argument is a loop invariant.
2867  auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
2868  if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
2869  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
2870  if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
2871  beforeBlockInitValMap.insert({index, initVal});
2872  continue;
2873  }
2874  }
2875  }
2876  newInitArgs.emplace_back(initVal);
2877  newYieldOpArgs.emplace_back(yieldOpArg);
2878  newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
2879  }
2880 
2881  {
2882  OpBuilder::InsertionGuard g(rewriter);
2883  rewriter.setInsertionPoint(yieldOp);
2884  rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
2885  }
2886 
2887  auto newWhile =
2888  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
2889 
2890  Block &newBeforeBlock = *rewriter.createBlock(
2891  &newWhile.getBefore(), /*insertPt*/ {},
2892  ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
2893 
2894  Block &beforeBlock = op.getBefore().front();
2895  SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
2896  // For each i-th before block argument we find it's replacement value as :-
2897  // 1. If i-th before block argument is a loop invariant, we fetch it's
2898  // initial value from `beforeBlockInitValMap` by querying for key `i`.
2899  // 2. Else we fetch j-th new before block argument as the replacement
2900  // value of i-th before block argument.
2901  for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
2902  // If the index 'i' argument was a loop invariant we fetch it's initial
2903  // value from `beforeBlockInitValMap`.
2904  if (beforeBlockInitValMap.count(i) != 0)
2905  newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
2906  else
2907  newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
2908  }
2909 
2910  rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
2911  rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
2912  newWhile.getAfter().begin());
2913 
2914  rewriter.replaceOp(op, newWhile.getResults());
2915  return success();
2916  }
2917 };
2918 
2919 /// Remove loop invariant value from result (condition op) of scf.while.
2920 /// A value is considered loop invariant if the final value yielded by
2921 /// scf.condition is defined outside of the `before` block. We remove the
2922 /// corresponding argument in `after` block and replace the use with the value.
2923 /// We also replace the use of the corresponding result of scf.while with the
2924 /// value.
2925 ///
2926 /// Eg:
2927 /// INPUT :-
2928 /// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
2929 /// %argN_before = %N) {
2930 /// ...
2931 /// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
2932 /// } do {
2933 /// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
2934 /// ...
2935 /// some_func(%arg1_after)
2936 /// ...
2937 /// scf.yield %arg0_after, %arg2_after, ..., %argN_after
2938 /// }
2939 ///
2940 /// OUTPUT :-
2941 /// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
2942 /// ...
2943 /// scf.condition(%cond) %arg0, %arg1, ..., %argM
2944 /// } do {
2945 /// ^bb0(%arg0, %arg3, ..., %argM):
2946 /// ...
2947 /// some_func(%a)
2948 /// ...
2949 /// scf.yield %arg0, %b, ..., %argN
2950 /// }
2951 ///
2952 /// EXPLANATION:
2953 /// 1. The 1-th and 2-th operand of scf.condition are defined outside the
2954 /// before block of scf.while, so they get removed.
2955 /// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
2956 /// replaced by %b.
2957 /// 3. The corresponding after block argument %arg1_after's uses are
2958 /// replaced by %a and %arg2_after's uses are replaced by %b.
2959 struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
2961 
2962  LogicalResult matchAndRewrite(WhileOp op,
2963  PatternRewriter &rewriter) const override {
2964  Block &beforeBlock = op.getBefore().front();
2965  ConditionOp condOp = op.getConditionOp();
2966  OperandRange condOpArgs = condOp.getArgs();
2967 
2968  bool canSimplify = false;
2969  for (Value condOpArg : condOpArgs) {
2970  // Those values not defined within `before` block will be considered as
2971  // loop invariant values. We map the corresponding `index` with their
2972  // value.
2973  if (condOpArg.getParentBlock() != &beforeBlock) {
2974  canSimplify = true;
2975  break;
2976  }
2977  }
2978 
2979  if (!canSimplify)
2980  return failure();
2981 
2982  Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
2983 
2984  SmallVector<Value> newCondOpArgs;
2985  SmallVector<Type> newAfterBlockType;
2986  DenseMap<unsigned, Value> condOpInitValMap;
2987  SmallVector<Location> newAfterBlockArgLocs;
2988  for (const auto &it : llvm::enumerate(condOpArgs)) {
2989  auto index = static_cast<unsigned>(it.index());
2990  Value condOpArg = it.value();
2991  // Those values not defined within `before` block will be considered as
2992  // loop invariant values. We map the corresponding `index` with their
2993  // value.
2994  if (condOpArg.getParentBlock() != &beforeBlock) {
2995  condOpInitValMap.insert({index, condOpArg});
2996  } else {
2997  newCondOpArgs.emplace_back(condOpArg);
2998  newAfterBlockType.emplace_back(condOpArg.getType());
2999  newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
3000  }
3001  }
3002 
3003  {
3004  OpBuilder::InsertionGuard g(rewriter);
3005  rewriter.setInsertionPoint(condOp);
3006  rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
3007  newCondOpArgs);
3008  }
3009 
3010  auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
3011  op.getOperands());
3012 
3013  Block &newAfterBlock =
3014  *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
3015  newAfterBlockType, newAfterBlockArgLocs);
3016 
3017  Block &afterBlock = op.getAfter().front();
3018  // Since a new scf.condition op was created, we need to fetch the new
3019  // `after` block arguments which will be used while replacing operations of
3020  // previous scf.while's `after` blocks. We'd also be fetching new result
3021  // values too.
3022  SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
3023  SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
3024  for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
3025  Value afterBlockArg, result;
3026  // If index 'i' argument was loop invariant we fetch it's value from the
3027  // `condOpInitMap` map.
3028  if (condOpInitValMap.count(i) != 0) {
3029  afterBlockArg = condOpInitValMap[i];
3030  result = afterBlockArg;
3031  } else {
3032  afterBlockArg = newAfterBlock.getArgument(j);
3033  result = newWhile.getResult(j);
3034  j++;
3035  }
3036  newAfterBlockArgs[i] = afterBlockArg;
3037  newWhileResults[i] = result;
3038  }
3039 
3040  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3041  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3042  newWhile.getBefore().begin());
3043 
3044  rewriter.replaceOp(op, newWhileResults);
3045  return success();
3046  }
3047 };
3048 
3049 /// Remove WhileOp results that are also unused in 'after' block.
3050 ///
3051 /// %0:2 = scf.while () : () -> (i32, i64) {
3052 /// %condition = "test.condition"() : () -> i1
3053 /// %v1 = "test.get_some_value"() : () -> i32
3054 /// %v2 = "test.get_some_value"() : () -> i64
3055 /// scf.condition(%condition) %v1, %v2 : i32, i64
3056 /// } do {
3057 /// ^bb0(%arg0: i32, %arg1: i64):
3058 /// "test.use"(%arg0) : (i32) -> ()
3059 /// scf.yield
3060 /// }
3061 /// return %0#0 : i32
3062 ///
3063 /// becomes
3064 /// %0 = scf.while () : () -> (i32) {
3065 /// %condition = "test.condition"() : () -> i1
3066 /// %v1 = "test.get_some_value"() : () -> i32
3067 /// %v2 = "test.get_some_value"() : () -> i64
3068 /// scf.condition(%condition) %v1 : i32
3069 /// } do {
3070 /// ^bb0(%arg0: i32):
3071 /// "test.use"(%arg0) : (i32) -> ()
3072 /// scf.yield
3073 /// }
3074 /// return %0 : i32
3075 struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
3077 
3078  LogicalResult matchAndRewrite(WhileOp op,
3079  PatternRewriter &rewriter) const override {
3080  auto term = op.getConditionOp();
3081  auto afterArgs = op.getAfterArguments();
3082  auto termArgs = term.getArgs();
3083 
3084  // Collect results mapping, new terminator args and new result types.
3085  SmallVector<unsigned> newResultsIndices;
3086  SmallVector<Type> newResultTypes;
3087  SmallVector<Value> newTermArgs;
3088  SmallVector<Location> newArgLocs;
3089  bool needUpdate = false;
3090  for (const auto &it :
3091  llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
3092  auto i = static_cast<unsigned>(it.index());
3093  Value result = std::get<0>(it.value());
3094  Value afterArg = std::get<1>(it.value());
3095  Value termArg = std::get<2>(it.value());
3096  if (result.use_empty() && afterArg.use_empty()) {
3097  needUpdate = true;
3098  } else {
3099  newResultsIndices.emplace_back(i);
3100  newTermArgs.emplace_back(termArg);
3101  newResultTypes.emplace_back(result.getType());
3102  newArgLocs.emplace_back(result.getLoc());
3103  }
3104  }
3105 
3106  if (!needUpdate)
3107  return failure();
3108 
3109  {
3110  OpBuilder::InsertionGuard g(rewriter);
3111  rewriter.setInsertionPoint(term);
3112  rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
3113  newTermArgs);
3114  }
3115 
3116  auto newWhile =
3117  rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits());
3118 
3119  Block &newAfterBlock = *rewriter.createBlock(
3120  &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
3121 
3122  // Build new results list and new after block args (unused entries will be
3123  // null).
3124  SmallVector<Value> newResults(op.getNumResults());
3125  SmallVector<Value> newAfterBlockArgs(op.getNumResults());
3126  for (const auto &it : llvm::enumerate(newResultsIndices)) {
3127  newResults[it.value()] = newWhile.getResult(it.index());
3128  newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
3129  }
3130 
3131  rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
3132  newWhile.getBefore().begin());
3133 
3134  Block &afterBlock = op.getAfter().front();
3135  rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
3136 
3137  rewriter.replaceOp(op, newResults);
3138  return success();
3139  }
3140 };
3141 
3142 /// Replace operations equivalent to the condition in the do block with true,
3143 /// since otherwise the block would not be evaluated.
3144 ///
3145 /// scf.while (..) : (i32, ...) -> ... {
3146 /// %z = ... : i32
3147 /// %condition = cmpi pred %z, %a
3148 /// scf.condition(%condition) %z : i32, ...
3149 /// } do {
3150 /// ^bb0(%arg0: i32, ...):
3151 /// %condition2 = cmpi pred %arg0, %a
3152 /// use(%condition2)
3153 /// ...
3154 ///
3155 /// becomes
3156 /// scf.while (..) : (i32, ...) -> ... {
3157 /// %z = ... : i32
3158 /// %condition = cmpi pred %z, %a
3159 /// scf.condition(%condition) %z : i32, ...
3160 /// } do {
3161 /// ^bb0(%arg0: i32, ...):
3162 /// use(%true)
3163 /// ...
3164 struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
3166 
3167  LogicalResult matchAndRewrite(scf::WhileOp op,
3168  PatternRewriter &rewriter) const override {
3169  using namespace scf;
3170  auto cond = op.getConditionOp();
3171  auto cmp = cond.getCondition().getDefiningOp<arith::CmpIOp>();
3172  if (!cmp)
3173  return failure();
3174  bool changed = false;
3175  for (auto tup :
3176  llvm::zip(cond.getArgs(), op.getAfter().front().getArguments())) {
3177  for (size_t opIdx = 0; opIdx < 2; opIdx++) {
3178  if (std::get<0>(tup) != cmp.getOperand(opIdx))
3179  continue;
3180  for (OpOperand &u :
3181  llvm::make_early_inc_range(std::get<1>(tup).getUses())) {
3182  auto cmp2 = dyn_cast<arith::CmpIOp>(u.getOwner());
3183  if (!cmp2)
3184  continue;
3185  // For a binary operator 1-opIdx gets the other side.
3186  if (cmp2.getOperand(1 - opIdx) != cmp.getOperand(1 - opIdx))
3187  continue;
3188  bool samePredicate;
3189  if (cmp2.getPredicate() == cmp.getPredicate())
3190  samePredicate = true;
3191  else if (cmp2.getPredicate() ==
3192  arith::invertPredicate(cmp.getPredicate()))
3193  samePredicate = false;
3194  else
3195  continue;
3196 
3197  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(cmp2, samePredicate,
3198  1);
3199  changed = true;
3200  }
3201  }
3202  }
3203  return success(changed);
3204  }
3205 };
3206 
3207 struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
3209 
3210  LogicalResult matchAndRewrite(WhileOp op,
3211  PatternRewriter &rewriter) const override {
3212 
3213  if (!llvm::any_of(op.getBeforeArguments(),
3214  [](Value arg) { return arg.use_empty(); }))
3215  return failure();
3216 
3217  YieldOp yield = op.getYieldOp();
3218 
3219  // Collect results mapping, new terminator args and new result types.
3220  SmallVector<Value> newYields;
3221  SmallVector<Value> newInits;
3222  SmallVector<unsigned> argsToErase;
3223  for (const auto &it : llvm::enumerate(llvm::zip(
3224  op.getBeforeArguments(), yield.getOperands(), op.getInits()))) {
3225  Value beforeArg = std::get<0>(it.value());
3226  Value yieldValue = std::get<1>(it.value());
3227  Value initValue = std::get<2>(it.value());
3228  if (beforeArg.use_empty()) {
3229  argsToErase.push_back(it.index());
3230  } else {
3231  newYields.emplace_back(yieldValue);
3232  newInits.emplace_back(initValue);
3233  }
3234  }
3235 
3236  if (argsToErase.empty())
3237  return failure();
3238 
3239  rewriter.startRootUpdate(op);
3240  op.getBefore().front().eraseArguments(argsToErase);
3241  rewriter.finalizeRootUpdate(op);
3242 
3243  WhileOp replacement =
3244  rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInits);
3245  replacement.getBefore().takeBody(op.getBefore());
3246  replacement.getAfter().takeBody(op.getAfter());
3247  rewriter.replaceOp(op, replacement.getResults());
3248 
3249  rewriter.setInsertionPoint(yield);
3250  rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
3251  return success();
3252  }
3253 };
3254 } // namespace
3255 
3256 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
3257  MLIRContext *context) {
3258  results.add<RemoveLoopInvariantArgsFromBeforeBlock,
3259  RemoveLoopInvariantValueYielded, WhileConditionTruth,
3260  WhileCmpCond, WhileUnusedResult>(context);
3261 }
3262 
3263 //===----------------------------------------------------------------------===//
3264 // TableGen'd op method definitions
3265 //===----------------------------------------------------------------------===//
3266 
3267 #define GET_OP_CLASSES
3268 #include "mlir/Dialect/SCF/IR/SCFOps.cpp.inc"
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, BlockAndValueMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
Include the generated interface declarations.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
static std::string diag(llvm::Value &v)
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition: SCF.cpp:85
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:389
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Operation & back()
Definition: Block.h:143
This is a value defined by a result of an operation.
Definition: Value.h:425
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:43
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:310
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:148
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:345
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:496
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block...
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:523
void push_back(Block *block)
Definition: Region.h:61
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
std::vector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition: SCF.h:63
unsigned getNumOperands()
Definition: Operation.h:263
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b)
Return true if ops a and b (or their ancestors) are in mutually exclusive regions/blocks of an IfOp...
Definition: SCF.cpp:1283
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:206
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
This is the representation of an operand reference.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
Operation & front()
Definition: Block.h:144
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:200
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:248
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:484
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:312
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:169
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:359
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:81
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:408
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, llvm::unique_function< bool(OpOperand &) const > functor)
This method replaces the uses of the results of op with the values in newValues when the provided fun...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool empty()
Definition: Region.h:60
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
void addOperands(ValueRange newOperands)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:198
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:348
unsigned getNumArguments()
Definition: Block.h:119
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
U dyn_cast() const
Definition: Value.h:100
type_range getTypes() const
Definition: ValueRange.cpp:44
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:43
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:437
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
virtual ParseResult parseRParen()=0
Parse a ) token.
This is the interface that must be implemented by the dialects of operations to be inlined...
Definition: InliningUtils.h:40
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Block & back()
Definition: Region.h:64
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void addTypes(ArrayRef< Type > newTypes)
IntegerType getI1Type()
Definition: Builders.cpp:50
ParallelOp getParallelForInductionVarOwner(Value val)
Returns the parallel loop parent of an induction variable.
Definition: SCF.cpp:2292
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:114
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This represents an operation in an abstracted form, suitable for use with the builder APIs...
void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues=llvm::None)
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
ForeachThreadOp getForeachThreadOpThreadIndexOwner(Value val)
Returns the ForeachThreadOp parent of a thread index variable.
Definition: SCF.cpp:1205
Parens surrounding zero or more operands.
BlockArgListType getArguments()
Definition: Block.h:76
This class represents an argument of a Block.
Definition: Value.h:300
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:76
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
Definition: PatternMatch.h:489
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:139
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:320
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:203
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition: SCF.cpp:505
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:378
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:295
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
Type getType() const
Return the type of this value.
Definition: Value.h:118
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:332
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'. ...
Definition: Block.h:184
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for IfOp builders. Inserts a yield without arguments.
Definition: SCF.cpp:75
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, BlockAndValueMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent"...
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Block * lookupOrDefault(Block *from) const
Lookup a mapped value within the map.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2, <...>) where 'inner' values are assumed to be region arguments and 'outer' values are regular SSA values.
Definition: SCF.cpp:370
This class represents an operand of an operation.
Definition: Value.h:251
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
U cast() const
Definition: Value.h:108
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:383
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseEqual()=0
Parse a = token.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
Definition: Builders.cpp:381
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:508
bool isa() const
Definition: Types.h:254
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
This class represents success/failure for parsing-like operations that find it important to chain tog...
void setAttrs(DictionaryAttr newAttrs)
Set the attribute dictionary on this operation.
Definition: Operation.h:362
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block 'source' into the end of block 'dest'.
This class helps build Operations.
Definition: Builders.h:193
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override
Definition: SCF.cpp:157
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.