10 #include "llvm/IR/Constants.h"
18 struct LoopMetadataConversion {
19 LoopMetadataConversion(
const llvm::MDNode *node,
Location loc,
21 : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
24 LoopAnnotationAttr convert();
30 const llvm::MDNode *lookupAndEraseProperty(StringRef name);
43 StringRef disableName,
44 bool negated =
false);
61 llvm::StringMap<const llvm::MDNode *> propertyMap;
62 const llvm::MDNode *node;
71 if (node->getNumOperands() == 0 ||
72 dyn_cast<llvm::MDNode>(node->getOperand(0)) != node)
75 for (
const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) {
76 if (
auto *diLoc = dyn_cast<llvm::DILocation>(operand)) {
77 locations.push_back(diLoc);
81 auto *
property = dyn_cast<llvm::MDNode>(operand);
83 return emitWarning(loc) <<
"expected all loop properties to be either "
84 "debug locations or metadata nodes";
86 if (property->getNumOperands() == 0)
87 return emitWarning(loc) <<
"cannot import empty loop property";
89 auto *nameNode = dyn_cast<llvm::MDString>(property->getOperand(0));
91 return emitWarning(loc) <<
"cannot import loop property without a name";
92 StringRef name = nameNode->getString();
94 bool succ = propertyMap.try_emplace(name, property).second;
97 <<
"cannot import loop properties with duplicated names " << name;
104 LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
105 auto it = propertyMap.find(name);
106 if (it == propertyMap.end())
108 const llvm::MDNode *
property = it->getValue();
109 propertyMap.erase(it);
114 const llvm::MDNode *
property = lookupAndEraseProperty(name);
118 if (property->getNumOperands() != 1)
120 <<
"expected metadata node " << name <<
" to hold no value";
126 StringRef enableName, StringRef disableName,
bool negated) {
127 auto enable = lookupUnitNode(enableName);
128 auto disable = lookupUnitNode(disableName);
132 if (*enable && *disable)
134 <<
"expected metadata nodes " << enableName <<
" and " << disableName
135 <<
" to be mutually exclusive.";
147 const llvm::MDNode *
property = lookupAndEraseProperty(name);
151 auto emitNodeWarning = [&]() {
153 <<
"expected metadata node " << name <<
" to hold a boolean value";
156 if (property->getNumOperands() != 2)
157 return emitNodeWarning();
158 llvm::ConstantInt *val =
159 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
160 if (!val || val->getBitWidth() != 1)
161 return emitNodeWarning();
163 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated);
167 LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
168 const llvm::MDNode *
property = lookupAndEraseProperty(name);
172 auto emitNodeWarning = [&]() {
174 <<
"expected metadata node " << name <<
" to hold an integer value";
177 if (property->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt *val =
180 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
181 if (!val || val->getBitWidth() != 32)
182 return emitNodeWarning();
184 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1));
188 const llvm::MDNode *
property = lookupAndEraseProperty(name);
190 return IntegerAttr(
nullptr);
192 auto emitNodeWarning = [&]() {
194 <<
"expected metadata node " << name <<
" to hold an i32 value";
197 if (property->getNumOperands() != 2)
198 return emitNodeWarning();
200 llvm::ConstantInt *val =
201 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
202 if (!val || val->getBitWidth() != 32)
203 return emitNodeWarning();
206 val->getValue().getLimitedValue());
210 const llvm::MDNode *
property = lookupAndEraseProperty(name);
214 auto emitNodeWarning = [&]() {
216 <<
"expected metadata node " << name <<
" to hold an MDNode";
219 if (property->getNumOperands() != 2)
220 return emitNodeWarning();
222 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(1));
224 return emitNodeWarning();
230 LoopMetadataConversion::lookupMDNodes(StringRef name) {
231 const llvm::MDNode *
property = lookupAndEraseProperty(name);
236 auto emitNodeWarning = [&]() {
237 return emitWarning(loc) <<
"expected metadata node " << name
238 <<
" to hold one or multiple MDNodes";
241 if (property->getNumOperands() < 2)
242 return emitNodeWarning();
244 for (
unsigned i = 1, e = property->getNumOperands(); i < e; ++i) {
245 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(i));
247 return emitNodeWarning();
255 LoopMetadataConversion::lookupFollowupNode(StringRef name) {
256 auto node = lookupMDNode(name);
259 if (*node ==
nullptr)
260 return LoopAnnotationAttr(
nullptr);
262 return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
267 template <
typename T>
274 template <
typename T,
typename... P>
276 bool anyFailed = (
failed(args) || ...);
284 return T::get(ctx, *args...);
289 lookupBoolNode(
"llvm.loop.vectorize.enable",
true);
291 lookupBoolNode(
"llvm.loop.vectorize.predicate.enable");
293 lookupBoolNode(
"llvm.loop.vectorize.scalable.enable");
296 lookupFollowupNode(
"llvm.loop.vectorize.followup_vectorized");
298 lookupFollowupNode(
"llvm.loop.vectorize.followup_epilogue");
300 lookupFollowupNode(
"llvm.loop.vectorize.followup_all");
302 return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable,
303 scalableEnable, width, followupVec,
304 followupEpi, followupAll);
309 return createIfNonNull<LoopInterleaveAttr>(ctx, count);
314 "llvm.loop.unroll.enable",
"llvm.loop.unroll.disable",
true);
317 lookupUnitNode(
"llvm.loop.unroll.runtime.disable");
320 lookupFollowupNode(
"llvm.loop.unroll.followup_unrolled");
322 lookupFollowupNode(
"llvm.loop.unroll.followup_remainder");
324 lookupFollowupNode(
"llvm.loop.unroll.followup_all");
326 return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
327 full, followupUnrolled,
328 followupRemainder, followupAll);
332 LoopMetadataConversion::convertUnrollAndJamAttr() {
334 "llvm.loop.unroll_and_jam.enable",
"llvm.loop.unroll_and_jam.disable",
337 lookupIntNode(
"llvm.loop.unroll_and_jam.count");
339 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_outer");
341 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_inner");
343 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_outer");
345 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_inner");
347 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_all");
348 return createIfNonNull<LoopUnrollAndJamAttr>(
349 ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
350 followupRemainderInner, followupAll);
356 lookupUnitNode(
"llvm.loop.licm_versioning.disable");
357 return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable);
362 lookupBoolNode(
"llvm.loop.distribute.enable",
true);
364 lookupFollowupNode(
"llvm.loop.distribute.followup_coincident");
366 lookupFollowupNode(
"llvm.loop.distribute.followup_sequential");
368 lookupFollowupNode(
"llvm.loop.distribute.followup_fallback");
370 lookupFollowupNode(
"llvm.loop.distribute.followup_all");
371 return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident,
373 followupFallback, followupAll);
379 lookupIntNode(
"llvm.loop.pipeline.initiationinterval");
380 return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval);
385 return createIfNonNull<LoopPeeledAttr>(ctx, count);
390 lookupUnitNode(
"llvm.loop.unswitch.partial.disable");
391 return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable);
395 LoopMetadataConversion::convertParallelAccesses() {
397 lookupMDNodes(
"llvm.loop.parallel_accesses");
401 for (llvm::MDNode *node : *nodes) {
403 loopAnnotationImporter.lookupAccessGroupAttrs(node);
404 if (
failed(accessGroups)) {
405 emitWarning(loc) <<
"could not lookup access group";
408 llvm::append_range(refs, *accessGroups);
413 FusedLoc LoopMetadataConversion::convertStartLoc() {
414 if (locations.empty())
416 return dyn_cast<FusedLoc>(
417 loopAnnotationImporter.moduleImport.translateLoc(locations[0]));
421 if (locations.size() < 2)
423 if (locations.size() > 2)
425 <<
"expected loop metadata to have at most two DILocations";
426 return dyn_cast<FusedLoc>(
427 loopAnnotationImporter.moduleImport.translateLoc(locations[1]));
430 LoopAnnotationAttr LoopMetadataConversion::convert() {
431 if (
failed(initConversionState()))
435 lookupUnitNode(
"llvm.loop.disable_nonforced");
447 lookupIntNodeAsBoolAttr(
"llvm.loop.isvectorized");
449 convertParallelAccesses();
452 if (!propertyMap.empty()) {
453 for (
auto name : propertyMap.keys())
454 emitWarning(loc) <<
"unknown loop annotation " << name;
461 return createIfNonNull<LoopAnnotationAttr>(
462 ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
463 unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
464 unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
476 auto it = loopMetadataMapping.find(node);
477 if (it != loopMetadataMapping.end())
478 return it->getSecond();
480 LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *
this).convert();
482 mapLoopMetadata(node, attr);
490 if (!node->getNumOperands())
491 accessGroups.push_back(node);
492 for (
const llvm::MDOperand &operand : node->operands()) {
493 auto *childNode = dyn_cast<llvm::MDNode>(operand);
496 accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
500 for (
const llvm::MDNode *accessGroup : accessGroups) {
501 if (accessGroupMapping.count(accessGroup))
504 if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
506 <<
"expected an access group node to be empty and distinct";
509 accessGroupMapping[accessGroup] = builder.
getAttr<AccessGroupAttr>();
519 if (!node->getNumOperands())
520 accessGroups.push_back(accessGroupMapping.lookup(node));
521 for (
const llvm::MDOperand &operand : node->operands()) {
522 auto *node = cast<llvm::MDNode>(operand.get());
523 accessGroups.push_back(accessGroupMapping.lookup(node));
526 if (llvm::is_contained(accessGroups,
nullptr))
static MLIRContext * getContext(OpFoldResult val)
static T createIfNonNull(MLIRContext *ctx, const P &...args)
Helper function that only creates and attribute of type T if all argument conversion were successfull...
static bool isEmptyOrNull(const Attribute attr)
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class provides support for representing a failure result, or a valid value of type T.
A helper class that converts llvm.loop metadata nodes into corresponding LoopAnnotationAttrs and llvm...
LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, Location loc)
LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc)
Converts all LLVM access groups starting from node to MLIR access group attributes.
FailureOr< SmallVector< AccessGroupAttr > > lookupAccessGroupAttrs(const llvm::MDNode *node) const
Returns the access group attribute that map to the access group nodes starting from the access group ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.