18#include "llvm/Support/InterleavedRange.h"
29template <
typename EnumAttrClass,
typename EnumClass>
44 builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
53 assert(
index == 0 &&
"invalid successor index");
62 assert(
index < 2 &&
"invalid successor index");
64 ? getTrueTargetOperandsMutable()
65 : getFalseTargetOperandsMutable());
68ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
71 OpAsmParser::UnresolvedOperand condInfo;
82 IntegerAttr trueWeight, falseWeight;
83 NamedAttrList weights;
85 auto i32Type = builder.getIntegerType(32);
86 if (parser.
parseAttribute(trueWeight, i32Type,
"weight", weights) ||
92 StringAttr branchWeightsAttrName =
93 BranchConditionalOp::getBranchWeightsAttrName(
result.name);
94 result.addAttribute(branchWeightsAttrName,
95 builder.getArrayAttr({trueWeight, falseWeight}));
99 SmallVector<Value, 4> trueOperands;
103 result.addSuccessors(dest);
104 result.addOperands(trueOperands);
107 SmallVector<Value, 4> falseOperands;
111 result.addSuccessors(dest);
112 result.addOperands(falseOperands);
113 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
114 builder.getDenseI32ArrayAttr(
115 {1, static_cast<int32_t>(trueOperands.size()),
116 static_cast<int32_t>(falseOperands.size())}));
121void BranchConditionalOp::print(OpAsmPrinter &printer) {
122 printer <<
' ' << getCondition();
124 if (std::optional<ArrayAttr> weights = getBranchWeights()) {
126 << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
135LogicalResult BranchConditionalOp::verify() {
136 if (
auto weights = getBranchWeights()) {
137 if (weights->getValue().size() != 2) {
138 return emitOpError(
"must have exactly two branch weights");
140 if (llvm::all_of(*weights, [](Attribute attr) {
141 return llvm::cast<IntegerAttr>(attr).getValue().isZero();
143 return emitOpError(
"branch weights cannot both be zero");
153LogicalResult FunctionCallOp::verify() {
154 if (getNumResults() > 1) {
156 "expected callee function to have 0 or 1 result, but provided ")
163FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
164 auto fnName = getCalleeAttr();
170 << fnName.getValue() <<
"' not found in nearest symbol table";
173 auto functionType = funcOp.getFunctionType();
175 if (functionType.getNumInputs() != getNumOperands()) {
176 return emitOpError(
"has incorrect number of operands for callee: expected ")
177 << functionType.getNumInputs() <<
", but provided "
181 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
182 if (getOperand(i).
getType() != functionType.getInput(i)) {
183 return emitOpError(
"operand type mismatch: expected operand type ")
184 << functionType.getInput(i) <<
", but provided "
185 << getOperand(i).getType() <<
" for operand number " << i;
189 if (functionType.getNumResults() != getNumResults()) {
191 "has incorrect number of results has for callee: expected ")
192 << functionType.getNumResults() <<
", but provided "
196 if (getNumResults() &&
197 (getResult(0).
getType() != functionType.getResult(0))) {
198 return emitOpError(
"result type mismatch: expected ")
199 << functionType.getResult(0) <<
", but provided "
200 << getResult(0).getType();
206CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
207 return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
210void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
211 (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
215 return getArguments();
218MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
219 return getArgumentsMutable();
226void SwitchOp::build(OpBuilder &builder, OperationState &
result, Value selector,
228 DenseIntElementsAttr literals, BlockRange targets,
229 ArrayRef<ValueRange> targetOperands) {
230 build(builder,
result, selector, defaultOperands, targetOperands, literals,
231 defaultTarget, targets);
234void SwitchOp::build(OpBuilder &builder, OperationState &
result, Value selector,
236 ArrayRef<APInt> literals, BlockRange targets,
237 ArrayRef<ValueRange> targetOperands) {
238 DenseIntElementsAttr literalsAttr;
239 if (!literals.empty()) {
240 ShapedType literalType = VectorType::get(
241 static_cast<int64_t
>(literals.size()), selector.
getType());
244 build(builder,
result, selector, defaultTarget, defaultOperands, literalsAttr,
245 targets, targetOperands);
248void SwitchOp::build(OpBuilder &builder, OperationState &
result, Value selector,
250 ArrayRef<int32_t> literals, BlockRange targets,
251 ArrayRef<ValueRange> targetOperands) {
252 DenseIntElementsAttr literalsAttr;
253 if (!literals.empty()) {
254 ShapedType literalType = VectorType::get(
255 static_cast<int64_t
>(literals.size()), selector.
getType());
258 build(builder,
result, selector, defaultTarget, defaultOperands, literalsAttr,
259 targets, targetOperands);
262LogicalResult SwitchOp::verify() {
263 std::optional<DenseIntElementsAttr> literals = getLiterals();
264 BlockRange targets = getTargets();
266 if (!literals && targets.empty())
269 Type selectorType = getSelector().getType();
270 Type literalType = literals->getType().getElementType();
271 if (literalType != selectorType)
272 return emitOpError() <<
"'selector' type (" << selectorType
273 <<
") should match literals type (" << literalType
276 if (literals && literals->size() !=
static_cast<int64_t
>(targets.size()))
277 return emitOpError() <<
"number of literals (" << literals->size()
278 <<
") should match number of targets ("
279 << targets.size() <<
")";
283SuccessorOperands SwitchOp::getSuccessorOperands(
unsigned index) {
284 assert(index < getNumSuccessors() &&
"invalid successor index");
285 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
286 : getTargetOperandsMutable(index - 1));
289Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
290 std::optional<DenseIntElementsAttr> literals = getLiterals();
293 return getDefaultTarget();
295 SuccessorRange targets = getTargets();
296 if (
auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) {
297 for (
auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
298 if (literal == value.getValue())
299 return targets[index];
300 return getDefaultTarget();
309void LoopOp::build(OpBuilder &builder, OperationState &state) {
311 spirv::LoopControl::None));
315ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &
result) {
327void LoopOp::print(OpAsmPrinter &printer) {
328 auto control = getLoopControl();
329 if (control != spirv::LoopControl::None)
330 printer <<
" control(" << spirv::stringifyLoopControl(control) <<
")";
331 if (getNumResults() > 0) {
333 printer << getResultTypes();
344 if (!llvm::hasSingleElement(srcBlock))
347 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
348 return branchOp && branchOp.getSuccessor() == &dstBlock;
353 return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.
front());
359 return isa<spirv::MergeOp>(op) && op.getBlock() != ®ion.back();
363LogicalResult LoopOp::verifyRegions() {
364 auto *op = getOperation();
391 auto ®ion = op->getRegion(0);
400 return emitOpError(
"last block must be the merge block with only one "
401 "'spirv.mlir.merge' op");
404 "should not have 'spirv.mlir.merge' op outside the merge block");
406 if (region.hasOneBlock())
408 "must have an entry block branching to the loop header block");
412 if (std::next(region.begin(), 2) == region.end())
414 "must have a loop header block branched from the entry block");
416 Block &header = *std::next(region.begin(), 1);
420 "entry block must only have one 'spirv.Branch' op to the second block");
422 if (std::next(region.begin(), 3) == region.end())
424 "requires a loop continue block branching to the loop header block");
426 Block &cont = *std::prev(region.end(), 2);
432 [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
433 return emitOpError(
"second to last block must be the loop continue "
434 "block that branches to the loop header block");
438 for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
439 std::prev(region.end(), 2))) {
440 for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
441 if (block.getSuccessor(i) == &header) {
442 return emitOpError(
"can only have the entry and loop continue "
443 "block branching to the loop header block");
451Block *LoopOp::getEntryBlock() {
452 assert(!getBody().empty() &&
"op region should not be empty!");
453 return &getBody().front();
456Block *LoopOp::getHeaderBlock() {
457 assert(!getBody().empty() &&
"op region should not be empty!");
459 return &*std::next(getBody().begin());
462Block *LoopOp::getContinueBlock() {
463 assert(!getBody().empty() &&
"op region should not be empty!");
465 return &*std::prev(getBody().end(), 2);
468Block *LoopOp::getMergeBlock() {
469 assert(!getBody().empty() &&
"op region should not be empty!");
471 return &getBody().back();
474void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
475 assert(getBody().empty() &&
"entry and merge block already exist");
476 OpBuilder::InsertionGuard g(builder);
481 spirv::MergeOp::create(builder, getLoc());
488LogicalResult ReturnOp::verify() {
497LogicalResult ReturnValueOp::verify() {
506LogicalResult SelectOp::verify() {
507 if (
auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().
getType())) {
508 auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().
getType());
509 if (!resultVectorTy) {
510 return emitOpError(
"result expected to be of vector type when "
511 "condition is of vector type");
513 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
514 return emitOpError(
"result should have the same number of elements as "
515 "the condition when condition is of vector type");
523SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
526SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
529std::optional<spirv::Version> SelectOp::getMinVersion() {
532 if (isa<spirv::ScalarType>(getCondition().
getType()) &&
533 isa<spirv::CompositeType>(
getType()))
534 return Version::V_1_4;
536 return Version::V_1_0;
538std::optional<spirv::Version> SelectOp::getMaxVersion() {
539 return Version::V_1_6;
546ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &
result) {
548 spirv::SelectionControl>(parser,
result))
558void SelectionOp::print(OpAsmPrinter &printer) {
559 auto control = getSelectionControl();
560 if (control != spirv::SelectionControl::None)
561 printer <<
" control(" << spirv::stringifySelectionControl(control) <<
")";
562 if (getNumResults() > 0) {
564 printer << getResultTypes();
571LogicalResult SelectionOp::verifyRegions() {
572 auto *op = getOperation();
595 auto ®ion = op->getRegion(0);
603 return emitOpError(
"last block must be the merge block with only one "
604 "'spirv.mlir.merge' op");
607 "should not have 'spirv.mlir.merge' op outside the merge block");
609 if (region.hasOneBlock())
610 return emitOpError(
"must have a selection header block");
615Block *SelectionOp::getHeaderBlock() {
616 assert(!getBody().empty() &&
"op region should not be empty!");
618 return &getBody().front();
621Block *SelectionOp::getMergeBlock() {
622 assert(!getBody().empty() &&
"op region should not be empty!");
624 return &getBody().back();
627void SelectionOp::addMergeBlock(OpBuilder &builder) {
628 assert(getBody().empty() &&
"entry and merge block already exist");
629 OpBuilder::InsertionGuard guard(builder);
633 spirv::MergeOp::create(builder, getLoc());
637SelectionOp::createIfThen(Location loc, Value condition,
639 OpBuilder &builder) {
641 spirv::SelectionOp::create(builder, loc, spirv::SelectionControl::None);
643 selectionOp.addMergeBlock(builder);
644 Block *mergeBlock = selectionOp.getMergeBlock();
645 Block *thenBlock =
nullptr;
649 OpBuilder::InsertionGuard guard(builder);
652 spirv::BranchOp::create(builder, loc, mergeBlock);
657 OpBuilder::InsertionGuard guard(builder);
659 spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
672LogicalResult spirv::UnreachableOp::verify() {
673 auto *block = (*this)->getBlock();
676 if (block->isEntryBlock())
677 return emitOpError(
"cannot be used in reachable block");
678 if (block->hasNoPredecessors())
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
This class is a general helper class for creating context-global objects like types,...
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
This class models how operands are forwarded to block arguments in control flow.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Type getType() const
Return the type of this value.
constexpr char kControl[]
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
static bool hasOtherMerge(Region ®ion)
Returns true if a spirv.mlir.merge op outside the merge block.
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
constexpr StringRef attributeName()
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::function_ref< Fn > function_ref
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.