MLIR  19.0.0git
ControlFlowOps.cpp
Go to the documentation of this file.
1 //===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the control flow operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 
18 #include "SPIRVOpUtils.h"
19 #include "SPIRVParsingUtils.h"
20 
21 using namespace mlir::spirv::AttrNames;
22 
23 namespace mlir::spirv {
24 
25 /// Parses Function, Selection and Loop control attributes. If no control is
26 /// specified, "None" is used as a default.
27 template <typename EnumAttrClass, typename EnumClass>
28 static ParseResult
30  StringRef attrName = spirv::attributeName<EnumClass>()) {
32  EnumClass control;
33  if (parser.parseLParen() ||
34  spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
35  parser.parseRParen())
36  return failure();
37  return success();
38  }
39  // Set control to "None" otherwise.
40  Builder builder = parser.getBuilder();
41  state.addAttribute(attrName,
42  builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
43  return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // spirv.BranchOp
48 //===----------------------------------------------------------------------===//
49 
51  assert(index == 0 && "invalid successor index");
52  return SuccessorOperands(0, getTargetOperandsMutable());
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // spirv.BranchConditionalOp
57 //===----------------------------------------------------------------------===//
58 
59 SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
60  assert(index < 2 && "invalid successor index");
61  return SuccessorOperands(index == kTrueIndex
62  ? getTrueTargetOperandsMutable()
63  : getFalseTargetOperandsMutable());
64 }
65 
66 ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
67  OperationState &result) {
68  auto &builder = parser.getBuilder();
69  OpAsmParser::UnresolvedOperand condInfo;
70  Block *dest;
71 
72  // Parse the condition.
73  Type boolTy = builder.getI1Type();
74  if (parser.parseOperand(condInfo) ||
75  parser.resolveOperand(condInfo, boolTy, result.operands))
76  return failure();
77 
78  // Parse the optional branch weights.
79  if (succeeded(parser.parseOptionalLSquare())) {
80  IntegerAttr trueWeight, falseWeight;
81  NamedAttrList weights;
82 
83  auto i32Type = builder.getIntegerType(32);
84  if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
85  parser.parseComma() ||
86  parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
87  parser.parseRSquare())
88  return failure();
89 
90  StringAttr branchWeightsAttrName =
91  BranchConditionalOp::getBranchWeightsAttrName(result.name);
92  result.addAttribute(branchWeightsAttrName,
93  builder.getArrayAttr({trueWeight, falseWeight}));
94  }
95 
96  // Parse the true branch.
97  SmallVector<Value, 4> trueOperands;
98  if (parser.parseComma() ||
99  parser.parseSuccessorAndUseList(dest, trueOperands))
100  return failure();
101  result.addSuccessors(dest);
102  result.addOperands(trueOperands);
103 
104  // Parse the false branch.
105  SmallVector<Value, 4> falseOperands;
106  if (parser.parseComma() ||
107  parser.parseSuccessorAndUseList(dest, falseOperands))
108  return failure();
109  result.addSuccessors(dest);
110  result.addOperands(falseOperands);
111  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
112  builder.getDenseI32ArrayAttr(
113  {1, static_cast<int32_t>(trueOperands.size()),
114  static_cast<int32_t>(falseOperands.size())}));
115 
116  return success();
117 }
118 
119 void BranchConditionalOp::print(OpAsmPrinter &printer) {
120  printer << ' ' << getCondition();
121 
122  if (auto weights = getBranchWeights()) {
123  printer << " [";
124  llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
125  printer << llvm::cast<IntegerAttr>(a).getInt();
126  });
127  printer << "]";
128  }
129 
130  printer << ", ";
131  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
132  printer << ", ";
133  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
134 }
135 
136 LogicalResult BranchConditionalOp::verify() {
137  if (auto weights = getBranchWeights()) {
138  if (weights->getValue().size() != 2) {
139  return emitOpError("must have exactly two branch weights");
140  }
141  if (llvm::all_of(*weights, [](Attribute attr) {
142  return llvm::cast<IntegerAttr>(attr).getValue().isZero();
143  }))
144  return emitOpError("branch weights cannot both be zero");
145  }
146 
147  return success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // spirv.FunctionCall
152 //===----------------------------------------------------------------------===//
153 
154 LogicalResult FunctionCallOp::verify() {
155  auto fnName = getCalleeAttr();
156 
157  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
158  SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
159  if (!funcOp) {
160  return emitOpError("callee function '")
161  << fnName.getValue() << "' not found in nearest symbol table";
162  }
163 
164  auto functionType = funcOp.getFunctionType();
165 
166  if (getNumResults() > 1) {
167  return emitOpError(
168  "expected callee function to have 0 or 1 result, but provided ")
169  << getNumResults();
170  }
171 
172  if (functionType.getNumInputs() != getNumOperands()) {
173  return emitOpError("has incorrect number of operands for callee: expected ")
174  << functionType.getNumInputs() << ", but provided "
175  << getNumOperands();
176  }
177 
178  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
179  if (getOperand(i).getType() != functionType.getInput(i)) {
180  return emitOpError("operand type mismatch: expected operand type ")
181  << functionType.getInput(i) << ", but provided "
182  << getOperand(i).getType() << " for operand number " << i;
183  }
184  }
185 
186  if (functionType.getNumResults() != getNumResults()) {
187  return emitOpError(
188  "has incorrect number of results has for callee: expected ")
189  << functionType.getNumResults() << ", but provided "
190  << getNumResults();
191  }
192 
193  if (getNumResults() &&
194  (getResult(0).getType() != functionType.getResult(0))) {
195  return emitOpError("result type mismatch: expected ")
196  << functionType.getResult(0) << ", but provided "
197  << getResult(0).getType();
198  }
199 
200  return success();
201 }
202 
203 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
204  return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
205 }
206 
207 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
208  (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
209 }
210 
211 Operation::operand_range FunctionCallOp::getArgOperands() {
212  return getArguments();
213 }
214 
215 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
216  return getArgumentsMutable();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // spirv.mlir.loop
221 //===----------------------------------------------------------------------===//
222 
223 void LoopOp::build(OpBuilder &builder, OperationState &state) {
224  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
226  state.addRegion();
227 }
228 
229 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
230  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
231  result))
232  return failure();
233  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
234 }
235 
236 void LoopOp::print(OpAsmPrinter &printer) {
237  auto control = getLoopControl();
238  if (control != spirv::LoopControl::None)
239  printer << " control(" << spirv::stringifyLoopControl(control) << ")";
240  printer << ' ';
241  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
242  /*printBlockTerminators=*/true);
243 }
244 
245 /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
246 /// given `dstBlock`.
247 static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
248  // Check that there is only one op in the `srcBlock`.
249  if (!llvm::hasSingleElement(srcBlock))
250  return false;
251 
252  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
253  return branchOp && branchOp.getSuccessor() == &dstBlock;
254 }
255 
256 /// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
257 static bool isMergeBlock(Block &block) {
258  return !block.empty() && std::next(block.begin()) == block.end() &&
259  isa<spirv::MergeOp>(block.front());
260 }
261 
262 LogicalResult LoopOp::verifyRegions() {
263  auto *op = getOperation();
264 
265  // We need to verify that the blocks follow the following layout:
266  //
267  // +-------------+
268  // | entry block |
269  // +-------------+
270  // |
271  // v
272  // +-------------+
273  // | loop header | <-----+
274  // +-------------+ |
275  // |
276  // ... |
277  // \ | / |
278  // v |
279  // +---------------+ |
280  // | loop continue | -----+
281  // +---------------+
282  //
283  // ...
284  // \ | /
285  // v
286  // +-------------+
287  // | merge block |
288  // +-------------+
289 
290  auto &region = op->getRegion(0);
291  // Allow empty region as a degenerated case, which can come from
292  // optimizations.
293  if (region.empty())
294  return success();
295 
296  // The last block is the merge block.
297  Block &merge = region.back();
298  if (!isMergeBlock(merge))
299  return emitOpError("last block must be the merge block with only one "
300  "'spirv.mlir.merge' op");
301 
302  if (std::next(region.begin()) == region.end())
303  return emitOpError(
304  "must have an entry block branching to the loop header block");
305  // The first block is the entry block.
306  Block &entry = region.front();
307 
308  if (std::next(region.begin(), 2) == region.end())
309  return emitOpError(
310  "must have a loop header block branched from the entry block");
311  // The second block is the loop header block.
312  Block &header = *std::next(region.begin(), 1);
313 
314  if (!hasOneBranchOpTo(entry, header))
315  return emitOpError(
316  "entry block must only have one 'spirv.Branch' op to the second block");
317 
318  if (std::next(region.begin(), 3) == region.end())
319  return emitOpError(
320  "requires a loop continue block branching to the loop header block");
321  // The second to last block is the loop continue block.
322  Block &cont = *std::prev(region.end(), 2);
323 
324  // Make sure that we have a branch from the loop continue block to the loop
325  // header block.
326  if (llvm::none_of(
327  llvm::seq<unsigned>(0, cont.getNumSuccessors()),
328  [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
329  return emitOpError("second to last block must be the loop continue "
330  "block that branches to the loop header block");
331 
332  // Make sure that no other blocks (except the entry and loop continue block)
333  // branches to the loop header block.
334  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
335  std::prev(region.end(), 2))) {
336  for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
337  if (block.getSuccessor(i) == &header) {
338  return emitOpError("can only have the entry and loop continue "
339  "block branching to the loop header block");
340  }
341  }
342  }
343 
344  return success();
345 }
346 
347 Block *LoopOp::getEntryBlock() {
348  assert(!getBody().empty() && "op region should not be empty!");
349  return &getBody().front();
350 }
351 
352 Block *LoopOp::getHeaderBlock() {
353  assert(!getBody().empty() && "op region should not be empty!");
354  // The second block is the loop header block.
355  return &*std::next(getBody().begin());
356 }
357 
358 Block *LoopOp::getContinueBlock() {
359  assert(!getBody().empty() && "op region should not be empty!");
360  // The second to last block is the loop continue block.
361  return &*std::prev(getBody().end(), 2);
362 }
363 
364 Block *LoopOp::getMergeBlock() {
365  assert(!getBody().empty() && "op region should not be empty!");
366  // The last block is the loop merge block.
367  return &getBody().back();
368 }
369 
370 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
371  assert(getBody().empty() && "entry and merge block already exist");
372  OpBuilder::InsertionGuard g(builder);
373  builder.createBlock(&getBody());
374  builder.createBlock(&getBody());
375 
376  // Add a spirv.mlir.merge op into the merge block.
377  builder.create<spirv::MergeOp>(getLoc());
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // spirv.mlir.merge
382 //===----------------------------------------------------------------------===//
383 
384 LogicalResult MergeOp::verify() {
385  auto *parentOp = (*this)->getParentOp();
386  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
387  return emitOpError(
388  "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
389 
390  // TODO: This check should be done in `verifyRegions` of parent op.
391  Block &parentLastBlock = (*this)->getParentRegion()->back();
392  if (getOperation() != parentLastBlock.getTerminator())
393  return emitOpError("can only be used in the last block of "
394  "'spirv.mlir.selection' or 'spirv.mlir.loop'");
395  return success();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // spirv.Return
400 //===----------------------------------------------------------------------===//
401 
402 LogicalResult ReturnOp::verify() {
403  // Verification is performed in spirv.func op.
404  return success();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // spirv.ReturnValue
409 //===----------------------------------------------------------------------===//
410 
411 LogicalResult ReturnValueOp::verify() {
412  // Verification is performed in spirv.func op.
413  return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // spirv.Select
418 //===----------------------------------------------------------------------===//
419 
420 LogicalResult SelectOp::verify() {
421  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
422  auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
423  if (!resultVectorTy) {
424  return emitOpError("result expected to be of vector type when "
425  "condition is of vector type");
426  }
427  if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
428  return emitOpError("result should have the same number of elements as "
429  "the condition when condition is of vector type");
430  }
431  }
432  return success();
433 }
434 
435 // Custom availability implementation is needed for spirv.Select given the
436 // syntax changes starting v1.4.
437 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
438  return {};
439 }
440 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
441  return {};
442 }
443 std::optional<spirv::Version> SelectOp::getMinVersion() {
444  // Per the spec, "Before version 1.4, results are only computed per
445  // component."
446  if (isa<spirv::ScalarType>(getCondition().getType()) &&
447  isa<spirv::CompositeType>(getType()))
448  return Version::V_1_4;
449 
450  return Version::V_1_0;
451 }
452 std::optional<spirv::Version> SelectOp::getMaxVersion() {
453  return Version::V_1_6;
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // spirv.mlir.selection
458 //===----------------------------------------------------------------------===//
459 
460 ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
461  if (parseControlAttribute<spirv::SelectionControlAttr,
462  spirv::SelectionControl>(parser, result))
463  return failure();
464  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
465 }
466 
467 void SelectionOp::print(OpAsmPrinter &printer) {
468  auto control = getSelectionControl();
469  if (control != spirv::SelectionControl::None)
470  printer << " control(" << spirv::stringifySelectionControl(control) << ")";
471  printer << ' ';
472  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
473  /*printBlockTerminators=*/true);
474 }
475 
476 LogicalResult SelectionOp::verifyRegions() {
477  auto *op = getOperation();
478 
479  // We need to verify that the blocks follow the following layout:
480  //
481  // +--------------+
482  // | header block |
483  // +--------------+
484  // / | \
485  // ...
486  //
487  //
488  // +---------+ +---------+ +---------+
489  // | case #0 | | case #1 | | case #2 | ...
490  // +---------+ +---------+ +---------+
491  //
492  //
493  // ...
494  // \ | /
495  // v
496  // +-------------+
497  // | merge block |
498  // +-------------+
499 
500  auto &region = op->getRegion(0);
501  // Allow empty region as a degenerated case, which can come from
502  // optimizations.
503  if (region.empty())
504  return success();
505 
506  // The last block is the merge block.
507  if (!isMergeBlock(region.back()))
508  return emitOpError("last block must be the merge block with only one "
509  "'spirv.mlir.merge' op");
510 
511  if (std::next(region.begin()) == region.end())
512  return emitOpError("must have a selection header block");
513 
514  return success();
515 }
516 
517 Block *SelectionOp::getHeaderBlock() {
518  assert(!getBody().empty() && "op region should not be empty!");
519  // The first block is the loop header block.
520  return &getBody().front();
521 }
522 
523 Block *SelectionOp::getMergeBlock() {
524  assert(!getBody().empty() && "op region should not be empty!");
525  // The last block is the loop merge block.
526  return &getBody().back();
527 }
528 
529 void SelectionOp::addMergeBlock(OpBuilder &builder) {
530  assert(getBody().empty() && "entry and merge block already exist");
531  OpBuilder::InsertionGuard guard(builder);
532  builder.createBlock(&getBody());
533 
534  // Add a spirv.mlir.merge op into the merge block.
535  builder.create<spirv::MergeOp>(getLoc());
536 }
537 
538 SelectionOp
539 SelectionOp::createIfThen(Location loc, Value condition,
540  function_ref<void(OpBuilder &builder)> thenBody,
541  OpBuilder &builder) {
542  auto selectionOp =
543  builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
544 
545  selectionOp.addMergeBlock(builder);
546  Block *mergeBlock = selectionOp.getMergeBlock();
547  Block *thenBlock = nullptr;
548 
549  // Build the "then" block.
550  {
551  OpBuilder::InsertionGuard guard(builder);
552  thenBlock = builder.createBlock(mergeBlock);
553  thenBody(builder);
554  builder.create<spirv::BranchOp>(loc, mergeBlock);
555  }
556 
557  // Build the header block.
558  {
559  OpBuilder::InsertionGuard guard(builder);
560  builder.createBlock(thenBlock);
561  builder.create<spirv::BranchConditionalOp>(
562  loc, condition, thenBlock,
563  /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
564  /*falseArguments=*/ArrayRef<Value>());
565  }
566 
567  return selectionOp;
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // spirv.Unreachable
572 //===----------------------------------------------------------------------===//
573 
574 LogicalResult spirv::UnreachableOp::verify() {
575  auto *block = (*this)->getBlock();
576  // Fast track: if this is in entry block, its invalid. Otherwise, if no
577  // predecessors, it's valid.
578  if (block->isEntryBlock())
579  return emitOpError("cannot be used in reachable block");
580  if (block->hasNoPredecessors())
581  return success();
582 
583  // TODO: further verification needs to analyze reachability from
584  // the entry block.
585 
586  return success();
587 }
588 
589 } // namespace mlir::spirv
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
@ None
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
Definition: Block.h:31
unsigned getNumSuccessors()
Definition: Block.cpp:254
bool empty()
Definition: Block.h:146
Operation & back()
Definition: Block.h:150
Operation & front()
Definition: Block.h:151
iterator end()
Definition: Block.h:142
iterator begin()
Definition: Block.h:141
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:100
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
OperandRange operand_range
Definition: Operation.h:366
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
constexpr char kControl[]
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:147
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.