18 #include "llvm/Support/InterleavedRange.h" 
   29 template <
typename EnumAttrClass, 
typename EnumClass>
 
   32                       StringRef attrName = spirv::attributeName<EnumClass>()) {
 
   36         spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
 
   43   state.addAttribute(attrName,
 
   44                      builder.
getAttr<EnumAttrClass>(
static_cast<EnumClass
>(0)));
 
   53   assert(index == 0 && 
"invalid successor index");
 
   62   assert(index < 2 && 
"invalid successor index");
 
   63   return SuccessorOperands(index == kTrueIndex
 
   64                                ? getTrueTargetOperandsMutable()
 
   65                                : getFalseTargetOperandsMutable());
 
   69                                        OperationState &result) {
 
   70   auto &builder = parser.getBuilder();
 
   71   OpAsmParser::UnresolvedOperand condInfo;
 
   75   Type boolTy = builder.getI1Type();
 
   76   if (parser.parseOperand(condInfo) ||
 
   77       parser.resolveOperand(condInfo, boolTy, result.operands))
 
   81   if (succeeded(parser.parseOptionalLSquare())) {
 
   82     IntegerAttr trueWeight, falseWeight;
 
   83     NamedAttrList weights;
 
   85     auto i32Type = builder.getIntegerType(32);
 
   86     if (parser.parseAttribute(trueWeight, i32Type, 
"weight", weights) ||
 
   87         parser.parseComma() ||
 
   88         parser.parseAttribute(falseWeight, i32Type, 
"weight", weights) ||
 
   89         parser.parseRSquare())
 
   92     StringAttr branchWeightsAttrName =
 
   93         BranchConditionalOp::getBranchWeightsAttrName(result.name);
 
   94     result.addAttribute(branchWeightsAttrName,
 
   95                         builder.getArrayAttr({trueWeight, falseWeight}));
 
   99   SmallVector<Value, 4> trueOperands;
 
  100   if (parser.parseComma() ||
 
  101       parser.parseSuccessorAndUseList(dest, trueOperands))
 
  103   result.addSuccessors(dest);
 
  104   result.addOperands(trueOperands);
 
  107   SmallVector<Value, 4> falseOperands;
 
  108   if (parser.parseComma() ||
 
  109       parser.parseSuccessorAndUseList(dest, falseOperands))
 
  111   result.addSuccessors(dest);
 
  112   result.addOperands(falseOperands);
 
  113   result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
 
  114                       builder.getDenseI32ArrayAttr(
 
  115                           {1, static_cast<int32_t>(trueOperands.size()),
 
  116                            static_cast<int32_t>(falseOperands.size())}));
 
  122   printer << 
' ' << getCondition();
 
  124   if (std::optional<ArrayAttr> weights = getBranchWeights()) {
 
  126             << llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
 
  130   printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
 
  132   printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
 
  136   if (
auto weights = getBranchWeights()) {
 
  137     if (weights->getValue().size() != 2) {
 
  138       return emitOpError(
"must have exactly two branch weights");
 
  140     if (llvm::all_of(*weights, [](Attribute attr) {
 
  141           return llvm::cast<IntegerAttr>(attr).getValue().isZero();
 
  143       return emitOpError(
"branch weights cannot both be zero");
 
  154   if (getNumResults() > 1) {
 
  156                "expected callee function to have 0 or 1 result, but provided ")
 
  163 FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
  164   auto fnName = getCalleeAttr();
 
  167       symbolTable.lookupNearestSymbolFrom<spirv::FuncOp>(*
this, fnName);
 
  169     return emitOpError(
"callee function '")
 
  170            << fnName.getValue() << 
"' not found in nearest symbol table";
 
  173   auto functionType = funcOp.getFunctionType();
 
  175   if (functionType.getNumInputs() != getNumOperands()) {
 
  176     return emitOpError(
"has incorrect number of operands for callee: expected ")
 
  177            << functionType.getNumInputs() << 
", but provided " 
  181   for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
 
  182     if (getOperand(i).
getType() != functionType.getInput(i)) {
 
  183       return emitOpError(
"operand type mismatch: expected operand type ")
 
  184              << functionType.getInput(i) << 
", but provided " 
  185              << getOperand(i).getType() << 
" for operand number " << i;
 
  189   if (functionType.getNumResults() != getNumResults()) {
 
  191                "has incorrect number of results has for callee: expected ")
 
  192            << functionType.getNumResults() << 
", but provided " 
  196   if (getNumResults() &&
 
  197       (getResult(0).
getType() != functionType.getResult(0))) {
 
  198     return emitOpError(
"result type mismatch: expected ")
 
  199            << functionType.getResult(0) << 
", but provided " 
  200            << getResult(0).getType();
 
  206 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
 
  207   return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
 
  210 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 
  211   (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
 
  215   return getArguments();
 
  218 MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
 
  219   return getArgumentsMutable();
 
  226 void LoopOp::build(OpBuilder &builder, OperationState &state) {
 
  227   state.addAttribute(
"loop_control", builder.getAttr<spirv::LoopControlAttr>(
 
  232 ParseResult 
LoopOp::parse(OpAsmParser &parser, OperationState &result) {
 
  233   if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
 
  237   if (succeeded(parser.parseOptionalArrow()))
 
  238     if (parser.parseTypeList(result.types))
 
  241   return parser.parseRegion(*result.addRegion(), {});
 
  245   auto control = getLoopControl();
 
  247     printer << 
" control(" << spirv::stringifyLoopControl(control) << 
")";
 
  248   if (getNumResults() > 0) {
 
  250     printer << getResultTypes();
 
  253   printer.printRegion(getRegion(), 
false,
 
  261   if (!llvm::hasSingleElement(srcBlock))
 
  264   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.
back());
 
  265   return branchOp && branchOp.getSuccessor() == &dstBlock;
 
  270   return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.
front());
 
  276     return isa<spirv::MergeOp>(op) && op.getBlock() != ®ion.back();
 
  280 LogicalResult LoopOp::verifyRegions() {
 
  281   auto *op = getOperation();
 
  308   auto ®ion = op->getRegion(0);
 
  317     return emitOpError(
"last block must be the merge block with only one " 
  318                        "'spirv.mlir.merge' op");
 
  321         "should not have 'spirv.mlir.merge' op outside the merge block");
 
  323   if (region.hasOneBlock())
 
  325         "must have an entry block branching to the loop header block");
 
  329   if (std::next(region.begin(), 2) == region.end())
 
  331         "must have a loop header block branched from the entry block");
 
  333   Block &header = *std::next(region.begin(), 1);
 
  337         "entry block must only have one 'spirv.Branch' op to the second block");
 
  339   if (std::next(region.begin(), 3) == region.end())
 
  341         "requires a loop continue block branching to the loop header block");
 
  343   Block &cont = *std::prev(region.end(), 2);
 
  349           [&](
unsigned index) { return cont.getSuccessor(index) == &header; }))
 
  350     return emitOpError(
"second to last block must be the loop continue " 
  351                        "block that branches to the loop header block");
 
  355   for (
auto &block : llvm::make_range(std::next(region.begin(), 2),
 
  356                                       std::prev(region.end(), 2))) {
 
  357     for (
auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
 
  358       if (block.getSuccessor(i) == &header) {
 
  359         return emitOpError(
"can only have the entry and loop continue " 
  360                            "block branching to the loop header block");
 
  368 Block *LoopOp::getEntryBlock() {
 
  369   assert(!getBody().empty() && 
"op region should not be empty!");
 
  370   return &getBody().front();
 
  373 Block *LoopOp::getHeaderBlock() {
 
  374   assert(!getBody().empty() && 
"op region should not be empty!");
 
  376   return &*std::next(getBody().begin());
 
  379 Block *LoopOp::getContinueBlock() {
 
  380   assert(!getBody().empty() && 
"op region should not be empty!");
 
  382   return &*std::prev(getBody().end(), 2);
 
  385 Block *LoopOp::getMergeBlock() {
 
  386   assert(!getBody().empty() && 
"op region should not be empty!");
 
  388   return &getBody().back();
 
  391 void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
 
  392   assert(getBody().empty() && 
"entry and merge block already exist");
 
  393   OpBuilder::InsertionGuard g(builder);
 
  394   builder.createBlock(&getBody());
 
  395   builder.createBlock(&getBody());
 
  398   spirv::MergeOp::create(builder, getLoc());
 
  424   if (
auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().
getType())) {
 
  425     auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().
getType());
 
  426     if (!resultVectorTy) {
 
  427       return emitOpError(
"result expected to be of vector type when " 
  428                          "condition is of vector type");
 
  430     if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
 
  431       return emitOpError(
"result should have the same number of elements as " 
  432                          "the condition when condition is of vector type");
 
  440 SmallVector<ArrayRef<spirv::Extension>, 1> SelectOp::getExtensions() {
 
  443 SmallVector<ArrayRef<spirv::Capability>, 1> SelectOp::getCapabilities() {
 
  449   if (isa<spirv::ScalarType>(getCondition().
getType()) &&
 
  450       isa<spirv::CompositeType>(
getType()))
 
  451     return Version::V_1_4;
 
  453   return Version::V_1_0;
 
  455 std::optional<spirv::Version> SelectOp::getMaxVersion() {
 
  456   return Version::V_1_6;
 
  465                             spirv::SelectionControl>(parser, result))
 
  468   if (succeeded(parser.parseOptionalArrow()))
 
  469     if (parser.parseTypeList(result.types))
 
  472   return parser.parseRegion(*result.addRegion(), {});
 
  476   auto control = getSelectionControl();
 
  478     printer << 
" control(" << spirv::stringifySelectionControl(control) << 
")";
 
  479   if (getNumResults() > 0) {
 
  481     printer << getResultTypes();
 
  484   printer.printRegion(getRegion(), 
false,
 
  488 LogicalResult SelectionOp::verifyRegions() {
 
  489   auto *op = getOperation();
 
  512   auto ®ion = op->getRegion(0);
 
  520     return emitOpError(
"last block must be the merge block with only one " 
  521                        "'spirv.mlir.merge' op");
 
  524         "should not have 'spirv.mlir.merge' op outside the merge block");
 
  526   if (region.hasOneBlock())
 
  527     return emitOpError(
"must have a selection header block");
 
  532 Block *SelectionOp::getHeaderBlock() {
 
  533   assert(!getBody().empty() && 
"op region should not be empty!");
 
  535   return &getBody().front();
 
  538 Block *SelectionOp::getMergeBlock() {
 
  539   assert(!getBody().empty() && 
"op region should not be empty!");
 
  541   return &getBody().back();
 
  544 void SelectionOp::addMergeBlock(OpBuilder &builder) {
 
  545   assert(getBody().empty() && 
"entry and merge block already exist");
 
  546   OpBuilder::InsertionGuard guard(builder);
 
  547   builder.createBlock(&getBody());
 
  550   spirv::MergeOp::create(builder, getLoc());
 
  554 SelectionOp::createIfThen(Location loc, Value condition,
 
  556                           OpBuilder &builder) {
 
  560   selectionOp.addMergeBlock(builder);
 
  561   Block *mergeBlock = selectionOp.getMergeBlock();
 
  562   Block *thenBlock = 
nullptr;
 
  566     OpBuilder::InsertionGuard guard(builder);
 
  567     thenBlock = builder.createBlock(mergeBlock);
 
  569     spirv::BranchOp::create(builder, loc, mergeBlock);
 
  574     OpBuilder::InsertionGuard guard(builder);
 
  575     builder.createBlock(thenBlock);
 
  576     spirv::BranchConditionalOp::create(builder, loc, condition, thenBlock,
 
  590   auto *block = (*this)->getBlock();
 
  593   if (block->isEntryBlock())
 
  594     return emitOpError(
"cannot be used in reachable block");
 
  595   if (block->hasNoPredecessors())
 
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseLParen()=0
Parse a ( token.
Block represents an ordered list of Operations.
unsigned getNumSuccessors()
This class is a general helper class for creating context-global objects like types,...
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
This class models how operands are forwarded to block arguments in control flow.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kControl[]
static bool hasOtherMerge(Region &region)
Returns true if there is a spirv.mlir.merge op outside the merge block.
static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock)
Returns true if the given srcBlock contains only one spirv.Branch to the given dstBlock.
static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName=spirv::attributeName< EnumClass >())
Parses Function, Selection and Loop control attributes.
static bool isMergeBlock(Block &block)
Returns true if the given block only contains one spirv.mlir.merge op.
TosaSpecificationVersion getMinVersion(const Profile &profile)
llvm::function_ref< Fn > function_ref
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.