28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
48 return llvm::any_of(region, [](
Block &block) {
50 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
61 bool wouldBeCloned)
const final {
72 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
81 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
94 if (
auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
97 }
else if (
auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
99 retValOp->getOperands());
108 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
113 assert(valuesToRepl.size() == 1 &&
114 "spirv.ReturnValue expected to only handle one result");
115 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
124 void SPIRVDialect::initialize() {
125 registerAttributes();
131 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
134 addInterfaces<SPIRVInlinerInterface>();
137 allowUnknownOperations();
138 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
141 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
150 template <
typename ValTy>
151 static std::optional<ValTy>
parseAndVerify(SPIRVDialect
const &dialect,
173 if (
auto t = llvm::dyn_cast<FloatType>(type)) {
175 parser.
emitError(typeLoc,
"cannot use 'bf16' to compose SPIR-V types");
178 }
else if (
auto t = llvm::dyn_cast<IntegerType>(type)) {
181 "only 1/8/16/32/64-bit integer type allowed but found ")
185 }
else if (
auto t = llvm::dyn_cast<VectorType>(type)) {
186 if (t.getRank() != 1) {
187 parser.
emitError(typeLoc,
"only 1-D vector allowed but found ") << t;
190 if (t.getNumElements() > 4) {
192 typeLoc,
"vector length has to be less than or equal to 4 but found ")
193 << t.getNumElements();
198 << type <<
" to compose SPIR-V types";
212 if (
auto t = llvm::dyn_cast<VectorType>(type)) {
213 if (t.getRank() != 1) {
214 parser.
emitError(typeLoc,
"only 1-D vector allowed but found ") << t;
217 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
219 "matrix columns size has to be less than or equal "
220 "to 4 and greater than or equal 2, but found ")
221 << t.getNumElements();
225 if (!llvm::isa<FloatType>(t.getElementType())) {
226 parser.
emitError(typeLoc,
"matrix columns' elements must be of "
228 << t.getElementType();
232 parser.
emitError(typeLoc,
"matrix must be composed using vector "
248 if (!llvm::isa<ImageType>(type)) {
250 "sampled image must be composed using image type, got ")
277 if (!(stride = *optStride)) {
278 parser.
emitError(strideLoc,
"ArrayStride must be greater than zero");
300 if (countDims.size() != 1) {
302 "expected single integer for array element count");
308 int64_t count = countDims[0];
310 parser.
emitError(countLoc,
"expected array length greater than 0");
340 if (dims.size() != 2) {
341 parser.
emitError(countLoc,
"expected row and column count");
354 CooperativeMatrixUseKHR use;
378 if (dims.size() != 2) {
379 parser.
emitError(countLoc,
"expected rows and columns size");
386 MatrixLayout matrixLayout;
418 StringRef storageClassSpec;
423 auto storageClass = symbolizeStorageClass(storageClassSpec);
425 parser.
emitError(storageClassLoc,
"unknown storage class: ")
464 if (countDims.size() != 1) {
465 parser.
emitError(countLoc,
"expected single unsigned "
466 "integer for number of columns");
470 int64_t columnCount = countDims[0];
472 if (columnCount < 2 || columnCount > 4) {
473 parser.
emitError(countLoc,
"matrix is expected to have 2, 3, or 4 "
490 template <
typename ValTy>
499 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
501 parser.
emitError(enumLoc,
"unknown attribute: '") << enumSpec <<
"'";
515 template <
typename IntTy>
527 return parseAndVerifyInteger<unsigned>(dialect, parser);
535 template <
typename ParseType,
typename... Args>
536 struct ParseCommaSeparatedList {
537 std::optional<std::tuple<ParseType, Args...>>
539 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
543 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
546 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
547 if (!remainingValues)
549 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
550 remainingValues.value());
556 template <
typename ParseType>
557 struct ParseCommaSeparatedList<ParseType> {
558 std::optional<std::tuple<ParseType>>
560 if (
auto value = parseAndVerify<ParseType>(dialect, parser))
561 return std::tuple<ParseType>(*value);
588 ParseCommaSeparatedList<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
589 ImageSamplingInfo, ImageSamplerUseInfo,
590 ImageFormat>{}(dialect, parser);
626 if (
failed(*offsetParseResult))
629 if (offsetInfo.size() != memberTypes.size() - 1) {
631 "offset specification must be given for "
634 offsetInfo.push_back(offset);
646 auto parseDecorations = [&]() {
647 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
648 if (!memberDecoration)
653 auto memberDecorationValue =
654 parseAndVerifyInteger<uint32_t>(dialect, parser);
656 if (!memberDecorationValue)
659 memberDecorationInfo.emplace_back(
660 static_cast<uint32_t
>(memberTypes.size() - 1), 1,
661 memberDecoration.value(), memberDecorationValue.value());
663 memberDecorationInfo.emplace_back(
664 static_cast<uint32_t
>(memberTypes.size() - 1), 0,
665 memberDecoration.value(), 0);
689 StringRef identifier;
702 "recursive struct reference not nested in struct definition");
713 if (
failed(cyclicParse)) {
715 "identifier already used for an enclosing struct");
730 if (!identifier.empty())
741 memberTypes.push_back(memberType);
745 memberDecorationInfo))
749 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
751 "offset specification must be given for all members");
758 if (!identifier.empty()) {
760 memberDecorationInfo)))
780 if (keyword ==
"array")
782 if (keyword ==
"coopmatrix")
784 if (keyword ==
"jointmatrix")
786 if (keyword ==
"image")
788 if (keyword ==
"ptr")
790 if (keyword ==
"rtarray")
792 if (keyword ==
"sampled_image")
794 if (keyword ==
"struct")
796 if (keyword ==
"matrix")
809 os <<
", stride=" << stride;
816 os <<
", stride=" << stride;
827 <<
", " << stringifyImageDepthInfo(type.
getDepthInfo()) <<
", "
847 if (
failed(cyclicPrint)) {
857 auto printMember = [&](
unsigned i) {
861 if (type.
hasOffset() || !decorations.empty()) {
865 if (!decorations.empty())
869 os << stringifyDecoration(decoration.decoration);
870 if (decoration.hasValue) {
871 os <<
"=" << decoration.decorationValue;
874 llvm::interleaveComma(decorations, os, eachFn);
878 llvm::interleaveComma(llvm::seq<unsigned>(0, type.
getNumElements()), os,
893 os <<
", " << stringifyScope(type.
getScope()) <<
">";
906 .Default([](
Type) { llvm_unreachable(
"unhandled SPIR-V type"); });
916 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
917 return builder.
create<ub::PoisonOp>(loc, type, poison);
919 if (!spirv::ConstantOp::isBuildableWith(type))
922 return builder.
create<spirv::ConstantOp>(loc, type, value);
931 StringRef symbol = attribute.
getName().strref();
935 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
937 << symbol <<
"' attribute must be an entry point ABI attribute";
940 if (!llvm::isa<spirv::TargetEnvAttr>(attr))
941 return op->
emitError(
"'") << symbol <<
"' must be a spirv::TargetEnvAttr";
943 return op->
emitError(
"found unsupported '")
944 << symbol <<
"' attribute on operation";
954 StringRef symbol = attribute.
getName().strref();
958 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
961 << symbol <<
"' must be a spirv::InterfaceVarABIAttr";
965 <<
"' attribute cannot specify storage class "
966 "when attaching to a non-scalar value";
969 if (symbol == spirv::DecorationAttr::name) {
970 if (!isa<spirv::DecorationAttr>(attr))
972 << symbol <<
"' must be a spirv::DecorationAttr";
976 return emitError(loc,
"found unsupported '")
977 << symbol <<
"' attribute on region argument";
981 unsigned regionIndex,
984 auto funcOp = dyn_cast<FunctionOpInterface>(op);
987 Type argType = funcOp.getArgumentTypes()[argIndex];
995 return op->
emitError(
"cannot attach SPIR-V attributes to region result");
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseJointMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a `=` token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a `)` token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a `]` token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a `)` token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a `=` token.
virtual SMLoc getCurrentLocation()=0
Returns the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a `,` token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a `]` token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a `(` token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a `,` token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a `[` token if present.
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void erase()
Remove this operation from its parent block and delete it.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
Scope getScope() const
Return the scope of the joint matrix.
unsigned getColumns() const
Returns the number of columns of the matrix.
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
unsigned getRows() const
Returns the number of rows of the matrix.
MatrixLayout getMatrixLayout() const
Returns the layout of the matrix.
Type getElementType() const
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
Type getImageType() const
static SampledImageType get(Type imageType)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.