27 template <
typename EnumAttrClass,
typename EnumClass>
30 StringRef attrName = spirv::attributeName<EnumClass>()) {
34 spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
41 state.addAttribute(attrName,
42 builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
51 assert(index == 0 &&
"invalid successor index");
60 assert(index < 2 &&
"invalid successor index");
61 return SuccessorOperands(index == kTrueIndex
62 ? getTrueTargetOperandsMutable()
63 : getFalseTargetOperandsMutable());
67 OperationState &result) {
68 auto &builder = parser.getBuilder();
69 OpAsmParser::UnresolvedOperand condInfo;
73 Type boolTy = builder.getI1Type();
74 if (parser.parseOperand(condInfo) ||
75 parser.resolveOperand(condInfo, boolTy, result.operands))
79 if (succeeded(parser.parseOptionalLSquare())) {
80 IntegerAttr trueWeight, falseWeight;
81 NamedAttrList weights;
83 auto i32Type = builder.getIntegerType(32);
84 if (parser.parseAttribute(trueWeight, i32Type,
"weight", weights) ||
85 parser.parseComma() ||
86 parser.parseAttribute(falseWeight, i32Type,
"weight", weights) ||
87 parser.parseRSquare())
90 StringAttr branchWeightsAttrName =
91 BranchConditionalOp::getBranchWeightsAttrName(result.name);
92 result.addAttribute(branchWeightsAttrName,
93 builder.getArrayAttr({trueWeight, falseWeight}));
97 SmallVector<Value, 4> trueOperands;
98 if (parser.parseComma() ||
99 parser.parseSuccessorAndUseList(dest, trueOperands))
101 result.addSuccessors(dest);
102 result.addOperands(trueOperands);
105 SmallVector<Value, 4> falseOperands;
106 if (parser.parseComma() ||
107 parser.parseSuccessorAndUseList(dest, falseOperands))
109 result.addSuccessors(dest);
110 result.addOperands(falseOperands);
111 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
112 builder.getDenseI32ArrayAttr(
113 {1, static_cast<int32_t>(trueOperands.size()),
114 static_cast<int32_t>(falseOperands.size())}));
120 printer <<
' ' << getCondition();
122 if (
auto weights = getBranchWeights()) {
124 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
125 printer << llvm::cast<IntegerAttr>(a).getInt();
131 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
133 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
137 if (
auto weights = getBranchWeights()) {
138 if (weights->getValue().size() != 2) {
139 return emitOpError(
"must have exactly two branch weights");
141 if (llvm::all_of(*weights, [](Attribute attr) {
142 return llvm::cast<IntegerAttr>(attr).getValue().isZero();
144 return emitOpError(
"branch weights cannot both be zero");
155 auto fnName = getCalleeAttr();
157 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
160 return emitOpError(
"callee function '")
161 << fnName.getValue() <<
"' not found in nearest symbol table";
164 auto functionType = funcOp.getFunctionType();
166 if (getNumResults() > 1) {
168 "expected callee function to have 0 or 1 result, but provided ")
172 if (functionType.getNumInputs() != getNumOperands()) {
173 return emitOpError(
"has incorrect number of operands for callee: expected ")
174 << functionType.getNumInputs() <<
", but provided "
178 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
179 if (getOperand(i).
getType() != functionType.getInput(i)) {
180 return emitOpError(
"operand type mismatch: expected operand type ")
181 << functionType.getInput(i) <<
", but provided "
182 << getOperand(i).getType() <<
" for operand number " << i;
186 if (functionType.getNumResults() != getNumResults()) {
188 "has incorrect number of results has for callee: expected ")
189 << functionType.getNumResults() <<
", but provided "
193 if (getNumResults() &&
194 (getResult(0).
getType() != functionType.getResult(0))) {
195 return emitOpError(
"result type mismatch: expected ")
196 << functionType.getResult(0) <<
", but provided "
197 << getResult(0).getType();
203 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
204 return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
207 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
208 (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
212 return getArguments();
215 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
216 return getArgumentsMutable();
223 void LoopOp::build(OpBuilder &builder, OperationState &state) {
224 state.addAttribute(
"loop_control", builder.getAttr<spirv::LoopControlAttr>(
229 ParseResult
LoopOp::parse(OpAsmParser &parser, OperationState &result) {
230 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
233 return parser.parseRegion(*result.addRegion(), {});
237 auto control = getLoopControl();
239 printer <<
" control(" << spirv::stringifyLoopControl(control) <<
")";
241 printer.printRegion(getRegion(),
false,
249 if (!llvm::hasSingleElement(srcBlock))
252 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
253 return branchOp && branchOp.getSuccessor() == &dstBlock;
258 return !block.
empty() && std::next(block.
begin()) == block.
end() &&
259 isa<spirv::MergeOp>(block.
front());
262 LogicalResult LoopOp::verifyRegions() {
263 auto *op = getOperation();
290 auto ®ion = op->getRegion(0);
299 return emitOpError(
"last block must be the merge block with only one "
300 "'spirv.mlir.merge' op");
302 if (std::next(region.begin()) == region.end())
304 "must have an entry block branching to the loop header block");
308 if (std::next(region.begin(), 2) == region.end())
310 "must have a loop header block branched from the entry block");
312 Block &header = *std::next(region.begin(), 1);
316 "entry block must only have one 'spirv.Branch' op to the second block");
318 if (std::next(region.begin(), 3) == region.end())
320 "requires a loop continue block branching to the loop header block");
322 Block &cont = *std::prev(region.end(), 2);
328 [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
329 return emitOpError(
"second to last block must be the loop continue "
330 "block that branches to the loop header block");
334 for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
335 std::prev(region.end(), 2))) {
336 for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
337 if (block.getSuccessor(i) == &header) {
338 return emitOpError(
"can only have the entry and loop continue "
339 "block branching to the loop header block");
347 Block *LoopOp::getEntryBlock() {
348 assert(!getBody().empty() &&
"op region should not be empty!");
349 return &getBody().front();
352 Block *LoopOp::getHeaderBlock() {
353 assert(!getBody().empty() &&
"op region should not be empty!");
355 return &*std::next(getBody().begin());
358 Block *LoopOp::getContinueBlock() {
359 assert(!getBody().empty() &&
"op region should not be empty!");
361 return &*std::prev(getBody().end(), 2);
364 Block *LoopOp::getMergeBlock() {
365 assert(!getBody().empty() &&
"op region should not be empty!");
367 return &getBody().back();
370 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
371 assert(getBody().empty() &&
"entry and merge block already exist");
372 OpBuilder::InsertionGuard g(builder);
373 builder.createBlock(&getBody());
374 builder.createBlock(&getBody());
377 builder.create<spirv::MergeOp>(getLoc());
385 auto *parentOp = (*this)->getParentOp();
386 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
388 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
391 Block &parentLastBlock = (*this)->getParentRegion()->back();
392 if (getOperation() != parentLastBlock.getTerminator())
393 return emitOpError(
"can only be used in the last block of "
394 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
421 if (
auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().
getType())) {
422 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().
getType());
423 if (!resultVectorTy) {
424 return emitOpError(
"result expected to be of vector type when "
425 "condition is of vector type");
427 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
428 return emitOpError(
"result should have the same number of elements as "
429 "the condition when condition is of vector type");
437 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
440 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
443 std::optional<spirv::Version> SelectOp::getMinVersion() {
446 if (isa<spirv::ScalarType>(getCondition().
getType()) &&
447 isa<spirv::CompositeType>(
getType()))
448 return Version::V_1_4;
450 return Version::V_1_0;
452 std::optional<spirv::Version> SelectOp::getMaxVersion() {
453 return Version::V_1_6;
462 spirv::SelectionControl>(parser, result))
464 return parser.parseRegion(*result.addRegion(), {});
468 auto control = getSelectionControl();
470 printer <<
" control(" << spirv::stringifySelectionControl(control) <<
")";
472 printer.printRegion(getRegion(),
false,
476 LogicalResult SelectionOp::verifyRegions() {
477 auto *op = getOperation();
500 auto ®ion = op->getRegion(0);
508 return emitOpError(
"last block must be the merge block with only one "
509 "'spirv.mlir.merge' op");
511 if (std::next(region.begin()) == region.end())
512 return emitOpError(
"must have a selection header block");
517 Block *SelectionOp::getHeaderBlock() {
518 assert(!getBody().empty() &&
"op region should not be empty!");
520 return &getBody().front();
523 Block *SelectionOp::getMergeBlock() {
524 assert(!getBody().empty() &&
"op region should not be empty!");
526 return &getBody().back();
529 void SelectionOp::addMergeBlock(OpBuilder &builder) {
530 assert(getBody().empty() &&
"entry and merge block already exist");
531 OpBuilder::InsertionGuard guard(builder);
532 builder.createBlock(&getBody());
535 builder.create<spirv::MergeOp>(getLoc());
539 SelectionOp::createIfThen(Location loc, Value condition,
541 OpBuilder &builder) {
545 selectionOp.addMergeBlock(builder);
546 Block *mergeBlock = selectionOp.getMergeBlock();
547 Block *thenBlock =
nullptr;
551 OpBuilder::InsertionGuard guard(builder);
552 thenBlock = builder.createBlock(mergeBlock);
554 builder.create<spirv::BranchOp>(loc, mergeBlock);
559 OpBuilder::InsertionGuard guard(builder);
560 builder.createBlock(thenBlock);
561 builder.create<spirv::BranchConditionalOp>(
562 loc, condition, thenBlock,
563 ArrayRef<Value>(), mergeBlock,
575 auto *block = (*this)->getBlock();
578 if (block->isEntryBlock())
579 return emitOpError(
"cannot be used in reachable block");
580 if (block->hasNoPredecessors())
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
This class is a general helper class for creating context-global objects like types,...
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
OperandRange operand_range
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint that for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kControl[]
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
llvm::function_ref< Fn > function_ref
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.