27 template <
typename EnumAttrClass,
typename EnumClass>
30 StringRef attrName = spirv::attributeName<EnumClass>()) {
34 spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
41 state.addAttribute(attrName,
42 builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
51 assert(index == 0 &&
"invalid successor index");
60 assert(index < 2 &&
"invalid successor index");
61 return SuccessorOperands(index == kTrueIndex
62 ? getTrueTargetOperandsMutable()
63 : getFalseTargetOperandsMutable());
67 OperationState &result) {
68 auto &builder = parser.getBuilder();
69 OpAsmParser::UnresolvedOperand condInfo;
73 Type boolTy = builder.getI1Type();
74 if (parser.parseOperand(condInfo) ||
75 parser.resolveOperand(condInfo, boolTy, result.operands))
79 if (
succeeded(parser.parseOptionalLSquare())) {
80 IntegerAttr trueWeight, falseWeight;
81 NamedAttrList weights;
83 auto i32Type = builder.getIntegerType(32);
84 if (parser.parseAttribute(trueWeight, i32Type,
"weight", weights) ||
85 parser.parseComma() ||
86 parser.parseAttribute(falseWeight, i32Type,
"weight", weights) ||
87 parser.parseRSquare())
91 builder.getArrayAttr({trueWeight, falseWeight}));
95 SmallVector<Value, 4> trueOperands;
96 if (parser.parseComma() ||
97 parser.parseSuccessorAndUseList(dest, trueOperands))
99 result.addSuccessors(dest);
100 result.addOperands(trueOperands);
103 SmallVector<Value, 4> falseOperands;
104 if (parser.parseComma() ||
105 parser.parseSuccessorAndUseList(dest, falseOperands))
107 result.addSuccessors(dest);
108 result.addOperands(falseOperands);
109 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
110 builder.getDenseI32ArrayAttr(
111 {1, static_cast<int32_t>(trueOperands.size()),
112 static_cast<int32_t>(falseOperands.size())}));
118 printer <<
' ' << getCondition();
120 if (
auto weights = getBranchWeights()) {
122 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
123 printer << llvm::cast<IntegerAttr>(a).getInt();
129 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
131 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
135 if (
auto weights = getBranchWeights()) {
136 if (weights->getValue().size() != 2) {
137 return emitOpError(
"must have exactly two branch weights");
139 if (llvm::all_of(*weights, [](Attribute attr) {
140 return llvm::cast<IntegerAttr>(attr).getValue().isZero();
142 return emitOpError(
"branch weights cannot both be zero");
153 auto fnName = getCalleeAttr();
155 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
158 return emitOpError(
"callee function '")
159 << fnName.getValue() <<
"' not found in nearest symbol table";
162 auto functionType = funcOp.getFunctionType();
164 if (getNumResults() > 1) {
166 "expected callee function to have 0 or 1 result, but provided ")
170 if (functionType.getNumInputs() != getNumOperands()) {
171 return emitOpError(
"has incorrect number of operands for callee: expected ")
172 << functionType.getNumInputs() <<
", but provided "
176 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
177 if (getOperand(i).getType() != functionType.getInput(i)) {
178 return emitOpError(
"operand type mismatch: expected operand type ")
179 << functionType.getInput(i) <<
", but provided "
180 << getOperand(i).getType() <<
" for operand number " << i;
184 if (functionType.getNumResults() != getNumResults()) {
186 "has incorrect number of results has for callee: expected ")
187 << functionType.getNumResults() <<
", but provided "
191 if (getNumResults() &&
192 (getResult(0).getType() != functionType.getResult(0))) {
193 return emitOpError(
"result type mismatch: expected ")
194 << functionType.getResult(0) <<
", but provided "
195 << getResult(0).getType();
201 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
202 return (*this)->getAttrOfType<SymbolRefAttr>(
kCallee);
205 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
206 (*this)->setAttr(
kCallee, callee.get<SymbolRefAttr>());
210 return getArguments();
213 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
214 return getArgumentsMutable();
221 void LoopOp::build(OpBuilder &builder, OperationState &state) {
222 state.addAttribute(
"loop_control", builder.getAttr<spirv::LoopControlAttr>(
227 ParseResult
LoopOp::parse(OpAsmParser &parser, OperationState &result) {
228 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
231 return parser.parseRegion(*result.addRegion(), {});
235 auto control = getLoopControl();
237 printer <<
" control(" << spirv::stringifyLoopControl(control) <<
")";
239 printer.printRegion(getRegion(),
false,
247 if (!llvm::hasSingleElement(srcBlock))
250 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
251 return branchOp && branchOp.getSuccessor() == &dstBlock;
256 return !block.
empty() && std::next(block.
begin()) == block.
end() &&
257 isa<spirv::MergeOp>(block.
front());
261 auto *op = getOperation();
288 auto ®ion = op->getRegion(0);
297 return emitOpError(
"last block must be the merge block with only one "
298 "'spirv.mlir.merge' op");
300 if (std::next(region.begin()) == region.end())
302 "must have an entry block branching to the loop header block");
306 if (std::next(region.begin(), 2) == region.end())
308 "must have a loop header block branched from the entry block");
310 Block &header = *std::next(region.begin(), 1);
314 "entry block must only have one 'spirv.Branch' op to the second block");
316 if (std::next(region.begin(), 3) == region.end())
318 "requires a loop continue block branching to the loop header block");
320 Block &cont = *std::prev(region.end(), 2);
326 [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
327 return emitOpError(
"second to last block must be the loop continue "
328 "block that branches to the loop header block");
332 for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
333 std::prev(region.end(), 2))) {
334 for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
335 if (block.getSuccessor(i) == &header) {
336 return emitOpError(
"can only have the entry and loop continue "
337 "block branching to the loop header block");
345 Block *LoopOp::getEntryBlock() {
346 assert(!getBody().empty() &&
"op region should not be empty!");
347 return &getBody().front();
350 Block *LoopOp::getHeaderBlock() {
351 assert(!getBody().empty() &&
"op region should not be empty!");
353 return &*std::next(getBody().begin());
356 Block *LoopOp::getContinueBlock() {
357 assert(!getBody().empty() &&
"op region should not be empty!");
359 return &*std::prev(getBody().end(), 2);
362 Block *LoopOp::getMergeBlock() {
363 assert(!getBody().empty() &&
"op region should not be empty!");
365 return &getBody().back();
368 void LoopOp::addEntryAndMergeBlock() {
369 assert(getBody().empty() &&
"entry and merge block already exist");
370 getBody().push_back(
new Block());
371 auto *mergeBlock =
new Block();
372 getBody().push_back(mergeBlock);
376 builder.create<spirv::MergeOp>(getLoc());
384 auto *parentOp = (*this)->getParentOp();
385 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
387 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
390 Block &parentLastBlock = (*this)->getParentRegion()->back();
391 if (getOperation() != parentLastBlock.getTerminator())
392 return emitOpError(
"can only be used in the last block of "
393 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
420 if (
auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
421 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
422 if (!resultVectorTy) {
423 return emitOpError(
"result expected to be of vector type when "
424 "condition is of vector type");
426 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
427 return emitOpError(
"result should have the same number of elements as "
428 "the condition when condition is of vector type");
436 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
439 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
442 std::optional<spirv::Version> SelectOp::getMinVersion() {
445 if (isa<spirv::ScalarType>(getCondition().getType()) &&
446 isa<spirv::CompositeType>(getType()))
447 return Version::V_1_4;
449 return Version::V_1_0;
451 std::optional<spirv::Version> SelectOp::getMaxVersion() {
452 return Version::V_1_6;
461 spirv::SelectionControl>(parser, result))
463 return parser.parseRegion(*result.addRegion(), {});
467 auto control = getSelectionControl();
469 printer <<
" control(" << spirv::stringifySelectionControl(control) <<
")";
471 printer.printRegion(getRegion(),
false,
475 LogicalResult SelectionOp::verifyRegions() {
476 auto *op = getOperation();
499 auto ®ion = op->getRegion(0);
507 return emitOpError(
"last block must be the merge block with only one "
508 "'spirv.mlir.merge' op");
510 if (std::next(region.begin()) == region.end())
511 return emitOpError(
"must have a selection header block");
516 Block *SelectionOp::getHeaderBlock() {
517 assert(!getBody().empty() &&
"op region should not be empty!");
519 return &getBody().front();
522 Block *SelectionOp::getMergeBlock() {
523 assert(!getBody().empty() &&
"op region should not be empty!");
525 return &getBody().back();
528 void SelectionOp::addMergeBlock() {
529 assert(getBody().empty() &&
"entry and merge block already exist");
530 auto *mergeBlock =
new Block();
531 getBody().push_back(mergeBlock);
535 builder.create<spirv::MergeOp>(getLoc());
539 SelectionOp::createIfThen(Location loc, Value condition,
541 OpBuilder &builder) {
545 selectionOp.addMergeBlock();
546 Block *mergeBlock = selectionOp.getMergeBlock();
547 Block *thenBlock =
nullptr;
551 OpBuilder::InsertionGuard guard(builder);
552 thenBlock = builder.createBlock(mergeBlock);
554 builder.create<spirv::BranchOp>(loc, mergeBlock);
559 OpBuilder::InsertionGuard guard(builder);
560 builder.createBlock(thenBlock);
561 builder.create<spirv::BranchConditionalOp>(
562 loc, condition, thenBlock,
563 ArrayRef<Value>(), mergeBlock,
575 auto *block = (*this)->getBlock();
578 if (block->isEntryBlock())
579 return emitOpError(
"cannot be used in reachable block");
580 if (block->hasNoPredecessors())
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
This class is a general helper class for creating context-global objects like types,...
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
OperandRange operand_range
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint that for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kBranchWeightAttrName[]
constexpr char kControl[]
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
llvm::function_ref< Fn > function_ref
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.