14 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"
36 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
40 : descriptorSet(descriptorSet), binding(binding),
41 storageClass(storageClass) {}
44 return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
45 std::get<2>(key) == storageClass;
61 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
65 : version(version), capabilities(capabilities), extensions(extensions) {}
68 return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
69 std::get<2>(key) == extensions;
86 std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;
89 Vendor vendorID, DeviceType deviceType,
91 : triple(triple), limits(limits), clientAPI(clientAPI),
92 vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {}
95 return key == std::make_tuple(triple, clientAPI, vendorID, deviceType,
103 std::get<2>(key), std::get<3>(key),
104 std::get<4>(key), std::get<5>(key));
124 std::optional<spirv::StorageClass> storageClass,
129 auto storageClassAttr =
132 return get(descriptorSetAttr, bindingAttr, storageClassAttr);
137 IntegerAttr storageClass) {
138 assert(descriptorSet && binding);
140 return Base::get(context, descriptorSet, binding, storageClass);
144 return "interface_var_abi";
148 return llvm::cast<IntegerAttr>(getImpl()->binding).getInt();
152 return llvm::cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
155 std::optional<spirv::StorageClass>
157 if (getImpl()->storageClass)
158 return static_cast<spirv::StorageClass
>(
159 llvm::cast<IntegerAttr>(getImpl()->storageClass)
167 IntegerAttr binding, IntegerAttr storageClass) {
168 if (!descriptorSet.getType().isSignlessInteger(32))
169 return emitError() <<
"expected 32-bit integer for descriptor set";
171 if (!binding.getType().isSignlessInteger(32))
172 return emitError() <<
"expected 32-bit integer for binding";
175 if (
auto storageClassAttr = llvm::cast<IntegerAttr>(storageClass)) {
176 auto storageClassValue =
177 spirv::symbolizeStorageClass(storageClassAttr.getInt());
178 if (!storageClassValue)
179 return emitError() <<
"unknown storage class";
181 return emitError() <<
"expected valid storage class";
197 auto versionAttr = b.getI32IntegerAttr(
static_cast<uint32_t
>(version));
200 capAttrs.reserve(capabilities.size());
201 for (spirv::Capability cap : capabilities)
202 capAttrs.push_back(b.getI32IntegerAttr(
static_cast<uint32_t
>(cap)));
205 extAttrs.reserve(extensions.size());
206 for (spirv::Extension ext : extensions)
207 extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
209 return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
213 ArrayAttr capabilities,
214 ArrayAttr extensions) {
215 assert(version && capabilities && extensions);
217 return Base::get(context, version, capabilities, extensions);
223 return static_cast<spirv::Version
>(
224 llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
228 :
llvm::mapped_iterator<ArrayAttr::iterator,
231 return *symbolizeExtension(llvm::cast<StringAttr>(attr).getValue());
240 return llvm::cast<ArrayAttr>(
getImpl()->extensions);
244 :
llvm::mapped_iterator<ArrayAttr::iterator,
247 return *symbolizeCapability(
248 llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
257 return llvm::cast<ArrayAttr>(
getImpl()->capabilities);
262 ArrayAttr capabilities, ArrayAttr extensions) {
263 if (!version.getType().isSignlessInteger(32))
264 return emitError() <<
"expected 32-bit integer for version";
266 if (!llvm::all_of(capabilities.getValue(), [](
Attribute attr) {
267 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
268 if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
272 return emitError() <<
"unknown capability in capability list";
274 if (!llvm::all_of(extensions.getValue(), [](
Attribute attr) {
275 if (auto strAttr = llvm::dyn_cast<StringAttr>(attr))
276 if (spirv::symbolizeExtension(strAttr.getValue()))
280 return emitError() <<
"unknown extension in extension list";
291 Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {
292 assert(triple && limits &&
"expected valid triple and limits");
294 return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,
301 return llvm::cast<spirv::VerCapExtAttr>(
getImpl()->triple);
305 return getTripleAttr().getVersion();
309 return getTripleAttr().getExtensions();
313 return getTripleAttr().getExtensionsAttr();
317 return getTripleAttr().getCapabilities();
321 return getTripleAttr().getCapabilitiesAttr();
341 return llvm::cast<spirv::ResourceLimitsAttr>(
getImpl()->limits);
348 #define GET_ATTRDEF_CLASSES
349 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
359 function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) {
370 auto loc = parser.getCurrentLocation();
372 if (parser.parseKeyword(&keyword) ||
373 failed(processKeyword(loc, keyword)))
391 IntegerAttr descriptorSetAttr;
394 uint32_t descriptorSet = 0;
397 if (!descriptorSetParseResult.has_value() ||
398 failed(*descriptorSetParseResult)) {
399 parser.
emitError(loc,
"missing descriptor set");
408 IntegerAttr bindingAttr;
411 uint32_t binding = 0;
414 if (!bindingParseResult.has_value() || failed(*bindingParseResult)) {
415 parser.
emitError(loc,
"missing binding");
424 IntegerAttr storageClassAttr;
428 StringRef storageClass;
432 if (
auto storageClassSymbol =
433 spirv::symbolizeStorageClass(storageClass)) {
435 static_cast<uint32_t
>(*storageClassSymbol));
437 parser.
emitError(loc,
"unknown storage class: ") << storageClass;
456 IntegerAttr versionAttr;
463 if (
auto versionSymbol = spirv::symbolizeVersion(version)) {
467 parser.
emitError(loc,
"unknown version: ") << version;
472 ArrayAttr capabilitiesAttr;
476 StringRef errorKeyword;
478 auto processCapability = [&](SMLoc loc, StringRef capability) {
479 if (
auto capSymbol = spirv::symbolizeCapability(capability)) {
480 capabilities.push_back(
484 return errorloc = loc, errorKeyword = capability, failure();
487 if (!errorKeyword.empty())
488 parser.
emitError(errorloc,
"unknown capability: ") << errorKeyword;
495 ArrayAttr extensionsAttr;
499 StringRef errorKeyword;
501 auto processExtension = [&](SMLoc loc, StringRef extension) {
502 if (spirv::symbolizeExtension(extension)) {
506 return errorloc = loc, errorKeyword = extension, failure();
509 if (!errorKeyword.empty())
510 parser.
emitError(errorloc,
"unknown extension: ") << errorKeyword;
533 auto clientAPI = spirv::ClientAPI::Unknown;
541 if (
auto apiSymbol = spirv::symbolizeClientAPI(apiStr))
542 clientAPI = *apiSymbol;
544 parser.
emitError(loc,
"unknown client API: ") << apiStr;
550 Vendor vendorID = Vendor::Unknown;
551 DeviceType deviceType = DeviceType::Unknown;
557 if (
auto vendorSymbol = spirv::symbolizeVendor(vendorStr))
558 vendorID = *vendorSymbol;
560 parser.
emitError(loc,
"unknown vendor: ") << vendorStr;
564 StringRef deviceTypeStr;
567 if (
auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))
568 deviceType = *deviceTypeSymbol;
570 parser.
emitError(loc,
"unknown device type: ") << deviceTypeStr;
583 ResourceLimitsAttr limitsAttr;
588 deviceType, deviceID);
603 generatedAttributeParser(parser, &attrKind, type, attr);
626 << spirv::stringifyVersion(triple.
getVersion()) <<
", [";
627 llvm::interleaveComma(
629 [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
632 os << llvm::cast<StringAttr>(attr).getValue();
641 if (clientAPI != spirv::ClientAPI::Unknown)
642 printer <<
", api=" << clientAPI;
646 if (vendorID != spirv::Vendor::Unknown) {
647 printer <<
", " << spirv::stringifyVendor(vendorID);
648 if (deviceType != spirv::DeviceType::Unknown) {
649 printer <<
":" << spirv::stringifyDeviceType(deviceType);
651 printer <<
":" << deviceID;
664 printer <<
", " << spirv::stringifyStorageClass(*storageClass);
668 void SPIRVDialect::printAttribute(
Attribute attr,
670 if (succeeded(generatedAttributePrinter(attr, printer)))
673 if (
auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr))
674 print(targetEnv, printer);
675 else if (
auto vceAttr = llvm::dyn_cast<VerCapExtAttr>(attr))
676 print(vceAttr, printer);
677 else if (
auto interfaceVarABIAttr = llvm::dyn_cast<InterfaceVarABIAttr>(attr))
678 print(interfaceVarABIAttr, printer);
680 llvm_unreachable(
"unhandled SPIR-V attribute kind");
687 void spirv::SPIRVDialect::registerAttributes() {
688 addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
690 #define GET_ATTRDEF_LIST
691 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
static Attribute parseTargetEnvAttr(DialectAsmParser &parser)
Parses a spirv::TargetEnvAttr.
static ParseResult parseKeywordList(DialectAsmParser &parser, function_ref< LogicalResult(SMLoc, StringRef)> processKeyword)
Parses a comma-separated list of keywords, invokes processKeyword on each of the parsed keywords,...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Attribute parseVerCapExtAttr(DialectAsmParser &parser)
static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser)
Parses a spirv::InterfaceVarABIAttr.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
ImplType * getImpl() const
Utility for easy access to the storage instance.
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spirv.' prefix).
uint32_t getDescriptorSet()
Returns descriptor set.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets an InterfaceVarABIAttr.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, IntegerAttr descriptorSet, IntegerAttr binding, IntegerAttr storageClass)
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
An attribute that specifies the target version, allowed extensions and capabilities,...
Version getVersion() const
Returns the target version.
VerCapExtAttr::cap_range getCapabilities()
Returns the target capabilities.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spirv.' prefix).
VerCapExtAttr getTripleAttr() const
Returns the (version, capabilities, extensions) triple attribute.
ArrayAttr getCapabilitiesAttr()
Returns the target capabilities as an integer array attribute.
VerCapExtAttr::ext_range getExtensions()
Returns the target extensions.
Vendor getVendorID() const
Returns the vendor ID.
DeviceType getDeviceType() const
Returns the device type.
ClientAPI getClientAPI() const
Returns the client API.
ArrayAttr getExtensionsAttr()
Returns the target extensions as a string array attribute.
uint32_t getDeviceID() const
Returns the device ID.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
cap_range getCapabilities()
Returns the capabilities.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, IntegerAttr version, ArrayAttr capabilities, ArrayAttr extensions)
Version getVersion()
Returns the version.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spirv.' prefix).
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
ArrayAttr getCapabilitiesAttr()
Returns the capabilities as an integer array attribute.
ext_range getExtensions()
Returns the extensions.
ArrayAttr getExtensionsAttr()
Returns the extensions as a string array attribute.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
cap_iterator(ArrayAttr::iterator it)
ext_iterator(ArrayAttr::iterator it)
InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding, Attribute storageClass)
std::tuple< Attribute, Attribute, Attribute > KeyTy
bool operator==(const KeyTy &key) const
static InterfaceVarABIAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
static TargetEnvAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI, Vendor vendorID, DeviceType deviceType, uint32_t deviceID, Attribute limits)
bool operator==(const KeyTy &key) const
std::tuple< Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute > KeyTy
static VerCapExtAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Attribute, Attribute, Attribute > KeyTy
VerCapExtAttributeStorage(Attribute version, Attribute capabilities, Attribute extensions)
bool operator==(const KeyTy &key) const