33 #include "llvm/ADT/APFloat.h"
34 #include "llvm/ADT/APInt.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/TypeSwitch.h"
42 #include <type_traits>
52 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
56 auto valueAttr = constOp.getValue();
57 auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
58 if (!integerValueAttr) {
62 if (integerValueAttr.getType().isSignlessInteger())
63 value = integerValueAttr.getInt();
65 value = integerValueAttr.getSInt();
72 spirv::MemorySemantics memorySemantics) {
79 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
80 spirv::MemorySemantics::Release |
81 spirv::MemorySemantics::AcquireRelease |
82 spirv::MemorySemantics::SequentiallyConsistent;
85 llvm::popcount(
static_cast<uint32_t
>(memorySemantics & atMostOneInSet));
88 "expected at most one of these four memory constraints "
89 "to be set: `Acquire`, `Release`,"
90 "`AcquireRelease` or `SequentiallyConsistent`");
99 stringifyDecoration(spirv::Decoration::DescriptorSet));
100 auto bindingName = llvm::convertToSnakeFromCamelCase(
101 stringifyDecoration(spirv::Decoration::Binding));
104 if (descriptorSet && binding) {
107 printer <<
" bind(" << descriptorSet.getInt() <<
", " << binding.getInt()
112 auto builtInName = llvm::convertToSnakeFromCamelCase(
113 stringifyDecoration(spirv::Decoration::BuiltIn));
114 if (
auto builtin = op->
getAttrOfType<StringAttr>(builtInName)) {
115 printer <<
" " << builtInName <<
"(\"" << builtin.getValue() <<
"\")";
116 elidedAttrs.push_back(builtInName);
134 auto fnType = llvm::dyn_cast<FunctionType>(type);
136 parser.
emitError(loc,
"expected function type");
141 result.
addTypes(fnType.getResults());
152 assert(op->
getNumResults() == 1 &&
"op should have one result");
158 [&](
Type type) { return type != resultType; })) {
167 p <<
" : " << resultType;
170 template <
typename Op>
172 spirv::ImageOperandsAttr attr,
175 if (operands.empty())
178 return imageOp.
emitError(
"the Image Operands should encode what operands "
179 "follow, as per Image Operands");
183 spirv::ImageOperands noSupportOperands =
184 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
185 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
186 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
187 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
188 spirv::ImageOperands::MakeTexelAvailable |
189 spirv::ImageOperands::MakeTexelVisible |
190 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
192 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
193 llvm_unreachable(
"unimplemented operands of Image Operands");
198 template <
typename BlockReadWriteOpTy>
202 if (
auto valVecTy = llvm::dyn_cast<VectorType>(valType))
203 valType = valVecTy.getElementType();
206 llvm::cast<spirv::PointerType>(ptr.
getType()).getPointeeType()) {
207 return op.
emitOpError(
"mismatch in result type and pointer type");
218 if (indices.empty()) {
219 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
223 for (
auto index : indices) {
224 if (
auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
225 if (cType.hasCompileTimeKnownNumElements() &&
227 static_cast<uint64_t
>(index) >= cType.getNumElements())) {
228 emitErrorFn(
"index ") << index <<
" out of bounds for " << type;
231 type = cType.getElementType(index);
233 emitErrorFn(
"cannot extract from non-composite type ")
234 << type <<
" with index " << index;
244 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
245 if (!indicesArrayAttr) {
246 emitErrorFn(
"expected a 32-bit integer array attribute for 'indices'");
249 if (indicesArrayAttr.empty()) {
250 emitErrorFn(
"expected at least one index for spirv.CompositeExtract");
255 for (
auto indexAttr : indicesArrayAttr) {
256 auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
258 emitErrorFn(
"expected an 32-bit integer for index, but found '")
262 indexVals.push_back(indexIntAttr.getInt());
282 template <
typename ExtendedBinaryOp>
284 auto resultType = llvm::cast<spirv::StructType>(op.getType());
285 if (resultType.getNumElements() != 2)
286 return op.
emitOpError(
"expected result struct type containing two members");
288 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
289 resultType.getElementType(0),
290 resultType.getElementType(1)}))
292 "expected all operand types and struct member types are the same");
309 auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
310 if (!structType || structType.getNumElements() != 2)
311 return parser.
emitError(loc,
"expected spirv.struct type with two members");
331 return op->
emitError(
"expected the same type for the first operand and "
332 "result, but provided ")
344 spirv::GlobalVariableOp var) {
349 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
353 return emitOpError(
"expected spirv.GlobalVariable symbol");
355 if (getPointer().getType() != varOp.getType()) {
357 "result type mismatch with the referenced global variable's type");
367 operand_range constituents = this->getConstituents();
375 auto coopElementType =
379 [](
auto coopType) {
return coopType.getElementType(); })
380 .Default([](
Type) {
return nullptr; });
383 if (coopElementType) {
384 if (constituents.size() != 1)
385 return emitOpError(
"has incorrect number of operands: expected ")
386 <<
"1, but provided " << constituents.size();
387 if (coopElementType != constituents.front().getType())
388 return emitOpError(
"operand type mismatch: expected operand type ")
389 << coopElementType <<
", but provided "
390 << constituents.front().getType();
395 auto cType = llvm::cast<spirv::CompositeType>(getType());
396 if (constituents.size() == cType.getNumElements()) {
397 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
398 if (constituents[index].getType() != cType.getElementType(index)) {
399 return emitOpError(
"operand type mismatch: expected operand type ")
400 << cType.getElementType(index) <<
", but provided "
401 << constituents[index].getType();
408 auto resultType = llvm::dyn_cast<VectorType>(cType);
411 "expected to return a vector or cooperative matrix when the number of "
412 "constituents is less than what the result needs");
415 for (
Value component : constituents) {
416 if (!llvm::isa<VectorType>(component.getType()) &&
417 !component.getType().isIntOrFloat())
418 return emitOpError(
"operand type mismatch: expected operand to have "
419 "a scalar or vector type, but provided ")
420 << component.getType();
422 Type elementType = component.getType();
423 if (
auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
424 sizes.push_back(vectorType.getNumElements());
425 elementType = vectorType.getElementType();
430 if (elementType != resultType.getElementType())
431 return emitOpError(
"operand element type mismatch: expected to be ")
432 << resultType.getElementType() <<
", but provided " << elementType;
434 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
435 if (totalCount != cType.getNumElements())
436 return emitOpError(
"has incorrect number of operands: expected ")
437 << cType.getNumElements() <<
", but provided " << totalCount;
454 build(builder, state, elementType, composite, indexAttr);
482 printer <<
' ' << getComposite() <<
getIndices() <<
" : "
487 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
489 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
493 if (resultType != getType()) {
494 return emitOpError(
"invalid result type: expected ")
495 << resultType <<
" but provided " << getType();
509 build(builder, state, composite.
getType(),
object, composite, indexAttr);
515 Type objectType, compositeType;
530 auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(
getIndices());
532 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
536 if (objectType != getObject().getType()) {
537 return emitOpError(
"object operand type should be ")
538 << objectType <<
", but found " << getObject().getType();
541 if (getComposite().getType() != getType()) {
542 return emitOpError(
"result type should be the same as "
543 "the composite type, but found ")
544 << getComposite().getType() <<
" vs " << getType();
551 printer <<
" " << getObject() <<
", " << getComposite() <<
getIndices()
552 <<
" : " << getObject().
getType() <<
" into "
553 << getComposite().getType();
567 if (
auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
568 type = typedAttr.getType();
569 if (llvm::isa<NoneType, TensorType>(type)) {
578 printer <<
' ' << getValue();
579 if (llvm::isa<spirv::ArrayType>(getType()))
580 printer <<
" : " << getType();
585 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
586 auto valueType = llvm::cast<TypedAttr>(value).getType();
587 if (valueType != opType)
589 << opType <<
") does not match value type (" << valueType <<
")";
592 if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
593 auto valueType = llvm::cast<TypedAttr>(value).getType();
594 if (valueType == opType)
596 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
597 auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
600 << opType <<
") does not match value type (" << valueType
601 <<
"), must be the same or spirv.array";
603 int numElements = arrayType.getNumElements();
604 auto opElemType = arrayType.getElementType();
605 while (
auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
606 numElements *= t.getNumElements();
607 opElemType = t.getElementType();
609 if (!opElemType.isIntOrFloat())
610 return op.
emitOpError(
"only support nested array result type");
612 auto valueElemType = shapedType.getElementType();
613 if (valueElemType != opElemType) {
615 << opElemType <<
") does not match value element type ("
616 << valueElemType <<
")";
619 if (numElements != shapedType.getNumElements()) {
620 return op.
emitOpError(
"result number of elements (")
621 << numElements <<
") does not match value number of elements ("
622 << shapedType.getNumElements() <<
")";
626 if (
auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
627 auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
630 "must have spirv.array result type for array value");
631 Type elemType = arrayType.getElementType();
632 for (
Attribute element : arrayAttr.getValue()) {
639 return op.
emitOpError(
"cannot have attribute: ") << value;
649 bool spirv::ConstantOp::isBuildableWith(
Type type) {
651 if (!llvm::isa<spirv::SPIRVType>(type))
656 return llvm::isa<spirv::ArrayType>(type);
664 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
665 unsigned width = intType.getWidth();
667 return builder.
create<spirv::ConstantOp>(loc, type,
669 return builder.
create<spirv::ConstantOp>(
672 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
673 return builder.
create<spirv::ConstantOp>(
676 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
677 Type elemType = vectorType.getElementType();
678 if (llvm::isa<IntegerType>(elemType)) {
679 return builder.
create<spirv::ConstantOp>(
684 if (llvm::isa<FloatType>(elemType)) {
685 return builder.
create<spirv::ConstantOp>(
692 llvm_unreachable(
"unimplemented types for ConstantOp::getZero()");
695 spirv::ConstantOp spirv::ConstantOp::getOne(
Type type,
Location loc,
697 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
698 unsigned width = intType.getWidth();
700 return builder.
create<spirv::ConstantOp>(loc, type,
702 return builder.
create<spirv::ConstantOp>(
705 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
706 return builder.
create<spirv::ConstantOp>(
709 if (
auto vectorType = llvm::dyn_cast<VectorType>(type)) {
710 Type elemType = vectorType.getElementType();
711 if (llvm::isa<IntegerType>(elemType)) {
712 return builder.
create<spirv::ConstantOp>(
717 if (llvm::isa<FloatType>(elemType)) {
718 return builder.
create<spirv::ConstantOp>(
725 llvm_unreachable(
"unimplemented types for ConstantOp::getOne()");
728 void mlir::spirv::ConstantOp::getAsmResultNames(
730 Type type = getType();
733 llvm::raw_svector_ostream specialName(specialNameBuffer);
734 specialName <<
"cst";
736 IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
738 if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
739 if (intTy && intTy.getWidth() == 1) {
740 return setNameFn(getResult(), (intCst.getInt() ?
"true" :
"false"));
743 if (intTy.isSignless()) {
744 specialName << intCst.getInt();
745 }
else if (intTy.isUnsigned()) {
746 specialName << intCst.getUInt();
748 specialName << intCst.getSInt();
752 if (intTy || llvm::isa<FloatType>(type)) {
753 specialName <<
'_' << type;
756 if (
auto vecType = llvm::dyn_cast<VectorType>(type)) {
757 specialName <<
"_vec_";
758 specialName << vecType.getDimSize(0);
760 Type elementType = vecType.getElementType();
762 if (llvm::isa<IntegerType>(elementType) ||
763 llvm::isa<FloatType>(elementType)) {
764 specialName <<
"x" << elementType;
768 setNameFn(getResult(), specialName.str());
771 void mlir::spirv::AddressOfOp::getAsmResultNames(
774 llvm::raw_svector_ostream specialName(specialNameBuffer);
775 specialName << getVariable() <<
"_addr";
776 setNameFn(getResult(), specialName.str());
792 spirv::ExecutionModel executionModel,
793 spirv::FuncOp
function,
795 build(builder, state,
802 spirv::ExecutionModel execModel;
808 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
817 FlatSymbolRefAttr var;
819 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
821 interfaceVars.push_back(var);
834 auto interfaceVars = getInterface().getValue();
835 if (!interfaceVars.empty()) {
837 llvm::interleaveComma(interfaceVars, printer);
852 spirv::FuncOp
function,
853 spirv::ExecutionMode executionMode,
862 spirv::ExecutionMode execMode;
865 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
877 values.push_back(llvm::cast<IntegerAttr>(value).getInt());
887 printer <<
" \"" << stringifyExecutionMode(getExecutionMode()) <<
"\"";
888 auto values = this->getValues();
892 llvm::interleaveComma(values, printer, [&](
Attribute a) {
893 printer << llvm::cast<IntegerAttr>(a).getInt();
914 bool isVariadic =
false;
916 parser,
false, entryArgs, isVariadic, resultTypes,
921 for (
auto &arg : entryArgs)
922 argTypes.push_back(arg.type);
928 spirv::FunctionControl fnControl;
929 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
937 assert(resultAttrs.size() == resultTypes.size());
939 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
940 getResAttrsAttrName(result.
name));
953 auto fnType = getFunctionType();
955 printer, *
this, fnType.getInputs(),
956 false, fnType.getResults());
957 printer <<
" \"" << spirv::stringifyFunctionControl(getFunctionControl())
961 {spirv::attributeName<spirv::FunctionControl>(),
962 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
963 getFunctionControlAttrName()});
966 Region &body = this->getBody();
975 if (getFunctionType().getNumResults() > 1)
976 return emitOpError(
"cannot have more than one result");
981 FunctionType fnType = getFunctionType();
984 if (
auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
985 if (fnType.getNumResults() != 0)
986 return retOp.emitOpError(
"cannot be used in functions returning value");
987 }
else if (
auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
988 if (fnType.getNumResults() != 1)
989 return retOp.emitOpError(
990 "returns 1 value but enclosing function requires ")
991 << fnType.getNumResults() <<
" results";
993 auto retOperandType = retOp.getValue().getType();
994 auto fnResultType = fnType.getResult(0);
995 if (retOperandType != fnResultType)
996 return retOp.emitOpError(
" return value's type (")
997 << retOperandType <<
") mismatch with function's result type ("
998 << fnResultType <<
")";
1005 return failure(walkResult.wasInterrupted());
1009 StringRef name, FunctionType type,
1010 spirv::FunctionControl control,
1014 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
1015 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1016 builder.
getAttr<spirv::FunctionControlAttr>(control));
1017 state.attributes.append(attrs.begin(), attrs.end());
1065 Type type, StringRef name,
1066 unsigned descriptorSet,
unsigned binding) {
1069 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1072 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1077 Type type, StringRef name,
1078 spirv::BuiltIn builtin) {
1081 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1088 StringAttr nameAttr;
1113 if (!llvm::isa<spirv::PointerType>(type)) {
1114 return parser.
emitError(loc,
"expected spirv.ptr type");
1123 spirv::attributeName<spirv::StorageClass>()};
1131 if (
auto initializer = this->getInitializer()) {
1140 printer <<
" : " << getType();
1144 if (!llvm::isa<spirv::PointerType>(getType()))
1145 return emitOpError(
"result must be of a !spv.ptr type");
1151 auto storageClass = this->storageClass();
1152 if (storageClass == spirv::StorageClass::Generic ||
1153 storageClass == spirv::StorageClass::Function) {
1154 return emitOpError(
"storage class cannot be '")
1155 << stringifyStorageClass(storageClass) <<
"'";
1161 (*this)->getParentOp(), init.getAttr());
1166 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
1167 return emitOpError(
"initializer must be result of a "
1168 "spirv.SpecConstant or spirv.GlobalVariable op");
1182 spirv::StorageClass storageClass;
1191 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1203 printer <<
" " << getPtr() <<
" : " << getType();
1220 spirv::StorageClass storageClass;
1231 if (
auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
1242 printer <<
" " << getPtr() <<
", " << getValue() <<
" : "
1243 << getValue().getType();
1334 std::optional<StringRef> name) {
1344 spirv::AddressingModel addressingModel,
1345 spirv::MemoryModel memoryModel,
1346 std::optional<VerCapExtAttr> vceTriple,
1347 std::optional<StringRef> name) {
1350 builder.
getAttr<spirv::AddressingModelAttr>(addressingModel));
1351 state.addAttribute(
"memory_model",
1352 builder.
getAttr<spirv::MemoryModelAttr>(memoryModel));
1356 state.addAttribute(getVCETripleAttrName(), *vceTriple);
1367 StringAttr nameAttr;
1372 spirv::AddressingModel addrModel;
1373 spirv::MemoryModel memoryModel;
1374 if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
1376 spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
1383 spirv::ModuleOp::getVCETripleAttrName(),
1400 if (std::optional<StringRef> name = getName()) {
1409 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
1410 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
1411 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
1414 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
1415 printer <<
" requires " << *triple;
1416 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
1425 Dialect *dialect = (*this)->getDialect();
1430 for (
auto &op : *getBody()) {
1432 return op.
emitError(
"'spirv.module' can only contain spirv.* ops");
1437 if (
auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
1438 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
1440 return entryPointOp.emitError(
"function '")
1441 << entryPointOp.getFn() <<
"' not found in 'spirv.module'";
1443 if (
auto interface = entryPointOp.getInterface()) {
1445 auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
1447 return entryPointOp.emitError(
1448 "expected symbol reference for interface "
1449 "specification instead of '")
1453 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
1455 return entryPointOp.emitError(
"expected spirv.GlobalVariable "
1456 "symbol reference instead of'")
1457 << varSymRef <<
"'";
1462 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
1463 funcOp, entryPointOp.getExecutionModel());
1464 auto entryPtIt = entryPoints.find(key);
1465 if (entryPtIt != entryPoints.end()) {
1466 return entryPointOp.emitError(
"duplicate of a previous EntryPointOp");
1468 entryPoints[key] = entryPointOp;
1469 }
else if (
auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
1473 auto linkageAttr = funcOp.getLinkageAttributes();
1474 auto hasImportLinkage =
1475 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
1476 spirv::LinkageType::Import);
1477 if (funcOp.isExternal() && !hasImportLinkage)
1479 "'spirv.module' cannot contain external functions "
1480 "without 'Import' linkage_attributes (LinkageAttributes)");
1483 for (
auto &block : funcOp)
1484 for (
auto &op : block) {
1487 "functions in 'spirv.module' can only contain spirv.* ops");
1501 (*this)->getParentOp(), getSpecConstAttr());
1504 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
1506 constType = specConstOp.getDefaultValue().getType();
1508 auto specConstCompositeOp =
1509 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
1510 if (specConstCompositeOp)
1511 constType = specConstCompositeOp.getType();
1513 if (!specConstOp && !specConstCompositeOp)
1515 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
1517 if (getReference().getType() != constType)
1518 return emitOpError(
"result type mismatch with the referenced "
1519 "specialization constant's type");
1530 StringAttr nameAttr;
1539 IntegerAttr specIdAttr;
1557 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1559 printer <<
" = " << getDefaultValue();
1563 if (
auto specID = (*this)->getAttrOfType<IntegerAttr>(
kSpecIdAttrName))
1564 if (specID.getValue().isNegative())
1565 return emitOpError(
"SpecId cannot be negative");
1567 auto value = getDefaultValue();
1568 if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
1570 if (!llvm::isa<spirv::SPIRVType>(value.getType()))
1571 return emitOpError(
"default value bitwidth disallowed");
1575 "default value can only be a bool, integer, or float scalar");
1583 VectorType resultType = llvm::cast<VectorType>(getType());
1585 size_t numResultElements = resultType.getNumElements();
1586 if (numResultElements != getComponents().size())
1587 return emitOpError(
"result type element count (")
1588 << numResultElements
1589 <<
") mismatch with the number of component selectors ("
1590 << getComponents().size() <<
")";
1592 size_t totalSrcElements =
1593 llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
1594 llvm::cast<VectorType>(getVector2().getType()).getNumElements();
1596 for (
const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
1597 uint32_t index = selector.getZExtValue();
1598 if (index >= totalSrcElements &&
1599 index != std::numeric_limits<uint32_t>().
max())
1600 return emitOpError(
"component selector ")
1601 << index <<
" out of range: expected to be in [0, "
1602 << totalSrcElements <<
") or 0xffffffff";
1616 [](
auto matrixType) {
return matrixType.getElementType(); })
1617 .Default([](
Type) {
return nullptr; });
1619 assert(elementType &&
"Unhandled type");
1622 if (getScalar().getType() != elementType)
1623 return emitOpError(
"input matrix components' type and scaling value must "
1624 "have the same type");
1634 auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1635 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1638 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
1639 return emitError(
"input matrix rows count must be equal to "
1640 "output matrix columns count");
1642 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
1643 return emitError(
"input matrix columns count must be equal to "
1644 "output matrix rows count");
1647 if (inputMatrix.getElementType() != resultMatrix.getElementType())
1648 return emitError(
"input and output matrices must have the same "
1659 auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
1660 auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
1661 auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
1664 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
1665 return emitError(
"left matrix columns' count must be equal to "
1666 "the right matrix rows' count");
1669 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
1671 "right and result matrices must have equal columns' count");
1674 if (rightMatrix.getElementType() != resultMatrix.getElementType())
1675 return emitError(
"right and result matrices' component type must"
1679 if (leftMatrix.getElementType() != resultMatrix.getElementType())
1680 return emitError(
"left and result matrices' component type"
1681 " must be the same");
1684 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
1685 return emitError(
"left and result matrices must have equal rows' count");
1697 StringAttr compositeName;
1709 const char *attrName =
"spec_const";
1716 constituents.push_back(specConstRef);
1738 auto constituents = this->getConstituents().getValue();
1740 if (!constituents.empty())
1741 llvm::interleaveComma(constituents, printer);
1743 printer <<
") : " << getType();
1747 auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
1748 auto constituents = this->getConstituents().getValue();
1751 return emitError(
"result type must be a composite type, but provided ")
1754 if (llvm::isa<spirv::CooperativeMatrixNVType>(cType))
1755 return emitError(
"unsupported composite type ") << cType;
1756 if (llvm::isa<spirv::JointMatrixINTELType>(cType))
1757 return emitError(
"unsupported composite type ") << cType;
1758 if (constituents.size() != cType.getNumElements())
1759 return emitError(
"has incorrect number of operands: expected ")
1760 << cType.getNumElements() <<
", but provided "
1761 << constituents.size();
1763 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1764 auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
1766 auto constituentSpecConstOp =
1768 (*this)->getParentOp(), constituent.getAttr()));
1770 if (constituentSpecConstOp.getDefaultValue().getType() !=
1771 cType.getElementType(index))
1772 return emitError(
"has incorrect types of operands: expected ")
1773 << cType.getElementType(index) <<
", but provided "
1774 << constituentSpecConstOp.getDefaultValue().getType();
1812 printer <<
" wraps ";
1816 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
1817 Block &block = getRegion().getBlocks().
front();
1820 return emitOpError(
"expected exactly 2 nested ops");
1825 return emitOpError(
"invalid enclosed op");
1828 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
1829 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
1831 "invalid operand, must be defined by a constant operation");
1842 llvm::dyn_cast<spirv::StructType>(getResult().getType());
1845 return emitError(
"result type must be a struct type with two memebers");
1849 VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
1850 IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
1852 Type operandTy = getOperand().getType();
1853 VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
1854 FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
1856 if (significandTy != operandTy)
1857 return emitError(
"member zero of the resulting struct type must be the "
1858 "same type as the operand");
1860 if (exponentVecTy) {
1861 IntegerType componentIntTy =
1862 llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
1863 if (!componentIntTy || componentIntTy.getWidth() != 32)
1864 return emitError(
"member one of the resulting struct type must"
1865 "be a scalar or vector of 32 bit integer type");
1866 }
else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
1867 return emitError(
"member one of the resulting struct type "
1868 "must be a scalar or vector of 32 bit integer type");
1872 if (operandVecTy && exponentVecTy &&
1873 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
1876 if (operandFTy && exponentIntTy)
1879 return emitError(
"member one of the resulting struct type must have the same "
1880 "number of components as the operand type");
1888 Type significandType = getX().getType();
1889 Type exponentType = getExp().getType();
1891 if (llvm::isa<FloatType>(significandType) !=
1892 llvm::isa<IntegerType>(exponentType))
1893 return emitOpError(
"operands must both be scalars or vectors");
1896 if (
auto vectorType = llvm::dyn_cast<VectorType>(type))
1897 return vectorType.getNumElements();
1902 return emitOpError(
"operands must have the same number of elements");
1912 VectorType resultType = llvm::cast<VectorType>(getResult().getType());
1913 auto sampledImageType =
1914 llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
1916 llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
1918 if (resultType.getNumElements() != 4)
1919 return emitOpError(
"result type must be a vector of four components");
1921 Type elementType = resultType.getElementType();
1922 Type sampledElementType = imageType.getElementType();
1923 if (!llvm::isa<NoneType>(sampledElementType) &&
1924 elementType != sampledElementType)
1926 "the component type of result must be the same as sampled type of the "
1927 "underlying image type");
1929 spirv::Dim imageDim = imageType.getDim();
1930 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
1932 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
1933 imageDim != spirv::Dim::Rect)
1935 "the Dim operand of the underlying image type must be 2D, Cube, or "
1938 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
1939 return emitOpError(
"the MS operand of the underlying image type must be 0");
1941 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
1942 auto operandArguments = getOperandArguments();
1976 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1982 if (rhsMask.isZero())
1983 return getOperand2();
1986 if (rhsMask.isAllOnes())
1987 return getOperand1();
1990 if (
auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1993 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1994 return getOperand1();
2004 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
2010 if (rhsMask.isZero())
2011 return getOperand1();
2014 if (rhsMask.isAllOnes())
2015 return getOperand2();
2026 llvm::cast<spirv::ImageType>(getImage().getType());
2027 Type resultType = getResult().getType();
2029 spirv::Dim dim = imageType.
getDim();
2033 case spirv::Dim::Dim1D:
2034 case spirv::Dim::Dim2D:
2035 case spirv::Dim::Dim3D:
2036 case spirv::Dim::Cube:
2037 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
2038 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
2039 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
2041 "if Dim is 1D, 2D, 3D, or Cube, "
2042 "it must also have either an MS of 1 or a Sampled of 0 or 2");
2044 case spirv::Dim::Buffer:
2045 case spirv::Dim::Rect:
2048 return emitError(
"the Dim operand of the image type must "
2049 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
2052 unsigned componentNumber = 0;
2054 case spirv::Dim::Dim1D:
2055 case spirv::Dim::Buffer:
2056 componentNumber = 1;
2058 case spirv::Dim::Dim2D:
2059 case spirv::Dim::Cube:
2060 case spirv::Dim::Rect:
2061 componentNumber = 2;
2063 case spirv::Dim::Dim3D:
2064 componentNumber = 3;
2070 if (imageType.
getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
2071 componentNumber += 1;
2073 unsigned resultComponentNumber = 1;
2074 if (
auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
2075 resultComponentNumber = resultVectorType.getNumElements();
2077 if (componentNumber != resultComponentNumber)
2078 return emitError(
"expected the result to have ")
2079 << componentNumber <<
" component(s), but found "
2080 << resultComponentNumber <<
" component(s)";
2090 if (getVector().getType() != getType())
2091 return emitOpError(
"vector operand and result type mismatch");
2092 auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
2093 if (getScalar().getType() != scalarType)
2094 return emitOpError(
"scalar operand and result element type match");
static std::string bindingName()
Returns the string name of the Binding decoration.
static std::string descriptorSetName()
Returns the string name of the DescriptorSet decoration.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType)
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op)
static LogicalResult verifyShiftOp(Operation *op)
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void printOneResultOp(Operation *op, OpAsmPrinter &p)
static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands)
static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer)
static int64_t getNumElements(ShapedType type)
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalSymbolName(StringAttr &result)=0
Parse an optional -identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType & getOperations()
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
FloatAttr getFloatAttr(Type type, double value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual Operation * parseGenericOperation(Block *insertBlock, Block::iterator insertPt)=0
Parse an operation in its generic form.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printGenericOp(Operation *op, bool printOpName=true)=0
Print the entire operation with the default generic assembly form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
A trait to mark ops that can be enclosed/wrapped in a SpecConstantOperation op.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
AttrClass getAttrOfType(StringAttr name)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
ImageArrayedInfo getArrayedInfo() const
ImageSamplerUseInfo getSamplerUseInfo() const
ImageSamplingInfo getSamplingInfo() const
static PointerType get(Type pointeeType, StorageClass storageClass)
unsigned getNumElements() const
Type getElementType(unsigned) const
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
constexpr char kInitializerAttrName[]
constexpr char kIndicesAttrName[]
constexpr char kSpecIdAttrName[]
constexpr char kValuesAttrName[]
constexpr char kValueAttrName[]
constexpr char kCompositeSpecConstituentsName[]
constexpr char kInterfaceAttrName[]
constexpr char kTypeAttrName[]
constexpr char kDefaultValueAttrName[]
LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
AddressingModel getAddressingModel(TargetEnvAttr targetAttr, bool use64bitAddress)
Returns addressing model selected based on target environment.
FailureOr< ExecutionModel > getExecutionModel(TargetEnvAttr targetAttr)
Returns execution model selected based on target environment.
FailureOr< MemoryModel > getMemoryModel(TargetEnvAttr targetAttr)
Returns memory model selected based on target environment.
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.