32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Support/InterleavedRange.h"
42 #include <type_traits>
52 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
56 auto valueAttr = constOp.getValue();
57 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58 if (!integerValueAttr) {
62 if (integerValueAttr.getType().isSignlessInteger())
63 value = integerValueAttr.getInt();
65 value = integerValueAttr.getSInt();
72 spirv::MemorySemantics memorySemantics) {
79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80 spirv::MemorySemantics::Release |
81 spirv::MemorySemantics::AcquireRelease |
82 spirv::MemorySemantics::SequentiallyConsistent;
85 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
88 "expected at most one of these four memory constraints "
89 "to be set: `Acquire`, `Release`,"
90 "`AcquireRelease` or `SequentiallyConsistent`");
99 stringifyDecoration(spirv::Decoration::DescriptorSet));
100 auto bindingName = llvm::convertToSnakeFromCamelCase(
101 stringifyDecoration(spirv::Decoration::Binding));
104 if (descriptorSet && binding) {
107 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
112 auto builtInName = llvm::convertToSnakeFromCamelCase(
113 stringifyDecoration(spirv::Decoration::BuiltIn));
114 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
115 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
116 elidedAttrs.push_back(builtInName);
134 auto fnType = llvm::dyn_cast<FunctionType>(type);
136 parser.
emitError(loc,
"expected function type");
141 result.
addTypes(fnType.getResults());
152 assert(op->
getNumResults() == 1 &&
"op should have one result");
158 [&](
Type type) { return type != resultType; })) {
167 p <<
" : " << resultType;
170 template <
typename BlockReadWriteOpTy>
174 if (
auto valVecTy = llvm::dyn_cast<VectorType>(valType))
175 valType = valVecTy.getElementType();
178 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
179 return op.emitOpError(
"mismatch in result type and pointer type");
190 if (indices.empty()) {
191 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
195 for (
auto index : indices) {
196 if (
auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
197 if (cType.hasCompileTimeKnownNumElements() &&
199 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
200 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
203 type = cType.getElementType(index);
205 emitErrorFn(
"cannot extract from non-composite type ")
206 << type <<
" with index " << index;
216 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
217 if (!indicesArrayAttr) {
218 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
221 if (indicesArrayAttr.empty()) {
222 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
227 for (
auto indexAttr : indicesArrayAttr) {
228 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
230 emitErrorFn(
"expected an 32-bit integer for index, but found '")
234 indexVals.push_back(indexIntAttr.getInt());
254 template <
typename ExtendedBinaryOp>
256 auto resultType = llvm::cast<spirv::StructType>(op.getType());
257 if (resultType.getNumElements() != 2)
258 return op.emitOpError(
"expected result struct type containing two members");
260 if (!llvm::all_equal({op.getOperand1().
getType(), op.getOperand2().getType(),
261 resultType.getElementType(0),
262 resultType.getElementType(1)}))
263 return op.emitOpError(
264 "expected all operand types and struct member types are the same");
281 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
282 if (!structType || structType.getNumElements() != 2)
283 return parser.
emitError(loc,
"expected spirv.struct type with two members");
303 return op->
emitError(
"expected the same type for the first operand and "
304 "result, but provided ")
316 spirv::GlobalVariableOp var) {
321 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
325 return emitOpError(
"expected spirv.GlobalVariable symbol");
327 if (getPointer().
getType() != varOp.getType()) {
329 "result type mismatch with the referenced global variable's type");
339 operand_range constituents = this->getConstituents();
347 auto coopElementType =
350 [](
auto coopType) {
return coopType.getElementType(); })
351 .Default([](
Type) {
return nullptr; });
354 if (coopElementType) {
355 if (constituents.size() != 1)
356 return emitOpError(
"has incorrect number of operands: expected ")
357 <<
"1, but provided " << constituents.size();
358 if (coopElementType != constituents.front().getType())
359 return emitOpError(
"operand type mismatch: expected operand type ")
360 << coopElementType <<
", but provided "
361 << constituents.front().getType();
366 auto cType = llvm::cast<spirv::CompositeType>(
getType());
367 if (constituents.size() == cType.getNumElements()) {
368 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
369 if (constituents[index].
getType() != cType.getElementType(index)) {
370 return emitOpError(
"operand type mismatch: expected operand type ")
371 << cType.getElementType(index) <<
", but provided "
372 << constituents[index].getType();
379 auto resultType = llvm::dyn_cast<VectorType>(cType);
382 "expected to return a vector or cooperative matrix when the number of "
383 "constituents is less than what the result needs");
386 for (
Value component : constituents) {
387 if (!llvm::isa<VectorType>(component.getType()) &&
388 !component.getType().isIntOrFloat())
389 return emitOpError(
"operand type mismatch: expected operand to have "
390 "a scalar or vector type, but provided ")
391 << component.getType();
393 Type elementType = component.getType();
394 if (
auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
395 sizes.push_back(vectorType.getNumElements());
396 elementType = vectorType.getElementType();
401 if (elementType != resultType.getElementType())
402 return emitOpError(
"operand element type mismatch: expected to be ")
403 << resultType.getElementType() <<
", but provided " << elementType;
405 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
406 if (totalCount != cType.getNumElements())
407 return emitOpError(
"has incorrect number of operands: expected ")
408 << cType.getNumElements() <<
", but provided " << totalCount;
425 build(builder, state, elementType, composite, indexAttr);
432 StringRef indicesAttrName =
433 spirv::CompositeExtractOp::getIndicesAttrName(result.
name);
455 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
460 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
467 return emitOpError(
"invalid result type: expected ")
468 << resultType <<
" but provided " <<
getType();
482 build(builder, state, composite.
getType(),
object, composite, indexAttr);
488 Type objectType, compositeType;
490 StringRef indicesAttrName =
491 spirv::CompositeInsertOp::getIndicesAttrName(result.
name);
505 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
511 if (objectType != getObject().
getType()) {
512 return emitOpError(
"object operand type should be ")
513 << objectType <<
", but found " << getObject().getType();
517 return emitOpError(
"result type should be the same as "
518 "the composite type, but found ")
519 << getComposite().getType() <<
" vs " <<
getType();
526 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
527 <<
" : " << getObject().
getType() <<
" into "
528 << getComposite().getType();
538 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.
name);
543 if (
auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
544 type = typedAttr.getType();
545 if (llvm::isa<NoneType, TensorType>(type)) {
554 printer <<
' ' << getValue();
555 if (llvm::isa<spirv::ArrayType>(
getType()))
561 if (isa<spirv::CooperativeMatrixType>(opType)) {
562 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
563 if (!denseAttr || !denseAttr.isSplat())
564 return op.emitOpError(
"expected a splat dense attribute for cooperative "
565 "matrix constant, but found ")
568 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
569 auto valueType = llvm::cast<TypedAttr>(value).getType();
570 if (valueType != opType)
571 return op.emitOpError(
"result type (")
572 << opType <<
") does not match value type (" << valueType <<
")";
575 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
576 auto valueType = llvm::cast<TypedAttr>(value).getType();
577 if (valueType == opType)
579 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
580 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
582 return op.emitOpError(
"result or element type (")
583 << opType <<
") does not match value type (" << valueType
584 <<
"), must be the same or spirv.array";
586 int numElements = arrayType.getNumElements();
587 auto opElemType = arrayType.getElementType();
588 while (
auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
589 numElements *= t.getNumElements();
590 opElemType = t.getElementType();
592 if (!opElemType.isIntOrFloat())
593 return op.emitOpError(
"only support nested array result type");
595 auto valueElemType = shapedType.getElementType();
596 if (valueElemType != opElemType) {
597 return op.emitOpError(
"result element type (")
598 << opElemType <<
") does not match value element type ("
599 << valueElemType <<
")";
602 if (numElements != shapedType.getNumElements()) {
603 return op.emitOpError(
"result number of elements (")
604 << numElements <<
") does not match value number of elements ("
605 << shapedType.getNumElements() <<
")";
609 if (
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
610 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
612 return op.emitOpError(
613 "must have spirv.array result type for array value");
614 Type elemType = arrayType.getElementType();
615 for (
Attribute element : arrayAttr.getValue()) {
622 return op.emitOpError(
"cannot have attribute: ") << value;
632 bool spirv::ConstantOp::isBuildableWith(
Type type) {
634 if (!llvm::isa<spirv::SPIRVType>(type))
639 return llvm::isa<spirv::ArrayType>(type);
647 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
648 unsigned width = intType.getWidth();
650 return builder.
create<spirv::ConstantOp>(loc, type,
652 return builder.
create<spirv::ConstantOp>(
655 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
656 return builder.
create<spirv::ConstantOp>(
659 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
660 Type elemType = vectorType.getElementType();
661 if (llvm::isa<IntegerType>(elemType)) {
662 return builder.
create<spirv::ConstantOp>(
667 if (llvm::isa<FloatType>(elemType)) {
668 return builder.
create<spirv::ConstantOp>(
675 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
678 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
680 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
681 unsigned width = intType.getWidth();
683 return builder.
create<spirv::ConstantOp>(loc, type,
685 return builder.
create<spirv::ConstantOp>(
688 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
689 return builder.
create<spirv::ConstantOp>(
692 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
693 Type elemType = vectorType.getElementType();
694 if (llvm::isa<IntegerType>(elemType)) {
695 return builder.
create<spirv::ConstantOp>(
700 if (llvm::isa<FloatType>(elemType)) {
701 return builder.
create<spirv::ConstantOp>(
708 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
711 void mlir::spirv::ConstantOp::getAsmResultNames(
716 llvm::raw_svector_ostream specialName(specialNameBuffer);
717 specialName <<
"cst";
719 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
721 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
722 if (intTy && intTy.getWidth() == 1) {
723 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
726 if (intTy.isSignless()) {
727 specialName << intCst.getInt();
728 }
else if (intTy.isUnsigned()) {
729 specialName << intCst.getUInt();
731 specialName << intCst.getSInt();
735 if (intTy || llvm::isa<FloatType>(type)) {
736 specialName <<
'_' << type;
739 if (
auto vecType = llvm::dyn_cast<VectorType>(type)) {
740 specialName <<
"_vec_";
741 specialName << vecType.getDimSize(0);
743 Type elementType = vecType.getElementType();
745 if (llvm::isa<IntegerType>(elementType) ||
746 llvm::isa<FloatType>(elementType)) {
747 specialName <<
"x" << elementType;
751 setNameFn(getResult(), specialName.str());
754 void mlir::spirv::AddressOfOp::getAsmResultNames(
757 llvm::raw_svector_ostream specialName(specialNameBuffer);
758 specialName << getVariable() <<
"_addr";
759 setNameFn(getResult(), specialName.str());
775 spirv::ExecutionModel executionModel,
776 spirv::FuncOp
function,
778 build(builder, state,
785 spirv::ExecutionModel execModel;
789 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
798 FlatSymbolRefAttr var;
800 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
802 interfaceVars.push_back(var);
815 auto interfaceVars = getInterface().getValue();
816 if (!interfaceVars.empty())
817 printer <<
", " << llvm::interleaved(interfaceVars);
831 spirv::FuncOp
function,
832 spirv::ExecutionMode executionMode,
841 spirv::ExecutionMode execMode;
844 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
856 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
858 StringRef valuesAttrName =
859 spirv::ExecutionModeOp::getValuesAttrName(result.
name);
868 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
869 ArrayAttr values = this->getValues();
871 printer <<
", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
891 bool isVariadic =
false;
893 parser,
false, entryArgs, isVariadic, resultTypes,
898 for (
auto &arg : entryArgs)
899 argTypes.push_back(arg.type);
905 spirv::FunctionControl fnControl;
906 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
914 assert(resultAttrs.size() == resultTypes.size());
916 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
917 getResAttrsAttrName(result.
name));
923 return failure(parseResult.
has_value() && failed(*parseResult));
930 auto fnType = getFunctionType();
932 printer, *
this, fnType.getInputs(),
933 false, fnType.getResults());
934 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
938 {spirv::attributeName<spirv::FunctionControl>(),
939 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940 getFunctionControlAttrName()});
943 Region &body = this->getBody();
951 LogicalResult spirv::FuncOp::verifyType() {
952 FunctionType fnType = getFunctionType();
953 if (fnType.getNumResults() > 1)
954 return emitOpError(
"cannot have more than one result");
956 auto hasDecorationAttr = [&](spirv::Decoration decoration,
958 auto func = llvm::cast<FunctionOpInterface>(getOperation());
959 for (
auto argAttr : cast<FunctionOpInterface>(func).
getArgAttrs(argIndex)) {
960 if (argAttr.getName() != spirv::DecorationAttr::name)
962 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
963 return decAttr.getValue() == decoration;
968 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969 Type param = fnType.getInputs()[i];
970 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
974 auto pointeePtrType =
975 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
976 if (pointeePtrType) {
982 if (pointeePtrType.getStorageClass() !=
983 spirv::StorageClass::PhysicalStorageBuffer)
987 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988 bool hasRestrictPtr =
989 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990 if (!hasAliasedPtr && !hasRestrictPtr)
992 <<
"with a pointer points to a physical buffer pointer must "
993 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1000 if (
auto pointeeArrayType =
1001 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1003 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1005 pointeePtrType = inputPtrType;
1008 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009 spirv::StorageClass::PhysicalStorageBuffer)
1012 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014 if (!hasAliased && !hasRestrict)
1015 return emitOpError() <<
"with physical buffer pointer must be decorated "
1016 "either 'Aliased' or 'Restrict'";
1022 LogicalResult spirv::FuncOp::verifyBody() {
1023 FunctionType fnType = getFunctionType();
1024 if (!isExternal()) {
1025 Block &entryBlock = front();
1027 unsigned numArguments = this->getNumArguments();
1029 return emitOpError(
"entry block must have ")
1030 << numArguments <<
" arguments to match function signature";
1032 for (
auto [index, fnArgType, blockArgType] :
1034 if (blockArgType != fnArgType) {
1035 return emitOpError(
"type of entry block argument #")
1036 << index <<
'(' << blockArgType
1037 <<
") must match the type of the corresponding argument in "
1038 <<
"function signature(" << fnArgType <<
')';
1044 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1045 if (fnType.getNumResults() != 0)
1046 return retOp.emitOpError(
"cannot be used in functions returning value");
1047 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1048 if (fnType.getNumResults() != 1)
1049 return retOp.emitOpError(
1050 "returns 1 value but enclosing function requires ")
1051 << fnType.getNumResults() <<
" results";
1053 auto retOperandType = retOp.getValue().getType();
1054 auto fnResultType = fnType.getResult(0);
1055 if (retOperandType != fnResultType)
1056 return retOp.emitOpError(
" return value's type (")
1057 << retOperandType <<
") mismatch with function's result type ("
1058 << fnResultType <<
")";
1065 return failure(walkResult.wasInterrupted());
1069 StringRef name, FunctionType type,
1070 spirv::FunctionControl control,
1074 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
1075 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1076 builder.
getAttr<spirv::FunctionControlAttr>(control));
1077 state.attributes.append(attrs.begin(), attrs.end());
1125 Type type, StringRef name,
1126 unsigned descriptorSet,
unsigned binding) {
1129 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1132 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1137 Type type, StringRef name,
1138 spirv::BuiltIn builtin) {
1141 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1148 StringAttr nameAttr;
1149 StringRef initializerAttrName =
1150 spirv::GlobalVariableOp::getInitializerAttrName(result.
name);
1171 StringRef typeAttrName =
1172 spirv::GlobalVariableOp::getTypeAttrName(result.
name);
1177 if (!llvm::isa<spirv::PointerType>(type)) {
1178 return parser.
emitError(loc,
"expected spirv.ptr type");
1187 spirv::attributeName<spirv::StorageClass>()};
1194 StringRef initializerAttrName = this->getInitializerAttrName();
1196 if (
auto initializer = this->getInitializer()) {
1197 printer <<
" " << initializerAttrName <<
'(';
1200 elidedAttrs.push_back(initializerAttrName);
1203 StringRef typeAttrName = this->getTypeAttrName();
1204 elidedAttrs.push_back(typeAttrName);
1206 printer <<
" : " <<
getType();
1210 if (!llvm::isa<spirv::PointerType>(
getType()))
1211 return emitOpError(
"result must be of a !spv.ptr type");
1217 auto storageClass = this->storageClass();
1218 if (storageClass == spirv::StorageClass::Generic ||
1219 storageClass == spirv::StorageClass::Function) {
1220 return emitOpError(
"storage class cannot be '")
1221 << stringifyStorageClass(storageClass) <<
"'";
1225 this->getInitializerAttrName())) {
1227 (*this)->getParentOp(), init.getAttr());
1231 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1232 spirv::SpecConstantCompositeOp>(initOp)) {
1233 return emitOpError(
"initializer must be result of a "
1234 "spirv.SpecConstant or spirv.GlobalVariable or "
1235 "spirv.SpecConstantCompositeOp op");
1260 spirv::StorageClass storageClass;
1271 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1282 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1283 << getValue().getType();
1374 std::optional<StringRef> name) {
1384 spirv::AddressingModel addressingModel,
1385 spirv::MemoryModel memoryModel,
1386 std::optional<VerCapExtAttr> vceTriple,
1387 std::optional<StringRef> name) {
1390 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1391 state.addAttribute(
"memory_model",
1392 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1396 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1407 StringAttr nameAttr;
1412 spirv::AddressingModel addrModel;
1413 spirv::MemoryModel memoryModel;
1414 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1416 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1423 spirv::ModuleOp::getVCETripleAttrName(),
1440 if (std::optional<StringRef> name = getName()) {
1449 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1450 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1451 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1454 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1455 printer <<
" requires " << *triple;
1456 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1464 LogicalResult spirv::ModuleOp::verifyRegions() {
1465 Dialect *dialect = (*this)->getDialect();
1470 for (
auto &op : *getBody()) {
1472 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1477 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1478 auto funcOp =
table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1480 return entryPointOp.emitError(
"function '")
1481 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1483 if (
auto interface = entryPointOp.getInterface()) {
1485 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1487 return entryPointOp.emitError(
1488 "expected symbol reference for interface "
1489 "specification instead of '")
1493 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1495 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1496 "symbol reference instead of'")
1497 << varSymRef <<
"'";
1502 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1503 funcOp, entryPointOp.getExecutionModel());
1504 if (!entryPoints.try_emplace(key, entryPointOp).second)
1505 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1506 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1510 auto linkageAttr = funcOp.getLinkageAttributes();
1511 auto hasImportLinkage =
1512 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1513 spirv::LinkageType::Import);
1514 if (funcOp.isExternal() && !hasImportLinkage)
1516 "'spirv.module' cannot contain external functions "
1517 "without 'Import' linkage_attributes (LinkageAttributes)");
1520 for (
auto &block : funcOp)
1521 for (
auto &op : block) {
1524 "functions in 'spirv.module' can only contain spirv.* ops");
1538 (*this)->getParentOp(), getSpecConstAttr());
1541 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1543 constType = specConstOp.getDefaultValue().getType();
1545 auto specConstCompositeOp =
1546 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1547 if (specConstCompositeOp)
1548 constType = specConstCompositeOp.getType();
1550 if (!specConstOp && !specConstCompositeOp)
1552 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1554 if (getReference().
getType() != constType)
1555 return emitOpError(
"result type mismatch with the referenced "
1556 "specialization constant's type");
1567 StringAttr nameAttr;
1569 StringRef defaultValueAttrName =
1570 spirv::SpecConstantOp::getDefaultValueAttrName(result.
name);
1578 IntegerAttr specIdAttr;
1595 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1597 printer <<
" = " << getDefaultValue();
1601 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1602 if (specID.getValue().isNegative())
1603 return emitOpError(
"SpecId cannot be negative");
1605 auto value = getDefaultValue();
1606 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1608 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1609 return emitOpError(
"default value bitwidth disallowed");
1613 "default value can only be a bool, integer, or float scalar");
1621 VectorType resultType = llvm::cast<VectorType>(
getType());
1623 size_t numResultElements = resultType.getNumElements();
1624 if (numResultElements != getComponents().size())
1625 return emitOpError(
"result type element count (")
1626 << numResultElements
1627 <<
") mismatch with the number of component selectors ("
1628 << getComponents().size() <<
")";
1630 size_t totalSrcElements =
1631 llvm::cast<VectorType>(getVector1().
getType()).getNumElements() +
1632 llvm::cast<VectorType>(getVector2().
getType()).getNumElements();
1634 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1635 uint32_t index = selector.getZExtValue();
1636 if (index >= totalSrcElements &&
1637 index != std::numeric_limits<uint32_t>().
max())
1638 return emitOpError(
"component selector ")
1639 << index <<
" out of range: expected to be in [0, "
1640 << totalSrcElements <<
") or 0xffffffff";
1653 [](
auto matrixType) {
return matrixType.getElementType(); })
1654 .Default([](
Type) {
return nullptr; });
1656 assert(elementType &&
"Unhandled type");
1659 if (getScalar().
getType() != elementType)
1660 return emitOpError(
"input matrix components' type and scaling value must "
1661 "have the same type");
1671 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1672 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1675 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1676 return emitError(
"input matrix rows count must be equal to "
1677 "output matrix columns count");
1679 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1680 return emitError(
"input matrix columns count must be equal to "
1681 "output matrix rows count");
1684 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1685 return emitError(
"input and output matrices must have the same "
1696 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1697 auto vectorType = llvm::cast<VectorType>(getVector().
getType());
1698 auto resultType = llvm::cast<VectorType>(
getType());
1700 if (matrixType.getNumColumns() != vectorType.getNumElements())
1701 return emitOpError(
"matrix columns (")
1702 << matrixType.getNumColumns() <<
") must match vector operand size ("
1703 << vectorType.getNumElements() <<
")";
1705 if (resultType.getNumElements() != matrixType.getNumRows())
1706 return emitOpError(
"result size (")
1707 << resultType.getNumElements() <<
") must match the matrix rows ("
1708 << matrixType.getNumRows() <<
")";
1710 if (matrixType.getElementType() != resultType.getElementType())
1711 return emitOpError(
"matrix and result element types must match");
1721 auto vectorType = llvm::cast<VectorType>(getVector().
getType());
1722 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1723 auto resultType = llvm::cast<VectorType>(
getType());
1725 if (matrixType.getNumRows() != vectorType.getNumElements())
1726 return emitOpError(
"number of components in vector must equal the number "
1727 "of components in each column in matrix");
1729 if (resultType.getNumElements() != matrixType.getNumColumns())
1730 return emitOpError(
"number of columns in matrix must equal the number of "
1731 "components in result");
1733 if (matrixType.getElementType() != resultType.getElementType())
1734 return emitOpError(
"matrix must be a matrix with the same component type "
1735 "as the component type in result");
1745 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().
getType());
1746 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().
getType());
1747 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1750 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1751 return emitError(
"left matrix columns' count must be equal to "
1752 "the right matrix rows' count");
1755 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1757 "right and result matrices must have equal columns' count");
1760 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1761 return emitError(
"right and result matrices' component type must"
1765 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1766 return emitError(
"left and result matrices' component type"
1767 " must be the same");
1770 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1771 return emitError(
"left and result matrices must have equal rows' count");
1783 StringAttr compositeName;
1795 const char *attrName =
"spec_const";
1802 constituents.push_back(specConstRef);
1808 StringAttr compositeSpecConstituentsName =
1809 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.
name);
1817 StringAttr typeAttrName =
1818 spirv::SpecConstantCompositeOp::getTypeAttrName(result.
name);
1827 printer <<
" (" << llvm::interleaved(this->getConstituents().getValue())
1832 auto cType = llvm::dyn_cast<spirv::CompositeType>(
getType());
1833 auto constituents = this->getConstituents().getValue();
1836 return emitError(
"result type must be a composite type, but provided ")
1839 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1840 return emitError(
"unsupported composite type ") << cType;
1841 if (constituents.size() != cType.getNumElements())
1842 return emitError(
"has incorrect number of operands: expected ")
1843 << cType.getNumElements() <<
", but provided "
1844 << constituents.size();
1846 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1847 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1849 auto constituentSpecConstOp =
1851 (*this)->getParentOp(), constituent.getAttr()));
1853 if (constituentSpecConstOp.getDefaultValue().getType() !=
1854 cType.getElementType(index))
1855 return emitError(
"has incorrect types of operands: expected ")
1856 << cType.getElementType(index) <<
", but provided "
1857 << constituentSpecConstOp.getDefaultValue().getType();
1895 printer <<
" wraps ";
1899 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1900 Block &block = getRegion().getBlocks().
front();
1903 return emitOpError(
"expected exactly 2 nested ops");
1908 return emitOpError(
"invalid enclosed op");
1911 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1912 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1914 "invalid operand, must be defined by a constant operation");
1925 llvm::dyn_cast<spirv::StructType>(getResult().
getType());
1928 return emitError(
"result type must be a struct type with two memebers");
1932 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1933 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1935 Type operandTy = getOperand().getType();
1936 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1937 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1939 if (significandTy != operandTy)
1940 return emitError(
"member zero of the resulting struct type must be the "
1941 "same type as the operand");
1943 if (exponentVecTy) {
1944 IntegerType componentIntTy =
1945 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1946 if (!componentIntTy || componentIntTy.getWidth() != 32)
1947 return emitError(
"member one of the resulting struct type must"
1948 "be a scalar or vector of 32 bit integer type");
1949 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1950 return emitError(
"member one of the resulting struct type "
1951 "must be a scalar or vector of 32 bit integer type");
1955 if (operandVecTy && exponentVecTy &&
1956 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1959 if (operandFTy && exponentIntTy)
1962 return emitError(
"member one of the resulting struct type must have the same "
1963 "number of components as the operand type");
1971 Type significandType = getX().getType();
1972 Type exponentType = getExp().getType();
1974 if (llvm::isa<FloatType>(significandType) !=
1975 llvm::isa<IntegerType>(exponentType))
1976 return emitOpError(
"operands must both be scalars or vectors");
1979 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
1980 return vectorType.getNumElements();
1985 return emitOpError(
"operands must have the same number of elements");
2020 return emitOpError(
"vector operand and result type mismatch");
2021 auto scalarType = llvm::cast<VectorType>(
getType()).getElementType();
2022 if (getScalar().
getType() != scalarType)
2023 return emitOpError(
"scalar operand and result element type match");
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.