31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/APInt.h"
33#include "llvm/ADT/ArrayRef.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/StringExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/InterleavedRange.h"
50 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
54 auto valueAttr = constOp.getValue();
55 auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
56 if (!integerValueAttr) {
60 if (integerValueAttr.getType().isSignlessInteger())
61 value = integerValueAttr.getInt();
63 value = integerValueAttr.getSInt();
70 spirv::MemorySemantics memorySemantics) {
77 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
78 spirv::MemorySemantics::Release |
79 spirv::MemorySemantics::AcquireRelease |
80 spirv::MemorySemantics::SequentiallyConsistent;
83 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
86 "expected at most one of these four memory constraints "
87 "to be set: `Acquire`, `Release`,"
88 "`AcquireRelease` or `SequentiallyConsistent`");
99 auto pointeePtrType = dyn_cast<spirv::PointerType>(pointeeType);
100 if (!pointeePtrType) {
101 if (
auto pointeeArrayType = dyn_cast<spirv::ArrayType>(pointeeType)) {
103 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
107 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
108 spirv::StorageClass::PhysicalStorageBuffer)
111 auto getDecorationAttr = [op](spirv::Decoration decoration) {
116 getDecorationAttr(spirv::Decoration::AliasedPointer) !=
nullptr;
117 bool hasRestrictPtr =
118 getDecorationAttr(spirv::Decoration::RestrictPointer) !=
nullptr;
120 if (!hasAliasedPtr && !hasRestrictPtr)
122 <<
" with physical buffer pointer must be decorated "
123 "either 'AliasedPointer' or 'RestrictPointer'";
125 if (hasAliasedPtr && hasRestrictPtr)
127 <<
" with physical buffer pointer must have exactly one "
128 "aliasing decoration";
137 stringifyDecoration(spirv::Decoration::DescriptorSet));
138 auto bindingName = llvm::convertToSnakeFromCamelCase(
139 stringifyDecoration(spirv::Decoration::Binding));
142 if (descriptorSet && binding) {
145 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
150 auto builtInName = llvm::convertToSnakeFromCamelCase(
151 stringifyDecoration(spirv::Decoration::BuiltIn));
153 printer <<
" " << builtInName <<
"(\"" <<
builtin.getValue() <<
"\")";
154 elidedAttrs.push_back(builtInName);
172 auto fnType = dyn_cast<FunctionType>(type);
174 parser.
emitError(loc,
"expected function type");
179 result.addTypes(fnType.getResults());
190 assert(op->
getNumResults() == 1 &&
"op should have one result");
196 [&](
Type type) { return type != resultType; })) {
205 p <<
" : " << resultType;
208template <
typename BlockReadWriteOpTy>
212 if (
auto valVecTy = dyn_cast<VectorType>(valType))
213 valType = valVecTy.getElementType();
215 if (valType != cast<spirv::PointerType>(
ptr.getType()).getPointeeType()) {
216 return op.emitOpError(
"mismatch in result type and pointer type");
228 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
233 if (
auto cType = dyn_cast<spirv::CompositeType>(type)) {
234 if (cType.hasCompileTimeKnownNumElements() &&
236 static_cast<uint64_t
>(
index) >= cType.getNumElements())) {
237 emitErrorFn(
"index ") <<
index <<
" out of bounds for " << type;
240 type = cType.getElementType(
index);
242 emitErrorFn(
"cannot extract from non-composite type ")
243 << type <<
" with index " <<
index;
253 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
indices);
254 if (!indicesArrayAttr) {
255 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
258 if (indicesArrayAttr.empty()) {
259 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
264 for (
auto indexAttr : indicesArrayAttr) {
265 auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
267 emitErrorFn(
"expected an 32-bit integer for index, but found '")
271 indexVals.push_back(indexIntAttr.getInt());
278 return ::mlir::emitError(loc, err);
291template <
typename ExtendedBinaryOp>
293 auto resultType = cast<spirv::StructType>(op.getType());
294 if (resultType.getNumElements() != 2)
295 return op.emitOpError(
"expected result struct type containing two members");
297 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
298 resultType.getElementType(0),
299 resultType.getElementType(1)}))
300 return op.emitOpError(
301 "expected all operand types and struct member types are the same");
318 auto structType = dyn_cast<spirv::StructType>(resultType);
319 if (!structType || structType.getNumElements() != 2)
320 return parser.
emitError(loc,
"expected spirv.struct type with two members");
326 result.addTypes(resultType);
340 return op->
emitError(
"expected the same type for the first operand and "
341 "result, but provided ")
353 spirv::GlobalVariableOp var) {
354 build(builder, state, var.getType(), SymbolRefAttr::get(var));
357LogicalResult spirv::AddressOfOp::verify() {
358 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
362 return emitOpError(
"expected spirv.GlobalVariable symbol");
364 if (getPointer().
getType() != varOp.getType()) {
366 "result type mismatch with the referenced global variable's type");
375LogicalResult spirv::CompositeConstructOp::verify() {
376 operand_range constituents = this->getConstituents();
391 if (coopElementType) {
392 if (constituents.size() != 1)
393 return emitOpError(
"has incorrect number of operands: expected ")
394 <<
"1, but provided " << constituents.size();
395 if (coopElementType != constituents.front().getType())
396 return emitOpError(
"operand type mismatch: expected operand type ")
397 << coopElementType <<
", but provided "
398 << constituents.front().getType();
403 auto cType = cast<spirv::CompositeType>(
getType());
404 if (constituents.size() == cType.getNumElements()) {
405 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
407 return emitOpError(
"operand type mismatch: expected operand type ")
408 << cType.getElementType(
index) <<
", but provided "
409 << constituents[
index].getType();
416 auto resultType = dyn_cast<VectorType>(cType);
419 "expected to return a vector or cooperative matrix when the number of "
420 "constituents is less than what the result needs");
423 for (
Value component : constituents) {
424 if (!isa<VectorType>(component.getType()) &&
425 !component.getType().isIntOrFloat())
426 return emitOpError(
"operand type mismatch: expected operand to have "
427 "a scalar or vector type, but provided ")
428 << component.getType();
430 Type elementType = component.getType();
431 if (
auto vectorType = dyn_cast<VectorType>(component.getType())) {
432 sizes.push_back(vectorType.getNumElements());
433 elementType = vectorType.getElementType();
438 if (elementType != resultType.getElementType())
439 return emitOpError(
"operand element type mismatch: expected to be ")
440 << resultType.getElementType() <<
", but provided " << elementType;
442 unsigned totalCount = llvm::sum_of(sizes);
443 if (totalCount != cType.getNumElements())
444 return emitOpError(
"has incorrect number of operands: expected ")
445 << cType.getNumElements() <<
", but provided " << totalCount;
462 build(builder, state, elementType, composite, indexAttr);
465ParseResult spirv::CompositeExtractOp::parse(
OpAsmParser &parser,
469 StringRef indicesAttrName =
470 spirv::CompositeExtractOp::getIndicesAttrName(
result.name);
487 result.addTypes(resultType);
491void spirv::CompositeExtractOp::print(
OpAsmPrinter &printer) {
492 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
496LogicalResult spirv::CompositeExtractOp::verify() {
497 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
getIndices());
504 return emitOpError(
"invalid result type: expected ")
505 << resultType <<
" but provided " <<
getType();
519 build(builder, state, composite.
getType(),
object, composite, indexAttr);
522ParseResult spirv::CompositeInsertOp::parse(
OpAsmParser &parser,
525 Type objectType, compositeType;
527 StringRef indicesAttrName =
528 spirv::CompositeInsertOp::getIndicesAttrName(
result.name);
541LogicalResult spirv::CompositeInsertOp::verify() {
542 auto indicesArrayAttr = dyn_cast<ArrayAttr>(
getIndices());
548 if (objectType != getObject().
getType()) {
549 return emitOpError(
"object operand type should be ")
550 << objectType <<
", but found " << getObject().getType();
554 return emitOpError(
"result type should be the same as "
555 "the composite type, but found ")
556 << getComposite().getType() <<
" vs " <<
getType();
562void spirv::CompositeInsertOp::print(
OpAsmPrinter &printer) {
563 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
564 <<
" : " << getObject().
getType() <<
" into "
565 << getComposite().getType();
572ParseResult spirv::ConstantOp::parse(
OpAsmParser &parser,
575 StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(
result.name);
580 if (
auto typedAttr = dyn_cast<TypedAttr>(value))
581 type = typedAttr.getType();
582 if (isa<NoneType, TensorType>(type)) {
587 if (isa<TensorArmType>(type)) {
597 printer <<
' ' << getValue();
598 if (isa<spirv::ArrayType, spirv::StructType>(
getType()))
604 if (isa<spirv::CooperativeMatrixType>(opType)) {
605 auto denseAttr = dyn_cast<DenseElementsAttr>(value);
606 if (!denseAttr || !denseAttr.isSplat())
607 return op.emitOpError(
"expected a splat dense attribute for cooperative "
608 "matrix constant, but found ")
611 if (isa<IntegerAttr, FloatAttr>(value)) {
612 auto valueType = cast<TypedAttr>(value).getType();
613 if (valueType != opType)
614 return op.emitOpError(
"result type (")
615 << opType <<
") does not match value type (" << valueType <<
")";
618 if (isa<DenseTypedElementsAttr, SparseElementsAttr>(value)) {
619 auto valueType = cast<TypedAttr>(value).getType();
620 if (valueType == opType)
622 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
623 auto shapedType = dyn_cast<ShapedType>(valueType);
625 return op.emitOpError(
"result or element type (")
626 << opType <<
") does not match value type (" << valueType
627 <<
"), must be the same or spirv.array";
629 int numElements = arrayType.getNumElements();
630 auto opElemType = arrayType.getElementType();
631 while (
auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
632 numElements *= t.getNumElements();
633 opElemType = t.getElementType();
635 if (!opElemType.isIntOrFloat())
636 return op.emitOpError(
"only support nested array result type");
638 auto valueElemType = shapedType.getElementType();
639 if (valueElemType != opElemType) {
640 return op.emitOpError(
"result element type (")
641 << opElemType <<
") does not match value element type ("
642 << valueElemType <<
")";
645 if (numElements != shapedType.getNumElements()) {
646 return op.emitOpError(
"result number of elements (")
647 << numElements <<
") does not match value number of elements ("
648 << shapedType.getNumElements() <<
")";
652 if (
auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
653 if (
auto structType = dyn_cast<spirv::StructType>(opType)) {
655 if (structType.isIdentified())
656 return op.emitOpError(
657 "cannot have an identified struct as a constant type");
658 if (arrayAttr.size() != structType.getNumElements())
659 return op.emitOpError(
"number of constituents (")
661 <<
") does not match number of struct members ("
662 << structType.getNumElements() <<
")";
663 for (
auto [idx, element] : llvm::enumerate(arrayAttr.getValue())) {
665 structType.getElementType(idx))))
670 auto arrayType = dyn_cast<spirv::ArrayType>(opType);
672 return op.emitOpError(
673 "must have spirv.array or spirv.struct result type for array value");
674 Type elemType = arrayType.getElementType();
675 for (
Attribute element : arrayAttr.getValue()) {
682 return op.emitOpError(
"cannot have attribute: ") << value;
685LogicalResult spirv::ConstantOp::verify() {
692bool spirv::ConstantOp::isBuildableWith(
Type type) {
694 if (!isa<spirv::SPIRVType>(type))
698 if (
auto structType = dyn_cast<spirv::StructType>(type))
699 return !structType.isIdentified();
700 return isa<spirv::ArrayType>(type);
706spirv::ConstantOp spirv::ConstantOp::getZero(
Type type,
Location loc,
708 if (
auto intType = dyn_cast<IntegerType>(type)) {
709 unsigned width = intType.getWidth();
711 return spirv::ConstantOp::create(builder, loc, type,
713 return spirv::ConstantOp::create(
714 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 0)));
716 if (
auto floatType = dyn_cast<FloatType>(type)) {
717 return spirv::ConstantOp::create(builder, loc, type,
720 if (
auto vectorType = dyn_cast<VectorType>(type)) {
721 Type elemType = vectorType.getElementType();
722 if (isa<IntegerType>(elemType)) {
723 return spirv::ConstantOp::create(
726 IntegerAttr::get(elemType, 0).getValue()));
728 if (isa<FloatType>(elemType)) {
729 return spirv::ConstantOp::create(
732 FloatAttr::get(elemType, 0.0).getValue()));
736 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
739spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
741 if (
auto intType = dyn_cast<IntegerType>(type)) {
742 unsigned width = intType.getWidth();
744 return spirv::ConstantOp::create(builder, loc, type,
746 return spirv::ConstantOp::create(
747 builder, loc, type, builder.
getIntegerAttr(type, APInt(width, 1)));
749 if (
auto floatType = dyn_cast<FloatType>(type)) {
750 return spirv::ConstantOp::create(builder, loc, type,
753 if (
auto vectorType = dyn_cast<VectorType>(type)) {
754 Type elemType = vectorType.getElementType();
755 if (isa<IntegerType>(elemType)) {
756 return spirv::ConstantOp::create(
759 IntegerAttr::get(elemType, 1).getValue()));
761 if (isa<FloatType>(elemType)) {
762 return spirv::ConstantOp::create(
765 FloatAttr::get(elemType, 1.0).getValue()));
769 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
772void mlir::spirv::ConstantOp::getAsmResultNames(
777 llvm::raw_svector_ostream specialName(specialNameBuffer);
778 specialName <<
"cst";
780 IntegerType intTy = dyn_cast<IntegerType>(type);
782 if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
785 if (intTy.getWidth() == 1) {
786 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
789 if (intTy.isSignless()) {
790 specialName << intCst.getInt();
791 }
else if (intTy.isUnsigned()) {
792 specialName << intCst.getUInt();
794 specialName << intCst.getSInt();
798 if (intTy || isa<FloatType>(type)) {
799 specialName <<
'_' << type;
802 if (
auto vecType = dyn_cast<VectorType>(type)) {
803 specialName <<
"_vec_";
804 specialName << vecType.getDimSize(0);
806 Type elementType = vecType.getElementType();
808 if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
809 specialName <<
"x" << elementType;
813 setNameFn(getResult(), specialName.str());
816void mlir::spirv::AddressOfOp::getAsmResultNames(
819 llvm::raw_svector_ostream specialName(specialNameBuffer);
820 specialName << getVariable() <<
"_addr";
821 setNameFn(getResult(), specialName.str());
832 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
833 return typedAttr.getType();
836 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
843LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
846 return emitError(
"unknown value attribute type");
848 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
850 return emitError(
"result type is not a composite type");
852 Type compositeElementType = compositeType.getElementType(0);
855 while (
auto type = dyn_cast<spirv::CompositeType>(compositeElementType)) {
856 compositeElementType = type.getElementType(0);
857 possibleTypes.push_back(compositeElementType);
860 if (!is_contained(possibleTypes, valueType)) {
861 return emitError(
"expected value attribute type ")
862 << interleaved(possibleTypes,
" or ") <<
", but got: " << valueType;
872LogicalResult spirv::ControlBarrierOp::verify() {
881 spirv::ExecutionModel executionModel,
882 spirv::FuncOp function,
884 build(builder, state,
885 spirv::ExecutionModelAttr::get(builder.
getContext(), executionModel),
886 SymbolRefAttr::get(function), builder.
getArrayAttr(interfaceVars));
889ParseResult spirv::EntryPointOp::parse(
OpAsmParser &parser,
891 spirv::ExecutionModel execModel;
904 FlatSymbolRefAttr var;
906 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
908 interfaceVars.push_back(var);
913 result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(
result.name),
921 auto interfaceVars = getInterface().getValue();
922 if (!interfaceVars.empty())
923 printer <<
", " << llvm::interleaved(interfaceVars);
926LogicalResult spirv::EntryPointOp::verify() {
937 spirv::FuncOp function,
938 spirv::ExecutionMode executionMode,
940 build(builder, state, SymbolRefAttr::get(function),
941 spirv::ExecutionModeAttr::get(builder.
getContext(), executionMode),
945ParseResult spirv::ExecutionModeOp::parse(
OpAsmParser &parser,
947 spirv::ExecutionMode execMode;
962 values.push_back(cast<IntegerAttr>(value).getInt());
964 StringRef valuesAttrName =
965 spirv::ExecutionModeOp::getValuesAttrName(
result.name);
966 result.addAttribute(valuesAttrName,
971void spirv::ExecutionModeOp::print(
OpAsmPrinter &printer) {
974 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
977 printer <<
", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
984ParseResult spirv::ExecutionModeIdOp::parse(
OpAsmParser &parser,
986 ExecutionMode execMode;
995 FlatSymbolRefAttr attr;
996 if (parser.parseAttribute(attr))
998 values.push_back(attr);
1004 StringRef valuesAttrName = getValuesAttrName(
result.name);
1006 result.addAttribute(valuesAttrName, valuesAttr);
1010void spirv::ExecutionModeIdOp::print(
OpAsmPrinter &printer) {
1013 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\" ";
1015 llvm::interleaveComma(
1016 getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
1020LogicalResult spirv::ExecutionModeIdOp::verify() {
1022 switch (getExecutionMode()) {
1023 case ExecutionMode::SubgroupsPerWorkgroupId:
1024 case ExecutionMode::LocalSizeId:
1025 case ExecutionMode::LocalSizeHintId:
1028 return emitOpError(
"expected ExecutionMode that takes extra operands that "
1029 "are <id> operands, got: ")
1030 << stringifyExecutionMode(getExecutionMode());
1033 if (getValues().empty())
1034 return emitOpError(
"expected at least one value operand");
1037 auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
1039 return emitOpError(
"expected value operands to be symbol reference");
1041 (*this)->getParentOp(), valueSymbol);
1043 return emitOpError(
"cannot find symbol referenced by value operand: ")
1044 << valueSymbol.getValue();
1061 StringAttr nameAttr;
1067 bool isVariadic =
false;
1069 parser,
false, entryArgs, isVariadic, resultTypes,
1074 for (
auto &arg : entryArgs)
1075 argTypes.push_back(arg.type);
1077 result.addAttribute(getFunctionTypeAttrName(
result.name),
1078 TypeAttr::get(fnType));
1081 spirv::FunctionControl fnControl;
1090 assert(resultAttrs.size() == resultTypes.size());
1092 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
1093 getResAttrsAttrName(
result.name));
1096 auto *body =
result.addRegion();
1106 auto fnType = getFunctionType();
1108 printer, *
this, fnType.getInputs(),
1109 false, fnType.getResults());
1110 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
1114 {spirv::attributeName<spirv::FunctionControl>(),
1115 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
1116 getFunctionControlAttrName()});
1119 Region &body = this->getBody();
1120 if (!body.empty()) {
1127LogicalResult spirv::FuncOp::verifyType() {
1128 FunctionType fnType = getFunctionType();
1129 if (fnType.getNumResults() > 1)
1130 return emitOpError(
"cannot have more than one result");
1132 auto hasDecorationAttr = [&](spirv::Decoration decoration,
1133 unsigned argIndex) {
1134 auto func = cast<FunctionOpInterface>(getOperation());
1135 for (
auto argAttr : cast<FunctionOpInterface>(
func).
getArgAttrs(argIndex)) {
1136 if (argAttr.getName() != spirv::DecorationAttr::name)
1138 if (
auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
1139 return decAttr.getValue() == decoration;
1144 for (
unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
1145 Type param = fnType.getInputs()[i];
1146 auto inputPtrType = dyn_cast<spirv::PointerType>(param);
1150 auto pointeePtrType =
1151 dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
1152 if (pointeePtrType) {
1158 if (pointeePtrType.getStorageClass() !=
1159 spirv::StorageClass::PhysicalStorageBuffer)
1162 bool hasAliasedPtr =
1163 hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
1164 bool hasRestrictPtr =
1165 hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
1166 if (!hasAliasedPtr && !hasRestrictPtr)
1168 <<
"with a pointer points to a physical buffer pointer must "
1169 "be decorated either 'AliasedPointer' or 'RestrictPointer'";
1176 if (
auto pointeeArrayType =
1177 dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
1179 dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
1181 pointeePtrType = inputPtrType;
1184 if (!pointeePtrType || pointeePtrType.getStorageClass() !=
1185 spirv::StorageClass::PhysicalStorageBuffer)
1188 bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
1189 bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
1190 if (!hasAliased && !hasRestrict)
1191 return emitOpError() <<
"with physical buffer pointer must be decorated "
1192 "either 'Aliased' or 'Restrict'";
1198LogicalResult spirv::FuncOp::verifyBody() {
1199 FunctionType fnType = getFunctionType();
1200 if (!isExternal()) {
1201 Block &entryBlock = front();
1203 unsigned numArguments = this->getNumArguments();
1206 << numArguments <<
" arguments to match function signature";
1208 for (
auto [
index, fnArgType, blockArgType] :
1210 if (blockArgType != fnArgType) {
1211 return emitOpError(
"type of entry block argument #")
1212 <<
index <<
'(' << blockArgType
1213 <<
") must match the type of the corresponding argument in "
1214 <<
"function signature(" << fnArgType <<
')';
1220 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1221 if (fnType.getNumResults() != 0)
1222 return retOp.emitOpError(
"cannot be used in functions returning value");
1223 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1224 if (fnType.getNumResults() != 1)
1225 return retOp.emitOpError(
1226 "returns 1 value but enclosing function requires ")
1227 << fnType.getNumResults() <<
" results";
1229 auto retOperandType = retOp.getValue().getType();
1230 auto fnResultType = fnType.getResult(0);
1231 if (retOperandType != fnResultType)
1232 return retOp.emitOpError(
" return value's type (")
1233 << retOperandType <<
") mismatch with function's result type ("
1234 << fnResultType <<
")";
1241 return failure(walkResult.wasInterrupted());
1245 StringRef name, FunctionType type,
1246 spirv::FunctionControl control,
1250 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
1251 state.
addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1252 builder.
getAttr<spirv::FunctionControlAttr>(control));
1261ParseResult spirv::GLFClampOp::parse(
OpAsmParser &parser,
1271ParseResult spirv::GLUClampOp::parse(
OpAsmParser &parser,
1281ParseResult spirv::GLSClampOp::parse(
OpAsmParser &parser,
1301 Type type, StringRef name,
1302 unsigned descriptorSet,
unsigned binding) {
1303 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
1305 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1308 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1313 Type type, StringRef name,
1315 build(builder, state, TypeAttr::get(type), builder.
getStringAttr(name));
1317 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1321ParseResult spirv::GlobalVariableOp::parse(
OpAsmParser &parser,
1324 StringAttr nameAttr;
1325 StringRef initializerAttrName =
1326 spirv::GlobalVariableOp::getInitializerAttrName(
result.name);
1347 StringRef typeAttrName =
1348 spirv::GlobalVariableOp::getTypeAttrName(
result.name);
1353 if (!isa<spirv::PointerType>(type)) {
1354 return parser.
emitError(loc,
"expected spirv.ptr type");
1356 result.addAttribute(typeAttrName, TypeAttr::get(type));
1361void spirv::GlobalVariableOp::print(
OpAsmPrinter &printer) {
1363 spirv::attributeName<spirv::StorageClass>()};
1370 StringRef initializerAttrName = this->getInitializerAttrName();
1372 if (
auto initializer = this->getInitializer()) {
1373 printer <<
" " << initializerAttrName <<
'(';
1376 elidedAttrs.push_back(initializerAttrName);
1379 StringRef typeAttrName = this->getTypeAttrName();
1380 elidedAttrs.push_back(typeAttrName);
1382 printer <<
" : " <<
getType();
1385LogicalResult spirv::GlobalVariableOp::verify() {
1386 if (!isa<spirv::PointerType>(
getType()))
1387 return emitOpError(
"result must be of a !spv.ptr type");
1393 auto storageClass = this->storageClass();
1394 if (storageClass == spirv::StorageClass::Generic ||
1395 storageClass == spirv::StorageClass::Function) {
1397 << stringifyStorageClass(storageClass) <<
"'";
1402 if (std::optional<spirv::LinkageAttributesAttr> linkage =
1403 getLinkageAttributes()) {
1404 if (linkage->getLinkageType().getValue() == spirv::LinkageType::Import &&
1407 "with Import linkage type must not have an initializer");
1412 this->getInitializerAttrName())) {
1414 (*this)->getParentOp(), init.getAttr());
1426 !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
1427 return emitOpError(
"initializer must be result of a "
1428 "spirv.SpecConstant or "
1429 "spirv.SpecConstantCompositeOp op");
1433 Type pointeeType = cast<spirv::PointerType>(
getType()).getPointeeType();
1445LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
1456ParseResult spirv::INTELSubgroupBlockWriteOp::parse(
OpAsmParser &parser,
1459 spirv::StorageClass storageClass;
1470 if (
auto valVecTy = dyn_cast<VectorType>(elementType))
1480void spirv::INTELSubgroupBlockWriteOp::print(
OpAsmPrinter &printer) {
1481 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1482 << getValue().getType();
1485LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
1496LogicalResult spirv::IAddCarryOp::verify() {
1497 return ::verifyArithmeticExtendedBinaryOp(*
this);
1500ParseResult spirv::IAddCarryOp::parse(
OpAsmParser &parser,
1502 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1513LogicalResult spirv::ISubBorrowOp::verify() {
1514 return ::verifyArithmeticExtendedBinaryOp(*
this);
1517ParseResult spirv::ISubBorrowOp::parse(
OpAsmParser &parser,
1519 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1522void spirv::ISubBorrowOp::print(
OpAsmPrinter &printer) {
1530LogicalResult spirv::SMulExtendedOp::verify() {
1531 return ::verifyArithmeticExtendedBinaryOp(*
this);
1534ParseResult spirv::SMulExtendedOp::parse(
OpAsmParser &parser,
1536 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1539void spirv::SMulExtendedOp::print(
OpAsmPrinter &printer) {
1547LogicalResult spirv::UMulExtendedOp::verify() {
1548 return ::verifyArithmeticExtendedBinaryOp(*
this);
1551ParseResult spirv::UMulExtendedOp::parse(
OpAsmParser &parser,
1553 return ::parseArithmeticExtendedBinaryOp(parser,
result);
1556void spirv::UMulExtendedOp::print(
OpAsmPrinter &printer) {
1564LogicalResult spirv::MemoryBarrierOp::verify() {
1572LogicalResult spirv::MemoryNamedBarrierOp::verify() {
1581 std::optional<StringRef> name) {
1591 spirv::AddressingModel addressingModel,
1592 spirv::MemoryModel memoryModel,
1593 std::optional<VerCapExtAttr> vceTriple,
1594 std::optional<StringRef> name) {
1597 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1599 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1603 state.
addAttribute(getVCETripleAttrName(), *vceTriple);
1609ParseResult spirv::ModuleOp::parse(
OpAsmParser &parser,
1614 StringAttr nameAttr;
1619 spirv::AddressingModel addrModel;
1620 spirv::MemoryModel memoryModel;
1630 spirv::ModuleOp::getVCETripleAttrName(),
1647 if (std::optional<StringRef> name = getName()) {
1656 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1657 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1658 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1661 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1662 printer <<
" requires " << *triple;
1663 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1671LogicalResult spirv::ModuleOp::verifyRegions() {
1672 Dialect *dialect = (*this)->getDialect();
1677 for (
auto &op : *getBody()) {
1679 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1684 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1685 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1687 return entryPointOp.emitError(
"function '")
1688 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1690 if (
auto interface = entryPointOp.getInterface()) {
1692 auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
1694 return entryPointOp.emitError(
1695 "expected symbol reference for interface "
1696 "specification instead of '")
1700 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1702 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1703 "symbol reference instead of'")
1704 << varSymRef <<
"'";
1709 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1710 funcOp, entryPointOp.getExecutionModel());
1711 if (!entryPoints.try_emplace(key, entryPointOp).second)
1712 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1713 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1717 auto linkageAttr = funcOp.getLinkageAttributes();
1718 auto hasImportLinkage =
1719 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1720 spirv::LinkageType::Import);
1721 if (funcOp.isExternal() && !hasImportLinkage)
1723 "'spirv.module' cannot contain external functions "
1724 "without 'Import' linkage_attributes (LinkageAttributes)");
1727 for (
auto &block : funcOp)
1728 for (
auto &op : block) {
1731 "functions in 'spirv.module' can only contain spirv.* ops");
1743LogicalResult spirv::ReferenceOfOp::verify() {
1745 (*this)->getParentOp(), getSpecConstAttr());
1748 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1750 constType = specConstOp.getDefaultValue().getType();
1752 auto specConstCompositeOp =
1753 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1754 if (specConstCompositeOp)
1755 constType = specConstCompositeOp.getType();
1757 if (!specConstOp && !specConstCompositeOp)
1759 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1761 if (getReference().
getType() != constType)
1762 return emitOpError(
"result type mismatch with the referenced "
1763 "specialization constant's type");
1772ParseResult spirv::SpecConstantOp::parse(
OpAsmParser &parser,
1774 StringAttr nameAttr;
1776 StringRef defaultValueAttrName =
1777 spirv::SpecConstantOp::getDefaultValueAttrName(
result.name);
1785 IntegerAttr specIdAttr;
1799void spirv::SpecConstantOp::print(
OpAsmPrinter &printer) {
1802 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1804 printer <<
" = " << getDefaultValue();
1807LogicalResult spirv::SpecConstantOp::verify() {
1808 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1809 if (specID.getValue().isNegative())
1812 auto value = getDefaultValue();
1813 if (isa<IntegerAttr, FloatAttr>(value)) {
1815 if (!isa<spirv::SPIRVType>(value.getType()))
1816 return emitOpError(
"default value bitwidth disallowed");
1820 "default value can only be a bool, integer, or float scalar");
1827LogicalResult spirv::VectorShuffleOp::verify() {
1828 VectorType resultType = cast<VectorType>(
getType());
1830 size_t numResultElements = resultType.getNumElements();
1831 if (numResultElements != getComponents().size())
1833 << numResultElements
1834 <<
") mismatch with the number of component selectors ("
1835 << getComponents().size() <<
")";
1837 size_t totalSrcElements =
1838 cast<VectorType>(getVector1().
getType()).getNumElements() +
1839 cast<VectorType>(getVector2().
getType()).getNumElements();
1841 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1842 uint32_t
index = selector.getZExtValue();
1843 if (
index >= totalSrcElements &&
1844 index != std::numeric_limits<uint32_t>().
max())
1846 <<
index <<
" out of range: expected to be in [0, "
1847 << totalSrcElements <<
") or 0xffffffff";
1856ParseResult spirv::SpecConstantCompositeOp::parse(
OpAsmParser &parser,
1859 StringAttr compositeName;
1871 const char *attrName =
"spec_const";
1878 constituents.push_back(specConstRef);
1884 StringAttr compositeSpecConstituentsName =
1885 spirv::SpecConstantCompositeOp::getConstituentsAttrName(
result.name);
1886 result.addAttribute(compositeSpecConstituentsName,
1893 StringAttr typeAttrName =
1894 spirv::SpecConstantCompositeOp::getTypeAttrName(
result.name);
1895 result.addAttribute(typeAttrName, TypeAttr::get(type));
1900void spirv::SpecConstantCompositeOp::print(
OpAsmPrinter &printer) {
1903 printer <<
" (" << llvm::interleaved(this->getConstituents().getValue())
1907LogicalResult spirv::SpecConstantCompositeOp::verify() {
1908 auto cType = dyn_cast<spirv::CompositeType>(
getType());
1909 auto constituents = this->getConstituents().getValue();
1912 return emitError(
"result type must be a composite type, but provided ")
1915 if (isa<spirv::CooperativeMatrixType>(cType))
1916 return emitError(
"unsupported composite type ") << cType;
1917 if (constituents.size() != cType.getNumElements())
1918 return emitError(
"has incorrect number of operands: expected ")
1919 << cType.getNumElements() <<
", but provided "
1920 << constituents.size();
1922 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1923 auto constituent = cast<FlatSymbolRefAttr>(constituents[
index]);
1926 (*this)->getParentOp(), constituent.getAttr());
1929 return emitError(
"unknown constituent symbol ") << constituent.getAttr();
1931 Type constituentType;
1932 if (
auto specConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp)) {
1933 constituentType = specConstOp.getDefaultValue().getType();
1934 }
else if (
auto specConstCompositeOp =
1935 dyn_cast<spirv::SpecConstantCompositeOp>(constituentOp)) {
1936 constituentType = specConstCompositeOp.getType();
1938 return emitError(
"unsupported constituent ")
1939 << constituent.getAttr()
1940 <<
": must reference a spirv.SpecConstant or "
1941 "spirv.SpecConstantComposite";
1944 if (constituentType != cType.getElementType(
index))
1945 return emitError(
"has incorrect types of operands: expected ")
1946 << cType.getElementType(
index) <<
", but provided "
1958spirv::EXTSpecConstantCompositeReplicateOp::parse(
OpAsmParser &parser,
1960 StringAttr compositeName;
1962 const char *attrName =
"spec_const";
1973 StringAttr compositeSpecConstituentName =
1974 spirv::EXTSpecConstantCompositeReplicateOp::getConstituentAttrName(
1976 result.addAttribute(compositeSpecConstituentName, specConstRef);
1978 StringAttr typeAttrName =
1979 spirv::EXTSpecConstantCompositeReplicateOp::getTypeAttrName(
result.name);
1980 result.addAttribute(typeAttrName, TypeAttr::get(type));
1985void spirv::EXTSpecConstantCompositeReplicateOp::print(
OpAsmPrinter &printer) {
1988 printer <<
" (" << this->getConstituent() <<
") : " <<
getType();
1991LogicalResult spirv::EXTSpecConstantCompositeReplicateOp::verify() {
1992 auto compositeType = dyn_cast<spirv::CompositeType>(
getType());
1994 return emitError(
"result type must be a composite type, but provided ")
1998 (*this)->getParentOp(), this->getConstituent());
2001 "splat spec constant reference defining constituent not found");
2003 auto constituentSpecConstOp = dyn_cast<spirv::SpecConstantOp>(constituentOp);
2004 if (!constituentSpecConstOp)
2005 return emitError(
"constituent is not a spec constant");
2007 Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
2008 Type compositeElementType = compositeType.getElementType(0);
2009 if (constituentType != compositeElementType)
2010 return emitError(
"constituent has incorrect type: expected ")
2011 << compositeElementType <<
", but provided " << constituentType;
2020ParseResult spirv::SpecConstantOperationOp::parse(
OpAsmParser &parser,
2036 spirv::YieldOp::create(builder, wrappedOp->
getLoc(), wrappedOp->
getResult(0));
2047void spirv::SpecConstantOperationOp::print(
OpAsmPrinter &printer) {
2048 printer <<
" wraps ";
2052LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
2053 Block &block = getRegion().getBlocks().
front();
2056 return emitOpError(
"expected exactly 2 nested ops");
2064 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
2065 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
2067 "invalid operand, must be defined by a constant operation");
2076LogicalResult spirv::GLFrexpStructOp::verify() {
2078 dyn_cast<spirv::StructType>(getResult().
getType());
2081 return emitError(
"result type must be a struct type with two memebers");
2085 VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
2086 IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
2088 Type operandTy = getOperand().getType();
2089 VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
2090 FloatType operandFTy = dyn_cast<FloatType>(operandTy);
2092 if (significandTy != operandTy)
2093 return emitError(
"member zero of the resulting struct type must be the "
2094 "same type as the operand");
2096 if (exponentVecTy) {
2097 IntegerType componentIntTy =
2098 dyn_cast<IntegerType>(exponentVecTy.getElementType());
2099 if (!componentIntTy || componentIntTy.getWidth() != 32)
2100 return emitError(
"member one of the resulting struct type must"
2101 "be a scalar or vector of 32 bit integer type");
2102 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
2103 return emitError(
"member one of the resulting struct type "
2104 "must be a scalar or vector of 32 bit integer type");
2108 if (operandVecTy && exponentVecTy &&
2109 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
2112 if (operandFTy && exponentIntTy)
2115 return emitError(
"member one of the resulting struct type must have the same "
2116 "number of components as the operand type");
2125 if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
2126 return op->
emitOpError(
"operands must both be scalars or vectors");
2129 if (
auto vectorType = dyn_cast<VectorType>(type))
2130 return vectorType.getNumElements();
2135 return op->
emitOpError(
"operands must have the same number of elements");
2140LogicalResult spirv::GLLdexpOp::verify() {
2149LogicalResult spirv::CLLdexpOp::verify() {
2158LogicalResult spirv::CLPownOp::verify() {
2167LogicalResult spirv::CLRootnOp::verify() {
2176LogicalResult spirv::ShiftLeftLogicalOp::verify() {
2184LogicalResult spirv::ShiftRightArithmeticOp::verify() {
2192LogicalResult spirv::ShiftRightLogicalOp::verify() {
2200LogicalResult spirv::VectorTimesScalarOp::verify() {
2202 return emitOpError(
"vector operand and result type mismatch");
2203 auto scalarType = cast<VectorType>(
getType()).getElementType();
2204 if (getScalar().
getType() != scalarType)
2205 return emitOpError(
"scalar operand and result element type match");
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static Type getValueType(Attribute attr)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType, Type integerType)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
Type getElementType() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
uint64_t getN(LevelType lt)
constexpr char kFnNameAttrName[]
constexpr char kSpecIdAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
LogicalResult verifyPhysicalStorageBufferDecorations(Operation *op, Type pointeeType)
Verifies the SPV_KHR_physical_storage_buffer rule that a variable whose pointee is a pointer (or arra...
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.