28 #include "llvm/ADT/Sequence.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/TypeSwitch.h"
35 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
44 return llvm::any_of(region, [](
Block &block) {
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
57 bool wouldBeCloned)
const final {
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
85 if (isa<spirv::KillOp>(op))
94 if (
auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
96 spirv::BranchOp::create(builder, op->getLoc(), newDest);
98 }
else if (
auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
100 spirv::BranchOp::create(builder, retValOp->getLoc(), newDest,
101 retValOp->getOperands());
110 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
115 assert(valuesToRepl.size() == 1 &&
116 "spirv.ReturnValue expected to only handle one result");
117 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
126 void SPIRVDialect::initialize() {
127 registerAttributes();
133 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
136 addInterfaces<SPIRVInlinerInterface>();
139 allowUnknownOperations();
140 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
143 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
144 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
152 template <
typename ValTy>
153 static std::optional<ValTy>
parseAndVerify(SPIRVDialect
const &dialect,
175 if (
auto t = llvm::dyn_cast<FloatType>(type)) {
177 }
else if (
auto t = llvm::dyn_cast<IntegerType>(type)) {
180 "only 1/8/16/32/64-bit integer type allowed but found ")
184 }
else if (
auto t = llvm::dyn_cast<VectorType>(type)) {
185 if (t.getRank() != 1) {
186 parser.
emitError(typeLoc,
"only 1-D vector allowed but found ") << t;
189 if (t.getNumElements() > 4) {
191 typeLoc,
"vector length has to be less than or equal to 4 but found ")
192 << t.getNumElements();
195 }
else if (
auto t = dyn_cast<TensorArmType>(type)) {
196 if (!isa<ScalarType>(t.getElementType())) {
198 typeLoc,
"only scalar element type allowed in tensor type but found ")
199 << t.getElementType();
204 << type <<
" to compose SPIR-V types";
218 if (
auto t = llvm::dyn_cast<VectorType>(type)) {
219 if (t.getRank() != 1) {
220 parser.
emitError(typeLoc,
"only 1-D vector allowed but found ") << t;
223 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
225 "matrix columns size has to be less than or equal "
226 "to 4 and greater than or equal 2, but found ")
227 << t.getNumElements();
231 if (!llvm::isa<FloatType>(t.getElementType())) {
232 parser.
emitError(typeLoc,
"matrix columns' elements must be of "
234 << t.getElementType();
238 parser.
emitError(typeLoc,
"matrix must be composed using vector "
254 if (!llvm::isa<ImageType>(type)) {
256 "sampled image must be composed using image type, got ")
283 if (!(stride = *optStride)) {
284 parser.
emitError(strideLoc,
"ArrayStride must be greater than zero");
306 if (countDims.size() != 1) {
308 "expected single integer for array element count");
314 int64_t count = countDims[0];
316 parser.
emitError(countLoc,
"expected array length greater than 0");
346 if (dims.size() != 2) {
347 parser.
emitError(countLoc,
"expected row and column count");
360 CooperativeMatrixUseKHR use;
378 bool unranked =
false;
390 if (!unranked && dims.empty()) {
391 parser.
emitError(countLoc,
"arm.tensors do not support rank zero");
395 if (llvm::is_contained(dims, 0)) {
396 parser.
emitError(countLoc,
"arm.tensors do not support zero dimensions");
400 if (llvm::any_of(dims, [](int64_t dim) {
return dim < 0; }) &&
401 llvm::any_of(dims, [](int64_t dim) {
return dim > 0; })) {
402 parser.
emitError(countLoc,
"arm.tensor shape dimensions must be either "
403 "fully dynamic or completed shaped");
435 StringRef storageClassSpec;
440 auto storageClass = symbolizeStorageClass(storageClassSpec);
442 parser.
emitError(storageClassLoc,
"unknown storage class: ")
481 if (countDims.size() != 1) {
482 parser.
emitError(countLoc,
"expected single unsigned "
483 "integer for number of columns");
487 int64_t columnCount = countDims[0];
489 if (columnCount < 2 || columnCount > 4) {
490 parser.
emitError(countLoc,
"matrix is expected to have 2, 3, or 4 "
507 template <
typename ValTy>
516 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
518 parser.
emitError(enumLoc,
"unknown attribute: '") << enumSpec <<
"'";
532 template <
typename IntTy>
544 return parseAndVerifyInteger<unsigned>(dialect, parser);
552 template <
typename ParseType,
typename... Args>
553 struct ParseCommaSeparatedList {
554 std::optional<std::tuple<ParseType, Args...>>
556 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
560 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
563 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
564 if (!remainingValues)
566 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
567 remainingValues.value());
573 template <
typename ParseType>
574 struct ParseCommaSeparatedList<ParseType> {
575 std::optional<std::tuple<ParseType>>
577 if (
auto value = parseAndVerify<ParseType>(dialect, parser))
578 return std::tuple<ParseType>(*value);
605 ParseCommaSeparatedList<
Type, Dim, ImageDepthInfo, ImageArrayedInfo,
606 ImageSamplingInfo, ImageSamplerUseInfo,
607 ImageFormat>{}(dialect, parser);
643 if (
failed(*offsetParseResult))
646 if (offsetInfo.size() != memberTypes.size() - 1) {
648 "offset specification must be given for "
651 offsetInfo.push_back(offset);
663 auto parseDecorations = [&]() {
664 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
665 if (!memberDecoration)
674 memberDecorationInfo.emplace_back(
675 static_cast<uint32_t
>(memberTypes.size() - 1),
676 memberDecoration.value(), memberDecorationValue);
678 memberDecorationInfo.emplace_back(
679 static_cast<uint32_t
>(memberTypes.size() - 1),
680 memberDecoration.value(),
UnitAttr::get(dialect.getContext()));
706 StringRef identifier;
707 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
716 if (succeeded(cyclicParse)) {
719 "recursive struct reference not nested in struct definition");
730 if (
failed(cyclicParse)) {
732 "identifier already used for an enclosing struct");
747 if (!identifier.empty())
758 memberTypes.push_back(memberType);
762 memberDecorationInfo))
766 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
768 "offset specification must be given for all members");
777 auto parseStructDecoration = [&]() {
778 std::optional<spirv::Decoration> decoration =
779 parseAndVerify<spirv::Decoration>(dialect, parser);
789 structDecorationInfo.emplace_back(decoration.value(), decorationValue);
791 structDecorationInfo.emplace_back(decoration.value(),
798 if (
failed(parseStructDecoration()))
804 if (!identifier.empty()) {
806 memberDecorationInfo,
807 structDecorationInfo)))
813 structDecorationInfo);
828 if (keyword ==
"array")
830 if (keyword ==
"coopmatrix")
832 if (keyword ==
"image")
834 if (keyword ==
"ptr")
836 if (keyword ==
"rtarray")
838 if (keyword ==
"sampled_image")
840 if (keyword ==
"struct")
842 if (keyword ==
"matrix")
844 if (keyword ==
"arm.tensor")
857 os <<
", stride=" << stride;
864 os <<
", stride=" << stride;
875 <<
", " << stringifyImageDepthInfo(type.
getDepthInfo()) <<
", "
887 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
895 if (
failed(cyclicPrint)) {
905 auto printMember = [&](
unsigned i) {
909 if (type.
hasOffset() || !decorations.empty()) {
913 if (!decorations.empty())
917 os << stringifyDecoration(decoration.decoration);
918 if (decoration.hasValue()) {
923 llvm::interleaveComma(decorations, os, eachFn);
927 llvm::interleaveComma(llvm::seq<unsigned>(0, type.
getNumElements()), os,
933 if (!decorations.empty()) {
936 os << stringifyDecoration(decoration.decoration);
937 if (decoration.hasValue()) {
942 llvm::interleaveComma(decorations, os, eachFn);
965 if (ShapedType::isDynamic(dim))
981 [&](
auto type) {
print(type, os); })
982 .Default([](
Type) { llvm_unreachable(
"unhandled SPIR-V type"); });
992 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
993 return ub::PoisonOp::create(builder, loc, type, poison);
995 if (!spirv::ConstantOp::isBuildableWith(type))
998 return spirv::ConstantOp::create(builder, loc, type, value);
1005 LogicalResult SPIRVDialect::verifyOperationAttribute(
Operation *op,
1007 StringRef symbol = attribute.
getName().strref();
1011 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
1013 << symbol <<
"' attribute must be an entry point ABI attribute";
1016 if (!llvm::isa<spirv::TargetEnvAttr>(attr))
1017 return op->
emitError(
"'") << symbol <<
"' must be a spirv::TargetEnvAttr";
1019 return op->
emitError(
"found unsupported '")
1020 << symbol <<
"' attribute on operation";
1030 StringRef symbol = attribute.
getName().strref();
1034 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
1037 << symbol <<
"' must be a spirv::InterfaceVarABIAttr";
1041 <<
"' attribute cannot specify storage class "
1042 "when attaching to a non-scalar value";
1045 if (symbol == spirv::DecorationAttr::name) {
1046 if (!isa<spirv::DecorationAttr>(attr))
1048 << symbol <<
"' must be a spirv::DecorationAttr";
1052 return emitError(loc,
"found unsupported '")
1053 << symbol <<
"' attribute on region argument";
1056 LogicalResult SPIRVDialect::verifyRegionArgAttribute(
Operation *op,
1057 unsigned regionIndex,
1060 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1063 Type argType = funcOp.getArgumentTypes()[argIndex];
1068 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1071 return op->
emitError(
"cannot attach SPIR-V attributes to region result");
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
std::optional< unsigned > parseAndVerify< unsigned >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, DialectAsmParser &parser, unsigned &stride)
Parses an optional `, stride = N` assembly segment.
static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute)
Verifies the given SPIR-V attribute attached to a value of the given valueType is valid.
static Type parseTensorArmType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static void print(ArrayType type, DialectAsmPrinter &os)
static Type parseSampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static ParseResult parseStructMemberDecorations(SPIRVDialect const &dialect, DialectAsmParser &parser, ArrayRef< Type > memberTypes, SmallVectorImpl< StructType::OffsetInfo > &offsetInfo, SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorationInfo)
static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
std::optional< Type > parseAndVerify< Type >(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static bool containsReturn(Region &region)
Returns true if the given region contains spirv.Return or spirv.ReturnValue ops.
static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< IntTy > parseAndVerifyInteger(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseMatrixType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parseImageType(SPIRVDialect const &dialect, DialectAsmParser &parser)
static std::optional< ValTy > parseAndVerify(SPIRVDialect const &dialect, DialectAsmParser &parser)
static Type parsePointerType(SPIRVDialect const &dialect, DialectAsmParser &parser)
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
FailureOr< CyclicParseReset > tryStartCyclicParse(AttrOrTypeT attrOrType)
Attempts to start a cyclic parsing region for attrOrType.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalStar()=0
Parse a '*' token if present.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseXInDimensionList()=0
Parse an 'x' token in a dimension list, handling the case where the x is juxtaposed with an element t...
FailureOr< CyclicPrintReset > tryStartCyclicPrint(AttrOrTypeT attrOrType)
Attempts to start a cyclic printing region for attrOrType.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
static ArrayType get(Type elementType, unsigned elementCount)
Scope getScope() const
Returns the scope of the matrix.
uint32_t getRows() const
Returns the number of rows of the matrix.
uint32_t getColumns() const
Returns the number of columns of the matrix.
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Type getElementType() const
CooperativeMatrixUseKHR getUse() const
Returns the use parameter of the cooperative matrix.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
ImageDepthInfo getDepthInfo() const
ImageArrayedInfo getArrayedInfo() const
ImageFormat getImageFormat() const
ImageSamplerUseInfo getSamplerUseInfo() const
Type getElementType() const
ImageSamplingInfo getSamplingInfo() const
static MatrixType get(Type columnType, uint32_t columnCount)
Type getColumnType() const
unsigned getNumColumns() const
Returns the number of columns.
Type getPointeeType() const
StorageClass getStorageClass() const
static PointerType get(Type pointeeType, StorageClass storageClass)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
static RuntimeArrayType get(Type elementType)
Type getImageType() const
static SampledImageType get(Type imageType)
static bool isValid(FloatType)
Returns true if the given float type is valid for the SPIR-V dialect.
void getStructDecorations(SmallVectorImpl< StructType::StructDecorationInfo > &structDecorations) const
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
unsigned getNumElements() const
Type getElementType(unsigned) const
LogicalResult trySetBody(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Sets the contents of an incomplete identified StructType.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
uint64_t getMemberOffset(unsigned) const
Type getElementType() const
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
ArrayRef< int64_t > getShape() const
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
StringRef getInterfaceVarABIAttrName()
Returns the attribute name for specifying argument ABI information.
ParseResult parseEnumKeywordAttr(EnumClass &value, ParserType &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next keyword in parser as an enumerant of the given EnumClass.
StringRef getTargetEnvAttrName()
Returns the attribute name for specifying SPIR-V target environment.
StringRef getEntryPointABIAttrName()
Returns the attribute name for specifying entry point information.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.