10 #include "llvm/IR/Constants.h"
// Conversion state and helper routines for turning one loop metadata node
// into a LoopAnnotationAttr.
// NOTE(review): this listing omits several lines of the struct (part of the
// constructor, remaining members and the closing brace).
18 struct LoopMetadataConversion {
// Constructor: stores the metadata node, the location and the importer.
19 LoopMetadataConversion(
const llvm::MDNode *node,
Location loc,
21 : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
// Runs the whole conversion and returns the resulting loop annotation.
24 LoopAnnotationAttr convert();
// Validates the node and fills `propertyMap` and the debug location list.
27 LogicalResult initConversionState();
// Finds the property called `name` in `propertyMap` and erases its entry.
30 const llvm::MDNode *lookupAndEraseProperty(StringRef name);
// Typed lookup helpers. Each one consumes the named property (directly or
// indirectly through lookupAndEraseProperty) and checks its operands.
35 FailureOr<BoolAttr> lookupUnitNode(StringRef name);
36 FailureOr<BoolAttr> lookupBoolNode(StringRef name,
bool negated =
false);
37 FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
38 FailureOr<IntegerAttr> lookupIntNode(StringRef name);
39 FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
40 FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
41 FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name);
42 FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName,
43 StringRef disableName,
44 bool negated =
false);
// Converters for the individual groups of loop properties.
47 FailureOr<LoopVectorizeAttr> convertVectorizeAttr();
48 FailureOr<LoopInterleaveAttr> convertInterleaveAttr();
49 FailureOr<LoopUnrollAttr> convertUnrollAttr();
50 FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr();
51 FailureOr<LoopLICMAttr> convertLICMAttr();
52 FailureOr<LoopDistributeAttr> convertDistributeAttr();
53 FailureOr<LoopPipelineAttr> convertPipelineAttr();
54 FailureOr<LoopPeeledAttr> convertPeeledAttr();
55 FailureOr<LoopUnswitchAttr> convertUnswitchAttr();
56 FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses();
// Translates the second recorded debug location, if there is one.
58 FailureOr<FusedLoc> convertEndLoc();
// Maps property names to their metadata nodes; entries are erased as the
// lookup helpers consume them.
61 llvm::StringMap<const llvm::MDNode *> propertyMap;
// The loop metadata node being converted.
62 const llvm::MDNode *node;
// Validates the loop metadata node and records its contents: DILocation
// operands are appended to `locations`, every other operand must be a named
// MDNode and is stored in `propertyMap` under its name.
// NOTE(review): several guard/return lines are missing from this listing, so
// the conditions attached to some of the warnings below are not visible.
69 LogicalResult LoopMetadataConversion::initConversionState() {
// The node needs at least one operand, and operand 0 has to be the node
// itself (self-reference check).
71 if (node->getNumOperands() == 0 ||
72 dyn_cast<llvm::MDNode>(node->getOperand(0)) != node)
// Walk all operands except the first one.
75 for (
const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) {
// Debug locations are collected separately from the named properties.
76 if (
auto *diLoc = dyn_cast<llvm::DILocation>(operand)) {
77 locations.push_back(diLoc);
81 auto *
property = dyn_cast<llvm::MDNode>(operand);
// Presumably reached when the operand is not an MDNode -- the guard line is
// not part of this listing.
83 return emitWarning(loc) <<
"expected all loop properties to be either "
84 "debug locations or metadata nodes";
86 if (property->getNumOperands() == 0)
87 return emitWarning(loc) <<
"cannot import empty loop property";
// The first operand of a property is expected to be its name string.
89 auto *nameNode = dyn_cast<llvm::MDString>(property->getOperand(0));
91 return emitWarning(loc) <<
"cannot import loop property without a name";
92 StringRef name = nameNode->getString();
// try_emplace does not overwrite: `succ` is false if the name already exists.
94 bool succ = propertyMap.try_emplace(name, property).second;
97 <<
"cannot import loop properties with duplicated names " << name;
// Looks up `name` in `propertyMap`, reads the stored metadata node and erases
// the entry so that the property counts as consumed.
// NOTE(review): the return statements (not-found case and final result) are
// not part of this listing.
104 LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
105 auto it = propertyMap.find(name);
106 if (it == propertyMap.end())
// Read the value before erasing; the iterator is invalid afterwards.
108 const llvm::MDNode *
property = it->getValue();
109 propertyMap.erase(it);
// Consumes the property `name` and checks that it is a unit property, i.e.
// that its node has exactly one operand (only the name, no value).
// NOTE(review): the handling of a missing property and the returned
// attribute are not visible in this listing.
113 FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) {
114 const llvm::MDNode *
property = lookupAndEraseProperty(name);
// Anything besides the name operand means the node carries a value.
118 if (property->getNumOperands() != 1)
120 <<
"expected metadata node " << name <<
" to hold no value";
// Consumes a pair of unit properties (`enableName` / `disableName`) and
// diagnoses the case where both are present at once.
// NOTE(review): the lines that build the result (and apply `negated`) are not
// part of this listing.
125 FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode(
126 StringRef enableName, StringRef disableName,
bool negated) {
// Both lookups run unconditionally so that both properties are consumed.
127 auto enable = lookupUnitNode(enableName);
128 auto disable = lookupUnitNode(disableName);
129 if (failed(enable) || failed(disable))
// Both attributes being non-null means both properties were present.
132 if (*enable && *disable)
134 <<
"expected metadata nodes " << enableName <<
" and " << disableName
135 <<
" to be mutually exclusive.";
// Consumes the property `name` and converts its value to a BoolAttr. The
// node must have exactly two operands (name and value) and the value must be
// a 1-bit constant integer; the result is XOR-ed with `negated`.
// NOTE(review): the second signature line and the handling of a missing
// property are not part of this listing.
145 FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
147 const llvm::MDNode *
property = lookupAndEraseProperty(name);
// Shared diagnostic for every malformed-node case below.
151 auto emitNodeWarning = [&]() {
153 <<
"expected metadata node " << name <<
" to hold a boolean value";
156 if (property->getNumOperands() != 2)
157 return emitNodeWarning();
158 llvm::ConstantInt *val =
159 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
// Only i1 constants are accepted as booleans.
160 if (!val || val->getBitWidth() != 1)
161 return emitNodeWarning();
// `negated` flips the value read from the metadata.
163 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated);
// Consumes the property `name`, which must hold a 32-bit constant integer,
// and returns it as a BoolAttr.
// NOTE(review): the return type line and the handling of a missing property
// are not part of this listing.
167 LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
168 const llvm::MDNode *
property = lookupAndEraseProperty(name);
// Shared diagnostic for every malformed-node case below.
172 auto emitNodeWarning = [&]() {
174 <<
"expected metadata node " << name <<
" to hold an integer value";
177 if (property->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt *val =
180 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
181 if (!val || val->getBitWidth() != 32)
182 return emitNodeWarning();
// getLimitedValue(1) caps the value at 1, so any non-zero integer yields true.
184 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1));
// Consumes the property `name` and converts its value to an IntegerAttr. The
// node must have exactly two operands and the value must be a 32-bit
// constant integer.
// NOTE(review): some lines (guards and the start of the final return) are
// not part of this listing.
187 FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
188 const llvm::MDNode *
property = lookupAndEraseProperty(name);
// Null attribute result; presumably for a property that is absent -- the
// guard line is not visible here.
190 return IntegerAttr(
nullptr);
// Shared diagnostic for every malformed-node case below.
192 auto emitNodeWarning = [&]() {
194 <<
"expected metadata node " << name <<
" to hold an i32 value";
197 if (property->getNumOperands() != 2)
198 return emitNodeWarning();
200 llvm::ConstantInt *val =
201 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
202 if (!val || val->getBitWidth() != 32)
203 return emitNodeWarning();
// Tail of the final return expression: the integer value of the constant.
206 val->getValue().getLimitedValue());
// Consumes the property `name` and returns its single value operand as an
// MDNode. The node must have exactly two operands (name and value).
// NOTE(review): the handling of a missing property, the null check after the
// cast and the final return are not part of this listing.
209 FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) {
210 const llvm::MDNode *
property = lookupAndEraseProperty(name);
// Shared diagnostic for every malformed-node case below.
214 auto emitNodeWarning = [&]() {
216 <<
"expected metadata node " << name <<
" to hold an MDNode";
219 if (property->getNumOperands() != 2)
220 return emitNodeWarning();
// This local `node` shadows the member of the same name.
222 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(1));
224 return emitNodeWarning();
// Consumes the property `name` and collects all of its value operands as
// MDNodes. The node must have at least two operands (name plus one value).
// NOTE(review): the handling of a missing property, the result container and
// the final return are not part of this listing.
229 FailureOr<SmallVector<llvm::MDNode *>>
230 LoopMetadataConversion::lookupMDNodes(StringRef name) {
231 const llvm::MDNode *
property = lookupAndEraseProperty(name);
// Shared diagnostic for every malformed-node case below.
236 auto emitNodeWarning = [&]() {
237 return emitWarning(loc) <<
"expected metadata node " << name
238 <<
" to hold one or multiple MDNodes";
241 if (property->getNumOperands() < 2)
242 return emitNodeWarning();
// Operand 0 is the property name, so the values start at index 1.
244 for (
unsigned i = 1, e = property->getNumOperands(); i < e; ++i) {
245 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(i));
247 return emitNodeWarning();
// Consumes the followup property `name` and translates the metadata node it
// refers to into a LoopAnnotationAttr through the importer.
// NOTE(review): the failure check on the lookup result is not part of this
// listing.
254 FailureOr<LoopAnnotationAttr>
255 LoopMetadataConversion::lookupFollowupNode(StringRef name) {
256 auto node = lookupMDNode(name);
// A null node yields a null attribute instead of a translation.
259 if (*node ==
nullptr)
260 return LoopAnnotationAttr(
nullptr);
// The referenced node is itself translated as a loop annotation.
262 return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
// Fragments of two template helpers; their signatures and most of their
// bodies are missing from this listing.
// NOTE(review): the first template header presumably belongs to a separate
// helper whose body is not shown -- confirm against the full file.
267 template <
typename T>
274 template <
typename T,
typename... P>
// Fold expression: true as soon as any of the argument results failed.
276 bool anyFailed = (failed(args) || ...);
// Unwraps every argument result and builds the attribute of type T.
284 return T::get(ctx, *args...);
// Consumes all "llvm.loop.vectorize.*" properties handled here and builds a
// LoopVectorizeAttr from them via createIfNonNull.
287 FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() {
// Looked up with negated = true, so the attribute holds the inverse of the
// metadata's "enable" value even though the local is called `enable`.
288 FailureOr<BoolAttr> enable =
289 lookupBoolNode(
"llvm.loop.vectorize.enable",
true);
290 FailureOr<BoolAttr> predicateEnable =
291 lookupBoolNode(
"llvm.loop.vectorize.predicate.enable");
292 FailureOr<BoolAttr> scalableEnable =
293 lookupBoolNode(
"llvm.loop.vectorize.scalable.enable");
294 FailureOr<IntegerAttr> width = lookupIntNode(
"llvm.loop.vectorize.width");
// Followup properties refer to further loop metadata nodes.
295 FailureOr<LoopAnnotationAttr> followupVec =
296 lookupFollowupNode(
"llvm.loop.vectorize.followup_vectorized");
297 FailureOr<LoopAnnotationAttr> followupEpi =
298 lookupFollowupNode(
"llvm.loop.vectorize.followup_epilogue");
299 FailureOr<LoopAnnotationAttr> followupAll =
300 lookupFollowupNode(
"llvm.loop.vectorize.followup_all");
302 return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable,
303 scalableEnable, width, followupVec,
304 followupEpi, followupAll);
// Consumes "llvm.loop.interleave.count" and builds a LoopInterleaveAttr from
// it via createIfNonNull.
307 FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() {
308 FailureOr<IntegerAttr> count = lookupIntNode(
"llvm.loop.interleave.count");
309 return createIfNonNull<LoopInterleaveAttr>(ctx, count);
// Consumes all "llvm.loop.unroll.*" properties handled here and builds a
// LoopUnrollAttr from them via createIfNonNull.
312 FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
// The enable/disable unit properties are combined into one attribute; the
// call passes negated = true.
313 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
314 "llvm.loop.unroll.enable",
"llvm.loop.unroll.disable",
true);
315 FailureOr<IntegerAttr> count = lookupIntNode(
"llvm.loop.unroll.count");
316 FailureOr<BoolAttr> runtimeDisable =
317 lookupUnitNode(
"llvm.loop.unroll.runtime.disable");
318 FailureOr<BoolAttr> full = lookupUnitNode(
"llvm.loop.unroll.full");
// Followup properties refer to further loop metadata nodes.
319 FailureOr<LoopAnnotationAttr> followupUnrolled =
320 lookupFollowupNode(
"llvm.loop.unroll.followup_unrolled");
321 FailureOr<LoopAnnotationAttr> followupRemainder =
322 lookupFollowupNode(
"llvm.loop.unroll.followup_remainder");
323 FailureOr<LoopAnnotationAttr> followupAll =
324 lookupFollowupNode(
"llvm.loop.unroll.followup_all");
326 return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
327 full, followupUnrolled,
328 followupRemainder, followupAll);
// Consumes all "llvm.loop.unroll_and_jam.*" properties handled here and
// builds a LoopUnrollAndJamAttr from them via createIfNonNull.
331 FailureOr<LoopUnrollAndJamAttr>
332 LoopMetadataConversion::convertUnrollAndJamAttr() {
// NOTE(review): the last argument line of this call is missing from the
// listing.
333 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
334 "llvm.loop.unroll_and_jam.enable",
"llvm.loop.unroll_and_jam.disable",
336 FailureOr<IntegerAttr> count =
337 lookupIntNode(
"llvm.loop.unroll_and_jam.count");
// Followup properties refer to further loop metadata nodes.
338 FailureOr<LoopAnnotationAttr> followupOuter =
339 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_outer");
340 FailureOr<LoopAnnotationAttr> followupInner =
341 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_inner");
342 FailureOr<LoopAnnotationAttr> followupRemainderOuter =
343 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_outer");
344 FailureOr<LoopAnnotationAttr> followupRemainderInner =
345 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_remainder_inner");
346 FailureOr<LoopAnnotationAttr> followupAll =
347 lookupFollowupNode(
"llvm.loop.unroll_and_jam.followup_all");
348 return createIfNonNull<LoopUnrollAndJamAttr>(
349 ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
350 followupRemainderInner, followupAll);
// Consumes the two LICM-related unit properties and builds a LoopLICMAttr
// from them via createIfNonNull.
353 FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() {
// Note that this property name has no "loop." component, unlike the others.
354 FailureOr<BoolAttr> disable = lookupUnitNode(
"llvm.licm.disable");
355 FailureOr<BoolAttr> versioningDisable =
356 lookupUnitNode(
"llvm.loop.licm_versioning.disable");
357 return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable);
// Consumes all "llvm.loop.distribute.*" properties handled here and builds a
// LoopDistributeAttr from them via createIfNonNull.
360 FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() {
// Looked up with negated = true: the "enable" metadata value is stored
// inverted, as a disable flag.
361 FailureOr<BoolAttr> disable =
362 lookupBoolNode(
"llvm.loop.distribute.enable",
true);
// Followup properties refer to further loop metadata nodes.
363 FailureOr<LoopAnnotationAttr> followupCoincident =
364 lookupFollowupNode(
"llvm.loop.distribute.followup_coincident");
365 FailureOr<LoopAnnotationAttr> followupSequential =
366 lookupFollowupNode(
"llvm.loop.distribute.followup_sequential");
367 FailureOr<LoopAnnotationAttr> followupFallback =
368 lookupFollowupNode(
"llvm.loop.distribute.followup_fallback");
369 FailureOr<LoopAnnotationAttr> followupAll =
370 lookupFollowupNode(
"llvm.loop.distribute.followup_all");
// NOTE(review): one argument line of this call is missing from the listing.
371 return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident,
373 followupFallback, followupAll);
// Consumes the "llvm.loop.pipeline.*" properties handled here and builds a
// LoopPipelineAttr from them via createIfNonNull.
376 FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() {
// The metadata is already a "disable" flag, so no negation is requested.
377 FailureOr<BoolAttr> disable = lookupBoolNode(
"llvm.loop.pipeline.disable");
378 FailureOr<IntegerAttr> initiationinterval =
379 lookupIntNode(
"llvm.loop.pipeline.initiationinterval");
380 return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval);
// Consumes "llvm.loop.peeled.count" and builds a LoopPeeledAttr from it via
// createIfNonNull.
383 FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() {
384 FailureOr<IntegerAttr> count = lookupIntNode(
"llvm.loop.peeled.count");
385 return createIfNonNull<LoopPeeledAttr>(ctx, count);
// Consumes the unit property "llvm.loop.unswitch.partial.disable" and builds
// a LoopUnswitchAttr from it via createIfNonNull.
388 FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() {
389 FailureOr<BoolAttr> partialDisable =
390 lookupUnitNode(
"llvm.loop.unswitch.partial.disable");
391 return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable);
// Consumes "llvm.loop.parallel_accesses" and resolves every referenced
// metadata node to access group attributes through the importer, collecting
// them in `refs`.
// NOTE(review): the failure check on `nodes`, the declaration of `refs` and
// the return statements are not part of this listing.
394 FailureOr<SmallVector<AccessGroupAttr>>
395 LoopMetadataConversion::convertParallelAccesses() {
396 FailureOr<SmallVector<llvm::MDNode *>> nodes =
397 lookupMDNodes(
"llvm.loop.parallel_accesses");
401 for (llvm::MDNode *node : *nodes) {
402 FailureOr<SmallVector<AccessGroupAttr>> accessGroups =
403 loopAnnotationImporter.lookupAccessGroupAttrs(node);
// A node without a matching access group attribute is reported.
404 if (failed(accessGroups)) {
405 emitWarning(loc) <<
"could not lookup access group";
408 llvm::append_range(refs, *accessGroups);
// Translates the first recorded debug location into an MLIR location and
// returns it as a FusedLoc (dyn_cast, so the result is null if the translated
// location is of a different kind).
// NOTE(review): the return for the empty case is not part of this listing.
413 FusedLoc LoopMetadataConversion::convertStartLoc() {
414 if (locations.empty())
416 return dyn_cast<FusedLoc>(
417 loopAnnotationImporter.moduleImport.translateLoc(locations[0]));
// Translates the second recorded debug location into an MLIR location and
// returns it as a FusedLoc. More than two recorded locations are diagnosed.
// NOTE(review): the return for fewer than two locations and the start of the
// diagnostic statement are not part of this listing.
420 FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() {
421 if (locations.size() < 2)
423 if (locations.size() > 2)
425 <<
"expected loop metadata to have at most two DILocations";
// Index 1 is the second location; index 0 is handled by convertStartLoc.
426 return dyn_cast<FusedLoc>(
427 loopAnnotationImporter.moduleImport.translateLoc(locations[1]));
// Entry point of the conversion: initializes the property map, consumes every
// supported loop property, warns about the ones left over and assembles the
// final LoopAnnotationAttr via createIfNonNull.
// NOTE(review): the early returns and the last argument line of the final
// call are not part of this listing.
430 LoopAnnotationAttr LoopMetadataConversion::convert() {
431 if (failed(initConversionState()))
// Each lookup/convert call below erases the properties it handles from
// `propertyMap`.
434 FailureOr<BoolAttr> disableNonForced =
435 lookupUnitNode(
"llvm.loop.disable_nonforced");
436 FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr();
437 FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr();
438 FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr();
439 FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr();
440 FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr();
441 FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
442 FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
443 FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr();
444 FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr();
445 FailureOr<BoolAttr> mustProgress = lookupUnitNode(
"llvm.loop.mustprogress");
446 FailureOr<BoolAttr> isVectorized =
447 lookupIntNodeAsBoolAttr(
"llvm.loop.isvectorized");
448 FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses =
449 convertParallelAccesses();
// Whatever is still in the map was not recognized by any lookup above.
452 if (!propertyMap.empty()) {
453 for (
auto name : propertyMap.keys())
454 emitWarning(loc) <<
"unknown loop annotation " << name;
458 FailureOr<FusedLoc> startLoc = convertStartLoc();
459 FailureOr<FusedLoc> endLoc = convertEndLoc();
461 return createIfNonNull<LoopAnnotationAttr>(
462 ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
463 unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
464 unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
// Body fragment of a member function whose signature is missing from this
// listing (presumably the importer's translateLoopAnnotation -- confirm).
// Returns the cached attribute for `node` if one exists; otherwise converts
// the node and records the result in the mapping.
476 auto it = loopMetadataMapping.find(node);
477 if (it != loopMetadataMapping.end())
478 return it->getSecond();
// Cache miss: run a fresh conversion with this importer as context.
480 LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *
this).convert();
482 mapLoopMetadata(node, attr);
// Body fragment of a member function whose signature is missing from this
// listing (presumably the importer's translateAccessGroup -- confirm).
// Collects the access group nodes reachable from `node` and creates an
// AccessGroupAttr for each one that satisfies the checks below.
// NOTE(review): several guard/return lines are not part of this listing.
// A node without operands is treated as an access group itself.
490 if (!node->getNumOperands())
491 accessGroups.push_back(node);
// Otherwise every operand is taken as an access group node.
492 for (
const llvm::MDOperand &operand : node->operands()) {
493 auto *childNode = dyn_cast<llvm::MDNode>(operand);
496 accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
500 for (
const llvm::MDNode *accessGroup : accessGroups) {
// Test for groups that already have a mapping; the action taken is not
// visible here.
501 if (accessGroupMapping.count(accessGroup))
// An access group node has to be empty and distinct.
504 if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
506 <<
"expected an access group node to be empty and distinct";
509 accessGroupMapping[accessGroup] = builder.
getAttr<AccessGroupAttr>();
// Fragment of a member function whose name line is missing from this listing
// (presumably the importer's lookupAccessGroupAttrs -- confirm).
// Looks up the mapped AccessGroupAttr for `node` itself when it has no
// operands, and for each of its operands otherwise.
// NOTE(review): the action taken when a lookup yields a null attribute and
// the final return are not part of this listing.
514 FailureOr<SmallVector<AccessGroupAttr>>
519 if (!node->getNumOperands())
520 accessGroups.push_back(accessGroupMapping.lookup(node));
521 for (
const llvm::MDOperand &operand : node->operands()) {
// This local `node` shadows the outer `node` inside the loop.
522 auto *node = cast<llvm::MDNode>(operand.get());
523 accessGroups.push_back(accessGroupMapping.lookup(node));
// Checks whether any lookup above produced a null attribute.
526 if (llvm::is_contained(accessGroups,
nullptr))
static MLIRContext * getContext(OpFoldResult val)
static T createIfNonNull(MLIRContext *ctx, const P &...args)
Helper function that only creates an attribute of type T if all argument conversions were successful...
static bool isEmptyOrNull(const Attribute attr)
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A helper class that converts llvm.loop metadata nodes into corresponding LoopAnnotationAttrs and llvm...
LoopAnnotationAttr translateLoopAnnotation(const llvm::MDNode *node, Location loc)
LogicalResult translateAccessGroup(const llvm::MDNode *node, Location loc)
Converts all LLVM access groups starting from node to MLIR access group attributes.
FailureOr< SmallVector< AccessGroupAttr > > lookupAccessGroupAttrs(const llvm::MDNode *node) const
Returns the access group attribute that map to the access group nodes starting from the access group ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...