31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
63 value = integerValueAttr.getSInt();
70 spirv::MemorySemantics memorySemantics) {
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
83 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
97 stringifyDecoration(spirv::Decoration::DescriptorSet));
98 auto bindingName = llvm::convertToSnakeFromCamelCase(
99 stringifyDecoration(spirv::Decoration::Binding));
102 if (descriptorSet && binding) {
105 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
110 auto builtInName = llvm::convertToSnakeFromCamelCase(
111 stringifyDecoration(spirv::Decoration::BuiltIn));
113 printer <<
" " << builtInName <<
"(\"" <<
builtin.getValue() <<
"\")";
114 elidedAttrs.push_back(builtInName);
132 auto fnType = dyn_cast<FunctionType>(type);
134 parser.
emitError(loc,
"expected function type");
139 result.addTypes(fnType.getResults());
150 assert(op->
getNumResults() == 1 &&
"op should have one result");
156 [&](
Type type) { return type != resultType; })) {
165 p <<
" : " << resultType;
168template <
typename BlockReadWriteOpTy>
172 if (
auto valVecTy = dyn_cast<VectorType>(valType))
173 valType = valVecTy.getElementType();
175 if (valType != cast<spirv::PointerType>(
ptr.getType()).getPointeeType()) {
176 return op.emitOpError(
"mismatch in result type and pointer type");
188 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
193 if (
auto cType = dyn_cast<spirv::CompositeType>(type)) {
194 if (cType.hasCompileTimeKnownNumElements() &&
196 static_cast<uint64_t
>(
index) >= cType.getNumElements())) {
197 emitErrorFn(
"index ") <<
index <<
" out of bounds for " << type;
200 type = cType.getElementType(
index);
202 emitErrorFn(
"cannot extract from non-composite type ")
203 << type <<
" with index " <<
index;
213 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
indices);
214 if (!indicesArrayAttr) {
215 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
218 if (indicesArrayAttr.empty()) {
219 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
224 for (
auto indexAttr : indicesArrayAttr) {
225 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
227 emitErrorFn(
"expected an 32-bit integer for index, but found '")
231 indexVals.push_back(indexIntAttr.getInt());
238 return ::mlir::emitError(loc, err);
251template <
typename ExtendedBinaryOp>
253 auto resultType = cast<spirv::StructType>(op.getType());
254 if (resultType.getNumElements() != 2)
255 return op.emitOpError(
"expected result struct type containing two members");
257 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
258 resultType.getElementType(0),
259 resultType.getElementType(1)}))
260 return op.emitOpError(
261 "expected all operand types and struct member types are the same");
278 auto structType = dyn_cast<spirv::StructType>(resultType);
279 if (!structType || structType.getNumElements() != 2)
280 return parser.
emitError(loc,
"expected spirv.struct type with two members");
286 result.addTypes(resultType);
300 return op->
emitError(
"expected the same type for the first operand and "
301 "result, but provided ")
313 spirv::GlobalVariableOp var) {
314 build(builder, state, var.getType(), SymbolRefAttr::get(var));
317LogicalResult spirv::AddressOfOp::verify() {
318 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
322 return emitOpError(
"expected spirv.GlobalVariable symbol");
324 if (getPointer().
getType() != varOp.getType()) {
326 "result type mismatch with the referenced global variable's type");
335LogicalResult spirv::CompositeConstructOp::verify() {
336 operand_range constituents = this->getConstituents();
344 auto coopElementType =
347 [](
auto coopType) {
return coopType.getElementType(); })
351 if (coopElementType) {
352 if (constituents.size() != 1)
353 return emitOpError(
"has incorrect number of operands: expected ")
354 <<
"1, but provided " << constituents.size();
355 if (coopElementType != constituents.front().getType())
356 return emitOpError(
"operand type mismatch: expected operand type ")
357 << coopElementType <<
", but provided "
358 << constituents.front().getType();
363 auto cType = cast<spirv::CompositeType>(
getType());
364 if (constituents.size() == cType.getNumElements()) {
365 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
367 return emitOpError(
"operand type mismatch: expected operand type ")
368 << cType.getElementType(
index) <<
", but provided "
369 << constituents[
index].getType();
376 auto resultType = dyn_cast<VectorType>(cType);
379 "expected to return a vector or cooperative matrix when the number of "
380 "constituents is less than what the result needs");
383 for (
Value component : constituents) {
384 if (!isa<VectorType>(component.getType()) &&
385 !component.getType().isIntOrFloat())
386 return emitOpError(
"operand type mismatch: expected operand to have "
387 "a scalar or vector type, but provided ")
388 << component.getType();
390 Type elementType = component.getType();
391 if (
auto vectorType = dyn_cast<VectorType>(component.getType())) {
392 sizes.push_back(vectorType.getNumElements());
393 elementType = vectorType.getElementType();
398 if (elementType != resultType.getElementType())
399 return emitOpError(
"operand element type mismatch: expected to be ")
400 << resultType.getElementType() <<
", but provided " << elementType;
402 unsigned totalCount = llvm::sum_of(sizes);
403 if (totalCount != cType.getNumElements())
404 return emitOpError(
"has incorrect number of operands: expected ")
405 << cType.getNumElements() <<
", but provided " << totalCount;
422 build(builder, state, elementType, composite, indexAttr);
425ParseResult spirv::CompositeExtractOp::parse(
OpAsmParser &parser,
429 StringRef indicesAttrName =
430 spirv::CompositeExtractOp::getIndicesAttrName(
result.name);
447 result.addTypes(resultType);
451void spirv::CompositeExtractOp::print(
OpAsmPrinter &printer) {
452 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
456LogicalResult spirv::CompositeExtractOp::verify() {
457 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
getIndices());
464 return emitOpError(
"invalid result type: expected ")
465 << resultType <<
" but provided " <<
getType();
479 build(builder, state, composite.
getType(),
object, composite, indexAttr);
482ParseResult spirv::CompositeInsertOp::parse(
OpAsmParser &parser,
485 Type objectType, compositeType;
487 StringRef indicesAttrName =
488 spirv::CompositeInsertOp::getIndicesAttrName(
result.name);
501LogicalResult spirv::CompositeInsertOp::verify() {
502 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
getIndices());
508 if (objectType != getObject().
getType()) {
509 return emitOpError(
"object operand type should be ")
510 << objectType <<
", but found " << getObject().getType();
514 return emitOpError(
"result type should be the same as "
515 "the composite type, but found ")
516 << getComposite().getType() <<
" vs " <<
getType();
522void spirv::CompositeInsertOp::print(
OpAsmPrinter &printer) {
523 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
524 <<
" : " << getObject().
getType() <<
" into "
525 << getComposite().getType();
532ParseResult spirv::ConstantOp::parse(
OpAsmParser &parser,
535 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(
result.name);
540 if (
auto typedAttr = dyn_cast<TypedAttr>(value))
541 type = typedAttr.getType();
542 if (isa<NoneType, TensorType>(type)) {
547 if (isa<TensorArmType>(type)) {
557 printer <<
' ' << getValue();
558 if (isa<spirv::ArrayType>(
getType()))
564 if (isa<spirv::CooperativeMatrixType>(opType)) {
565 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
566 if (!denseAttr || !denseAttr.isSplat())
567 return op.emitOpError(
"expected a splat dense attribute for cooperative "
568 "matrix constant, but found ")
571 if (isa<IntegerAttr, FloatAttr>(value)) {
572 auto valueType = cast<TypedAttr>(value).getType();
573 if (valueType != opType)
574 return op.emitOpError(
"result type (")
575 << opType <<
") does not match value type (" << valueType <<
")";
578 if (isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
579 auto valueType = cast<TypedAttr>(value).getType();
580 if (valueType == opType)
582 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
583 auto shapedType = dyn_cast<ShapedType>(valueType);
585 return op.emitOpError(
"result or element type (")
586 << opType <<
") does not match value type (" << valueType
587 <<
"), must be the same or spirv.array";
589 int numElements = arrayType.getNumElements();
590 auto opElemType = arrayType.getElementType();
591 while (
auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
592 numElements *= t.getNumElements();
593 opElemType = t.getElementType();
595 if (!opElemType.isIntOrFloat())
596 return op.emitOpError(
"only support nested array result type");
598 auto valueElemType = shapedType.getElementType();
599 if (valueElemType != opElemType) {
600 return op.emitOpError(
"result element type (")
601 << opElemType <<
") does not match value element type ("
602 << valueElemType <<
")";
605 if (numElements != shapedType.getNumElements()) {
606 return op.emitOpError(
"result number of elements (")
607 << numElements <<
") does not match value number of elements ("
608 << shapedType.getNumElements() <<
")";
612 if (
auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
613 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
615 return op.emitOpError(
616 "must have spirv.array result type for array value");
617 Type elemType = arrayType.getElementType();
618 for (
Attribute element : arrayAttr.getValue()) {
625 return op.emitOpError(
"cannot have attribute: ") << value;
628LogicalResult spirv::ConstantOp::verify() {
635bool spirv::ConstantOp::isBuildableWith(
Type type) {
637 if (!isa<spirv::SPIRVType>(type))
642 return isa<spirv::ArrayType>(type);
648spirv::ConstantOp spirv::ConstantOp::getZero(
Type type,
Location loc,
650 if (
auto intType = dyn_cast<IntegerType>(type)) {
651 unsigned width = intType.getWidth();
653 return spirv::ConstantOp::create(builder, loc, type,
655 return spirv::ConstantOp::create(
656 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 0)));
658 if (
auto floatType = dyn_cast<FloatType>(type)) {
659 return spirv::ConstantOp::create(builder, loc, type,
662 if (
auto vectorType = dyn_cast<VectorType>(type)) {
663 Type elemType = vectorType.getElementType();
664 if (isa<IntegerType>(elemType)) {
665 return spirv::ConstantOp::create(
668 IntegerAttr::get(elemType, 0).getValue()));
670 if (isa<FloatType>(elemType)) {
671 return spirv::ConstantOp::create(
674 FloatAttr::get(elemType, 0.0).getValue()));
678 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
681spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
683 if (
auto intType = dyn_cast<IntegerType>(type)) {
684 unsigned width = intType.getWidth();
686 return spirv::ConstantOp::create(builder, loc, type,
688 return spirv::ConstantOp::create(
689 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 1)));
691 if (
auto floatType = dyn_cast<FloatType>(type)) {
692 return spirv::ConstantOp::create(builder, loc, type,
695 if (
auto vectorType = dyn_cast<VectorType>(type)) {
696 Type elemType = vectorType.getElementType();
697 if (isa<IntegerType>(elemType)) {
698 return spirv::ConstantOp::create(
701 IntegerAttr::get(elemType, 1).getValue()));
703 if (isa<FloatType>(elemType)) {
704 return spirv::ConstantOp::create(
707 FloatAttr::get(elemType, 1.0).getValue()));
711 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
714void mlir::spirv::ConstantOp::getAsmResultNames(
719 llvm::raw_svector_ostream specialName(specialNameBuffer);
720 specialName <<
"cst";
722 IntegerType intTy = dyn_cast<IntegerType>(type);
724 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
727 if (intTy.getWidth() == 1) {
728 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
731 if (intTy.isSignless()) {
732 specialName << intCst.getInt();
733 }
else if (intTy.isUnsigned()) {
734 specialName << intCst.getUInt();
736 specialName << intCst.getSInt();
740 if (intTy || isa<FloatType>(type)) {
741 specialName <<
'_' << type;
744 if (
auto vecType = dyn_cast<VectorType>(type)) {
745 specialName <<
"_vec_";
746 specialName << vecType.getDimSize(0);
748 Type elementType = vecType.getElementType();
750 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
751 specialName <<
"x" << elementType;
755 setNameFn(getResult(), specialName.str());
758void mlir::spirv::AddressOfOp::getAsmResultNames(
761 llvm::raw_svector_ostream specialName(specialNameBuffer);
762 specialName << getVariable() <<
"_addr";
763 setNameFn(getResult(), specialName.str());
774 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
775 return typedAttr.getType();
778 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
785LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
788 return emitError(
"unknown value attribute type");
790 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
792 return emitError(
"result type is not a composite type");
794 Type compositeElementType = compositeType.getElementType(0);
797 while (
auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
798 compositeElementType = type.getElementType(0);
799 possibleTypes.push_back(compositeElementType);
802 if (!is_contained(possibleTypes, valueType)) {
803 return emitError(
"expected value attribute type ")
804 << interleaved(possibleTypes,
" or ") <<
", but got: " << valueType;
814LogicalResult spirv::ControlBarrierOp::verify() {
823 spirv::ExecutionModel executionModel,
824 spirv::FuncOp function,
826 build(builder, state,
827 spirv::ExecutionModelAttr::get(builder.
getContext(), executionModel),
828 SymbolRefAttr::get(function), builder.
getArrayAttr(interfaceVars));
831ParseResult spirv::EntryPointOp::parse(
OpAsmParser &parser,
833 spirv::ExecutionModel execModel;
846 FlatSymbolRefAttr var;
848 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
850 interfaceVars.push_back(var);
855 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(
result.name),
863 auto interfaceVars = getInterface().getValue();
864 if (!interfaceVars.empty())
865 printer <<
", " << llvm::interleaved(interfaceVars);
868LogicalResult spirv::EntryPointOp::verify() {
879 spirv::FuncOp function,
880 spirv::ExecutionMode executionMode,
882 build(builder, state, SymbolRefAttr::get(function),
883 spirv::ExecutionModeAttr::get(builder.
getContext(), executionMode),
887ParseResult spirv::ExecutionModeOp::parse(
OpAsmParser &parser,
889 spirv::ExecutionMode execMode;
904 values.push_back(cast<IntegerAttr>(value).getInt());
906 StringRef valuesAttrName =
907 spirv::ExecutionModeOp::getValuesAttrName(
result.name);
908 result.addAttribute(valuesAttrName,
913void spirv::ExecutionModeOp::print(
OpAsmPrinter &printer) {
916 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
919 printer <<
", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
939 bool isVariadic =
false;
941 parser,
false, entryArgs, isVariadic, resultTypes,
946 for (
auto &arg : entryArgs)
947 argTypes.push_back(arg.type);
949 result.addAttribute(getFunctionTypeAttrName(
result.name),
950 TypeAttr::get(fnType));
953 spirv::FunctionControl fnControl;
962 assert(resultAttrs.size() == resultTypes.size());
964 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
965 getResAttrsAttrName(
result.name));
968 auto *body =
result.addRegion();
978 auto fnType = getFunctionType();
980 printer, *
this, fnType.getInputs(),
981 false, fnType.getResults());
982 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
987 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
988 getFunctionControlAttrName()});
991 Region &body = this->getBody();
999LogicalResult spirv::FuncOp::verifyType() {
1000 FunctionType fnType = getFunctionType();
1001 if (fnType.getNumResults() > 1)
1002 return emitOpError(
"cannot have more than one result");
1004 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1005 unsigned argIndex) {
1006 auto func = cast<FunctionOpInterface>(getOperation());
1007 for (
auto argAttr : cast<FunctionOpInterface>(
func).
getArgAttrs(argIndex)) {
1008 if (argAttr.getName() != spirv::DecorationAttr::name)
1010 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1011 return decAttr.getValue() == decoration;
1016 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1017 Type param = fnType.getInputs()[i];
1018 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1022 auto pointeePtrType =
1023 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1024 if (pointeePtrType) {
1030 if (pointeePtrType.getStorageClass() !=
1031 spirv::StorageClass::PhysicalStorageBuffer)
1034 bool hasAliasedPtr =
1035 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1036 bool hasRestrictPtr =
1037 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1038 if (!hasAliasedPtr && !hasRestrictPtr)
1040 <<
"with a pointer points to a physical buffer pointer must "
1041 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1048 if (
auto pointeeArrayType =
1049 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1051 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1053 pointeePtrType = inputPtrType;
1056 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1057 spirv::StorageClass::PhysicalStorageBuffer)
1060 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1061 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1062 if (!hasAliased && !hasRestrict)
1063 return emitOpError() <<
"with physical buffer pointer must be decorated "
1064 "either 'Aliased' or 'Restrict'";
1070LogicalResult spirv::FuncOp::verifyBody() {
1071 FunctionType fnType = getFunctionType();
1072 if (!isExternal()) {
1073 Block &entryBlock = front();
1075 unsigned numArguments = this->getNumArguments();
1078 << numArguments <<
" arguments to match function signature";
1080 for (
auto [
index, fnArgType, blockArgType] :
1082 if (blockArgType != fnArgType) {
1083 return emitOpError(
"type of entry block argument #")
1084 <<
index <<
'(' << blockArgType
1085 <<
") must match the type of the corresponding argument in "
1086 <<
"function signature(" << fnArgType <<
')';
1092 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1093 if (fnType.getNumResults() != 0)
1094 return retOp.emitOpError(
"cannot be used in functions returning value");
1095 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1096 if (fnType.getNumResults() != 1)
1097 return retOp.emitOpError(
1098 "returns 1 value but enclosing function requires ")
1099 << fnType.getNumResults() <<
" results";
1101 auto retOperandType = retOp.getValue().getType();
1102 auto fnResultType = fnType.getResult(0);
1103 if (retOperandType != fnResultType)
1104 return retOp.emitOpError(
" return value's type (")
1105 << retOperandType <<
") mismatch with function's result type ("
1106 << fnResultType <<
")";
1113 return failure(walkResult.wasInterrupted());
1117 StringRef name, FunctionType type,
1118 spirv::FunctionControl control,
1122 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
1124 builder.
getAttr<spirv::FunctionControlAttr>(control));
1133ParseResult spirv::GLFClampOp::parse(
OpAsmParser &parser,
1143ParseResult spirv::GLUClampOp::parse(
OpAsmParser &parser,
1153ParseResult spirv::GLSClampOp::parse(
OpAsmParser &parser,
1173 Type type, StringRef name,
1174 unsigned descriptorSet,
unsigned binding) {
1175 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
1177 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1180 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1185 Type type, StringRef name,
1187 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
1189 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1193ParseResult spirv::GlobalVariableOp::parse(
OpAsmParser &parser,
1196 StringAttr nameAttr;
1197 StringRef initializerAttrName =
1198 spirv::GlobalVariableOp::getInitializerAttrName(
result.name);
1219 StringRef typeAttrName =
1220 spirv::GlobalVariableOp::getTypeAttrName(
result.name);
1225 if (!isa<spirv::PointerType>(type)) {
1226 return parser.
emitError(loc,
"expected spirv.ptr type");
1228 result.addAttribute(typeAttrName, TypeAttr::get(type));
1233void spirv::GlobalVariableOp::print(
OpAsmPrinter &printer) {
1242 StringRef initializerAttrName = this->getInitializerAttrName();
1244 if (
auto initializer = this->getInitializer()) {
1245 printer <<
" " << initializerAttrName <<
'(';
1248 elidedAttrs.push_back(initializerAttrName);
1251 StringRef typeAttrName = this->getTypeAttrName();
1252 elidedAttrs.push_back(typeAttrName);
1254 printer <<
" : " <<
getType();
1257LogicalResult spirv::GlobalVariableOp::verify() {
1258 if (!isa<spirv::PointerType>(
getType()))
1259 return emitOpError(
"result must be of a !spv.ptr type");
1265 auto storageClass = this->storageClass();
1266 if (storageClass == spirv::StorageClass::Generic ||
1267 storageClass == spirv::StorageClass::Function) {
1269 << stringifyStorageClass(storageClass) <<
"'";
1273 this->getInitializerAttrName())) {
1275 (*this)->getParentOp(), init.getAttr());
1287 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1288 return emitOpError(
"initializer must be result of a "
1289 "spirv.SpecConstant or "
1290 "spirv.SpecConstantCompositeOp op");
1301LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1312ParseResult spirv::INTELSubgroupBlockWriteOp::parse(
OpAsmParser &parser,
1315 spirv::StorageClass storageClass;
1326 if (
auto valVecTy = dyn_cast<VectorType>(elementType))
1336void spirv::INTELSubgroupBlockWriteOp::print(
OpAsmPrinter &printer) {
1337 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1338 << getValue().getType();
1341LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1352LogicalResult spirv::IAddCarryOp::verify() {
1353 return ::verifyArithmeticExtendedBinaryOp(*
this);
1356ParseResult spirv::IAddCarryOp::parse(
OpAsmParser &parser,
1358 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1369LogicalResult spirv::ISubBorrowOp::verify() {
1370 return ::verifyArithmeticExtendedBinaryOp(*
this);
1373ParseResult spirv::ISubBorrowOp::parse(
OpAsmParser &parser,
1375 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1378void spirv::ISubBorrowOp::print(
OpAsmPrinter &printer) {
1386LogicalResult spirv::SMulExtendedOp::verify() {
1387 return ::verifyArithmeticExtendedBinaryOp(*
this);
1390ParseResult spirv::SMulExtendedOp::parse(
OpAsmParser &parser,
1392 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1395void spirv::SMulExtendedOp::print(
OpAsmPrinter &printer) {
1403LogicalResult spirv::UMulExtendedOp::verify() {
1404 return ::verifyArithmeticExtendedBinaryOp(*
this);
1407ParseResult spirv::UMulExtendedOp::parse(
OpAsmParser &parser,
1409 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1412void spirv::UMulExtendedOp::print(
OpAsmPrinter &printer) {
1420LogicalResult spirv::MemoryBarrierOp::verify() {
1429 std::optional<StringRef> name) {
1439 spirv::AddressingModel addressingModel,
1440 spirv::MemoryModel memoryModel,
1441 std::optional<VerCapExtAttr> vceTriple,
1442 std::optional<StringRef> name) {
1445 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1447 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1451 state.
addAttribute(getVCETripleAttrName(), *vceTriple);
1457ParseResult spirv::ModuleOp::parse(
OpAsmParser &parser,
1462 StringAttr nameAttr;
1467 spirv::AddressingModel addrModel;
1468 spirv::MemoryModel memoryModel;
1478 spirv::ModuleOp::getVCETripleAttrName(),
1495 if (std::optional<StringRef> name = getName()) {
1506 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1509 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1510 printer <<
" requires " << *triple;
1511 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1519LogicalResult spirv::ModuleOp::verifyRegions() {
1520 Dialect *dialect = (*this)->getDialect();
1525 for (
auto &op : *getBody()) {
1527 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1532 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1533 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1535 return entryPointOp.emitError(
"function '")
1536 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1538 if (
auto interface = entryPointOp.getInterface()) {
1540 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1542 return entryPointOp.emitError(
1543 "expected symbol reference for interface "
1544 "specification instead of '")
1548 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1550 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1551 "symbol reference instead of'")
1552 << varSymRef <<
"'";
1557 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1558 funcOp, entryPointOp.getExecutionModel());
1559 if (!entryPoints.try_emplace(key, entryPointOp).second)
1560 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1561 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1565 auto linkageAttr = funcOp.getLinkageAttributes();
1566 auto hasImportLinkage =
1567 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1568 spirv::LinkageType::Import);
1569 if (funcOp.isExternal() && !hasImportLinkage)
1571 "'spirv.module' cannot contain external functions "
1572 "without 'Import' linkage_attributes (LinkageAttributes)");
1575 for (
auto &block : funcOp)
1576 for (
auto &op : block) {
1579 "functions in 'spirv.module' can only contain spirv.* ops");
1591LogicalResult spirv::ReferenceOfOp::verify() {
1593 (*this)->getParentOp(), getSpecConstAttr());
1596 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1598 constType = specConstOp.getDefaultValue().getType();
1600 auto specConstCompositeOp =
1601 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1602 if (specConstCompositeOp)
1603 constType = specConstCompositeOp.getType();
1605 if (!specConstOp && !specConstCompositeOp)
1607 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1609 if (getReference().
getType() != constType)
1610 return emitOpError(
"result type mismatch with the referenced "
1611 "specialization constant's type");
1620ParseResult spirv::SpecConstantOp::parse(
OpAsmParser &parser,
1622 StringAttr nameAttr;
1624 StringRef defaultValueAttrName =
1625 spirv::SpecConstantOp::getDefaultValueAttrName(
result.name);
1633 IntegerAttr specIdAttr;
1647void spirv::SpecConstantOp::print(
OpAsmPrinter &printer) {
1650 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1652 printer <<
" = " << getDefaultValue();
1655LogicalResult spirv::SpecConstantOp::verify() {
1656 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1657 if (specID.getValue().isNegative())
1660 auto value = getDefaultValue();
1661 if (isa<IntegerAttr, FloatAttr>(value)) {
1663 if (!isa<spirv::SPIRVType>(value.getType()))
1664 return emitOpError(
"default value bitwidth disallowed");
1668 "default value can only be a bool, integer, or float scalar");
1675LogicalResult spirv::VectorShuffleOp::verify() {
1676 VectorType resultType = cast<VectorType>(
getType());
1678 size_t numResultElements = resultType.getNumElements();
1679 if (numResultElements != getComponents().size())
1681 << numResultElements
1682 <<
") mismatch with the number of component selectors ("
1683 << getComponents().size() <<
")";
1685 size_t totalSrcElements =
1686 cast<VectorType>(getVector1().
getType()).getNumElements() +
1687 cast<VectorType>(getVector2().
getType()).getNumElements();
1689 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1690 uint32_t
index = selector.getZExtValue();
1691 if (
index >= totalSrcElements &&
1692 index != std::numeric_limits<uint32_t>().
max())
1694 <<
index <<
" out of range: expected to be in [0, "
1695 << totalSrcElements <<
") or 0xffffffff";
1704LogicalResult spirv::MatrixTimesScalarOp::verify() {
1708 [](
auto matrixType) {
return matrixType.getElementType(); })
1711 assert(elementType &&
"Unhandled type");
1714 if (getScalar().
getType() != elementType)
1715 return emitOpError(
"input matrix components' type and scaling value must "
1716 "have the same type");
1725LogicalResult spirv::TransposeOp::verify() {
1726 auto inputMatrix = cast<spirv::MatrixType>(getMatrix().
getType());
1727 auto resultMatrix = cast<spirv::MatrixType>(getResult().
getType());
1730 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1731 return emitError(
"input matrix rows count must be equal to "
1732 "output matrix columns count");
1734 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1735 return emitError(
"input matrix columns count must be equal to "
1736 "output matrix rows count");
1739 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1740 return emitError(
"input and output matrices must have the same "
1750LogicalResult spirv::MatrixTimesVectorOp::verify() {
1751 auto matrixType = cast<spirv::MatrixType>(getMatrix().
getType());
1752 auto vectorType = cast<VectorType>(getVector().
getType());
1753 auto resultType = cast<VectorType>(
getType());
1755 if (matrixType.getNumColumns() != vectorType.getNumElements())
1757 << matrixType.getNumColumns() <<
") must match vector operand size ("
1758 << vectorType.getNumElements() <<
")";
1760 if (resultType.getNumElements() != matrixType.getNumRows())
1762 << resultType.getNumElements() <<
") must match the matrix rows ("
1763 << matrixType.getNumRows() <<
")";
1765 if (matrixType.getElementType() != resultType.getElementType())
1766 return emitOpError(
"matrix and result element types must match");
1775LogicalResult spirv::VectorTimesMatrixOp::verify() {
1776 auto vectorType = cast<VectorType>(getVector().
getType());
1777 auto matrixType = cast<spirv::MatrixType>(getMatrix().
getType());
1778 auto resultType = cast<VectorType>(
getType());
1780 if (matrixType.getNumRows() != vectorType.getNumElements())
1781 return emitOpError(
"number of components in vector must equal the number "
1782 "of components in each column in matrix");
1784 if (resultType.getNumElements() != matrixType.getNumColumns())
1785 return emitOpError(
"number of columns in matrix must equal the number of "
1786 "components in result");
1788 if (matrixType.getElementType() != resultType.getElementType())
1789 return emitOpError(
"matrix must be a matrix with the same component type "
1790 "as the component type in result");
1799LogicalResult spirv::MatrixTimesMatrixOp::verify() {
1800 auto leftMatrix = cast<spirv::MatrixType>(getLeftmatrix().
getType());
1801 auto rightMatrix = cast<spirv::MatrixType>(getRightmatrix().
getType());
1802 auto resultMatrix = cast<spirv::MatrixType>(getResult().
getType());
1805 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1806 return emitError(
"left matrix columns' count must be equal to "
1807 "the right matrix rows' count");
1810 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1812 "right and result matrices must have equal columns' count");
1815 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1816 return emitError(
"right and result matrices' component type must"
1820 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1821 return emitError(
"left and result matrices' component type"
1822 " must be the same");
1825 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1826 return emitError(
"left and result matrices must have equal rows' count");
1835ParseResult spirv::SpecConstantCompositeOp::parse(
OpAsmParser &parser,
1838 StringAttr compositeName;
1850 const char *attrName =
"spec_const";
1857 constituents.push_back(specConstRef);
1863 StringAttr compositeSpecConstituentsName =
1864 spirv::SpecConstantCompositeOp::getConstituentsAttrName(
result.name);
1865 result.addAttribute(compositeSpecConstituentsName,
1872 StringAttr typeAttrName =
1873 spirv::SpecConstantCompositeOp::getTypeAttrName(
result.name);
1874 result.addAttribute(typeAttrName, TypeAttr::get(type));
1879void spirv::SpecConstantCompositeOp::print(
OpAsmPrinter &printer) {
1882 printer <<
" (" << llvm::interleaved(this->getConstituents().getValue())
1886LogicalResult spirv::SpecConstantCompositeOp::verify() {
1887 auto cType = dyn_cast<spirv::CompositeType>(
getType());
1888 auto constituents = this->getConstituents().getValue();
1891 return emitError(
"result type must be a composite type, but provided ")
1894 if (isa<spirv::CooperativeMatrixType>(cType))
1895 return emitError(
"unsupported composite type ") << cType;
1896 if (constituents.size() != cType.getNumElements())
1897 return emitError(
"has incorrect number of operands: expected ")
1898 << cType.getNumElements() <<
", but provided "
1899 << constituents.size();
1901 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1902 auto constituent = cast<FlatSymbolRefAttr>(constituents[
index]);
1904 auto constituentSpecConstOp =
1906 (*this)->getParentOp(), constituent.getAttr()));
1908 if (constituentSpecConstOp.getDefaultValue().getType() !=
1909 cType.getElementType(
index))
1910 return emitError(
"has incorrect types of operands: expected ")
1911 << cType.getElementType(
index) <<
", but provided "
1912 << constituentSpecConstOp.getDefaultValue().getType();
1923spirv::EXTSpecConstantCompositeReplicateOp::parse(
OpAsmParser &parser,
1925 StringAttr compositeName;
1927 const char *attrName =
"spec_const";
1938 StringAttr compositeSpecConstituentName =
1939 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1941 result.addAttribute(compositeSpecConstituentName, specConstRef);
1943 StringAttr typeAttrName =
1944 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(
result.name);
1945 result.addAttribute(typeAttrName, TypeAttr::get(type));
1950void spirv::EXTSpecConstantCompositeReplicateOp::print(
OpAsmPrinter &printer) {
1953 printer <<
" (" << this->getConstituent() <<
") : " <<
getType();
1956LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1957 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
1959 return emitError(
"result type must be a composite type, but provided ")
1963 (*this)->getParentOp(), this->getConstituent());
1966 "splat spec constant reference defining constituent not found");
1968 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
1969 if (!constituentSpecConstOp)
1970 return emitError(
"constituent is not a spec constant");
1972 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
1973 Type compositeElementType = compositeType.getElementType(0);
1974 if (constituentType != compositeElementType)
1975 return emitError(
"constituent has incorrect type: expected ")
1976 << compositeElementType <<
", but provided " << constituentType;
1985ParseResult spirv::SpecConstantOperationOp::parse(
OpAsmParser &parser,
2001 spirv::YieldOp::create(builder, wrappedOp->
getLoc(), wrappedOp->
getResult(0));
2012void spirv::SpecConstantOperationOp::print(
OpAsmPrinter &printer) {
2013 printer <<
" wraps ";
2017LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2018 Block &block = getRegion().getBlocks().
front();
2021 return emitOpError(
"expected exactly 2 nested ops");
2029 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2030 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2032 "invalid operand, must be defined by a constant operation");
2041LogicalResult spirv::GLFrexpStructOp::verify() {
2043 dyn_cast<spirv::StructType>(getResult().
getType());
2046 return emitError(
"result type must be a struct type with two memebers");
2050 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2051 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2053 Type operandTy = getOperand().getType();
2054 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2055 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2057 if (significandTy != operandTy)
2058 return emitError(
"member zero of the resulting struct type must be the "
2059 "same type as the operand");
2061 if (exponentVecTy) {
2062 IntegerType componentIntTy =
2063 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2064 if (!componentIntTy || componentIntTy.getWidth() != 32)
2065 return emitError(
"member one of the resulting struct type must"
2066 "be a scalar or vector of 32 bit integer type");
2067 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2068 return emitError(
"member one of the resulting struct type "
2069 "must be a scalar or vector of 32 bit integer type");
2073 if (operandVecTy && exponentVecTy &&
2074 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2077 if (operandFTy && exponentIntTy)
2080 return emitError(
"member one of the resulting struct type must have the same "
2081 "number of components as the operand type");
2088LogicalResult spirv::GLLdexpOp::verify() {
2089 Type significandType = getX().getType();
2090 Type exponentType = getExp().getType();
2092 if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
2093 return emitOpError(
"operands must both be scalars or vectors");
2096 if (
auto vectorType = dyn_cast<VectorType>(type))
2097 return vectorType.getNumElements();
2102 return emitOpError(
"operands must have the same number of elements");
2111LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2119LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2127LogicalResult spirv::ShiftRightLogicalOp::verify() {
2135LogicalResult spirv::VectorTimesScalarOp::verify() {
2137 return emitOpError(
"vector operand and result type mismatch");
2138 auto scalarType = cast<VectorType>(
getType()).getElementType();
2139 if (getScalar().
getType() != scalarType)
2140 return emitOpError(
"scalar operand and result element type match");
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static Type getValueType(Attribute attr)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
AttrClass getAttrOfType(StringAttr name)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
constexpr StringRef attributeName()
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.