31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
63 value = integerValueAttr.getSInt();
70 spirv::MemorySemantics memorySemantics) {
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
83 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
102 if (descriptorSet && binding) {
105 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
113 printer <<
" " << builtInName <<
"(\"" <<
builtin.getValue() <<
"\")";
114 elidedAttrs.push_back(builtInName);
132 auto fnType = dyn_cast<FunctionType>(type);
134 parser.
emitError(loc,
"expected function type");
139 result.addTypes(fnType.getResults());
150 assert(op->
getNumResults() == 1 &&
"op should have one result");
156 [&](
Type type) { return type != resultType; })) {
165 p <<
" : " << resultType;
168template <
typename BlockReadWriteOpTy>
172 if (
auto valVecTy = dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
175 if (valType != cast<spirv::PointerType>(
ptr.getType()).getPointeeType()) {
176 return op.emitOpError(
"mismatch in result type and pointer type");
188 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
193 if (
auto cType = dyn_cast<spirv::CompositeType>(type)) {
194 if (cType.hasCompileTimeKnownNumElements() &&
196 static_cast<uint64_t
>(
index) >= cType.getNumElements())) {
197 emitErrorFn(
"index ") <<
index <<
" out of bounds for " << type;
200 type = cType.getElementType(
index);
202 emitErrorFn(
"cannot extract from non-composite type ")
203 << type <<
" with index " <<
index;
213 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
indices);
214 if (!indicesArrayAttr) {
215 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
218 if (indicesArrayAttr.empty()) {
219 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
224 for (
auto indexAttr : indicesArrayAttr) {
225 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
227 emitErrorFn(
"expected an 32-bit integer for index, but found '")
231 indexVals.push_back(indexIntAttr.getInt());
238 return ::mlir::emitError(loc, err);
251template <
typename ExtendedBinaryOp>
253 auto resultType = cast<spirv::StructType>(op.getType());
254 if (resultType.getNumElements() != 2)
255 return op.emitOpError(
"expected result struct type containing two members");
257 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
258 resultType.getElementType(0),
259 resultType.getElementType(1)}))
260 return op.emitOpError(
261 "expected all operand types and struct member types are the same");
278 auto structType = dyn_cast<spirv::StructType>(resultType);
279 if (!structType || structType.getNumElements() != 2)
280 return parser.
emitError(loc,
"expected spirv.struct type with two members");
286 result.addTypes(resultType);
300 return op->
emitError(
"expected the same type for the first operand and "
301 "result, but provided ")
313 spirv::GlobalVariableOp var) {
314 build(builder, state, var.getType(), SymbolRefAttr::get(var));
317LogicalResult spirv::AddressOfOp::verify() {
318 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
322 return emitOpError(
"expected spirv.GlobalVariable symbol");
324 if (getPointer().
getType() != varOp.getType()) {
326 "result type mismatch with the referenced global variable's type");
335LogicalResult spirv::CompositeConstructOp::verify() {
336 operand_range constituents = this->getConstituents();
351 if (coopElementType) {
352 if (constituents.size() != 1)
353 return emitOpError(
"has incorrect number of operands: expected ")
354 <<
"1, but provided " << constituents.size();
355 if (coopElementType != constituents.front().getType())
356 return emitOpError(
"operand type mismatch: expected operand type ")
357 << coopElementType <<
", but provided "
358 << constituents.front().getType();
363 auto cType = cast<spirv::CompositeType>(
getType());
364 if (constituents.size() == cType.getNumElements()) {
365 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
367 return emitOpError(
"operand type mismatch: expected operand type ")
368 << cType.getElementType(
index) <<
", but provided "
369 << constituents[
index].getType();
376 auto resultType = dyn_cast<VectorType>(cType);
379 "expected to return a vector or cooperative matrix when the number of "
380 "constituents is less than what the result needs");
383 for (
Value component : constituents) {
384 if (!isa<VectorType>(component.getType()) &&
385 !component.getType().isIntOrFloat())
386 return emitOpError(
"operand type mismatch: expected operand to have "
387 "a scalar or vector type, but provided ")
388 << component.getType();
390 Type elementType = component.getType();
391 if (
auto vectorType = dyn_cast<VectorType>(component.getType())) {
392 sizes.push_back(vectorType.getNumElements());
393 elementType = vectorType.getElementType();
398 if (elementType != resultType.getElementType())
399 return emitOpError(
"operand element type mismatch: expected to be ")
400 << resultType.getElementType() <<
", but provided " << elementType;
402 unsigned totalCount = llvm::sum_of(sizes);
403 if (totalCount != cType.getNumElements())
404 return emitOpError(
"has incorrect number of operands: expected ")
405 << cType.getNumElements() <<
", but provided " << totalCount;
422 build(builder, state, elementType, composite, indexAttr);
425ParseResult spirv::CompositeExtractOp::parse(
OpAsmParser &parser,
429 StringRef indicesAttrName =
430 spirv::CompositeExtractOp::getIndicesAttrName(
result.name);
447 result.addTypes(resultType);
451void spirv::CompositeExtractOp::print(
OpAsmPrinter &printer) {
452 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
456LogicalResult spirv::CompositeExtractOp::verify() {
457 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
getIndices());
464 return emitOpError(
"invalid result type: expected ")
465 << resultType <<
" but provided " <<
getType();
479 build(builder, state, composite.
getType(),
object, composite, indexAttr);
482ParseResult spirv::CompositeInsertOp::parse(
OpAsmParser &parser,
485 Type objectType, compositeType;
487 StringRef indicesAttrName =
488 spirv::CompositeInsertOp::getIndicesAttrName(
result.name);
501LogicalResult spirv::CompositeInsertOp::verify() {
502 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
getIndices());
508 if (objectType != getObject().
getType()) {
509 return emitOpError(
"object operand type should be ")
510 << objectType <<
", but found " << getObject().getType();
514 return emitOpError(
"result type should be the same as "
515 "the composite type, but found ")
516 << getComposite().getType() <<
" vs " <<
getType();
522void spirv::CompositeInsertOp::print(
OpAsmPrinter &printer) {
523 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
524 <<
" : " << getObject().
getType() <<
" into "
525 << getComposite().getType();
532ParseResult spirv::ConstantOp::parse(
OpAsmParser &parser,
535 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(
result.name);
540 if (
auto typedAttr = dyn_cast<TypedAttr>(value))
541 type = typedAttr.getType();
542 if (isa<NoneType, TensorType>(type)) {
547 if (isa<TensorArmType>(type)) {
557 printer <<
' ' << getValue();
558 if (isa<spirv::ArrayType, spirv::StructType>(
getType()))
564 if (isa<spirv::CooperativeMatrixType>(opType)) {
565 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
566 if (!denseAttr || !denseAttr.isSplat())
567 return op.emitOpError(
"expected a splat dense attribute for cooperative "
568 "matrix constant, but found ")
571 if (isa<IntegerAttr, FloatAttr>(value)) {
572 auto valueType = cast<TypedAttr>(value).getType();
573 if (valueType != opType)
574 return op.emitOpError(
"result type (")
575 << opType <<
") does not match value type (" << valueType <<
")";
578 if (isa<DenseTypedElementsAttr, SparseElementsAttr>(value)) {
579 auto valueType = cast<TypedAttr>(value).getType();
580 if (valueType == opType)
582 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
583 auto shapedType = dyn_cast<ShapedType>(valueType);
585 return op.emitOpError(
"result or element type (")
586 << opType <<
") does not match value type (" << valueType
587 <<
"), must be the same or spirv.array";
589 int numElements = arrayType.getNumElements();
590 auto opElemType = arrayType.getElementType();
591 while (
auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
592 numElements *= t.getNumElements();
593 opElemType = t.getElementType();
595 if (!opElemType.isIntOrFloat())
596 return op.emitOpError(
"only support nested array result type");
598 auto valueElemType = shapedType.getElementType();
599 if (valueElemType != opElemType) {
600 return op.emitOpError(
"result element type (")
601 << opElemType <<
") does not match value element type ("
602 << valueElemType <<
")";
605 if (numElements != shapedType.getNumElements()) {
606 return op.emitOpError(
"result number of elements (")
607 << numElements <<
") does not match value number of elements ("
608 << shapedType.getNumElements() <<
")";
612 if (
auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
613 if (
auto structType = dyn_cast<spirv::StructType>(opType)) {
615 if (structType.isIdentified())
616 return op.emitOpError(
617 "cannot have an identified struct as a constant type");
618 if (arrayAttr.size() != structType.getNumElements())
619 return op.emitOpError(
"number of constituents (")
621 <<
") does not match number of struct members ("
622 << structType.getNumElements() <<
")";
623 for (
auto [idx, element] : llvm::enumerate(arrayAttr.getValue())) {
625 structType.getElementType(idx))))
630 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
632 return op.emitOpError(
633 "must have spirv.array or spirv.struct result type for array value");
634 Type elemType = arrayType.getElementType();
635 for (
Attribute element : arrayAttr.getValue()) {
642 return op.emitOpError(
"cannot have attribute: ") << value;
645LogicalResult spirv::ConstantOp::verify() {
652bool spirv::ConstantOp::isBuildableWith(
Type type) {
654 if (!isa<spirv::SPIRVType>(type))
658 if (
auto structType = dyn_cast<spirv::StructType>(type))
659 return !structType.isIdentified();
660 return isa<spirv::ArrayType>(type);
666spirv::ConstantOp spirv::ConstantOp::getZero(
Type type,
Location loc,
668 if (
auto intType = dyn_cast<IntegerType>(type)) {
669 unsigned width = intType.getWidth();
671 return spirv::ConstantOp::create(builder, loc, type,
673 return spirv::ConstantOp::create(
674 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 0)));
676 if (
auto floatType = dyn_cast<FloatType>(type)) {
677 return spirv::ConstantOp::create(builder, loc, type,
680 if (
auto vectorType = dyn_cast<VectorType>(type)) {
681 Type elemType = vectorType.getElementType();
682 if (isa<IntegerType>(elemType)) {
683 return spirv::ConstantOp::create(
686 IntegerAttr::get(elemType, 0).getValue()));
688 if (isa<FloatType>(elemType)) {
689 return spirv::ConstantOp::create(
692 FloatAttr::get(elemType, 0.0).getValue()));
696 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
699spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
701 if (
auto intType = dyn_cast<IntegerType>(type)) {
702 unsigned width = intType.getWidth();
704 return spirv::ConstantOp::create(builder, loc, type,
706 return spirv::ConstantOp::create(
707 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 1)));
709 if (
auto floatType = dyn_cast<FloatType>(type)) {
710 return spirv::ConstantOp::create(builder, loc, type,
713 if (
auto vectorType = dyn_cast<VectorType>(type)) {
714 Type elemType = vectorType.getElementType();
715 if (isa<IntegerType>(elemType)) {
716 return spirv::ConstantOp::create(
719 IntegerAttr::get(elemType, 1).getValue()));
721 if (isa<FloatType>(elemType)) {
722 return spirv::ConstantOp::create(
725 FloatAttr::get(elemType, 1.0).getValue()));
729 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
732void mlir::spirv::ConstantOp::getAsmResultNames(
737 llvm::raw_svector_ostream specialName(specialNameBuffer);
738 specialName <<
"cst";
740 IntegerType intTy = dyn_cast<IntegerType>(type);
742 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
745 if (intTy.getWidth() == 1) {
746 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
749 if (intTy.isSignless()) {
750 specialName << intCst.getInt();
751 }
else if (intTy.isUnsigned()) {
752 specialName << intCst.getUInt();
754 specialName << intCst.getSInt();
758 if (intTy || isa<FloatType>(type)) {
759 specialName <<
'_' << type;
762 if (
auto vecType = dyn_cast<VectorType>(type)) {
763 specialName <<
"_vec_";
764 specialName << vecType.getDimSize(0);
766 Type elementType = vecType.getElementType();
768 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
769 specialName <<
"x" << elementType;
773 setNameFn(getResult(), specialName.str());
776void mlir::spirv::AddressOfOp::getAsmResultNames(
779 llvm::raw_svector_ostream specialName(specialNameBuffer);
780 specialName << getVariable() <<
"_addr";
781 setNameFn(getResult(), specialName.str());
792 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
793 return typedAttr.getType();
796 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
803LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
806 return emitError(
"unknown value attribute type");
808 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
810 return emitError(
"result type is not a composite type");
812 Type compositeElementType = compositeType.getElementType(0);
815 while (
auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
816 compositeElementType = type.getElementType(0);
817 possibleTypes.push_back(compositeElementType);
820 if (!is_contained(possibleTypes, valueType)) {
821 return emitError(
"expected value attribute type ")
822 << interleaved(possibleTypes,
" or ") <<
", but got: " << valueType;
832LogicalResult spirv::ControlBarrierOp::verify() {
841 spirv::ExecutionModel executionModel,
842 spirv::FuncOp function,
844 build(builder, state,
845 spirv::ExecutionModelAttr::get(builder.
getContext(), executionModel),
846 SymbolRefAttr::get(function), builder.
getArrayAttr(interfaceVars));
849ParseResult spirv::EntryPointOp::parse(
OpAsmParser &parser,
851 spirv::ExecutionModel execModel;
864 FlatSymbolRefAttr var;
866 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
868 interfaceVars.push_back(var);
873 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(
result.name),
881 auto interfaceVars = getInterface().getValue();
882 if (!interfaceVars.empty())
883 printer <<
", " << llvm::interleaved(interfaceVars);
886LogicalResult spirv::EntryPointOp::verify() {
897 spirv::FuncOp function,
898 spirv::ExecutionMode executionMode,
900 build(builder, state, SymbolRefAttr::get(function),
901 spirv::ExecutionModeAttr::get(builder.
getContext(), executionMode),
905ParseResult spirv::ExecutionModeOp::parse(
OpAsmParser &parser,
907 spirv::ExecutionMode execMode;
922 values.push_back(cast<IntegerAttr>(value).getInt());
924 StringRef valuesAttrName =
925 spirv::ExecutionModeOp::getValuesAttrName(
result.name);
926 result.addAttribute(valuesAttrName,
931void spirv::ExecutionModeOp::print(
OpAsmPrinter &printer) {
934 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
937 printer <<
", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
944ParseResult spirv::ExecutionModeIdOp::parse(
OpAsmParser &parser,
946 ExecutionMode execMode;
955 FlatSymbolRefAttr attr;
956 if (parser.parseAttribute(attr))
958 values.push_back(attr);
964 StringRef valuesAttrName = getValuesAttrName(
result.name);
966 result.addAttribute(valuesAttrName, valuesAttr);
970void spirv::ExecutionModeIdOp::print(
OpAsmPrinter &printer) {
973 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\" ";
975 llvm::interleaveComma(
976 getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
980LogicalResult spirv::ExecutionModeIdOp::verify() {
982 switch (getExecutionMode()) {
983 case ExecutionMode::SubgroupsPerWorkgroupId:
984 case ExecutionMode::LocalSizeId:
985 case ExecutionMode::LocalSizeHintId:
988 return emitOpError(
"expected ExecutionMode that takes extra operands that "
989 "are <id> operands, got: ")
990 << stringifyExecutionMode(getExecutionMode());
993 if (getValues().empty())
994 return emitOpError(
"expected at least one value operand");
997 auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
999 return emitOpError(
"expected value operands to be symbol reference");
1001 (*this)->getParentOp(), valueSymbol);
1003 return emitOpError(
"cannot find symbol referenced by value operand: ")
1004 << valueSymbol.getValue();
1021 StringAttr nameAttr;
1027 bool isVariadic =
false;
1029 parser,
false, entryArgs, isVariadic, resultTypes,
1034 for (
auto &arg : entryArgs)
1035 argTypes.push_back(arg.type);
1037 result.addAttribute(getFunctionTypeAttrName(
result.name),
1038 TypeAttr::get(fnType));
1041 spirv::FunctionControl fnControl;
1050 assert(resultAttrs.size() == resultTypes.size());
1052 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
1053 getResAttrsAttrName(
result.name));
1056 auto *body =
result.addRegion();
1066 auto fnType = getFunctionType();
1068 printer, *
this, fnType.getInputs(),
1069 false, fnType.getResults());
1070 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
1074 {spirv::attributeName<spirv::FunctionControl>(),
1075 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
1076 getFunctionControlAttrName()});
1079 Region &body = this->getBody();
1080 if (!body.empty()) {
1087LogicalResult spirv::FuncOp::verifyType() {
1088 FunctionType fnType = getFunctionType();
1089 if (fnType.getNumResults() > 1)
1090 return emitOpError(
"cannot have more than one result");
1092 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1093 unsigned argIndex) {
1094 auto func = cast<FunctionOpInterface>(getOperation());
1095 for (
auto argAttr : cast<FunctionOpInterface>(
func).
getArgAttrs(argIndex)) {
1096 if (argAttr.getName() != spirv::DecorationAttr::name)
1098 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1099 return decAttr.getValue() == decoration;
1104 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1105 Type param = fnType.getInputs()[i];
1106 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1110 auto pointeePtrType =
1111 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1112 if (pointeePtrType) {
1118 if (pointeePtrType.getStorageClass() !=
1119 spirv::StorageClass::PhysicalStorageBuffer)
1122 bool hasAliasedPtr =
1123 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1124 bool hasRestrictPtr =
1125 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1126 if (!hasAliasedPtr && !hasRestrictPtr)
1128 <<
"with a pointer points to a physical buffer pointer must "
1129 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1136 if (
auto pointeeArrayType =
1137 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1139 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1141 pointeePtrType = inputPtrType;
1144 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1145 spirv::StorageClass::PhysicalStorageBuffer)
1148 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1149 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1150 if (!hasAliased && !hasRestrict)
1151 return emitOpError() <<
"with physical buffer pointer must be decorated "
1152 "either 'Aliased' or 'Restrict'";
1158LogicalResult spirv::FuncOp::verifyBody() {
1159 FunctionType fnType = getFunctionType();
1160 if (!isExternal()) {
1161 Block &entryBlock = front();
1163 unsigned numArguments = this->getNumArguments();
1166 << numArguments <<
" arguments to match function signature";
1168 for (
auto [
index, fnArgType, blockArgType] :
1170 if (blockArgType != fnArgType) {
1171 return emitOpError(
"type of entry block argument #")
1172 <<
index <<
'(' << blockArgType
1173 <<
") must match the type of the corresponding argument in "
1174 <<
"function signature(" << fnArgType <<
')';
1180 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1181 if (fnType.getNumResults() != 0)
1182 return retOp.emitOpError(
"cannot be used in functions returning value");
1183 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1184 if (fnType.getNumResults() != 1)
1185 return retOp.emitOpError(
1186 "returns 1 value but enclosing function requires ")
1187 << fnType.getNumResults() <<
" results";
1189 auto retOperandType = retOp.getValue().getType();
1190 auto fnResultType = fnType.getResult(0);
1191 if (retOperandType != fnResultType)
1192 return retOp.emitOpError(
" return value's type (")
1193 << retOperandType <<
") mismatch with function's result type ("
1194 << fnResultType <<
")";
1201 return failure(walkResult.wasInterrupted());
1205 StringRef name, FunctionType type,
1206 spirv::FunctionControl control,
1210 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
1211 state.
addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1212 builder.
getAttr<spirv::FunctionControlAttr>(control));
1221ParseResult spirv::GLFClampOp::parse(
OpAsmParser &parser,
1231ParseResult spirv::GLUClampOp::parse(
OpAsmParser &parser,
1241ParseResult spirv::GLSClampOp::parse(
OpAsmParser &parser,
1261 Type type, StringRef name,
1262 unsigned descriptorSet,
unsigned binding) {
1263 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
1265 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1268 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1273 Type type, StringRef name,
1275 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
1277 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1281ParseResult spirv::GlobalVariableOp::parse(
OpAsmParser &parser,
1284 StringAttr nameAttr;
1285 StringRef initializerAttrName =
1286 spirv::GlobalVariableOp::getInitializerAttrName(
result.name);
1307 StringRef typeAttrName =
1308 spirv::GlobalVariableOp::getTypeAttrName(
result.name);
1313 if (!isa<spirv::PointerType>(type)) {
1314 return parser.
emitError(loc,
"expected spirv.ptr type");
1316 result.addAttribute(typeAttrName, TypeAttr::get(type));
1321void spirv::GlobalVariableOp::print(
OpAsmPrinter &printer) {
1323 spirv::attributeName<spirv::StorageClass>()};
1330 StringRef initializerAttrName = this->getInitializerAttrName();
1332 if (
auto initializer = this->getInitializer()) {
1333 printer <<
" " << initializerAttrName <<
'(';
1336 elidedAttrs.push_back(initializerAttrName);
1339 StringRef typeAttrName = this->getTypeAttrName();
1340 elidedAttrs.push_back(typeAttrName);
1342 printer <<
" : " <<
getType();
1345LogicalResult spirv::GlobalVariableOp::verify() {
1346 if (!isa<spirv::PointerType>(
getType()))
1347 return emitOpError(
"result must be of a !spv.ptr type");
1353 auto storageClass = this->storageClass();
1354 if (storageClass == spirv::StorageClass::Generic ||
1355 storageClass == spirv::StorageClass::Function) {
1357 << stringifyStorageClass(storageClass) <<
"'";
1362 if (std::optional<spirv::LinkageAttributesAttr> linkage =
1363 getLinkageAttributes()) {
1364 if (linkage->getLinkageType().getValue() == spirv::LinkageType::Import &&
1367 "with Import linkage type must not have an initializer");
1372 this->getInitializerAttrName())) {
1374 (*this)->getParentOp(), init.getAttr());
1386 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1387 return emitOpError(
"initializer must be result of a "
1388 "spirv.SpecConstant or "
1389 "spirv.SpecConstantCompositeOp op");
1400LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1411ParseResult spirv::INTELSubgroupBlockWriteOp::parse(
OpAsmParser &parser,
1414 spirv::StorageClass storageClass;
1425 if (
auto valVecTy = dyn_cast<VectorType>(elementType))
1435void spirv::INTELSubgroupBlockWriteOp::print(
OpAsmPrinter &printer) {
1436 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1437 << getValue().getType();
1440LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1451LogicalResult spirv::IAddCarryOp::verify() {
1452 return ::verifyArithmeticExtendedBinaryOp(*
this);
1455ParseResult spirv::IAddCarryOp::parse(
OpAsmParser &parser,
1457 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1468LogicalResult spirv::ISubBorrowOp::verify() {
1469 return ::verifyArithmeticExtendedBinaryOp(*
this);
1472ParseResult spirv::ISubBorrowOp::parse(
OpAsmParser &parser,
1474 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1477void spirv::ISubBorrowOp::print(
OpAsmPrinter &printer) {
1485LogicalResult spirv::SMulExtendedOp::verify() {
1486 return ::verifyArithmeticExtendedBinaryOp(*
this);
1489ParseResult spirv::SMulExtendedOp::parse(
OpAsmParser &parser,
1491 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1494void spirv::SMulExtendedOp::print(
OpAsmPrinter &printer) {
1502LogicalResult spirv::UMulExtendedOp::verify() {
1503 return ::verifyArithmeticExtendedBinaryOp(*
this);
1506ParseResult spirv::UMulExtendedOp::parse(
OpAsmParser &parser,
1508 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1511void spirv::UMulExtendedOp::print(
OpAsmPrinter &printer) {
1519LogicalResult spirv::MemoryBarrierOp::verify() {
1527LogicalResult spirv::MemoryNamedBarrierOp::verify() {
1536 std::optional<StringRef> name) {
1546 spirv::AddressingModel addressingModel,
1547 spirv::MemoryModel memoryModel,
1548 std::optional<VerCapExtAttr> vceTriple,
1549 std::optional<StringRef> name) {
1552 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1554 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1558 state.
addAttribute(getVCETripleAttrName(), *vceTriple);
1564ParseResult spirv::ModuleOp::parse(
OpAsmParser &parser,
1569 StringAttr nameAttr;
1574 spirv::AddressingModel addrModel;
1575 spirv::MemoryModel memoryModel;
1585 spirv::ModuleOp::getVCETripleAttrName(),
1602 if (std::optional<StringRef> name = getName()) {
1611 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1612 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1613 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1616 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1617 printer <<
" requires " << *triple;
1618 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1626LogicalResult spirv::ModuleOp::verifyRegions() {
1627 Dialect *dialect = (*this)->getDialect();
1632 for (
auto &op : *getBody()) {
1634 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1639 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1640 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1642 return entryPointOp.emitError(
"function '")
1643 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1645 if (
auto interface = entryPointOp.getInterface()) {
1647 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1649 return entryPointOp.emitError(
1650 "expected symbol reference for interface "
1651 "specification instead of '")
1655 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1657 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1658 "symbol reference instead of'")
1659 << varSymRef <<
"'";
1664 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1665 funcOp, entryPointOp.getExecutionModel());
1666 if (!entryPoints.try_emplace(key, entryPointOp).second)
1667 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1668 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1672 auto linkageAttr = funcOp.getLinkageAttributes();
1673 auto hasImportLinkage =
1674 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1675 spirv::LinkageType::Import);
1676 if (funcOp.isExternal() && !hasImportLinkage)
1678 "'spirv.module' cannot contain external functions "
1679 "without 'Import' linkage_attributes (LinkageAttributes)");
1682 for (
auto &block : funcOp)
1683 for (
auto &op : block) {
1686 "functions in 'spirv.module' can only contain spirv.* ops");
1698LogicalResult spirv::ReferenceOfOp::verify() {
1700 (*this)->getParentOp(), getSpecConstAttr());
1703 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1705 constType = specConstOp.getDefaultValue().getType();
1707 auto specConstCompositeOp =
1708 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1709 if (specConstCompositeOp)
1710 constType = specConstCompositeOp.getType();
1712 if (!specConstOp && !specConstCompositeOp)
1714 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1716 if (getReference().
getType() != constType)
1717 return emitOpError(
"result type mismatch with the referenced "
1718 "specialization constant's type");
1727ParseResult spirv::SpecConstantOp::parse(
OpAsmParser &parser,
1729 StringAttr nameAttr;
1731 StringRef defaultValueAttrName =
1732 spirv::SpecConstantOp::getDefaultValueAttrName(
result.name);
1740 IntegerAttr specIdAttr;
1754void spirv::SpecConstantOp::print(
OpAsmPrinter &printer) {
1757 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1759 printer <<
" = " << getDefaultValue();
1762LogicalResult spirv::SpecConstantOp::verify() {
1763 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1764 if (specID.getValue().isNegative())
1767 auto value = getDefaultValue();
1768 if (isa<IntegerAttr, FloatAttr>(value)) {
1770 if (!isa<spirv::SPIRVType>(value.getType()))
1771 return emitOpError(
"default value bitwidth disallowed");
1775 "default value can only be a bool, integer, or float scalar");
1782LogicalResult spirv::VectorShuffleOp::verify() {
1783 VectorType resultType = cast<VectorType>(
getType());
1785 size_t numResultElements = resultType.getNumElements();
1786 if (numResultElements != getComponents().size())
1788 << numResultElements
1789 <<
") mismatch with the number of component selectors ("
1790 << getComponents().size() <<
")";
1792 size_t totalSrcElements =
1793 cast<VectorType>(getVector1().
getType()).getNumElements() +
1794 cast<VectorType>(getVector2().
getType()).getNumElements();
1796 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1797 uint32_t
index = selector.getZExtValue();
1798 if (
index >= totalSrcElements &&
1799 index != std::numeric_limits<uint32_t>().
max())
1801 <<
index <<
" out of range: expected to be in [0, "
1802 << totalSrcElements <<
") or 0xffffffff";
1811ParseResult spirv::SpecConstantCompositeOp::parse(
OpAsmParser &parser,
1814 StringAttr compositeName;
1826 const char *attrName =
"spec_const";
1833 constituents.push_back(specConstRef);
1839 StringAttr compositeSpecConstituentsName =
1840 spirv::SpecConstantCompositeOp::getConstituentsAttrName(
result.name);
1841 result.addAttribute(compositeSpecConstituentsName,
1848 StringAttr typeAttrName =
1849 spirv::SpecConstantCompositeOp::getTypeAttrName(
result.name);
1850 result.addAttribute(typeAttrName, TypeAttr::get(type));
1855void spirv::SpecConstantCompositeOp::print(
OpAsmPrinter &printer) {
1858 printer <<
" (" << llvm::interleaved(this->getConstituents().getValue())
1862LogicalResult spirv::SpecConstantCompositeOp::verify() {
1863 auto cType = dyn_cast<spirv::CompositeType>(
getType());
1864 auto constituents = this->getConstituents().getValue();
1867 return emitError(
"result type must be a composite type, but provided ")
1870 if (isa<spirv::CooperativeMatrixType>(cType))
1871 return emitError(
"unsupported composite type ") << cType;
1872 if (constituents.size() != cType.getNumElements())
1873 return emitError(
"has incorrect number of operands: expected ")
1874 << cType.getNumElements() <<
", but provided "
1875 << constituents.size();
1877 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1878 auto constituent = cast<FlatSymbolRefAttr>(constituents[
index]);
1881 (*this)->getParentOp(), constituent.getAttr());
1884 return emitError(
"unknown constituent symbol ") << constituent.getAttr();
1886 Type constituentType;
1887 if (
auto specConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp)) {
1888 constituentType = specConstOp.getDefaultValue().getType();
1889 }
else if (
auto specConstCompositeOp =
1890 dyn_cast<spirv::SpecConstantCompositeOp>(constituentOp)) {
1891 constituentType = specConstCompositeOp.getType();
1893 return emitError(
"unsupported constituent ")
1894 << constituent.getAttr()
1895 <<
": must reference a spirv.SpecConstant or "
1896 "spirv.SpecConstantComposite";
1899 if (constituentType != cType.getElementType(
index))
1900 return emitError(
"has incorrect types of operands: expected ")
1901 << cType.getElementType(
index) <<
", but provided "
1913spirv::EXTSpecConstantCompositeReplicateOp::parse(
OpAsmParser &parser,
1915 StringAttr compositeName;
1917 const char *attrName =
"spec_const";
1928 StringAttr compositeSpecConstituentName =
1929 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1931 result.addAttribute(compositeSpecConstituentName, specConstRef);
1933 StringAttr typeAttrName =
1934 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(
result.name);
1935 result.addAttribute(typeAttrName, TypeAttr::get(type));
1940void spirv::EXTSpecConstantCompositeReplicateOp::print(
OpAsmPrinter &printer) {
1943 printer <<
" (" << this->getConstituent() <<
") : " <<
getType();
1946LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1947 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
1949 return emitError(
"result type must be a composite type, but provided ")
1953 (*this)->getParentOp(), this->getConstituent());
1956 "splat spec constant reference defining constituent not found");
1958 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1959 if (!constituentSpecConstOp)
1960 return emitError(
"constituent is not a spec constant");
1962 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1963 Type compositeElementType = compositeType.getElementType(0);
1964 if (constituentType != compositeElementType)
1965 return emitError(
"constituent has incorrect type: expected ")
1966 << compositeElementType <<
", but provided " << constituentType;
1975ParseResult spirv::SpecConstantOperationOp::parse(
OpAsmParser &parser,
1991 spirv::YieldOp::create(builder, wrappedOp->
getLoc(), wrappedOp->
getResult(0));
2002void spirv::SpecConstantOperationOp::print(
OpAsmPrinter &printer) {
2003 printer <<
" wraps ";
2007LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2008 Block &block = getRegion().getBlocks().
front();
2011 return emitOpError(
"expected exactly 2 nested ops");
2019 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2020 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2022 "invalid operand, must be defined by a constant operation");
2031LogicalResult spirv::GLFrexpStructOp::verify() {
2033 dyn_cast<spirv::StructType>(getResult().
getType());
2036 return emitError(
"result type must be a struct type with two memebers");
2040 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2041 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2043 Type operandTy = getOperand().getType();
2044 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2045 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2047 if (significandTy != operandTy)
2048 return emitError(
"member zero of the resulting struct type must be the "
2049 "same type as the operand");
2051 if (exponentVecTy) {
2052 IntegerType componentIntTy =
2053 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2054 if (!componentIntTy || componentIntTy.getWidth() != 32)
2055 return emitError(
"member one of the resulting struct type must"
2056 "be a scalar or vector of 32 bit integer type");
2057 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2058 return emitError(
"member one of the resulting struct type "
2059 "must be a scalar or vector of 32 bit integer type");
2063 if (operandVecTy && exponentVecTy &&
2064 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2067 if (operandFTy && exponentIntTy)
2070 return emitError(
"member one of the resulting struct type must have the same "
2071 "number of components as the operand type");
2080 if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
2081 return op->
emitOpError(
"operands must both be scalars or vectors");
2084 if (
auto vectorType = dyn_cast<VectorType>(type))
2085 return vectorType.getNumElements();
2090 return op->
emitOpError(
"operands must have the same number of elements");
2095LogicalResult spirv::GLLdexpOp::verify() {
2104LogicalResult spirv::CLLdexpOp::verify() {
2113LogicalResult spirv::CLPownOp::verify() {
2122LogicalResult spirv::CLRootnOp::verify() {
2131LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2139LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2147LogicalResult spirv::ShiftRightLogicalOp::verify() {
2155LogicalResult spirv::VectorTimesScalarOp::verify() {
2157 return emitOpError(
"vector operand and result type mismatch");
2158 auto scalarType = cast<VectorType>(
getType()).getElementType();
2159 if (getScalar().
getType() != scalarType)
2160 return emitOpError(
"scalar operand and result element type match");
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static Type getValueType(Attribute attr)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType, Type integerType)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
Type getElementType() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
uint64_t getN(LevelType lt)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.