MLIR  21.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the control flow operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 
18 #include "llvm/Support/InterleavedRange.h"
19 
20 #include "SPIRVOpUtils.h"
21 #include "SPIRVParsingUtils.h"
22 
23 using namespace mlir::spirv::AttrNames;
24 
25 namespace mlir::spirv {
26 
27 /// Parses Function, Selection and Loop control attributes. If no control is
28 /// specified, "None" is used as a default.
29 template <typename EnumAttrClass, typename EnumClass>
30 static ParseResult
32  StringRef attrName = spirv::attributeName<EnumClass>()) {
33  if (succeeded(parser.parseOptionalKeyword(kControl))) {
34  EnumClass control;
35  if (parser.parseLParen() ||
36  spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
37  parser.parseRParen())
38  return failure();
39  return success();
40  }
41  // Set control to "None" otherwise.
42  Builder builder = parser.getBuilder();
43  state.addAttribute(attrName,
44  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
45  return success();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // spirv.BranchOp
50 //===----------------------------------------------------------------------===//
51 
53  assert(index == 0 && "invalid successor index");
54  return SuccessorOperands(0, getTargetOperandsMutable());
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // spirv.BranchConditionalOp
59 //===----------------------------------------------------------------------===//
60 
61 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
62  assert(index < 2 && "invalid successor index");
63  return SuccessorOperands(index == kTrueIndex
64  ? getTrueTargetOperandsMutable()
65  : getFalseTargetOperandsMutable());
66 }
67 
68 ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
69  OperationState &result) {
70  auto &builder = parser.getBuilder();
71  OpAsmParser::UnresolvedOperand condInfo;
72  Block *dest;
73 
74  // Parse the condition.
75  Type boolTy = builder.getI1Type();
76  if (parser.parseOperand(condInfo) ||
77  parser.resolveOperand(condInfo, boolTy, result.operands))
78  return failure();
79 
80  // Parse the optional branch weights.
81  if (succeeded(parser.parseOptionalLSquare())) {
82  IntegerAttr trueWeight, falseWeight;
83  NamedAttrList weights;
84 
85  auto i32Type = builder.getIntegerType(32);
86  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
87  parser.parseComma() ||
88  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
89  parser.parseRSquare())
90  return failure();
91 
92  StringAttr branchWeightsAttrName =
93  BranchConditionalOp::getBranchWeightsAttrName(result.name);
94  result.addAttribute(branchWeightsAttrName,
95  builder.getArrayAttr({trueWeight, falseWeight}));
96  }
97 
98  // Parse the true branch.
99  SmallVector<Value, 4> trueOperands;
100  if (parser.parseComma() ||
101  parser.parseSuccessorAndUseList(dest, trueOperands))
102  return failure();
103  result.addSuccessors(dest);
104  result.addOperands(trueOperands);
105 
106  // Parse the false branch.
107  SmallVector<Value, 4> falseOperands;
108  if (parser.parseComma() ||
109  parser.parseSuccessorAndUseList(dest, falseOperands))
110  return failure();
111  result.addSuccessors(dest);
112  result.addOperands(falseOperands);
113  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
114  builder.getDenseI32ArrayAttr(
115  {1, static_cast<int32_t>(trueOperands.size()),
116  static_cast<int32_t>(falseOperands.size())}));
117 
118  return success();
119 }
120 
121 void BranchConditionalOp::print(OpAsmPrinter &printer) {
122  printer << ' ' << getCondition();
123 
124  if (std::optional<ArrayAttr> weights = getBranchWeights()) {
125  printer << ' '
126  << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
127  }
128 
129  printer << ", ";
130  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
131  printer << ", ";
132  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
133 }
134 
135 LogicalResult BranchConditionalOp::verify() {
136  if (auto weights = getBranchWeights()) {
137  if (weights->getValue().size() != 2) {
138  return emitOpError("must have exactly two branch weights");
139  }
140  if (llvm::all_of(*weights, [](Attribute attr) {
141  return llvm::cast<IntegerAttr>(attr).getValue().isZero();
142  }))
143  return emitOpError("branch weights cannot both be zero");
144  }
145 
146  return success();
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // spirv.FunctionCall
151 //===----------------------------------------------------------------------===//
152 
153 LogicalResult FunctionCallOp::verify() {
154  auto fnName = getCalleeAttr();
155 
156  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
157  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
158  if (!funcOp) {
159  return emitOpError("callee function '")
160  << fnName.getValue() << "' not found in nearest symbol table";
161  }
162 
163  auto functionType = funcOp.getFunctionType();
164 
165  if (getNumResults() > 1) {
166  return emitOpError(
167  "expected callee function to have 0 or 1 result, but provided ")
168  << getNumResults();
169  }
170 
171  if (functionType.getNumInputs() != getNumOperands()) {
172  return emitOpError("has incorrect number of operands for callee: expected ")
173  << functionType.getNumInputs() << ", but provided "
174  << getNumOperands();
175  }
176 
177  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
178  if (getOperand(i).getType() != functionType.getInput(i)) {
179  return emitOpError("operand type mismatch: expected operand type ")
180  << functionType.getInput(i) << ", but provided "
181  << getOperand(i).getType() << " for operand number " << i;
182  }
183  }
184 
185  if (functionType.getNumResults() != getNumResults()) {
186  return emitOpError(
187  "has incorrect number of results has for callee: expected ")
188  << functionType.getNumResults() << ", but provided "
189  << getNumResults();
190  }
191 
192  if (getNumResults() &&
193  (getResult(0).getType() != functionType.getResult(0))) {
194  return emitOpError("result type mismatch: expected ")
195  << functionType.getResult(0) << ", but provided "
196  << getResult(0).getType();
197  }
198 
199  return success();
200 }
201 
202 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
203  return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
204 }
205 
206 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
207  (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
208 }
209 
210 Operation::operand_range FunctionCallOp::getArgOperands() {
211  return getArguments();
212 }
213 
214 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
215  return getArgumentsMutable();
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // spirv.mlir.loop
220 //===----------------------------------------------------------------------===//
221 
222 void LoopOp::build(OpBuilder &builder, OperationState &state) {
223  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
225  state.addRegion();
226 }
227 
228 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
229  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
230  result))
231  return failure();
232  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
233 }
234 
235 void LoopOp::print(OpAsmPrinter &printer) {
236  auto control = getLoopControl();
237  if (control != spirv::LoopControl::None)
238  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
239  printer << ' ';
240  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
241  /*printBlockTerminators=*/true);
242 }
243 
244 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
245 /// given `dstBlock`.
246 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
247  // Check that there is only one op in the `srcBlock`.
248  if (!llvm::hasSingleElement(srcBlock))
249  return false;
250 
251  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
252  return branchOp && branchOp.getSuccessor() == &dstBlock;
253 }
254 
255 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
256 static bool isMergeBlock(Block &block) {
257  return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.front());
258 }
259 
260 /// Returns true if a `spirv.mlir.merge` op outside the merge block.
261 static bool hasOtherMerge(Region &region) {
262  return !region.empty() && llvm::any_of(region.getOps(), [&](Operation &op) {
263  return isa<spirv::MergeOp>(op) && op.getBlock() != &region.back();
264  });
265 }
266 
267 LogicalResult LoopOp::verifyRegions() {
268  auto *op = getOperation();
269 
270  // We need to verify that the blocks follow the following layout:
271  //
272  // +-------------+
273  // | entry block |
274  // +-------------+
275  // |
276  // v
277  // +-------------+
278  // | loop header | <-----+
279  // +-------------+ |
280  // |
281  // ... |
282  // \ | / |
283  // v |
284  // +---------------+ |
285  // | loop continue | -----+
286  // +---------------+
287  //
288  // ...
289  // \ | /
290  // v
291  // +-------------+
292  // | merge block |
293  // +-------------+
294 
295  auto &region = op->getRegion(0);
296  // Allow empty region as a degenerated case, which can come from
297  // optimizations.
298  if (region.empty())
299  return success();
300 
301  // The last block is the merge block.
302  Block &merge = region.back();
303  if (!isMergeBlock(merge))
304  return emitOpError("last block must be the merge block with only one "
305  "'spirv.mlir.merge' op");
306  if (hasOtherMerge(region))
307  return emitOpError(
308  "should not have 'spirv.mlir.merge' op outside the merge block");
309 
310  if (region.hasOneBlock())
311  return emitOpError(
312  "must have an entry block branching to the loop header block");
313  // The first block is the entry block.
314  Block &entry = region.front();
315 
316  if (std::next(region.begin(), 2) == region.end())
317  return emitOpError(
318  "must have a loop header block branched from the entry block");
319  // The second block is the loop header block.
320  Block &header = *std::next(region.begin(), 1);
321 
322  if (!hasOneBranchOpTo(entry, header))
323  return emitOpError(
324  "entry block must only have one 'spirv.Branch' op to the second block");
325 
326  if (std::next(region.begin(), 3) == region.end())
327  return emitOpError(
328  "requires a loop continue block branching to the loop header block");
329  // The second to last block is the loop continue block.
330  Block &cont = *std::prev(region.end(), 2);
331 
332  // Make sure that we have a branch from the loop continue block to the loop
333  // header block.
334  if (llvm::none_of(
335  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
336  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
337  return emitOpError("second to last block must be the loop continue "
338  "block that branches to the loop header block");
339 
340  // Make sure that no other blocks (except the entry and loop continue block)
341  // branches to the loop header block.
342  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
343  std::prev(region.end(), 2))) {
344  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
345  if (block.getSuccessor(i) == &header) {
346  return emitOpError("can only have the entry and loop continue "
347  "block branching to the loop header block");
348  }
349  }
350  }
351 
352  return success();
353 }
354 
355 Block *LoopOp::getEntryBlock() {
356  assert(!getBody().empty() && "op region should not be empty!");
357  return &getBody().front();
358 }
359 
360 Block *LoopOp::getHeaderBlock() {
361  assert(!getBody().empty() && "op region should not be empty!");
362  // The second block is the loop header block.
363  return &*std::next(getBody().begin());
364 }
365 
366 Block *LoopOp::getContinueBlock() {
367  assert(!getBody().empty() && "op region should not be empty!");
368  // The second to last block is the loop continue block.
369  return &*std::prev(getBody().end(), 2);
370 }
371 
372 Block *LoopOp::getMergeBlock() {
373  assert(!getBody().empty() && "op region should not be empty!");
374  // The last block is the loop merge block.
375  return &getBody().back();
376 }
377 
378 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
379  assert(getBody().empty() && "entry and merge block already exist");
380  OpBuilder::InsertionGuard g(builder);
381  builder.createBlock(&getBody());
382  builder.createBlock(&getBody());
383 
384  // Add a spirv.mlir.merge op into the merge block.
385  builder.create<spirv::MergeOp>(getLoc());
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // spirv.Return
390 //===----------------------------------------------------------------------===//
391 
392 LogicalResult ReturnOp::verify() {
393  // Verification is performed in spirv.func op.
394  return success();
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // spirv.ReturnValue
399 //===----------------------------------------------------------------------===//
400 
401 LogicalResult ReturnValueOp::verify() {
402  // Verification is performed in spirv.func op.
403  return success();
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // spirv.Select
408 //===----------------------------------------------------------------------===//
409 
410 LogicalResult SelectOp::verify() {
411  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
412  auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
413  if (!resultVectorTy) {
414  return emitOpError("result expected to be of vector type when "
415  "condition is of vector type");
416  }
417  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
418  return emitOpError("result should have the same number of elements as "
419  "the condition when condition is of vector type");
420  }
421  }
422  return success();
423 }
424 
425 // Custom availability implementation is needed for spirv.Select given the
426 // syntax changes starting v1.4.
427 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
428  return {};
429 }
430 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
431  return {};
432 }
433 std::optional<spirv::Version> SelectOp::getMinVersion() {
434  // Per the spec, "Before version 1.4, results are only computed per
435  // component."
436  if (isa<spirv::ScalarType>(getCondition().getType()) &&
437  isa<spirv::CompositeType>(getType()))
438  return Version::V_1_4;
439 
440  return Version::V_1_0;
441 }
442 std::optional<spirv::Version> SelectOp::getMaxVersion() {
443  return Version::V_1_6;
444 }
445 
446 //===----------------------------------------------------------------------===//
447 // spirv.mlir.selection
448 //===----------------------------------------------------------------------===//
449 
450 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
451  if (parseControlAttribute<spirv::SelectionControlAttr,
452  spirv::SelectionControl>(parser, result))
453  return failure();
454 
455  if (succeeded(parser.parseOptionalArrow()))
456  if (parser.parseTypeList(result.types))
457  return failure();
458 
459  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
460 }
461 
462 void SelectionOp::print(OpAsmPrinter &printer) {
463  auto control = getSelectionControl();
464  if (control != spirv::SelectionControl::None)
465  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
466  if (getNumResults() > 0) {
467  printer << " -> ";
468  printer << getResultTypes();
469  }
470  printer << ' ';
471  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
472  /*printBlockTerminators=*/true);
473 }
474 
475 LogicalResult SelectionOp::verifyRegions() {
476  auto *op = getOperation();
477 
478  // We need to verify that the blocks follow the following layout:
479  //
480  // +--------------+
481  // | header block |
482  // +--------------+
483  // / | \
484  // ...
485  //
486  //
487  // +---------+ +---------+ +---------+
488  // | case #0 | | case #1 | | case #2 | ...
489  // +---------+ +---------+ +---------+
490  //
491  //
492  // ...
493  // \ | /
494  // v
495  // +-------------+
496  // | merge block |
497  // +-------------+
498 
499  auto &region = op->getRegion(0);
500  // Allow empty region as a degenerated case, which can come from
501  // optimizations.
502  if (region.empty())
503  return success();
504 
505  // The last block is the merge block.
506  if (!isMergeBlock(region.back()))
507  return emitOpError("last block must be the merge block with only one "
508  "'spirv.mlir.merge' op");
509  if (hasOtherMerge(region))
510  return emitOpError(
511  "should not have 'spirv.mlir.merge' op outside the merge block");
512 
513  if (region.hasOneBlock())
514  return emitOpError("must have a selection header block");
515 
516  return success();
517 }
518 
519 Block *SelectionOp::getHeaderBlock() {
520  assert(!getBody().empty() && "op region should not be empty!");
521  // The first block is the loop header block.
522  return &getBody().front();
523 }
524 
525 Block *SelectionOp::getMergeBlock() {
526  assert(!getBody().empty() && "op region should not be empty!");
527  // The last block is the loop merge block.
528  return &getBody().back();
529 }
530 
531 void SelectionOp::addMergeBlock(OpBuilder &builder) {
532  assert(getBody().empty() && "entry and merge block already exist");
533  OpBuilder::InsertionGuard guard(builder);
534  builder.createBlock(&getBody());
535 
536  // Add a spirv.mlir.merge op into the merge block.
537  builder.create<spirv::MergeOp>(getLoc());
538 }
539 
540 SelectionOp
541 SelectionOp::createIfThen(Location loc, Value condition,
542  function_ref<void(OpBuilder &builder)> thenBody,
543  OpBuilder &builder) {
544  auto selectionOp =
545  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
546 
547  selectionOp.addMergeBlock(builder);
548  Block *mergeBlock = selectionOp.getMergeBlock();
549  Block *thenBlock = nullptr;
550 
551  // Build the "then" block.
552  {
553  OpBuilder::InsertionGuard guard(builder);
554  thenBlock = builder.createBlock(mergeBlock);
555  thenBody(builder);
556  builder.create<spirv::BranchOp>(loc, mergeBlock);
557  }
558 
559  // Build the header block.
560  {
561  OpBuilder::InsertionGuard guard(builder);
562  builder.createBlock(thenBlock);
563  builder.create<spirv::BranchConditionalOp>(
564  loc, condition, thenBlock,
565  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
566  /*falseArguments=*/ArrayRef<Value>());
567  }
568 
569  return selectionOp;
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // spirv.Unreachable
574 //===----------------------------------------------------------------------===//
575 
576 LogicalResult spirv::UnreachableOp::verify() {
577  auto *block = (*this)->getBlock();
578  // Fast track: if this is in entry block, its invalid. Otherwise, if no
579  // predecessors, it's valid.
580  if (block->isEntryBlock())
581  return emitOpError("cannot be used in reachable block");
582  if (block->hasNoPredecessors())
583  return success();
584 
585  // TODO: further verification needs to analyze reachability from
586  // the entry block.
587 
588  return success();
589 }
590 
591 } // namespace mlir::spirv
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
@ None
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumSuccessors()
Definition: Block.cpp:257
Operation & back()
Definition: Block.h:152
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperandRange operand_range
Definition: Operation.h:371
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
constexpr char kControl[]
static bool hasOtherMerge(Region &region)
Returns true if a spirv.mlir.merge op exists outside the merge block.
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
This represents an operation in an abstracted form, suitable for use with the builder APIs.