10#include "llvm/IR/Constants.h"
/// Converts one LLVM loop metadata node (`node`) into an MLIR
/// LoopAnnotationAttr, reporting malformed input as warnings at `loc`.
/// NOTE(review): this listing lacks several original lines of the struct
/// (e.g. the end of the constructor's initializer list and some member
/// declarations such as `loc`); confirm against the source file.
18struct LoopMetadataConversion {
19 LoopMetadataConversion(
const llvm::MDNode *node, Location loc,
20 LoopAnnotationImporter &loopAnnotationImporter)
21 : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
/// Converts `node` into a LoopAnnotationAttr.
24 LoopAnnotationAttr convert();
/// Validates `node` and fills `locations` and `propertyMap` from its
/// operands; emits a warning on malformed input.
27 LogicalResult initConversionState();
/// Returns the property called `name` and removes it from `propertyMap`.
30 const llvm::MDNode *lookupAndEraseProperty(StringRef name);
/// Property lookup helpers. Each one consumes the named property through
/// lookupAndEraseProperty. On the property-absent path the visible code
/// returns a null attribute; a malformed property emits a warning, which
/// yields failure.
35 FailureOr<BoolAttr> lookupUnitNode(StringRef name);
36 FailureOr<BoolAttr> lookupBoolNode(StringRef name,
bool negated =
false);
37 FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
38 FailureOr<IntegerAttr> lookupIntNode(StringRef name);
39 FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
40 FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
41 FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name);
42 FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName,
43 StringRef disableName,
44 bool negated =
false);
/// Converters for the individual sub-attributes of a LoopAnnotationAttr.
47 FailureOr<LoopVectorizeAttr> convertVectorizeAttr();
48 FailureOr<LoopInterleaveAttr> convertInterleaveAttr();
49 FailureOr<LoopUnrollAttr> convertUnrollAttr();
50 FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr();
51 FailureOr<LoopLICMAttr> convertLICMAttr();
52 FailureOr<LoopDistributeAttr> convertDistributeAttr();
53 FailureOr<LoopPipelineAttr> convertPipelineAttr();
54 FailureOr<LoopPeeledAttr> convertPeeledAttr();
55 FailureOr<LoopUnswitchAttr> convertUnswitchAttr();
56 FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses();
57 FusedLoc convertStartLoc();
58 FailureOr<FusedLoc> convertEndLoc();
/// Debug locations found among the loop node's operands.
60 llvm::SmallVector<llvm::DILocation *, 2> locations;
/// Named loop properties keyed by their MDString name; lookups erase the
/// entries they consume.
61 llvm::StringMap<const llvm::MDNode *> propertyMap;
/// The loop metadata node being converted.
62 const llvm::MDNode *node;
/// The importer that created this conversion.
64 LoopAnnotationImporter &loopAnnotationImporter;
/// Validates the loop metadata node and sorts its operands (after operand 0)
/// into debug locations (`locations`) and named properties (`propertyMap`).
/// Each operand that cannot be imported produces a warning at `loc`.
/// NOTE(review): several original lines are missing from this listing (the
/// embedded line numbers skip), including what follows the first check and
/// what are presumably the null checks guarding some of the warnings.
69LogicalResult LoopMetadataConversion::initConversionState() {
// A loop node must be non-empty and reference itself as operand 0.
71 if (node->getNumOperands() == 0 ||
72 dyn_cast<llvm::MDNode>(node->getOperand(0)) != node)
// Operand 0 is the self-reference checked above, so skip it.
75 for (
const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) {
// Debug locations are collected separately; convertStartLoc and
// convertEndLoc read them later.
76 if (
auto *diLoc = dyn_cast<llvm::DILocation>(operand)) {
77 locations.push_back(diLoc);
// Any other operand must be a non-empty metadata node whose operand 0 is an
// MDString holding the property name.
81 auto *
property = dyn_cast<llvm::MDNode>(operand);
83 return emitWarning(loc) <<
"expected all loop properties to be either "
84 "debug locations or metadata nodes";
86 if (property->getNumOperands() == 0)
87 return emitWarning(loc) <<
"cannot import empty loop property";
89 auto *nameNode = dyn_cast<llvm::MDString>(property->getOperand(0));
91 return emitWarning(loc) <<
"cannot import loop property without a name";
92 StringRef name = nameNode->getString();
// try_emplace never overwrites an existing entry, so `.second` is false
// exactly when the name was already seen.
94 bool succ = propertyMap.try_emplace(name, property).second;
97 <<
"cannot import loop properties with duplicated names " << name;
/// Returns the property named `name` and erases it from `propertyMap`, so
/// that any entries still left after all lookups can be reported as unknown.
/// NOTE(review): the return-type line and the not-found return are missing
/// from this listing; presumably the latter returns nullptr — confirm.
104LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
105 auto it = propertyMap.find(name);
106 if (it == propertyMap.end())
108 const llvm::MDNode *
property = it->getValue();
109 propertyMap.erase(it);
/// Looks up a unit property, i.e. a node holding nothing but its name.
/// Warns if the node carries any operand besides the name.
/// NOTE(review): the guard before `return BoolAttr(nullptr)` (presumably the
/// property-absent check) and the success return are missing from this
/// listing; confirm what is returned when the property is present.
113FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) {
114 const llvm::MDNode *
property = lookupAndEraseProperty(name);
116 return BoolAttr(
nullptr);
// Operand 0 is the name; a unit property must have nothing else.
118 if (property->getNumOperands() != 1)
120 <<
"expected metadata node " << name <<
" to hold no value";
/// Merges a pair of unit properties (`enableName`, `disableName`) into a
/// single BoolAttr. Both names are looked up, and thereby consumed, before
/// either result is inspected; having both present is diagnosed as the two
/// being mutually exclusive.
/// NOTE(review): lines 129-131 and 136-141 are missing from this listing,
/// including how failures are propagated and how `negated` is applied.
125FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode(
126 StringRef enableName, StringRef disableName,
bool negated) {
127 auto enable = lookupUnitNode(enableName);
128 auto disable = lookupUnitNode(disableName);
// A loop cannot carry both the enable and the disable property.
132 if (*enable && *disable)
134 <<
"expected metadata nodes " << enableName <<
" and " << disableName
135 <<
" to be mutually exclusive.";
142 return BoolAttr(
nullptr);
/// Looks up a property whose operand 1 is an i1 constant. Warns unless the
/// node has exactly two operands and operand 1 is an i1 ConstantInt.
/// Otherwise returns the constant XOR-ed with `negated`, which lets callers
/// read an "enable" flag into a "disable" field.
/// NOTE(review): the `negated` parameter line and the guard before
/// `return BoolAttr(nullptr)` (presumably the property-absent check) are
/// missing from this listing.
145FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
147 const llvm::MDNode *
property = lookupAndEraseProperty(name);
149 return BoolAttr(
nullptr);
151 auto emitNodeWarning = [&]() {
153 <<
"expected metadata node " << name <<
" to hold a boolean value";
156 if (property->getNumOperands() != 2)
157 return emitNodeWarning();
// Operand 1 must be a constant integer of bit width 1.
158 llvm::ConstantInt *val =
159 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
160 if (!val || val->getBitWidth() != 1)
161 return emitNodeWarning();
// getLimitedValue(1) yields 0 or 1; XOR with `negated` optionally inverts it.
163 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated);
/// Looks up a property whose operand 1 is an i32 constant and returns it as
/// a BoolAttr. Warns unless the node has exactly two operands and operand 1
/// is an i32 ConstantInt.
/// NOTE(review): the return-type line and the guard before
/// `return BoolAttr(nullptr)` (presumably the property-absent check) are
/// missing from this listing.
167LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
168 const llvm::MDNode *
property = lookupAndEraseProperty(name);
170 return BoolAttr(
nullptr);
172 auto emitNodeWarning = [&]() {
174 <<
"expected metadata node " << name <<
" to hold an integer value";
177 if (property->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt *val =
180 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
181 if (!val || val->getBitWidth() != 32)
182 return emitNodeWarning();
// getLimitedValue(1) clamps the value to at most 1, so any non-zero integer
// becomes `true`.
184 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1));
/// Looks up a property whose operand 1 is an i32 constant and returns it as
/// an i32 IntegerAttr. Warns unless the node has exactly two operands and
/// operand 1 is an i32 ConstantInt.
/// NOTE(review): the guard before `return IntegerAttr(nullptr)` (presumably
/// the property-absent check) is missing from this listing.
187FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
188 const llvm::MDNode *
property = lookupAndEraseProperty(name);
190 return IntegerAttr(
nullptr);
192 auto emitNodeWarning = [&]() {
194 <<
"expected metadata node " << name <<
" to hold an i32 value";
197 if (property->getNumOperands() != 2)
198 return emitNodeWarning();
// Only 32-bit constant integers are accepted.
200 llvm::ConstantInt *val =
201 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
202 if (!val || val->getBitWidth() != 32)
203 return emitNodeWarning();
205 return IntegerAttr::get(IntegerType::get(ctx, 32),
206 val->getValue().getLimitedValue());
/// Looks up a property whose operand 1 is itself an MDNode. Warns unless the
/// property has exactly two operands and operand 1 is an MDNode.
/// NOTE(review): the property-absent path (lines 211-213), the check before
/// the second warning and the final return are missing from this listing.
209FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) {
210 const llvm::MDNode *
property = lookupAndEraseProperty(name);
214 auto emitNodeWarning = [&]() {
216 <<
"expected metadata node " << name <<
" to hold an MDNode";
219 if (property->getNumOperands() != 2)
220 return emitNodeWarning();
// NOTE(review): this local shadows the member `node`; a distinct name would
// make the code easier to follow.
222 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(1));
224 return emitNodeWarning();
/// Looks up a property that lists one or more MDNodes in operands 1..N-1.
/// Warns if there is no payload operand or if any payload operand is not an
/// MDNode.
/// NOTE(review): the property-absent path, the check before the in-loop
/// warning, the insertion into `res` and the return are missing from this
/// listing.
229FailureOr<SmallVector<llvm::MDNode *>>
230LoopMetadataConversion::lookupMDNodes(StringRef name) {
231 const llvm::MDNode *
property = lookupAndEraseProperty(name);
232 SmallVector<llvm::MDNode *> res;
236 auto emitNodeWarning = [&]() {
237 return emitWarning(loc) <<
"expected metadata node " << name
238 <<
" to hold one or multiple MDNodes";
// Operand 0 is the property name, so at least one more operand is required.
241 if (property->getNumOperands() < 2)
242 return emitNodeWarning();
244 for (
unsigned i = 1, e = property->getNumOperands(); i < e; ++i) {
// NOTE(review): this local shadows the member `node`.
245 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(i));
247 return emitNodeWarning();
/// Looks up a follow-up property (an MDNode payload, via lookupMDNode) and
/// returns a null LoopAnnotationAttr when no node was found.
/// NOTE(review): the failure check (lines 257-258) and the handling of a
/// present node (line 261 onward) are missing from this listing.
254FailureOr<LoopAnnotationAttr>
255LoopMetadataConversion::lookupFollowupNode(StringRef name) {
// NOTE(review): this local shadows the member `node`.
256 auto node = lookupMDNode(name);
259 if (*node ==
nullptr)
260 return LoopAnnotationAttr(
nullptr);
/// Creates an attribute of type T from `args` (FailureOr values that are
/// dereferenced when passed to T::get) only if every argument conversion
/// succeeded.
/// NOTE(review): the signature (line 275) and lines 277-283 are missing from
/// this listing; the original may also return a null T when every argument
/// is empty — confirm.
274template <
typename T,
typename... P>
// Fold expression: true if any of the argument conversions failed.
276 bool anyFailed = (failed(args) || ...);
284 return T::get(ctx, *args...);
/// Gathers the llvm.loop.vectorize.* properties into a LoopVectorizeAttr.
/// NOTE(review): `enable` is read with negated=true, so it actually holds the
/// inverse of llvm.loop.vectorize.enable (a "disable" flag); renaming it
/// would avoid confusion. Lines 301-302 (including the start of the return
/// statement) are missing from this listing.
287FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() {
288 FailureOr<BoolAttr> enable =
289 lookupBoolNode(
"llvm.loop.vectorize.enable",
true);
290 FailureOr<BoolAttr> predicateEnable =
291 lookupBoolNode(
"llvm.loop.vectorize.predicate.enable");
292 FailureOr<BoolAttr> scalableEnable =
293 lookupBoolNode(
"llvm.loop.vectorize.scalable.enable");
294 FailureOr<IntegerAttr> width = lookupIntNode(
"llvm.loop.vectorize.width");
295 FailureOr<LoopAnnotationAttr> followupVec =
296 lookupFollowupNode(
"llvm.loop.vectorize.followup_vectorized");
297 FailureOr<LoopAnnotationAttr> followupEpi =
298 lookupFollowupNode(
"llvm.loop.vectorize.followup_epilogue");
299 FailureOr<LoopAnnotationAttr> followupAll =
300 lookupFollowupNode(
"llvm.loop.vectorize.followup_all");
303 scalableEnable, width, followupVec,
304 followupEpi, followupAll);
/// Reads llvm.loop.interleave.count into a LoopInterleaveAttr.
/// NOTE(review): the remainder of this function (line 309 onward) is missing
/// from this listing.
307FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() {
308 FailureOr<IntegerAttr> count = lookupIntNode(
"llvm.loop.interleave.count");
/// Gathers the llvm.loop.unroll.* properties into a LoopUnrollAttr. The
/// enable/disable unit pair is merged into a single `disable` flag
/// (lookupBooleanUnitNode with negated=true).
/// NOTE(review): lines 325-326 (including the start of the return statement)
/// are missing from this listing.
312FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
313 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
314 "llvm.loop.unroll.enable",
"llvm.loop.unroll.disable",
true);
315 FailureOr<IntegerAttr> count = lookupIntNode(
"llvm.loop.unroll.count");
316 FailureOr<BoolAttr> runtimeDisable =
317 lookupUnitNode(
"llvm.loop.unroll.runtime.disable");
318 FailureOr<BoolAttr> full = lookupUnitNode(
"llvm.loop.unroll.full");
319 FailureOr<LoopAnnotationAttr> followupUnrolled =
320 lookupFollowupNode(
"llvm.loop.unroll.followup_unrolled");
321 FailureOr<LoopAnnotationAttr> followupRemainder =
322 lookupFollowupNode(
"llvm.loop.unroll.followup_remainder");
323 FailureOr<LoopAnnotationAttr> followupAll =
324 lookupFollowupNode(
"llvm.loop.unroll.followup_all");
327 full, followupUnrolled,
328 followupRemainder, followupAll);
/// Gathers the llvm.loop.unroll_and_jam.* properties into a
/// LoopUnrollAndJamAttr, merging the enable/disable unit pair into `disable`.
/// NOTE(review): line 335 (presumably the rest of the lookupBooleanUnitNode
/// call) and line 348 (the start of the return statement) are missing from
/// this listing.
331FailureOr<LoopUnrollAndJamAttr>
332LoopMetadataConversion::convertUnrollAndJamAttr() {
333 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
334 "llvm.loop.unroll_and_jam.enable",
"llvm.loop.unroll_and_jam.disable",
336 FailureOr<IntegerAttr> count =
337 lookupIntNode(
"llvm.loop.unroll_and_jam.count");
338 FailureOr<LoopAnnotationAttr> followupOuter =
339 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_outer");
340 FailureOr<LoopAnnotationAttr> followupInner =
341 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_inner");
342 FailureOr<LoopAnnotationAttr> followupRemainderOuter =
343 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_outer");
344 FailureOr<LoopAnnotationAttr> followupRemainderInner =
345 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_inner");
346 FailureOr<LoopAnnotationAttr> followupAll =
347 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_all");
349 ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
350 followupRemainderInner, followupAll);
/// Reads the LICM-related unit properties into a LoopLICMAttr.
/// NOTE(review): unlike the other keys, "llvm.licm.disable" has no "loop."
/// infix; this appears to match LLVM's LangRef spelling — confirm before
/// changing it. Lines 357-359 are missing from this listing.
353FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() {
354 FailureOr<BoolAttr> disable = lookupUnitNode(
"llvm.licm.disable");
355 FailureOr<BoolAttr> versioningDisable =
356 lookupUnitNode(
"llvm.loop.licm_versioning.disable");
/// Gathers the llvm.loop.distribute.* properties into a LoopDistributeAttr.
/// llvm.loop.distribute.enable is read with negated=true, so `disable` holds
/// its inverse.
/// NOTE(review): lines 371-372 (the start of the return statement) are
/// missing from this listing.
360FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() {
361 FailureOr<BoolAttr> disable =
362 lookupBoolNode(
"llvm.loop.distribute.enable",
true);
363 FailureOr<LoopAnnotationAttr> followupCoincident =
364 lookupFollowupNode(
"llvm.loop.distribute.followup_coincident");
365 FailureOr<LoopAnnotationAttr> followupSequential =
366 lookupFollowupNode(
"llvm.loop.distribute.followup_sequential");
367 FailureOr<LoopAnnotationAttr> followupFallback =
368 lookupFollowupNode(
"llvm.loop.distribute.followup_fallback");
369 FailureOr<LoopAnnotationAttr> followupAll =
370 lookupFollowupNode(
"llvm.loop.distribute.followup_all");
373 followupFallback, followupAll);
/// Reads the llvm.loop.pipeline.* properties into a LoopPipelineAttr. The
/// boolean key here is already a "disable" flag, so it is read without
/// negation.
/// NOTE(review): the remainder of this function is missing from this listing.
376FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() {
377 FailureOr<BoolAttr> disable = lookupBoolNode(
"llvm.loop.pipeline.disable");
378 FailureOr<IntegerAttr> initiationinterval =
379 lookupIntNode(
"llvm.loop.pipeline.initiationinterval");
/// Reads llvm.loop.peeled.count into a LoopPeeledAttr.
/// NOTE(review): the remainder of this function is missing from this listing.
383FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() {
384 FailureOr<IntegerAttr> count = lookupIntNode(
"llvm.loop.peeled.count");
/// Reads the llvm.loop.unswitch.partial.disable unit property into a
/// LoopUnswitchAttr.
/// NOTE(review): the remainder of this function is missing from this listing.
388FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() {
389 FailureOr<BoolAttr> partialDisable =
390 lookupUnitNode(
"llvm.loop.unswitch.partial.disable");
/// Resolves the access-group nodes listed by llvm.loop.parallel_accesses
/// into AccessGroupAttrs, appended in list order. A node whose access groups
/// cannot be looked up triggers the warning "could not lookup access group".
/// NOTE(review): several lines are missing from this listing, including the
/// checks after lookupMDNodes, the access-group lookup call itself, and the
/// handling that follows the warning.
394FailureOr<SmallVector<AccessGroupAttr>>
395LoopMetadataConversion::convertParallelAccesses() {
396 FailureOr<SmallVector<llvm::MDNode *>> nodes =
397 lookupMDNodes(
"llvm.loop.parallel_accesses");
400 SmallVector<AccessGroupAttr> refs;
// NOTE(review): the loop variable shadows the member `node`.
401 for (llvm::MDNode *node : *nodes) {
402 FailureOr<SmallVector<AccessGroupAttr>> accessGroups =
404 if (
failed(accessGroups)) {
405 emitWarning(loc) <<
"could not lookup access group";
408 llvm::append_range(refs, *accessGroups);
/// Converts the loop's start debug location into a FusedLoc. dyn_cast is
/// used, so the result is null if the translated location is not a FusedLoc.
/// NOTE(review): the empty-case return (line 415) and the translated
/// location expression (line 417) are missing. Presumably a null FusedLoc is
/// returned when `locations` is empty and locations[0] is translated
/// otherwise — confirm.
413FusedLoc LoopMetadataConversion::convertStartLoc() {
414 if (locations.empty())
416 return dyn_cast<FusedLoc>(
/// Converts the loop's end debug location into a FusedLoc. More than two
/// collected debug locations are rejected with the diagnostic "expected loop
/// metadata to have at most two DILocations".
/// NOTE(review): lines 422, 424 and 427 onward are missing. Presumably fewer
/// than two locations yield a null FusedLoc and the end location is
/// locations[1] — confirm.
420FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() {
421 if (locations.size() < 2)
423 if (locations.size() > 2)
425 <<
"expected loop metadata to have at most two DILocations";
426 return dyn_cast<FusedLoc>(
/// Converts the loop metadata node into a LoopAnnotationAttr.
/// Every property lookup runs before the leftover check because each lookup
/// erases the property it consumes. Anything still in `propertyMap` is then
/// reported with an "unknown loop annotation" warning.
/// NOTE(review): several lines are missing from this listing, e.g. what is
/// returned when initConversionState fails, what follows the unknown-property
/// warnings, and the start and end of the final attribute construction.
430LoopAnnotationAttr LoopMetadataConversion::convert() {
431 if (
failed(initConversionState()))
434 FailureOr<BoolAttr> disableNonForced =
435 lookupUnitNode(
"llvm.loop.disable_nonforced");
436 FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr();
437 FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr();
438 FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr();
439 FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr();
440 FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr();
441 FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
442 FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
443 FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr();
444 FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr();
445 FailureOr<BoolAttr> mustProgress = lookupUnitNode(
"llvm.loop.mustprogress");
446 FailureOr<BoolAttr> isVectorized =
447 lookupIntNodeAsBoolAttr(
"llvm.loop.isvectorized");
448 FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses =
449 convertParallelAccesses();
// Report every property that none of the lookups above consumed.
452 if (!propertyMap.empty()) {
453 for (
auto name : propertyMap.keys())
454 emitWarning(loc) <<
"unknown loop annotation " << name;
// The start/end locations come from the debug locations collected by
// initConversionState.
458 FailureOr<FusedLoc> startLoc = convertStartLoc();
459 FailureOr<FusedLoc> endLoc = convertEndLoc();
462 ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
463 unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
464 unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
// NOTE(review): the header of this function and lines 479, 481 and 483
// onward are missing from this listing.
// Reuse the attribute already created for this metadata node, if any.
476 auto it = loopMetadataMapping.find(node);
477 if (it != loopMetadataMapping.end())
478 return it->getSecond();
// Otherwise convert the node and record the result through mapLoopMetadata
// (presumably populating loopMetadataMapping — confirm).
480 LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *
this).convert();
482 mapLoopMetadata(node, attr);
// NOTE(review): the header of this function and several interior lines are
// missing from this listing.
// An operand-less node is itself an access group; otherwise each of its
// operands is treated as an access group.
490 if (!node->getNumOperands())
491 accessGroups.push_back(node);
492 for (
const llvm::MDOperand &operand : node->operands()) {
// NOTE(review): the push below re-casts `operand` rather than reusing
// `childNode`; the lines between them (494-495) are missing here.
493 auto *childNode = dyn_cast<llvm::MDNode>(operand);
496 accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
// Create one AccessGroupAttr per access group. A valid access group node
// must be empty and distinct. Groups that already have a mapping are
// presumably skipped (the line after the count() check is missing).
500 for (
const llvm::MDNode *accessGroup : accessGroups) {
501 if (accessGroupMapping.count(accessGroup))
504 if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
506 <<
"expected an access group node to be empty and distinct";
509 accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>();
/// Returns the AccessGroupAttrs mapped to the access group node(s) described
/// by `node`. A null entry signals an unmapped access group, which is checked
/// at the end.
/// NOTE(review): the signature and the return statements are missing from
/// this listing; confirm the failure path.
514FailureOr<SmallVector<AccessGroupAttr>>
// An operand-less node is itself an access group; otherwise each operand is
// one.
519 if (!node->getNumOperands())
520 accessGroups.push_back(accessGroupMapping.lookup(node));
521 for (
const llvm::MDOperand &operand : node->operands()) {
// NOTE(review): this local shadows the `node` parameter; consider renaming.
522 auto *node = cast<llvm::MDNode>(operand.get());
523 accessGroups.push_back(accessGroupMapping.lookup(node));
// lookup() yields a null attribute for nodes with no entry in
// accessGroupMapping, so a nullptr entry means an unmapped access group.
526 if (llvm::is_contained(accessGroups,
nullptr))
static T createIfNonNull(MLIRContext *ctx, const P &...args)
Helper function that only creates an attribute of type T if all argument conversions were successful...
static bool isEmptyOrNull(const Attribute attr)
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
Location translateLoc(llvm::DILocation *loc)
Translates the debug location.
LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, Location loc)
LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc)
Converts all LLVM access groups starting from node to MLIR access group attributes.
ModuleImport & moduleImport
The ModuleImport owning this instance.
FailureOr< SmallVector< AccessGroupAttr > > lookupAccessGroupAttrs(const llvm::MDNode *node) const
Returns the access group attributes that map to the access group nodes starting from the access group ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.