27 template <
typename EnumAttrClass,
typename EnumClass>
30 StringRef attrName = spirv::attributeName<EnumClass>()) {
34 spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
41 state.addAttribute(attrName,
42 builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
51 assert(index == 0 &&
"invalid successor index");
60 assert(index < 2 &&
"invalid successor index");
61 return SuccessorOperands(index == kTrueIndex
62 ? getTrueTargetOperandsMutable()
63 : getFalseTargetOperandsMutable());
67 OperationState &result) {
// Visible portion of the custom assembly parser. It reads, in order: the
// condition operand, an optional bracketed pair of branch weights, and then
// the true and false successors together with their forwarded operands.
// NOTE(review): several lines of this listing (including the bodies of the
// `if` statements below) are not visible here.
68 auto &builder = parser.getBuilder();
69 OpAsmParser::UnresolvedOperand condInfo;
// Parse the condition operand and resolve it against the i1 type, appending
// the resolved value to the operation's operand list.
73 Type boolTy = builder.getI1Type();
74 if (parser.parseOperand(condInfo) ||
75 parser.resolveOperand(condInfo, boolTy, result.operands))
// An opening '[' is optional; when present it must be followed by two
// attributes parsed as 32-bit integers, separated by a comma and closed
// by ']'.
79 if (
succeeded(parser.parseOptionalLSquare())) {
80 IntegerAttr trueWeight, falseWeight;
// Scratch list that receives the parsed "weight" attributes; the values
// actually stored on the op are taken from trueWeight/falseWeight below.
81 NamedAttrList weights;
83 auto i32Type = builder.getIntegerType(32);
84 if (parser.parseAttribute(trueWeight, i32Type,
"weight", weights) ||
85 parser.parseComma() ||
86 parser.parseAttribute(falseWeight, i32Type,
"weight", weights) ||
87 parser.parseRSquare())
// Record both weights as one array attribute, {true weight, false weight},
// under the op's branch-weights attribute name.
90 StringAttr branchWeightsAttrName =
91 BranchConditionalOp::getBranchWeightsAttrName(result.name);
92 result.addAttribute(branchWeightsAttrName,
93 builder.getArrayAttr({trueWeight, falseWeight}));
// True branch: a comma, then the successor block and its use list.
97 SmallVector<Value, 4> trueOperands;
98 if (parser.parseComma() ||
99 parser.parseSuccessorAndUseList(dest, trueOperands))
101 result.addSuccessors(dest);
102 result.addOperands(trueOperands);
// False branch: parsed the same way, reusing `dest`.
105 SmallVector<Value, 4> falseOperands;
106 if (parser.parseComma() ||
107 parser.parseSuccessorAndUseList(dest, falseOperands))
109 result.addSuccessors(dest);
110 result.addOperands(falseOperands);
// Operand segment sizes: 1 for the condition, followed by the number of
// operands forwarded to the true and to the false successor.
111 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
112 builder.getDenseI32ArrayAttr(
113 {1, static_cast<int32_t>(trueOperands.size()),
114 static_cast<int32_t>(falseOperands.size())}));
120 printer <<
' ' << getCondition();
122 if (
auto weights = getBranchWeights()) {
124 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
125 printer << llvm::cast<IntegerAttr>(a).getInt();
131 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
133 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
137 if (
auto weights = getBranchWeights()) {
138 if (weights->getValue().size() != 2) {
139 return emitOpError(
"must have exactly two branch weights");
141 if (llvm::all_of(*weights, [](Attribute attr) {
142 return llvm::cast<IntegerAttr>(attr).getValue().isZero();
144 return emitOpError(
"branch weights cannot both be zero");
155 auto fnName = getCalleeAttr();
157 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
160 return emitOpError(
"callee function '")
161 << fnName.getValue() <<
"' not found in nearest symbol table";
164 auto functionType = funcOp.getFunctionType();
166 if (getNumResults() > 1) {
168 "expected callee function to have 0 or 1 result, but provided ")
172 if (functionType.getNumInputs() != getNumOperands()) {
173 return emitOpError(
"has incorrect number of operands for callee: expected ")
174 << functionType.getNumInputs() <<
", but provided "
178 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
179 if (getOperand(i).getType() != functionType.getInput(i)) {
180 return emitOpError(
"operand type mismatch: expected operand type ")
181 << functionType.getInput(i) <<
", but provided "
182 << getOperand(i).getType() <<
" for operand number " << i;
186 if (functionType.getNumResults() != getNumResults()) {
188 "has incorrect number of results has for callee: expected ")
189 << functionType.getNumResults() <<
", but provided "
193 if (getNumResults() &&
194 (getResult(0).getType() != functionType.getResult(0))) {
195 return emitOpError(
"result type mismatch: expected ")
196 << functionType.getResult(0) <<
", but provided "
197 << getResult(0).getType();
/// Returns the callee of this call: the SymbolRefAttr stored on the operation
/// under the attribute name given by getCalleeAttrName().
203 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
204 return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
/// Overwrites the callee attribute (named by getCalleeAttrName()) with the
/// SymbolRefAttr extracted from `callee`.
207 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
// NOTE(review): presumably `callee` must hold a SymbolRefAttr rather than a
// Value for this get<>() to be valid — confirm against callers.
208 (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
212 return getArguments();
/// Returns the call's argument operands as a mutable range by forwarding to
/// getArgumentsMutable().
215 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
216 return getArgumentsMutable();
223 void LoopOp::build(OpBuilder &builder, OperationState &state) {
224 state.addAttribute(
"loop_control", builder.getAttr<spirv::LoopControlAttr>(
229 ParseResult
LoopOp::parse(OpAsmParser &parser, OperationState &result) {
230 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
233 return parser.parseRegion(*result.addRegion(), {});
237 auto control = getLoopControl();
239 printer <<
" control(" << spirv::stringifyLoopControl(control) <<
")";
241 printer.printRegion(getRegion(),
false,
// The source block must contain exactly one operation.
249 if (!llvm::hasSingleElement(srcBlock))
// That single operation must be a spirv::BranchOp whose successor is
// `dstBlock`; dyn_cast yields a null op for any other operation kind, which
// makes the result false.
252 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
253 return branchOp && branchOp.getSuccessor() == &dstBlock;
// True only when the block is non-empty, holds exactly one operation (the
// iterator after begin() is already end()), and that operation is a
// spirv::MergeOp.
258 return !block.
empty() && std::next(block.
begin()) == block.
end() &&
259 isa<spirv::MergeOp>(block.
front());
263 auto *op = getOperation();
290 auto &region = op->getRegion(0);
299 return emitOpError(
"last block must be the merge block with only one "
300 "'spirv.mlir.merge' op");
302 if (std::next(region.begin()) == region.end())
304 "must have an entry block branching to the loop header block");
308 if (std::next(region.begin(), 2) == region.end())
310 "must have a loop header block branched from the entry block");
312 Block &header = *std::next(region.begin(), 1);
316 "entry block must only have one 'spirv.Branch' op to the second block");
318 if (std::next(region.begin(), 3) == region.end())
320 "requires a loop continue block branching to the loop header block");
322 Block &cont = *std::prev(region.end(), 2);
328 [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
329 return emitOpError(
"second to last block must be the loop continue "
330 "block that branches to the loop header block");
334 for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
335 std::prev(region.end(), 2))) {
336 for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
337 if (block.getSuccessor(i) == &header) {
338 return emitOpError(
"can only have the entry and loop continue "
339 "block branching to the loop header block");
/// Returns the first block of the loop's body region.
/// Asserts that the body region is not empty.
347 Block *LoopOp::getEntryBlock() {
348 assert(!getBody().empty() &&
"op region should not be empty!");
349 return &getBody().front();
/// Returns the second block of the loop's body region, i.e. the block that
/// immediately follows the first one.
/// Asserts that the body region is not empty.
352 Block *LoopOp::getHeaderBlock() {
// NOTE(review): the assert only rules out an empty region, while the
// dereference below needs a second block — presumably guaranteed by the
// op's region verifier; confirm.
353 assert(!getBody().empty() &&
"op region should not be empty!");
355 return &*std::next(getBody().begin());
/// Returns the second-to-last block of the loop's body region (two positions
/// before end()).
/// Asserts that the body region is not empty.
358 Block *LoopOp::getContinueBlock() {
// NOTE(review): the assert only rules out an empty region, while stepping
// back two positions from end() needs at least two blocks — presumably
// guaranteed by the op's region verifier; confirm.
359 assert(!getBody().empty() &&
"op region should not be empty!");
361 return &*std::prev(getBody().end(), 2);
/// Returns the last block of the loop's body region.
/// Asserts that the body region is not empty.
364 Block *LoopOp::getMergeBlock() {
365 assert(!getBody().empty() &&
"op region should not be empty!");
367 return &getBody().back();
/// Populates an empty loop body with two new blocks and creates a
/// spirv::MergeOp at this op's location.
/// Asserts that the body region is empty on entry.
370 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
371 assert(getBody().empty() &&
"entry and merge block already exist");
// NOTE(review): the guard presumably restores the builder's insertion point
// when it goes out of scope, so callers see it unchanged — confirm.
372 OpBuilder::InsertionGuard g(builder);
// Two blocks are created in the body region; going by the function name
// these are the entry block and the merge block.
373 builder.createBlock(&getBody());
374 builder.createBlock(&getBody());
// NOTE(review): presumably the insertion point is in the most recently
// created block at this point, so the merge op lands in the second block —
// confirm.
377 builder.create<spirv::MergeOp>(getLoc());
385 auto *parentOp = (*this)->getParentOp();
386 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
388 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
391 Block &parentLastBlock = (*this)->getParentRegion()->back();
392 if (getOperation() != parentLastBlock.getTerminator())
393 return emitOpError(
"can only be used in the last block of "
394 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
421 if (
auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
422 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
423 if (!resultVectorTy) {
424 return emitOpError(
"result expected to be of vector type when "
425 "condition is of vector type");
427 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
428 return emitOpError(
"result should have the same number of elements as "
429 "the condition when condition is of vector type");
437 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
440 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
/// Returns the minimum SPIR-V version for this op: V_1_4 when the condition
/// has a scalar type and the result has a composite type, V_1_0 otherwise.
443 std::optional<spirv::Version> SelectOp::getMinVersion() {
// NOTE(review): presumably selecting between composite values with a scalar
// condition is only permitted from SPIR-V 1.4 onwards — confirm against the
// specification.
446 if (isa<spirv::ScalarType>(getCondition().getType()) &&
447 isa<spirv::CompositeType>(getType()))
448 return Version::V_1_4;
450 return Version::V_1_0;
/// Returns the maximum SPIR-V version for this op, which is always V_1_6.
452 std::optional<spirv::Version> SelectOp::getMaxVersion() {
453 return Version::V_1_6;
462 spirv::SelectionControl>(parser, result))
464 return parser.parseRegion(*result.addRegion(), {});
468 auto control = getSelectionControl();
470 printer <<
" control(" << spirv::stringifySelectionControl(control) <<
")";
472 printer.printRegion(getRegion(),
false,
476 LogicalResult SelectionOp::verifyRegions() {
477 auto *op = getOperation();
500 auto &region = op->getRegion(0);
508 return emitOpError(
"last block must be the merge block with only one "
509 "'spirv.mlir.merge' op");
511 if (std::next(region.begin()) == region.end())
512 return emitOpError(
"must have a selection header block");
/// Returns the first block of the selection's body region.
/// Asserts that the body region is not empty.
517 Block *SelectionOp::getHeaderBlock() {
518 assert(!getBody().empty() &&
"op region should not be empty!");
520 return &getBody().front();
/// Returns the last block of the selection's body region.
/// Asserts that the body region is not empty.
523 Block *SelectionOp::getMergeBlock() {
524 assert(!getBody().empty() &&
"op region should not be empty!");
526 return &getBody().back();
/// Populates an empty selection body with one new block and creates a
/// spirv::MergeOp at this op's location.
/// Asserts that the body region is empty on entry.
529 void SelectionOp::addMergeBlock(OpBuilder &builder) {
530 assert(getBody().empty() &&
"entry and merge block already exist");
// NOTE(review): the guard presumably restores the builder's insertion point
// when it goes out of scope, so callers see it unchanged — confirm.
531 OpBuilder::InsertionGuard guard(builder);
532 builder.createBlock(&getBody());
// NOTE(review): presumably the insertion point is in the block created
// above, so the merge op is placed inside it — confirm.
535 builder.create<spirv::MergeOp>(getLoc());
539 SelectionOp::createIfThen(Location loc, Value condition,
541 OpBuilder &builder) {
545 selectionOp.addMergeBlock(builder);
546 Block *mergeBlock = selectionOp.getMergeBlock();
547 Block *thenBlock =
nullptr;
551 OpBuilder::InsertionGuard guard(builder);
552 thenBlock = builder.createBlock(mergeBlock);
554 builder.create<spirv::BranchOp>(loc, mergeBlock);
559 OpBuilder::InsertionGuard guard(builder);
560 builder.createBlock(thenBlock);
561 builder.create<spirv::BranchConditionalOp>(
562 loc, condition, thenBlock,
563 ArrayRef<Value>(), mergeBlock,
575 auto *block = (*this)->getBlock();
578 if (block->isEntryBlock())
579 return emitOpError(
"cannot be used in reachable block");
580 if (block->hasNoPredecessors())
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
This class is a general helper class for creating context-global objects like types,...
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
OperandRange operand_range
This class models how operands are forwarded to block arguments in control flow.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
@ Type
An inlay hint that is for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kControl[]
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
llvm::function_ref< Fn > function_ref
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.