16 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/MathExtras.h"
24 #include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
26 #define GET_ATTRDEF_CLASSES
27 #include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
29 #define DEBUG_TYPE "dlti"
38 using KeyTy = std::pair<DataLayoutEntryKey, Attribute>;
60 return Base::get(key.getContext(), key, value);
68 return getImpl()->entryKey;
71 Attribute DataLayoutEntryAttr::getValue()
const {
return getImpl()->value; }
80 std::string identifier;
88 parser.
emitError(idLoc) <<
"expected a type or a quoted string";
98 return type ?
get(type, value)
99 :
get(parser.getBuilder().getStringAttr(identifier), value);
104 if (
auto type = llvm::dyn_cast_if_present<Type>(getKey()))
107 os <<
"\"" << getKey().get<StringAttr>().strref() <<
"\"";
108 os <<
", " << getValue() <<
">";
119 for (DataLayoutEntryInterface entry : entries) {
120 if (
auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
121 if (!types.insert(type).second)
122 return emitError() <<
"repeated layout entry key: " << type;
124 auto id = entry.getKey().get<StringAttr>();
125 if (!ids.insert(
id).second)
126 return emitError() <<
"repeated layout entry key: " <<
id.getValue();
153 unsigned oldEntriesSize = oldEntries.size();
154 for (DataLayoutEntryInterface entry : newEntries) {
159 bool replaced =
false;
160 for (
unsigned i = 0; i < oldEntriesSize; ++i) {
161 if (oldEntries[i].getKey() == entry.getKey()) {
162 oldEntries[i] = entry;
168 oldEntries.push_back(entry);
185 spec.bucketEntriesByType(newEntriesForType, newEntriesForID);
188 for (
auto &kvp : newEntriesForType) {
189 if (!entriesForType.count(kvp.first)) {
190 entriesForType[kvp.first] = std::move(kvp.second);
194 Type typeSample = kvp.second.front().getKey().get<
Type>();
197 "unexpected data layout entry for built-in type");
199 auto interface = llvm::cast<DataLayoutTypeInterface>(typeSample);
200 if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second))
206 for (
const auto &kvp : newEntriesForID) {
207 StringAttr
id = kvp.second.getKey().get<StringAttr>();
208 Dialect *dialect =
id.getReferencedDialect();
209 if (!entriesForID.count(
id)) {
210 entriesForID[id] = kvp.second;
218 dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
219 entriesForID[
id], kvp.second)
222 if (!entriesForID[
id])
233 if (llvm::any_of(specs, [](DataLayoutSpecInterface spec) {
234 return !llvm::isa<DataLayoutSpecAttr>(spec);
241 for (DataLayoutSpecInterface spec : specs)
249 llvm::append_range(entries, llvm::make_second_range(entriesForID));
250 for (
const auto &kvp : entriesForType)
251 llvm::append_range(entries, kvp.getSecond());
257 DataLayoutSpecAttr::getEndiannessIdentifier(
MLIRContext *context)
const {
262 DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(
MLIRContext *context)
const {
264 DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
267 StringAttr DataLayoutSpecAttr::getProgramMemorySpaceIdentifier(
270 DLTIDialect::kDataLayoutProgramMemorySpaceKey);
274 DataLayoutSpecAttr::getGlobalMemorySpaceIdentifier(
MLIRContext *context)
const {
276 DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
280 DataLayoutSpecAttr::getStackAlignmentIdentifier(
MLIRContext *context)
const {
282 DLTIDialect::kDataLayoutStackAlignmentKey);
299 [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
309 llvm::interleaveComma(getEntries(), os);
323 std::string deviceID;
327 <<
"DeviceID is missing, or is not of string type";
336 auto target_device_spec =
338 if (failed(target_device_spec)) {
340 <<
"Error in parsing target device spec";
345 *target_device_spec);
351 return printer << param.first <<
" : " << param.second;
363 for (DataLayoutEntryInterface entry : entries) {
364 if (
auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
366 <<
"dlti.target_device_spec does not allow type as a key: "
370 auto id = entry.getKey().get<StringAttr>();
371 if (!ids.insert(
id).second)
372 return emitError() <<
"repeated layout entry key: " <<
id.getValue();
388 for (
const auto &entry : entries) {
389 TargetDeviceSpecInterface target_device_spec = entry.second;
393 target_device_spec.getEntries())))
397 TargetSystemSpecInterface::DeviceID device_id = entry.first;
398 if (!device_ids.insert(device_id).second) {
399 return emitError() <<
"repeated Device ID in dlti.target_system_spec: "
413 static std::pair<DLTIQueryInterface, Operation *>
415 DLTIQueryInterface queryable = {};
420 if ((queryable = llvm::dyn_cast<DLTIQueryInterface>(attr.getValue())))
424 return std::pair(queryable, op);
431 auto diag = op->
emitError() <<
"target op of failed DLTI query";
432 diag.attachNote(op->
getLoc()) <<
"no keys provided to attempt query with";
438 Operation *reportOp = (queryOp ? queryOp : op);
442 auto diag = op->
emitError() <<
"target op of failed DLTI query";
444 <<
"no DLTI-queryable attrs on target op or any of its ancestors";
452 .Case<StringAttr,
Type>(
453 [&](
auto key) { llvm::raw_string_ostream(buf) << key; })
454 .Default([](
auto) { llvm_unreachable(
"unexpected entry key kind"); });
460 if (
auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr)) {
461 auto maybeAttr = map.query(key);
462 if (failed(maybeAttr)) {
464 auto diag = op->
emitError() <<
"target op of failed DLTI query";
466 <<
"key " << keyToStr(key)
467 <<
" has no DLTI-mapping per attr: " << map;
471 currentAttr = *maybeAttr;
474 std::string commaSeparatedKeys;
476 keys.take_front(idx),
477 [&](
auto key) { commaSeparatedKeys += keyToStr(key); },
478 [&]() { commaSeparatedKeys +=
","; });
480 auto diag = op->
emitError() <<
"target op of failed DLTI query";
482 <<
"got non-DLTI-queryable attribute upon looking up keys ["
483 << commaSeparatedKeys <<
"] at op";
492 constexpr
const StringLiteral mlir::DLTIDialect::kDataLayoutAttrName;
493 constexpr
const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessKey;
494 constexpr
const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessBig;
495 constexpr
const StringLiteral mlir::DLTIDialect::kDataLayoutEndiannessLittle;
502 LogicalResult verifyEntry(DataLayoutEntryInterface entry,
504 StringRef entryName = entry.getKey().get<StringAttr>().strref();
505 if (entryName == DLTIDialect::kDataLayoutEndiannessKey) {
506 auto value = llvm::dyn_cast<StringAttr>(entry.getValue());
508 (value.getValue() == DLTIDialect::kDataLayoutEndiannessBig ||
509 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle))
511 return emitError(loc) <<
"'" << entryName
512 <<
"' data layout entry is expected to be either '"
513 << DLTIDialect::kDataLayoutEndiannessBig <<
"' or '"
514 << DLTIDialect::kDataLayoutEndiannessLittle <<
"'";
516 if (entryName == DLTIDialect::kDataLayoutAllocaMemorySpaceKey ||
517 entryName == DLTIDialect::kDataLayoutProgramMemorySpaceKey ||
518 entryName == DLTIDialect::kDataLayoutGlobalMemorySpaceKey ||
519 entryName == DLTIDialect::kDataLayoutStackAlignmentKey)
521 return emitError(loc) <<
"unknown data layout entry name: " << entryName;
526 void DLTIDialect::initialize() {
528 #define GET_ATTRDEF_LIST
529 #include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
531 addInterfaces<TargetDataLayoutInterface>();
534 LogicalResult DLTIDialect::verifyOperationAttribute(
Operation *op,
536 if (attr.
getName() == DLTIDialect::kDataLayoutAttrName) {
537 if (!llvm::isa<DataLayoutSpecAttr>(attr.
getValue())) {
538 return op->
emitError() <<
"'" << DLTIDialect::kDataLayoutAttrName
539 <<
"' is expected to be a #dlti.dl_spec attribute";
541 if (isa<ModuleOp>(op))
546 if (attr.
getName() == DLTIDialect::kTargetSystemDescAttrName) {
547 if (!llvm::isa<TargetSystemSpecAttr>(attr.
getValue())) {
549 <<
"'" << DLTIDialect::kTargetSystemDescAttrName
550 <<
"' is expected to be a #dlti.target_system_spec attribute";
555 if (attr.
getName() == DLTIDialect::kMapAttrName) {
556 if (!llvm::isa<MapAttr>(attr.
getValue())) {
557 return op->
emitError() <<
"'" << DLTIDialect::kMapAttrName
558 <<
"' is expected to be a #dlti.map attribute";
564 <<
"' not supported by dialect";
static std::pair< DLTIQueryInterface, Operation * > getClosestQueryable(Operation *op)
Retrieve the first DLTIQueryInterface-implementing attribute that is attached to op or such an attr o...
static LogicalResult combineOneSpec(DataLayoutSpecInterface spec, DenseMap< TypeID, DataLayoutEntryList > &entriesForType, DenseMap< StringAttr, DataLayoutEntryInterface > &entriesForID)
Combines a data layout spec into the given lists of entries organized by type class and identifier,...
static void overwriteDuplicateEntries(SmallVectorImpl< DataLayoutEntryInterface > &oldEntries, ArrayRef< DataLayoutEntryInterface > newEntries)
Given a list of old and a list of new entries, overwrites old entries with new ones if they have matc...
static LogicalResult verifyEntries(function_ref< InFlightDiagnostic()> emitError, ArrayRef< DataLayoutEntryInterface > entries)
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual OptionalParseResult parseOptionalType(Type &result)=0
Parse an optional type.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
StringAttr getStringAttr(const Twine &bytes)
An interface to be implemented by dialects that can have identifiers in the data layout specification...
DataLayoutDialectInterface(Dialect *dialect)
static DataLayoutEntryInterface defaultCombine(DataLayoutEntryInterface outer, DataLayoutEntryInterface inner)
Default implementation of entry combination that combines identical entries and returns null otherwis...
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This is a utility allocator used to allocate memory for instances of derived types.
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Dialect & getDialect() const
Get the dialect this type is registered to.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
DataLayoutEntryKey entryKey
DataLayoutEntryAttrStorage(DataLayoutEntryKey entryKey, Attribute value)
bool operator==(const KeyTy &other) const
std::pair< DataLayoutEntryKey, Attribute > KeyTy
static DataLayoutEntryAttrStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult verifyDataLayoutOp(Operation *op)
Verifies that the operation implementing the data layout interface, or a module operation,...
FailureOr< Attribute > query(Operation *op, ArrayRef< DataLayoutEntryKey > keys, bool emitError=false)
Perform a DLTI-query at op, recursively querying each key of keys on query interface-implementing att...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::pair< StringAttr, TargetDeviceSpecInterface > DeviceIDTargetDeviceSpecPair
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
static FailureOr< DeviceIDTargetDeviceSpecPair > parse(AsmParser &parser)
Provide a template class that can be specialized by users to dispatch to parsers.