32 #include "llvm/ADT/APFloat.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/StringExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
41 #include <type_traits>
51 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
55 auto valueAttr = constOp.getValue();
56 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
57 if (!integerValueAttr) {
61 if (integerValueAttr.getType().isSignlessInteger())
62 value = integerValueAttr.getInt();
64 value = integerValueAttr.getSInt();
71 spirv::MemorySemantics memorySemantics) {
78 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
79 spirv::MemorySemantics::Release |
80 spirv::MemorySemantics::AcquireRelease |
81 spirv::MemorySemantics::SequentiallyConsistent;
84 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
87 "expected at most one of these four memory constraints "
88 "to be set: `Acquire`, `Release`,"
89 "`AcquireRelease` or `SequentiallyConsistent`");
98 stringifyDecoration(spirv::Decoration::DescriptorSet));
99 auto bindingName = llvm::convertToSnakeFromCamelCase(
100 stringifyDecoration(spirv::Decoration::Binding));
103 if (descriptorSet && binding) {
106 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
111 auto builtInName = llvm::convertToSnakeFromCamelCase(
112 stringifyDecoration(spirv::Decoration::BuiltIn));
113 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
114 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
115 elidedAttrs.push_back(builtInName);
133 auto fnType = llvm::dyn_cast<FunctionType>(type);
135 parser.
emitError(loc,
"expected function type");
140 result.
addTypes(fnType.getResults());
151 assert(op->
getNumResults() == 1 &&
"op should have one result");
157 [&](
Type type) { return type != resultType; })) {
166 p <<
" : " << resultType;
169 template <
typename BlockReadWriteOpTy>
173 if (
auto valVecTy = llvm::dyn_cast<VectorType>(valType))
174 valType = valVecTy.getElementType();
177 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
178 return op.emitOpError(
"mismatch in result type and pointer type");
189 if (indices.empty()) {
190 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
194 for (
auto index : indices) {
195 if (
auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
196 if (cType.hasCompileTimeKnownNumElements() &&
198 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
199 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
202 type = cType.getElementType(index);
204 emitErrorFn(
"cannot extract from non-composite type ")
205 << type <<
" with index " << index;
215 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
216 if (!indicesArrayAttr) {
217 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
220 if (indicesArrayAttr.empty()) {
221 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
226 for (
auto indexAttr : indicesArrayAttr) {
227 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
229 emitErrorFn(
"expected an 32-bit integer for index, but found '")
233 indexVals.push_back(indexIntAttr.getInt());
253 template <
typename ExtendedBinaryOp>
255 auto resultType = llvm::cast<spirv::StructType>(op.getType());
256 if (resultType.getNumElements() != 2)
257 return op.emitOpError(
"expected result struct type containing two members");
259 if (!llvm::all_equal({op.getOperand1().
getType(), op.getOperand2().getType(),
260 resultType.getElementType(0),
261 resultType.getElementType(1)}))
262 return op.emitOpError(
263 "expected all operand types and struct member types are the same");
280 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
281 if (!structType || structType.getNumElements() != 2)
282 return parser.
emitError(loc,
"expected spirv.struct type with two members");
302 return op->
emitError(
"expected the same type for the first operand and "
303 "result, but provided ")
315 spirv::GlobalVariableOp var) {
320 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
324 return emitOpError(
"expected spirv.GlobalVariable symbol");
326 if (getPointer().
getType() != varOp.getType()) {
328 "result type mismatch with the referenced global variable's type");
338 operand_range constituents = this->getConstituents();
346 auto coopElementType =
349 [](
auto coopType) {
return coopType.getElementType(); })
350 .Default([](
Type) {
return nullptr; });
353 if (coopElementType) {
354 if (constituents.size() != 1)
355 return emitOpError(
"has incorrect number of operands: expected ")
356 <<
"1, but provided " << constituents.size();
357 if (coopElementType != constituents.front().getType())
358 return emitOpError(
"operand type mismatch: expected operand type ")
359 << coopElementType <<
", but provided "
360 << constituents.front().getType();
365 auto cType = llvm::cast<spirv::CompositeType>(
getType());
366 if (constituents.size() == cType.getNumElements()) {
367 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
368 if (constituents[index].
getType() != cType.getElementType(index)) {
369 return emitOpError(
"operand type mismatch: expected operand type ")
370 << cType.getElementType(index) <<
", but provided "
371 << constituents[index].getType();
378 auto resultType = llvm::dyn_cast<VectorType>(cType);
381 "expected to return a vector or cooperative matrix when the number of "
382 "constituents is less than what the result needs");
385 for (
Value component : constituents) {
386 if (!llvm::isa<VectorType>(component.getType()) &&
387 !component.getType().isIntOrFloat())
388 return emitOpError(
"operand type mismatch: expected operand to have "
389 "a scalar or vector type, but provided ")
390 << component.getType();
392 Type elementType = component.getType();
393 if (
auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
394 sizes.push_back(vectorType.getNumElements());
395 elementType = vectorType.getElementType();
400 if (elementType != resultType.getElementType())
401 return emitOpError(
"operand element type mismatch: expected to be ")
402 << resultType.getElementType() <<
", but provided " << elementType;
404 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
405 if (totalCount != cType.getNumElements())
406 return emitOpError(
"has incorrect number of operands: expected ")
407 << cType.getNumElements() <<
", but provided " << totalCount;
424 build(builder, state, elementType, composite, indexAttr);
431 StringRef indicesAttrName =
432 spirv::CompositeExtractOp::getIndicesAttrName(result.
name);
454 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
459 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
466 return emitOpError(
"invalid result type: expected ")
467 << resultType <<
" but provided " <<
getType();
481 build(builder, state, composite.
getType(),
object, composite, indexAttr);
487 Type objectType, compositeType;
489 StringRef indicesAttrName =
490 spirv::CompositeInsertOp::getIndicesAttrName(result.
name);
504 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
510 if (objectType != getObject().
getType()) {
511 return emitOpError(
"object operand type should be ")
512 << objectType <<
", but found " << getObject().getType();
516 return emitOpError(
"result type should be the same as "
517 "the composite type, but found ")
518 << getComposite().getType() <<
" vs " <<
getType();
525 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
526 <<
" : " << getObject().
getType() <<
" into "
527 << getComposite().getType();
537 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.
name);
542 if (
auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
543 type = typedAttr.getType();
544 if (llvm::isa<NoneType, TensorType>(type)) {
553 printer <<
' ' << getValue();
554 if (llvm::isa<spirv::ArrayType>(
getType()))
560 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
561 auto valueType = llvm::cast<TypedAttr>(value).getType();
562 if (valueType != opType)
563 return op.emitOpError(
"result type (")
564 << opType <<
") does not match value type (" << valueType <<
")";
567 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
568 auto valueType = llvm::cast<TypedAttr>(value).getType();
569 if (valueType == opType)
571 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
572 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
574 return op.emitOpError(
"result or element type (")
575 << opType <<
") does not match value type (" << valueType
576 <<
"), must be the same or spirv.array";
578 int numElements = arrayType.getNumElements();
579 auto opElemType = arrayType.getElementType();
580 while (
auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
581 numElements *= t.getNumElements();
582 opElemType = t.getElementType();
584 if (!opElemType.isIntOrFloat())
585 return op.emitOpError(
"only support nested array result type");
587 auto valueElemType = shapedType.getElementType();
588 if (valueElemType != opElemType) {
589 return op.emitOpError(
"result element type (")
590 << opElemType <<
") does not match value element type ("
591 << valueElemType <<
")";
594 if (numElements != shapedType.getNumElements()) {
595 return op.emitOpError(
"result number of elements (")
596 << numElements <<
") does not match value number of elements ("
597 << shapedType.getNumElements() <<
")";
601 if (
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
602 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
604 return op.emitOpError(
605 "must have spirv.array result type for array value");
606 Type elemType = arrayType.getElementType();
607 for (
Attribute element : arrayAttr.getValue()) {
614 return op.emitOpError(
"cannot have attribute: ") << value;
624 bool spirv::ConstantOp::isBuildableWith(
Type type) {
626 if (!llvm::isa<spirv::SPIRVType>(type))
631 return llvm::isa<spirv::ArrayType>(type);
639 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
640 unsigned width = intType.getWidth();
642 return builder.
create<spirv::ConstantOp>(loc, type,
644 return builder.
create<spirv::ConstantOp>(
647 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
648 return builder.
create<spirv::ConstantOp>(
651 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
652 Type elemType = vectorType.getElementType();
653 if (llvm::isa<IntegerType>(elemType)) {
654 return builder.
create<spirv::ConstantOp>(
659 if (llvm::isa<FloatType>(elemType)) {
660 return builder.
create<spirv::ConstantOp>(
667 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
670 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
672 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
673 unsigned width = intType.getWidth();
675 return builder.
create<spirv::ConstantOp>(loc, type,
677 return builder.
create<spirv::ConstantOp>(
680 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
681 return builder.
create<spirv::ConstantOp>(
684 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
685 Type elemType = vectorType.getElementType();
686 if (llvm::isa<IntegerType>(elemType)) {
687 return builder.
create<spirv::ConstantOp>(
692 if (llvm::isa<FloatType>(elemType)) {
693 return builder.
create<spirv::ConstantOp>(
700 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
703 void mlir::spirv::ConstantOp::getAsmResultNames(
708 llvm::raw_svector_ostream specialName(specialNameBuffer);
709 specialName <<
"cst";
711 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
713 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
714 if (intTy && intTy.getWidth() == 1) {
715 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
718 if (intTy.isSignless()) {
719 specialName << intCst.getInt();
720 }
else if (intTy.isUnsigned()) {
721 specialName << intCst.getUInt();
723 specialName << intCst.getSInt();
727 if (intTy || llvm::isa<FloatType>(type)) {
728 specialName <<
'_' << type;
731 if (
auto vecType = llvm::dyn_cast<VectorType>(type)) {
732 specialName <<
"_vec_";
733 specialName << vecType.getDimSize(0);
735 Type elementType = vecType.getElementType();
737 if (llvm::isa<IntegerType>(elementType) ||
738 llvm::isa<FloatType>(elementType)) {
739 specialName <<
"x" << elementType;
743 setNameFn(getResult(), specialName.str());
746 void mlir::spirv::AddressOfOp::getAsmResultNames(
749 llvm::raw_svector_ostream specialName(specialNameBuffer);
750 specialName << getVariable() <<
"_addr";
751 setNameFn(getResult(), specialName.str());
767 spirv::ExecutionModel executionModel,
768 spirv::FuncOp
function,
770 build(builder, state,
777 spirv::ExecutionModel execModel;
783 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
792 FlatSymbolRefAttr var;
794 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
796 interfaceVars.push_back(var);
809 auto interfaceVars = getInterface().getValue();
810 if (!interfaceVars.empty()) {
812 llvm::interleaveComma(interfaceVars, printer);
827 spirv::FuncOp
function,
828 spirv::ExecutionMode executionMode,
837 spirv::ExecutionMode execMode;
840 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
852 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
854 StringRef valuesAttrName =
855 spirv::ExecutionModeOp::getValuesAttrName(result.
name);
864 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
865 auto values = this->getValues();
869 llvm::interleaveComma(values, printer, [&](
Attribute a) {
870 printer << llvm::cast<IntegerAttr>(a).getInt();
891 bool isVariadic =
false;
893 parser,
false, entryArgs, isVariadic, resultTypes,
898 for (
auto &arg : entryArgs)
899 argTypes.push_back(arg.type);
905 spirv::FunctionControl fnControl;
906 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
914 assert(resultAttrs.size() == resultTypes.size());
916 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
917 getResAttrsAttrName(result.
name));
923 return failure(parseResult.
has_value() && failed(*parseResult));
930 auto fnType = getFunctionType();
932 printer, *
this, fnType.getInputs(),
933 false, fnType.getResults());
934 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
938 {spirv::attributeName<spirv::FunctionControl>(),
939 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
940 getFunctionControlAttrName()});
943 Region &body = this->getBody();
951 LogicalResult spirv::FuncOp::verifyType() {
952 FunctionType fnType = getFunctionType();
953 if (fnType.getNumResults() > 1)
954 return emitOpError(
"cannot have more than one result");
956 auto hasDecorationAttr = [&](spirv::Decoration decoration,
958 auto func = llvm::cast<FunctionOpInterface>(getOperation());
959 for (
auto argAttr : cast<FunctionOpInterface>(func).
getArgAttrs(argIndex)) {
960 if (argAttr.getName() != spirv::DecorationAttr::name)
962 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
963 return decAttr.getValue() == decoration;
968 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
969 Type param = fnType.getInputs()[i];
970 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
974 auto pointeePtrType =
975 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
976 if (pointeePtrType) {
982 if (pointeePtrType.getStorageClass() !=
983 spirv::StorageClass::PhysicalStorageBuffer)
987 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
988 bool hasRestrictPtr =
989 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
990 if (!hasAliasedPtr && !hasRestrictPtr)
992 <<
"with a pointer points to a physical buffer pointer must "
993 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1000 if (
auto pointeeArrayType =
1001 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1003 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1005 pointeePtrType = inputPtrType;
1008 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1009 spirv::StorageClass::PhysicalStorageBuffer)
1012 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1013 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1014 if (!hasAliased && !hasRestrict)
1015 return emitOpError() <<
"with physical buffer pointer must be decorated "
1016 "either 'Aliased' or 'Restrict'";
1022 LogicalResult spirv::FuncOp::verifyBody() {
1023 FunctionType fnType = getFunctionType();
1024 if (!isExternal()) {
1025 Block &entryBlock = front();
1027 unsigned numArguments = this->getNumArguments();
1029 return emitOpError(
"entry block must have ")
1030 << numArguments <<
" arguments to match function signature";
1032 for (
auto [index, fnArgType, blockArgType] :
1034 if (blockArgType != fnArgType) {
1035 return emitOpError(
"type of entry block argument #")
1036 << index <<
'(' << blockArgType
1037 <<
") must match the type of the corresponding argument in "
1038 <<
"function signature(" << fnArgType <<
')';
1044 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1045 if (fnType.getNumResults() != 0)
1046 return retOp.emitOpError(
"cannot be used in functions returning value");
1047 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1048 if (fnType.getNumResults() != 1)
1049 return retOp.emitOpError(
1050 "returns 1 value but enclosing function requires ")
1051 << fnType.getNumResults() <<
" results";
1053 auto retOperandType = retOp.getValue().getType();
1054 auto fnResultType = fnType.getResult(0);
1055 if (retOperandType != fnResultType)
1056 return retOp.emitOpError(
" return value's type (")
1057 << retOperandType <<
") mismatch with function's result type ("
1058 << fnResultType <<
")";
1065 return failure(walkResult.wasInterrupted());
1069 StringRef name, FunctionType type,
1070 spirv::FunctionControl control,
1074 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
1075 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1076 builder.
getAttr<spirv::FunctionControlAttr>(control));
1077 state.attributes.append(attrs.begin(), attrs.end());
1125 Type type, StringRef name,
1126 unsigned descriptorSet,
unsigned binding) {
1129 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1132 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1137 Type type, StringRef name,
1138 spirv::BuiltIn builtin) {
1141 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1148 StringAttr nameAttr;
1149 StringRef initializerAttrName =
1150 spirv::GlobalVariableOp::getInitializerAttrName(result.
name);
1171 StringRef typeAttrName =
1172 spirv::GlobalVariableOp::getTypeAttrName(result.
name);
1177 if (!llvm::isa<spirv::PointerType>(type)) {
1178 return parser.
emitError(loc,
"expected spirv.ptr type");
1187 spirv::attributeName<spirv::StorageClass>()};
1194 StringRef initializerAttrName = this->getInitializerAttrName();
1196 if (
auto initializer = this->getInitializer()) {
1197 printer <<
" " << initializerAttrName <<
'(';
1200 elidedAttrs.push_back(initializerAttrName);
1203 StringRef typeAttrName = this->getTypeAttrName();
1204 elidedAttrs.push_back(typeAttrName);
1206 printer <<
" : " <<
getType();
1210 if (!llvm::isa<spirv::PointerType>(
getType()))
1211 return emitOpError(
"result must be of a !spv.ptr type");
1217 auto storageClass = this->storageClass();
1218 if (storageClass == spirv::StorageClass::Generic ||
1219 storageClass == spirv::StorageClass::Function) {
1220 return emitOpError(
"storage class cannot be '")
1221 << stringifyStorageClass(storageClass) <<
"'";
1225 this->getInitializerAttrName())) {
1227 (*this)->getParentOp(), init.getAttr());
1231 if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1232 spirv::SpecConstantCompositeOp>(initOp)) {
1233 return emitOpError(
"initializer must be result of a "
1234 "spirv.SpecConstant or spirv.GlobalVariable or "
1235 "spirv.SpecConstantCompositeOp op");
1260 spirv::StorageClass storageClass;
1271 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1282 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1283 << getValue().getType();
1374 std::optional<StringRef> name) {
1384 spirv::AddressingModel addressingModel,
1385 spirv::MemoryModel memoryModel,
1386 std::optional<VerCapExtAttr> vceTriple,
1387 std::optional<StringRef> name) {
1390 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1391 state.addAttribute(
"memory_model",
1392 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1396 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1407 StringAttr nameAttr;
1412 spirv::AddressingModel addrModel;
1413 spirv::MemoryModel memoryModel;
1414 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1416 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1423 spirv::ModuleOp::getVCETripleAttrName(),
1440 if (std::optional<StringRef> name = getName()) {
1449 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1450 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1451 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1454 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1455 printer <<
" requires " << *triple;
1456 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1464 LogicalResult spirv::ModuleOp::verifyRegions() {
1465 Dialect *dialect = (*this)->getDialect();
1470 for (
auto &op : *getBody()) {
1472 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1477 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1478 auto funcOp =
table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1480 return entryPointOp.emitError(
"function '")
1481 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1483 if (
auto interface = entryPointOp.getInterface()) {
1485 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1487 return entryPointOp.emitError(
1488 "expected symbol reference for interface "
1489 "specification instead of '")
1493 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1495 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1496 "symbol reference instead of'")
1497 << varSymRef <<
"'";
1502 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1503 funcOp, entryPointOp.getExecutionModel());
1504 if (!entryPoints.try_emplace(key, entryPointOp).second)
1505 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1506 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1510 auto linkageAttr = funcOp.getLinkageAttributes();
1511 auto hasImportLinkage =
1512 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1513 spirv::LinkageType::Import);
1514 if (funcOp.isExternal() && !hasImportLinkage)
1516 "'spirv.module' cannot contain external functions "
1517 "without 'Import' linkage_attributes (LinkageAttributes)");
1520 for (
auto &block : funcOp)
1521 for (
auto &op : block) {
1524 "functions in 'spirv.module' can only contain spirv.* ops");
1538 (*this)->getParentOp(), getSpecConstAttr());
1541 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1543 constType = specConstOp.getDefaultValue().getType();
1545 auto specConstCompositeOp =
1546 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1547 if (specConstCompositeOp)
1548 constType = specConstCompositeOp.getType();
1550 if (!specConstOp && !specConstCompositeOp)
1552 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1554 if (getReference().
getType() != constType)
1555 return emitOpError(
"result type mismatch with the referenced "
1556 "specialization constant's type");
1567 StringAttr nameAttr;
1569 StringRef defaultValueAttrName =
1570 spirv::SpecConstantOp::getDefaultValueAttrName(result.
name);
1578 IntegerAttr specIdAttr;
1595 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1597 printer <<
" = " << getDefaultValue();
1601 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1602 if (specID.getValue().isNegative())
1603 return emitOpError(
"SpecId cannot be negative");
1605 auto value = getDefaultValue();
1606 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1608 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1609 return emitOpError(
"default value bitwidth disallowed");
1613 "default value can only be a bool, integer, or float scalar");
1621 VectorType resultType = llvm::cast<VectorType>(
getType());
1623 size_t numResultElements = resultType.getNumElements();
1624 if (numResultElements != getComponents().size())
1625 return emitOpError(
"result type element count (")
1626 << numResultElements
1627 <<
") mismatch with the number of component selectors ("
1628 << getComponents().size() <<
")";
1630 size_t totalSrcElements =
1631 llvm::cast<VectorType>(getVector1().
getType()).getNumElements() +
1632 llvm::cast<VectorType>(getVector2().
getType()).getNumElements();
1634 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1635 uint32_t index = selector.getZExtValue();
1636 if (index >= totalSrcElements &&
1637 index != std::numeric_limits<uint32_t>().
max())
1638 return emitOpError(
"component selector ")
1639 << index <<
" out of range: expected to be in [0, "
1640 << totalSrcElements <<
") or 0xffffffff";
1653 [](
auto matrixType) {
return matrixType.getElementType(); })
1654 .Default([](
Type) {
return nullptr; });
1656 assert(elementType &&
"Unhandled type");
1659 if (getScalar().
getType() != elementType)
1660 return emitOpError(
"input matrix components' type and scaling value must "
1661 "have the same type");
1671 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1672 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1675 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1676 return emitError(
"input matrix rows count must be equal to "
1677 "output matrix columns count");
1679 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1680 return emitError(
"input matrix columns count must be equal to "
1681 "output matrix rows count");
1684 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1685 return emitError(
"input and output matrices must have the same "
1696 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1697 auto vectorType = llvm::cast<VectorType>(getVector().
getType());
1698 auto resultType = llvm::cast<VectorType>(
getType());
1700 if (matrixType.getNumColumns() != vectorType.getNumElements())
1701 return emitOpError(
"matrix columns (")
1702 << matrixType.getNumColumns() <<
") must match vector operand size ("
1703 << vectorType.getNumElements() <<
")";
1705 if (resultType.getNumElements() != matrixType.getNumRows())
1706 return emitOpError(
"result size (")
1707 << resultType.getNumElements() <<
") must match the matrix rows ("
1708 << matrixType.getNumRows() <<
")";
1710 if (matrixType.getElementType() != resultType.getElementType())
1711 return emitOpError(
"matrix and result element types must match");
1721 auto vectorType = llvm::cast<VectorType>(getVector().
getType());
1722 auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().
getType());
1723 auto resultType = llvm::cast<VectorType>(
getType());
1725 if (matrixType.getNumRows() != vectorType.getNumElements())
1726 return emitOpError(
"number of components in vector must equal the number "
1727 "of components in each column in matrix");
1729 if (resultType.getNumElements() != matrixType.getNumColumns())
1730 return emitOpError(
"number of columns in matrix must equal the number of "
1731 "components in result");
1733 if (matrixType.getElementType() != resultType.getElementType())
1734 return emitOpError(
"matrix must be a matrix with the same component type "
1735 "as the component type in result");
1745 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().
getType());
1746 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().
getType());
1747 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().
getType());
1750 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1751 return emitError(
"left matrix columns' count must be equal to "
1752 "the right matrix rows' count");
1755 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1757 "right and result matrices must have equal columns' count");
1760 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1761 return emitError(
"right and result matrices' component type must"
1765 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1766 return emitError(
"left and result matrices' component type"
1767 " must be the same");
1770 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1771 return emitError(
"left and result matrices must have equal rows' count");
1783 StringAttr compositeName;
1795 const char *attrName =
"spec_const";
1802 constituents.push_back(specConstRef);
1808 StringAttr compositeSpecConstituentsName =
1809 spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.
name);
1817 StringAttr typeAttrName =
1818 spirv::SpecConstantCompositeOp::getTypeAttrName(result.
name);
1828 auto constituents = this->getConstituents().getValue();
1830 if (!constituents.empty())
1831 llvm::interleaveComma(constituents, printer);
1833 printer <<
") : " <<
getType();
1837 auto cType = llvm::dyn_cast<spirv::CompositeType>(
getType());
1838 auto constituents = this->getConstituents().getValue();
1841 return emitError(
"result type must be a composite type, but provided ")
1844 if (llvm::isa<spirv::CooperativeMatrixType>(cType))
1845 return emitError(
"unsupported composite type ") << cType;
1846 if (constituents.size() != cType.getNumElements())
1847 return emitError(
"has incorrect number of operands: expected ")
1848 << cType.getNumElements() <<
", but provided "
1849 << constituents.size();
1851 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1852 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1854 auto constituentSpecConstOp =
1856 (*this)->getParentOp(), constituent.getAttr()));
1858 if (constituentSpecConstOp.getDefaultValue().getType() !=
1859 cType.getElementType(index))
1860 return emitError(
"has incorrect types of operands: expected ")
1861 << cType.getElementType(index) <<
", but provided "
1862 << constituentSpecConstOp.getDefaultValue().getType();
1900 printer <<
" wraps ";
1904 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1905 Block &block = getRegion().getBlocks().
front();
1908 return emitOpError(
"expected exactly 2 nested ops");
1913 return emitOpError(
"invalid enclosed op");
1916 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1917 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1919 "invalid operand, must be defined by a constant operation");
1930 llvm::dyn_cast<spirv::StructType>(getResult().
getType());
1933 return emitError(
"result type must be a struct type with two memebers");
1937 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1938 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1940 Type operandTy = getOperand().getType();
1941 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1942 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1944 if (significandTy != operandTy)
1945 return emitError(
"member zero of the resulting struct type must be the "
1946 "same type as the operand");
1948 if (exponentVecTy) {
1949 IntegerType componentIntTy =
1950 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1951 if (!componentIntTy || componentIntTy.getWidth() != 32)
1952 return emitError(
"member one of the resulting struct type must"
1953 "be a scalar or vector of 32 bit integer type");
1954 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1955 return emitError(
"member one of the resulting struct type "
1956 "must be a scalar or vector of 32 bit integer type");
1960 if (operandVecTy && exponentVecTy &&
1961 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1964 if (operandFTy && exponentIntTy)
1967 return emitError(
"member one of the resulting struct type must have the same "
1968 "number of components as the operand type");
1976 Type significandType = getX().getType();
1977 Type exponentType = getExp().getType();
1979 if (llvm::isa<FloatType>(significandType) !=
1980 llvm::isa<IntegerType>(exponentType))
1981 return emitOpError(
"operands must both be scalars or vectors");
1984 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
1985 return vectorType.getNumElements();
1990 return emitOpError(
"operands must have the same number of elements");
2025 return emitOpError(
"vector operand and result type mismatch");
2026 auto scalarType = llvm::cast<VectorType>(
getType()).getElementType();
2027 if (getScalar().
getType() != scalarType)
2028 return emitOpError(
"scalar operand and result element type match");
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.