MLIR  22.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the control flow operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 
18 #include "llvm/Support/InterleavedRange.h"
19 
20 #include "SPIRVOpUtils.h"
21 #include "SPIRVParsingUtils.h"
22 
23 using namespace mlir::spirv::AttrNames;
24 
25 namespace mlir::spirv {
26 
27 /// Parses Function, Selection and Loop control attributes. If no control is
28 /// specified, "None" is used as a default.
29 template <typename EnumAttrClass, typename EnumClass>
30 static ParseResult
32  StringRef attrName = spirv::attributeName<EnumClass>()) {
33  if (succeeded(parser.parseOptionalKeyword(kControl))) {
34  EnumClass control;
35  if (parser.parseLParen() ||
36  spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
37  parser.parseRParen())
38  return failure();
39  return success();
40  }
41  // Set control to "None" otherwise.
42  Builder builder = parser.getBuilder();
43  state.addAttribute(attrName,
44  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
45  return success();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // spirv.BranchOp
50 //===----------------------------------------------------------------------===//
51 
53  assert(index == 0 && "invalid successor index");
54  return SuccessorOperands(0, getTargetOperandsMutable());
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // spirv.BranchConditionalOp
59 //===----------------------------------------------------------------------===//
60 
61 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
62  assert(index < 2 && "invalid successor index");
63  return SuccessorOperands(index == kTrueIndex
64  ? getTrueTargetOperandsMutable()
65  : getFalseTargetOperandsMutable());
66 }
67 
68 ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
69  OperationState &result) {
70  auto &builder = parser.getBuilder();
71  OpAsmParser::UnresolvedOperand condInfo;
72  Block *dest;
73 
74  // Parse the condition.
75  Type boolTy = builder.getI1Type();
76  if (parser.parseOperand(condInfo) ||
77  parser.resolveOperand(condInfo, boolTy, result.operands))
78  return failure();
79 
80  // Parse the optional branch weights.
81  if (succeeded(parser.parseOptionalLSquare())) {
82  IntegerAttr trueWeight, falseWeight;
83  NamedAttrList weights;
84 
85  auto i32Type = builder.getIntegerType(32);
86  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
87  parser.parseComma() ||
88  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
89  parser.parseRSquare())
90  return failure();
91 
92  StringAttr branchWeightsAttrName =
93  BranchConditionalOp::getBranchWeightsAttrName(result.name);
94  result.addAttribute(branchWeightsAttrName,
95  builder.getArrayAttr({trueWeight, falseWeight}));
96  }
97 
98  // Parse the true branch.
99  SmallVector<Value, 4> trueOperands;
100  if (parser.parseComma() ||
101  parser.parseSuccessorAndUseList(dest, trueOperands))
102  return failure();
103  result.addSuccessors(dest);
104  result.addOperands(trueOperands);
105 
106  // Parse the false branch.
107  SmallVector<Value, 4> falseOperands;
108  if (parser.parseComma() ||
109  parser.parseSuccessorAndUseList(dest, falseOperands))
110  return failure();
111  result.addSuccessors(dest);
112  result.addOperands(falseOperands);
113  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
114  builder.getDenseI32ArrayAttr(
115  {1, static_cast<int32_t>(trueOperands.size()),
116  static_cast<int32_t>(falseOperands.size())}));
117 
118  return success();
119 }
120 
121 void BranchConditionalOp::print(OpAsmPrinter &printer) {
122  printer << ' ' << getCondition();
123 
124  if (std::optional<ArrayAttr> weights = getBranchWeights()) {
125  printer << ' '
126  << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
127  }
128 
129  printer << ", ";
130  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
131  printer << ", ";
132  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
133 }
134 
135 LogicalResult BranchConditionalOp::verify() {
136  if (auto weights = getBranchWeights()) {
137  if (weights->getValue().size() != 2) {
138  return emitOpError("must have exactly two branch weights");
139  }
140  if (llvm::all_of(*weights, [](Attribute attr) {
141  return llvm::cast<IntegerAttr>(attr).getValue().isZero();
142  }))
143  return emitOpError("branch weights cannot both be zero");
144  }
145 
146  return success();
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // spirv.FunctionCall
151 //===----------------------------------------------------------------------===//
152 
153 LogicalResult FunctionCallOp::verify() {
154  if (getNumResults() > 1) {
155  return emitOpError(
156  "expected callee function to have 0 or 1 result, but provided ")
157  << getNumResults();
158  }
159  return success();
160 }
161 
162 LogicalResult
163 FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
164  auto fnName = getCalleeAttr();
165 
166  auto funcOp =
167  symbolTable.lookupNearestSymbolFrom<spirv::FuncOp>(*this, fnName);
168  if (!funcOp) {
169  return emitOpError("callee function '")
170  << fnName.getValue() << "' not found in nearest symbol table";
171  }
172 
173  auto functionType = funcOp.getFunctionType();
174 
175  if (functionType.getNumInputs() != getNumOperands()) {
176  return emitOpError("has incorrect number of operands for callee: expected ")
177  << functionType.getNumInputs() << ", but provided "
178  << getNumOperands();
179  }
180 
181  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
182  if (getOperand(i).getType() != functionType.getInput(i)) {
183  return emitOpError("operand type mismatch: expected operand type ")
184  << functionType.getInput(i) << ", but provided "
185  << getOperand(i).getType() << " for operand number " << i;
186  }
187  }
188 
189  if (functionType.getNumResults() != getNumResults()) {
190  return emitOpError(
191  "has incorrect number of results has for callee: expected ")
192  << functionType.getNumResults() << ", but provided "
193  << getNumResults();
194  }
195 
196  if (getNumResults() &&
197  (getResult(0).getType() != functionType.getResult(0))) {
198  return emitOpError("result type mismatch: expected ")
199  << functionType.getResult(0) << ", but provided "
200  << getResult(0).getType();
201  }
202 
203  return success();
204 }
205 
206 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
207  return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
208 }
209 
210 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
211  (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
212 }
213 
214 Operation::operand_range FunctionCallOp::getArgOperands() {
215  return getArguments();
216 }
217 
218 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
219  return getArgumentsMutable();
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // spirv.mlir.loop
224 //===----------------------------------------------------------------------===//
225 
226 void LoopOp::build(OpBuilder &builder, OperationState &state) {
227  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
229  state.addRegion();
230 }
231 
232 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
233  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
234  result))
235  return failure();
236 
237  if (succeeded(parser.parseOptionalArrow()))
238  if (parser.parseTypeList(result.types))
239  return failure();
240 
241  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
242 }
243 
244 void LoopOp::print(OpAsmPrinter &printer) {
245  auto control = getLoopControl();
246  if (control != spirv::LoopControl::None)
247  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
248  if (getNumResults() > 0) {
249  printer << " -> ";
250  printer << getResultTypes();
251  }
252  printer << ' ';
253  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
254  /*printBlockTerminators=*/true);
255 }
256 
257 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
258 /// given `dstBlock`.
259 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
260  // Check that there is only one op in the `srcBlock`.
261  if (!llvm::hasSingleElement(srcBlock))
262  return false;
263 
264  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
265  return branchOp && branchOp.getSuccessor() == &dstBlock;
266 }
267 
268 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
269 static bool isMergeBlock(Block &block) {
270  return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.front());
271 }
272 
273 /// Returns true if a `spirv.mlir.merge` op outside the merge block.
274 static bool hasOtherMerge(Region &region) {
275  return !region.empty() && llvm::any_of(region.getOps(), [&](Operation &op) {
276  return isa<spirv::MergeOp>(op) && op.getBlock() != &region.back();
277  });
278 }
279 
280 LogicalResult LoopOp::verifyRegions() {
281  auto *op = getOperation();
282 
283  // We need to verify that the blocks follow the following layout:
284  //
285  // +-------------+
286  // | entry block |
287  // +-------------+
288  // |
289  // v
290  // +-------------+
291  // | loop header | <-----+
292  // +-------------+ |
293  // |
294  // ... |
295  // \ | / |
296  // v |
297  // +---------------+ |
298  // | loop continue | -----+
299  // +---------------+
300  //
301  // ...
302  // \ | /
303  // v
304  // +-------------+
305  // | merge block |
306  // +-------------+
307 
308  auto &region = op->getRegion(0);
309  // Allow empty region as a degenerated case, which can come from
310  // optimizations.
311  if (region.empty())
312  return success();
313 
314  // The last block is the merge block.
315  Block &merge = region.back();
316  if (!isMergeBlock(merge))
317  return emitOpError("last block must be the merge block with only one "
318  "'spirv.mlir.merge' op");
319  if (hasOtherMerge(region))
320  return emitOpError(
321  "should not have 'spirv.mlir.merge' op outside the merge block");
322 
323  if (region.hasOneBlock())
324  return emitOpError(
325  "must have an entry block branching to the loop header block");
326  // The first block is the entry block.
327  Block &entry = region.front();
328 
329  if (std::next(region.begin(), 2) == region.end())
330  return emitOpError(
331  "must have a loop header block branched from the entry block");
332  // The second block is the loop header block.
333  Block &header = *std::next(region.begin(), 1);
334 
335  if (!hasOneBranchOpTo(entry, header))
336  return emitOpError(
337  "entry block must only have one 'spirv.Branch' op to the second block");
338 
339  if (std::next(region.begin(), 3) == region.end())
340  return emitOpError(
341  "requires a loop continue block branching to the loop header block");
342  // The second to last block is the loop continue block.
343  Block &cont = *std::prev(region.end(), 2);
344 
345  // Make sure that we have a branch from the loop continue block to the loop
346  // header block.
347  if (llvm::none_of(
348  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
349  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
350  return emitOpError("second to last block must be the loop continue "
351  "block that branches to the loop header block");
352 
353  // Make sure that no other blocks (except the entry and loop continue block)
354  // branches to the loop header block.
355  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
356  std::prev(region.end(), 2))) {
357  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
358  if (block.getSuccessor(i) == &header) {
359  return emitOpError("can only have the entry and loop continue "
360  "block branching to the loop header block");
361  }
362  }
363  }
364 
365  return success();
366 }
367 
368 Block *LoopOp::getEntryBlock() {
369  assert(!getBody().empty() && "op region should not be empty!");
370  return &getBody().front();
371 }
372 
373 Block *LoopOp::getHeaderBlock() {
374  assert(!getBody().empty() && "op region should not be empty!");
375  // The second block is the loop header block.
376  return &*std::next(getBody().begin());
377 }
378 
379 Block *LoopOp::getContinueBlock() {
380  assert(!getBody().empty() && "op region should not be empty!");
381  // The second to last block is the loop continue block.
382  return &*std::prev(getBody().end(), 2);
383 }
384 
385 Block *LoopOp::getMergeBlock() {
386  assert(!getBody().empty() && "op region should not be empty!");
387  // The last block is the loop merge block.
388  return &getBody().back();
389 }
390 
391 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
392  assert(getBody().empty() && "entry and merge block already exist");
393  OpBuilder::InsertionGuard g(builder);
394  builder.createBlock(&getBody());
395  builder.createBlock(&getBody());
396 
397  // Add a spirv.mlir.merge op into the merge block.
398  spirv::MergeOp::create(builder, getLoc());
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // spirv.Return
403 //===----------------------------------------------------------------------===//
404 
405 LogicalResult ReturnOp::verify() {
406  // Verification is performed in spirv.func op.
407  return success();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // spirv.ReturnValue
412 //===----------------------------------------------------------------------===//
413 
414 LogicalResult ReturnValueOp::verify() {
415  // Verification is performed in spirv.func op.
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // spirv.Select
421 //===----------------------------------------------------------------------===//
422 
423 LogicalResult SelectOp::verify() {
424  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
425  auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
426  if (!resultVectorTy) {
427  return emitOpError("result expected to be of vector type when "
428  "condition is of vector type");
429  }
430  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
431  return emitOpError("result should have the same number of elements as "
432  "the condition when condition is of vector type");
433  }
434  }
435  return success();
436 }
437 
438 // Custom availability implementation is needed for spirv.Select given the
439 // syntax changes starting v1.4.
440 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
441  return {};
442 }
443 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
444  return {};
445 }
446 std::optional<spirv::Version> SelectOp::getMinVersion() {
447  // Per the spec, "Before version 1.4, results are only computed per
448  // component."
449  if (isa<spirv::ScalarType>(getCondition().getType()) &&
450  isa<spirv::CompositeType>(getType()))
451  return Version::V_1_4;
452 
453  return Version::V_1_0;
454 }
455 std::optional<spirv::Version> SelectOp::getMaxVersion() {
456  return Version::V_1_6;
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // spirv.mlir.selection
461 //===----------------------------------------------------------------------===//
462 
463 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
464  if (parseControlAttribute<spirv::SelectionControlAttr,
465  spirv::SelectionControl>(parser, result))
466  return failure();
467 
468  if (succeeded(parser.parseOptionalArrow()))
469  if (parser.parseTypeList(result.types))
470  return failure();
471 
472  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
473 }
474 
475 void SelectionOp::print(OpAsmPrinter &printer) {
476  auto control = getSelectionControl();
477  if (control != spirv::SelectionControl::None)
478  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
479  if (getNumResults() > 0) {
480  printer << " -> ";
481  printer << getResultTypes();
482  }
483  printer << ' ';
484  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
485  /*printBlockTerminators=*/true);
486 }
487 
488 LogicalResult SelectionOp::verifyRegions() {
489  auto *op = getOperation();
490 
491  // We need to verify that the blocks follow the following layout:
492  //
493  // +--------------+
494  // | header block |
495  // +--------------+
496  // / | \
497  // ...
498  //
499  //
500  // +---------+ +---------+ +---------+
501  // | case #0 | | case #1 | | case #2 | ...
502  // +---------+ +---------+ +---------+
503  //
504  //
505  // ...
506  // \ | /
507  // v
508  // +-------------+
509  // | merge block |
510  // +-------------+
511 
512  auto &region = op->getRegion(0);
513  // Allow empty region as a degenerated case, which can come from
514  // optimizations.
515  if (region.empty())
516  return success();
517 
518  // The last block is the merge block.
519  if (!isMergeBlock(region.back()))
520  return emitOpError("last block must be the merge block with only one "
521  "'spirv.mlir.merge' op");
522  if (hasOtherMerge(region))
523  return emitOpError(
524  "should not have 'spirv.mlir.merge' op outside the merge block");
525 
526  if (region.hasOneBlock())
527  return emitOpError("must have a selection header block");
528 
529  return success();
530 }
531 
532 Block *SelectionOp::getHeaderBlock() {
533  assert(!getBody().empty() && "op region should not be empty!");
534  // The first block is the loop header block.
535  return &getBody().front();
536 }
537 
538 Block *SelectionOp::getMergeBlock() {
539  assert(!getBody().empty() && "op region should not be empty!");
540  // The last block is the loop merge block.
541  return &getBody().back();
542 }
543 
544 void SelectionOp::addMergeBlock(OpBuilder &builder) {
545  assert(getBody().empty() && "entry and merge block already exist");
546  OpBuilder::InsertionGuard guard(builder);
547  builder.createBlock(&getBody());
548 
549  // Add a spirv.mlir.merge op into the merge block.
550  spirv::MergeOp::create(builder, getLoc());
551 }
552 
553 SelectionOp
554 SelectionOp::createIfThen(Location loc, Value condition,
555  function_ref<void(OpBuilder &builder)> thenBody,
556  OpBuilder &builder) {
557  auto selectionOp =
558  spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None);
559 
560  selectionOp.addMergeBlock(builder);
561  Block *mergeBlock = selectionOp.getMergeBlock();
562  Block *thenBlock = nullptr;
563 
564  // Build the "then" block.
565  {
566  OpBuilder::InsertionGuard guard(builder);
567  thenBlock = builder.createBlock(mergeBlock);
568  thenBody(builder);
569  spirv::BranchOp::create(builder, loc, mergeBlock);
570  }
571 
572  // Build the header block.
573  {
574  OpBuilder::InsertionGuard guard(builder);
575  builder.createBlock(thenBlock);
576  spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
577  /*trueArguments=*/ArrayRef<Value>(),
578  mergeBlock,
579  /*falseArguments=*/ArrayRef<Value>());
580  }
581 
582  return selectionOp;
583 }
584 
585 //===----------------------------------------------------------------------===//
586 // spirv.Unreachable
587 //===----------------------------------------------------------------------===//
588 
589 LogicalResult spirv::UnreachableOp::verify() {
590  auto *block = (*this)->getBlock();
591  // Fast track: if this is in entry block, its invalid. Otherwise, if no
592  // predecessors, it's valid.
593  if (block->isEntryBlock())
594  return emitOpError("cannot be used in reachable block");
595  if (block->hasNoPredecessors())
596  return success();
597 
598  // TODO: further verification needs to analyze reachability from
599  // the entry block.
600 
601  return success();
602 }
603 
604 } // namespace mlir::spirv
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:140
@ None
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumSuccessors()
Definition: Block.cpp:265
Operation & back()
Definition: Block.h:152
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperandRange operand_range
Definition: Operation.h:371
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
This class models how operands are forwarded to block arguments in control flow.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
constexpr char kControl[]
static bool hasOtherMerge(Region &region)
Returns true if a spirv.mlir.merge op outside the merge block.
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.