28#include "llvm/ADT/Sequence.h"
29#include "llvm/ADT/StringExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
35#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
44 return llvm::any_of(region, [](
Block &block) {
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
52struct SPIRVInlinerInterface :
public DialectInlinerInterface {
53 using DialectInlinerInterface::DialectInlinerInterface;
57 bool wouldBeCloned)
const final {
64 IRMapping &)
const final {
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
75 IRMapping &)
const final {
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
85 if (isa<spirv::KillOp>(op))
93 void handleTerminator(Operation *op,
Block *newDest)
const final {
94 if (
auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95 auto builder = OpBuilder(op);
96 spirv::BranchOp::create(builder, op->getLoc(), newDest);
98 }
else if (
auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
99 auto builder = OpBuilder(op);
100 spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
101 retValOp->getOperands());
108 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
110 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
115 assert(valuesToRepl.size() == 1 &&
116 "spirv.ReturnValue expected to only handle one result");
117 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
126void SPIRVDialect::initialize() {
127 registerAttributes();
133#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
136 addInterfaces<SPIRVInlinerInterface>();
139 allowUnknownOperations();
140 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
143std::string SPIRVDialect::getAttributeName(Decoration decoration) {
152template <
typename ValTy>
153static std::optional<ValTy>
parseAndVerify(SPIRVDialect
const &dialect,
175 if (
auto t = dyn_cast<FloatType>(type)) {
178 "only 8/16/32/64-bit float type allowed but found ")
182 }
else if (
auto t = dyn_cast<IntegerType>(type)) {
185 "only 1/8/16/32/64-bit integer type allowed but found ")
189 }
else if (
auto t = dyn_cast<VectorType>(type)) {
190 if (t.getRank() != 1) {
191 parser.
emitError(typeLoc,
"only 1-D vector allowed but found ") << t;
194 if (t.getNumElements() < 2) {
195 parser.
emitError(typeLoc,
"SPIR-V does not allow one-element vectors");
198 if (t.getNumElements() > 4) {
200 typeLoc,
"vector length has to be less than or equal to 4 but found ")
201 << t.getNumElements();
204 if (!isa<ScalarType>(t.getElementType())) {
207 "vector element type must be a SPIR-V scalar type but found ")
208 << t.getElementType();
211 }
else if (
auto t = dyn_cast<TensorArmType>(type)) {
212 if (!isa<ScalarType>(t.getElementType())) {
214 typeLoc,
"only scalar element type allowed in tensor type but found ")
215 << t.getElementType();
220 << type <<
" to compose SPIR-V types";
234 if (
auto t = dyn_cast<VectorType>(type)) {
235 if (t.getRank() != 1) {
236 parser.
emitError(typeLoc,
"only 1-D vector allowed but found ") << t;
239 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
241 "matrix columns size has to be less than or equal "
242 "to 4 and greater than or equal 2, but found ")
243 << t.getNumElements();
247 if (!isa<FloatType>(t.getElementType())) {
248 parser.
emitError(typeLoc,
"matrix columns' elements must be of "
250 << t.getElementType();
254 parser.
emitError(typeLoc,
"matrix must be composed using vector "
270 auto imageType = dyn_cast<ImageType>(type);
273 "sampled image must be composed using image type, got ")
278 if (llvm::is_contained({Dim::SubpassData, Dim::Buffer}, imageType.getDim())) {
280 typeLoc,
"sampled image Dim must not be SubpassData or Buffer, got ")
281 << stringifyDim(imageType.getDim());
307 if (!(stride = *optStride)) {
308 parser.
emitError(strideLoc,
"ArrayStride must be greater than zero");
330 if (countDims.size() != 1) {
332 "expected single integer for array element count");
340 parser.
emitError(countLoc,
"expected array length greater than 0");
370 if (dims.size() != 2) {
371 parser.
emitError(countLoc,
"expected row and column count");
384 CooperativeMatrixUseKHR use;
402 bool unranked =
false;
414 if (!unranked && dims.empty()) {
415 parser.
emitError(countLoc,
"arm.tensors do not support rank zero");
419 if (llvm::is_contained(dims, 0)) {
420 parser.
emitError(countLoc,
"arm.tensors do not support zero dimensions");
424 if (llvm::any_of(dims, [](
int64_t dim) {
return dim < 0; }) &&
425 llvm::any_of(dims, [](
int64_t dim) {
return dim > 0; })) {
426 parser.
emitError(countLoc,
"arm.tensor shape dimensions must be either "
427 "fully dynamic or completed shaped");
459 StringRef storageClassSpec;
464 auto storageClass = symbolizeStorageClass(storageClassSpec);
466 parser.
emitError(storageClassLoc,
"unknown storage class: ")
505 if (countDims.size() != 1) {
506 parser.
emitError(countLoc,
"expected single unsigned "
507 "integer for number of columns");
511 int64_t columnCount = countDims[0];
513 if (columnCount < 2 || columnCount > 4) {
514 parser.
emitError(countLoc,
"matrix is expected to have 2, 3, or 4 "
531template <
typename ValTy>
540 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
542 parser.
emitError(enumLoc,
"unknown attribute: '") << enumSpec <<
"'";
556template <
typename IntTy>
559 IntTy offsetVal = std::numeric_limits<IntTy>::max();
576template <
typename ParseType,
typename... Args>
577struct ParseCommaSeparatedList {
578 std::optional<std::tuple<ParseType, Args...>>
579 operator()(SPIRVDialect
const &dialect, DialectAsmParser &parser)
const {
584 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
587 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
588 if (!remainingValues)
590 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
591 remainingValues.value());
597template <
typename ParseType>
598struct ParseCommaSeparatedList<ParseType> {
599 std::optional<std::tuple<ParseType>>
600 operator()(SPIRVDialect
const &dialect, DialectAsmParser &parser)
const {
602 return std::tuple<ParseType>(*value);
629 ParseCommaSeparatedList<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
630 ImageSamplingInfo, ImageSamplerUseInfo,
631 ImageFormat>{}(dialect, parser);
667 if (failed(*offsetParseResult))
670 if (offsetInfo.size() != memberTypes.size() - 1) {
672 "offset specification must be given for "
675 offsetInfo.push_back(offset);
687 auto parseDecorations = [&]() {
689 if (!memberDecoration)
698 memberDecorationInfo.emplace_back(
699 static_cast<uint32_t
>(memberTypes.size() - 1),
700 memberDecoration.value(), memberDecorationValue);
702 memberDecorationInfo.emplace_back(
703 static_cast<uint32_t
>(memberTypes.size() - 1),
704 memberDecoration.value(), UnitAttr::get(dialect.getContext()));
730 StringRef identifier;
731 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
740 if (succeeded(cyclicParse)) {
743 "recursive struct reference not nested in struct definition");
754 if (failed(cyclicParse)) {
756 "identifier already used for an enclosing struct");
771 if (!identifier.empty())
782 if (!isa<SPIRVType>(memberType)) {
784 "member type must be a valid SPIR-V type");
787 memberTypes.push_back(memberType);
791 memberDecorationInfo))
795 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
797 "offset specification must be given for all members");
806 auto parseStructDecoration = [&]() {
807 std::optional<spirv::Decoration> decoration =
818 structDecorationInfo.emplace_back(decoration.value(), decorationValue);
820 structDecorationInfo.emplace_back(decoration.value(),
821 UnitAttr::get(dialect.getContext()));
827 if (failed(parseStructDecoration()))
833 if (!identifier.empty()) {
834 if (failed(idStructTy.
trySetBody(memberTypes, offsetInfo,
835 memberDecorationInfo,
836 structDecorationInfo)))
842 structDecorationInfo);
857 if (keyword ==
"array")
859 if (keyword ==
"coopmatrix")
861 if (keyword ==
"image")
863 if (keyword ==
"ptr")
865 if (keyword ==
"rtarray")
867 if (keyword ==
"sampled_image")
869 if (keyword ==
"sampler")
871 if (keyword ==
"named_barrier")
873 if (keyword ==
"struct")
875 if (keyword ==
"matrix")
877 if (keyword ==
"arm.tensor")
890 os <<
", stride=" << stride;
897 os <<
", stride=" << stride;
908 <<
", " << stringifyImageDepthInfo(type.
getDepthInfo()) <<
", "
922 os <<
"named_barrier";
926 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
934 if (failed(cyclicPrint)) {
944 auto printMember = [&](
unsigned i) {
948 if (type.
hasOffset() || !decorations.empty()) {
952 if (!decorations.empty())
956 os << stringifyDecoration(decoration.decoration);
957 if (decoration.hasValue()) {
962 llvm::interleaveComma(decorations, os, eachFn);
966 llvm::interleaveComma(llvm::seq<unsigned>(0, type.
getNumElements()), os,
972 if (!decorations.empty()) {
975 os << stringifyDecoration(decoration.decoration);
976 if (decoration.hasValue()) {
981 llvm::interleaveComma(decorations, os, eachFn);
1004 if (ShapedType::isDynamic(dim))
1021 [&](
auto type) {
print(type, os); })
1022 .DefaultUnreachable(
"Unhandled SPIR-V type");
1032 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
1033 return ub::PoisonOp::create(builder, loc, type, poison);
1035 if (!spirv::ConstantOp::isBuildableWith(type))
1038 return spirv::ConstantOp::create(builder, loc, type, value);
1045LogicalResult SPIRVDialect::verifyOperationAttribute(
Operation *op,
1047 StringRef symbol = attribute.
getName().strref();
1051 if (!isa<spirv::EntryPointABIAttr>(attr)) {
1053 << symbol <<
"' attribute must be an entry point ABI attribute";
1056 if (!isa<spirv::TargetEnvAttr>(attr))
1057 return op->
emitError(
"'") << symbol <<
"' must be a spirv::TargetEnvAttr";
1059 if (!isa<spirv::LoopControlAttr>(attr))
1061 << symbol <<
"' must be a spirv::LoopControlAttr";
1063 if (!isa<spirv::SelectionControlAttr>(attr))
1065 << symbol <<
"' must be a spirv::SelectionControlAttr";
1067 return op->
emitError(
"found unsupported '")
1068 << symbol <<
"' attribute on operation";
1078 StringRef symbol = attribute.
getName().strref();
1082 auto varABIAttr = dyn_cast<spirv::InterfaceVarABIAttr>(attr);
1085 << symbol <<
"' must be a spirv::InterfaceVarABIAttr";
1089 <<
"' attribute cannot specify storage class "
1090 "when attaching to a non-scalar value";
1093 if (symbol == spirv::DecorationAttr::name) {
1094 if (!isa<spirv::DecorationAttr>(attr))
1096 << symbol <<
"' must be a spirv::DecorationAttr";
1100 return emitError(loc,
"found unsupported '")
1101 << symbol <<
"' attribute on region argument";
1104LogicalResult SPIRVDialect::verifyRegionArgAttribute(
Operation *op,
1105 unsigned regionIndex,
1108 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1111 Type argType = funcOp.getArgumentTypes()[argIndex];
1116LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1117 Operation *op,
unsigned ,
unsigned resultIndex,
1119 if (
auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
1121 op->
getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
1123 "cannot attach SPIR-V attributes to region result which is "
1124 "not part of a spirv::GraphARMOp type");
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional , stride = N assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseTensorArmType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region ®ion)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalStar()=0
Parse a '*' token if present.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseXInDimensionList()=0
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
static NamedBarrierType get(MLIRContext *context)
Type getPointeeType() const
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
Type getImageType() const
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
StringRef getLoopControlAttrName()
Returns the attribute name for specifying loop control.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
StringRef getSelectionControlAttrName()
Returns the attribute name for specifying selection control.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch