MLIR  18.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the control flow operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 
18 #include "SPIRVOpUtils.h"
19 #include "SPIRVParsingUtils.h"
20 
21 using namespace mlir::spirv::AttrNames;
22 
23 namespace mlir::spirv {
24 
25 /// Parses Function, Selection and Loop control attributes. If no control is
26 /// specified, "None" is used as a default.
27 template <typename EnumAttrClass, typename EnumClass>
28 static ParseResult
30  StringRef attrName = spirv::attributeName<EnumClass>()) {
32  EnumClass control;
33  if (parser.parseLParen() ||
34  spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
35  parser.parseRParen())
36  return failure();
37  return success();
38  }
39  // Set control to "None" otherwise.
40  Builder builder = parser.getBuilder();
41  state.addAttribute(attrName,
42  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
43  return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // spirv.BranchOp
48 //===----------------------------------------------------------------------===//
49 
51  assert(index == 0 && "invalid successor index");
52  return SuccessorOperands(0, getTargetOperandsMutable());
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // spirv.BranchConditionalOp
57 //===----------------------------------------------------------------------===//
58 
59 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
60  assert(index < 2 && "invalid successor index");
61  return SuccessorOperands(index == kTrueIndex
62  ? getTrueTargetOperandsMutable()
63  : getFalseTargetOperandsMutable());
64 }
65 
66 ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
67  OperationState &result) {
68  auto &builder = parser.getBuilder();
69  OpAsmParser::UnresolvedOperand condInfo;
70  Block *dest;
71 
72  // Parse the condition.
73  Type boolTy = builder.getI1Type();
74  if (parser.parseOperand(condInfo) ||
75  parser.resolveOperand(condInfo, boolTy, result.operands))
76  return failure();
77 
78  // Parse the optional branch weights.
79  if (succeeded(parser.parseOptionalLSquare())) {
80  IntegerAttr trueWeight, falseWeight;
81  NamedAttrList weights;
82 
83  auto i32Type = builder.getIntegerType(32);
84  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
85  parser.parseComma() ||
86  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
87  parser.parseRSquare())
88  return failure();
89 
90  result.addAttribute(kBranchWeightAttrName,
91  builder.getArrayAttr({trueWeight, falseWeight}));
92  }
93 
94  // Parse the true branch.
95  SmallVector<Value, 4> trueOperands;
96  if (parser.parseComma() ||
97  parser.parseSuccessorAndUseList(dest, trueOperands))
98  return failure();
99  result.addSuccessors(dest);
100  result.addOperands(trueOperands);
101 
102  // Parse the false branch.
103  SmallVector<Value, 4> falseOperands;
104  if (parser.parseComma() ||
105  parser.parseSuccessorAndUseList(dest, falseOperands))
106  return failure();
107  result.addSuccessors(dest);
108  result.addOperands(falseOperands);
109  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
110  builder.getDenseI32ArrayAttr(
111  {1, static_cast<int32_t>(trueOperands.size()),
112  static_cast<int32_t>(falseOperands.size())}));
113 
114  return success();
115 }
116 
117 void BranchConditionalOp::print(OpAsmPrinter &printer) {
118  printer << ' ' << getCondition();
119 
120  if (auto weights = getBranchWeights()) {
121  printer << " [";
122  llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
123  printer << llvm::cast<IntegerAttr>(a).getInt();
124  });
125  printer << "]";
126  }
127 
128  printer << ", ";
129  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
130  printer << ", ";
131  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
132 }
133 
134 LogicalResult BranchConditionalOp::verify() {
135  if (auto weights = getBranchWeights()) {
136  if (weights->getValue().size() != 2) {
137  return emitOpError("must have exactly two branch weights");
138  }
139  if (llvm::all_of(*weights, [](Attribute attr) {
140  return llvm::cast<IntegerAttr>(attr).getValue().isZero();
141  }))
142  return emitOpError("branch weights cannot both be zero");
143  }
144 
145  return success();
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // spirv.FunctionCall
150 //===----------------------------------------------------------------------===//
151 
152 LogicalResult FunctionCallOp::verify() {
153  auto fnName = getCalleeAttr();
154 
155  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
156  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
157  if (!funcOp) {
158  return emitOpError("callee function '")
159  << fnName.getValue() << "' not found in nearest symbol table";
160  }
161 
162  auto functionType = funcOp.getFunctionType();
163 
164  if (getNumResults() > 1) {
165  return emitOpError(
166  "expected callee function to have 0 or 1 result, but provided ")
167  << getNumResults();
168  }
169 
170  if (functionType.getNumInputs() != getNumOperands()) {
171  return emitOpError("has incorrect number of operands for callee: expected ")
172  << functionType.getNumInputs() << ", but provided "
173  << getNumOperands();
174  }
175 
176  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
177  if (getOperand(i).getType() != functionType.getInput(i)) {
178  return emitOpError("operand type mismatch: expected operand type ")
179  << functionType.getInput(i) << ", but provided "
180  << getOperand(i).getType() << " for operand number " << i;
181  }
182  }
183 
184  if (functionType.getNumResults() != getNumResults()) {
185  return emitOpError(
186  "has incorrect number of results has for callee: expected ")
187  << functionType.getNumResults() << ", but provided "
188  << getNumResults();
189  }
190 
191  if (getNumResults() &&
192  (getResult(0).getType() != functionType.getResult(0))) {
193  return emitOpError("result type mismatch: expected ")
194  << functionType.getResult(0) << ", but provided "
195  << getResult(0).getType();
196  }
197 
198  return success();
199 }
200 
201 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
202  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
203 }
204 
205 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
206  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
207 }
208 
209 Operation::operand_range FunctionCallOp::getArgOperands() {
210  return getArguments();
211 }
212 
213 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
214  return getArgumentsMutable();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // spirv.mlir.loop
219 //===----------------------------------------------------------------------===//
220 
221 void LoopOp::build(OpBuilder &builder, OperationState &state) {
222  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
224  state.addRegion();
225 }
226 
227 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
228  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
229  result))
230  return failure();
231  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
232 }
233 
234 void LoopOp::print(OpAsmPrinter &printer) {
235  auto control = getLoopControl();
236  if (control != spirv::LoopControl::None)
237  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
238  printer << ' ';
239  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
240  /*printBlockTerminators=*/true);
241 }
242 
243 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
244 /// given `dstBlock`.
245 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
246  // Check that there is only one op in the `srcBlock`.
247  if (!llvm::hasSingleElement(srcBlock))
248  return false;
249 
250  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
251  return branchOp && branchOp.getSuccessor() == &dstBlock;
252 }
253 
254 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
255 static bool isMergeBlock(Block &block) {
256  return !block.empty() && std::next(block.begin()) == block.end() &&
257  isa<spirv::MergeOp>(block.front());
258 }
259 
/// Verifies the block layout of a `spirv.mlir.loop` region. Blocks must
/// appear in this order: entry block, loop header, any number of body blocks,
/// loop continue block, merge block. An empty region is accepted as a
/// degenerate case.
LogicalResult LoopOp::verifyRegions() {
  auto *op = getOperation();

  // We need to verify that the blocks follow the following layout:
  //
  //                     +-------------+
  //                     | entry block |
  //                     +-------------+
  //                            |
  //                            v
  //                     +-------------+
  //                     | loop header | <-----+
  //                     +-------------+       |
  //                                           |
  //                           ...             |
  //                          \ | /            |
  //                            v              |
  //                    +---------------+      |
  //                    | loop continue | -----+
  //                    +---------------+
  //
  //                           ...
  //                          \ | /
  //                            v
  //                     +-------------+
  //                     | merge block |
  //                     +-------------+

  auto &region = op->getRegion(0);
  // Allow empty region as a degenerated case, which can come from
  // optimizations.
  if (region.empty())
    return success();

  // The last block is the merge block.
  Block &merge = region.back();
  if (!isMergeBlock(merge))
    return emitOpError("last block must be the merge block with only one "
                       "'spirv.mlir.merge' op");

  // Exactly one block: only the merge block exists, so there is no entry.
  if (std::next(region.begin()) == region.end())
    return emitOpError(
        "must have an entry block branching to the loop header block");
  // The first block is the entry block.
  Block &entry = region.front();

  // Exactly two blocks: entry and merge, with no loop header in between.
  if (std::next(region.begin(), 2) == region.end())
    return emitOpError(
        "must have a loop header block branched from the entry block");
  // The second block is the loop header block.
  Block &header = *std::next(region.begin(), 1);

  if (!hasOneBranchOpTo(entry, header))
    return emitOpError(
        "entry block must only have one 'spirv.Branch' op to the second block");

  // Exactly three blocks: the second-to-last block would be the header
  // itself, so a distinct continue block is missing.
  if (std::next(region.begin(), 3) == region.end())
    return emitOpError(
        "requires a loop continue block branching to the loop header block");
  // The second to last block is the loop continue block.
  Block &cont = *std::prev(region.end(), 2);

  // Make sure that we have a branch from the loop continue block to the loop
  // header block.
  if (llvm::none_of(
          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
    return emitOpError("second to last block must be the loop continue "
                       "block that branches to the loop header block");

  // Make sure that no other blocks (except the entry and loop continue block)
  // branches to the loop header block.
  // NOTE(review): the scanned range starts at the third block and stops before
  // the continue block, so the header block itself is never checked for a
  // branch back to itself.
  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
                                      std::prev(region.end(), 2))) {
    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
      if (block.getSuccessor(i) == &header) {
        return emitOpError("can only have the entry and loop continue "
                           "block branching to the loop header block");
      }
    }
  }

  return success();
}
344 
345 Block *LoopOp::getEntryBlock() {
346  assert(!getBody().empty() && "op region should not be empty!");
347  return &getBody().front();
348 }
349 
350 Block *LoopOp::getHeaderBlock() {
351  assert(!getBody().empty() && "op region should not be empty!");
352  // The second block is the loop header block.
353  return &*std::next(getBody().begin());
354 }
355 
356 Block *LoopOp::getContinueBlock() {
357  assert(!getBody().empty() && "op region should not be empty!");
358  // The second to last block is the loop continue block.
359  return &*std::prev(getBody().end(), 2);
360 }
361 
362 Block *LoopOp::getMergeBlock() {
363  assert(!getBody().empty() && "op region should not be empty!");
364  // The last block is the loop merge block.
365  return &getBody().back();
366 }
367 
368 void LoopOp::addEntryAndMergeBlock() {
369  assert(getBody().empty() && "entry and merge block already exist");
370  getBody().push_back(new Block());
371  auto *mergeBlock = new Block();
372  getBody().push_back(mergeBlock);
373  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
374 
375  // Add a spirv.mlir.merge op into the merge block.
376  builder.create<spirv::MergeOp>(getLoc());
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // spirv.mlir.merge
381 //===----------------------------------------------------------------------===//
382 
383 LogicalResult MergeOp::verify() {
384  auto *parentOp = (*this)->getParentOp();
385  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
386  return emitOpError(
387  "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
388 
389  // TODO: This check should be done in `verifyRegions` of parent op.
390  Block &parentLastBlock = (*this)->getParentRegion()->back();
391  if (getOperation() != parentLastBlock.getTerminator())
392  return emitOpError("can only be used in the last block of "
393  "'spirv.mlir.selection' or 'spirv.mlir.loop'");
394  return success();
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // spirv.Return
399 //===----------------------------------------------------------------------===//
400 
401 LogicalResult ReturnOp::verify() {
402  // Verification is performed in spirv.func op.
403  return success();
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // spirv.ReturnValue
408 //===----------------------------------------------------------------------===//
409 
410 LogicalResult ReturnValueOp::verify() {
411  // Verification is performed in spirv.func op.
412  return success();
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // spirv.Select
417 //===----------------------------------------------------------------------===//
418 
419 LogicalResult SelectOp::verify() {
420  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
421  auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
422  if (!resultVectorTy) {
423  return emitOpError("result expected to be of vector type when "
424  "condition is of vector type");
425  }
426  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
427  return emitOpError("result should have the same number of elements as "
428  "the condition when condition is of vector type");
429  }
430  }
431  return success();
432 }
433 
434 // Custom availability implementation is needed for spirv.Select given the
435 // syntax changes starting v1.4.
436 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
437  return {};
438 }
439 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
440  return {};
441 }
442 std::optional<spirv::Version> SelectOp::getMinVersion() {
443  // Per the spec, "Before version 1.4, results are only computed per
444  // component."
445  if (isa<spirv::ScalarType>(getCondition().getType()) &&
446  isa<spirv::CompositeType>(getType()))
447  return Version::V_1_4;
448 
449  return Version::V_1_0;
450 }
451 std::optional<spirv::Version> SelectOp::getMaxVersion() {
452  return Version::V_1_6;
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // spirv.mlir.selection
457 //===----------------------------------------------------------------------===//
458 
459 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
460  if (parseControlAttribute<spirv::SelectionControlAttr,
461  spirv::SelectionControl>(parser, result))
462  return failure();
463  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
464 }
465 
466 void SelectionOp::print(OpAsmPrinter &printer) {
467  auto control = getSelectionControl();
468  if (control != spirv::SelectionControl::None)
469  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
470  printer << ' ';
471  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
472  /*printBlockTerminators=*/true);
473 }
474 
475 LogicalResult SelectionOp::verifyRegions() {
476  auto *op = getOperation();
477 
478  // We need to verify that the blocks follow the following layout:
479  //
480  // +--------------+
481  // | header block |
482  // +--------------+
483  // / | \
484  // ...
485  //
486  //
487  // +---------+ +---------+ +---------+
488  // | case #0 | | case #1 | | case #2 | ...
489  // +---------+ +---------+ +---------+
490  //
491  //
492  // ...
493  // \ | /
494  // v
495  // +-------------+
496  // | merge block |
497  // +-------------+
498 
499  auto &region = op->getRegion(0);
500  // Allow empty region as a degenerated case, which can come from
501  // optimizations.
502  if (region.empty())
503  return success();
504 
505  // The last block is the merge block.
506  if (!isMergeBlock(region.back()))
507  return emitOpError("last block must be the merge block with only one "
508  "'spirv.mlir.merge' op");
509 
510  if (std::next(region.begin()) == region.end())
511  return emitOpError("must have a selection header block");
512 
513  return success();
514 }
515 
516 Block *SelectionOp::getHeaderBlock() {
517  assert(!getBody().empty() && "op region should not be empty!");
518  // The first block is the loop header block.
519  return &getBody().front();
520 }
521 
522 Block *SelectionOp::getMergeBlock() {
523  assert(!getBody().empty() && "op region should not be empty!");
524  // The last block is the loop merge block.
525  return &getBody().back();
526 }
527 
528 void SelectionOp::addMergeBlock() {
529  assert(getBody().empty() && "entry and merge block already exist");
530  auto *mergeBlock = new Block();
531  getBody().push_back(mergeBlock);
532  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
533 
534  // Add a spirv.mlir.merge op into the merge block.
535  builder.create<spirv::MergeOp>(getLoc());
536 }
537 
538 SelectionOp
539 SelectionOp::createIfThen(Location loc, Value condition,
540  function_ref<void(OpBuilder &builder)> thenBody,
541  OpBuilder &builder) {
542  auto selectionOp =
543  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
544 
545  selectionOp.addMergeBlock();
546  Block *mergeBlock = selectionOp.getMergeBlock();
547  Block *thenBlock = nullptr;
548 
549  // Build the "then" block.
550  {
551  OpBuilder::InsertionGuard guard(builder);
552  thenBlock = builder.createBlock(mergeBlock);
553  thenBody(builder);
554  builder.create<spirv::BranchOp>(loc, mergeBlock);
555  }
556 
557  // Build the header block.
558  {
559  OpBuilder::InsertionGuard guard(builder);
560  builder.createBlock(thenBlock);
561  builder.create<spirv::BranchConditionalOp>(
562  loc, condition, thenBlock,
563  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
564  /*falseArguments=*/ArrayRef<Value>());
565  }
566 
567  return selectionOp;
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // spirv.Unreachable
572 //===----------------------------------------------------------------------===//
573 
574 LogicalResult spirv::UnreachableOp::verify() {
575  auto *block = (*this)->getBlock();
576  // Fast track: if this is in entry block, its invalid. Otherwise, if no
577  // predecessors, it's valid.
578  if (block->isEntryBlock())
579  return emitOpError("cannot be used in reachable block");
580  if (block->hasNoPredecessors())
581  return success();
582 
583  // TODO: further verification needs to analyze reachability from
584  // the entry block.
585 
586  return success();
587 }
588 
589 } // namespace mlir::spirv
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
@ None
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumSuccessors()
Definition: Block.cpp:249
bool empty()
Definition: Block.h:141
Operation & back()
Definition: Block.h:145
Operation & front()
Definition: Block.h:146
iterator end()
Definition: Block.h:137
iterator begin()
Definition: Block.h:136
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:244
OperandRange operand_range
Definition: Operation.h:366
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint that for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
constexpr char kCallee[]
constexpr char kBranchWeightAttrName[]
constexpr char kControl[]
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.