MLIR  21.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the control flow operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 
18 #include "llvm/Support/InterleavedRange.h"
19 
20 #include "SPIRVOpUtils.h"
21 #include "SPIRVParsingUtils.h"
22 
23 using namespace mlir::spirv::AttrNames;
24 
25 namespace mlir::spirv {
26 
27 /// Parses Function, Selection and Loop control attributes. If no control is
28 /// specified, "None" is used as a default.
29 template <typename EnumAttrClass, typename EnumClass>
30 static ParseResult
32  StringRef attrName = spirv::attributeName<EnumClass>()) {
33  if (succeeded(parser.parseOptionalKeyword(kControl))) {
34  EnumClass control;
35  if (parser.parseLParen() ||
36  spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
37  parser.parseRParen())
38  return failure();
39  return success();
40  }
41  // Set control to "None" otherwise.
42  Builder builder = parser.getBuilder();
43  state.addAttribute(attrName,
44  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
45  return success();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // spirv.BranchOp
50 //===----------------------------------------------------------------------===//
51 
53  assert(index == 0 && "invalid successor index");
54  return SuccessorOperands(0, getTargetOperandsMutable());
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // spirv.BranchConditionalOp
59 //===----------------------------------------------------------------------===//
60 
61 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
62  assert(index < 2 && "invalid successor index");
63  return SuccessorOperands(index == kTrueIndex
64  ? getTrueTargetOperandsMutable()
65  : getFalseTargetOperandsMutable());
66 }
67 
68 ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
69  OperationState &result) {
70  auto &builder = parser.getBuilder();
71  OpAsmParser::UnresolvedOperand condInfo;
72  Block *dest;
73 
74  // Parse the condition.
75  Type boolTy = builder.getI1Type();
76  if (parser.parseOperand(condInfo) ||
77  parser.resolveOperand(condInfo, boolTy, result.operands))
78  return failure();
79 
80  // Parse the optional branch weights.
81  if (succeeded(parser.parseOptionalLSquare())) {
82  IntegerAttr trueWeight, falseWeight;
83  NamedAttrList weights;
84 
85  auto i32Type = builder.getIntegerType(32);
86  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
87  parser.parseComma() ||
88  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
89  parser.parseRSquare())
90  return failure();
91 
92  StringAttr branchWeightsAttrName =
93  BranchConditionalOp::getBranchWeightsAttrName(result.name);
94  result.addAttribute(branchWeightsAttrName,
95  builder.getArrayAttr({trueWeight, falseWeight}));
96  }
97 
98  // Parse the true branch.
99  SmallVector<Value, 4> trueOperands;
100  if (parser.parseComma() ||
101  parser.parseSuccessorAndUseList(dest, trueOperands))
102  return failure();
103  result.addSuccessors(dest);
104  result.addOperands(trueOperands);
105 
106  // Parse the false branch.
107  SmallVector<Value, 4> falseOperands;
108  if (parser.parseComma() ||
109  parser.parseSuccessorAndUseList(dest, falseOperands))
110  return failure();
111  result.addSuccessors(dest);
112  result.addOperands(falseOperands);
113  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
114  builder.getDenseI32ArrayAttr(
115  {1, static_cast<int32_t>(trueOperands.size()),
116  static_cast<int32_t>(falseOperands.size())}));
117 
118  return success();
119 }
120 
121 void BranchConditionalOp::print(OpAsmPrinter &printer) {
122  printer << ' ' << getCondition();
123 
124  if (std::optional<ArrayAttr> weights = getBranchWeights()) {
125  printer << ' '
126  << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
127  }
128 
129  printer << ", ";
130  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
131  printer << ", ";
132  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
133 }
134 
135 LogicalResult BranchConditionalOp::verify() {
136  if (auto weights = getBranchWeights()) {
137  if (weights->getValue().size() != 2) {
138  return emitOpError("must have exactly two branch weights");
139  }
140  if (llvm::all_of(*weights, [](Attribute attr) {
141  return llvm::cast<IntegerAttr>(attr).getValue().isZero();
142  }))
143  return emitOpError("branch weights cannot both be zero");
144  }
145 
146  return success();
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // spirv.FunctionCall
151 //===----------------------------------------------------------------------===//
152 
153 LogicalResult FunctionCallOp::verify() {
154  auto fnName = getCalleeAttr();
155 
156  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
157  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
158  if (!funcOp) {
159  return emitOpError("callee function '")
160  << fnName.getValue() << "' not found in nearest symbol table";
161  }
162 
163  auto functionType = funcOp.getFunctionType();
164 
165  if (getNumResults() > 1) {
166  return emitOpError(
167  "expected callee function to have 0 or 1 result, but provided ")
168  << getNumResults();
169  }
170 
171  if (functionType.getNumInputs() != getNumOperands()) {
172  return emitOpError("has incorrect number of operands for callee: expected ")
173  << functionType.getNumInputs() << ", but provided "
174  << getNumOperands();
175  }
176 
177  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
178  if (getOperand(i).getType() != functionType.getInput(i)) {
179  return emitOpError("operand type mismatch: expected operand type ")
180  << functionType.getInput(i) << ", but provided "
181  << getOperand(i).getType() << " for operand number " << i;
182  }
183  }
184 
185  if (functionType.getNumResults() != getNumResults()) {
186  return emitOpError(
187  "has incorrect number of results has for callee: expected ")
188  << functionType.getNumResults() << ", but provided "
189  << getNumResults();
190  }
191 
192  if (getNumResults() &&
193  (getResult(0).getType() != functionType.getResult(0))) {
194  return emitOpError("result type mismatch: expected ")
195  << functionType.getResult(0) << ", but provided "
196  << getResult(0).getType();
197  }
198 
199  return success();
200 }
201 
202 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
203  return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
204 }
205 
206 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
207  (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
208 }
209 
210 Operation::operand_range FunctionCallOp::getArgOperands() {
211  return getArguments();
212 }
213 
214 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
215  return getArgumentsMutable();
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // spirv.mlir.loop
220 //===----------------------------------------------------------------------===//
221 
222 void LoopOp::build(OpBuilder &builder, OperationState &state) {
223  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
225  state.addRegion();
226 }
227 
228 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
229  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
230  result))
231  return failure();
232 
233  if (succeeded(parser.parseOptionalArrow()))
234  if (parser.parseTypeList(result.types))
235  return failure();
236 
237  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
238 }
239 
240 void LoopOp::print(OpAsmPrinter &printer) {
241  auto control = getLoopControl();
242  if (control != spirv::LoopControl::None)
243  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
244  if (getNumResults() > 0) {
245  printer << " -> ";
246  printer << getResultTypes();
247  }
248  printer << ' ';
249  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
250  /*printBlockTerminators=*/true);
251 }
252 
253 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
254 /// given `dstBlock`.
255 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
256  // Check that there is only one op in the `srcBlock`.
257  if (!llvm::hasSingleElement(srcBlock))
258  return false;
259 
260  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
261  return branchOp && branchOp.getSuccessor() == &dstBlock;
262 }
263 
264 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
265 static bool isMergeBlock(Block &block) {
266  return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.front());
267 }
268 
269 /// Returns true if a `spirv.mlir.merge` op outside the merge block.
270 static bool hasOtherMerge(Region &region) {
271  return !region.empty() && llvm::any_of(region.getOps(), [&](Operation &op) {
272  return isa<spirv::MergeOp>(op) && op.getBlock() != &region.back();
273  });
274 }
275 
276 LogicalResult LoopOp::verifyRegions() {
277  auto *op = getOperation();
278 
279  // We need to verify that the blocks follow the following layout:
280  //
281  // +-------------+
282  // | entry block |
283  // +-------------+
284  // |
285  // v
286  // +-------------+
287  // | loop header | <-----+
288  // +-------------+ |
289  // |
290  // ... |
291  // \ | / |
292  // v |
293  // +---------------+ |
294  // | loop continue | -----+
295  // +---------------+
296  //
297  // ...
298  // \ | /
299  // v
300  // +-------------+
301  // | merge block |
302  // +-------------+
303 
304  auto &region = op->getRegion(0);
305  // Allow empty region as a degenerated case, which can come from
306  // optimizations.
307  if (region.empty())
308  return success();
309 
310  // The last block is the merge block.
311  Block &merge = region.back();
312  if (!isMergeBlock(merge))
313  return emitOpError("last block must be the merge block with only one "
314  "'spirv.mlir.merge' op");
315  if (hasOtherMerge(region))
316  return emitOpError(
317  "should not have 'spirv.mlir.merge' op outside the merge block");
318 
319  if (region.hasOneBlock())
320  return emitOpError(
321  "must have an entry block branching to the loop header block");
322  // The first block is the entry block.
323  Block &entry = region.front();
324 
325  if (std::next(region.begin(), 2) == region.end())
326  return emitOpError(
327  "must have a loop header block branched from the entry block");
328  // The second block is the loop header block.
329  Block &header = *std::next(region.begin(), 1);
330 
331  if (!hasOneBranchOpTo(entry, header))
332  return emitOpError(
333  "entry block must only have one 'spirv.Branch' op to the second block");
334 
335  if (std::next(region.begin(), 3) == region.end())
336  return emitOpError(
337  "requires a loop continue block branching to the loop header block");
338  // The second to last block is the loop continue block.
339  Block &cont = *std::prev(region.end(), 2);
340 
341  // Make sure that we have a branch from the loop continue block to the loop
342  // header block.
343  if (llvm::none_of(
344  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
345  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
346  return emitOpError("second to last block must be the loop continue "
347  "block that branches to the loop header block");
348 
349  // Make sure that no other blocks (except the entry and loop continue block)
350  // branches to the loop header block.
351  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
352  std::prev(region.end(), 2))) {
353  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
354  if (block.getSuccessor(i) == &header) {
355  return emitOpError("can only have the entry and loop continue "
356  "block branching to the loop header block");
357  }
358  }
359  }
360 
361  return success();
362 }
363 
364 Block *LoopOp::getEntryBlock() {
365  assert(!getBody().empty() && "op region should not be empty!");
366  return &getBody().front();
367 }
368 
369 Block *LoopOp::getHeaderBlock() {
370  assert(!getBody().empty() && "op region should not be empty!");
371  // The second block is the loop header block.
372  return &*std::next(getBody().begin());
373 }
374 
375 Block *LoopOp::getContinueBlock() {
376  assert(!getBody().empty() && "op region should not be empty!");
377  // The second to last block is the loop continue block.
378  return &*std::prev(getBody().end(), 2);
379 }
380 
381 Block *LoopOp::getMergeBlock() {
382  assert(!getBody().empty() && "op region should not be empty!");
383  // The last block is the loop merge block.
384  return &getBody().back();
385 }
386 
387 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
388  assert(getBody().empty() && "entry and merge block already exist");
389  OpBuilder::InsertionGuard g(builder);
390  builder.createBlock(&getBody());
391  builder.createBlock(&getBody());
392 
393  // Add a spirv.mlir.merge op into the merge block.
394  builder.create<spirv::MergeOp>(getLoc());
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // spirv.Return
399 //===----------------------------------------------------------------------===//
400 
401 LogicalResult ReturnOp::verify() {
402  // Verification is performed in spirv.func op.
403  return success();
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // spirv.ReturnValue
408 //===----------------------------------------------------------------------===//
409 
410 LogicalResult ReturnValueOp::verify() {
411  // Verification is performed in spirv.func op.
412  return success();
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // spirv.Select
417 //===----------------------------------------------------------------------===//
418 
419 LogicalResult SelectOp::verify() {
420  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
421  auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
422  if (!resultVectorTy) {
423  return emitOpError("result expected to be of vector type when "
424  "condition is of vector type");
425  }
426  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
427  return emitOpError("result should have the same number of elements as "
428  "the condition when condition is of vector type");
429  }
430  }
431  return success();
432 }
433 
434 // Custom availability implementation is needed for spirv.Select given the
435 // syntax changes starting v1.4.
436 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
437  return {};
438 }
439 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
440  return {};
441 }
442 std::optional<spirv::Version> SelectOp::getMinVersion() {
443  // Per the spec, "Before version 1.4, results are only computed per
444  // component."
445  if (isa<spirv::ScalarType>(getCondition().getType()) &&
446  isa<spirv::CompositeType>(getType()))
447  return Version::V_1_4;
448 
449  return Version::V_1_0;
450 }
451 std::optional<spirv::Version> SelectOp::getMaxVersion() {
452  return Version::V_1_6;
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // spirv.mlir.selection
457 //===----------------------------------------------------------------------===//
458 
459 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
460  if (parseControlAttribute<spirv::SelectionControlAttr,
461  spirv::SelectionControl>(parser, result))
462  return failure();
463 
464  if (succeeded(parser.parseOptionalArrow()))
465  if (parser.parseTypeList(result.types))
466  return failure();
467 
468  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
469 }
470 
471 void SelectionOp::print(OpAsmPrinter &printer) {
472  auto control = getSelectionControl();
473  if (control != spirv::SelectionControl::None)
474  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
475  if (getNumResults() > 0) {
476  printer << " -> ";
477  printer << getResultTypes();
478  }
479  printer << ' ';
480  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
481  /*printBlockTerminators=*/true);
482 }
483 
484 LogicalResult SelectionOp::verifyRegions() {
485  auto *op = getOperation();
486 
487  // We need to verify that the blocks follow the following layout:
488  //
489  // +--------------+
490  // | header block |
491  // +--------------+
492  // / | \
493  // ...
494  //
495  //
496  // +---------+ +---------+ +---------+
497  // | case #0 | | case #1 | | case #2 | ...
498  // +---------+ +---------+ +---------+
499  //
500  //
501  // ...
502  // \ | /
503  // v
504  // +-------------+
505  // | merge block |
506  // +-------------+
507 
508  auto &region = op->getRegion(0);
509  // Allow empty region as a degenerated case, which can come from
510  // optimizations.
511  if (region.empty())
512  return success();
513 
514  // The last block is the merge block.
515  if (!isMergeBlock(region.back()))
516  return emitOpError("last block must be the merge block with only one "
517  "'spirv.mlir.merge' op");
518  if (hasOtherMerge(region))
519  return emitOpError(
520  "should not have 'spirv.mlir.merge' op outside the merge block");
521 
522  if (region.hasOneBlock())
523  return emitOpError("must have a selection header block");
524 
525  return success();
526 }
527 
528 Block *SelectionOp::getHeaderBlock() {
529  assert(!getBody().empty() && "op region should not be empty!");
530  // The first block is the loop header block.
531  return &getBody().front();
532 }
533 
534 Block *SelectionOp::getMergeBlock() {
535  assert(!getBody().empty() && "op region should not be empty!");
536  // The last block is the loop merge block.
537  return &getBody().back();
538 }
539 
540 void SelectionOp::addMergeBlock(OpBuilder &builder) {
541  assert(getBody().empty() && "entry and merge block already exist");
542  OpBuilder::InsertionGuard guard(builder);
543  builder.createBlock(&getBody());
544 
545  // Add a spirv.mlir.merge op into the merge block.
546  builder.create<spirv::MergeOp>(getLoc());
547 }
548 
549 SelectionOp
550 SelectionOp::createIfThen(Location loc, Value condition,
551  function_ref<void(OpBuilder &builder)> thenBody,
552  OpBuilder &builder) {
553  auto selectionOp =
554  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
555 
556  selectionOp.addMergeBlock(builder);
557  Block *mergeBlock = selectionOp.getMergeBlock();
558  Block *thenBlock = nullptr;
559 
560  // Build the "then" block.
561  {
562  OpBuilder::InsertionGuard guard(builder);
563  thenBlock = builder.createBlock(mergeBlock);
564  thenBody(builder);
565  builder.create<spirv::BranchOp>(loc, mergeBlock);
566  }
567 
568  // Build the header block.
569  {
570  OpBuilder::InsertionGuard guard(builder);
571  builder.createBlock(thenBlock);
572  builder.create<spirv::BranchConditionalOp>(
573  loc, condition, thenBlock,
574  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
575  /*falseArguments=*/ArrayRef<Value>());
576  }
577 
578  return selectionOp;
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // spirv.Unreachable
583 //===----------------------------------------------------------------------===//
584 
585 LogicalResult spirv::UnreachableOp::verify() {
586  auto *block = (*this)->getBlock();
587  // Fast track: if this is in entry block, its invalid. Otherwise, if no
588  // predecessors, it's valid.
589  if (block->isEntryBlock())
590  return emitOpError("cannot be used in reachable block");
591  if (block->hasNoPredecessors())
592  return success();
593 
594  // TODO: further verification needs to analyze reachability from
595  // the entry block.
596 
597  return success();
598 }
599 
600 } // namespace mlir::spirv
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
@ None
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumSuccessors()
Definition: Block.cpp:257
Operation & back()
Definition: Block.h:152
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperandRange operand_range
Definition: Operation.h:371
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint that for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
constexpr char kControl[]
static bool hasOtherMerge(Region &region)
Returns true if a spirv.mlir.merge op outside the merge block.
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.