14#include "llvm/ADT/TypeSwitch.h"
15#include "llvm/Support/InterleavedRange.h"
26#include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"
37 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
62 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
87 std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;
104 std::get<2>(key), std::get<3>(key),
105 std::get<4>(key), std::get<5>(key));
125 std::optional<spirv::StorageClass> storageClass,
128 auto descriptorSetAttr =
b.getI32IntegerAttr(descriptorSet);
129 auto bindingAttr =
b.getI32IntegerAttr(binding);
130 auto storageClassAttr =
131 storageClass ?
b.getI32IntegerAttr(
static_cast<uint32_t
>(*storageClass))
133 return get(descriptorSetAttr, bindingAttr, storageClassAttr);
136spirv::InterfaceVarABIAttr
138 IntegerAttr storageClass) {
139 assert(descriptorSet && binding);
141 return Base::get(context, descriptorSet, binding, storageClass);
145 return "interface_var_abi";
149 return cast<IntegerAttr>(
getImpl()->binding).getInt();
153 return cast<IntegerAttr>(
getImpl()->descriptorSet).getInt();
156std::optional<spirv::StorageClass>
159 return static_cast<spirv::StorageClass
>(
160 cast<IntegerAttr>(
getImpl()->storageClass).getValue().getZExtValue());
166 IntegerAttr binding, IntegerAttr storageClass) {
167 if (!descriptorSet.getType().isSignlessInteger(32))
168 return emitError() <<
"expected 32-bit integer for descriptor set";
170 if (!binding.getType().isSignlessInteger(32))
171 return emitError() <<
"expected 32-bit integer for binding";
174 if (
auto storageClassAttr = cast<IntegerAttr>(storageClass)) {
175 auto storageClassValue =
176 spirv::symbolizeStorageClass(storageClassAttr.getInt());
177 if (!storageClassValue)
178 return emitError() <<
"unknown storage class";
180 return emitError() <<
"expected valid storage class";
196 auto versionAttr =
b.getI32IntegerAttr(
static_cast<uint32_t
>(version));
199 capAttrs.reserve(capabilities.size());
200 for (spirv::Capability cap : capabilities)
201 capAttrs.push_back(
b.getI32IntegerAttr(
static_cast<uint32_t
>(cap)));
204 extAttrs.reserve(extensions.size());
205 for (spirv::Extension ext : extensions)
206 extAttrs.push_back(
b.getStringAttr(spirv::stringifyExtension(ext)));
208 return get(versionAttr,
b.getArrayAttr(capAttrs),
b.getArrayAttr(extAttrs));
212 ArrayAttr capabilities,
213 ArrayAttr extensions) {
214 assert(version && capabilities && extensions);
216 return Base::get(context, version, capabilities, extensions);
222 return static_cast<spirv::Version
>(
223 cast<IntegerAttr>(
getImpl()->version).getValue().getZExtValue());
227 :
llvm::mapped_iterator<ArrayAttr::iterator,
230 return *symbolizeExtension(cast<StringAttr>(attr).getValue());
239 return cast<ArrayAttr>(
getImpl()->extensions);
243 :
llvm::mapped_iterator<ArrayAttr::iterator,
246 return *symbolizeCapability(
247 cast<IntegerAttr>(attr).getValue().getZExtValue());
256 return cast<ArrayAttr>(
getImpl()->capabilities);
261 ArrayAttr capabilities, ArrayAttr extensions) {
262 if (!version.getType().isSignlessInteger(32))
263 return emitError() <<
"expected 32-bit integer for version";
265 if (!llvm::all_of(capabilities.getValue(), [](
Attribute attr) {
266 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
267 if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
271 return emitError() <<
"unknown capability in capability list";
273 if (!llvm::all_of(extensions.getValue(), [](
Attribute attr) {
274 if (auto strAttr = dyn_cast<StringAttr>(attr))
275 if (spirv::symbolizeExtension(strAttr.getValue()))
279 return emitError() <<
"unknown extension in extension list";
290 Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {
291 assert(triple && limits &&
"expected valid triple and limits");
293 return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,
300 return cast<spirv::VerCapExtAttr>(
getImpl()->triple);
340 return cast<spirv::ResourceLimitsAttr>(
getImpl()->limits);
347#define GET_ATTRDEF_CLASSES
348#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
358 function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) {
369 auto loc = parser.getCurrentLocation();
371 if (parser.parseKeyword(&keyword) ||
372 failed(processKeyword(loc, keyword)))
390 IntegerAttr descriptorSetAttr;
393 uint32_t descriptorSet = 0;
396 if (!descriptorSetParseResult.has_value() ||
397 failed(*descriptorSetParseResult)) {
398 parser.
emitError(loc,
"missing descriptor set");
407 IntegerAttr bindingAttr;
410 uint32_t binding = 0;
413 if (!bindingParseResult.has_value() || failed(*bindingParseResult)) {
414 parser.
emitError(loc,
"missing binding");
423 IntegerAttr storageClassAttr;
427 StringRef storageClass;
431 if (
auto storageClassSymbol =
432 spirv::symbolizeStorageClass(storageClass)) {
434 static_cast<uint32_t
>(*storageClassSymbol));
436 parser.
emitError(loc,
"unknown storage class: ") << storageClass;
455 IntegerAttr versionAttr;
462 if (
auto versionSymbol = spirv::symbolizeVersion(version)) {
466 parser.
emitError(loc,
"unknown version: ") << version;
475 StringRef errorKeyword;
477 auto processCapability = [&](SMLoc loc, StringRef capability) {
478 if (
auto capSymbol = spirv::symbolizeCapability(capability)) {
479 capabilities.push_back(
483 return errorloc = loc, errorKeyword = capability, failure();
486 if (!errorKeyword.empty())
487 parser.
emitError(errorloc,
"unknown capability: ") << errorKeyword;
498 StringRef errorKeyword;
500 auto processExtension = [&](SMLoc loc, StringRef extension) {
501 if (spirv::symbolizeExtension(extension)) {
505 return errorloc = loc, errorKeyword = extension, failure();
508 if (!errorKeyword.empty())
509 parser.
emitError(errorloc,
"unknown extension: ") << errorKeyword;
532 auto clientAPI = spirv::ClientAPI::Unknown;
540 if (
auto apiSymbol = spirv::symbolizeClientAPI(apiStr))
541 clientAPI = *apiSymbol;
543 parser.
emitError(loc,
"unknown client API: ") << apiStr;
549 Vendor vendorID = Vendor::Unknown;
550 DeviceType deviceType = DeviceType::Unknown;
556 if (
auto vendorSymbol = spirv::symbolizeVendor(vendorStr))
557 vendorID = *vendorSymbol;
559 parser.
emitError(loc,
"unknown vendor: ") << vendorStr;
563 StringRef deviceTypeStr;
566 if (
auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))
567 deviceType = *deviceTypeSymbol;
569 parser.
emitError(loc,
"unknown device type: ") << deviceTypeStr;
582 ResourceLimitsAttr limitsAttr;
587 deviceType, deviceID);
590Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
601 OptionalParseResult
result =
602 generatedAttributeParser(parser, &attrKind, type, attr);
624 << spirv::stringifyVersion(triple.
getVersion()) <<
", "
625 << llvm::interleaved_array(llvm::map_range(
628 << llvm::interleaved_array(
637 if (clientAPI != spirv::ClientAPI::Unknown)
638 printer <<
", api=" << clientAPI;
642 if (vendorID != spirv::Vendor::Unknown) {
643 printer <<
", " << spirv::stringifyVendor(vendorID);
644 if (deviceType != spirv::DeviceType::Unknown) {
645 printer <<
":" << spirv::stringifyDeviceType(deviceType);
647 printer <<
":" << deviceID;
660 printer <<
", " << spirv::stringifyStorageClass(*storageClass);
664void SPIRVDialect::printAttribute(Attribute attr,
665 DialectAsmPrinter &printer)
const {
666 if (succeeded(generatedAttributePrinter(attr, printer)))
669 if (
auto targetEnv = dyn_cast<TargetEnvAttr>(attr))
670 print(targetEnv, printer);
671 else if (
auto vceAttr = dyn_cast<VerCapExtAttr>(attr))
672 print(vceAttr, printer);
673 else if (
auto interfaceVarABIAttr = dyn_cast<InterfaceVarABIAttr>(attr))
674 print(interfaceVarABIAttr, printer);
676 llvm_unreachable(
"unhandled SPIR-V attribute kind");
683void spirv::SPIRVDialect::registerAttributes() {
684 addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
686#define GET_ATTRDEF_LIST
687#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
static Attribute parseTargetEnvAttr(DialectAsmParser &parser)
Parses a spirv::TargetEnvAttr.
static ParseResult parseKeywordList(DialectAsmParser &parser, function_ref< LogicalResult(SMLoc, StringRef)> processKeyword)
Parses a comma-separated list of keywords, invokes processKeyword on each of the parsed keywords,...
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Attribute parseVerCapExtAttr(DialectAsmParser &parser)
static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser)
Parses a spirv::InterfaceVarABIAttr.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
T * allocate()
Allocate an instance of the provided type.
ImplType * getImpl() const
An attribute that specifies the information regarding the interface variable: descriptor set,...
uint32_t getBinding()
Returns binding.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spirv.' prefix).
uint32_t getDescriptorSet()
Returns descriptor set.
static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding, std::optional< StorageClass > storageClass, MLIRContext *context)
Gets an InterfaceVarABIAttr.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, IntegerAttr descriptorSet, IntegerAttr binding, IntegerAttr storageClass)
std::optional< StorageClass > getStorageClass()
Returns spirv::StorageClass.
An attribute that specifies the target version, allowed extensions and capabilities,...
Version getVersion() const
Returns the target version.
VerCapExtAttr::cap_range getCapabilities()
Returns the target capabilities.
ResourceLimitsAttr getResourceLimits() const
Returns the target resource limits.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spirv.' prefix).
VerCapExtAttr getTripleAttr() const
Returns the (version, capabilities, extensions) triple attribute.
ArrayAttr getCapabilitiesAttr()
Returns the target capabilities as an integer array attribute.
VerCapExtAttr::ext_range getExtensions()
Returns the target extensions.
Vendor getVendorID() const
Returns the vendor ID.
DeviceType getDeviceType() const
Returns the device type.
ClientAPI getClientAPI() const
Returns the client API.
ArrayAttr getExtensionsAttr()
Returns the target extensions as a string array attribute.
uint32_t getDeviceID() const
Returns the device ID.
static constexpr uint32_t kUnknownDeviceID
ID for unknown devices.
static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI=ClientAPI::Unknown, Vendor vendorID=Vendor::Unknown, DeviceType deviceType=DeviceType::Unknown, uint32_t deviceId=kUnknownDeviceID)
Gets a TargetEnvAttr instance.
An attribute that specifies the SPIR-V (version, capabilities, extensions) triple.
cap_range getCapabilities()
Returns the capabilities.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, IntegerAttr version, ArrayAttr capabilities, ArrayAttr extensions)
Version getVersion()
Returns the version.
static StringRef getKindName()
Returns the attribute kind's name (without the 'spirv.' prefix).
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
ArrayAttr getCapabilitiesAttr()
Returns the capabilities as an integer array attribute.
llvm::iterator_range< ext_iterator > ext_range
ext_range getExtensions()
Returns the extensions.
ArrayAttr getExtensionsAttr()
Returns the extensions as a string array attribute.
llvm::iterator_range< cap_iterator > cap_range
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
StorageUniquer::StorageAllocator AttributeStorageAllocator
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
cap_iterator(ArrayAttr::iterator it)
ext_iterator(ArrayAttr::iterator it)
InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding, Attribute storageClass)
std::tuple< Attribute, Attribute, Attribute > KeyTy
bool operator==(const KeyTy &key) const
static InterfaceVarABIAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute > KeyTy
TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI, Vendor vendorID, DeviceType deviceType, uint32_t deviceID, Attribute limits)
bool operator==(const KeyTy &key) const
static TargetEnvAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
std::tuple< Attribute, Attribute, Attribute > KeyTy
static VerCapExtAttributeStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
VerCapExtAttributeStorage(Attribute version, Attribute capabilities, Attribute extensions)
bool operator==(const KeyTy &key) const