18 #include "llvm/Support/InterleavedRange.h"
29 template <
typename EnumAttrClass,
typename EnumClass>
32 StringRef attrName = spirv::attributeName<EnumClass>()) {
36 spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
43 state.addAttribute(attrName,
44 builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
53 assert(index == 0 &&
"invalid successor index");
62 assert(index < 2 &&
"invalid successor index");
63 return SuccessorOperands(index == kTrueIndex
64 ? getTrueTargetOperandsMutable()
65 : getFalseTargetOperandsMutable());
69 OperationState &result) {
70 auto &builder = parser.getBuilder();
71 OpAsmParser::UnresolvedOperand condInfo;
75 Type boolTy = builder.getI1Type();
76 if (parser.parseOperand(condInfo) ||
77 parser.resolveOperand(condInfo, boolTy, result.operands))
81 if (succeeded(parser.parseOptionalLSquare())) {
82 IntegerAttr trueWeight, falseWeight;
83 NamedAttrList weights;
85 auto i32Type = builder.getIntegerType(32);
86 if (parser.parseAttribute(trueWeight, i32Type,
"weight", weights) ||
87 parser.parseComma() ||
88 parser.parseAttribute(falseWeight, i32Type,
"weight", weights) ||
89 parser.parseRSquare())
92 StringAttr branchWeightsAttrName =
93 BranchConditionalOp::getBranchWeightsAttrName(result.name);
94 result.addAttribute(branchWeightsAttrName,
95 builder.getArrayAttr({trueWeight, falseWeight}));
99 SmallVector<Value, 4> trueOperands;
100 if (parser.parseComma() ||
101 parser.parseSuccessorAndUseList(dest, trueOperands))
103 result.addSuccessors(dest);
104 result.addOperands(trueOperands);
107 SmallVector<Value, 4> falseOperands;
108 if (parser.parseComma() ||
109 parser.parseSuccessorAndUseList(dest, falseOperands))
111 result.addSuccessors(dest);
112 result.addOperands(falseOperands);
113 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
114 builder.getDenseI32ArrayAttr(
115 {1, static_cast<int32_t>(trueOperands.size()),
116 static_cast<int32_t>(falseOperands.size())}));
122 printer <<
' ' << getCondition();
124 if (std::optional<ArrayAttr> weights = getBranchWeights()) {
126 << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
130 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
132 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
136 if (
auto weights = getBranchWeights()) {
137 if (weights->getValue().size() != 2) {
138 return emitOpError(
"must have exactly two branch weights");
140 if (llvm::all_of(*weights, [](Attribute attr) {
141 return llvm::cast<IntegerAttr>(attr).getValue().isZero();
143 return emitOpError(
"branch weights cannot both be zero");
154 if (getNumResults() > 1) {
156 "expected callee function to have 0 or 1 result, but provided ")
163 FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
164 auto fnName = getCalleeAttr();
167 symbolTable.lookupNearestSymbolFrom<spirv::FuncOp>(*
this, fnName);
169 return emitOpError(
"callee function '")
170 << fnName.getValue() <<
"' not found in nearest symbol table";
173 auto functionType = funcOp.getFunctionType();
175 if (functionType.getNumInputs() != getNumOperands()) {
176 return emitOpError(
"has incorrect number of operands for callee: expected ")
177 << functionType.getNumInputs() <<
", but provided "
181 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
182 if (getOperand(i).
getType() != functionType.getInput(i)) {
183 return emitOpError(
"operand type mismatch: expected operand type ")
184 << functionType.getInput(i) <<
", but provided "
185 << getOperand(i).getType() <<
" for operand number " << i;
189 if (functionType.getNumResults() != getNumResults()) {
191 "has incorrect number of results has for callee: expected ")
192 << functionType.getNumResults() <<
", but provided "
196 if (getNumResults() &&
197 (getResult(0).
getType() != functionType.getResult(0))) {
198 return emitOpError(
"result type mismatch: expected ")
199 << functionType.getResult(0) <<
", but provided "
200 << getResult(0).getType();
206 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
207 return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
210 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
211 (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
215 return getArguments();
218 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
219 return getArgumentsMutable();
226 void LoopOp::build(OpBuilder &builder, OperationState &state) {
227 state.addAttribute(
"loop_control", builder.getAttr<spirv::LoopControlAttr>(
232 ParseResult
LoopOp::parse(OpAsmParser &parser, OperationState &result) {
233 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
237 if (succeeded(parser.parseOptionalArrow()))
238 if (parser.parseTypeList(result.types))
241 return parser.parseRegion(*result.addRegion(), {});
245 auto control = getLoopControl();
247 printer <<
" control(" << spirv::stringifyLoopControl(control) <<
")";
248 if (getNumResults() > 0) {
250 printer << getResultTypes();
253 printer.printRegion(getRegion(),
false,
261 if (!llvm::hasSingleElement(srcBlock))
264 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
265 return branchOp && branchOp.getSuccessor() == &dstBlock;
270 return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.
front());
276 return isa<spirv::MergeOp>(op) && op.getBlock() != ®ion.back();
280 LogicalResult LoopOp::verifyRegions() {
281 auto *op = getOperation();
308 auto ®ion = op->getRegion(0);
317 return emitOpError(
"last block must be the merge block with only one "
318 "'spirv.mlir.merge' op");
321 "should not have 'spirv.mlir.merge' op outside the merge block");
323 if (region.hasOneBlock())
325 "must have an entry block branching to the loop header block");
329 if (std::next(region.begin(), 2) == region.end())
331 "must have a loop header block branched from the entry block");
333 Block &header = *std::next(region.begin(), 1);
337 "entry block must only have one 'spirv.Branch' op to the second block");
339 if (std::next(region.begin(), 3) == region.end())
341 "requires a loop continue block branching to the loop header block");
343 Block &cont = *std::prev(region.end(), 2);
349 [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
350 return emitOpError(
"second to last block must be the loop continue "
351 "block that branches to the loop header block");
355 for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
356 std::prev(region.end(), 2))) {
357 for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
358 if (block.getSuccessor(i) == &header) {
359 return emitOpError(
"can only have the entry and loop continue "
360 "block branching to the loop header block");
368 Block *LoopOp::getEntryBlock() {
369 assert(!getBody().empty() &&
"op region should not be empty!");
370 return &getBody().front();
373 Block *LoopOp::getHeaderBlock() {
374 assert(!getBody().empty() &&
"op region should not be empty!");
376 return &*std::next(getBody().begin());
379 Block *LoopOp::getContinueBlock() {
380 assert(!getBody().empty() &&
"op region should not be empty!");
382 return &*std::prev(getBody().end(), 2);
385 Block *LoopOp::getMergeBlock() {
386 assert(!getBody().empty() &&
"op region should not be empty!");
388 return &getBody().back();
391 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
392 assert(getBody().empty() &&
"entry and merge block already exist");
393 OpBuilder::InsertionGuard g(builder);
394 builder.createBlock(&getBody());
395 builder.createBlock(&getBody());
398 spirv::MergeOp::create(builder, getLoc());
424 if (
auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().
getType())) {
425 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().
getType());
426 if (!resultVectorTy) {
427 return emitOpError(
"result expected to be of vector type when "
428 "condition is of vector type");
430 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
431 return emitOpError(
"result should have the same number of elements as "
432 "the condition when condition is of vector type");
440 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
443 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
446 std::optional<spirv::Version> SelectOp::getMinVersion() {
449 if (isa<spirv::ScalarType>(getCondition().
getType()) &&
450 isa<spirv::CompositeType>(
getType()))
451 return Version::V_1_4;
453 return Version::V_1_0;
455 std::optional<spirv::Version> SelectOp::getMaxVersion() {
456 return Version::V_1_6;
465 spirv::SelectionControl>(parser, result))
468 if (succeeded(parser.parseOptionalArrow()))
469 if (parser.parseTypeList(result.types))
472 return parser.parseRegion(*result.addRegion(), {});
476 auto control = getSelectionControl();
478 printer <<
" control(" << spirv::stringifySelectionControl(control) <<
")";
479 if (getNumResults() > 0) {
481 printer << getResultTypes();
484 printer.printRegion(getRegion(),
false,
488 LogicalResult SelectionOp::verifyRegions() {
489 auto *op = getOperation();
512 auto ®ion = op->getRegion(0);
520 return emitOpError(
"last block must be the merge block with only one "
521 "'spirv.mlir.merge' op");
524 "should not have 'spirv.mlir.merge' op outside the merge block");
526 if (region.hasOneBlock())
527 return emitOpError(
"must have a selection header block");
532 Block *SelectionOp::getHeaderBlock() {
533 assert(!getBody().empty() &&
"op region should not be empty!");
535 return &getBody().front();
538 Block *SelectionOp::getMergeBlock() {
539 assert(!getBody().empty() &&
"op region should not be empty!");
541 return &getBody().back();
544 void SelectionOp::addMergeBlock(OpBuilder &builder) {
545 assert(getBody().empty() &&
"entry and merge block already exist");
546 OpBuilder::InsertionGuard guard(builder);
547 builder.createBlock(&getBody());
550 spirv::MergeOp::create(builder, getLoc());
554 SelectionOp::createIfThen(Location loc, Value condition,
556 OpBuilder &builder) {
560 selectionOp.addMergeBlock(builder);
561 Block *mergeBlock = selectionOp.getMergeBlock();
562 Block *thenBlock =
nullptr;
566 OpBuilder::InsertionGuard guard(builder);
567 thenBlock = builder.createBlock(mergeBlock);
569 spirv::BranchOp::create(builder, loc, mergeBlock);
574 OpBuilder::InsertionGuard guard(builder);
575 builder.createBlock(thenBlock);
576 spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
590 auto *block = (*this)->getBlock();
593 if (block->isEntryBlock())
594 return emitOpError(
"cannot be used in reachable block");
595 if (block->hasNoPredecessors())
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
This class is a general helper class for creating context-global objects like types,...
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
This class models how operands are forwarded to block arguments in control flow.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kControl[]
static bool hasOtherMerge(Region ®ion)
Returns true if a spirv.mlir.merge op outside the merge block.
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
llvm::function_ref< Fn > function_ref
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.